In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# BatchNorm2d

[torch.nn.BatchNorm2d - PyTorch documentation](https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d)

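$$
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

In training mode, $\mathrm{E}[x]$ and $\mathrm{Var}[x]$ are computed per channel over the $(N, H, W)$ dimensions (using the biased variance); in eval mode the stored `running_mean` / `running_var` are used instead. $\gamma$ and $\beta$ are the learnable per-channel `weight` and `bias`.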

In [2]:
# BatchNorm2d over C=4 channels; eps, momentum, affine and
# track_running_stats are all left at their defaults.
norm = nn.BatchNorm2d(num_features=4)
norm

BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [3]:
# running_mean: buffer holding the per-channel running estimate of E[x]
# (starts at zeros). Buffers are plain tensors, so the legacy `.data`
# accessor is unnecessary here.
norm.running_mean

tensor([0., 0., 0., 0.])

In [4]:
# running_var: buffer holding the per-channel running estimate of Var[x]
# (starts at ones). Read the buffer directly instead of via legacy `.data`.
norm.running_var

tensor([1., 1., 1., 1.])

In [5]:
# gamma: learnable per-channel scale (an nn.Parameter, initialised to ones).
# `.detach()` shows it as a plain tensor without using the legacy `.data`
# accessor, which silently bypasses autograd tracking.
norm.weight.detach()

tensor([1., 1., 1., 1.])

In [6]:
# beta: learnable per-channel shift (an nn.Parameter, initialised to zeros).
# `.detach()` is the supported way to get a plain-tensor view; avoid `.data`.
norm.bias.detach()

tensor([0., 0., 0., 0.])

In [7]:
# eps: small constant added to Var[x] under the square root to keep the
# denominator away from zero
norm.eps

1e-05

# 运算 (forward pass)

In [8]:
# Toy input in (N, C, H, W) layout: one sample, 4 channels, 5x5 spatial,
# filled with 0..99 so each channel holds its own consecutive range.
x = torch.arange(100, dtype=torch.float32).view(1, 4, 5, 5)
x

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29.],
          [30., 31., 32., 33., 34.],
          [35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44.],
          [45., 46., 47., 48., 49.]],

         [[50., 51., 52., 53., 54.],
          [55., 56., 57., 58., 59.],
          [60., 61., 62., 63., 64.],
          [65., 66., 67., 68., 69.],
          [70., 71., 72., 73., 74.]],

         [[75., 76., 77., 78., 79.],
          [80., 81., 82., 83., 84.],
          [85., 86., 87., 88., 89.],
          [90., 91., 92., 93., 94.],
          [95., 96., 97., 98., 99.]]]])

In [9]:
x.size()  # same torch.Size as x.shape: (N, C, H, W)

torch.Size([1, 4, 5, 5])

In [10]:
# A freshly constructed module is in training mode, so each channel is
# normalized with its own batch mean and (biased) variance over N, H, W --
# which is why every channel maps to the same values below. With
# track_running_stats=True this call also updates running_mean/running_var.
y = norm(x)
y

tensor([[[[-1.6641, -1.5254, -1.3868, -1.2481, -1.1094],
          [-0.9707, -0.8321, -0.6934, -0.5547, -0.4160],
          [-0.2774, -0.1387,  0.0000,  0.1387,  0.2774],
          [ 0.4160,  0.5547,  0.6934,  0.8321,  0.9707],
          [ 1.1094,  1.2481,  1.3868,  1.5254,  1.6641]],

         [[-1.6641, -1.5254, -1.3868, -1.2481, -1.1094],
          [-0.9707, -0.8321, -0.6934, -0.5547, -0.4160],
          [-0.2773, -0.1387,  0.0000,  0.1387,  0.2773],
          [ 0.4160,  0.5547,  0.6934,  0.8321,  0.9707],
          [ 1.1094,  1.2481,  1.3868,  1.5254,  1.6641]],

         [[-1.6641, -1.5254, -1.3868, -1.2481, -1.1094],
          [-0.9707, -0.8320, -0.6934, -0.5547, -0.4160],
          [-0.2773, -0.1387,  0.0000,  0.1387,  0.2774],
          [ 0.4160,  0.5547,  0.6934,  0.8321,  0.9707],
          [ 1.1094,  1.2481,  1.3868,  1.5254,  1.6641]],

         [[-1.6641, -1.5254, -1.3868, -1.2481, -1.1094],
          [-0.9707, -0.8321, -0.6934, -0.5547, -0.4160],
          [-0.2773, -0.13