In [1]:
import torch
import torch.nn as nn
import numpy as np

# Batch norm implementation

In [2]:
class CustomBatchNorm1d:
    """Hand-written counterpart of ``torch.nn.BatchNorm1d``.

    Statistics are taken over dim 0 of the input, so the input is expected to
    be laid out as ``(batch, features)`` with ``features == weight.shape[0]``.

    In training mode the layer normalises with the statistics of the current
    batch and updates exponential running estimates of mean and variance.
    After ``eval()`` it normalises with those running estimates and leaves
    them untouched.

    Args:
        weight: per-feature scale tensor of shape ``(features,)``.
        bias: per-feature shift tensor of shape ``(features,)``.
        eps: value added to the variance before taking the square root.
        momentum: weight of the *new* batch statistic in the running-average
            update, i.e. ``running = (1 - momentum) * running + momentum * batch``
            (the same convention ``nn.BatchNorm1d`` uses).
    """

    def __init__(self, weight, bias, eps, momentum):
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.momentum = momentum

        # NOTE: despite its name, ``eval_flag`` is True while the layer is in
        # *training* mode and False once ``eval()`` has been called. The
        # attribute name is kept for backward compatibility.
        self.eval_flag = True
        # Keep the running buffers on the same dtype/device as the parameters
        # so the arithmetic in __call__ never mixes devices.
        self.running_mean = torch.zeros(
            weight.shape[0], dtype=weight.dtype, device=weight.device
        )
        self.running_var = torch.ones(
            weight.shape[0], dtype=weight.dtype, device=weight.device
        )

    def __call__(self, input_tensor):
        """Normalise ``input_tensor`` and apply the affine transform.

        Args:
            input_tensor: tensor whose dim 0 is the batch dimension.

        Returns:
            Tensor of the same shape as ``input_tensor``.

        Raises:
            ValueError: in training mode when the batch holds fewer than two
                samples (the unbiased variance is undefined in that case).
        """
        if self.eval_flag:  # training mode
            n = input_tensor.shape[0]  # batch size
            if n < 2:
                # n / (n - 1) would divide by zero and poison running_var with
                # NaN; nn.BatchNorm1d rejects this input as well.
                raise ValueError(
                    "Expected more than 1 value per feature when training, "
                    "got input size {}".format(tuple(input_tensor.shape))
                )

            mean = torch.mean(input_tensor, dim=0)
            var = torch.var(input_tensor, dim=0, unbiased=False)

            # The running variance tracks the *unbiased* estimate, while the
            # normalisation below uses the biased one.
            unbiased_var = var * n / (n - 1)

            # detach(): running statistics are buffers, not part of the
            # autograd graph; without it the graph would grow on every call.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean.detach()
            self.running_var = (1 - m) * self.running_var + m * unbiased_var.detach()
        else:  # inference mode
            mean = self.running_mean
            var = self.running_var

        return (input_tensor - mean) / (var + self.eps) ** 0.5 * self.weight + self.bias

    def eval(self):
        """Switch to inference mode. Calling it repeatedly is harmless."""
        self.eval_flag = False

    def train(self, mode=True):
        """Switch to training mode (or to inference mode when ``mode`` is false)."""
        self.eval_flag = bool(mode)

In [3]:
# Uncomment the next line to make the random draws below reproducible.
# torch.manual_seed(0)

input_size = 3  # features per sample
batch_size = 5  # samples per batch
eps = 1e-1      # stabiliser added to the variance

# Reference layer; momentum is supplied at construction time.
batch_norm = nn.BatchNorm1d(input_size, eps=eps, momentum=0.5)

# Random affine parameters (bias is drawn first, then weight).
batch_norm.bias.data = torch.randn(input_size, dtype=torch.float)
batch_norm.weight.data = torch.randn(input_size, dtype=torch.float)

# The hand-written layer reuses the reference layer's parameter tensors and
# hyper-parameters so both can be compared on identical inputs.
custom_batch_norm1d = CustomBatchNorm1d(
    weight=batch_norm.weight.data,
    bias=batch_norm.bias.data,
    eps=eps,
    momentum=batch_norm.momentum,
)

In [4]:
all_correct = True

# Two passes over random batches: the first with both layers in training
# mode, the second after both have been switched to inference mode.
for phase in ("train", "eval"):
    if phase == "eval":
        batch_norm.eval()
        print(batch_norm._buffers)
        custom_batch_norm1d.eval()

    for i in range(8):
        torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
        norm_output = batch_norm(torch_input)
        custom_output = custom_batch_norm1d(torch_input)

        # Both the values and the shape have to agree with the reference layer.
        values_match = torch.allclose(norm_output, custom_output)
        shapes_match = norm_output.shape == custom_output.shape
        all_correct &= values_match and shapes_match

print(all_correct)

OrderedDict([('running_mean', tensor([-0.1593,  0.1291, -0.1376])), ('running_var', tensor([0.8012, 1.1854, 0.6336])), ('num_batches_tracked', tensor(8))])
True
