From 31d79559fe27af30f79628792ecabe41864a03d3 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Wed, 13 Nov 2019 19:56:38 +0000 Subject: [PATCH] fixing base lamb optimizer --- python/mxnet/optimizer/optimizer.py | 22 ++-- src/operator/optimizer_op-inl.h | 152 +++++++++++++++++------- src/operator/optimizer_op.cc | 34 ++++-- src/operator/optimizer_op.cu | 8 +- tests/python/unittest/test_optimizer.py | 27 +++-- 5 files changed, 173 insertions(+), 70 deletions(-) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index d076540cb662..19efe4d9df1a 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -34,7 +34,7 @@ multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, multi_mp_sgd_mom_update, preloaded_multi_sgd_update, preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update, - preloaded_multi_mp_sgd_mom_update, lamb_update) + preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2) from ..ndarray import sparse from ..random import normal from ..util import is_np_array @@ -1250,7 +1250,7 @@ class LAMB(Optimizer): """LAMB Optimizer. """ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, - lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs): + lower_bound=None, upper_bound=None, bias_correction=False, **kwargs): super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs) self.beta1 = beta1 self.beta2 = beta2 @@ -1259,13 +1259,14 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, self.upper_bound = upper_bound self.bias_correction = bias_correction + def create_state(self, index, weight): stype = weight.stype dtype = weight.dtype return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype), zeros(weight.shape, weight.context, dtype=dtype, stype=stype)) - def update(self, index, weight,grad, state): + def update(self, index, weight, grad, state): assert(isinstance(weight, NDArray)) assert(isinstance(grad, NDArray)) self._update_count(index) @@ -1274,14 +1275,21 @@ def update(self, index, weight,grad, state): t = self._index_update_count[index] kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon, - 'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound, 'bias_correction': self.bias_correction, 't': t, 'rescale_grad': self.rescale_grad} + mean, var = state if self.clip_gradient: kwargs['clip_gradient'] = self.clip_gradient - - mean, var = state - lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs) + g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs) + + kwargs = {} + if self.lower_bound: + kwargs['lower_bound'] = self.lower_bound + if self.upper_bound: + kwargs['upper_bound'] = self.upper_bound + r_1 = weight.norm() + r_2 = g.norm() + lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs) # pylint: enable=line-too-long diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 6cf05618f536..7c511f4788ed 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -1563,21 +1563,16 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs, } } -struct LAMBParam : public dmlc::Parameter { - float lr; +struct LambUpdatePhaseOneParam : public dmlc::Parameter { float beta1; float beta2; float epsilon; - float lower_bound; - float upper_bound; float t; bool bias_correction; float wd; float rescale_grad; float clip_gradient; - DMLC_DECLARE_PARAMETER(LAMBParam) { - 
DMLC_DECLARE_FIELD(lr) - .describe("Learning rate"); + DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) { DMLC_DECLARE_FIELD(beta1) .set_default(0.9f) .describe("The decay rate for the 1st moment estimates."); @@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(epsilon) .set_default(1e-6f) .describe("A small constant for numerical stability."); - DMLC_DECLARE_FIELD(lower_bound) - .set_default(1e-3f) - .describe("Lower limit of norm of weight."); - DMLC_DECLARE_FIELD(upper_bound) - .set_default(10.0f) - .describe("Upper limit of norm of weight."); DMLC_DECLARE_FIELD(t) .describe("Index update count."); DMLC_DECLARE_FIELD(bias_correction) .set_default(false) .describe("Whether to use bias correction."); DMLC_DECLARE_FIELD(wd) - .set_default(0.0f) .describe("Weight decay augments the objective function with a " "regularization term that penalizes large weights. " "The penalty scales with the square of the magnitude of each weight."); @@ -1614,19 +1602,33 @@ struct LAMBParam : public dmlc::Parameter { } }; -struct LAMBUpdateKernel { +struct LambUpdatePhaseTwoParam : public dmlc::Parameter { + float lr; + float lower_bound; + float upper_bound; + DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) { + DMLC_DECLARE_FIELD(lr) + .describe("Learning rate"); + DMLC_DECLARE_FIELD(lower_bound) + .set_default(-1.0f) + .describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set"); + DMLC_DECLARE_FIELD(upper_bound) + .set_default(-1.0f) + .describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set"); + } +}; + +struct LambUpdatePhaseOneKernel { template MSHADOW_XINLINE static void Map(int i, DType* out_data, DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data, const DType clip_gradient, const DType rescale_grad, - const DType beta1, const DType beta2, - DType lr, const DType wd, - const DType epsilon, const DType lower_bound, - const DType upper_bound, const DType t, + const DType beta1, const DType beta2, const DType wd, + const DType epsilon, const DType t, bool bias_correction, const OpReqType req) { using namespace mshadow_op; - DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd; + DType grad_rescaled = grad_data[i] * rescale_grad; if (clip_gradient >= 0.f) { grad_rescaled = clip::Map(grad_rescaled, clip_gradient); } @@ -1634,36 +1636,26 @@ struct LAMBUpdateKernel { mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled; var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled; - DType r1 = square_root::Map(square::Map(weight_data[i])); - - r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound); - DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i]; + DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i]; if (bias_correction) { DType mean_hat = mean_data[i] / (1. 
- power::Map(beta1, t)); DType var_hat = var_data[i] / (1 - power::Map(beta2, t)); - g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i]; - } - DType r2 = square_root::Map(square::Map(g)); - if (r1 == 0.0f || r2 == 0.0f) { - lr = lr * 1.0f; - } else { - lr = lr * r1 / r2; + g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i]; } - - KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g); + KERNEL_ASSIGN(out_data[i], req, g); } }; template -inline void LAMBUpdate(const nnvm::NodeAttrs& attrs, +inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { - using namespace mxnet_op; - const LAMBParam& param = nnvm::get(attrs.parsed); - Stream* s = ctx.get_stream(); + using namespace mxnet_op; + const LambUpdatePhaseOneParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { Tensor weight = inputs[0].FlatTo2D(s); Tensor grad = inputs[1].FlatTo2D(s); @@ -1671,17 +1663,91 @@ inline void LAMBUpdate(const nnvm::NodeAttrs& attrs, Tensor var = inputs[3].FlatTo2D(s); Tensor out = outputs[0].FlatTo2D(s); - Kernel::Launch(s, weight.shape_.Size(), + Kernel::Launch(s, weight.shape_.Size(), out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_, static_cast(param.clip_gradient), static_cast(param.rescale_grad), static_cast(param.beta1), static_cast(param.beta2), - static_cast(param.lr), static_cast(param.wd), - static_cast(param.epsilon), static_cast(param.lower_bound), - static_cast(param.upper_bound), static_cast(param.t), - static_cast(param.bias_correction), req[0]); - }); + static_cast(param.wd), static_cast(param.epsilon), + static_cast(param.t), static_cast(param.bias_correction), req[0]); + }); +} + +inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 4U); + CHECK_EQ(out_attrs->size(), 1U); + + mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1); + + mxnet::TShape& weight_shape = in_attrs->at(0); + mxnet::TShape& g_shape = in_attrs->at(1); + CHECK_EQ(weight_shape.ndim(), g_shape.ndim()) + << "total no. 
of dimensions for weights and g must match"; + for (int i=0; i < weight_shape.ndim(); ++i) { + CHECK_EQ(weight_shape[i], g_shape[i]) + << "weight and g dimension size mismatch at " << i << "-th index"; + } + mxnet::TShape& r1_shape = in_attrs->at(2); + mxnet::TShape& r2_shape = in_attrs->at(3); + CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect"; + CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect"; + for (int i=0; i < expected_out.ndim(); ++i) { + expected_out[i] = weight_shape[i]; + } + + SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out); + return shape_is_known(expected_out); } +struct LambUpdatePhaseTwoKernel { + template + MSHADOW_XINLINE static void Map(int i, DType* out_data, + const DType* weight_data, const DType* g, + const DType* r1, const DType* r2, + DType lr, const DType lower_bound, + const DType upper_bound, const OpReqType req) { + using namespace mshadow_op; + + DType new_r1 = r1[0]; + if (lower_bound >= 0) { + new_r1 = maximum::Map(new_r1, lower_bound); + } + if (upper_bound >= 0) { + new_r1 = minimum::Map(new_r1, upper_bound); + } + if (new_r1 == 0.0f || r2[0] == 0.0f) { + lr = lr * 1.0f; + } else { + lr = lr * new_r1 / r2[0]; + } + + KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]); + } +}; + +template +inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + const LambUpdatePhaseTwoParam& param = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { + Tensor weight = inputs[0].FlatTo2D(s); + Tensor g = inputs[1].FlatTo2D(s); + Tensor r1 = inputs[2].FlatTo2D(s); + Tensor r2 = inputs[3].FlatTo2D(s); + Tensor out = outputs[0].FlatTo2D(s); + + Kernel::Launch(s, weight.shape_.Size(), + out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_, + static_cast(param.lr), static_cast(param.lower_bound), + static_cast(param.upper_bound), req[0]); + }); +} // This RMSProp code follows the version in // http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45) diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 259fa0397fc7..ff248861788a 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -43,7 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam); DMLC_REGISTER_PARAMETER(SignSGDParam); DMLC_REGISTER_PARAMETER(SignumParam); DMLC_REGISTER_PARAMETER(AdagradParam); -DMLC_REGISTER_PARAMETER(LAMBParam); +DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam); +DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam); NNVM_REGISTER_OP(signsgd_update) .describe(R"code(Update function for SignSGD optimizer. @@ -922,20 +923,39 @@ Note that non-zero values for the weight decay option are not supported. .add_argument("history", "NDArray-or-Symbol", "History") .add_arguments(AdagradParam::__FIELDS__()); -NNVM_REGISTER_OP(lamb_update) +NNVM_REGISTER_OP(lamb_update_phase1) .describe(R"code(Update function for lamb optimizer. 
)code" ADD_FILELINE) .set_num_inputs(4) .set_num_outputs(1) -.set_attr_parser(ParamParser) -.set_attr("FInferShape", ElemwiseShape<4,1>) -.set_attr("FInferType", ElemwiseType<4,1>) -.set_attr("FCompute", LAMBUpdate) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<4, 1>) +.set_attr("FInferType", ElemwiseType<4, 1>) +.set_attr("FCompute", LambUpdatePhaseOne) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + return std::vector{2, 3}; + }) .add_argument("weight", "NDArray-or-Symbol", "Weight") .add_argument("grad", "NDArray-or-Symbol", "Gradient") .add_argument("mean", "NDArray-or-Symbol", "Moving mean") .add_argument("var", "NDArray-or-Symbol", "Moving variance") -.add_arguments(LAMBParam::__FIELDS__()); +.add_arguments(LambUpdatePhaseOneParam::__FIELDS__()); + +NNVM_REGISTER_OP(lamb_update_phase2) +.describe(R"code(Update function for lamb optimizer. +)code" ADD_FILELINE) +.set_num_inputs(4) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", LambUpdatePhaseTwoShape) +.set_attr("FInferType", ElemwiseType<4, 1>) +.set_attr("FCompute", LambUpdatePhaseTwo) +.add_argument("weight", "NDArray-or-Symbol", "Weight") +.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1") +.add_argument("r1", "NDArray-or-Symbol", "r1") +.add_argument("r2", "NDArray-or-Symbol", "r2") +.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index d9a1e3414178..a602b649b63d 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -277,8 +277,12 @@ NNVM_REGISTER_OP(ftrl_update) NNVM_REGISTER_OP(_sparse_adagrad_update) .set_attr("FComputeEx", AdagradUpdateEx); -NNVM_REGISTER_OP(lamb_update) -.set_attr("FCompute", LambUpdate); +NNVM_REGISTER_OP(lamb_update_phase1) +.set_attr("FCompute", LambUpdatePhaseOne); + +NNVM_REGISTER_OP(lamb_update_phase2) +.set_attr("FCompute", LambUpdatePhaseTwo); + } // namespace op } // namespace mxnet diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py index fce27a4719f0..892b595c4b04 100644 --- a/tests/python/unittest/test_optimizer.py +++ b/tests/python/unittest/test_optimizer.py @@ -432,7 +432,7 @@ class PyLAMB(mx.optimizer.Optimizer): Python reference implementation of LAMB optimizer. """ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6, - lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs): + lower_bound=None, upper_bound=None, bias_correction=False, **kwargs): super(PyLAMB, self).__init__(learning_rate=learning_rate, **kwargs) self.beta1 = beta1 self.beta2 = beta2 @@ -454,31 +454,34 @@ def update(self, index, weight, grad, state): grad *= self.rescale_grad if self.clip_gradient is not None: - grad = clip(grad, -self.clip_gradient, self.clip_gradient) + grad = mx.nd.clip(grad, -self.clip_gradient, self.clip_gradient) mean, var = state mean[:] = self.beta1 * mean + (1. - self.beta1) * grad var[:] = self.beta2 * var + (1. - self.beta2) * mx.nd.square(grad) + mean_hat = mean + var_hat = var r1 = weight.norm() - if not self.bias_correction: - r1 = mx.nd.minimum(mx.nd.maximum(r1, self.lower_bound), self.upper_bound) - g = mean / (mx.nd.sqrt(var) + self.epsilon) + wd * weight - - else: + if self.lower_bound: + r1 = mx.nd.maximum(r1, self.lower_bound) + if self.upper_bound: + r1 = mx.nd.minimum(r1, self.upper_bound) + if self.bias_correction: mean_hat = mean / (1. 
- mx.nd.power(self.beta1, t)) var_hat = var / (1. - mx.nd.power(self.beta2, t)) - g = mean_hat / mx.nd.sqrt(var_hat + self.epsilon) + wd * weight + g = mean_hat / (mx.nd.sqrt(var_hat) + self.epsilon) + wd * weight r2 = g.norm() - # calculate lamb_trust_ratio r = 1. if r1 == 0. or r2 == 0. else r1 / r2 lr *= r - # update weight weight[:] -= lr * g + def update_multi_precision(self, index, weight, grad, state): + self.update(index, weight, grad, state) + @with_seed() def test_lamb(): opt1 = PyLAMB @@ -488,7 +491,9 @@ def test_lamb(): rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}] wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}] bc_options = [{}, {'bias_correction': False}, {'bias_correction': True}] - for params in itertools.product(cg_options, rg_options, wd_options, bc_options): + lb_options = [{}, {'lower_bound': None}, {'lower_bound': 1e-3}] + ub_options = [{}, {'upper_bound': None}, {'upper_bound': 10}] + for params in itertools.product(cg_options, rg_options, wd_options, bc_options, lb_options, ub_options): kwarg = {k: v for param in params for k, v in param.items()} compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32)
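
Note on the two-phase split implemented above: the removed fused kernel computed r1 and r2 element-wise via square_root(square(x)), while LAMB needs layer-wise norms of the weight and of the update direction. Splitting the operator lets the frontend compute weight.norm() and g.norm() between lamb_update_phase1 (moment update, producing g) and lamb_update_phase2 (trust-ratio scaling), and makes lower_bound/upper_bound optional (default -1.0 in C++, None in Python, meaning no clamping). The following is a minimal NumPy sketch of the same math for reference; the helper names lamb_phase1/lamb_phase2 and the use of NumPy instead of mx.nd are illustrative only and not part of the patch.

import numpy as np


def lamb_phase1(weight, grad, mean, var, t, beta1=0.9, beta2=0.999,
                epsilon=1e-6, wd=0.0, rescale_grad=1.0, clip_gradient=None,
                bias_correction=False):
    # Mirrors LambUpdatePhaseOneKernel: update the moments in place and
    # return the raw update direction g (the output of lamb_update_phase1).
    g = grad * rescale_grad
    if clip_gradient is not None:
        g = np.clip(g, -clip_gradient, clip_gradient)
    mean[:] = beta1 * mean + (1. - beta1) * g
    var[:] = beta2 * var + (1. - beta2) * g * g
    mean_hat, var_hat = mean, var
    if bias_correction:
        mean_hat = mean / (1. - beta1 ** t)
        var_hat = var / (1. - beta2 ** t)
    return mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight


def lamb_phase2(weight, g, lr, lower_bound=None, upper_bound=None):
    # Mirrors LambUpdatePhaseTwoKernel: clamp r1 only when a bound is given,
    # scale lr by the trust ratio r1 / r2, and apply the update in place.
    r1 = np.linalg.norm(weight)
    if lower_bound is not None:
        r1 = max(r1, lower_bound)
    if upper_bound is not None:
        r1 = min(r1, upper_bound)
    r2 = np.linalg.norm(g)
    trust_ratio = 1. if r1 == 0. or r2 == 0. else r1 / r2
    weight[:] -= lr * trust_ratio * g

Phase 1 is a purely element-wise map, while r1 and r2 are whole-tensor reductions, which is why they are computed between the two operator calls. Usage from the frontend mirrors LAMB.update above: g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs), then lamb_update_phase2(weight, g, weight.norm(), g.norm(), lr=lr, out=weight).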