polynomial_decay.py

from overrides import overrides
import torch

from allennlp.common.checks import ConfigurationError
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler


@LearningRateScheduler.register("polynomial_decay")
class PolynomialDecay(LearningRateScheduler):
"""
Implements polynomial decay Learning rate scheduling. The learning rate is
first linearly increased for the first `warmup_steps` training steps. Then
it is decayed for `total_steps` - `warmup_steps` from the initial learning
rate to `end_learning_rate` using a polynomial of degree `power`.
Formally,
`lr` = (`initial_lr` - `end_learning_rate`) *
((`total_steps` - `steps`)/(`total_steps` - `warmup_steps`)) ** `power`
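
    For example (with purely illustrative values), if `initial_lr = 0.001`,
    `end_learning_rate = 0.0`, `warmup_steps = 100`, `total_steps = 1100`, and
    `power = 1.0`, then at step 50 (during warmup) the learning rate is
    (50 / 100) * 0.001 = 0.0005, and at step 600 (during decay) it is
    0.001 * (1100 - 600) / (1100 - 100) = 0.0005.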

    # Parameters

    optimizer : `torch.optim.Optimizer`
        This argument does not get an entry in a configuration file for the
        object.
    num_epochs : `int`
        The number of epochs in the experiment. This does *NOT* get an entry in
        the config.
    num_steps_per_epoch : `int`
        The number of steps per epoch. This does *NOT* get an entry in the
        config.
    warmup_steps : `int`, optional (default = `0`)
        The number of steps over which to linearly increase the learning rate.
    power : `float`, optional (default = `1.0`)
        The power of the polynomial used for decaying.
    end_learning_rate : `float`, optional (default = `0.0`)
        Final learning rate to decay towards.

    # Example

    Config for using the `PolynomialDecay` learning rate scheduler with
    `warmup_steps` set to `100`, `power` set to `2`, and `end_learning_rate`
    set to `1e-10`:

    ```json
    {
        ...
        "trainer": {
            ...
            "learning_rate_scheduler": {
                "type": "polynomial_decay",
                "power": 2,
                "warmup_steps": 100,
                "end_learning_rate": 1e-10
            },
            ...
        }
    }
    ```

    Note that you do NOT pass an `optimizer`, `num_epochs`, or
    `num_steps_per_epoch` key to the learning rate scheduler.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_epochs: int,
        num_steps_per_epoch: int,
        power=1.0,
        warmup_steps=0,
        end_learning_rate=0.0,
        last_epoch: int = -1,
    ):
        super().__init__(optimizer, last_epoch)

        # Sanity check here.
        if num_steps_per_epoch is None:
            raise ConfigurationError(
                "'num_steps_per_epoch' is required for this LR scheduler.\n\n"
                "If you know how many batches per epoch for your training data, you can set this value "
                "directly in your config. Otherwise you'll need to use compatible settings with your data loader "
                "so that it can report an accurate number of batches per epoch. "
                "If you're using the MultiProcessDataLoader, "
                "this means you either need to set 'batches_per_epoch' "
                "or leave 'max_instances_in_memory' as None (if your entire dataset can fit into memory)."
            )

        self.power = power
        self.warmup_steps = warmup_steps
        self.total_steps = num_epochs * num_steps_per_epoch
        self.end_learning_rate = end_learning_rate
        self.steps = 0

        # Apply the schedule once at step 0 so the very first batch already uses
        # the scheduled rate (zero when warming up, otherwise the base rate).
        self.step_batch(0)

    @overrides
    def get_values(self):
        # Warmup phase: scale the base learning rate(s) linearly by steps / warmup_steps.
        if self.warmup_steps > 0 and self.steps < self.warmup_steps:
            f = self.steps / self.warmup_steps
            return [f * lr for lr in self.base_values]

        # Past the end of training: hold the final learning rate.
        if self.steps >= self.total_steps:
            return [self.end_learning_rate for _ in self.base_values]

        # Decay phase: interpolate polynomially between the base learning
        # rate(s) and `end_learning_rate`.
        current_decay_steps = self.total_steps - self.steps
        total_decay_steps = self.total_steps - self.warmup_steps
        f = (current_decay_steps / total_decay_steps) ** self.power
        return [
            f * (lr - self.end_learning_rate) + self.end_learning_rate for lr in self.base_values
        ]

    @overrides
    def step(self, metric: float = None) -> None:
        # The schedule is driven entirely by `step_batch`, so the per-epoch
        # `step` call is a no-op.
        pass

    @overrides
    def step_batch(self, batch_num_total: int = None) -> None:
        # Advance (or jump to) the current step count, then write the new
        # learning rate into each parameter group.
        if batch_num_total is None:
            self.steps += 1
        else:
            self.steps = batch_num_total
        for param_group, lr in zip(self.optimizer.param_groups, self.get_values()):
            param_group[self.param_group_field] = lr
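

# Usage sketch (not part of the original file): a minimal, hedged example of
# driving this scheduler directly from Python instead of through an AllenNLP
# config. The toy model, optimizer, and hyperparameter values below are purely
# illustrative assumptions; in a real AllenNLP experiment the trainer builds
# the scheduler and calls `step_batch` after every batch for you.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 1)  # toy model, only to give the optimizer parameters
    optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

    scheduler = PolynomialDecay(
        optimizer,
        num_epochs=10,
        num_steps_per_epoch=110,  # total_steps = 10 * 110 = 1100
        power=1.0,
        warmup_steps=100,
        end_learning_rate=0.0,
    )

    # The learning rate rises linearly from 0 to 0.001 over the first 100 steps,
    # then decays towards end_learning_rate over the remaining 1000 steps.
    for step in (0, 50, 100, 600, 1100):
        scheduler.step_batch(step)
        print(step, optimizer.param_groups[0]["lr"])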