From 3f431948fa8d3a313d6538226c95fddeb0725a0c Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 29 Jun 2018 16:38:44 -0700 Subject: [PATCH 01/34] sync batch norm --- python/mxnet/gluon/contrib/nn/basic_layers.py | 74 ++- src/operator/contrib/sync_batch_norm-inl.h | 560 ++++++++++++++++++ src/operator/contrib/sync_batch_norm.cc | 111 ++++ src/operator/contrib/sync_batch_norm.cu | 37 ++ .../python/unittest/test_contrib_operator.py | 83 +++ 5 files changed, 863 insertions(+), 2 deletions(-) create mode 100644 src/operator/contrib/sync_batch_norm-inl.h create mode 100644 src/operator/contrib/sync_batch_norm.cc create mode 100644 src/operator/contrib/sync_batch_norm.cu diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 1edef1476ee3..4b2768551110 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -18,11 +18,12 @@ # coding: utf-8 # pylint: disable= arguments-differ """Custom neural network layers in model_zoo.""" -__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding'] +__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding', + 'SyncBatchNorm'] from .... import nd from ...block import HybridBlock, Block -from ...nn import Sequential, HybridSequential +from ...nn import Sequential, HybridSequential, BatchNorm class Concurrent(Sequential): """Lays `Block`s concurrently. @@ -151,3 +152,72 @@ def __repr__(self): s = '{block_name}({input_dim} -> {output_dim}, {dtype})' return s.format(block_name=self.__class__.__name__, **self._kwargs) + +class SyncBatchNorm(BatchNorm): + """Cross-GPU Synchronized Batch normalization (SyncBN) + Standard BN [1]_ implementation only normalize the data within each device. + SyncBN normalizes the input within the whole mini-batch. + We follow the sync-onece implmentation described in the paper [2]_ . + Parameters + ---------- + momentum: float, default 0.9 + Momentum for the moving average. + epsilon: float, default 1e-5 + Small float added to variance to avoid dividing by zero. + center: bool, default True + If True, add offset of `beta` to normalized tensor. + If False, `beta` is ignored. + scale: bool, default True + If True, multiply by `gamma`. If False, `gamma` is not used. + When the next layer is linear (also e.g. `nn.relu`), + this can be disabled since the scaling + will be done by the next layer. + use_global_stats: bool, default False + If True, use global moving statistics instead of local batch-norm. This will force + change batch-norm into a scale shift operator. + If False, use local batch-norm. + beta_initializer: str or `Initializer`, default 'zeros' + Initializer for the beta weight. + gamma_initializer: str or `Initializer`, default 'ones' + Initializer for the gamma weight. + moving_mean_initializer: str or `Initializer`, default 'zeros' + Initializer for the moving mean. + moving_variance_initializer: str or `Initializer`, default 'ones' + Initializer for the moving variance. + in_channels : int, default 0 + Number of channels (feature maps) in input data. If not specified, + initialization will be deferred to the first time `forward` is called + and `in_channels` will be inferred from the shape of input data. + num_devices : int, default number of visible GPUs + Inputs: + - **data**: input tensor with arbitrary shape. + Outputs: + - **out**: output tensor with the same shape as `data`. + Reference: + .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating + deep network training by reducing internal covariate shift." *ICML 2015* + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, + Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* + """ + def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, + center=True, scale=True, use_global_stats=False, beta_initializer='zeros', + gamma_initializer='ones', running_mean_initializer='zeros', + running_variance_initializer='ones', **kwargs): + super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, + use_global_stats, beta_initializer, gamma_initializer, + running_mean_initializer, running_variance_initializer, + in_channels, **kwargs) + num_devices = self._get_num_devices() if num_devices is None else num_devices + self._kwargs = {'eps': epsilon, 'momentum': momentum, + 'fix_gamma': not scale, 'use_global_stats': use_global_stats, + 'ndev': num_devices, 'key': self.prefix} + + def _get_num_devices(self): + # Caution: if not using all the GPUs, please mannually set num_devices + num_devices = len(test_utils.list_gpus()) + num_devices = num_devices if num_devices > 0 else 1 + return num_devices + + def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): + return F.SyncBatchNorm(x, gamma, beta, running_mean, running_var, + name='fwd', **self._kwargs) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h new file mode 100644 index 000000000000..2ed32995a22a --- /dev/null +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -0,0 +1,560 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file sync_batch_norm-inl.h + * \brief Synchronized BatchNorm modified from BatchNormV1 + * \author Hang Zhang +*/ +#ifndef MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ +#define MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" + +namespace mxnet { +namespace op { + +#define MAX_GPU_NUM 16 + +template +class SharedND { +private: + int nDev = 4; + bool flag[MAX_GPU_NUM]; + T mean; + bool meanReady = false; + bool meanInited = false; +public: + T data[MAX_GPU_NUM]; + SharedND(int ndev) + :nDev(ndev) { + //LOG(INFO) << "Creating SharedND"; + memset(flag, false, MAX_GPU_NUM * sizeof(bool)); + } + + bool Push(T input, int index) { + if (flag[index] == false) { + data[index] = input; + flag[index] = true; + return true; + } + else { + //LOG(INFO) << "Error in Pushing" << index; + return false; + } + } + + T Pop(int index) { + //LOG(INFO) << "Poping: " << index; + while(!MeanReady());//deadlock may occur + flag[index] = false; + T tmp = mean; + ResetMean(); + return tmp; + } + + bool MeanReady() { + if (meanReady) { + //LOG(INFO) << "meanReady"; + return true; + } + //LOG(INFO) << "Not meanReady"; + for (int i = 0; i < nDev; i++) { + if (!flag[i]) { + //LOG(INFO) << "flag[i] is not ready: " << i; + return false; + } + //LOG(INFO) << "flag[i] is ready: " << i; + } + //LOG(INFO) << "reducing the data now"; + for (int i = 1; i < nDev; i++) { + data[0] += data[i]; + //LOG(INFO) << "adding data[i]" << i; + } + //LOG(INFO) << "reducing the data finished"; + if (!meanInited) { + mean = mshadow::NewTensor(data[0].shape_, 0.0f); + meanInited = true; + } + mean = data[0] * 1.0f / nDev; + meanReady = true; + //LOG(INFO) << "meanReady now"; + return true; + } + + void ResetMean() { + for (int i = 0; i < nDev; i++) { + if (flag[i]) return; + } + meanReady = false; + } +}; + +template +class GlobalShared { + public: + T* Register(const std::string &key, int ndev) { + std::lock_guard lock(mutex_); + auto it = registry_.find(key); + if (it != registry_.end()) return it->second; + T *newT = new T(ndev); + registry_[key] = newT; + //LOG(INFO) << "registring" << key; + return newT; + } + private: + std::mutex mutex_; + std::map registry_; +}; + +namespace syncbatchnorm { +enum BatchNormOpInputs {kData, kGamma, kBeta}; +enum BatchNormOpOutputs {kOut, kMean, kVar}; +enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; +enum BatchNormBackResource {kTempSpace}; +} // namespace syncbatchnorm + +struct SyncBatchNormParam : public dmlc::Parameter { + float eps; + float momentum; + bool fix_gamma; + bool use_global_stats; + bool output_mean_var; + int ndev; + std::string key; + DMLC_DECLARE_PARAMETER(SyncBatchNormParam) { + DMLC_DECLARE_FIELD(eps).set_default(1e-3f) + .describe("Epsilon to prevent div 0"); + DMLC_DECLARE_FIELD(momentum).set_default(0.9f) + .describe("Momentum for moving average"); + DMLC_DECLARE_FIELD(fix_gamma).set_default(true) + .describe("Fix gamma while training"); + DMLC_DECLARE_FIELD(use_global_stats).set_default(false) + .describe("Whether use global moving statistics instead of local batch-norm. " + "This will force change batch-norm into a scale shift operator."); + DMLC_DECLARE_FIELD(output_mean_var).set_default(false) + .describe("Output All,normal mean and var"); + DMLC_DECLARE_FIELD(ndev).set_default(1) + .describe("The count of GPU devices"); + DMLC_DECLARE_FIELD(key) + .set_default("") + .describe("Hash key for synchronization"); + } +}; + +static pthread_mutex_t mm = PTHREAD_MUTEX_INITIALIZER; +static pthread_barrier_t globalBarrier; +static int globalRank = 0; +static bool flagGlobalBarrier = false; +static GlobalShared>> globalSharedMean; +static GlobalShared>> globalSharedVar; +static GlobalShared>> globalSharedGrad; +static GlobalShared>> globalSharedProd; + +template +class SyncBatchNorm : public Operator { + public: + explicit SyncBatchNorm(SyncBatchNormParam param) { + this->param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 3U); + CHECK_EQ(aux_states.size(), 2U); + if (ctx.is_train) { + CHECK_EQ(out_data.size(), 3U); + CHECK_EQ(req.size(), 3U); + } else { + CHECK_GE(out_data.size(), 1U); + CHECK_GE(req.size(), 1U); + CHECK_EQ(req[syncbatchnorm::kOut], kWriteTo); + } + // get my rank + pthread_mutex_lock(&mm); + if (flagGlobalBarrier == false) { + pthread_barrier_init(&globalBarrier, NULL, param_.ndev); + flagGlobalBarrier = true; + } + int myRank = globalRank; + //LOG(INFO) << "myRank" << myRank; + globalRank += 1; + pthread_mutex_unlock(&mm); + pthread_barrier_wait(&globalBarrier); + globalRank = 0; + + Stream *s = ctx.get_stream(); + const real_t scale = static_cast(in_data[syncbatchnorm::kData].shape_[1]) / + static_cast(in_data[syncbatchnorm::kData].shape_.Size()); + Tensor data; + Tensor out; + if (in_data[syncbatchnorm::kData].ndim() == 2) { + Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], + in_data[syncbatchnorm::kData].shape_[1], 1, 1); + data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); + out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); + } else { + data = in_data[syncbatchnorm::kData].get(s); + out = out_data[syncbatchnorm::kOut].get(s); + } + Tensor slope = in_data[syncbatchnorm::kGamma].get(s); + Tensor bias = in_data[syncbatchnorm::kBeta].get(s); + Tensor moving_mean = aux_states[syncbatchnorm::kMovingMean].get(s); + Tensor moving_var = aux_states[syncbatchnorm::kMovingVar].get(s); + + if (param_.fix_gamma) slope = 1.f; + + // whether use global statistics + if (ctx.is_train && !param_.use_global_stats) { + Tensor mean = out_data[syncbatchnorm::kMean].get(s); + Tensor var = out_data[syncbatchnorm::kVar].get(s); + CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] == kWriteTo); + CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == kWriteTo); + // E(x) and E(x^2) + mean = scale * sumall_except_dim<1>(data); + var = scale * sumall_except_dim<1>(F(data)); + //var = scale * sumall_except_dim<1>(F( + // data - broadcast<1>(mean, data.shape_))); + // copy to cpu + SharedND> *sharedMean = + globalSharedMean.Register(param_.key, param_.ndev); + SharedND> *sharedVar = + globalSharedVar.Register(param_.key, param_.ndev); + Tensor mean_cpu = NewTensor(mean.shape_, 0.0f); + mshadow::Copy(mean_cpu, mean, s); + Tensor var_cpu = NewTensor(var.shape_, 0.0f); + mshadow::Copy(var_cpu,var,s); + // push and pull + sharedMean->Push(mean_cpu, myRank); + sharedVar->Push(var_cpu, myRank); + pthread_barrier_wait(&globalBarrier); + pthread_mutex_lock(&mm); + mean_cpu = sharedMean->Pop(myRank); + var_cpu = sharedVar->Pop(myRank); + pthread_mutex_unlock(&mm); + // copy back to gpu + mshadow::Copy(mean, mean_cpu, s); + mshadow::Copy(var, var_cpu, s); + + var = var-F(mean); + Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope, out.shape_) * + (data - broadcast<1>(mean, data.shape_)) / + F(broadcast<1>(var + param_.eps, data.shape_)) + + broadcast<1>(bias, out.shape_)); + } else { + Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope / + F(moving_var + param_.eps), + data.shape_) * data + + broadcast<1>(bias - (slope * moving_mean) / + F(moving_var + param_.eps), data.shape_)); + // Set mean and var tensors to their moving values + Tensor mean = out_data[syncbatchnorm::kMean].get(s); + Tensor var = out_data[syncbatchnorm::kVar].get(s); + mean = F(moving_mean); + var = F(moving_var); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), param_.output_mean_var ? 3U : 1U); + CHECK_EQ(in_data.size(), 3U); + CHECK_EQ(out_data.size(), 3U); + CHECK_EQ(in_grad.size(), 3U); + // get my rank + pthread_mutex_lock(&mm); + if (flagGlobalBarrier == false) { + pthread_barrier_init(&globalBarrier, NULL, param_.ndev); + flagGlobalBarrier = true; + } + int myRank = globalRank; + //LOG(INFO) << "myRank" << myRank; + globalRank += 1; + pthread_mutex_unlock(&mm); + pthread_barrier_wait(&globalBarrier); + globalRank = 0; + + Stream *s = ctx.get_stream(); + Tensor data, grad, grad_in; + const real_t scale = static_cast(out_grad[syncbatchnorm::kOut].shape_[1]) / + static_cast(out_grad[syncbatchnorm::kOut].shape_.Size()); + if (in_data[syncbatchnorm::kData].ndim() == 2) { + Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], + out_grad[syncbatchnorm::kOut].shape_[1], 1, 1); + data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); + grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); + grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); + } else { + data = in_data[syncbatchnorm::kData].get(s); + grad = out_grad[syncbatchnorm::kOut].get(s); + grad_in = in_grad[syncbatchnorm::kData].get(s); + } + + Tensor mean = out_data[syncbatchnorm::kMean].get(s); + Tensor var = out_data[syncbatchnorm::kVar].get(s); + Tensor slope = in_data[syncbatchnorm::kGamma].get(s); + // Tensor bias = in_data[kBeta].get(s); + Tensor gslope = in_grad[syncbatchnorm::kGamma].get(s); + Tensor gbias = in_grad[syncbatchnorm::kBeta].get(s); + // update moving avg + Tensor moving_mean = aux_states[syncbatchnorm::kMovingMean].get(s); + Tensor moving_var = aux_states[syncbatchnorm::kMovingVar].get(s); + + if (param_.fix_gamma) slope = 1.f; + + if (ctx.is_train && !param_.use_global_stats) { + // get requested temp space + Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( + mshadow::Shape2(5, mean.shape_[0]), s); + Tensor gmean = workspace[0]; + Tensor gvar = workspace[1]; + //Tensor gstd = workspace[1]; + // Tensor tmp = workspace[2]; + + moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum); + moving_var = moving_var * param_.momentum + var * (1 - param_.momentum); + // cal + Tensor sumGrad = workspace[3]; + Tensor sumProd = workspace[4]; + sumGrad = sumall_except_dim<1>(grad); + sumProd = sumall_except_dim<1>(grad * data); + + SharedND> *sharedGrad= + globalSharedGrad.Register(param_.key, param_.ndev); + SharedND> *sharedProd = + globalSharedProd.Register(param_.key, param_.ndev); + + Tensor grad_cpu = NewTensor(sumGrad.shape_, 0.0f); + mshadow::Copy(grad_cpu, sumGrad, s); + Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); + mshadow::Copy(prod_cpu, sumProd, s); + // push and pull + sharedGrad->Push(grad_cpu, myRank); + sharedProd->Push(prod_cpu, myRank); + pthread_barrier_wait(&globalBarrier); + pthread_mutex_lock(&mm); + grad_cpu = sharedGrad->Pop(myRank); + prod_cpu = sharedProd->Pop(myRank); + pthread_mutex_unlock(&mm); + // copy back to gpu + mshadow::Copy(sumGrad, grad_cpu, s); + mshadow::Copy(sumProd, prod_cpu, s); + + gvar = (sumProd - sumGrad * mean) * slope * (-0.5f) * + F(var + param_.eps, -1.5f); + //gstd = (sumProd - sumGrad * mean) * slope * F(var + param_.eps, -2.0f); + //gmean = - slope / F(var + param_.eps) * sumGrad; + gmean = sumGrad * slope; + gmean *= -1.0f / F(var + param_.eps); + // assign + if (!param_.fix_gamma) { + Assign(gslope, req[syncbatchnorm::kGamma], + sumall_except_dim<1>( + grad * (data - broadcast<1>(mean, data.shape_)) / + F(broadcast<1>(var + param_.eps, data.shape_)))); + } else { + Assign(gslope, req[syncbatchnorm::kGamma], 0.0f); + } + Assign(grad_in, req[syncbatchnorm::kData], + (grad * broadcast<1>(slope, data.shape_)) * + broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + + broadcast<1>(gvar, data.shape_) * //(1 - scale / param_.ndev) * + scale * 2.0f * (data - broadcast<1>(mean, data.shape_)) + + broadcast<1>(gmean, data.shape_) * scale); + Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); + } else { + // use global statistics with freeze moving mean and var. + if (!param_.fix_gamma) { + Assign(gslope, req[syncbatchnorm::kGamma], + sumall_except_dim<1>( + grad * (data - broadcast<1>(moving_mean, data.shape_)) / + F(broadcast<1>(moving_var + param_.eps, data.shape_)))); + } else { + Assign(gslope, req[syncbatchnorm::kGamma], 0.0f); + } + Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); + Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * + broadcast<1>( + 1.0f / F(moving_var + param_.eps), data.shape_)); + } + } + + private: + SyncBatchNormParam param_; +}; // class SyncBatchNorm + +template +Operator *CreateOp(SyncBatchNormParam param, int dtype); + + +#if DMLC_USE_CXX11 +class SyncBatchNormProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + in_shape->at(1) = TShape(Shape1(dshape[1])); + in_shape->at(2) = TShape(Shape1(dshape[1])); + out_shape->clear(); + out_shape->push_back(dshape); + out_shape->push_back(Shape1(dshape[1])); + out_shape->push_back(Shape1(dshape[1])); + + aux_shape->clear(); + aux_shape->push_back(Shape1(dshape[1])); + aux_shape->push_back(Shape1(dshape[1])); + return true; + } + + bool InferType(std::vector *in_type, + std::vector *out_type, + std::vector *aux_type) const override { + using namespace mshadow; + CHECK_GE(in_type->size(), 1U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + // For float16 input type beta, gamma, mean, and average are stored in float32. + // For other input types, these parameters have the same type as input + // NOTE: This requirement is from cuDNN (v. 4 and 5) + int dtype_param = (dtype == kFloat16) ? kFloat32 : dtype; + for (index_t i = 1; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype_param; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]); + } + } + for (index_t i = 0; i < aux_type->size(); ++i) { + if ((*aux_type)[i] != -1) { + UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]); + } + } + int n_aux = this->ListAuxiliaryStates().size(); + aux_type->clear(); + for (int i = 0; i < n_aux; ++i ) aux_type->push_back(dtype_param); + int n_out = this->ListOutputs().size(); + out_type->clear(); + out_type->push_back(dtype); + for (int i = 1; i < n_out; ++i ) out_type->push_back(dtype_param); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new SyncBatchNormProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "SyncBatchNorm"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[syncbatchnorm::kOut], + out_data[syncbatchnorm::kMean], + out_data[syncbatchnorm::kVar], + in_data[syncbatchnorm::kData], + in_data[syncbatchnorm::kGamma] + }; + } + + std::vector BackwardResource( + const std::vector &in_shape) const override { + return {ResourceRequest::kTempSpace}; + } + + int NumVisibleOutputs() const override { + if (param_.output_mean_var) { + return 3; + } + return 1; + } + + int NumOutputs() const override { + return 3; + } + + std::vector ListArguments() const override { + return {"data", "gamma", "beta"}; + } + + std::vector ListOutputs() const override { + return {"output", "mean", "var"}; + } + + std::vector ListAuxiliaryStates() const override { + return {"moving_mean", "moving_var"}; + } + + Operator* CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented."; + return NULL; + } + + Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const override; + + inline const SyncBatchNormParam& getParam() const { + return param_; + } + + private: + SyncBatchNormParam param_; +}; // class SyncBatchNormProp + +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc new file mode 100644 index 000000000000..7f3a8706d482 --- /dev/null +++ b/src/operator/contrib/sync_batch_norm.cc @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file sync_batch_norm.cc + * \brief Synchronized BatchNorm modified from BatchNormV1 + * \author Hang Zhang +*/ + +#include "sync_batch_norm-inl.h" +#include + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(SyncBatchNormParam param, int dtype) { + return new SyncBatchNorm(param); +} + +// DO_BIND_DISPATCH comes from operator_common.h +Operator *SyncBatchNormProp::CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const { + std::vector out_shape, aux_shape; + std::vector out_type, aux_type; + CHECK(InferType(in_type, &out_type, &aux_type)); + CHECK(InferShape(in_shape, &out_shape, &aux_shape)); + DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]); +} + +DMLC_REGISTER_PARAMETER(SyncBatchNormParam); + +MXNET_REGISTER_OP_PROPERTY(SyncBatchNorm, SyncBatchNormProp) +.describe(R"code(Batch normalization. + +This operator is DEPRECATED. Perform BatchNorm on the input. + +Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as +well as offset ``beta``. + +Assume the input has more than one dimension and we normalize along axis 1. +We first compute the mean and variance along this axis: + +.. math:: + + data\_mean[i] = mean(data[:,i,:,...]) \\ + data\_var[i] = var(data[:,i,:,...]) + +Then compute the normalized output, which has the same shape as input, as following: + +.. math:: + + out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i] + +Both *mean* and *var* returns a scalar by treating the input as a vector. + +Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta`` +have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and +``data_var`` as well, which are needed for the backward pass. + +Besides the inputs and the outputs, this operator accepts two auxiliary +states, ``moving_mean`` and ``moving_var``, which are *k*-length +vectors. They are global statistics for the whole dataset, which are updated +by:: + + moving_mean = moving_mean * momentum + data_mean * (1 - momentum) + moving_var = moving_var * momentum + data_var * (1 - momentum) + +If ``use_global_stats`` is set to be true, then ``moving_mean`` and +``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute +the output. It is often used during inference. + +Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true, +then set ``gamma`` to 1 and its gradient to 0. + +)code" ADD_FILELINE) +.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") +.add_argument("gamma", "NDArray-or-Symbol", "gamma array") +.add_argument("beta", "NDArray-or-Symbol", "beta array") +.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input") +.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input") +.add_arguments(SyncBatchNormParam::__FIELDS__()); + +NNVM_REGISTER_OP(SyncBatchNorm) +.set_attr("FSetInputVarAttrOnCompose", + [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) { + if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return; + if (index == 3) { + var->attrs.dict["__init__"] = "[\"zero\", {}]"; + } else if (index == 4) { + var->attrs.dict["__init__"] = "[\"one\", {}]"; + } + }); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/contrib/sync_batch_norm.cu b/src/operator/contrib/sync_batch_norm.cu new file mode 100644 index 000000000000..005ac3f61a9c --- /dev/null +++ b/src/operator/contrib/sync_batch_norm.cu @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file sync_batch_norm.cc + * \brief Synchronized BatchNorm modified from BatchNormV1 + * \author Hang Zhang +*/ + +#include "sync_batch_norm-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator *CreateOp(SyncBatchNormParam param, int dtype) { + return new SyncBatchNorm(param); +} + +} // namespace op +} // namespace mxnet + diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index a220f08d20d4..b7fa986fb7ad 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -24,6 +24,8 @@ from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * import unittest +from mxnet.gluon import nn +from mxnet.gluon.utils import split_and_load def test_box_nms_op(): def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0, @@ -243,6 +245,87 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) + +def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(mx.cpu(0))) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + bn1.collect_params().reset_ctx(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(i) for i in range(num_devices)] + + input1.attach_grad() + #input2.attach_grad() + + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + #print('output1', output1) + #print('output2', output2) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + #print('input1.grad:', input1.grad) + #print('input1.grad:', input2grad) + _assert_tensor_close(input1.grad, input2grad) + +def testSyncBN(): + ndev = 4 + + bn = nn.BatchNorm(in_channels=1) + sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) + + bn.initialize() + ctx_list = [mx.cpu(i) for i in range(ndev)] + sync_bn.initialize(ctx=ctx_list) + + # check with unsync version + for i in range(10): + _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=False) + if __name__ == '__main__': import nose nose.runmodule() From 7d58073bea2b00e1322fbac7a12f01cc9d3da095 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 12:43:26 -0700 Subject: [PATCH 02/34] global rank and barrier --- python/mxnet/gluon/contrib/nn/basic_layers.py | 2 +- src/operator/contrib/sync_batch_norm-inl.h | 190 +++++++++--------- 2 files changed, 97 insertions(+), 95 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 4b2768551110..453e5d9024d4 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -21,7 +21,7 @@ __all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding', 'SyncBatchNorm'] -from .... import nd +from .... import nd, test_utils from ...block import HybridBlock, Block from ...nn import Sequential, HybridSequential, BatchNorm diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 2ed32995a22a..8e4feccb2d2e 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -28,32 +28,67 @@ #include #include #include +#include #include #include #include #include -#include #include "../operator_common.h" #include "../mshadow_op.h" namespace mxnet { namespace op { +namespace syncbatchnorm { +enum BatchNormOpInputs {kData, kGamma, kBeta}; +enum BatchNormOpOutputs {kOut, kMean, kVar}; +enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; +enum BatchNormBackResource {kTempSpace}; +} // namespace syncbatchnorm + +struct SyncBatchNormParam : public dmlc::Parameter { + float eps; + float momentum; + bool fix_gamma; + bool use_global_stats; + bool output_mean_var; + int ndev; + std::string key; + DMLC_DECLARE_PARAMETER(SyncBatchNormParam) { + DMLC_DECLARE_FIELD(eps).set_default(1e-3f) + .describe("Epsilon to prevent div 0"); + DMLC_DECLARE_FIELD(momentum).set_default(0.9f) + .describe("Momentum for moving average"); + DMLC_DECLARE_FIELD(fix_gamma).set_default(true) + .describe("Fix gamma while training"); + DMLC_DECLARE_FIELD(use_global_stats).set_default(false) + .describe("Whether use global moving statistics instead of local batch-norm. " + "This will force change batch-norm into a scale shift operator."); + DMLC_DECLARE_FIELD(output_mean_var).set_default(false) + .describe("Output All,normal mean and var"); + DMLC_DECLARE_FIELD(ndev).set_default(1) + .describe("The count of GPU devices"); + DMLC_DECLARE_FIELD(key) + .set_default("") + .describe("Hash key for synchronization"); + } +}; + #define MAX_GPU_NUM 16 template class SharedND { -private: + private: int nDev = 4; bool flag[MAX_GPU_NUM]; T mean; bool meanReady = false; bool meanInited = false; -public: + + public: T data[MAX_GPU_NUM]; SharedND(int ndev) :nDev(ndev) { - //LOG(INFO) << "Creating SharedND"; memset(flag, false, MAX_GPU_NUM * sizeof(bool)); } @@ -62,16 +97,13 @@ class SharedND { data[index] = input; flag[index] = true; return true; - } - else { - //LOG(INFO) << "Error in Pushing" << index; + } else { return false; } } T Pop(int index) { - //LOG(INFO) << "Poping: " << index; - while(!MeanReady());//deadlock may occur + while(!MeanReady()); flag[index] = false; T tmp = mean; ResetMean(); @@ -80,30 +112,22 @@ class SharedND { bool MeanReady() { if (meanReady) { - //LOG(INFO) << "meanReady"; return true; } - //LOG(INFO) << "Not meanReady"; for (int i = 0; i < nDev; i++) { if (!flag[i]) { - //LOG(INFO) << "flag[i] is not ready: " << i; return false; } - //LOG(INFO) << "flag[i] is ready: " << i; } - //LOG(INFO) << "reducing the data now"; for (int i = 1; i < nDev; i++) { data[0] += data[i]; - //LOG(INFO) << "adding data[i]" << i; } - //LOG(INFO) << "reducing the data finished"; if (!meanInited) { mean = mshadow::NewTensor(data[0].shape_, 0.0f); meanInited = true; } mean = data[0] * 1.0f / nDev; meanReady = true; - //LOG(INFO) << "meanReady now"; return true; } @@ -116,7 +140,7 @@ class SharedND { }; template -class GlobalShared { +class GlobalSharedND { public: T* Register(const std::string &key, int ndev) { std::lock_guard lock(mutex_); @@ -124,7 +148,6 @@ class GlobalShared { if (it != registry_.end()) return it->second; T *newT = new T(ndev); registry_[key] = newT; - //LOG(INFO) << "registring" << key; return newT; } private: @@ -132,49 +155,45 @@ class GlobalShared { std::map registry_; }; -namespace syncbatchnorm { -enum BatchNormOpInputs {kData, kGamma, kBeta}; -enum BatchNormOpOutputs {kOut, kMean, kVar}; -enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; -enum BatchNormBackResource {kTempSpace}; -} // namespace syncbatchnorm +template +class GlobalSharedRank { + public: + T* Register(const std::string &key) { + std::lock_guard lock(mutex_); + auto it = registry_.find(key); + if (it != registry_.end()) return it->second; + T *newT = new T(0); + registry_[key] = newT; + return newT; + } + private: + std::mutex mutex_; + std::map registry_; +}; -struct SyncBatchNormParam : public dmlc::Parameter { - float eps; - float momentum; - bool fix_gamma; - bool use_global_stats; - bool output_mean_var; - int ndev; - std::string key; - DMLC_DECLARE_PARAMETER(SyncBatchNormParam) { - DMLC_DECLARE_FIELD(eps).set_default(1e-3f) - .describe("Epsilon to prevent div 0"); - DMLC_DECLARE_FIELD(momentum).set_default(0.9f) - .describe("Momentum for moving average"); - DMLC_DECLARE_FIELD(fix_gamma).set_default(true) - .describe("Fix gamma while training"); - DMLC_DECLARE_FIELD(use_global_stats).set_default(false) - .describe("Whether use global moving statistics instead of local batch-norm. " - "This will force change batch-norm into a scale shift operator."); - DMLC_DECLARE_FIELD(output_mean_var).set_default(false) - .describe("Output All,normal mean and var"); - DMLC_DECLARE_FIELD(ndev).set_default(1) - .describe("The count of GPU devices"); - DMLC_DECLARE_FIELD(key) - .set_default("") - .describe("Hash key for synchronization"); +class GlobalSharedBarrier { + public: + pthread_barrier_t* Register(const std::string &key, int ndev) { + std::lock_guard lock(mutex_); + auto it = registry_.find(key); + if (it != registry_.end()) return it->second; + pthread_barrier_t *newBarrier = new pthread_barrier_t(); + pthread_barrier_init(newBarrier, NULL, ndev); + registry_[key] = newBarrier; + return newBarrier; } + private: + std::mutex mutex_; + std::map registry_; }; static pthread_mutex_t mm = PTHREAD_MUTEX_INITIALIZER; -static pthread_barrier_t globalBarrier; -static int globalRank = 0; -static bool flagGlobalBarrier = false; -static GlobalShared>> globalSharedMean; -static GlobalShared>> globalSharedVar; -static GlobalShared>> globalSharedGrad; -static GlobalShared>> globalSharedProd; +static GlobalSharedRank globalSharedRank; +static GlobalSharedBarrier globalSharedBarrier; +static GlobalSharedND>> globalSharedMean; +static GlobalSharedND>> globalSharedVar; +static GlobalSharedND>> globalSharedGrad; +static GlobalSharedND>> globalSharedProd; template class SyncBatchNorm : public Operator { @@ -200,18 +219,6 @@ class SyncBatchNorm : public Operator { CHECK_GE(req.size(), 1U); CHECK_EQ(req[syncbatchnorm::kOut], kWriteTo); } - // get my rank - pthread_mutex_lock(&mm); - if (flagGlobalBarrier == false) { - pthread_barrier_init(&globalBarrier, NULL, param_.ndev); - flagGlobalBarrier = true; - } - int myRank = globalRank; - //LOG(INFO) << "myRank" << myRank; - globalRank += 1; - pthread_mutex_unlock(&mm); - pthread_barrier_wait(&globalBarrier); - globalRank = 0; Stream *s = ctx.get_stream(); const real_t scale = static_cast(in_data[syncbatchnorm::kData].shape_[1]) / @@ -236,6 +243,14 @@ class SyncBatchNorm : public Operator { // whether use global statistics if (ctx.is_train && !param_.use_global_stats) { + // get my rank + pthread_barrier_t *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + int *globalRank = globalSharedRank.Register(param_.key); + pthread_mutex_lock(&mm); + int myRank = *globalRank; + *globalRank += 1; + pthread_mutex_unlock(&mm); + // get the mean and var Tensor mean = out_data[syncbatchnorm::kMean].get(s); Tensor var = out_data[syncbatchnorm::kVar].get(s); CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] == kWriteTo); @@ -243,13 +258,11 @@ class SyncBatchNorm : public Operator { // E(x) and E(x^2) mean = scale * sumall_except_dim<1>(data); var = scale * sumall_except_dim<1>(F(data)); - //var = scale * sumall_except_dim<1>(F( - // data - broadcast<1>(mean, data.shape_))); - // copy to cpu SharedND> *sharedMean = globalSharedMean.Register(param_.key, param_.ndev); SharedND> *sharedVar = globalSharedVar.Register(param_.key, param_.ndev); + // copy to cpu Tensor mean_cpu = NewTensor(mean.shape_, 0.0f); mshadow::Copy(mean_cpu, mean, s); Tensor var_cpu = NewTensor(var.shape_, 0.0f); @@ -257,7 +270,8 @@ class SyncBatchNorm : public Operator { // push and pull sharedMean->Push(mean_cpu, myRank); sharedVar->Push(var_cpu, myRank); - pthread_barrier_wait(&globalBarrier); + pthread_barrier_wait(globalBarrier); + *globalRank = 0; pthread_mutex_lock(&mm); mean_cpu = sharedMean->Pop(myRank); var_cpu = sharedVar->Pop(myRank); @@ -277,11 +291,6 @@ class SyncBatchNorm : public Operator { data.shape_) * data + broadcast<1>(bias - (slope * moving_mean) / F(moving_var + param_.eps), data.shape_)); - // Set mean and var tensors to their moving values - Tensor mean = out_data[syncbatchnorm::kMean].get(s); - Tensor var = out_data[syncbatchnorm::kVar].get(s); - mean = F(moving_mean); - var = F(moving_var); } } @@ -298,18 +307,6 @@ class SyncBatchNorm : public Operator { CHECK_EQ(in_data.size(), 3U); CHECK_EQ(out_data.size(), 3U); CHECK_EQ(in_grad.size(), 3U); - // get my rank - pthread_mutex_lock(&mm); - if (flagGlobalBarrier == false) { - pthread_barrier_init(&globalBarrier, NULL, param_.ndev); - flagGlobalBarrier = true; - } - int myRank = globalRank; - //LOG(INFO) << "myRank" << myRank; - globalRank += 1; - pthread_mutex_unlock(&mm); - pthread_barrier_wait(&globalBarrier); - globalRank = 0; Stream *s = ctx.get_stream(); Tensor data, grad, grad_in; @@ -340,12 +337,18 @@ class SyncBatchNorm : public Operator { if (param_.fix_gamma) slope = 1.f; if (ctx.is_train && !param_.use_global_stats) { + // get my rank + pthread_barrier_t *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + int *globalRank = globalSharedRank.Register(param_.key); + pthread_mutex_lock(&mm); + int myRank = *globalRank; + *globalRank += 1; + pthread_mutex_unlock(&mm); // get requested temp space Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( mshadow::Shape2(5, mean.shape_[0]), s); Tensor gmean = workspace[0]; Tensor gvar = workspace[1]; - //Tensor gstd = workspace[1]; // Tensor tmp = workspace[2]; moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum); @@ -368,7 +371,8 @@ class SyncBatchNorm : public Operator { // push and pull sharedGrad->Push(grad_cpu, myRank); sharedProd->Push(prod_cpu, myRank); - pthread_barrier_wait(&globalBarrier); + pthread_barrier_wait(globalBarrier); + *globalRank = 0; pthread_mutex_lock(&mm); grad_cpu = sharedGrad->Pop(myRank); prod_cpu = sharedProd->Pop(myRank); @@ -379,8 +383,6 @@ class SyncBatchNorm : public Operator { gvar = (sumProd - sumGrad * mean) * slope * (-0.5f) * F(var + param_.eps, -1.5f); - //gstd = (sumProd - sumGrad * mean) * slope * F(var + param_.eps, -2.0f); - //gmean = - slope / F(var + param_.eps) * sumGrad; gmean = sumGrad * slope; gmean *= -1.0f / F(var + param_.eps); // assign From 9908468dedb58e77fde688c849a2734a59562a1b Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 13:02:28 -0700 Subject: [PATCH 03/34] lint --- src/operator/contrib/sync_batch_norm-inl.h | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 8e4feccb2d2e..928672a99064 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -79,16 +79,15 @@ struct SyncBatchNormParam : public dmlc::Parameter { template class SharedND { private: - int nDev = 4; - bool flag[MAX_GPU_NUM]; + int nDev; T mean; + T data[MAX_GPU_NUM]; + bool flag[MAX_GPU_NUM]; bool meanReady = false; bool meanInited = false; public: - T data[MAX_GPU_NUM]; - SharedND(int ndev) - :nDev(ndev) { + SharedND(int ndev) :nDev(ndev) { memset(flag, false, MAX_GPU_NUM * sizeof(bool)); } @@ -103,11 +102,11 @@ class SharedND { } T Pop(int index) { - while(!MeanReady()); + while (!MeanReady()) {}; flag[index] = false; T tmp = mean; ResetMean(); - return tmp; + return tmp; } bool MeanReady() { @@ -265,8 +264,8 @@ class SyncBatchNorm : public Operator { // copy to cpu Tensor mean_cpu = NewTensor(mean.shape_, 0.0f); mshadow::Copy(mean_cpu, mean, s); - Tensor var_cpu = NewTensor(var.shape_, 0.0f); - mshadow::Copy(var_cpu,var,s); + Tensor var_cpu = NewTensor(var.shape_, 0.0f); + mshadow::Copy(var_cpu, var, s); // push and pull sharedMean->Push(mean_cpu, myRank); sharedVar->Push(var_cpu, myRank); @@ -366,7 +365,7 @@ class SyncBatchNorm : public Operator { Tensor grad_cpu = NewTensor(sumGrad.shape_, 0.0f); mshadow::Copy(grad_cpu, sumGrad, s); - Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); + Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); mshadow::Copy(prod_cpu, sumProd, s); // push and pull sharedGrad->Push(grad_cpu, myRank); @@ -397,7 +396,7 @@ class SyncBatchNorm : public Operator { Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + - broadcast<1>(gvar, data.shape_) * //(1 - scale / param_.ndev) * + broadcast<1>(gvar, data.shape_) * scale * 2.0f * (data - broadcast<1>(mean, data.shape_)) + broadcast<1>(gmean, data.shape_) * scale); Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); From 3ce37a3d9a88c06a0f130fe4d2050d51afda1fa8 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 14:02:10 -0700 Subject: [PATCH 04/34] cpplint --- src/operator/contrib/sync_batch_norm-inl.h | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 928672a99064..e790c2da5ed2 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -87,7 +87,7 @@ class SharedND { bool meanInited = false; public: - SharedND(int ndev) :nDev(ndev) { + explicit SharedND(int ndev) :nDev(ndev) { memset(flag, false, MAX_GPU_NUM * sizeof(bool)); } @@ -102,7 +102,7 @@ class SharedND { } T Pop(int index) { - while (!MeanReady()) {}; + while (!MeanReady()) {} flag[index] = false; T tmp = mean; ResetMean(); @@ -358,14 +358,14 @@ class SyncBatchNorm : public Operator { sumGrad = sumall_except_dim<1>(grad); sumProd = sumall_except_dim<1>(grad * data); - SharedND> *sharedGrad= + SharedND> *sharedGrad = globalSharedGrad.Register(param_.key, param_.ndev); SharedND> *sharedProd = globalSharedProd.Register(param_.key, param_.ndev); Tensor grad_cpu = NewTensor(sumGrad.shape_, 0.0f); mshadow::Copy(grad_cpu, sumGrad, s); - Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); + Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); mshadow::Copy(prod_cpu, sumProd, s); // push and pull sharedGrad->Push(grad_cpu, myRank); From a32ec07575331ec56af674a51f87cb97ef359222 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 14:07:33 -0700 Subject: [PATCH 05/34] pylint --- python/mxnet/gluon/contrib/nn/basic_layers.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 453e5d9024d4..64ffd91bde02 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -203,10 +203,10 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, center=True, scale=True, use_global_stats=False, beta_initializer='zeros', gamma_initializer='ones', running_mean_initializer='zeros', running_variance_initializer='ones', **kwargs): - super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, - use_global_stats, beta_initializer, gamma_initializer, - running_mean_initializer, running_variance_initializer, - in_channels, **kwargs) + super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats, + beta_initializer, gamma_initializer, + running_mean_initializer, running_variance_initializer, + in_channels, **kwargs) num_devices = self._get_num_devices() if num_devices is None else num_devices self._kwargs = {'eps': epsilon, 'momentum': momentum, 'fix_gamma': not scale, 'use_global_stats': use_global_stats, From cd6d93b4a535a541036f3d109ef54a3cfecd62b8 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 14:09:51 -0700 Subject: [PATCH 06/34] doc --- docs/api/python/gluon/contrib.md | 1 + python/mxnet/gluon/contrib/nn/basic_layers.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md index 877a294d9a1f..98f36f882164 100644 --- a/docs/api/python/gluon/contrib.md +++ b/docs/api/python/gluon/contrib.md @@ -36,6 +36,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p HybridConcurrent Identity SparseEmbedding + SyncBatchNorm ``` ### Recurrent neural network diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 64ffd91bde02..78953b2a1ca4 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -189,10 +189,13 @@ class SyncBatchNorm(BatchNorm): initialization will be deferred to the first time `forward` is called and `in_channels` will be inferred from the shape of input data. num_devices : int, default number of visible GPUs + + Inputs: - **data**: input tensor with arbitrary shape. Outputs: - **out**: output tensor with the same shape as `data`. + Reference: .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." *ICML 2015* From a3bd860119de50d3b76030507c5700015027918e Mon Sep 17 00:00:00 2001 From: Zhang Date: Sat, 30 Jun 2018 14:32:42 -0700 Subject: [PATCH 07/34] add ref --- src/operator/contrib/sync_batch_norm-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index e790c2da5ed2..f35c9161811d 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -76,6 +76,7 @@ struct SyncBatchNormParam : public dmlc::Parameter { #define MAX_GPU_NUM 16 +// Adapt from https://github.com/brucechin/SharedTensor template class SharedND { private: From a76dad416c5022c5961e05fd08b8b8422c2ff70a Mon Sep 17 00:00:00 2001 From: Zhang Date: Sun, 1 Jul 2018 12:31:02 -0700 Subject: [PATCH 08/34] customized barrier --- src/operator/contrib/sync_batch_norm-inl.h | 48 +++++++++++++++++----- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index f35c9161811d..62c514825faa 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -28,7 +28,8 @@ #include #include #include -#include +//#include +# include #include #include #include @@ -140,7 +141,7 @@ class SharedND { }; template -class GlobalSharedND { +class GlobalShared { public: T* Register(const std::string &key, int ndev) { std::lock_guard lock(mutex_); @@ -171,6 +172,27 @@ class GlobalSharedRank { std::map registry_; }; +class Barrier { + private: + std::mutex mutex_; + std::condition_variable cv_; + std::size_t count_; + std::size_t total_count_; + public: + explicit Barrier(std::size_t count) : count_{count}, total_count_{count} { } + void Wait() + { + std::unique_lock lock{mutex_}; + if (--count_ == 0) { + count_ = total_count_; + cv_.notify_all(); + } else { + cv_.wait(lock, [this] { return count_ == total_count_; }); + } + } +}; + +/* class GlobalSharedBarrier { public: pthread_barrier_t* Register(const std::string &key, int ndev) { @@ -186,14 +208,16 @@ class GlobalSharedBarrier { std::mutex mutex_; std::map registry_; }; +*/ static pthread_mutex_t mm = PTHREAD_MUTEX_INITIALIZER; static GlobalSharedRank globalSharedRank; -static GlobalSharedBarrier globalSharedBarrier; -static GlobalSharedND>> globalSharedMean; -static GlobalSharedND>> globalSharedVar; -static GlobalSharedND>> globalSharedGrad; -static GlobalSharedND>> globalSharedProd; +// static GlobalSharedBarrier globalSharedBarrier; +static GlobalShared globalSharedBarrier; +static GlobalShared>> globalSharedMean; +static GlobalShared>> globalSharedVar; +static GlobalShared>> globalSharedGrad; +static GlobalShared>> globalSharedProd; template class SyncBatchNorm : public Operator { @@ -244,7 +268,7 @@ class SyncBatchNorm : public Operator { // whether use global statistics if (ctx.is_train && !param_.use_global_stats) { // get my rank - pthread_barrier_t *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); int *globalRank = globalSharedRank.Register(param_.key); pthread_mutex_lock(&mm); int myRank = *globalRank; @@ -270,7 +294,8 @@ class SyncBatchNorm : public Operator { // push and pull sharedMean->Push(mean_cpu, myRank); sharedVar->Push(var_cpu, myRank); - pthread_barrier_wait(globalBarrier); + // pthread_barrier_wait(globalBarrier); + globalBarrier->Wait(); *globalRank = 0; pthread_mutex_lock(&mm); mean_cpu = sharedMean->Pop(myRank); @@ -338,7 +363,7 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank - pthread_barrier_t *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); int *globalRank = globalSharedRank.Register(param_.key); pthread_mutex_lock(&mm); int myRank = *globalRank; @@ -371,7 +396,8 @@ class SyncBatchNorm : public Operator { // push and pull sharedGrad->Push(grad_cpu, myRank); sharedProd->Push(prod_cpu, myRank); - pthread_barrier_wait(globalBarrier); + // pthread_barrier_wait(globalBarrier); + globalBarrier->Wait(); *globalRank = 0; pthread_mutex_lock(&mm); grad_cpu = sharedGrad->Pop(myRank); From b3793bf030bf9a1183ab80645e7f1f83d3317044 Mon Sep 17 00:00:00 2001 From: Zhang Date: Sun, 1 Jul 2018 12:35:47 -0700 Subject: [PATCH 09/34] cpplint --- src/operator/contrib/sync_batch_norm-inl.h | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 62c514825faa..b423ae7d2cb5 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -28,7 +28,7 @@ #include #include #include -//#include +// #include # include #include #include @@ -180,8 +180,7 @@ class Barrier { std::size_t total_count_; public: explicit Barrier(std::size_t count) : count_{count}, total_count_{count} { } - void Wait() - { + void Wait() { std::unique_lock lock{mutex_}; if (--count_ == 0) { count_ = total_count_; From b31fcfe453a928eb9b51b7fbc2398c9a41e9572d Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 09:02:49 -0700 Subject: [PATCH 10/34] get rid of pthread --- src/operator/contrib/sync_batch_norm-inl.h | 71 +++++++--------------- 1 file changed, 23 insertions(+), 48 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index b423ae7d2cb5..37fc9a65204b 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -28,7 +28,6 @@ #include #include #include -// #include # include #include #include @@ -75,22 +74,28 @@ struct SyncBatchNormParam : public dmlc::Parameter { } }; -#define MAX_GPU_NUM 16 - -// Adapt from https://github.com/brucechin/SharedTensor +// Modified from https://github.com/brucechin/SharedTensor template class SharedND { private: int nDev; T mean; - T data[MAX_GPU_NUM]; - bool flag[MAX_GPU_NUM]; + T *data; + bool *flag; bool meanReady = false; bool meanInited = false; + std::mutex mutex_; public: explicit SharedND(int ndev) :nDev(ndev) { - memset(flag, false, MAX_GPU_NUM * sizeof(bool)); + flag = new bool[ndev]; + data = new T[ndev]; + memset(flag, false, ndev * sizeof(bool)); + } + + ~SharedND() { + delete [] flag; + delete [] data; } bool Push(T input, int index) { @@ -104,6 +109,7 @@ class SharedND { } T Pop(int index) { + std::lock_guard lock(mutex_); while (!MeanReady()) {} flag[index] = false; T tmp = mean; @@ -159,13 +165,17 @@ class GlobalShared { template class GlobalSharedRank { public: - T* Register(const std::string &key) { + T Register(const std::string &key, int ndev) { std::lock_guard lock(mutex_); auto it = registry_.find(key); - if (it != registry_.end()) return it->second; + if (it != registry_.end()) { + T* tmpT = it->second; + *tmpT = (*tmpT == ndev - 1) ? 0 : *tmpT + 1; + return *tmpT; + } T *newT = new T(0); registry_[key] = newT; - return newT; + return *newT; } private: std::mutex mutex_; @@ -191,27 +201,8 @@ class Barrier { } }; -/* -class GlobalSharedBarrier { - public: - pthread_barrier_t* Register(const std::string &key, int ndev) { - std::lock_guard lock(mutex_); - auto it = registry_.find(key); - if (it != registry_.end()) return it->second; - pthread_barrier_t *newBarrier = new pthread_barrier_t(); - pthread_barrier_init(newBarrier, NULL, ndev); - registry_[key] = newBarrier; - return newBarrier; - } - private: - std::mutex mutex_; - std::map registry_; -}; -*/ - -static pthread_mutex_t mm = PTHREAD_MUTEX_INITIALIZER; +// Global variables for Synchronizations static GlobalSharedRank globalSharedRank; -// static GlobalSharedBarrier globalSharedBarrier; static GlobalShared globalSharedBarrier; static GlobalShared>> globalSharedMean; static GlobalShared>> globalSharedVar; @@ -268,11 +259,7 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); - int *globalRank = globalSharedRank.Register(param_.key); - pthread_mutex_lock(&mm); - int myRank = *globalRank; - *globalRank += 1; - pthread_mutex_unlock(&mm); + int myRank = globalSharedRank.Register(param_.key, param_.ndev); // get the mean and var Tensor mean = out_data[syncbatchnorm::kMean].get(s); Tensor var = out_data[syncbatchnorm::kVar].get(s); @@ -293,13 +280,9 @@ class SyncBatchNorm : public Operator { // push and pull sharedMean->Push(mean_cpu, myRank); sharedVar->Push(var_cpu, myRank); - // pthread_barrier_wait(globalBarrier); globalBarrier->Wait(); - *globalRank = 0; - pthread_mutex_lock(&mm); mean_cpu = sharedMean->Pop(myRank); var_cpu = sharedVar->Pop(myRank); - pthread_mutex_unlock(&mm); // copy back to gpu mshadow::Copy(mean, mean_cpu, s); mshadow::Copy(var, var_cpu, s); @@ -363,11 +346,7 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); - int *globalRank = globalSharedRank.Register(param_.key); - pthread_mutex_lock(&mm); - int myRank = *globalRank; - *globalRank += 1; - pthread_mutex_unlock(&mm); + int myRank = globalSharedRank.Register(param_.key, param_.ndev); // get requested temp space Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( mshadow::Shape2(5, mean.shape_[0]), s); @@ -395,13 +374,9 @@ class SyncBatchNorm : public Operator { // push and pull sharedGrad->Push(grad_cpu, myRank); sharedProd->Push(prod_cpu, myRank); - // pthread_barrier_wait(globalBarrier); globalBarrier->Wait(); - *globalRank = 0; - pthread_mutex_lock(&mm); grad_cpu = sharedGrad->Pop(myRank); prod_cpu = sharedProd->Pop(myRank); - pthread_mutex_unlock(&mm); // copy back to gpu mshadow::Copy(sumGrad, grad_cpu, s); mshadow::Copy(sumProd, prod_cpu, s); From 8f58e63cb6f5893cbdc64a527c283a29e3480adb Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 11:34:45 -0700 Subject: [PATCH 11/34] address comments --- python/mxnet/gluon/contrib/nn/basic_layers.py | 10 +-- src/operator/contrib/sync_batch_norm-inl.h | 74 +++++++++++-------- src/operator/contrib/sync_batch_norm.cc | 10 ++- .../python/unittest/test_contrib_operator.py | 6 -- 4 files changed, 56 insertions(+), 44 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 78953b2a1ca4..517a2754b09f 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -160,6 +160,11 @@ class SyncBatchNorm(BatchNorm): We follow the sync-onece implmentation described in the paper [2]_ . Parameters ---------- + in_channels : int, default 0 + Number of channels (feature maps) in input data. If not specified, + initialization will be deferred to the first time `forward` is called + and `in_channels` will be inferred from the shape of input data. + num_devices : int, default number of visible GPUs momentum: float, default 0.9 Momentum for the moving average. epsilon: float, default 1e-5 @@ -184,11 +189,6 @@ class SyncBatchNorm(BatchNorm): Initializer for the moving mean. moving_variance_initializer: str or `Initializer`, default 'ones' Initializer for the moving variance. - in_channels : int, default 0 - Number of channels (feature maps) in input data. If not specified, - initialization will be deferred to the first time `forward` is called - and `in_channels` will be inferred from the shape of input data. - num_devices : int, default number of visible GPUs Inputs: diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 37fc9a65204b..2d70959f89cc 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -28,7 +28,7 @@ #include #include #include -# include +#include #include #include #include @@ -78,30 +78,30 @@ struct SyncBatchNormParam : public dmlc::Parameter { template class SharedND { private: - int nDev; - T mean; - T *data; - bool *flag; - bool meanReady = false; - bool meanInited = false; + int num_devices_; + T mean_; + T *data_; + bool *flag_; + bool mean_ready_ = false; + bool mean_inited_ = false; std::mutex mutex_; public: - explicit SharedND(int ndev) :nDev(ndev) { - flag = new bool[ndev]; - data = new T[ndev]; - memset(flag, false, ndev * sizeof(bool)); + explicit SharedND(int ndev) :num_devices_(ndev) { + flag_ = new bool[ndev]; + data_ = new T[ndev]; + memset(flag_, false, ndev * sizeof(bool)); } ~SharedND() { - delete [] flag; - delete [] data; + delete [] flag_; + delete [] data_; } bool Push(T input, int index) { - if (flag[index] == false) { - data[index] = input; - flag[index] = true; + if (flag_[index] == false) { + data_[index] = input; + flag_[index] = true; return true; } else { return false; @@ -111,38 +111,38 @@ class SharedND { T Pop(int index) { std::lock_guard lock(mutex_); while (!MeanReady()) {} - flag[index] = false; - T tmp = mean; + flag_[index] = false; + T tmp = mean_; ResetMean(); return tmp; } bool MeanReady() { - if (meanReady) { + if (mean_ready_) { return true; } - for (int i = 0; i < nDev; i++) { - if (!flag[i]) { + for (int i = 0; i < num_devices_; i++) { + if (!flag_[i]) { return false; } } - for (int i = 1; i < nDev; i++) { - data[0] += data[i]; + for (int i = 1; i < num_devices_; i++) { + data_[0] += data_[i]; } - if (!meanInited) { - mean = mshadow::NewTensor(data[0].shape_, 0.0f); - meanInited = true; + if (!mean_inited_) { + mean_ = mshadow::NewTensor(data_[0].shape_, 0.0f); + mean_inited_ = true; } - mean = data[0] * 1.0f / nDev; - meanReady = true; + mean_ = data_[0] * 1.0f / num_devices_; + mean_ready_ = true; return true; } void ResetMean() { - for (int i = 0; i < nDev; i++) { - if (flag[i]) return; + for (int i = 0; i < num_devices_; i++) { + if (flag_[i]) return; } - meanReady = false; + mean_ready_ = false; } }; @@ -157,6 +157,12 @@ class GlobalShared { registry_[key] = newT; return newT; } + ~GlobalShared() { + for (auto it = registry_.begin(); it != registry_.end(); it++) { + T *ptr = it->second; + delete [] ptr; + } + } private: std::mutex mutex_; std::map registry_; @@ -177,6 +183,12 @@ class GlobalSharedRank { registry_[key] = newT; return *newT; } + ~GlobalSharedRank() { + for (auto it = registry_.begin(); it != registry_.end(); it++) { + T *ptr = it->second; + delete [] ptr; + } + } private: std::mutex mutex_; std::map registry_; diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc index 7f3a8706d482..33b5200d1b36 100644 --- a/src/operator/contrib/sync_batch_norm.cc +++ b/src/operator/contrib/sync_batch_norm.cc @@ -48,10 +48,11 @@ DMLC_REGISTER_PARAMETER(SyncBatchNormParam); MXNET_REGISTER_OP_PROPERTY(SyncBatchNorm, SyncBatchNormProp) .describe(R"code(Batch normalization. -This operator is DEPRECATED. Perform BatchNorm on the input. - Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as well as offset ``beta``. +Standard BN [1]_ implementation only normalize the data within each device. +SyncBN normalizes the input within the whole mini-batch. +We follow the sync-onece implmentation described in the paper [2]_ . Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis: @@ -88,6 +89,11 @@ the output. It is often used during inference. Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true, then set ``gamma`` to 1 and its gradient to 0. +Reference: + .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating + deep network training by reducing internal covariate shift." *ICML 2015* + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, + Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* )code" ADD_FILELINE) .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") .add_argument("gamma", "NDArray-or-Symbol", "gamma array") diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index b7fa986fb7ad..5f38cd7f3ad6 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -282,8 +282,6 @@ def _syncParameters(bn1, bn2): ctx_list = [mx.cpu(i) for i in range(num_devices)] input1.attach_grad() - #input2.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) for xi in inputs2: xi.attach_grad() @@ -297,8 +295,6 @@ def _syncParameters(bn1, bn2): mx.autograd.backward(loss2) output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - #print('output1', output1) - #print('output2', output2) # assert forwarding _assert_tensor_close(input1, input2) _assert_tensor_close(output1, output2) @@ -307,8 +303,6 @@ def _syncParameters(bn1, bn2): _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), _find_bn(bn2).running_var.data(ctx_list[0])) input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - #print('input1.grad:', input1.grad) - #print('input1.grad:', input2grad) _assert_tensor_close(input1.grad, input2grad) def testSyncBN(): From cc60d11f44e37e954a627c8db43bd1b6fc45e68d Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 13:52:50 -0700 Subject: [PATCH 12/34] warning --- python/mxnet/gluon/contrib/nn/basic_layers.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 517a2754b09f..af7dbebbb85d 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -24,6 +24,7 @@ from .... import nd, test_utils from ...block import HybridBlock, Block from ...nn import Sequential, HybridSequential, BatchNorm +import warnings class Concurrent(Sequential): """Lays `Block`s concurrently. @@ -216,7 +217,9 @@ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, 'ndev': num_devices, 'key': self.prefix} def _get_num_devices(self): - # Caution: if not using all the GPUs, please mannually set num_devices + warnings.warn("Caution using SyncBatchNorm: " + "if not using all the GPUs, please mannually set num_devices", + UserWarning) num_devices = len(test_utils.list_gpus()) num_devices = num_devices if num_devices > 0 else 1 return num_devices From 1c87e6ebae69c89625d1cb73210924b98c1ee2e8 Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 14:33:21 -0700 Subject: [PATCH 13/34] pylint --- python/mxnet/gluon/contrib/nn/basic_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index af7dbebbb85d..9a42c1994e79 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -21,10 +21,10 @@ __all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding', 'SyncBatchNorm'] +import warnings from .... import nd, test_utils from ...block import HybridBlock, Block from ...nn import Sequential, HybridSequential, BatchNorm -import warnings class Concurrent(Sequential): """Lays `Block`s concurrently. From 4bb0f3a8f975c055fdc5f642fb355cc1a6461e55 Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 17:20:38 -0700 Subject: [PATCH 14/34] gpu unitest --- tests/python/gpu/test_operator_gpu.py | 78 +++++++++++++++++++ .../python/unittest/test_contrib_operator.py | 77 ------------------ 2 files changed, 78 insertions(+), 77 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 736a9ee0268f..014b9627d4f8 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -28,6 +28,8 @@ from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError from mxnet import autograd +from mxnet.gluon import nn +from mxnet.gluon.utils import split_and_load from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -1914,6 +1916,82 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 + +def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(mx.gpu(0))) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + bn1.collect_params().reset_ctx(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.gpu(i) for i in range(num_devices)] + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + _assert_tensor_close(input1.grad, input2grad) + +def testSyncBN(): + ndev = 4 + + bn = nn.BatchNorm(in_channels=1) + sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) + + bn.initialize() + ctx_list = [mx.gpu(i) for i in range(ndev)] + sync_bn.initialize(ctx=ctx_list) + + # check with unsync version + for i in range(10): + _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 5f38cd7f3ad6..a220f08d20d4 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -24,8 +24,6 @@ from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * import unittest -from mxnet.gluon import nn -from mxnet.gluon.utils import split_and_load def test_box_nms_op(): def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0, @@ -245,81 +243,6 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) - -def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - - def _find_bn(module): - if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(mx.cpu(0))) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2)) - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - bn1.collect_params().reset_ctx(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(i) for i in range(num_devices)] - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - _assert_tensor_close(input1.grad, input2grad) - -def testSyncBN(): - ndev = 4 - - bn = nn.BatchNorm(in_channels=1) - sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) - - bn.initialize() - ctx_list = [mx.cpu(i) for i in range(ndev)] - sync_bn.initialize(ctx=ctx_list) - - # check with unsync version - for i in range(10): - _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=False) - if __name__ == '__main__': import nose nose.runmodule() From 62d88910a67ffb6fe7683eed01e15f080dd16ddb Mon Sep 17 00:00:00 2001 From: Zhang Date: Mon, 2 Jul 2018 18:07:01 -0700 Subject: [PATCH 15/34] gpu 0 --- tests/python/gpu/test_operator_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 014b9627d4f8..a73afa7f4234 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1948,9 +1948,9 @@ def _syncParameters(bn1, bn2): if cuda: input1 = input.as_in_context(mx.gpu(0)) bn1.collect_params().reset_ctx(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] + ctx_list = [mx.gpu(0) for _ in range(num_devices)] else: - ctx_list = [mx.gpu(i) for i in range(num_devices)] + ctx_list = [mx.gpu(0) for _ in range(num_devices)] input1.attach_grad() inputs2 = split_and_load(input2, ctx_list, batch_axis=0) @@ -1983,7 +1983,7 @@ def testSyncBN(): sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) bn.initialize() - ctx_list = [mx.gpu(i) for i in range(ndev)] + ctx_list = [mx.gpu(0) for _ in range(ndev)] sync_bn.initialize(ctx=ctx_list) # check with unsync version From 24543c995e855b0c4efa6c66461d4b7b7bb56878 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 3 Jul 2018 10:25:31 -0700 Subject: [PATCH 16/34] mv to cpu test --- tests/python/gpu/test_operator_gpu.py | 78 ------------------- .../python/unittest/test_contrib_operator.py | 76 ++++++++++++++++++ 2 files changed, 76 insertions(+), 78 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index a73afa7f4234..736a9ee0268f 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -28,8 +28,6 @@ from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError from mxnet import autograd -from mxnet.gluon import nn -from mxnet.gluon.utils import split_and_load from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -1916,82 +1914,6 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 - -def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - - def _find_bn(module): - if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(mx.gpu(0))) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2)) - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - bn1.collect_params().reset_ctx(mx.gpu(0)) - ctx_list = [mx.gpu(0) for _ in range(num_devices)] - else: - ctx_list = [mx.gpu(0) for _ in range(num_devices)] - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - _assert_tensor_close(input1.grad, input2grad) - -def testSyncBN(): - ndev = 4 - - bn = nn.BatchNorm(in_channels=1) - sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) - - bn.initialize() - ctx_list = [mx.gpu(0) for _ in range(ndev)] - sync_bn.initialize(ctx=ctx_list) - - # check with unsync version - for i in range(10): - _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - - if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index a220f08d20d4..748cc64afb59 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -24,6 +24,8 @@ from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * import unittest +from mxnet.gluon import nn +from mxnet.gluon.utils import split_and_load def test_box_nms_op(): def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0, @@ -243,6 +245,80 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) +def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(mx.cpu(0))) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + + if cuda: + input1 = input.as_in_context(mx.cpu(0)) + bn1.collect_params().reset_ctx(mx.cpu(0)) + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + _assert_tensor_close(input1.grad, input2grad) + +def testSyncBN(): + ndev = 4 + + bn = nn.BatchNorm(in_channels=1) + sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) + + bn.initialize() + ctx_list = [mx.cpu(0) for _ in range(ndev)] + sync_bn.initialize(ctx=ctx_list) + + # check with unsync version + for i in range(10): + _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + if __name__ == '__main__': import nose nose.runmodule() From 50f8593fae7faa21165663ce3343a1bbfe086a20 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 3 Jul 2018 15:53:14 -0700 Subject: [PATCH 17/34] Revert "mv to cpu test" This reverts commit 24543c995e855b0c4efa6c66461d4b7b7bb56878. --- tests/python/gpu/test_operator_gpu.py | 78 +++++++++++++++++++ .../python/unittest/test_contrib_operator.py | 76 ------------------ 2 files changed, 78 insertions(+), 76 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 736a9ee0268f..a73afa7f4234 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -28,6 +28,8 @@ from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError from mxnet import autograd +from mxnet.gluon import nn +from mxnet.gluon.utils import split_and_load from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -1914,6 +1916,82 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 + +def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(mx.gpu(0))) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + bn1.collect_params().reset_ctx(mx.gpu(0)) + ctx_list = [mx.gpu(0) for _ in range(num_devices)] + else: + ctx_list = [mx.gpu(0) for _ in range(num_devices)] + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + _assert_tensor_close(input1.grad, input2grad) + +def testSyncBN(): + ndev = 4 + + bn = nn.BatchNorm(in_channels=1) + sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) + + bn.initialize() + ctx_list = [mx.gpu(0) for _ in range(ndev)] + sync_bn.initialize(ctx=ctx_list) + + # check with unsync version + for i in range(10): + _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 748cc64afb59..a220f08d20d4 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -24,8 +24,6 @@ from numpy.testing import assert_allclose, assert_array_equal from mxnet.test_utils import * import unittest -from mxnet.gluon import nn -from mxnet.gluon.utils import split_and_load def test_box_nms_op(): def test_box_nms_forward(data, expected, thresh=0.5, valid=0, topk=-1, coord=2, score=1, cid=0, @@ -245,80 +243,6 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) -def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - - def _find_bn(module): - if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(mx.cpu(0))) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2)) - - if cuda: - input1 = input.as_in_context(mx.cpu(0)) - bn1.collect_params().reset_ctx(mx.cpu(0)) - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - _assert_tensor_close(input1.grad, input2grad) - -def testSyncBN(): - ndev = 4 - - bn = nn.BatchNorm(in_channels=1) - sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) - - bn.initialize() - ctx_list = [mx.cpu(0) for _ in range(ndev)] - sync_bn.initialize(ctx=ctx_list) - - # check with unsync version - for i in range(10): - _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - if __name__ == '__main__': import nose nose.runmodule() From b70426d4b8a1269f8f265efbeb845f1fcfba81ab Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 3 Jul 2018 16:10:32 -0700 Subject: [PATCH 18/34] ndev = 2 --- tests/python/gpu/test_operator_gpu.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 6214c9398ab2..538441f62c28 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1947,15 +1947,18 @@ def _syncParameters(bn1, bn2): input1 = input.copy() input2 = input.copy() - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2)) - if cuda: input1 = input.as_in_context(mx.gpu(0)) - bn1.collect_params().reset_ctx(mx.gpu(0)) ctx_list = [mx.gpu(0) for _ in range(num_devices)] else: - ctx_list = [mx.gpu(0) for _ in range(num_devices)] + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + input1.attach_grad() inputs2 = split_and_load(input2, ctx_list, batch_axis=0) @@ -1982,14 +1985,11 @@ def _syncParameters(bn1, bn2): _assert_tensor_close(input1.grad, input2grad) def testSyncBN(): - ndev = 4 + ndev = 2 bn = nn.BatchNorm(in_channels=1) sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) - bn.initialize() - ctx_list = [mx.gpu(0) for _ in range(ndev)] - sync_bn.initialize(ctx=ctx_list) # check with unsync version for i in range(10): From f888706cac935536606661b6f5d69f134f2ee003 Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 11:50:13 -0700 Subject: [PATCH 19/34] debuging --- src/operator/contrib/sync_batch_norm-inl.h | 2 +- tests/python/gpu/test_operator_gpu.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 2d70959f89cc..f7df67cceb23 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -410,7 +410,7 @@ class SyncBatchNorm : public Operator { (grad * broadcast<1>(slope, data.shape_)) * broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + broadcast<1>(gvar, data.shape_) * - scale * 2.0f * (data - broadcast<1>(mean, data.shape_)) + + scale * (data - broadcast<1>(mean, data.shape_)) + broadcast<1>(gmean, data.shape_) * scale); Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); } else { diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 538441f62c28..d87b3b1beaf6 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1937,9 +1937,9 @@ def _find_bn(module): raise RuntimeError('BN not found') - def _syncParameters(bn1, bn2): + def _syncParameters(bn1, bn2, ctx): ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(mx.gpu(0))) + bn2.gamma.set_data(bn1.gamma.data(ctx)) bn2.beta.set_data(bn1.beta.data(ctx)) bn2.running_mean.set_data(bn1.running_mean.data(ctx)) bn2.running_var.set_data(bn1.running_var.data(ctx)) @@ -1957,7 +1957,7 @@ def _syncParameters(bn1, bn2): bn2.initialize(ctx=ctx_list) # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2)) + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) input1.attach_grad() @@ -1982,6 +1982,8 @@ def _syncParameters(bn1, bn2): _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), _find_bn(bn2).running_var.data(ctx_list[0])) input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + print('input1.grad', input1.grad) + print('input2grad', input2grad) _assert_tensor_close(input1.grad, input2grad) def testSyncBN(): @@ -1990,7 +1992,6 @@ def testSyncBN(): bn = nn.BatchNorm(in_channels=1) sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) - # check with unsync version for i in range(10): _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), From b7d2d3c918c88a5edf815e307c0080715a0dc13f Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 13:59:41 -0700 Subject: [PATCH 20/34] sum prod --- src/operator/contrib/sync_batch_norm-inl.h | 10 +++++----- tests/python/gpu/test_operator_gpu.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index f7df67cceb23..e295f7710907 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -372,7 +372,7 @@ class SyncBatchNorm : public Operator { Tensor sumGrad = workspace[3]; Tensor sumProd = workspace[4]; sumGrad = sumall_except_dim<1>(grad); - sumProd = sumall_except_dim<1>(grad * data); + sumProd = sumall_except_dim<1>(grad * (data - broadcast<1>(mean, data.shape_))); SharedND> *sharedGrad = globalSharedGrad.Register(param_.key, param_.ndev); @@ -393,7 +393,7 @@ class SyncBatchNorm : public Operator { mshadow::Copy(sumGrad, grad_cpu, s); mshadow::Copy(sumProd, prod_cpu, s); - gvar = (sumProd - sumGrad * mean) * slope * (-0.5f) * + gvar = -1.0f * sumProd * slope * F(var + param_.eps, -1.5f); gmean = sumGrad * slope; gmean *= -1.0f / F(var + param_.eps); @@ -408,9 +408,9 @@ class SyncBatchNorm : public Operator { } Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * - broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + - broadcast<1>(gvar, data.shape_) * - scale * (data - broadcast<1>(mean, data.shape_)) + + broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + + broadcast<1>(gvar, data.shape_) * // (1.0f - 1.0f * scale / param_.ndev) * + scale * (data - broadcast<1>(mean, data.shape_)) + broadcast<1>(gmean, data.shape_) * scale); Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); } else { diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d87b3b1beaf6..7f0224d96d44 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1949,7 +1949,7 @@ def _syncParameters(bn1, bn2, ctx): if cuda: input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(0) for _ in range(num_devices)] + ctx_list = [mx.gpu(i) for i in range(num_devices)] else: ctx_list = [mx.cpu(0) for _ in range(num_devices)] From 723c670ce49abe8e95320dfa32eed9d3c918ec6e Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 14:03:39 -0700 Subject: [PATCH 21/34] lint --- src/operator/contrib/sync_batch_norm-inl.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index e295f7710907..372be487d184 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -409,7 +409,7 @@ class SyncBatchNorm : public Operator { Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + - broadcast<1>(gvar, data.shape_) * // (1.0f - 1.0f * scale / param_.ndev) * + broadcast<1>(gvar, data.shape_) * scale * (data - broadcast<1>(mean, data.shape_)) + broadcast<1>(gmean, data.shape_) * scale); Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); From abdd5d102464d1849ebeda39fc768d8cbc080b0a Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 15:35:58 -0700 Subject: [PATCH 22/34] contrib, ngpu --- python/mxnet/gluon/contrib/nn/basic_layers.py | 4 ++-- src/operator/contrib/sync_batch_norm-inl.h | 13 +++++++------ src/operator/contrib/sync_batch_norm.cc | 4 ++-- tests/python/gpu/test_operator_gpu.py | 3 +-- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 9a42c1994e79..69efaa77e4da 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -225,5 +225,5 @@ def _get_num_devices(self): return num_devices def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var): - return F.SyncBatchNorm(x, gamma, beta, running_mean, running_var, - name='fwd', **self._kwargs) + return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var, + name='fwd', **self._kwargs) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 372be487d184..bebd37e7aff7 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -70,7 +70,8 @@ struct SyncBatchNormParam : public dmlc::Parameter { .describe("The count of GPU devices"); DMLC_DECLARE_FIELD(key) .set_default("") - .describe("Hash key for synchronization"); + .describe("Hash key for synchronization, please set the same hash key for same layer, " + "Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`."); } }; @@ -270,7 +271,7 @@ class SyncBatchNorm : public Operator { // whether use global statistics if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + Barrier *global_barrier = globalSharedBarrier.Register(param_.key, param_.ndev); int myRank = globalSharedRank.Register(param_.key, param_.ndev); // get the mean and var Tensor mean = out_data[syncbatchnorm::kMean].get(s); @@ -292,7 +293,7 @@ class SyncBatchNorm : public Operator { // push and pull sharedMean->Push(mean_cpu, myRank); sharedVar->Push(var_cpu, myRank); - globalBarrier->Wait(); + global_barrier->Wait(); mean_cpu = sharedMean->Pop(myRank); var_cpu = sharedVar->Pop(myRank); // copy back to gpu @@ -357,7 +358,7 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *globalBarrier = globalSharedBarrier.Register(param_.key, param_.ndev); + Barrier *global_barrier = globalSharedBarrier.Register(param_.key, param_.ndev); int myRank = globalSharedRank.Register(param_.key, param_.ndev); // get requested temp space Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( @@ -386,7 +387,7 @@ class SyncBatchNorm : public Operator { // push and pull sharedGrad->Push(grad_cpu, myRank); sharedProd->Push(prod_cpu, myRank); - globalBarrier->Wait(); + global_barrier->Wait(); grad_cpu = sharedGrad->Pop(myRank); prod_cpu = sharedProd->Pop(myRank); // copy back to gpu @@ -509,7 +510,7 @@ class SyncBatchNormProp : public OperatorProperty { } std::string TypeString() const override { - return "SyncBatchNorm"; + return "_contrib_SyncBatchNorm"; } std::vector DeclareBackwardDependency( diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc index 33b5200d1b36..42e5d4d35404 100644 --- a/src/operator/contrib/sync_batch_norm.cc +++ b/src/operator/contrib/sync_batch_norm.cc @@ -45,7 +45,7 @@ Operator *SyncBatchNormProp::CreateOperatorEx(Context ctx, std::vector * DMLC_REGISTER_PARAMETER(SyncBatchNormParam); -MXNET_REGISTER_OP_PROPERTY(SyncBatchNorm, SyncBatchNormProp) +MXNET_REGISTER_OP_PROPERTY(_contrib_SyncBatchNorm, SyncBatchNormProp) .describe(R"code(Batch normalization. Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as @@ -102,7 +102,7 @@ then set ``gamma`` to 1 and its gradient to 0. .add_argument("moving_var", "NDArray-or-Symbol", "running variance of input") .add_arguments(SyncBatchNormParam::__FIELDS__()); -NNVM_REGISTER_OP(SyncBatchNorm) +NNVM_REGISTER_OP(_contrib_SyncBatchNorm) .set_attr("FSetInputVarAttrOnCompose", [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) { if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return; diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 7f0224d96d44..e765ad335897 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1987,7 +1987,7 @@ def _syncParameters(bn1, bn2, ctx): _assert_tensor_close(input1.grad, input2grad) def testSyncBN(): - ndev = 2 + ndev = 1 if len(mxnet.test_utils.list_gpus()) >= 2 else 1 bn = nn.BatchNorm(in_channels=1) sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) @@ -1997,7 +1997,6 @@ def testSyncBN(): _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) - if __name__ == '__main__': import nose nose.runmodule() From c72413af28de63961dfbc4cf86b27d155d9d91af Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 17:56:25 -0700 Subject: [PATCH 23/34] code style --- src/operator/contrib/sync_batch_norm-inl.h | 108 ++++++++++++--------- tests/python/gpu/test_operator_gpu.py | 5 +- 2 files changed, 66 insertions(+), 47 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index bebd37e7aff7..34a2c12c7874 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -84,24 +84,49 @@ class SharedND { T *data_; bool *flag_; bool mean_ready_ = false; - bool mean_inited_ = false; + bool data_inited_ = false; std::mutex mutex_; public: explicit SharedND(int ndev) :num_devices_(ndev) { - flag_ = new bool[ndev]; - data_ = new T[ndev]; - memset(flag_, false, ndev * sizeof(bool)); + flag_ = new bool[ndev]; + data_ = new T[ndev]; + memset(flag_, false, ndev * sizeof(bool)); } ~SharedND() { + mshadow::FreeSpace(&mean_); + for (int i = 0; i < num_devices_; i++) { + mshadow::FreeSpace(&data_[i]); + } delete [] flag_; delete [] data_; } - bool Push(T input, int index) { + void Init(mshadow::Shape<1> shape) { + std::lock_guard lock(mutex_); + if (!data_inited_) { + for (int i = 0; i < num_devices_; i++) { + data_[i] = mshadow::NewTensor(shape, 0.0f); + } + mean_ = mshadow::NewTensor(shape, 0.0f); + data_inited_ = true; + } + } + + T* Retrieve(mshadow::Shape<1> shape, int index) { + if (!data_inited_) { + Init(shape); + } + if (flag_[index] == false) { + return &data_[index]; + } else { + return nullptr; + } + } + + bool SetReady(int index) { if (flag_[index] == false) { - data_[index] = input; flag_[index] = true; return true; } else { @@ -130,10 +155,6 @@ class SharedND { for (int i = 1; i < num_devices_; i++) { data_[0] += data_[i]; } - if (!mean_inited_) { - mean_ = mshadow::NewTensor(data_[0].shape_, 0.0f); - mean_inited_ = true; - } mean_ = data_[0] * 1.0f / num_devices_; mean_ready_ = true; return true; @@ -215,12 +236,12 @@ class Barrier { }; // Global variables for Synchronizations -static GlobalSharedRank globalSharedRank; -static GlobalShared globalSharedBarrier; -static GlobalShared>> globalSharedMean; -static GlobalShared>> globalSharedVar; -static GlobalShared>> globalSharedGrad; -static GlobalShared>> globalSharedProd; +static GlobalSharedRank global_shared_rank; +static GlobalShared global_shared_barrier; +static GlobalShared>> global_shared_mean; +static GlobalShared>> global_shared_var; +static GlobalShared>> global_shared_grad; +static GlobalShared>> global_shared_prod; template class SyncBatchNorm : public Operator { @@ -271,8 +292,8 @@ class SyncBatchNorm : public Operator { // whether use global statistics if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *global_barrier = globalSharedBarrier.Register(param_.key, param_.ndev); - int myRank = globalSharedRank.Register(param_.key, param_.ndev); + Barrier *global_barrier = global_shared_barrier.Register(param_.key, param_.ndev); + int myRank = global_shared_rank.Register(param_.key, param_.ndev); // get the mean and var Tensor mean = out_data[syncbatchnorm::kMean].get(s); Tensor var = out_data[syncbatchnorm::kVar].get(s); @@ -282,20 +303,19 @@ class SyncBatchNorm : public Operator { mean = scale * sumall_except_dim<1>(data); var = scale * sumall_except_dim<1>(F(data)); SharedND> *sharedMean = - globalSharedMean.Register(param_.key, param_.ndev); + global_shared_mean.Register(param_.key, param_.ndev); SharedND> *sharedVar = - globalSharedVar.Register(param_.key, param_.ndev); - // copy to cpu - Tensor mean_cpu = NewTensor(mean.shape_, 0.0f); - mshadow::Copy(mean_cpu, mean, s); - Tensor var_cpu = NewTensor(var.shape_, 0.0f); - mshadow::Copy(var_cpu, var, s); - // push and pull - sharedMean->Push(mean_cpu, myRank); - sharedVar->Push(var_cpu, myRank); + global_shared_var.Register(param_.key, param_.ndev); + // copy to cpu, push and pull + Tensor* mean_cpu_ptr = sharedMean->Retrieve(mean.shape_, myRank); + Tensor* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, myRank); + mshadow::Copy(*mean_cpu_ptr, mean, s); + mshadow::Copy(*var_cpu_ptr, var, s); + sharedMean->SetReady(myRank); + sharedVar->SetReady(myRank); global_barrier->Wait(); - mean_cpu = sharedMean->Pop(myRank); - var_cpu = sharedVar->Pop(myRank); + Tensor mean_cpu = sharedMean->Pop(myRank); + Tensor var_cpu = sharedVar->Pop(myRank); // copy back to gpu mshadow::Copy(mean, mean_cpu, s); mshadow::Copy(var, var_cpu, s); @@ -358,8 +378,8 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *global_barrier = globalSharedBarrier.Register(param_.key, param_.ndev); - int myRank = globalSharedRank.Register(param_.key, param_.ndev); + Barrier *global_barrier = global_shared_barrier.Register(param_.key, param_.ndev); + int myRank = global_shared_rank.Register(param_.key, param_.ndev); // get requested temp space Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( mshadow::Shape2(5, mean.shape_[0]), s); @@ -374,22 +394,20 @@ class SyncBatchNorm : public Operator { Tensor sumProd = workspace[4]; sumGrad = sumall_except_dim<1>(grad); sumProd = sumall_except_dim<1>(grad * (data - broadcast<1>(mean, data.shape_))); - SharedND> *sharedGrad = - globalSharedGrad.Register(param_.key, param_.ndev); + global_shared_grad.Register(param_.key, param_.ndev); SharedND> *sharedProd = - globalSharedProd.Register(param_.key, param_.ndev); - - Tensor grad_cpu = NewTensor(sumGrad.shape_, 0.0f); - mshadow::Copy(grad_cpu, sumGrad, s); - Tensor prod_cpu = NewTensor(sumProd.shape_, 0.0f); - mshadow::Copy(prod_cpu, sumProd, s); - // push and pull - sharedGrad->Push(grad_cpu, myRank); - sharedProd->Push(prod_cpu, myRank); + global_shared_prod.Register(param_.key, param_.ndev); + // copy to cpu, push and pull + Tensor* grad_cpu_ptr = sharedGrad->Retrieve(sumGrad.shape_, myRank); + Tensor* prod_cpu_ptr = sharedProd->Retrieve(sumGrad.shape_, myRank); + mshadow::Copy(*grad_cpu_ptr, sumGrad, s); + mshadow::Copy(*prod_cpu_ptr, sumProd, s); + sharedGrad->SetReady(myRank); + sharedProd->SetReady(myRank); global_barrier->Wait(); - grad_cpu = sharedGrad->Pop(myRank); - prod_cpu = sharedProd->Pop(myRank); + Tensor grad_cpu = sharedGrad->Pop(myRank); + Tensor prod_cpu = sharedProd->Pop(myRank); // copy back to gpu mshadow::Copy(sumGrad, grad_cpu, s); mshadow::Copy(sumProd, prod_cpu, s); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index e765ad335897..c04104a5931d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1998,5 +1998,6 @@ def testSyncBN(): num_devices=ndev, cuda=True) if __name__ == '__main__': - import nose - nose.runmodule() + testSyncBN() + #import nose + #nose.runmodule() From 2e0dc7976db17399d6137ffc5f3a50590ee0bfbb Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 5 Jul 2018 17:57:40 -0700 Subject: [PATCH 24/34] code style --- tests/python/gpu/test_operator_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index c04104a5931d..6e198cf10995 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1922,7 +1922,7 @@ def test_context_num_gpus(): assert mx.context.num_gpus() > 0 -def _checkBatchNormResult(bn1, bn2, input, num_devices=1, cuda=False): +def _check_batchnorm_result(bn1, bn2, input, num_devices=1, cuda=False): def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): npa, npb = a.asnumpy(), b.asnumpy() assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ @@ -1986,7 +1986,7 @@ def _syncParameters(bn1, bn2, ctx): print('input2grad', input2grad) _assert_tensor_close(input1.grad, input2grad) -def testSyncBN(): +def test_sync_batchnorm(): ndev = 1 if len(mxnet.test_utils.list_gpus()) >= 2 else 1 bn = nn.BatchNorm(in_channels=1) @@ -1994,7 +1994,7 @@ def testSyncBN(): # check with unsync version for i in range(10): - _checkBatchNormResult(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), + _check_batchnorm_result(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), num_devices=ndev, cuda=True) if __name__ == '__main__': From 3013cb3f8ef82c026d17aff8d6139a3a31d1273c Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 6 Jul 2018 10:59:03 -0700 Subject: [PATCH 25/34] forward backward --- src/operator/contrib/sync_batch_norm-inl.h | 14 +++++++----- tests/python/gpu/test_operator_gpu.py | 26 +++++++++++++--------- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 34a2c12c7874..49cbaa68875e 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -236,8 +236,10 @@ class Barrier { }; // Global variables for Synchronizations -static GlobalSharedRank global_shared_rank; -static GlobalShared global_shared_barrier; +static GlobalSharedRank global_shared_rank_forward; +static GlobalSharedRank global_shared_rank_backward; +static GlobalShared global_shared_barrier_forward; +static GlobalShared global_shared_barrier_backward; static GlobalShared>> global_shared_mean; static GlobalShared>> global_shared_var; static GlobalShared>> global_shared_grad; @@ -292,8 +294,8 @@ class SyncBatchNorm : public Operator { // whether use global statistics if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *global_barrier = global_shared_barrier.Register(param_.key, param_.ndev); - int myRank = global_shared_rank.Register(param_.key, param_.ndev); + Barrier *global_barrier = global_shared_barrier_forward.Register(param_.key, param_.ndev); + int myRank = global_shared_rank_forward.Register(param_.key, param_.ndev); // get the mean and var Tensor mean = out_data[syncbatchnorm::kMean].get(s); Tensor var = out_data[syncbatchnorm::kVar].get(s); @@ -378,8 +380,8 @@ class SyncBatchNorm : public Operator { if (ctx.is_train && !param_.use_global_stats) { // get my rank - Barrier *global_barrier = global_shared_barrier.Register(param_.key, param_.ndev); - int myRank = global_shared_rank.Register(param_.key, param_.ndev); + Barrier *global_barrier = global_shared_barrier_backward.Register(param_.key, param_.ndev); + int myRank = global_shared_rank_backward.Register(param_.key, param_.ndev); // get requested temp space Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( mshadow::Shape2(5, mean.shape_[0]), s); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 6e198cf10995..6c11d86d8df9 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1922,7 +1922,7 @@ def test_context_num_gpus(): assert mx.context.num_gpus() > 0 -def _check_batchnorm_result(bn1, bn2, input, num_devices=1, cuda=False): +def _check_batchnorm_result(input, num_devices=1, cuda=False): def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): npa, npb = a.asnumpy(), b.asnumpy() assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ @@ -1953,6 +1953,10 @@ def _syncParameters(bn1, bn2, ctx): else: ctx_list = [mx.cpu(0) for _ in range(num_devices)] + nch = input.shape[1] + bn1 = nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + bn1.initialize(ctx=ctx_list[0]) bn2.initialize(ctx=ctx_list) @@ -1987,17 +1991,19 @@ def _syncParameters(bn1, bn2, ctx): _assert_tensor_close(input1.grad, input2grad) def test_sync_batchnorm(): - ndev = 1 if len(mxnet.test_utils.list_gpus()) >= 2 else 1 - - bn = nn.BatchNorm(in_channels=1) - sync_bn = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=1, num_devices=ndev) + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + ndev = 1 if get_num_devices() >= 2 else 1 # check with unsync version for i in range(10): - _check_batchnorm_result(bn, sync_bn, mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) if __name__ == '__main__': - testSyncBN() - #import nose - #nose.runmodule() + import nose + nose.runmodule() From ffce503c05ac7adc7767bb698097fbd48be915db Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 6 Jul 2018 13:34:02 -0700 Subject: [PATCH 26/34] test --- tests/python/gpu/test_operator_gpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 6c11d86d8df9..295dfe2e8505 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1997,7 +1997,7 @@ def get_num_devices(): mx.nd.zeros((1,), ctx=mx.gpu(i)) except: return i - ndev = 1 if get_num_devices() >= 2 else 1 + ndev = 1 #if get_num_devices() >= 2 else 1 # check with unsync version for i in range(10): From 6acdc711ca42af8517f83b584a29d4851fab8481 Mon Sep 17 00:00:00 2001 From: Zhang Date: Fri, 6 Jul 2018 16:26:11 -0700 Subject: [PATCH 27/34] cpu test --- tests/python/gpu/test_operator_gpu.py | 85 ------------------ .../python/unittest/test_contrib_operator.py | 89 ++++++++++++++++++- 2 files changed, 87 insertions(+), 87 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 295dfe2e8505..5622beecb674 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -28,8 +28,6 @@ from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError from mxnet import autograd -from mxnet.gluon import nn -from mxnet.gluon.utils import split_and_load from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -1921,89 +1919,6 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 - -def _check_batchnorm_result(input, num_devices=1, cuda=False): - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - - def _find_bn(module): - if isinstance(module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] - bn1 = nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - print('input1.grad', input1.grad) - print('input2grad', input2grad) - _assert_tensor_close(input1.grad, input2grad) - -def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i - ndev = 1 #if get_num_devices() >= 2 else 1 - - # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=True) - if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index a220f08d20d4..5fea0a169c04 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -243,6 +243,91 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) + +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + print('input1.grad', input1.grad) + print('input2grad', input2grad) + _assert_tensor_close(input1.grad, input2grad) + +def test_sync_batchnorm(): + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + ndev = 1 #if get_num_devices() >= 2 else 1 + + # check with unsync version + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=False) + if __name__ == '__main__': - import nose - nose.runmodule() + test_sync_batchnorm() + #import nose + #nose.runmodule() From 3a439d0f9c7d78043fd4ef6401945e43143d2c47 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 10 Jul 2018 10:45:04 -0700 Subject: [PATCH 28/34] fix deconstruction --- src/operator/contrib/sync_batch_norm-inl.h | 7 +- tests/python/gpu/test_operator_gpu.py | 85 ++++++++++++++++++ .../python/unittest/test_contrib_operator.py | 89 +------------------ 3 files changed, 89 insertions(+), 92 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index 49cbaa68875e..c2be8b4da492 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -96,9 +96,6 @@ class SharedND { ~SharedND() { mshadow::FreeSpace(&mean_); - for (int i = 0; i < num_devices_; i++) { - mshadow::FreeSpace(&data_[i]); - } delete [] flag_; delete [] data_; } @@ -182,7 +179,7 @@ class GlobalShared { ~GlobalShared() { for (auto it = registry_.begin(); it != registry_.end(); it++) { T *ptr = it->second; - delete [] ptr; + delete ptr; } } private: @@ -208,7 +205,7 @@ class GlobalSharedRank { ~GlobalSharedRank() { for (auto it = registry_.begin(); it != registry_.end(); it++) { T *ptr = it->second; - delete [] ptr; + delete ptr; } } private: diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f8930e12d1da..cf3defb1cc2b 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1909,6 +1909,91 @@ def test_context_num_gpus(): # Test that num_gpus reports at least one GPU, as the test is run on a GPU host. assert mx.context.num_gpus() > 0 +def _check_batchnorm_result(input, num_devices=1, cuda=False): + from mxnet.gluon.utils import split_and_load + def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): + npa, npb = a.asnumpy(), b.asnumpy() + assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ + 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( + a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) + + def _find_bn(module): + if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module + elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): + return module.module + + raise RuntimeError('BN not found') + + def _syncParameters(bn1, bn2, ctx): + ctx = input.context + bn2.gamma.set_data(bn1.gamma.data(ctx)) + bn2.beta.set_data(bn1.beta.data(ctx)) + bn2.running_mean.set_data(bn1.running_mean.data(ctx)) + bn2.running_var.set_data(bn1.running_var.data(ctx)) + + input1 = input.copy() + input2 = input.copy() + + if cuda: + input1 = input.as_in_context(mx.gpu(0)) + ctx_list = [mx.gpu(i) for i in range(num_devices)] + else: + ctx_list = [mx.cpu(0) for _ in range(num_devices)] + + nch = input.shape[1] + bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) + bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) + + bn1.initialize(ctx=ctx_list[0]) + bn2.initialize(ctx=ctx_list) + + # using the same values for gamma and beta + #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) + + + input1.attach_grad() + inputs2 = split_and_load(input2, ctx_list, batch_axis=0) + for xi in inputs2: + xi.attach_grad() + + with mx.autograd.record(): + output1 = bn1(input1) + output2 = [bn2(xi) for xi in inputs2] + loss1 = (output1 ** 2).sum() + loss2 = [(output ** 2).sum() for output in output2] + mx.autograd.backward(loss1) + mx.autograd.backward(loss2) + + output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) + # assert forwarding + _assert_tensor_close(input1, input2) + _assert_tensor_close(output1, output2) + _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0])) + _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0])) + input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) + #print('input1.grad', input1.grad) + #print('input2grad', input2grad) + _assert_tensor_close(input1.grad, input2grad) + +def test_sync_batchnorm(): + def get_num_devices(): + for i in range(100): + try: + mx.nd.zeros((1,), ctx=mx.gpu(i)) + except: + return i + # no need to use SyncBN with 1 gpu + if get_num_devices() < 2: + return + ndev = 2 + # check with unsync version + for i in range(10): + _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), + num_devices=ndev, cuda=True) + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_contrib_operator.py b/tests/python/unittest/test_contrib_operator.py index 5fea0a169c04..a220f08d20d4 100644 --- a/tests/python/unittest/test_contrib_operator.py +++ b/tests/python/unittest/test_contrib_operator.py @@ -243,91 +243,6 @@ def assert_match(inputs, x, y, threshold, is_ascend=False): assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [1, -1, 0], [2, 0], 1e-12, False) assert_match([[0.5, 0.6], [0.1, 0.2], [0.3, 0.4]], [-1, 0, 1], [1, 2], 100, True) - -def _check_batchnorm_result(input, num_devices=1, cuda=False): - from mxnet.gluon.utils import split_and_load - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - - def _find_bn(module): - if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module - elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): - return module.module - - raise RuntimeError('BN not found') - - def _syncParameters(bn1, bn2, ctx): - ctx = input.context - bn2.gamma.set_data(bn1.gamma.data(ctx)) - bn2.beta.set_data(bn1.beta.data(ctx)) - bn2.running_mean.set_data(bn1.running_mean.data(ctx)) - bn2.running_var.set_data(bn1.running_var.data(ctx)) - - input1 = input.copy() - input2 = input.copy() - - if cuda: - input1 = input.as_in_context(mx.gpu(0)) - ctx_list = [mx.gpu(i) for i in range(num_devices)] - else: - ctx_list = [mx.cpu(0) for _ in range(num_devices)] - - nch = input.shape[1] - bn1 = mx.gluon.nn.BatchNorm(in_channels=nch) - bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices) - - bn1.initialize(ctx=ctx_list[0]) - bn2.initialize(ctx=ctx_list) - - # using the same values for gamma and beta - #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - - - input1.attach_grad() - inputs2 = split_and_load(input2, ctx_list, batch_axis=0) - for xi in inputs2: - xi.attach_grad() - - with mx.autograd.record(): - output1 = bn1(input1) - output2 = [bn2(xi) for xi in inputs2] - loss1 = (output1 ** 2).sum() - loss2 = [(output ** 2).sum() for output in output2] - mx.autograd.backward(loss1) - mx.autograd.backward(loss2) - - output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) - # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) - input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - print('input1.grad', input1.grad) - print('input2grad', input2grad) - _assert_tensor_close(input1.grad, input2grad) - -def test_sync_batchnorm(): - def get_num_devices(): - for i in range(100): - try: - mx.nd.zeros((1,), ctx=mx.gpu(i)) - except: - return i - ndev = 1 #if get_num_devices() >= 2 else 1 - - # check with unsync version - for i in range(10): - _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)), - num_devices=ndev, cuda=False) - if __name__ == '__main__': - test_sync_batchnorm() - #import nose - #nose.runmodule() + import nose + nose.runmodule() From a7918e0e219deca22bfd93a9bb9ac31e04c662b0 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 10 Jul 2018 13:24:11 -0700 Subject: [PATCH 29/34] doc indent --- python/mxnet/gluon/contrib/nn/basic_layers.py | 8 +++++--- src/operator/contrib/sync_batch_norm.cc | 7 ++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 69efaa77e4da..30c276b82b91 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -158,7 +158,8 @@ class SyncBatchNorm(BatchNorm): """Cross-GPU Synchronized Batch normalization (SyncBN) Standard BN [1]_ implementation only normalize the data within each device. SyncBN normalizes the input within the whole mini-batch. - We follow the sync-onece implmentation described in the paper [2]_ . + We follow the sync-onece implmentation described in the paper [2]_. + Parameters ---------- in_channels : int, default 0 @@ -199,9 +200,10 @@ class SyncBatchNorm(BatchNorm): Reference: .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating - deep network training by reducing internal covariate shift." *ICML 2015* + deep network training by reducing internal covariate shift." *ICML 2015* + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, - Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* + Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* """ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, center=True, scale=True, use_global_stats=False, beta_initializer='zeros', diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc index 42e5d4d35404..73ee0018f02b 100644 --- a/src/operator/contrib/sync_batch_norm.cc +++ b/src/operator/contrib/sync_batch_norm.cc @@ -52,7 +52,7 @@ Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as well as offset ``beta``. Standard BN [1]_ implementation only normalize the data within each device. SyncBN normalizes the input within the whole mini-batch. -We follow the sync-onece implmentation described in the paper [2]_ . +We follow the sync-onece implmentation described in the paper [2]_. Assume the input has more than one dimension and we normalize along axis 1. We first compute the mean and variance along this axis: @@ -91,9 +91,10 @@ then set ``gamma`` to 1 and its gradient to 0. Reference: .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating - deep network training by reducing internal covariate shift." *ICML 2015* + deep network training by reducing internal covariate shift." *ICML 2015* + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, - Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* + Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* )code" ADD_FILELINE) .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") .add_argument("gamma", "NDArray-or-Symbol", "gamma array") From 55fef70e29396b3c4e7b79774315f8034a4657aa Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 10 Jul 2018 13:28:43 -0700 Subject: [PATCH 30/34] doc --- python/mxnet/gluon/contrib/nn/basic_layers.py | 5 ++--- src/operator/contrib/sync_batch_norm.cc | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 30c276b82b91..52824bee9d89 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -199,10 +199,9 @@ class SyncBatchNorm(BatchNorm): - **out**: output tensor with the same shape as `data`. Reference: - .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating + .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \ deep network training by reducing internal covariate shift." *ICML 2015* - - .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \ Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* """ def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5, diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc index 73ee0018f02b..1b465d88b69e 100644 --- a/src/operator/contrib/sync_batch_norm.cc +++ b/src/operator/contrib/sync_batch_norm.cc @@ -90,10 +90,9 @@ Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is tr then set ``gamma`` to 1 and its gradient to 0. Reference: - .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating + .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \ deep network training by reducing internal covariate shift." *ICML 2015* - - .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, + .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \ Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018* )code" ADD_FILELINE) .add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization") From 9884d6015ec0850dfbf6983121050b4af5fb1689 Mon Sep 17 00:00:00 2001 From: Zhang Date: Tue, 10 Jul 2018 18:10:06 -0700 Subject: [PATCH 31/34] doc --- python/mxnet/gluon/contrib/nn/basic_layers.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py index 52824bee9d89..f764fa195a92 100644 --- a/python/mxnet/gluon/contrib/nn/basic_layers.py +++ b/python/mxnet/gluon/contrib/nn/basic_layers.py @@ -156,6 +156,7 @@ def __repr__(self): class SyncBatchNorm(BatchNorm): """Cross-GPU Synchronized Batch normalization (SyncBN) + Standard BN [1]_ implementation only normalize the data within each device. SyncBN normalizes the input within the whole mini-batch. We follow the sync-onece implmentation described in the paper [2]_. From a2780ceae31465350fae14a47f4b7b45e9112a57 Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 12 Jul 2018 14:48:16 -0700 Subject: [PATCH 32/34] address comments --- src/operator/contrib/sync_batch_norm-inl.h | 6 ++++-- tests/python/gpu/test_operator_gpu.py | 23 +++++++--------------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h index c2be8b4da492..1f548dbc7e5e 100644 --- a/src/operator/contrib/sync_batch_norm-inl.h +++ b/src/operator/contrib/sync_batch_norm-inl.h @@ -95,7 +95,7 @@ class SharedND { } ~SharedND() { - mshadow::FreeSpace(&mean_); + if (data_inited_) mshadow::FreeSpace(&mean_); delete [] flag_; delete [] data_; } @@ -112,6 +112,7 @@ class SharedND { } T* Retrieve(mshadow::Shape<1> shape, int index) { + // Retrieve a pointer for copying values if (!data_inited_) { Init(shape); } @@ -123,6 +124,7 @@ class SharedND { } bool SetReady(int index) { + // Set data ready after copying if (flag_[index] == false) { flag_[index] = true; return true; @@ -132,6 +134,7 @@ class SharedND { } T Pop(int index) { + // Pop the mean value after suming up std::lock_guard lock(mutex_); while (!MeanReady()) {} flag_[index] = false; @@ -384,7 +387,6 @@ class SyncBatchNorm : public Operator { mshadow::Shape2(5, mean.shape_[0]), s); Tensor gmean = workspace[0]; Tensor gvar = workspace[1]; - // Tensor tmp = workspace[2]; moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum); moving_var = moving_var * param_.momentum + var * (1 - param_.momentum); diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index cf3defb1cc2b..46ff92dfcc22 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1911,12 +1911,6 @@ def test_context_num_gpus(): def _check_batchnorm_result(input, num_devices=1, cuda=False): from mxnet.gluon.utils import split_and_load - def _assert_tensor_close(a, b, atol=1e-3, rtol=1e-3): - npa, npb = a.asnumpy(), b.asnumpy() - assert np.allclose(npa, npb, rtol=rtol, atol=atol), \ - 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format( - a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) - def _find_bn(module): if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)): return module @@ -1951,7 +1945,6 @@ def _syncParameters(bn1, bn2, ctx): # using the same values for gamma and beta #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0]) - input1.attach_grad() inputs2 = split_and_load(input2, ctx_list, batch_axis=0) for xi in inputs2: @@ -1967,16 +1960,14 @@ def _syncParameters(bn1, bn2, ctx): output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) # assert forwarding - _assert_tensor_close(input1, input2) - _assert_tensor_close(output1, output2) - _assert_tensor_close(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0])) - _assert_tensor_close(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0])) + assert_almost_equal(input1, input2, atol=1e-3, rtol=1e-3) + assert_almost_equal(output1, output2, atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0], atol=1e-3, rtol=1e-3), + _find_bn(bn2).running_mean.data(ctx_list[0]), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]), + _find_bn(bn2).running_var.data(ctx_list[0]), atol=1e-3, rtol=1e-3) input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - #print('input1.grad', input1.grad) - #print('input2grad', input2grad) - _assert_tensor_close(input1.grad, input2grad) + assert_almost_equal(input1.grad, input2grad, atol=1e-3, rtol=1e-3) def test_sync_batchnorm(): def get_num_devices(): From 16df5d47177c7c52f6a0d97d497b19f1c1561800 Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 12 Jul 2018 17:05:11 -0700 Subject: [PATCH 33/34] typo --- tests/python/gpu/test_operator_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 46ff92dfcc22..8dde14405d0b 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1962,10 +1962,10 @@ def _syncParameters(bn1, bn2, ctx): # assert forwarding assert_almost_equal(input1, input2, atol=1e-3, rtol=1e-3) assert_almost_equal(output1, output2, atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0], atol=1e-3, rtol=1e-3), - _find_bn(bn2).running_mean.data(ctx_list[0]), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]), + _find_bn(bn2).running_mean.data(ctx_list[0]), atol=1e-3, rtol=1e-3) assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0]), atol=1e-3, rtol=1e-3) + _find_bn(bn2).running_var.data(ctx_list[0]), atol=1e-3, rtol=1e-3) input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) assert_almost_equal(input1.grad, input2grad, atol=1e-3, rtol=1e-3) From 809854d8063388628d46b21beb9fdbd9272c9c1d Mon Sep 17 00:00:00 2001 From: Zhang Date: Thu, 12 Jul 2018 19:29:34 -0700 Subject: [PATCH 34/34] asnumpy --- tests/python/gpu/test_operator_gpu.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 8dde14405d0b..b3a54b164c3a 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -1960,14 +1960,16 @@ def _syncParameters(bn1, bn2, ctx): output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0) # assert forwarding - assert_almost_equal(input1, input2, atol=1e-3, rtol=1e-3) - assert_almost_equal(output1, output2, atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]), - _find_bn(bn2).running_mean.data(ctx_list[0]), atol=1e-3, rtol=1e-3) - assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]), - _find_bn(bn2).running_var.data(ctx_list[0]), atol=1e-3, rtol=1e-3) + assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) + assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(), + _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(), + atol=1e-3, rtol=1e-3) input2grad = mx.nd.concat(*[output.grad.as_in_context(input.context) for output in inputs2], dim=0) - assert_almost_equal(input1.grad, input2grad, atol=1e-3, rtol=1e-3) + assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3) def test_sync_batchnorm(): def get_num_devices():