This repository has been archived by the owner on Mar 3, 2024. It is now read-only.

Commit

Update to tf.keras
CyberZHG committed Jul 11, 2020
1 parent fed9103 commit 71ca69e
Showing 10 changed files with 285 additions and 194 deletions.
30 changes: 6 additions & 24 deletions .travis.yml
@@ -1,19 +1,11 @@
dist: xenial
language: python
python:
- 2.7
- 3.6
python: "3.6"
env:
- KERAS_BACKEND=tensorflow
- KERAS_BACKEND=tensorflow TF_KERAS=1
- KERAS_BACKEND=theano THEANO_FLAGS=optimizer=fast_compile
# - KERAS_BACKEND=cntk PYTHONWARNINGS=ignore
global:
- COVERALLS_PARALLEL=true
install:
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
wget https://repo.continuum.io/miniconda/Miniconda-latest-Linux-x86_64.sh -O miniconda.sh;
else
wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
fi
- wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh;
- bash miniconda.sh -b -p $HOME/miniconda
- export PATH="$HOME/miniconda/bin:$PATH"
- conda config --set always_yes yes --set changeps1 no
@@ -25,20 +17,10 @@ install:
- pip install --upgrade pip
- pip install -r requirements.txt
- pip install -r requirements-dev.txt
- if [[ $KERAS_BACKEND == "theano" ]]; then pip install theano && conda install mkl mkl-service; fi
- if [[ "$KERAS_BACKEND" == "cntk" ]]; then
set -e &&
pip install cntk &&
mkdir -p ~/mpi &&
pushd ~/mpi &&
wget http://cntk.ai/PythonWheel/ForKeras/depends/openmpi_1.10-3.zip &&
unzip ./openmpi_1.10-3.zip &&
sudo dpkg -i openmpi_1.10-3.deb &&
popd;
fi
- if [[ $TRAVIS_PYTHON_VERSION == 3.6 ]]; then pip install adabound; fi
- pip install coveralls
script:
- ./test.sh
after_success:
coveralls
notifications:
webhooks: https://coveralls.io/webhook
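With the Python 2.7, Theano, and CNTK rows gone, the only remaining CI configuration is Python 3.6 with TensorFlow's bundled Keras (TF_KERAS=1). A quick sanity check one might run in that environment, assuming the package is installed and importable as keras_adabound (an editor's sketch, not part of the commit):

import tensorflow as tf
from keras_adabound import AdaBound

# After this commit the optimizer subclasses the tf.keras base class,
# so this check should pass in the TensorFlow-only CI job.
assert issubclass(AdaBound, tf.keras.optimizers.Optimizer)
print("TensorFlow", tf.__version__)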
2 changes: 2 additions & 0 deletions keras_adabound/__init__.py
@@ -1 +1,3 @@
from .optimizers import *

__version__ = '0.6.0'
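Because AdaBound now subclasses tf.keras.optimizers.Optimizer, it can be passed directly to model.compile. A minimal usage sketch (the model, shapes, and random data below are invented for illustration and assume the package is installed):

import numpy as np
import tensorflow as tf
from keras_adabound import AdaBound

model = tf.keras.Sequential([
    tf.keras.layers.Dense(32, activation="relu", input_shape=(16,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer=AdaBound(learning_rate=1e-3, final_lr=0.1), loss="mse")

x = np.random.random((64, 16)).astype("float32")
y = np.random.random((64, 1)).astype("float32")
model.fit(x, y, epochs=1, verbose=0)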
46 changes: 0 additions & 46 deletions keras_adabound/backend.py

This file was deleted.
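The deleted module was presumably the small compatibility shim that picked between standalone Keras and tf.keras at import time via the TF_KERAS environment variable; with the package now targeting tf.keras only, it becomes redundant. A hypothetical sketch of that kind of shim (the exact contents of the deleted file are not shown in this diff):

import os

# Hypothetical keras / tf.keras switch of the sort such a shim provides.
if os.environ.get("TF_KERAS", "0") != "0":
    from tensorflow import keras
else:
    import keras

backend = keras.backend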

246 changes: 166 additions & 80 deletions keras_adabound/optimizers.py
@@ -1,14 +1,22 @@
from .backend import keras
from .backend import backend as K
from typing import Union, Callable, Dict, Optional

import numpy as np
import tensorflow as tf
from tensorflow import keras
from typeguard import typechecked


K = keras.backend
FloatTensorLike = Union[tf.Tensor, float, np.float16, np.float32, np.float64]


class AdaBound(keras.optimizers.Optimizer):
"""AdamBound optimizer.
# Arguments
lr: float >= 0. Learning rate.
final_lr: float >= 0. Final (SGD) learning rate.
learning_rate: float >= 0. Learning rate.
base_lr: float >= 0. Used for loading the optimizer. Do not set the argument manually.
final_lr: float >= 0. Final (SGD) learning rate.
beta_1: float, 0 < beta < 1. Generally close to 1.
beta_2: float, 0 < beta < 1. Generally close to 1.
gamma: float, 0 < gamma < 1. Convergence speed of the bound functions.
@@ -22,91 +30,169 @@ class AdaBound(keras.optimizers.Optimizer):
(https://openreview.net/forum?id=Bkg3g2R9FX)
"""

def __init__(self, lr=0.001, final_lr=0.1, base_lr=None,
beta_1=0.9, beta_2=0.999, gamma=0.001,
epsilon=None, decay=0., weight_decay=0., amsgrad=False, **kwargs):
super(AdaBound, self).__init__(**kwargs)
with K.name_scope(self.__class__.__name__):
self.iterations = K.variable(0, dtype='int64', name='iterations')
self.lr = K.variable(lr, name='lr')
self.final_lr = K.variable(final_lr, name='final_lr')
self.beta_1 = K.variable(beta_1, name='beta_1')
self.beta_2 = K.variable(beta_2, name='beta_2')
self.gamma = K.variable(gamma, name='gamma')
self.decay = K.variable(decay, name='decay')
self.weight_decay = K.variable(weight_decay, name='weight_decay')
if epsilon is None:
epsilon = K.epsilon()
@typechecked
def __init__(
self,
learning_rate: Union[FloatTensorLike, Callable, Dict] = 0.001,
base_lr: Optional[FloatTensorLike] = None,
final_lr: FloatTensorLike = 0.1,
beta_1: FloatTensorLike = 0.9,
beta_2: FloatTensorLike = 0.999,
gamma: FloatTensorLike = 0.001,
epsilon: FloatTensorLike = 1e-8,
weight_decay: Union[FloatTensorLike, Callable, Dict] = 0.0,
amsgrad: bool = False,
name: str = "AdaBound",
**kwargs
):
super(AdaBound, self).__init__(name=name, **kwargs)

if isinstance(learning_rate, Dict):
learning_rate = tf.keras.optimizers.schedules.deserialize(learning_rate)

if isinstance(weight_decay, Dict):
weight_decay = tf.keras.optimizers.schedules.deserialize(weight_decay)

if base_lr is None:
self.base_lr = lr
else:
self.base_lr = base_lr
self.epsilon = epsilon
self.initial_decay = decay
self.initial_weight_decay = weight_decay
if isinstance(learning_rate, tf.keras.optimizers.schedules.LearningRateSchedule):
base_lr = learning_rate(0)
else:
base_lr = learning_rate

self._set_hyper("learning_rate", kwargs.get("lr", learning_rate))
self._set_hyper("base_lr", base_lr)
self._set_hyper("final_lr", final_lr)
self._set_hyper("beta_1", beta_1)
self._set_hyper("beta_2", beta_2)
self._set_hyper("gamma", gamma)
self._set_hyper("decay", self._initial_decay)
self._set_hyper("weight_decay", weight_decay)
self.epsilon = epsilon or tf.keras.backend.epsilon()
self.amsgrad = amsgrad
self._has_weight_decay = weight_decay != 0.0

def _create_slots(self, var_list):
for var in var_list:
self.add_slot(var, "m")
for var in var_list:
self.add_slot(var, "v")
if self.amsgrad:
for var in var_list:
self.add_slot(var, "vhat")

def get_updates(self, loss, params):
grads = self.get_gradients(loss, params)
self.updates = [K.update_add(self.iterations, 1)]
def _decayed_wd(self, var_dtype):
wd_t = self._get_hyper("weight_decay", var_dtype)
if isinstance(wd_t, tf.keras.optimizers.schedules.LearningRateSchedule):
wd_t = tf.cast(wd_t(self.iterations), var_dtype)
return wd_t

lr = self.lr
if self.initial_decay > 0:
lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
K.dtype(self.decay))))
def _resource_apply_dense(self, grad, var):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
wd_t = self._decayed_wd(var_dtype)
base_lr = self._get_hyper("base_lr", var_dtype)
final_lr = self._get_hyper("final_lr", var_dtype)
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta_1_t = self._get_hyper("beta_1", var_dtype)
beta_2_t = self._get_hyper("beta_2", var_dtype)
gamma = self._get_hyper("gamma", var_dtype)
epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
local_step = tf.cast(self.iterations + 1, var_dtype)
beta_1_power = tf.pow(beta_1_t, local_step)
beta_2_power = tf.pow(beta_2_t, local_step)

t = K.cast(self.iterations, K.floatx()) + 1
lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
(1. - K.pow(self.beta_1, t)))
final_lr = self.final_lr * lr / self.base_lr
lower_bound = final_lr * (1.0 - 1.0 / (self.gamma * t + 1.0))
upper_bound = final_lr * (1.0 + 1.0 / (self.gamma * t))
if self._has_weight_decay:
grad += wd_t * var

m_t = m.assign(beta_1_t * m + (1.0 - beta_1_t) * grad, use_locking=self._use_locking)
v_t = v.assign(beta_2_t * v + (1.0 - beta_2_t) * tf.square(grad), use_locking=self._use_locking)

ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
if self.amsgrad:
vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
vhat = self.get_slot(var, "vhat")
vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
denom = tf.sqrt(vhat_t) + epsilon_t
else:
vhats = [K.zeros(1) for _ in params]
self.weights = [self.iterations] + ms + vs + vhats

for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
if self.initial_weight_decay > 0.:
# Note that the decayed weights are added to the momentums.
# The mechanism is the same as the official repo.
g += self.weight_decay * p

m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)

if self.amsgrad:
vhat_t = K.maximum(vhat, v_t)
step = lr_t / (K.sqrt(vhat_t) + self.epsilon)
self.updates.append(K.update(vhat, vhat_t))
else:
step = lr_t / (K.sqrt(v_t) + self.epsilon)
p_t = p - K.minimum(K.maximum(step, lower_bound), upper_bound) * m_t
self.updates.append(K.update(m, m_t))
self.updates.append(K.update(v, v_t))
new_p = p_t
vhat_t = None
denom = tf.sqrt(v_t) + epsilon_t

final_lr = final_lr * lr_t / base_lr
lower_bound = final_lr * (1.0 - 1.0 / (gamma * local_step + 1.0))
upper_bound = final_lr * (1.0 + 1.0 / (gamma * local_step))
lr_t = lr_t * (tf.sqrt(1.0 - beta_2_power) / (1.0 - beta_1_power))
lr_t = tf.clip_by_value(lr_t / denom, lower_bound, upper_bound)
var_update = var.assign_sub(lr_t * m_t, use_locking=self._use_locking)

updates = [var_update, m_t, v_t]
if self.amsgrad:
updates.append(vhat_t)
return tf.group(*updates)

def _resource_apply_sparse(self, grad, var, indices):
var_dtype = var.dtype.base_dtype
lr_t = self._decayed_lr(var_dtype)
wd_t = self._decayed_wd(var_dtype)
base_lr = self._get_hyper("base_lr", var_dtype)
final_lr = self._get_hyper("final_lr", var_dtype)
m = self.get_slot(var, "m")
v = self.get_slot(var, "v")
beta_1_t = self._get_hyper("beta_1", var_dtype)
beta_2_t = self._get_hyper("beta_2", var_dtype)
gamma = self._get_hyper("gamma", var_dtype)
epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
local_step = tf.cast(self.iterations + 1, var_dtype)
beta_1_power = tf.pow(beta_1_t, local_step)
beta_2_power = tf.pow(beta_2_t, local_step)

if self._has_weight_decay:
grad = grad + wd_t * tf.squeeze(tf.gather(tf.expand_dims(var, axis=0), indices, axis=1), axis=0)

# Apply constraints.
if getattr(p, 'constraint', None) is not None:
new_p = p.constraint(new_p)
m_scaled_g_values = grad * (1 - beta_1_t)
m_t = m.assign(m * beta_1_t, use_locking=self._use_locking)
with tf.control_dependencies([m_t]):
m_t = self._resource_scatter_add(m, indices, m_scaled_g_values)

self.updates.append(K.update(p, new_p))
return self.updates
v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
v_t = v.assign(v * beta_2_t, use_locking=self._use_locking)
with tf.control_dependencies([v_t]):
v_t = self._resource_scatter_add(v, indices, v_scaled_g_values)

if self.amsgrad:
vhat = self.get_slot(var, "vhat")
vhat_t = vhat.assign(tf.maximum(vhat, v_t), use_locking=self._use_locking)
denom = tf.sqrt(vhat_t) + epsilon_t
else:
vhat_t = None
denom = tf.sqrt(v_t) + epsilon_t

final_lr = final_lr * lr_t / base_lr
lower_bound = final_lr * (1.0 - 1.0 / (gamma * local_step + 1.0))
upper_bound = final_lr * (1.0 + 1.0 / (gamma * local_step))
lr_t = lr_t * (tf.sqrt(1.0 - beta_2_power) / (1.0 - beta_1_power))
lr_t = tf.clip_by_value(lr_t / denom, lower_bound, upper_bound)
with tf.control_dependencies([m_t]):
var_update = self._resource_scatter_add(
var, indices, tf.gather(-lr_t * m_t, indices)
)

updates = [var_update, m_t, v_t]
if self.amsgrad:
updates.append(vhat_t)
return tf.group(*updates)

def get_config(self):
config = {'lr': float(K.get_value(self.lr)),
'final_lr': float(K.get_value(self.final_lr)),
'base_lr': self.base_lr,
'beta_1': float(K.get_value(self.beta_1)),
'beta_2': float(K.get_value(self.beta_2)),
'gamma': float(K.get_value(self.gamma)),
'decay': float(K.get_value(self.decay)),
'weight_decay': float(K.get_value(self.weight_decay)),
'epsilon': self.epsilon,
'amsgrad': self.amsgrad}
base_config = super(AdaBound, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super().get_config()
config.update(
{
"learning_rate": self._serialize_hyperparameter("learning_rate"),
"base_lr": self._serialize_hyperparameter("base_lr"),
"final_lr": self._serialize_hyperparameter("final_lr"),
"beta_1": self._serialize_hyperparameter("beta_1"),
"beta_2": self._serialize_hyperparameter("beta_2"),
"gamma": self._serialize_hyperparameter("gamma"),
"weight_decay": self._serialize_hyperparameter("weight_decay"),
"epsilon": self.epsilon,
"amsgrad": self.amsgrad
}
)
return config
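The clipping bounds computed in _resource_apply_dense and _resource_apply_sparse tighten around final_lr as training progresses, which is what moves AdaBound from Adam-like steps toward SGD-like steps. A small standalone sketch of that schedule, reusing the bound formulas from the diff with arbitrary example values (the diff additionally rescales final_lr by lr_t / base_lr, a factor of 1 when the learning rate is constant):

final_lr, gamma = 0.1, 0.001

for t in (1, 10, 100, 1_000, 10_000, 100_000):
    lower = final_lr * (1.0 - 1.0 / (gamma * t + 1.0))
    upper = final_lr * (1.0 + 1.0 / (gamma * t))
    # Both bounds converge to final_lr, so the clipped step
    # lr_t / denom is eventually pinned to the SGD learning rate.
    print(f"step {t:>7}: bounds = [{lower:.5f}, {upper:.5f}]")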
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -6,3 +6,4 @@ nose
tensorflow
pycodestyle
coverage
adabound
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,2 +1,3 @@
numpy
Keras
typeguard
