This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

fixing base lamb optimizer
Rohit Kumar Srivastava committed Nov 14, 2019
1 parent 4f5ebff commit 31d7955
Showing 5 changed files with 173 additions and 70 deletions.
22 changes: 15 additions & 7 deletions python/mxnet/optimizer/optimizer.py
@@ -34,7 +34,7 @@
multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update,
multi_mp_sgd_mom_update, preloaded_multi_sgd_update,
preloaded_multi_sgd_mom_update, preloaded_multi_mp_sgd_update,
preloaded_multi_mp_sgd_mom_update, lamb_update)
preloaded_multi_mp_sgd_mom_update, lamb_update_phase1, lamb_update_phase2)
from ..ndarray import sparse
from ..random import normal
from ..util import is_np_array
@@ -1250,7 +1250,7 @@ class LAMB(Optimizer):
"""LAMB Optimizer.
"""
def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
lower_bound=1e-3, upper_bound=10.0, bias_correction=False, **kwargs):
lower_bound=None, upper_bound=None, bias_correction=False, **kwargs):
super(LAMB, self).__init__(learning_rate=learning_rate, **kwargs)
self.beta1 = beta1
self.beta2 = beta2
@@ -1259,13 +1259,14 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-6,
self.upper_bound = upper_bound
self.bias_correction = bias_correction


def create_state(self, index, weight):
stype = weight.stype
dtype = weight.dtype
return (zeros(weight.shape, weight.context, dtype=dtype, stype=stype),
zeros(weight.shape, weight.context, dtype=dtype, stype=stype))

def update(self, index, weight,grad, state):
def update(self, index, weight, grad, state):
assert(isinstance(weight, NDArray))
assert(isinstance(grad, NDArray))
self._update_count(index)
@@ -1274,14 +1275,21 @@ def update(self, index, weight,grad, state):
t = self._index_update_count[index]

kwargs = {'beta1': self.beta1, 'beta2': self.beta2, 'epsilon': self.epsilon,
'lower_bound': self.lower_bound, 'upper_bound': self.upper_bound,
'bias_correction': self.bias_correction, 't': t,
'rescale_grad': self.rescale_grad}
mean, var = state
if self.clip_gradient:
kwargs['clip_gradient'] = self.clip_gradient

mean, var = state
lamb_update(weight, grad, mean, var, out=weight, lr=lr, wd=wd, **kwargs)
g = lamb_update_phase1(weight, grad, mean, var, wd=wd, **kwargs)

kwargs = {}
if self.lower_bound:
kwargs['lower_bound'] = self.lower_bound
if self.upper_bound:
kwargs['upper_bound'] = self.upper_bound
r_1 = weight.norm()
r_2 = g.norm()
lamb_update_phase2(weight, g, r_1, r_2, lr=lr, out=weight, **kwargs)


# pylint: enable=line-too-long
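For context, the rewritten `update()` above now drives the step in two operator calls instead of the single fused `lamb_update`. A minimal sketch of exercising the optimizer class directly (illustrative only; it assumes `LAMB` is exported from `mxnet.optimizer` and that the build includes the new operators):

```python
import mxnet as mx
from mxnet.optimizer import LAMB

opt = LAMB(learning_rate=0.001, bias_correction=True)
weight = mx.nd.random.normal(shape=(64, 128))
grad = mx.nd.random.normal(shape=(64, 128)) * 0.01

state = opt.create_state(0, weight)   # (mean, var), both zero-initialized
opt.update(0, weight, grad, state)    # phase 1 computes g, phase 2 applies the trust ratio in place
```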
152 changes: 109 additions & 43 deletions src/operator/optimizer_op-inl.h
@@ -1563,21 +1563,16 @@ inline void AdamUpdateEx(const nnvm::NodeAttrs& attrs,
}
}

struct LAMBParam : public dmlc::Parameter<LAMBParam> {
float lr;
struct LambUpdatePhaseOneParam : public dmlc::Parameter<LambUpdatePhaseOneParam> {
float beta1;
float beta2;
float epsilon;
float lower_bound;
float upper_bound;
float t;
bool bias_correction;
float wd;
float rescale_grad;
float clip_gradient;
DMLC_DECLARE_PARAMETER(LAMBParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_PARAMETER(LambUpdatePhaseOneParam) {
DMLC_DECLARE_FIELD(beta1)
.set_default(0.9f)
.describe("The decay rate for the 1st moment estimates.");
@@ -1587,19 +1582,12 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
DMLC_DECLARE_FIELD(epsilon)
.set_default(1e-6f)
.describe("A small constant for numerical stability.");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(1e-3f)
.describe("Lower limit of norm of weight.");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(10.0f)
.describe("Upper limit of norm of weight.");
DMLC_DECLARE_FIELD(t)
.describe("Index update count.");
DMLC_DECLARE_FIELD(bias_correction)
.set_default(false)
.describe("Whether to use bias correction.");
DMLC_DECLARE_FIELD(wd)
.set_default(0.0f)
.describe("Weight decay augments the objective function with a "
"regularization term that penalizes large weights. "
"The penalty scales with the square of the magnitude of each weight.");
@@ -1614,74 +1602,152 @@ struct LAMBParam : public dmlc::Parameter<LAMBParam> {
}
};

struct LAMBUpdateKernel {
struct LambUpdatePhaseTwoParam : public dmlc::Parameter<LambUpdatePhaseTwoParam> {
float lr;
float lower_bound;
float upper_bound;
DMLC_DECLARE_PARAMETER(LambUpdatePhaseTwoParam) {
DMLC_DECLARE_FIELD(lr)
.describe("Learning rate");
DMLC_DECLARE_FIELD(lower_bound)
.set_default(-1.0f)
.describe("Lower limit of norm of weight. If lower_bound <= 0, Lower limit is not set");
DMLC_DECLARE_FIELD(upper_bound)
.set_default(-1.0f)
.describe("Upper limit of norm of weight. If upper_bound <= 0, Upper limit is not set");
}
};

struct LambUpdatePhaseOneKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
DType* mean_data, DType* var_data, const DType* weight_data, const DType* grad_data,
const DType clip_gradient, const DType rescale_grad,
const DType beta1, const DType beta2,
DType lr, const DType wd,
const DType epsilon, const DType lower_bound,
const DType upper_bound, const DType t,
const DType beta1, const DType beta2, const DType wd,
const DType epsilon, const DType t,
bool bias_correction, const OpReqType req) {
using namespace mshadow_op;

DType grad_rescaled = grad_data[i] * rescale_grad + weight_data[i] * wd;
DType grad_rescaled = grad_data[i] * rescale_grad;
if (clip_gradient >= 0.f) {
grad_rescaled = clip::Map(grad_rescaled, clip_gradient);
}

mean_data[i] = beta1 * mean_data[i] + (1.f - beta1) * grad_rescaled;
var_data[i] = beta2 * var_data[i] + (1.f - beta2) * grad_rescaled * grad_rescaled;

DType r1 = square_root::Map(square::Map(weight_data[i]));

r1 = minimum::Map(maximum::Map(r1, lower_bound), upper_bound);
DType g = mean_data[i] / square_root::Map(var_data[i] + epsilon) + wd * weight_data[i];
DType g = mean_data[i] / (square_root::Map(var_data[i]) + epsilon) + wd * weight_data[i];

if (bias_correction) {
DType mean_hat = mean_data[i] / (1. - power::Map(beta1, t));
DType var_hat = var_data[i] / (1 - power::Map(beta2, t));
g = mean_hat / square_root::Map(var_hat + epsilon) + wd * weight_data[i];
}
DType r2 = square_root::Map(square::Map(g));
if (r1 == 0.0f || r2 == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * r1 / r2;
g = mean_hat / (square_root::Map(var_hat) + epsilon) + wd * weight_data[i];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g);
KERNEL_ASSIGN(out_data[i], req, g);
}
};
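Written out in NumPy, the per-element math this phase-1 kernel performs looks roughly as follows (an illustrative sketch, not code from the commit; it assumes `clip::Map(x, c)` clamps to `[-c, c]`, as in the other optimizer kernels):

```python
import numpy as np

def lamb_phase1(weight, grad, mean, var, beta1, beta2, epsilon, t,
                wd, rescale_grad, clip_gradient=-1.0, bias_correction=False):
    """NumPy sketch of LambUpdatePhaseOneKernel, vectorized over all elements."""
    g = grad * rescale_grad
    if clip_gradient >= 0:
        g = np.clip(g, -clip_gradient, clip_gradient)
    # Adam-style first and second moments, updated in place.
    mean[:] = beta1 * mean + (1.0 - beta1) * g
    var[:] = beta2 * var + (1.0 - beta2) * g * g
    if bias_correction:
        mean_hat = mean / (1.0 - beta1 ** t)
        var_hat = var / (1.0 - beta2 ** t)
        return mean_hat / (np.sqrt(var_hat) + epsilon) + wd * weight
    return mean / (np.sqrt(var) + epsilon) + wd * weight
```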

template<typename xpu>
inline void LAMBUpdate(const nnvm::NodeAttrs& attrs,
inline void LambUpdatePhaseOne(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LAMBParam& param = nnvm::get<LAMBParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
using namespace mxnet_op;
const LambUpdatePhaseOneParam& param = nnvm::get<LambUpdatePhaseOneParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> grad = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> mean = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> var = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LAMBUpdateKernel, xpu>::Launch(s, weight.shape_.Size(),
Kernel<LambUpdatePhaseOneKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, mean.dptr_, var.dptr_, weight.dptr_, grad.dptr_,
static_cast<DType>(param.clip_gradient), static_cast<DType>(param.rescale_grad),
static_cast<DType>(param.beta1), static_cast<DType>(param.beta2),
static_cast<DType>(param.lr), static_cast<DType>(param.wd),
static_cast<DType>(param.epsilon), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), static_cast<DType>(param.t),
static_cast<bool>(param.bias_correction), req[0]);
});
static_cast<DType>(param.wd), static_cast<DType>(param.epsilon),
static_cast<DType>(param.t), static_cast<bool>(param.bias_correction), req[0]);
});
}

inline bool LambUpdatePhaseTwoShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector* in_attrs,
mxnet::ShapeVector* out_attrs) {
CHECK_EQ(in_attrs->size(), 4U);
CHECK_EQ(out_attrs->size(), 1U);

mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);

mxnet::TShape& weight_shape = in_attrs->at(0);
mxnet::TShape& g_shape = in_attrs->at(1);
CHECK_EQ(weight_shape.ndim(), g_shape.ndim())
<< "total no. of dimensions for weights and g must match";
for (int i=0; i < weight_shape.ndim(); ++i) {
CHECK_EQ(weight_shape[i], g_shape[i])
<< "weight and g dimension size mismatch at " << i << "-th index";
}
mxnet::TShape& r1_shape = in_attrs->at(2);
mxnet::TShape& r2_shape = in_attrs->at(3);
CHECK_EQ(r1_shape[0], 1U) << "r1 shape incorrect";
CHECK_EQ(r2_shape[0], 1U) << "r2 shape incorrect";
for (int i=0; i < expected_out.ndim(); ++i) {
expected_out[i] = weight_shape[i];
}

SHAPE_ASSIGN_CHECK(*out_attrs, 0, expected_out);
return shape_is_known(expected_out);
}

struct LambUpdatePhaseTwoKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data,
const DType* weight_data, const DType* g,
const DType* r1, const DType* r2,
DType lr, const DType lower_bound,
const DType upper_bound, const OpReqType req) {
using namespace mshadow_op;

DType new_r1 = r1[0];
if (lower_bound >= 0) {
new_r1 = maximum::Map(new_r1, lower_bound);
}
if (upper_bound >= 0) {
new_r1 = minimum::Map(new_r1, upper_bound);
}
if (new_r1 == 0.0f || r2[0] == 0.0f) {
lr = lr * 1.0f;
} else {
lr = lr * new_r1 / r2[0];
}

KERNEL_ASSIGN(out_data[i], req, weight_data[i] - lr * g[i]);
}
};
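The phase-2 kernel then scales the step by the layer-wise trust ratio. A sketch of the same logic (negative bounds mean the corresponding limit is unset, matching the parameter defaults above; `weight` and `g` may be arrays, `r1` and `r2` are scalar norms):

```python
def lamb_phase2(weight, g, r1, r2, lr, lower_bound=-1.0, upper_bound=-1.0):
    """Sketch of LambUpdatePhaseTwoKernel."""
    if lower_bound >= 0:
        r1 = max(r1, lower_bound)
    if upper_bound >= 0:
        r1 = min(r1, upper_bound)
    # Fall back to the plain learning rate if either norm is zero.
    ratio = 1.0 if (r1 == 0.0 or r2 == 0.0) else r1 / r2
    return weight - lr * ratio * g
```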

template<typename xpu>
inline void LambUpdatePhaseTwo(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
const LambUpdatePhaseTwoParam& param = nnvm::get<LambUpdatePhaseTwoParam>(attrs.parsed);
Stream<xpu>* s = ctx.get_stream<xpu>();
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
Tensor<xpu, 2, DType> weight = inputs[0].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> g = inputs[1].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r1 = inputs[2].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> r2 = inputs[3].FlatTo2D<xpu, DType>(s);
Tensor<xpu, 2, DType> out = outputs[0].FlatTo2D<xpu, DType>(s);

Kernel<LambUpdatePhaseTwoKernel, xpu>::Launch(s, weight.shape_.Size(),
out.dptr_, weight.dptr_, g.dptr_, r1.dptr_, r2.dptr_,
static_cast<DType>(param.lr), static_cast<DType>(param.lower_bound),
static_cast<DType>(param.upper_bound), req[0]);
});
}

// This RMSProp code follows the version in
// http://arxiv.org/pdf/1308.0850v5.pdf Eq(38) - Eq(45)
34 changes: 27 additions & 7 deletions src/operator/optimizer_op.cc
@@ -43,7 +43,8 @@ DMLC_REGISTER_PARAMETER(FtrlParam);
DMLC_REGISTER_PARAMETER(SignSGDParam);
DMLC_REGISTER_PARAMETER(SignumParam);
DMLC_REGISTER_PARAMETER(AdagradParam);
DMLC_REGISTER_PARAMETER(LAMBParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseOneParam);
DMLC_REGISTER_PARAMETER(LambUpdatePhaseTwoParam);

NNVM_REGISTER_OP(signsgd_update)
.describe(R"code(Update function for SignSGD optimizer.
@@ -922,20 +923,39 @@ Note that non-zero values for the weight decay option are not supported.
.add_argument("history", "NDArray-or-Symbol", "History")
.add_arguments(AdagradParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update)
NNVM_REGISTER_OP(lamb_update_phase1)
.describe(R"code(Update function for lamb optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LAMBParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4,1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4,1>)
.set_attr<FCompute>("FCompute<cpu>", LAMBUpdate<cpu>)
.set_attr_parser(ParamParser<LambUpdatePhaseOneParam>)
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<4, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseOne<cpu>)
.set_attr<nnvm::FMutateInputs>("FMutateInputs",
[](const nnvm::NodeAttrs& attrs) {
return std::vector<uint32_t>{2, 3};
})
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("grad", "NDArray-or-Symbol", "Gradient")
.add_argument("mean", "NDArray-or-Symbol", "Moving mean")
.add_argument("var", "NDArray-or-Symbol", "Moving variance")
.add_arguments(LAMBParam::__FIELDS__());
.add_arguments(LambUpdatePhaseOneParam::__FIELDS__());

NNVM_REGISTER_OP(lamb_update_phase2)
.describe(R"code(Update function for lamb optimizer.
)code" ADD_FILELINE)
.set_num_inputs(4)
.set_num_outputs(1)
.set_attr_parser(ParamParser<LambUpdatePhaseTwoParam>)
.set_attr<mxnet::FInferShape>("FInferShape", LambUpdatePhaseTwoShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<4, 1>)
.set_attr<FCompute>("FCompute<cpu>", LambUpdatePhaseTwo<cpu>)
.add_argument("weight", "NDArray-or-Symbol", "Weight")
.add_argument("g", "NDArray-or-Symbol", "Output of lamb_update_phase 1")
.add_argument("r1", "NDArray-or-Symbol", "r1")
.add_argument("r2", "NDArray-or-Symbol", "r2")
.add_arguments(LambUpdatePhaseTwoParam::__FIELDS__());
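Together, the two registrations expose the split update on the NDArray frontend. A sketch of calling them directly, mirroring what `LAMB.update()` does in Python (this assumes the standard `mx.nd` bindings are generated for these operators; the hyperparameter values are illustrative):

```python
import mxnet as mx

weight = mx.nd.random.normal(shape=(16, 32))
grad = mx.nd.random.normal(shape=(16, 32)) * 0.01
mean = mx.nd.zeros_like(weight)
var = mx.nd.zeros_like(weight)

# Phase 1: moment updates plus the raw update direction g (mean and var are mutated).
g = mx.nd.lamb_update_phase1(weight, grad, mean, var,
                             beta1=0.9, beta2=0.999, epsilon=1e-6,
                             t=1, bias_correction=True,
                             wd=0.01, rescale_grad=1.0)

# Phase 2: clamp the weight norm, form the trust ratio r1/r2, and apply the step.
r1 = weight.norm()
r2 = g.norm()
mx.nd.lamb_update_phase2(weight, g, r1, r2, lr=0.001,
                         lower_bound=1e-3, upper_bound=10.0,
                         out=weight)
```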

} // namespace op
} // namespace mxnet
8 changes: 6 additions & 2 deletions src/operator/optimizer_op.cu
@@ -277,8 +277,12 @@ NNVM_REGISTER_OP(ftrl_update)
NNVM_REGISTER_OP(_sparse_adagrad_update)
.set_attr<FComputeEx>("FComputeEx<gpu>", AdagradUpdateEx<gpu>);

NNVM_REGISTER_OP(lamb_update)
.set_attr<FCompute>("FCompute<gpu>", LambUpdate<gpu>);
NNVM_REGISTER_OP(lamb_update_phase1)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseOne<gpu>);

NNVM_REGISTER_OP(lamb_update_phase2)
.set_attr<FCompute>("FCompute<gpu>", LambUpdatePhaseTwo<gpu>);


} // namespace op
} // namespace mxnet
