From f9e62fd56d504f410905627ac3f5f87b47565d7a Mon Sep 17 00:00:00 2001 From: "Jakub A. W" Date: Wed, 25 Feb 2026 13:33:09 +0100 Subject: [PATCH 1/3] refactor: fix OCP violation in usage cost mapping Move provider-specific cost mapping data out of usage/cost.go and into each provider's Registration var, wired through the factory at startup. This ensures adding a new provider no longer requires editing the usage package. Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 + internal/app/app.go | 4 + internal/core/cost_mapping.go | 25 +++++ internal/providers/anthropic/anthropic.go | 4 + internal/providers/factory.go | 48 +++++++-- internal/providers/factory_test.go | 54 ++++++++++ internal/providers/gemini/gemini.go | 4 + internal/providers/openai/openai.go | 14 +++ internal/providers/xai/xai.go | 7 ++ internal/usage/cost.go | 120 ++++++++-------------- internal/usage/setup_test.go | 45 ++++++++ internal/usage/stream_wrapper.go | 2 +- 12 files changed, 240 insertions(+), 89 deletions(-) create mode 100644 internal/core/cost_mapping.go create mode 100644 internal/usage/setup_test.go diff --git a/CLAUDE.md b/CLAUDE.md index d97a39c1..b8557d44 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -107,6 +107,8 @@ helm/ # Kubernetes Helm charts 1. Create `internal/providers/{name}/` implementing `core.Provider` 2. Export a `Registration` variable: `var Registration = providers.Registration{Type: "{name}", New: New}` + - Optionally add `CostMappings: []core.TokenCostMapping{...}` for provider-specific token cost fields (cached tokens, reasoning tokens, etc.) + - Optionally add `InformationalFields: []string{...}` for known token breakdown fields that don't need separate pricing 3. Register in `cmd/gomodel/main.go` via `factory.Add({name}.Registration)` 4. Add API key env var to `.env.template` and to `knownProviders` in `config/config.go` diff --git a/internal/app/app.go b/internal/app/app.go index f2b65313..b7c72431 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -72,6 +72,10 @@ func New(ctx context.Context, cfg Config) (*App, error) { } app.providers = providerResult + // Register provider cost mappings for usage tracking + costMappings, informationalFields := cfg.Factory.CostRegistry() + usage.RegisterCostMappings(costMappings, informationalFields) + // Initialize audit logging auditResult, err := auditlog.New(ctx, appCfg) if err != nil { diff --git a/internal/core/cost_mapping.go b/internal/core/cost_mapping.go new file mode 100644 index 00000000..ce0be70f --- /dev/null +++ b/internal/core/cost_mapping.go @@ -0,0 +1,25 @@ +package core + +// CostSide indicates whether a token cost contributes to input or output. +type CostSide int + +const ( + CostSideInput CostSide = iota + CostSideOutput +) + +// CostUnit indicates how the pricing field is applied. +type CostUnit int + +const ( + CostUnitPerMtok CostUnit = iota // divide token count by 1M, multiply by rate + CostUnitPerItem // multiply count directly by rate +) + +// TokenCostMapping maps a RawData key to a pricing field and cost side. +type TokenCostMapping struct { + RawDataKey string + PricingField func(p *ModelPricing) *float64 + Side CostSide + Unit CostUnit +} diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 0bcc85b5..0ac38215 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -24,6 +24,10 @@ import ( var Registration = providers.Registration{ Type: "anthropic", New: New, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "cache_read_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "cache_creation_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CacheWritePerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + }, } const ( diff --git a/internal/providers/factory.go b/internal/providers/factory.go index ff78bc60..6992cf7d 100644 --- a/internal/providers/factory.go +++ b/internal/providers/factory.go @@ -21,21 +21,23 @@ type ProviderConstructor func(apiKey string, opts ProviderOptions) core.Provider // Registration contains metadata for registering a provider with the factory. type Registration struct { - Type string - New ProviderConstructor + Type string + New ProviderConstructor + CostMappings []core.TokenCostMapping // optional: provider-specific token cost mappings + InformationalFields []string // optional: known breakdown fields that need no separate pricing } // ProviderFactory manages provider registration and creation. type ProviderFactory struct { - mu sync.RWMutex - builders map[string]ProviderConstructor - hooks llmclient.Hooks + mu sync.RWMutex + registrations map[string]Registration + hooks llmclient.Hooks } // NewProviderFactory creates a new provider factory instance. func NewProviderFactory() *ProviderFactory { return &ProviderFactory{ - builders: make(map[string]ProviderConstructor), + registrations: make(map[string]Registration), } } @@ -58,13 +60,13 @@ func (f *ProviderFactory) Add(reg Registration) { } f.mu.Lock() defer f.mu.Unlock() - f.builders[reg.Type] = reg.New + f.registrations[reg.Type] = reg } // Create instantiates a provider based on its resolved configuration. func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) { f.mu.RLock() - builder, ok := f.builders[cfg.Type] + reg, ok := f.registrations[cfg.Type] hooks := f.hooks f.mu.RUnlock() @@ -77,7 +79,7 @@ func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) { Resilience: cfg.Resilience, } - p := builder(cfg.APIKey, opts) + p := reg.New(cfg.APIKey, opts) if cfg.BaseURL != "" { if setter, ok := p.(interface{ SetBaseURL(string) }); ok { @@ -93,9 +95,33 @@ func (f *ProviderFactory) RegisteredTypes() []string { f.mu.RLock() defer f.mu.RUnlock() - types := make([]string, 0, len(f.builders)) - for t := range f.builders { + types := make([]string, 0, len(f.registrations)) + for t := range f.registrations { types = append(types, t) } return types } + +// CostRegistry returns aggregated cost mappings and informational fields from all +// registered providers. The returned map is keyed by provider type. +func (f *ProviderFactory) CostRegistry() (mappings map[string][]core.TokenCostMapping, informationalFields []string) { + f.mu.RLock() + defer f.mu.RUnlock() + + mappings = make(map[string][]core.TokenCostMapping) + seen := make(map[string]struct{}) + + for _, reg := range f.registrations { + if len(reg.CostMappings) > 0 { + mappings[reg.Type] = reg.CostMappings + } + for _, field := range reg.InformationalFields { + if _, ok := seen[field]; !ok { + seen[field] = struct{}{} + informationalFields = append(informationalFields, field) + } + } + } + + return mappings, informationalFields +} diff --git a/internal/providers/factory_test.go b/internal/providers/factory_test.go index 03f9200c..4c0251ce 100644 --- a/internal/providers/factory_test.go +++ b/internal/providers/factory_test.go @@ -3,6 +3,7 @@ package providers import ( "context" "io" + "sort" "testing" "time" @@ -343,3 +344,56 @@ func TestProviderFactory_Create_PassesResilienceConfig(t *testing.T) { t.Errorf("JitterFactor = %f, want 0.5", r.JitterFactor) } } + +func TestProviderFactory_CostRegistry(t *testing.T) { + factory := NewProviderFactory() + + factory.Add(Registration{ + Type: "provider-a", + New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} }, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + }, + InformationalFields: []string{"prompt_text_tokens"}, + }) + + factory.Add(Registration{ + Type: "provider-b", + New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} }, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + }, + InformationalFields: []string{"prompt_text_tokens", "prompt_image_tokens"}, + }) + + // Provider with no cost mappings + factory.Add(Registration{ + Type: "provider-c", + New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} }, + }) + + mappings, informational := factory.CostRegistry() + + // Check mappings + if len(mappings) != 2 { + t.Fatalf("expected 2 provider mappings, got %d", len(mappings)) + } + if len(mappings["provider-a"]) != 1 { + t.Errorf("provider-a: expected 1 mapping, got %d", len(mappings["provider-a"])) + } + if len(mappings["provider-b"]) != 1 { + t.Errorf("provider-b: expected 1 mapping, got %d", len(mappings["provider-b"])) + } + if _, ok := mappings["provider-c"]; ok { + t.Error("provider-c should not have mappings") + } + + // Check informational fields are deduplicated + sort.Strings(informational) + if len(informational) != 2 { + t.Fatalf("expected 2 informational fields (deduplicated), got %d: %v", len(informational), informational) + } + if informational[0] != "prompt_image_tokens" || informational[1] != "prompt_text_tokens" { + t.Errorf("unexpected informational fields: %v", informational) + } +} diff --git a/internal/providers/gemini/gemini.go b/internal/providers/gemini/gemini.go index d47e2394..e7a25c21 100644 --- a/internal/providers/gemini/gemini.go +++ b/internal/providers/gemini/gemini.go @@ -19,6 +19,10 @@ import ( var Registration = providers.Registration{ Type: "gemini", New: New, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + }, } const ( diff --git a/internal/providers/openai/openai.go b/internal/providers/openai/openai.go index 2d43d247..043b24bd 100644 --- a/internal/providers/openai/openai.go +++ b/internal/providers/openai/openai.go @@ -16,6 +16,20 @@ import ( var Registration = providers.Registration{ Type: "openai", New: New, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + }, + InformationalFields: []string{ + "prompt_text_tokens", + "prompt_image_tokens", + "completion_accepted_prediction_tokens", + "completion_rejected_prediction_tokens", + }, } const ( diff --git a/internal/providers/xai/xai.go b/internal/providers/xai/xai.go index 6b435c95..cdfd7983 100644 --- a/internal/providers/xai/xai.go +++ b/internal/providers/xai/xai.go @@ -15,6 +15,13 @@ import ( var Registration = providers.Registration{ Type: "xai", New: New, + CostMappings: []core.TokenCostMapping{ + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "image_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.InputPerImage }, Side: core.CostSideInput, Unit: core.CostUnitPerItem}, + }, } const ( diff --git a/internal/usage/cost.go b/internal/usage/cost.go index ecb8ab68..ccccb56f 100644 --- a/internal/usage/cost.go +++ b/internal/usage/cost.go @@ -16,79 +16,43 @@ type CostResult struct { Caveat string } -// costSide indicates whether a token cost contributes to input or output. -type costSide int - -const ( - sideInput costSide = iota - sideOutput -) - -// costUnit indicates how the pricing field is applied. -type costUnit int - -const ( - unitPerMtok costUnit = iota // divide token count by 1M, multiply by rate - unitPerItem // multiply count directly by rate -) - -// tokenCostMapping maps a RawData key to a pricing field and cost side. -type tokenCostMapping struct { - rawDataKey string - pricingField func(p *core.ModelPricing) *float64 - side costSide - unit costUnit +// costRegistry holds provider-specific cost mappings and informational fields, +// populated at startup via RegisterCostMappings. +type costRegistry struct { + providerMappings map[string][]core.TokenCostMapping + informationalFields map[string]struct{} + extendedFieldSet map[string]struct{} } -// providerMappings defines the per-provider RawData key to pricing field mappings. -var providerMappings = map[string][]tokenCostMapping{ - "openai": { - {rawDataKey: "cached_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "prompt_cached_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "reasoning_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - {rawDataKey: "completion_reasoning_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - {rawDataKey: "prompt_audio_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.AudioInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "completion_audio_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.AudioOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - }, - "anthropic": { - {rawDataKey: "cache_read_input_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "cache_creation_input_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CacheWritePerMtok }, side: sideInput, unit: unitPerMtok}, - }, - "gemini": { - {rawDataKey: "cached_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "thought_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - }, - "xai": { - {rawDataKey: "cached_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "prompt_cached_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, side: sideInput, unit: unitPerMtok}, - {rawDataKey: "reasoning_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - {rawDataKey: "completion_reasoning_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, side: sideOutput, unit: unitPerMtok}, - {rawDataKey: "image_tokens", pricingField: func(p *core.ModelPricing) *float64 { return p.InputPerImage }, side: sideInput, unit: unitPerItem}, - }, +// defaultCostRegistry is the package-level registry used by CalculateGranularCost +// and stream_wrapper.go. Populated by RegisterCostMappings at startup. +var defaultCostRegistry = &costRegistry{ + providerMappings: make(map[string][]core.TokenCostMapping), + informationalFields: make(map[string]struct{}), + extendedFieldSet: make(map[string]struct{}), } -// informationalFields are token fields that are known breakdowns of the base -// input/output counts. They never need separate pricing and should not trigger -// "unmapped token field" caveats. -var informationalFields = map[string]struct{}{ - "prompt_text_tokens": {}, - "prompt_image_tokens": {}, - "completion_accepted_prediction_tokens": {}, - "completion_rejected_prediction_tokens": {}, -} +// RegisterCostMappings populates the cost registry with provider-specific mappings +// and informational fields. Called once at startup after providers are registered. +func RegisterCostMappings(mappings map[string][]core.TokenCostMapping, informational []string) { + reg := &costRegistry{ + providerMappings: mappings, + informationalFields: make(map[string]struct{}, len(informational)), + extendedFieldSet: make(map[string]struct{}), + } -// extendedFieldSet is derived from providerMappings and contains all RawData keys -// that providers may report. Used by stream_wrapper.go to extract extended fields -// from SSE usage data without maintaining a separate hard-coded list. -var extendedFieldSet = func() map[string]struct{} { - set := make(map[string]struct{}) - for _, mappings := range providerMappings { - for _, m := range mappings { - set[m.rawDataKey] = struct{}{} + for _, f := range informational { + reg.informationalFields[f] = struct{}{} + } + + for _, ms := range mappings { + for _, m := range ms { + reg.extendedFieldSet[m.RawDataKey] = struct{}{} } } - return set -}() + + defaultCostRegistry = reg +} // CalculateGranularCost computes input, output, and total costs from token counts, // raw provider-specific data, and pricing information. It accounts for cached tokens, @@ -101,6 +65,8 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any return CostResult{} } + reg := defaultCostRegistry + var inputCost, outputCost float64 var hasInput, hasOutput bool var caveats []string @@ -125,15 +91,15 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any // rawData keys map to the same pricing field (e.g. cached_tokens and prompt_cached_tokens // both map to CachedInputPerMtok). appliedFields := make(map[*float64]bool) - if mappings, ok := providerMappings[providerType]; ok { + if mappings, ok := reg.providerMappings[providerType]; ok { for _, m := range mappings { - count := extractInt(rawData, m.rawDataKey) + count := extractInt(rawData, m.RawDataKey) if count == 0 { continue } - mappedKeys[m.rawDataKey] = true + mappedKeys[m.RawDataKey] = true - rate := m.pricingField(pricing) + rate := m.PricingField(pricing) if rate == nil { continue // Base rate covers this token type; no adjustment needed } @@ -144,18 +110,18 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any appliedFields[rate] = true var cost float64 - switch m.unit { - case unitPerMtok: + switch m.Unit { + case core.CostUnitPerMtok: cost = float64(count) * *rate / 1_000_000 - case unitPerItem: + case core.CostUnitPerItem: cost = float64(count) * *rate } - switch m.side { - case sideInput: + switch m.Side { + case core.CostSideInput: inputCost += cost hasInput = true - case sideOutput: + case core.CostSideOutput: outputCost += cost hasOutput = true } @@ -167,7 +133,7 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any if mappedKeys[key] { continue } - if _, ok := informationalFields[key]; ok { + if _, ok := reg.informationalFields[key]; ok { continue // Known breakdown of base counts, not separately priced } if isTokenField(key) { diff --git a/internal/usage/setup_test.go b/internal/usage/setup_test.go new file mode 100644 index 00000000..dadabff0 --- /dev/null +++ b/internal/usage/setup_test.go @@ -0,0 +1,45 @@ +package usage + +import ( + "os" + "testing" + + "gomodel/internal/core" +) + +func TestMain(m *testing.M) { + // Register cost mappings that were previously hardcoded in cost.go. + // This mirrors the data that providers supply via their Registration vars. + RegisterCostMappings(map[string][]core.TokenCostMapping{ + "openai": { + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + }, + "anthropic": { + {RawDataKey: "cache_read_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "cache_creation_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CacheWritePerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + }, + "gemini": { + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + }, + "xai": { + {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, + {RawDataKey: "image_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.InputPerImage }, Side: core.CostSideInput, Unit: core.CostUnitPerItem}, + }, + }, []string{ + "prompt_text_tokens", + "prompt_image_tokens", + "completion_accepted_prediction_tokens", + "completion_rejected_prediction_tokens", + }) + + os.Exit(m.Run()) +} diff --git a/internal/usage/stream_wrapper.go b/internal/usage/stream_wrapper.go index 8d5fb7fa..22403abe 100644 --- a/internal/usage/stream_wrapper.go +++ b/internal/usage/stream_wrapper.go @@ -233,7 +233,7 @@ func (w *StreamUsageWrapper) extractUsageFromJSON(data []byte) *UsageEntry { // Extract extended usage data (provider-specific) using the field set // derived from providerMappings in cost.go (single source of truth). - for field := range extendedFieldSet { + for field := range defaultCostRegistry.extendedFieldSet { if v, ok := usageMap[field].(float64); ok && v > 0 { rawData[field] = int(v) } From a0cb34f7d04518d7c383b758066514b138f84f1d Mon Sep 17 00:00:00 2001 From: "Jakub A. W" Date: Wed, 25 Feb 2026 15:58:50 +0100 Subject: [PATCH 2/3] fix: address review findings on OCP cost mapping refactor - Add CostSideUnknown/CostUnitUnknown zero-value sentinels and field doc comments to TokenCostMapping (core/cost_mapping.go) - Sort informationalFields in CostRegistry for deterministic output - Add default branches to Unit/Side switches to surface unknown enums - Use atomic.Pointer for defaultCostRegistry to prevent data races - Replace hardcoded test data with provider Registration vars - Add cache_read_input_tokens and cache_creation_input_tokens to Anthropic streaming usage (both ChatCompletion and Responses) - Fix Shutdown docstring and startup sequence in CLAUDE.md Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 2 +- internal/app/app.go | 5 +-- internal/core/cost_mapping.go | 19 ++++++---- internal/providers/anthropic/anthropic.go | 18 ++++++++-- internal/providers/factory.go | 2 ++ internal/usage/cost.go | 31 ++++++++++++----- internal/usage/setup_test.go | 42 ++++++----------------- internal/usage/stream_wrapper.go | 2 +- 8 files changed, 70 insertions(+), 51 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index b8557d44..a61c50f9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -62,7 +62,7 @@ Client → Echo Middleware (logger → recover → body limit → audit log → - `internal/observability/metrics.go` — Prometheus metrics via hooks injected at factory level: `gomodel_requests_total`, `gomodel_request_duration_seconds`, `gomodel_requests_in_flight`. - `internal/cache/` — Local file or Redis cache backends for model registry. -**Startup:** Config load (defaults → YAML → env vars) → Register providers with factory → Init providers (cache → async model load → background refresh → router) → Init audit logging → Init usage tracking (shares storage if same backend) → Build guardrails pipeline → Create server → Start listening +**Startup:** Config load (defaults → YAML → env vars) → Register providers with factory → Init providers (cache → async model load → background refresh → router) → Register cost mappings (`RegisterCostMappings`) → Init audit logging → Init usage tracking (shares storage if same backend) → Build guardrails pipeline → Create server → Start listening **Shutdown (in order):** HTTP server (stop accepting) → Providers (stop refresh + close cache) → Usage tracking (flush buffer) → Audit logging (flush buffer) diff --git a/internal/app/app.go b/internal/app/app.go index b7c72431..41f81def 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -216,8 +216,9 @@ func (a *App) Start(addr string) error { // Shutdown gracefully shuts down all components in the correct order. // It ensures proper cleanup of resources: // 1. HTTP server (stop accepting new requests) -// 2. Background refresh goroutine and cache -// 3. Audit logging +// 2. Providers (stop background refresh goroutine and close cache) +// 3. Usage tracking (flush pending entries) +// 4. Audit logging (flush pending logs) // // Safe to call multiple times; subsequent calls are no-ops. func (a *App) Shutdown(ctx context.Context) error { diff --git a/internal/core/cost_mapping.go b/internal/core/cost_mapping.go index ce0be70f..0e35aa82 100644 --- a/internal/core/cost_mapping.go +++ b/internal/core/cost_mapping.go @@ -4,7 +4,8 @@ package core type CostSide int const ( - CostSideInput CostSide = iota + CostSideUnknown CostSide = iota // zero value; must not be used in mappings + CostSideInput CostSideOutput ) @@ -12,14 +13,20 @@ const ( type CostUnit int const ( - CostUnitPerMtok CostUnit = iota // divide token count by 1M, multiply by rate + CostUnitUnknown CostUnit = iota // zero value; must not be used in mappings + CostUnitPerMtok // divide token count by 1M, multiply by rate CostUnitPerItem // multiply count directly by rate ) -// TokenCostMapping maps a RawData key to a pricing field and cost side. +// TokenCostMapping maps a provider-specific RawData key to a pricing field and cost side. type TokenCostMapping struct { - RawDataKey string + // RawDataKey is the key in the usage RawData map (e.g. "cached_tokens"). + RawDataKey string + // PricingField returns a pointer to the relevant rate from ModelPricing, or nil + // if the base rate already covers this token type. PricingField func(p *ModelPricing) *float64 - Side CostSide - Unit CostUnit + // Side indicates whether this cost contributes to input or output. + Side CostSide + // Unit indicates the pricing unit (per million tokens or per item). + Unit CostUnit } diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 0ac38215..772614e6 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -502,11 +502,18 @@ func (sc *streamConverter) convertEvent(event *anthropicStreamEvent) string { } // Include usage data if present (OpenAI format) if event.Usage != nil { - chunk["usage"] = map[string]interface{}{ + usageMap := map[string]interface{}{ "prompt_tokens": event.Usage.InputTokens, "completion_tokens": event.Usage.OutputTokens, "total_tokens": event.Usage.InputTokens + event.Usage.OutputTokens, } + if event.Usage.CacheReadInputTokens > 0 { + usageMap["cache_read_input_tokens"] = event.Usage.CacheReadInputTokens + } + if event.Usage.CacheCreationInputTokens > 0 { + usageMap["cache_creation_input_tokens"] = event.Usage.CacheCreationInputTokens + } + chunk["usage"] = usageMap } jsonData, err := json.Marshal(chunk) if err != nil { @@ -788,11 +795,18 @@ func (sc *responsesStreamConverter) Read(p []byte) (n int, err error) { } // Include usage data if captured from message_delta if sc.cachedUsage != nil { - responseData["usage"] = map[string]interface{}{ + usageMap := map[string]interface{}{ "input_tokens": sc.cachedUsage.InputTokens, "output_tokens": sc.cachedUsage.OutputTokens, "total_tokens": sc.cachedUsage.InputTokens + sc.cachedUsage.OutputTokens, } + if sc.cachedUsage.CacheReadInputTokens > 0 { + usageMap["cache_read_input_tokens"] = sc.cachedUsage.CacheReadInputTokens + } + if sc.cachedUsage.CacheCreationInputTokens > 0 { + usageMap["cache_creation_input_tokens"] = sc.cachedUsage.CacheCreationInputTokens + } + responseData["usage"] = usageMap } doneEvent := map[string]interface{}{ "type": "response.completed", diff --git a/internal/providers/factory.go b/internal/providers/factory.go index 6992cf7d..983b5a9a 100644 --- a/internal/providers/factory.go +++ b/internal/providers/factory.go @@ -3,6 +3,7 @@ package providers import ( "fmt" + "sort" "sync" "gomodel/config" @@ -123,5 +124,6 @@ func (f *ProviderFactory) CostRegistry() (mappings map[string][]core.TokenCostMa } } + sort.Strings(informationalFields) return mappings, informationalFields } diff --git a/internal/usage/cost.go b/internal/usage/cost.go index ccccb56f..e0c375c3 100644 --- a/internal/usage/cost.go +++ b/internal/usage/cost.go @@ -4,6 +4,7 @@ import ( "fmt" "sort" "strings" + "sync/atomic" "gomodel/internal/core" ) @@ -24,12 +25,21 @@ type costRegistry struct { extendedFieldSet map[string]struct{} } -// defaultCostRegistry is the package-level registry used by CalculateGranularCost -// and stream_wrapper.go. Populated by RegisterCostMappings at startup. -var defaultCostRegistry = &costRegistry{ - providerMappings: make(map[string][]core.TokenCostMapping), - informationalFields: make(map[string]struct{}), - extendedFieldSet: make(map[string]struct{}), +// costRegistryPtr is the package-level registry used by CalculateGranularCost +// and stream_wrapper.go. Published atomically by RegisterCostMappings. +var costRegistryPtr atomic.Pointer[costRegistry] + +func init() { + costRegistryPtr.Store(&costRegistry{ + providerMappings: make(map[string][]core.TokenCostMapping), + informationalFields: make(map[string]struct{}), + extendedFieldSet: make(map[string]struct{}), + }) +} + +// loadCostRegistry returns the current cost registry. Never returns nil. +func loadCostRegistry() *costRegistry { + return costRegistryPtr.Load() } // RegisterCostMappings populates the cost registry with provider-specific mappings @@ -51,7 +61,7 @@ func RegisterCostMappings(mappings map[string][]core.TokenCostMapping, informati } } - defaultCostRegistry = reg + costRegistryPtr.Store(reg) } // CalculateGranularCost computes input, output, and total costs from token counts, @@ -65,7 +75,7 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any return CostResult{} } - reg := defaultCostRegistry + reg := loadCostRegistry() var inputCost, outputCost float64 var hasInput, hasOutput bool @@ -115,6 +125,9 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any cost = float64(count) * *rate / 1_000_000 case core.CostUnitPerItem: cost = float64(count) * *rate + default: + caveats = append(caveats, fmt.Sprintf("unknown cost unit %d for field %s", int(m.Unit), m.RawDataKey)) + continue } switch m.Side { @@ -124,6 +137,8 @@ func CalculateGranularCost(inputTokens, outputTokens int, rawData map[string]any case core.CostSideOutput: outputCost += cost hasOutput = true + default: + caveats = append(caveats, fmt.Sprintf("unknown cost side %d for field %s", int(m.Side), m.RawDataKey)) } } } diff --git a/internal/usage/setup_test.go b/internal/usage/setup_test.go index dadabff0..81155da3 100644 --- a/internal/usage/setup_test.go +++ b/internal/usage/setup_test.go @@ -5,41 +5,21 @@ import ( "testing" "gomodel/internal/core" + "gomodel/internal/providers/anthropic" + "gomodel/internal/providers/gemini" + "gomodel/internal/providers/openai" + "gomodel/internal/providers/xai" ) func TestMain(m *testing.M) { - // Register cost mappings that were previously hardcoded in cost.go. - // This mirrors the data that providers supply via their Registration vars. + // Register cost mappings from the authoritative provider Registration vars + // so tests use the single source of truth rather than duplicated data. RegisterCostMappings(map[string][]core.TokenCostMapping{ - "openai": { - {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "prompt_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "completion_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - }, - "anthropic": { - {RawDataKey: "cache_read_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "cache_creation_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CacheWritePerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - }, - "gemini": { - {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - }, - "xai": { - {RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok}, - {RawDataKey: "image_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.InputPerImage }, Side: core.CostSideInput, Unit: core.CostUnitPerItem}, - }, - }, []string{ - "prompt_text_tokens", - "prompt_image_tokens", - "completion_accepted_prediction_tokens", - "completion_rejected_prediction_tokens", - }) + openai.Registration.Type: openai.Registration.CostMappings, + anthropic.Registration.Type: anthropic.Registration.CostMappings, + gemini.Registration.Type: gemini.Registration.CostMappings, + xai.Registration.Type: xai.Registration.CostMappings, + }, openai.Registration.InformationalFields) os.Exit(m.Run()) } diff --git a/internal/usage/stream_wrapper.go b/internal/usage/stream_wrapper.go index 22403abe..a732cf88 100644 --- a/internal/usage/stream_wrapper.go +++ b/internal/usage/stream_wrapper.go @@ -233,7 +233,7 @@ func (w *StreamUsageWrapper) extractUsageFromJSON(data []byte) *UsageEntry { // Extract extended usage data (provider-specific) using the field set // derived from providerMappings in cost.go (single source of truth). - for field := range defaultCostRegistry.extendedFieldSet { + for field := range loadCostRegistry().extendedFieldSet { if v, ok := usageMap[field].(float64); ok && v > 0 { rawData[field] = int(v) } From 8f305eba929eba954f36cae67e18d0447b036940 Mon Sep 17 00:00:00 2001 From: "Jakub A. W" Date: Wed, 25 Feb 2026 16:08:31 +0100 Subject: [PATCH 3/3] refactor: extract appendCacheFields helper and use factory in test setup - Extract duplicated cache field injection in Anthropic streaming into appendCacheFields helper, called from both ChatCompletion and Responses stream converters - Use ProviderFactory.CostRegistry() in usage TestMain to aggregate informational fields the same way production does, instead of hardcoding openai.Registration.InformationalFields Co-Authored-By: Claude Opus 4.6 --- internal/providers/anthropic/anthropic.go | 24 +++++++++++------------ internal/usage/setup_test.go | 20 ++++++++++--------- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/internal/providers/anthropic/anthropic.go b/internal/providers/anthropic/anthropic.go index 772614e6..7dbe8149 100644 --- a/internal/providers/anthropic/anthropic.go +++ b/internal/providers/anthropic/anthropic.go @@ -507,12 +507,7 @@ func (sc *streamConverter) convertEvent(event *anthropicStreamEvent) string { "completion_tokens": event.Usage.OutputTokens, "total_tokens": event.Usage.InputTokens + event.Usage.OutputTokens, } - if event.Usage.CacheReadInputTokens > 0 { - usageMap["cache_read_input_tokens"] = event.Usage.CacheReadInputTokens - } - if event.Usage.CacheCreationInputTokens > 0 { - usageMap["cache_creation_input_tokens"] = event.Usage.CacheCreationInputTokens - } + appendCacheFields(usageMap, event.Usage) chunk["usage"] = usageMap } jsonData, err := json.Marshal(chunk) @@ -679,6 +674,16 @@ func convertAnthropicResponseToResponses(resp *anthropicResponse, model string) } } +// appendCacheFields adds non-zero cache token fields from anthropicUsage to the given map. +func appendCacheFields(m map[string]interface{}, u *anthropicUsage) { + if u.CacheReadInputTokens > 0 { + m["cache_read_input_tokens"] = u.CacheReadInputTokens + } + if u.CacheCreationInputTokens > 0 { + m["cache_creation_input_tokens"] = u.CacheCreationInputTokens + } +} + // buildAnthropicRawUsage extracts cache fields from anthropicUsage into a RawData map. func buildAnthropicRawUsage(u anthropicUsage) map[string]any { raw := make(map[string]any) @@ -800,12 +805,7 @@ func (sc *responsesStreamConverter) Read(p []byte) (n int, err error) { "output_tokens": sc.cachedUsage.OutputTokens, "total_tokens": sc.cachedUsage.InputTokens + sc.cachedUsage.OutputTokens, } - if sc.cachedUsage.CacheReadInputTokens > 0 { - usageMap["cache_read_input_tokens"] = sc.cachedUsage.CacheReadInputTokens - } - if sc.cachedUsage.CacheCreationInputTokens > 0 { - usageMap["cache_creation_input_tokens"] = sc.cachedUsage.CacheCreationInputTokens - } + appendCacheFields(usageMap, sc.cachedUsage) responseData["usage"] = usageMap } doneEvent := map[string]interface{}{ diff --git a/internal/usage/setup_test.go b/internal/usage/setup_test.go index 81155da3..55d7d54d 100644 --- a/internal/usage/setup_test.go +++ b/internal/usage/setup_test.go @@ -4,7 +4,7 @@ import ( "os" "testing" - "gomodel/internal/core" + "gomodel/internal/providers" "gomodel/internal/providers/anthropic" "gomodel/internal/providers/gemini" "gomodel/internal/providers/openai" @@ -12,14 +12,16 @@ import ( ) func TestMain(m *testing.M) { - // Register cost mappings from the authoritative provider Registration vars - // so tests use the single source of truth rather than duplicated data. - RegisterCostMappings(map[string][]core.TokenCostMapping{ - openai.Registration.Type: openai.Registration.CostMappings, - anthropic.Registration.Type: anthropic.Registration.CostMappings, - gemini.Registration.Type: gemini.Registration.CostMappings, - xai.Registration.Type: xai.Registration.CostMappings, - }, openai.Registration.InformationalFields) + // Build cost mappings and informational fields the same way production does: + // register all providers into a factory and use CostRegistry to aggregate. + factory := providers.NewProviderFactory() + factory.Add(openai.Registration) + factory.Add(anthropic.Registration) + factory.Add(gemini.Registration) + factory.Add(xai.Registration) + + costMappings, informationalFields := factory.CostRegistry() + RegisterCostMappings(costMappings, informationalFields) os.Exit(m.Run()) }