/
memoization.py
276 lines (198 loc) · 9.36 KB
/
memoization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
from __future__ import annotations
import hashlib
from functools import lru_cache, singledispatch
import logging
import pickle
from parsl.dataflow.taskrecord import TaskRecord
from typing import Dict, Any, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from parsl import DataFlowKernel # import loop at runtime - needed for typechecking - TODO turn into "if typing:"
from concurrent.futures import Future
import types
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@singledispatch
def id_for_memo(obj: object, output_ref: bool = False) -> bytes:
    """Return a byte sequence identifying ``obj`` for memoization purposes.

    For any two calls of id_for_memo, the returned bytes should be the same
    when the "same" value is supplied, and different otherwise.

    "Same" is in quotes above because sameness is not as straightforward as
    serialising out the content. For example, for two dicts x, y:

        x = {"a":3, "b":4}
        y = {"b":4, "a":3}

    then x == y, but their serializations are not equal, and some other
    functions on x and y are not equal: list(x.keys()) != list(y.keys())

    id_for_memo is invoked with output_ref=True when the parameter is an
    output reference (a value in the outputs=[] parameter of an app
    invocation).

    Memo hashing might be different for such parameters: for example, a
    user might choose to hash input File content so that changing the
    content of an input file invalidates memoization. This does not make
    sense to do for output files: there is no meaningful content stored
    where an output filename points at memoization time.

    This base implementation is reached only for types with no registered
    implementation: it logs an error and raises ValueError.
    """
    unsupported = type(obj)
    logger.error(f"id_for_memo attempted on unknown type {unsupported}")
    raise ValueError(f"unknown type for memoization: {unsupported}")
@id_for_memo.register(str)
@id_for_memo.register(int)
@id_for_memo.register(float)
@id_for_memo.register(type(None))
def id_for_memo_pickle(obj: object, output_ref: bool = False) -> bytes:
    """Identify a str, int, float or None value by its pickled form.

    output_ref is accepted for dispatch compatibility but does not affect
    the result for these types.
    """
    serialized = pickle.dumps(obj)
    return serialized
@id_for_memo.register(list)
def id_for_memo_list(denormalized_list: list, output_ref: bool = False) -> bytes:
    """Identify a list by the memo ids of its elements, in order.

    Each element is normalised with id_for_memo (passing output_ref
    through) and the resulting list of byte strings is pickled.

    Raises:
        ValueError: if given a subclass of list rather than a plain list
            (singledispatch routes subclasses here too).
    """
    if type(denormalized_list) is not list:
        raise ValueError("id_for_memo_list cannot work on subclasses of list")

    element_ids = [id_for_memo(item, output_ref=output_ref) for item in denormalized_list]
    return pickle.dumps(element_ids)
@id_for_memo.register(tuple)
def id_for_memo_tuple(denormalized_tuple: tuple, output_ref: bool = False) -> bytes:
    """Identify a tuple by the memo ids of its elements, in order.

    Each element is normalised with id_for_memo (passing output_ref
    through) and the resulting *list* of byte strings is pickled.

    NOTE(review): because a list (not a tuple) is pickled, a tuple gets the
    same id as a list with equal elements under id_for_memo_list. Changing
    this would alter ids and invalidate existing checkpoints.

    Raises:
        ValueError: if given a subclass of tuple rather than a plain tuple
            (singledispatch routes subclasses here too).
    """
    if type(denormalized_tuple) is not tuple:
        raise ValueError("id_for_memo_tuple cannot work on subclasses of tuple")

    element_ids = [id_for_memo(item, output_ref=output_ref) for item in denormalized_tuple]
    return pickle.dumps(element_ids)
@id_for_memo.register(dict)
def id_for_memo_dict(denormalized_dict: dict, output_ref: bool = False) -> bytes:
    """This normalises the keys and values of the supplied dictionary.

    Entries are visited in sorted key order, so dicts that are equal but
    were built in a different insertion order get the same id. For each
    entry the key and then its value are normalised with id_for_memo, and
    the resulting flat list of byte strings is pickled.

    When output_ref=True, the values are normalised as output refs, but
    the keys are not.

    Raises:
        ValueError: if given a subclass of dict rather than a plain dict, or
            (from id_for_memo) if a key or value has an unsupported type.
    """
    if type(denormalized_dict) is not dict:
        raise ValueError("id_for_memo_dict cannot work on subclasses of dict")

    try:
        keys = sorted(denormalized_dict)
    except TypeError:
        # Keys of mutually unorderable types (e.g. {1: ..., "a": ...}) make
        # sorted() raise, which previously made such dicts impossible to
        # memoize. Order them by their memo ids instead: those are bytes and
        # always comparable. Dicts whose keys sort natively never reach this
        # branch, so their ids (and existing checkpoints) are unchanged.
        keys = sorted(denormalized_dict, key=id_for_memo)

    normalized_list = []
    for k in keys:
        normalized_list.append(id_for_memo(k))
        normalized_list.append(id_for_memo(denormalized_dict[k], output_ref=output_ref))

    return pickle.dumps(normalized_list)
# The LRU cache decorator must be applied closer to the id_for_memo_function
# definition than the .register() call, so that the cache-decorated version
# is the one that gets registered.
@id_for_memo.register(types.FunctionType)
@lru_cache()
def id_for_memo_function(f: types.FunctionType, output_ref: bool = False) -> bytes:
    """Identify a function by its name and module name only.

    Because source code is not part of the id, changing a function's body
    (without renaming it) will not cause a checkpoint invalidation.

    NOTE(review): any two functions sharing both __name__ and __module__
    (e.g. lambdas defined in the same module) receive the same id.
    """
    identity = ["types.FunctionType", f.__name__, f.__module__]
    return pickle.dumps(identity)
class Memoizer:
    """Memoizer is responsible for ensuring that identical work is not repeated.

    When a task is repeated, i.e., the same function is called with the same exact arguments, the
    result from a previous execution is reused. `wiki <https://en.wikipedia.org/wiki/Memoization>`_

    The memoizer implementation here does not collapse duplicate calls
    at call time, but works **only** when the result of a previous
    call is available at the time the duplicate call is made.

    For instance::

        No advantage from              Memoization helps
        memoization here:              here:

          TaskA                          TaskB
            |    TaskA                     |
            |      |                     done
            |      |                              TaskB  (result reused)
          done     |                              done
                 done

    The memoizer creates a lookup table by hashing the function name
    and its inputs, and storing the results of the function.

    When a task is ready for launch, i.e., all of its arguments
    have resolved, we add its hash to the task datastructure.
    """

    def __init__(self, dfk: DataFlowKernel, memoize: bool = True,
                 checkpoint: Optional[Dict[str, Future[Any]]] = None):
        """Initialize the memoizer.

        Args:
            - dfk (DFK obj): The DFK object

        KWargs:
            - memoize (Bool): enable memoization or not.
            - checkpoint (Dict): A checkpoint loaded as a dict, mapping
              hashsum to result future. When memoization is enabled, this
              dict is used directly as the lookup table (update_memo adds
              entries to it). Defaults to a fresh, empty table.
        """
        self.dfk = dfk
        self.memoize = memoize

        if self.memoize:
            logger.info("App caching initialized")
            # Previously the default was a single `{}` shared by every
            # Memoizer built without a checkpoint; since update_memo mutates
            # the table, results leaked between instances. Allocate per call.
            self.memo_lookup_table = checkpoint if checkpoint is not None else {}
        else:
            logger.info("App caching disabled for all apps")
            self.memo_lookup_table = {}

    def make_hash(self, task: TaskRecord) -> str:
        """Create a hash of the task inputs.

        The hash covers the task's keyword args, function and positional
        args. Keyword args named in task['ignore_for_cache'] are left out,
        and an ``outputs`` kwarg (unless ignored) is normalised separately
        with output_ref=True.

        Args:
            - task (dict) : Task dictionary from dfk.tasks

        Returns:
            - hash (str) : md5 hex digest of the normalised inputs
        """
        t: List[bytes] = []

        # Shallow copy so removing entries does not touch the task's kwargs.
        filtered_kw = task['kwargs'].copy()

        ignore_list = task['ignore_for_cache']

        logger.debug("Ignoring these kwargs for checkpointing: %s", ignore_list)
        for k in ignore_list:
            logger.debug("Ignoring kwarg %s", k)
            if k in filtered_kw:
                del filtered_kw[k]
            else:
                # Previously a bare KeyError; most likely a misspelt name.
                logger.warning("Kwarg %s listed in ignore_for_cache was not passed to this task; nothing to ignore", k)

        # An outputs kwarg is normalised as output references rather than
        # ordinary values. Test filtered_kw (not task['kwargs']) so that an
        # 'outputs' entry listed in ignore_for_cache is honoured instead of
        # raising KeyError on a second deletion.
        if 'outputs' in filtered_kw:
            outputs = filtered_kw.pop('outputs')
            t.append(id_for_memo(outputs, output_ref=True))

        t.extend(map(id_for_memo, (filtered_kw, task['func'], task['args'])))

        x = b''.join(t)
        return hashlib.md5(x).hexdigest()

    def check_memo(self, task: TaskRecord) -> Optional[Future[Any]]:
        """Create a hash of the task and its inputs and check the lookup table for this hash.

        If present, the results are returned.

        Args:
            - task(task) : task from the dfk.tasks table

        Returns:
            - Result of the function if present in table, wrapped in a Future,
              otherwise None.

        This call will also set task['hashsum'] to the unique hashsum for the
        func+inputs, or to None when memoization is disabled for the task.
        """
        task_id = task['id']

        if not self.memoize or not task['memoize']:
            task['hashsum'] = None
            logger.debug("Task {} will not be memoized".format(task_id))
            return None

        hashsum = self.make_hash(task)
        logger.debug("Task {} has memoization hash {}".format(task_id, hashsum))
        result = None
        if hashsum in self.memo_lookup_table:
            result = self.memo_lookup_table[hashsum]
            logger.info("Task %s using result from cache", task_id)
        else:
            logger.info("Task %s had no result in cache", task_id)

        task['hashsum'] = hashsum

        assert isinstance(result, Future) or result is None
        return result

    def hash_lookup(self, hashsum: str) -> Future[Any]:
        """Lookup a hash in the memoization table.

        Args:
            - hashsum (str): The same hashes used to uniquely identify apps+inputs

        Returns:
            - Lookup result

        Raises:
            - KeyError: if hash not in table
        """
        return self.memo_lookup_table[hashsum]

    def update_memo(self, task: TaskRecord, r: Future[Any]) -> None:
        """Updates the memoization lookup table with the result from a task.

        Nothing is stored when memoization is disabled globally or for this
        task, or when the task has no string 'hashsum' (check_memo sets it).

        Args:
            - task (dict) : A task dict from dfk.tasks
            - r (Result future): Result future
        """
        # TODO: could use typeguard
        assert isinstance(r, Future)

        task_id = task['id']

        if not self.memoize or not task['memoize'] or 'hashsum' not in task:
            return

        if not isinstance(task['hashsum'], str):
            logger.error("Attempting to update app cache entry but hashsum is not a string key")
            return

        if task['hashsum'] in self.memo_lookup_table:
            logger.info(f"Replacing app cache entry {task['hashsum']} with result from task {task_id}")
        else:
            logger.debug(f"Storing app cache entry {task['hashsum']} with result from task {task_id}")
        self.memo_lookup_table[task['hashsum']] = r