# 📗 Phase 2 – Representation: 5️⃣ Normalization

## Instructor: Deep Learning with PyTorch

---

## Learning objectives

After completing this notebook, you will:
- ✅ Understand **why** normalization works (Internal Covariate Shift, gradient flow)
- ✅ Master **Batch Normalization**: train vs inference, running statistics
- ✅ Understand **Layer Normalization** and why it suits Transformers
- ✅ Explore **RMSNorm**: efficiency and trade-offs
- ✅ Practice comparing BN/LN/RMSNorm through experiments

---

In [None]:
# Import the required libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Plot styling
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {device}")
print(f"🔥 PyTorch version: {torch.__version__}")

---

## 5.1 Why does normalization work?

### 🎯 Internal Covariate Shift (the historical view)

**Definition**: The change in the distribution of each layer's inputs during training, as the parameters of the preceding layers are updated.

**Problems**:
- Later layers have to keep "re-learning" as their input distribution shifts
- Training slows down
- A small learning rate is needed to keep training stable

### 🎯 Smoothing the Optimization Landscape

**More recent research** (Santurkar et al., 2018) argues that the main benefit comes from a smoother optimization problem rather than from reduced covariate shift:
- BatchNorm smooths the loss landscape
- Gradients become more predictable
- Larger learning rates become usable

### 🎯 Stabilizing Gradient Flow

- Helps prevent vanishing/exploding gradients
- Keeps gradient magnitudes in a reasonable range
- Makes deeper networks trainable, and training faster
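
The effect on gradient flow can be checked directly. The sketch below is illustrative only: the helper `first_layer_grad_norm`, the depth of 10 and the width of 64 are assumptions chosen for the demo, not values taken from this notebook. It builds the same deep ReLU MLP with and without `nn.BatchNorm1d` and compares the gradient norm that reaches the first layer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def first_layer_grad_norm(use_bn, depth=10, width=64, seed=0):
    """Gradient norm at the first Linear layer of a deep ReLU MLP (hypothetical helper)."""
    torch.manual_seed(seed)
    layers = []
    for _ in range(depth):
        layers.append(nn.Linear(width, width))
        if use_bn:
            layers.append(nn.BatchNorm1d(width))
        layers.append(nn.ReLU())
    layers.append(nn.Linear(width, 1))
    model = nn.Sequential(*layers)  # train mode by default, so BN uses batch statistics

    x = torch.randn(128, width)
    target = torch.randn(128, 1)
    loss = F.mse_loss(model(x), target)
    loss.backward()
    return model[0].weight.grad.norm().item()

grad_plain = first_layer_grad_norm(use_bn=False)
grad_bn = first_layer_grad_norm(use_bn=True)
print(f"No norm   : {grad_plain:.2e}")
print(f"BatchNorm : {grad_bn:.2e}")
```

With PyTorch's default initialization the un-normalized network shrinks the signal at every layer, so the first-layer gradient is typically much smaller than with BatchNorm. The exact numbers depend on depth, width and initialization.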

In [None]:
# Visualization: plot the activation distribution across layers
def visualize_activation_distribution(model, x, title="Activation Distribution"):
    """
    Visualize how activation distributions change across layers
    """
    activations = []
    
    # Forward pass, storing the activations
    with torch.no_grad():
        a = x
        for layer in model:
            a = layer(a)
            activations.append(a.cpu().numpy().flatten())
    
    # Plot histograms
    fig, axes = plt.subplots(1, len(activations), figsize=(15, 3))
    fig.suptitle(title, fontsize=16, fontweight='bold')
    
    for idx, (ax, act) in enumerate(zip(axes, activations)):
        ax.hist(act, bins=50, alpha=0.7, color=f'C{idx}')
        ax.set_title(f'Layer {idx+1}\nμ={act.mean():.3f}\nσ={act.std():.3f}')
        ax.set_xlabel('Activation value')
        ax.set_ylabel('Frequency')
    
    plt.tight_layout()
    plt.show()

# Demo: compare a network with and without normalization
print("🔬 Demo: Internal Covariate Shift\n")

# Create sample data
x_sample = torch.randn(1000, 100)

# Network WITHOUT normalization
model_no_norm = nn.Sequential(
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.ReLU()
)

# Network WITH normalization
model_with_norm = nn.Sequential(
    nn.Linear(100, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.BatchNorm1d(100),
    nn.ReLU(),
    nn.Linear(100, 100),
    nn.BatchNorm1d(100),
    nn.ReLU()
)

# Visualize the pre-activation outputs (just before each ReLU), so that both
# networks are compared at the same point and ReLU stays in the forward chain
visualize_activation_distribution(
    [model_no_norm[0], model_no_norm[1:3], model_no_norm[3:5]], 
    x_sample,
    "❌ WITHOUT Normalization - distribution drifts"
)

visualize_activation_distribution(
    [model_with_norm[0:2], model_with_norm[2:5], model_with_norm[5:8]], 
    x_sample,
    "✅ WITH Normalization - distribution stays stable"
)

### 📊 Observations:

- **Without norm**: the pre-activation statistics drift from layer to layer, and the variance changes with depth
- **With norm**: the pre-activations stay stable, with mean≈0 and std≈1 at every layer

---

## 5.2 Batch Normalization (BN)

### 📐 Mathematical formulation

**Training mode**:
$$
\begin{align}
\mu_B &= \frac{1}{m}\sum_{i=1}^{m} x_i \quad \text{(batch mean)} \\
\sigma_B^2 &= \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2 \quad \text{(batch variance)} \\
\hat{x}_i &= \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}} \quad \text{(normalize)} \\
y_i &= \gamma \hat{x}_i + \beta \quad \text{(scale and shift - learnable)}
\end{align}
$$

**Inference mode** (followed by the same scale and shift, $y = \gamma \hat{x} + \beta$):
$$
\hat{x} = \frac{x - \mu_{running}}{\sqrt{\sigma_{running}^2 + \epsilon}}
$$

### 🔄 Train vs Inference

| Mode | Mean & Variance | Update |
|------|----------------|----------|
| **Train** | Computed from the current batch | Running statistics are updated |
| **Inference** | Running statistics are used | No update |
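
The running statistics in the table are exponential moving averages of the batch statistics. The following is a minimal check against `nn.BatchNorm1d`, assuming its default `momentum=0.1` and freshly initialized buffers (running mean 0, running variance 1). The variable names are chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(32, 4) * 2 + 5   # batch with mean≈5, std≈2
bn = nn.BatchNorm1d(4)           # momentum=0.1 by default

bn.train()
_ = bn(x)                        # one training-mode forward pass updates the buffers

momentum = bn.momentum
# running = (1 - momentum) * running + momentum * batch_statistic
expected_mean = (1 - momentum) * torch.zeros(4) + momentum * x.mean(dim=0)
# PyTorch stores the *unbiased* batch variance in running_var
expected_var = (1 - momentum) * torch.ones(4) + momentum * x.var(dim=0, unbiased=True)

mean_matches = torch.allclose(bn.running_mean, expected_mean, atol=1e-5)
var_matches = torch.allclose(bn.running_var, expected_var, atol=1e-5)
print(mean_matches, var_matches)
```

Note the asymmetry: the batch is normalized with the biased variance, while the value folded into `running_var` is the unbiased one.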

In [None]:
# Demo: BatchNorm from scratch, to make the mechanism explicit
class BatchNorm1dFromScratch(nn.Module):
    """
    Batch Normalization implemented from scratch to make the mechanism explicit
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        
        # Learnable parameters
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))
        
        # Running statistics (not learned, only tracked)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
    
    def forward(self, x):
        # x shape: (batch_size, num_features)
        
        if self.training:
            # TRAINING MODE: compute mean & variance from the batch
            batch_mean = x.mean(dim=0)
            batch_var = x.var(dim=0, unbiased=False)
            
            # Normalize
            x_normalized = (x - batch_mean) / torch.sqrt(batch_var + self.eps)
            
            # Update the running statistics in place. Like nn.BatchNorm1d, the
            # running variance uses the unbiased estimate (needs batch_size > 1).
            with torch.no_grad():
                unbiased_var = x.var(dim=0, unbiased=True)
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * batch_mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)
                self.num_batches_tracked += 1
        else:
            # INFERENCE MODE: use the running statistics
            x_normalized = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)
        
        # Scale & shift
        out = self.gamma * x_normalized + self.beta
        return out

print("🧪 Demo: BatchNorm Train vs Inference\n")

# Create data
x_train = torch.randn(32, 10) * 2 + 5  # mean≈5, std≈2

# Initialize BN layer
bn_custom = BatchNorm1dFromScratch(10)
bn_pytorch = nn.BatchNorm1d(10)

# Training mode
bn_custom.train()
bn_pytorch.train()
out_custom_train = bn_custom(x_train)
out_pytorch_train = bn_pytorch(x_train)

print("📊 TRAINING MODE:")
print(f"Input - Mean: {x_train.mean():.3f}, Std: {x_train.std():.3f}")
print(f"Output (Custom) - Mean: {out_custom_train.mean():.3f}, Std: {out_custom_train.std():.3f}")
print(f"Output (PyTorch) - Mean: {out_pytorch_train.mean():.3f}, Std: {out_pytorch_train.std():.3f}")
print(f"Running Mean (Custom): {bn_custom.running_mean.mean():.3f}")

# Inference mode with different data
x_test = torch.randn(16, 10) * 3 + 10  # mean≈10, std≈3 (very different!)
bn_custom.eval()
bn_pytorch.eval()
out_custom_test = bn_custom(x_test)
out_pytorch_test = bn_pytorch(x_test)

print("\n📊 INFERENCE MODE:")
print(f"Input - Mean: {x_test.mean():.3f}, Std: {x_test.std():.3f}")
print(f"Output (Custom) - Mean: {out_custom_test.mean():.3f}, Std: {out_custom_test.std():.3f}")
print(f"Output (PyTorch) - Mean: {out_pytorch_test.mean():.3f}, Std: {out_pytorch_test.std():.3f}")
print("\n✅ Inference uses the running stats → the output for a sample does not depend on the rest of the test batch!")

### 🔍 Batch Size Dependency

BatchNorm **depends on the batch size**:
- Small batches → noisy, unreliable statistics
- Different batch sizes at train and test time → a train/inference mismatch
- A batch size of at least 16-32 is usually needed
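
To make the dependency concrete, here is a small sketch (tensor sizes and variable names are illustrative assumptions): in train mode the output for one sample changes with the batch it is placed in, while in eval mode it does not.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm1d(4)
x = torch.randn(8, 4)

# Train mode: sample 0 is normalized with the statistics of whatever batch it is in
bn.train()
with torch.no_grad():
    in_big_batch = bn(x)[0]
    in_small_batch = bn(x[:2])[0]
same_in_train = torch.allclose(in_big_batch, in_small_batch, atol=1e-5)

# Eval mode: the fixed running statistics are used, so batch-mates no longer matter
bn.eval()
with torch.no_grad():
    in_big_batch = bn(x)[0]
    in_small_batch = bn(x[:2])[0]
same_in_eval = torch.allclose(in_big_batch, in_small_batch, atol=1e-5)

print(f"Same output in train mode: {same_in_train}")
print(f"Same output in eval mode:  {same_in_eval}")
```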

In [None]:
# Experiment: effect of batch size on the statistics BN relies on
def test_batch_size_effect(batch_sizes=(2, 8, 32, 128)):
    """
    Measure how noisy the per-batch mean/variance estimates are for each batch size.
    (In train mode the normalized outputs always have mean≈0 and std≈1 within a batch,
    so the noise only shows up in the batch statistics themselves.)
    """
    results = {}
    x_full = torch.randn(1000, 50)  # true mean = 0, true variance = 1
    
    for bs in batch_sizes:
        batch_means, batch_vars = [], []
        
        # Process batch by batch, keeping full-size batches only
        for i in range(0, len(x_full) - bs + 1, bs):
            batch = x_full[i:i+bs]
            batch_means.append(batch.mean(dim=0))
            batch_vars.append(batch.var(dim=0, unbiased=False))
        
        results[bs] = {
            'mean_noise': torch.stack(batch_means).std().item(),
            'var_noise': torch.stack(batch_vars).std().item()
        }
    
    return results

print("🔬 Experiment: Effect of Batch Size\n")
results = test_batch_size_effect()

for bs, stats in results.items():
    print(f"Batch Size {bs:3d}: std of batch means={stats['mean_noise']:.3f}, "
          f"std of batch variances={stats['var_noise']:.3f}")

print("\n📌 Takeaway: the smaller the batch, the noisier the statistics BN normalizes with!")

### 🖼️ BatchNorm in CNNs vs MLPs

**CNN**: `BatchNorm2d`
- Normalizes over (N, H, W), which preserves the spatial structure
- Input shape: `(N, C, H, W)`
- Statistics shape: `(C,)` - each channel has its own mean/var

**MLP**: `BatchNorm1d`
- Normalizes over the batch dimension
- Input shape: `(N, features)`
- Statistics shape: `(features,)`
In [None]:
# Demo: BatchNorm2d for CNNs
print("🖼️  Demo: BatchNorm2d in a CNN\n")

# Create fake image data
images = torch.randn(8, 3, 32, 32)  # (batch, channels, height, width)
print(f"Input shape: {images.shape}")

# BatchNorm2d
bn2d = nn.BatchNorm2d(3)  # 3 channels
bn2d.train()
output = bn2d(images)

print(f"Output shape: {output.shape}")
print(f"\nGamma shape (learnable): {bn2d.weight.shape}")
print(f"Beta shape (learnable): {bn2d.bias.shape}")
print(f"Running mean shape: {bn2d.running_mean.shape}")
print(f"Running var shape: {bn2d.running_var.shape}")

# Verify normalization per channel
print("\n📊 Statistics per channel:")
for c in range(3):
    channel_data = output[:, c, :, :]
    print(f"Channel {c}: Mean={channel_data.mean():.3f}, Std={channel_data.std():.3f}")

---

## 5.3 Layer Normalization (LN)

### 📐 Formula

$$
\begin{align}
\mu &= \frac{1}{D}\sum_{i=1}^{D} x_i \quad \text{(mean per sample)} \\
\sigma^2 &= \frac{1}{D}\sum_{i=1}^{D} (x_i - \mu)^2 \quad \text{(variance per sample)} \\
\hat{x}_i &= \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \\
y_i &= \gamma \hat{x}_i + \beta
\end{align}
$$

### 🔑 Key differences from BatchNorm:

| Aspect | BatchNorm | LayerNorm |
|--------|-----------|----------|
| Normalizes over | **Batch** dimension | **Feature** dimension |
| Depends on batch size | ✅ Yes | ❌ No |
| Train = Inference | ❌ Different | ✅ Same |
| Running statistics | ✅ Yes | ❌ No |
| Well suited to | CNN, MLP | **Transformers, RNN** |
In [None]:
# Visualization: BatchNorm vs LayerNorm
def visualize_norm_difference():
    """
    Visualize the difference between BatchNorm and LayerNorm
    """
    # Create data: (batch_size=4, features=6)
    x = torch.tensor([
        [1., 2., 3., 4., 5., 6.],
        [2., 4., 6., 8., 10., 12.],
        [3., 6., 9., 12., 15., 18.],
        [4., 8., 12., 16., 20., 24.]
    ])
    
    # Apply normalization
    bn = nn.BatchNorm1d(6)
    ln = nn.LayerNorm(6)
    
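    # BatchNorm must be in train mode to normalize with the batch statistics;
    # in eval mode it would use the (still default) running statistics instead
    bn.train()
    ln.eval()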
    
    x_bn = bn(x)
    x_ln = ln(x)
    
    # Plotting
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))
    
    # Original
    im0 = axes[0].imshow(x.numpy(), cmap='viridis', aspect='auto')
    axes[0].set_title('Original Data', fontweight='bold')
    axes[0].set_xlabel('Features')
    axes[0].set_ylabel('Samples')
    plt.colorbar(im0, ax=axes[0])
    
    # BatchNorm
    im1 = axes[1].imshow(x_bn.detach().numpy(), cmap='viridis', aspect='auto')
    axes[1].set_title('BatchNorm\n(normalizes along columns)', fontweight='bold')
    axes[1].set_xlabel('Features')
    axes[1].set_ylabel('Samples')
    plt.colorbar(im1, ax=axes[1])
    
    # LayerNorm
    im2 = axes[2].imshow(x_ln.detach().numpy(), cmap='viridis', aspect='auto')
    axes[2].set_title('LayerNorm\n(normalizes along rows)', fontweight='bold')
    axes[2].set_xlabel('Features')
    axes[2].set_ylabel('Samples')
    plt.colorbar(im2, ax=axes[2])
    
    plt.tight_layout()
    plt.show()
    
    # Statistics (detach first: the outputs are attached to the autograd graph)
    print("📊 Statistics Verification:\n")
    print("BatchNorm - Mean per feature (should be ≈0):")
    print(x_bn.mean(dim=0).detach().numpy().round(3))
    print("\nLayerNorm - Mean per sample (should be ≈0):")
    print(x_ln.mean(dim=1).detach().numpy().round(3))

print("🎨 Visualization: BatchNorm vs LayerNorm\n")
visualize_norm_difference()

### 🤖 Why is LayerNorm a good fit for Transformers?

1. **Variable sequence lengths**: each sample can have a different length
2. **Small batch sizes**: Transformers are often trained with small batches (due to memory)
3. **Train = Inference**: no running statistics are needed
4. **Per-token normalization**: each token is normalized independently

**In a Transformer**:
```python
# Input shape: (batch, seq_len, d_model)
# LayerNorm normalizes over the d_model dimension
ln = nn.LayerNorm(d_model)
```
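
Per-token normalization can be checked with a small sketch (the sizes are illustrative assumptions): overwriting every other token in the sequence, for example with padding, leaves the first token's output unchanged.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
ln = nn.LayerNorm(d_model)

x = torch.randn(2, 5, d_model)        # (batch, seq_len, d_model)
x_changed = x.clone()
x_changed[:, 1:, :] = 0.0             # overwrite every token except the first

with torch.no_grad():
    out = ln(x)
    out_changed = ln(x_changed)

# Statistics are per token, so the first token's output is unchanged
first_token_unchanged = torch.allclose(out[:, 0], out_changed[:, 0])
# Each token has mean≈0 over d_model
max_abs_token_mean = out.mean(dim=-1).abs().max().item()
print(first_token_unchanged, max_abs_token_mean < 1e-5)
```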

In [None]:
# Demo: LayerNorm in a Transformer-style architecture
class SimpleTransformerBlock(nn.Module):
    """
    Simplified Transformer block with LayerNorm
    """
    def __init__(self, d_model=512, nhead=8):
        super().__init__()
        self.attention = nn.MultiheadAttention(d_model, nhead, batch_first=True)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.ReLU(),
            nn.Linear(d_model * 4, d_model)
        )
    
    def forward(self, x):
        # Self-attention + residual + LayerNorm
        attn_out, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_out)
        
        # FFN + residual + LayerNorm
        ffn_out = self.ffn(x)
        x = self.norm2(x + ffn_out)
        return x

print("🤖 Demo: LayerNorm in a Transformer\n")

# Fake sequence batch; the sequences have different true lengths
batch_size = 4
seq_lengths = [10, 15, 8, 12]  # all different!
d_model = 512

# Pad to the longest sequence (for brevity, no padding mask is passed to attention)
max_len = max(seq_lengths)
x = torch.randn(batch_size, max_len, d_model)

# Create model
model = SimpleTransformerBlock(d_model=d_model)
model.eval()

# Forward pass
with torch.no_grad():
    output = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"\n✅ LayerNorm normalizes each token over d_model, so it works regardless of sequence length!")
print(f"✅ No batch statistics needed → stable inference!")

---

## 5.4 RMSNorm (Root Mean Square Normalization)

### 📐 Formula

$$
\begin{align}
\text{RMS}(x) &= \sqrt{\frac{1}{D}\sum_{i=1}^{D} x_i^2 + \epsilon} \\
\hat{x}_i &= \frac{x_i}{\text{RMS}(x)} \\
y_i &= \gamma \hat{x}_i
\end{align}
$$

### 🔑 Differences from LayerNorm:

| Aspect | LayerNorm | RMSNorm |
|--------|-----------|----------|
| Mean centering | ✅ Yes (`x - μ`) | ❌ **NO** |
| Variance normalization | ✅ Yes | ✅ Yes (via RMS) |
| Learnable bias β | ✅ Yes | ❌ No |
| Statistics computed | Mean and variance | **Mean square only** |
| Speed | Slower | **Often faster** (the RMSNorm paper reports 7%-64% lower running time, depending on the model) |

### 💡 Why drop the mean?

- **Empirical finding**: mean centering turns out to matter little; the RMSNorm paper hypothesizes that the re-centering in LayerNorm is dispensable
- **Efficiency**: less computation
- **Used in**: LLaMA, T5
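
A worked example on a two-element vector makes the difference concrete. This is plain-Python arithmetic with $\gamma = 1$, $\beta = 0$ and $\epsilon$ omitted; the input values are chosen purely for illustration.

```python
import math

x = [3.0, 4.0]
D = len(x)

# RMSNorm: divide by the root mean square, no centering
rms = math.sqrt(sum(v * v for v in x) / D)          # sqrt((9 + 16) / 2) = sqrt(12.5)
rms_out = [v / rms for v in x]

# LayerNorm: subtract the mean, then divide by the standard deviation
mu = sum(x) / D                                      # 3.5
var = sum((v - mu) ** 2 for v in x) / D              # 0.25
ln_out = [(v - mu) / math.sqrt(var) for v in x]

print([round(v, 4) for v in rms_out])   # [0.8485, 1.1314]
print([round(v, 4) for v in ln_out])    # [-1.0, 1.0]
```

RMSNorm keeps the sign and relative size of the inputs and only rescales them to unit root mean square, whereas LayerNorm also shifts them to zero mean.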

In [None]:
# Implementation: RMSNorm from scratch
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization
    Paper: https://arxiv.org/abs/1910.07467
    """
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))
    
    def forward(self, x):
        # Compute RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)
        
        # Normalize and scale
        x_normalized = x / rms
        return self.gamma * x_normalized

print("🔬 Implementation: RMSNorm\n")

# Test data with a non-zero mean, so that the missing centering is visible
x_test = torch.randn(4, 8, 512) + 1.0  # (batch, seq, features)

# Initialize
ln = nn.LayerNorm(512)
rms = RMSNorm(512)

# Forward pass
with torch.no_grad():
    out_ln = ln(x_test)
    out_rms = rms(x_test)

print(f"Input shape: {x_test.shape}")
print(f"\nLayerNorm output:")
print(f"  Mean: {out_ln.mean():.6f} (should be ≈0)")
print(f"  Std: {out_ln.std():.6f}")
print(f"\nRMSNorm output:")
print(f"  Mean: {out_rms.mean():.6f} (NOT zero!)")
print(f"  Std: {out_rms.std():.6f}")
print(f"\n📌 RMSNorm does NOT re-center to 0 → one less statistic to compute!")

### ⚖️ Trade-offs

**Advantages**:
- ⚡ Often faster than LayerNorm (the size of the speedup depends on the model and hardware)
- 💾 Slightly fewer parameters (no bias β) and no mean to compute
- 🎯 Simpler gradient computation

**Potential Issues**:
- ⚠️ No re-centering to 0 → may affect some architectures
- ⚠️ Less studied than LayerNorm
- ⚠️ Hyperparameters may need re-tuning when switching from LN
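
The parameter and centering differences can be checked with a short sketch. `TinyRMSNorm` and `count_params` are hypothetical names used only for this comparison, and `dim=512` is an illustrative choice.

```python
import torch
import torch.nn as nn

class TinyRMSNorm(nn.Module):
    """Minimal RMSNorm used only for this comparison (hypothetical name)."""
    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        return self.gamma * x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

def count_params(module):
    return sum(p.numel() for p in module.parameters())

dim = 512
ln_params = count_params(nn.LayerNorm(dim))      # gamma + beta
rms_params = count_params(TinyRMSNorm(dim))      # gamma only
print(f"LayerNorm params: {ln_params}, RMSNorm params: {rms_params}")

# A constant input is mapped to 0 by LayerNorm but keeps its offset under RMSNorm
x = torch.full((1, dim), 2.0)
with torch.no_grad():
    ln_const = nn.LayerNorm(dim)(x).abs().max().item()
    rms_const = TinyRMSNorm(dim)(x).mean().item()
print(f"LayerNorm max |output|: {ln_const:.4f}, RMSNorm mean output: {rms_const:.4f}")
```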

---

## 5.5 Practical Experiments: Comparing BN / LN / RMSNorm

We train the same model with three different kinds of normalization (plus a baseline without normalization) and compare:
- ⚡ Convergence speed
- 📊 Batch size sensitivity
- 📈 Gradient norm stability
In [None]:
# Create a synthetic dataset
def create_classification_dataset(n_samples=10000, n_features=50, n_classes=10):
    """
    Create a synthetic classification dataset
    """
    X = torch.randn(n_samples, n_features)
    # Create labels with some structure: the argmax of a random linear map
    W = torch.randn(n_features, n_classes)
    logits = X @ W
    y = logits.argmax(dim=1)
    
    return TensorDataset(X, y)

# Define models with different normalization layers
class MLPWithNorm(nn.Module):
    """
    Simple MLP with pluggable normalization
    """
    def __init__(self, input_dim=50, hidden_dim=128, output_dim=10, norm_type='bn'):
        super().__init__()
        self.norm_type = norm_type
        
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        
        # Normalization layers
        if norm_type == 'bn':
            self.norm1 = nn.BatchNorm1d(hidden_dim)
            self.norm2 = nn.BatchNorm1d(hidden_dim)
        elif norm_type == 'ln':
            self.norm1 = nn.LayerNorm(hidden_dim)
            self.norm2 = nn.LayerNorm(hidden_dim)
        elif norm_type == 'rms':
            self.norm1 = RMSNorm(hidden_dim)
            self.norm2 = RMSNorm(hidden_dim)
        else:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
    
    def forward(self, x):
        x = F.relu(self.norm1(self.fc1(x)))
        x = F.relu(self.norm2(self.fc2(x)))
        x = self.fc3(x)
        return x

print("🏗️  Setup: Creating models and dataset...\n")
dataset = create_classification_dataset()
print(f"✅ Dataset created: {len(dataset)} samples")

In [None]:
# Training function with gradient tracking
def train_model(model, dataloader, epochs=10, lr=0.001, device='cpu'):
    """
    Train the model and track metrics
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    history = {
        'loss': [],
        'accuracy': [],
        'grad_norm': []
    }
    
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0
        grad_norms = []
        
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            
            optimizer.zero_grad()
            outputs = model(X)
            loss = criterion(outputs, y)
            loss.backward()
            
            # Track gradient norm
            total_norm = 0
            for p in model.parameters():
                if p.grad is not None:
                    total_norm += p.grad.data.norm(2).item() ** 2
            grad_norms.append(total_norm ** 0.5)
            
            optimizer.step()
            
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            correct += predicted.eq(y).sum().item()
            total += y.size(0)
        
        avg_loss = total_loss / len(dataloader)
        accuracy = 100. * correct / total
        avg_grad_norm = np.mean(grad_norms)
        
        history['loss'].append(avg_loss)
        history['accuracy'].append(accuracy)
        history['grad_norm'].append(avg_grad_norm)
    
    return history

print("🎯 Training function ready!")

In [None]:
# Experiment 1: Convergence speed with batch size = 32
print("🚀 Experiment 1: Convergence Speed\n")

batch_size = 32
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Train the models
results = {}
norm_types = ['none', 'bn', 'ln', 'rms']
colors = ['gray', 'blue', 'green', 'red']

for norm_type in tqdm(norm_types, desc="Training models"):
    model = MLPWithNorm(norm_type=norm_type)
    history = train_model(model, dataloader, epochs=20, device=device)
    results[norm_type] = history

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('📊 Convergence Comparison (Batch Size = 32)', fontsize=16, fontweight='bold')

# Loss
for norm_type, color in zip(norm_types, colors):
    axes[0].plot(results[norm_type]['loss'], label=norm_type.upper(), 
                color=color, linewidth=2)
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy
for norm_type, color in zip(norm_types, colors):
    axes[1].plot(results[norm_type]['accuracy'], label=norm_type.upper(),
                color=color, linewidth=2)
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Gradient norm
for norm_type, color in zip(norm_types, colors):
    axes[2].plot(results[norm_type]['grad_norm'], label=norm_type.upper(),
                color=color, linewidth=2)
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Gradient Norm')
axes[2].set_title('Gradient Stability')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

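print("\n✅ What to look for (check against the curves above; results vary with the seed):")
print("   - NONE: typically slower convergence and less stable gradients")
print("   - BN/LN/RMS: typically faster convergence and more stable gradients")
print("   - RMS vs LN: usually very similar curves")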

In [None]:
# Experiment 2: Batch size sensitivity
print("🔬 Experiment 2: Batch Size Sensitivity\n")

batch_sizes = [8, 16, 32, 64]
batch_results = {bs: {} for bs in batch_sizes}

for bs in tqdm(batch_sizes, desc="Testing batch sizes"):
    dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
    
    for norm_type in ['bn', 'ln', 'rms']:
        model = MLPWithNorm(norm_type=norm_type)
        history = train_model(model, dataloader, epochs=10, device=device)
        batch_results[bs][norm_type] = history['accuracy'][-1]  # Final accuracy

# Plotting
fig, ax = plt.subplots(figsize=(10, 6))
fig.suptitle('📊 Batch Size Sensitivity', fontsize=16, fontweight='bold')

for norm_type, color in zip(['bn', 'ln', 'rms'], ['blue', 'green', 'red']):
    accuracies = [batch_results[bs][norm_type] for bs in batch_sizes]
    ax.plot(batch_sizes, accuracies, marker='o', label=norm_type.upper(),
           color=color, linewidth=2, markersize=8)

ax.set_xlabel('Batch Size', fontsize=12)
ax.set_ylabel('Final Accuracy (%)', fontsize=12)
ax.set_title('Effect of Batch Size on Performance')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xscale('log', base=2)

plt.tight_layout()
plt.show()

print("\n✅ What to look for (check against the plot above; results vary with the seed):")
print("   - BN: tends to be more sensitive to batch size (worse with small batches)")
print("   - LN/RMS: tend to be more robust across batch sizes")
print("   - RMS: should track LN closely")

---

## 📚 Summary: When to use what?

| Use Case | Recommended | Reason |
|----------|-------------|-------|
| **CNN (Computer Vision)** | BatchNorm2d | Spatial structure, batch statistics are reliable |
| **Transformers / NLP** | LayerNorm | Variable length, small batches |
| **RNN / LSTM** | LayerNorm | Temporal data, sequence-to-sequence |
| **Large Language Models** | RMSNorm | Efficiency, used in LLaMA and T5 |
| **Small batch training** | LayerNorm / RMSNorm | Independent of the batch |
| **Inference optimization** | RMSNorm | Fewer operations than LayerNorm |

### 💡 Best Practices:

1. **BatchNorm**:
   - Batch size ≥ 16-32
   - `.train()` and `.eval()` matter!
   - Be careful with distributed training

2. **LayerNorm**:
   - A safe default for most cases
   - Especially good for Transformers
   - No need to worry about batch size

3. **RMSNorm**:
   - Try it when you need to optimize speed
   - Re-tune the hyperparameters
   - Monitor training stability
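
The two BatchNorm cautions above can be illustrated with a short sketch. The small model is an illustrative assumption. `nn.SyncBatchNorm.convert_sync_batchnorm` only swaps the layers here; actually running the converted model in training mode requires an initialized `torch.distributed` process group, which this sketch does not set up.

```python
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8), nn.ReLU())

# .train() / .eval() toggle the `training` flag that BatchNorm checks in forward()
model.eval()
flag_in_eval = model[1].training     # running statistics are used
model.train()
flag_in_train = model[1].training    # batch statistics are used, buffers get updated
print(flag_in_eval, flag_in_train)

# For multi-GPU training, BatchNorm layers can be swapped for SyncBatchNorm,
# which synchronizes the batch statistics across processes
sync_model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
sync_layer_name = type(sync_model[1]).__name__
print(sync_layer_name)
```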

---

## 🎓 Practice exercises

1. **Implement GroupNorm** from scratch (normalize over groups of channels)
2. **Compare BN vs SyncBN** in distributed training
3. **Implement InstanceNorm** for style transfer
4. **Test normalization** with different activation functions

---

## 📖 References

1. [Batch Normalization Paper](https://arxiv.org/abs/1502.03167)
2. [Layer Normalization Paper](https://arxiv.org/abs/1607.06450)
3. [RMSNorm Paper](https://arxiv.org/abs/1910.07467)
4. [How Does Batch Normalization Help?](https://arxiv.org/abs/1805.11604)

---

### 🙏 Thank you for learning!

Next: **6️⃣ Activation Functions** 🔥