Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-979] Add fix_beta support in BatchNorm #12625

2 changes: 1 addition & 1 deletion python/mxnet/gluon/nn/basic_layers.py
Expand Up @@ -324,7 +324,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
in_channels=0, **kwargs):
super(BatchNorm, self).__init__(**kwargs)
self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
'fix_gamma': not scale, 'use_global_stats': use_global_stats}
'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
if in_channels != 0:
self.in_channels = in_channels

Expand Down
5 changes: 5 additions & 0 deletions src/operator/nn/batch_norm-inl.h
Expand Up @@ -62,6 +62,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
double eps;
float momentum;
bool fix_gamma;
bool fix_beta;
bool use_global_stats;
bool output_mean_var;
int axis;
Expand All @@ -75,6 +76,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
.describe("Momentum for moving average");
DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
.describe("Fix gamma while training");
DMLC_DECLARE_FIELD(fix_beta).set_default(false)
.describe("Fix beta while training");
DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
.describe("Whether use global moving statistics instead of local batch-norm. "
"This will force change batch-norm into a scale shift operator.");
Expand All @@ -90,6 +93,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
return this->eps == other.eps &&
this->momentum == other.momentum &&
this->fix_gamma == other.fix_gamma &&
this->fix_beta == other.fix_beta &&
this->use_global_stats == other.use_global_stats &&
this->output_mean_var == other.output_mean_var &&
this->axis == other.axis &&
Expand All @@ -107,6 +111,7 @@ struct hash<mxnet::op::BatchNormParam> {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.momentum);
ret = dmlc::HashCombine(ret, val.fix_gamma);
ret = dmlc::HashCombine(ret, val.fix_beta);
ret = dmlc::HashCombine(ret, val.use_global_stats);
ret = dmlc::HashCombine(ret, val.output_mean_var);
ret = dmlc::HashCombine(ret, val.axis);
Expand Down
84 changes: 66 additions & 18 deletions src/operator/nn/batch_norm.cc
Expand Up @@ -155,7 +155,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,

// compute output
AccReal *w = weights.dptr<AccReal>();
const AccReal *b = bias.dptr<AccReal>();
AccReal *b = bias.dptr<AccReal>();

const AccReal thisMean = mean[channel];
const AccReal thisInvstd = var[channel];
Expand All @@ -164,25 +164,65 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,

// note that var is still invstd
if (!param_.fix_gamma) {
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
});
if (!param_.fix_beta) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you really need 4 separate cases, rather than collating them into 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is basically 2 cases, based on fix_gamma; the fix_beta cases are collated within each of them.
Keeping the flags explicit on each case improves readability. Suggestions?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixing gamma and beta only determines whether the (*gamma) and (+bias) terms are included, yet more than 50 lines are used to accomplish that. The current lambda approach is fine with one or two cases, but with 4 cases — or maybe 8 in the future — it is going to cause a lot of trouble.

// Case 1
// fix_gamma = False
// fix_beta = False
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using a lambda, I think it may be better to define a named function here, since the same computation is used in all 4 cases. A function with a proper name will make the code more readable.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, that change would hurt readability and would unnecessarily pass more parameters across functions.

DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
});
}
} else {
// Case 2
// fix_gamma = False
// fix_beta = True
if (IsBNWriting(req[batchnorm::kBeta])) {
b[channel] = AccReal(0);
}
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) * thisWeight);
});
}
}
} else {
if (IsBNWriting(req[batchnorm::kGamma])) {
w[channel] = AccReal(1);
}
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) + thisBias);
});
if (!param_.fix_beta) {
// Case 3
// fix_gamma = True
// fix_beta = False
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) + thisBias);
});
}
} else {
// Case 4
// fix_gamma = True
// fix_beta = True
if (IsBNWriting(req[batchnorm::kBeta])) {
b[channel] = AccReal(0);
}
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd));
});
}
}
}
}
Expand Down Expand Up @@ -309,7 +349,11 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
}

if (IsBNWriting(req[batchnorm::kBeta])) {
gradBiasData[channel] = scale * sumGradOut;
if (!param_.fix_beta) {
gradBiasData[channel] = scale * sumGradOut;
} else {
gradBiasData[channel] = AccReal(0);
}
}
}
}
Expand Down Expand Up @@ -478,6 +522,9 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
}
if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
}
return dispatched;
}

Expand Down Expand Up @@ -565,11 +612,12 @@ the 'channel' (separately normalized groups). The default is 1. Specifying -1
axis to be the last item in the input shape.

Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0.
then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
and its gradient to 0.

Note::

When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
the sparse tensors will fall back.

)code" ADD_FILELINE)
Expand Down
30 changes: 23 additions & 7 deletions src/operator/nn/batch_norm.cu
Expand Up @@ -32,8 +32,9 @@
#define WRITE_GAMMA_FLAG 2
#define WRITE_BETA_FLAG 4
#define FIX_GAMMA_FLAG 8
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32
#define FIX_BETA_FLAG 16
#define IS_TRAINING_FLAG 32
#define USE_GLOBAL_STATS_FLAG 64

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#include "./cudnn/cudnn_batch_norm-inl.h"
Expand Down Expand Up @@ -223,15 +224,20 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
if (threadIdx.x == 0) {
saveMean[plane] = runningMean[plane];
saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
&& bias.numElements() > 0) {
bias[plane] = AccReal(0);
}
}
// Write normalized and update the output
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
Expand Down Expand Up @@ -282,14 +288,19 @@ __global__ void BatchNormalizationUpdateOutputKernel(
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
&& bias.numElements() > 0) {
bias[plane] = AccReal(0);
}
}

// Write normalized and update the output
const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
const DType inp = input.get_ref(batch, plane, x);
Expand Down Expand Up @@ -388,7 +399,11 @@ static __global__ void BatchNormalizationBackwardKernel(
}

if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
if ((flags & FIX_BETA_FLAG) == 0) {
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
} else {
tensors.gradBias[plane] = DType(0);
}
}
}

Expand Down Expand Up @@ -582,6 +597,7 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
uint32_t flags = 0;
flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
Expand Down
4 changes: 4 additions & 0 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
Expand Up @@ -115,6 +115,8 @@ class CuDNNBatchNormOp {

if (param_.fix_gamma) gamma = 1.f;

if (param_.fix_beta) beta = 0.f;

if (ctx.is_train) {
Tensor<gpu, 1, DTypeParam> save_mean =
out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
Expand Down Expand Up @@ -229,6 +231,7 @@ class CuDNNBatchNormOp {
global_stats ? nullptr : save_mean.dptr_,
global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
if (param_.fix_beta) dbeta = 0.f;
})
#else // CUDNN_VERSION < 4007
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
Expand Down Expand Up @@ -267,6 +270,7 @@ class CuDNNBatchNormOp {
global_stats ? nullptr : save_mean.dptr_,
global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
if (param_.fix_beta) dbeta = 0.f;
})
#endif
}
Expand Down