Adding process group in convert_syncbn_model
jjsjann123 committed Dec 10, 2018
commit 6d3c75e, 1 parent: 920da6d
Showing 3 changed files with 30 additions and 15 deletions.
apex/parallel/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -8,7 +8,7 @@
     print("Warning: apex was installed without --cuda_ext. Fused syncbn kernels will be unavailable. Python fallbacks will be used instead.")
     from .sync_batchnorm import SyncBatchNorm
 
-def convert_syncbn_model(module):
+def convert_syncbn_model(module, process_group=None):
     '''
     Recursively traverse module and its children to replace all
     `torch.nn.modules.batchnorm._BatchNorm` with `apex.parallel.SyncBatchNorm`
@@ -27,7 +27,7 @@ def convert_syncbn_model(module):
     '''
     mod = module
     if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
-        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats)
+        mod = SyncBatchNorm(module.num_features, module.eps, module.momentum, module.affine, module.track_running_stats, process_group)
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        if module.affine:
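
A minimal usage sketch of the new argument (illustrative, not part of the diff): after this change, convert_syncbn_model can scope statistics synchronization to a caller-supplied process group instead of the default global one. The sketch assumes the script is launched so that RANK and WORLD_SIZE are set (for example via torch.distributed.launch); the group ranks and the model are placeholders.

# Illustrative only: group ranks and model are placeholders.
import torch
import apex.parallel

torch.distributed.init_process_group(backend='nccl', init_method='env://')

# Synchronize batchnorm statistics only among ranks 0 and 1
# rather than across the default (global) group.
bn_group = torch.distributed.new_group(ranks=[0, 1])

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, 3),
    torch.nn.BatchNorm2d(16),
    torch.nn.ReLU(),
).cuda()

# Every _BatchNorm child is replaced by apex.parallel.SyncBatchNorm
# constructed with the supplied process group.
model = apex.parallel.convert_syncbn_model(model, process_group=bn_group)
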
apex/parallel/sync_batchnorm.py (25 changes: 18 additions & 7 deletions)
@@ -44,8 +44,12 @@ class SyncBatchNorm(_BatchNorm):
         >>> out = sbn(inp)
     """
 
-    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, process_group=None):
         super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+        self.process_group = process_group
+
+    def _specify_process_group(self, process_group):
+        self.process_group = process_group
 
     def forward(self, input):
         torch.cuda.nvtx.range_push("sync_bn_fw_with_mean_var")
@@ -56,6 +60,13 @@ def forward(self, input):
             torch.cuda.nvtx.range_pop()
             return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, 0.0, self.eps)
         else:
+            process_group = self.process_group
+            world_size = 0
+            if self.process_group:
+                world_size = torch.distributed.get_world_size(process_group)
+            else:
+                process_group = torch.distributed.get_default_group()
+                world_size = torch.distributed.get_world_size()
             self.num_batches_tracked += 1
             with torch.no_grad():
                 channel_first_input = input.transpose(0, 1).contiguous()
@@ -69,12 +80,12 @@ def forward(self, input):
                     squashed_input_tensor_view, 2).mean(1)
                 if torch.distributed.is_initialized():
                     torch.distributed.all_reduce(
-                        local_mean, op=torch.distributed.reduce_op.SUM)
-                    mean = local_mean / torch.distributed.get_world_size()
+                        local_mean, torch.distributed.reduce_op.SUM, process_group)
+                    mean = local_mean / world_size
                     torch.distributed.all_reduce(
-                        local_sqr_mean, op=torch.distributed.reduce_op.SUM)
-                    sqr_mean = local_sqr_mean / torch.distributed.get_world_size()
-                    m = local_m * torch.distributed.get_world_size()
+                        local_sqr_mean, torch.distributed.reduce_op.SUM, process_group)
+                    sqr_mean = local_sqr_mean / world_size
+                    m = local_m * world_size
                 else:
                     m = local_m
                     mean = local_mean
@@ -94,4 +105,4 @@ def forward(self, input):
                     (m-1) * self.momentum * var + \
                     (1 - self.momentum) * self.running_var
         torch.cuda.nvtx.range_pop()
-        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps)
+        return SyncBatchnormFunction.apply(input, self.weight, self.bias, mean, var, self.eps, process_group, world_size)
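
Conceptually, the forward path now performs its reduction over the caller's group: each rank computes per-channel E[x] and E[x^2], both are all-reduced over process_group, the sums are divided by that group's world_size, and the variance follows from var = E[x^2] - E[x]^2. Below is a self-contained sketch of that pooling, separate from the module itself; the helper name is made up, and it uses the current ReduceOp spelling rather than the older reduce_op alias that appears in the diff.

# Standalone illustration of the forward-pass reduction (names are illustrative),
# assuming torch.distributed is already initialized.
import torch
import torch.distributed as dist

def gather_batch_stats(x, process_group=None):
    # x: (N, C, ...) -> per-channel statistics pooled across the chosen group
    group = process_group if process_group is not None else dist.group.WORLD
    world_size = dist.get_world_size(group=group)

    flat = x.transpose(0, 1).contiguous().view(x.size(1), -1)
    local_mean = flat.mean(1)
    local_sqr_mean = flat.pow(2).mean(1)

    # Sum the per-process statistics over the group, then average by its size.
    dist.all_reduce(local_mean, op=dist.ReduceOp.SUM, group=group)
    dist.all_reduce(local_sqr_mean, op=dist.ReduceOp.SUM, group=group)
    mean = local_mean / world_size
    sqr_mean = local_sqr_mean / world_size

    # Var(x) = E[x^2] - E[x]^2, computed from the group-wide moments.
    var = sqr_mean - mean * mean
    return mean, var
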
apex/parallel/sync_batchnorm_kernel.py (16 changes: 10 additions & 6 deletions)
@@ -5,14 +5,16 @@
 class SyncBatchnormFunction(Function):
 
     @staticmethod
-    def forward(ctx, input, weight, bias, running_mean, running_variance, eps):
+    def forward(ctx, input, weight, bias, running_mean, running_variance, eps, process_group, world_size):
         torch.cuda.nvtx.range_push("sync_BN_fw")
         # transpose it to channel last to support broadcasting for input with different rank
         c_last_input = input.transpose(1, -1).contiguous().clone()
 
         ctx.save_for_backward(c_last_input, weight, bias,
                               running_mean, running_variance)
         ctx.eps = eps
+        ctx.process_group = process_group
+        ctx.world_size = world_size
 
         c_last_input = (c_last_input - running_mean) / \
             torch.sqrt(running_variance + eps)
@@ -34,6 +36,8 @@ def backward(ctx, grad_output):
         c_last_input, weight, bias, running_mean, running_variance = ctx.saved_tensors
 
         eps = ctx.eps
+        process_group = ctx.process_group
+        world_size = ctx.world_size
         grad_input = grad_weight = grad_bias = None
         num_features = running_mean.size()[0]
 
@@ -53,11 +57,11 @@ def backward(ctx, grad_output):
                                           running_mean)).view(-1, num_features).mean(0)
             if torch.distributed.is_initialized():
                 torch.distributed.all_reduce(
-                    mean_dy, op=torch.distributed.reduce_op.SUM)
-                mean_dy = mean_dy / torch.distributed.get_world_size()
+                    mean_dy, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy = mean_dy / world_size
                 torch.distributed.all_reduce(
-                    mean_dy_xmu, op=torch.distributed.reduce_op.SUM)
-                mean_dy_xmu = mean_dy_xmu / torch.distributed.get_world_size()
+                    mean_dy_xmu, torch.distributed.reduce_op.SUM, process_group)
+                mean_dy_xmu = mean_dy_xmu / world_size
             c_last_grad_input = (c_last_grad - mean_dy - (c_last_input - running_mean) / (
                 running_variance + eps) * mean_dy_xmu) / torch.sqrt(running_variance + eps)
             if weight is not None:
@@ -78,4 +82,4 @@ def backward(ctx, grad_output):
             grad_bias = c_grad.sum(0)
 
         torch.cuda.nvtx.range_pop()
-        return grad_input, grad_weight, grad_bias, None, None, None
+        return grad_input, grad_weight, grad_bias, None, None, None, None, None
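
The backward kernel mirrors the forward: mean_dy and mean_dy_xmu are summed over the group saved on ctx and divided by the saved world_size. Because autograd requires one gradient slot per forward input, the two new non-tensor inputs (process_group, world_size) account for the two extra trailing Nones in the return. A standalone sketch of that gradient pooling follows; the function and argument names are illustrative, again using ReduceOp in place of the older reduce_op alias.

# Standalone illustration of the backward-pass reduction (names are illustrative);
# mirrors how mean_dy and mean_dy_xmu are pooled over the saved group.
import torch
import torch.distributed as dist

def pool_grad_stats(grad_c_last, x_c_last, running_mean, process_group, world_size):
    # grad_c_last, x_c_last: channels-last views of grad_output and the saved input
    num_features = running_mean.size(0)
    mean_dy = grad_c_last.view(-1, num_features).mean(0)
    mean_dy_xmu = (grad_c_last * (x_c_last - running_mean)).view(-1, num_features).mean(0)

    if dist.is_initialized():
        # Same pattern as the forward pass: SUM over the group, then divide
        # by the group's world size saved on ctx.
        dist.all_reduce(mean_dy, op=dist.ReduceOp.SUM, group=process_group)
        dist.all_reduce(mean_dy_xmu, op=dist.ReduceOp.SUM, group=process_group)
        mean_dy = mean_dy / world_size
        mean_dy_xmu = mean_dy_xmu / world_size
    return mean_dy, mean_dy_xmu
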
