diff --git a/docs/api/python/gluon/contrib.md b/docs/api/python/gluon/contrib.md
index 877a294d9a1f..98f36f882164 100644
--- a/docs/api/python/gluon/contrib.md
+++ b/docs/api/python/gluon/contrib.md
@@ -36,6 +36,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` p
     HybridConcurrent
     Identity
     SparseEmbedding
+    SyncBatchNorm
 ```
 
 ### Recurrent neural network
diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index 1edef1476ee3..f764fa195a92 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,11 +18,13 @@
 # coding: utf-8
 # pylint: disable= arguments-differ
 """Custom neural network layers in model_zoo."""
-__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding']
+__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
+           'SyncBatchNorm']
 
-from .... import nd
+import warnings
+from .... import nd, test_utils
 from ...block import HybridBlock, Block
-from ...nn import Sequential, HybridSequential
+from ...nn import Sequential, HybridSequential, BatchNorm
 
 class Concurrent(Sequential):
     """Lays `Block`s concurrently.
@@ -151,3 +153,79 @@ def __repr__(self):
         s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
         return s.format(block_name=self.__class__.__name__,
                         **self._kwargs)
+
+class SyncBatchNorm(BatchNorm):
+    """Cross-GPU Synchronized Batch normalization (SyncBN).
+
+    Standard BN [1]_ implementation only normalizes the data within each device.
+    SyncBN normalizes the input within the whole mini-batch.
+    We follow the sync-once implementation described in the paper [2]_.
+
+    Parameters
+    ----------
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+    num_devices : int, default number of visible GPUs
+        Number of devices to synchronize the batch statistics across.
+    momentum: float, default 0.9
+        Momentum for the moving average.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+        When the next layer is linear (also e.g. `nn.relu`),
+        this can be disabled since the scaling
+        will be done by the next layer.
+    use_global_stats: bool, default False
+        If True, use global moving statistics instead of local batch-norm. This will force
+        change batch-norm into a scale shift operator.
+        If False, use local batch-norm.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    running_mean_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the running mean.
+    running_variance_initializer: str or `Initializer`, default 'ones'
+        Initializer for the running variance.
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    Reference:
+        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
+          deep network training by reducing internal covariate shift." *ICML 2015*
+        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
+          Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation."
+          *CVPR 2018*
+    """
+    def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
+                 center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
+                 gamma_initializer='ones', running_mean_initializer='zeros',
+                 running_variance_initializer='ones', **kwargs):
+        super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
+                                            beta_initializer, gamma_initializer,
+                                            running_mean_initializer, running_variance_initializer,
+                                            in_channels, **kwargs)
+        num_devices = self._get_num_devices() if num_devices is None else num_devices
+        self._kwargs = {'eps': epsilon, 'momentum': momentum,
+                        'fix_gamma': not scale, 'use_global_stats': use_global_stats,
+                        'ndev': num_devices, 'key': self.prefix}
+
+    def _get_num_devices(self):
+        warnings.warn("Caution when using SyncBatchNorm: "
+                      "if not all GPUs are used, please set num_devices manually",
+                      UserWarning)
+        num_devices = len(test_utils.list_gpus())
+        num_devices = num_devices if num_devices > 0 else 1
+        return num_devices
+
+    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
+        return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
+                                       name='fwd', **self._kwargs)
diff --git a/src/operator/contrib/sync_batch_norm-inl.h b/src/operator/contrib/sync_batch_norm-inl.h
new file mode 100644
index 000000000000..1f548dbc7e5e
--- /dev/null
+++ b/src/operator/contrib/sync_batch_norm-inl.h
@@ -0,0 +1,594 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors + * \file sync_batch_norm-inl.h + * \brief Synchronized BatchNorm modified from BatchNormV1 + * \author Hang Zhang +*/ +#ifndef MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ +#define MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mshadow_op.h" + +namespace mxnet { +namespace op { + +namespace syncbatchnorm { +enum BatchNormOpInputs {kData, kGamma, kBeta}; +enum BatchNormOpOutputs {kOut, kMean, kVar}; +enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; +enum BatchNormBackResource {kTempSpace}; +} // namespace syncbatchnorm + +struct SyncBatchNormParam : public dmlc::Parameter { + float eps; + float momentum; + bool fix_gamma; + bool use_global_stats; + bool output_mean_var; + int ndev; + std::string key; + DMLC_DECLARE_PARAMETER(SyncBatchNormParam) { + DMLC_DECLARE_FIELD(eps).set_default(1e-3f) + .describe("Epsilon to prevent div 0"); + DMLC_DECLARE_FIELD(momentum).set_default(0.9f) + .describe("Momentum for moving average"); + DMLC_DECLARE_FIELD(fix_gamma).set_default(true) + .describe("Fix gamma while training"); + DMLC_DECLARE_FIELD(use_global_stats).set_default(false) + .describe("Whether use global moving statistics instead of local batch-norm. " + "This will force change batch-norm into a scale shift operator."); + DMLC_DECLARE_FIELD(output_mean_var).set_default(false) + .describe("Output All,normal mean and var"); + DMLC_DECLARE_FIELD(ndev).set_default(1) + .describe("The count of GPU devices"); + DMLC_DECLARE_FIELD(key) + .set_default("") + .describe("Hash key for synchronization, please set the same hash key for same layer, " + "Block.prefix is typically used as in :class:`gluon.nn.contrib.SyncBatchNorm`."); + } +}; + +// Modified from https://github.com/brucechin/SharedTensor +template +class SharedND { + private: + int num_devices_; + T mean_; + T *data_; + bool *flag_; + bool mean_ready_ = false; + bool data_inited_ = false; + std::mutex mutex_; + + public: + explicit SharedND(int ndev) :num_devices_(ndev) { + flag_ = new bool[ndev]; + data_ = new T[ndev]; + memset(flag_, false, ndev * sizeof(bool)); + } + + ~SharedND() { + if (data_inited_) mshadow::FreeSpace(&mean_); + delete [] flag_; + delete [] data_; + } + + void Init(mshadow::Shape<1> shape) { + std::lock_guard lock(mutex_); + if (!data_inited_) { + for (int i = 0; i < num_devices_; i++) { + data_[i] = mshadow::NewTensor(shape, 0.0f); + } + mean_ = mshadow::NewTensor(shape, 0.0f); + data_inited_ = true; + } + } + + T* Retrieve(mshadow::Shape<1> shape, int index) { + // Retrieve a pointer for copying values + if (!data_inited_) { + Init(shape); + } + if (flag_[index] == false) { + return &data_[index]; + } else { + return nullptr; + } + } + + bool SetReady(int index) { + // Set data ready after copying + if (flag_[index] == false) { + flag_[index] = true; + return true; + } else { + return false; + } + } + + T Pop(int index) { + // Pop the mean value after suming up + std::lock_guard lock(mutex_); + while (!MeanReady()) {} + flag_[index] = false; + T tmp = mean_; + ResetMean(); + return tmp; + } + + bool MeanReady() { + if (mean_ready_) { + return true; + } + for (int i = 0; i < num_devices_; i++) { + if (!flag_[i]) { + return false; + } + } + for (int i = 1; i < num_devices_; i++) { + data_[0] += data_[i]; + } + mean_ = data_[0] * 1.0f / num_devices_; + mean_ready_ = true; + return true; + } + + void ResetMean() { + for (int i = 0; i < num_devices_; i++) { 
+ if (flag_[i]) return; + } + mean_ready_ = false; + } +}; + +template +class GlobalShared { + public: + T* Register(const std::string &key, int ndev) { + std::lock_guard lock(mutex_); + auto it = registry_.find(key); + if (it != registry_.end()) return it->second; + T *newT = new T(ndev); + registry_[key] = newT; + return newT; + } + ~GlobalShared() { + for (auto it = registry_.begin(); it != registry_.end(); it++) { + T *ptr = it->second; + delete ptr; + } + } + private: + std::mutex mutex_; + std::map registry_; +}; + +template +class GlobalSharedRank { + public: + T Register(const std::string &key, int ndev) { + std::lock_guard lock(mutex_); + auto it = registry_.find(key); + if (it != registry_.end()) { + T* tmpT = it->second; + *tmpT = (*tmpT == ndev - 1) ? 0 : *tmpT + 1; + return *tmpT; + } + T *newT = new T(0); + registry_[key] = newT; + return *newT; + } + ~GlobalSharedRank() { + for (auto it = registry_.begin(); it != registry_.end(); it++) { + T *ptr = it->second; + delete ptr; + } + } + private: + std::mutex mutex_; + std::map registry_; +}; + +class Barrier { + private: + std::mutex mutex_; + std::condition_variable cv_; + std::size_t count_; + std::size_t total_count_; + public: + explicit Barrier(std::size_t count) : count_{count}, total_count_{count} { } + void Wait() { + std::unique_lock lock{mutex_}; + if (--count_ == 0) { + count_ = total_count_; + cv_.notify_all(); + } else { + cv_.wait(lock, [this] { return count_ == total_count_; }); + } + } +}; + +// Global variables for Synchronizations +static GlobalSharedRank global_shared_rank_forward; +static GlobalSharedRank global_shared_rank_backward; +static GlobalShared global_shared_barrier_forward; +static GlobalShared global_shared_barrier_backward; +static GlobalShared>> global_shared_mean; +static GlobalShared>> global_shared_var; +static GlobalShared>> global_shared_grad; +static GlobalShared>> global_shared_prod; + +template +class SyncBatchNorm : public Operator { + public: + explicit SyncBatchNorm(SyncBatchNormParam param) { + this->param_ = param; + } + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 3U); + CHECK_EQ(aux_states.size(), 2U); + if (ctx.is_train) { + CHECK_EQ(out_data.size(), 3U); + CHECK_EQ(req.size(), 3U); + } else { + CHECK_GE(out_data.size(), 1U); + CHECK_GE(req.size(), 1U); + CHECK_EQ(req[syncbatchnorm::kOut], kWriteTo); + } + + Stream *s = ctx.get_stream(); + const real_t scale = static_cast(in_data[syncbatchnorm::kData].shape_[1]) / + static_cast(in_data[syncbatchnorm::kData].shape_.Size()); + Tensor data; + Tensor out; + if (in_data[syncbatchnorm::kData].ndim() == 2) { + Shape<4> dshape = Shape4(in_data[syncbatchnorm::kData].shape_[0], + in_data[syncbatchnorm::kData].shape_[1], 1, 1); + data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); + out = out_data[syncbatchnorm::kOut].get_with_shape(dshape, s); + } else { + data = in_data[syncbatchnorm::kData].get(s); + out = out_data[syncbatchnorm::kOut].get(s); + } + Tensor slope = in_data[syncbatchnorm::kGamma].get(s); + Tensor bias = in_data[syncbatchnorm::kBeta].get(s); + Tensor moving_mean = aux_states[syncbatchnorm::kMovingMean].get(s); + Tensor moving_var = aux_states[syncbatchnorm::kMovingVar].get(s); + + if (param_.fix_gamma) slope = 1.f; + + // whether use global statistics + if (ctx.is_train && !param_.use_global_stats) { + // 
get my rank + Barrier *global_barrier = global_shared_barrier_forward.Register(param_.key, param_.ndev); + int myRank = global_shared_rank_forward.Register(param_.key, param_.ndev); + // get the mean and var + Tensor mean = out_data[syncbatchnorm::kMean].get(s); + Tensor var = out_data[syncbatchnorm::kVar].get(s); + CHECK(req[syncbatchnorm::kMean] == kNullOp || req[syncbatchnorm::kMean] == kWriteTo); + CHECK(req[syncbatchnorm::kVar] == kNullOp || req[syncbatchnorm::kVar] == kWriteTo); + // E(x) and E(x^2) + mean = scale * sumall_except_dim<1>(data); + var = scale * sumall_except_dim<1>(F(data)); + SharedND> *sharedMean = + global_shared_mean.Register(param_.key, param_.ndev); + SharedND> *sharedVar = + global_shared_var.Register(param_.key, param_.ndev); + // copy to cpu, push and pull + Tensor* mean_cpu_ptr = sharedMean->Retrieve(mean.shape_, myRank); + Tensor* var_cpu_ptr = sharedVar->Retrieve(mean.shape_, myRank); + mshadow::Copy(*mean_cpu_ptr, mean, s); + mshadow::Copy(*var_cpu_ptr, var, s); + sharedMean->SetReady(myRank); + sharedVar->SetReady(myRank); + global_barrier->Wait(); + Tensor mean_cpu = sharedMean->Pop(myRank); + Tensor var_cpu = sharedVar->Pop(myRank); + // copy back to gpu + mshadow::Copy(mean, mean_cpu, s); + mshadow::Copy(var, var_cpu, s); + + var = var-F(mean); + Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope, out.shape_) * + (data - broadcast<1>(mean, data.shape_)) / + F(broadcast<1>(var + param_.eps, data.shape_)) + + broadcast<1>(bias, out.shape_)); + } else { + Assign(out, req[syncbatchnorm::kOut], broadcast<1>(slope / + F(moving_var + param_.eps), + data.shape_) * data + + broadcast<1>(bias - (slope * moving_mean) / + F(moving_var + param_.eps), data.shape_)); + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), param_.output_mean_var ? 
3U : 1U); + CHECK_EQ(in_data.size(), 3U); + CHECK_EQ(out_data.size(), 3U); + CHECK_EQ(in_grad.size(), 3U); + + Stream *s = ctx.get_stream(); + Tensor data, grad, grad_in; + const real_t scale = static_cast(out_grad[syncbatchnorm::kOut].shape_[1]) / + static_cast(out_grad[syncbatchnorm::kOut].shape_.Size()); + if (in_data[syncbatchnorm::kData].ndim() == 2) { + Shape<4> dshape = Shape4(out_grad[syncbatchnorm::kOut].shape_[0], + out_grad[syncbatchnorm::kOut].shape_[1], 1, 1); + data = in_data[syncbatchnorm::kData].get_with_shape(dshape, s); + grad = out_grad[syncbatchnorm::kOut].get_with_shape(dshape, s); + grad_in = in_grad[syncbatchnorm::kData].get_with_shape(dshape, s); + } else { + data = in_data[syncbatchnorm::kData].get(s); + grad = out_grad[syncbatchnorm::kOut].get(s); + grad_in = in_grad[syncbatchnorm::kData].get(s); + } + + Tensor mean = out_data[syncbatchnorm::kMean].get(s); + Tensor var = out_data[syncbatchnorm::kVar].get(s); + Tensor slope = in_data[syncbatchnorm::kGamma].get(s); + // Tensor bias = in_data[kBeta].get(s); + Tensor gslope = in_grad[syncbatchnorm::kGamma].get(s); + Tensor gbias = in_grad[syncbatchnorm::kBeta].get(s); + // update moving avg + Tensor moving_mean = aux_states[syncbatchnorm::kMovingMean].get(s); + Tensor moving_var = aux_states[syncbatchnorm::kMovingVar].get(s); + + if (param_.fix_gamma) slope = 1.f; + + if (ctx.is_train && !param_.use_global_stats) { + // get my rank + Barrier *global_barrier = global_shared_barrier_backward.Register(param_.key, param_.ndev); + int myRank = global_shared_rank_backward.Register(param_.key, param_.ndev); + // get requested temp space + Tensor workspace = ctx.requested[syncbatchnorm::kTempSpace].get_space( + mshadow::Shape2(5, mean.shape_[0]), s); + Tensor gmean = workspace[0]; + Tensor gvar = workspace[1]; + + moving_mean = moving_mean * param_.momentum + mean * (1 - param_.momentum); + moving_var = moving_var * param_.momentum + var * (1 - param_.momentum); + // cal + Tensor sumGrad = workspace[3]; + Tensor sumProd = workspace[4]; + sumGrad = sumall_except_dim<1>(grad); + sumProd = sumall_except_dim<1>(grad * (data - broadcast<1>(mean, data.shape_))); + SharedND> *sharedGrad = + global_shared_grad.Register(param_.key, param_.ndev); + SharedND> *sharedProd = + global_shared_prod.Register(param_.key, param_.ndev); + // copy to cpu, push and pull + Tensor* grad_cpu_ptr = sharedGrad->Retrieve(sumGrad.shape_, myRank); + Tensor* prod_cpu_ptr = sharedProd->Retrieve(sumGrad.shape_, myRank); + mshadow::Copy(*grad_cpu_ptr, sumGrad, s); + mshadow::Copy(*prod_cpu_ptr, sumProd, s); + sharedGrad->SetReady(myRank); + sharedProd->SetReady(myRank); + global_barrier->Wait(); + Tensor grad_cpu = sharedGrad->Pop(myRank); + Tensor prod_cpu = sharedProd->Pop(myRank); + // copy back to gpu + mshadow::Copy(sumGrad, grad_cpu, s); + mshadow::Copy(sumProd, prod_cpu, s); + + gvar = -1.0f * sumProd * slope * + F(var + param_.eps, -1.5f); + gmean = sumGrad * slope; + gmean *= -1.0f / F(var + param_.eps); + // assign + if (!param_.fix_gamma) { + Assign(gslope, req[syncbatchnorm::kGamma], + sumall_except_dim<1>( + grad * (data - broadcast<1>(mean, data.shape_)) / + F(broadcast<1>(var + param_.eps, data.shape_)))); + } else { + Assign(gslope, req[syncbatchnorm::kGamma], 0.0f); + } + Assign(grad_in, req[syncbatchnorm::kData], + (grad * broadcast<1>(slope, data.shape_)) * + broadcast<1>(1.0f / F(var + param_.eps), data.shape_) + + broadcast<1>(gvar, data.shape_) * + scale * (data - broadcast<1>(mean, data.shape_)) + + broadcast<1>(gmean, data.shape_) * 
scale); + Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); + } else { + // use global statistics with freeze moving mean and var. + if (!param_.fix_gamma) { + Assign(gslope, req[syncbatchnorm::kGamma], + sumall_except_dim<1>( + grad * (data - broadcast<1>(moving_mean, data.shape_)) / + F(broadcast<1>(moving_var + param_.eps, data.shape_)))); + } else { + Assign(gslope, req[syncbatchnorm::kGamma], 0.0f); + } + Assign(gbias, req[syncbatchnorm::kBeta], sumall_except_dim<1>(grad)); + Assign(grad_in, req[syncbatchnorm::kData], (grad * broadcast<1>(slope, data.shape_)) * + broadcast<1>( + 1.0f / F(moving_var + param_.eps), data.shape_)); + } + } + + private: + SyncBatchNormParam param_; +}; // class SyncBatchNorm + +template +Operator *CreateOp(SyncBatchNormParam param, int dtype); + + +#if DMLC_USE_CXX11 +class SyncBatchNormProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::map GetParams() const override { + return param_.__DICT__(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + in_shape->at(1) = TShape(Shape1(dshape[1])); + in_shape->at(2) = TShape(Shape1(dshape[1])); + out_shape->clear(); + out_shape->push_back(dshape); + out_shape->push_back(Shape1(dshape[1])); + out_shape->push_back(Shape1(dshape[1])); + + aux_shape->clear(); + aux_shape->push_back(Shape1(dshape[1])); + aux_shape->push_back(Shape1(dshape[1])); + return true; + } + + bool InferType(std::vector *in_type, + std::vector *out_type, + std::vector *aux_type) const override { + using namespace mshadow; + CHECK_GE(in_type->size(), 1U); + int dtype = (*in_type)[0]; + CHECK_NE(dtype, -1) << "First input must have specified type"; + // For float16 input type beta, gamma, mean, and average are stored in float32. + // For other input types, these parameters have the same type as input + // NOTE: This requirement is from cuDNN (v. 4 and 5) + int dtype_param = (dtype == kFloat16) ? 
kFloat32 : dtype; + for (index_t i = 1; i < in_type->size(); ++i) { + if ((*in_type)[i] == -1) { + (*in_type)[i] = dtype_param; + } else { + UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, ListArguments()[i]); + } + } + for (index_t i = 0; i < aux_type->size(); ++i) { + if ((*aux_type)[i] != -1) { + UNIFORM_TYPE_CHECK((*aux_type)[i], dtype_param, ListArguments()[i]); + } + } + int n_aux = this->ListAuxiliaryStates().size(); + aux_type->clear(); + for (int i = 0; i < n_aux; ++i ) aux_type->push_back(dtype_param); + int n_out = this->ListOutputs().size(); + out_type->clear(); + out_type->push_back(dtype); + for (int i = 1; i < n_out; ++i ) out_type->push_back(dtype_param); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new SyncBatchNormProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "_contrib_SyncBatchNorm"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return {out_grad[syncbatchnorm::kOut], + out_data[syncbatchnorm::kMean], + out_data[syncbatchnorm::kVar], + in_data[syncbatchnorm::kData], + in_data[syncbatchnorm::kGamma] + }; + } + + std::vector BackwardResource( + const std::vector &in_shape) const override { + return {ResourceRequest::kTempSpace}; + } + + int NumVisibleOutputs() const override { + if (param_.output_mean_var) { + return 3; + } + return 1; + } + + int NumOutputs() const override { + return 3; + } + + std::vector ListArguments() const override { + return {"data", "gamma", "beta"}; + } + + std::vector ListOutputs() const override { + return {"output", "mean", "var"}; + } + + std::vector ListAuxiliaryStates() const override { + return {"moving_mean", "moving_var"}; + } + + Operator* CreateOperator(Context ctx) const override { + LOG(FATAL) << "Not Implemented."; + return NULL; + } + + Operator* CreateOperatorEx(Context ctx, std::vector *in_shape, + std::vector *in_type) const override; + + inline const SyncBatchNormParam& getParam() const { + return param_; + } + + private: + SyncBatchNormParam param_; +}; // class SyncBatchNormProp + +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_CONTRIB_SYNC_BATCH_NORM_INL_H_ diff --git a/src/operator/contrib/sync_batch_norm.cc b/src/operator/contrib/sync_batch_norm.cc new file mode 100644 index 000000000000..1b465d88b69e --- /dev/null +++ b/src/operator/contrib/sync_batch_norm.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * Copyright (c) 2018 by Contributors
+ * \file sync_batch_norm.cc
+ * \brief Synchronized BatchNorm modified from BatchNormV1
+ * \author Hang Zhang
+*/
+
+#include "sync_batch_norm-inl.h"
+#include <nnvm/op_attr_types.h>
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<cpu>(SyncBatchNormParam param, int dtype) {
+  return new SyncBatchNorm<cpu>(param);
+}
+
+// DO_BIND_DISPATCH comes from operator_common.h
+Operator *SyncBatchNormProp::CreateOperatorEx(Context ctx, std::vector<TShape> *in_shape,
+                                              std::vector<int> *in_type) const {
+  std::vector<TShape> out_shape, aux_shape;
+  std::vector<int> out_type, aux_type;
+  CHECK(InferType(in_type, &out_type, &aux_type));
+  CHECK(InferShape(in_shape, &out_shape, &aux_shape));
+  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+}
+
+DMLC_REGISTER_PARAMETER(SyncBatchNormParam);
+
+MXNET_REGISTER_OP_PROPERTY(_contrib_SyncBatchNorm, SyncBatchNormProp)
+.describe(R"code(Batch normalization.
+
+Normalizes a data batch by mean and variance, and applies a scale ``gamma`` as
+well as offset ``beta``.
+Standard BN [1]_ implementation only normalizes the data within each device.
+SyncBN normalizes the input within the whole mini-batch.
+We follow the sync-once implementation described in the paper [2]_.
+
+Assume the input has more than one dimension and we normalize along axis 1.
+We first compute the mean and variance along this axis:
+
+.. math::
+
+  data\_mean[i] = mean(data[:,i,:,...]) \\
+  data\_var[i] = var(data[:,i,:,...])
+
+Then compute the normalized output, which has the same shape as input, as following:
+
+.. math::
+
+  out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}} * gamma[i] + beta[i]
+
+Both *mean* and *var* return a scalar by treating the input as a vector.
+
+Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
+have shape *(k,)*. If ``output_mean_var`` is set to be true, then outputs both ``data_mean`` and
+``data_var`` as well, which are needed for the backward pass.
+
+Besides the inputs and the outputs, this operator accepts two auxiliary
+states, ``moving_mean`` and ``moving_var``, which are *k*-length
+vectors. They are global statistics for the whole dataset, which are updated
+by::
+
+  moving_mean = moving_mean * momentum + data_mean * (1 - momentum)
+  moving_var = moving_var * momentum + data_var * (1 - momentum)
+
+If ``use_global_stats`` is set to be true, then ``moving_mean`` and
+``moving_var`` are used instead of ``data_mean`` and ``data_var`` to compute
+the output. It is often used during inference.
+
+Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
+then set ``gamma`` to 1 and its gradient to 0.
+
+Reference:
+  .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
+    deep network training by reducing internal covariate shift." *ICML 2015*
+  .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
+    Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation."
+    *CVPR 2018*
+)code" ADD_FILELINE)
+.add_argument("data", "NDArray-or-Symbol", "Input data to batch normalization")
+.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
+.add_argument("beta", "NDArray-or-Symbol", "beta array")
+.add_argument("moving_mean", "NDArray-or-Symbol", "running mean of input")
+.add_argument("moving_var", "NDArray-or-Symbol", "running variance of input")
+.add_arguments(SyncBatchNormParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_SyncBatchNorm)
+.set_attr<nnvm::FSetInputVarAttrOnCompose>("FSetInputVarAttrOnCompose",
+    [](const nnvm::NodeAttrs& attrs, nnvm::NodePtr var, const int index) {
+      if (var->attrs.dict.find("__init__") != var->attrs.dict.end()) return;
+      if (index == 3) {
+        var->attrs.dict["__init__"] = "[\"zero\", {}]";
+      } else if (index == 4) {
+        var->attrs.dict["__init__"] = "[\"one\", {}]";
+      }
+    });
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/contrib/sync_batch_norm.cu b/src/operator/contrib/sync_batch_norm.cu
new file mode 100644
index 000000000000..005ac3f61a9c
--- /dev/null
+++ b/src/operator/contrib/sync_batch_norm.cu
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file sync_batch_norm.cu
+ * \brief Synchronized BatchNorm modified from BatchNormV1
+ * \author Hang Zhang
+*/
+
+#include "sync_batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+template<>
+Operator *CreateOp<gpu>(SyncBatchNormParam param, int dtype) {
+  return new SyncBatchNorm<gpu>(param);
+}
+
+}  // namespace op
+}  // namespace mxnet
+
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index f8930e12d1da..b3a54b164c3a 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -1909,6 +1909,84 @@ def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
     assert mx.context.num_gpus() > 0
 
+
+def _check_batchnorm_result(input, num_devices=1, cuda=False):
+    from mxnet.gluon.utils import split_and_load
+
+    def _find_bn(module):
+        if isinstance(module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module
+        elif isinstance(module.module, (mx.gluon.nn.BatchNorm, mx.gluon.contrib.nn.SyncBatchNorm)):
+            return module.module
+        raise RuntimeError('BN not found')
+
+    def _syncParameters(bn1, bn2, ctx):
+        ctx = input.context
+        bn2.gamma.set_data(bn1.gamma.data(ctx))
+        bn2.beta.set_data(bn1.beta.data(ctx))
+        bn2.running_mean.set_data(bn1.running_mean.data(ctx))
+        bn2.running_var.set_data(bn1.running_var.data(ctx))
+
+    input1 = input.copy()
+    input2 = input.copy()
+
+    if cuda:
+        input1 = input.as_in_context(mx.gpu(0))
+        ctx_list = [mx.gpu(i) for i in range(num_devices)]
+    else:
+        ctx_list = [mx.cpu(0) for _ in range(num_devices)]
+
+    nch = input.shape[1]
+    bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
+    bn2 = mx.gluon.contrib.nn.SyncBatchNorm(in_channels=nch, num_devices=num_devices)
+
+    bn1.initialize(ctx=ctx_list[0])
+    bn2.initialize(ctx=ctx_list)
+
+    # using the same values for gamma and beta
+    #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
+
+    input1.attach_grad()
+    inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
+    for xi in inputs2:
+        xi.attach_grad()
+
+    with mx.autograd.record():
+        output1 = bn1(input1)
+        output2 = [bn2(xi) for xi in inputs2]
+        loss1 = (output1 ** 2).sum()
+        loss2 = [(output ** 2).sum() for output in output2]
+        mx.autograd.backward(loss1)
+        mx.autograd.backward(loss2)
+
+    output2 = mx.nd.concat(*[output.as_in_context(input.context) for output in output2], dim=0)
+    # assert forward outputs and running statistics match the unsynchronized BatchNorm
+    assert_almost_equal(input1.asnumpy(), input2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(output1.asnumpy(), output2.asnumpy(), atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
+                        _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
+                        atol=1e-3, rtol=1e-3)
+    # assert backward: per-device gradients gathered together match the single-device gradient
+    input2grad = mx.nd.concat(*[xi.grad.as_in_context(input.context) for xi in inputs2], dim=0)
+    assert_almost_equal(input1.grad.asnumpy(), input2grad.asnumpy(), atol=1e-3, rtol=1e-3)
+
+def test_sync_batchnorm():
+    def get_num_devices():
+        for i in range(100):
+            try:
+                # force synchronization so a missing device raises inside the try block
+                mx.nd.zeros((1,), ctx=mx.gpu(i)).asnumpy()
+            except Exception:
+                return i
+        return 100
+    # no need to use SyncBN with 1 GPU
+    if get_num_devices() < 2:
+        return
+    ndev = 2
+    # check against the unsynchronized version
+    for i in range(10):
+        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
+                                num_devices=ndev, cuda=True)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
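
For reviewers who want to try the new block, here is a minimal usage sketch. It is not part of the patch: the two-GPU setup, the tiny network, the input shape, and the squared-mean loss are illustrative assumptions only, and it presumes a build that already includes this change.

```python
import mxnet as mx
from mxnet import autograd, gluon
from mxnet.gluon.utils import split_and_load

# Assumption: two visible GPUs.
num_devices = 2
ctx_list = [mx.gpu(i) for i in range(num_devices)]

# SyncBatchNorm is a drop-in replacement for gluon.nn.BatchNorm; its batch
# statistics are synchronized across the devices used for training.
net = gluon.nn.HybridSequential()
with net.name_scope():
    net.add(gluon.nn.Conv2D(channels=8, kernel_size=3, padding=1))
    net.add(mx.gluon.contrib.nn.SyncBatchNorm(in_channels=8, num_devices=num_devices))
    net.add(gluon.nn.Activation('relu'))
net.initialize(ctx=ctx_list)

# One training step: split the global batch across devices, record, backprop.
data = mx.nd.random.uniform(shape=(8, 3, 32, 32))
parts = split_and_load(data, ctx_list, batch_axis=0)
with autograd.record():
    losses = [(net(x) ** 2).mean() for x in parts]
autograd.backward(losses)
mx.nd.waitall()
```

Inside the operator, the cross-device reduction relies on a per-layer key (the block's `prefix`), a shared accumulator (`SharedND`), and a `Barrier`, all registered in process-global tables. The sketch below reduces that scheme to plain Python threads purely as an illustration; the names `SharedStats` and `push_and_reduce` are hypothetical and do not exist in MXNet.

```python
import threading
import numpy as np

class SharedStats:
    """Toy analogue of SharedND + Barrier: every rank deposits its partial
    statistic, waits until all ranks have arrived, then reads back the mean."""
    def __init__(self, num_devices):
        self.num_devices = num_devices
        self.slots = [None] * num_devices
        self.barrier = threading.Barrier(num_devices)

    def push_and_reduce(self, rank, value):
        self.slots[rank] = value                      # per-device partial E[x]
        self.barrier.wait()                           # wait until all ranks arrive
        return sum(self.slots) / self.num_devices     # identical result on every rank

shared = SharedStats(num_devices=2)

def device_worker(rank, x):
    global_mean = shared.push_and_reduce(rank, x.mean())
    # The real operator reduces E[x^2] the same way and then uses
    # var = E[x^2] - E[x]^2 to normalize the whole mini-batch consistently.
    print(rank, float(global_mean))

threads = [threading.Thread(target=device_worker, args=(r, np.random.rand(4)))
           for r in range(2)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```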