# Batch normalization in the evaluation stage

In [10]:
import torch
import torch.nn as nn


# Experiment configuration: number of features per sample, number of samples
# per random batch, and the variance epsilon shared by the reference and the
# custom layer (much larger than PyTorch's 1e-5 default).
input_size, batch_size, eps = 3, 5, 0.1


class CustomBatchNorm1d:
    """Re-implementation of ``torch.nn.BatchNorm1d`` with affine parameters and running stats.

    In training mode every call normalizes with the current batch's mean and
    biased variance, and blends those statistics into running estimates of
    the mean and the unbiased variance. After :meth:`eval`, calls normalize
    with the running estimates instead, as ``nn.BatchNorm1d`` does in eval mode.

    Args:
        weight: Per-feature scale (gamma) tensor; its length is the number of features.
        bias: Per-feature shift (beta) tensor of the same length as ``weight``.
        eps: Value added to the variance before the square root, for stability.
        momentum: Weight of the newest batch statistic in the running average.
            ``None`` selects a cumulative (simple) average, as in PyTorch.
    """

    def __init__(self, weight, bias, eps, momentum):
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.momentum = momentum
        # False -> training mode (batch statistics); True -> evaluation mode.
        self.evaluation = False
        num_features = weight.shape[0]
        # Same initial values as nn.BatchNorm1d's running buffers (mean 0, var 1).
        self.running_mean = torch.zeros(num_features, dtype=weight.dtype, device=weight.device)
        self.running_var = torch.ones(num_features, dtype=weight.dtype, device=weight.device)
        # Number of training batches seen; drives the cumulative average when momentum is None.
        self.num_batches_tracked = 0

    def __call__(self, input_tensor):
        """Normalize ``input_tensor`` of shape (N, C) or (N, C, L) per feature C.

        Raises:
            ValueError: if the input is not 2D/3D, or if a training-mode batch
                has only one value per feature (the unbiased variance is undefined).
        """
        if input_tensor.dim() not in (2, 3):
            raise ValueError(f'expected 2D or 3D input (got {input_tensor.dim()}D input)')
        # Statistics are taken over every dim except the feature dim 1;
        # `shape` makes a (C,) vector broadcast against (N, C) or (N, C, L).
        reduce_dims = (0,) + tuple(range(2, input_tensor.dim()))
        shape = (1, -1) + (1,) * (input_tensor.dim() - 2)

        if self.evaluation:
            print('Do the normalization on the floating mean value (floating mathematical expectation)')
            mean, var = self.running_mean, self.running_var
        else:
            print('Do the normalization on the full mean value (mathematical expectation)')
            n = input_tensor.shape[0] * input_tensor.shape[2:].numel()
            if n < 2:
                raise ValueError('Expected more than 1 value per channel when training, '
                                 f'got input size {input_tensor.shape}')
            mean = input_tensor.mean(dim=reduce_dims)
            # Biased (divide-by-N) variance is what training-mode normalization uses.
            var = ((input_tensor - mean.view(shape)) ** 2).mean(dim=reduce_dims)
            self._update_running_stats(mean, var, n)
            print('input tensor shape:', input_tensor.shape)

        normalized = (input_tensor - mean.view(shape)) / torch.sqrt(var.view(shape) + self.eps)
        return normalized * self.weight.view(shape) + self.bias.view(shape)

    def _update_running_stats(self, batch_mean, batch_var, n):
        """Blend one batch's mean and biased variance into the running estimates."""
        self.num_batches_tracked += 1
        factor = 1.0 / self.num_batches_tracked if self.momentum is None else self.momentum
        # The running variance stores the unbiased estimate (Bessel's correction),
        # even though the batch itself was normalized with the biased one.
        unbiased_var = batch_var * n / (n - 1)
        self.running_mean = (1 - factor) * self.running_mean + factor * batch_mean.detach()
        self.running_var = (1 - factor) * self.running_var + factor * unbiased_var.detach()

    def eval(self):
        """Switch to evaluation (prediction) mode: normalize with the running statistics."""
        self.evaluation = True
        return

    def train(self, mode=True):
        """Switch back to training mode (or to evaluation mode when ``mode`` is False)."""
        self.evaluation = not mode
        return


# Reference PyTorch layer; its affine parameters are re-drawn at random so the
# comparison does not rely on the default weight=1 / bias=0 initialization.
batch_norm = nn.BatchNorm1d(input_size, eps=eps)
with torch.no_grad():
    batch_norm.bias.copy_(torch.randn(input_size, dtype=torch.float))
    batch_norm.weight.copy_(torch.randn(input_size, dtype=torch.float))
batch_norm.momentum = 0.5

# The custom layer receives the very same parameter tensors and hyper-parameters.
custom_batch_norm1d = CustomBatchNorm1d(batch_norm.weight.data,
                                        batch_norm.bias.data, eps, batch_norm.momentum)


def _outputs_match(reference, custom, n_batches=8):
    """Feed ``n_batches`` random batches to both layers; return True if all outputs agree.

    Every batch is fed to both layers (no early exit). In training mode
    nn.BatchNorm1d updates its running statistics on every call, and the
    evaluation-mode check below depends on those statistics.
    """
    all_match = True
    for _ in range(n_batches):
        torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
        norm_output = reference(torch_input)
        custom_output = custom(torch_input)
        # Compare shapes first: torch.allclose raises on non-broadcastable shapes.
        all_match &= (norm_output.shape == custom_output.shape
                      and torch.allclose(norm_output, custom_output, atol=1e-04))
    return all_match


# Self-check: compare the custom layer against nn.BatchNorm1d, first in
# training mode and then in evaluation mode. (The original exercise runs an
# equivalent check automatically and asks for this section to be commented
# out in the submitted version.)
all_correct = _outputs_match(batch_norm, custom_batch_norm1d)
print('batch normalization is correct:', all_correct)

batch_norm.eval()
custom_batch_norm1d.eval()

all_correct &= _outputs_match(batch_norm, custom_batch_norm1d)
print('batch normalization in eval mode is correct:', all_correct)

Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
Do the normalization on the full mean value (mathematical expectation)
input tensor shape: torch.Size([5, 3])
batch normalization is correct: True
Do the normalization on the floating mean value (floating mathematical expectation)