-
Notifications
You must be signed in to change notification settings - Fork 1
Add rate limiting, production server hardening, and CD workflow #20
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| name: Deploy to Fly.io | ||
|
|
||
| on: | ||
| push: | ||
| branches: [main] | ||
|
|
||
| jobs: | ||
| deploy: | ||
| runs-on: ubuntu-latest | ||
| concurrency: deploy-production | ||
| steps: | ||
| - uses: actions/checkout@v4 | ||
| - uses: superfly/flyctl-actions/setup-flyctl@master | ||
| - run: flyctl deploy --remote-only | ||
| env: | ||
| FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }} | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| // Package ratelimit provides per-API-key rate limiting middleware. | ||
| package ratelimit | ||
|
|
||
| import ( | ||
| "net/http" | ||
| "strings" | ||
| "sync" | ||
|
|
||
| "golang.org/x/time/rate" | ||
| ) | ||
|
|
||
| // Limiter holds per-key token-bucket limiters. | ||
| type Limiter struct { | ||
| mu sync.Mutex | ||
| entries map[string]*rate.Limiter | ||
| r rate.Limit | ||
| b int | ||
| } | ||
|
|
||
| // New creates a Limiter allowing r tokens per second with a burst of b. | ||
| func New(r rate.Limit, b int) *Limiter { | ||
| return &Limiter{ | ||
| entries: make(map[string]*rate.Limiter), | ||
| r: r, | ||
| b: b, | ||
| } | ||
| } | ||
|
|
||
| // Middleware returns HTTP middleware that rate-limits by API key. | ||
| // Keys are read from "Authorization: Bearer <key>" or "X-API-Key: <key>", | ||
| // matching the auth middleware extraction logic. Requests with no key | ||
| // are passed through — the auth middleware upstream handles rejection. | ||
| func (l *Limiter) Middleware(next http.Handler) http.Handler { | ||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| key := extractKey(r) | ||
| if key != "" && !l.get(key).Allow() { | ||
| http.Error(w, "rate limit exceeded", http.StatusTooManyRequests) | ||
| return | ||
| } | ||
| next.ServeHTTP(w, r) | ||
| }) | ||
| } | ||
|
|
||
| func (l *Limiter) get(key string) *rate.Limiter { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The actual rate limited never removes a client when added, which will be an infinite memory leak. We should eventually remove clients after some time threshold or use a library like https://github.com/didip/tollbooth which handles the removal. |
||
| l.mu.Lock() | ||
| defer l.mu.Unlock() | ||
| lim, ok := l.entries[key] | ||
| if !ok { | ||
| lim = rate.NewLimiter(l.r, l.b) | ||
| l.entries[key] = lim | ||
| } | ||
| return lim | ||
| } | ||
|
|
||
| func extractKey(r *http.Request) string { | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This function appears to be a duplicate of the one in middleware.go
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should probably use the authentication layer. The auth should then stash the verified key in r.Context() and let ratelimit read it from there. |
||
| if h := r.Header.Get("Authorization"); h != "" { | ||
| if rest, ok := strings.CutPrefix(h, "Bearer "); ok { | ||
| return strings.TrimSpace(rest) | ||
| } | ||
| } | ||
| return strings.TrimSpace(r.Header.Get("X-API-Key")) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,88 @@ | ||
| package ratelimit | ||
|
|
||
| import ( | ||
| "net/http" | ||
| "net/http/httptest" | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/require" | ||
| "golang.org/x/time/rate" | ||
| ) | ||
|
|
||
| func okHandler() http.Handler { | ||
| return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { | ||
| w.WriteHeader(http.StatusOK) | ||
| }) | ||
| } | ||
|
|
||
| func TestRateLimit_AllowsWithinBurst(t *testing.T) { | ||
| l := New(rate.Limit(1), 5) | ||
| h := l.Middleware(okHandler()) | ||
|
|
||
| for i := 0; i < 5; i++ { | ||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req.Header.Set("X-API-Key", "testkey") | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusOK, rec.Code, "request %d should be allowed", i+1) | ||
| } | ||
| } | ||
|
|
||
| func TestRateLimit_BlocksAfterBurst(t *testing.T) { | ||
| l := New(rate.Limit(1), 3) | ||
| h := l.Middleware(okHandler()) | ||
|
|
||
| for i := 0; i < 3; i++ { | ||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req.Header.Set("X-API-Key", "testkey") | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusOK, rec.Code) | ||
| } | ||
|
|
||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req.Header.Set("X-API-Key", "testkey") | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusTooManyRequests, rec.Code) | ||
| } | ||
|
|
||
| func TestRateLimit_IndependentPerKey(t *testing.T) { | ||
| l := New(rate.Limit(1), 1) | ||
| h := l.Middleware(okHandler()) | ||
|
|
||
| for _, key := range []string{"key-a", "key-b", "key-c"} { | ||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req.Header.Set("X-API-Key", key) | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusOK, rec.Code, "first request for %s should pass", key) | ||
| } | ||
| } | ||
|
|
||
| func TestRateLimit_BearerToken(t *testing.T) { | ||
| l := New(rate.Limit(1), 1) | ||
| h := l.Middleware(okHandler()) | ||
|
|
||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req.Header.Set("Authorization", "Bearer mytoken") | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusOK, rec.Code) | ||
|
|
||
| req2 := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| req2.Header.Set("Authorization", "Bearer mytoken") | ||
| rec2 := httptest.NewRecorder() | ||
| h.ServeHTTP(rec2, req2) | ||
| require.Equal(t, http.StatusTooManyRequests, rec2.Code) | ||
| } | ||
|
|
||
| func TestRateLimit_NoKey_PassesThrough(t *testing.T) { | ||
| l := New(rate.Limit(1), 1) | ||
| h := l.Middleware(okHandler()) | ||
|
|
||
| req := httptest.NewRequest(http.MethodGet, "/data", nil) | ||
| rec := httptest.NewRecorder() | ||
| h.ServeHTTP(rec, req) | ||
| require.Equal(t, http.StatusOK, rec.Code) | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,17 +1,23 @@ | ||
| package main | ||
|
|
||
| import ( | ||
| "context" | ||
| "database/sql" | ||
| "fmt" | ||
| "log" | ||
| "net/http" | ||
| "os" | ||
| "os/signal" | ||
| "syscall" | ||
| "time" | ||
|
|
||
| _ "modernc.org/sqlite" | ||
|
|
||
| "github.com/Ribbit-Network/api/internal/auth" | ||
| "github.com/Ribbit-Network/api/internal/data" | ||
| "github.com/Ribbit-Network/api/internal/ratelimit" | ||
| "github.com/joho/godotenv" | ||
| "golang.org/x/time/rate" | ||
| ) | ||
|
|
||
| func main() { | ||
|
|
@@ -31,20 +37,45 @@ func runServer() { | |
| if err != nil { | ||
| log.Fatal(err) | ||
| } | ||
|
|
||
| requireKey := auth.Require(store) | ||
| // 60 requests/minute per key with a burst of 30. | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The comment here and the burst limit in the code below disagree. The comment says burst 30, but the code below is burst 60. |
||
| limiter := ratelimit.New(rate.Every(time.Second), 60) | ||
|
|
||
| http.HandleFunc("/", handle) | ||
| http.Handle("/data", requireKey(http.HandlerFunc(data.Handle))) | ||
| mux := http.NewServeMux() | ||
| mux.HandleFunc("/", handleRoot) | ||
| mux.HandleFunc("/healthz", handleHealthz) | ||
| mux.Handle("/data", requireKey(limiter.Middleware(http.HandlerFunc(data.Handle)))) | ||
|
|
||
| port := os.Getenv("PORT") | ||
| if port == "" { | ||
| port = "8080" | ||
| } | ||
| addr := fmt.Sprintf(":%s", port) | ||
|
|
||
| log.Println("API running at http://localhost" + addr) | ||
| if err := http.ListenAndServe(addr, nil); err != nil { | ||
| log.Fatal(err) | ||
| srv := &http.Server{ | ||
| Addr: fmt.Sprintf(":%s", port), | ||
| Handler: corsMiddleware(mux), | ||
| ReadTimeout: 30 * time.Second, | ||
| WriteTimeout: 30 * time.Second, | ||
| IdleTimeout: 120 * time.Second, | ||
| } | ||
|
|
||
| go func() { | ||
| log.Println("API running at http://localhost" + srv.Addr) | ||
| if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { | ||
| log.Fatal(err) | ||
| } | ||
| }() | ||
|
|
||
| quit := make(chan os.Signal, 1) | ||
| signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM) | ||
| <-quit | ||
|
|
||
| log.Println("shutting down...") | ||
| ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) | ||
| defer cancel() | ||
| if err := srv.Shutdown(ctx); err != nil { | ||
| log.Fatalf("shutdown: %v", err) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -60,6 +91,24 @@ func openKeyStore() (*auth.Store, error) { | |
| return auth.NewStore(db) | ||
| } | ||
|
|
||
| func handle(w http.ResponseWriter, _ *http.Request) { | ||
| func handleRoot(w http.ResponseWriter, _ *http.Request) { | ||
| _, _ = fmt.Fprintln(w, "🐸") | ||
| } | ||
|
|
||
| func handleHealthz(w http.ResponseWriter, _ *http.Request) { | ||
| w.WriteHeader(http.StatusOK) | ||
| _, _ = fmt.Fprintln(w, "ok") | ||
| } | ||
|
|
||
| func corsMiddleware(next http.Handler) http.Handler { | ||
| return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
| w.Header().Set("Access-Control-Allow-Origin", "*") | ||
| w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS") | ||
| w.Header().Set("Access-Control-Allow-Headers", "Authorization, X-API-Key, Content-Type") | ||
| if r.Method == http.MethodOptions { | ||
| w.WriteHeader(http.StatusNoContent) | ||
| return | ||
| } | ||
| next.ServeHTTP(w, r) | ||
| }) | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This deploys no matter what, even if the tests fail etc. Either needs: the CI job or run go test ./... before flyctl deploy