From 622f3c4d1b33a2fd5dee68a9fc284424222cf4a3 Mon Sep 17 00:00:00 2001 From: CyberZHG <853842+CyberZHG@users.noreply.github.com> Date: Wed, 11 Sep 2019 22:04:03 +0800 Subject: [PATCH] Fix slow weights in tf.keras --- keras_lookahead/optimizers.py | 71 +++++++++++++++++++++++------------ setup.py | 4 +- tests/test_optimizers.py | 4 +- 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/keras_lookahead/optimizers.py b/keras_lookahead/optimizers.py index f83a1ee..670b2c0 100644 --- a/keras_lookahead/optimizers.py +++ b/keras_lookahead/optimizers.py @@ -1,4 +1,4 @@ -from .backend import keras +from .backend import keras, TF_KERAS from .backend import backend as K __all__ = ['Lookahead'] @@ -34,31 +34,52 @@ def lr(self): def lr(self, lr): self.optimizer.lr = lr + @property + def iterations(self): + return self.optimizer.iterations + def get_updates(self, loss, params): - slow_params = {p.name: K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)} - sync_cond = K.equal((self.optimizer.iterations + 1) % self.sync_period, 0) - original_update = getattr(K, 'update') - setattr(K, 'update', lambda x, new_x: (x, new_x)) - self.updates = self.optimizer.get_updates(loss, params) - setattr(K, 'update', original_update) - slow_updates = [] - for i, update in enumerate(self.updates): - if isinstance(update, tuple): - if update[0].name not in slow_params: - self.updates[i] = K.update(update[0], update[1]) - else: - slow_param = slow_params[update[0].name] - slow_param_t = slow_param + self.slow_step * (update[1] - slow_param) - slow_updates.append(K.update(slow_param, K.switch( - sync_cond, - slow_param_t, - slow_param, - ))) - self.updates[i] = K.update(update[0], K.switch( - sync_cond, - slow_param_t, - update[1], - )) + sync_cond = K.equal((self.iterations + 1) % self.sync_period, 0) + if TF_KERAS: + slow_params = [K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)] + self.updates = self.optimizer.get_updates(loss, params) + slow_updates = [] + for p, sp in zip(params, slow_params): + sp_t = sp + self.slow_step * (p - sp) + slow_updates.append(K.update(sp, K.switch( + sync_cond, + sp_t, + sp, + ))) + slow_updates.append(K.update_add(p, K.switch( + sync_cond, + sp_t - p, + K.zeros_like(p), + ))) + else: + slow_params = {p.name: K.variable(K.get_value(p), name='sp_{}'.format(i)) for i, p in enumerate(params)} + original_update = getattr(K, 'update') + setattr(K, 'update', lambda x, new_x: (x, new_x)) + self.updates = self.optimizer.get_updates(loss, params) + setattr(K, 'update', original_update) + slow_updates = [] + for i, update in enumerate(self.updates): + if isinstance(update, tuple): + if update[0].name not in slow_params: + self.updates[i] = K.update(update[0], update[1]) + else: + slow_param = slow_params[update[0].name] + slow_param_t = slow_param + self.slow_step * (update[1] - slow_param) + slow_updates.append(K.update(slow_param, K.switch( + sync_cond, + slow_param_t, + slow_param, + ))) + self.updates[i] = K.update(update[0], K.switch( + sync_cond, + slow_param_t, + update[1], + )) self.updates += slow_updates return self.updates diff --git a/setup.py b/setup.py index cf3a789..414f4a5 100644 --- a/setup.py +++ b/setup.py @@ -11,13 +11,13 @@ setup( name='keras-lookahead', - version='0.2.0', + version='0.3.0', packages=find_packages(), url='https://github.com/CyberZHG/keras-lookahead', license='MIT', author='CyberZHG', author_email='CyberZHG@users.noreply.github.com', - description='lookahead implemented in Keras', + description='Lookahead mechanism for optimizers in Keras', long_description=long_description, long_description_content_type='text/markdown', install_requires=install_requires, diff --git a/tests/test_optimizers.py b/tests/test_optimizers.py index 9c46147..ecbe81c 100644 --- a/tests/test_optimizers.py +++ b/tests/test_optimizers.py @@ -62,13 +62,13 @@ def test_ranger(self): def test_half(self): weight = np.random.standard_normal((5, 1)) - x, y, _ = self._init_data(data_size=320) + x, y, _ = self._init_data(data_size=3200) model = self._init_model('adam', w=weight) model.fit(x, y, batch_size=32) original = model.get_weights()[0] - model = self._init_model(Lookahead('adam', sync_period=10, slow_step=0.5), w=weight) + model = self._init_model(Lookahead('adam', sync_period=100, slow_step=0.5), w=weight) model.fit(x, y, batch_size=32) step_back = model.get_weights()[0]