From 71ca69e71deccba0992a33f1b3685f276fa620c6 Mon Sep 17 00:00:00 2001
From: CyberZHG <853842+CyberZHG@users.noreply.github.com>
Date: Sat, 11 Jul 2020 14:54:18 +0800
Subject: [PATCH] Update to tf.keras

---
 .travis.yml                  |  30 +----
 keras_adabound/__init__.py   |   2 +
 keras_adabound/backend.py    |  46 -------
 keras_adabound/optimizers.py | 246 +++++++++++++++++++++++------------
 requirements-dev.txt         |   1 +
 requirements.txt             |   1 +
 setup.py                     |  31 +++--
 test.sh                      |   2 +-
 tests/test_optimize.py       |  78 +++++++++--
 tests/test_similar.py        |  42 +++---
 10 files changed, 285 insertions(+), 194 deletions(-)
 delete mode 100644 keras_adabound/backend.py

diff --git a/.travis.yml b/.travis.yml
index 96a29e7..847fa87 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,19 +1,11 @@
 dist: xenial
 language: python
-python:
-  - 2.7
-  - 3.6
+python: "3.6"
 env:
-  - KERAS_BACKEND=tensorflow
-  - KERAS_BACKEND=tensorflow TF_KERAS=1
-  - KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile
-  # - KERAS_BACKEND=cntk PYTHONWARNINGS=ignore
+  global:
+    - COVERALLS_PARALLEL=true
 install:
-  - if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
-      wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh;
-    else
-      wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
-    fi
+  - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
   - bash miniconda.sh -b -p $HOME/miniconda
   - export PATH="$HOME/miniconda/bin:$PATH"
   - conda config --set always_yes yes --set changeps1 no
@@ -25,20 +17,10 @@ install:
   - pip install --upgrade pip
   - pip install -r requirements.txt
   - pip install -r requirements-dev.txt
-  - if [[ $KERAS_BACKEND == "theano" ]]; then pip install theano && conda install mkl mkl-service; fi
-  - if [[ "$KERAS_BACKEND" == "cntk" ]]; then
-      set -e &&
-      pip install cntk &&
-      mkdir -p ~/mpi &&
-      pushd ~/mpi &&
-      wget http://cntk.ai/PythonWheel/ForKeras/depends/openmpi_1.10-3.zip &&
-      unzip ./openmpi_1.10-3.zip &&
-      sudo dpkg -i openmpi_1.10-3.deb &&
-      popd;
-    fi
-  - if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install adabound; fi
   - pip install coveralls
 script:
   - ./test.sh
 after_success: coveralls
+notifications:
+  webhooks: https://coveralls.io/webhook
diff --git a/keras_adabound/__init__.py b/keras_adabound/__init__.py
index acd7379..33d39ba 100644
--- a/keras_adabound/__init__.py
+++ b/keras_adabound/__init__.py
@@ -1 +1,3 @@
 from .optimizers import *
+
+__version__ = '0.6.0'
diff --git a/keras_adabound/backend.py b/keras_adabound/backend.py
deleted file mode 100644
index e6b9f58..0000000
--- a/keras_adabound/backend.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import os
-
-__all__ = [
-    'keras', 'utils', 'activations', 'applications', 'backend', 'datasets', 'engine',
-    'layers', 'preprocessing', 'wrappers', 'callbacks', 'constraints', 'initializers',
-    'metrics', 'models', 'losses', 'optimizers', 'regularizers',
-]
-
-if 'TF_KERAS' in os.environ and os.environ['TF_KERAS'] != '0':
-    from tensorflow.python import keras
-    from tensorflow.python.keras import utils
-    from tensorflow.python.keras import activations
-    from tensorflow.python.keras import applications
-    from tensorflow.python.keras import backend
-    from tensorflow.python.keras import datasets
-    from tensorflow.python.keras import engine
-    from tensorflow.python.keras import layers
-    from tensorflow.python.keras import preprocessing
-    from tensorflow.python.keras import wrappers
-    from tensorflow.python.keras import callbacks
-    from tensorflow.python.keras import constraints
-    from tensorflow.python.keras import initializers
-    from tensorflow.python.keras import metrics
-    from tensorflow.python.keras import models
-    from tensorflow.python.keras import losses
-    from tensorflow.python.keras import optimizers
-    from tensorflow.python.keras import regularizers
-else:
-    import keras
-    from keras import utils
-    from keras import activations
-    from keras import applications
-    from keras import backend
-    from keras import datasets
-    from keras import engine
-    from keras import layers
-    from keras import preprocessing
-    from keras import wrappers
-    from keras import callbacks
-    from keras import constraints
-    from keras import initializers
-    from keras import metrics
-    from keras import models
-    from keras import losses
-    from keras import optimizers
-    from keras import regularizers
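
With the keras/tf.keras backend shim deleted, the package targets tf.keras directly and the old TF_KERAS environment switch goes away. A minimal usage sketch under the new import style (the toy model and data below are illustrative, not part of this patch):

import numpy as np
from tensorflow import keras
from keras_adabound import AdaBound

# `learning_rate` replaces the old `lr` argument after this patch;
# `final_lr` is the SGD-like rate the bounds converge to.
model = keras.models.Sequential([keras.layers.Dense(units=1, input_shape=(3,))])
model.compile(optimizer=AdaBound(learning_rate=1e-3, final_lr=0.1), loss='mse')

x = np.random.standard_normal((32, 3))
y = x.sum(axis=1, keepdims=True)
model.fit(x, y, epochs=2, verbose=0)
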
diff --git a/keras_adabound/optimizers.py b/keras_adabound/optimizers.py
index 4172509..37cff6b 100644
--- a/keras_adabound/optimizers.py
+++ b/keras_adabound/optimizers.py
@@ -1,14 +1,22 @@
-from .backend import keras
-from .backend import backend as K
+from typing import Union, Callable, Dict, Optional
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from typeguard import typechecked
+
+
+K = keras.backend
+FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]
 
 
 class AdaBound(keras.optimizers.Optimizer):
     """AdamBound optimizer.
 
     # Arguments
-        lr: float >= 0. Learning rate.
-        final_lr: float >= 0. Final (SGD) learning rate.
+        learning_rate: float >= 0. Learning rate.
         base_lr: float >= 0. Used for loading the optimizer. Do not set the argument manually.
+        final_lr: float >= 0. Final (SGD) learning rate.
         beta_1: float, 0 < beta < 1. Generally close to 1.
         beta_2: float, 0 < beta < 1. Generally close to 1.
         gamma: float, 0 < gamma < 1. Convergence speed of the bound functions.
@@ -22,91 +30,169 @@ class AdaBound(keras.optimizers.Optimizer):
         (https://openreview.net/forum?id=Bkg3g2R9FX)
     """
 
-    def __init__(self, lr=0.001, final_lr=0.1, base_lr=None,
-                 beta_1=0.9, beta_2=0.999, gamma=0.001,
-                 epsilon=None, decay=0., weight_decay=0., amsgrad=False, **kwargs):
-        super(AdaBound, self).__init__(**kwargs)
-        with K.name_scope(self.__class__.__name__):
-            self.iterations = K.variable(0, dtype='int64', name='iterations')
-            self.lr = K.variable(lr, name='lr')
-            self.final_lr = K.variable(final_lr, name='final_lr')
-            self.beta_1 = K.variable(beta_1, name='beta_1')
-            self.beta_2 = K.variable(beta_2, name='beta_2')
-            self.gamma = K.variable(gamma, name='gamma')
-            self.decay = K.variable(decay, name='decay')
-            self.weight_decay = K.variable(weight_decay, name='weight_decay')
-        if epsilon is None:
-            epsilon = K.epsilon()
+    @typechecked
+    def __init__(
+        self,
+        learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
+        base_lr: Optional[FloatTensorLike] = None,
+        final_lr: FloatTensorLike = 0.1,
+        beta_1: FloatTensorLike = 0.9,
+        beta_2: FloatTensorLike = 0.999,
+        gamma: FloatTensorLike = 0.001,
+        epsilon: FloatTensorLike = 1e-8,
+        weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
+        amsgrad: bool = False,
+        name: str = "AdaBound",
+        **kwargs
+    ):
+        super(AdaBound, self).__init__(name=name, **kwargs)
+
+        if isinstance(learning_rate, Dict):
+            learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)
+
+        if isinstance(weight_decay, Dict):
+            weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)
+
         if base_lr is None:
-            self.base_lr = lr
-        else:
-            self.base_lr = base_lr
-        self.epsilon = epsilon
-        self.initial_decay = decay
-        self.initial_weight_decay = weight_decay
+            if isinstance(learning_rate, tf.keras.optimizers.schedules.LearningRateSchedule):
+                base_lr = learning_rate(0)
+            else:
+                base_lr = learning_rate
+
+        self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
+        self._set_hyper("base_lr", base_lr)
+        self._set_hyper("final_lr", final_lr)
+        self._set_hyper("beta_1", beta_1)
+        self._set_hyper("beta_2", beta_2)
+        self._set_hyper("gamma", gamma)
+        self._set_hyper("decay", self._initial_decay)
+        self._set_hyper("weight_decay", weight_decay)
+        self.epsilon = epsilon or tf.keras.backend.epsilon()
         self.amsgrad = amsgrad
+        self._has_weight_decay = weight_decay != 0.0
+
+    def _create_slots(self, var_list):
+        for var in var_list:
+            self.add_slot(var, "m")
+        for var in var_list:
+            self.add_slot(var, "v")
+        if self.amsgrad:
+            for var in var_list:
+                self.add_slot(var, "vhat")
 
-    def get_updates(self, loss, params):
-        grads = self.get_gradients(loss, params)
-        self.updates = [K.update_add(self.iterations, 1)]
+    def _decayed_wd(self, var_dtype):
+        wd_t = self._get_hyper("weight_decay", var_dtype)
+        if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
+            wd_t = tf.cast(wd_t(self.iterations), var_dtype)
+        return wd_t
 
-        lr = self.lr
-        if self.initial_decay > 0:
-            lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
-                                                      K.dtype(self.decay))))
+    def _resource_apply_dense(self, grad, var):
+        var_dtype = var.dtype.base_dtype
+        lr_t = self._decayed_lr(var_dtype)
+        wd_t = self._decayed_wd(var_dtype)
+        base_lr = self._get_hyper("base_lr", var_dtype)
+        final_lr = self._get_hyper("final_lr", var_dtype)
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+        beta_1_t = self._get_hyper("beta_1", var_dtype)
+        beta_2_t = self._get_hyper("beta_2", var_dtype)
+        gamma = self._get_hyper("gamma", var_dtype)
+        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
+        local_step = tf.cast(self.iterations + 1, var_dtype)
+        beta_1_power = tf.pow(beta_1_t, local_step)
+        beta_2_power = tf.pow(beta_2_t, local_step)
 
-        t = K.cast(self.iterations, K.floatx()) + 1
-        lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
-                     (1. - K.pow(self.beta_1, t)))
-        final_lr = self.final_lr * lr / self.base_lr
-        lower_bound = final_lr * (1.0 - 1.0 / (self.gamma * t + 1.0))
-        upper_bound = final_lr * (1.0 + 1.0 / (self.gamma * t))
+        if self._has_weight_decay:
+            grad += wd_t * var
+
+        m_t = m.assign(beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking)
+        v_t = v.assign(beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad), use_locking=self._use_locking)
 
-        ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
-        vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
         if self.amsgrad:
-            vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+            vhat = self.get_slot(var, "vhat")
+            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
+            denom = tf.sqrt(vhat_t) + epsilon_t
         else:
-            vhats = [K.zeros(1) for _ in params]
-        self.weights = [self.iterations] + ms + vs + vhats
-
-        for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
-            if self.initial_weight_decay > 0.:
-                # Note that the decayed weights are added to the momentums.
-                # The mechanism is the same as the official repo.
-                g += self.weight_decay * p
-
-            m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
-            v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
-
-            if self.amsgrad:
-                vhat_t = K.maximum(vhat, v_t)
-                step = lr_t / (K.sqrt(vhat_t) + self.epsilon)
-                self.updates.append(K.update(vhat, vhat_t))
-            else:
-                step = lr_t / (K.sqrt(v_t) + self.epsilon)
-            p_t = p - K.minimum(K.maximum(step, lower_bound), upper_bound) * m_t
-            self.updates.append(K.update(m, m_t))
-            self.updates.append(K.update(v, v_t))
-            new_p = p_t
+            vhat_t = None
+            denom = tf.sqrt(v_t) + epsilon_t
+
+        final_lr = final_lr * lr_t / base_lr
+        lower_bound = final_lr * (1.0 - 1.0 / (gamma * local_step + 1.0))
+        upper_bound = final_lr * (1.0 + 1.0 / (gamma * local_step))
+        lr_t = lr_t * (tf.sqrt(1.0 - beta_2_power) / (1.0 - beta_1_power))
+        lr_t = tf.clip_by_value(lr_t / denom, lower_bound, upper_bound)
+        var_update = var.assign_sub(lr_t * m_t, use_locking=self._use_locking)
+
+        updates = [var_update, m_t, v_t]
+        if self.amsgrad:
+            updates.append(vhat_t)
+        return tf.group(*updates)
+
+    def _resource_apply_sparse(self, grad, var, indices):
+        var_dtype = var.dtype.base_dtype
+        lr_t = self._decayed_lr(var_dtype)
+        wd_t = self._decayed_wd(var_dtype)
+        base_lr = self._get_hyper("base_lr", var_dtype)
+        final_lr = self._get_hyper("final_lr", var_dtype)
+        m = self.get_slot(var, "m")
+        v = self.get_slot(var, "v")
+        beta_1_t = self._get_hyper("beta_1", var_dtype)
+        beta_2_t = self._get_hyper("beta_2", var_dtype)
+        gamma = self._get_hyper("gamma", var_dtype)
+        epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
+        local_step = tf.cast(self.iterations + 1, var_dtype)
+        beta_1_power = tf.pow(beta_1_t, local_step)
+        beta_2_power = tf.pow(beta_2_t, local_step)
+
+        if self._has_weight_decay:
+            grad = grad + wd_t * tf.squeeze(tf.gather(tf.expand_dims(var, axis=0), indices, axis=1), axis=0)
 
-            # Apply constraints.
-            if getattr(p, 'constraint', None) is not None:
-                new_p = p.constraint(new_p)
+        m_scaled_g_values = grad * (1 - beta_1_t)
+        m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
+        with tf.control_dependencies([m_t]):
+            m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)
 
-            self.updates.append(K.update(p, new_p))
-        return self.updates
+        v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
+        v_t = v.assign(v * beta_2_t, use_locking=self._use_locking)
+        with tf.control_dependencies([v_t]):
+            v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)
+
+        if self.amsgrad:
+            vhat = self.get_slot(var, "vhat")
+            vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
+            denom = tf.sqrt(vhat_t) + epsilon_t
+        else:
+            vhat_t = None
+            denom = tf.sqrt(v_t) + epsilon_t
+
+        final_lr = final_lr * lr_t / base_lr
+        lower_bound = final_lr * (1.0 - 1.0 / (gamma * local_step + 1.0))
+        upper_bound = final_lr * (1.0 + 1.0 / (gamma * local_step))
+        lr_t = lr_t * (tf.sqrt(1.0 - beta_2_power) / (1.0 - beta_1_power))
+        lr_t = tf.clip_by_value(lr_t / denom, lower_bound, upper_bound)
+        with tf.control_dependencies([m_t]):
+            var_update = self._resource_scatter_add(
+                var, indices, tf.gather(-lr_t * m_t, indices)
+            )
+
+        updates = [var_update, m_t, v_t]
+        if self.amsgrad:
+            updates.append(vhat_t)
+        return tf.group(*updates)
 
     def get_config(self):
-        config = {'lr': float(K.get_value(self.lr)),
-                  'final_lr': float(K.get_value(self.final_lr)),
-                  'base_lr': self.base_lr,
-                  'beta_1': float(K.get_value(self.beta_1)),
-                  'beta_2': float(K.get_value(self.beta_2)),
-                  'gamma': float(K.get_value(self.gamma)),
-                  'decay': float(K.get_value(self.decay)),
-                  'weight_decay': float(K.get_value(self.weight_decay)),
-                  'epsilon': self.epsilon,
-                  'amsgrad': self.amsgrad}
-        base_config = super(AdaBound, self).get_config()
-        return dict(list(base_config.items()) + list(config.items()))
+        config = super().get_config()
+        config.update(
+            {
+                "learning_rate": self._serialize_hyperparameter("learning_rate"),
+                "base_lr": self._serialize_hyperparameter("base_lr"),
+                "final_lr": self._serialize_hyperparameter("final_lr"),
+                "beta_1": self._serialize_hyperparameter("beta_1"),
+                "beta_2": self._serialize_hyperparameter("beta_2"),
+                "gamma": self._serialize_hyperparameter("gamma"),
+                "weight_decay": self._serialize_hyperparameter("weight_decay"),
+                "epsilon": self.epsilon,
+                "amsgrad": self.amsgrad
+            }
+        )
+        return config
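
The rewritten _resource_apply_dense clips the per-parameter Adam step between a lower and an upper bound that both converge to `final_lr`, so training transitions smoothly from Adam towards SGD. A NumPy sketch of a single update step under the same formulas (names are illustrative; this is not a drop-in replica of the hunk above):

import numpy as np

def adabound_step(var, grad, m, v, t, lr=1e-3, base_lr=1e-3, final_lr=0.1,
                  beta_1=0.9, beta_2=0.999, gamma=1e-3, epsilon=1e-8):
    # Adam-style first and second moment estimates.
    m = beta_1 * m + (1.0 - beta_1) * grad
    v = beta_2 * v + (1.0 - beta_2) * grad ** 2
    denom = np.sqrt(v) + epsilon
    # Bias-corrected step size, as in Adam.
    step = lr * np.sqrt(1.0 - beta_2 ** t) / (1.0 - beta_1 ** t)
    # Bounds tighten around final_lr as t grows (gamma controls the speed).
    fl = final_lr * lr / base_lr
    lower = fl * (1.0 - 1.0 / (gamma * t + 1.0))
    upper = fl * (1.0 + 1.0 / (gamma * t))
    clipped = np.clip(step / denom, lower, upper)
    return var - clipped * m, m, v
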
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 7b354ae..7239052 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -6,3 +6,4 @@ nose
 tensorflow
 pycodestyle
 coverage
+adabound
diff --git a/requirements.txt b/requirements.txt
index 2ea7c02..0cf2311 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 numpy
 Keras
+typeguard
diff --git a/setup.py b/setup.py
index 39dea61..8e244ee 100644
--- a/setup.py
+++ b/setup.py
@@ -1,30 +1,43 @@
+import os
+import re
 import codecs
 from setuptools import setup, find_packages
 
+current_path = os.path.abspath(os.path.dirname(__file__))
 
-with codecs.open('README.md', 'r', 'utf8') as reader:
-    long_description = reader.read()
 
+def read_file(*parts):
+    with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
+        return reader.read()
 
-with codecs.open('requirements.txt', 'r', 'utf8') as reader:
-    install_requires = list(map(lambda x: x.strip(), reader.readlines()))
+
+def get_requirements(*parts):
+    with codecs.open(os.path.join(current_path, *parts), 'r', 'utf8') as reader:
+        return list(map(lambda x: x.strip(), reader.readlines()))
+
+
+def find_version(*file_paths):
+    version_file = read_file(*file_paths)
+    version_match = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", version_file, re.M)
+    if version_match:
+        return version_match.group(1)
+    raise RuntimeError('Unable to find version string.')
 
 
 setup(
     name='keras-adabound',
-    version='0.5.0',
+    version=find_version('keras_adabound', '__init__.py'),
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-adabound',
     license='MIT',
     author='CyberZHG',
     author_email='CyberZHG@gmail.com',
     description='AdaBound optimizer in Keras',
-    long_description=long_description,
+    long_description=read_file('README.md'),
     long_description_content_type='text/markdown',
-    install_requires=install_requires,
+    install_requires=get_requirements('requirements.txt'),
     classifiers=(
-        "Programming Language :: Python :: 2.7",
-        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
     ),
diff --git a/test.sh b/test.sh
index 2d69a7b..9cb0ec3 100755
--- a/test.sh
+++ b/test.sh
@@ -1,3 +1,3 @@
 #!/usr/bin/env bash
 pycodestyle --max-line-length=120 keras_adabound tests && \
-    nosetests --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_adabound tests
+    nosetests --nocapture --with-coverage --cover-erase --cover-html --cover-html-dir=htmlcov --cover-package=keras_adabound tests
diff --git a/tests/test_optimize.py b/tests/test_optimize.py
index 6f8954d..445ae0c 100644
--- a/tests/test_optimize.py
+++ b/tests/test_optimize.py
@@ -3,17 +3,19 @@ from unittest import TestCase
 
 import numpy as np
 import tensorflow as tf
-from keras_adabound.backend import keras
-from keras_adabound.backend import backend as K
+from tensorflow import keras
 from keras_adabound import AdaBound
 
 
+K = keras.backend
+
+
 class TestOptimizers(TestCase):
 
     @staticmethod
     def reset_seed(seed):
         np.random.seed(seed)
-        tf.set_random_seed(seed)
+        tf.random.set_seed(seed)
 
     @staticmethod
     def gen_keras_linear(w, b, amsgrad=False):
@@ -26,16 +28,20 @@ def gen_keras_linear(w, b, amsgrad=False):
     def gen_random_weights():
         return np.random.standard_normal((3, 5)), np.random.standard_normal((5,))
 
-    def test_with_constraint(self):
+    def test_with_scheduler(self):
         w, b = self.gen_random_weights()
         model = keras.models.Sequential()
         model.add(keras.layers.Dense(
             input_shape=(3,),
             units=5,
-            kernel_constraint=keras.constraints.max_norm(1.0),
             weights=[w, b]),
         )
-        model.compile(optimizer=AdaBound(lr=1e-3, final_lr=0.1, decay=0.5), loss='mse')
+        decay = tf.keras.optimizers.schedules.ExponentialDecay(0.001, decay_steps=100000, decay_rate=0.96)
+        decay = tf.keras.optimizers.schedules.serialize(decay)
+        model.compile(optimizer=AdaBound(learning_rate=decay,
+                                         final_lr=0.1,
+                                         decay=0.5,
+                                         weight_decay=decay), loss='mse')
         x = np.random.standard_normal((1, 3))
         y = np.dot(x, w) + b
         model.train_on_batch(x, y)
@@ -53,8 +59,62 @@ def test_with_plateau(self):
         x = np.random.standard_normal((10000, 3))
         y = np.dot(x, w) + b
         model.fit(x, y, epochs=100, callbacks=[keras.callbacks.ReduceLROnPlateau(monitor='loss')], verbose=False)
-        model_path = os.path.join(tempfile.gettempdir(), 'keras_adabound_plateau.h5')
-        model.save(model_path)
-        model = keras.models.load_model(model_path, custom_objects={'AdaBound': AdaBound})
+        with tempfile.TemporaryDirectory() as temp_path:
+            model_path = os.path.join(temp_path, 'keras_adabound.h5')
+            model.save(model_path)
+            model = keras.models.load_model(model_path, custom_objects={'AdaBound': AdaBound})
         self.assertGreater(1e-3, float(K.get_value(model.optimizer.lr)))
         self.assertEqual(1e-3, model.optimizer.base_lr)
+
+    def _embedding_data(self):
+        while True:
+            x = np.random.randint(0, 10, (3, 7))
+            y = np.zeros(3)
+            for i in range(3):
+                if 5 in x[i]:
+                    y[i] = 1
+            yield x, y
+
+    def test_with_embedding(self):
+        model = keras.models.Sequential()
+        model.add(keras.layers.Embedding(
+            input_dim=10,
+            output_dim=5,
+            mask_zero=True,
+            input_shape=(7,)),
+        )
+        model.add(keras.layers.LSTM(units=5))
+        model.add(keras.layers.Dense(units=2, activation='softmax'))
+        model.compile(optimizer=AdaBound(), loss='sparse_categorical_crossentropy')
+        model.fit(self._embedding_data(),
+                  steps_per_epoch=1000,
+                  validation_data=self._embedding_data(),
+                  validation_steps=10,
+                  epochs=3)
+
+    def test_with_embedding_amsgrad(self):
+        model = keras.models.Sequential()
+        model.add(keras.layers.Embedding(
+            input_dim=10,
+            mask_zero=True,
+            output_dim=5,
+            input_shape=(7,)),
+        )
+        model.add(keras.layers.LSTM(units=5))
+        model.add(keras.layers.Dense(units=2, activation='softmax'))
+        model.compile(optimizer=AdaBound(amsgrad=True,
+                                         weight_decay=1e-3), loss='sparse_categorical_crossentropy')
+        model.fit(self._embedding_data(),
+                  steps_per_epoch=1000,
+                  validation_data=self._embedding_data(),
+                  validation_steps=10,
+                  epochs=2)
+        with tempfile.TemporaryDirectory() as temp_path:
+            model_path = os.path.join(temp_path, 'keras_adabound.h5')
+            model.save(model_path)
+            model = keras.models.load_model(model_path, custom_objects={'AdaBound': AdaBound})
+        model.fit(self._embedding_data(),
+                  steps_per_epoch=1000,
+                  validation_data=self._embedding_data(),
+                  validation_steps=10,
+                  epochs=1)
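
test_with_scheduler exercises the new dict-valued hyperparameters: a serialized learning-rate schedule can be passed for both `learning_rate` and `weight_decay`, and the constructor deserializes it. A short sketch of that round trip outside the test harness (schedule settings are illustrative):

import tensorflow as tf
from keras_adabound import AdaBound

schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1e-3, decay_steps=1000, decay_rate=0.96)
config = tf.keras.optimizers.schedules.serialize(schedule)

# Either the schedule object or its serialized config can be passed;
# dicts are run through tf.keras.optimizers.schedules.deserialize.
opt = AdaBound(learning_rate=config, weight_decay=config, final_lr=0.1)
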
diff --git a/tests/test_similar.py b/tests/test_similar.py
index ff57edf..63ba914 100644
--- a/tests/test_similar.py
+++ b/tests/test_similar.py
@@ -1,16 +1,14 @@
 import os
-import sys
 import tempfile
 from unittest import TestCase
 
 import torch
 import numpy as np
 import tensorflow as tf
-from keras_adabound.backend import keras
-from keras_adabound.backend import backend as K
+from tensorflow import keras
 from keras_adabound import AdaBound
+from adabound import AdaBound as OfficialAdaBound
 
-if sys.version_info[0] == 3:
-    from adabound import AdaBound as OfficialAdaBound
+K = keras.backend
 
 
 class TestOptimizers(TestCase):
@@ -18,7 +16,7 @@ class TestOptimizers(TestCase):
     @staticmethod
     def reset_seed(seed):
         np.random.seed(seed)
-        tf.set_random_seed(seed)
+        tf.random.set_seed(seed)
         torch.manual_seed(seed)
 
     @staticmethod
@@ -45,8 +43,6 @@ def gen_random_weights():
         return np.random.standard_normal((3, 5)), np.random.standard_normal((5,))
 
     def test_same(self):
-        if sys.version_info[0] < 3:
-            return
         self.reset_seed(0xcafe)
         w, b = self.gen_random_weights()
         torch_linear = self.gen_torch_linear(w, b)
@@ -56,7 +52,7 @@ def test_same(self):
         keras_linear = keras.models.load_model(model_path, custom_objects={'AdaBound': AdaBound})
         w, b = self.gen_random_weights()
         criterion = torch.nn.MSELoss()
-        optimizer = OfficialAdaBound(torch_linear.parameters(), lr=1e-3, final_lr=0.1, eps=K.epsilon())
+        optimizer = OfficialAdaBound(torch_linear.parameters(), lr=1e-3, final_lr=0.1, eps=1e-8)
         for i in range(300):
             x = np.random.standard_normal((1, 3))
             y = np.dot(x, w) + b
@@ -66,23 +62,21 @@ def test_same(self):
             torch_loss = loss.tolist()
             loss.backward()
             optimizer.step()
-            keras_loss = keras_linear.train_on_batch(x, y).tolist()
+            keras_loss = keras_linear.train_on_batch(x, y)
             # print(i, torch_loss, keras_loss)
-            self.assertTrue(abs(torch_loss - keras_loss) < 1e-2)
+            self.assertTrue(abs(torch_loss - keras_loss) < 1e-4)
         self.assertTrue(np.allclose(
             torch_linear.weight.detach().numpy().transpose(),
             keras_linear.get_weights()[0],
-            atol=1e-2,
+            atol=1e-4,
         ))
         self.assertTrue(np.allclose(
             torch_linear.bias.detach().numpy(),
             keras_linear.get_weights()[1],
-            atol=1e-2,
+            atol=1e-4,
         ))
 
     def test_same_amsgrad(self):
-        if sys.version_info[0] < 3:
-            return
         self.reset_seed(0xcafe)
         w, b = self.gen_random_weights()
         torch_linear = self.gen_torch_linear(w, b)
@@ -105,23 +99,21 @@ def test_same_amsgrad(self):
             torch_loss = loss.tolist()
             loss.backward()
             optimizer.step()
-            keras_loss = keras_linear.train_on_batch(x, y).tolist()
+            keras_loss = keras_linear.train_on_batch(x, y)
             # print(i, torch_loss, keras_loss)
-            self.assertTrue(abs(torch_loss - keras_loss) < 1e-2)
+            self.assertTrue(abs(torch_loss - keras_loss) < 1e-4)
         self.assertTrue(np.allclose(
             torch_linear.weight.detach().numpy().transpose(),
             keras_linear.get_weights()[0],
-            atol=1e-2,
+            atol=1e-4,
         ))
         self.assertTrue(np.allclose(
             torch_linear.bias.detach().numpy(),
             keras_linear.get_weights()[1],
-            atol=1e-2,
+            atol=1e-4,
         ))
 
     def test_same_weight_decay(self):
-        if sys.version_info[0] < 3:
-            return
         self.reset_seed(0xcafe)
         w, b = self.gen_random_weights()
         torch_linear = self.gen_torch_linear(w, b)
@@ -144,16 +136,16 @@ def test_same_weight_decay(self):
             torch_loss = loss.tolist()
             loss.backward()
             optimizer.step()
-            keras_loss = keras_linear.train_on_batch(x, y).tolist()
+            keras_loss = keras_linear.train_on_batch(x, y)
             # print(i, torch_loss, keras_loss)
-            self.assertTrue(abs(torch_loss - keras_loss) < 1e-2)
+            self.assertTrue(abs(torch_loss - keras_loss) < 1e-4)
         self.assertTrue(np.allclose(
             torch_linear.weight.detach().numpy().transpose(),
             keras_linear.get_weights()[0],
-            atol=1e-2,
+            atol=1e-4,
         ))
         self.assertTrue(np.allclose(
             torch_linear.bias.detach().numpy(),
             keras_linear.get_weights()[1],
-            atol=1e-2,
+            atol=1e-4,
         ))
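
The tolerances in these parity tests drop from 1e-2 to 1e-4, which relies on both sides using the same epsilon (1e-8) and the same weight-decay mechanism: the decay term is folded into the gradient before the moment updates, as in the official PyTorch repository. A tiny sketch of that folding (values are illustrative):

import numpy as np

def apply_weight_decay(grad, var, weight_decay):
    # Decayed weights are added to the gradient, so they also flow
    # through the first and second moment estimates.
    return grad + weight_decay * var

g = np.array([0.1, -0.2])
w = np.array([1.0, 2.0])
print(apply_weight_decay(g, w, 1e-3))  # [ 0.101 -0.198]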