Synchronize updates; fix AdamW lr_t (keras)
**BUGFIXES**:
 - The last weight in the network was updated with `t_cur` one step ahead, desynchronizing it from all other weights
 - `AdamW` weight updates in `keras` (optimizers.py, optimizers_225.py) were _not_ mediated by `eta_t`, so cosine annealing had no effect. Pardon the mishap

**FEATURES**:
 - Added `lr_t` to tf.keras optimizers to track the "actual" learning rate externally; use `K.eval(model.optimizer.lr_t)` to get the "actual" learning rate for a given `t_cur` and `iterations` (see the sketch after this list)
 - Added `lr_t` vs. iterations plot to README, and source code in `example.py`
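
A minimal sketch of the new tracking (the toy Dense model and hyperparameters are illustrative stand-ins, not from this commit; only the `lr_t`/`eta_t` reads follow the notes above, mirroring the pattern of `example.py` below):

```python
# Toy sketch of reading the "actual" learning rate externally; model and
# hyperparameters are illustrative only.
import numpy as np
from keras.layers import Input, Dense
from keras.models import Model
from keras import backend as K
from keras_adamw import AdamW
from keras_adamw.utils import K_eval

ipt = Input(shape=(4,))
out = Dense(1, activation='sigmoid')(ipt)
model = Model(ipt, out)
optimizer = AdamW(lr=1e-4, use_cosine_annealing=True, total_iterations=24)
model.compile(optimizer, loss='binary_crossentropy')

x = np.random.rand(32, 4)             # dummy data
y = np.random.randint(0, 2, (32, 1))  # dummy labels
model.train_on_batch(x, y)

print("lr_t: ", K_eval(model.optimizer.lr_t, K))   # "actual" lr for current t_cur, iterations
print("eta_t:", K_eval(model.optimizer.eta_t, K))  # cosine-annealing multiplier
# With the tf.keras optimizers, K.eval(model.optimizer.lr_t) works as noted above.
```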

**MISC**:
 - Added `test_updates` to ensure all weights update synchronously, and that `eta_t` is first applied to the weights as-is and only _then_ updated according to `t_cur` (toy illustration after this list)
 - Fixes #47
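
For intuition on the desync bugfix and what `test_updates` guards against, a toy sketch in plain Python (not the package's code; `schedule` stands in for the default cosine-annealed `eta_t`):

```python
# Toy illustration: bumping a shared t_cur inside the per-weight update loop
# gives the last weight a different schedule value than all other weights.
import math

def schedule(t_cur, total_iterations=24):     # stand-in for eta_t(t_cur)
    return 0.5 * (1 + math.cos(math.pi * t_cur / total_iterations))

def buggy_updates(n_weights, t_cur):
    etas = []
    for i in range(n_weights):
        if i == n_weights - 1:                # counter advanced one update early
            t_cur += 1
        etas.append(schedule(t_cur))
    return etas, t_cur

def fixed_updates(n_weights, t_cur):
    etas = [schedule(t_cur)] * n_weights      # every weight sees the same t_cur
    return etas, t_cur + 1                    # advance once, after all updates

print(buggy_updates(3, t_cur=0)[0])  # last entry differs from the rest
print(fixed_updates(3, t_cur=0)[0])  # all entries equal
```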
OverLordGoldDragon committed Jun 4, 2020
1 parent 630e9ae commit 8dc42e0
Showing 9 changed files with 195 additions and 84 deletions.
README.md (4 changes: 2 additions & 2 deletions)
@@ -102,9 +102,9 @@ for epoch in range(3):
K.set_value(model.optimizer.t_cur, -1) # WARM RESTART: reset cosine annealing argument
print("EPOCH {} COMPLETED\n".format(epoch + 1))
```
<img src="https://user-images.githubusercontent.com/16495490/65729113-2063d400-e08b-11e9-8b6a-3a2ea1c62fdd.png" width="450">
<img src="https://user-images.githubusercontent.com/16495490/83707138-51d56c00-a62a-11ea-9eba-60284490992b.png" width="470">

(Full example + plot code: [example.py](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/example.py))
(Full example + plot code, and explanation of `lr_t` vs. `lr`: [example.py](https://github.com/OverLordGoldDragon/keras-adamw/blob/master/example.py))

## Use guidelines
### Weight decay
example.py (33 changes: 26 additions & 7 deletions)
@@ -4,11 +4,12 @@
from keras import backend as K
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

from keras_adamw import AdamW
from keras_adamw.utils import K_eval


#%%############################################################################
ipt = Input(shape=(120,4))
x = LSTM(60, activation='relu', name='lstm_1',
kernel_regularizer=l1(1e-4), recurrent_regularizer=l2(2e-4))(ipt)
@@ -21,22 +22,40 @@
use_cosine_annealing=True, total_iterations=24)
model.compile(optimizer, loss='binary_crossentropy')

#%%############################################################################
eta_history = []
lr_history = []
for epoch in range(3):
for iteration in range(24):
x = np.random.rand(10, 120, 4) # dummy data
y = np.random.randint(0, 2, (10, 1)) # dummy labels
loss = model.train_on_batch(x, y)
eta_history.append(K_eval(model.optimizer.eta_t, K))
lr_history.append(K_eval(model.optimizer.lr_t, K))
print("Iter {} loss: {}".format(iteration + 1, "%.3f" % loss))
if iteration == (24 - 2):
K.set_value(model.optimizer.t_cur, -1) # WARM RESTART
print("EPOCH {} COMPLETED\n".format(epoch + 1))

plt.plot(eta_history, linewidth=2)
plt.xlim(0, len(eta_history))
plt.ylim(0, 1.05)
plt.ylabel('eta_t', weight='bold', fontsize=15)
plt.xlabel('Train iterations', weight='bold', fontsize=15)
plt.gcf().set_size_inches(10, 5)
# learning rate at iteration `t` (lr_t) is subject to scaling depending on
# optimizer; Adam and Nadam use betas (1 & 2), SGD w/ momentum uses beta.
# -W optimizers additionally scale by eta_t. The lr_t plots reflect the
# ultimate learning rate as a result of all the scalings.

#%%############################################################################
_, ax = plt.subplots(figsize=(10, 5))
ax.plot(eta_history, linewidth=2)
ax.set_xlim(0, len(eta_history))
ax.set_ylim(0, 1.05)
ax.set_ylabel('eta_t', weight='bold', fontsize=15)
ax.set_xlabel('Train iterations', weight='bold', fontsize=15)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%0.0e'))

_, ax = plt.subplots(figsize=(10, 5))
ax.plot(lr_history, linewidth=2)
ax.set_xlim(0, len(lr_history))
ax.set_ylim(0, 1.05 * np.max(lr_history))
ax.set_ylabel('lr_t', weight='bold', fontsize=15)
ax.set_xlabel('Train iterations', weight='bold', fontsize=15)
ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%0.0e'))
plt.show()
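
A small follow-up check one could append to the script above (hedged sketch; it reuses the variables already defined, and the exact composition of `lr_t` depends on the optimizer, as the comment block above notes):

```python
# Sketch: inspect the pieces behind the last recorded lr_t. Illustrative only;
# betas, eta_t and any lr multipliers all factor in, per the comment block above.
lr_base = K_eval(model.optimizer.lr, K)      # base learning rate
eta_now = K_eval(model.optimizer.eta_t, K)   # current cosine-annealing multiplier
print("base lr: %.2e | eta_t: %.3f | last lr_t: %.2e"
      % (lr_base, eta_now, lr_history[-1]))
```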
keras_adamw/optimizers.py (5 changes: 3 additions & 2 deletions)
@@ -148,10 +148,11 @@ def get_updates(self, loss, params):
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
p_t = p - self.eta_t * lr_t * m_t / (
K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
p_t = p - self.eta_t * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))
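
For context, a plain-NumPy sketch of the corrected step (illustrative only, not the package's code; Adam's bias correction of `lr_t` and the decoupled weight decay are omitted): the parameter update is now scaled by `eta_t`, so cosine annealing actually shrinks the step.

```python
# Plain-NumPy sketch of an eta_t-mediated Adam step; constants are hypothetical.
import numpy as np

beta_1, beta_2, eps, lr_t = 0.9, 0.999, 1e-7, 1e-3
total_iterations, t_cur = 24, 6
p, m, v = np.zeros(3), np.zeros(3), np.zeros(3)
g = np.array([0.1, -0.2, 0.3])                                # dummy gradient

eta_t = 0.5 * (1 + np.cos(np.pi * t_cur / total_iterations))  # cosine annealing
m = beta_1 * m + (1. - beta_1) * g
v = beta_2 * v + (1. - beta_2) * g ** 2
p = p - eta_t * lr_t * m / (np.sqrt(v) + eps)                 # the fix: step scaled by eta_t
```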
keras_adamw/optimizers_225.py (14 changes: 8 additions & 6 deletions)
@@ -123,15 +123,16 @@ def get_updates(self, loss, params):

m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))

if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
p_t = p - self.eta_t * lr_t * m_t / (
K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))
p_t = p - self.eta_t * lr_t * m_t / (K.sqrt(v_t) + self.epsilon)

# Weight decays
if p.name in self.weight_decays.keys():
@@ -306,8 +307,9 @@ def get_updates(self, loss, params):

self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))

p_t = p - self.eta_t * lr_t * m_t_bar / (
K.sqrt(v_t_prime) + self.epsilon)
K.sqrt(v_t_prime) + self.epsilon)

# Weight decays
if p.name in self.weight_decays.keys():
keras_adamw/optimizers_225tf.py (71 changes: 44 additions & 27 deletions)
@@ -6,7 +6,7 @@
from tensorflow.python.util.tf_export import keras_export
from tensorflow.keras import backend as K
from .utils import _init_weight_decays, _apply_weight_decays, _check_args
from .utils import _update_t_cur_eta_t_apply_lr_mult
from .utils import _update_t_cur_eta_t_v2, _apply_lr_multiplier


@keras_export('keras.optimizers.AdamW')
@@ -144,11 +144,10 @@ def _resource_apply_dense(self, grad, var):
epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)

lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)


m_t = state_ops.assign(m,
beta_1_t * m + (1.0 - beta_1_t) * grad,
@@ -171,9 +170,13 @@ def _resource_apply_dense(self, grad, var):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update, m_t, v_t]
if iteration_done:
@@ -197,11 +200,9 @@ def _resource_apply_sparse(self, grad, var, indices):
epsilon_t = ops.convert_to_tensor(self.epsilon, var_dtype)

lr_t = lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)

m_scaled_g_values = grad * (1 - beta_1_t)
m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
@@ -231,9 +232,13 @@ def _resource_apply_sparse(self, grad, var, indices):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update, m_t, v_t]
if iteration_done:
@@ -419,9 +424,8 @@ def _resource_apply_dense(self, grad, var):
decay_base = math_ops.cast(0.96, var_dtype)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)

# Due to the recommendations in [2], i.e. warming momentum schedule
momentum_cache_t = beta_1_t * (1. - 0.5 * (
@@ -454,9 +458,13 @@ def _resource_apply_dense(self, grad, var):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update, m_t, v_t]
if iteration_done:
@@ -478,9 +486,8 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
decay_base = math_ops.cast(0.96, var_dtype)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)

momentum_cache_t = beta_1_t * (1. - 0.5 * (
math_ops.pow(decay_base, self._initial_decay * local_step)))
@@ -523,9 +530,13 @@ def _resource_apply_sparse(self, grad, var, indices, apply_state=None):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update, m_t_bar, v_t]
if iteration_done:
@@ -667,9 +678,8 @@ def _resource_apply_dense(self, grad, var):
lr_t = self._decayed_lr(var_dtype)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)

if self._momentum:
momentum = array_ops.identity(self._get_hyper('momentum', var_dtype))
@@ -690,9 +700,13 @@ def _resource_apply_dense(self, grad, var):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update]
if self._momentum:
@@ -708,9 +722,8 @@ def _resource_apply_sparse(self, grad, var, indices):
lr_t = self._decayed_lr(var_dtype)

# Learning rate multipliers
# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_apply_lr_mult(self, lr_t, var)
if self.lr_multipliers is not None:
lr_t = _apply_lr_multiplier(self, lr_t, var)

if self._momentum:
momentum = array_ops.identity(self._get_hyper('momentum', var_dtype))
@@ -731,9 +744,13 @@ def _resource_apply_sparse(self, grad, var, indices):
if var.name in self.weight_decays.keys():
var_t = _apply_weight_decays(self, var, var_t)

var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

# Cosine annealing
(iteration_done, t_cur_update, eta_t_update
) = _update_t_cur_eta_t_v2(self, lr_t, var)
if iteration_done and not self._init_notified:
self._init_notified = True
var_update = state_ops.assign(var, var_t, use_locking=self._use_locking)

updates = [var_update]
if self._momentum:
(4 more changed files not rendered on this page)