forked from Lightning-AI/pytorch-lightning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathearly_stopping.py
216 lines (172 loc) · 7.77 KB
/
early_stopping.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
r"""
Early Stopping
==============
Monitor a validation metric and stop training when it stops improving.
"""
from copy import deepcopy
import os
import numpy as np
import torch
import torch.distributed as dist
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import rank_zero_warn
# Positive infinity as a tensor; ``EarlyStopping`` uses it (or its negation)
# as the initial ``best_score`` before any metric value has been observed.
# NOTE: ``np.inf`` is used because the ``np.Inf`` alias was removed in NumPy 2.0.
torch_inf = torch.tensor(np.inf)

# ``torch_xla`` is optional: record whether it could be imported so that the
# TPU-specific code paths can be guarded by ``XLA_AVAILABLE``.
try:
    import torch_xla
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True
class EarlyStopping(Callback):
    r"""
    Stop training when a monitored metric has stopped improving.

    Args:
        monitor: quantity to be monitored. Default: ``'val_loss'``.
        min_delta: minimum change in the monitored quantity
            to qualify as an improvement, i.e. an absolute
            change of less than `min_delta`, will count as no
            improvement. Default: ``0.0``.
        patience: number of validation epochs with no improvement
            after which training will be stopped. Default: ``3``.
        verbose: verbosity mode. Default: ``False``.
        mode: one of {auto, min, max}. In `min` mode,
            training will stop when the quantity
            monitored has stopped decreasing; in `max`
            mode it will stop when the quantity
            monitored has stopped increasing; in `auto`
            mode, `max` is used when `monitor` is exactly ``'acc'``
            and `min` is used for every other metric name.
            Any unrecognised value falls back to `auto`. Default: ``'auto'``.
        strict: whether to crash the training if `monitor` is
            not found in the validation metrics. Default: ``True``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import EarlyStopping
        >>> early_stopping = EarlyStopping('val_loss')
        >>> trainer = Trainer(early_stop_callback=early_stopping)
    """

    # Comparison used to decide whether ``current`` improves on ``best_score``.
    mode_dict = {
        'min': torch.lt,
        'max': torch.gt,
    }

    def __init__(self, monitor: str = 'val_loss', min_delta: float = 0.0, patience: int = 3,
                 verbose: bool = False, mode: str = 'auto', strict: bool = True):
        super().__init__()
        self.monitor = monitor
        self.patience = patience
        self.verbose = verbose
        self.strict = strict
        self.min_delta = min_delta
        self.wait_count = 0
        self.stopped_epoch = 0
        self.mode = mode

        if mode not in self.mode_dict:
            # 'auto' is a valid, documented choice (and the default), so only
            # report genuinely unknown modes before falling back to it.
            if mode != 'auto' and self.verbose > 0:
                log.info(f'EarlyStopping mode {mode} is unknown, fallback to auto mode.')
            self.mode = 'auto'

        if self.mode == 'auto':
            # Only the exact metric name 'acc' is treated as "higher is better".
            if self.monitor == 'acc':
                self.mode = 'max'
            else:
                self.mode = 'min'
            if self.verbose > 0:
                log.info(f'EarlyStopping mode set to {self.mode} for monitoring {self.monitor}.')

        # Give min_delta the sign of the improvement direction so that
        # ``current - min_delta`` can be compared directly with ``best_score``.
        self.min_delta *= 1 if self.monitor_op == torch.gt else -1
        # Start from the worst possible score for the chosen direction.
        self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf

    def _validate_condition_metric(self, logs):
        """
        Checks that the condition metric for early stopping is good

        Args:
            logs: callback metrics from validation output

        Return:
            True if specified metric is available

        Raises:
            RuntimeError: if the metric is missing and ``strict`` is set.
        """
        if logs.get(self.monitor) is not None:
            return True

        # Build the message only when it is actually needed.
        available = "`, `".join(logs.keys())
        error_msg = (f'Early stopping conditioned on metric `{self.monitor}`'
                     f' which is not available. Either add `{self.monitor}` to the return of '
                     f' validation_epoch end or modify your EarlyStopping callback to use any of the '
                     f'following: `{available}`')

        if self.strict:
            raise RuntimeError(error_msg)
        if self.verbose > 0:
            rank_zero_warn(error_msg, RuntimeWarning)

        return False

    @property
    def monitor_op(self):
        """Comparison function (``torch.lt`` or ``torch.gt``) for the current mode."""
        return self.mode_dict[self.mode]

    def state_dict(self):
        """Return the counters, best score and patience as a plain dict."""
        return {
            'wait_count': self.wait_count,
            'stopped_epoch': self.stopped_epoch,
            'best_score': self.best_score,
            'patience': self.patience
        }

    def load_state_dict(self, state_dict):
        """Restore the values produced by :meth:`state_dict` from a deep copy of ``state_dict``."""
        state_dict = deepcopy(state_dict)
        self.wait_count = state_dict['wait_count']
        self.stopped_epoch = state_dict['stopped_epoch']
        self.best_score = state_dict['best_score']
        self.patience = state_dict['patience']

    def on_validation_end(self, trainer, pl_module):
        """Run the early stopping check at the end of validation."""
        self._run_early_stopping_check(trainer, pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Switch to the ``val_early_stop_on`` metric when present, then validate the monitor."""
        val_es_key = 'val_early_stop_on'
        if trainer.callback_metrics.get(val_es_key) is not None:
            self.monitor = val_es_key

        # disable strict checking when using structured results
        if val_es_key in trainer.callback_metrics:
            self.strict = False

        self._validate_condition_metric(trainer.callback_metrics)

    def on_train_epoch_end(self, trainer, pl_module):
        """Run the check from the train loop when only ``early_stop_on`` is reported."""
        # disable early stopping in train loop when there's a val loop
        if self.monitor == 'val_early_stop_on':
            return

        # early stopping can also work in the train loop when there is no val loop and when using structured results
        train_es_key = 'early_stop_on'
        if trainer.callback_metrics.get(train_es_key, None) is not None:
            self.monitor = train_es_key
            self._run_early_stopping_check(trainer, pl_module)

    def _run_early_stopping_check(self, trainer, pl_module):
        """
        Compare the monitored metric with the best score seen so far and
        flag ``trainer.should_stop`` once ``patience`` checks pass without
        improvement.
        """
        logs = trainer.callback_metrics

        if not self._validate_condition_metric(logs):
            return  # short circuit if metric not present

        current = logs.get(self.monitor)

        # when in dev debugging
        trainer.dev_debugger.track_early_stopping_history(current)

        if not isinstance(current, torch.Tensor):
            current = torch.tensor(current, device=pl_module.device)

        if trainer.use_tpu and XLA_AVAILABLE:
            current = current.cpu()

        if self.monitor_op(current - self.min_delta, self.best_score):
            self.best_score = current
            self.wait_count = 0
        else:
            self.wait_count += 1
            should_stop = self.wait_count >= self.patience

            if bool(should_stop):
                self.stopped_epoch = trainer.current_epoch
                trainer.should_stop = True

        # Synchronise the decision across processes. This is called on every
        # check (not only when stopping) so all ranks take part in the reduction.
        self._stop_distributed_training(trainer, pl_module)

    def _stop_distributed_training(self, trainer, pl_module):
        """
        Reduce ``trainer.should_stop`` across processes.

        The per-process flags are summed and ``trainer.should_stop`` is set to
        True only when the sum equals ``trainer.world_size``, i.e. when every
        process has flagged that training should stop.
        """
        if trainer.use_ddp or trainer.use_ddp2:
            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device)
            # ``dist.ReduceOp`` replaces the deprecated ``dist.reduce_op`` alias.
            dist.all_reduce(stop, op=dist.ReduceOp.SUM)
            dist.barrier()
            # Store a plain bool rather than a device tensor.
            trainer.should_stop = int(stop.item()) == trainer.world_size

        # Guard on XLA_AVAILABLE: ``xm`` only exists when torch_xla was imported.
        if trainer.use_tpu and XLA_AVAILABLE:
            stop = torch.tensor(int(trainer.should_stop), device=pl_module.device, dtype=torch.int32)
            # Sum the per-core flags; ``torch.cat`` cannot concatenate 0-dim
            # tensors and would not yield the scalar that ``.item()`` needs.
            stop = xm.mesh_reduce("stop_signal", stop, sum)
            xm.rendezvous("pl.EarlyStoppingCallback.stop_distributed_training_check")
            trainer.should_stop = int(stop.item()) == trainer.world_size

    def on_train_end(self, trainer, pl_module):
        """Log the epoch at which early stopping fired (verbose mode only)."""
        # NOTE: nothing is logged when stopped_epoch is 0.
        if self.stopped_epoch > 0 and self.verbose > 0:
            rank_zero_warn('Displayed epoch numbers by `EarlyStopping` start from "1" until v0.6.x,'
                           ' but will start from "0" in v0.8.0.', DeprecationWarning)
            log.info(f'Epoch {self.stopped_epoch + 1:05d}: early stopping triggered.')