# Playground for Normalization Layer

## Prototype Normalization Layer
* Normalizes activation levels to mean 0 and variance 1
    * shifts activation values by mean
    * divides activation levels by sqrt(variance)
* Contains additional learnable parameters to shift mean from 0 if necessary
* Contains additional learnable parameters to scale variance away from 1 if necessary

In [1]:
import torch
import torch.nn as nn

torch.set_printoptions(threshold=10, edgeitems=3)
torch.manual_seed(42)

class LayerNorm(nn.Module):

    def __init__(self, embed_dim, verbose=False):
        super().__init__()
        
        self.eps = 1e-5         # prevents division by 0
        self.shift = nn.Parameter(torch.zeros(embed_dim))
        self.scale = nn.Parameter(torch.ones(embed_dim))
        
        if verbose:
            print(f"\n=== LayerNorm Initialization ===")
            print(f"    embed_dim =", embed_dim)
            print(f"    Generating self.shift = nn.Parameter(torch.zeros({embed_dim}))")
            print(f"    Generating self.scale = nn.Parameter(torch.ones({embed_dim}))")        
            print(f"=== End LayerNorm Initialization ===\n")

    def forward(self, x, verbose=False):
        in_mean = x.mean(dim=-1, keepdim=True)
        in_variance = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x-in_mean) / torch.sqrt(in_variance + self.eps)
        norm_mean = norm_x.mean(dim=-1, keepdim=True)
        norm_var = norm_x.var(dim=-1, keepdim=True, unbiased=False)

        if verbose:
            print(f"\n=== LayerNorm Forward Pass ===")
            print(f'    Input (first 5 elements): ', x[0 , :5])
            print(f'    Normalized (norm_x first 5 elements): ', norm_x[0 , :5])
            print(f'    Output = self.scale * norm_x + self.shift')
            print(f"=== End LayerNorm Forward Pass ===\n")

        return self.scale * norm_x + self.shift

In [2]:
def test_normalization(verbose = False):

    embbed_dim = 6
    print(f'Embbed_dim: ', embbed_dim)

    batch_example = torch.randn(2, 4)   # Batch size 2 and context = 4
    layer = nn.Sequential(nn.Linear(4, embbed_dim), nn.ReLU())
    out = layer(batch_example)
    
    print(f'Test input\n:', out)

    mean = out.mean(dim=-1, keepdim=True)
    var = out.var(dim=-1, keepdim=True)

    print("Mean:\n", mean)
    print("Variance:\n", var)

    norm = LayerNorm(embbed_dim, verbose=verbose)
    normalized_output = norm(out, verbose=verbose)


if '__file__' not in dir(): _test_run = test_normalization(True)

Embbed_dim:  6
Test input
: tensor([[0.2483, 0.0000, 0.0000, 0.4067, 0.4628, 0.0000],
        [0.0000, 0.0000, 0.5766, 1.7535, 0.0000, 0.0000]],
       grad_fn=<ReluBackward0>)
Mean:
 tensor([[0.1863],
        [0.3884]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[0.0466],
        [0.5005]], grad_fn=<VarBackward0>)

=== LayerNorm Initialization ===
    embed_dim = 6
    Generating self.shift = nn.Parameter(torch.zeros(6))
    Generating self.scale = nn.Parameter(torch.ones(6))
=== End LayerNorm Initialization ===


=== LayerNorm Forward Pass ===
    Input (first 5 elements):  tensor([0.2483, 0.0000, 0.0000, 0.4067, 0.4628], grad_fn=<SliceBackward0>)
    Normalized (norm_x first 5 elements):  tensor([ 0.3144, -0.9452, -0.9452,  1.1184,  1.4030], grad_fn=<SliceBackward0>)
    Output = self.scale * norm_x + self.shift
=== End LayerNorm Forward Pass ===

