[Batch Normalization paper](https://arxiv.org/pdf/1502.03167v3.pdf) <br>
[Group Normalization paper](https://arxiv.org/pdf/1803.08494.pdf)

------------
### 1-D case
Considering a batch of size 10 where each sample has 3 features (input shape: batch_size x num_features):

In [1]:
import torch

# Random input of shape (10, 3): 10 samples, 3 values per sample.
tensor = torch.randn(10, 3)
# BatchNorm1d layer built for 3 features.
# NOTE(review): `batch_layer1` is not referenced again in the visible cells — confirm whether it is still needed.
batch_layer1 = torch.nn.BatchNorm1d(3)

In [2]:
def my_batchnorm1d(tensor, epsilon=1e-05):
    """
    Normalize every feature column using statistics taken over the batch axis.

    Expects tensor of shape (batch_size, input_dimensionality).

    The mean and the biased variance are computed along dimension 0, and
    `epsilon` is added to the variance before taking the square root.
    Returns a tensor with the same shape as the input.
    """
    per_feature_mean = tensor.mean(dim=0)
    per_feature_std = torch.sqrt(tensor.var(dim=0, unbiased=False) + epsilon)
    centered = tensor - per_feature_mean
    return centered / per_feature_std

In [3]:
# Normalize the (10, 3) input with the hand-written implementation (default epsilon).
output = my_batchnorm1d(tensor)

#### Comparing with pytorch's implementation

In [4]:
import torch

# Reference layer for 3 features, used as constructed (no call to .eval()).
bn = torch.nn.BatchNorm1d(3)

# Element-wise closeness check reduced with .all(); the cell displays the resulting boolean tensor.
torch.isclose(output, bn(tensor)).all()

tensor(True)

<br>

------------
### 2-D case
Considering a batch of size 10 with 3 channels and $\;8\,\times\,8\;$ dimensions (height x width):

In [5]:
import torch

# Four independent random batches, each of shape (10, 3, 8, 8):
# (batch_size, channels, height, width).
batch_1 = torch.randn(10, 3, 8, 8)
batch_2 = torch.randn(10, 3, 8, 8)
batch_3 = torch.randn(10, 3, 8, 8)
batches = [batch_1, batch_2, batch_3]  # Used in training mode
batch_4 = torch.randn(10, 3, 8, 8)  # Used in eval mode

In [6]:
class BatchNorm2d():
    """
    Minimal 2-D batch normalization without learnable scale/shift parameters.

    Expecting a tensor of shape (BS:[batch_size], C:[channels], H:[height], W:[width]).

    Set `eval_mode = False` to normalize with the statistics of the current
    batch and update the running estimates; leave it `True` to normalize with
    the running estimates instead.
    """

    def __init__(self, epsilon=1e-05, momentum=0.1):  # Using the same defaults as pytorch
        """
        epsilon:  value added to the variance before the square root.
        momentum: weight given to the current batch when updating the running statistics.
        """
        self.eval_mode = True
        self.epsilon = epsilon
        self.momentum = momentum
        self.batches_processed_while_training = 0
        # Scalar placeholders (mean 0, variance 1) until the first training batch
        # replaces them with per-channel tensors of shape (C,).
        self.running_mean = 0
        self.running_var = 1.0

    @staticmethod
    def _broadcastable(stat, like):
        """Return `stat` in a form that broadcasts against the (BS, C, H, W) tensor `like`."""
        if not torch.is_tensor(stat):
            # Still the scalar placeholder: no training batch has been seen yet.
            stat = torch.as_tensor(stat, dtype=like.dtype, device=like.device)
        if stat.dim() == 0:
            return stat  # A 0-d tensor broadcasts against any shape.
        # A (C,) vector has to become (C, 1, 1) so that it lines up with the channel axis.
        return stat.reshape(like.shape[1], 1, 1)

    def forward(self, tensor):
        """
        Normalize `tensor` per channel and return a tensor of the same shape.

        In training mode the running statistics and the batch counter are
        updated as a side effect.
        """
        if self.eval_mode:
            # The channel count is taken from the input itself, so evaluation also
            # works before any training batch has been processed.
            mean = self._broadcastable(self.running_mean, tensor)
            var = self._broadcastable(self.running_var, tensor)
            return (tensor - mean) / torch.sqrt(var + self.epsilon)

        # TRAINING
        self.C = tensor.shape[1]

        # Mean over the batch (0), the height (2) and the width (3) --> Shape: (C,)
        current_mean = tensor.mean((0, 2, 3))

        # Variance (biased) over the batch (0), the height (2) and the width (3) --> Shape: (C,)
        current_var = tensor.var((0, 2, 3), unbiased=False)

        # For running statistics the unbiased variance is used
        current_var_unbiased = tensor.var((0, 2, 3), unbiased=True)

        # The running statistics are detached so that they never hold on to the
        # autograd graph of the batches they were computed from.
        self.running_mean = (1 - self.momentum) * self.running_mean + (self.momentum * current_mean.detach())
        self.running_var = (1 - self.momentum) * self.running_var + (self.momentum * current_var_unbiased.detach())

        self.batches_processed_while_training += 1

        # The 1-d per-channel vectors need to be reshaped to (C, 1, 1) so that broadcasting works as expected.
        return (tensor - current_mean.reshape(self.C, 1, 1)) / torch.sqrt(current_var.reshape(self.C, 1, 1) + self.epsilon)

    def __call__(self, tensor):
        """Allow the instance to be called like a function; delegates to `forward`."""
        return self.forward(tensor)

#### Train mode - running statistics are being updated

In [7]:
# Same epsilon as the class default; momentum is set to 0.4 instead of the default 0.1.
my_bn = BatchNorm2d(epsilon=1e-05, momentum=0.4)

my_bn.eval_mode = False  # TRAINING
my_outputs = []
# Each call normalizes with the batch's own statistics and updates the running estimates.
# (`index` is not used inside the loop body.)
for index, batch in enumerate(batches):
    my_outputs.append(my_bn(batch))

In [8]:
# Switch to evaluation: batch_4 is normalized with the running statistics gathered above.
my_bn.eval_mode = True
my_eval_output = my_bn(batch_4)

#### Comparing with pytorch's implementation

In [9]:
# Reference layer: 3 channels, and the same momentum (0.4) as my_bn.
torch_bn = torch.nn.BatchNorm2d(3, momentum=0.4)

torch_bn.train()  # TRAINING
torch_outputs = []
# Feed the same three batches, in the same order, as were given to my_bn.
for index, batch in enumerate(batches):
    torch_outputs.append(torch_bn(batch))

In [10]:
# Put the reference layer in eval mode and run it on the same held-out batch.
torch_bn.eval()
torch_eval_output = torch_bn(batch_4)

#### Comparing Statistics after the training is over (the ones used in the final evaluation mode for each case):

In [11]:
# Print the running mean/variance of both implementations for a side-by-side comparison.
print(f'My implementation results:\nrunning mean: {my_bn.running_mean} | Running variance: {my_bn.running_var}\n')
print(f'Torch results:\nrunning mean: {torch_bn.running_mean} | Running variance: {torch_bn.running_var}')

My implementation results:
running mean: tensor([-0.0017,  0.0115,  0.0170]) | Running variance: tensor([1.0094, 0.9398, 0.9941])

Torch results:
running mean: tensor([-0.0017,  0.0115,  0.0170]) | Running variance: tensor([1.0094, 0.9398, 0.9941])


<br>

#### And finally comparing the resulting tensors
- first three are used in training mode (running statistics are computed but not used)
- last one is used in evaluation mode (running statistics are used)

In [12]:
# Each entry is a boolean tensor: True when every element of the two outputs is close
# (torch.isclose with its default tolerances).
print([torch.isclose(my_outputs[x], torch_outputs[x]).all() for x in range(3)])  # Train mode - output tensors comparison
print(torch.isclose(my_eval_output, torch_eval_output).all())  # Eval mode - output tensor comparison

[tensor(True), tensor(True), tensor(True)]
tensor(True)
