In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Applying a reduction (e.g. mean) along a dimension collapses that dimension.
# mean(dim=0): dim 0 indexes the rows, so all rows collapse into a single row;
# each output element is the mean of the corresponding elements across the rows
# (one mean per column, i.e. per feature, taken over the samples).

# mean(dim=1): dim 1 indexes the columns, so all columns collapse into a single
# column; each output element is the mean of all the numbers in one row
# (one mean per row, i.e. per sample, taken over its features).
# A 2x3 float matrix: 2 rows (samples) x 3 columns (features).
x = torch.tensor(
    [[1, 2, 3], [3, 4, 5]],
    dtype=torch.float32,
)
print(x.shape)
for reduced_dim in (0, 1):
  # dim=0 collapses the rows    -> one mean per column (3 values);
  # dim=1 collapses the columns -> one mean per row    (2 values).
  print(f'mean dim={reduced_dim}, {x.mean(dim=reduced_dim)}')


torch.Size([2, 3])
mean dim=0, tensor([2., 3., 4.])
mean dim=1, tensor([2., 4.])


In [None]:
# Batch normalization
# IDEA: we usually scale input features to a common range (e.g. [-1, 1] or
# [0, 1]). In a deep network, however, every layer transforms its inputs, so the
# scale of the activations can drift as they pass through the layers, which
# contributes to exploding or vanishing gradients. Normalizing the activations
# at the layer level keeps them on a stable scale and stabilizes training.
# The network is also given a learnable per-feature scale (gamma) and shift
# (beta), so it can restore a different scale/offset wherever that helps.

# One thing to remember: normalizing with batch statistics requires a batch of
# data, but at inference we often predict a single image, whose own mean and
# variance are not meaningful. To handle this, batch norm keeps running
# (exponential moving) averages of the mean and variance during training and
# uses those fixed values at inference time.


# Batch norm computes its statistics per channel over the batch AND spatial
# dimensions: for images, each channel's mean/variance is taken over all pixels
# of all images in the batch.

# This suits CNN architectures because a convolution filter shares its weights
# across all spatial positions, so all of a channel's activations are treated
# as one feature and normalized together.

# Batch norm: per-channel statistics taken across the batch (and spatial positions)
class BatchNormLayer(nn.Module):
  """Batch normalization for channels-last inputs.

  Each feature (the LAST dimension of ``x``) is normalized with the mean and
  variance computed over every other dimension, e.g. over (N, H, W) for an
  image batch laid out as (N, H, W, C), or over N for a (N, C) batch. A learned
  per-feature scale (``gamma``) and shift (``beta``) are then applied.

  In training mode the batch statistics are used and exponential running
  averages of them are updated; in eval mode (``.eval()``) the running averages
  are used instead, so the layer also works on a single sample at inference.

  Args:
    num_features: size of the last (feature/channel) dimension.
    eps: added to the variance before the square root for numerical stability.
    momentum: weight of the current batch when updating the running stats.
  """

  def __init__(self, num_features, eps=1e-5, momentum=0.1):
    super().__init__()
    self.eps = eps
    self.momentum = momentum
    self.gamma = torch.nn.Parameter(torch.ones(num_features))
    self.beta = torch.nn.Parameter(torch.zeros(num_features))
    # Buffers, not parameters: stored in state_dict and moved by .to(), but
    # never updated by the optimizer.
    self.register_buffer('running_mean', torch.zeros(num_features))
    self.register_buffer('running_var', torch.ones(num_features))

  def forward(self, x):
    """Normalize ``x`` of shape (..., num_features) and apply gamma/beta.

    Raises:
      ValueError: if ``x`` has no batch dimension or its last dimension does
        not equal ``num_features``.
    """
    num_features = self.gamma.shape[0]
    if x.dim() < 2 or x.shape[-1] != num_features:
      raise ValueError(
          f'expected input of shape (..., {num_features}) with a batch '
          f'dimension, got {tuple(x.shape)}')
    # Reduce over every dimension except the trailing feature dimension.
    reduce_dims = tuple(range(x.dim() - 1))

    if self.training:
      mean = x.mean(dim=reduce_dims)
      # Normalization uses the biased (population) variance.
      var = x.var(dim=reduce_dims, unbiased=False)
      with torch.no_grad():
        # The running variance tracks the unbiased estimate (same convention
        # as nn.BatchNorm); max(..., 1) guards a single value per feature.
        n = x.numel() // num_features
        unbiased_var = var * (n / max(n - 1, 1))
        self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
        self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
    else:
      mean = self.running_mean
      var = self.running_var

    # mean/var have shape (num_features,) and broadcast over the trailing dim;
    # eps goes inside the sqrt so zero-variance features stay finite.
    x_hat = (x - mean) / torch.sqrt(var + self.eps)
    return self.gamma * x_hat + self.beta

# Layer norm: per-sample statistics taken across the features (last dimension)
class LayerNorm(nn.Module):
  """Layer normalization over the last dimension.

  Every position of ``x`` (all leading dimensions) is normalized independently
  using the mean and biased variance of its last dimension, then scaled by
  ``gamma`` and shifted by ``beta``.

  Args:
    num_features: size of ``gamma``/``beta``. Use the size of the last
      dimension for a per-feature affine; the default 1 gives a single
      scalar scale/shift shared by all features (it broadcasts).
    eps: added to the variance before the square root for numerical stability.
  """

  def __init__(self, num_features=1, eps=1e-5):
    super().__init__()
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(num_features))
    self.beta = nn.Parameter(torch.zeros(num_features))

  def forward(self, x):
    """Return the normalized, affine-transformed ``x`` (same shape as input)."""
    x_mean = x.mean(dim=-1, keepdim=True)
    # Biased (population) variance, matching the standard layer-norm definition.
    x_var = x.var(dim=-1, keepdim=True, unbiased=False)
    # eps inside the sqrt keeps constant rows finite.
    x_hat = (x - x_mean) / torch.sqrt(x_var + self.eps)
    return self.gamma * x_hat + self.beta




In [None]:
# Sanity-check both layers on random data; seeded so the outputs are reproducible.
torch.manual_seed(0)

# Channels-last image batch: (N, H, W, C) = (32, 28, 28, 3).
normLayer = BatchNormLayer(3)
x = torch.rand((32,  28, 28, 3))
x_n = normLayer(x)
# With gamma=1, beta=0 (initial values), each channel of the normalized batch
# should have ~zero mean and ~unit std over (N, H, W).
print('BN per-channel mean:', x_n.mean(dim=(0, 1, 2)).detach())
print('BN per-channel std: ', x_n.std(dim=(0, 1, 2)).detach())
assert torch.allclose(x_n.mean(dim=(0, 1, 2)), torch.zeros(3), atol=1e-4)
assert torch.allclose(x_n.std(dim=(0, 1, 2)), torch.ones(3), atol=1e-2)

# Tabular batch: 32 samples x 28 features.
y = torch.rand((32, 28), dtype=torch.float32)
normLayer = LayerNorm(y.shape[1])
x_n2 = normLayer(y)
# Each sample (row) should now have ~zero mean across its features.
print('LN first row (first 5 features):', x_n2[0, :5].detach())
assert torch.allclose(x_n2.mean(dim=-1), torch.zeros(y.shape[0]), atol=1e-4)

xmean torch.Size([1, 1, 1, 3])
xvar torch.Size([1, 1, 1, 3])
torch.Size([32, 28, 28, 3])
tensor([[[-1.6546, -1.1129,  1.6230],
         [-0.7874, -1.6135, -0.3364],
         [-1.6502,  0.2450, -0.5971],
         ...,
         [-1.2984,  0.1397,  1.0136],
         [-0.1021, -0.1122,  1.5324],
         [ 0.4109,  0.4648,  1.2886]],

        [[ 1.5251, -0.8946, -1.6309],
         [ 0.5354,  0.3831, -0.3087],
         [-0.6289,  0.2811, -0.1496],
         ...,
         [ 0.8044, -1.0594, -0.3974],
         [-1.1633, -0.2822,  0.7319],
         [-0.1156,  1.1924,  0.5140]],

        [[ 1.0720, -0.1923,  0.1814],
         [-0.4184, -0.5187, -1.3650],
         [ 0.1533, -0.7254, -1.1338],
         ...,
         [-1.1585,  0.4252,  1.4898],
         [ 1.1304, -0.8287,  1.4644],
         [-1.2427, -1.3963, -0.4700]],

        ...,

        [[ 1.5125,  0.8966,  1.5434],
         [ 0.7702, -0.2155,  1.6845],
         [ 0.8664,  0.5861,  0.6780],
         ...,
         [-0.7737,  1.6845,  0.1731],