In [1]:
import torch
import numpy as np

In [2]:
data=torch.arange(15).reshape(3,5).float()
data

tensor([[ 0.,  1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14.]])

In [3]:
bn = torch.nn.BatchNorm1d(5)

In [4]:
bn.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0., 0., 0.])),
             ('running_var', tensor([1., 1., 1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

In [5]:
bn(data)

tensor([[-1.2247, -1.2247, -1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247,  1.2247,  1.2247]],
       grad_fn=<NativeBatchNormBackward0>)

### 0.1\*25+0.9\*1 = 3.4
25 为样本方差

In [6]:
bn.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.])),
             ('running_mean',
              tensor([0.5000, 0.6000, 0.7000, 0.8000, 0.9000])),
             ('running_var', tensor([3.4000, 3.4000, 3.4000, 3.4000, 3.4000])),
             ('num_batches_tracked', tensor(1))])

![jupyter](./var.png)

torch.var两种方差都可以计算，这取决于一个参数，即unbiased，无偏的意思。默认值为true，也就是说，默认的目的是样本估计总体，使用的是上面这个样本方差公式，计算的是样本方差。

In [7]:
var = data.var(0)
var

tensor([25., 25., 25., 25., 25.])

In [8]:
data_var = data.var(0,False) # 母体方差
data_var

tensor([16.6667, 16.6667, 16.6667, 16.6667, 16.6667])

In [9]:
data_mean = data.mean(0)
data_mean

tensor([5., 6., 7., 8., 9.])

### 使用母体方差

In [10]:
(data - data_mean)/np.sqrt(data_var)

tensor([[-1.2247, -1.2247, -1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247,  1.2247,  1.2247]])

In [11]:
bn1 = torch.nn.BatchNorm1d(5, momentum=1)

In [12]:
bn1

BatchNorm1d(5, eps=1e-05, momentum=1, affine=True, track_running_stats=True)

In [13]:
bn1(data)

tensor([[-1.2247, -1.2247, -1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247,  1.2247,  1.2247]],
       grad_fn=<NativeBatchNormBackward0>)

In [14]:
bn1.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.])),
             ('running_mean', tensor([5., 6., 7., 8., 9.])),
             ('running_var', tensor([25., 25., 25., 25., 25.])),
             ('num_batches_tracked', tensor(1))])

In [15]:
data_var=torch.Tensor.var(data, 0, True)

In [16]:
data_var

tensor([25., 25., 25., 25., 25.])

In [17]:
bn2=torch.nn.BatchNorm1d(5, momentum=0.5)
bn2.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0., 0., 0.])),
             ('running_var', tensor([1., 1., 1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

In [18]:
bn2(data)

tensor([[-1.2247, -1.2247, -1.2247, -1.2247, -1.2247],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 1.2247,  1.2247,  1.2247,  1.2247,  1.2247]],
       grad_fn=<NativeBatchNormBackward0>)

## running_var = 0.5\*1+0.5\*25=13

In [19]:
bn2.state_dict()

OrderedDict([('weight', tensor([1., 1., 1., 1., 1.])),
             ('bias', tensor([0., 0., 0., 0., 0.])),
             ('running_mean',
              tensor([2.5000, 3.0000, 3.5000, 4.0000, 4.5000])),
             ('running_var', tensor([13., 13., 13., 13., 13.])),
             ('num_batches_tracked', tensor(1))])

In [20]:
a = torch.randn(3,3)

In [21]:
a

tensor([[ 0.7976,  0.7063,  0.5399],
        [ 0.9518, -0.9554, -0.4935],
        [-1.4796,  0.9450,  0.3082]])

<img src="./var.png", width=320, heigth=240>

In [22]:
a.var(0, False)

tensor([1.2357, 0.7144, 0.1960])

In [23]:
a.var(0, True)

tensor([1.8536, 1.0716, 0.2940])

In [24]:
mean = a.mean(0)
mean

tensor([0.0899, 0.2320, 0.1182])

In [25]:
(a-mean)/np.sqrt(var)

RuntimeError: The size of tensor a (3) must match the size of tensor b (5) at non-singleton dimension 1

In [None]:
bn1 = torch.nn.BatchNorm1d(3, momentum=1)

In [None]:
bn1(a)