-
Notifications
You must be signed in to change notification settings - Fork 434
/
model.py
409 lines (344 loc) · 13.2 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
from __future__ import annotations
import dataclasses as dc
import io
import typing as t
from contextlib import contextmanager
import torch
from torch.utils import data
from tqdm import tqdm
from superduperdb.backends.query_dataset import QueryDataset
from superduperdb.base.datalayer import Datalayer
from superduperdb.components.component import ensure_initialized
from superduperdb.components.datatype import (
DataType,
dill_serializer,
)
from superduperdb.components.model import (
CallableInputs,
Model,
Signature,
_DeviceManaged,
_Fittable,
)
from superduperdb.ext.torch.utils import device_of, eval, to_device
from superduperdb.misc.annotations import merge_docstrings
if t.TYPE_CHECKING:
from superduperdb.jobs.job import Job
def torchmodel(class_obj):
    """A decorator to convert a `torch.nn.Module` into a `TorchModel`.

    Decorate a `torch.nn.Module` so that when it is invoked,
    the result is a `TorchModel`.

    The decorated name becomes a factory: positional arguments after
    ``identifier`` and any keyword arguments not listed below are forwarded
    to ``class_obj`` to build the wrapped module; the listed keyword
    arguments are forwarded to ``TorchModel``.

    :param class_obj: Class to decorate
    """

    def factory(
        identifier: str,
        *args,
        preprocess: t.Optional[t.Callable] = None,
        postprocess: t.Optional[t.Callable] = None,
        collate_fn: t.Optional[t.Callable] = None,
        optimizer_state: t.Optional[t.Any] = None,
        forward_method: str = '__call__',
        train_forward_method: str = '__call__',
        loader_kwargs: t.Optional[t.Dict] = None,
        preprocess_signature: Signature = 'singleton',
        forward_signature: Signature = 'singleton',
        postprocess_signature: Signature = 'singleton',
        **kwargs,
    ):
        """Instantiate ``class_obj`` and wrap the instance in a `TorchModel`.

        :param identifier: Identifier passed to the `TorchModel`
        :param args: Positional arguments for ``class_obj``
        :param preprocess: Preprocess function passed to the `TorchModel`
        :param postprocess: Postprocess function passed to the `TorchModel`
        :param collate_fn: Collate function passed to the `TorchModel`
        :param optimizer_state: Optimizer state passed to the `TorchModel`
        :param forward_method: Name of the forward method
        :param train_forward_method: Name of the train forward method
        :param loader_kwargs: Dataloader kwargs; ``None`` means an empty dict
        :param preprocess_signature: Signature of the preprocess function
        :param forward_signature: Signature of the forward method
        :param postprocess_signature: Signature of the postprocess function
        :param kwargs: Keyword arguments for ``class_obj``
        """
        # ``dc.field(...)`` is only meaningful inside a dataclass body; used as
        # a plain function default it would hand a ``Field`` object to
        # ``TorchModel``. Use ``None`` and build a fresh dict per call instead.
        if loader_kwargs is None:
            loader_kwargs = {}
        return TorchModel(
            identifier=identifier,
            object=class_obj(*args, **kwargs),
            preprocess=preprocess,
            postprocess=postprocess,
            collate_fn=collate_fn,
            optimizer_state=optimizer_state,
            forward_method=forward_method,
            train_forward_method=train_forward_method,
            loader_kwargs=loader_kwargs,
            preprocess_signature=preprocess_signature,
            forward_signature=forward_signature,
            postprocess_signature=postprocess_signature,
        )

    return factory
class BasicDataset(data.Dataset):
    """
    Basic dataset iterating over a list of documents and applying a transformation.

    :param items: items, typically documents
    :param transform: function, typically a preprocess function; may be ``None``,
                      in which case items are returned unchanged
    :param signature: signature of the transform function
    """

    def __init__(self, items, transform, signature):
        super().__init__()
        self.items = items
        self.transform = transform
        self.signature = signature

    def __len__(self):
        """Return the number of items."""
        return len(self.items)

    def __getitem__(self, item):
        """Return the item at index ``item``, transformed if a transform is set.

        :param item: index into ``items``
        """
        record = self.items[item]
        if self.transform is None:
            return record

        # Imported here, as in the original implementation, so the lookup only
        # happens when a transform actually has to be applied.
        from superduperdb.components.model import Model

        call_args, call_kwargs = Model.handle_input_type(record, self.signature)
        return self.transform(*call_args, **call_kwargs)
@merge_docstrings
@dc.dataclass(kw_only=True)
class TorchModel(Model, _Fittable, _DeviceManaged):
    """Torch model. This class is a wrapper around a PyTorch model.

    :param object: Torch model, e.g. `torch.nn.Module`
    :param preprocess: Preprocess function, the function to apply to the input
    :param preprocess_signature: The signature of the preprocess function
    :param postprocess: The postprocess function, the function to apply to the output
    :param postprocess_signature: The signature of the postprocess function
    :param forward_method: The forward method, the method to call on the model
    :param forward_signature: The signature of the forward method
    :param train_forward_method: Train forward method, the method to call on the model
    :param train_forward_signature: The signature of the train forward method
    :param train_preprocess: Train preprocess function,
                             the function to apply to the input
    :param train_preprocess_signature: The signature of the train preprocess function
    :param collate_fn: The collate function for the dataloader
    :param optimizer_state: The optimizer state
    :param loader_kwargs: The kwargs for the dataloader
    :param trainer: `Trainer` object to train the model
    :param preferred_devices: The order of devices to use
    :param device: The device to be used
    """

    # The wrapped module is stored as an artifact using the dill serializer.
    _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = (
        ('object', dill_serializer),
    )

    object: torch.nn.Module
    preprocess: t.Optional[t.Callable] = None
    preprocess_signature: Signature = 'singleton'
    postprocess: t.Optional[t.Callable] = None
    postprocess_signature: Signature = 'singleton'
    forward_method: str = '__call__'
    forward_signature: Signature = 'singleton'
    train_forward_method: str = '__call__'
    train_forward_signature: Signature = 'singleton'
    train_preprocess: t.Optional[t.Callable] = None
    train_preprocess_signature: Signature = 'singleton'
    collate_fn: t.Optional[t.Callable] = None
    optimizer_state: t.Optional[t.Any] = None
    loader_kwargs: t.Dict = dc.field(default_factory=dict)

    def __post_init__(self, db, artifacts):
        """Finish initialization after the dataclass ``__init__``.

        :param db: forwarded to the parent ``__post_init__``
        :param artifacts: forwarded to the parent ``__post_init__``
        """
        super().__post_init__(db, artifacts=artifacts)

        if self.optimizer_state is not None:
            # NOTE(review): `optimizer` is not defined in this class; presumably
            # it is provided by a base class -- confirm.
            self.optimizer.load_state_dict(self.optimizer_state)

        self._validation_set_cache = {}

    @property
    def signature(self):
        """Get the signature of the model.

        This is the preprocess signature when a preprocess function is set,
        otherwise the forward signature.
        """
        if self.preprocess:
            return self.preprocess_signature
        return self.forward_signature

    @signature.setter
    def signature(self, signature):
        """Set the signature of the model.

        Writes to the preprocess signature when a preprocess function is set,
        otherwise to the forward signature.

        :param signature: Signature
        """
        if self.preprocess:
            self.preprocess_signature = signature
        else:
            self.forward_signature = signature

    def schedule_jobs(
        self,
        db: Datalayer,
        dependencies: t.Sequence['Job'] = (),
    ) -> t.Sequence[t.Any]:
        """Schedule jobs for the model.

        Delegates explicitly to ``_Fittable.schedule_jobs``.

        :param db: Datalayer
        :param dependencies: Dependencies
        """
        jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies)
        return jobs

    @property
    def inputs(self) -> CallableInputs:
        """Get the inputs callable for the model.

        Built from the preprocess function when one is set, otherwise from the
        wrapped module's ``forward``.
        """
        return CallableInputs(
            self.object.forward if not self.preprocess else self.preprocess, {}
        )

    def to(self, device):
        """Move the model to a device.

        :param device: Device
        """
        self.object.to(device)

    def save(self, db: Datalayer):
        """Save the model to the database.

        The write happens inside ``saving()``, i.e. with the wrapped module in
        evaluation mode.

        :param db: Datalayer
        """
        with self.saving():
            db.replace(object=self, upsert=True)

    @contextmanager
    def evaluating(self):
        """Context manager for evaluating the model.

        This context manager ensures that the model is in evaluation mode
        """
        # The helper must actually be entered, and it must receive the wrapped
        # module (as in `predict` / `predict_one`), not this wrapper; merely
        # constructing and yielding it would leave the mode untouched.
        with eval(self.object):
            yield

    def train(self):
        """Set the model to training mode."""
        return self.object.train()

    def eval(self):
        """Set the model to evaluation mode."""
        return self.object.eval()

    def parameters(self):
        """Get the model parameters."""
        return self.object.parameters()

    def state_dict(self):
        """Get the model state dict."""
        return self.object.state_dict()

    @contextmanager
    def saving(self):
        """Context manager for saving the model.

        This context manager ensures that the model is in evaluation mode
        and switches it back to training mode afterwards if it was training.
        """
        was_training = self.object.training
        try:
            self.object.eval()
            yield
        finally:
            if was_training:
                self.object.train()

    def __getstate__(self):
        """Return the instance state for pickling.

        For TorchScript objects the serialized bytes produced by
        ``torch.jit.save`` are added under the ``'_object_bytes'`` key.
        """
        state = self.__dict__.copy()
        if isinstance(self.object, (torch.jit.ScriptModule, torch.jit.ScriptFunction)):
            buffer = io.BytesIO()
            torch.jit.save(self.object, buffer)
            # NOTE(review): the live `object` entry is still kept in the state
            # next to the bytes -- confirm the pickler in use can handle it.
            state['_object_bytes'] = buffer.getvalue()
        return state

    def __setstate__(self, state):
        """Restore the instance state produced by ``__getstate__``.

        Every key except ``'_object_bytes'`` is copied into the instance
        ``__dict__``; when ``'_object_bytes'`` is present, ``object`` is
        rebuilt from it with ``torch.jit.load``.

        :param state: mapping returned by ``__getstate__``
        """
        # Work on a shallow copy so the caller's mapping is neither mutated nor
        # resized while it is being consumed.
        remaining = dict(state)
        object_bytes = remaining.pop('_object_bytes', None)
        self.__dict__.update(remaining)
        if object_bytes is not None:
            # Applied last so the deserialized module wins over any `object`
            # entry that was also present in the state.
            self.__dict__['object'] = torch.jit.load(io.BytesIO(object_bytes))

    @ensure_initialized
    def predict_one(self, *args, **kwargs):
        """Predict on a single input.

        :param args: Input arguments
        :param kwargs: Input keyword arguments
        """
        with torch.no_grad(), eval(self.object):
            if self.preprocess is not None:
                out = self.preprocess(*args, **kwargs)
                # NOTE(review): with a preprocess set, `self.signature` resolves
                # to `preprocess_signature`; confirm that is the intended
                # signature for interpreting the preprocess *output*.
                args, kwargs = self.handle_input_type(out, self.signature)

            args, kwargs = to_device((args, kwargs), self.device)
            # Add a leading batch dimension of size one.
            args, kwargs = create_batch((args, kwargs))

            method = getattr(self.object, self.forward_method)
            output = method(*args, **kwargs)
            output = to_device(output, 'cpu')
            # Strip the batch dimension again: keep the single row.
            args = unpack_batch(output)[0]
            if self.postprocess is not None:
                args = self.postprocess(args)
            return args

    @ensure_initialized
    def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
        """Predict on a dataset.

        Items are preprocessed lazily through a ``BasicDataset``, batched with
        a ``DataLoader`` built from ``loader_kwargs`` and ``collate_fn``, run
        through the forward method, moved to the CPU, split into per-row
        outputs and finally postprocessed row by row.

        :param dataset: Dataset
        """
        with torch.no_grad(), eval(self.object):
            inputs = BasicDataset(
                items=dataset,
                transform=self.preprocess,
                signature=self.preprocess_signature,
            )
            loader = torch.utils.data.DataLoader(
                inputs, **self.loader_kwargs, collate_fn=self.collate_fn
            )
            out = []
            for batch in tqdm(loader, total=len(loader)):
                batch = to_device(batch, device_of(self.object))
                args, kwargs = self.handle_input_type(batch, self.signature)
                method = getattr(self.object, self.forward_method)
                tmp = method(*args, **kwargs, **self.predict_kwargs)
                tmp = to_device(tmp, 'cpu')
                tmp = unpack_batch(tmp)
                if self.postprocess:
                    # Expand each row into (args, kwargs) for the postprocess
                    # function according to its signature.
                    tmp = [
                        self.handle_input_type(x, self.postprocess_signature)
                        for x in tmp
                    ]
                    tmp = [self.postprocess(*x[0], **x[1]) for x in tmp]
                out.extend(tmp)
            return out

    def train_forward(self, X, y=None):
        """The forward method for training.

        When the wrapped module has a ``train_forward`` attribute the result
        of the train forward method is returned as is (with ``y`` passed as a
        keyword when given). Otherwise the output is returned in a container
        together with ``y`` when given.

        :param X: Input
        :param y: Target
        """
        X = X.to(self.device)
        if y is not None:
            y = y.to(self.device)

        method = getattr(self.object, self.train_forward_method)
        # NOTE(review): the check is against the literal name 'train_forward',
        # not `self.train_forward_method` -- confirm this is intended.
        if hasattr(self.object, 'train_forward'):
            if y is None:
                return method(X)
            else:
                return method(X, y=y)
        else:
            if y is None:
                return (method(X),)
            else:
                return [method(X), y]
def unpack_batch(args):
    """Unpack a batch into lines of tensor output.

    Tensors are split along their first dimension; lists, tuples and dicts
    are unpacked recursively and re-assembled per row (lists and tuples both
    become lists). An empty list, tuple or dict carries no rows and yields
    an empty list.

    :param args: a batch of model outputs
    :raises NotImplementedError: if ``args`` (or a nested value) is not a
        tensor, list, tuple or dict

    >>> unpack_batch(torch.randn(1, 10))[0].shape
    torch.Size([10])
    >>> out = unpack_batch([torch.randn(2, 10), torch.randn(2, 3, 5)])
    >>> type(out)
    <class 'list'>
    >>> len(out)
    2
    >>> out = unpack_batch({'a': torch.randn(2, 10), 'b': torch.randn(2, 3, 5)})
    >>> [type(x) for x in out]
    [<class 'dict'>, <class 'dict'>]
    >>> out[0]['a'].shape
    torch.Size([10])
    >>> out[0]['b'].shape
    torch.Size([3, 5])
    >>> out = unpack_batch({'a': {'b': torch.randn(2, 10)}})
    >>> out[0]['a']['b'].shape
    torch.Size([10])
    >>> out[1]['a']['b'].shape
    torch.Size([10])
    >>> unpack_batch([])
    []
    >>> unpack_batch({})
    []
    """
    if isinstance(args, torch.Tensor):
        return [args[i] for i in range(args.shape[0])]

    if isinstance(args, (list, tuple)):
        if not args:
            # No members means no batch dimension to read: nothing to unpack.
            return []
        columns = [unpack_batch(x) for x in args]
        # The first member defines the batch size for the whole container.
        batch_size = len(columns[0])
        return [[col[i] for col in columns] for i in range(batch_size)]

    if isinstance(args, dict):
        if not args:
            # Avoid leaking StopIteration from next(iter(...)) on an empty dict.
            return []
        columns = {k: unpack_batch(v) for k, v in args.items()}
        batch_size = len(next(iter(columns.values())))
        return [{k: col[i] for k, col in columns.items()} for i in range(batch_size)]

    raise NotImplementedError(
        f'Cannot unpack a batch of type {type(args).__name__!r}; '
        'only tensors and lists, tuples or dicts of them are supported'
    )
def create_batch(args):
    """Create a singleton batch in a manner similar to the PyTorch dataloader.

    Tensors gain a leading dimension of size one, Python numbers become
    one-element tensors, and dicts, tuples and lists are converted
    recursively (tuples and lists are both returned as tuples).

    :param args: single data point for batching
    :raises TypeError: if ``args`` (or a nested value) has an unsupported type

    >>> create_batch(3.).shape
    torch.Size([1])
    >>> x, y = create_batch([torch.randn(5), torch.randn(3, 7)])
    >>> x.shape
    torch.Size([1, 5])
    >>> y.shape
    torch.Size([1, 3, 7])
    >>> d = create_batch(({'a': torch.randn(4)}))
    >>> d['a'].shape
    torch.Size([1, 4])
    """
    # Leaves first: a tensor or a plain number is batched directly.
    if isinstance(args, torch.Tensor):
        return args.unsqueeze(0)
    if isinstance(args, (float, int)):
        return torch.tensor([args])

    # Containers: batch every member recursively.
    if isinstance(args, dict):
        return {key: create_batch(value) for key, value in args.items()}
    if isinstance(args, (tuple, list)):
        return tuple(create_batch(member) for member in args)

    raise TypeError('Only tensors and tuples of tensors recursively supported...')