from .backend import keras
from .backend import backend as K

__all__ = ['AdamAccumulated']


def identity(x):
    return x
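
# Added note: on recent TensorFlow-backed Keras (2.3+), `K.symbolic` is a
# decorator that runs graph-building code in the backend's symbolic scope;
# when the backend lacks it, the no-op `identity` above keeps the `@symbolic`
# decorator harmless on older versions.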
symbolic = identity
if hasattr(K, 'symbolic'):
    symbolic = K.symbolic


class AdamAccumulated(keras.optimizers.Optimizer):
    """Adam optimizer with gradient accumulation.

    Default parameters follow those provided in the original Adam paper.

    # Arguments
        accumulation_steps: int > 0. Accumulate gradients and apply a weight
            update once every `accumulation_steps` training steps.
        learning_rate: float >= 0. Learning rate.
        beta_1: float, 0 < beta < 1. Generally close to 1.
        beta_2: float, 0 < beta < 1. Generally close to 1.
        epsilon: float >= 0. Fuzz factor. If `None`, defaults to `K.epsilon()`.
        decay: float >= 0. Learning rate decay over each update.
        amsgrad: boolean. Whether to apply the AMSGrad variant of this
            algorithm from the paper "On the Convergence of Adam and Beyond".

    # References
        - [Adam - A Method for Stochastic Optimization](https://arxiv.org/abs/1412.6980v8)
        - [On the Convergence of Adam and Beyond](https://openreview.net/forum?id=ryQu7f-RZ)
    """
    def __init__(self, accumulation_steps, learning_rate=0.001, beta_1=0.9, beta_2=0.999,
                 epsilon=None, decay=0., amsgrad=False, **kwargs):
        # Accept the legacy `lr` keyword as an alias for `learning_rate`.
        learning_rate = kwargs.pop('lr', learning_rate)
        super(AdamAccumulated, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.accumulation_steps = K.variable(accumulation_steps, dtype='int64', name='accumulation_steps')
            self.learning_rate = K.variable(learning_rate, name='learning_rate')
            self.beta_1 = K.variable(beta_1, name='beta_1')
            self.beta_2 = K.variable(beta_2, name='beta_2')
            self.decay = K.variable(decay, name='decay')
        if epsilon is None:
            epsilon = K.epsilon()
        self.epsilon = epsilon
        self.initial_decay = decay
        self.amsgrad = amsgrad
    @symbolic
    def get_updates(self, loss, params):
        grads = self.get_gradients(loss, params)
        self.updates = [K.update_add(self.iterations, 1)]

        # A real weight update fires only on the last sub-step of each
        # accumulation window; `sub_step` counts 1..accumulation_steps within it.
        update_cond = K.equal((self.iterations + 1) % self.accumulation_steps, 0)
        sub_step = self.iterations % self.accumulation_steps + 1

        # `t` counts completed accumulation windows, so bias correction is
        # applied per effective update rather than per sub-step.
        t = K.cast(self.iterations // self.accumulation_steps, K.floatx()) + 1
        lr = self.learning_rate
        if self.initial_decay > 0:
            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations, K.dtype(self.decay))))
        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) / (1. - K.pow(self.beta_1, t)))
        lr_t = K.switch(update_cond, lr_t, 0.0)

        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p), name='m_' + str(i)) for (i, p) in enumerate(params)]
        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p), name='v_' + str(i)) for (i, p) in enumerate(params)]
        if self.amsgrad:
            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p), name='vhat_' + str(i)) for (i, p) in enumerate(params)]
        else:
            vhats = [K.zeros(1, name='vhat_' + str(i)) for i in range(len(params))]
        self.weights = [self.iterations] + ms + vs + vhats

        acc_grads = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
        for grad, acc_grad in zip(grads, acc_grads):
            # Keep a running mean of the pre-scaled gradients: reset on the first
            # sub-step of a window, then fold each new gradient in incrementally.
            ave_grad = grad / K.cast(self.accumulation_steps, K.floatx())
            self.updates.append(K.update(
                acc_grad,
                K.switch(
                    K.equal(sub_step, 1),
                    ave_grad,
                    acc_grad + (ave_grad - acc_grad) / K.cast(sub_step, K.floatx())
                ),
            ))
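
        # Added worked example (comment only): with accumulation_steps=2 and
        # per-step gradients g1, g2, the recurrence above yields
        #   sub_step 1: acc = g1 / 2
        #   sub_step 2: acc = g1/2 + (g2/2 - g1/2)/2 = (g1 + g2) / 4
        # i.e. the running mean of grad / accumulation_steps. The absolute scale
        # largely cancels in Adam's m / sqrt(v) ratio (up to epsilon), so what
        # matters is that `acc_grad` summarizes the window when `update_cond` fires.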
        grads = [K.switch(update_cond, grad, K.zeros_like(grad)) for grad in acc_grads]

        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
            m_t = K.switch(update_cond, (self.beta_1 * m) + (1. - self.beta_1) * g, m)
            v_t = K.switch(update_cond, (self.beta_2 * v) + (1. - self.beta_2) * K.square(g), v)
            if self.amsgrad:
                vhat_t = K.switch(update_cond, K.maximum(vhat, v_t), vhat)
                p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
                self.updates.append(K.update(vhat, vhat_t))
            else:
                p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
            self.updates.append(K.update(m, m_t))
            self.updates.append(K.update(v, v_t))
            new_p = p_t

            # Respect per-variable constraints (e.g. weight clipping), if any.
            if getattr(p, 'constraint', None) is not None:
                new_p = p.constraint(new_p)
            self.updates.append(K.update(p, new_p))
        return self.updates
    def get_config(self):
        config = {'accumulation_steps': int(K.get_value(self.accumulation_steps)),
                  'learning_rate': float(K.get_value(self.learning_rate)),
                  'beta_1': float(K.get_value(self.beta_1)),
                  'beta_2': float(K.get_value(self.beta_2)),
                  'decay': float(K.get_value(self.decay)),
                  'epsilon': self.epsilon,
                  'amsgrad': self.amsgrad}
        base_config = super(AdamAccumulated, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
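

# A minimal usage sketch (added comment; not part of the original file). The
# import path assumes this module ships inside the `keras_gradient_accumulation`
# package and that the package re-exports `AdamAccumulated`; adjust to your
# installation. With `accumulation_steps=8` and a per-step batch size of 32,
# weights are updated once every 8 batches, approximating an effective batch
# size of 256:
#
#     from keras_gradient_accumulation import AdamAccumulated
#
#     model.compile(
#         optimizer=AdamAccumulated(accumulation_steps=8, learning_rate=1e-3),
#         loss='sparse_categorical_crossentropy',
#         metrics=['accuracy'],
#     )
#     model.fit(x_train, y_train, batch_size=32, epochs=3)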