-
Notifications
You must be signed in to change notification settings - Fork 2
/
ipreqlimit.go
132 lines (120 loc) · 3.78 KB
/
ipreqlimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
// Copyright 2023, DASH-Industry Forum. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE.md file.
package app
import (
	"encoding/json"
	"fmt"
	"log/slog"
	"net"
	"net/http"
	"os"
	"strings"
	"sync"
	"time"
)
// IPRequestLimiter counts requests per client IP address and reports when an
// address exceeds MaxNrRequests within the current interval.
//
// The exported fields are the limiter's serializable state (they are what is
// written, as JSON, to the log file); the unexported fields are never
// serialized by encoding/json and therefore need no json tags.
type IPRequestLimiter struct {
	// MaxNrRequests is the number of requests per IP allowed in one interval.
	MaxNrRequests int `json:"maxNrRequests"`
	// Interval is the length of a counting interval. Being a time.Duration,
	// it is encoded in JSON as an integer number of nanoseconds.
	Interval time.Duration `json:"interval"`
	// ResetTime is the start of the current counting interval.
	ResetTime time.Time `json:"resetTime"`
	// Counters maps a client IP to the number of requests seen from it in the
	// current interval.
	Counters map[string]int `json:"counters"`
	// logFile, when non-empty, is the file the state is appended to each time
	// the counters are reset.
	logFile string
	// mux protects Counters and ResetTime, which are updated on every request.
	mux sync.Mutex
}
// NewIPRequestLimiter creates a limiter allowing up to maxNrRequests requests
// per client IP in each interval, with the first interval beginning at start.
// A non-empty logFile makes the limiter append its state to that file whenever
// an expired interval is reset.
func NewIPRequestLimiter(maxNrRequests int, interval time.Duration, start time.Time, logFile string) *IPRequestLimiter {
	il := &IPRequestLimiter{
		Counters: map[string]int{},
		logFile:  logFile,
		mux:      sync.Mutex{}, // starts unlocked
	}
	il.MaxNrRequests = maxNrRequests
	il.Interval = interval
	il.ResetTime = start
	return il
}
// NewLimiterMiddleware returns a middleware that limits the number of requests
// per client IP address per interval, using reqLimiter for the bookkeeping.
//
// A request whose client IP cannot be determined gets 400 Bad Request, and a
// request over the limit gets 429 Too Many Requests; in both cases next is not
// called. If hdrName is non-empty, every counted request (allowed or rejected)
// gets a response header with that name reporting the current count and the
// maximum, e.g. "3 (max 10)".
func NewLimiterMiddleware(hdrName string, reqLimiter *IPRequestLimiter) func(next http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ip, err := ipFromRequest(r)
			if err != nil {
				// The status must be sent before the body: the first Write
				// implicitly sends 200 OK, after which WriteHeader is ignored.
				w.WriteHeader(http.StatusBadRequest)
				// The status is already committed; a failed body write cannot
				// be reported to the client, so the error is deliberately dropped.
				_, _ = w.Write([]byte("could not read client IP"))
				return
			}
			count, ok := reqLimiter.Inc(time.Now(), ip)
			// Headers must be set before WriteHeader/next writes the response.
			if hdrName != "" {
				w.Header().Set(hdrName, fmt.Sprintf("%d (max %d)", count, reqLimiter.MaxNrRequests))
			}
			if !ok {
				w.WriteHeader(http.StatusTooManyRequests)
				return
			}
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(fn)
	}
}
// Inc records one request from ip at time now. It returns the updated count
// for ip and whether that count is still within MaxNrRequests.
//
// When more than Interval has passed since ResetTime, the current state is
// first dumped to the log file (if one is configured), every counter is
// cleared and a new interval starts at now.
func (il *IPRequestLimiter) Inc(now time.Time, ip string) (int, bool) {
	il.mux.Lock()
	defer il.mux.Unlock()

	windowOver := now.Sub(il.ResetTime) > il.Interval
	if windowOver {
		if il.logFile != "" {
			il.dump()
		}
		il.ResetTime = now
		il.Counters = make(map[string]int)
	}

	n := il.Counters[ip] + 1
	il.Counters[ip] = n
	withinLimit := n <= il.MaxNrRequests
	return n, withinLimit
}
// Count reports how many requests from ip have been counted in the current
// interval. An address that has made no requests yields 0.
func (il *IPRequestLimiter) Count(ip string) int {
	il.mux.Lock()
	n := il.Counters[ip]
	il.mux.Unlock()
	return n
}
// EndTime returns the end of the current counting interval, i.e. ResetTime
// plus Interval. The counters are reset by the first Inc call made after this
// instant.
//
// The lock is taken because Inc rewrites ResetTime under the same mutex;
// reading it unlocked while requests are being counted is a data race.
func (il *IPRequestLimiter) EndTime() time.Time {
	il.mux.Lock()
	defer il.mux.Unlock()
	return il.ResetTime.Add(il.Interval)
}
// dump appends the limiter's state, JSON-encoded and newline-terminated, to
// il.logFile, creating the file if needed. Each dump is therefore one line
// (JSON Lines), so successive dumps can be read back individually.
//
// Inc calls dump while holding il.mux, so dump must not lock the mutex itself.
// Errors are logged and otherwise ignored so that a logging problem never
// interferes with rate limiting.
func (il *IPRequestLimiter) dump() {
	payload, err := json.Marshal(il)
	if err != nil {
		slog.Error("could not marshal IPRequestLimiter", "error", err.Error())
		return
	}
	f, err := os.OpenFile(il.logFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0600)
	if err != nil {
		slog.Error("could not open IPRequestLimiter log file", "error", err.Error())
		return // nothing to write to; previously fell through with a nil file
	}
	payload = append(payload, '\n')
	if _, err = f.Write(payload); err != nil {
		slog.Error("could not write to IPRequestLimiter log file", "error", err.Error())
	}
	// Close can report a deferred write failure, so its error is checked too.
	if err = f.Close(); err != nil {
		slog.Error("could not close IPRequestLimiter log file", "error", err.Error())
	}
}
// ipFromRequest returns the client IP address used as the rate-limiting key,
// in the canonical text form produced by net.IP.String.
//
// If the request has an X-Forwarded-For header, its left-most entry (the
// originating client by that header's convention) is used, provided it parses
// as an IP address. The header may hold a comma-separated list when the request
// passed through several proxies. If the header is absent or its first entry
// is not an IP, the host part of req.RemoteAddr is used instead. An error is
// returned when RemoteAddr cannot be split or does not contain an IP.
//
// NOTE(review): X-Forwarded-For is client-controlled unless a trusted proxy in
// front of this server overwrites it, so a client may be able to choose its own
// key and evade the limit. Confirm this matches the deployment.
func ipFromRequest(req *http.Request) (string, error) {
	if forwarded := req.Header.Get("X-Forwarded-For"); forwarded != "" {
		first, _, _ := strings.Cut(forwarded, ",")
		if ip := net.ParseIP(strings.TrimSpace(first)); ip != nil {
			return ip.String(), nil
		}
		// Malformed header: fall back to the peer address rather than using
		// arbitrary client-supplied text as a counter key.
	}
	host, _, err := net.SplitHostPort(req.RemoteAddr)
	if err != nil {
		return "", err
	}
	userIP := net.ParseIP(host)
	if userIP == nil {
		return "", fmt.Errorf("no IP found")
	}
	return userIP.String(), nil
}