This repository has been archived by the owner on Jun 24, 2021. It is now read-only.

Fix slow weights in tf.keras
CyberZHG committed Sep 11, 2019
1 parent 85b6fb2 commit 622f3c4
Showing 3 changed files with 50 additions and 29 deletions.
71 changes: 46 additions & 25 deletions keras_lookahead/optimizers.py
```diff
@@ -1,4 +1,4 @@
-from .backend import keras
+from .backend import keras, TF_KERAS
 from .backend import backend as K
 
 __all__ = ['Lookahead']
@@ -34,31 +34,52 @@ def lr(self):
     def lr(self, lr):
         self.optimizer.lr = lr
 
+    @property
+    def iterations(self):
+        return self.optimizer.iterations
+
     def get_updates(self, loss, params):
-        slow_params = {p.name: K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)}
-        sync_cond = K.equal((self.optimizer.iterations + 1) % self.sync_period, 0)
-        original_update = getattr(K, 'update')
-        setattr(K, 'update', lambda x, new_x: (x, new_x))
-        self.updates = self.optimizer.get_updates(loss, params)
-        setattr(K, 'update', original_update)
-        slow_updates = []
-        for i, update in enumerate(self.updates):
-            if isinstance(update, tuple):
-                if update[0].name not in slow_params:
-                    self.updates[i] = K.update(update[0], update[1])
-                else:
-                    slow_param = slow_params[update[0].name]
-                    slow_param_t = slow_param + self.slow_step * (update[1] - slow_param)
-                    slow_updates.append(K.update(slow_param, K.switch(
-                        sync_cond,
-                        slow_param_t,
-                        slow_param,
-                    )))
-                    self.updates[i] = K.update(update[0], K.switch(
-                        sync_cond,
-                        slow_param_t,
-                        update[1],
-                    ))
+        sync_cond = K.equal((self.iterations + 1) % self.sync_period, 0)
+        if TF_KERAS:
+            slow_params = [K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)]
+            self.updates = self.optimizer.get_updates(loss, params)
+            slow_updates = []
+            for p, sp in zip(params, slow_params):
+                sp_t = sp + self.slow_step * (p - sp)
+                slow_updates.append(K.update(sp, K.switch(
+                    sync_cond,
+                    sp_t,
+                    sp,
+                )))
+                slow_updates.append(K.update_add(p, K.switch(
+                    sync_cond,
+                    sp_t - p,
+                    K.zeros_like(p),
+                )))
+        else:
+            slow_params = {p.name: K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)}
+            original_update = getattr(K, 'update')
+            setattr(K, 'update', lambda x, new_x: (x, new_x))
+            self.updates = self.optimizer.get_updates(loss, params)
+            setattr(K, 'update', original_update)
+            slow_updates = []
+            for i, update in enumerate(self.updates):
+                if isinstance(update, tuple):
+                    if update[0].name not in slow_params:
+                        self.updates[i] = K.update(update[0], update[1])
+                    else:
+                        slow_param = slow_params[update[0].name]
+                        slow_param_t = slow_param + self.slow_step * (update[1] - slow_param)
+                        slow_updates.append(K.update(slow_param, K.switch(
+                            sync_cond,
+                            slow_param_t,
+                            slow_param,
+                        )))
+                        self.updates[i] = K.update(update[0], K.switch(
+                            sync_cond,
+                            slow_param_t,
+                            update[1],
+                        ))
         self.updates += slow_updates
         return self.updates
```
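For context: Lookahead keeps a slow copy of every weight and, every sync_period batches, pulls it toward the fast weights by slow_step, then resets the fast weights to the slow copy. The old code intercepted the inner optimizer's K.update calls by monkey-patching K.update, which apparently never fires under tf.keras (whose optimizers build a single update op via apply_gradients), leaving the slow weights stale; the new TF_KERAS branch instead applies K.update_add to the weights directly. A minimal NumPy sketch of the rule itself (standalone illustration with hypothetical names, not the package's code):

```python
import numpy as np

def lookahead_sync(fast, slow, step, sync_period=5, slow_step=0.5):
    """Sketch of the slow-weight rule in the diff: every sync_period steps,
    move the slow copy toward the fast weights, then reset the fast weights
    to it. Names here are illustrative."""
    if (step + 1) % sync_period == 0:
        slow = slow + slow_step * (fast - slow)  # sp_t = sp + slow_step * (p - sp)
        fast = slow.copy()                       # p += sp_t - p, i.e. p = sp_t
    return fast, slow

fast = np.zeros(3)
slow = fast.copy()
for step in range(10):
    fast = fast + 0.1 * np.random.standard_normal(3)  # stand-in for the inner optimizer step
    fast, slow = lookahead_sync(fast, slow, step)
```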
4 changes: 2 additions & 2 deletions setup.py
```diff
@@ -11,13 +11,13 @@
 
 setup(
     name='keras-lookahead',
-    version='0.2.0',
+    version='0.3.0',
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-lookahead',
     license='MIT',
     author='CyberZHG',
     author_email='CyberZHG@users.noreply.github.com',
-    description='lookahead implemented in Keras',
+    description='Lookahead mechanism for optimizers in Keras',
     long_description=long_description,
     long_description_content_type='text/markdown',
     install_requires=install_requires,
```
4 changes: 2 additions & 2 deletions tests/test_optimizers.py
```diff
@@ -62,13 +62,13 @@ def test_ranger(self):
 
     def test_half(self):
         weight = np.random.standard_normal((5, 1))
-        x, y, _ = self._init_data(data_size=320)
+        x, y, _ = self._init_data(data_size=3200)
 
         model = self._init_model('adam', w=weight)
         model.fit(x, y, batch_size=32)
         original = model.get_weights()[0]
 
-        model = self._init_model(Lookahead('adam', sync_period=10, slow_step=0.5), w=weight)
+        model = self._init_model(Lookahead('adam', sync_period=100, slow_step=0.5), w=weight)
         model.fit(x, y, batch_size=32)
         step_back = model.get_weights()[0]
```
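The constructor call in the test above shows the wrapper's interface; a minimal usage sketch built on it (the model and data here are illustrative, and the import path assumes the package layout shown in this commit):

```python
import numpy as np
from tensorflow import keras
from keras_lookahead import Lookahead  # package path as in this commit

# Illustrative model and data; the Lookahead(...) call mirrors the test above.
model = keras.models.Sequential([keras.layers.Dense(1, input_shape=(5,))])
model.compile(optimizer=Lookahead('adam', sync_period=5, slow_step=0.5), loss='mse')

x = np.random.standard_normal((320, 5))
y = x @ np.random.standard_normal((5, 1))
model.fit(x, y, batch_size=32)
```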
