In [46]:
import torch
import torch.nn as nn

In [431]:
C = 2
B_SZ = 1

### Normalize over each channel

**momentum** – the value used for the running_mean and running_var computation. Can be set to None for cumulative moving average (i.e. simple average). Default: 0.1

In [483]:
bn = nn.BatchNorm2d(C, momentum=1) # mom=1 makes it easier to calculate

In [484]:
x = torch.randn(B_SZ,C,3,3); x

tensor([[[[-1.5723, -1.8493,  0.2430],
          [-0.1550,  0.1350, -1.0884],
          [ 0.8568,  1.6856, -0.7365]],

         [[-0.7027,  0.5681,  0.0645],
          [ 0.5936,  1.8537, -1.1054],
          [-0.4720,  1.3684,  0.5947]]]])

In [485]:
# batch, c, w, h
x.shape

torch.Size([1, 2, 3, 3])

In [486]:
y = bn(x)

In [487]:
x[:,0].mean(), x[:,1].mean()

(tensor(-0.2757), tensor(0.3070))

In [488]:
x[:,0].var(), x[:,1].var()

(tensor(1.3264), tensor(0.9290))

In [489]:
bn.running_mean

tensor([-0.2757,  0.3070])

In [490]:
bn.running_var

tensor([1.3264, 0.9290])

In [491]:
bn.eps

1e-05

## I'm not able to reproduce it exactly, always of by around 0.2

In [496]:
n = (x[:,0]-bn.running_mean[0])/(torch.sqrt(bn.running_var[0]+bn.eps))
n.mean(), n.var()

(tensor(-3.9736e-08), tensor(1.0000))

In [497]:
n

tensor([[[-1.1258, -1.3663,  0.4503],
         [ 0.1048,  0.3565, -0.7056],
         [ 0.9833,  1.7029, -0.4001]]])

In [498]:
y[:,0]

tensor([[[-1.1941, -1.4492,  0.4777],
         [ 0.1111,  0.3782, -0.7484],
         [ 1.0429,  1.8062, -0.4244]]], grad_fn=<SelectBackward>)

In [495]:
y[:,0].mean(), y[:,0].var()

(tensor(-1.3245e-08, grad_fn=<MeanBackward0>),
 tensor(1.1250, grad_fn=<VarBackward0>))

### Trainable params

In [500]:
bn.weight

Parameter containing:
tensor([1., 1.], requires_grad=True)

In [462]:
bn.bias

Parameter containing:
tensor([0., 0.], requires_grad=True)

In [463]:
y.shape

torch.Size([1, 2, 3, 3])

In [464]:
z = y.sum()

In [465]:
z.backward()

In [466]:
bn.bias.grad # equals the number of elements in y for each channel

tensor([9., 9.])

In [468]:
n.sum()

tensor(-8.9407e-08)

In [467]:
bn.weight.grad # depends on x

tensor([4.2699e-08, 3.4601e-08])