In [1]:
import torch
from torch import nn

In [2]:


class BatchNorm(nn.Module):

    """Batch normalization layer for convolutional neural network"""

    def __init__(self, num_channels: int = 3, eps: float = 1e-5, momentum:float = 0.1,
    is_affine: bool = True, track_running_stats: bool = True) -> None:
        super().__init__()

        self.num_channels = num_channels
        self.eps = eps
        self.momentum = momentum
        self.is_affine = is_affine
        self.track_running_stats = track_running_stats

        if is_affine: # Create parameters for γ and β for scale and shift
            self.scale = nn.Parameter(torch.ones(num_channels))
            self.shift = nn.Parameter(torch.zeros(num_channels))

        if track_running_stats: # Create buffers to store exponential moving averages of mean E[x(k)] and variance Var[x(k)]
            # Initialize as N(0, 1)
            self.register_buffer('exp_mean', nn.Parameter(torch.zeros(num_channels)))
            self.register_buffer('exp_var', nn.Parameter(torch.ones(num_channels)))

    def forward(self, x: torch.Tensor):

        """
        x: (B, C, D)
        D = l if 1D, h x w if 2D, h x w x t if 3D
        """
        
        x_shape_old = x.shape

        assert self.num_channels == x_shape_old[1]

        x = x.flatten(start_dim=2) # (B, C, n)
                                   # n is overall number of object's features
        #print(f'new_shape = {x.shape}')
        if self.training or not self.track_running_stats: # We will calculate the mini-batch mean and variance 
                                                          # if we are in training mode or 
                                                          # if we have not tracked exponential moving averages

            # calculate mean and variance along batch and feature dimensions (for each channel)
            batch_mean = x.mean(dim = [0, -1])
            batch_var = x.var(dim = [0, -1])

            if self.training and self.track_running_stats:
                # update exponential moving averages
                self.exp_mean = (1 - self.momentum) * self.exp_mean + self.momentum * batch_mean
                self.exp_var = (1 - self.momentum) * self.exp_var + self.momentum * batch_var
        else: # if not training and tracking stats
            batch_mean = self.exp_mean
            batch_var = self.exp_var
        
        # normalize x to N(0, 1)
        x_hat = (x - batch_mean.reshape(1, -1, 1)) / torch.sqrt(batch_var + self.eps).reshape(1, -1, 1)
        #x_hat = x_hat.reshape(1, -1 , 1)

        if self.is_affine:
            x_hat = self.scale.reshape(1, -1, 1) * x_hat #+ self.shift.reshape(1, -1, 1)
        
        return x_hat.reshape(x_shape_old)


In [13]:
def print_info(x):
    print(f"""
    shape: {x.shape}
    mean: {x.mean(dim = [0, 2, 3])}
    var: {x.var(dim=[0, 2, 3])}
    """
    )


x = torch.rand([2, 3, 2, 4])
print_info(x)

bn = BatchNorm(3)
x = bn(x)
print_info(x)




    shape: torch.Size([2, 3, 2, 4])
    mean: tensor([0.5102, 0.5535, 0.3655])
    var: tensor([0.0559, 0.0838, 0.0954])
    

    shape: torch.Size([2, 3, 2, 4])
    mean: tensor([ 2.5705e-07, -5.9605e-08,  2.9802e-08], grad_fn=<MeanBackward1>)
    var: tensor([0.9998, 0.9999, 0.9999], grad_fn=<VarBackward0>)
    
