In [12]:
import torch
from torch import nn
from torch.nn import functional as F

# BatchNorm2d

https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d

$$
y = \frac {x - E[x]} {\sqrt {Var[x] + \epsilon}} * \gamma + \beta
$$

In [13]:
norm = nn.BatchNorm2d(3)
norm

BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [14]:
# running_mean
norm.running_mean.data

tensor([0., 0., 0.])

In [15]:
# running_var
norm.running_var.data

tensor([1., 1., 1.])

In [16]:
# gamma
norm.weight.data

tensor([1., 1., 1.])

In [17]:
# beta
norm.bias.data

tensor([0., 0., 0.])

In [18]:
# eps
norm.eps

1e-05

# 运算

In [19]:
x = torch.arange(48.0).reshape(1, 3, 4, 4)
x

tensor([[[[ 0.,  1.,  2.,  3.],
          [ 4.,  5.,  6.,  7.],
          [ 8.,  9., 10., 11.],
          [12., 13., 14., 15.]],

         [[16., 17., 18., 19.],
          [20., 21., 22., 23.],
          [24., 25., 26., 27.],
          [28., 29., 30., 31.]],

         [[32., 33., 34., 35.],
          [36., 37., 38., 39.],
          [40., 41., 42., 43.],
          [44., 45., 46., 47.]]]])

In [20]:
norm(x)

tensor([[[[-1.6270, -1.4100, -1.1931, -0.9762],
          [-0.7593, -0.5423, -0.3254, -0.1085],
          [ 0.1085,  0.3254,  0.5423,  0.7593],
          [ 0.9762,  1.1931,  1.4100,  1.6270]],

         [[-1.6270, -1.4100, -1.1931, -0.9762],
          [-0.7593, -0.5423, -0.3254, -0.1085],
          [ 0.1085,  0.3254,  0.5423,  0.7593],
          [ 0.9762,  1.1931,  1.4100,  1.6270]],

         [[-1.6270, -1.4100, -1.1931, -0.9762],
          [-0.7593, -0.5423, -0.3254, -0.1085],
          [ 0.1085,  0.3254,  0.5423,  0.7593],
          [ 0.9762,  1.1931,  1.4100,  1.6270]]]],
       grad_fn=<NativeBatchNormBackward0>)

In [28]:
(x - x.mean(dim=(0, 2, 3), keepdim=True)) / torch.sqrt(
    (x.var(dim=(0, 2, 3), keepdim=True) + norm.eps)
)

tensor([[[[-1.5753, -1.3653, -1.1552, -0.9452],
          [-0.7351, -0.5251, -0.3151, -0.1050],
          [ 0.1050,  0.3151,  0.5251,  0.7351],
          [ 0.9452,  1.1552,  1.3653,  1.5753]],

         [[-1.5753, -1.3653, -1.1552, -0.9452],
          [-0.7351, -0.5251, -0.3151, -0.1050],
          [ 0.1050,  0.3151,  0.5251,  0.7351],
          [ 0.9452,  1.1552,  1.3653,  1.5753]],

         [[-1.5753, -1.3653, -1.1552, -0.9452],
          [-0.7351, -0.5251, -0.3151, -0.1050],
          [ 0.1050,  0.3151,  0.5251,  0.7351],
          [ 0.9452,  1.1552,  1.3653,  1.5753]]]])

In [25]:
x.mean(dim=(0, 2, 3), keepdim=True)

tensor([[[[ 7.5000]],

         [[23.5000]],

         [[39.5000]]]])