Skip to content
Merged
9 changes: 9 additions & 0 deletions cmd/gomodel/docs/docs.go
Original file line number Diff line number Diff line change
Expand Up @@ -1779,6 +1779,9 @@ const docTemplate = `{
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"reasoning": {
"$ref": "#/definitions/core.Reasoning"
},
Expand Down Expand Up @@ -1879,6 +1882,9 @@ const docTemplate = `{
"input": {},
"model": {
"type": "string"
},
"provider": {
"type": "string"
}
}
},
Expand Down Expand Up @@ -2327,6 +2333,9 @@ const docTemplate = `{
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"reasoning": {
"$ref": "#/definitions/core.Reasoning"
},
Expand Down
9 changes: 9 additions & 0 deletions cmd/gomodel/docs/swagger.json
Original file line number Diff line number Diff line change
Expand Up @@ -1775,6 +1775,9 @@
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
Comment on lines +1778 to +1780
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🏁 Script executed:

# Check if docs.go is in the changed files for this PR
git diff --name-only | grep -E "cmd/gomodel/docs/docs.go" && echo "docs.go found in changes" || echo "docs.go NOT found in changes"

# If docs.go exists and is changed, check for provider field
if [ -f cmd/gomodel/docs/docs.go ]; then
  echo "=== Checking provider field in docs.go ==="
  grep -n '"provider"' cmd/gomodel/docs/docs.go | head -20
fi

# Also verify the swagger.json has the provider fields mentioned in the review
echo "=== Verifying swagger.json provider fields ==="
grep -n '"provider"' cmd/gomodel/docs/swagger.json | head -10

Repository: ENTERPILOT/GOModel

Length of output: 1420


Commit cmd/gomodel/docs/docs.go alongside swagger.json per coding guidelines.

The coding guidelines explicitly require committing both cmd/gomodel/docs/swagger.json and cmd/gomodel/docs/docs.go together when making API changes. While swagger.json is included in this PR with the provider field additions, docs.go is not in the changed files. Regenerate and commit docs.go with this PR to maintain consistency between the generated Swagger files.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cmd/gomodel/docs/swagger.json` around lines 1778 - 1780, The swagger.json
change added the provider field but docs.go was not regenerated/committed; run
the docs generator used in the project to regenerate cmd/gomodel/docs/docs.go so
it reflects the updated swagger.json (ensuring the new "provider" schema is
included), then add and commit the regenerated docs.go alongside
cmd/gomodel/docs/swagger.json in this PR so both generated artifacts stay in
sync.

"reasoning": {
"$ref": "#/definitions/core.Reasoning"
},
Expand Down Expand Up @@ -1875,6 +1878,9 @@
"input": {},
"model": {
"type": "string"
},
"provider": {
"type": "string"
}
}
},
Expand Down Expand Up @@ -2323,6 +2329,9 @@
"model": {
"type": "string"
},
"provider": {
"type": "string"
},
"reasoning": {
"$ref": "#/definitions/core.Reasoning"
},
Expand Down
60 changes: 60 additions & 0 deletions internal/core/model_selector.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package core

import (
"fmt"
"strings"
)

// ModelSelector is a normalized model routing selector.
// Model is always the raw upstream model ID (without provider prefix);
// Provider is the optional explicit provider name, or "" when unset.
type ModelSelector struct {
	Model    string
	Provider string
}

// QualifiedModel returns "provider/model" when Provider is set, or only the
// raw model ID otherwise.
func (s ModelSelector) QualifiedModel() string {
	if s.Provider == "" {
		return s.Model
	}
	return s.Provider + "/" + s.Model
}

// ParseModelSelector normalizes model/provider routing input.
//
// Accepted forms:
//   - model only: "gpt-4o"
//   - model with prefix: "openai/gpt-4o"
//   - explicit provider field: provider="openai", model="gpt-4o"
//
// If a provider is present both as an explicit field and as a model prefix,
// the two values must match; otherwise an error is returned. Leading and
// trailing whitespace around either value is ignored.
func ParseModelSelector(model, provider string) (ModelSelector, error) {
	model = strings.TrimSpace(model)
	provider = strings.TrimSpace(provider)

	if model == "" {
		return ModelSelector{}, fmt.Errorf("model is required")
	}

	// Split at the first "/" only: the remainder may itself contain slashes,
	// so everything after the first separator is kept intact as the model ID.
	if parts := strings.SplitN(model, "/", 2); len(parts) == 2 {
		prefix := strings.TrimSpace(parts[0])
		rest := strings.TrimSpace(parts[1])
		// Only treat the prefix as a provider when both halves are non-empty;
		// degenerate inputs such as "openai/" or "/gpt-4o" pass through as-is
		// and are left for provider lookup to reject.
		if prefix != "" && rest != "" {
			if provider != "" && provider != prefix {
				return ModelSelector{}, fmt.Errorf("provider field %q conflicts with model prefix %q", provider, prefix)
			}
			provider = prefix
			model = rest
		}
	}

	// Note: model cannot be empty here — it was non-empty after the first
	// check, and the branch above only ever replaces it with a non-empty rest.
	return ModelSelector{
		Model:    model,
		Provider: provider,
	}, nil
}
73 changes: 73 additions & 0 deletions internal/core/model_selector_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
package core

import "testing"

// TestParseModelSelector covers all accepted input forms plus the error
// paths: provider conflicts, empty/whitespace-only models, and whitespace
// trimming around both fields.
func TestParseModelSelector(t *testing.T) {
	tests := []struct {
		name          string
		model         string
		provider      string
		wantModel     string
		wantProvider  string
		wantQualified string
		wantErr       bool
	}{
		{
			name:          "plain model",
			model:         "gpt-4o",
			wantModel:     "gpt-4o",
			wantProvider:  "",
			wantQualified: "gpt-4o",
		},
		{
			name:          "prefixed model",
			model:         "openai/gpt-4o",
			wantModel:     "gpt-4o",
			wantProvider:  "openai",
			wantQualified: "openai/gpt-4o",
		},
		{
			name:          "provider field",
			model:         "gpt-4o",
			provider:      "openai",
			wantModel:     "gpt-4o",
			wantProvider:  "openai",
			wantQualified: "openai/gpt-4o",
		},
		{
			// Non-conflict path: explicit field agrees with the prefix.
			name:          "provider matches prefix",
			model:         "openai/gpt-4o",
			provider:      "openai",
			wantModel:     "gpt-4o",
			wantProvider:  "openai",
			wantQualified: "openai/gpt-4o",
		},
		{
			name:          "whitespace trimmed",
			model:         "  openai/gpt-4o  ",
			provider:      "",
			wantModel:     "gpt-4o",
			wantProvider:  "openai",
			wantQualified: "openai/gpt-4o",
		},
		{
			name:     "provider conflict",
			model:    "openai/gpt-4o",
			provider: "anthropic",
			wantErr:  true,
		},
		{
			name:    "missing model",
			model:   "",
			wantErr: true,
		},
		{
			name:    "whitespace-only model",
			model:   "   ",
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			selector, err := ParseModelSelector(tt.model, tt.provider)
			if tt.wantErr {
				if err == nil {
					t.Fatal("expected error, got nil")
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if selector.Model != tt.wantModel {
				t.Fatalf("Model = %q, want %q", selector.Model, tt.wantModel)
			}
			if selector.Provider != tt.wantProvider {
				t.Fatalf("Provider = %q, want %q", selector.Provider, tt.wantProvider)
			}
			if selector.QualifiedModel() != tt.wantQualified {
				t.Fatalf("QualifiedModel = %q, want %q", selector.QualifiedModel(), tt.wantQualified)
			}
		})
	}
}
4 changes: 3 additions & 1 deletion internal/core/responses.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package core
// This is the OpenAI-compatible /v1/responses endpoint.
type ResponsesRequest struct {
Model string `json:"model"`
Provider string `json:"provider,omitempty"`
Input interface{} `json:"input" swaggertype:"string" example:"Tell me a joke"` // string or []ResponsesInputItem — see docs for array form
Instructions string `json:"instructions,omitempty"`
Tools []map[string]any `json:"tools,omitempty"`
Expand All @@ -20,6 +21,7 @@ type ResponsesRequest struct {
func (r *ResponsesRequest) WithStreaming() *ResponsesRequest {
return &ResponsesRequest{
Model: r.Model,
Provider: r.Provider,
Input: r.Input,
Instructions: r.Instructions,
Tools: r.Tools,
Expand Down Expand Up @@ -81,7 +83,7 @@ type ResponsesUsage struct {
TotalTokens int `json:"total_tokens"`
PromptTokensDetails *PromptTokensDetails `json:"prompt_tokens_details,omitempty"`
CompletionTokensDetails *CompletionTokensDetails `json:"completion_tokens_details,omitempty"`
RawUsage map[string]any `json:"raw_usage,omitempty"`
RawUsage map[string]any `json:"raw_usage,omitempty"`
}

// ResponsesError represents an error in the response.
Expand Down
3 changes: 3 additions & 0 deletions internal/core/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type ChatRequest struct {
Temperature *float64 `json:"temperature,omitempty"`
MaxTokens *int `json:"max_tokens,omitempty"`
Model string `json:"model"`
Provider string `json:"provider,omitempty"`
Messages []Message `json:"messages"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Expand All @@ -36,6 +37,7 @@ func (r *ChatRequest) WithStreaming() *ChatRequest {
Temperature: r.Temperature,
MaxTokens: r.MaxTokens,
Model: r.Model,
Provider: r.Provider,
Messages: r.Messages,
Stream: true,
StreamOptions: r.StreamOptions,
Expand Down Expand Up @@ -232,6 +234,7 @@ type ModelsResponse struct {
// EmbeddingRequest represents the incoming embeddings request (OpenAI-compatible).
type EmbeddingRequest struct {
Model string `json:"model"`
Provider string `json:"provider,omitempty"`
Input any `json:"input"`
EncodingFormat string `json:"encoding_format,omitempty"`
Dimensions *int `json:"dimensions,omitempty"`
Expand Down
1 change: 1 addition & 0 deletions internal/providers/responses_adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ type ChatProvider interface {
func ConvertResponsesRequestToChat(req *core.ResponsesRequest) *core.ChatRequest {
chatReq := &core.ChatRequest{
Model: req.Model,
Provider: req.Provider,
Messages: make([]core.Message, 0),
Temperature: req.Temperature,
Stream: req.Stream,
Expand Down
83 changes: 50 additions & 33 deletions internal/providers/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,23 @@ func (r *Router) checkReady() error {
return nil
}

// resolveProvider checks router readiness, normalizes the model/provider
// routing input, and resolves the provider registered for the qualified
// model name. Parse failures are surfaced as invalid-request errors.
func (r *Router) resolveProvider(model, provider string) (core.Provider, core.ModelSelector, error) {
	if err := r.checkReady(); err != nil {
		return nil, core.ModelSelector{}, err
	}
	sel, err := core.ParseModelSelector(model, provider)
	if err != nil {
		return nil, core.ModelSelector{}, core.NewInvalidRequestError(err.Error(), err)
	}
	qualified := sel.QualifiedModel()
	if target := r.lookup.GetProvider(qualified); target != nil {
		return target, sel, nil
	}
	return nil, core.ModelSelector{}, fmt.Errorf("no provider found for model: %s", qualified)
}

// Supports returns true if any provider supports the given model.
// Returns false if the lookup has no models loaded.
func (r *Router) Supports(model string) bool {
Expand All @@ -52,31 +69,31 @@ func (r *Router) Supports(model string) bool {
// ChatCompletion routes the request to the appropriate provider.
// Returns ErrRegistryNotInitialized if the lookup has no models loaded.
func (r *Router) ChatCompletion(ctx context.Context, req *core.ChatRequest) (*core.ChatResponse, error) {
if err := r.checkReady(); err != nil {
provider, selector, err := r.resolveProvider(req.Model, req.Provider)
if err != nil {
return nil, err
}
provider := r.lookup.GetProvider(req.Model)
if provider == nil {
return nil, fmt.Errorf("no provider found for model: %s", req.Model)
}
resp, err := provider.ChatCompletion(ctx, req)
forwardReq := *req
forwardReq.Model = selector.Model
forwardReq.Provider = ""
resp, err := provider.ChatCompletion(ctx, &forwardReq)
if err == nil && resp != nil {
resp.Provider = r.GetProviderType(req.Model)
resp.Provider = r.GetProviderType(selector.QualifiedModel())
}
return resp, err
}

// StreamChatCompletion routes the streaming request to the appropriate provider.
// Returns ErrRegistryNotInitialized if the lookup has no models loaded.
func (r *Router) StreamChatCompletion(ctx context.Context, req *core.ChatRequest) (io.ReadCloser, error) {
if err := r.checkReady(); err != nil {
provider, selector, err := r.resolveProvider(req.Model, req.Provider)
if err != nil {
return nil, err
}
provider := r.lookup.GetProvider(req.Model)
if provider == nil {
return nil, fmt.Errorf("no provider found for model: %s", req.Model)
}
return provider.StreamChatCompletion(ctx, req)
forwardReq := *req
forwardReq.Model = selector.Model
forwardReq.Provider = ""
return provider.StreamChatCompletion(ctx, &forwardReq)
}

// ListModels returns all models from the lookup.
Expand All @@ -95,45 +112,45 @@ func (r *Router) ListModels(_ context.Context) (*core.ModelsResponse, error) {
// Responses routes the Responses API request to the appropriate provider.
// Returns ErrRegistryNotInitialized if the lookup has no models loaded.
func (r *Router) Responses(ctx context.Context, req *core.ResponsesRequest) (*core.ResponsesResponse, error) {
if err := r.checkReady(); err != nil {
provider, selector, err := r.resolveProvider(req.Model, req.Provider)
if err != nil {
return nil, err
}
provider := r.lookup.GetProvider(req.Model)
if provider == nil {
return nil, fmt.Errorf("no provider found for model: %s", req.Model)
}
resp, err := provider.Responses(ctx, req)
forwardReq := *req
forwardReq.Model = selector.Model
forwardReq.Provider = ""
resp, err := provider.Responses(ctx, &forwardReq)
if err == nil && resp != nil {
resp.Provider = r.GetProviderType(req.Model)
resp.Provider = r.GetProviderType(selector.QualifiedModel())
}
return resp, err
}

// StreamResponses routes the streaming Responses API request to the appropriate provider.
// Returns ErrRegistryNotInitialized if the lookup has no models loaded.
func (r *Router) StreamResponses(ctx context.Context, req *core.ResponsesRequest) (io.ReadCloser, error) {
if err := r.checkReady(); err != nil {
provider, selector, err := r.resolveProvider(req.Model, req.Provider)
if err != nil {
return nil, err
}
provider := r.lookup.GetProvider(req.Model)
if provider == nil {
return nil, fmt.Errorf("no provider found for model: %s", req.Model)
}
return provider.StreamResponses(ctx, req)
forwardReq := *req
forwardReq.Model = selector.Model
forwardReq.Provider = ""
return provider.StreamResponses(ctx, &forwardReq)
}

// Embeddings routes the embeddings request to the appropriate provider.
func (r *Router) Embeddings(ctx context.Context, req *core.EmbeddingRequest) (*core.EmbeddingResponse, error) {
if err := r.checkReady(); err != nil {
provider, selector, err := r.resolveProvider(req.Model, req.Provider)
if err != nil {
return nil, err
}
provider := r.lookup.GetProvider(req.Model)
if provider == nil {
return nil, fmt.Errorf("no provider found for model: %s", req.Model)
}
resp, err := provider.Embeddings(ctx, req)
forwardReq := *req
forwardReq.Model = selector.Model
forwardReq.Provider = ""
resp, err := provider.Embeddings(ctx, &forwardReq)
if err == nil && resp != nil {
resp.Provider = r.GetProviderType(req.Model)
resp.Provider = r.GetProviderType(selector.QualifiedModel())
}
return resp, err
}
Expand Down
Loading