In [64]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


In [65]:
# Dimensions of the toy fully connected layer:
#   m           => batch size
#   hidden      => number of neurons in this layer
#   hidden_prev => number of neurons in the previous layer
m, hidden, hidden_prev = 4, 5, 6

# Seed NumPy's global RNG so the random weights and inputs below are identical
# every time this cell is run (e.g. on Restart Kernel -> Run All).
np.random.seed(0)

# Weights of shape (hidden, hidden_prev), drawn uniformly from [0, 1).
W = np.random.rand(hidden, hidden_prev)
# Previous-layer outputs of shape (m, hidden_prev): integers in [0, 100).
Z_prev = np.random.randint(0, 100, size=(m, hidden_prev))

# Identity affine parameters (scale 1, shift 0): one row per neuron, shaped
# (hidden, 1) so they broadcast across the batch axis.
gamma = np.ones((hidden, 1))
beta = np.zeros((hidden, 1))

In [66]:
# Pre-activations of the current layer: (hidden, hidden_prev) x (hidden_prev, m) -> (hidden, m).
Z = np.matmul(W, Z_prev.T)

In [67]:
# Sanity check: W is (hidden, hidden_prev) and Z_prev.T is (hidden_prev, m), so Z is (hidden, m).
Z.shape

(5, 4)

In [68]:
def batch_nom(Z, gamma, beta, eps=1e-8):
    """Batch-normalize the pre-activations of a fully connected layer.

    Each row of ``Z`` is standardized with the mean and (population) standard
    deviation taken along axis 1, then scaled by ``gamma`` and shifted by
    ``beta``.

    Parameters
    ----------
    Z : array_like
        2-D values; statistics are computed along axis 1. In this notebook
        the layout is (hidden, m), i.e. one row per neuron.
    gamma, beta : array_like
        Scale and shift, broadcastable against ``Z`` (e.g. shape (hidden, 1)).
    eps : float, optional
        Small constant added to the standard deviation to avoid division by
        zero for constant rows. Defaults to 1e-8.

    Returns
    -------
    numpy.ndarray
        ``gamma * (Z - mean) / (std + eps) + beta``, same shape as ``Z``.

    Notes
    -----
    ``eps`` is added to the standard deviation itself, whereas
    ``torch.nn.BatchNorm1d`` adds it to the variance inside the square root.
    The two agree closely whenever the variance is much larger than ``eps``.
    """
    # Accept nested lists / other array-likes as well as ndarrays.
    Z = np.asarray(Z)
    mean = Z.mean(axis=1, keepdims=True)
    std_dev = Z.std(axis=1, keepdims=True)
    Z_norm = (Z - mean) / (std_dev + eps)
    return gamma * Z_norm + beta

In [69]:
# Normalize the dense-layer pre-activations with the NumPy implementation.
Z_norm = batch_nom(Z=Z, gamma=gamma, beta=beta)

In [70]:
# Per-neuron mean and standard deviation of the raw pre-activations (taken across the batch).
print(np.mean(Z, axis=1, keepdims=True))
print(np.std(Z, axis=1, keepdims=True))

[[145.40224892]
 [ 94.07523818]
 [160.53466406]
 [104.3046715 ]
 [160.32462711]]
[[54.9585836 ]
 [44.39167015]
 [61.28986913]
 [41.07497073]
 [54.77469304]]


In [71]:
# After normalization each neuron should show a mean of ~0 and a standard deviation of ~1.
print(np.mean(Z_norm, axis=1, keepdims=True))
print(np.std(Z_norm, axis=1, keepdims=True))

[[-6.93889390e-18]
 [ 1.52655666e-16]
 [-1.11022302e-16]
 [ 2.04697370e-16]
 [-7.63278329e-17]]
[[1.]
 [1.]
 [1.]
 [1.]
 [1.]]


In [72]:
# Show the normalized pre-activations: one row per neuron, one column per sample.
print(Z_norm)

[[ 0.72112619  0.99574911 -1.57069825 -0.14617705]
 [ 1.32157857  0.13765206 -1.4944045   0.03517387]
 [ 0.72728205  1.12665275 -1.41848407 -0.43545073]
 [ 0.68807939  0.86385778 -1.66367405  0.11173688]
 [ 0.6955258   1.06144702 -1.52890832 -0.2280645 ]]


In [73]:
def batch_norm_pytorch(Z, hidden):
    """Reference batch normalization of a dense layer via ``nn.BatchNorm1d``.

    Parameters
    ----------
    Z : numpy.ndarray
        Pre-activations laid out as (hidden, m).
    hidden : int
        Number of features (neurons) handed to ``nn.BatchNorm1d``.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape (hidden, m). It is still attached to the
        autograd graph, so callers must ``.detach()`` before ``.numpy()``.
    """
    # nn.BatchNorm1d expects (batch, features), so feed it the transpose.
    batch_major = torch.tensor(Z.T, dtype=torch.float32)
    # A freshly constructed layer is in training mode (batch statistics are
    # used) with weight initialised to 1 and bias to 0.
    layer = nn.BatchNorm1d(hidden)
    normalized = layer(batch_major)
    # Transpose back to the notebook's (hidden, m) layout.
    return normalized.T
# Normalize the same pre-activations with PyTorch for comparison against the NumPy result.
Z_norm_pytorch = batch_norm_pytorch(Z=Z, hidden=hidden)

In [74]:
# Show the PyTorch result, already transposed back to the (hidden, m) layout.
print(Z_norm_pytorch)

tensor([[ 0.7211,  0.9957, -1.5707, -0.1462],
        [ 1.3216,  0.1377, -1.4944,  0.0352],
        [ 0.7273,  1.1267, -1.4185, -0.4355],
        [ 0.6881,  0.8639, -1.6637,  0.1117],
        [ 0.6955,  1.0614, -1.5289, -0.2281]], grad_fn=<PermuteBackward0>)


In [76]:
# Check that the NumPy and PyTorch results agree element-wise within np.allclose's default tolerances.
np.allclose(Z_norm, Z_norm_pytorch.detach().numpy())

True

## Batch normalization in a CNN

In [78]:
# Feature maps in this section are laid out as (height, width, channels, batch).
# NOTE: this rebinds `m` and `Z`, which the fully connected example earlier in
# the notebook also uses.
h, w, c, m = 5, 5, 4, 8

# Seed NumPy's global RNG so re-running this cell always yields the same feature maps.
np.random.seed(0)
Z = np.random.randint(0, 100, size=(h, w, c, m))

# Identity affine parameters (scale 1, shift 0): one value per channel, shaped
# to broadcast over height, width and batch.
gamma2D = np.ones((1, 1, c, 1))
beta2D = np.zeros((1, 1, c, 1))

In [80]:
def batch_norm2D(Z, gamma, beta, eps=1e-8):
    """Batch-normalize convolutional feature maps, one statistic per channel.

    The mean and (population) standard deviation are computed over axes
    (0, 1, 3), leaving one value per index of axis 2. The standardized values
    are then scaled by ``gamma`` and shifted by ``beta``.

    Parameters
    ----------
    Z : array_like
        4-D values; in this notebook the layout is
        (height, width, channels, batch).
    gamma, beta : array_like
        Scale and shift, broadcastable against ``Z``
        (e.g. shape (1, 1, channels, 1)).
    eps : float, optional
        Small constant added to the standard deviation to avoid division by
        zero for constant channels. Defaults to 1e-8.

    Returns
    -------
    numpy.ndarray
        ``gamma * (Z - mean) / (std + eps) + beta``, same shape as ``Z``.

    Notes
    -----
    ``eps`` is added to the standard deviation itself, whereas
    ``torch.nn.BatchNorm2d`` adds it to the variance inside the square root.
    The two agree closely whenever the variance is much larger than ``eps``.
    """
    Z = np.asarray(Z)
    # Reduce over everything except the channel axis (axis 2).
    reduce_axes = (0, 1, 3)
    mean = np.mean(Z, axis=reduce_axes, keepdims=True)
    std_dev = np.std(Z, axis=reduce_axes, keepdims=True)
    Z_norm = (Z - mean) / (std_dev + eps)
    return gamma * Z_norm + beta
# Normalize the feature maps with the NumPy implementation.
Z_norm2D = batch_norm2D(Z=Z, gamma=gamma2D, beta=beta2D)


In [81]:
def batch_morm2D_pytorch(Z, channels):
    """Reference batch normalization of feature maps via ``nn.BatchNorm2d``.

    Parameters
    ----------
    Z : numpy.ndarray
        Feature maps laid out as (height, width, channels, batch).
    channels : int
        Number of channels handed to ``nn.BatchNorm2d``.

    Returns
    -------
    numpy.ndarray
        float32 array in the same (height, width, channels, batch) layout.
    """
    # (h, w, c, m) -> (m, c, h, w), the layout nn.BatchNorm2d expects.
    nchw = Z.transpose(3, 2, 0, 1)
    inputs = torch.tensor(nchw, dtype=torch.float32)
    # A freshly constructed layer is in training mode (batch statistics are
    # used) with weight initialised to 1 and bias to 0.
    layer = nn.BatchNorm2d(channels)
    normalized = layer(inputs)
    # (m, c, h, w) -> (h, w, c, m), then leave the autograd graph before
    # converting to NumPy.
    restored = normalized.permute(2, 3, 1, 0)
    return restored.detach().numpy()
# Normalize the same feature maps with PyTorch for comparison against the NumPy result.
Z_norm2D_pytorch = batch_morm2D_pytorch(Z=Z, channels=c)
    

In [83]:
# Check that the NumPy and PyTorch feature-map results agree element-wise within np.allclose's default tolerances.
np.allclose(Z_norm2D, Z_norm2D_pytorch)

True