In [1]:
import os
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.multiprocessing as mp

%load_ext autoreload
%autoreload 2
from syncbn import SyncBatchNorm

### preliminary computations

In [434]:
def bn(x, rmean, rvar, momentum=0.1, eps=1e-05):
    batch = x.size(0)
    hidden = x.size(2)
    mean = x.mean(dim=(0, 2))
    
    squared_mean = x.pow(2).mean(dim=(0, 2))
    
    var = squared_mean - mean.pow(2)
    
    inverse = 1 / torch.sqrt(var + eps)
    normalized = (x - mean.view(1, -1, 1)) * inverse.view(1, -1, 1)

    bias_correction = (batch * hidden) / (batch * hidden - 1) 
    
    with torch.no_grad():
        rmean.mul_(1 - momentum).add_(mean * momentum)
        rvar.mul_(1 - momentum).add_(var * bias_correction * momentum)
    
    return normalized, inverse

In [435]:
def bn_grad(grad_output, inverse, normalized):
    
    batch = grad_output.size(0)
    hidden = grad_output.size(2)
    
    N = batch * hidden
    inverse = inverse.view(1, -1, 1)
    grad_input = inverse / N * (
        N * grad_output \
        - grad_output.sum(dim=(0, 2)).view(1, -1, 1) \
        - normalized * (normalized * grad_output).sum(dim=(0, 2)).view(1, -1, 1)
    )
    
    return grad_input

In [436]:
x = torch.rand(16, 100, 256)
copy_x = x.clone()

x.requires_grad = True
copy_x.requires_grad = True

In [437]:
rmean = torch.zeros(x.size(1))
rvar = torch.ones(x.size(1))

y_pred, inverse = bn(x, rmean, rvar)
y_pred.sum().backward()

In [438]:
bn1d = nn.BatchNorm1d(100, affine=False, track_running_stats=True)
y_true = bn1d(copy_x)
y_true.sum().backward()

In [441]:
print('outputs same: ', torch.allclose(y_true, y_pred, atol=1e-0))
print(' rmeans same: ', torch.allclose(bn1d.running_mean, rmean))
print('  rvars smae: ', torch.allclose(bn1d.running_var, rvar))
print('  grads same: ', torch.allclose(copy_x.grad, x.grad, atol=1e-5))

d = torch.ones_like(x)
print('  grads same: ', torch.allclose(bn_grad(d, inverse, y_pred), copy_x.grad, atol=1e-5))

outputs same:  True
 rmeans same:  True
  rvars smae:  True
  grads same:  True
  grads same:  True


### final testing

In [2]:
x = torch.randn(16, 100, 256)
x_copy = x.detach().clone()

x.requires_grad = True
x_copy.requires_grad = True

In [3]:
sbn = SyncBatchNorm(100)
bn1d = nn.BatchNorm1d(100, affine=False, track_running_stats=True)

In [4]:
y_true = bn1d(x_copy)
y_pred = sbn(x)

print(torch.allclose(y_pred, y_true, atol=1e-08))
print(torch.allclose(sbn.running_mean, bn1d.running_mean, atol=1e-08))
print(torch.allclose(sbn.running_var, bn1d.running_var, atol=1e-08))

y_true.sum().backward()
y_pred.sum().backward()
print(torch.allclose(x.grad, x_copy.grad, atol=1e-05))

True
True
True
True


In [5]:
[(i, j, k) for i in range(0, 3) for j in range(2, 7 ,2) for k in range(10, 31, 10)]

[(0, 2, 10),
 (0, 2, 20),
 (0, 2, 30),
 (0, 4, 10),
 (0, 4, 20),
 (0, 4, 30),
 (0, 6, 10),
 (0, 6, 20),
 (0, 6, 30),
 (1, 2, 10),
 (1, 2, 20),
 (1, 2, 30),
 (1, 4, 10),
 (1, 4, 20),
 (1, 4, 30),
 (1, 6, 10),
 (1, 6, 20),
 (1, 6, 30),
 (2, 2, 10),
 (2, 2, 20),
 (2, 2, 30),
 (2, 4, 10),
 (2, 4, 20),
 (2, 4, 30),
 (2, 6, 10),
 (2, 6, 20),
 (2, 6, 30)]

In [20]:
def run(rank, size, input):
    """ Distributed function to be implemented later. """
    out = sbn(input)

def init_process(rank, size, input, fn, backend='gloo'):
    """ Initialize the distributed environment. """
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    
    dist.init_process_group(backend, rank=rank, world_size=size)
    fn(rank, size, input)

In [21]:
size = 2
x = torch.rand(size, 16, 100, 256)

processes = []
mp.set_start_method("spawn")
for rank in range(size):
    p = mp.Process(target=init_process, args=(rank, size, x[rank], run))
    p.start()
    processes.append(p)

for p in processes:
    p.join()

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/kvtamogashev/miniforge3/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/kvtamogashev/miniforge3/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'init_process' on <module '__main__' (built-in)>
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/kvtamogashev/miniforge3/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/kvtamogashev/miniforge3/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'init_process' on <module '__main__' (built-in)>


In [19]:
x[0].shape

torch.Size([16, 100, 256])

In [25]:
torch.randn(10, 10, 10).mean(dim=(0, 1), keepdims=True)

tensor([[[ 0.0784,  0.0996, -0.0383, -0.0372, -0.0514,  0.0836, -0.1197,
           0.1208,  0.0244, -0.1221]]])