Skip to content

Commit 904f2ad

Browse files
committed
fucking rate limiter
1 parent e2f022d commit 904f2ad

File tree

1 file changed

+126
-0
lines changed

1 file changed

+126
-0
lines changed

src/play/ratelimiter.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
BASIC: provide a rate limiter make sure function call time less than *max_call* time in *period* second @@
3+
ADVANCED:
4+
1. what happen if exceed the limit? Just raise exception or sleep ?
5+
2. if function call will cost some second? calculate from the start of function call or the end?
6+
3. if exceeding, start a hook call @@
7+
CODE TASTE:
8+
1. wrapper? @@
9+
2. context?
10+
3. param check @@
11+
PARALLEL SAFE:
12+
1. thread safe?
13+
2. support async await
14+
"""
15+
import time
16+
import threading
17+
18+
from typing import Callable, Optional
19+
from functools import wraps
20+
from collections import deque
21+
from enum import Enum
22+
23+
24+
class EStrategy(Enum):
25+
Raise: int = 0
26+
Sleep: int = 1
27+
28+
29+
class CStrategy(Enum):
30+
Begin: int = 0
31+
End: int = 1
32+
33+
34+
class RateLimiter:
35+
36+
def __init__(self, max_call: int, period: int,
37+
exceed_hook: Callable[[int, int], None] = None,
38+
exceed_strategy: int = EStrategy.Raise,
39+
calc_strategy: int = CStrategy.Begin):
40+
if max_call <= 0:
41+
raise ValueError("max_call must be positive")
42+
43+
if period <= 0:
44+
raise ValueError("period must be positive")
45+
46+
self.max_call = max_call
47+
self.period = period
48+
self.exceed_hook = exceed_hook
49+
self.exceed_strategy = exceed_strategy
50+
self.calc_strategy = calc_strategy
51+
self._calls = deque()
52+
self._lock = threading.Lock()
53+
54+
def _is_exceed(self):
55+
return len(self._calls) >= self.max_call
56+
57+
def _record_latest(self):
58+
self._calls.append(time.time())
59+
60+
def _clean_expired(self):
61+
while len(self._calls) > 0 and time.time() - self._calls[0] > self.period:
62+
self._calls.popleft()
63+
64+
def __enter__(self):
65+
with self._lock:
66+
if self._is_exceed():
67+
if self.exceed_strategy == EStrategy.Sleep:
68+
until = self._calls[0] + self.period
69+
wait = until - time.time()
70+
if self.exceed_hook:
71+
hook_thread = threading.Thread(target=self.exceed_hook, args=(wait, until))
72+
hook_thread.setDaemon(True)
73+
hook_thread.start()
74+
if wait > 0:
75+
time.sleep(wait)
76+
else:
77+
raise Exception("exceed max_call {} limit in {} second", self.max_call, self.period)
78+
79+
if self.calc_strategy == CStrategy.Begin:
80+
self._record_latest()
81+
return self
82+
83+
def __exit__(self, exc_type, exc_val, exc_tb):
84+
with self._lock:
85+
if self.calc_strategy == CStrategy.End:
86+
self._record_latest()
87+
self._clean_expired()
88+
89+
def __call__(self, func: Callable):
90+
@wraps(func)
91+
def inner(*args, **kwargs):
92+
with self:
93+
return func(*args, **kwargs)
94+
return inner
95+
96+
97+
def run_when_exceed(wait: int, until: int):
98+
print("wait {} until {}", wait, until)
99+
100+
101+
def run_when_exceed_timeout(wait: int, until: int):
102+
print("threading id: {}".format(threading.get_ident()))
103+
run_when_exceed(wait, until)
104+
time.sleep(2)
105+
print("threading id: {} I am a long time hook".format(threading.get_ident()))
106+
107+
108+
@RateLimiter(max_call=2, period=1, exceed_strategy=EStrategy.Sleep, exceed_hook=run_when_exceed)
109+
def work(*args, **kwargs):
110+
print(args, kwargs)
111+
112+
113+
def work_timeout(*args, **kwargs):
114+
print(args, kwargs)
115+
time.sleep(1)
116+
117+
118+
if __name__ == "__main__":
119+
120+
for i in range(30):
121+
t = threading.Thread(target=work, args=(1,2), kwargs={"a": 3})
122+
t.setDaemon(True)
123+
t.start()
124+
125+
for i in range(100):
126+
work(1, 2, a=3)

0 commit comments

Comments
 (0)