-
Notifications
You must be signed in to change notification settings - Fork 0
feat(gateway): enforce agent decision budgets #10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,86 @@ | ||
| package auth | ||
|
|
||
| import ( | ||
| "context" | ||
| "crypto/hmac" | ||
| "crypto/sha256" | ||
| "encoding/base64" | ||
| "net/http/httptest" | ||
| "strings" | ||
| "testing" | ||
| "time" | ||
| ) | ||
|
|
||
| func TestHS256VerifierAcceptsValidSupabaseToken(t *testing.T) { | ||
| verifier := hs256Verifier{ | ||
| secret: []byte("test-secret"), | ||
| now: func() time.Time { return time.Unix(100, 0).UTC() }, | ||
| } | ||
| token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","role":"authenticated","email":"p@example.test","exp":200}`) | ||
|
|
||
| identity, err := verifier.Verify(context.Background(), token) | ||
| if err != nil { | ||
| t.Fatalf("expected token to verify: %v", err) | ||
| } | ||
| if identity.PlayerID != "player-1" { | ||
| t.Fatalf("expected player id from sub, got %q", identity.PlayerID) | ||
| } | ||
| if identity.Role != "authenticated" { | ||
| t.Fatalf("expected role from token, got %q", identity.Role) | ||
| } | ||
| } | ||
|
|
||
| func TestHS256VerifierRejectsInvalidSignature(t *testing.T) { | ||
| verifier := hs256Verifier{ | ||
| secret: []byte("test-secret"), | ||
| now: func() time.Time { return time.Unix(100, 0).UTC() }, | ||
| } | ||
| token := signTestJWT(t, "other-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`) | ||
|
|
||
| if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { | ||
| t.Fatalf("expected ErrInvalidJWT, got %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestHS256VerifierRejectsExpiredToken(t *testing.T) { | ||
| verifier := hs256Verifier{ | ||
| secret: []byte("test-secret"), | ||
| now: func() time.Time { return time.Unix(300, 0).UTC() }, | ||
| } | ||
| token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`) | ||
|
|
||
| if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { | ||
| t.Fatalf("expected ErrInvalidJWT, got %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestFromRequestRejectsOversizedBearerToken(t *testing.T) { | ||
| req := httptest.NewRequest("POST", "/v1/agent/decide", nil) | ||
| req.Header.Set("Authorization", "Bearer "+strings.Repeat("a", maxBearerTokenBytes+1)) | ||
|
|
||
| if _, err := FromRequest(req); err != ErrInvalidJWT { | ||
| t.Fatalf("expected ErrInvalidJWT, got %v", err) | ||
| } | ||
| } | ||
|
|
||
| func TestHS256VerifierRejectsOversizedJWTPart(t *testing.T) { | ||
| verifier := hs256Verifier{ | ||
| secret: []byte("test-secret"), | ||
| now: func() time.Time { return time.Unix(100, 0).UTC() }, | ||
| } | ||
| token := strings.Repeat("a", maxJWTPartBytes+1) + ".payload.signature" | ||
|
|
||
| if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT { | ||
| t.Fatalf("expected ErrInvalidJWT, got %v", err) | ||
| } | ||
| } | ||
|
|
||
// signTestJWT builds a compact three-part token for tests: the header and
// claims JSON are each encoded with unpadded URL-safe base64, joined with a
// dot, and followed by the HMAC-SHA256 of that string keyed with secret.
func signTestJWT(t *testing.T, secret string, header string, claims string) string {
	t.Helper()

	enc := base64.RawURLEncoding
	signingInput := enc.EncodeToString([]byte(header)) + "." + enc.EncodeToString([]byte(claims))

	signer := hmac.New(sha256.New, []byte(secret))
	signer.Write([]byte(signingInput))
	signature := enc.EncodeToString(signer.Sum(nil))

	return signingInput + "." + signature
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| package server | ||
|
|
||
| import ( | ||
| "strings" | ||
| "sync" | ||
| "time" | ||
|
|
||
| "github.com/DOS/Second-Spawn/backend/gateway/internal/config" | ||
| ) | ||
|
|
||
// agentDecisionLimitResult is the JSON-serialisable description of why an
// agent decision request was refused. The numeric fields carry omitempty, so
// a zero value is left out of the encoded object.
type agentDecisionLimitResult struct {
	// Error is the human-readable message.
	Error string `json:"error"`
	// Reason is the machine-readable code, e.g. "rate_limit_exceeded".
	Reason string `json:"reason"`
	// PlayerID is the (normalized) player the limit was applied to.
	PlayerID string `json:"player_id"`
	// RetryAfterSeconds is set on rate-limit refusals.
	RetryAfterSeconds int64 `json:"retry_after_seconds,omitempty"`
	// The token fields below are set on token-budget refusals.
	TokenEstimate        int `json:"token_estimate,omitempty"`
	TokenBudgetPerDay    int `json:"token_budget_per_day,omitempty"`
	TokenBudgetUsedToday int `json:"token_budget_used_today,omitempty"`
	TokenBudgetRemaining int `json:"token_budget_remaining,omitempty"`
}
|
|
||
| type agentDecisionLimiter struct { | ||
| mu sync.Mutex | ||
| cfg *config.Config | ||
| now func() time.Time | ||
| lastPruned time.Time | ||
| // TODO(#13): Move limiter state to Redis or another shared store before | ||
| // running more than one gateway instance. | ||
| players map[string]*agentDecisionLimitState | ||
| } | ||
|
Comment on lines
+22
to
+30
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The current limiter implementation is entirely in-memory. While this works for a single-instance deployment, it will not correctly enforce budgets or rate limits across multiple gateway instances. Given that the configuration already includes a […]
||
|
|
||
// agentDecisionLimitState is one player's usage counters: a per-minute request
// count and a per-day token total, each tagged with the window it belongs to.
type agentDecisionLimitState struct {
	// minuteStart is the start of the minute that minuteCount applies to.
	minuteStart time.Time
	// minuteCount is the number of requests admitted in that minute.
	minuteCount int
	// day is the "2006-01-02"-formatted date that tokensUsed applies to.
	day string
	// tokensUsed is the sum of token estimates admitted on that day.
	tokensUsed int
	// lastSeen is the time of the player's most recent Allow call; pruning
	// uses it to drop idle entries.
	lastSeen time.Time
}
|
|
||
// agentDecisionLimitStateTTL is how long a player's limiter entry may go
// unseen before pruning removes it.
const agentDecisionLimitStateTTL time.Duration = 25 * time.Hour
|
|
||
| func newAgentDecisionLimiter(cfg *config.Config, now func() time.Time) *agentDecisionLimiter { | ||
| if cfg == nil { | ||
| cfg = &config.Config{} | ||
| } | ||
| if now == nil { | ||
| now = func() time.Time { return time.Now().UTC() } | ||
| } | ||
| return &agentDecisionLimiter{ | ||
| cfg: cfg, | ||
| now: now, | ||
| players: map[string]*agentDecisionLimitState{}, | ||
| } | ||
| } | ||
|
|
||
| func (l *agentDecisionLimiter) Allow(playerID string, tokenEstimate int) (bool, agentDecisionLimitResult) { | ||
| if l == nil || l.cfg == nil || !l.enabled() { | ||
| return true, agentDecisionLimitResult{} | ||
| } | ||
|
|
||
| playerID = normalizeLimitPlayerID(playerID) | ||
| tokenEstimate = max(tokenEstimate, 1) | ||
| now := l.now().UTC() | ||
| minuteStart := now.Truncate(time.Minute) | ||
| day := now.Format("2006-01-02") | ||
|
|
||
| l.mu.Lock() | ||
| defer l.mu.Unlock() | ||
|
|
||
| l.pruneExpiredIfDue(now) | ||
| state := l.playerState(playerID, minuteStart, day) | ||
| state.resetWindows(minuteStart, day) | ||
| state.lastSeen = now | ||
| if result, blocked := state.rateLimitResult(playerID, l.cfg.LLMRateLimitPerPlayerPerMin, now); blocked { | ||
| return false, result | ||
| } | ||
| if result, blocked := state.tokenBudgetResult(playerID, tokenEstimate, l.cfg.LLMTokenBudgetPerPlayerDay); blocked { | ||
| return false, result | ||
| } | ||
|
|
||
| state.minuteCount++ | ||
| state.tokensUsed += tokenEstimate | ||
| return true, agentDecisionLimitResult{} | ||
| } | ||
|
|
||
| func (l *agentDecisionLimiter) enabled() bool { | ||
| return l.cfg.LLMRateLimitPerPlayerPerMin > 0 || l.cfg.LLMTokenBudgetPerPlayerDay > 0 | ||
| } | ||
|
|
||
| func (l *agentDecisionLimiter) playerState(playerID string, minuteStart time.Time, day string) *agentDecisionLimitState { | ||
| state := l.players[playerID] | ||
| if state != nil { | ||
| return state | ||
| } | ||
| state = &agentDecisionLimitState{minuteStart: minuteStart, day: day} | ||
| l.players[playerID] = state | ||
| return state | ||
| } | ||
|
|
||
| func (l *agentDecisionLimiter) pruneExpiredIfDue(now time.Time) { | ||
| if !l.lastPruned.IsZero() && now.Sub(l.lastPruned) < time.Minute { | ||
| return | ||
| } | ||
| l.lastPruned = now | ||
|
|
||
| cutoff := now.Add(-agentDecisionLimitStateTTL) | ||
| for playerID, state := range l.players { | ||
| if !state.lastSeen.IsZero() && state.lastSeen.Before(cutoff) { | ||
| delete(l.players, playerID) | ||
| } | ||
| } | ||
| } | ||
|
Comment on lines
+100
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The
Comment on lines
+100
to
+112
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| func (s *agentDecisionLimitState) resetWindows(minuteStart time.Time, day string) { | ||
| if !s.minuteStart.Equal(minuteStart) { | ||
| s.minuteStart = minuteStart | ||
| s.minuteCount = 0 | ||
| } | ||
| if s.day != day { | ||
| s.day = day | ||
| s.tokensUsed = 0 | ||
| } | ||
| } | ||
|
|
||
| func (s *agentDecisionLimitState) rateLimitResult(playerID string, rateLimit int, now time.Time) (agentDecisionLimitResult, bool) { | ||
| if rateLimit <= 0 || s.minuteCount < rateLimit { | ||
| return agentDecisionLimitResult{}, false | ||
| } | ||
| retryAfter := s.minuteStart.Add(time.Minute).Sub(now) | ||
| if retryAfter < time.Second { | ||
| retryAfter = time.Second | ||
| } | ||
| return agentDecisionLimitResult{ | ||
| Error: "agent decision rate limit exceeded", | ||
| Reason: "rate_limit_exceeded", | ||
| PlayerID: playerID, | ||
| RetryAfterSeconds: int64(retryAfter.Seconds()), | ||
| }, true | ||
| } | ||
|
|
||
| func (s *agentDecisionLimitState) tokenBudgetResult(playerID string, tokenEstimate int, tokenBudget int) (agentDecisionLimitResult, bool) { | ||
| if tokenBudget <= 0 || s.tokensUsed+tokenEstimate <= tokenBudget { | ||
| return agentDecisionLimitResult{}, false | ||
| } | ||
| return agentDecisionLimitResult{ | ||
| Error: "agent decision token budget exhausted", | ||
| Reason: "token_budget_exhausted", | ||
| PlayerID: playerID, | ||
| TokenEstimate: tokenEstimate, | ||
| TokenBudgetPerDay: tokenBudget, | ||
| TokenBudgetUsedToday: s.tokensUsed, | ||
| TokenBudgetRemaining: max(tokenBudget-s.tokensUsed, 0), | ||
| }, true | ||
| } | ||
|
|
||
// normalizeLimitPlayerID trims surrounding whitespace from playerID and
// substitutes "unknown" when nothing is left, so every request maps to a
// non-empty limiter key.
func normalizeLimitPlayerID(playerID string) string {
	if trimmed := strings.TrimSpace(playerID); trimmed != "" {
		return trimmed
	}
	return "unknown"
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The JWT header is decoded and its algorithm is checked before the signature is verified. While this is necessary to determine the verification method, the
`decodeJWTPart` function uses `json.Unmarshal` on unauthenticated input. A maliciously crafted JWT with an extremely large header could lead to excessive memory consumption or a Denial of Service (DoS) before signature verification occurs. Consider adding a maximum length check for the JWT string or its parts before processing.