# Playground for Normalization Layer

## Prototype Normalization Layer
* Normalizes activations over the last (embedding) dimension to mean 0 and variance 1
    * subtracts the mean from the activation values
    * divides the result by sqrt(variance + eps), where the small constant eps prevents division by zero
* Contains a learnable shift parameter that can move the mean away from 0 if necessary
* Contains a learnable scale parameter that can scale the variance away from 1 if necessary

In [1]:
import torch
import torch.nn as nn

# Fix the global PyTorch RNG seed so the random example tensors and the
# randomly initialized layer weights are reproducible (given the same cell order).
torch.manual_seed(seed=42)

class LayerNorm(nn.Module):
    """Prototype layer normalization over the last dimension of the input.

    Each vector along the last dimension is shifted to mean 0 and divided by
    ``sqrt(variance + eps)`` (giving approximately unit variance). Then the
    element-wise learnable affine transform ``scale * x + shift`` is applied,
    so the layer can move the mean away from 0 and the variance away from 1
    if that helps.

    Args:
        embed_dim: size of the learnable ``scale`` and ``shift`` vectors. The
            last dimension of the input is expected to match it so that the
            parameters broadcast element-wise.
        verbose: if True, print the parameters created during initialization.
        eps: keyword-only. A small constant added to the variance before the
            square root to prevent division by zero. The default is ``1e-5``.
    """

    def __init__(self, embed_dim, verbose=False, *, eps=1e-5):
        super().__init__()

        self.eps = eps          # prevents division by 0
        self.shift = nn.Parameter(torch.zeros(embed_dim))  # learnable offset, starts at 0
        self.scale = nn.Parameter(torch.ones(embed_dim))   # learnable gain, starts at 1

        if verbose:
            print("\n=== LayerNorm Initialization ===")
            print("    embed_dim =", embed_dim)
            print(f"    Generating self.shift = nn.Parameter(torch.zeros({embed_dim}))")
            print(f"    Generating self.scale = nn.Parameter(torch.ones({embed_dim}))")
            print("=== End Initialization ===\n")

    def forward(self, x, verbose=False):
        """Normalize ``x`` over its last dimension and apply scale and shift.

        Args:
            x: input tensor. Its last dimension should equal ``embed_dim``.
            verbose: if True, print the input, its mean and variance, the
                normalized tensor, and the normalized tensor's mean and
                variance.

        Returns:
            ``self.scale * norm_x + self.shift``. It has the same shape as
            ``x`` when the last dimension of ``x`` equals ``embed_dim``.
        """
        in_mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance: divide by N rather than N - 1.
        in_variance = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - in_mean) / torch.sqrt(in_variance + self.eps)

        if verbose:
            # The statistics of the normalized tensor are only needed for
            # this diagnostic printout, so they are not computed otherwise.
            norm_mean = norm_x.mean(dim=-1, keepdim=True)
            norm_var = norm_x.var(dim=-1, keepdim=True, unbiased=False)

            print("\n=== LayerNorm Forward Pass ===")
            print('    Input: ', x)
            print('    Input_Mean: ', in_mean)
            print('    Input_Variance: ', in_variance)
            print('    Normalized (norm_x): ', norm_x)
            print('    Norm_Mean: ', norm_mean)
            print('    Norm_Variance: ', norm_var)
            print('    Output = self.scale * norm_x + self.shift')
            print("=== End Forward Pass ===\n")

        return self.scale * norm_x + self.shift

In [1]:
def test_normalization(verbose=False):
    """Run a small end-to-end demo of ``LayerNorm`` on a random batch.

    A random (2, 4) input (2 examples with 4 features each) is passed through
    ``Linear(4, 6) -> ReLU`` to produce a (2, 6) activation tensor. The
    function prints that tensor's per-row mean and variance, then normalizes
    it with ``LayerNorm``.

    Args:
        verbose: forwarded to both the ``LayerNorm`` constructor and its
            forward pass, which then print their intermediate values.

    Returns:
        The normalized (2, 6) tensor produced by ``LayerNorm``.
    """
    embed_dim = 6
    print('Embed_dim: ', embed_dim)

    batch_example = torch.randn(2, 4)   # 2 examples, 4 input features each
    layer = nn.Sequential(nn.Linear(4, embed_dim), nn.ReLU())
    out = layer(batch_example)

    print('Output:\n', out)

    # Use the biased variance (unbiased=False), the same definition LayerNorm
    # normalizes by, so these numbers are directly comparable to its output.
    mean = out.mean(dim=-1, keepdim=True)
    var = out.var(dim=-1, keepdim=True, unbiased=False)

    print("Mean:\n", mean)
    print("Variance:\n", var)

    norm = LayerNorm(embed_dim, verbose=verbose)
    normalized_output = norm(out, verbose=verbose)
    return normalized_output


# _test_run = test_normalization(True)


