diff --git a/module/http_middleware.go b/module/http_middleware.go index 484ab696..2aa3d6de 100644 --- a/module/http_middleware.go +++ b/module/http_middleware.go @@ -3,8 +3,10 @@ package module import ( "context" "fmt" + "math" "net" "net/http" + "strconv" "strings" "sync" "time" @@ -167,7 +169,14 @@ func (m *RateLimitMiddleware) Process(next http.Handler) http.Handler { // Check if request can proceed if c.tokens < 1 { m.mu.Unlock() - w.Header().Set("Retry-After", "60") + // Compute how many seconds until 1 token refills, based on the + // fractional per-minute rate (ratePerMinute tokens/minute). + retryAfter := "60" + if m.ratePerMinute > 0 { + secondsUntilToken := 60.0 / m.ratePerMinute + retryAfter = strconv.Itoa(int(math.Ceil(secondsUntilToken))) + } + w.Header().Set("Retry-After", retryAfter) http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) return } diff --git a/module/http_middleware_test.go b/module/http_middleware_test.go index 11866d3f..60e3f72d 100644 --- a/module/http_middleware_test.go +++ b/module/http_middleware_test.go @@ -460,8 +460,8 @@ func TestRateLimitMiddleware_RetryAfterHeader(t *testing.T) { if rec2.Code != http.StatusTooManyRequests { t.Errorf("expected 429, got %d", rec2.Code) } - if rec2.Header().Get("Retry-After") != "60" { - t.Errorf("expected Retry-After header '60', got %q", rec2.Header().Get("Retry-After")) + if rec2.Header().Get("Retry-After") != "1" { + t.Errorf("expected Retry-After header '1', got %q", rec2.Header().Get("Retry-After")) } } @@ -640,3 +640,53 @@ func TestNewRateLimitMiddlewareWithHourlyRate_RatePerMinute(t *testing.T) { t.Errorf("expected ratePerMinute=1.0, got %f", m.ratePerMinute) } } + +func TestNewRateLimitMiddlewareWithHourlyRate_FractionalRefill(t *testing.T) { + // 3600 requests/hour -> ratePerMinute = 60.0, timePerToken = 1 second. + // Using a high hourly rate keeps the sleep short while still exercising + // the fractional refill path. + m := NewRateLimitMiddlewareWithHourlyRate("rl-hour-fractional", 3600, 1) + + handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + // First request should be allowed (uses the single burst token). + req1 := httptest.NewRequest("GET", "/fractional", nil) + rec1 := httptest.NewRecorder() + handler.ServeHTTP(rec1, req1) + if rec1.Code != http.StatusOK { + t.Fatalf("first request: expected 200, got %d", rec1.Code) + } + + // Second immediate request must be rate-limited (burst exhausted, no refill yet). + req2 := httptest.NewRequest("GET", "/fractional", nil) + rec2 := httptest.NewRecorder() + handler.ServeHTTP(rec2, req2) + if rec2.Code != http.StatusTooManyRequests { + t.Fatalf("second request: expected 429, got %d", rec2.Code) + } + + // Wait slightly longer than the time needed to refill one token. + if m.ratePerMinute <= 0 { + t.Fatalf("ratePerMinute must be positive, got %f", m.ratePerMinute) + } + timePerToken := time.Duration(float64(time.Minute) / m.ratePerMinute) + time.Sleep(timePerToken + 100*time.Millisecond) + + // After waiting, exactly one additional request should be allowed. + req3 := httptest.NewRequest("GET", "/fractional", nil) + rec3 := httptest.NewRecorder() + handler.ServeHTTP(rec3, req3) + if rec3.Code != http.StatusOK { + t.Fatalf("third request after refill: expected 200, got %d", rec3.Code) + } + + // An immediately following request must still be rate-limited. + req4 := httptest.NewRequest("GET", "/fractional", nil) + rec4 := httptest.NewRecorder() + handler.ServeHTTP(rec4, req4) + if rec4.Code != http.StatusTooManyRequests { + t.Fatalf("fourth request after refill: expected 429, got %d", rec4.Code) + } +} diff --git a/plugins/http/modules.go b/plugins/http/modules.go index c72c7b61..03223ea3 100644 --- a/plugins/http/modules.go +++ b/plugins/http/modules.go @@ -135,24 +135,36 @@ func loggingMiddlewareFactory(name string, cfg map[string]any) modular.Module { func rateLimitMiddlewareFactory(name string, cfg map[string]any) modular.Module { burstSize := 10 if bs, ok := cfg["burstSize"].(int); ok { - burstSize = bs + if bs > 0 { + burstSize = bs + } } else if bs, ok := cfg["burstSize"].(float64); ok { - burstSize = int(bs) + if intBS := int(bs); intBS > 0 { + burstSize = intBS + } } // requestsPerHour takes precedence over requestsPerMinute for low-frequency // endpoints (e.g. registration) where fractional per-minute rates are needed. if rph, ok := cfg["requestsPerHour"].(int); ok { - return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize) + if rph > 0 { + return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize) + } } else if rph, ok := cfg["requestsPerHour"].(float64); ok { - return module.NewRateLimitMiddlewareWithHourlyRate(name, int(rph), burstSize) + if intRPH := int(rph); intRPH > 0 { + return module.NewRateLimitMiddlewareWithHourlyRate(name, intRPH, burstSize) + } } requestsPerMinute := 60 if rpm, ok := cfg["requestsPerMinute"].(int); ok { - requestsPerMinute = rpm + if rpm > 0 { + requestsPerMinute = rpm + } } else if rpm, ok := cfg["requestsPerMinute"].(float64); ok { - requestsPerMinute = int(rpm) + if intRPM := int(rpm); intRPM > 0 { + requestsPerMinute = intRPM + } } return module.NewRateLimitMiddleware(name, requestsPerMinute, burstSize) } diff --git a/plugins/http/plugin_test.go b/plugins/http/plugin_test.go index acea60ef..3b5d220b 100644 --- a/plugins/http/plugin_test.go +++ b/plugins/http/plugin_test.go @@ -356,6 +356,41 @@ func TestRateLimitMiddlewareFactory_RequestsPerHour(t *testing.T) { } } +func TestRateLimitMiddlewareFactory_InvalidValues(t *testing.T) { + factories := moduleFactories() + factory, ok := factories["http.middleware.ratelimit"] + if !ok { + t.Fatal("no factory for http.middleware.ratelimit") + } + + // Zero requestsPerHour must fall through to requestsPerMinute path (not crash). + modZeroRPH := factory("rl-zero-rph", map[string]any{ + "requestsPerHour": 0, + "requestsPerMinute": 30, + "burstSize": 5, + }) + if modZeroRPH == nil { + t.Fatal("factory returned nil for zero requestsPerHour config") + } + + // Negative requestsPerMinute must use default (60). + modNegRPM := factory("rl-neg-rpm", map[string]any{ + "requestsPerMinute": -10, + }) + if modNegRPM == nil { + t.Fatal("factory returned nil for negative requestsPerMinute config") + } + + // Zero burstSize must keep default (10). + modZeroBurst := factory("rl-zero-burst", map[string]any{ + "requestsPerMinute": 60, + "burstSize": 0, + }) + if modZeroBurst == nil { + t.Fatal("factory returned nil for zero burstSize config") + } +} + func TestPluginLoaderIntegration(t *testing.T) { p := New() diff --git a/plugins/http/schemas.go b/plugins/http/schemas.go index ae4b0d0b..41a13231 100644 --- a/plugins/http/schemas.go +++ b/plugins/http/schemas.go @@ -161,7 +161,7 @@ func rateLimitMiddlewareSchema() *schema.ModuleSchema { ConfigFields: []schema.ConfigFieldDef{ {Key: "requestsPerMinute", Label: "Requests Per Minute", Type: schema.FieldTypeNumber, DefaultValue: 60, Description: "Maximum number of requests per minute per client (mutually exclusive with requestsPerHour)"}, {Key: "requestsPerHour", Label: "Requests Per Hour", Type: schema.FieldTypeNumber, DefaultValue: 0, Description: "Maximum number of requests per hour per client; takes precedence over requestsPerMinute when set"}, - {Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum burst of requests allowed above the rate limit"}, + {Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum number of tokens in the bucket; determines how many requests can burst when the bucket is full"}, }, DefaultConfig: map[string]any{"requestsPerMinute": 60, "burstSize": 10}, }