In [1]:
import torch
import numpy as np
from torch.autograd import Variable

In [2]:
n = 10
d = 5

module = torch.nn.BatchNorm1d(d, affine=False)

running_mean = Variable(torch.zeros(1, d))
running_var = Variable(torch.ones(1, d))

In [3]:
def batch_norm(x, training, batch_count=None, momentum=0.1):
    ones = Variable(torch.ones(x.size()[0], 1))
    global running_mean
    global running_var
    if training:
        ones = Variable(torch.ones(x.size()[0], 1))
        mean = torch.mean(x, 0)
        var = np.var(x.data.numpy(), axis=0).reshape(1, x.size()[1])
        var = Variable(torch.FloatTensor(var))
        running_mean = ((1.0 - momentum) * running_mean) + momentum * mean
        running_var = ((1.0 - momentum) * running_var) + momentum * var
    else:
        mean = running_mean
        var = running_var
    x_normalized = torch.div(x - ones.mm(mean), ones.mm(torch.sqrt(var + 1e-5)))
    return x_normalized

In [4]:
t = np.random.randint(1, 10, size=(n, d))
A = Variable(torch.FloatTensor(t))
A

Variable containing:
    1     7     3     9     7
    8     3     8     5     6
    3     7     2     8     1
    7     8     6     8     9
    6     5     2     3     7
    6     3     2     2     2
    3     9     4     9     5
    8     2     2     3     2
    2     4     1     3     8
    6     5     6     3     7
[torch.FloatTensor of size 10x5]

In [5]:
v1 = batch_norm(A, True)
v1

Variable containing:
-1.6609  0.7595 -0.2727  1.3592  0.6030
 1.2457 -1.0276  2.0000 -0.1102  0.2261
-0.8305  0.7595 -0.7273  0.9919 -1.6583
 0.8305  1.2063  1.0909  0.9919  1.3568
 0.4152 -0.1340 -0.7273 -0.8449  0.6030
 0.4152 -1.0276 -0.7273 -1.2123 -1.2814
-0.8305  1.6530  0.1818  1.3592 -0.1508
 1.2457 -1.4743 -0.7273 -0.8449 -1.2814
-1.2457 -0.5808 -1.1818 -0.8449  0.9799
 0.4152 -0.1340  1.0909 -0.8449  0.6030
[torch.FloatTensor of size 10x5]

In [6]:
v2 = module(A)
v2

Variable containing:
-1.6609  0.7595 -0.2727  1.3592  0.6030
 1.2457 -1.0276  2.0000 -0.1102  0.2261
-0.8305  0.7595 -0.7273  0.9919 -1.6583
 0.8305  1.2063  1.0909  0.9919  1.3568
 0.4152 -0.1340 -0.7273 -0.8449  0.6030
 0.4152 -1.0276 -0.7273 -1.2123 -1.2814
-0.8305  1.6530  0.1818  1.3592 -0.1508
 1.2457 -1.4743 -0.7273 -0.8449 -1.2814
-1.2457 -0.5808 -1.1818 -0.8449  0.9799
 0.4152 -0.1340  1.0909 -0.8449  0.6030
[torch.FloatTensor of size 10x5]

In [7]:
np.isclose(v1.data.numpy(), v2.data.numpy())

array([[ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]], dtype=bool)

In [8]:
module.eval()
v3 = module(A)
print(v3)
module.train()

Variable containing:
 0.4023  5.3607  2.2017  6.4520  4.9807
 6.0350  2.0465  6.3716  3.4050  4.2097
 2.0117  5.3607  1.3677  5.6903  0.3547
 5.2303  6.1893  4.7036  5.6903  6.5227
 4.4256  3.7036  1.3677  1.8815  4.9807
 4.4256  2.0465  1.3677  1.1198  1.1257
 2.0117  7.0178  3.0357  6.4520  3.4387
 6.0350  1.2180  1.3677  1.8815  1.1257
 1.2070  2.8751  0.5337  1.8815  5.7517
 4.4256  3.7036  4.7036  1.8815  4.9807
[torch.FloatTensor of size 10x5]



BatchNorm1d(5, eps=1e-05, momentum=0.1, affine=False)

In [9]:
v4 = batch_norm(A, False, batch_count=1)
print(v4)

Variable containing:
 0.4110  5.4662  2.2441  6.6119  5.1007
 6.1649  2.0868  6.4942  3.4894  4.3111
 2.0550  5.4662  1.3940  5.8313  0.3632
 5.3429  6.3110  4.7941  5.8313  6.6799
 4.5210  3.7765  1.3940  1.9282  5.1007
 4.5210  2.0868  1.3940  1.1475  1.1528
 2.0550  7.1559  3.0941  6.6119  3.5215
 6.1649  1.2419  1.3940  1.9282  1.1528
 1.2330  2.9316  0.5440  1.9282  5.8903
 4.5210  3.7765  4.7941  1.9282  5.1007
[torch.FloatTensor of size 10x5]



In [10]:
np.isclose(v3.data.numpy(), v4.data.numpy())

array([[False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False]], dtype=bool)