### Batch-norm I

In this step, you need to implement the batch normalization function without using the standard function with the following simplifications:

    The Beta parameter is taken equal to 0.
    The Gamma parameter is assumed to be 1.
    The function should work correctly only at the training stage.
    The input is a two-dimensional tensor of shape (number of elements in the batch) x (length of each instance).


In [2]:
import numpy as np
import torch
import torch.nn as nn

def custom_batch_norm1d(input_tensor, eps):
    """Standardize a batch along dim 0 using its own statistics.

    Simplified training-mode batch normalization: the scale is fixed at 1
    and the shift at 0, so the result is just the standardized input.

    Args:
        input_tensor: Tensor whose first dimension indexes the batch
            elements; statistics are reduced over that dimension.
        eps: Constant added to the variance before taking the square root.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    batch_mean = input_tensor.mean(dim=0)
    # Normalization uses the biased (divide-by-N) variance estimate.
    batch_var = input_tensor.var(dim=0, unbiased=False)
    centered = input_tensor - batch_mean
    return centered / (batch_var + eps).sqrt()


# Fixed two-element batch used to compare the custom function against the
# reference layer.
input_tensor = torch.Tensor([
    [0.0, 0, 1, 0, 2],
    [0, 1, 1, 0, 10],
])
# affine=False: the reference layer applies no learnable scale or shift.
batch_norm = nn.BatchNorm1d(input_tensor.shape[1], affine=False)


all_correct = True
for eps_power in range(10):
    # Sweep eps over 1, 1e-1, ..., 1e-9.
    eps = np.power(10., -eps_power)
    batch_norm.eps = eps
    batch_norm_out = batch_norm(input_tensor)
    custom_batch_norm_out = custom_batch_norm1d(input_tensor, eps)

    values_match = torch.allclose(batch_norm_out, custom_batch_norm_out)
    shapes_match = batch_norm_out.shape == custom_batch_norm_out.shape
    all_correct = all_correct and values_match and shapes_match
print(all_correct)

True


### Batch-norm II

Let's generalize the function from the previous step a bit - let's add the ability to set Beta and Gamma parameters.

At this step, you need to implement the batch normalization function without using the standard function with the following simplifications:

    The function should work correctly only at the training stage.
    The input is a two-dimensional tensor of shape (number of elements in the batch) x (length of each instance).


In [1]:
import torch
import torch.nn as nn

# Shape of the random test batch: batch_size rows, input_size columns.
batch_size, input_size = 5, 7
input_tensor = torch.randn((batch_size, input_size), dtype=torch.float32)

# Constant added to the variance inside the square root.
eps = 0.001

def custom_batch_norm1d(input_tensor, weight, bias, eps):
    """Apply training-mode batch normalization with a scale and a shift.

    The input is standardized along dim 0 with its own mean and biased
    variance, then multiplied by ``weight`` and offset by ``bias``.

    Args:
        input_tensor: Tensor whose first dimension indexes the batch
            elements; statistics are reduced over that dimension.
        weight: Scale (gamma) multiplied into the standardized values.
        bias: Shift (beta) added after scaling.
        eps: Constant added to the variance before taking the square root.

    Returns:
        The normalized, scaled and shifted tensor.
    """
    batch_mean = input_tensor.mean(dim=0)
    # Normalization uses the biased (divide-by-N) variance estimate.
    batch_var = input_tensor.var(dim=0, unbiased=False)
    standardized = (input_tensor - batch_mean) / (batch_var + eps).sqrt()
    return standardized * weight + bias


# Reference layer whose shift and scale are overwritten with random values
# (bias is drawn before weight).
batch_norm = nn.BatchNorm1d(input_size, eps=eps)
batch_norm.bias.data = torch.randn(input_size, dtype=torch.float)
batch_norm.weight.data = torch.randn(input_size, dtype=torch.float)

batch_norm_out = batch_norm(input_tensor)
custom_batch_norm_out = custom_batch_norm1d(
    input_tensor,
    batch_norm.weight.data,
    batch_norm.bias.data,
    eps,
)

# The custom result must agree with the reference in both values and shape.
values_match = torch.allclose(batch_norm_out, custom_batch_norm_out)
shapes_match = batch_norm_out.shape == custom_batch_norm_out.shape
print(values_match and shapes_match)

True


### Batch-norm III

Let's get rid of one more simplification: we now implement the behaviour of the batch-normalization layer at the prediction (inference) stage as well.

At this stage, instead of batch statistics, we will use exponentially smoothed statistics from the layer's training history.

In this step, you need to implement a full-fledged batch normalization class, without using the standard implementation, that takes a two-dimensional tensor as input. Be careful: the batch variance used to normalize the current batch is the biased estimate, whereas the running (exponentially smoothed) variance is updated with the unbiased estimate.

In [1]:
import torch
import torch.nn as nn


# Number of columns and rows of each random test batch.
input_size, batch_size = 3, 5
# Constant added to the variance inside the square root.
eps = 0.1


class CustomBatchNorm1d:
    """Hand-written batch normalization for two-dimensional inputs.

    In training mode the input is normalized with the statistics of the
    current batch, and exponentially smoothed running statistics are
    updated. After ``eval()`` the running statistics are used instead.

    Attributes:
        weight: Scale applied after normalization.
        bias: Shift applied after scaling.
        eps: Constant added to the variance inside the square root.
        momentum: Weight given to the newest batch statistic when the
            running statistics are updated (same convention as
            ``nn.BatchNorm1d``).
        training: True until ``eval()`` is called.
        mean, var: Statistics of the most recent training batch
            (``var`` is the biased estimate); None before the first call.
        run_mean, run_var: Running statistics, starting at 0 and 1.
        normed_tensor: Output of the most recent call; None before it.
    """

    def __init__(self, weight, bias, eps, momentum):
        """Store the layer parameters and reset all statistics."""
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.momentum = momentum
        self.normed_tensor = None
        self.training = True
        self.mean = None
        self.var = None
        self.run_mean = 0
        self.run_var = 1

    def __call__(self, input_tensor):
        """Normalize ``input_tensor`` and return the result.

        Args:
            input_tensor: Two-dimensional tensor; statistics are reduced
                over dim 0.

        Returns:
            The normalized tensor, scaled by ``weight`` and shifted by
            ``bias``. It is also stored in ``normed_tensor``.
        """
        if self.training:
            self.mean = torch.mean(input_tensor, dim=0)
            # The current batch is normalized with the biased variance ...
            self.var = torch.var(input_tensor, dim=0, unbiased=False)
            # ... while the running variance tracks the unbiased estimate.
            unbiased_var = torch.var(input_tensor, dim=0, unbiased=True)
            # momentum weights the NEW observation and (1 - momentum) the
            # accumulated history, as nn.BatchNorm1d does.
            keep = 1 - self.momentum
            self.run_mean = keep * self.run_mean + self.momentum * self.mean
            self.run_var = keep * self.run_var + self.momentum * unbiased_var
            mean, var = self.mean, self.var
        else:
            mean, var = self.run_mean, self.run_var

        # Use the instance's eps rather than any module-level name.
        # as_tensor lets the initial scalar run_var (eval before any
        # training call) go through torch.sqrt as well.
        std = torch.sqrt(torch.as_tensor(var + self.eps))
        self.normed_tensor = ((input_tensor - mean) / std) * self.weight + self.bias
        return self.normed_tensor

    def eval(self):
        """Switch the layer to prediction mode.

        Returns:
            The output of the most recent call (None if the layer has not
            been called yet); kept so existing callers see no change.
        """
        self.training = False
        return self.normed_tensor


# Reference layer whose shift and scale are overwritten with random values
# (bias is drawn before weight, which fixes the order of the random draws).
batch_norm = nn.BatchNorm1d(input_size, eps=eps)
batch_norm.bias.data = torch.randn(input_size, dtype=torch.float)
batch_norm.weight.data = torch.randn(input_size, dtype=torch.float)
# NOTE: with momentum 0.5 the previous running value and the new batch
# statistic get equal weight, so this check cannot tell which of the two
# terms the momentum is applied to.
batch_norm.momentum = 0.5

# The custom layer shares the reference layer's weight, bias, eps and momentum.
custom_batch_norm1d = CustomBatchNorm1d(batch_norm.weight.data,
                                        batch_norm.bias.data, eps, batch_norm.momentum)

all_correct = True

# Training stage: both layers see the same 8 random batches, so their running
# statistics are built from identical data.
for i in range(8):
    torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
    norm_output = batch_norm(torch_input)
    custom_output = custom_batch_norm1d(torch_input)
    all_correct &= torch.allclose(norm_output, custom_output, atol=1e-04) \
        and norm_output.shape == custom_output.shape
print(all_correct)

# Switch both layers to prediction mode.
batch_norm.eval()
custom_batch_norm1d.eval()

# Prediction stage: outputs now depend on the running statistics accumulated
# by the loop above.
for i in range(8):
    torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
    norm_output = batch_norm(torch_input)
    custom_output = custom_batch_norm1d(torch_input)
    all_correct &= torch.allclose(norm_output, custom_output, atol=1e-04) \
        and norm_output.shape == custom_output.shape
# Show the outputs for the last prediction batch, then the overall verdict.
print(norm_output)
print(custom_output)
print(all_correct)

True
tensor([[-0.1108,  2.5511,  0.7395],
        [ 0.0340,  0.5591,  0.5695],
        [ 0.2618,  0.7414,  0.2233],
        [ 0.4852, -0.3341,  0.4275],
        [ 0.6944, -1.2451,  1.0495]], grad_fn=<NativeBatchNormBackward0>)
tensor([[-0.1108,  2.5511,  0.7395],
        [ 0.0340,  0.5591,  0.5695],
        [ 0.2618,  0.7414,  0.2233],
        [ 0.4852, -0.3341,  0.4275],
        [ 0.6944, -1.2451,  1.0495]])
True
