diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 26ef64dfd0bd..d26841977ac2 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -324,7 +324,7 @@ def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
                  in_channels=0, **kwargs):
         super(BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
+                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
 
         if in_channels != 0:
             self.in_channels = in_channels
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index f8b381c87bef..3f47d58bb8c3 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -62,7 +62,6 @@ struct BatchNormParam : public dmlc::Parameter {
   double eps;
   float momentum;
   bool fix_gamma;
-  bool fix_beta;
   bool use_global_stats;
   bool output_mean_var;
   int axis;
@@ -76,8 +75,6 @@
       .describe("Momentum for moving average");
     DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
       .describe("Fix gamma while training");
-    DMLC_DECLARE_FIELD(fix_beta).set_default(false)
-      .describe("Fix beta while training");
     DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
       .describe("Whether use global moving statistics instead of local batch-norm. "
                 "This will force change batch-norm into a scale shift operator.");
@@ -93,7 +90,6 @@
     return this->eps == other.eps &&
            this->momentum == other.momentum &&
            this->fix_gamma == other.fix_gamma &&
-           this->fix_beta == other.fix_beta &&
           this->use_global_stats == other.use_global_stats &&
           this->output_mean_var == other.output_mean_var &&
           this->axis == other.axis &&
@@ -111,7 +107,6 @@ struct hash {
     size_t ret = 0;
     ret = dmlc::HashCombine(ret, val.momentum);
     ret = dmlc::HashCombine(ret, val.fix_gamma);
-    ret = dmlc::HashCombine(ret, val.fix_beta);
    ret = dmlc::HashCombine(ret, val.use_global_stats);
    ret = dmlc::HashCombine(ret, val.output_mean_var);
    ret = dmlc::HashCombine(ret, val.axis);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index ec90a3092092..be542ba5b6be 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -155,34 +155,35 @@ void BatchNormForwardImpl(mshadow::Stream *,
       // compute output
       AccReal *w = weights.dptr();
-      AccReal *b = bias.dptr();
-
-      // Ignore gamma
-      if (param_.fix_gamma) {
-        if (IsBNWriting(req[batchnorm::kGamma])) {
-          w[channel] = AccReal(1);
-        }
-      }
-
-      // Ignore beta
-      if (param_.fix_beta) {
-        if (IsBNWriting(req[batchnorm::kBeta])) {
-          b[channel] = AccReal(0);
-        }
-      }
+      const AccReal *b = bias.dptr();
 
       const AccReal thisMean = mean[channel];
       const AccReal thisInvstd = var[channel];
       const AccReal thisWeight = w[channel];
       const AccReal thisBias = b[channel];
 
-      if (IsBNWriting(req[batchnorm::kData])) {
-        ForEachFast(inputData, outputData, channel,
-                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                 DType *out_data) {
-                      *out_data = static_cast(
-                        ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
-                    });
+      // note that var is still invstd
+      if (!param_.fix_gamma) {
+        if (IsBNWriting(req[batchnorm::kData])) {
+          ForEachFast(inputData, outputData, channel,
+                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                   DType *out_data) {
+                        *out_data = static_cast(
+                          ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
+                      });
+        }
+      } else {
+        if (IsBNWriting(req[batchnorm::kGamma])) {
+          w[channel] = AccReal(1);
+        }
+        if (IsBNWriting(req[batchnorm::kData])) {
+          ForEachFast(inputData, outputData, channel,
+                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                   DType *out_data) {
+                        *out_data = static_cast(
+                          ((*in_data - thisMean) * thisInvstd) + thisBias);
+                      });
+        }
       }
     }
   }
@@ -308,11 +309,7 @@ void BatchNormBackwardImpl(mshadow::Stream *,
       }
 
       if (IsBNWriting(req[batchnorm::kBeta])) {
-        if (!param_.fix_beta) {
-          gradBiasData[channel] = scale * sumGradOut;
-        } else {
-          gradBiasData[channel] = AccReal(0);
-        }
+        gradBiasData[channel] = scale * sumGradOut;
       }
     }
   }
@@ -481,9 +478,6 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
 
   if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
     LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
   }
-  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
-    LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
-  }
   return dispatched;
 }
@@ -571,12 +565,11 @@ the 'channel' (separately normalized groups). The default is 1. Specifying -1
 axis to be the last item in the input shape.
 
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
-and its gradient to 0.
+then set ``gamma`` to 1 and its gradient to 0.
 
 Note::
 
-When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
+When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
 the sparse tensors will fallback.
 
 )code" ADD_FILELINE)
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 309542d33c2b..03962cbc0f33 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -32,9 +32,8 @@
 #define WRITE_GAMMA_FLAG 2
 #define WRITE_BETA_FLAG 4
 #define FIX_GAMMA_FLAG 8
-#define FIX_BETA_FLAG 16
-#define IS_TRAINING_FLAG 32
-#define USE_GLOBAL_STATS_FLAG 64
+#define IS_TRAINING_FLAG 16
+#define USE_GLOBAL_STATS_FLAG 32
 
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -224,9 +223,8 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
   AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                   ? ScalarConvert::to(weight[plane])
                   : ScalarConvert::to(1);
-  AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
-                 ? ScalarConvert::to(bias[plane])
-                 : ScalarConvert::to(0);
+  AccReal beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane])
+                                        : ScalarConvert::to(0);
   if (threadIdx.x == 0) {
     saveMean[plane] = runningMean[plane];
     saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
@@ -234,10 +232,6 @@
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
-    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
-        && bias.numElements() > 0) {
-      bias[plane] = AccReal(0);
-    }
   }
   // Write normalized and update the output
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
@@ -288,19 +282,14 @@ __global__ void BatchNormalizationUpdateOutputKernel(
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
-    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
-        && bias.numElements() > 0) {
-      bias[plane] = AccReal(0);
-    }
   }
 
   // Write normalized and update the output
   const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                         ? ScalarConvert::to(weight[plane])
                         : ScalarConvert::to(1);
-  const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
-                       ? ScalarConvert::to(bias[plane])
-                       : ScalarConvert::to(0);
+  const AccReal beta = bias.numElements() > 0 ? ScalarConvert::to(bias[plane])
+                                              : ScalarConvert::to(0);
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
     for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
       const DType inp = input.get_ref(batch, plane, x);
@@ -399,11 +388,7 @@ static __global__ void BatchNormalizationBackwardKernel(
   }
 
   if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 &&
      (flags & WRITE_BETA_FLAG) != 0) {
-    if ((flags & FIX_BETA_FLAG) == 0) {
-      tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum);
-    } else {
-      tensors.gradBias[plane] = DType(0);
-    }
+    tensors.gradBias[plane] = ScalarConvert::to(gradOutputSum);
   }
 }
@@ -597,7 +582,6 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   uint32_t flags = 0;
   flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
   flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
-  flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index 9caa9d3ddd3c..d4b9f84ed2f5 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -115,8 +115,6 @@ class CuDNNBatchNormOp {
 
       if (param_.fix_gamma) gamma = 1.f;
 
-      if (param_.fix_beta) beta = 0.f;
-
       if (ctx.is_train) {
         Tensor save_mean =
           out_data[cudnnbatchnorm::kMean].get_with_shape(Shape1(shape_[1]), s);
@@ -231,7 +229,6 @@ class CuDNNBatchNormOp {
                    global_stats ? nullptr : save_mean.dptr_,
                    global_stats ? nullptr : save_inv_var.dptr_));
         if (param_.fix_gamma) dgamma = 0.f;
-        if (param_.fix_beta) dbeta = 0.f;
       })
 #else  // CUDNN_VERSION < 4007
       MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
@@ -270,7 +267,6 @@ class CuDNNBatchNormOp {
                    global_stats ? nullptr : save_mean.dptr_,
                    global_stats ? nullptr : save_inv_var.dptr_));
         if (param_.fix_gamma) dgamma = 0.f;
-        if (param_.fix_beta) dbeta = 0.f;
       })
 #endif
     }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 13022c108c4f..dd7ec985c7c8 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -303,52 +303,35 @@ def test_batchnorm_with_type():
 
     # V2, 2D
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_2D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_2D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_2D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_2D)
-    # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
     sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_2D)
-    # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
-    sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_2D)
 
     # V2, 1D
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_1D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_1D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_1D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_1D)
-    # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
     sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_1D)
-    # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
-    sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_1D)
-
+
     # # # V2, 3D
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_3D)
-    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_3D)
-    # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
     sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
     check_consistency(sym, ctx_list_v2_3D)
-    # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
-    sym = mx.sym.BatchNorm(name='norm', fix_beta=False, cudnn_off=True)
-    check_consistency(sym, ctx_list_v2_3D)
+
 
 @with_seed()
 def test_batchnorm_versions():
-    def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta, use_global_stats):
+    def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
         ctx_list = []
         sym_list = []
         # BatchNormV1 cpu
@@ -369,7 +352,6 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta,
         if 'batchnorm_cpu' in batchnorm_op_list:
             ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data,
                              'type_dict': {'batchnorm_data': np.float32}})
             sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
-                                             fix_beta=fix_beta,
                                              use_global_stats=use_global_stats,
                                              name='batchnorm'))
@@ -377,7 +359,6 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta,
         if 'batchnorm_gpu' in batchnorm_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data,
                              'type_dict': {'batchnorm_data': np.float32}})
             sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
-                                             fix_beta=fix_beta,
                                              use_global_stats=use_global_stats,
                                              name='batchnorm', cudnn_off=True))
@@ -385,54 +366,47 @@ def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta,
         if 'batchnorm_cudnn' in batchnorm_op_list:
             ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data,
                              'type_dict': {'batchnorm_data': np.float32}})
             sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
-                                             fix_beta=fix_beta,
                                              use_global_stats=use_global_stats,
                                              name='batchnorm', cudnn_off=False))
 
         check_consistency(sym_list, ctx_list)
 
-    def test_1d_batchnorm(fix_gamma, fix_beta, use_global_stats):
+    def test_1d_batchnorm(fix_gamma, use_global_stats):
         data = (2, 3, 20)
         test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                           'batchnorm_gpu', 'batchnorm_cudnn'],
                                        data=data,
-                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+                                       fix_gamma=fix_gamma, use_global_stats=use_global_stats)
 
-    def test_2d_batchnorm(fix_gamma, fix_beta, use_global_stats):
+    def test_2d_batchnorm(fix_gamma, use_global_stats):
         data = (2, 3, 10, 10)
-        # batchmorm_v1 is deprecated.
-        # `fix_beta` parameter is available only in new batchnorm operator.
-        # Checking consistency separately for batchnormv1 and batchnorm.
-        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu'],
-                                       data=data,
-                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
-
-        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
+        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu',
+                                                          'batchnorm_cpu',
                                                           'batchnorm_gpu', 'batchnorm_cudnn'],
                                        data=data,
-                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+                                       fix_gamma=fix_gamma, use_global_stats=use_global_stats)
 
-    def test_3d_batchnorm(fix_gamma, fix_beta, use_global_stats):
+    def test_3d_batchnorm(fix_gamma, use_global_stats):
         data = (2, 3, 3, 5, 5)
         test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                           'batchnorm_gpu'],
                                        data=data,
-                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
-
-    test_1d_batchnorm(True, False, False)
-    test_1d_batchnorm(False, True, False)
-    test_1d_batchnorm(False, False, True)
-    test_1d_batchnorm(True, True, True)
-
-    test_2d_batchnorm(True, False, False)
-    test_2d_batchnorm(False, True, False)
-    test_2d_batchnorm(False, False, True)
-    test_2d_batchnorm(True, True, True)
-
-    test_3d_batchnorm(True, False, False)
-    test_3d_batchnorm(False, True, False)
-    test_3d_batchnorm(False, False, True)
-    test_3d_batchnorm(True, True, True)
+                                       fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+
+    test_1d_batchnorm(True, False)
+    test_1d_batchnorm(False, False)
+    test_1d_batchnorm(False, True)
+    test_1d_batchnorm(True, True)
+
+    test_2d_batchnorm(True, False)
+    test_2d_batchnorm(False, False)
+    test_2d_batchnorm(False, True)
+    test_2d_batchnorm(True, True)
+
+    test_3d_batchnorm(True, False)
+    test_3d_batchnorm(False, False)
+    test_3d_batchnorm(False, True)
+    test_3d_batchnorm(True, True)
 
 
 @with_seed(1234)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index a3c39ff53616..e5490096f60f 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -235,7 +235,7 @@ def check_batchnorm_training(stype):
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype),
                         mx.nd.array(rolling_std).tostype(stype)]
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5a5d95669ffc..5332517fa680 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1534,25 +1534,25 @@ def check_batchnorm_training(stype):
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             # Test varying channel axis
@@ -1581,16 +1581,16 @@ def check_batchnorm_training(stype):
                 xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
                              mx.nd.array(xrolling_std).tostype(stype)]
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
     check_batchnorm_training('default')
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index bddab11f95de..57808248b081 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2124,19 +2124,13 @@ def test_batchnorm_fallback():
             test = mx.symbol.BatchNorm(data, fix_gamma=True)
             assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_beta=True)
-            assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
-
            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
            assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True)
-            assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
-
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
             # Test varying channel axis
@@ -2167,20 +2161,14 @@ def test_batchnorm_fallback():
                 test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
                 assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
-
-                test = mx.symbol.BatchNorm(data, fix_beta=True, axis=chaxis)
-                assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
                 test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
                 assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True, axis=chaxis)
-                assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
-
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
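
Usage sketch (illustrative, not part of the patch): the snippet below shows how BatchNorm is driven once `fix_beta` is gone. Only `fix_gamma` remains as an explicit freeze switch on the operator, and the Gluon layer is expected to express `center=False` by not training its `beta` parameter rather than by forwarding a `fix_beta` kwarg. It assumes a standard MXNet 1.x build; the `net.beta.grad_req` access is an assumption about the Gluon internals, not something this diff shows.

import mxnet as mx

# Symbol API after the change: beta is still a learnable argument of the op;
# only the fix_gamma flag survives as an explicit "freeze" option.
data = mx.sym.Variable('data')
bn = mx.sym.BatchNorm(data=data, fix_gamma=False, use_global_stats=False, name='bn')
print(bn.list_arguments())          # ['data', 'bn_gamma', 'bn_beta']
print(bn.list_auxiliary_states())   # ['bn_moving_mean', 'bn_moving_var']

# Gluon API: center/scale decide whether beta/gamma are trained at all.
# With fix_beta removed, center=False is assumed to freeze beta via the
# parameter's grad_req instead of mapping to a fix_beta kwarg.
net = mx.gluon.nn.BatchNorm(center=False, scale=True, in_channels=3)
net.initialize()
out = net(mx.nd.random.uniform(shape=(2, 3, 10, 10)))
print(out.shape)                    # (2, 3, 10, 10)
print(net.beta.grad_req)            # expected 'null': beta keeps its initial value

This lines up with the backward change in batch_norm.cc: the beta gradient is now always written as scale * sumGradOut whenever the kBeta write request is set, so keeping the offset fixed becomes a training-loop concern (skip or zero the update) rather than an operator flag.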