In [2]:
import torch
import torch.nn as nn

### Creating custom Batch norm class from scratch

In [18]:
class CustomBatchNorm(nn.Module):
    def __init__(self, n_features:int, momentum:float = 0.2, eps:float=1e-5):
        """Constructor"""
        super(CustomBatchNorm, self).__init__()
        self.n_features = n_features 
        self.momentum = momentum 
        self.eps = eps 
        #init learnable parameters 
        self.w = nn.Parameter(torch.ones(n_features)) 
        self.b = nn.Parameter(torch.zeros(n_features)) 
        #init running parameters 
        self.register_buffer("r_mean", torch.zeros(n_features))
        self.register_buffer("r_var", torch.ones(n_features))

    def forward(self, X):
        if self.training:
            #self.training parameter inherited from nn.Module: 
            # set by model.train()/model.eval() under the hood
            mean = X.mean(dim=0)
            var = X.var(dim=0)
            self.r_mean = (1-self.momentum) * self.r_mean + self.momentum * mean 
            self.r_var = (1-self.momentum) * self.r_var + self.momentum * var 
        else: #model.eval()
            mean = self.r_mean 
            var = self.r_var 

        X_norm = (X - mean) / torch.sqrt(var + self.eps)
        out = X_norm * self.w + self.b 
        return out

In [19]:
batch_size, num_features = 32, 10
x = torch.randn(batch_size, num_features)
bn_custom = CustomBatchNorm(num_features)
out = bn_custom(x)
print(out)

tensor([[ 0.6956, -0.1333, -0.9063,  0.9191,  0.6351, -0.0619,  0.4855,  0.2765,
         -0.4178, -0.0142],
        [ 1.5361,  0.6359,  0.2768,  0.5297,  0.7738, -0.4328,  2.1645,  2.6478,
         -0.3231, -1.6783],
        [-0.5130, -2.8987,  1.5252,  0.8198, -0.9121, -1.4080, -0.4937, -0.4988,
          0.8220, -0.0117],
        [-1.9619, -0.0703,  0.3361,  0.4745,  0.7986,  0.0859, -0.8789, -0.0128,
          0.7536, -0.0692],
        [ 0.0384,  0.4572, -1.9582,  0.6919,  0.8999, -0.0966, -0.3107,  0.3296,
          1.3193,  0.8358],
        [ 0.2025, -0.3043, -0.2313,  2.3063, -0.9926, -0.0366,  0.1006,  0.5048,
         -0.5140,  1.1757],
        [-0.5413,  0.5176,  0.8963, -0.2723, -0.2341, -0.0221,  0.1557,  0.1573,
         -0.3923,  2.2432],
        [ 0.1394,  0.6526,  0.4680,  0.1660,  1.1206,  0.2214, -1.5110, -1.4656,
         -0.0124,  0.5773],
        [-0.2750, -0.0571,  0.0694,  0.2175,  0.1449, -1.3595,  0.1814, -2.6076,
         -2.1126, -1.1307],
        [-0.7024,  