diff --git a/cmd/root.go b/cmd/root.go index 3ac48f8d..59bbcd6f 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -164,6 +164,7 @@ func init() { f.IntVar(&rootArgs.config.RateLimitRPS, "rate-limit-rps", defaultRateLimitRPS, "Maximum requests per second per IP for rate limiting") f.IntVar(&rootArgs.config.RateLimitBurst, "rate-limit-burst", defaultRateLimitBurst, "Maximum burst size per IP for rate limiting") f.BoolVar(&rootArgs.config.RateLimitFailClosed, "rate-limit-fail-closed", false, "On rate-limit backend errors, reject with 503 instead of allowing the request") + f.StringSliceVar(&rootArgs.config.TrustedProxies, "trusted-proxies", nil, "Comma-separated CIDRs of trusted reverse proxies. When set, gin uses X-Forwarded-For from these networks. Empty (default) trusts no proxies and uses RemoteAddr.") // JWT flags f.StringVar(&rootArgs.config.JWTType, "jwt-type", "", "Type of JWT to use") @@ -472,6 +473,7 @@ func runRoot(c *cobra.Command, args []string) { // Prepare server deps := &server.Dependencies{ Log: &log, + AppConfig: &rootArgs.config, HTTPProvider: httpProvider, } // Create the server diff --git a/internal/config/config.go b/internal/config/config.go index a43984e2..0d096ccf 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -259,4 +259,10 @@ type Config struct { RateLimitBurst int // RateLimitFailClosed rejects requests when the rate limit backend errors (default: fail-open). RateLimitFailClosed bool + // TrustedProxies is the list of CIDRs allowed to set X-Forwarded-For + // and similar proxy headers. Empty (the default) means no proxies are + // trusted and gin will use RemoteAddr directly. Operators behind a + // reverse proxy MUST set this explicitly or rate limiting and audit + // logs will key on the proxy IP, not the real client IP. + TrustedProxies []string } diff --git a/internal/rate_limit/in_memory.go b/internal/rate_limit/in_memory.go index ca479626..dd5063b5 100644 --- a/internal/rate_limit/in_memory.go +++ b/internal/rate_limit/in_memory.go @@ -3,6 +3,7 @@ package rate_limit import ( "context" "sync" + "sync/atomic" "time" "golang.org/x/time/rate" @@ -13,9 +14,22 @@ import ( // Compile-time interface guard var _ Provider = (*inMemoryProvider)(nil) +// entry stores a per-IP rate limiter together with the last-seen timestamp +// used by the cleanup goroutine. lastSeen is accessed concurrently by +// Allow() and cleanup(), so it MUST be touched only via the atomic helpers +// below — using a plain time.Time field would trip the race detector and +// (more importantly) drop updates non-deterministically. type entry struct { limiter *rate.Limiter - lastSeen time.Time // benign data race: only used for cleanup staleness heuristic + lastSeen atomic.Int64 // unix nanoseconds +} + +func (e *entry) touch() { + e.lastSeen.Store(time.Now().UnixNano()) +} + +func (e *entry) lastSeenTime() time.Time { + return time.Unix(0, e.lastSeen.Load()) } type inMemoryProvider struct { @@ -38,15 +52,11 @@ func newInMemoryProvider(cfg *config.Config, deps *Dependencies) (*inMemoryProvi // Allow checks if a request from the given IP is allowed func (p *inMemoryProvider) Allow(_ context.Context, ip string) (bool, error) { - v, loaded := p.visitors.LoadOrStore(ip, &entry{ - limiter: rate.NewLimiter(p.rps, p.burst), - lastSeen: time.Now(), - }) + newEntry := &entry{limiter: rate.NewLimiter(p.rps, p.burst)} + newEntry.touch() + v, _ := p.visitors.LoadOrStore(ip, newEntry) e := v.(*entry) - e.lastSeen = time.Now() - if loaded { - p.visitors.Store(ip, e) - } + e.touch() return e.limiter.Allow(), nil } @@ -67,7 +77,7 @@ func (p *inMemoryProvider) cleanup(ctx context.Context) { case <-ticker.C: p.visitors.Range(func(key, value any) bool { e := value.(*entry) - if time.Since(e.lastSeen) > 10*time.Minute { + if time.Since(e.lastSeenTime()) > 10*time.Minute { p.visitors.Delete(key) } return true diff --git a/internal/rate_limit/redis.go b/internal/rate_limit/redis.go index f1d27121..55f5cb2a 100644 --- a/internal/rate_limit/redis.go +++ b/internal/rate_limit/redis.go @@ -39,10 +39,15 @@ type redisProvider struct { } func newRedisProvider(cfg *config.Config, deps *Dependencies) (*redisProvider, error) { - // Window = burst / rps, minimum 1 second + // Window in seconds = ceil(burst / rps), minimum 1. The previous integer + // division silently truncated to 0 when burst < rps and produced + // inconsistent enforcement vs the in-memory limiter (which uses the + // same effective window via golang.org/x/time/rate). Use ceiling + // arithmetic so the redis window is at least as long as the rps period. window := 1 if cfg.RateLimitRPS > 0 { - w := int(cfg.RateLimitBurst / cfg.RateLimitRPS) + // ceil(burst / rps) without floats: (a + b - 1) / b + w := (cfg.RateLimitBurst + cfg.RateLimitRPS - 1) / cfg.RateLimitRPS if w > 1 { window = w } @@ -55,13 +60,21 @@ func newRedisProvider(cfg *config.Config, deps *Dependencies) (*redisProvider, e }, nil } -// Allow checks if a request from the given IP is allowed using Redis +// Allow checks if a request from the given IP is allowed using Redis. +// +// Errors are PROPAGATED to the caller (the rate-limit middleware) so that +// the operator-controlled RateLimitFailClosed config can actually take +// effect. Previously this returned (true, nil) on any redis error, which +// meant fail-closed mode was a no-op and a flapping redis silently disabled +// rate limiting entirely. func (p *redisProvider) Allow(ctx context.Context, ip string) (bool, error) { key := "rate_limit:" + ip result, err := p.client.Eval(ctx, rateLimitScript, []string{key}, p.burst, p.window).Int64() if err != nil { - p.log.Error().Err(err).Str("ip", ip).Msg("rate limit redis error, failing open") - return true, nil + p.log.Error().Err(err).Str("ip", ip).Msg("rate limit redis error") + // Default to allowing the request, but PROPAGATE the error so the + // middleware can fail-closed when configured. + return true, err } return result == 1, nil } diff --git a/internal/server/http_routes.go b/internal/server/http_routes.go index ca332f9d..63d4bd11 100644 --- a/internal/server/http_routes.go +++ b/internal/server/http_routes.go @@ -11,6 +11,17 @@ import ( // NewRouter creates new gin router func (s *server) NewRouter() *gin.Engine { router := gin.New() + // Restrict the set of proxies whose forwarded headers are honoured. + // When TrustedProxies is empty/nil, gin trusts NO proxies and falls back + // to RemoteAddr — preventing X-Forwarded-For spoofing for rate limiting, + // audit logs, and CSRF same-origin comparisons. + var trustedProxies []string + if s.Dependencies.AppConfig != nil { + trustedProxies = s.Dependencies.AppConfig.TrustedProxies + } + if err := router.SetTrustedProxies(trustedProxies); err != nil { + s.Dependencies.Log.Warn().Err(err).Msg("failed to apply trusted proxies; falling back to gin defaults") + } router.Use(gin.Recovery()) router.Use(s.Dependencies.HTTPProvider.SecurityHeadersMiddleware()) diff --git a/internal/server/server.go b/internal/server/server.go index 2b146486..a98e0534 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -10,6 +10,7 @@ import ( "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/rs/zerolog" + "github.com/authorizerdev/authorizer/internal/config" "github.com/authorizerdev/authorizer/internal/graphql" "github.com/authorizerdev/authorizer/internal/http_handlers" ) @@ -29,6 +30,7 @@ type Config struct { // Dependencies for a server type Dependencies struct { Log *zerolog.Logger + AppConfig *config.Config GraphQLProvider graphql.Provider HTTPProvider http_handlers.Provider }