1
+ import pytest
2
+ from functools import wraps
3
+ from typing import Callable
4
+ from collections import namedtuple
5
+ import time
6
+
7
+
8
+ class Node :
9
+ __slots__ = ['pre' , 'after' , 'key' , 'result' , 'ts' ]
10
+
11
+ def __init__ (self , pre , after , key , result , ts ):
12
+ self .pre = pre
13
+ self .after = after
14
+ self .key = key
15
+ self .result = result
16
+ self .ts = ts
17
+
18
+ def __iter__ (self ):
19
+ return [self .pre , self .after , self .key , self .result , self .ts ].__iter__ ()
20
+
21
+
22
+ # 又可以抛出无法被hash的异常, 如果一个参数是无法被hash,需要考虑是否适合缓存
23
+ def _make_key (* args , ** kwargs ):
24
+ key_tuple = tuple (args )
25
+ if kwargs :
26
+ for x in kwargs .items ():
27
+ key_tuple += x
28
+ return hash (key_tuple )
29
+
30
+
31
+ class Cache :
32
+
33
+ def __init__ (self ,
34
+ max_size = 100 ,
35
+ expire_sec = 30 ,
36
+ expire_hook : Callable [[str ], None ] = None ,
37
+ hit_hook : Callable [[str ], None ] = None ,
38
+ not_hit_hook : Callable [[str ], None ] = None ,
39
+ ):
40
+ self .max_size = max_size
41
+ self .expire_sec = expire_sec
42
+
43
+ self .expire_hook = expire_hook
44
+ self .hit_hook = hit_hook
45
+ self .not_hit_hook = not_hit_hook
46
+
47
+ self ._cache = {}
48
+ self ._full = False
49
+ self ._header = Node (None , None , None , None , 0 )
50
+ self ._header .pre = self ._header
51
+ self ._header .after = self ._header
52
+
53
+ def __call__ (self , fn ):
54
+ @wraps (fn )
55
+ def inner (* args , ** kwargs ):
56
+ key = _make_key (* args , ** kwargs )
57
+
58
+ result = self .get (key )
59
+ if not result :
60
+ result = fn (* args , ** kwargs )
61
+ self .set (key , result )
62
+ return result
63
+ return inner
64
+
65
+ def get (self , key ):
66
+ if key in self ._cache :
67
+ pre , after , key , result , ts = self ._cache [key ]
68
+ if time .time () - ts > self .expire_sec :
69
+ if self .expire_hook : self .expire_hook (key )
70
+ return None
71
+
72
+ # 将cache放到最新
73
+ pre .after = after
74
+ after .pre = pre
75
+
76
+ last = self ._header .pre
77
+ last .after = self ._cache [key ]
78
+ self ._cache [key ].pre = last
79
+ self ._header .pre = self ._cache [key ]
80
+ self ._cache [key ].ts = time .time ()
81
+
82
+ if self .hit_hook : self .hit_hook (key )
83
+ return result
84
+
85
+ if self .not_hit_hook : self .not_hit_hook (key )
86
+ return None
87
+
88
+ def set (self , key , result ):
89
+ if self ._full :
90
+ # 将最老的替换掉
91
+ # 直接将新数据放到头节点,并一定头节点指针
92
+ oldest_key = self ._header .key
93
+ self ._header .key = key
94
+ self ._header .result = result
95
+ self ._header .ts = time .time ()
96
+
97
+ del self ._cache [oldest_key ]
98
+ self ._cache [key ] = self ._header
99
+ self ._header = self ._header .after
100
+ else :
101
+ if self ._header .key is None :
102
+ # 从来没有初始化
103
+ self ._header .key = key
104
+ self ._header .result = result
105
+ self ._header .ts = time .time ()
106
+ self ._cache [key ] = self ._header
107
+ else :
108
+ # 加到最新
109
+ last = self ._header .pre
110
+ self ._cache [key ] = Node (last , self ._header , key , result , time .time ())
111
+ last .after = self ._cache [key ]
112
+ self ._header .pre = self ._cache [key ]
113
+ self ._full = len (self ._cache ) >= self .max_size
114
+
115
+
116
+ def test_simple_cache ():
117
+ call_time = []
118
+
119
+ @Cache ()
120
+ def work (* args , ** kwargs ):
121
+ call_time .append (1 )
122
+ return 1
123
+
124
+ work (1 , 2 )
125
+ work (1 , 2 )
126
+ work (1 , a = 1 )
127
+ work (1 , a = 2 )
128
+ work (1 , b = 1 )
129
+ work (1 , a = 1 )
130
+
131
+ assert len (call_time ) == 4
132
+
133
+
134
+ def test_expire_cache ():
135
+ call_time = []
136
+
137
+ @Cache (expire_sec = 2 )
138
+ def work (* args , ** kwargs ):
139
+ call_time .append (1 )
140
+ return 1
141
+
142
+ work (1 , 2 )
143
+ work (1 , a = 1 )
144
+ work (1 , a = 2 )
145
+ work (1 , b = 1 )
146
+ work (1 , a = 1 )
147
+
148
+ time .sleep (3 )
149
+ work (1 , 2 )
150
+ assert len (call_time ) == 5
151
+
152
+
153
+ def test_hook ():
154
+ not_hit = []
155
+ hit = []
156
+ expire = []
157
+
158
+ def not_hit_hook (key ): not_hit .append (1 )
159
+ def hit_hook (key ): hit .append (1 )
160
+ def expire_hook (key ): expire .append (1 )
161
+
162
+ @Cache (expire_sec = 1 , not_hit_hook = not_hit_hook , hit_hook = hit_hook , expire_hook = expire_hook )
163
+ def work (* args , ** kwargs ):
164
+ return 1
165
+
166
+ work (1 , 2 )
167
+ work (1 , a = 1 )
168
+ work (1 , a = 2 )
169
+ work (1 , b = 1 )
170
+ work (1 , a = 1 )
171
+ time .sleep (2 )
172
+ work (1 , 2 )
173
+
174
+ assert len (hit ) == 1
175
+ assert len (expire ) == 1
176
+ assert len (not_hit ) == 4
177
+
178
+
179
+ def test_if_full ():
180
+
181
+ call_time = []
182
+
183
+ @Cache (expire_sec = 2 , max_size = 3 )
184
+ def work (* args , ** kwargs ):
185
+ call_time .append (1 )
186
+ return 1
187
+
188
+ # lru
189
+ work (1 )
190
+ work (2 )
191
+ work (3 )
192
+ work (4 )
193
+ assert len (call_time ) == 4
194
+
195
+ # cache should has [2,3,4]
196
+ work (1 )
197
+ assert len (call_time ) == 5
198
+
199
+ # cache should hash [3,4, 1]
200
+ work (4 )
201
+ assert len (call_time ) == 5
202
+
203
+
204
+ if __name__ == "__main__" :
205
+ pytest .main (["./lru.py::test_if_full" , "-v" , "-s" ])
0 commit comments