callback.py
from typing import List, Dict, Any, Optional, TYPE_CHECKING

import torch

from allennlp.common import Registrable
from allennlp.data import TensorDict

if TYPE_CHECKING:
    from allennlp.training.gradient_descent_trainer import GradientDescentTrainer


class TrainerCallback(Registrable):
    """
    A general callback object that handles multiple events.

    This class has `on_backward`, `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
    each callback type. Each one receives the state of the wrapper object as `self`.
    This enables easier state sharing between related callbacks.

    Also, this callback type is instantiated with `serialization_dir`, and `on_start` is called
    with the trainer instance as an argument. This can be useful when a callback needs to log or
    save its own files next to the config/checkpoints/logs/etc.
    """

    def __init__(self, serialization_dir: str) -> None:
        self.serialization_dir = serialization_dir
        self.trainer: Optional["GradientDescentTrainer"] = None

    def on_start(
        self, trainer: "GradientDescentTrainer", is_primary: bool = True, **kwargs
    ) -> None:
        """
        This callback hook is called before training starts.
        """
        self.trainer = trainer

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        """
        This callback hook performs backpropagation and allows for gradient manipulation.
        `backward_called` indicates whether `loss.backward` has been called prior to this callback.
        `on_backward` should return `True` if and only if `loss.backward` is called in its body.
        """
        return False

    def on_batch(
        self,
        trainer: "GradientDescentTrainer",
        batch_inputs: List[TensorDict],
        batch_outputs: List[Dict[str, Any]],
        batch_metrics: Dict[str, Any],
        epoch: int,
        batch_number: int,
        is_training: bool,
        is_primary: bool = True,
        batch_grad_norm: Optional[float] = None,
        **kwargs,
    ) -> None:
        """
        This callback hook is called after the end of each batch.
        """
        pass

    def on_epoch(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Dict[str, Any],
        epoch: int,
        is_primary: bool = True,
        **kwargs,
    ) -> None:
        """
        This callback hook is called after the end of each epoch.
        """
        pass

    def on_end(
        self,
        trainer: "GradientDescentTrainer",
        metrics: Optional[Dict[str, Any]] = None,
        epoch: Optional[int] = None,
        is_primary: bool = True,
        **kwargs,
    ) -> None:
        """
        This callback hook is called after the final training epoch.
        """
        pass

    def state_dict(self) -> Dict[str, Any]:
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass
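

# Register the no-op base class itself under the name "null", so a config can
# explicitly select a callback that does nothing.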
TrainerCallback.register("null")(TrainerCallback)
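

# ---------------------------------------------------------------------------
# A minimal usage sketch (not part of the original file): a hypothetical
# subclass that runs the backward pass itself and clips gradients. The names
# `ClipGradCallback`, `"clip_grad_example"`, and `max_norm` are illustrative
# assumptions. It follows the `on_backward` contract documented above,
# returning `True` only when it calls `loss.backward()` in its body, and it
# assumes the scalar loss is available as `batch_outputs["loss"]`, as AllenNLP
# models produce.
@TrainerCallback.register("clip_grad_example")
class ClipGradCallback(TrainerCallback):
    def __init__(self, serialization_dir: str, max_norm: float = 1.0) -> None:
        super().__init__(serialization_dir)
        self.max_norm = max_norm

    def on_backward(
        self,
        trainer: "GradientDescentTrainer",
        batch_outputs: Dict[str, torch.Tensor],
        backward_called: bool,
        **kwargs,
    ) -> bool:
        if backward_called:
            # Another callback already ran `loss.backward()`; do nothing here.
            return False
        batch_outputs["loss"].backward()
        # Clip gradients in place before the trainer takes its optimizer step.
        torch.nn.utils.clip_grad_norm_(trainer.model.parameters(), self.max_norm)
        return True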