From 9be69c42db09262079ba4f5050ac176fd6320727 Mon Sep 17 00:00:00 2001 From: Alexandre Balmes Date: Wed, 13 May 2026 10:57:00 +0200 Subject: [PATCH] feat(tokens): unify counting via ports.Tokenizer port - `CHANGELOG.md`: Document F094 unified token counting changes - `CLAUDE.md`: Add nolint:errcheck replication rule; remove stale pitfall - `docs/development/architecture.md`: Update tokenizer/ description with ports.Tokenizer detail - `docs/development/creating-agent-provider.md`: Add new provider creation guide (1004 lines) - `docs/development/project-structure.md`: Update tokenizer/ directory description - `docs/reference/interpolation.md`: Document TokensInput, TokensOutput, TokensEstimated fields - `docs/user-guide/agent-steps.md`: Add token tracking table with new fields and provider matrix - `go.mod`: Remove tiktoken-go and glamour dependencies - `go.sum`: Remove checksums for removed dependencies - `internal/application/execution_service.go`: Propagate TokensInput, TokensOutput, TokensEstimated into step state - `internal/application/interpolation_helpers.go`: Map new token fields into interpolation context - `internal/domain/workflow/context.go`: Add TokensInput, TokensOutput, TokensEstimated to StepState - `internal/domain/workflow/reference.go`: Register new token properties in ValidStateProperties and alias map - `internal/infrastructure/agents/base_cli_provider.go`: Inject ports.Tokenizer; add extractTokenUsage hook; use real tokens when available, fallback to tokenizer estimate - `internal/infrastructure/agents/base_cli_provider_tokenizer_test.go`: Add 390-line tokenizer integration tests for execute and conversation paths - `internal/infrastructure/agents/claude_provider.go`: Wire extractTokenUsage hook from claude result event usage field - `internal/infrastructure/agents/codex_provider.go`: Wire extractTokenUsage hook from turn.completed event usage field - `internal/infrastructure/agents/copilot_provider.go`: Wire extractTokenUsage hook from 
assistant.message outputTokens field - `internal/infrastructure/agents/gemini_provider.go`: Wire extractTokenUsage hook from result event stats field - `internal/infrastructure/agents/helpers.go`: Remove dead estimateTokens and estimateInputTokens helpers - `internal/infrastructure/agents/helpers_test.go`: Remove tests for deleted estimation helpers - `internal/infrastructure/agents/opencode_provider.go`: Wire extractTokenUsage hook from step_finish part.tokens field - `internal/infrastructure/agents/options.go`: Add SetTokenizer option for baseCLIProvider injection - `internal/infrastructure/agents/provider_options_test.go`: Add tokenizer injection tests - `internal/infrastructure/tokenizer/tiktoken_tokenizer.go`: Delete TiktokenTokenizer (tiktoken dep removed) - `internal/infrastructure/tokenizer/tiktoken_tokenizer_test.go`: Delete tiktoken tokenizer tests - `pkg/interpolation/reference.go`: Register TokensInput, TokensOutput, TokensEstimated in ValidStateProperties - `pkg/interpolation/reference_json_field_test.go`: Update tests for new token property names - `pkg/interpolation/reference_test.go`: Update reference validation tests - `pkg/interpolation/resolver.go`: Handle TokensEstimated bool type in template resolver Closes #339 --- CHANGELOG.md | 6 + CLAUDE.md | 2 +- docs/development/architecture.md | 2 +- docs/development/creating-agent-provider.md | 1004 +++++++++++++++++ docs/development/project-structure.md | 2 +- docs/reference/interpolation.md | 42 +- docs/user-guide/agent-steps.md | 35 +- go.mod | 10 - go.sum | 26 - internal/application/execution_service.go | 5 + internal/application/interpolation_helpers.go | 17 +- internal/domain/workflow/context.go | 7 +- internal/domain/workflow/reference.go | 38 +- .../agents/base_cli_provider.go | 88 +- .../base_cli_provider_tokenizer_test.go | 390 +++++++ .../infrastructure/agents/claude_provider.go | 48 +- .../infrastructure/agents/codex_provider.go | 54 +- .../infrastructure/agents/copilot_provider.go | 49 +- 
.../infrastructure/agents/gemini_provider.go | 62 +- internal/infrastructure/agents/helpers.go | 57 +- .../infrastructure/agents/helpers_test.go | 93 -- .../agents/opencode_provider.go | 48 +- internal/infrastructure/agents/options.go | 30 + .../agents/provider_options_test.go | 135 ++- .../tokenizer/tiktoken_tokenizer.go | 67 -- .../tokenizer/tiktoken_tokenizer_test.go | 545 --------- pkg/interpolation/reference.go | 17 +- .../reference_json_field_test.go | 5 +- pkg/interpolation/reference_test.go | 2 +- pkg/interpolation/resolver.go | 19 +- 30 files changed, 2002 insertions(+), 903 deletions(-) create mode 100644 docs/development/creating-agent-provider.md create mode 100644 internal/infrastructure/agents/base_cli_provider_tokenizer_test.go delete mode 100644 internal/infrastructure/tokenizer/tiktoken_tokenizer.go delete mode 100644 internal/infrastructure/tokenizer/tiktoken_tokenizer_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0aa953c2..a76800f2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Changed + +- **F094**: Unified token counting via `ports.Tokenizer` port — all CLI-based agent providers (Claude, Gemini, Codex, GitHub Copilot, OpenCode) now count tokens through an injected `Tokenizer` interface instead of inline `len(output)/4` helpers; default `ApproximationTokenizer` preserves identical behavior; eliminates mutation side-effect on shared conversation turn state during input token estimation; dead `estimateTokens`/`estimateInputTokens` helpers removed; enables future swap to real token counting (e.g., tiktoken, stream-extracted counts) by changing a single injection point + ## [0.8.1] - 2026-05-11 +- **F093**: Add dangerously_skip_permissions alias for --allow-all for Github Copilot agent + ## [0.8.0] - 2026-05-09 ### Fixed diff --git a/CLAUDE.md b/CLAUDE.md index a697123e..eca80535 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ 
-243,7 +243,6 @@ func TestWorkflowValidation(t *testing.T) { ## Common Pitfalls -- When removing redundant infrastructure code, document the architectural ownership pattern; explain which layer assumed responsibility and why the field was removed - Always apply code deletions before writing tests that validate the deletion effect; tests may pass against overridden behavior instead of the intended code path - Wrap YAML/JSON mapping errors (duration parse, type conversion) in domain error types; surface failures immediately to prevent silent defaults - Never merge infrastructure provider stubs; always implement ExecuteConversation fully or return NotImplementedError with linked tracking issue @@ -284,6 +283,7 @@ func TestWorkflowValidation(t *testing.T) { - Never silently initialize nested struct fields during YAML unmarshaling; explicitly map all sections (events, metadata, etc.) to prevent zero values from hiding parsing bugs - Always stage all modified implementation files and run 'git status' before marking task complete; unstaged files indicate incomplete task closure. - Update plan task status immediately when implementation completes; regenerate validation report to catch status-code mismatches before submission. +- Always replicate nolint:errcheck directives identically across all provider implementations; verify explanatory comments match before make lint ## Test Conventions diff --git a/docs/development/architecture.md b/docs/development/architecture.md index 83e508d0..a8c4b8c3 100644 --- a/docs/development/architecture.md +++ b/docs/development/architecture.md @@ -207,7 +207,7 @@ Implements domain ports with concrete technologies. 
- `pluginmgr/` - Plugin lifecycle (manifest, state, gRPC connections); delegates transport to `pkg/registry/` - `repository/` - YAML file loader implementing `Repository` - `store/` - JSON state store implementing `StateStore`, SQLite history storage -- `tokenizer/` - Token counting for conversation context management +- `tokenizer/` - Token counting implementation (`ApproximationTokenizer`) implementing `ports.Tokenizer`; injected into `baseCLIProvider` for unified token counting across all CLI agent providers - `xdg/` - XDG directory discovery **Shared Packages (`pkg/`):** diff --git a/docs/development/creating-agent-provider.md b/docs/development/creating-agent-provider.md new file mode 100644 index 00000000..d685ecd6 --- /dev/null +++ b/docs/development/creating-agent-provider.md @@ -0,0 +1,1004 @@ +--- +title: "Creating an Agent Provider" +--- + +Guide for implementing a new agent provider in AWF. Covers the domain contract, infrastructure base layer, hooks, options, display events, session management, and registration. + +## Architecture + +Agent providers live in the **infrastructure layer** and implement the `ports.AgentProvider` interface defined in the **domain layer**. The base infrastructure handles execution orchestration, token counting, state cloning, and stream filtering. Each provider only implements the provider-specific parts via hooks. 
+ +``` +Domain Layer (ports) Infrastructure Layer (agents) +┌───────────────────────┐ ┌──────────────────────────────────┐ +│ AgentProvider │◄────────│ baseCLIProvider │ +│ CLIExecutor │ │ ├── execute() │ +│ Tokenizer │ │ ├── executeConversation() │ +│ Logger │ │ └── cliProviderHooks{...} │ +└───────────────────────┘ │ │ + │ YourProvider │ + │ ├── newBase() → hooks wiring │ + │ ├── buildExecuteArgs() │ + │ ├── buildConversationArgs() │ + │ ├── extractSessionID() │ + │ ├── parseDisplayEvents() │ + │ └── validateOptions() │ + └──────────────────────────────────┘ +``` + +## Domain Contract + +### AgentProvider Interface + +**File:** `internal/domain/ports/agent_provider.go` + +```go +type AgentProvider interface { + Execute(ctx context.Context, prompt string, options map[string]any, + stdout, stderr io.Writer) (*workflow.AgentResult, error) + ExecuteConversation(ctx context.Context, state *workflow.ConversationState, + prompt string, options map[string]any, + stdout, stderr io.Writer) (*workflow.ConversationResult, error) + Name() string + Validate() error +} +``` + +| Method | Purpose | +|--------|---------| +| `Execute` | Single-turn prompt execution. Returns `AgentResult` with output, tokens, timing. | +| `ExecuteConversation` | Multi-turn execution with conversation state. Returns `ConversationResult` with updated state. | +| `Name` | Unique provider identifier used in workflow YAML (`provider: your_name`). | +| `Validate` | Pre-flight check (binary in PATH, API key set, etc.). Called before first execution. 
| + +### AgentResult + +**File:** `internal/domain/workflow/agent_config.go` + +```go +type AgentResult struct { + Provider string + Output string // extracted text output + DisplayOutput string // filtered output for terminal display + Response map[string]any // parsed JSON response (optional) + Tokens int + TokensEstimated bool + Error error + StartedAt time.Time + CompletedAt time.Time +} +``` + +### ConversationResult + +**File:** `internal/domain/workflow/conversation.go` + +```go +type ConversationResult struct { + Provider string + State *ConversationState // updated state with new turns + Output string // last assistant response + DisplayOutput string + Response map[string]any + TokensInput int + TokensOutput int + TokensTotal int + TokensEstimated bool + Error error + StartedAt time.Time + CompletedAt time.Time +} +``` + +### ConversationState + +```go +type ConversationState struct { + SessionID string + Turns []Turn + TotalTurns int + TotalTokens int + StoppedBy StopReason +} + +type Turn struct { + Role TurnRole // "system", "user", "assistant" + Content string + Tokens int +} +``` + +## Base Layer: baseCLIProvider + +All CLI-based providers delegate to `baseCLIProvider`, which handles: + +- Prompt validation +- CLI binary execution via `CLIExecutor` +- Stream filtering and display event rendering +- Token counting via injected `Tokenizer` +- Conversation state cloning and turn management +- Timing (StartedAt / CompletedAt) + +### Hooks + +Provider-specific behavior is injected via `cliProviderHooks`: + +```go +type cliProviderHooks struct { + buildExecuteArgs func(prompt string, options map[string]any) ([]string, error) + buildConversationArgs func(state *workflow.ConversationState, prompt string, options map[string]any) ([]string, error) + extractSessionID func(output string) (string, error) + extractTextContent func(output string) string // optional + validateOptions func(options map[string]any) error // optional + parseDisplayEvents 
DisplayEventParser // optional +} +``` + +| Hook | Required | Purpose | +|------|----------|---------| +| `buildExecuteArgs` | **yes** | Construct CLI argv for single-turn execution. | +| `buildConversationArgs` | **yes** | Construct CLI argv for multi-turn execution (session resume). | +| `extractSessionID` | **yes** | Parse session/thread ID from CLI output for conversation resume. | +| `extractTextContent` | no | Extract human-readable text from structured output (e.g., JSON wrapper). Falls back to raw output if nil. | +| `validateOptions` | no | Validate provider-specific options before execution. Return error to reject. | +| `parseDisplayEvents` | no | Parse a single NDJSON line into `[]DisplayEvent` for real-time terminal display. | + +### What baseCLIProvider Does For You + +**In `execute()`:** +1. Rejects empty prompts +2. Calls `validateOptions` hook (if set) +3. Calls `buildExecuteArgs` hook to get CLI arguments +4. Runs binary via `CLIExecutor.Run()` +5. Filters output through `StreamFilterWriter` (if `parseDisplayEvents` set) +6. Counts output tokens: `b.tokenizer.CountTokens(output)` +7. Builds and returns `AgentResult` + +**In `executeConversation()`:** +1. Clones conversation state (caller's original is never mutated) +2. Appends user turn to cloned state +3. Calls `validateOptions` and `buildConversationArgs` hooks +4. Runs binary +5. Calls `extractSessionID` hook, updates state +6. Appends assistant turn to state +7. Counts input tokens (`CountTurnsTokens`) and output tokens (`CountTokens`) +8. Builds and returns `ConversationResult` + +## Step-by-Step Implementation + +### 1. 
Create the provider file + +**File:** `internal/infrastructure/agents/myprovider_provider.go` + +```go +package agents + +import ( + "context" + "fmt" + "io" + "os/exec" + + "github.com/awf-project/cli/internal/domain/ports" + "github.com/awf-project/cli/internal/domain/workflow" + "github.com/awf-project/cli/internal/infrastructure/logger" +) + +type MyProviderProvider struct { + base *baseCLIProvider + logger ports.Logger + executor ports.CLIExecutor + tokenizer ports.Tokenizer +} +``` + +### 2. Add constructors + +Two constructors are required: a zero-config default and a functional-options variant. + +```go +func NewMyProviderProvider() *MyProviderProvider { + p := &MyProviderProvider{ + logger: logger.NopLogger{}, + executor: NewExecCLIExecutor(), + } + p.base = p.newBase() + return p +} + +func NewMyProviderProviderWithOptions(opts ...MyProviderProviderOption) *MyProviderProvider { + p := &MyProviderProvider{ + logger: logger.NopLogger{}, + executor: NewExecCLIExecutor(), + } + for _, opt := range opts { + opt(p) + } + p.base = p.newBase() + return p +} +``` + +> **Important:** `p.newBase()` must be called **after** applying options, since options may set the executor, logger, or tokenizer that `newBase` forwards. + +### 3. Wire the hooks via newBase() + +```go +func (p *MyProviderProvider) newBase() *baseCLIProvider { + b := newBaseCLIProvider("myprovider", "myprovider-cli", p.executor, p.logger, cliProviderHooks{ + buildExecuteArgs: p.buildExecuteArgs, + buildConversationArgs: p.buildConversationArgs, + extractSessionID: p.extractSessionID, + validateOptions: validateMyProviderOptions, + parseDisplayEvents: p.parseMyProviderDisplayEvents, + }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b +} +``` + +**Parameters to `newBaseCLIProvider`:** + +| Parameter | Value | +|-----------|-------| +| `name` | Provider identifier returned by `Name()`. Used in `AgentResult.Provider`. Must match the value users write in `provider:` YAML field. 
| +| `binary` | CLI binary name looked up in `$PATH`. | +| `executor` | The `CLIExecutor` to run the binary. Always forward `p.executor`. | +| `log` | Logger. Nil-defaults to `NopLogger`. | +| `hooks` | Provider-specific hooks (see table above). | + +### 4. Implement the required hooks + +#### buildExecuteArgs + +Construct the CLI arguments for a single-turn call. + +```go +func (p *MyProviderProvider) buildExecuteArgs(prompt string, options map[string]any) ([]string, error) { + args := []string{"run", "--prompt", prompt, "--format", "json"} + + if model, ok := getStringOption(options, "model"); ok { + args = append(args, "--model", model) + } + if skip, ok := getBoolOption(options, "dangerously_skip_permissions"); ok && skip { + args = append(args, "--yes") + } + + return args, nil +} +``` + +**Available helpers:** `getStringOption(options, key)`, `getBoolOption(options, key)` — type-safe extraction from `map[string]any`. + +#### buildConversationArgs + +Construct CLI arguments for multi-turn. Must handle session resume vs first turn. + +```go +func (p *MyProviderProvider) buildConversationArgs( + state *workflow.ConversationState, prompt string, options map[string]any, +) ([]string, error) { + var args []string + if state.SessionID != "" { + args = []string{"resume", state.SessionID, "--prompt", prompt, "--format", "json"} + } else { + effectivePrompt := buildFirstTurnPrompt(prompt, options) + args = []string{"run", "--prompt", effectivePrompt, "--format", "json"} + } + + if model, ok := getStringOption(options, "model"); ok { + args = append(args, "--model", model) + } + + return args, nil +} +``` + +**Key patterns:** +- Use `state.SessionID` to detect resume vs new conversation. +- Use `buildFirstTurnPrompt(prompt, options)` to inline `system_prompt` into the first message when the CLI has no native `--system-prompt` flag. +- Always force a structured output format (JSON/NDJSON) for reliable parsing. 
+ +#### extractSessionID + +Parse the session identifier from CLI output so subsequent turns can resume. + +```go +func (p *MyProviderProvider) extractSessionID(output string) (string, error) { + if output == "" { + return "", errors.New("empty output") + } + evt := findFirstNDJSONEvent(output, "session_start") + if evt == nil { + return "", errors.New("session_start event not found") + } + id, ok := evt["session_id"].(string) + if !ok || id == "" { + return "", errors.New("session_id missing or empty") + } + return id, nil +} +``` + +**Available helper:** `findFirstNDJSONEvent(output, eventType)` — scans NDJSON output line-by-line for the first `{"type": eventType, ...}` event and returns it as `map[string]any`. + +> Session ID extraction errors are **non-fatal**. The base layer logs the error and continues in stateless mode. The conversation still works; it just cannot resume on the next turn. + +### 5. Implement the optional hooks + +#### validateOptions + +Reject invalid option combinations before execution. + +```go +func validateMyProviderOptions(options map[string]any) error { + if options == nil { + return nil + } + if model, ok := getStringOption(options, "model"); ok { + if !strings.HasPrefix(model, "myprovider-") { + return fmt.Errorf("invalid model: %s (must start with 'myprovider-')", model) + } + } + return nil +} +``` + +#### parseDisplayEvents + +Parse a single NDJSON line into display events for real-time terminal rendering. 
+ +```go +func (p *MyProviderProvider) parseMyProviderDisplayEvents(line []byte) []DisplayEvent { + var evt struct { + Type string `json:"type"` + Content string `json:"content"` + Tool string `json:"tool_name"` + } + if err := json.Unmarshal(line, &evt); err != nil { + return nil + } + + switch evt.Type { + case "text": + return []DisplayEvent{{Kind: EventText, Text: evt.Content}} + case "tool_call": + return []DisplayEvent{{Kind: EventToolUse, Name: evt.Tool}} + } + return nil +} +``` + +**Display event kinds:** + +| Constant | Purpose | +|----------|---------| +| `EventText` | Text content from the assistant. Aggregated for `DisplayOutput`. | +| `EventToolUse` | Tool invocation. Rendered as tool name + argument preview. | + +**DisplayEvent fields:** + +| Field | Required | Purpose | +|-------|----------|---------| +| `Kind` | **yes** | `EventText` or `EventToolUse` | +| `Text` | for text | The text content | +| `Name` | for tools | Tool name | +| `Arg` | no | Truncated argument preview. Use `extractArgPreviewFromMap(args)` or `extractArgPreview(jsonStr)`. | +| `ID` | no | Tool call ID (empty if provider doesn't emit one) | +| `Delta` | no | `true` for streaming deltas (partial text chunks) | +| `Type` | no | Raw event type from provider output (for debugging) | + +### 6. Implement the AgentProvider interface methods + +#### Execute + +Delegate to `p.base.execute()`, then apply provider-specific post-processing. 
+ +```go +func (p *MyProviderProvider) Execute( + ctx context.Context, prompt string, options map[string]any, stdout, stderr io.Writer, +) (*workflow.AgentResult, error) { + result, rawOutput, err := p.base.execute(ctx, prompt, options, stdout, stderr) + if err != nil { + return nil, err + } + + // Post-processing: extract text from structured output + if extracted := extractDisplayTextFromEvents(rawOutput, p.parseMyProviderDisplayEvents); extracted != "" { + result.Output = extracted + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } + + // Optional: parse JSON response + userFormat, _ := getStringOption(options, "output_format") + if userFormat == "json" || userFormat == "stream-json" { + if jsonResp := tryParseJSONResponse(rawOutput); jsonResp != nil { + result.Response = jsonResp + } + } + + return result, nil +} +``` + +**Why post-process?** When the CLI outputs NDJSON (events), the raw output is not human-readable. Post-processing extracts the actual assistant text and re-counts tokens on the extracted content. + +#### ExecuteConversation + +Most providers simply delegate without post-processing: + +```go +func (p *MyProviderProvider) ExecuteConversation( + ctx context.Context, state *workflow.ConversationState, prompt string, + options map[string]any, stdout, stderr io.Writer, +) (*workflow.ConversationResult, error) { + result, _, err := p.base.executeConversation(ctx, state, prompt, options, stdout, stderr) + if err != nil { + return nil, err + } + return result, nil +} +``` + +#### Name and Validate + +```go +func (p *MyProviderProvider) Name() string { + return "myprovider" +} + +func (p *MyProviderProvider) Validate() error { + _, err := exec.LookPath("myprovider-cli") + if err != nil { + return fmt.Errorf("myprovider-cli not found in PATH: %w", err) + } + return nil +} +``` + +### 7. 
Add functional options + +**File:** `internal/infrastructure/agents/options.go` + +```go +type MyProviderProviderOption func(*MyProviderProvider) + +func WithMyProviderExecutor(executor ports.CLIExecutor) MyProviderProviderOption { + return func(p *MyProviderProvider) { + p.executor = executor + } +} + +func WithMyProviderTokenizer(tok ports.Tokenizer) MyProviderProviderOption { + return func(p *MyProviderProvider) { + p.tokenizer = tok + } +} + +func WithMyProviderLogger(l ports.Logger) MyProviderProviderOption { + return func(p *MyProviderProvider) { + p.logger = l + } +} +``` + +### 8. Register in the registry + +**File:** `internal/infrastructure/agents/registry.go` + +Add to `RegisterDefaults()`: + +```go +func (r *AgentRegistry) RegisterDefaults() error { + defaults := []ports.AgentProvider{ + NewClaudeProvider(), + NewCodexProvider(), + NewGeminiProvider(), + NewOpenAICompatibleProvider(), + NewOpenCodeProvider(), + NewCopilotProvider(), + NewMyProviderProvider(), // <-- add here + } + // ... +} +``` + +## Testing + +### Option tests + +**File:** `internal/infrastructure/agents/provider_options_test.go` + +```go +func TestWithMyProviderTokenizer(t *testing.T) { + tok := &mockTokenizer{countTokensResult: 99} + provider := NewMyProviderProviderWithOptions( + WithMyProviderExecutor(mocks.NewMockCLIExecutor()), + WithMyProviderTokenizer(tok), + ) + assert.Equal(t, tok, provider.base.tokenizer) +} +``` + +### Argument construction tests + +Test that `buildExecuteArgs` and `buildConversationArgs` produce correct CLI arguments for all option combinations. 
+ +```go +func TestMyProvider_BuildExecuteArgs(t *testing.T) { + tests := []struct { + name string + prompt string + options map[string]any + wantArgs []string + wantErr bool + }{ + { + name: "basic prompt", + prompt: "hello", + options: nil, + wantArgs: []string{"run", "--prompt", "hello", "--format", "json"}, + }, + { + name: "with model", + prompt: "hello", + options: map[string]any{"model": "myprovider-large"}, + wantArgs: []string{"run", "--prompt", "hello", "--format", "json", "--model", "myprovider-large"}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewMyProviderProvider() + args, err := p.buildExecuteArgs(tt.prompt, tt.options) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantArgs, args) + }) + } +} +``` + +### Session ID extraction tests + +```go +func TestMyProvider_ExtractSessionID(t *testing.T) { + tests := []struct { + name string + output string + wantID string + wantErr bool + }{ + { + name: "valid session", + output: `{"type":"session_start","session_id":"abc-123"}`, + wantID: "abc-123", + }, + { + name: "missing event", + output: `{"type":"text","content":"hello"}`, + wantErr: true, + }, + { + name: "empty output", + output: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewMyProviderProvider() + id, err := p.extractSessionID(tt.output) + if tt.wantErr { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.wantID, id) + }) + } +} +``` + +### Display event parser tests + +```go +func TestMyProvider_ParseDisplayEvents(t *testing.T) { + p := NewMyProviderProvider() + + t.Run("text event", func(t *testing.T) { + events := p.parseMyProviderDisplayEvents([]byte(`{"type":"text","content":"hello"}`)) + require.Len(t, events, 1) + assert.Equal(t, EventText, events[0].Kind) + assert.Equal(t, "hello", events[0].Text) + }) + + t.Run("tool event", func(t *testing.T) { + events := 
p.parseMyProviderDisplayEvents([]byte(`{"type":"tool_call","tool_name":"read_file"}`)) + require.Len(t, events, 1) + assert.Equal(t, EventToolUse, events[0].Kind) + assert.Equal(t, "read_file", events[0].Name) + }) + + t.Run("unknown event returns nil", func(t *testing.T) { + events := p.parseMyProviderDisplayEvents([]byte(`{"type":"unknown"}`)) + assert.Nil(t, events) + }) + + t.Run("invalid JSON returns nil", func(t *testing.T) { + events := p.parseMyProviderDisplayEvents([]byte(`not json`)) + assert.Nil(t, events) + }) +} +``` + +### Option validation tests + +```go +func TestMyProvider_ValidateOptions(t *testing.T) { + tests := []struct { + name string + options map[string]any + wantErr bool + }{ + {name: "nil options", options: nil}, + {name: "valid model", options: map[string]any{"model": "myprovider-large"}}, + {name: "invalid model", options: map[string]any{"model": "gpt-4"}, wantErr: true}, + {name: "unknown option ignored", options: map[string]any{"unknown": "value"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := validateMyProviderOptions(tt.options) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + }) + } +} +``` + +## Mandatory Cross-Provider Conventions + +Every provider **must** handle these patterns. Omitting any of them creates inconsistency for users who switch between providers in their workflows. + +### Force structured output format + +All CLI providers force NDJSON/JSON output at the CLI level, regardless of what the user requests. This ensures consistent session ID extraction, display event filtering, and text extraction. + +```go +// The user's output_format preference controls post-processing (display vs raw), +// but the wire format is always NDJSON. 
+func (p *MyProviderProvider) buildExecuteArgs(prompt string, options map[string]any) ([]string, error) { + args := []string{"run", "--prompt", prompt} + args = append(args, "--format", "json") // always force structured output + // ... +} +``` + +How each provider does it: + +| Provider | Forced flag | +|----------|-------------| +| Claude | `--output-format stream-json --verbose` | +| Gemini | `--output-format stream-json` | +| Codex | `exec --json` | +| Copilot | `--output-format=json --silent` | +| OpenCode | `--format json` | + +### Handle `dangerously_skip_permissions` + +This option is **cross-provider** — users expect it to work in any workflow regardless of provider. Each CLI maps it to its own flag: + +```go +// In buildExecuteArgs and buildConversationArgs: +if skip, ok := getBoolOption(options, "dangerously_skip_permissions"); ok && skip { + args = append(args, "--your-cli-equivalent-flag") +} +``` + +| Provider | CLI flag | +|----------|----------| +| Claude | `--dangerously-skip-permissions` | +| Gemini | `--approval-mode=yolo` | +| Codex | `--dangerously-bypass-approvals-and-sandbox` | +| Copilot | `--allow-all` | +| OpenCode | Not supported (logged at debug level, silently ignored) | + +If your CLI has no equivalent, log a debug message and ignore: + +```go +if skip, ok := getBoolOption(options, "dangerously_skip_permissions"); ok && skip { + p.logger.Debug("dangerously_skip_permissions is not supported by myprovider and will be ignored") +} +``` + +### Handle `system_prompt` + +Only Claude has a native `--system-prompt` flag. 
All other providers inline it into the first turn's message using the shared helper: + +```go +// In buildConversationArgs, for the first turn (no session ID): +effectivePrompt := buildFirstTurnPrompt(prompt, options) +// Returns: "system prompt content\n\nuserPrompt" or just "userPrompt" if no system_prompt +``` + +If your CLI has a native system prompt flag, use it directly instead: + +```go +if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { + args = append(args, "--system-prompt", sysPrompt) +} +``` + +System prompt must only be applied on the **first turn**. On subsequent turns (when `state.SessionID != ""`), the provider's session already retains the system context. + +### Handle `model` + +Every provider must support the `model` option. Validate the model name in `validateOptions` to reject models incompatible with your CLI: + +```go +func validateMyProviderOptions(options map[string]any) error { + if options == nil { + return nil + } + if model, ok := getStringOption(options, "model"); ok { + if !strings.HasPrefix(model, "myprovider-") { + return fmt.Errorf("invalid model: %s (must start with 'myprovider-')", model) + } + } + return nil +} +``` + +### Handle `output_format` for response parsing + +The `output_format` option controls what the user sees. When the user requests `json` or `stream-json`, expose the parsed JSON response in `result.Response`: + +```go +// In Execute(), after text extraction: +userFormat, _ := getStringOption(options, "output_format") +if userFormat == "json" || userFormat == "stream-json" { + if jsonResp := tryParseJSONResponse(rawOutput); jsonResp != nil { + result.Response = jsonResp + } +} +``` + +### Ignore unknown options silently + +Go's `map[string]any` behavior means unsupported option keys are simply not looked up. Never iterate over options to reject unknown keys — this allows cross-provider workflows to pass provider-specific options that only apply to certain providers. 
+ +### Token counting pattern + +Every `CountTokens` call in provider code must use the `//nolint:errcheck` directive with an explanatory comment. This is enforced by `golangci-lint` with `check-blank: true`: + +```go +tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio +result.Tokens = tokens +``` + +### NUL byte sanitization in display event parsers + +CLI tools may output NUL bytes (`0x00`) that break `json.Unmarshal`. Sanitize before parsing: + +```go +func (p *MyProviderProvider) parseMyProviderDisplayEvents(line []byte) []DisplayEvent { + // Escape NUL bytes to valid JSON unicode sequences + sanitized := bytes.ReplaceAll(line, []byte{0x00}, []byte(`\u0000`)) + + var evt struct { /* ... */ } + if err := json.Unmarshal(sanitized, &evt); err != nil { + return nil + } + // ... +} +``` + +Codex and OpenCode use this escape pattern. Claude replaces NUL with spaces instead. + +### Error handling conventions + +| Scenario | Handling | +|----------|----------| +| `Validate()` — binary not found | Return `fmt.Errorf("binary not found in PATH: %w", err)` | +| `extractSessionID` fails | **Non-fatal.** Base layer logs at debug and continues stateless. | +| JSON parsing fails in `Execute()` | **Non-fatal.** `result.Response` stays nil. | +| `validateOptions` returns error | **Fatal.** Execution is aborted before running the CLI. | +| Empty output from CLI | Base layer substitutes `" "` (single space) to prevent zero-length issues. | + +### Apply `dangerously_skip_permissions` in both arg builders + +The `buildExecuteArgs` and `buildConversationArgs` hooks must **both** handle `dangerously_skip_permissions` (and `model`, etc.). Users don't know which execution path their workflow triggers — missing the option in one path creates hard-to-debug inconsistencies. 
+ +### extractTextContent vs extractDisplayTextFromEvents + +Two mechanisms exist for extracting human-readable text from structured output: + +| Mechanism | When to use | +|-----------|-------------| +| `extractTextContent` hook | Your CLI wraps the final answer in a specific JSON envelope (e.g., Claude's `result` event, Copilot's `assistant.message` event). Set this hook to extract from that envelope. | +| `extractDisplayTextFromEvents()` | Your CLI outputs NDJSON events where text is spread across multiple `EventText` events. This helper aggregates all text events via your `parseDisplayEvents` hook. | + +Most providers use `extractDisplayTextFromEvents` in their `Execute()` post-processing. Only set `extractTextContent` if your provider needs a different extraction strategy for `executeConversation`. + +## Existing Providers Reference + +| Provider | Binary | Name | Session Event | Session Field | Resume Flag | System Prompt | +|----------|--------|------|---------------|---------------|-------------|---------------| +| Claude | `claude` | `claude` | `result` | `session_id` | `-r ID` | `--system-prompt` (native) | +| Gemini | `gemini` | `gemini` | `init` | `session_id` | `--resume ID` | Inlined in first turn | +| Codex | `codex` | `codex` | `thread.started` | `thread_id` | `resume ID` (subcommand) | Inlined in first turn | +| Copilot | `copilot` | `github_copilot` | `result` | `sessionId` (camelCase) | `--resume=ID` | Inlined in first turn | +| OpenCode | `opencode` | `opencode` | `step_start` | `sessionID` | `-s ID` / `-c` (fallback) | Inlined in first turn | +| OpenAI-Compatible | HTTP API | `openai_compatible` | API response | N/A | Messages array | `system` role message | + +## Non-CLI Provider (HTTP API) + +`OpenAICompatibleProvider` follows a completely different path from CLI-based providers. It implements `AgentProvider` **directly** without using `baseCLIProvider`, hooks, or any of the CLI infrastructure. 
+ +### What changes vs CLI providers + +| Aspect | CLI providers | HTTP provider (OpenAI-Compatible) | +|--------|--------------|----------------------------------| +| Execution | `CLIExecutor.Run()` → binary subprocess | `httpx.Client` → HTTP POST to `/chat/completions` | +| Token counting | `extractTokenUsage` hook → real counts when the CLI output reports usage; otherwise `ports.Tokenizer` → estimation (`len/4`), `TokensEstimated: true` | API response `usage` field → exact counts, `TokensEstimated: false` | +| Session management | Extract session ID from NDJSON, resume via CLI flag | No session ID — full messages array sent each turn | +| System prompt | Inlined in first turn or native CLI flag | `system` role message in messages array | +| Display events | NDJSON stream filtering via `DisplayEventParser` | Direct write to stdout, no parsing needed | +| State cloning | Done by `baseCLIProvider.executeConversation()` | Must call `cloneState()` manually | +| Base struct | `base *baseCLIProvider` field | No base — flat struct with `httpClient *httpx.Client` | + +### Token counting: the key difference + +CLI providers read real usage from the CLI's JSON output through the `extractTokenUsage` hook when it is present, and fall back to a `ports.Tokenizer` estimate when it is not: + +```go +// CLI provider pattern — estimation fallback (no usage data in the CLI output) +tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio +result.Tokens = tokens +result.TokensEstimated = true // set by tokenizer.IsEstimate() +``` + +The HTTP provider gets exact counts from the API response: + +```go +// HTTP provider pattern — exact counts from API +result.Tokens = resp.Usage.TotalTokens +result.TokensEstimated = false + +// In ExecuteConversation, input/output are separated: +result.TokensInput = resp.Usage.PromptTokens +result.TokensOutput = resp.Usage.CompletionTokens +result.TokensTotal = resp.Usage.TotalTokens +``` + +No `Tokenizer` port is used. No `//nolint:errcheck` is needed. 
+ +### Conversation: messages array vs session resume + +CLI providers maintain a session ID and pass it as a CLI flag to resume: + +```go +// CLI: resume with session ID +args = []string{"--resume", state.SessionID, "-p", prompt} +``` + +The HTTP provider reconstructs the full messages array from conversation state on every turn: + +```go +// HTTP: rebuild messages from turns +messages := make([]chatMessage, 0, len(state.Turns)+2) +if opts.systemPrompt != "" { + messages = append(messages, chatMessage{Role: "system", Content: opts.systemPrompt}) +} +for _, turn := range state.Turns { + messages = append(messages, chatMessage{Role: string(turn.Role), Content: turn.Content}) +} +messages = append(messages, chatMessage{Role: "user", Content: prompt}) +``` + +### Struct and constructor + +```go +type OpenAICompatibleProvider struct { + httpClient *httpx.Client // no base, no logger, no executor, no tokenizer +} + +func NewOpenAICompatibleProvider(opts ...OpenAICompatibleProviderOption) *OpenAICompatibleProvider { + p := &OpenAICompatibleProvider{ + httpClient: httpx.NewClient(), + } + for _, opt := range opts { + opt(p) + } + return p +} +``` + +### Option handling + +Options are parsed into a dedicated `parsedOptions` struct with env var fallbacks: + +```go +type parsedOptions struct { + baseURL string // required — env: OPENAI_BASE_URL + model string // required — env: OPENAI_MODEL + apiKey string // optional — env: OPENAI_API_KEY + systemPrompt string + temperature *float64 // 0.0–2.0 + maxCompletionTokens *int + topP *float64 // 0.0–1.0 +} +``` + +### When to use this pattern + +Use the HTTP provider pattern (not `baseCLIProvider`) when: +- Your provider is an HTTP API, not a CLI binary +- The API returns exact token counts in its response +- Conversation is managed via a messages array, not session IDs +- There is no NDJSON stream to parse + +Use `OpenAICompatibleProvider` as your reference implementation. 
+ +## Checklist + +### Structure +- [ ] Provider struct with `base`, `logger`, `executor`, `tokenizer` fields +- [ ] `NewXxxProvider()` zero-config constructor +- [ ] `NewXxxProviderWithOptions()` functional-options constructor +- [ ] `newBase()` called **after** options, wires all hooks, forwards tokenizer with nil-check +- [ ] Option types added to `options.go` (`WithXxxExecutor`, `WithXxxTokenizer`, `WithXxxLogger`) +- [ ] Provider registered in `registry.go` `RegisterDefaults()` + +### Hooks (required) +- [ ] `buildExecuteArgs` forces structured output format (JSON/NDJSON) +- [ ] `buildConversationArgs` handles first turn vs session resume +- [ ] `extractSessionID` parses session ID from provider-specific event + +### Cross-provider options (mandatory) +- [ ] `model` handled in both `buildExecuteArgs` and `buildConversationArgs` +- [ ] `model` validated in `validateOptions` (prefix check or allowlist) +- [ ] `dangerously_skip_permissions` mapped to CLI-specific flag (or logged + ignored if unsupported) +- [ ] `system_prompt` handled via native flag or `buildFirstTurnPrompt()` on first turn only +- [ ] `output_format` checked in `Execute()` to conditionally expose `result.Response` +- [ ] Unknown options silently ignored (never iterate to reject) +- [ ] All options applied in **both** `buildExecuteArgs` and `buildConversationArgs` + +### Execute post-processing +- [ ] Text extracted from NDJSON via `extractDisplayTextFromEvents` or `extractTextContent` +- [ ] Tokens re-counted on extracted text (not raw output) +- [ ] `//nolint:errcheck` with explanatory comment on every `CountTokens` call +- [ ] JSON response parsed when `output_format` is `json` or `stream-json` + +### Interface methods +- [ ] `Execute` delegates to `p.base.execute()` with post-processing +- [ ] `ExecuteConversation` delegates to `p.base.executeConversation()` +- [ ] `Name()` returns unique provider identifier (matches `provider:` YAML field) +- [ ] `Validate()` checks binary via 
`exec.LookPath` with `%w` error wrapping + +### Display events +- [ ] `parseDisplayEvents` handles text events (`EventText`) and tool events (`EventToolUse`) +- [ ] NUL bytes sanitized before `json.Unmarshal` +- [ ] Unknown/malformed events return `nil` (never error) + +### Tests +- [ ] Option injection tests (`TestWithXxxTokenizer`, `TestWithXxxExecutor`) +- [ ] `buildExecuteArgs` table-driven tests (basic, with model, with permissions) +- [ ] `buildConversationArgs` tests (first turn with system_prompt, resume with session ID) +- [ ] `extractSessionID` tests (valid, missing event, empty output) +- [ ] `parseDisplayEvents` tests (text, tool, unknown, invalid JSON) +- [ ] `validateOptions` tests (nil, valid, invalid model, unknown option) + +### Final verification +- [ ] `make build` passes +- [ ] `make lint` passes with zero violations +- [ ] `make test` passes +- [ ] `grep -rn "dangerously_skip_permissions" your_provider.go` returns at least one match diff --git a/docs/development/project-structure.md b/docs/development/project-structure.md index 909be019..b8ce4ef0 100644 --- a/docs/development/project-structure.md +++ b/docs/development/project-structure.md @@ -58,7 +58,7 @@ awf/ │ │ ├── plugin/ # RPC plugin manager, composite provider │ │ ├── repository/ # YAML workflow loaders │ │ ├── store/ # SQLite history, JSON state store -│ │ ├── tokenizer/ # Token counting +│ │ ├── tokenizer/ # Token counting (ports.Tokenizer implementations) │ │ └── xdg/ # XDG directory discovery │ │ │ └── interfaces/ # External interfaces diff --git a/docs/reference/interpolation.md b/docs/reference/interpolation.md index b56ae46a..b8814637 100644 --- a/docs/reference/interpolation.md +++ b/docs/reference/interpolation.md @@ -45,7 +45,10 @@ Access output, exit code, and token usage from previous steps: ```yaml {{.states.step_name.Output}} # Command output (raw text, or cleaned if output_format set) {{.states.step_name.ExitCode}} # Exit code (0 for success, non-zero for failure) 
-{{.states.step_name.TokensUsed}} # Tokens consumed by agent steps +{{.states.step_name.TokensUsed}} # Total tokens consumed by agent steps +{{.states.step_name.TokensInput}} # Input tokens (prompt + context) +{{.states.step_name.TokensOutput}} # Output tokens (assistant response) +{{.states.step_name.TokensEstimated}} # true if token counts are estimates, false if from provider {{.states.step_name.Response.field}} # Parsed field from operation/agent structured output (heuristic) {{.states.step_name.JSON.field}} # Parsed field from output_format: json (explicit) ``` @@ -98,7 +101,7 @@ On POSIX systems, exit codes are typically 0–255. Exit code 0 indicates succes #### TokensUsed -Tokens consumed by agent steps (Claude, Gemini, Codex). Available for all agent step types: +Total tokens consumed by agent steps. Available for all agent step types: ```yaml run_agent: @@ -121,7 +124,40 @@ transitions: goto: token_exceeded ``` -**Note**: Replaced deprecated `states.step_name.Tokens` field. If migrating from earlier versions, update workflow YAML expressions from `{{.states.step_name.Tokens}}` to `{{.states.step_name.TokensUsed}}`. +#### TokensInput / TokensOutput + +Separate input and output token counts. Available for all agent step types: + +```yaml +log_details: + type: step + command: | + echo "Input: {{.states.analyze.TokensInput}}, Output: {{.states.analyze.TokensOutput}}" +``` + +In conversation mode (`continue_from`), `TokensInput` includes all prior turns. In single-turn mode, `TokensInput` is `0` and `TokensOutput` equals `TokensUsed`. + +#### TokensEstimated + +Boolean indicating whether token counts are exact (from the provider's JSON output) or estimated (`len/4` approximation). 
Use it to decide whether to trust the values for billing or budgeting: + +```yaml +transitions: + - when: "states.analyze.TokensEstimated == false and states.analyze.TokensUsed > inputs.budget" + goto: over_budget +``` + +| Provider | TokensEstimated | Source | +|----------|----------------|--------| +| Claude | `false` | `result` event `usage` field | +| Gemini | `false` | `result` event `stats` field | +| Codex | `false` | `turn.completed` event `usage` field | +| Copilot | `false` (output only) | `assistant.message` event `outputTokens` | +| OpenCode | `false` | `step_finish` event `part.tokens` field | +| OpenAI-Compatible | `false` | API response `usage` field | +| Any provider (fallback) | `true` | `len(output) / 4` estimation | + +**Note**: `TokensUsed` replaced deprecated `states.step_name.Tokens` field. If migrating from earlier versions, update workflow YAML expressions from `{{.states.step_name.Tokens}}` to `{{.states.step_name.TokensUsed}}`. #### Response (Operation Outputs) diff --git a/docs/user-guide/agent-steps.md b/docs/user-guide/agent-steps.md index 126ea042..f5241713 100644 --- a/docs/user-guide/agent-steps.md +++ b/docs/user-guide/agent-steps.md @@ -200,7 +200,7 @@ analyze: - `top_p`: Nucleus sampling threshold - `system_prompt`: System message prepended to conversation (used in `mode: conversation`) -**Token Tracking:** Unlike CLI-based providers that estimate tokens from output length, `openai_compatible` reports actual token usage from the API response. +**Token Tracking:** Unlike CLI-based providers, which parse token usage from the CLI's JSON output and fall back to an estimate from the unified `Tokenizer` port when that data is missing, `openai_compatible` always reports actual token usage from the API response. **Display Cadence:** Unlike streaming CLI providers (Claude, Codex, Gemini, OpenCode) that display output incrementally, `openai_compatible` displays all events in a single burst after the HTTP response completes. 
This means tool-use markers and text output appear together at the end of execution rather than interleaved during streaming. The rendered shape and tool markers are identical across all providers — only the timing differs. @@ -547,7 +547,10 @@ Agent responses are automatically captured in the execution state: | `{{.states.step_name.Output}}` | string | Raw response text (or cleaned text if `output_format` is set) | | `{{.states.step_name.Response}}` | object | Parsed JSON response (automatic heuristic) | | `{{.states.step_name.JSON}}` | object | Parsed JSON from `output_format: json` (explicit, see [Output Formatting](#output-formatting)) | -| `{{.states.step_name.TokensUsed}}` | int | Tokens consumed by this step | +| `{{.states.step_name.TokensUsed}}` | int | Total tokens consumed by this step | +| `{{.states.step_name.TokensInput}}` | int | Input tokens (prompt + context). `0` in single-turn mode. | +| `{{.states.step_name.TokensOutput}}` | int | Output tokens (assistant response) | +| `{{.states.step_name.TokensEstimated}}` | bool | `false` when tokens come from the provider, `true` when estimated | | `{{.states.step_name.ExitCode}}` | int | 0 for success, non-zero for failure | ### Accessing Raw Output @@ -964,7 +967,7 @@ aggregate: ## Token Tracking -Some providers report token usage (useful for cost tracking): +All agent providers report token usage in the `TokensUsed` field: ```yaml analyze: @@ -981,7 +984,31 @@ log_tokens: on_success: done ``` -**Note**: All agent providers (Claude, Gemini, Codex) report token usage in the `TokensUsed` field. +**How it works:** + +All 6 providers extract **real token counts** from their CLI/API JSON output when available. `TokensEstimated` is `false` in this case. If the provider output does not contain token data, AWF falls back to an approximation (`len(output)/4`) and sets `TokensEstimated` to `true`. 
+ +| Provider | Source of real tokens | Fields available | +|----------|----------------------|-----------------| +| Claude | `result` event `usage` field | input, output, cost | +| Gemini | `result` event `stats` field | input, output, total | +| Codex | `turn.completed` event `usage` field | input, output | +| Copilot | `assistant.message` event | output only | +| OpenCode | `step_finish` event `part.tokens` field | input, output, total, cost | +| OpenAI-Compatible | API response `usage` field | input, output, total | + +Use `TokensInput` and `TokensOutput` for detailed tracking: + +```yaml +log_details: + type: step + command: | + echo "Input: {{.states.analyze.TokensInput}}, Output: {{.states.analyze.TokensOutput}}" + echo "Estimated: {{.states.analyze.TokensEstimated}}" + on_success: done +``` + +In conversation mode (`continue_from`), `TokensInput` includes all prior turns. In single-turn mode, `TokensInput` is `0` and `TokensOutput` equals `TokensUsed`. ## Best Practices diff --git a/go.mod b/go.mod index 4748fae2..6051df8e 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,12 @@ go 1.25.8 require ( charm.land/bubbles/v2 v2.1.0 charm.land/bubbletea/v2 v2.0.6 - charm.land/glamour/v2 v2.0.0 charm.land/lipgloss/v2 v2.0.3 github.com/expr-lang/expr v1.17.7 github.com/fatih/color v1.18.0 github.com/google/uuid v1.6.0 github.com/hashicorp/go-hclog v1.6.3 github.com/hashicorp/go-plugin v1.7.0 - github.com/pkoukk/tiktoken-go v0.1.8 github.com/spf13/cobra v1.10.2 github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 @@ -31,27 +29,22 @@ require ( ) require ( - github.com/alecthomas/chroma/v2 v2.20.0 // indirect github.com/atotto/clipboard v0.1.4 // indirect - github.com/aymerick/douceur v0.2.0 // indirect github.com/cenkalti/backoff/v5 v5.0.3 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/charmbracelet/colorprofile v0.4.3 // indirect github.com/charmbracelet/ultraviolet v0.0.0-20260416155717-489999b90468 // indirect 
github.com/charmbracelet/x/ansi v0.11.7 // indirect - github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf // indirect github.com/charmbracelet/x/term v0.2.2 // indirect github.com/charmbracelet/x/termios v0.1.1 // indirect github.com/charmbracelet/x/windows v0.2.2 // indirect github.com/clipperhouse/displaywidth v0.11.0 // indirect github.com/clipperhouse/uax29/v2 v2.7.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/dlclark/regexp2 v1.11.5 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect github.com/golang/protobuf v1.5.4 // indirect - github.com/gorilla/css v1.0.1 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 // indirect github.com/hashicorp/yamux v0.1.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect @@ -59,7 +52,6 @@ require ( github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.23 // indirect - github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/muesli/cancelreader v0.2.2 // indirect github.com/ncruces/go-strftime v1.0.0 // indirect github.com/oklog/run v1.1.0 // indirect @@ -69,8 +61,6 @@ require ( github.com/sahilm/fuzzy v0.1.1 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - github.com/yuin/goldmark v1.7.13 // indirect - github.com/yuin/goldmark-emoji v1.0.6 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect go.opentelemetry.io/otel/metric v1.43.0 // indirect go.opentelemetry.io/proto/otlp v1.10.0 // indirect diff --git a/go.sum b/go.sum index 006298eb..8af27184 100644 --- a/go.sum +++ b/go.sum @@ -2,22 +2,12 @@ charm.land/bubbles/v2 v2.1.0 h1:YSnNh5cPYlYjPxRrzs5VEn3vwhtEn3jVGRBT3M7/I0g= charm.land/bubbles/v2 v2.1.0/go.mod h1:l97h4hym2hvWBVfmJDtrEHHCtkIKeTEb3TTJ4ZOB3wY= charm.land/bubbletea/v2 v2.0.6 
h1:UHN/91OyuhaOFGSrBXQ/hMZD8IO1Uc4BvHlgHXL2WJo= charm.land/bubbletea/v2 v2.0.6/go.mod h1:MH/D8ZLlN3op37vQvijKuU29g3rqTp+aQapURFonF9g= -charm.land/glamour/v2 v2.0.0 h1:IDBoqLEy7Hdpb9VOXN+khLP/XSxtJy1VsHuW/yF87+U= -charm.land/glamour/v2 v2.0.0/go.mod h1:kjq9WB0s8vuUYZNYey2jp4Lgd9f4cKdzAw88FZtpj/w= charm.land/lipgloss/v2 v2.0.3 h1:yM2zJ4Cf5Y51b7RHIwioil4ApI/aypFXXVHSwlM6RzU= charm.land/lipgloss/v2 v2.0.3/go.mod h1:7myLU9iG/3xluAWzpY/fSxYYHCgoKTie7laxk6ATwXA= -github.com/alecthomas/assert/v2 v2.11.0 h1:2Q9r3ki8+JYXvGsDyBXwH3LcJ+WK5D0gc5E8vS6K3D0= -github.com/alecthomas/assert/v2 v2.11.0/go.mod h1:Bze95FyfUr7x34QZrjL+XP+0qgp/zg8yS+TtBj1WA3k= -github.com/alecthomas/chroma/v2 v2.20.0 h1:sfIHpxPyR07/Oylvmcai3X/exDlE8+FA820NTz+9sGw= -github.com/alecthomas/chroma/v2 v2.20.0/go.mod h1:e7tViK0xh/Nf4BYHl00ycY6rV7b8iXBksI9E359yNmA= -github.com/alecthomas/repr v0.5.1 h1:E3G4t2QbHTSNpPKBgMTln5KLkZHLOcU7r37J4pXBuIg= -github.com/alecthomas/repr v0.5.1/go.mod h1:Fr0507jx4eOXV7AlPV6AVZLYrLIuIeSOWtW57eE/O/4= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI= github.com/aymanbagabas/go-udiff v0.4.1 h1:OEIrQ8maEeDBXQDoGCbbTTXYJMYRCRO1fnodZ12Gv5o= github.com/aymanbagabas/go-udiff v0.4.1/go.mod h1:0L9PGwj20lrtmEMeyw4WKJ/TMyDtvAoK9bf2u/mNo3w= -github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk= -github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4= github.com/bufbuild/protocompile v0.14.1 h1:iA73zAf/fyljNjQKwYzUHD6AD4R8KMasmwa/FBatYVw= github.com/bufbuild/protocompile v0.14.1/go.mod h1:ppVdAIhbr2H8asPk6k4pY7t9zB1OU5DoEw9xY/FUi1c= github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= @@ -32,8 +22,6 @@ github.com/charmbracelet/x/ansi v0.11.7 h1:kzv1kJvjg2S3r9KHo8hDdHFQLEqn4RBCb39dA github.com/charmbracelet/x/ansi v0.11.7/go.mod h1:9qGpnAVYz+8ACONkZBUWPtL7lulP9No6p1epAihUZwQ= 
github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f h1:pk6gmGpCE7F3FcjaOEKYriCvpmIN4+6OS/RD0vm4uIA= github.com/charmbracelet/x/exp/golden v0.0.0-20250806222409-83e3a29d542f/go.mod h1:IfZAMTHB6XkZSeXUqriemErjAWCCzT0LwjKFYCZyw0I= -github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf h1:rLG0Yb6MQSDKdB52aGX55JT1oi0P0Kuaj7wi1bLUpnI= -github.com/charmbracelet/x/exp/slice v0.0.0-20250327172914-2fdc97757edf/go.mod h1:B3UgsnsBZS/eX42BlaNiJkD1pPOUa+oF1IYC6Yd2CEU= github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk= github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI= github.com/charmbracelet/x/termios v0.1.1 h1:o3Q2bT8eqzGnGPOYheoYS8eEleT5ZVNYNy8JawjaNZY= @@ -48,8 +36,6 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6N github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ= -github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/expr-lang/expr v1.17.7 h1:Q0xY/e/2aCIp8g9s/LGvMDCC5PxYlvHgDZRQ4y16JX8= @@ -70,8 +56,6 @@ github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17k github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/gorilla/css v1.0.1 
h1:ntNaBIghp6JmvWnxbZKANoLyuXTPZ4cAMlo6RyhlbO8= -github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A3b0= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0 h1:HWRh5R2+9EifMyIHV7ZV+MIZqgz+PMpZ14Jynv3O2Zs= github.com/grpc-ecosystem/grpc-gateway/v2 v2.28.0/go.mod h1:JfhWUomR1baixubs02l85lZYYOm7LV6om4ceouMv45c= github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= @@ -82,8 +66,6 @@ github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/hashicorp/yamux v0.1.2 h1:XtB8kyFOyHXYVFnwT5C3+Bdo8gArse7j2AQ0DA0Uey8= github.com/hashicorp/yamux v0.1.2/go.mod h1:C+zze2n6e/7wshOZep2A70/aQU6QBRWJO/G6FT1wIns= -github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= -github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jhump/protoreflect v1.17.0 h1:qOEr613fac2lOuTgWN4tPAtLL7fUSbuJL5X5XumQh94= @@ -106,16 +88,12 @@ github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWE github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.23 h1:7ykA0T0jkPpzSvMS5i9uoNn2Xy3R383f9HDx3RybWcw= github.com/mattn/go-runewidth v0.0.23/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= -github.com/microcosm-cc/bluemonday v1.0.27 h1:MpEUotklkwCSLeH+Qdx1VJgNqLlpY2KXwXFM08ygZfk= -github.com/microcosm-cc/bluemonday v1.0.27/go.mod h1:jFi9vgW+H7c3V0lb6nR74Ib/DIB5OBs92Dimizgw2cA= github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA= github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo= 
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= github.com/oklog/run v1.1.0 h1:GEenZ1cK0+q0+wsJew9qUg/DyD8k3JzYsZAi5gYi2mA= github.com/oklog/run v1.1.0/go.mod h1:sVPdnTZT1zYwAJeCMu2Th4T21pA3FPOQRfWjQlk7DVU= -github.com/pkoukk/tiktoken-go v0.1.8 h1:85ENo+3FpWgAACBaEUVp+lctuTcYUO7BtmfhlN/QTRo= -github.com/pkoukk/tiktoken-go v0.1.8/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= @@ -140,10 +118,6 @@ github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yuin/goldmark v1.7.13 h1:GPddIs617DnBLFFVJFgpo1aBfe/4xcvMc3SB5t/D0pA= -github.com/yuin/goldmark v1.7.13/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg= -github.com/yuin/goldmark-emoji v1.0.6 h1:QWfF2FYaXwL74tfGOW5izeiZepUDroDJfWubQI9HTHs= -github.com/yuin/goldmark-emoji v1.0.6/go.mod h1:ukxJDKFpdFb5x0a5HqbdlcKtebh086iJpI31LTKmWuA= go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= diff --git a/internal/application/execution_service.go b/internal/application/execution_service.go index 5515f807..efbec636 100644 --- 
a/internal/application/execution_service.go +++ b/internal/application/execution_service.go @@ -2332,6 +2332,9 @@ func (s *ExecutionService) executeAgentStep( state.DisplayOutput = convResult.DisplayOutput state.Response = convResult.Response state.TokensUsed = convResult.TokensTotal + state.TokensInput = convResult.TokensInput + state.TokensOutput = convResult.TokensOutput + state.TokensEstimated = convResult.TokensEstimated state.Conversation = convResult.State result = &workflow.AgentResult{ Provider: convResult.Provider, @@ -2354,6 +2357,8 @@ func (s *ExecutionService) executeAgentStep( state.Response = result.Response // AC6: Token usage in states.step_name.tokens_used state.TokensUsed = result.Tokens + state.TokensOutput = result.Tokens + state.TokensEstimated = result.TokensEstimated // F065: Apply output format post-processing if formatErr := s.applyOutputFormat(step, &state, execCtx); formatErr != nil { diff --git a/internal/application/interpolation_helpers.go b/internal/application/interpolation_helpers.go index 7a05aeee..66a9f554 100644 --- a/internal/application/interpolation_helpers.go +++ b/internal/application/interpolation_helpers.go @@ -23,13 +23,16 @@ func buildInterpolationContext( for name := range allStates { state := allStates[name] states[name] = interpolation.StepStateData{ - Output: state.Output, - Stderr: state.Stderr, - ExitCode: state.ExitCode, - Status: state.Status.String(), - Response: state.Response, - TokensUsed: state.TokensUsed, - JSON: state.JSON, + Output: state.Output, + Stderr: state.Stderr, + ExitCode: state.ExitCode, + Status: state.Status.String(), + Response: state.Response, + TokensUsed: state.TokensUsed, + TokensInput: state.TokensInput, + TokensOutput: state.TokensOutput, + TokensEstimated: state.TokensEstimated, + JSON: state.JSON, } } diff --git a/internal/domain/workflow/context.go b/internal/domain/workflow/context.go index b6ad9a81..2604fa85 100644 --- a/internal/domain/workflow/context.go +++ 
b/internal/domain/workflow/context.go @@ -34,8 +34,11 @@ type StepState struct { Response map[string]any // parsed JSON response from agent steps JSON any // parsed JSON output when output_format: json is specified (map[string]any or []any) // F033: Conversation mode fields - Conversation *ConversationState // conversation history and state (nil for non-conversation steps) - TokensUsed int // total tokens used in conversation mode + Conversation *ConversationState // conversation history and state (nil for non-conversation steps) + TokensUsed int // total tokens used + TokensInput int // input tokens (from provider JSON or estimation) + TokensOutput int // output tokens (from provider JSON or estimation) + TokensEstimated bool // true if token counts are estimates, false if from provider // C019: Output streaming fields for memory management OutputPath string // Path to temp file if output was streamed (empty if in-memory) diff --git a/internal/domain/workflow/reference.go b/internal/domain/workflow/reference.go index e2519542..7da99936 100644 --- a/internal/domain/workflow/reference.go +++ b/internal/domain/workflow/reference.go @@ -43,27 +43,35 @@ var ValidWorkflowProperties = map[string]bool{ // ValidStateProperties lists known step state properties that can be referenced. // NOTE: This map is duplicated in pkg/interpolation/reference.go — keep both in sync until dedup cleanup. var ValidStateProperties = map[string]bool{ - "Output": true, - "Stderr": true, - "ExitCode": true, - "Status": true, - "Response": true, - "TokensUsed": true, - "JSON": true, + "Output": true, + "Stderr": true, + "ExitCode": true, + "Status": true, + "Response": true, + "TokensUsed": true, + "TokensInput": true, + "TokensOutput": true, + "TokensEstimated": true, + "JSON": true, } // lowercaseToUppercase maps lowercase property names to their correct uppercase equivalents. // Used to provide actionable error messages when users use incorrect casing. 
// "tokens" and "tokensused" both map to "TokensUsed" as canonical aliases. var lowercaseToUppercase = map[string]string{ - "output": "Output", - "stderr": "Stderr", - "exit_code": "ExitCode", - "status": "Status", - "response": "Response", - "tokens": "TokensUsed", - "tokensused": "TokensUsed", - "json": "JSON", + "output": "Output", + "stderr": "Stderr", + "exit_code": "ExitCode", + "status": "Status", + "response": "Response", + "tokens": "TokensUsed", + "tokensused": "TokensUsed", + "tokensinput": "TokensInput", + "tokens_input": "TokensInput", + "tokensoutput": "TokensOutput", + "tokens_output": "TokensOutput", + "tokensestimated": "TokensEstimated", + "json": "JSON", } // lowercaseToUppercaseError maps lowercase error property names to their correct uppercase equivalents. diff --git a/internal/infrastructure/agents/base_cli_provider.go b/internal/infrastructure/agents/base_cli_provider.go index 4029ac5b..cbc618a5 100644 --- a/internal/infrastructure/agents/base_cli_provider.go +++ b/internal/infrastructure/agents/base_cli_provider.go @@ -13,8 +13,28 @@ import ( "github.com/awf-project/cli/internal/infrastructure/logger" ) +type fallbackTokenizer struct{} + +func (fallbackTokenizer) CountTokens(text string) (int, error) { return len(text) / 4, nil } +func (fallbackTokenizer) CountTurnsTokens(turns []string) (int, error) { + n := 0 + for _, t := range turns { + n += len(t) + } + return n / 4, nil +} +func (fallbackTokenizer) IsEstimate() bool { return true } +func (fallbackTokenizer) ModelName() string { return "fallback" } + +type tokenUsage struct { + InputTokens int + OutputTokens int + TotalTokens int + CostUSD float64 +} + // cliProviderHooks captures provider-specific behavior as function values. -// Optional hooks (extractTextContent, validateOptions, parseDisplayEvents) may be nil. +// Optional hooks (extractTextContent, validateOptions, parseDisplayEvents, extractTokenUsage) may be nil. 
type cliProviderHooks struct { buildExecuteArgs func(prompt string, options map[string]any) ([]string, error) buildConversationArgs func(state *workflow.ConversationState, prompt string, options map[string]any) ([]string, error) @@ -22,16 +42,18 @@ type cliProviderHooks struct { extractTextContent func(output string) string validateOptions func(options map[string]any) error parseDisplayEvents DisplayEventParser + extractTokenUsage func(rawOutput string) *tokenUsage } // baseCLIProvider encapsulates the shared Execute and ExecuteConversation // orchestration logic for all CLI-based agent providers. type baseCLIProvider struct { - name string - binary string - executor ports.CLIExecutor - logger ports.Logger - hooks cliProviderHooks + name string + binary string + executor ports.CLIExecutor + logger ports.Logger + tokenizer ports.Tokenizer + hooks cliProviderHooks } func newBaseCLIProvider(name, binary string, executor ports.CLIExecutor, log ports.Logger, hooks cliProviderHooks) *baseCLIProvider { @@ -39,11 +61,12 @@ func newBaseCLIProvider(name, binary string, executor ports.CLIExecutor, log por log = logger.NopLogger{} } return &baseCLIProvider{ - name: name, - binary: binary, - executor: executor, - logger: log, - hooks: hooks, + name: name, + binary: binary, + executor: executor, + logger: log, + tokenizer: fallbackTokenizer{}, + hooks: hooks, } } @@ -116,14 +139,26 @@ func (b *baseCLIProvider) execute(ctx context.Context, prompt string, options ma displayOutput = extractDisplayTextFromEvents(rawOutput, b.hooks.parseDisplayEvents) } + var outputTokens int + hasRealTokens := false + if b.hooks.extractTokenUsage != nil { + if usage := b.hooks.extractTokenUsage(rawOutput); usage != nil { + outputTokens = usage.TotalTokens + hasRealTokens = true + } + } + if !hasRealTokens { + outputTokens, _ = b.tokenizer.CountTokens(outputStr) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + } + result := &workflow.AgentResult{ Provider: b.name, Output: 
outputStr, DisplayOutput: displayOutput, StartedAt: startedAt, CompletedAt: completedAt, - Tokens: estimateTokens(outputStr), - TokensEstimated: true, + Tokens: outputTokens, + TokensEstimated: !hasRealTokens && b.tokenizer.IsEstimate(), } return result, rawOutput, nil @@ -193,7 +228,21 @@ func (b *baseCLIProvider) executeConversation(ctx context.Context, state *workfl } assistantTurn := workflow.NewTurn(workflow.TurnRoleAssistant, outputStr) - assistantTurn.Tokens = estimateTokens(outputStr) + + var assistantTokens, inputTokens int + hasRealTokens := false + if b.hooks.extractTokenUsage != nil { + if usage := b.hooks.extractTokenUsage(rawOutput); usage != nil { + assistantTokens = usage.OutputTokens + inputTokens = usage.InputTokens + hasRealTokens = true + } + } + if !hasRealTokens { + assistantTokens, _ = b.tokenizer.CountTokens(outputStr) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + } + assistantTurn.Tokens = assistantTokens + if addErr := workingState.AddTurn(assistantTurn); addErr != nil { return nil, "", fmt.Errorf("failed to add assistant turn: %w", addErr) } @@ -204,7 +253,14 @@ func (b *baseCLIProvider) executeConversation(ctx context.Context, state *workfl } workingState.SessionID = sessionID - inputTokens := estimateInputTokens(workingState.Turns, 1) + if !hasRealTokens { + limit := len(workingState.Turns) - 1 + turnContents := make([]string, 0, limit) + for _, t := range workingState.Turns[0:limit] { + turnContents = append(turnContents, t.Content) + } + inputTokens, _ = b.tokenizer.CountTurnsTokens(turnContents) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + } var displayOutput string if !rawDisplay { @@ -219,7 +275,7 @@ func (b *baseCLIProvider) executeConversation(ctx context.Context, state *workfl TokensInput: inputTokens, TokensOutput: assistantTurn.Tokens, TokensTotal: inputTokens + assistantTurn.Tokens, - TokensEstimated: true, + TokensEstimated: !hasRealTokens && 
b.tokenizer.IsEstimate(), StartedAt: startedAt, CompletedAt: completedAt, } diff --git a/internal/infrastructure/agents/base_cli_provider_tokenizer_test.go b/internal/infrastructure/agents/base_cli_provider_tokenizer_test.go new file mode 100644 index 00000000..d6fd8caa --- /dev/null +++ b/internal/infrastructure/agents/base_cli_provider_tokenizer_test.go @@ -0,0 +1,390 @@ +package agents + +import ( + "context" + "errors" + "testing" + + "github.com/awf-project/cli/internal/domain/workflow" + "github.com/awf-project/cli/internal/testutil/mocks" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testExecuteHooks() cliProviderHooks { + return cliProviderHooks{ + buildExecuteArgs: func(prompt string, options map[string]any) ([]string, error) { + return []string{"--prompt", prompt}, nil + }, + } +} + +func testConversationHooks() cliProviderHooks { + return cliProviderHooks{ + buildConversationArgs: func(state *workflow.ConversationState, prompt string, options map[string]any) ([]string, error) { + return []string{"--prompt", prompt}, nil + }, + extractSessionID: func(output string) (string, error) { + return "", nil + }, + } +} + +// mockTokenizerWithTracker records calls and returns configured values. 
+type mockTokenizerWithTracker struct { + countTokensResult int + countTokensError error + countTurnsTokensResult int + countTurnsTokensError error + isEstimate bool + modelName string + + countTokensCalls []string + countTurnsTokensCalls [][]string +} + +func newMockTokenizerWithTracker(tokensResult int, isEst bool) *mockTokenizerWithTracker { + return &mockTokenizerWithTracker{ + countTokensResult: tokensResult, + countTurnsTokensResult: tokensResult, + isEstimate: isEst, + modelName: "test-tokenizer", + } +} + +func (m *mockTokenizerWithTracker) CountTokens(text string) (int, error) { + m.countTokensCalls = append(m.countTokensCalls, text) + return m.countTokensResult, m.countTokensError +} + +func (m *mockTokenizerWithTracker) CountTurnsTokens(turns []string) (int, error) { + m.countTurnsTokensCalls = append(m.countTurnsTokensCalls, turns) + return m.countTurnsTokensResult, m.countTurnsTokensError +} + +func (m *mockTokenizerWithTracker) IsEstimate() bool { + return m.isEstimate +} + +func (m *mockTokenizerWithTracker) ModelName() string { + return m.modelName +} + +func TestBaseCLIProvider_HasTokenizerField(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testExecuteHooks()) + + require.NotNil(t, provider) + assert.NotNil(t, provider.tokenizer) +} + +func TestBaseCLIProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + tests := []struct { + name string + outputStr string + expectedTokens int + }{ + { + name: "simple output", + outputStr: "4", + expectedTokens: 42, + }, + { + name: "longer output", + outputStr: "This is a longer response", + expectedTokens: 100, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(tt.outputStr), nil) + + mockTokenizer := newMockTokenizerWithTracker(tt.expectedTokens, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, 
testExecuteHooks()) + provider.tokenizer = mockTokenizer + + result, _, err := provider.execute(context.Background(), "test prompt", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.expectedTokens, result.Tokens) + assert.Len(t, mockTokenizer.countTokensCalls, 1) + assert.Equal(t, tt.outputStr, mockTokenizer.countTokensCalls[0]) + }) + } +} + +func TestBaseCLIProvider_Execute_TokensEstimatedFromTokenizer(t *testing.T) { + tests := []struct { + name string + isEstimate bool + expectedEstimated bool + }{ + { + name: "approximation tokenizer", + isEstimate: true, + expectedEstimated: true, + }, + { + name: "exact tokenizer", + isEstimate: false, + expectedEstimated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("output"), nil) + + mockTokenizer := newMockTokenizerWithTracker(42, tt.isEstimate) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testExecuteHooks()) + provider.tokenizer = mockTokenizer + + result, _, err := provider.execute(context.Background(), "test", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.expectedEstimated, result.TokensEstimated) + }) + } +} + +func TestBaseCLIProvider_ExecuteConversation_UsesInjectedTokenizerForAssistantTokens(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("assistant response"), nil) + + mockTokenizer := newMockTokenizerWithTracker(77, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{}, + } + + result, _, err := provider.executeConversation(context.Background(), state, "user prompt", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 77, result.TokensOutput) + 
assert.Len(t, mockTokenizer.countTokensCalls, 1) + assert.Equal(t, "assistant response", mockTokenizer.countTokensCalls[0]) +} + +func TestBaseCLIProvider_ExecuteConversation_UsesCountTurnsTokens(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("response"), nil) + + mockTokenizer := newMockTokenizerWithTracker(100, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{ + {Role: workflow.TurnRoleUser, Content: "first user message"}, + {Role: workflow.TurnRoleAssistant, Content: "first assistant response"}, + }, + } + + result, _, err := provider.executeConversation(context.Background(), state, "second user message", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, 100, result.TokensInput) + assert.Len(t, mockTokenizer.countTurnsTokensCalls, 1) + // After adding user turn (3 turns total) and assistant turn (4 turns total), + // limit = 4 - 1 = 3, so countTurnsTokens receives turns[0:3] = original 2 + new user turn + assert.Len(t, mockTokenizer.countTurnsTokensCalls[0], 3) + assert.Equal(t, "first user message", mockTokenizer.countTurnsTokensCalls[0][0]) + assert.Equal(t, "first assistant response", mockTokenizer.countTurnsTokensCalls[0][1]) + assert.Equal(t, "second user message", mockTokenizer.countTurnsTokensCalls[0][2]) +} + +func TestBaseCLIProvider_ExecuteConversation_TokensEstimatedFromTokenizer(t *testing.T) { + tests := []struct { + name string + isEstimate bool + expectedEstimated bool + }{ + { + name: "approximation tokenizer", + isEstimate: true, + expectedEstimated: true, + }, + { + name: "exact tokenizer", + isEstimate: false, + expectedEstimated: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("response"), 
nil) + + mockTokenizer := newMockTokenizerWithTracker(42, tt.isEstimate) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{}, + } + + result, _, err := provider.executeConversation(context.Background(), state, "test", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + assert.Equal(t, tt.expectedEstimated, result.TokensEstimated) + }) + } +} + +func TestBaseCLIProvider_ExecuteConversation_NoTurnMutation(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("response"), nil) + + mockTokenizer := newMockTokenizerWithTracker(50, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + initialState := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{ + {Role: workflow.TurnRoleUser, Content: "initial turn", Tokens: 10}, + }, + } + + originalTokens := initialState.Turns[0].Tokens + + _, _, err := provider.executeConversation(context.Background(), initialState, "test prompt", nil, nil, nil) + + require.NoError(t, err) + assert.Equal(t, originalTokens, initialState.Turns[0].Tokens, "prior turn tokens should not be mutated") + assert.Len(t, initialState.Turns, 1, "original state should not be modified") +} + +func TestBaseCLIProvider_ExecuteConversation_ExtractsLastTurnContents(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("new response"), nil) + + mockTokenizer := newMockTokenizerWithTracker(100, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{ + {Role: workflow.TurnRoleUser, Content: "turn 1"}, + {Role: workflow.TurnRoleAssistant, 
Content: "turn 2"}, + {Role: workflow.TurnRoleUser, Content: "turn 3"}, + }, + } + + result, _, err := provider.executeConversation(context.Background(), state, "new user message", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, mockTokenizer.countTurnsTokensCalls, 1) + + turnContents := mockTokenizer.countTurnsTokensCalls[0] + // Starting with 3 turns, add user (4 total), add assistant (5 total) + // limit = 5 - 1 = 4, so turns[0:4] = original 3 + new user message + assert.Len(t, turnContents, 4, "should include all prior turns and new user message, excluding newly added assistant turn") + assert.Equal(t, "turn 1", turnContents[0]) + assert.Equal(t, "turn 2", turnContents[1]) + assert.Equal(t, "turn 3", turnContents[2]) + assert.Equal(t, "new user message", turnContents[3]) +} + +func TestBaseCLIProvider_ExecuteConversation_EmptyTurnsCountTurnsTokens(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("response"), nil) + + mockTokenizer := newMockTokenizerWithTracker(42, false) + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testConversationHooks()) + provider.tokenizer = mockTokenizer + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{}, + } + + result, _, err := provider.executeConversation(context.Background(), state, "first message", nil, nil, nil) + + require.NoError(t, err) + require.NotNil(t, result) + require.Len(t, mockTokenizer.countTurnsTokensCalls, 1) + // Starting with 0 turns, add user (1 turn total), add assistant (2 turns total) + // limit = 2 - 1 = 1, so turns[0:1] = the first user message + assert.Len(t, mockTokenizer.countTurnsTokensCalls[0], 1, "should include the first user message added by executeConversation") + assert.Equal(t, "first message", mockTokenizer.countTurnsTokensCalls[0][0]) +} + +func TestBaseCLIProvider_TokenizerNotCalledOnError(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + 
mockTokenizer := newMockTokenizerWithTracker(42, false) + + hooks := cliProviderHooks{ + buildExecuteArgs: func(prompt string, options map[string]any) ([]string, error) { + return nil, errors.New("build args failed") + }, + } + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, hooks) + provider.tokenizer = mockTokenizer + + _, _, err := provider.execute(context.Background(), "test", nil, nil, nil) + + require.Error(t, err) + assert.Empty(t, mockTokenizer.countTokensCalls, "tokenizer should not be called on buildExecuteArgs error") +} + +func TestBaseCLIProvider_TokenizerWithBothExecutePaths(t *testing.T) { + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte("output"), nil) + + mockTokenizer := newMockTokenizerWithTracker(88, false) + + conversationHooks := testConversationHooks() + conversationHooks.extractSessionID = func(output string) (string, error) { + return "session-1", nil + } + + provider := newBaseCLIProvider("test", "test-binary", mockExec, nil, testExecuteHooks()) + provider.tokenizer = mockTokenizer + + result1, _, err1 := provider.execute(context.Background(), "test1", nil, nil, nil) + require.NoError(t, err1) + require.NotNil(t, result1) + assert.Equal(t, 88, result1.Tokens) + + mockExec.SetOutput([]byte("output2"), nil) + provider.hooks = conversationHooks + + state := &workflow.ConversationState{ + SessionID: "", + Turns: []workflow.Turn{}, + } + + result2, _, err2 := provider.executeConversation(context.Background(), state, "test2", nil, nil, nil) + require.NoError(t, err2) + require.NotNil(t, result2) + assert.Equal(t, 88, result2.TokensOutput) + + assert.Len(t, mockTokenizer.countTokensCalls, 2, "tokenizer should be used by both execute paths") + assert.Len(t, mockTokenizer.countTurnsTokensCalls, 1, "only conversation path uses CountTurnsTokens") +} diff --git a/internal/infrastructure/agents/claude_provider.go b/internal/infrastructure/agents/claude_provider.go index dd0f92b0..e9b911a6 100644 --- 
a/internal/infrastructure/agents/claude_provider.go +++ b/internal/infrastructure/agents/claude_provider.go @@ -19,9 +19,10 @@ import ( // ClaudeProvider implements AgentProvider for Claude CLI. // Invokes: claude -p "prompt" --output-format stream-json type ClaudeProvider struct { - base *baseCLIProvider - logger ports.Logger - executor ports.CLIExecutor + base *baseCLIProvider + logger ports.Logger + executor ports.CLIExecutor + tokenizer ports.Tokenizer } func NewClaudeProvider(l ...ports.Logger) *ClaudeProvider { @@ -52,14 +53,19 @@ func NewClaudeProviderWithOptions(opts ...ClaudeProviderOption) *ClaudeProvider } func (p *ClaudeProvider) newBase() *baseCLIProvider { - return newBaseCLIProvider("claude", "claude", p.executor, p.logger, cliProviderHooks{ + b := newBaseCLIProvider("claude", "claude", p.executor, p.logger, cliProviderHooks{ buildExecuteArgs: p.buildExecuteArgs, buildConversationArgs: p.buildConversationArgs, extractSessionID: p.extractSessionID, extractTextContent: p.extractTextFromJSON, validateOptions: validateClaudeOptions, parseDisplayEvents: p.parseClaudeDisplayEvents, + extractTokenUsage: p.extractClaudeTokenUsage, }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b } func (p *ClaudeProvider) Execute(ctx context.Context, prompt string, options map[string]any, stdout, stderr io.Writer) (*workflow.AgentResult, error) { @@ -76,7 +82,10 @@ func (p *ClaudeProvider) Execute(ctx context.Context, prompt string, options map // downstream JSON post-processing. 
if extracted := p.extractTextFromJSON(rawOutput); extracted != "" { result.Output = extracted - result.Tokens = estimateTokens(extracted) + if result.TokensEstimated { + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } } if userFormat == "json" || userFormat == "stream-json" { @@ -261,6 +270,35 @@ func (p *ClaudeProvider) extractTextFromJSON(output string) string { return "" } +func (p *ClaudeProvider) extractClaudeTokenUsage(rawOutput string) *tokenUsage { + evt := p.extractResultEvent(rawOutput) + if evt == nil { + return nil + } + usageVal, ok := evt["usage"] + if !ok || usageVal == nil { + return nil + } + usage, ok := usageVal.(map[string]any) + if !ok { + return nil + } + input := intFromMap(usage, "input_tokens") + + intFromMap(usage, "cache_creation_input_tokens") + + intFromMap(usage, "cache_read_input_tokens") + output := intFromMap(usage, "output_tokens") + var costUSD float64 + if v, ok := evt["total_cost_usd"].(float64); ok { + costUSD = v + } + return &tokenUsage{ + InputTokens: input, + OutputTokens: output, + TotalTokens: input + output, + CostUSD: costUSD, + } +} + func (p *ClaudeProvider) parseClaudeDisplayEvents(line []byte) []DisplayEvent { // Replace NUL bytes with a space to avoid JSON parse errors on malformed input. line = bytes.ReplaceAll(line, []byte{0x00}, []byte(" ")) diff --git a/internal/infrastructure/agents/codex_provider.go b/internal/infrastructure/agents/codex_provider.go index 4f2f0154..503bdb22 100644 --- a/internal/infrastructure/agents/codex_provider.go +++ b/internal/infrastructure/agents/codex_provider.go @@ -18,9 +18,10 @@ import ( // CodexProvider implements AgentProvider for Codex CLI. 
// Invokes: codex exec --json "prompt" type CodexProvider struct { - base *baseCLIProvider - logger ports.Logger - executor ports.CLIExecutor + base *baseCLIProvider + logger ports.Logger + executor ports.CLIExecutor + tokenizer ports.Tokenizer } func NewCodexProvider() *CodexProvider { @@ -45,13 +46,18 @@ func NewCodexProviderWithOptions(opts ...CodexProviderOption) *CodexProvider { } func (p *CodexProvider) newBase() *baseCLIProvider { - return newBaseCLIProvider("codex", "codex", p.executor, p.logger, cliProviderHooks{ + b := newBaseCLIProvider("codex", "codex", p.executor, p.logger, cliProviderHooks{ buildExecuteArgs: p.buildExecuteArgs, buildConversationArgs: p.buildConversationArgs, extractSessionID: p.extractSessionID, validateOptions: validateCodexOptions, parseDisplayEvents: p.parseCodexDisplayEvents, + extractTokenUsage: p.extractCodexTokenUsage, }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b } func (p *CodexProvider) Execute(ctx context.Context, prompt string, options map[string]any, stdout, stderr io.Writer) (*workflow.AgentResult, error) { @@ -67,7 +73,10 @@ func (p *CodexProvider) Execute(ctx context.Context, prompt string, options map[ if userFormat != "json" && userFormat != "stream-json" { if extracted := extractDisplayTextFromEvents(rawOutput, p.parseCodexDisplayEvents); extracted != "" { result.Output = extracted - result.Tokens = estimateTokens(extracted) + if result.TokensEstimated { + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } } } @@ -108,22 +117,13 @@ func (p *CodexProvider) buildConversationArgs(state *workflow.ConversationState, } else { // Codex CLI has no --system-prompt flag; inline the system prompt into // the first-turn message only when a session is not yet established. 
- effectivePrompt := buildCodexFirstTurnPrompt(prompt, options) + effectivePrompt := buildFirstTurnPrompt(prompt, options) args = []string{"exec", "--json", effectivePrompt} } args = appendCodexOptions(args, options) return args, nil } -// buildCodexFirstTurnPrompt prepends an optional system prompt for the first turn. -// Codex CLI has no --system-prompt flag, so the system context must be embedded in the message. -func buildCodexFirstTurnPrompt(userPrompt string, options map[string]any) string { - if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { - return sysPrompt + "\n\n" + userPrompt - } - return userPrompt -} - // appendCodexOptions appends Codex CLI flags from options; unknown keys are silently ignored. func appendCodexOptions(args []string, options map[string]any) []string { if model, ok := getStringOption(options, "model"); ok && model != "" { @@ -157,6 +157,28 @@ func (p *CodexProvider) extractSessionID(output string) (string, error) { return "", errors.New("thread_id is not a non-empty string") } +func (p *CodexProvider) extractCodexTokenUsage(rawOutput string) *tokenUsage { + evt := findFirstNDJSONEvent(rawOutput, "turn.completed") + if evt == nil { + return nil + } + usageVal, ok := evt["usage"] + if !ok || usageVal == nil { + return nil + } + usage, ok := usageVal.(map[string]any) + if !ok { + return nil + } + input := intFromMap(usage, "input_tokens") + output := intFromMap(usage, "output_tokens") + return &tokenUsage{ + InputTokens: input, + OutputTokens: output, + TotalTokens: input + output, + } +} + func validateCodexOptions(options map[string]any) error { if options == nil { return nil @@ -208,7 +230,7 @@ func (p *CodexProvider) parseCodexDisplayEvents(line []byte) []DisplayEvent { return []DisplayEvent{{Type: evt.Type, Kind: EventText, Text: evt.Item.Text}} case "function_call": // Codex does not emit tool-call IDs; ID is always empty. 
- preview := parseToolCallArgPreview(evt.Item.Arguments) + preview := extractArgPreview(evt.Item.Arguments) return []DisplayEvent{{Type: evt.Type, Kind: EventToolUse, Name: evt.Item.Name, Arg: preview, ID: ""}} } } diff --git a/internal/infrastructure/agents/copilot_provider.go b/internal/infrastructure/agents/copilot_provider.go index 2f138339..78190e23 100644 --- a/internal/infrastructure/agents/copilot_provider.go +++ b/internal/infrastructure/agents/copilot_provider.go @@ -17,9 +17,10 @@ import ( // CopilotProvider implements AgentProvider for GitHub Copilot CLI. // Invokes: copilot -p "prompt" --output-format=json --silent type CopilotProvider struct { - base *baseCLIProvider - logger ports.Logger - executor ports.CLIExecutor + base *baseCLIProvider + logger ports.Logger + executor ports.CLIExecutor + tokenizer ports.Tokenizer } func NewCopilotProvider() *CopilotProvider { @@ -44,14 +45,19 @@ func NewCopilotProviderWithOptions(opts ...CopilotProviderOption) *CopilotProvid } func (p *CopilotProvider) newBase() *baseCLIProvider { - return newBaseCLIProvider("github_copilot", "copilot", p.executor, p.logger, cliProviderHooks{ + b := newBaseCLIProvider("github_copilot", "copilot", p.executor, p.logger, cliProviderHooks{ buildExecuteArgs: p.buildCopilotExecuteArgs, buildConversationArgs: p.buildCopilotConversationArgs, extractSessionID: p.extractCopilotSessionID, extractTextContent: p.extractCopilotTextContent, validateOptions: validateCopilotOptions, parseDisplayEvents: p.parseCopilotDisplayEvents, + extractTokenUsage: p.extractCopilotTokenUsage, }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b } func (p *CopilotProvider) Execute(ctx context.Context, prompt string, options map[string]any, stdout, stderr io.Writer) (*workflow.AgentResult, error) { @@ -61,7 +67,10 @@ func (p *CopilotProvider) Execute(ctx context.Context, prompt string, options ma } if extracted := p.extractCopilotTextContent(rawOutput); extracted != "" { result.Output = 
extracted - result.Tokens = estimateTokens(extracted) + if result.TokensEstimated { + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } } return result, nil } @@ -97,22 +106,13 @@ func (p *CopilotProvider) buildCopilotConversationArgs(state *workflow.Conversat if state.SessionID != "" { args = []string{"--resume=" + state.SessionID, "-p", prompt, "--output-format=json", "--silent"} } else { - effectivePrompt := buildCopilotFirstTurnPrompt(prompt, options) + effectivePrompt := buildFirstTurnPrompt(prompt, options) args = []string{"-p", effectivePrompt, "--output-format=json", "--silent"} } args = appendCopilotOptions(args, options) return args, nil } -// buildCopilotFirstTurnPrompt prepends an optional system prompt. -// Copilot CLI has no --system-prompt flag; the system context must be embedded in the message. -func buildCopilotFirstTurnPrompt(userPrompt string, options map[string]any) string { - if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { - return sysPrompt + "\n\n" + userPrompt - } - return userPrompt -} - // appendCopilotOptions appends Copilot CLI flags from options; unknown keys are silently ignored. 
func appendCopilotOptions(args []string, options map[string]any) []string { if model, ok := getStringOption(options, "model"); ok && model != "" { @@ -203,6 +203,25 @@ func (p *CopilotProvider) extractCopilotSessionID(output string) (string, error) return "", errors.New("sessionId is not a non-empty string") } +func (p *CopilotProvider) extractCopilotTokenUsage(rawOutput string) *tokenUsage { + evt := findLastNDJSONEvent(rawOutput, "assistant.message") + if evt == nil { + return nil + } + data, ok := evt["data"].(map[string]any) + if !ok { + return nil + } + outputTokens := intFromMap(data, "outputTokens") + if outputTokens == 0 { + return nil + } + return &tokenUsage{ + OutputTokens: outputTokens, + TotalTokens: outputTokens, + } +} + func (p *CopilotProvider) parseCopilotDisplayEvents(line []byte) []DisplayEvent { var evt struct { Type string `json:"type"` diff --git a/internal/infrastructure/agents/gemini_provider.go b/internal/infrastructure/agents/gemini_provider.go index aaef605a..218e0e0d 100644 --- a/internal/infrastructure/agents/gemini_provider.go +++ b/internal/infrastructure/agents/gemini_provider.go @@ -16,8 +16,9 @@ import ( // GeminiProvider implements AgentProvider for Gemini CLI. 
// Invokes: gemini -p "prompt" type GeminiProvider struct { - base *baseCLIProvider - executor ports.CLIExecutor + base *baseCLIProvider + executor ports.CLIExecutor + tokenizer ports.Tokenizer } func NewGeminiProvider() *GeminiProvider { @@ -40,13 +41,18 @@ func NewGeminiProviderWithOptions(opts ...GeminiProviderOption) *GeminiProvider } func (p *GeminiProvider) newBase() *baseCLIProvider { - return newBaseCLIProvider("gemini", "gemini", p.executor, nil, cliProviderHooks{ + b := newBaseCLIProvider("gemini", "gemini", p.executor, nil, cliProviderHooks{ buildExecuteArgs: p.buildExecuteArgs, buildConversationArgs: p.buildConversationArgs, extractSessionID: p.extractSessionID, validateOptions: validateGeminiOptions, parseDisplayEvents: p.parseGeminiDisplayEvents, + extractTokenUsage: p.extractGeminiTokenUsage, }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b } func validateGeminiOptions(options map[string]any) error { @@ -78,7 +84,10 @@ func (p *GeminiProvider) Execute(ctx context.Context, prompt string, options map // state.Output breaks any downstream JSON post-processing. if extracted := extractDisplayTextFromEvents(rawOutput, p.parseGeminiDisplayEvents); extracted != "" { result.Output = extracted - result.Tokens = estimateTokens(extracted) + if result.TokensEstimated { + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } } userFormat, _ := getStringOption(options, "output_format") @@ -111,18 +120,23 @@ func (p *GeminiProvider) Validate() error { return nil } -func (p *GeminiProvider) buildExecuteArgs(prompt string, options map[string]any) ([]string, error) { - args := []string{"-p", prompt} - +func prependGeminiGlobalFlags(args []string, options map[string]any) []string { if model, ok := getStringOption(options, "model"); ok { args = append([]string{"--model", model}, args...) 
} - // Always force stream-json NDJSON at the CLI level so the F082 display filter - // and text extraction have a consistent wire format (F082, aligned with Claude). args = append([]string{"--output-format", "stream-json"}, args...) if skipPerms, ok := getBoolOption(options, "dangerously_skip_permissions"); ok && skipPerms { args = append([]string{"--approval-mode=yolo"}, args...) } + return args +} + +func (p *GeminiProvider) buildExecuteArgs(prompt string, options map[string]any) ([]string, error) { + args := []string{"-p", prompt} + + // Always force stream-json NDJSON at the CLI level so the F082 display filter + // and text extraction have a consistent wire format (F082, aligned with Claude). + args = prependGeminiGlobalFlags(args, options) return args, nil } @@ -134,21 +148,11 @@ func (p *GeminiProvider) buildConversationArgs(state *workflow.ConversationState if state != nil && state.SessionID != "" { args = []string{"--resume", state.SessionID, "-p", prompt} } else { - effectivePrompt := prompt - if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { - effectivePrompt = sysPrompt + "\n\n" + prompt - } - args = []string{"-p", effectivePrompt} + args = []string{"-p", buildFirstTurnPrompt(prompt, options)} } - if model, ok := getStringOption(options, "model"); ok { - args = append([]string{"--model", model}, args...) - } // Force stream-json unconditionally for reliable session ID extraction. - args = append([]string{"--output-format", "stream-json"}, args...) - if skipPerms, ok := getBoolOption(options, "dangerously_skip_permissions"); ok && skipPerms { - args = append([]string{"--approval-mode=yolo"}, args...) 
- } + args = prependGeminiGlobalFlags(args, options) return args, nil } @@ -179,6 +183,22 @@ func (p *GeminiProvider) extractSessionID(output string) (string, error) { return str, nil } +func (p *GeminiProvider) extractGeminiTokenUsage(rawOutput string) *tokenUsage { + evt := findFirstNDJSONEvent(rawOutput, "result") + if evt == nil { + return nil + } + stats, ok := evt["stats"].(map[string]any) + if !ok { + return nil + } + return &tokenUsage{ + InputTokens: intFromMap(stats, "input_tokens"), + OutputTokens: intFromMap(stats, "output_tokens"), + TotalTokens: intFromMap(stats, "total_tokens"), + } +} + func (p *GeminiProvider) parseGeminiDisplayEvents(line []byte) []DisplayEvent { var evt struct { Type string `json:"type"` diff --git a/internal/infrastructure/agents/helpers.go b/internal/infrastructure/agents/helpers.go index 8310610d..73423837 100644 --- a/internal/infrastructure/agents/helpers.go +++ b/internal/infrastructure/agents/helpers.go @@ -7,10 +7,6 @@ import ( "github.com/awf-project/cli/internal/domain/workflow" ) -func estimateTokens(output string) int { - return len(output) / 4 -} - func cloneState(state *workflow.ConversationState) *workflow.ConversationState { if state == nil { return nil @@ -36,6 +32,13 @@ func getStringOption(options map[string]any, key string) (string, bool) { return val, ok } +func buildFirstTurnPrompt(userPrompt string, options map[string]any) string { + if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { + return sysPrompt + "\n\n" + userPrompt + } + return userPrompt +} + func getBoolOption(options map[string]any, key string) (value, found bool) { if options == nil { return false, false @@ -44,21 +47,6 @@ func getBoolOption(options map[string]any, key string) (value, found bool) { return val, ok } -func estimateInputTokens(turns []workflow.Turn, excludeLastN int) int { - inputTokens := 0 - limit := len(turns) - excludeLastN - if limit < 0 { - limit = 0 - } - for i := 0; i < limit; i++ { - if 
turns[i].Tokens == 0 { - turns[i].Tokens = estimateTokens(turns[i].Content) - } - inputTokens += turns[i].Tokens - } - return inputTokens -} - func tryParseJSONResponse(output string) map[string]any { trimmed := strings.TrimSpace(output) if !strings.HasPrefix(trimmed, "{") { @@ -90,6 +78,31 @@ func findFirstNDJSONEvent(output, eventType string) map[string]any { return nil } +func findLastNDJSONEvent(output, eventType string) map[string]any { + var found map[string]any + for _, line := range strings.Split(output, "\n") { + line = strings.TrimSpace(line) + if line == "" { + continue + } + var evt map[string]any + if err := json.Unmarshal([]byte(line), &evt); err != nil { + continue + } + if t, ok := evt["type"].(string); ok && t == eventType { + found = evt + } + } + return found +} + +func intFromMap(m map[string]any, key string) int { + if v, ok := m[key].(float64); ok { + return int(v) + } + return 0 +} + // argPreviewKeys defines the ordered list of input map keys used to extract a // human-readable preview for EventToolUse.Arg. The first matching key wins. var argPreviewKeys = []string{"file_path", "command", "cmd", "query", "pattern"} @@ -135,9 +148,3 @@ func extractArgPreview(arguments string) string { } return extractArgPreviewFromMap(m) } - -// parseToolCallArgPreview is an alias for extractArgPreview retained for -// compatibility with providers that use the longer name. 
-func parseToolCallArgPreview(arguments string) string { - return extractArgPreview(arguments) -} diff --git a/internal/infrastructure/agents/helpers_test.go b/internal/infrastructure/agents/helpers_test.go index 9f6d8e4d..a90d0566 100644 --- a/internal/infrastructure/agents/helpers_test.go +++ b/internal/infrastructure/agents/helpers_test.go @@ -8,42 +8,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestEstimateTokens(t *testing.T) { - tests := []struct { - name string - output string - expected int - }{ - { - name: "empty_string", - output: "", - expected: 0, - }, - { - name: "short_string", - output: "test", // 4 chars = 1 token - expected: 1, - }, - { - name: "medium_string", - output: "hello world test", // 16 chars = 4 tokens - expected: 4, - }, - { - name: "long_string", - output: "This is a longer string for testing token estimation accurately.", // 64 chars - expected: 16, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := estimateTokens(tt.output) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestCloneState(t *testing.T) { tests := []struct { name string @@ -213,63 +177,6 @@ func TestGetBoolOption(t *testing.T) { } } -func TestEstimateInputTokens(t *testing.T) { - tests := []struct { - name string - turns []workflow.Turn - excludeLastN int - expected int - }{ - { - name: "empty_turns", - turns: []workflow.Turn{}, - excludeLastN: 0, - expected: 0, - }, - { - name: "single_turn_exclude_one", - turns: []workflow.Turn{ - {Role: "user", Content: "test", Tokens: 10}, - }, - excludeLastN: 1, - expected: 0, - }, - { - name: "multiple_turns_exclude_one", - turns: []workflow.Turn{ - {Role: "user", Content: "hello", Tokens: 5}, - {Role: "assistant", Content: "hi there", Tokens: 10}, - }, - excludeLastN: 1, - expected: 5, - }, - { - name: "estimate_missing_tokens", - turns: []workflow.Turn{ - {Role: "user", Content: "test", Tokens: 0}, // Will be estimated: 4/4=1 - {Role: "assistant", Content: 
"response", Tokens: 0}, - }, - excludeLastN: 1, - expected: 1, - }, - { - name: "exclude_more_than_available", - turns: []workflow.Turn{ - {Role: "user", Content: "test", Tokens: 10}, - }, - excludeLastN: 5, - expected: 0, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := estimateInputTokens(tt.turns, tt.excludeLastN) - assert.Equal(t, tt.expected, result) - }) - } -} - func TestTryParseJSONResponse(t *testing.T) { tests := []struct { name string diff --git a/internal/infrastructure/agents/opencode_provider.go b/internal/infrastructure/agents/opencode_provider.go index 6cd21212..9166303f 100644 --- a/internal/infrastructure/agents/opencode_provider.go +++ b/internal/infrastructure/agents/opencode_provider.go @@ -17,13 +17,12 @@ import ( // OpenCodeProvider implements AgentProvider for OpenCode CLI. // Invokes: opencode run "prompt" type OpenCodeProvider struct { - base *baseCLIProvider - logger ports.Logger - executor ports.CLIExecutor + base *baseCLIProvider + logger ports.Logger + executor ports.CLIExecutor + tokenizer ports.Tokenizer } -// NewOpenCodeProvider creates a new OpenCodeProvider. -// If no executor is provided, ExecCLIExecutor is used by default. func NewOpenCodeProvider() *OpenCodeProvider { p := &OpenCodeProvider{ logger: logger.NopLogger{}, @@ -33,7 +32,6 @@ func NewOpenCodeProvider() *OpenCodeProvider { return p } -// NewOpenCodeProviderWithOptions creates a new OpenCodeProvider with functional options. 
func NewOpenCodeProviderWithOptions(opts ...OpenCodeProviderOption) *OpenCodeProvider { p := &OpenCodeProvider{ logger: logger.NopLogger{}, @@ -47,13 +45,18 @@ func NewOpenCodeProviderWithOptions(opts ...OpenCodeProviderOption) *OpenCodePro } func (p *OpenCodeProvider) newBase() *baseCLIProvider { - return newBaseCLIProvider("opencode", "opencode", p.executor, p.logger, cliProviderHooks{ + b := newBaseCLIProvider("opencode", "opencode", p.executor, p.logger, cliProviderHooks{ buildExecuteArgs: p.buildExecuteArgs, buildConversationArgs: p.buildConversationArgs, extractSessionID: p.extractSessionID, validateOptions: validateOpenCodeOptions, parseDisplayEvents: p.parseOpencodeDisplayEvents, + extractTokenUsage: p.extractOpenCodeTokenUsage, }) + if p.tokenizer != nil { + b.tokenizer = p.tokenizer + } + return b } // Execute invokes the OpenCode CLI with the given prompt and options. @@ -68,7 +71,10 @@ func (p *OpenCodeProvider) Execute(ctx context.Context, prompt string, options m // leaving NDJSON in state.Output breaks any downstream JSON post-processing. 
if extracted := extractDisplayTextFromEvents(rawOutput, p.parseOpencodeDisplayEvents); extracted != "" { result.Output = extracted - result.Tokens = estimateTokens(extracted) + if result.TokensEstimated { + tokens, _ := p.base.tokenizer.CountTokens(extracted) //nolint:errcheck // ApproximationTokenizer never errors with a valid ratio + result.Tokens = tokens + } } userFormat, _ := getStringOption(options, "output_format") @@ -119,9 +125,7 @@ func (p *OpenCodeProvider) buildExecuteArgs(prompt string, options map[string]an func (p *OpenCodeProvider) buildConversationArgs(state *workflow.ConversationState, prompt string, options map[string]any) ([]string, error) { effectivePrompt := prompt if len(state.Turns) == 0 { - if sysPrompt, ok := getStringOption(options, "system_prompt"); ok && sysPrompt != "" { - effectivePrompt = sysPrompt + "\n\n" + prompt - } + effectivePrompt = buildFirstTurnPrompt(prompt, options) } args := []string{"run", effectivePrompt} @@ -226,6 +230,28 @@ func (p *OpenCodeProvider) extractSessionID(output string) (string, error) { return sessionID, nil } +func (p *OpenCodeProvider) extractOpenCodeTokenUsage(rawOutput string) *tokenUsage { + evt := findFirstNDJSONEvent(rawOutput, "step_finish") + if evt == nil { + return nil + } + part, ok := evt["part"].(map[string]any) + if !ok { + return nil + } + tokens, ok := part["tokens"].(map[string]any) + if !ok { + return nil + } + cost, _ := part["cost"].(float64) //nolint:errcheck // type assertion; zero-value fallback is intentional + return &tokenUsage{ + InputTokens: intFromMap(tokens, "input"), + OutputTokens: intFromMap(tokens, "output"), + TotalTokens: intFromMap(tokens, "total"), + CostUSD: cost, + } +} + func (p *OpenCodeProvider) parseOpencodeDisplayEvents(line []byte) []DisplayEvent { // Escape NUL bytes to JSON unicode sequence so json.Unmarshal preserves them // in decoded string fields while avoiding parse errors. 
diff --git a/internal/infrastructure/agents/options.go b/internal/infrastructure/agents/options.go index eda912d1..1c7b0faa 100644 --- a/internal/infrastructure/agents/options.go +++ b/internal/infrastructure/agents/options.go @@ -13,6 +13,12 @@ func WithClaudeExecutor(executor ports.CLIExecutor) ClaudeProviderOption { } } +func WithClaudeTokenizer(tok ports.Tokenizer) ClaudeProviderOption { + return func(p *ClaudeProvider) { + p.tokenizer = tok + } +} + type GeminiProviderOption func(*GeminiProvider) func WithGeminiExecutor(executor ports.CLIExecutor) GeminiProviderOption { @@ -21,6 +27,12 @@ func WithGeminiExecutor(executor ports.CLIExecutor) GeminiProviderOption { } } +func WithGeminiTokenizer(tok ports.Tokenizer) GeminiProviderOption { + return func(p *GeminiProvider) { + p.tokenizer = tok + } +} + type CodexProviderOption func(*CodexProvider) func WithCodexExecutor(executor ports.CLIExecutor) CodexProviderOption { @@ -35,6 +47,12 @@ func WithCodexLogger(l ports.Logger) CodexProviderOption { } } +func WithCodexTokenizer(tok ports.Tokenizer) CodexProviderOption { + return func(p *CodexProvider) { + p.tokenizer = tok + } +} + type OpenCodeProviderOption func(*OpenCodeProvider) func WithOpenCodeExecutor(executor ports.CLIExecutor) OpenCodeProviderOption { @@ -49,6 +67,12 @@ func WithOpenCodeLogger(l ports.Logger) OpenCodeProviderOption { } } +func WithOpenCodeTokenizer(tok ports.Tokenizer) OpenCodeProviderOption { + return func(p *OpenCodeProvider) { + p.tokenizer = tok + } +} + type CopilotProviderOption func(*CopilotProvider) func WithCopilotExecutor(executor ports.CLIExecutor) CopilotProviderOption { @@ -63,6 +87,12 @@ func WithCopilotLogger(l ports.Logger) CopilotProviderOption { } } +func WithCopilotTokenizer(tok ports.Tokenizer) CopilotProviderOption { + return func(p *CopilotProvider) { + p.tokenizer = tok + } +} + type OpenAICompatibleProviderOption func(*OpenAICompatibleProvider) func WithHTTPClient(client *httpx.Client) OpenAICompatibleProviderOption { 
diff --git a/internal/infrastructure/agents/provider_options_test.go b/internal/infrastructure/agents/provider_options_test.go index f0fa2c40..c3b3dbe2 100644 --- a/internal/infrastructure/agents/provider_options_test.go +++ b/internal/infrastructure/agents/provider_options_test.go @@ -5,13 +5,29 @@ import ( "errors" "testing" + "github.com/awf-project/cli/internal/domain/ports" "github.com/awf-project/cli/internal/testutil/mocks" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -// Component: T004 - Provider Constructor Functional Options -// Tests the refactored provider constructors with CLIExecutor dependency injection +type stubTokenizer struct{} + +func (stubTokenizer) CountTokens(string) (int, error) { return 0, nil } +func (stubTokenizer) CountTurnsTokens([]string) (int, error) { return 0, nil } +func (stubTokenizer) IsEstimate() bool { return true } +func (stubTokenizer) ModelName() string { return "stub" } + +var _ ports.Tokenizer = stubTokenizer{} + +type countingTokenizer struct{ count int } + +func (t countingTokenizer) CountTokens(string) (int, error) { return t.count, nil } +func (t countingTokenizer) CountTurnsTokens([]string) (int, error) { return t.count, nil } +func (t countingTokenizer) IsEstimate() bool { return false } +func (t countingTokenizer) ModelName() string { return "counting" } + +var _ ports.Tokenizer = countingTokenizer{} func TestClaudeProvider_NewWithOptions_HappyPath(t *testing.T) { tests := []struct { @@ -455,6 +471,121 @@ func TestProviderOptions_ErrorHandling(t *testing.T) { }) } +func TestWithCopilotTokenizer(t *testing.T) { + tok := stubTokenizer{} + provider := NewCopilotProviderWithOptions(WithCopilotTokenizer(tok)) + require.NotNil(t, provider) + assert.Equal(t, ports.Tokenizer(tok), provider.base.tokenizer) +} + +func TestWithCodexTokenizer(t *testing.T) { + tok := stubTokenizer{} + provider := NewCodexProviderWithOptions(WithCodexTokenizer(tok)) + require.NotNil(t, provider) + assert.Equal(t, 
ports.Tokenizer(tok), provider.base.tokenizer) +} + +func TestWithOpenCodeTokenizer(t *testing.T) { + tok := stubTokenizer{} + provider := NewOpenCodeProviderWithOptions(WithOpenCodeTokenizer(tok)) + require.NotNil(t, provider) + assert.Equal(t, ports.Tokenizer(tok), provider.base.tokenizer) +} + +func TestWithClaudeTokenizer(t *testing.T) { + tok := stubTokenizer{} + provider := NewClaudeProviderWithOptions(WithClaudeTokenizer(tok)) + require.NotNil(t, provider) + assert.Equal(t, ports.Tokenizer(tok), provider.base.tokenizer) +} + +func TestWithGeminiTokenizer(t *testing.T) { + tok := stubTokenizer{} + provider := NewGeminiProviderWithOptions(WithGeminiTokenizer(tok)) + require.NotNil(t, provider) + assert.Equal(t, ports.Tokenizer(tok), provider.base.tokenizer) +} + +func TestClaudeProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + const expectedTokens = 99 + tok := countingTokenizer{count: expectedTokens} + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(`{"type":"result","result":"extracted text here"}`), []byte("")) + + provider := NewClaudeProviderWithOptions( + WithClaudeExecutor(mockExec), + WithClaudeTokenizer(tok), + ) + + result, err := provider.Execute(context.Background(), "prompt", nil, nil, nil) + require.NoError(t, err) + assert.Equal(t, expectedTokens, result.Tokens) +} + +func TestGeminiProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + const expectedTokens = 99 + tok := countingTokenizer{count: expectedTokens} + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(`{"type":"message","role":"assistant","content":"gemini text here"}`), []byte("")) + + provider := NewGeminiProviderWithOptions( + WithGeminiExecutor(mockExec), + WithGeminiTokenizer(tok), + ) + + result, err := provider.Execute(context.Background(), "prompt", nil, nil, nil) + require.NoError(t, err) + assert.Equal(t, expectedTokens, result.Tokens) +} + +func TestCopilotProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + const 
expectedTokens = 88 + tok := countingTokenizer{count: expectedTokens} + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(`{"type":"assistant.message","data":{"content":"copilot extracted text here","messageId":"m1"}}`+"\n"+`{"type":"result","sessionId":"s1","exitCode":0}`), []byte("")) + + provider := NewCopilotProviderWithOptions( + WithCopilotExecutor(mockExec), + WithCopilotTokenizer(tok), + ) + + result, err := provider.Execute(context.Background(), "prompt", nil, nil, nil) + require.NoError(t, err) + assert.Equal(t, expectedTokens, result.Tokens) +} + +func TestCodexProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + const expectedTokens = 77 + tok := countingTokenizer{count: expectedTokens} + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(`{"type":"item.completed","item":{"item_type":"assistant_message","text":"codex extracted text here"}}`), []byte("")) + + provider := NewCodexProviderWithOptions( + WithCodexExecutor(mockExec), + WithCodexTokenizer(tok), + ) + + result, err := provider.Execute(context.Background(), "prompt", nil, nil, nil) + require.NoError(t, err) + assert.Equal(t, expectedTokens, result.Tokens) +} + +func TestOpenCodeProvider_Execute_UsesInjectedTokenizer(t *testing.T) { + const expectedTokens = 66 + tok := countingTokenizer{count: expectedTokens} + mockExec := mocks.NewMockCLIExecutor() + mockExec.SetOutput([]byte(`{"type":"text","part":{"text":"opencode extracted text here"}}`), []byte("")) + + provider := NewOpenCodeProviderWithOptions( + WithOpenCodeExecutor(mockExec), + WithOpenCodeTokenizer(tok), + ) + + result, err := provider.Execute(context.Background(), "prompt", nil, nil, nil) + require.NoError(t, err) + assert.Equal(t, expectedTokens, result.Tokens) +} + func TestProviderOptions_Integration(t *testing.T) { t.Run("claude provider with mock executor executes successfully", func(t *testing.T) { mockExec := mocks.NewMockCLIExecutor() diff --git 
a/internal/infrastructure/tokenizer/tiktoken_tokenizer.go b/internal/infrastructure/tokenizer/tiktoken_tokenizer.go deleted file mode 100644 index 718b7353..00000000 --- a/internal/infrastructure/tokenizer/tiktoken_tokenizer.go +++ /dev/null @@ -1,67 +0,0 @@ -package tokenizer - -import ( - "fmt" - - "github.com/awf-project/cli/internal/domain/ports" - tiktoken "github.com/pkoukk/tiktoken-go" -) - -// TiktokenTokenizer implements ports.Tokenizer using pkoukk/tiktoken-go library. -// Provides accurate token counting for OpenAI-compatible models. -type TiktokenTokenizer struct { - modelName string -} - -// NewTiktokenTokenizer creates a new TiktokenTokenizer for the specified model. -// Common models: "cl100k_base" (GPT-4, GPT-3.5-turbo), "p50k_base" (Codex), "r50k_base" (GPT-3). -func NewTiktokenTokenizer(modelName string) (ports.Tokenizer, error) { - return &TiktokenTokenizer{ - modelName: modelName, - }, nil -} - -// CountTokens returns the exact number of tokens in the given text. -func (t *TiktokenTokenizer) CountTokens(text string) (int, error) { - if text == "" { - return 0, nil - } - - // Get encoding for the model - encoding, err := tiktoken.GetEncoding(t.modelName) - if err != nil { - return 0, fmt.Errorf("failed to get encoding for model %s: %w", t.modelName, err) - } - - // Encode the text to get tokens - tokens := encoding.Encode(text, nil, nil) - return len(tokens), nil -} - -// CountTurnsTokens returns the total token count across multiple conversation turns. -func (t *TiktokenTokenizer) CountTurnsTokens(turns []string) (int, error) { - if len(turns) == 0 { - return 0, nil - } - - totalTokens := 0 - for _, turn := range turns { - count, err := t.CountTokens(turn) - if err != nil { - return 0, err - } - totalTokens += count - } - - return totalTokens, nil -} - -// IsEstimate returns false because tiktoken provides exact counts. -func (t *TiktokenTokenizer) IsEstimate() bool { - return false -} - -// ModelName returns the tiktoken model identifier. 
-func (t *TiktokenTokenizer) ModelName() string { - return t.modelName -} diff --git a/internal/infrastructure/tokenizer/tiktoken_tokenizer_test.go b/internal/infrastructure/tokenizer/tiktoken_tokenizer_test.go deleted file mode 100644 index 68ebd9b2..00000000 --- a/internal/infrastructure/tokenizer/tiktoken_tokenizer_test.go +++ /dev/null @@ -1,545 +0,0 @@ -package tokenizer - -import ( - "strings" - "testing" - - "github.com/awf-project/cli/internal/domain/ports" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -// Component: tiktoken_adapter -// Feature: F033 - -func TestTiktokenTokenizer_InterfaceCompliance(t *testing.T) { - // Verify TiktokenTokenizer implements ports.Tokenizer - var _ ports.Tokenizer = (*TiktokenTokenizer)(nil) -} - -func TestNewTiktokenTokenizer_HappyPath(t *testing.T) { - tests := []struct { - name string - modelName string - }{ - {"cl100k_base (GPT-4, GPT-3.5-turbo)", "cl100k_base"}, - {"p50k_base (Codex)", "p50k_base"}, - {"r50k_base (GPT-3)", "r50k_base"}, - {"gpt2", "gpt2"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer(tt.modelName) - - require.NoError(t, err) - require.NotNil(t, tokenizer) - assert.Equal(t, tt.modelName, tokenizer.ModelName()) - assert.False(t, tokenizer.IsEstimate(), "tiktoken should provide exact counts") - }) - } -} - -func TestNewTiktokenTokenizer_EmptyModelName(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("") - - require.NoError(t, err) - require.NotNil(t, tokenizer) - assert.Equal(t, "", tokenizer.ModelName()) -} - -func TestNewTiktokenTokenizer_InvalidModelName(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("invalid_model_xyz") - - require.NoError(t, err) - require.NotNil(t, tokenizer) - assert.Equal(t, "invalid_model_xyz", tokenizer.ModelName()) -} - -func TestTiktokenTokenizer_CountTokens_HappyPath(t *testing.T) { - tests := []struct { - name string - modelName string - text string - 
expectedMin int // Minimum expected tokens - expectedMax int // Maximum expected tokens - }{ - { - name: "simple sentence", - modelName: "cl100k_base", - text: "Hello, world!", - expectedMin: 2, - expectedMax: 5, - }, - { - name: "longer text", - modelName: "cl100k_base", - text: "This is a test prompt for token counting in the AWF CLI application.", - expectedMin: 10, - expectedMax: 20, - }, - { - name: "code snippet", - modelName: "cl100k_base", - text: "func main() {\n\tfmt.Println(\"Hello\")\n}", - expectedMin: 8, - expectedMax: 20, - }, - { - name: "markdown text", - modelName: "cl100k_base", - text: "# Header\n\n**Bold** and *italic* text with [link](https://example.com)", - expectedMin: 12, - expectedMax: 25, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer(tt.modelName) - require.NoError(t, err) - - count, err := tokenizer.CountTokens(tt.text) - - require.NoError(t, err) - assert.GreaterOrEqual(t, count, tt.expectedMin, - "token count should be at least %d", tt.expectedMin) - assert.LessOrEqual(t, count, tt.expectedMax, - "token count should be at most %d", tt.expectedMax) - }) - } -} - -func TestTiktokenTokenizer_CountTokens_EmptyString(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTokens("") - - require.NoError(t, err) - assert.Equal(t, 0, count) -} - -func TestTiktokenTokenizer_CountTokens_WhitespaceOnly(t *testing.T) { - tests := []struct { - name string - text string - }{ - {"single space", " "}, - {"multiple spaces", " "}, - {"tabs", "\t\t"}, - {"newlines", "\n\n\n"}, - {"mixed whitespace", " \t\n \t\n"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTokens(tt.text) - - require.NoError(t, err) - assert.GreaterOrEqual(t, count, 0) - }) - } -} - -func 
TestTiktokenTokenizer_CountTokens_UnicodeText(t *testing.T) { - tests := []struct { - name string - text string - }{ - {"chinese", "你好世界"}, - {"russian", "Привет мир"}, - {"arabic", "مرحبا بالعالم"}, - {"emoji", "Hello 👋 World 🌍"}, - {"mixed unicode", "Hello 世界! Привет мир! مرحبا بالعالم!"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTokens(tt.text) - - require.NoError(t, err) - assert.Greater(t, count, 0, "unicode text should produce tokens") - }) - } -} - -func TestTiktokenTokenizer_CountTokens_SpecialCharacters(t *testing.T) { - tests := []struct { - name string - text string - }{ - {"code block", "```python\ndef foo():\n return \"bar\"\n```"}, - {"xml/html", "content"}, - {"json", "{\"key\": \"value\", \"number\": 42}"}, - {"special symbols", "!@#$%^&*()_+-=[]{}|;':\",./<>?"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTokens(tt.text) - - require.NoError(t, err) - assert.Greater(t, count, 0) - }) - } -} - -func TestTiktokenTokenizer_CountTokens_LargeText(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - // Create large text (~100KB) - largeText := strings.Repeat("This is a test sentence with multiple words. 
", 2000) - - count, err := tokenizer.CountTokens(largeText) - - require.NoError(t, err) - assert.Greater(t, count, 1000, "large text should produce many tokens") -} - -func TestTiktokenTokenizer_CountTokens_VeryLongSingleWord(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - // Very long "word" (no spaces) - longWord := strings.Repeat("a", 10000) - - count, err := tokenizer.CountTokens(longWord) - - require.NoError(t, err) - assert.Greater(t, count, 0) -} - -func TestTiktokenTokenizer_CountTokens_RepeatedCounting(t *testing.T) { - // Test that counting the same text multiple times is consistent - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - text := "This is a test prompt that should produce consistent counts" - - counts := make([]int, 5) - for i := 0; i < 5; i++ { - count, err := tokenizer.CountTokens(text) - require.NoError(t, err) - counts[i] = count - } - - for i := 1; i < len(counts); i++ { - assert.Equal(t, counts[0], counts[i], - "count %d should equal first count %d", counts[i], counts[0]) - } -} - -func TestTiktokenTokenizer_CountTokens_InvalidModel(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("invalid_model_name") - require.NoError(t, err) - - count, err := tokenizer.CountTokens("test text") - - // Note: This behavior depends on tiktoken-go implementation - // For now, stub returns 0, nil - test will fail when implemented - if err != nil { - assert.Equal(t, 0, count, "count should be 0 on error") - assert.Error(t, err) - } else { - // Stub behavior - will fail when real implementation is added - t.Log("Warning: Expected error for invalid model, got success (stub behavior)") - } -} - -func TestTiktokenTokenizer_CountTurnsTokens_HappyPath(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{ - "You are a helpful assistant.", - "Analyze this code snippet.", - "Here is the detailed analysis of 
the code...", - "Thank you for the analysis!", - } - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 10, "multiple turns should produce significant tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_SingleTurn(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{"Single turn message"} - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 0) -} - -func TestTiktokenTokenizer_CountTurnsTokens_EmptyArray(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTurnsTokens([]string{}) - - require.NoError(t, err) - assert.Equal(t, 0, count, "empty array should produce 0 tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_NilArray(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - count, err := tokenizer.CountTurnsTokens(nil) - - require.NoError(t, err) - assert.Equal(t, 0, count, "nil array should produce 0 tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_MixedEmptyTurns(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{ - "First turn with content", - "", - "Third turn with content", - "", - "Fifth turn with content", - } - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 0, "non-empty turns should produce tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_AllEmptyTurns(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{"", "", "", ""} - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Equal(t, 0, count, "all empty turns should produce 0 tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_ManyTurns(t 
*testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - // Create 100 turns - turns := make([]string, 100) - for i := range turns { - turns[i] = "Turn with some content for testing" - } - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 100, "100 turns should produce significant tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_LargeIndividualTurns(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - // Create turns with large content - turns := []string{ - strings.Repeat("First turn with lots of content. ", 100), - strings.Repeat("Second turn with lots of content. ", 100), - strings.Repeat("Third turn with lots of content. ", 100), - } - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 300, "large turns should produce many tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_UnicodeInTurns(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{ - "Hello in English", - "你好 in Chinese", - "Привет in Russian", - "مرحبا in Arabic", - "🌍 Emoji turn", - } - - count, err := tokenizer.CountTurnsTokens(turns) - - require.NoError(t, err) - assert.Greater(t, count, 10, "unicode turns should produce tokens") -} - -func TestTiktokenTokenizer_CountTurnsTokens_MatchesIndividualCounts(t *testing.T) { - // Test that CountTurnsTokens produces same result as sum of individual CountTokens - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - turns := []string{ - "Turn one", - "Turn two", - "Turn three", - } - - individualTotal := 0 - for _, turn := range turns { - count, err := tokenizer.CountTokens(turn) - require.NoError(t, err) - individualTotal += count - } - - batchTotal, err := tokenizer.CountTurnsTokens(turns) - require.NoError(t, err) - - assert.Equal(t, 
individualTotal, batchTotal, - "batch count should match sum of individual counts") -} - -func TestTiktokenTokenizer_ModelName(t *testing.T) { - tests := []struct { - name string - modelName string - }{ - {"cl100k_base", "cl100k_base"}, - {"p50k_base", "p50k_base"}, - {"r50k_base", "r50k_base"}, - {"gpt2", "gpt2"}, - {"custom-model", "custom-model"}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer(tt.modelName) - require.NoError(t, err) - - name := tokenizer.ModelName() - - assert.Equal(t, tt.modelName, name) - }) - } -} - -func TestTiktokenTokenizer_IsEstimate(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - isEstimate := tokenizer.IsEstimate() - - assert.False(t, isEstimate, "tiktoken should return false for IsEstimate()") -} - -func TestTiktokenTokenizer_IsEstimate_AllModels(t *testing.T) { - models := []string{"cl100k_base", "p50k_base", "r50k_base", "gpt2"} - - for _, modelName := range models { - t.Run(modelName, func(t *testing.T) { - tokenizer, err := NewTiktokenTokenizer(modelName) - require.NoError(t, err) - - isEstimate := tokenizer.IsEstimate() - - assert.False(t, isEstimate, - "model %s should return false for IsEstimate()", modelName) - }) - } -} - -func TestTiktokenTokenizer_RealWorldPrompt(t *testing.T) { - // Test with a realistic AI workflow prompt - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - prompt := `You are a code review assistant. Analyze the following Go code for: -1. Potential bugs -2. Performance issues -3. Security vulnerabilities -4. 
Code style violations - -Code: -func ProcessData(data []string) error { - for _, item := range data { - // Process item - fmt.Println(item) - } - return nil -} - -Provide detailed feedback.` - - count, err := tokenizer.CountTokens(prompt) - - require.NoError(t, err) - assert.Greater(t, count, 50, "realistic prompt should have substantial tokens") - assert.Less(t, count, 200, "realistic prompt shouldn't have excessive tokens") -} - -func TestTiktokenTokenizer_ConversationScenario(t *testing.T) { - // Test with a multi-turn conversation - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - conversation := []string{ - "System: You are a helpful coding assistant.", - "User: How do I reverse a string in Go?", - "Assistant: Here's how to reverse a string in Go:\n\nfunc reverse(s string) string {\n runes := []rune(s)\n for i, j := 0, len(runes)-1; i < j; i, j = i+1, j-1 {\n runes[i], runes[j] = runes[j], runes[i]\n }\n return string(runes)\n}", - "User: Thanks! Can you add error handling?", - "Assistant: Sure! 
Here's the version with error handling...", - } - - count, err := tokenizer.CountTurnsTokens(conversation) - - require.NoError(t, err) - assert.Greater(t, count, 100, "multi-turn conversation should have many tokens") -} - -func TestTiktokenTokenizer_CodeSnippets(t *testing.T) { - // Test various programming language snippets - snippets := map[string]string{ - "go": `package main -import "fmt" -func main() { - fmt.Println("Hello, World!") -}`, - "python": `def hello(): - print("Hello, World!") -if __name__ == "__main__": - hello()`, - "javascript": `function hello() { - console.log("Hello, World!"); -} -hello();`, - "sql": `SELECT users.name, orders.total -FROM users -INNER JOIN orders ON users.id = orders.user_id -WHERE orders.total > 100;`, - } - - tokenizer, err := NewTiktokenTokenizer("cl100k_base") - require.NoError(t, err) - - for lang, code := range snippets { - t.Run(lang, func(t *testing.T) { - count, err := tokenizer.CountTokens(code) - - require.NoError(t, err) - assert.Greater(t, count, 5, "%s code should produce tokens", lang) - }) - } -} diff --git a/pkg/interpolation/reference.go b/pkg/interpolation/reference.go index 8afa68c9..49c0cc23 100644 --- a/pkg/interpolation/reference.go +++ b/pkg/interpolation/reference.go @@ -44,13 +44,16 @@ var ValidWorkflowProperties = map[string]bool{ // ValidStateProperties lists known step state properties that can be referenced. var ValidStateProperties = map[string]bool{ - "Output": true, - "Stderr": true, - "ExitCode": true, - "Status": true, - "Response": true, - "TokensUsed": true, - "JSON": true, + "Output": true, + "Stderr": true, + "ExitCode": true, + "Status": true, + "Response": true, + "TokensUsed": true, + "TokensInput": true, + "TokensOutput": true, + "TokensEstimated": true, + "JSON": true, } // ValidErrorProperties lists known error properties in error hooks. 
diff --git a/pkg/interpolation/reference_json_field_test.go b/pkg/interpolation/reference_json_field_test.go index aedc9f0a..7ef37903 100644 --- a/pkg/interpolation/reference_json_field_test.go +++ b/pkg/interpolation/reference_json_field_test.go @@ -25,7 +25,10 @@ func TestValidStateProperties_AllFields(t *testing.T) { "Status", "Response", "TokensUsed", - "JSON", // F065: new field for explicit JSON output + "TokensInput", + "TokensOutput", + "TokensEstimated", + "JSON", } for _, field := range expectedFields { diff --git a/pkg/interpolation/reference_test.go b/pkg/interpolation/reference_test.go index e47a7d1e..f5ae218b 100644 --- a/pkg/interpolation/reference_test.go +++ b/pkg/interpolation/reference_test.go @@ -819,7 +819,7 @@ func TestValidationMaps_Comprehensive(t *testing.T) { }, "ValidStateProperties": { validMap: interpolation.ValidStateProperties, - required: []string{"Output", "Stderr", "ExitCode", "Status", "Response", "TokensUsed", "JSON"}, + required: []string{"Output", "Stderr", "ExitCode", "Status", "Response", "TokensUsed", "TokensInput", "TokensOutput", "TokensEstimated", "JSON"}, invalid: []string{"stdout", "result", ""}, deprecated: []string{"output", "stderr", "exit_code", "status", "response", "tokensused", "json"}, }, diff --git a/pkg/interpolation/resolver.go b/pkg/interpolation/resolver.go index 9aace2db..c65178b0 100644 --- a/pkg/interpolation/resolver.go +++ b/pkg/interpolation/resolver.go @@ -36,14 +36,17 @@ func (l *LoopData) Index1() int { // StepStateData holds step execution results for interpolation. 
type StepStateData struct { - Output string - Stderr string - ExitCode int - Status string - Response map[string]any // parsed JSON response from agent steps - TokensUsed int // total tokens used from agent steps - JSON any // explicit JSON output from output_format (map[string]any or []any) - Data map[string]any // structured output from plugin custom step types + Output string + Stderr string + ExitCode int + Status string + Response map[string]any // parsed JSON response from agent steps + TokensUsed int // total tokens used from agent steps + TokensInput int // input tokens (from provider JSON or estimation) + TokensOutput int // output tokens (from provider JSON or estimation) + TokensEstimated bool // true if token counts are estimates, false if from provider + JSON any // explicit JSON output from output_format (map[string]any or []any) + Data map[string]any // structured output from plugin custom step types } // WorkflowData holds workflow metadata for interpolation.