Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions backend/gateway/internal/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ package auth

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
)

// PlayerID is the Supabase user ID extracted from the JWT.
Expand All @@ -36,6 +41,12 @@ var ErrMissingAuth = errors.New("missing Authorization header")
// signature, wrong issuer).
var ErrInvalidJWT = errors.New("invalid JWT")

const (
maxAuthorizationHeaderBytes = 8 * 1024
maxBearerTokenBytes = 6 * 1024
maxJWTPartBytes = 3 * 1024
)

// Verifier validates Supabase JWTs and extracts the player ID.
// The concrete impl uses the Supabase project's JWT secret (HS256) -
// no network call to Supabase per request, the secret is enough to
Expand All @@ -44,13 +55,25 @@ type Verifier interface {
Verify(ctx context.Context, jwt string) (Identity, error)
}

// NewHS256Verifier returns a local Supabase-compatible JWT verifier.
func NewHS256Verifier(secret string) Verifier {
secret = strings.TrimSpace(secret)
if secret == "" {
return nil
}
return hs256Verifier{secret: []byte(secret), now: func() time.Time { return time.Now().UTC() }}
}

// FromRequest extracts the bearer token from an HTTP request.
// Returns ErrMissingAuth if the header is absent or malformed.
func FromRequest(r *http.Request) (string, error) {
authHdr := r.Header.Get("Authorization")
if authHdr == "" {
return "", ErrMissingAuth
}
if len(authHdr) > maxAuthorizationHeaderBytes {
return "", ErrInvalidJWT
}
const prefix = "Bearer "
if !strings.HasPrefix(authHdr, prefix) {
return "", ErrMissingAuth
Expand All @@ -59,5 +82,87 @@ func FromRequest(r *http.Request) (string, error) {
if token == "" {
return "", ErrMissingAuth
}
if len(token) > maxBearerTokenBytes {
return "", ErrInvalidJWT
}
return token, nil
}

type hs256Verifier struct {
secret []byte
now func() time.Time
}

type jwtHeader struct {
Algorithm string `json:"alg"`
Type string `json:"typ"`
}

type supabaseClaims struct {
Subject string `json:"sub"`
Role string `json:"role"`
Email string `json:"email"`
IsAnonymous bool `json:"is_anonymous"`
ExpiresAt int64 `json:"exp"`
}

func (v hs256Verifier) Verify(_ context.Context, jwt string) (Identity, error) {
if len(jwt) > maxBearerTokenBytes {
return Identity{}, ErrInvalidJWT
}
parts := strings.Split(jwt, ".")
if len(parts) != 3 {
return Identity{}, ErrInvalidJWT
}
for _, part := range parts {
if part == "" || len(part) > maxJWTPartBytes {
return Identity{}, ErrInvalidJWT
}
}

var header jwtHeader
if err := decodeJWTPart(parts[0], &header); err != nil {
return Identity{}, ErrInvalidJWT
}
Comment on lines +124 to +126
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The JWT header is decoded and its algorithm is checked before the signature is verified. While this is necessary to determine the verification method, the decodeJWTPart function uses json.Unmarshal on unauthenticated input. A maliciously crafted JWT with an extremely large header could lead to excessive memory consumption or a Denial of Service (DoS) before signature verification occurs. Consider adding a maximum length check for the JWT string or its parts before processing.

if header.Algorithm != "HS256" {
return Identity{}, ErrInvalidJWT
}

mac := hmac.New(sha256.New, v.secret)
mac.Write([]byte(parts[0]))
mac.Write([]byte("."))
mac.Write([]byte(parts[1]))
expected := mac.Sum(nil)

signature, err := base64.RawURLEncoding.DecodeString(parts[2])
if err != nil || !hmac.Equal(signature, expected) {
return Identity{}, ErrInvalidJWT
}

var claims supabaseClaims
if err := decodeJWTPart(parts[1], &claims); err != nil {
return Identity{}, ErrInvalidJWT
}
if strings.TrimSpace(claims.Subject) == "" {
return Identity{}, ErrInvalidJWT
}
if claims.ExpiresAt > 0 && v.now().Unix() >= claims.ExpiresAt {
return Identity{}, ErrInvalidJWT
}

return Identity{
PlayerID: PlayerID(claims.Subject),
Role: claims.Role,
Email: claims.Email,
IsAnonymous: claims.IsAnonymous,
ExpiresAt: claims.ExpiresAt,
}, nil
}

func decodeJWTPart(part string, target any) error {
payload, err := base64.RawURLEncoding.DecodeString(part)
if err != nil {
return err
}
return json.Unmarshal(payload, target)
}
86 changes: 86 additions & 0 deletions backend/gateway/internal/auth/auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package auth

import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"net/http/httptest"
"strings"
"testing"
"time"
)

func TestHS256VerifierAcceptsValidSupabaseToken(t *testing.T) {
verifier := hs256Verifier{
secret: []byte("test-secret"),
now: func() time.Time { return time.Unix(100, 0).UTC() },
}
token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","role":"authenticated","email":"p@example.test","exp":200}`)

identity, err := verifier.Verify(context.Background(), token)
if err != nil {
t.Fatalf("expected token to verify: %v", err)
}
if identity.PlayerID != "player-1" {
t.Fatalf("expected player id from sub, got %q", identity.PlayerID)
}
if identity.Role != "authenticated" {
t.Fatalf("expected role from token, got %q", identity.Role)
}
}

func TestHS256VerifierRejectsInvalidSignature(t *testing.T) {
verifier := hs256Verifier{
secret: []byte("test-secret"),
now: func() time.Time { return time.Unix(100, 0).UTC() },
}
token := signTestJWT(t, "other-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`)

if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT {
t.Fatalf("expected ErrInvalidJWT, got %v", err)
}
}

func TestHS256VerifierRejectsExpiredToken(t *testing.T) {
verifier := hs256Verifier{
secret: []byte("test-secret"),
now: func() time.Time { return time.Unix(300, 0).UTC() },
}
token := signTestJWT(t, "test-secret", `{"alg":"HS256","typ":"JWT"}`, `{"sub":"player-1","exp":200}`)

if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT {
t.Fatalf("expected ErrInvalidJWT, got %v", err)
}
}

func TestFromRequestRejectsOversizedBearerToken(t *testing.T) {
req := httptest.NewRequest("POST", "/v1/agent/decide", nil)
req.Header.Set("Authorization", "Bearer "+strings.Repeat("a", maxBearerTokenBytes+1))

if _, err := FromRequest(req); err != ErrInvalidJWT {
t.Fatalf("expected ErrInvalidJWT, got %v", err)
}
}

func TestHS256VerifierRejectsOversizedJWTPart(t *testing.T) {
verifier := hs256Verifier{
secret: []byte("test-secret"),
now: func() time.Time { return time.Unix(100, 0).UTC() },
}
token := strings.Repeat("a", maxJWTPartBytes+1) + ".payload.signature"

if _, err := verifier.Verify(context.Background(), token); err != ErrInvalidJWT {
t.Fatalf("expected ErrInvalidJWT, got %v", err)
}
}

func signTestJWT(t *testing.T, secret string, header string, claims string) string {
t.Helper()
encodedHeader := base64.RawURLEncoding.EncodeToString([]byte(header))
encodedClaims := base64.RawURLEncoding.EncodeToString([]byte(claims))
unsigned := encodedHeader + "." + encodedClaims
mac := hmac.New(sha256.New, []byte(secret))
mac.Write([]byte(unsigned))
return unsigned + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
}
162 changes: 162 additions & 0 deletions backend/gateway/internal/server/agent_decision_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package server

import (
"strings"
"sync"
"time"

"github.com/DOS/Second-Spawn/backend/gateway/internal/config"
)

type agentDecisionLimitResult struct {
Error string `json:"error"`
Reason string `json:"reason"`
PlayerID string `json:"player_id"`
RetryAfterSeconds int64 `json:"retry_after_seconds,omitempty"`
TokenEstimate int `json:"token_estimate,omitempty"`
TokenBudgetPerDay int `json:"token_budget_per_day,omitempty"`
TokenBudgetUsedToday int `json:"token_budget_used_today,omitempty"`
TokenBudgetRemaining int `json:"token_budget_remaining,omitempty"`
}

type agentDecisionLimiter struct {
mu sync.Mutex
cfg *config.Config
now func() time.Time
lastPruned time.Time
// TODO(#13): Move limiter state to Redis or another shared store before
// running more than one gateway instance.
players map[string]*agentDecisionLimitState
}
Comment on lines +22 to +30
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The current limiter implementation is entirely in-memory. While this works for a single-instance deployment, it will not correctly enforce budgets or rate limits across multiple gateway instances. Given that the configuration already includes a RedisURL, consider transitioning this state to Redis to support horizontal scaling and persistence across restarts.


type agentDecisionLimitState struct {
minuteStart time.Time
minuteCount int
day string
tokensUsed int
lastSeen time.Time
}

const agentDecisionLimitStateTTL = 25 * time.Hour

func newAgentDecisionLimiter(cfg *config.Config, now func() time.Time) *agentDecisionLimiter {
if cfg == nil {
cfg = &config.Config{}
}
if now == nil {
now = func() time.Time { return time.Now().UTC() }
}
return &agentDecisionLimiter{
cfg: cfg,
now: now,
players: map[string]*agentDecisionLimitState{},
}
}

func (l *agentDecisionLimiter) Allow(playerID string, tokenEstimate int) (bool, agentDecisionLimitResult) {
if l == nil || l.cfg == nil || !l.enabled() {
return true, agentDecisionLimitResult{}
}

playerID = normalizeLimitPlayerID(playerID)
tokenEstimate = max(tokenEstimate, 1)
now := l.now().UTC()
minuteStart := now.Truncate(time.Minute)
day := now.Format("2006-01-02")

l.mu.Lock()
defer l.mu.Unlock()

l.pruneExpiredIfDue(now)
state := l.playerState(playerID, minuteStart, day)
state.resetWindows(minuteStart, day)
state.lastSeen = now
if result, blocked := state.rateLimitResult(playerID, l.cfg.LLMRateLimitPerPlayerPerMin, now); blocked {
return false, result
}
if result, blocked := state.tokenBudgetResult(playerID, tokenEstimate, l.cfg.LLMTokenBudgetPerPlayerDay); blocked {
return false, result
}

state.minuteCount++
state.tokensUsed += tokenEstimate
return true, agentDecisionLimitResult{}
}

func (l *agentDecisionLimiter) enabled() bool {
return l.cfg.LLMRateLimitPerPlayerPerMin > 0 || l.cfg.LLMTokenBudgetPerPlayerDay > 0
}

func (l *agentDecisionLimiter) playerState(playerID string, minuteStart time.Time, day string) *agentDecisionLimitState {
state := l.players[playerID]
if state != nil {
return state
}
state = &agentDecisionLimitState{minuteStart: minuteStart, day: day}
l.players[playerID] = state
return state
}

func (l *agentDecisionLimiter) pruneExpiredIfDue(now time.Time) {
if !l.lastPruned.IsZero() && now.Sub(l.lastPruned) < time.Minute {
return
}
l.lastPruned = now

cutoff := now.Add(-agentDecisionLimitStateTTL)
for playerID, state := range l.players {
if !state.lastSeen.IsZero() && state.lastSeen.Before(cutoff) {
delete(l.players, playerID)
}
}
}
Comment on lines +100 to +112
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pruneExpiredIfDue function iterates over the entire players map while holding a global mutex. In a production environment with a large number of unique players over a 25-hour window, this O(N) operation will cause significant latency spikes for the request that triggers the pruning (once per minute). Consider moving the pruning logic to a background goroutine or using a more efficient data structure for TTL management (e.g., a linked list or a specialized cache library) to avoid blocking the request path.

Comment on lines +100 to +112
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The pruneExpiredIfDue function performs an O(N) scan over the entire players map while holding the global mu mutex. As the number of unique players grows, this will cause significant latency spikes for the request that triggers the pruning (once per minute). Consider performing the pruning in a background goroutine or using a data structure that supports more efficient expiration (like a TTL cache) to avoid blocking the request path.


func (s *agentDecisionLimitState) resetWindows(minuteStart time.Time, day string) {
if !s.minuteStart.Equal(minuteStart) {
s.minuteStart = minuteStart
s.minuteCount = 0
}
if s.day != day {
s.day = day
s.tokensUsed = 0
}
}

func (s *agentDecisionLimitState) rateLimitResult(playerID string, rateLimit int, now time.Time) (agentDecisionLimitResult, bool) {
if rateLimit <= 0 || s.minuteCount < rateLimit {
return agentDecisionLimitResult{}, false
}
retryAfter := s.minuteStart.Add(time.Minute).Sub(now)
if retryAfter < time.Second {
retryAfter = time.Second
}
return agentDecisionLimitResult{
Error: "agent decision rate limit exceeded",
Reason: "rate_limit_exceeded",
PlayerID: playerID,
RetryAfterSeconds: int64(retryAfter.Seconds()),
}, true
}

func (s *agentDecisionLimitState) tokenBudgetResult(playerID string, tokenEstimate int, tokenBudget int) (agentDecisionLimitResult, bool) {
if tokenBudget <= 0 || s.tokensUsed+tokenEstimate <= tokenBudget {
return agentDecisionLimitResult{}, false
}
return agentDecisionLimitResult{
Error: "agent decision token budget exhausted",
Reason: "token_budget_exhausted",
PlayerID: playerID,
TokenEstimate: tokenEstimate,
TokenBudgetPerDay: tokenBudget,
TokenBudgetUsedToday: s.tokensUsed,
TokenBudgetRemaining: max(tokenBudget-s.tokensUsed, 0),
}, true
}

func normalizeLimitPlayerID(playerID string) string {
playerID = strings.TrimSpace(playerID)
if playerID == "" {
return "unknown"
}
return playerID
}
Loading
Loading