forked from Lightning-AI/pytorch-lightning
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlr_finder.py
executable file
·513 lines (398 loc) · 17.6 KB
/
lr_finder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
"""
Trainer Learning Rate Finder
"""
import os
import importlib
from abc import ABC, abstractmethod
from typing import Optional, Sequence, Tuple, List, Union
import numpy as np
import torch
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec('ipywidgets') is not None:
from tqdm.auto import tqdm
else:
from tqdm import tqdm
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning import _logger as log
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities import rank_zero_warn
class TrainerLRFinderMixin(ABC):
    """Mixin that adds a learning-rate range test (``lr_find``) to the Trainer.

    The concrete Trainer class must provide the attributes and the abstract
    methods declared below; this mixin only orchestrates them.
    """

    # this is just a summary on variables used in this abstract class,
    # the proper values/initialisation should be done in child class
    default_root_dir: str
    progress_bar_callback: ...
    global_step: int
    total_batch_idx: int
    on_gpu: bool

    @abstractmethod
    def save_checkpoint(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def restore(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def init_optimizers(self, *args) -> Tuple[List, List, List]:
        """Warning: this is just empty shell for code implemented in other class."""

    @abstractmethod
    def fit(self, *args):
        """Warning: this is just empty shell for code implemented in other class."""

    def _run_lr_finder_internally(self, model: LightningModule):
        """Call lr finder internally during Trainer.fit() and store the suggested lr on ``model``.

        If ``self.auto_lr_find`` is a string it is used as a (possibly dotted)
        attribute path on ``model``; otherwise ``model.lr`` or, failing that,
        ``model.learning_rate`` is overwritten.  When the finder cannot produce a
        suggestion, the model's learning rate is left unchanged and a warning is issued.

        Raises:
            MisconfigurationException: if the target attribute does not exist on ``model``.
        """
        lr_finder = self.lr_find(model)
        lr = lr_finder.suggestion()
        # TODO: log lr.results to self.logger

        # Resolve (and validate) the attribute to overwrite before touching the model.
        if isinstance(self.auto_lr_find, str):
            # Try to find requested field, may be nested
            attr_path = self.auto_lr_find
            if not _nested_hasattr(model, attr_path):
                raise MisconfigurationException(
                    f'`auto_lr_find` was set to {self.auto_lr_find}, however'
                    ' could not find this as a field in `model`.')
        elif hasattr(model, 'lr'):
            attr_path = 'lr'
        elif hasattr(model, 'learning_rate'):
            attr_path = 'learning_rate'
        else:
            raise MisconfigurationException(
                'When auto_lr_find is set to True, expects that `model`'
                ' either has field `lr` or `learning_rate` that can be overridden')

        # `suggestion()` returns None when it could not compute a value (e.g. too
        # few recorded points); writing None into the model would break the fit.
        if lr is None:
            rank_zero_warn('Learning rate finder could not suggest a learning rate;'
                           ' keeping the learning rate already set on the model.')
            return

        _nested_setattr(model, attr_path, lr)
        log.info(f'Learning rate set to {lr}')

    def lr_find(
            self,
            model: LightningModule,
            train_dataloader: Optional[DataLoader] = None,
            val_dataloaders: Optional[Union[DataLoader, List[DataLoader]]] = None,
            min_lr: float = 1e-8,
            max_lr: float = 1,
            num_training: int = 100,
            mode: str = 'exponential',
            early_stop_threshold: float = 4.0,
    ):
        r"""
        lr_find enables the user to do a range test of good initial learning rates,
        to reduce the amount of guesswork in picking a good starting learning rate.

        Args:
            model: Model to do range testing for

            train_dataloader: A PyTorch
                DataLoader with training samples. If the model has
                a predefined train_dataloader method this will be skipped.

            val_dataloaders: A PyTorch DataLoader (or list of them) with
                validation samples, forwarded to ``fit``.

            min_lr: minimum learning rate to investigate

            max_lr: maximum learning rate to investigate

            num_training: number of learning rates to test

            mode: search strategy, either 'linear' or 'exponential'. If set to
                'linear' the learning rate will be searched by linearly increasing
                after each batch. If set to 'exponential', will increase learning
                rate exponentially.

            early_stop_threshold: threshold for stopping the search. If the
                loss at any point is larger than early_stop_threshold*best_loss
                then the search is stopped. To disable, set to None.

        Returns:
            An ``_LRFinder`` holding the recorded ``lr``/``loss`` results.

        Raises:
            MisconfigurationException: if ``model.configure_optimizers()`` does not
                return exactly one optimizer.

        The trainer's settings, the model's weights and its ``configure_optimizers``
        are restored afterwards, also when the range test raises.

        Example::

            # Setup model and trainer
            model = MyModelClass(hparams)
            trainer = pl.Trainer()

            # Run lr finder
            lr_finder = trainer.lr_find(model, ...)

            # Inspect results
            fig = lr_finder.plot(); fig.show()
            suggested_lr = lr_finder.suggestion()

            # Overwrite lr and create new model
            hparams.lr = suggested_lr
            model = MyModelClass(hparams)

            # Ready to train with new learning rate
            trainer.fit(model)

        """
        save_path = os.path.join(self.default_root_dir, 'lr_find_temp.ckpt')

        self.__lr_finder_dump_params(model)
        # Only restore/remove a checkpoint that this call actually wrote, so a
        # stale file from an earlier crashed run is never loaded into the model.
        checkpoint_saved = False
        try:
            # Prevent going into infinite loop
            self.auto_lr_find = False

            # Initialize lr finder object (stores results)
            lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

            # Use special lr logger callback
            lr_callback = _LRCallback(num_training,
                                      early_stop_threshold,
                                      progress_bar_refresh_rate=1)
            self.callbacks = [lr_callback]

            # No logging
            self.logger = DummyLogger()

            # Max step set to number of iterations
            self.max_steps = num_training

            # Disable standard progress bar for fit
            if self.progress_bar_callback:
                self.progress_bar_callback.disable()

            # Disable standard checkpoint & early stopping
            self.checkpoint_callback = False
            self.early_stop_callback = None

            # Required for saving the model
            self.optimizers, self.schedulers = [], []
            self.model = model

            # Dump model checkpoint so the weights can be reset after the test
            self.save_checkpoint(str(save_path))
            checkpoint_saved = True

            # Configure optimizer and scheduler
            optimizers, _, _ = self.init_optimizers(model)

            if len(optimizers) != 1:
                raise MisconfigurationException(
                    f'`model.configure_optimizers()` returned {len(optimizers)}, but'
                    ' learning rate finder only works with single optimizer')
            model.configure_optimizers = lr_finder._get_new_optimizer(optimizers[0])

            # Fit, lr & loss logged in callback
            self.fit(model,
                     train_dataloader=train_dataloader,
                     val_dataloaders=val_dataloaders)

            # Prompt if we stopped early
            if self.global_step != num_training:
                log.info('LR finder stopped early due to diverging loss.')

            # Transfer results from callback to lr finder object
            lr_finder.results.update({'lr': lr_callback.lrs,
                                      'loss': lr_callback.losses})
            lr_finder._total_batch_idx = self.total_batch_idx  # for debug purpose
        finally:
            # Reset model state, even if the range test failed part-way
            if checkpoint_saved:
                self.restore(str(save_path), on_gpu=self.on_gpu)
                os.remove(save_path)

            # Finish by resetting variables so trainer is ready to fit model
            self.__lr_finder_restore_params(model)
            if self.progress_bar_callback:
                self.progress_bar_callback.enable()

        return lr_finder

    def __lr_finder_dump_params(self, model):
        """Snapshot the trainer/model settings that ``lr_find`` temporarily overrides."""
        self.__dumped_params = {
            'auto_lr_find': self.auto_lr_find,
            'callbacks': self.callbacks,
            'logger': self.logger,
            'max_steps': self.max_steps,
            'checkpoint_callback': self.checkpoint_callback,
            'early_stop_callback': self.early_stop_callback,
            'configure_optimizers': model.configure_optimizers,
        }

    def __lr_finder_restore_params(self, model):
        """Put back the settings saved by ``__lr_finder_dump_params`` and drop the snapshot."""
        self.auto_lr_find = self.__dumped_params['auto_lr_find']
        self.logger = self.__dumped_params['logger']
        self.callbacks = self.__dumped_params['callbacks']
        self.max_steps = self.__dumped_params['max_steps']
        self.checkpoint_callback = self.__dumped_params['checkpoint_callback']
        self.early_stop_callback = self.__dumped_params['early_stop_callback']
        model.configure_optimizers = self.__dumped_params['configure_optimizers']
        del self.__dumped_params
class _LRFinder(object):
""" LR finder object. This object stores the results of Trainer.lr_find().
Args:
mode: either `linear` or `exponential`, how to increase lr after each step
lr_min: lr to start search from
lr_max: lr to stop search
num_training: number of steps to take between lr_min and lr_max
Example::
# Run lr finder
lr_finder = trainer.lr_find(model)
# Results stored in
lr_finder.results
# Plot using
lr_finder.plot()
# Get suggestion
lr = lr_finder.suggestion()
"""
def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
assert mode in ('linear', 'exponential'), \
'mode should be either `linear` or `exponential`'
self.mode = mode
self.lr_min = lr_min
self.lr_max = lr_max
self.num_training = num_training
self.results = {}
self._total_batch_idx = 0 # for debug purpose
def _get_new_optimizer(self, optimizer: torch.optim.Optimizer):
""" Construct a new `configure_optimizers()` method, that has a optimizer
with initial lr set to lr_min and a scheduler that will either
linearly or exponentially increase the lr to lr_max in num_training steps.
Args:
optimizer: instance of `torch.optim.Optimizer`
"""
new_lrs = [self.lr_min] * len(optimizer.param_groups)
for param_group, new_lr in zip(optimizer.param_groups, new_lrs):
param_group["lr"] = new_lr
param_group["initial_lr"] = new_lr
args = (optimizer, self.lr_max, self.num_training)
scheduler = _LinearLR(*args) if self.mode == 'linear' else _ExponentialLR(*args)
def configure_optimizers():
return [optimizer], [{'scheduler': scheduler,
'interval': 'step'}]
return configure_optimizers
def plot(self, suggest: bool = False, show: bool = False):
""" Plot results from lr_find run
Args:
suggest: if True, will mark suggested lr to use with a red point
show: if True, will show figure
"""
import matplotlib.pyplot as plt
lrs = self.results["lr"]
losses = self.results["loss"]
fig, ax = plt.subplots()
# Plot loss as a function of the learning rate
ax.plot(lrs, losses)
if self.mode == 'exponential':
ax.set_xscale("log")
ax.set_xlabel("Learning rate")
ax.set_ylabel("Loss")
if suggest:
_ = self.suggestion()
if self._optimal_idx:
ax.plot(lrs[self._optimal_idx], losses[self._optimal_idx],
markersize=10, marker='o', color='red')
if show:
plt.show()
return fig
def suggestion(self, skip_begin: int = 10, skip_end: int = 1):
""" This will propose a suggestion for choice of initial learning rate
as the point with the steepest negative gradient.
Returns:
lr: suggested initial learning rate to use
skip_begin: how many samples to skip in the beginning. Prevent too naive estimates
skip_end: how many samples to skip in the end. Prevent too optimistic estimates
"""
try:
loss = np.array(self.results["loss"][skip_begin:-skip_end])
loss = loss[np.isfinite(loss)]
min_grad = np.gradient(loss).argmin()
self._optimal_idx = min_grad + skip_begin
return self.results["lr"][self._optimal_idx]
except Exception:
log.exception('Failed to compute suggesting for `lr`. There might not be enough points.')
self._optimal_idx = None
class _LRCallback(Callback):
    """ Callback used only by the learning rate finder.

    Before every optimizer step it records the learning rate about to be used.
    After that step it records a smoothed version of the training loss.

    Args:
        num_training: number of iterations done by the learning rate finder
        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold*best_loss``
            then the search is stopped. To disable, set to ``None``.
        progress_bar_refresh_rate: rate to refresh the progress bar for
            the learning rate finder
        beta: smoothing value, the loss being logged is a running average of
            loss values logged until now. ``beta`` controls the forget rate i.e.
            if ``beta=0`` all past information is ignored.
    """
    def __init__(self, num_training: int,
                 early_stop_threshold: float = 4.0,
                 progress_bar_refresh_rate: int = 0,
                 beta: float = 0.98):
        self.num_training = num_training
        self.early_stop_threshold = early_stop_threshold
        self.beta = beta
        self.losses = []
        self.lrs = []
        self.avg_loss = 0.0
        self.best_loss = 0.0
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None

    @staticmethod
    def _mid_accumulation(trainer) -> bool:
        """True for batches that only accumulate gradients (no optimizer step follows)."""
        return (trainer.batch_idx + 1) % trainer.accumulate_grad_batches != 0

    def on_batch_start(self, trainer, pl_module):
        """ Record the lr the scheduler will apply to this training batch """
        if self._mid_accumulation(trainer):
            return

        # The bar is created lazily on the first stepping batch.
        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc='Finding best initial lr', total=self.num_training)

        scheduler = trainer.lr_schedulers[0]['scheduler']
        self.lrs.append(scheduler.lr[0])

    def on_batch_end(self, trainer, pl_module):
        """ Record the smoothed loss of the batch and signal a stop on divergence """
        if self._mid_accumulation(trainer):
            return

        if self.progress_bar:
            self.progress_bar.update()

        loss_now = trainer.running_loss.last().item()
        step = trainer.global_step + 1  # remove the +1 in 1.0

        # Exponential moving average of the loss, divided by (1 - beta**step)
        # to correct the bias from starting the average at zero.
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * loss_now
        smoothed = self.avg_loss / (1 - self.beta**step)

        diverged = (self.early_stop_threshold is not None
                    and step > 1
                    and smoothed > self.early_stop_threshold * self.best_loss)
        if diverged:
            trainer.max_steps = step  # stop signal
            if self.progress_bar:
                self.progress_bar.close()

        # Track the lowest smoothed loss for the divergence check.
        if step == 1 or smoothed < self.best_loss:
            self.best_loss = smoothed

        self.losses.append(smoothed)
class _LinearLR(_LRScheduler):
"""Linearly increases the learning rate between two boundaries
over a number of iterations.
Arguments:
optimizer: wrapped optimizer.
end_lr: the final learning rate.
num_iter: the number of iterations over which the test occurs.
last_epoch: the index of last epoch. Default: -1.
"""
last_epoch: int
base_lrs: Sequence
def __init__(self,
optimizer: torch.optim.Optimizer,
end_lr: float,
num_iter: int,
last_epoch: int = -1):
self.end_lr = end_lr
self.num_iter = num_iter
super(_LinearLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
if self.last_epoch > 0:
val = [base_lr + r * (self.end_lr - base_lr) for base_lr in self.base_lrs]
else:
val = [base_lr for base_lr in self.base_lrs]
self._lr = val
return val
@property
def lr(self):
return self._lr
class _ExponentialLR(_LRScheduler):
"""Exponentially increases the learning rate between two boundaries
over a number of iterations.
Arguments:
optimizer: wrapped optimizer.
end_lr: the final learning rate.
num_iter: the number of iterations over which the test occurs.
last_epoch: the index of last epoch. Default: -1.
"""
last_epoch: int
base_lrs: Sequence
def __init__(self,
optimizer: torch.optim.Optimizer,
end_lr: float,
num_iter: int,
last_epoch: int = -1):
self.end_lr = end_lr
self.num_iter = num_iter
super(_ExponentialLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
curr_iter = self.last_epoch + 1
r = curr_iter / self.num_iter
if self.last_epoch > 0:
val = [base_lr * (self.end_lr / base_lr) ** r for base_lr in self.base_lrs]
else:
val = [base_lr for base_lr in self.base_lrs]
self._lr = val
return val
@property
def lr(self):
return self._lr
def _nested_hasattr(obj, path):
parts = path.split(".")
for part in parts:
if hasattr(obj, part):
obj = getattr(obj, part)
else:
return False
else:
return True
def _nested_setattr(obj, path, val):
parts = path.split(".")
for part in parts[:-1]:
if hasattr(obj, part):
obj = getattr(obj, part)
setattr(obj, parts[-1], val)