In [11]:
import torch.nn as nn
import torch

# prompt: Write a class for batch normalisation

class BatchNorm(nn.Module):

    def __init__(self, num_channels, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        # define parameters gamma, beta which are learnable
        # dimension of gamma and beta should be (num_channels) ie its a one dimensional vector
        # initializing gamma as ones vector and beta as zeros vector (implies no scaling/shifting at the start)
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))
        self.running_mean = torch.zeros(num_channels)
        self.running_var = torch.ones(num_channels)

    def forward(self, x):
      # x = (b,c,h,w)
      if self.training:
          assert len(x.shape) in (2, 4)
          if len(x.shape) == 2:
            # When using a fully connected layer, calculate the mean and
            # variance on the feature dimension
            batch_mean = x.mean(axis=0)
            batch_var = ((x - batch_mean) ** 2).mean(axis=0)
          else:
            # When using a two-dimensional convolutional layer, calculate the
            # mean and variance on the channel dimension (axis=1). Here we
            # need to maintain the shape of X, so that the broadcasting
            # operation can be carried out later
            batch_mean = x.mean(axis=(0, 2, 3), keepdims=True)
            batch_var = ((x - batch_mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
          batch_size = x.shape[0]
          self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * batch_mean
          self.running_var = (1 - self.momentum) * self.running_var + self.momentum * batch_var
      else:
          batch_mean = self.running_mean
          batch_var = self.running_var
      x_hat = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
      return self.gamma * x_hat + self.beta # Scale and shift



In [12]:
class LayerNorm2D(nn.Module):
    def __init__(self, num_channels, epsilon = 1e-5):
        super(LayerNorm2D, self).__init__()
        self.num_channels = num_channels
        self.epsilon = epsilon

        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        assert list(x.shape)[1] == self.num_channels
        assert len(x.shape) == 4 # 4 because len((batchsize, numchannels, height, width)) = 4

        variance, mean = torch.var(x, dim = [1,2, 3], unbiased=False), torch.mean(x, dim = [1,2, 3])
        out = (x-mean.view([-1, 1, 1, 1]))/torch.sqrt(variance.view([-1, 1, 1, 1])+self.epsilon)

        out = self.gamma.view([1, self.num_channels, 1, 1]) * out + self.beta.view([1, self.num_channels, 1, 1])
        return out

In [13]:
class GroupNorm2D(nn.Module):
    def __init__(self, num_channels, num_groups=4, epsilon=1e-5):
        super(GroupNorm2D, self).__init__()
        self.num_channels = num_channels
        self.num_groups = num_groups
        # self.num_groups = num_channels // 4
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(num_channels))
        self.beta = nn.Parameter(torch.zeros(num_channels))

    def forward(self, x):
        assert x.shape[1] == self.num_channels
        assert len(x.shape) == 4 #4 because (batchsize, numchannels, height, width)

        [N, C, H, W] = list(x.shape)
        print(x.shape)
        out = torch.reshape(x, (N, self.num_groups, self.num_channels//self.num_groups, H, W))
        print(out.shape)
        variance, mean = torch.var(out, dim = [2, 3, 4], unbiased=False, keepdim=True), torch.mean(out, dim = [2, 3, 4], keepdim=True)
        print(variance.shape)
        out = (out-mean)/torch.sqrt(variance +self.epsilon)
        out = out.view(N, self.num_channels, H, W)
        out = self.gamma.view([1, self.num_channels, 1, 1]) * out + self.beta.view([1, self.num_channels, 1, 1])
        return out

In [14]:
g = GroupNorm2D(12)
x = torch.randn(1,12,10,10)
y = g(x)

torch.Size([1, 12, 10, 10])
torch.Size([1, 4, 3, 10, 10])
torch.Size([1, 4, 1, 1, 1])


In [15]:
class InstanceNorm2D(nn.Module):
    def __init__(self, num_channels, epsilon = 1e-5, momentum = 0.9, rescale = True):
        super(InstanceNorm2D, self).__init__()
        self.num_channels = num_channels
        self.epsilon = epsilon
        self.momentum = momentum
        self.rescale = rescale

        if(self.rescale == True):
            # define parameters gamma, beta which are learnable
            # dimension of gamma and beta should be (num_channels) ie its a one dimensional vector
            # initializing gamma as ones vector and beta as zeros vector (implies no scaling/shifting at the start)
            self.gamma = nn.Parameter(torch.ones(num_channels))
            self.beta = nn.Parameter(torch.zeros(num_channels))

        # running mean and variance should have the same dimension as in batchnorm
        # ie, a vector of size num_channels because while testing, when we get one
        # sample at a time, we should be able to use this.
        self.register_buffer('runningmean', torch.zeros(num_channels))
        self.register_buffer('runningvar', torch.ones(num_channels))
    def forward(self, x):
        assert x.shape[1] == self.num_channels
        assert len(x.shape) == 4 # 4 because len((batchsize, numchannels, height, width)) = 4

        if(self.training):
            #calculate mean and variance along the dimensions other than the channel dimension
            #variance calculation is using the biased formula during training
            variance, mean = torch.var(x, dim = [2, 3], unbiased=False), torch.mean(x, dim = [2, 3])
            out = (x-mean.view([-1, self.num_channels, 1, 1]))/torch.sqrt(variance.view([-1, self.num_channels, 1, 1])+self.epsilon)

        else:
            variance, mean = torch.var(x, dim = [2, 3], unbiased=False), torch.mean(x, dim = [2, 3])
            out = (x-mean.view([-1, self.num_channels, 1, 1]))/torch.sqrt(variance.view([-1, self.num_channels, 1, 1])+self.epsilon)

        if(self.rescale == True):
            out = self.gamma.view([1, self.num_channels, 1, 1]) * out + self.beta.view([1, self.num_channels, 1, 1])
        return out