-
Notifications
You must be signed in to change notification settings - Fork 36
/
limiter.go
137 lines (121 loc) · 3.01 KB
/
limiter.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package limiter
import (
"sync"
"time"
"github.com/RicheyJang/PaimengBot/basic/limiter/rate"
log "github.com/sirupsen/logrus"
)
// PluginLimiter 插件级限流器,可以区分用户地管理插件的CD限流
type PluginLimiter struct {
Key string // 插件Key,仅做log所用
cd time.Duration
burst int
limiters sync.Map
cdMux sync.RWMutex
}
// NewPluginLimiter 新建PluginLimiter用于单个插件的限流
func NewPluginLimiter(cd time.Duration, burst int) *PluginLimiter {
res := &PluginLimiter{
cd: cd,
burst: burst,
}
res.ResetCD(cd)
go res.gc()
return res
}
// GetCD 获取当前CD
func (pl *PluginLimiter) GetCD() time.Duration {
pl.cdMux.RLock()
defer pl.cdMux.RUnlock()
return pl.cd
}
// ResetCD 重置PluginLimiter的CD时间长度
func (pl *PluginLimiter) ResetCD(cd time.Duration) {
// 重置CD 会在下次gc之后生效
pl.cdMux.Lock()
pl.cd = cd
pl.cdMux.Unlock()
// 重置所有已有Limiter的CD
pl.limiters.Range(func(key, value interface{}) bool {
l, ok := value.(*subLimiter)
if !ok {
return true
}
l.ttl = cd * 3
l.limiter.SetLimitAt(l.lastGet, rate.Every(cd)) // 重设CD并防止重新获取Token
return true
})
}
// Allow 判断指定用户(key)能否拿到令牌
func (pl *PluginLimiter) Allow(key int64) (bool, time.Duration) {
// 获取subLimiter
l := pl.getSubLimiter(key)
if l == nil {
return false, pl.cd
}
// 检查rate
return l.allow()
}
// ---- 内部方法 ----
// 子Limiter,指定了某个特定用户
type subLimiter struct {
limiter *rate.Limiter
lastGet time.Time //上一次获取token的时间
ttl time.Duration
}
// 回收过期subLimiter
func (pl *PluginLimiter) gc() {
for {
// 等待
pl.cdMux.RLock()
defaultTTL := time.Minute
if pl.cd*3 > defaultTTL {
defaultTTL = pl.cd * 3
}
pl.cdMux.RUnlock()
time.Sleep(defaultTTL)
// 回收
pl.limiters.Range(func(key, value interface{}) bool {
l, ok := value.(*subLimiter)
if !ok {
pl.limiters.Delete(key)
return true
}
if l.lastGet.Add(l.ttl).Before(time.Now()) { // 超时,删除
log.Infof("删除超时的<%v>插件[%v]子限流器", pl.Key, key)
pl.limiters.Delete(key)
return true
}
return true
})
}
}
// 根据key(用户ID)获取subLimiter
func (pl *PluginLimiter) getSubLimiter(key int64) *subLimiter {
value, ok := pl.limiters.Load(key) // 从Map中获取subLimiter
if !ok { // 不存在或已超时
pl.cdMux.RLock()
l := newSubLimiter(pl.cd, pl.burst) // 新建subLimiter
pl.cdMux.RUnlock()
pl.limiters.Store(key, l) // 存储subLimiter
return l
}
l, ok := value.(*subLimiter)
if ok {
return l
}
return nil
}
// 创建新的subLimiter
func newSubLimiter(cd time.Duration, burst int) *subLimiter {
return &subLimiter{
limiter: rate.NewLimiter(rate.Every(cd), burst),
lastGet: time.Now(),
ttl: cd * 3, // 3倍CD作为subLimiter的过期间隔
}
}
// 判断rate
func (l *subLimiter) allow() (bool, time.Duration) {
l.lastGet = time.Now()
return l.limiter.AllowAndLeft()
}