
Commit

Revert "[MXNET-979] Add fix_beta support in BatchNorm (#12625)" (#12789)
This reverts commit 0bab6d5 because the master branch started to fail with this change.
sandeep-krishnamurthy committed Oct 11, 2018
1 parent 822e59f commit 50d2313
Showing 9 changed files with 80 additions and 150 deletions.
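The user-visible effect of the revert is that BatchNorm once again exposes only fix_gamma (plus use_global_stats); the fix_beta keyword added in #12625 is removed from BatchNormParam, the Gluon wrapper, the CUDA/cuDNN kernels, and the tests. In the hunks below, the lines that mention fix_beta are the ones being removed. A minimal sketch of the post-revert symbolic API (assumes an MXNet 1.x install; not part of this commit):

import mxnet as mx

data = mx.sym.Variable('data')
# Post-revert: fix_gamma is the only "fix" switch the operator accepts.
norm = mx.sym.BatchNorm(data=data, name='norm', fix_gamma=True, cudnn_off=True)
print(norm.list_arguments())  # ['data', 'norm_gamma', 'norm_beta']; moving stats are auxiliary states
# Before the revert (commit 0bab6d5) the call could also take fix_beta=True/False.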
2 changes: 1 addition & 1 deletion python/mxnet/gluon/nn/basic_layers.py
@@ -324,7 +324,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
in_channels=0, **kwargs):
super(BatchNorm, self).__init__(**kwargs)
self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
'fix_gamma': not scale, 'use_global_stats': use_global_stats}
if in_channels != 0:
self.in_channels = in_channels

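In the Gluon layer above, scale still maps to the backend fix_gamma flag (scale=False means fix_gamma=True), while center no longer feeds a fix_beta entry in _kwargs; beta simply remains a regular learnable parameter. A small sketch of the resulting kwargs (illustrative; it inspects the private _kwargs dict shown in the hunk):

from mxnet.gluon import nn

# scale=False -> 'fix_gamma': True in the kwargs forwarded to the backend op.
bn = nn.BatchNorm(axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=False)
print(bn._kwargs)
# Expected roughly: {'axis': 1, 'eps': 1e-05, 'momentum': 0.9,
#                    'fix_gamma': True, 'use_global_stats': False}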
5 changes: 0 additions & 5 deletions src/operator/nn/batch_norm-inl.h
@@ -62,7 +62,6 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
double eps;
float momentum;
bool fix_gamma;
bool fix_beta;
bool use_global_stats;
bool output_mean_var;
int axis;
@@ -76,8 +75,6 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
.describe("Momentum for moving average");
DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
.describe("Fix gamma while training");
DMLC_DECLARE_FIELD(fix_beta).set_default(false)
.describe("Fix beta while training");
DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
.describe("Whether use global moving statistics instead of local batch-norm. "
"This will force change batch-norm into a scale shift operator.");
@@ -93,7 +90,6 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
return this->eps == other.eps &&
this->momentum == other.momentum &&
this->fix_gamma == other.fix_gamma &&
this->fix_beta == other.fix_beta &&
this->use_global_stats == other.use_global_stats &&
this->output_mean_var == other.output_mean_var &&
this->axis == other.axis &&
@@ -111,7 +107,6 @@ struct hash<mxnet::op::BatchNormParam> {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.momentum);
ret = dmlc::HashCombine(ret, val.fix_gamma);
ret = dmlc::HashCombine(ret, val.fix_beta);
ret = dmlc::HashCombine(ret, val.use_global_stats);
ret = dmlc::HashCombine(ret, val.output_mean_var);
ret = dmlc::HashCombine(ret, val.axis);
59 changes: 26 additions & 33 deletions src/operator/nn/batch_norm.cc
@@ -155,34 +155,35 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,

// compute output
AccReal *w = weights.dptr<AccReal>();
AccReal *b = bias.dptr<AccReal>();

// Ignore gamma
if (param_.fix_gamma) {
if (IsBNWriting(req[batchnorm::kGamma])) {
w[channel] = AccReal(1);
}
}

// Ignore beta
if (param_.fix_beta) {
if (IsBNWriting(req[batchnorm::kBeta])) {
b[channel] = AccReal(0);
}
}
const AccReal *b = bias.dptr<AccReal>();

const AccReal thisMean = mean[channel];
const AccReal thisInvstd = var[channel];
const AccReal thisWeight = w[channel];
const AccReal thisBias = b[channel];

if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
});
// note that var is still invstd
if (!param_.fix_gamma) {
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
});
}
} else {
if (IsBNWriting(req[batchnorm::kGamma])) {
w[channel] = AccReal(1);
}
if (IsBNWriting(req[batchnorm::kData])) {
ForEachFast(inputData, outputData, channel,
[thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
DType *out_data) {
*out_data = static_cast<DType>(
((*in_data - thisMean) * thisInvstd) + thisBias);
});
}
}
}
}
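The restored forward path computes the usual per-channel transform out = (x - mean) * invstd * gamma + beta; the fix_gamma case now folds into the output branch (gamma treated as 1, and rewritten to 1 when the gamma request is writable), and the up-front block that reset beta to 0 under fix_beta is gone. A NumPy sketch of that per-channel math (illustrative only, not the operator's actual code path):

import numpy as np

def batchnorm_channel(x, mean, invstd, gamma, beta, fix_gamma):
    # Mirrors the CPU kernel above: with fix_gamma the scale is forced to 1.
    if fix_gamma:
        gamma = 1.0
    return (x - mean) * invstd * gamma + beta

x = np.random.randn(8)                    # samples of a single channel
mean, var, eps = x.mean(), x.var(), 1e-5
invstd = 1.0 / np.sqrt(var + eps)         # "var is still invstd" in the kernel
print(batchnorm_channel(x, mean, invstd, gamma=2.0, beta=0.5, fix_gamma=False))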
@@ -308,11 +309,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
}

if (IsBNWriting(req[batchnorm::kBeta])) {
if (!param_.fix_beta) {
gradBiasData[channel] = scale * sumGradOut;
} else {
gradBiasData[channel] = AccReal(0);
}
gradBiasData[channel] = scale * sumGradOut;
}
}
}
@@ -481,9 +478,6 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
}
if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
}
return dispatched;
}

@@ -571,12 +565,11 @@ the 'channel' (separately normalized groups). The default is 1. Specifying -1
axis to be the last item in the input shape.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
and its gradient to 0.
then set ``gamma`` to 1 and its gradient to 0.
Note::
When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
the sparse tensors will fallback.
)code" ADD_FILELINE)
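As the docstring above says, fix_gamma=True makes the operator use gamma == 1 and zero gamma's gradient, and sparse inputs are only supported with fix_gamma=False. A hedged imperative example of the scaling behaviour (assumes MXNet 1.x; parameter shapes follow the default axis=1):

import mxnet as mx

x = mx.nd.random.randn(2, 3, 4, 4)
gamma = mx.nd.full((3,), 2.0)
beta = mx.nd.zeros((3,))
moving_mean, moving_var = mx.nd.zeros((3,)), mx.nd.ones((3,))

# With fix_gamma=True the kernel ignores the stored gamma and scales by 1.
y_fixed = mx.nd.BatchNorm(x, gamma, beta, moving_mean, moving_var, fix_gamma=True)
y_free = mx.nd.BatchNorm(x, gamma, beta, moving_mean, moving_var, fix_gamma=False)
# Expected ~0 when the moving stats (0, 1) are used, i.e. outside autograd.record().
print(mx.nd.abs(y_free - 2 * y_fixed).max())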
30 changes: 7 additions & 23 deletions src/operator/nn/batch_norm.cu
@@ -32,9 +32,8 @@
#define WRITE_GAMMA_FLAG 2
#define WRITE_BETA_FLAG 4
#define FIX_GAMMA_FLAG 8
#define FIX_BETA_FLAG 16
#define IS_TRAINING_FLAG 32
#define USE_GLOBAL_STATS_FLAG 64
#define IS_TRAINING_FLAG 16
#define USE_GLOBAL_STATS_FLAG 32

#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#include "./cudnn/cudnn_batch_norm-inl.h"
@@ -224,20 +223,15 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
if (threadIdx.x == 0) {
saveMean[plane] = runningMean[plane];
saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
if ((flags & WRITE_GAMMA_FLAG) != 0 && (flags & FIX_GAMMA_FLAG) != 0
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
&& bias.numElements() > 0) {
bias[plane] = AccReal(0);
}
}
// Write normalized and update the output
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
@@ -288,19 +282,14 @@ __global__ void BatchNormalizationUpdateOutputKernel(
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
&& bias.numElements() > 0) {
bias[plane] = AccReal(0);
}
}

// Write normalized and update the output
const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
: ScalarConvert<int, AccReal>::to(0);
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
const DType inp = input.get_ref(batch, plane, x);
@@ -399,11 +388,7 @@ static __global__ void BatchNormalizationBackwardKernel(
}

if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
if ((flags & FIX_BETA_FLAG) == 0) {
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
} else {
tensors.gradBias[plane] = DType(0);
}
tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
}
}

@@ -597,7 +582,6 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
uint32_t flags = 0;
flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
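With FIX_BETA_FLAG gone, the remaining bit values shift down (IS_TRAINING_FLAG back to 16, USE_GLOBAL_STATS_FLAG to 32) and SetupFlags composes the mask exactly as before minus the fix_beta bit. A small sketch of the same bit-mask pattern (WRITE_DATA_FLAG's value of 1 is assumed; it is defined just above the excerpted #define block):

# Flag values after the revert (see the #define block in batch_norm.cu above).
WRITE_DATA_FLAG = 1                     # assumed; not shown in the excerpt
WRITE_GAMMA_FLAG, WRITE_BETA_FLAG = 2, 4
FIX_GAMMA_FLAG, IS_TRAINING_FLAG, USE_GLOBAL_STATS_FLAG = 8, 16, 32

def setup_flags(is_train, fix_gamma, use_global_stats, write_data):
    # Same composition as SetupFlags above, minus the removed fix_beta bit.
    flags = 0
    flags |= IS_TRAINING_FLAG if is_train else 0
    flags |= FIX_GAMMA_FLAG if fix_gamma else 0
    flags |= USE_GLOBAL_STATS_FLAG if use_global_stats else 0
    flags |= WRITE_DATA_FLAG if write_data else 0
    return flags

f = setup_flags(is_train=True, fix_gamma=True, use_global_stats=False, write_data=True)
print((f & FIX_GAMMA_FLAG) != 0)        # True: the kernels fall back to gamma == 1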
4 changes: 0 additions & 4 deletions src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -115,8 +115,6 @@ class CuDNNBatchNormOp {

if (param_.fix_gamma) gamma = 1.f;

if (param_.fix_beta) beta = 0.f;

if (ctx.is_train) {
Tensor<gpu, 1, DTypeParam> save_mean =
out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
@@ -231,7 +229,6 @@ class CuDNNBatchNormOp {
global_stats ? nullptr : save_mean.dptr_,
global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
if (param_.fix_beta) dbeta = 0.f;
})
#else // CUDNN_VERSION < 4007
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
@@ -270,7 +267,6 @@ class CuDNNBatchNormOp {
global_stats ? nullptr : save_mean.dptr_,
global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
if (param_.fix_beta) dbeta = 0.f;
})
#endif
}
92 changes: 33 additions & 59 deletions tests/python/gpu/test_operator_gpu.py
@@ -303,52 +303,35 @@ def test_batchnorm_with_type():


# V2, 2D
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
# Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
# Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)

# V2, 1D
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
# Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
# Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)

#
# # V2, 3D
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)
# Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)
# Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)


@with_seed()
def test_batchnorm_versions():
def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta, use_global_stats):
def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
ctx_list = []
sym_list = []
# BatchNormV1 cpu
@@ -369,70 +352,61 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta,
if 'batchnorm_cpu' in batchnorm_op_list:
ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm'))

# BatchNorm gpu (organic)
if 'batchnorm_gpu' in batchnorm_op_list:
ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm', cudnn_off=True))

# BatchNorm gpu cudnn (if cudnn is enabled)
if 'batchnorm_cudnn' in batchnorm_op_list:
ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm', cudnn_off=False))

check_consistency(sym_list, ctx_list)

def test_1d_batchnorm(fix_gamma, fix_beta, use_global_stats):
def test_1d_batchnorm(fix_gamma, use_global_stats):
data = (2, 3, 20)
test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
'batchnorm_gpu', 'batchnorm_cudnn'],
data=data,
fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
fix_gamma=fix_gamma, use_global_stats=use_global_stats)

def test_2d_batchnorm(fix_gamma, fix_beta, use_global_stats):
def test_2d_batchnorm(fix_gamma, use_global_stats):
data = (2, 3, 10, 10)
# batchmorm_v1 is deprecated.
# `fix_beta` parameter is available only in new batchnorm operator.
# Checking consistency separately for batchnormv1 and batchnorm.
test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu'],
data=data,
fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)

test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu',
'batchnorm_cpu',
'batchnorm_gpu', 'batchnorm_cudnn'],
data=data,
fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
fix_gamma=fix_gamma, use_global_stats=use_global_stats)

def test_3d_batchnorm(fix_gamma, fix_beta, use_global_stats):
def test_3d_batchnorm(fix_gamma, use_global_stats):
data = (2, 3, 3, 5, 5)
test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
'batchnorm_gpu'],
data=data,
fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)

test_1d_batchnorm(True, False, False)
test_1d_batchnorm(False, True, False)
test_1d_batchnorm(False, False, True)
test_1d_batchnorm(True, True, True)

test_2d_batchnorm(True, False, False)
test_2d_batchnorm(False, True, False)
test_2d_batchnorm(False, False, True)
test_2d_batchnorm(True, True, True)

test_3d_batchnorm(True, False, False)
test_3d_batchnorm(False, True, False)
test_3d_batchnorm(False, False, True)
test_3d_batchnorm(True, True, True)
fix_gamma=fix_gamma, use_global_stats=use_global_stats)

test_1d_batchnorm(True, False)
test_1d_batchnorm(False, False)
test_1d_batchnorm(False, True)
test_1d_batchnorm(True, True)

test_2d_batchnorm(True, False)
test_2d_batchnorm(False, False)
test_2d_batchnorm(False, True)
test_2d_batchnorm(True, True)

test_3d_batchnorm(True, False)
test_3d_batchnorm(False, False)
test_3d_batchnorm(False, True)
test_3d_batchnorm(True, True)


@with_seed(1234)
