# moving_average.py
from typing import Iterable, Tuple, Optional

import torch

from allennlp.common.registrable import Registrable

NamedParameter = Tuple[str, torch.Tensor]


class MovingAverage(Registrable):
    """
    Tracks a moving average of model parameters.
    """

    default_implementation = "exponential"

    def __init__(self, parameters: Iterable[NamedParameter]) -> None:
        self._parameters = list(parameters)
        self._shadows = {name: parameter.data.clone() for name, parameter in self._parameters}
        self._backups = {name: parameter.data.clone() for name, parameter in self._parameters}

    def apply(self, num_updates: Optional[int] = None):
        """
        Update the moving averages based on the latest values of the parameters.
        """
        raise NotImplementedError

    def assign_average_value(self) -> None:
        """
        Replace all the parameter values with the averages.
        Save the current parameter values to restore later.
        """
        for name, parameter in self._parameters:
            self._backups[name].copy_(parameter.data)
            parameter.data.copy_(self._shadows[name])

    def restore(self) -> None:
        """
        Restore the backed-up (non-average) parameter values.
        """
        for name, parameter in self._parameters:
            parameter.data.copy_(self._backups[name])
@MovingAverage.register("exponential")
class ExponentialMovingAverage(MovingAverage):
    """
    Creates shadow variables and maintains an exponential moving average of the
    model parameters.

    Registered as a `MovingAverage` with name "exponential".

    # Parameters

    parameters : `Iterable[Tuple[str, Parameter]]`, required
        The parameters whose averages we'll be tracking.
        In a typical AllenNLP configuration file, this argument does not get an
        entry under the "moving_average" key; it gets passed in separately.
    decay : `float`, optional (default = `0.9999`)
        The decay rate that will be used if `num_updates` is not passed
        (and that will be used as an upper bound if `num_updates` is passed).
    numerator : `float`, optional (default = `1.0`)
        The numerator used to compute the decay rate if `num_updates` is passed.
    denominator : `float`, optional (default = `10.0`)
        The denominator used to compute the decay rate if `num_updates` is passed.
    """
    def __init__(
        self,
        parameters: Iterable[NamedParameter],
        decay: float = 0.9999,
        numerator: float = 1.0,
        denominator: float = 10.0,
    ) -> None:
        super().__init__(parameters)
        self._decay = decay
        self._numerator = numerator
        self._denominator = denominator

    def apply(self, num_updates: Optional[int] = None) -> None:
        """
        Update the moving averages of all the parameters being tracked.

        The optional `num_updates` parameter allows one to tweak the decay rate
        dynamically. If passed, the actual decay rate used is:

            `min(decay, (numerator + num_updates) / (denominator + num_updates))`

        (This logic is based on the Tensorflow exponential moving average
        <https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage>.)
        """
        if num_updates is not None:
            decay = min(
                self._decay, (self._numerator + num_updates) / (self._denominator + num_updates)
            )
        else:
            decay = self._decay
        for name, parameter in self._parameters:
            self._shadows[name].mul_(decay).add_((1 - decay) * parameter.data)
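

# ---------------------------------------------------------------------------
# A minimal sketch, not part of the original module: `_demo_decay_schedule` is
# a hypothetical helper that simply evaluates the dynamic decay formula from
# `ExponentialMovingAverage.apply` with the class defaults, showing how the
# effective decay warms up from a small value toward the `decay` upper bound.
def _demo_decay_schedule() -> None:
    decay, numerator, denominator = 0.9999, 1.0, 10.0
    for num_updates in (0, 9, 99, 999, 9999, 99999):
        effective = min(decay, (numerator + num_updates) / (denominator + num_updates))
        # Early on the shadows track the raw weights closely (small decay);
        # later the schedule saturates at the `decay` cap.
        print(f"num_updates={num_updates:>6d}  effective decay={effective:.4f}")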
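

# ---------------------------------------------------------------------------
# A hedged usage sketch, also not part of the original module: the tiny linear
# model, dummy loss, and SGD optimizer are illustrative assumptions; only the
# `ExponentialMovingAverage` calls follow the API defined above.
if __name__ == "__main__":
    _demo_decay_schedule()  # print the warm-up schedule from the sketch above

    model = torch.nn.Linear(4, 2)  # hypothetical stand-in for a real model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    ema = ExponentialMovingAverage(model.named_parameters())

    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 4)).pow(2).mean()  # dummy training loss
        loss.backward()
        optimizer.step()
        # Fold the freshly updated weights into the shadow averages; passing
        # `num_updates` enables the dynamic decay schedule shown above.
        ema.apply(num_updates=step)

    # For evaluation, temporarily swap the averaged weights into the model,
    # then restore the raw weights before training would resume.
    ema.assign_average_value()
    # ... run validation with the averaged weights here ...
    ema.restore()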