/
callback.py
85 lines (63 loc) · 2.07 KB
/
callback.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
from ...utils.decorator import Decorator
class callback(Decorator):
    """Decorator marking a function as a trainer callback.

    The decorated function is invoked with only the keyword arguments
    that appear in its own signature; any extra kwargs supplied by the
    trainer are silently dropped (see ``wrapper``).

    Examples:
        >>> @callback
        ... def savemodel(model):
        ...     model.save("path_to_file")
        ...
        ... trainer.train(model, callback = savemodel)
    """
    def __init__(self, *args, **kwargs):
        # Subclasses that define a `wrapped` method use it as the core
        # function; constructor arguments then act as configuration.
        if hasattr(self, 'wrapped'):
            # use `wrapped` func as core func
            super().__init__(self.wrapped)
            # setup configuration
            self.setup(*args, **kwargs)
            return
        # init normal decorator: `args` holds the decorated function
        super().__init__(*args, **kwargs)

    def setup_func(self, func):
        # Record the callback's signature so `wrapper` can filter kwargs.
        import inspect
        self._params = inspect.signature(func).parameters
        return func

    def wrapper(self, **kwargs):
        # Forward only the kwargs the callback actually declares.
        params = {k: v for k, v in kwargs.items() if k in self._params}
        return self.call(**params)
class checkpoint(callback):
    """Callback that writes a model checkpoint every ``every`` epochs.

    Attributes (overridable configuration):
        dirpath (str): directory to save checkpoints into
        every (int): save once every this many epochs
        filename (str): checkpoint file name template; formatted with
            ``name`` (model class name) and ``epoch`` fields
    """
    dirpath = "model_checkpoints"
    every = 1
    filename = "{name}-{epoch}.pt"

    def wrapper(self, **kwargs):
        # Expects the trainer to pass `model` and `epoch` keyword args.
        model = kwargs.get("model")
        epoch = kwargs.get("epoch")
        # Guard clause: on non-saving epochs do nothing at all, so the
        # checkpoint directory is not created unnecessarily (the original
        # created it on every call).
        if epoch % self.every != 0:
            return
        name = type(model).__name__
        from pathlib import Path
        dirpath = Path(self.dirpath)
        dirpath.mkdir(parents = True, exist_ok = True)
        path = dirpath / self.filename.format(name = name, epoch = epoch)
        # Delegate to `callback.wrapper`, which filters kwargs down to
        # the wrapped function's signature before calling it.
        super().wrapper(
            path = path,
            **kwargs
        )
class savemodel(checkpoint):
    """Checkpoint callback that persists the model's ``state_dict``.

    Attributes (inherited configuration from ``checkpoint``):
        dirpath (str): directory to save checkpoints into
        every (int): save once every this many epochs
        filename (str): checkpoint file format, default is `{name}-{epoch}.pt`
    """
    def wrapped(self, model, path):
        # Save only the state dict, not the full model object, so the
        # checkpoint can be loaded into a freshly-constructed model.
        import torch
        torch.save(model.state_dict(), path)