Closed
4 changes: 3 additions & 1 deletion CLAUDE.md
@@ -62,7 +62,7 @@ Client → Echo Middleware (logger → recover → body limit → audit log →
- `internal/observability/metrics.go` — Prometheus metrics via hooks injected at factory level: `gomodel_requests_total`, `gomodel_request_duration_seconds`, `gomodel_requests_in_flight`.
- `internal/cache/` — Local file or Redis cache backends for model registry.

-**Startup:** Config load (defaults → YAML → env vars) → Register providers with factory → Init providers (cache → async model load → background refresh → router) → Init audit logging → Init usage tracking (shares storage if same backend) → Build guardrails pipeline → Create server → Start listening
+**Startup:** Config load (defaults → YAML → env vars) → Register providers with factory → Init providers (cache → async model load → background refresh → router) → Register cost mappings (`RegisterCostMappings`) → Init audit logging → Init usage tracking (shares storage if same backend) → Build guardrails pipeline → Create server → Start listening

**Shutdown (in order):** HTTP server (stop accepting) → Providers (stop refresh + close cache) → Usage tracking (flush buffer) → Audit logging (flush buffer)

@@ -107,6 +107,8 @@ helm/ # Kubernetes Helm charts

1. Create `internal/providers/{name}/` implementing `core.Provider`
2. Export a `Registration` variable: `var Registration = providers.Registration{Type: "{name}", New: New}`
   - Optionally add `CostMappings: []core.TokenCostMapping{...}` for provider-specific token cost fields (cached tokens, reasoning tokens, etc.)
   - Optionally add `InformationalFields: []string{...}` for known token breakdown fields that don't need separate pricing
3. Register in `cmd/gomodel/main.go` via `factory.Add({name}.Registration)`
4. Add API key env var to `.env.template` and to `knownProviders` in `config/config.go`
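
As a rough illustration of step 2 with the new optional fields, here is a minimal sketch of what a registration might look like. The provider name `acme` is hypothetical, and the types are local stand-ins that mirror the shapes shown in this PR (the real ones live in `internal/core` and `internal/providers`); the `New` constructor is omitted so the sketch stays self-contained.

```go
package main

import "fmt"

// Stand-in types mirroring the shapes in this PR. They are assumptions made
// for a self-contained example, not the repository's actual definitions.
type ModelPricing struct {
	CachedInputPerMtok *float64
}

type CostSide int

const (
	CostSideUnknown CostSide = iota // zero value; must not be used in mappings
	CostSideInput
	CostSideOutput
)

type CostUnit int

const (
	CostUnitUnknown CostUnit = iota // zero value; must not be used in mappings
	CostUnitPerMtok
	CostUnitPerItem
)

type TokenCostMapping struct {
	RawDataKey   string
	PricingField func(p *ModelPricing) *float64
	Side         CostSide
	Unit         CostUnit
}

// Registration omits the New constructor field for brevity.
type Registration struct {
	Type                string
	CostMappings        []TokenCostMapping
	InformationalFields []string
}

// acmeRegistration is a hypothetical provider registration; "acme" does not
// exist in the repository.
var acmeRegistration = Registration{
	Type: "acme",
	CostMappings: []TokenCostMapping{
		{
			RawDataKey:   "cached_tokens",
			PricingField: func(p *ModelPricing) *float64 { return p.CachedInputPerMtok },
			Side:         CostSideInput,
			Unit:         CostUnitPerMtok,
		},
	},
	InformationalFields: []string{"prompt_text_tokens"},
}

func main() {
	fmt.Println(acmeRegistration.Type, len(acmeRegistration.CostMappings), acmeRegistration.InformationalFields)
}
```

Setting both `Side` and `Unit` explicitly matters here, because the zero values are the `Unknown` sentinels.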

9 changes: 7 additions & 2 deletions internal/app/app.go
@@ -72,6 +72,10 @@ func New(ctx context.Context, cfg Config) (*App, error) {
}
app.providers = providerResult

// Register provider cost mappings for usage tracking
costMappings, informationalFields := cfg.Factory.CostRegistry()
usage.RegisterCostMappings(costMappings, informationalFields)

// Initialize audit logging
auditResult, err := auditlog.New(ctx, appCfg)
if err != nil {
@@ -212,8 +216,9 @@ func (a *App) Start(addr string) error {
// Shutdown gracefully shuts down all components in the correct order.
// It ensures proper cleanup of resources:
// 1. HTTP server (stop accepting new requests)
-// 2. Background refresh goroutine and cache
-// 3. Audit logging
+// 2. Providers (stop background refresh goroutine and close cache)
+// 3. Usage tracking (flush pending entries)
+// 4. Audit logging (flush pending logs)
//
// Safe to call multiple times; subsequent calls are no-ops.
func (a *App) Shutdown(ctx context.Context) error {
32 changes: 32 additions & 0 deletions internal/core/cost_mapping.go
@@ -0,0 +1,32 @@
package core

// CostSide indicates whether a token cost contributes to input or output.
type CostSide int

const (
CostSideUnknown CostSide = iota // zero value; must not be used in mappings
CostSideInput
CostSideOutput
)

// CostUnit indicates how the pricing field is applied.
type CostUnit int

const (
CostUnitUnknown CostUnit = iota // zero value; must not be used in mappings
CostUnitPerMtok // divide token count by 1M, multiply by rate
CostUnitPerItem // multiply count directly by rate
Comment on lines +7 to +18
🧹 Nitpick | 🔵 Trivial

CostSideUnknown/CostUnitUnknown zero-value sentinels are unvalidated — silently incorrect mappings possible.

The doc comments say these values "must not be used in mappings," but nothing enforces this. A provider that accidentally omits Side or Unit in a TokenCostMapping literal will compile and run without error, producing a mapping with a zero-value enum that the cost-calculation logic likely cannot handle correctly. Consider adding a validation guard in factory.Add() (which already panics on empty Type/nil New) or in RegisterCostMappings:

🛡️ Suggested validation in factory.Add (illustrative)
// Inside factory.Add, after existing guards:
+for i, m := range reg.CostMappings {
+    if m.RawDataKey == "" {
+        panic(fmt.Sprintf("provider %q: CostMappings[%d] has empty RawDataKey", reg.Type, i))
+    }
+    if m.Side == core.CostSideUnknown {
+        panic(fmt.Sprintf("provider %q: CostMappings[%d] has unknown Side", reg.Type, i))
+    }
+    if m.Unit == core.CostUnitUnknown {
+        panic(fmt.Sprintf("provider %q: CostMappings[%d] has unknown Unit", reg.Type, i))
+    }
+}
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@internal/core/cost_mapping.go` around lines 7 - 18, The CostSideUnknown and
CostUnitUnknown zero-value sentinels can slip into TokenCostMapping when
providers omit Side or Unit, so add validation to reject mappings with those
zero-values: in the factory.Add (and/or RegisterCostMappings) code path where
TokenCostMapping entries are registered (the same place that already panics on
empty Type or nil New), check each mapping's Side against CostSideUnknown and
Unit against CostUnitUnknown and panic or return an error if found; ensure the
validation references TokenCostMapping.Side and TokenCostMapping.Unit to make
failures explicit and prevent silent incorrect mappings.
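
A variant of the guard suggested above, sketched as a standalone function that returns an error rather than panicking. This is only an illustration: `validateCostMappings` is a hypothetical helper name, the types are trimmed local stand-ins for the ones in `internal/core`, and whether the project prefers a panic in `factory.Add` or an error from `RegisterCostMappings` is left open.

```go
package main

import "fmt"

// Trimmed stand-ins for the core types in this PR (assumption: only the
// fields needed for validation are reproduced here).
type CostSide int

const (
	CostSideUnknown CostSide = iota
	CostSideInput
	CostSideOutput
)

type CostUnit int

const (
	CostUnitUnknown CostUnit = iota
	CostUnitPerMtok
	CostUnitPerItem
)

type TokenCostMapping struct {
	RawDataKey string
	Side       CostSide
	Unit       CostUnit
}

// validateCostMappings is a hypothetical helper. It returns the first problem
// found, or nil when every mapping has a key, a known side, and a known unit.
func validateCostMappings(providerType string, mappings []TokenCostMapping) error {
	for i, m := range mappings {
		if m.RawDataKey == "" {
			return fmt.Errorf("provider %q: CostMappings[%d] has empty RawDataKey", providerType, i)
		}
		if m.Side == CostSideUnknown {
			return fmt.Errorf("provider %q: CostMappings[%d] has unknown Side", providerType, i)
		}
		if m.Unit == CostUnitUnknown {
			return fmt.Errorf("provider %q: CostMappings[%d] has unknown Unit", providerType, i)
		}
	}
	return nil
}

func main() {
	// Unit is omitted on purpose, so it takes the zero value CostUnitUnknown.
	bad := []TokenCostMapping{{RawDataKey: "cached_tokens", Side: CostSideInput}}
	fmt.Println(validateCostMappings("acme", bad))
}
```

Returning an error keeps the check testable without `recover`; a panic at registration time would be consistent with the existing guards on empty `Type` and nil `New`.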

)

// TokenCostMapping maps a provider-specific RawData key to a pricing field and cost side.
type TokenCostMapping struct {
// RawDataKey is the key in the usage RawData map (e.g. "cached_tokens").
RawDataKey string
// PricingField returns a pointer to the relevant rate from ModelPricing, or nil
// if the base rate already covers this token type.
PricingField func(p *ModelPricing) *float64
// Side indicates whether this cost contributes to input or output.
Side CostSide
// Unit indicates the pricing unit (per million tokens or per item).
Unit CostUnit
}
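
To make the two `CostUnit` comments concrete, here is a small sketch of how a single mapping could be turned into a cost. The PR excerpt does not show the actual calculation in the usage package, so `mappingCost` is a hypothetical helper and the types are local stand-ins; it only illustrates "divide by 1M, multiply by rate" versus "multiply count directly by rate", and the nil-rate case described in the `PricingField` comment.

```go
package main

import "fmt"

// Stand-in types (assumptions for illustration; the real ones are in internal/core).
type ModelPricing struct {
	CachedInputPerMtok *float64
	InputPerImage      *float64
}

type CostUnit int

const (
	CostUnitUnknown CostUnit = iota
	CostUnitPerMtok
	CostUnitPerItem
)

type TokenCostMapping struct {
	RawDataKey   string
	PricingField func(p *ModelPricing) *float64
	Unit         CostUnit
}

// mappingCost is a hypothetical helper. It returns the cost for count units
// and false when no separate rate applies (nil pricing, nil rate, unknown unit).
func mappingCost(m TokenCostMapping, p *ModelPricing, count int) (float64, bool) {
	if p == nil || m.PricingField == nil {
		return 0, false
	}
	rate := m.PricingField(p)
	if rate == nil {
		return 0, false // the base rate already covers this token type
	}
	switch m.Unit {
	case CostUnitPerMtok:
		return float64(count) / 1_000_000 * *rate, true
	case CostUnitPerItem:
		return float64(count) * *rate, true
	default:
		return 0, false
	}
}

func main() {
	cachedRate, imageRate := 0.5, 0.25
	pricing := &ModelPricing{CachedInputPerMtok: &cachedRate, InputPerImage: &imageRate}

	cached := TokenCostMapping{
		RawDataKey:   "cached_tokens",
		PricingField: func(p *ModelPricing) *float64 { return p.CachedInputPerMtok },
		Unit:         CostUnitPerMtok,
	}
	images := TokenCostMapping{
		RawDataKey:   "image_tokens",
		PricingField: func(p *ModelPricing) *float64 { return p.InputPerImage },
		Unit:         CostUnitPerItem,
	}

	c1, _ := mappingCost(cached, pricing, 2_000_000) // 2M tokens at 0.5 per Mtok
	c2, _ := mappingCost(images, pricing, 3)         // 3 items at 0.25 each
	fmt.Println(c1, c2)
}
```

The rates used here (0.5 and 0.25) are made-up numbers chosen so the arithmetic is easy to follow; they are not real provider prices.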
22 changes: 20 additions & 2 deletions internal/providers/anthropic/anthropic.go
@@ -24,6 +24,10 @@ import (
var Registration = providers.Registration{
Type: "anthropic",
New: New,
CostMappings: []core.TokenCostMapping{
{RawDataKey: "cache_read_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "cache_creation_input_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CacheWritePerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
},
}

const (
@@ -498,11 +502,13 @@ func (sc *streamConverter) convertEvent(event *anthropicStreamEvent) string {
}
// Include usage data if present (OpenAI format)
if event.Usage != nil {
-chunk["usage"] = map[string]interface{}{
+usageMap := map[string]interface{}{
"prompt_tokens": event.Usage.InputTokens,
"completion_tokens": event.Usage.OutputTokens,
"total_tokens": event.Usage.InputTokens + event.Usage.OutputTokens,
}
appendCacheFields(usageMap, event.Usage)
chunk["usage"] = usageMap
}
jsonData, err := json.Marshal(chunk)
if err != nil {
@@ -668,6 +674,16 @@ func convertAnthropicResponseToResponses(resp *anthropicResponse, model string)
}
}

// appendCacheFields adds non-zero cache token fields from anthropicUsage to the given map.
func appendCacheFields(m map[string]interface{}, u *anthropicUsage) {
if u.CacheReadInputTokens > 0 {
m["cache_read_input_tokens"] = u.CacheReadInputTokens
}
if u.CacheCreationInputTokens > 0 {
m["cache_creation_input_tokens"] = u.CacheCreationInputTokens
}
}

// buildAnthropicRawUsage extracts cache fields from anthropicUsage into a RawData map.
func buildAnthropicRawUsage(u anthropicUsage) map[string]any {
raw := make(map[string]any)
@@ -784,11 +800,13 @@ func (sc *responsesStreamConverter) Read(p []byte) (n int, err error) {
}
// Include usage data if captured from message_delta
if sc.cachedUsage != nil {
-responseData["usage"] = map[string]interface{}{
+usageMap := map[string]interface{}{
"input_tokens": sc.cachedUsage.InputTokens,
"output_tokens": sc.cachedUsage.OutputTokens,
"total_tokens": sc.cachedUsage.InputTokens + sc.cachedUsage.OutputTokens,
}
appendCacheFields(usageMap, sc.cachedUsage)
responseData["usage"] = usageMap
}
doneEvent := map[string]interface{}{
"type": "response.completed",
50 changes: 39 additions & 11 deletions internal/providers/factory.go
@@ -3,6 +3,7 @@ package providers

import (
"fmt"
"sort"
"sync"

"gomodel/config"
@@ -21,21 +22,23 @@ type ProviderConstructor func(apiKey string, opts ProviderOptions) core.Provider

// Registration contains metadata for registering a provider with the factory.
type Registration struct {
-Type string
-New ProviderConstructor
+Type string
+New ProviderConstructor
+CostMappings []core.TokenCostMapping // optional: provider-specific token cost mappings
+InformationalFields []string // optional: known breakdown fields that need no separate pricing
}

// ProviderFactory manages provider registration and creation.
type ProviderFactory struct {
-mu sync.RWMutex
-builders map[string]ProviderConstructor
-hooks llmclient.Hooks
+mu sync.RWMutex
+registrations map[string]Registration
+hooks llmclient.Hooks
}

// NewProviderFactory creates a new provider factory instance.
func NewProviderFactory() *ProviderFactory {
return &ProviderFactory{
-builders: make(map[string]ProviderConstructor),
+registrations: make(map[string]Registration),
}
}

@@ -58,13 +61,13 @@ func (f *ProviderFactory) Add(reg Registration) {
}
f.mu.Lock()
defer f.mu.Unlock()
-f.builders[reg.Type] = reg.New
+f.registrations[reg.Type] = reg
}

// Create instantiates a provider based on its resolved configuration.
func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) {
f.mu.RLock()
-builder, ok := f.builders[cfg.Type]
+reg, ok := f.registrations[cfg.Type]
hooks := f.hooks
f.mu.RUnlock()

@@ -77,7 +80,7 @@ func (f *ProviderFactory) Create(cfg ProviderConfig) (core.Provider, error) {
Resilience: cfg.Resilience,
}

-p := builder(cfg.APIKey, opts)
+p := reg.New(cfg.APIKey, opts)

if cfg.BaseURL != "" {
if setter, ok := p.(interface{ SetBaseURL(string) }); ok {
@@ -93,9 +96,34 @@ func (f *ProviderFactory) RegisteredTypes() []string {
f.mu.RLock()
defer f.mu.RUnlock()

-types := make([]string, 0, len(f.builders))
-for t := range f.builders {
+types := make([]string, 0, len(f.registrations))
+for t := range f.registrations {
types = append(types, t)
}
return types
}

// CostRegistry returns aggregated cost mappings and informational fields from all
// registered providers. The returned map is keyed by provider type.
func (f *ProviderFactory) CostRegistry() (mappings map[string][]core.TokenCostMapping, informationalFields []string) {
f.mu.RLock()
defer f.mu.RUnlock()

mappings = make(map[string][]core.TokenCostMapping)
seen := make(map[string]struct{})

for _, reg := range f.registrations {
if len(reg.CostMappings) > 0 {
mappings[reg.Type] = reg.CostMappings
}
for _, field := range reg.InformationalFields {
if _, ok := seen[field]; !ok {
seen[field] = struct{}{}
informationalFields = append(informationalFields, field)
}
}
}

sort.Strings(informationalFields)
return mappings, informationalFields
}
54 changes: 54 additions & 0 deletions internal/providers/factory_test.go
@@ -3,6 +3,7 @@ package providers
import (
"context"
"io"
"sort"
"testing"
"time"

@@ -343,3 +344,56 @@ func TestProviderFactory_Create_PassesResilienceConfig(t *testing.T) {
t.Errorf("JitterFactor = %f, want 0.5", r.JitterFactor)
}
}

func TestProviderFactory_CostRegistry(t *testing.T) {
factory := NewProviderFactory()

factory.Add(Registration{
Type: "provider-a",
New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} },
CostMappings: []core.TokenCostMapping{
{RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
},
InformationalFields: []string{"prompt_text_tokens"},
})

factory.Add(Registration{
Type: "provider-b",
New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} },
CostMappings: []core.TokenCostMapping{
{RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
},
InformationalFields: []string{"prompt_text_tokens", "prompt_image_tokens"},
})

// Provider with no cost mappings
factory.Add(Registration{
Type: "provider-c",
New: func(_ string, _ ProviderOptions) core.Provider { return &factoryMockProvider{} },
})

mappings, informational := factory.CostRegistry()

// Check mappings
if len(mappings) != 2 {
t.Fatalf("expected 2 provider mappings, got %d", len(mappings))
}
if len(mappings["provider-a"]) != 1 {
t.Errorf("provider-a: expected 1 mapping, got %d", len(mappings["provider-a"]))
}
if len(mappings["provider-b"]) != 1 {
t.Errorf("provider-b: expected 1 mapping, got %d", len(mappings["provider-b"]))
}
if _, ok := mappings["provider-c"]; ok {
t.Error("provider-c should not have mappings")
}

// Check informational fields are deduplicated
sort.Strings(informational)
if len(informational) != 2 {
t.Fatalf("expected 2 informational fields (deduplicated), got %d: %v", len(informational), informational)
}
if informational[0] != "prompt_image_tokens" || informational[1] != "prompt_text_tokens" {
t.Errorf("unexpected informational fields: %v", informational)
}
}
4 changes: 4 additions & 0 deletions internal/providers/gemini/gemini.go
@@ -19,6 +19,10 @@ import (
var Registration = providers.Registration{
Type: "gemini",
New: New,
CostMappings: []core.TokenCostMapping{
{RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "thought_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
},
}

const (
14 changes: 14 additions & 0 deletions internal/providers/openai/openai.go
@@ -16,6 +16,20 @@ import (
var Registration = providers.Registration{
Type: "openai",
New: New,
CostMappings: []core.TokenCostMapping{
{RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
{RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
{RawDataKey: "prompt_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "completion_audio_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.AudioOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
},
InformationalFields: []string{
"prompt_text_tokens",
"prompt_image_tokens",
"completion_accepted_prediction_tokens",
"completion_rejected_prediction_tokens",
},
}

const (
7 changes: 7 additions & 0 deletions internal/providers/xai/xai.go
@@ -15,6 +15,13 @@ import (
var Registration = providers.Registration{
Type: "xai",
New: New,
CostMappings: []core.TokenCostMapping{
{RawDataKey: "cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "prompt_cached_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.CachedInputPerMtok }, Side: core.CostSideInput, Unit: core.CostUnitPerMtok},
{RawDataKey: "reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
{RawDataKey: "completion_reasoning_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.ReasoningOutputPerMtok }, Side: core.CostSideOutput, Unit: core.CostUnitPerMtok},
{RawDataKey: "image_tokens", PricingField: func(p *core.ModelPricing) *float64 { return p.InputPerImage }, Side: core.CostSideInput, Unit: core.CostUnitPerItem},
},
}

const (