Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions autorouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -961,6 +961,11 @@ func normalizeProviderRequest(raw map[string]any, providerName string) {
return
}

if providerName == "anthropic" {
normalizeAnthropicRequest(raw)
return
}

if providerName != "deepseek" {
return
}
Expand Down Expand Up @@ -993,6 +998,49 @@ func normalizeProviderRequest(raw map[string]any, providerName string) {
}
}

// defaultAnthropicMaxTokens is the "max_tokens" value written into an
// Anthropic request that does not already carry a positive numeric one.
const defaultAnthropicMaxTokens = 1024

// normalizeAnthropicRequest guarantees that raw has a positive numeric
// "max_tokens" entry. When the key is absent, or its value is rejected by
// hasPositiveNumber (zero, negative, non-numeric, ...), the entry is set to
// defaultAnthropicMaxTokens; an existing positive value is left untouched.
//
// raw is modified in place. A nil map is ignored.
func normalizeAnthropicRequest(raw map[string]any) {
	// Assigning into a nil map panics, and there is nothing to normalize.
	if raw == nil {
		return
	}
	if hasPositiveNumber(raw["max_tokens"]) {
		return
	}
	raw["max_tokens"] = defaultAnthropicMaxTokens
}

// hasPositiveNumber reports whether value is a number strictly greater than
// zero. The recognized dynamic types are int, int8..int64, uint, uint8..uint64,
// float32, float64 and json.Number; anything else (nil, strings, bools, ...)
// yields false.
func hasPositiveNumber(value any) bool {
	switch n := value.(type) {
	// Floating point. NaN compares false against zero and is rejected.
	case float64:
		return n > 0
	case float32:
		return n > 0

	// json.Number is judged by its float64 reading; a value that cannot be
	// converted is treated as not positive.
	case json.Number:
		parsed, err := n.Float64()
		if err != nil {
			return false
		}
		return parsed > 0

	// Signed integers.
	case int:
		return n > 0
	case int64:
		return n > 0
	case int32:
		return n > 0
	case int16:
		return n > 0
	case int8:
		return n > 0

	// Unsigned integers are positive exactly when they are non-zero.
	case uint:
		return n != 0
	case uint64:
		return n != 0
	case uint32:
		return n != 0
	case uint16:
		return n != 0
	case uint8:
		return n != 0
	}
	return false
}

func normalizeGoogleAIRequest(raw map[string]any) {
delete(raw, "stream")

Expand Down
95 changes: 95 additions & 0 deletions autorouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1411,6 +1411,101 @@ func TestAutoRouter_AnthropicStreamingNoStreamOptions(t *testing.T) {
}
}

// TestAutoRouter_AnthropicDefaultMaxTokens verifies that a request routed to
// the "anthropic" provider without a max_tokens field reaches the upstream
// with max_tokens set to defaultAnthropicMaxTokens.
func TestAutoRouter_AnthropicDefaultMaxTokens(t *testing.T) {
	// receivedBody captures the JSON body the upstream server was sent.
	var receivedBody map[string]any
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf rather than Fatalf: the handler does not run on the test
		// goroutine, where FailNow-style calls are required to happen.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("upstream: reading request body: %v", err)
		}
		if err := json.Unmarshal(body, &receivedBody); err != nil {
			t.Errorf("upstream: decoding request body %q: %v", body, err)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
	}))
	defer upstream.Close()

	provider := &mockProvider{
		name: "anthropic",
		parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
			data, err := io.ReadAll(body)
			if err != nil {
				return BodyMetadata{}, nil, err
			}
			var raw map[string]any
			if err := json.Unmarshal(data, &raw); err != nil {
				return BodyMetadata{}, nil, err
			}
			// A missing or non-float max_tokens deliberately maps to 0.
			maxTokens, _ := raw["max_tokens"].(float64)
			return BodyMetadata{Model: "claude-3-opus", MaxTokens: int(maxTokens)}, data, nil
		},
		enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
		resolveFn: func(meta BodyMetadata) (*url.URL, error) {
			return url.Parse(upstream.URL)
		},
		extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				return ResponseMetadata{}, nil, err
			}
			return ResponseMetadata{ID: "msg_test"}, body, nil
		},
	}

	router := NewAutoRouter(
		WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
	)
	router.RegisterProvider(provider)

	// The request body intentionally omits max_tokens.
	req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","messages":[{"role":"user","content":"Hello"}]}`)))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	router.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("StatusCode = %d, want 200", w.Code)
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	if got := receivedBody["max_tokens"]; got != float64(defaultAnthropicMaxTokens) {
		t.Fatalf("max_tokens = %v, want %d", got, defaultAnthropicMaxTokens)
	}
}

// TestAutoRouter_AnthropicPreservesMaxTokens verifies that a request routed to
// the "anthropic" provider that already sets max_tokens (64 here) reaches the
// upstream with that value unchanged.
func TestAutoRouter_AnthropicPreservesMaxTokens(t *testing.T) {
	// receivedBody captures the JSON body the upstream server was sent.
	var receivedBody map[string]any
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Errorf rather than Fatalf: the handler does not run on the test
		// goroutine, where FailNow-style calls are required to happen.
		body, err := io.ReadAll(r.Body)
		if err != nil {
			t.Errorf("upstream: reading request body: %v", err)
		}
		if err := json.Unmarshal(body, &receivedBody); err != nil {
			t.Errorf("upstream: decoding request body %q: %v", body, err)
		}
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusOK)
		w.Write([]byte(`{"id":"msg_test","type":"message","model":"claude-3-opus","content":[{"type":"text","text":"Hello"}],"usage":{"input_tokens":8,"output_tokens":1}}`))
	}))
	defer upstream.Close()

	provider := &mockProvider{
		name: "anthropic",
		parseFn: func(body io.ReadCloser) (BodyMetadata, []byte, error) {
			data, err := io.ReadAll(body)
			if err != nil {
				return BodyMetadata{}, nil, err
			}
			// Metadata is fixed to match the max_tokens sent by this test.
			return BodyMetadata{Model: "claude-3-opus", MaxTokens: 64}, data, nil
		},
		enrichFn: func(req *http.Request, meta BodyMetadata, body []byte) error { return nil },
		resolveFn: func(meta BodyMetadata) (*url.URL, error) {
			return url.Parse(upstream.URL)
		},
		extractFn: func(resp *http.Response) (ResponseMetadata, []byte, error) {
			body, err := io.ReadAll(resp.Body)
			if err != nil {
				return ResponseMetadata{}, nil, err
			}
			return ResponseMetadata{ID: "msg_test"}, body, nil
		},
	}

	router := NewAutoRouter(
		WithAutoRouterDetector(ProviderDetectorFunc(func(hint ProviderHint) string { return "anthropic" })),
	)
	router.RegisterProvider(provider)

	req := httptest.NewRequestWithContext(context.Background(), "POST", "/", bytes.NewReader([]byte(`{"model":"claude-3-opus","max_tokens":64,"messages":[{"role":"user","content":"Hello"}]}`)))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()

	router.ServeHTTP(w, req)

	if w.Code != http.StatusOK {
		t.Fatalf("StatusCode = %d, want 200", w.Code)
	}
	// encoding/json decodes JSON numbers into float64 when the target is any.
	if got := receivedBody["max_tokens"]; got != float64(64) {
		t.Fatalf("max_tokens = %v, want 64", got)
	}
}

func TestAutoRouter_StreamingWritesGatewayMetadataEvent(t *testing.T) {
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")
Expand Down
Loading