cosine.py
import logging

from overrides import overrides
import numpy as np
import torch

from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler

logger = logging.getLogger(__name__)

@LearningRateScheduler.register("cosine")
class CosineWithRestarts(LearningRateScheduler):
    """
    Cosine annealing with restarts.

    This is described in the paper https://arxiv.org/abs/1608.03983. Note that early
    stopping should typically be avoided when using this schedule.

    Registered as a `LearningRateScheduler` with name "cosine".

    # Parameters

    optimizer : `torch.optim.Optimizer`
        This argument does not get an entry in a configuration file for the object.
    t_initial : `int`
        The number of iterations (epochs) within the first cycle.
    t_mul : `float`, optional (default=`1`)
        Determines the number of iterations (epochs) in the i-th decay cycle,
        which is the length of the last cycle multiplied by `t_mul`.
    eta_min : `float`, optional (default=`0`)
        The minimum learning rate.
    eta_mul : `float`, optional (default=`1`)
        Determines the initial learning rate for the i-th decay cycle, which is the
        last initial learning rate multiplied by `eta_mul`.
    last_epoch : `int`, optional (default=`-1`)
        The index of the last epoch. This is used when restarting.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        t_initial: int,
        t_mul: float = 1.0,
        eta_min: float = 0.0,
        eta_mul: float = 1.0,
        last_epoch: int = -1,
    ) -> None:
        assert t_initial > 0
        assert eta_min >= 0
        if t_initial == 1 and t_mul == 1 and eta_mul == 1:
            logger.warning(
                "Cosine annealing scheduler will have no effect on the learning "
                "rate since t_initial = t_mul = eta_mul = 1."
            )
        self.t_initial = t_initial
        self.t_mul = t_mul
        self.eta_min = eta_min
        self.eta_mul = eta_mul
        self._last_restart: int = 0
        self._cycle_counter: int = 0
        self._cycle_len: int = t_initial
        self._n_restarts: int = 0
        super().__init__(optimizer, last_epoch)

    @overrides
    def get_values(self):
        """Get updated learning rate."""
        if self.last_epoch == -1:
            return self.base_values

        step = self.last_epoch + 1
        self._cycle_counter = step - self._last_restart

        # A full cycle has elapsed: record the restart and reset the in-cycle counter.
        if self._cycle_counter % self._cycle_len == 0:
            self._n_restarts += 1
            self._cycle_counter = 0
            self._last_restart = step

        # Each restart scales the initial learning rates by eta_mul and the
        # cycle length by t_mul.
        base_lrs = [lr * self.eta_mul ** self._n_restarts for lr in self.base_values]
        self._cycle_len = int(self.t_initial * self.t_mul ** self._n_restarts)

        # Cosine-anneal each base LR from its (decayed) starting value down to
        # eta_min over the current cycle:
        #   lr_t = eta_min + (lr - eta_min) / 2 * (1 + cos(pi * t_cur / cycle_len))
        lrs = [
            self.eta_min
            + ((lr - self.eta_min) / 2)
            * (np.cos(np.pi * (self._cycle_counter % self._cycle_len) / self._cycle_len) + 1)
            for lr in base_lrs
        ]

        return lrs
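

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): drives the scheduler by hand
# to show the cosine-with-restarts shape. It assumes the base class exposes
# AllenNLP's usual `step()` method, which advances `last_epoch` and writes the
# values from `get_values()` back into the optimizer's param groups. The model,
# optimizer, and hyperparameter values below are illustrative only.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # First cycle lasts 5 epochs; each restart doubles the cycle length
    # (t_mul=2.0) and halves the peak learning rate (eta_mul=0.5).
    scheduler = CosineWithRestarts(
        optimizer, t_initial=5, t_mul=2.0, eta_min=0.001, eta_mul=0.5
    )

    for epoch in range(20):
        # ... one epoch of training would go here ...
        scheduler.step()
        print(epoch, [group["lr"] for group in optimizer.param_groups])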