Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion module/http_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package module
import (
"context"
"fmt"
"math"
"net"
"net/http"
"strconv"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -167,7 +169,14 @@ func (m *RateLimitMiddleware) Process(next http.Handler) http.Handler {
// Check if request can proceed
if c.tokens < 1 {
m.mu.Unlock()
w.Header().Set("Retry-After", "60")
// Compute how many seconds until 1 token refills, based on the
// fractional per-minute rate (ratePerMinute tokens/minute).
retryAfter := "60"
if m.ratePerMinute > 0 {
secondsUntilToken := 60.0 / m.ratePerMinute
retryAfter = strconv.Itoa(int(math.Ceil(secondsUntilToken)))
}
w.Header().Set("Retry-After", retryAfter)
http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
return
}
Expand Down
54 changes: 52 additions & 2 deletions module/http_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -460,8 +460,8 @@ func TestRateLimitMiddleware_RetryAfterHeader(t *testing.T) {
if rec2.Code != http.StatusTooManyRequests {
t.Errorf("expected 429, got %d", rec2.Code)
}
if rec2.Header().Get("Retry-After") != "60" {
t.Errorf("expected Retry-After header '60', got %q", rec2.Header().Get("Retry-After"))
if rec2.Header().Get("Retry-After") != "1" {
t.Errorf("expected Retry-After header '1', got %q", rec2.Header().Get("Retry-After"))
}
}

Expand Down Expand Up @@ -640,3 +640,53 @@ func TestNewRateLimitMiddlewareWithHourlyRate_RatePerMinute(t *testing.T) {
t.Errorf("expected ratePerMinute=1.0, got %f", m.ratePerMinute)
}
}

func TestNewRateLimitMiddlewareWithHourlyRate_FractionalRefill(t *testing.T) {
// 3600 requests/hour -> ratePerMinute = 60.0, timePerToken = 1 second.
// Using a high hourly rate keeps the sleep short while still exercising
// the fractional refill path.
m := NewRateLimitMiddlewareWithHourlyRate("rl-hour-fractional", 3600, 1)

handler := m.Process(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))

// First request should be allowed (uses the single burst token).
req1 := httptest.NewRequest("GET", "/fractional", nil)
rec1 := httptest.NewRecorder()
handler.ServeHTTP(rec1, req1)
if rec1.Code != http.StatusOK {
t.Fatalf("first request: expected 200, got %d", rec1.Code)
}

// Second immediate request must be rate-limited (burst exhausted, no refill yet).
req2 := httptest.NewRequest("GET", "/fractional", nil)
rec2 := httptest.NewRecorder()
handler.ServeHTTP(rec2, req2)
if rec2.Code != http.StatusTooManyRequests {
t.Fatalf("second request: expected 429, got %d", rec2.Code)
}

// Wait slightly longer than the time needed to refill one token.
if m.ratePerMinute <= 0 {
t.Fatalf("ratePerMinute must be positive, got %f", m.ratePerMinute)
}
timePerToken := time.Duration(float64(time.Minute) / m.ratePerMinute)
time.Sleep(timePerToken + 100*time.Millisecond)

// After waiting, exactly one additional request should be allowed.
req3 := httptest.NewRequest("GET", "/fractional", nil)
rec3 := httptest.NewRecorder()
handler.ServeHTTP(rec3, req3)
if rec3.Code != http.StatusOK {
t.Fatalf("third request after refill: expected 200, got %d", rec3.Code)
}

// An immediately following request must still be rate-limited.
req4 := httptest.NewRequest("GET", "/fractional", nil)
rec4 := httptest.NewRecorder()
handler.ServeHTTP(rec4, req4)
if rec4.Code != http.StatusTooManyRequests {
t.Fatalf("fourth request after refill: expected 429, got %d", rec4.Code)
}
}
24 changes: 18 additions & 6 deletions plugins/http/modules.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,36 @@ func loggingMiddlewareFactory(name string, cfg map[string]any) modular.Module {
func rateLimitMiddlewareFactory(name string, cfg map[string]any) modular.Module {
burstSize := 10
if bs, ok := cfg["burstSize"].(int); ok {
burstSize = bs
if bs > 0 {
burstSize = bs
}
} else if bs, ok := cfg["burstSize"].(float64); ok {
burstSize = int(bs)
if intBS := int(bs); intBS > 0 {
burstSize = intBS
}
}

// requestsPerHour takes precedence over requestsPerMinute for low-frequency
// endpoints (e.g. registration) where fractional per-minute rates are needed.
if rph, ok := cfg["requestsPerHour"].(int); ok {
return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize)
if rph > 0 {
return module.NewRateLimitMiddlewareWithHourlyRate(name, rph, burstSize)
}
} else if rph, ok := cfg["requestsPerHour"].(float64); ok {
return module.NewRateLimitMiddlewareWithHourlyRate(name, int(rph), burstSize)
if intRPH := int(rph); intRPH > 0 {
return module.NewRateLimitMiddlewareWithHourlyRate(name, intRPH, burstSize)
}
}

requestsPerMinute := 60
if rpm, ok := cfg["requestsPerMinute"].(int); ok {
requestsPerMinute = rpm
if rpm > 0 {
requestsPerMinute = rpm
}
} else if rpm, ok := cfg["requestsPerMinute"].(float64); ok {
requestsPerMinute = int(rpm)
if intRPM := int(rpm); intRPM > 0 {
requestsPerMinute = intRPM
}
}
return module.NewRateLimitMiddleware(name, requestsPerMinute, burstSize)
}
Expand Down
35 changes: 35 additions & 0 deletions plugins/http/plugin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,41 @@ func TestRateLimitMiddlewareFactory_RequestsPerHour(t *testing.T) {
}
}

func TestRateLimitMiddlewareFactory_InvalidValues(t *testing.T) {
factories := moduleFactories()
factory, ok := factories["http.middleware.ratelimit"]
if !ok {
t.Fatal("no factory for http.middleware.ratelimit")
}

// Zero requestsPerHour must fall through to requestsPerMinute path (not crash).
modZeroRPH := factory("rl-zero-rph", map[string]any{
"requestsPerHour": 0,
"requestsPerMinute": 30,
"burstSize": 5,
})
if modZeroRPH == nil {
t.Fatal("factory returned nil for zero requestsPerHour config")
}

// Negative requestsPerMinute must use default (60).
modNegRPM := factory("rl-neg-rpm", map[string]any{
"requestsPerMinute": -10,
})
if modNegRPM == nil {
t.Fatal("factory returned nil for negative requestsPerMinute config")
}

// Zero burstSize must keep default (10).
modZeroBurst := factory("rl-zero-burst", map[string]any{
"requestsPerMinute": 60,
"burstSize": 0,
})
if modZeroBurst == nil {
t.Fatal("factory returned nil for zero burstSize config")
}
}

func TestPluginLoaderIntegration(t *testing.T) {
p := New()

Expand Down
2 changes: 1 addition & 1 deletion plugins/http/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ func rateLimitMiddlewareSchema() *schema.ModuleSchema {
ConfigFields: []schema.ConfigFieldDef{
{Key: "requestsPerMinute", Label: "Requests Per Minute", Type: schema.FieldTypeNumber, DefaultValue: 60, Description: "Maximum number of requests per minute per client (mutually exclusive with requestsPerHour)"},
{Key: "requestsPerHour", Label: "Requests Per Hour", Type: schema.FieldTypeNumber, DefaultValue: 0, Description: "Maximum number of requests per hour per client; takes precedence over requestsPerMinute when set"},
{Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum burst of requests allowed above the rate limit"},
{Key: "burstSize", Label: "Burst Size", Type: schema.FieldTypeNumber, DefaultValue: 10, Description: "Maximum number of tokens in the bucket; determines how many requests can burst when the bucket is full"},
},
DefaultConfig: map[string]any{"requestsPerMinute": 60, "burstSize": 10},
}
Expand Down