Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
Portions of this notebook consist of AI-generated content.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

# Lab 5b: Advanced Normalization Applications in Transformers

## Lab Overview

Welcome to an in-depth exploration of advanced normalization techniques used in modern transformer architectures! This lab builds upon basic normalization concepts to demonstrate sophisticated applications in large language models and neural networks.

**Lab Goal**: Master advanced normalization techniques including RMSNorm, Pre-Norm vs Post-Norm architectures, and positional encoding normalization.

## Learning Objectives

By the end of this lab, you will be able to:

1. **Implement RMSNorm**: Understand and code Root Mean Square Layer Normalization
2. **Compare Normalization Strategies**: Analyze Pre-Norm vs Post-Norm transformer architectures
3. **Apply Positional Encoding**: Implement and normalize sinusoidal positional embeddings
4. **Analyze Performance**: Compare different normalization techniques quantitatively
5. **Connect to LLMs**: Understand how these techniques are used in models like LLaMA and GPT


---

## 1. Environment Setup

In [None]:
# Core Libraries
import math
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Configure device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

Using device: cuda
PyTorch version: 2.9.1+rocm7.10.0
GPU: Radeon 8060S Graphics
GPU Memory: 103.1 GB


## 2. RMSNorm (Root Mean Square Layer Normalization)

RMSNorm is a simplified and more efficient alternative to LayerNorm, used in modern models like LLaMA. Instead of centering the data by subtracting the mean, RMSNorm only normalizes by the root mean square.

**Mathematical Foundation:**
- **Standard LayerNorm**: `y = (x - μ) / σ * γ + β`
- **RMSNorm**: `y = x / RMS(x) * γ` where `RMS(x) = √(mean(x²) + ε)`
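
To make the two formulas concrete, here is a minimal dependency-free sketch that evaluates both on one four-element vector with `γ = 1` and `β = 0`. The helper names `layernorm_1d` and `rmsnorm_1d` and the sample values are illustrative assumptions, not part of the lab code. The LayerNorm helper uses the biased variance with `ε` inside the square root.

```python
import math


def layernorm_1d(x, eps=1e-6):
    """LayerNorm with gamma = 1, beta = 0: subtract the mean, divide by the std."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)  # biased variance
    return [(v - mu) / math.sqrt(var + eps) for v in x]


def rmsnorm_1d(x, eps=1e-6):
    """RMSNorm with gamma = 1: divide by the root mean square, no centering."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def mean(values):
    return sum(values) / len(values)


x = [1.0, 2.0, 3.0, 6.0]  # illustrative values with mean 3.0
ln_out = layernorm_1d(x)
rms_out = rmsnorm_1d(x)

print("LayerNorm output mean:", round(mean(ln_out), 6))
print("RMSNorm output mean:  ", round(mean(rms_out), 6))
print("RMSNorm output RMS:   ", round(math.sqrt(mean([v * v for v in rms_out])), 6))
```

LayerNorm removes the mean entirely, while RMSNorm keeps a positive mean (the input was not centered) and only forces the output RMS to be approximately 1.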

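**Key Advantages:**
- **Computational Efficiency**: Fewer operations per token (no mean computation or subtraction)
- **Memory Efficiency**: Fewer intermediate tensors, and no bias parameter to store
- **Numerical Stability**: Often reported to train as stably as LayerNorm, although this depends on the model and training setup
- **Performance**: Can be faster on GPU when implemented as a fused kernel; an unfused implementation like the one in this lab may run slower than PyTorch's optimized `nn.LayerNorm`, as the timing later in this section shows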

**LLaMA Connection**: Meta's LLaMA models use RMSNorm in every transformer layer for better efficiency and training stability.
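
Two properties of the RMSNorm formula are easy to check by hand: rescaling the input does not change the output, and on zero-mean input RMSNorm coincides with LayerNorm (with `γ = 1`, `β = 0`). The sketch below is a dependency-free illustration with hypothetical helper names; it sets `ε = 0` so both properties hold exactly up to floating-point rounding, whereas with a small positive `ε` they hold only approximately.

```python
import math


def rmsnorm_1d(x, eps=0.0):
    """RMSNorm with unit weight on a plain Python list."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms for v in x]


def layernorm_1d(x, eps=0.0):
    """LayerNorm with gamma = 1, beta = 0 on a plain Python list."""
    mu = sum(x) / len(x)
    var = sum((v - mu) ** 2 for v in x) / len(x)
    return [(v - mu) / math.sqrt(var + eps) for v in x]


x = [0.5, -1.5, 2.0, 3.0]  # illustrative values with mean 1.0
scaled = [10.0 * v for v in x]

# Property 1: multiplying the input by a constant leaves RMSNorm unchanged.
scale_gap = max(abs(a - b) for a, b in zip(rmsnorm_1d(x), rmsnorm_1d(scaled)))

# Property 2: for zero-mean input, RMS(x) equals the standard deviation,
# so RMSNorm and LayerNorm produce the same values.
mu = sum(x) / len(x)
centered = [v - mu for v in x]
centered_gap = max(
    abs(a - b) for a, b in zip(rmsnorm_1d(centered), layernorm_1d(centered))
)

print(f"max difference after rescaling input: {scale_gap:.2e}")
print(f"max difference vs LayerNorm on centered input: {centered_gap:.2e}")
```

The second property gives some intuition for why dropping the centering step changes little when activations already have a mean close to zero.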

In [None]:
# Implement RMSNorm from Scratch
class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        """
        RMSNorm layer as used in LLaMA and other modern models.

        Args:
            dim: The dimension to normalize (usually the last dimension)
            eps: Small epsilon to prevent division by zero
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Compute the RMS normalization"""
        # Compute RMS: sqrt(mean(x^2) + eps)
        rms = x.pow(2).mean(-1, keepdim=True).add(self.eps).sqrt()
        return x / rms

    def forward(self, x):
        """Forward pass with learnable scaling"""
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


# Create test data
batch_size, seq_len, hidden_dim = 2, 4, 8
input_tensor = torch.randn(batch_size, seq_len, hidden_dim).to(device)

print("Testing RMSNorm Implementation")
print(f"Input shape: {input_tensor.shape}")
print(f"Input device: {input_tensor.device}")

# Initialize RMSNorm
rms_norm = RMSNorm(hidden_dim).to(device)
print(f"\n RMSNorm initialized on {device}")
print(f"Parameters: {sum(p.numel() for p in rms_norm.parameters())}")

# Apply RMSNorm
normalized_output = rms_norm(input_tensor)
print(f"\nOutput shape: {normalized_output.shape}")
print(f"Output device: {normalized_output.device}")

# Analyze normalization effect
print("\n Normalization Analysis:")
print(f"Input mean: {input_tensor.mean(-1).mean():.4f}")
print(f"Input std: {input_tensor.std(-1).mean():.4f}")
print(f"Output mean: {normalized_output.mean(-1).mean():.4f}")
print(f"Output std: {normalized_output.std(-1).mean():.4f}")

# Show RMS values
input_rms = input_tensor.pow(2).mean(-1).sqrt()
output_rms = normalized_output.pow(2).mean(-1).sqrt()
print("\nRMS Analysis:")
print(f"Input RMS: {input_rms.mean():.4f}")
print(f"Output RMS: {output_rms.mean():.4f} (should be ≈ weight scale)")
print(f"Weight values: {rms_norm.weight.data[:5]}")  # Show first 5 weights

Testing RMSNorm Implementation
Input shape: torch.Size([2, 4, 8])
Input device: cuda:0

 RMSNorm initialized on cuda
Parameters: 8

Output shape: torch.Size([2, 4, 8])
Output device: cuda:0

 Normalization Analysis:
Input mean: 0.0328
Input std: 1.0286
Output mean: 0.0356
Output std: 0.9911

RMS Analysis:
Input RMS: 1.0312
Output RMS: 1.0000 (should be ≈ weight scale)
Weight values: tensor([1., 1., 1., 1., 1.], device='cuda:0')


In [3]:
# Compare RMSNorm vs LayerNorm Performance and Behavior
print("Comparing RMSNorm vs LayerNorm")

# Create standard LayerNorm for comparison
layer_norm = nn.LayerNorm(hidden_dim).to(device)

# Test with the same input
ln_output = layer_norm(input_tensor)

print("\n Performance Comparison:")
print("=" * 50)

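# Time comparison (simplified - for demonstration)
import time


def time_module(module, x, iters=100, warmup=10):
    """Total wall-clock time for `iters` forward passes of `module`."""
    for _ in range(warmup):  # Warm up so one-time kernel setup is not timed
        module(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # GPU kernels launch asynchronously
    start = time.perf_counter()
    for _ in range(iters):
        module(x)
    if x.is_cuda:
        torch.cuda.synchronize()  # Wait for all queued kernels to finish
    return time.perf_counter() - start


with torch.no_grad():
    rms_time = time_module(rms_norm, input_tensor)
    ln_time = time_module(layer_norm, input_tensor)

print(f"RMSNorm time (100 iterations): {rms_time:.4f}s")
print(f"LayerNorm time (100 iterations): {ln_time:.4f}s")
print(f"Speedup: {ln_time / rms_time:.2f}x")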

# Statistical comparison
print("\n Statistical Comparison:")
print("=" * 50)
print("Original input:")
print(f"  Mean: {input_tensor.mean(-1)[:2]}")  # First 2 samples
print(f"  Std:  {input_tensor.std(-1)[:2]}")

print("\nRMSNorm output:")
print(f"  Mean: {normalized_output.mean(-1)[:2]}")
print(f"  Std:  {normalized_output.std(-1)[:2]}")

print("\nLayerNorm output:")
print(f"  Mean: {ln_output.mean(-1)[:2]}")
print(f"  Std:  {ln_output.std(-1)[:2]}")

# Gradient flow comparison
print("\n Gradient Flow Analysis:")
# Create a simple loss and backpropagate
target = torch.randn_like(input_tensor).to(device)

# RMSNorm gradients
rms_loss = F.mse_loss(normalized_output, target)
rms_loss.backward(retain_graph=True)
rms_grad_norm = rms_norm.weight.grad.norm().item()

# LayerNorm gradients
ln_loss = F.mse_loss(ln_output, target)
ln_loss.backward()
ln_weight_grad_norm = layer_norm.weight.grad.norm().item()
ln_bias_grad_norm = layer_norm.bias.grad.norm().item()

print(f"RMSNorm weight gradient norm: {rms_grad_norm:.4f}")
print(f"LayerNorm weight gradient norm: {ln_weight_grad_norm:.4f}")
print(f"LayerNorm bias gradient norm: {ln_bias_grad_norm:.4f}")

print("\n Key Insights:")
print(f"- LayerNorm/RMSNorm time ratio in this run: {ln_time / rms_time:.2f}x (above 1.0 means RMSNorm was faster)")
print("- RMSNorm doesn't center data (non-zero mean)")
print("- RMSNorm uses fewer parameters (no bias term)")
print("- Both provide similar gradient flow characteristics")

Comparing RMSNorm vs LayerNorm

 Performance Comparison:
RMSNorm time (100 iterations): 0.0023s
LayerNorm time (100 iterations): 0.0008s
Speedup: 0.34x

 Statistical Comparison:
Original input:
  Mean: tensor([[ 0.0007, -0.2742,  0.5283,  0.6174],
        [-0.1671, -0.3646, -0.2268,  0.1488]], device='cuda:0')
  Std:  tensor([[1.4991, 0.9872, 0.9361, 0.6845],
        [0.9583, 1.2751, 0.7560, 1.1328]], device='cuda:0')

RMSNorm output:
  Mean: tensor([[ 4.9339e-04, -2.8460e-01,  5.1661e-01,  6.9412e-01],
        [-1.8321e-01, -2.9233e-01, -3.0544e-01,  1.3909e-01]], device='cuda:0',
       grad_fn=<SliceBackward0>)
  Std:  tensor([[1.0690, 1.0248, 0.9153, 0.7696],
        [1.0509, 1.0223, 1.0180, 1.0587]], device='cuda:0',
       grad_fn=<SliceBackward0>)

LayerNorm output:
  Mean: tensor([[ 1.4901e-08,  7.4506e-09,  2.9802e-08, -2.9802e-08],
        [ 2.6077e-08, -1.4901e-08, -1.4901e-08, -4.4703e-08]], device='cuda:0',
       grad_fn=<SliceBackward0>)
  Std:  tensor([[1.0690, 1.0690, 

## 3. Pre-Norm vs Post-Norm Transformer Architectures

The placement of normalization layers in transformer architectures significantly impacts training dynamics and model performance. Let's explore both approaches and their implications.

**Architecture Comparison:**

**Post-Norm (Original Transformer):**
```
x → MultiHeadAttention → Add & Norm → FFN → Add & Norm → output
```

**Pre-Norm (Modern Approach):**
```
x → Norm → MultiHeadAttention → Add → Norm → FFN → Add → output
```

**Key Differences:**

**Post-Norm Characteristics:**
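- **Residual Path**: Normalization sits on the residual path, so the skip connection is renormalized after every sublayer
- **Gradient Flow**: Gradients must pass through every normalization layer, which can lead to vanishing gradients in deep networks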
- **Training Stability**: May require careful initialization and learning rates
- **Performance**: Often achieves better final performance when trained successfully

**Pre-Norm Characteristics:**
- **Training Stability**: More stable training, especially for deep networks
- **Gradient Flow**: The residual path is an unnormalized identity connection, so gradients reach early layers directly
- **Warmup Requirements**: Often requires less careful warmup strategies
- **Scalability**: Easier to scale to very deep architectures
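
The gradient-flow difference can be read off the update rules. The notation below is assumed for this sketch rather than taken from the lab code: `F_l` stands for the sublayer (attention or FFN) at depth `l`, and `Norm` for either LayerNorm or RMSNorm.

```latex
\begin{aligned}
\text{Post-Norm:}\quad & x_{l+1} = \mathrm{Norm}\bigl(x_l + F_l(x_l)\bigr) \\
\text{Pre-Norm:}\quad  & x_{l+1} = x_l + F_l\bigl(\mathrm{Norm}(x_l)\bigr)
  \;\;\Longrightarrow\;\;
  x_L = x_0 + \sum_{l=0}^{L-1} F_l\bigl(\mathrm{Norm}(x_l)\bigr)
\end{aligned}
```

Unrolling Pre-Norm leaves `x_0` as a plain additive term, so the Jacobian of `x_L` with respect to `x_0` contains an identity matrix regardless of depth. Post-Norm cannot be unrolled this way: each step wraps the sum in `Norm`, so the gradient reaching `x_0` is a product of `L` normalization Jacobians.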

**Real-World Usage:**
- **GPT Models**: GPT-2 and GPT-3 use Pre-Norm for better training stability (the original GPT used Post-Norm)
- **T5**: Uses Pre-Norm architecture
- **BERT**: Uses Post-Norm (original transformer style)
- **Modern LLMs**: Mostly adopt Pre-Norm for stability
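
One consequence of the two layouts can be seen without any training. The toy sketch below stacks 12 blocks in pure Python and tracks the RMS of the activations after each block. The `sublayer` function (an elementwise `tanh`) is a hypothetical stand-in for attention or the FFN, chosen only because it keeps the arithmetic simple; the starting vector is arbitrary.

```python
import math


def rms(x):
    return math.sqrt(sum(v * v for v in x) / len(x))


def norm(x, eps=1e-6):
    """RMSNorm with unit weight."""
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]


def sublayer(x):
    """Hypothetical stand-in for attention/FFN: an elementwise tanh."""
    return [math.tanh(v) for v in x]


def run_stack(x, depth, pre_norm):
    """Apply `depth` toy blocks and record the RMS of the stream after each."""
    history = []
    for _ in range(depth):
        if pre_norm:
            x = [a + b for a, b in zip(x, sublayer(norm(x)))]  # x + F(Norm(x))
        else:
            x = norm([a + b for a, b in zip(x, sublayer(x))])  # Norm(x + F(x))
        history.append(rms(x))
    return history


x0 = [0.3, -1.2, 0.8, 2.0, -0.5, 1.1, -0.9, 0.4]
pre_rms = run_stack(x0, depth=12, pre_norm=True)
post_rms = run_stack(x0, depth=12, pre_norm=False)

print("Pre-Norm RMS per block: ", [round(r, 2) for r in pre_rms])
print("Post-Norm RMS per block:", [round(r, 2) for r in post_rms])
```

In this toy setting the Pre-Norm stream grows with every block because the residual path is never renormalized, while Post-Norm pins the RMS near 1. This is one reason Pre-Norm models such as GPT-2 and LLaMA apply a final normalization layer after the last block.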

In [4]:
# Implement Pre-Norm and Post-Norm Transformer Blocks
class SimpleAttention(nn.Module):
    """Simplified attention mechanism for demonstration"""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.out = nn.Linear(dim, dim, bias=False)
        self.scale = dim**-0.5

    def forward(self, x):
        B, T, C = x.shape
        qkv = self.qkv(x).reshape(B, T, 3, C).permute(2, 0, 1, 3)
        q, k, v = qkv.unbind(0)

        # Simplified attention (no multi-head for clarity)
        att = (q @ k.transpose(-2, -1)) * self.scale
        att = F.softmax(att, dim=-1)
        out = att @ v
        return self.out(out)


class PostNormBlock(nn.Module):
    """Post-Norm Transformer Block (Original Transformer style)"""

    def __init__(self, dim):
        super().__init__()
        self.attention = SimpleAttention(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln1 = RMSNorm(dim)
        self.ln2 = RMSNorm(dim)

    def forward(self, x):
        # Post-Norm: x → Attention → Add & Norm → FFN → Add & Norm
        x = self.ln1(x + self.attention(x))
        x = self.ln2(x + self.ffn(x))
        return x


class PreNormBlock(nn.Module):
    """Pre-Norm Transformer Block (Modern style)"""

    def __init__(self, dim):
        super().__init__()
        self.attention = SimpleAttention(dim)
        self.ffn = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
        self.ln1 = RMSNorm(dim)
        self.ln2 = RMSNorm(dim)

    def forward(self, x):
        # Pre-Norm: x → Norm → Attention → Add → Norm → FFN → Add
        x = x + self.attention(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


# Create test models
dim = 128
seq_len = 32
batch_size = 4

test_input = torch.randn(batch_size, seq_len, dim).to(device)

post_norm_block = PostNormBlock(dim).to(device)
pre_norm_block = PreNormBlock(dim).to(device)

print("Transformer Block Comparison")
print(f"Input shape: {test_input.shape}")
print(f"Hidden dimension: {dim}")

# Test forward passes
print("\nTesting Forward Passes:")
post_norm_output = post_norm_block(test_input)
pre_norm_output = pre_norm_block(test_input)

print(f"Post-Norm output shape: {post_norm_output.shape}")
print(f"Pre-Norm output shape: {pre_norm_output.shape}")

# Analyze output statistics
print("\n Output Statistics:")
print(f"Post-Norm - Mean: {post_norm_output.mean():.4f}, Std: {post_norm_output.std():.4f}")
print(f"Pre-Norm - Mean: {pre_norm_output.mean():.4f}, Std: {pre_norm_output.std():.4f}")

# Parameter count comparison
post_norm_params = sum(p.numel() for p in post_norm_block.parameters())
pre_norm_params = sum(p.numel() for p in pre_norm_block.parameters())

print("\n Parameter Comparison:")
print(f"Post-Norm parameters: {post_norm_params:,}")
print(f"Pre-Norm parameters: {pre_norm_params:,}")
print(f"Parameter difference: {abs(post_norm_params - pre_norm_params):,}")

Transformer Block Comparison
Input shape: torch.Size([4, 32, 128])
Hidden dimension: 128

Testing Forward Passes:
Post-Norm output shape: torch.Size([4, 32, 128])
Pre-Norm output shape: torch.Size([4, 32, 128])

 Output Statistics:
Post-Norm - Mean: 0.0059, Std: 1.0000
Pre-Norm - Mean: 0.0116, Std: 1.0248

 Parameter Comparison:
Post-Norm parameters: 197,504
Pre-Norm parameters: 197,504
Parameter difference: 0


In [5]:
# Analyze Gradient Flow in Pre-Norm vs Post-Norm
print(" Gradient Flow Analysis")

# Create a simple loss for both models
target = torch.randn_like(test_input).to(device)

# Post-Norm gradient analysis
post_norm_block.zero_grad()
post_loss = F.mse_loss(post_norm_output, target)
post_loss.backward(retain_graph=True)

# Pre-Norm gradient analysis
pre_norm_block.zero_grad()
pre_loss = F.mse_loss(pre_norm_output, target)
pre_loss.backward()

print("\n Loss Comparison:")
print(f"Post-Norm loss: {post_loss.item():.4f}")
print(f"Pre-Norm loss: {pre_loss.item():.4f}")

# Analyze gradient norms for normalization layers
post_ln1_grad = post_norm_block.ln1.weight.grad.norm().item()
post_ln2_grad = post_norm_block.ln2.weight.grad.norm().item()
pre_ln1_grad = pre_norm_block.ln1.weight.grad.norm().item()
pre_ln2_grad = pre_norm_block.ln2.weight.grad.norm().item()

print("\n Gradient Norms for Normalization Layers:")
print("=" * 50)
print("Post-Norm:")
print(f"  LayerNorm 1 gradient norm: {post_ln1_grad:.4f}")
print(f"  LayerNorm 2 gradient norm: {post_ln2_grad:.4f}")
print("Pre-Norm:")
print(f"  LayerNorm 1 gradient norm: {pre_ln1_grad:.4f}")
print(f"  LayerNorm 2 gradient norm: {pre_ln2_grad:.4f}")


# Create a visualization of gradient magnitudes
def get_gradient_stats(model, name):
    grad_norms = []
    param_names = []
    for param_name, param in model.named_parameters():
        if param.grad is not None:
            grad_norms.append(param.grad.norm().item())
            param_names.append(f"{name}_{param_name}")
    return grad_norms, param_names


post_grads, post_names = get_gradient_stats(post_norm_block, "PostNorm")
pre_grads, pre_names = get_gradient_stats(pre_norm_block, "PreNorm")

print("\n Detailed Gradient Analysis:")
print("=" * 60)
print(f"{'Parameter':<30} {'Post-Norm':<12} {'Pre-Norm':<12}")
print("=" * 60)

# Compare common parameters
common_params = ["ln1.weight", "ln2.weight", "attention.qkv.weight", "attention.out.weight"]
for param in common_params:
    post_grad = next((g for g, n in zip(post_grads, post_names) if param in n), 0)
    pre_grad = next((g for g, n in zip(pre_grads, pre_names) if param in n), 0)
    print(f"{param:<30} {post_grad:<12.4f} {pre_grad:<12.4f}")

print("\n Training Insights:")
print("- Pre-Norm typically shows more stable gradient flow")
print("- Post-Norm may have larger gradient variations")
print("- Pre-Norm normalization layers often have smaller gradients")
print("- Both architectures can be trained successfully with proper setup")

 Gradient Flow Analysis

 Loss Comparison:
Post-Norm loss: 2.0103
Pre-Norm loss: 2.0617

 Gradient Norms for Normalization Layers:
Post-Norm:
  LayerNorm 1 gradient norm: 0.0144
  LayerNorm 2 gradient norm: 0.1798
Pre-Norm:
  LayerNorm 1 gradient norm: 0.0023
  LayerNorm 2 gradient norm: 0.0097

 Detailed Gradient Analysis:
Parameter                      Post-Norm    Pre-Norm    
ln1.weight                     0.0144       0.0023      
ln2.weight                     0.1798       0.0097      
attention.qkv.weight           0.0231       0.0361      
attention.out.weight           0.0207       0.0324      

 Training Insights:
- Pre-Norm typically shows more stable gradient flow
- Post-Norm may have larger gradient variations
- Pre-Norm normalization layers often have smaller gradients
- Both architectures can be trained successfully with proper setup


## 4. Advanced Positional Encoding with Normalization

Positional encodings are crucial for transformers to understand sequence order. We'll implement sinusoidal positional encodings and explore how normalization affects their integration with input embeddings.

**Sinusoidal Positional Encoding Formula:**
- `PE(pos, 2i) = sin(pos / 10000^(2i/d))`
- `PE(pos, 2i+1) = cos(pos / 10000^(2i/d))`

Where:
- `pos` is the position in the sequence
- `i` indexes the sine/cosine pair (`0 ≤ i < d/2`), so `2i` and `2i+1` are the even and odd dimensions
- `d` is the model dimension
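
The formula can be evaluated directly for a few positions. This is a minimal pure-Python sketch; the function name `positional_encoding` and the choice `d = 8` are illustrative assumptions, and the loop is written for readability rather than speed.

```python
import math


def positional_encoding(pos, d):
    """Return the d-dimensional sinusoidal encoding of position `pos` (d even)."""
    pe = []
    for i in range(d // 2):
        angle = pos / (10000 ** (2 * i / d))
        pe.append(math.sin(angle))  # dimension 2i
        pe.append(math.cos(angle))  # dimension 2i + 1
    return pe


d = 8
for pos in range(3):
    print(pos, [round(v, 4) for v in positional_encoding(pos, d)])

# Each (sin, cos) pair contributes sin^2 + cos^2 = 1 to the squared norm.
squared_norm = sum(v * v for v in positional_encoding(7, 64))
print("squared norm for d = 64:", round(squared_norm, 6))
```

Position 0 encodes as alternating 0 and 1, since `sin(0) = 0` and `cos(0) = 1`. Every position has squared norm `d/2`, so for `d = 64` the encoding norm is `√32 ≈ 5.6569`, which matches the "positional encoding contribution" reported later in this section.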

**Key Concepts:**
- **Frequency Modulation**: Different dimensions use different frequencies
- **Relative Position**: For a fixed offset `k`, `PE(pos + k)` is a linear function (a rotation) of `PE(pos)`, so dot products between encodings depend only on the offset
- **Extrapolation**: Can potentially handle longer sequences than seen during training
- **Normalization Impact**: How normalization affects the integration of positional information
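
The relative-position property can be verified numerically: because `sin(a)sin(b) + cos(a)cos(b) = cos(a - b)`, the dot product of two encodings depends only on how far apart the positions are. The sketch below uses hypothetical helper names and arbitrary example positions.

```python
import math


def positional_encoding(pos, d=64):
    """Sinusoidal encoding of a single position as a plain list."""
    pe = []
    for i in range(d // 2):
        angle = pos / (10000 ** (2 * i / d))
        pe.extend([math.sin(angle), math.cos(angle)])
    return pe


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


offset = 5
near_start = dot(positional_encoding(3), positional_encoding(3 + offset))
further_on = dot(positional_encoding(40), positional_encoding(40 + offset))

# Closed form: one cos(offset * frequency) term per sine/cosine pair.
closed_form = sum(math.cos(offset / (10000 ** (2 * i / 64))) for i in range(32))

print(f"positions 3 and 8:   {near_start:.6f}")
print(f"positions 40 and 45: {further_on:.6f}")
print(f"closed form:         {closed_form:.6f}")
```

All three values agree up to floating-point rounding, even though the absolute positions differ.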

**Advanced Applications:**
- **RoPE (Rotary Position Embedding)**: Used in modern models like LLaMA
- **ALiBi (Attention with Linear Biases)**: Alternative position encoding method
- **Learned Positional Embeddings**: Trainable position representations
- **Relative Position Encoding**: Direct relative position modeling
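
As a rough illustration of the RoPE idea, the sketch below rotates each consecutive pair of dimensions in a query and key vector by an angle proportional to the token position, then checks that the attention score depends only on the relative offset. This is a simplified pure-Python sketch under stated assumptions, not LLaMA's implementation: the helper `rotate_pairs`, the adjacent-pair layout, and the sample vectors are all illustrative, and real implementations differ in how they pair dimensions.

```python
import math


def rotate_pairs(vec, pos, base=10000.0):
    """Rotate each consecutive (even, odd) pair of `vec` by pos * theta_i."""
    d = len(vec)
    out = []
    for i in range(d // 2):
        theta = base ** (-2 * i / d)  # per-pair frequency, as in sinusoidal PE
        c, s = math.cos(pos * theta), math.sin(pos * theta)
        x, y = vec[2 * i], vec[2 * i + 1]
        out.extend([x * c - y * s, x * s + y * c])  # 2-D rotation
    return out


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.1, -0.4, 0.7, 0.2]   # illustrative query vector
k = [0.5, 0.3, -0.6, 0.9]   # illustrative key vector

score_a = dot(rotate_pairs(q, 2), rotate_pairs(k, 7))    # positions 2 and 7
score_b = dot(rotate_pairs(q, 20), rotate_pairs(k, 25))  # positions 20 and 25

print(f"score at positions (2, 7):   {score_a:.6f}")
print(f"score at positions (20, 25): {score_b:.6f}")
```

Both scores match because rotating `q` by `m·θ` and `k` by `n·θ` leaves a dot product that depends only on `n - m`. Rotations also preserve vector length, so unlike additive positional encodings this scheme does not change the norm of the query or key.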

In [6]:
# Implement Advanced Positional Encoding with Normalization
class SinusoidalPositionalEncoding(nn.Module):
    """Sinusoidal positional encoding as used in the original Transformer"""

    def __init__(self, dim: int, max_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)

        # Create frequency dividers
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim))

        # Apply sine and cosine
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Register as buffer (not a parameter)
        pe = pe.unsqueeze(0).transpose(0, 1)  # Shape: [max_len, 1, dim]
        self.register_buffer("pe", pe)

    def forward(self, x):
        """Add positional encoding to input embeddings"""
        # x shape: [batch_size, seq_len, dim]
        seq_len = x.size(1)
        x = x + self.pe[:seq_len, :, :].transpose(0, 1)
        return self.dropout(x)


# Test positional encoding
print(" Advanced Positional Encoding Implementation")

dim = 64
max_len = 100
batch_size = 2
seq_len = 20

# Create test embeddings
embeddings = torch.randn(batch_size, seq_len, dim).to(device)
print(f"Input embeddings shape: {embeddings.shape}")

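# Initialize positional encoding
# Note: the module stays in training mode, so dropout (p=0.1) is active in the
# statistics below; call pos_encoder.eval() first for deterministic values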
pos_encoder = SinusoidalPositionalEncoding(dim, max_len).to(device)

# Apply positional encoding
encoded_embeddings = pos_encoder(embeddings)
print(f"Encoded embeddings shape: {encoded_embeddings.shape}")

# Visualize positional encoding patterns
pe_matrix = pos_encoder.pe[:seq_len, 0, :].cpu()  # Get PE for visualization
print(f"Positional encoding matrix shape: {pe_matrix.shape}")

# Analyze positional encoding properties
print("\n Positional Encoding Analysis:")
print(f"PE range: [{pe_matrix.min():.3f}, {pe_matrix.max():.3f}]")
print(f"PE mean: {pe_matrix.mean():.6f}")
print(f"PE std: {pe_matrix.std():.6f}")

# Show frequency patterns for first few dimensions
print("\n Frequency Patterns (first 8 dimensions):")
for i in range(0, min(8, dim), 2):
    freq = 1 / (10000 ** (i / dim))
    print(f"Dim {i:2d}-{i + 1:2d}: frequency = {freq:.6f}")

# Compare embeddings before and after positional encoding
print("\n Embedding Comparison:")
print(f"Original embedding norm: {embeddings.norm(dim=-1).mean():.4f}")
print(f"Encoded embedding norm: {encoded_embeddings.norm(dim=-1).mean():.4f}")
print(f"Positional encoding contribution: {pos_encoder.pe[:seq_len, 0, :].norm(dim=-1).mean():.4f}")

 Advanced Positional Encoding Implementation
Input embeddings shape: torch.Size([2, 20, 64])
Encoded embeddings shape: torch.Size([2, 20, 64])
Positional encoding matrix shape: torch.Size([20, 64])

 Positional Encoding Analysis:
PE range: [-1.000, 1.000]
PE mean: 0.438697
PE std: 0.554784

 Frequency Patterns (first 8 dimensions):
Dim  0- 1: frequency = 1.000000
Dim  2- 3: frequency = 0.749894
Dim  4- 5: frequency = 0.562341
Dim  6- 7: frequency = 0.421697

 Embedding Comparison:
Original embedding norm: 8.1270
Encoded embedding norm: 10.5577
Positional encoding contribution: 5.6569


In [None]:
# Analyze Impact of Normalization on Positional Encoding
print(" Normalization Impact on Positional Encoding")

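# Create different normalization scenarios
# Note: "No Normalization" returns the raw embeddings (no PE is added), and each
# normalized scenario builds a fresh norm layer with default weights on every call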
scenarios = {
    "No Normalization": lambda x: x,
    "RMSNorm After PE": lambda x: RMSNorm(dim).to(device)(pos_encoder(x)),
    "RMSNorm Before PE": lambda x: pos_encoder(RMSNorm(dim).to(device)(x)),
    "LayerNorm After PE": lambda x: nn.LayerNorm(dim).to(device)(pos_encoder(x)),
}

results = {}

for name, transform in scenarios.items():
    # Apply transformation
    if "RMSNorm" in name or "LayerNorm" in name:
        # Create fresh embeddings for each test
        test_embeddings = torch.randn(batch_size, seq_len, dim).to(device)
        output = transform(test_embeddings)
    else:
        output = transform(embeddings)

    # Store results
    results[name] = {
        "output": output,
        "mean": output.mean(-1).mean().item(),
        "std": output.std(-1).mean().item(),
        "norm": output.norm(dim=-1).mean().item(),
    }

# Display comparison
print("\n Normalization Impact Comparison:")
print("=" * 70)
print(f"{'Scenario':<25} {'Mean':<10} {'Std':<10} {'L2 Norm':<10}")
print("=" * 70)

for name, stats in results.items():
    print(f"{name:<25} {stats['mean']:<10.4f} {stats['std']:<10.4f} {stats['norm']:<10.4f}")

# Analyze positional information preservation
print("\n Positional Information Analysis:")
print("=" * 50)

# Create position-specific patterns to test preservation
position_test = torch.zeros(1, seq_len, dim).to(device)
position_test[0, :, 0] = torch.arange(seq_len, dtype=torch.float)  # Position signal in first dim

for name, transform in scenarios.items():
    if name == "No Normalization":
        test_output = transform(position_test)
    else:
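        # Need to handle normalization properly
        if "Before PE" in name:
            # Per-token RMSNorm rescales every row to unit RMS, which flattens the
            # position ramp in dim 0 to a near-constant value before PE is added, so skip
            continue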
        else:
            test_output = transform(position_test)

    # Check if position signal is preserved
    position_signal = test_output[0, :, 0]
    correlation = torch.corrcoef(torch.stack([torch.arange(seq_len, dtype=torch.float), position_signal.cpu()]))[0, 1]

    print(f"{name:<25}: Position correlation = {correlation:.4f}")

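# Show actual values for first few positions
# Note: the normalized scenarios above ran on freshly sampled embeddings, so the
# two rows below come from different random inputs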
print("\n First 5 Positions Comparison:")
no_norm_vals = results["No Normalization"]["output"][0, :5, 0]
after_norm_vals = results["RMSNorm After PE"]["output"][0, :5, 0]

print(f"No Normalization: {no_norm_vals}")
print(f"RMSNorm After PE: {after_norm_vals}")
print(f"Relative ordering preserved: {torch.argsort(no_norm_vals).equal(torch.argsort(after_norm_vals))}")

 Normalization Impact on Positional Encoding

 Normalization Impact Comparison:
Scenario                  Mean       Std        L2 Norm   
No Normalization          0.0296     1.0144     8.1270    
RMSNorm After PE          0.3252     0.9481     8.0000    
RMSNorm Before PE         0.4410     1.2132     10.3244   
LayerNorm After PE        -0.0000    1.0079     8.0000    

 Positional Information Analysis:
No Normalization         : Position correlation = 1.0000
RMSNorm After PE         : Position correlation = 0.8034
LayerNorm After PE       : Position correlation = 0.4450

 Key Insights:
- Normalization AFTER PE preserves relative positional information
- Normalization BEFORE PE can destroy positional patterns
- RMSNorm tends to preserve more positional information than LayerNorm
- The interaction between embeddings and PE is crucial for model performance

 First 5 Positions Comparison:
No Normalization: tensor([-1.0904, -0.0875, -0.5452, -0.4202,  0.5460], device='cuda:0')
RMSNorm A

## Lab Summary

### Technical Concepts Learned
- **RMSNorm Implementation**: Building RMSNorm from scratch and comparing its efficiency with LayerNorm
- **Pre-Norm vs Post-Norm**: Understanding how normalization placement affects gradient flow and training stability
- **Sinusoidal Positional Encoding**: Implementing frequency-based position embeddings using sin/cos functions
- **Normalization Order Effects**: How normalization before vs after positional encoding impacts position information preservation
- **Gradient Flow Analysis**: Comparing gradient magnitudes through different transformer block architectures

### Experiment Further
- Implement RoPE (Rotary Position Embedding) and compare with sinusoidal PE
- Stack multiple Pre-Norm vs Post-Norm blocks and compare training stability at depth
- Measure actual GPU memory and compute time for RMSNorm vs LayerNorm at scale
- Test normalization impact on longer sequences (extrapolation beyond training length)
- Implement DeepNorm for training very deep transformer models