Skip to content

Commit dde304d

Browse files
committed
play.lru
1 parent 19b1463 commit dde304d

File tree

1 file changed

+205
-0
lines changed

1 file changed

+205
-0
lines changed

src/play/lru.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
import pytest
2+
from functools import wraps
3+
from typing import Callable
4+
from collections import namedtuple
5+
import time
6+
7+
8+
class Node:
9+
__slots__ = ['pre', 'after', 'key', 'result', 'ts']
10+
11+
def __init__(self, pre, after, key, result, ts):
12+
self.pre = pre
13+
self.after = after
14+
self.key = key
15+
self.result = result
16+
self.ts = ts
17+
18+
def __iter__(self):
19+
return [self.pre, self.after, self.key, self.result, self.ts].__iter__()
20+
21+
22+
# 又可以抛出无法被hash的异常, 如果一个参数是无法被hash,需要考虑是否适合缓存
23+
def _make_key(*args, **kwargs):
24+
key_tuple = tuple(args)
25+
if kwargs:
26+
for x in kwargs.items():
27+
key_tuple += x
28+
return hash(key_tuple)
29+
30+
31+
class Cache:
32+
33+
def __init__(self,
34+
max_size=100,
35+
expire_sec=30,
36+
expire_hook: Callable[[str], None] = None,
37+
hit_hook: Callable[[str], None] = None,
38+
not_hit_hook: Callable[[str], None] = None,
39+
):
40+
self.max_size = max_size
41+
self.expire_sec = expire_sec
42+
43+
self.expire_hook = expire_hook
44+
self.hit_hook = hit_hook
45+
self.not_hit_hook = not_hit_hook
46+
47+
self._cache = {}
48+
self._full = False
49+
self._header = Node(None, None, None, None, 0)
50+
self._header.pre = self._header
51+
self._header.after = self._header
52+
53+
def __call__(self, fn):
54+
@wraps(fn)
55+
def inner(*args, **kwargs):
56+
key = _make_key(*args, **kwargs)
57+
58+
result = self.get(key)
59+
if not result:
60+
result = fn(*args, **kwargs)
61+
self.set(key, result)
62+
return result
63+
return inner
64+
65+
def get(self, key):
66+
if key in self._cache:
67+
pre, after, key, result, ts = self._cache[key]
68+
if time.time() - ts > self.expire_sec:
69+
if self.expire_hook: self.expire_hook(key)
70+
return None
71+
72+
# 将cache放到最新
73+
pre.after = after
74+
after.pre = pre
75+
76+
last = self._header.pre
77+
last.after = self._cache[key]
78+
self._cache[key].pre = last
79+
self._header.pre = self._cache[key]
80+
self._cache[key].ts = time.time()
81+
82+
if self.hit_hook: self.hit_hook(key)
83+
return result
84+
85+
if self.not_hit_hook: self.not_hit_hook(key)
86+
return None
87+
88+
def set(self, key, result):
89+
if self._full:
90+
# 将最老的替换掉
91+
# 直接将新数据放到头节点,并一定头节点指针
92+
oldest_key = self._header.key
93+
self._header.key = key
94+
self._header.result = result
95+
self._header.ts = time.time()
96+
97+
del self._cache[oldest_key]
98+
self._cache[key] = self._header
99+
self._header = self._header.after
100+
else:
101+
if self._header.key is None:
102+
# 从来没有初始化
103+
self._header.key = key
104+
self._header.result = result
105+
self._header.ts = time.time()
106+
self._cache[key] = self._header
107+
else:
108+
# 加到最新
109+
last = self._header.pre
110+
self._cache[key] = Node(last, self._header, key, result, time.time())
111+
last.after = self._cache[key]
112+
self._header.pre = self._cache[key]
113+
self._full = len(self._cache) >= self.max_size
114+
115+
116+
def test_simple_cache():
117+
call_time = []
118+
119+
@Cache()
120+
def work(*args, **kwargs):
121+
call_time.append(1)
122+
return 1
123+
124+
work(1, 2)
125+
work(1, 2)
126+
work(1, a=1)
127+
work(1, a=2)
128+
work(1, b=1)
129+
work(1, a=1)
130+
131+
assert len(call_time) == 4
132+
133+
134+
def test_expire_cache():
135+
call_time = []
136+
137+
@Cache(expire_sec=2)
138+
def work(*args, **kwargs):
139+
call_time.append(1)
140+
return 1
141+
142+
work(1, 2)
143+
work(1, a=1)
144+
work(1, a=2)
145+
work(1, b=1)
146+
work(1, a=1)
147+
148+
time.sleep(3)
149+
work(1, 2)
150+
assert len(call_time) == 5
151+
152+
153+
def test_hook():
154+
not_hit = []
155+
hit = []
156+
expire = []
157+
158+
def not_hit_hook(key): not_hit.append(1)
159+
def hit_hook(key): hit.append(1)
160+
def expire_hook(key): expire.append(1)
161+
162+
@Cache(expire_sec=1, not_hit_hook=not_hit_hook, hit_hook=hit_hook, expire_hook=expire_hook)
163+
def work(*args, **kwargs):
164+
return 1
165+
166+
work(1, 2)
167+
work(1, a=1)
168+
work(1, a=2)
169+
work(1, b=1)
170+
work(1, a=1)
171+
time.sleep(2)
172+
work(1, 2)
173+
174+
assert len(hit) == 1
175+
assert len(expire) == 1
176+
assert len(not_hit) == 4
177+
178+
179+
def test_if_full():
180+
181+
call_time = []
182+
183+
@Cache(expire_sec=2, max_size=3)
184+
def work(*args, **kwargs):
185+
call_time.append(1)
186+
return 1
187+
188+
# lru
189+
work(1)
190+
work(2)
191+
work(3)
192+
work(4)
193+
assert len(call_time) == 4
194+
195+
# cache should has [2,3,4]
196+
work(1)
197+
assert len(call_time) == 5
198+
199+
# cache should hash [3,4, 1]
200+
work(4)
201+
assert len(call_time) == 5
202+
203+
204+
if __name__ == "__main__":
205+
pytest.main(["./lru.py::test_if_full", "-v", "-s"])

0 commit comments

Comments
 (0)