Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
initial commit lamb optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudhacharya authored and Rohit Kumar Srivastava committed Nov 14, 2019
1 parent 7e21bda commit 4f5ebff
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 2 deletions.
44 changes: 42 additions & 2 deletions python/mxnet/optimizer/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update)
preloaded_multi_mp_sgd_mom_update, lamb_update)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array

__all__ = [
'AdaDelta', 'AdaGrad', 'Adam', 'Adamax', 'DCASGD', 'FTML', 'Ftrl', 'LARS', 'LBSGD',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum',
'NAG', 'NDabs', 'Nadam', 'Optimizer', 'RMSProp', 'SGD', 'SGLD', 'Signum', 'LAMB',
'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register'
]

Expand Down Expand Up @@ -1244,6 +1244,46 @@ def update(self, index, weight, grad, state):
kwargs = {}
sgd_update(weight, grad, out=weight, lr=lr, wd=wd, **kwargs)


@register
class LAMB(Optimizer):
"""LAMB Optimizer.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.bias_correction = bias_correction

def create_state(self, index, weight):
stype = weight.stype
dtype = weight.dtype
return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

def update(self, index, weight,grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
'bias_correction': self.bias_correction, 't': t,
'rescale_grad': self.rescale_grad}
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)


# pylint: enable=line-too-long
@register
class DCASGD(Optimizer):
Expand Down
120 changes: 120 additions & 0 deletions src/operator/optimizer_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1563,6 +1563,126 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}
}

struct LAMBParam : public dmlc::Parameter<LAMBParam> {
float lr;
float beta1;
float beta2;
float epsilon;
float lower_bound;
float upper_bound;
float t;
bool bias_correction;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(LAMBParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(beta1)
.set_default(0.9f)
.describe("The decay rate for the 1st moment estimates.");
DMLC_DECLARE_FIELD(beta2)
.set_default(0.999f)
.describe("The decay rate for the 2nd moment estimates.");
DMLC_DECLARE_FIELD(epsilon)
.set_default(1e-6f)
.describe("A small constant for numerical stability.");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(1e-3f)
.describe("Lower limit of norm of weight.");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(10.0f)
.describe("Upper limit of norm of weight.");
DMLC_DECLARE_FIELD(t)
.describe("Index update count.");
DMLC_DECLARE_FIELD(bias_correction)
.set_default(false)
.describe("Whether to use bias correction.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
DMLC_DECLARE_FIELD(rescale_grad)
.set_default(1.0f)
.describe("Rescale gradient to grad = rescale_grad*grad.");
DMLC_DECLARE_FIELD(clip_gradient)
.set_default(-1.0f)
.describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
"If clip_gradient <= 0, gradient clipping is turned off. "
"grad = max(min(grad, clip_gradient), -clip_gradient).");
}
};

struct LAMBUpdateKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta2,
DType lr, const DType wd,
const DType epsilon, const DType lower_bound,
const DType upper_bound, const DType t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
if (clip_gradient >= 0.f) {
grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
}

mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

DType r1 = square_root::Map(square::Map(weight_data[i]));

r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];

if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
}
DType r2 = square_root::Map(square::Map(g));
if (r1 == 0.0f || r2 == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * r1 / r2;
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
}
};

template<typename xpu>
inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LAMBUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
static_cast<bool>(param.bias_correction), req[0]);
});
}


// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
// by Alex Graves, 2013.
Expand Down
16 changes: 16 additions & 0 deletions src/operator/optimizer_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LAMBParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.
Expand Down Expand Up @@ -921,5 +922,20 @@ Note that non-zero values for the weight decay option are not supported.
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update)
.describe(R"code(Update function for lamb optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LAMBParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4,1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4,1>)
.set_attr<FCompute>("FCompute<cpu>", LAMBUpdate<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LAMBParam::__FIELDS__());

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/optimizer_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -277,5 +277,8 @@ NNVM_REGISTER_OP(ftrl_update)
NNVM_REGISTER_OP(_sparse_adagrad_update)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

NNVM_REGISTER_OP(lamb_update)
.set_attr<FCompute>("FCompute<gpu>", LambUpdate<gpu>);

} // namespace op
} // namespace mxnet
68 changes: 68 additions & 0 deletions tests/python/unittest/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,74 @@ def test_nag():
continue
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, dtype, rtol=1e-3, atol=1e-4)


# LAMB optimizer
class PyLAMB(mx.optimizer.Optimizer):
"""
Python reference implementation of LAMB optimizer.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
super(PyLAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
self.epsilon = epsilon
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.bias_correction = bias_correction

def create_state(self, index, weight):
stype = weight.stype
return (mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype),
mx.nd.zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype))

def update(self, index, weight, grad, state):
self._update_count(index)
lr = self._get_lr(index)
wd = self._get_wd(index)
t = self._index_update_count[index]

grad *= self.rescale_grad
if self.clip_gradient is not None:
grad = clip(grad, -self.clip_gradient, self.clip_gradient)

mean, var = state
mean[:] = self.beta1 * mean + (1. - self.beta1) * grad
var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad)

r1 = weight.norm()
if not self.bias_correction:
r1 = mx.nd.minimum(mx.nd.maximum(r1, self.lower_bound), self.upper_bound)
g = mean / (mx.nd.sqrt(var) + self.epsilon) + wd * weight

else:
mean_hat = mean / (1. - mx.nd.power(self.beta1, t))
var_hat = var / (1. - mx.nd.power(self.beta2, t))
g = mean_hat / mx.nd.sqrt(var_hat + self.epsilon) + wd * weight

r2 = g.norm()

# calculate lamb_trust_ratio
r = 1. if r1 == 0. or r2 == 0. else r1 / r2
lr *= r

# update weight
weight[:] -= lr * g

@with_seed()
def test_lamb():
opt1 = PyLAMB
opt2 = mx.optimizer.LAMB
shape = (3, 4, 5)
cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
bc_options = [{}, {'bias_correction': False}, {'bias_correction': True}]
for params in itertools.product(cg_options, rg_options, wd_options, bc_options):
kwarg = {k: v for param in params for k, v in param.items()}
compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)


#SGLD
class PySGLD(mx.optimizer.Optimizer):
"""python reference implementation of SGLD"""
Expand Down

0 comments on commit 4f5ebff

Please sign in to comment.