In [0]:
import numpy as np
import torch
import torch.nn as nn

In [0]:
def custom_batch_norm1d(input_tensor, eps):
    """Normalize every feature (column) of a batch, with no affine step.

    Subtracts the per-column mean taken over dim 0 and divides by the square
    root of the per-column biased variance plus ``eps``.
    """
    centered = input_tensor - input_tensor.mean(dim=0)
    std = (input_tensor.var(dim=0, unbiased=False) + eps).sqrt()
    return centered / std

# Tiny two-sample batch. Columns 0, 2 and 3 are constant across the batch, so
# their variance is zero and only eps keeps the denominator non-zero.
input_tensor = torch.Tensor([
    [0., 0, 1, 0, 2],
    [0, 1, 1, 0, 10],
])
batch_norm = nn.BatchNorm1d(num_features=input_tensor.size(1), affine=False)

In [68]:
# Sweep eps over 10**0, 10**-1, ..., 10**-9 and check that the hand-written
# normalization agrees with nn.BatchNorm1d(affine=False) at every value.
all_correct = True
for eps_power in range(10):
    eps = np.power(10., -eps_power)
    batch_norm.eps = eps
    expected = batch_norm(input_tensor)
    actual = custom_batch_norm1d(input_tensor, eps)
    all_correct = all_correct & torch.allclose(expected, actual)
print(all_correct)

True


In [0]:
input_size = 7
batch_size = 5
input_tensor = torch.randn(batch_size, input_size, dtype=torch.float)

eps = 1e-3
batch_norm = nn.BatchNorm1d(input_size, eps=eps)
# Give the affine parameters random values so the comparison exercises the
# scale-and-shift step. Copy in place under no_grad instead of assigning
# ``.data``, whose modifications autograd cannot track. The draw order
# (bias first, then weight) is unchanged.
with torch.no_grad():
    batch_norm.bias.copy_(torch.randn(input_size, dtype=torch.float))
    batch_norm.weight.copy_(torch.randn(input_size, dtype=torch.float))

In [0]:
def custom_batch_norm1d(input_tensor, weight=None, bias=None, eps=1e-5):
    """Batch-normalize a ``(batch, features)`` tensor with an optional affine step.

    Per-feature statistics are taken over dim 0 using the biased (population)
    variance, as ``nn.BatchNorm1d`` does in training mode.

    Args:
        input_tensor: tensor whose dim 0 is the batch dimension. Assumes a
            2-D input; ``weight``/``bias`` broadcast against the last dim.
        weight: per-feature scale (gamma); ``None`` skips scaling.
        bias: per-feature shift (beta); ``None`` skips the shift.
        eps: value added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    mean = input_tensor.mean(dim=0)
    var = input_tensor.var(dim=0, unbiased=False)

    normed_tensor = (input_tensor - mean) / torch.sqrt(var + eps)
    # Apply scale then shift, in the same order as the original single expression.
    if weight is not None:
        normed_tensor = normed_tensor * weight
    if bias is not None:
        normed_tensor = normed_tensor + bias
    return normed_tensor

In [71]:
# Affine variant: hand the module's own gamma/beta to the hand-written function.
print(torch.allclose(
    batch_norm(input_tensor),
    custom_batch_norm1d(input_tensor,
                        batch_norm.weight.data,
                        batch_norm.bias.data,
                        eps),
))

True


In [0]:
# Feature count, batch size and eps for the running-statistics experiment.
input_size, batch_size, eps = 3, 5, 1e-1

In [0]:
class CustomBatchNorm1d:
    """Hand-written counterpart of ``nn.BatchNorm1d`` with running statistics.

    In training mode each call normalizes with the current batch's mean and
    biased variance. It also updates exponential moving averages (EMAs) of the
    mean and of the *unbiased* variance (biased variance * n / (n - 1)). In
    eval mode the stored EMAs are used instead of batch statistics.

    Args:
        weight: per-feature scale (gamma), expected to be a tensor.
        bias: per-feature shift (beta).
        eps: value added to the variance for numerical stability.
        momentum: EMA factor; new = momentum * batch + (1 - momentum) * old.
    """

    def __init__(self, weight, bias, eps, momentum):
        self.weight = weight
        self.bias = bias
        self.eps = eps
        self.momentum = momentum
        # Running statistics start at 0 (mean) and 1 (variance). They are
        # tensors shaped like ``weight``, so eval mode works even before the
        # first training batch (torch.sqrt rejects plain Python numbers).
        reference = torch.as_tensor(weight)
        self.EMA_mean = torch.zeros_like(reference)
        self.EMA_var = torch.ones_like(reference)
        self.flag_train = True

    def __call__(self, input_tensor):
        """Normalize ``input_tensor`` (batch on dim 0) and apply gamma/beta."""
        if self.flag_train:
            batch_size = input_tensor.shape[0]
            if batch_size < 2:
                # The unbiased correction below divides by (n - 1);
                # nn.BatchNorm1d likewise refuses such batches when training.
                raise ValueError(
                    "Expected more than 1 value per channel when training, "
                    f"got input of shape {tuple(input_tensor.shape)}")

            mean = input_tensor.mean(dim=0)
            var = input_tensor.var(dim=0, unbiased=False)

            self.EMA_mean = mean * self.momentum + (1 - self.momentum) * self.EMA_mean
            # Store the unbiased variance in the running average.
            self.EMA_var = (var * self.momentum * batch_size / (batch_size - 1)
                            + (1 - self.momentum) * self.EMA_var)
        else:
            mean = self.EMA_mean
            var = self.EMA_var

        # Use this layer's own eps, not a global variable of the same name.
        return (input_tensor - mean) / torch.sqrt(var + self.eps) * self.weight + self.bias

    def train(self, mode=True):
        """Switch to training mode (``mode=True``) or eval mode; returns self."""
        self.flag_train = bool(mode)
        return self

    def eval(self):
        """Switch to eval mode (use running statistics); returns self."""
        return self.train(False)
        

In [0]:
# Reference module with random affine parameters and momentum 0.5, plus a
# CustomBatchNorm1d that reads the very same parameter storage.
batch_norm = nn.BatchNorm1d(input_size, eps=eps, momentum=0.5)
with torch.no_grad():
    # In-place copy instead of assigning ``.data``; the draw order (bias,
    # then weight) is kept.
    batch_norm.bias.copy_(torch.randn(input_size, dtype=torch.float))
    batch_norm.weight.copy_(torch.randn(input_size, dtype=torch.float))

# detach() gives plain tensors sharing storage with the parameters, so the
# custom layer uses the same gamma/beta without tracking gradients.
custom_batch_norm1d = CustomBatchNorm1d(batch_norm.weight.detach(),
                                        batch_norm.bias.detach(), eps,
                                        batch_norm.momentum)

In [75]:
# Compare both implementations first in training mode (batch statistics,
# running averages updated), then in eval mode (running averages only).
# all_correct is not reset between phases, so the second printed value
# covers both phases together.
all_correct = True

for phase in ("train", "eval"):
    if phase == "eval":
        batch_norm.eval()
        custom_batch_norm1d.eval()
    for i in range(8):
        torch_input = torch.randn(batch_size, input_size, dtype=torch.float)
        all_correct &= torch.allclose(batch_norm(torch_input),
                                      custom_batch_norm1d(torch_input))
    print(all_correct)


True
True


In [0]:
eps = 1e-3

# Three samples, each with three square 10x10 feature maps.
input_channels, batch_size = 3, 3
height = width = 10

In [0]:
batch_norm_2d = nn.BatchNorm2d(num_features=input_channels, eps=eps, affine=False)
input_tensor = torch.randn((batch_size, input_channels, height, width),
                           dtype=torch.float)

In [0]:
def custom_batch_norm2d(input_tensor, eps):
    """Batch-normalize ``input_tensor`` per channel, with no affine step.

    Statistics for each channel (dim 1) are taken over every other dimension,
    i.e. the batch (dim 0) and all trailing spatial dims, using the biased
    variance. For a 4-D ``(N, C, H, W)`` input this reduces over (0, 2, 3), as
    before. ``(N, C)`` and ``(N, C, L)`` inputs are handled the same way.

    Args:
        input_tensor: tensor of rank >= 2 with channels on dim 1.
        eps: value added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    reduce_dims = (0, *range(2, input_tensor.dim()))
    # keepdim=True leaves size-1 axes, so the stats broadcast back per channel.
    mean = input_tensor.mean(dim=reduce_dims, keepdim=True)
    var = input_tensor.var(dim=reduce_dims, unbiased=False, keepdim=True)
    return (input_tensor - mean) / torch.sqrt(var + eps)

In [88]:
# Per-channel statistics over (batch, height, width) should match BatchNorm2d.
print(torch.allclose(
    batch_norm_2d(input_tensor),
    custom_batch_norm2d(input_tensor, eps),
))

True


In [0]:
def custom_layer_norm(input_tensor, eps):
    """Layer-normalize each sample over all of its non-batch dimensions.

    Every slice ``input_tensor[i]`` is shifted to zero mean and scaled by
    ``1 / sqrt(biased_var + eps)``, with no elementwise affine step. This is
    what ``nn.LayerNorm(input_tensor.shape[1:], elementwise_affine=False)``
    computes.

    Args:
        input_tensor: tensor with the batch on dim 0 and at least 2 dims.
        eps: value added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    dims = tuple(range(1, input_tensor.dim()))
    # keepdim=True replaces the manual reshape to (N, 1, ..., 1).
    mean = input_tensor.mean(dim=dims, keepdim=True)
    var = input_tensor.var(dim=dims, unbiased=False, keepdim=True)
    return (input_tensor - mean) / torch.sqrt(var + eps)

In [90]:
all_correct = True
for dim_count in range(3, 9):
    # Shapes (3, 4), (3, 4, 5), ..., (3, 4, ..., 9): 2-D up to 7-D tensors.
    input_tensor = torch.randn(*range(3, dim_count + 2), dtype=torch.float)
    layer_norm = nn.LayerNorm(input_tensor.shape[1:],
                              elementwise_affine=False, eps=eps)
    all_correct &= torch.allclose(layer_norm(input_tensor),
                                  custom_layer_norm(input_tensor, eps),
                                  rtol=1e-2)
print(all_correct)

True


In [0]:
eps = 1e-3

# (batch, channels, length) geometry for the InstanceNorm1d comparison.
batch_size, input_channels, input_length = 5, 2, 30

In [92]:
instance_norm = nn.InstanceNorm1d(num_features=input_channels, eps=eps, affine=False)
input_tensor = torch.randn((batch_size, input_channels, input_length),
                           dtype=torch.float)

ERROR! Session/line number was not unique in database. History logging moved to new session 59


In [0]:
def custom_instance_norm1d(input_tensor, eps):
    """Instance-normalize a ``(batch, channels, length)`` tensor, no affine step.

    Each ``(sample, channel)`` row is normalized over dim 2 on its own, using
    the biased variance.

    Args:
        input_tensor: 3-D tensor ``(N, C, L)``.
        eps: value added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as ``input_tensor``.
    """
    # keepdim=True keeps a trailing size-1 axis so the stats broadcast over L.
    mean = input_tensor.mean(dim=2, keepdim=True)
    var = input_tensor.var(dim=2, unbiased=False, keepdim=True)
    return (input_tensor - mean) / torch.sqrt(var + eps)

In [106]:
# Each (sample, channel) row normalized on its own should match InstanceNorm1d.
print(torch.allclose(
    instance_norm(input_tensor),
    custom_instance_norm1d(input_tensor, eps),
))

True


In [0]:
# GroupNorm settings: 6 channels divide evenly into 1, 2, 3 or 6 groups.
channel_count, eps = 6, 1e-3
batch_size, input_size = 20, 2

In [0]:
# Random (batch, channels, length) input shared by every GroupNorm check.
input_tensor = torch.randn((batch_size, channel_count, input_size))

In [0]:
def custom_group_norm(input_tensor, groups, eps):
    """Group-normalize ``input_tensor`` (channels on dim 1), no affine step.

    Channels are split into ``groups`` contiguous groups. Each (sample, group)
    slice, including all trailing dims, is normalized with its own mean and
    biased variance.

    Args:
        input_tensor: tensor of shape ``(N, C, *)``.
        groups: number of groups; must divide ``C``.
        eps: value added to the variance for numerical stability.

    Returns:
        Tensor with the same shape as ``input_tensor``.

    Raises:
        ValueError: if ``C`` is not divisible by ``groups``.
    """
    batch_size, num_channels = input_tensor.shape[:2]
    if num_channels % groups != 0:
        # Without this check the reshape below can still succeed (whenever the
        # per-sample element count is divisible), silently mixing channels
        # from different groups.
        raise ValueError(
            f"number of channels ({num_channels}) must be divisible by "
            f"groups ({groups})")

    grouped = input_tensor.reshape(batch_size, groups, -1)
    mean = grouped.mean(dim=2, keepdim=True)
    var = grouped.var(dim=2, unbiased=False, keepdim=True)
    return ((grouped - mean) / torch.sqrt(var + eps)).reshape(input_tensor.shape)

In [147]:
all_correct = True
for groups in (2, 3, 6):
    group_norm = nn.GroupNorm(num_groups=groups, num_channels=channel_count,
                              eps=eps, affine=False)
    reference = group_norm(input_tensor)
    candidate = custom_group_norm(input_tensor, groups, eps)
    all_correct &= torch.allclose(reference, candidate, rtol=1e-3)
print(all_correct)

True
