
[MXNET-614] Adding Synchronized Batch Normalization #11502

Merged
merged 38 commits on Jul 14, 2018
38 commits
3f43194
sync batch norm
zhanghang1989 Jun 29, 2018
7d58073
global rank and barrier
zhanghang1989 Jun 30, 2018
9908468
lint
zhanghang1989 Jun 30, 2018
3ce37a3
cpplint
zhanghang1989 Jun 30, 2018
a32ec07
pylint
zhanghang1989 Jun 30, 2018
cd6d93b
doc
zhanghang1989 Jun 30, 2018
a3bd860
add ref
zhanghang1989 Jun 30, 2018
a76dad4
customized barrier
zhanghang1989 Jul 1, 2018
b3793bf
cpplint
zhanghang1989 Jul 1, 2018
b31fcfe
get rid of pthread
zhanghang1989 Jul 2, 2018
8f58e63
address comments
zhanghang1989 Jul 2, 2018
cc60d11
warning
zhanghang1989 Jul 2, 2018
1c87e6e
pylint
zhanghang1989 Jul 2, 2018
31d7c2f
Merge remote-tracking branch 'upstream/master' into syncBN
zhanghang1989 Jul 2, 2018
4bb0f3a
gpu unitest
zhanghang1989 Jul 3, 2018
62d8891
gpu 0
zhanghang1989 Jul 3, 2018
24543c9
mv to cpu test
zhanghang1989 Jul 3, 2018
50f8593
Revert "mv to cpu test"
zhanghang1989 Jul 3, 2018
9239793
Merge remote-tracking branch 'upstream/master' into syncBN
zhanghang1989 Jul 3, 2018
b70426d
ndev = 2
zhanghang1989 Jul 3, 2018
f888706
debuging
zhanghang1989 Jul 5, 2018
b7d2d3c
sum prod
zhanghang1989 Jul 5, 2018
723c670
lint
zhanghang1989 Jul 5, 2018
abdd5d1
contrib, ngpu
zhanghang1989 Jul 5, 2018
c72413a
code style
zhanghang1989 Jul 6, 2018
2e0dc79
code style
zhanghang1989 Jul 6, 2018
3013cb3
forward backward
zhanghang1989 Jul 6, 2018
1c92152
Merge remote-tracking branch 'upstream/master' into syncBN
zhanghang1989 Jul 6, 2018
ffce503
test
zhanghang1989 Jul 6, 2018
6acdc71
cpu test
zhanghang1989 Jul 6, 2018
38ba23d
Merge remote-tracking branch 'upstream/master' into syncBN
zhanghang1989 Jul 9, 2018
3a439d0
fix deconstruction
zhanghang1989 Jul 10, 2018
a7918e0
doc indent
zhanghang1989 Jul 10, 2018
55fef70
doc
zhanghang1989 Jul 10, 2018
9884d60
doc
zhanghang1989 Jul 11, 2018
a2780ce
address comments
zhanghang1989 Jul 12, 2018
16df5d4
typo
zhanghang1989 Jul 13, 2018
809854d
asnumpy
zhanghang1989 Jul 13, 2018
1 change: 1 addition & 0 deletions docs/api/python/gluon/contrib.md
@@ -36,6 +36,7 @@ In the rest of this document, we list routines provided by the `gluon.contrib` package
HybridConcurrent
Identity
SparseEmbedding
SyncBatchNorm
```

### Recurrent neural network
84 changes: 81 additions & 3 deletions python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -18,11 +18,13 @@
# coding: utf-8
# pylint: disable= arguments-differ
"""Custom neural network layers in model_zoo."""
__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding']
__all__ = ['Concurrent', 'HybridConcurrent', 'Identity', 'SparseEmbedding',
           'SyncBatchNorm']

from .... import nd
import warnings
from .... import nd, test_utils
from ...block import HybridBlock, Block
from ...nn import Sequential, HybridSequential
from ...nn import Sequential, HybridSequential, BatchNorm

class Concurrent(Sequential):
"""Lays `Block`s concurrently.
@@ -151,3 +153,79 @@ def __repr__(self):
        s = '{block_name}({input_dim} -> {output_dim}, {dtype})'
        return s.format(block_name=self.__class__.__name__,
                        **self._kwargs)

class SyncBatchNorm(BatchNorm):
    """Cross-GPU Synchronized Batch normalization (SyncBN)

    Standard BN [1]_ implementation only normalizes the data within each device.
    SyncBN normalizes the input within the whole mini-batch.
    We follow the sync-once implementation described in the paper [2]_.

    Parameters
    ----------
    in_channels : int, default 0
        Number of channels (feature maps) in input data. If not specified,
        initialization will be deferred to the first time `forward` is called
        and `in_channels` will be inferred from the shape of input data.
    num_devices : int, default number of visible GPUs
    momentum: float, default 0.9
        Momentum for the moving average.
    epsilon: float, default 1e-5
        Small float added to variance to avoid dividing by zero.
    center: bool, default True
        If True, add offset of `beta` to normalized tensor.
        If False, `beta` is ignored.
    scale: bool, default True
        If True, multiply by `gamma`. If False, `gamma` is not used.
        When the next layer is linear (also e.g. `nn.relu`),
        this can be disabled since the scaling
        will be done by the next layer.
    use_global_stats: bool, default False
        If True, use global moving statistics instead of local batch-norm. This will force
        change batch-norm into a scale shift operator.
        If False, use local batch-norm.
    beta_initializer: str or `Initializer`, default 'zeros'
        Initializer for the beta weight.
    gamma_initializer: str or `Initializer`, default 'ones'
        Initializer for the gamma weight.
    running_mean_initializer: str or `Initializer`, default 'zeros'
        Initializer for the running mean.
    running_variance_initializer: str or `Initializer`, default 'ones'
        Initializer for the running variance.


    Inputs:
        - **data**: input tensor with arbitrary shape.
    Outputs:
        - **out**: output tensor with the same shape as `data`.

    Reference:
        .. [1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating \
        deep network training by reducing internal covariate shift." *ICML 2015*
        .. [2] Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, \
        Ambrish Tyagi, and Amit Agrawal. "Context Encoding for Semantic Segmentation." *CVPR 2018*
    """
    def __init__(self, in_channels=0, num_devices=None, momentum=0.9, epsilon=1e-5,
                 center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
                 gamma_initializer='ones', running_mean_initializer='zeros',
                 running_variance_initializer='ones', **kwargs):
        super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
                                            beta_initializer, gamma_initializer,
                                            running_mean_initializer, running_variance_initializer,
                                            in_channels, **kwargs)
        num_devices = self._get_num_devices() if num_devices is None else num_devices
        self._kwargs = {'eps': epsilon, 'momentum': momentum,
                        'fix_gamma': not scale, 'use_global_stats': use_global_stats,
                        'ndev': num_devices, 'key': self.prefix}

    def _get_num_devices(self):
        warnings.warn("Caution using SyncBatchNorm: "
                      "if not using all the GPUs, please manually set num_devices",
                      UserWarning)
        num_devices = len(test_utils.list_gpus())
        num_devices = num_devices if num_devices > 0 else 1
        return num_devices

    def hybrid_forward(self, F, x, gamma, beta, running_mean, running_var):
        return F.contrib.SyncBatchNorm(x, gamma, beta, running_mean, running_var,
                                       name='fwd', **self._kwargs)