-
Notifications
You must be signed in to change notification settings - Fork 301
/
rate_limiter.go
153 lines (134 loc) · 3.08 KB
/
rate_limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
package rate_limiter
import (
"net/netip"
"sync"
"time"
"golang.org/x/time/rate"
)
const (
	// tableShards is the number of independently locked shards the client
	// table is split into, so that Allow calls for clients that hash to
	// different shards never wait on the same mutex.
	tableShards = 1 << 5

	// gcInterval is both how often the background sweep runs and how long an
	// entry may go unseen before that sweep removes it.
	gcInterval = 60 * time.Second
)
// Limiter is a per-client rate limiter keyed by netip.Addr. Client entries
// are spread over tableShards mutex-guarded shards, and a background
// goroutine started by NewRateLimiter periodically drops idle entries.
// Construct it with NewRateLimiter: the zero value is not usable because its
// shard pointers are nil.
type Limiter struct {
	// Limit and Burst are read-only.
	// They are read without synchronization each time a new client entry is
	// created, so modifying them after construction is a data race; entries
	// that already exist would keep their old settings regardless.
	Limit rate.Limit
	Burst int
	// closeOnce makes Close idempotent; closing closeNotify stops the gc loop.
	closeOnce   sync.Once
	closeNotify chan struct{}
	// tables holds the shards; an address's shard is chosen by getTableShardIdx.
	tables [tableShards]*tableShard
}
// tableShard is one independently locked portion of a Limiter's client table.
type tableShard struct {
	// m guards table; every read or write of table in this file holds it.
	m     sync.Mutex
	table map[netip.Addr]*limiterEntry
}
// limiterEntry is the per-client state stored in a tableShard.
type limiterEntry struct {
	// l is the client's token bucket. It is assigned once when the entry is
	// created and never reassigned in this file, so it may be used after the
	// shard lock has been released.
	l *rate.Limiter
	// lastSeen is the time of the client's most recent Allow call. It is read
	// and written only under the owning shard's mutex; the gc sweep removes
	// entries whose lastSeen is more than gcInterval in the past.
	lastSeen time.Time
	// NOTE(review): this embedded sync.Once is not used anywhere in this file;
	// it only adds per-entry memory and promotes a Do method onto the type.
	// Confirm no other file in the package relies on it before removing it.
	sync.Once
}
// NewRateLimiter creates a new per-client rate limiter. Every distinct client
// address gets its own rate.Limiter built from limit and burst, both of which
// should be greater than zero; see rate.Limiter for the exact semantics.
//
// The returned Limiter owns an internal gc goroutine that wakes up every
// minute and removes entries that have not been seen for longer than that.
// The goroutine runs until Close is called.
//
// If the token refill time (burst/limit) is longer than one minute, the
// effective average qps limit may be higher than expected, because an idle
// client's state may be deleted and later re-initialized.
func NewRateLimiter(limit rate.Limit, burst int) *Limiter {
	var shards [tableShards]*tableShard
	for idx := range shards {
		shards[idx] = &tableShard{table: make(map[netip.Addr]*limiterEntry)}
	}

	lim := &Limiter{
		Limit:       limit,
		Burst:       burst,
		closeNotify: make(chan struct{}),
		tables:      shards,
	}
	go lim.gcLoop(gcInterval)
	return lim
}
// Allow reports whether the client identified by unmappedAddr may perform one
// event now, consuming a token from that client's bucket if so. The first call
// for an address creates its bucket from l.Limit and l.Burst. Every call
// refreshes the entry's lastSeen time, so the background gc will not drop it
// for at least another gcInterval.
//
// unmappedAddr is used verbatim as the table key. Callers should pass an
// address that has already been through netip.Addr.Unmap; otherwise an IPv4
// client and its IPv4-mapped IPv6 form are tracked as two separate buckets.
func (l *Limiter) Allow(unmappedAddr netip.Addr) bool {
	now := time.Now()
	shard := l.getTableShard(unmappedAddr)

	shard.m.Lock()
	e, ok := shard.table[unmappedAddr]
	if !ok {
		// lastSeen is set unconditionally below, so it is left zero here.
		e = &limiterEntry{l: rate.NewLimiter(l.Limit, l.Burst)}
		shard.table[unmappedAddr] = e
	}
	e.lastSeen = now
	shard.m.Unlock()

	// e.l is assigned only when the entry is created, and *rate.Limiter guards
	// its own state, so the token check can run outside the shard lock.
	return e.l.AllowN(now, 1)
}
// Close stops the Limiter's background gc goroutine. It may be called any
// number of times and always returns nil.
//
// Close does not wait for the goroutine to exit. Allow keeps working after
// Close, but entries created afterwards are no longer swept.
func (l *Limiter) Close() error {
	signalStop := func() { close(l.closeNotify) }
	l.closeOnce.Do(signalStop)
	return nil
}
// gcLoop runs doGc on every tick of a ticker with period gcInterval, until
// closeNotify is closed. The same gcInterval is passed to doGc as the idle
// threshold. The parameter shadows the package-level constant of the same
// name.
func (l *Limiter) gcLoop(gcInterval time.Duration) {
	tk := time.NewTicker(gcInterval)
	defer tk.Stop()

	for {
		select {
		case tickTime := <-tk.C:
			l.doGc(tickTime, gcInterval)
		case <-l.closeNotify:
			return
		}
	}
}
// doGc removes, shard by shard, every entry whose lastSeen is more than
// gcInterval before now. Each shard's mutex is held only while that shard is
// being swept. Deleting from a Go map while ranging over it is permitted.
func (l *Limiter) doGc(now time.Time, gcInterval time.Duration) {
	for i := range l.tables {
		s := l.tables[i]
		s.m.Lock()
		for addr, entry := range s.table {
			idle := now.Sub(entry.lastSeen)
			if idle <= gcInterval {
				continue
			}
			delete(s.table, addr)
		}
		s.m.Unlock()
	}
}
// getTableShard returns the shard responsible for unmappedAddr.
func (l *Limiter) getTableShard(unmappedAddr netip.Addr) *tableShard {
	idx := getTableShardIdx(unmappedAddr)
	return l.tables[idx]
}
// ForEach calls doFunc for every client entry currently stored in l, passing
// the client's address and its *rate.Limiter. Iteration stops as soon as
// doFunc returns true, in which case ForEach returns true. Otherwise it
// returns false after every shard has been visited.
//
// Shards are visited one at a time, and doFunc runs while that shard's mutex
// is held. Consequently:
//   - doFunc must not call Allow, Len or ForEach on the same Limiter. The
//     mutex is not reentrant, so such a call can deadlock.
//   - The visit is not an atomic snapshot. Entries in shards that are not
//     currently locked may be added or removed concurrently.
//   - Visiting order follows Go map iteration and is unspecified.
//
// If doFunc panics, the shard lock is still released, so the Limiter remains
// usable after the panic is recovered.
func (l *Limiter) ForEach(doFunc func(unmappedAddr netip.Addr, r *rate.Limiter) (doBreak bool)) (doBreak bool) {
	// visit walks a single shard under its lock. The deferred Unlock runs on
	// both early exit and panic; previously a panicking doFunc left the shard
	// locked forever.
	visit := func(shard *tableShard) bool {
		shard.m.Lock()
		defer shard.m.Unlock()
		for a, e := range shard.table {
			if doFunc(a, e.l) {
				return true
			}
		}
		return false
	}

	for _, shard := range l.tables {
		if visit(shard) {
			return true
		}
	}
	return false
}
// Len returns the number of client entries currently held by the Limiter.
// Each shard is locked and counted in turn rather than all at once. Under
// concurrent Allow or gc activity the result is therefore a best-effort count,
// not an atomic snapshot.
func (l *Limiter) Len() int {
	total := 0
	for i := range l.tables {
		s := l.tables[i]
		s.m.Lock()
		size := len(s.table)
		s.m.Unlock()
		total += size
	}
	return total
}
// getTableShardIdx maps an address to a shard index in [0, tableShards) by
// XOR-folding its bytes. It folds the 4 bytes of an IPv4 address, and all 16
// bytes of As16 for anything else.
//
// For an IPv4-mapped IPv6 address the two 0xff marker bytes cancel each other
// out, so it lands in the same shard as the plain IPv4 address. It is still a
// different map key, however.
func getTableShardIdx(unmappedAddr netip.Addr) int {
	var raw []byte
	switch {
	case unmappedAddr.Is4():
		v4 := unmappedAddr.As4()
		raw = v4[:]
	default:
		v16 := unmappedAddr.As16()
		raw = v16[:]
	}

	var folded byte
	for _, b := range raw {
		folded ^= b
	}
	return int(folded) % tableShards
}