This repository has been archived by the owner on Dec 16, 2022. It is now read-only.
/
linear_with_warmup.py
71 lines (61 loc) · 2.09 KB
/
linear_with_warmup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import torch
from allennlp.training.learning_rate_schedulers import PolynomialDecay
from allennlp.training.learning_rate_schedulers.learning_rate_scheduler import LearningRateScheduler
@LearningRateScheduler.register("linear_with_warmup")
class LinearWithWarmup(PolynomialDecay):
    """
    A learning rate scheduler that linearly ramps the learning rate up to
    `lr` over the first `warmup_steps` steps, then linearly decays it to
    zero over the remaining training steps.

    Under the hood this is simply [`PolynomialDecay`](
    https://docs.allennlp.org/main/api/training/
    learning_rate_schedulers/polynomial_decay/)
    pinned at `power=1` and `end_learning_rate=0`.

    # Parameters

    optimizer : `torch.optim.Optimizer`
        The wrapped optimizer. This argument does not get an entry in a
        configuration file for the object.
    num_epochs : `int`
        Total number of epochs in the experiment. This does *NOT* get an
        entry in the config.
    num_steps_per_epoch : `int`
        Number of optimizer steps per epoch. This does *NOT* get an entry
        in the config.
    warmup_steps : `int`, required
        How many steps to spend linearly increasing the learning rate.

    # Example

    Config for using the `LinearWithWarmup` learning rate scheduler with
    `warmup_steps` set to `100`.

    ```json
    {
        ...
        "trainer":{
            ...
            "learning_rate_scheduler": {
                "type": "linear_with_warmup",
                "warmup_steps":100
            },
            ...
        }
    }
    ```

    Note that you do NOT pass an `optimizer`, `num_epochs`, nor
    `num_steps_per_epoch` key to the learning rate scheduler — the
    framework supplies those.
    """

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        num_epochs: int,
        num_steps_per_epoch: int,
        warmup_steps: int = 100,
        last_epoch: int = -1,
    ) -> None:
        # "Linear with warmup" is polynomial decay with the polynomial
        # fixed to degree one and a final learning rate of zero.
        pinned = {"power": 1.0, "end_learning_rate": 0.0}
        super().__init__(
            optimizer,
            num_epochs,
            num_steps_per_epoch,
            warmup_steps=warmup_steps,
            last_epoch=last_epoch,
            **pinned,
        )