# BatchNorm

# In [4]:
import torch.nn as nn
import torch

# prompt: Write a class for batch normalisation

class BatchNorm(nn.Module):
    """Batch normalisation over the channel dimension (dim 1).

    Each channel is normalised with statistics pooled over every other
    dimension. For ``(N, C)`` inputs that is the batch dim; for
    ``(N, C, H, W)`` inputs it is the batch and spatial dims. Any
    ``(N, C, *)`` shape is accepted.

    In training mode the current batch statistics are used, and they are
    folded into exponential running averages. In eval mode the running
    averages are used instead.

    Args:
        num_channels: Number of channels ``C``.
        eps: Added to the variance inside the square root for stability.
        momentum: Weight of the current batch when updating the running
            statistics: ``running = (1 - momentum) * running + momentum * batch``.
    """

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.num_channels = num_channels
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
        # Registered as buffers (not plain attributes) so they move with
        # .to(device) and are saved in state_dict, but are not trained.
        self.register_buffer("running_mean", torch.zeros(num_channels))
        self.register_buffer("running_var", torch.ones(num_channels))

    def forward(self, x):
        """Normalise ``x`` of shape ``(N, C, *)`` and apply ``gamma``/``beta``.

        Raises:
            ValueError: If ``x`` has fewer than 2 dims or its channel
                dimension does not equal ``num_channels``.
        """
        if x.dim() < 2 or x.shape[1] != self.num_channels:
            raise ValueError(
                f"expected input of shape (N, {self.num_channels}, *), "
                f"got {tuple(x.shape)}"
            )
        # Shape that broadcasts a per-channel (C,) vector against x.
        stat_shape = (1, self.num_channels) + (1,) * (x.dim() - 2)

        if self.training:
            reduce_dims = [0] + list(range(2, x.dim()))
            mean = x.mean(dim=reduce_dims)
            # The biased variance normalises the batch (same as nn.BatchNorm).
            var = x.var(dim=reduce_dims, unbiased=False)
            # Update the running stats without autograd, so they never hold
            # on to the computation graph of past batches.
            with torch.no_grad():
                n = x.numel() // self.num_channels
                # nn.BatchNorm tracks the *unbiased* variance.
                unbiased_var = var * (n / max(n - 1, 1))
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
        else:
            mean, var = self.running_mean, self.running_var

        x_hat = (x - mean.view(stat_shape)) / torch.sqrt(var.view(stat_shape) + self.eps)
        return self.gamma.view(stat_shape) * x_hat + self.beta.view(stat_shape)


# In [5]:
# Smoke-test BatchNorm on a random (batch, channels, height, width) tensor.
x = torch.randn(7, 12, 10, 10)
g = BatchNorm(num_channels=12)
y = g(x)

# LayerNorm

# In [6]:
import torch
class LayerNorm2D(nn.Module):
    """Layer normalisation over the trailing ``normalized_shape`` dimensions.

    Each sample is normalised with the mean and biased variance of its last
    ``len(normalized_shape)`` dims. The result is then scaled by ``gamma``
    and shifted by ``beta`` elementwise. This is the same maths as
    ``nn.LayerNorm``.

    For an ``(N, C, H, W)`` input, pass ``normalized_shape=[C, H, W]``.

    Args:
        normalized_shape: Shape of the trailing dims to normalise (int or sequence).
        epsilon: Added to the variance inside the square root for stability.
    """

    def __init__(self, normalized_shape, epsilon=1e-5):
        super().__init__()
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.epsilon = epsilon
        self.normalized_shape = tuple(normalized_shape)
        self.gamma = nn.Parameter(torch.ones(self.normalized_shape))
        self.beta = nn.Parameter(torch.zeros(self.normalized_shape))

    def forward(self, x):
        """Normalise ``x`` over its trailing ``normalized_shape`` dims.

        Raises:
            ValueError: If the trailing dims of ``x`` do not match ``normalized_shape``.
        """
        k = len(self.normalized_shape)
        if x.dim() < k or tuple(x.shape[-k:]) != self.normalized_shape:
            raise ValueError(
                f"expected input ending in {self.normalized_shape}, got {tuple(x.shape)}"
            )
        dims = list(range(-k, 0))
        # Biased variance, with eps *inside* the sqrt, as in nn.LayerNorm.
        variance = torch.var(x, dim=dims, unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=dims, keepdim=True)
        out = (x - mean) / torch.sqrt(variance + self.epsilon)
        # gamma/beta have shape normalized_shape, so they broadcast over the leading dims.
        return self.gamma * out + self.beta

# In [None]:
# Normalise each sample over its full (C, H, W) = (12, 10, 10) feature map.
g = LayerNorm2D(normalized_shape=[12, 10, 10])
x = torch.randn(7, 12, 10, 10)
y = g(x)


# GroupNorm

# In [8]:
class GroupNorm2D(nn.Module):
    """Group normalisation for 4-D ``(N, C, H, W)`` inputs.

    The channels are split into ``num_groups`` contiguous groups. Each
    (sample, group) slice is normalised with its own mean and biased
    variance, and then a learned per-channel scale and shift are applied.

    Args:
        num_channels: Number of channels ``C``; must be divisible by ``num_groups``.
        num_groups: Number of channel groups.
        epsilon: Added to the variance inside the square root for stability.
    """

    def __init__(self, num_channels, num_groups=4, epsilon=1e-5):
        super().__init__()
        self.num_channels = num_channels
        self.num_groups = num_groups
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        assert x.shape[1] == self.num_channels
        assert self.num_channels % self.num_groups == 0

        batch, _, height, width = x.shape
        per_group = self.num_channels // self.num_groups
        # (N, C, H, W) -> (N, G, C/G, H, W): one slice per (sample, group).
        grouped = x.reshape(batch, self.num_groups, per_group, height, width)

        reduce_dims = [2, 3, 4]
        mu = grouped.mean(dim=reduce_dims, keepdim=True)
        sigma2 = grouped.var(dim=reduce_dims, unbiased=False, keepdim=True)
        normed = (grouped - mu) / (sigma2 + self.epsilon).sqrt()
        normed = normed.view(batch, self.num_channels, height, width)

        # Per-channel affine parameters, broadcast over batch and spatial dims.
        scale = self.gamma[None, :, None, None]
        shift = self.beta[None, :, None, None]
        return scale * normed + shift

# In [9]:
# 12 channels split into 3 groups of 4 channels each.
g = GroupNorm2D(12, num_groups=3)
x = torch.randn(7, 12, 10, 10)
y = g(x)

# InstanceNorm

# In [13]:
class InstanceNorm2D(nn.Module):
    """Instance normalisation for 4-D ``(N, C, H, W)`` inputs.

    Every (sample, channel) plane is normalised with its own spatial mean
    and biased variance. The same per-instance statistics are used in both
    training and eval mode. An optional learned per-channel affine transform
    is applied afterwards.

    Args:
        num_channels: Number of channels ``C``.
        epsilon: Added to the variance inside the square root for stability.
        momentum: Stored on the module; ``forward`` does not use it.
        rescale: If truthy, create learnable ``gamma``/``beta`` and apply them.
    """

    def __init__(self, num_channels, epsilon=1e-5, momentum=0.9, rescale=True):
        super().__init__()
        self.num_channels = num_channels
        self.epsilon = epsilon
        self.momentum = momentum
        self.rescale = rescale

        if self.rescale:
            self.gamma = nn.Parameter(torch.ones(num_channels))
            self.beta = nn.Parameter(torch.zeros(num_channels))
        # NOTE: registered, so they appear in state_dict, but forward() never
        # reads or updates them.
        self.register_buffer('runningmean', torch.zeros(num_channels))
        self.register_buffer('runningvar', torch.ones(num_channels))

    def forward(self, x):
        """Normalise each (sample, channel) plane of ``x``.

        Raises:
            ValueError: If ``x`` is not 4-D or its channel dimension does not
                equal ``num_channels``.
        """
        if x.dim() != 4 or x.shape[1] != self.num_channels:
            raise ValueError(
                f"expected input of shape (N, {self.num_channels}, H, W), "
                f"got {tuple(x.shape)}"
            )
        # The statistics are always per-instance, so one code path serves
        # both training and eval mode.
        variance = torch.var(x, dim=[2, 3], unbiased=False, keepdim=True)
        mean = torch.mean(x, dim=[2, 3], keepdim=True)
        out = (x - mean) / torch.sqrt(variance + self.epsilon)

        if self.rescale:
            shape = [1, self.num_channels, 1, 1]
            out = self.gamma.view(shape) * out + self.beta.view(shape)
        return out

# In [14]:
# Compare the hand-written InstanceNorm2D against PyTorch's built-in one.
x = torch.randn(7, 12, 10, 10)

g = InstanceNorm2D(num_channels=12)
y_old = g(x)
print(y_old.shape)

g = torch.nn.InstanceNorm2d(num_features=12)
y_true = g(x)
print(y_true.shape)

# Elementwise agreement within an absolute tolerance of 1e-6.
print(((y_true - y_old).abs() < 1e-6).all())


# Output:
# torch.Size([7, 12, 10, 10])
# torch.Size([7, 12, 10, 10])
# tensor(True)
