In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Learning Rate Scheduling and Gradient Clipping -- Vizuara

> **What you will build:** A warmup + cosine decay learning rate scheduler and max-norm gradient clipping from scratch, then watch them stabilize training in real time.

## 1. Why Does This Matter?

Even with Adam, two things can go catastrophically wrong during training:

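1. **Learning rate too high at the start:** With random initial weights, early gradients are unreliable. Adam's running moment estimates have also seen only a handful of batches, so its per-parameter step sizes are poorly calibrated. A large learning rate amplifies this noise, and the loss can blow up (often to `inf`/`NaN`) within the first few steps -- a failure that training usually cannot recover from.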

2. **Gradient explosions:** A single unusual batch can produce gradients 100x larger than normal. Without protection, this single step wipes out the progress of thousands of previous steps.

Learning rate scheduling (warmup + cosine decay) and gradient clipping are the two safety mechanisms that prevent these failures. Nearly every modern LLM uses both.

## 2. Building Intuition

Think of training as driving a car on an unfamiliar mountain road at night.

**Warmup** is like starting slowly when you first get on the road. You do not know the curves yet, so you drive cautiously. As you learn the road, you speed up.

**Cosine decay** is like slowing down as you approach your destination. You need precision to park, not speed.

**Gradient clipping** is like a speed limiter. Even if the road briefly becomes a steep downhill, the car's maximum speed is capped so you do not crash.

In [None]:
import math

import matplotlib.pyplot as plt
import numpy as np
import torch

# Preview the three safety mechanisms on toy curves before building them for real.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
warm_ax, cos_ax, clip_ax = axes

# Panel 1 -- warmup: LR ramps linearly from 0 toward a 3e-4 peak over 1000 steps.
steps = np.arange(1000)
warmup_lr = 3e-4 * steps / 1000
warm_ax.plot(steps, warmup_lr, color='#3498db', linewidth=2)
warm_ax.fill_between(steps, warmup_lr, alpha=0.2, color='#3498db')
warm_ax.set_title('Warmup: Cautious Start', fontsize=12, fontweight='bold')
warm_ax.set(xlabel='Step', ylabel='Learning Rate', ylim=(0, 3.5e-4))
warm_ax.grid(alpha=0.3)

# Panel 2 -- cosine decay from 3e-4 down to 1e-5 over 9000 steps.
steps = np.arange(9000)
cosine_lr = 1e-5 + 0.5 * (3e-4 - 1e-5) * (1 + np.cos(np.pi * steps / 9000))
cos_ax.plot(steps, cosine_lr, color='#e74c3c', linewidth=2)
cos_ax.fill_between(steps, cosine_lr, alpha=0.2, color='#e74c3c')
cos_ax.set_title('Cosine Decay: Gradual Slowdown', fontsize=12, fontweight='bold')
cos_ax.set(xlabel='Step', ylabel='Learning Rate')
cos_ax.grid(alpha=0.3)

# Panel 3 -- clipping: exponentially distributed norms overlaid with their capped values.
grad_norms = np.random.exponential(1.0, 200) * 2
clipped = np.minimum(grad_norms, 1.0)
bar_positions = range(len(grad_norms))
clip_ax.bar(bar_positions, grad_norms, alpha=0.4, color='#e74c3c', label='Original')
clip_ax.bar(bar_positions, clipped, alpha=0.6, color='#2ecc71', label='Clipped')
clip_ax.axhline(y=1.0, color='black', linestyle='--', linewidth=1, label='Clip threshold')
clip_ax.set_title('Gradient Clipping: Safety Net', fontsize=12, fontweight='bold')
clip_ax.set(xlabel='Step', ylabel='Gradient Norm')
clip_ax.legend(fontsize=9)
clip_ax.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('safety_mechanisms.png', dpi=150, bbox_inches='tight')
plt.show()
print("Checkpoint: Warmup, cosine decay, and gradient clipping work together to stabilize training.")

## 3. The Mathematics

### Linear Warmup

During the first $W$ warmup steps ($t = 0, \dots, W-1$), the learning rate increases linearly from $0$. It reaches $\eta_{\max}$ at $t = W$, the first step of the cosine phase:

$$\eta_t = \eta_{\max} \cdot \frac{t}{W} \quad \text{for } t < W$$

### Cosine Decay

After warmup, the learning rate follows a cosine curve from $\eta_{\max}$ to $\eta_{\min}$:

$$\eta_t = \eta_{\min} + \frac{1}{2}(\eta_{\max} - \eta_{\min})\left(1 + \cos\left(\frac{\pi \cdot (t - W)}{T - W}\right)\right) \quad \text{for } t \geq W$$

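where $T$ is the total number of training steps. The rate equals $\eta_{\max}$ at $t = W$ and reaches $\eta_{\min}$ at $t = T$; the implementation below clamps the progress term, so steps beyond $T$ stay at $\eta_{\min}$.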

### Max-Norm Gradient Clipping

Given the gradient vector $g$ (all parameter gradients flattened and concatenated into one vector) and threshold $c$:

$$g_{\text{clipped}} = \begin{cases} g & \text{if } \|g\| \leq c \\ c \cdot \frac{g}{\|g\|} & \text{if } \|g\| > c \end{cases}$$

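This preserves the direction of the gradient but caps its magnitude at $c$. Because the single scale factor $c / \|g\|$ is applied to every parameter's gradient, the relative sizes of the per-layer gradients are unchanged.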

In [None]:
# Numerical trace: evaluate the section-3 formulas by hand at a few steps, then
# clip one concrete gradient vector, so the plots that follow have numbers behind them.
eta_max, eta_min, T, W = 3e-4, 1e-5, 10000, 1000
checkpoints = [0, 500, 1000, 3000, 5000, 7000, 9000, 10000]


def _trace_lr(step_idx):
    """Return (lr, phase_name) for `step_idx` under warmup (t < W) then cosine decay."""
    if step_idx < W:
        return eta_max * step_idx / W, "warmup"
    progress = (step_idx - W) / (T - W)
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * progress)), "cosine"


print("Learning Rate at Key Training Steps:")
print("-" * 55)
for step_idx in checkpoints:
    lr, phase = _trace_lr(step_idx)
    print(f"  Step {step_idx:>5d} ({phase:>6s}):  lr = {lr:.2e}")

# Max-norm clipping of a 3-4-5 vector: norm 5 is scaled down to exactly c.
print("\nGradient Clipping Example:")
g = np.array([3.0, 4.0, 0.0])
c = 1.0
norm = np.linalg.norm(g)
if norm > c:
    g_clipped = c * g / norm
else:
    g_clipped = g
clipped_norm = np.linalg.norm(g_clipped)
print(f"  Original gradient: {g},  norm = {norm:.1f}")
print(f"  Clip threshold:    c = {c}")
print(f"  Clipped gradient:  {g_clipped},  norm = {clipped_norm:.1f}")
print(f"  Direction preserved: ratios {g[0]/g[1]:.2f} = {g_clipped[0]/g_clipped[1]:.2f}")

## 4. Let's Build It -- Component by Component

### 4.1 Learning Rate Scheduler

In [None]:
class WarmupCosineScheduler:
    """
    Learning rate scheduler with linear warmup + cosine decay.
    Used by GPT, LLaMA, Mistral, and virtually every modern LLM.

    For a zero-based step index t:
      * t < warmup_steps:  lr = lr_max * t / warmup_steps  (exactly 0 at t = 0)
      * t >= warmup_steps: cosine decay from lr_max (at t = warmup_steps) down to
        lr_min (at t = total_steps); steps past total_steps stay at lr_min.

    The caller passes the step index to `step()` explicitly, and the same lr is
    written to every param group of the optimizer.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, lr_max, lr_min=1e-5):
        """
        Args:
            optimizer: The optimizer whose learning rate to adjust
            warmup_steps: Number of warmup steps
            total_steps: Total training steps
            lr_max: Peak learning rate (reached at end of warmup)
            lr_min: Minimum learning rate (reached at end of training)
        """
        self.optimizer = optimizer
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.lr_max = lr_max
        self.lr_min = lr_min

    def get_lr(self, step):
        """
        Compute the learning rate for a given step. Pure: the optimizer is not touched.

        Args:
            step: Zero-based training step index.

        Returns:
            The learning rate for `step` as a float.
        """
        if step < self.warmup_steps:
            # Linear warmup
            return self.lr_max * step / self.warmup_steps

        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            # No room for a cosine phase (total_steps <= warmup_steps). The formula
            # below would divide by zero (or run backwards), so treat training as
            # finished and hold the floor.
            return self.lr_min

        # Cosine decay. progress is >= 0 in this branch; clamping the top keeps
        # lr at lr_min for any step beyond total_steps.
        progress = min((step - self.warmup_steps) / decay_steps, 1.0)
        return self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + math.cos(math.pi * progress))

    def step(self, current_step):
        """
        Write get_lr(current_step) into every optimizer param group.

        Call once per training step, before optimizer.step().

        Returns:
            The learning rate that was set.
        """
        lr = self.get_lr(current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

# Test the scheduler: build it on a throwaway model and plot its full curve.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_steps=1000,
    total_steps=10000,
    lr_max=3e-4,
    lr_min=1e-5,
)
warmup_end = scheduler.warmup_steps

# get_lr has no side effects, so the whole schedule can be tabulated up front.
steps_range = range(scheduler.total_steps)
lr_schedule = list(map(scheduler.get_lr, steps_range))

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(steps_range, lr_schedule, linewidth=2, color='#3498db')
ax.axvline(x=warmup_end, color='gray', linestyle='--', linewidth=1, alpha=0.5, label='End of warmup')
# Shade the area under the warmup ramp.
ax.fill_between(range(warmup_end), lr_schedule[:warmup_end], alpha=0.2, color='#3498db')
ax.set_xlabel('Training Step', fontsize=12)
ax.set_ylabel('Learning Rate', fontsize=12)
ax.set_title('Warmup + Cosine Decay Schedule', fontsize=13, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(alpha=0.3)

# Mark the peak (end of warmup) and the floor near the end of training.
ax.annotate(f'Peak: {scheduler.lr_max:.0e}', xy=(warmup_end, scheduler.lr_max), xytext=(2000, 3.2e-4),
            arrowprops=dict(arrowstyle='->', color='red'), fontsize=10, color='red')
ax.annotate(f'End: {scheduler.lr_min:.0e}', xy=(scheduler.total_steps - 100, scheduler.lr_min),
            xytext=(8000, 5e-5),
            arrowprops=dict(arrowstyle='->', color='red'), fontsize=10, color='red')

plt.tight_layout()
plt.savefig('lr_schedule.png', dpi=150, bbox_inches='tight')
plt.show()
print("Checkpoint: Linear warmup for 1000 steps, then smooth cosine decay.")

### 4.2 Gradient Clipping

In [None]:
def clip_grad_norm(parameters, max_norm, norm_type=2.0):
    """
    Clip gradients by their total norm (max-norm clipping).

    All gradients are treated as one long concatenated vector g. If ||g|| exceeds
    max_norm, every gradient is multiplied in place by the same factor
    max_norm / ||g||. This caps the norm at max_norm while preserving direction.

    Args:
        parameters: Iterable of model parameters. A generator is fine; it is
            consumed once. Parameters whose .grad is None are skipped.
        max_norm: Maximum allowed total gradient norm.
        norm_type: Order of the norm. The default 2.0 is the L2 norm from section 3;
            float('inf') gives the max-abs norm.

    Returns:
        The total gradient norm before clipping, as a Python float
        (0.0 when no parameter has a gradient).
    """
    grads = [p.grad for p in parameters if p.grad is not None]
    if not grads:
        return 0.0

    with torch.no_grad():
        # ||concat(g_i)||_p == || (||g_1||_p, ||g_2||_p, ...) ||_p, so per-tensor norms
        # can stay on-device and be reduced together with a single host sync (.item()).
        device = grads[0].device
        per_tensor = torch.stack(
            [torch.linalg.vector_norm(g, norm_type).to(device) for g in grads]
        )
        total_norm = torch.linalg.vector_norm(per_tensor, norm_type).item()

        # Step 2: if the norm exceeds the threshold, scale every gradient by the same factor.
        if total_norm > max_norm:
            clip_coef = max_norm / total_norm
            for g in grads:
                g.mul_(clip_coef)

    return total_norm

# Demonstrate gradient clipping on a tiny layer with artificially large gradients.
MAX_NORM = 1.0
model = torch.nn.Linear(5, 3)

# Simulate a large gradient
model.weight.grad = torch.randn(3, 5) * 10  # Unusually large
model.bias.grad = torch.randn(3) * 10


def _grad_norm(module):
    """Total L2 norm of all of `module`'s gradients, as a Python float."""
    return sum(p.grad.norm(2).item() ** 2 for p in module.parameters() if p.grad is not None) ** 0.5


# Before clipping
grad_norm_before = _grad_norm(model)
print(f"Gradient norm BEFORE clipping: {grad_norm_before:.2f}")

# Apply clipping. clip_grad_norm returns the pre-clipping norm, which must
# agree with our own measurement.
reported_norm = clip_grad_norm(model.parameters(), max_norm=MAX_NORM)
assert math.isclose(reported_norm, grad_norm_before, rel_tol=1e-5)

# After clipping
grad_norm_after = _grad_norm(model)
print(f"Gradient norm AFTER clipping:  {grad_norm_after:.2f}")
print(f"Clip threshold: {MAX_NORM}")
# Gradients are only rescaled when the norm exceeds the threshold; otherwise the factor is 1.
scale_factor = min(1.0, MAX_NORM / grad_norm_before) if grad_norm_before > 0 else 1.0
print(f"Scale factor: {scale_factor:.4f}")

### Visualization: What happens without gradient clipping?

In [None]:
import torch.nn as nn

# Synthetic regression data: pure-noise targets, with every 50th row scaled 20x
# in both inputs and targets so that training sees occasional gradient spikes.
torch.manual_seed(42)
n_samples, n_features = 500, 10
X = torch.randn(n_samples, n_features)
y = torch.randn(n_samples, 1)

outlier_rows = slice(None, None, 50)
outlier_scale = 20
X[outlier_rows] *= outlier_scale
y[outlier_rows] *= outlier_scale

class SmallNet(nn.Module):
    """
    Small MLP used by the clipping and scheduling experiments:
    Linear(in, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU -> Linear(hidden, out).

    The defaults (10 -> 64 -> 64 -> 1) reproduce the original fixed architecture with
    the same module order (and therefore the same seeded initialization and
    state_dict keys).
    """

    def __init__(self, in_features=10, hidden=64, out_features=1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
            nn.Linear(hidden, out_features),
        )

    def forward(self, x):
        """Map a (..., in_features) tensor to (..., out_features)."""
        return self.layers(x)

def train_with_clipping(use_clipping, max_norm=1.0, lr=0.01, epochs=100, inputs=None, targets=None):
    """
    Full-batch SGD training of a freshly seeded SmallNet, with optional gradient clipping.

    Args:
        use_clipping: If True, clip the total gradient norm to `max_norm` before each step.
        max_norm: Clipping threshold (ignored when use_clipping is False).
        lr: SGD learning rate.
        epochs: Number of full-batch updates.
        inputs, targets: Training tensors. When omitted, they default to the notebook
            globals X and y as they are at call time. Pass them explicitly to avoid
            depending on which cell last rebound X / y (section 6 rebinds both).

    Returns:
        (losses, grad_norms): per-epoch loss and pre-clipping total L2 gradient norm.
    """
    if inputs is None:
        inputs = X
    if targets is None:
        targets = y

    torch.manual_seed(42)  # identical initialization for the clipped and unclipped runs
    model = SmallNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    losses = []
    grad_norms = []

    for _ in range(epochs):
        loss = loss_fn(model(inputs), targets)

        optimizer.zero_grad()
        loss.backward()

        # Record the norm BEFORE clipping so both runs are measured the same way.
        total_norm = sum(
            p.grad.norm(2).item() ** 2 for p in model.parameters() if p.grad is not None
        ) ** 0.5
        grad_norms.append(total_norm)

        if use_clipping:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)

        optimizer.step()
        losses.append(loss.item())

    return losses, grad_norms

losses_no_clip, norms_no_clip = train_with_clipping(False)
losses_clip, norms_clip = train_with_clipping(True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
ax_loss, ax_norm = axes

# Left panel: loss curves for both runs.
ax_loss.plot(losses_no_clip, label='No Clipping', linewidth=1.5, color='#e74c3c', alpha=0.7)
ax_loss.plot(losses_clip, label='With Clipping (c=1.0)', linewidth=2, color='#2ecc71')
ax_loss.set_ylabel('Loss', fontsize=12)
ax_loss.set_title('Training Loss', fontsize=13, fontweight='bold')

# Right panel: pre-clipping gradient norms, with the clip threshold for reference.
ax_norm.plot(norms_no_clip, label='No Clipping', linewidth=1, color='#e74c3c', alpha=0.5)
ax_norm.plot(norms_clip, label='With Clipping', linewidth=1, color='#2ecc71', alpha=0.7)
ax_norm.axhline(y=1.0, color='black', linestyle='--', linewidth=1, label='Clip threshold')
ax_norm.set_ylabel('Gradient Norm', fontsize=12)
ax_norm.set_title('Gradient Norms Over Training', fontsize=13, fontweight='bold')

# Cosmetics shared by both panels (both use a log y-axis).
for panel in axes:
    panel.set_xlabel('Epoch', fontsize=12)
    panel.legend(fontsize=11)
    panel.grid(alpha=0.3)
    panel.set_yscale('log')

plt.tight_layout()
plt.savefig('clipping_effect.png', dpi=150, bbox_inches='tight')
plt.show()
print("Checkpoint: Gradient clipping prevents spikes from destabilizing training.")

## 5. Your Turn

### TODO 1: Implement a Step Decay Scheduler

Besides cosine decay, step decay is another common strategy: multiply the learning rate by a fixed factor $\gamma$ (e.g. 0.5) every $N$ steps.

In [None]:
class StepDecayScheduler:
    """
    Step decay: multiply lr by gamma every step_size steps.
    Example: lr = 0.01, gamma = 0.5, step_size = 1000
    Steps 0-999: lr = 0.01
    Steps 1000-1999: lr = 0.005
    Steps 2000-2999: lr = 0.0025

    Same calling convention as WarmupCosineScheduler: pass the step index to step().
    """

    def __init__(self, optimizer, lr_initial, gamma=0.5, step_size=1000):
        self.optimizer = optimizer
        self.lr_initial = lr_initial
        self.gamma = gamma
        self.step_size = step_size

    def get_lr(self, step):
        """
        TODO: Compute the learning rate at the given step.
        Hint: How many times has the lr been decayed?
        """
        # YOUR CODE HERE
        # num_decays = step // self.step_size
        # return self.lr_initial * (self.gamma ** num_decays)
        #
        # Replace the raise below with your implementation. A bare `pass` would
        # make step() silently write None into the optimizer's param groups.
        raise NotImplementedError("TODO 1: implement StepDecayScheduler.get_lr")

    def step(self, current_step):
        """Set every param group's lr to get_lr(current_step) and return that lr."""
        lr = self.get_lr(current_step)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

# SOLUTION (uncomment to verify):
# class StepDecayScheduler:
#     def __init__(self, optimizer, lr_initial, gamma=0.5, step_size=1000):
#         self.optimizer = optimizer
#         self.lr_initial = lr_initial
#         self.gamma = gamma
#         self.step_size = step_size
#     def get_lr(self, step):
#         num_decays = step // self.step_size
#         return self.lr_initial * (self.gamma ** num_decays)
#     def step(self, current_step):
#         lr = self.get_lr(current_step)
#         for param_group in self.optimizer.param_groups:
#             param_group['lr'] = lr
#         return lr

### TODO 2: Implement Value-Based Gradient Clipping

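Instead of clipping by total norm, clip each gradient element independently to the range $[-c, c]$. Unlike norm clipping, this can change the gradient's direction: large components are cut back while small ones pass through unchanged.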

In [None]:
def clip_grad_value(parameters, clip_value):
    """
    Clip each gradient element to [-clip_value, clip_value].

    This is different from norm clipping:
    - Norm clipping scales ALL gradients by the same factor
    - Value clipping clips EACH element independently (so it can change direction)

    Args:
        parameters: Iterable of model parameters
        clip_value: Maximum absolute value for any gradient element

    Raises:
        NotImplementedError: until TODO 2 is completed.
    """
    # TODO: Implement value-based clipping
    # Hint: Use torch.clamp_ on each parameter's gradient
    # Replace the raise below with your implementation. A bare `pass` would
    # silently leave the gradients unclipped.
    raise NotImplementedError("TODO 2: implement clip_grad_value")

    # SOLUTION (uncomment to verify):
    # for p in parameters:
    #     if p.grad is not None:
    #         p.grad.data.clamp_(-clip_value, clip_value)

## 6. Putting It All Together

Let us combine the learning rate scheduler and gradient clipping in a complete training setup.

In [None]:
# Full training loop: warmup + cosine schedule and max-norm clipping, together.
torch.manual_seed(42)

# Data: a linear target is computed first; only afterwards is every 20th input row
# scaled 15x. Those rows become outliers whose targets no longer match their inputs,
# which produces occasional gradient spikes.
X = torch.randn(1000, 10)
y = (2 * X[:, :1] - X[:, 1:2] + 0.5).detach()
X[::20] *= 15

model = SmallNet()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)

total_steps = 200
warmup_steps = 20
scheduler = WarmupCosineScheduler(
    optimizer,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    lr_max=3e-4,
    lr_min=1e-5,
)

loss_fn = nn.MSELoss()
losses, learning_rates, gradient_norms = [], [], []

for step in range(total_steps):
    loss = loss_fn(model(X), y)

    optimizer.zero_grad()
    loss.backward()

    # Total L2 gradient norm measured BEFORE clipping, for the plot below.
    squared_norms = [p.grad.norm(2).item() ** 2 for p in model.parameters()]
    total_norm = sum(squared_norms) ** 0.5
    gradient_norms.append(total_norm)

    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

    # Set this step's LR, then apply the update.
    lr = scheduler.step(step)
    learning_rates.append(lr)
    optimizer.step()
    losses.append(loss.item())

    if step % 40 == 0:
        print(f"Step {step:>3d} | Loss: {loss.item():.4f} | LR: {lr:.2e} | Grad Norm: {total_norm:.2f}")

print(f"\nFinal loss: {losses[-1]:.6f}")

In [None]:
# Three stacked panels sharing the step axis: loss, learning rate, pre-clip gradient norm.
fig, axes = plt.subplots(3, 1, figsize=(12, 10), sharex=True)
ax_loss, ax_lr, ax_grad = axes

ax_loss.plot(losses, linewidth=1.5, color='#3498db')
ax_loss.set_ylabel('Loss', fontsize=12)
ax_loss.set_title('Training Dynamics with LR Schedule + Gradient Clipping', fontsize=14, fontweight='bold')
ax_loss.set_yscale('log')

ax_lr.plot(learning_rates, linewidth=1.5, color='#e74c3c')
ax_lr.axvline(x=warmup_steps, color='gray', linestyle='--', alpha=0.5, label='End warmup')
ax_lr.set_ylabel('Learning Rate', fontsize=12)

ax_grad.plot(gradient_norms, linewidth=1, color='#2ecc71', alpha=0.7)
ax_grad.axhline(y=1.0, color='black', linestyle='--', linewidth=1, label='Clip threshold')
ax_grad.set_ylabel('Gradient Norm', fontsize=12)
ax_grad.set_xlabel('Training Step', fontsize=12)

# Every panel gets a grid; only the two with reference lines get a legend.
for panel in axes:
    panel.grid(alpha=0.3)
for panel in (ax_lr, ax_grad):
    panel.legend(fontsize=10)

plt.tight_layout()
plt.savefig('full_training_dynamics.png', dpi=150, bbox_inches='tight')
plt.show()
print("Checkpoint: LR starts low (warmup), peaks, decays (cosine). Gradients stay bounded.")

## 7. Training and Results: Comparing Schedules

In [None]:
# Compare LR schedules on the section-6 data (X, y). Every run uses the same seed,
# architecture, optimizer and clipping; only the LR profile differs.
PEAK_LR, MIN_LR = 3e-4, 1e-5
N_STEPS, N_WARMUP = 200, 20


def _cosine_lr(progress):
    """Cosine interpolation from PEAK_LR (progress=0) down to MIN_LR (progress=1)."""
    return MIN_LR + 0.5 * (PEAK_LR - MIN_LR) * (1 + math.cos(math.pi * progress))


schedules = {
    'No Schedule (constant lr)': lambda s: PEAK_LR,
    'Warmup Only': lambda s: PEAK_LR * min(1.0, s / N_WARMUP),
    'Cosine Only': lambda s: _cosine_lr(s / N_STEPS),
    'Warmup + Cosine': lambda s: (PEAK_LR * s / N_WARMUP) if s < N_WARMUP
                                 else _cosine_lr((s - N_WARMUP) / (N_STEPS - N_WARMUP)),
}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
colors = ['#e74c3c', '#f39c12', '#3498db', '#2ecc71']
loss_fn = nn.MSELoss()
final_losses = {}

for (name, lr_fn), color in zip(schedules.items(), colors):
    torch.manual_seed(42)
    sched_model = SmallNet()
    sched_optimizer = torch.optim.AdamW(sched_model.parameters(), lr=PEAK_LR, weight_decay=0.01)
    # Distinct names so this cell does not overwrite `losses` from section 6,
    # which the dynamics plot above reads.
    sched_losses = []
    sched_lrs = []

    for step in range(N_STEPS):
        loss = loss_fn(sched_model(X), y)
        sched_optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(sched_model.parameters(), max_norm=1.0)

        lr = lr_fn(step)
        for pg in sched_optimizer.param_groups:
            pg['lr'] = lr
        sched_optimizer.step()

        sched_losses.append(loss.item())
        sched_lrs.append(lr)

    final_losses[name] = sched_losses[-1]
    ax1.plot(sched_losses, label=name, linewidth=1.5, color=color)
    ax2.plot(sched_lrs, label=name, linewidth=1.5, color=color)

ax1.set_xlabel('Step', fontsize=12)
ax1.set_ylabel('Loss (log)', fontsize=12)
ax1.set_yscale('log')
ax1.set_title('Loss by Schedule Strategy', fontsize=13, fontweight='bold')
ax1.legend(fontsize=9)
ax1.grid(alpha=0.3)

ax2.set_xlabel('Step', fontsize=12)
ax2.set_ylabel('Learning Rate', fontsize=12)
ax2.set_title('Learning Rate Profiles', fontsize=13, fontweight='bold')
ax2.legend(fontsize=9)
ax2.grid(alpha=0.3)

plt.tight_layout()
plt.savefig('schedule_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

# Report what THIS run actually measured instead of asserting a winner in advance.
# A short toy run need not reproduce the ranking seen in large-scale LLM training.
# NaN losses (diverged runs) are sorted last.
ranked = sorted(final_losses.items(), key=lambda kv: (math.isnan(kv[1]), kv[1]))
print("Final loss by schedule (lowest first):")
for name, final_loss in ranked:
    print(f"  {name:<28s} {final_loss:.6f}")
print(f"Lowest final loss in this run: {ranked[0][0]}")

## 8. Final Output

In [None]:
# Summary card: the three mechanisms covered in this notebook, printed as one block.
divider = "=" * 60
summary_lines = [
    divider,
    "  TRAINING SAFETY MECHANISMS -- SUMMARY",
    divider,
    "",
    "  1. LINEAR WARMUP",
    "     - Start with lr near 0, increase linearly",
    "     - Prevents early instability from random weights",
    "     - Typical: 5-10% of total training steps",
    "",
    "  2. COSINE DECAY",
    "     - Smoothly decrease lr following a cosine curve",
    "     - Fast decay in the middle, slow at start and end",
    "     - Lands at a small lr_min (e.g. 1e-5)",
    "",
    "  3. GRADIENT CLIPPING",
    "     - Cap gradient norm at a threshold (typically 1.0)",
    "     - Preserves gradient direction, limits magnitude",
    "     - Prevents catastrophic weight updates",
    "",
    "  USED BY: GPT, LLaMA, Mistral, Gemma, Phi, and",
    "  virtually every modern language model.",
    divider,
]
print(*summary_lines, sep="\n")

## 9. Reflection and Next Steps

**What we built:** A complete learning rate scheduler (warmup + cosine decay) and gradient clipping implementation from scratch.

**Key takeaways:**
- Warmup prevents early-training instability from random weight initialization
- Cosine decay provides smooth learning rate annealing with a slow start, fast middle, and slow finish
- Gradient clipping caps the gradient norm while preserving direction
- These three mechanisms are standard practice in modern LLM training

**What is next:** In Notebook 5, we will combine EVERYTHING -- tokenizer, dataset, dataloader, optimizer, scheduler, and gradient clipping -- into a complete, working training loop and train a small language model from scratch.