Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions catalog/opencodego/opencodego_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func TestProtocolForModel(t *testing.T) {
}

func TestUsesMessagesAPI_HeuristicFallback(t *testing.T) {
t.Parallel()
// Reset map so we test the heuristic fallback.
ResetProtocolMap()
t.Cleanup(ResetProtocolMap)
tests := []struct {
model string
want bool
Expand All @@ -75,8 +75,8 @@ func TestUsesMessagesAPI_HeuristicFallback(t *testing.T) {
}

func TestUsesMessagesAPI_DynamicMapOverrides(t *testing.T) {
t.Parallel()
ResetProtocolMap()
t.Cleanup(ResetProtocolMap)
// Simulate live fetch returning protocol data.
UpdateProtocolMap([]struct{ ID, Protocol string }{
{"kimi-k2.6", "openai"},
Expand All @@ -100,13 +100,11 @@ func TestUsesMessagesAPI_DynamicMapOverrides(t *testing.T) {
if UsesMessagesAPI("totally-new-model") {
t.Error("totally-new-model should default to openai (heuristic fallback)")
}

ResetProtocolMap()
}

func TestProtocolMapSnapshot(t *testing.T) {
t.Parallel()
ResetProtocolMap()
t.Cleanup(ResetProtocolMap)
UpdateProtocolMap([]struct{ ID, Protocol string }{
{"kimi-k2.6", "openai"},
{"minimax-m3", "anthropic"},
Expand All @@ -118,7 +116,6 @@ func TestProtocolMapSnapshot(t *testing.T) {
if snap["minimax-m3"] != "anthropic" {
t.Errorf("snapshot minimax-m3 = %q, want anthropic", snap["minimax-m3"])
}
ResetProtocolMap()
}

func TestUsageTracker_RecordAndSpend(t *testing.T) {
Expand Down
76 changes: 76 additions & 0 deletions catalog/registry/derive.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,82 @@ func DisplayName(providerID string) string {
return providerID
}

// ChatProviderPreferenceOrder returns provider ids ordered by chat/runtime preference.
func ChatProviderPreferenceOrder() []string {
specs := DefaultRegistry.All()
sort.Slice(specs, func(i, j int) bool {
left := specs[i].ChatPreference
right := specs[j].ChatPreference
if left == 0 {
left = specs[i].SortOrder + 10_000
}
if right == 0 {
right = specs[j].SortOrder + 10_000
}
if left != right {
return left < right
}
return specs[i].ProviderID < specs[j].ProviderID
})
out := make([]string, 0, len(specs))
for _, spec := range specs {
if spec.ProviderID != "" {
out = append(out, spec.ProviderID)
}
}
return out
}

// RuntimeProfileKey returns the config runtime-profile key for a provider.
func RuntimeProfileKey(providerID string) string {
if spec, ok := SpecByProviderID(providerID); ok {
return strings.TrimSpace(spec.RuntimeProfileKey)
}
return ""
}

// DirectFallbackProviderIDs returns direct-provider fallback ids for providerID.
func DirectFallbackProviderIDs(providerID string) []string {
spec, ok := SpecByProviderID(providerID)
if !ok || len(spec.DirectFallbacks) == 0 {
return nil
}
out := make([]string, 0, len(spec.DirectFallbacks))
for _, id := range spec.DirectFallbacks {
if trimmed := strings.TrimSpace(id); trimmed != "" {
out = append(out, trimmed)
}
}
return out
}

// CredentialAliases returns compatibility env var names for providerID.
func CredentialAliases(providerID string) []string {
spec, ok := SpecByProviderID(providerID)
if !ok || len(spec.CredentialAliases) == 0 {
return nil
}
out := make([]string, 0, len(spec.CredentialAliases))
for _, env := range spec.CredentialAliases {
if trimmed := strings.TrimSpace(env); trimmed != "" {
out = append(out, trimmed)
}
}
return out
}

// CredentialEnvPreparedProviders returns providers that need config-derived env before discovery.
func CredentialEnvPreparedProviders() []string {
var out []string
for _, spec := range DefaultRegistry.All() {
if spec.PrepareCredentialEnv && spec.ProviderID != "" {
out = append(out, spec.ProviderID)
}
}
sort.Strings(out)
return out
}

// CredentialRegistry derives credential rows from provider specs.
func CredentialRegistry() []CredentialSpec {
specs := DefaultRegistry.All()
Expand Down
38 changes: 38 additions & 0 deletions catalog/registry/provider_spec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,44 @@ func TestOpenCodeGo_HasProbeBaseURL(t *testing.T) {
}
}

func TestProviderRuntimePolicy_Metadata(t *testing.T) {
t.Parallel()

order := registry.ChatProviderPreferenceOrder()
if len(order) < 3 {
t.Fatalf("runtime preference order too short: %v", order)
}
if order[0] != "openai" || order[1] != "anthropic" || order[2] != "openrouter" {
t.Fatalf("unexpected runtime preference prefix: %v", order[:3])
}

if got := registry.DirectFallbackProviderIDs("openai"); len(got) != 1 || got[0] != "anthropic" {
t.Fatalf("openai direct fallbacks = %v, want [anthropic]", got)
}
if got := registry.DirectFallbackProviderIDs("anthropic"); len(got) != 1 || got[0] != "openai" {
t.Fatalf("anthropic direct fallbacks = %v, want [openai]", got)
}

if got := registry.CredentialAliases("anthropic"); len(got) != 1 || got[0] != "CLAUDE_API_KEY" {
t.Fatalf("anthropic credential aliases = %v", got)
}

prepared := registry.CredentialEnvPreparedProviders()
wantPrepared := map[string]bool{
"xiaomi_mimo_token_plan": true,
"zai_coding": true,
"zai_payg": true,
}
if len(prepared) != len(wantPrepared) {
t.Fatalf("prepared providers = %v", prepared)
}
for _, providerID := range prepared {
if !wantPrepared[providerID] {
t.Fatalf("unexpected prepared provider %q in %v", providerID, prepared)
}
}
}

func TestProviderSpecs_TableDriven(t *testing.T) {
t.Parallel()
tests := []struct {
Expand Down
Loading
Loading