# 📘 Learning Rate Schedules

During training, the **learning rate (LR)** controls how large each parameter update is.  
Instead of keeping the LR constant, we can vary it over the course of training to improve convergence.  
The rule that maps the current step or epoch to an LR is called a **learning rate schedule**.

---

## 🔹 Step Decay

**Concept:**  
- Keep LR constant for a while, then reduce it suddenly (in "steps").  
- Often used when the loss plateaus after certain epochs.  
- Produces abrupt drops in LR.

**Formula:**  

$$
\alpha_t = \alpha_0 \cdot \gamma^{\left\lfloor \frac{t}{T_{\text{step}}} \right\rfloor}
$$

- $ \alpha_0 $ : initial learning rate  
- $ \gamma \in (0,1) $ : decay factor (e.g., 0.1)  
- $ T_{\text{step}} $ : interval of steps/epochs between drops  
- $ t $ : current step or epoch  
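
As a quick sanity check of the formula, the sketch below evaluates it in plain Python. The function name `step_decay_lr` and the example values ($ \alpha_0 = 0.1 $, $ \gamma = 0.1 $, $ T_{\text{step}} = 5 $) are assumptions chosen for illustration, not part of any library.

```python
def step_decay_lr(t: int, alpha0: float, gamma: float, t_step: int) -> float:
    """Closed-form step decay: alpha0 * gamma ** floor(t / t_step)."""
    return alpha0 * gamma ** (t // t_step)

# Illustrative values (assumed for this example only).
schedule = [step_decay_lr(t, alpha0=0.1, gamma=0.1, t_step=5) for t in range(12)]
for t, lr in enumerate(schedule):
    print(f"t={t:2d}  lr={lr:.4f}")
```

With these values the LR stays at 0.1 for $ t = 0,\dots,4 $, drops to 0.01 at $ t = 5 $, and drops again to 0.001 at $ t = 10 $.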

---

## 🔹 Cosine Annealing

**Concept:**  
- Smoothly decrease LR following a half-cosine curve.  
- Starts at a maximum ($ \alpha_0 $) and ends at a minimum ($ \eta_{\min} $).  
- Avoids sudden jumps; often works better than step decay in practice, though this is an empirical tendency, not a guarantee.

**Formula:**  

$$
\alpha_t = \eta_{\min} + \tfrac{1}{2}(\alpha_0 - \eta_{\min}) 
\left(1 + \cos\!\left(\frac{\pi t}{T_{\max}}\right)\right)
$$

- $ \alpha_0 $ : initial learning rate  
- $ \eta_{\min} $ : final (minimum) learning rate  
- $ T_{\max} $ : total number of steps/epochs in the schedule  
- $ t $ : current step or epoch  
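
To see how the formula behaves at its endpoints, the sketch below evaluates it at $ t = 0 $, $ t = T_{\max}/2 $, and $ t = T_{\max} $. The function name `cosine_annealing_lr` and the example values are assumptions for illustration only.

```python
from math import cos, pi

def cosine_annealing_lr(t: int, alpha0: float, eta_min: float, t_max: int) -> float:
    """Closed-form cosine annealing from alpha0 (t=0) down to eta_min (t=t_max)."""
    return eta_min + 0.5 * (alpha0 - eta_min) * (1 + cos(pi * t / t_max))

# Illustrative values (assumed for this example only).
start = cosine_annealing_lr(0, alpha0=0.1, eta_min=0.001, t_max=20)
middle = cosine_annealing_lr(10, alpha0=0.1, eta_min=0.001, t_max=20)
end = cosine_annealing_lr(20, alpha0=0.1, eta_min=0.001, t_max=20)
print(f"start={start:.4f}  middle={middle:.4f}  end={end:.4f}")
```

At $ t = 0 $ the cosine term is 1, so the LR equals $ \alpha_0 $; at the midpoint it is the average of $ \alpha_0 $ and $ \eta_{\min} $; at $ t = T_{\max} $ the cosine term is $-1$ and the LR equals $ \eta_{\min} $.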

---

## 🔑 Notes
- **Epoch vs Step:** schedules can be applied per-epoch or per-mini-batch step.  
- **Step Decay:** simple and effective, but has abrupt changes.  
- **Cosine Annealing:** smooth decay, often preferred in modern deep learning.  
- Both can be combined with a **warm-up phase** (gradual LR increase at start).
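
The warm-up idea in the last note can be sketched as a short ramp placed in front of the decay. The function below is a minimal illustration, not a standard definition: the name `warmup_cosine_lr`, the linear shape of the ramp, and the choice to start it at `alpha0 / warmup_steps` are all assumptions of this example.

```python
from math import cos, pi

def warmup_cosine_lr(t, alpha0, eta_min, warmup_steps, total_steps):
    """Linear warm-up to alpha0, then cosine annealing down to eta_min."""
    if t < warmup_steps:
        # Ramp linearly: alpha0/warmup_steps at t=0, alpha0 at t=warmup_steps-1.
        return alpha0 * (t + 1) / warmup_steps
    # Cosine phase, with progress measured from the end of warm-up.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(t - warmup_steps, decay_steps) / decay_steps
    return eta_min + 0.5 * (alpha0 - eta_min) * (1 + cos(pi * progress))

# Illustrative values (assumed): 5 warm-up steps, 25 steps in total.
lrs = [warmup_cosine_lr(t, 0.1, 0.001, warmup_steps=5, total_steps=25) for t in range(26)]
for t in (0, 4, 5, 15, 25):
    print(f"t={t:2d}  lr={lrs[t]:.4f}")
```

The same wrapper pattern would work with step decay in place of the cosine phase.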


In [1]:
#imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random

## Step Decay — from scratch (short recap)

- Drops LR by a constant factor every fixed interval.
- **Formula (epoch/step $ t $):**

  $$
  \alpha_t \;=\; \alpha_0 \cdot \gamma^{\left\lfloor \frac{t}{T_{\text{step}}} \right\rfloor}
  $$

- Use when loss plateaus at predictable times; simple but abrupt.


In [2]:
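# Step Decay (compact)

class StepDecayLR:
    """alpha_t = alpha0 * gamma^(floor(t/step_size))"""
    def __init__(self, optimizer, step_size: int, gamma: float = 0.1):
        if int(step_size) <= 0:
            raise ValueError("step_size must be a positive integer")
        self.opt = optimizer
        self.step_size = int(step_size)
        self.gamma = float(gamma)
        self.t = -1  # step() below advances this to 0
        # Remember each param group's initial LR (alpha0); decay is always applied to it.
        self.base = [g["lr"] for g in self.opt.param_groups]
        self.step()  # set t=0

    def step(self):
        """Advance the counter by one and write the new LR into every param group."""
        self.t += 1
        k = self.t // self.step_size  # number of drops so far: floor(t / step_size)
        for g, b in zip(self.opt.param_groups, self.base):
            g["lr"] = b * (self.gamma ** k)

    def last_lr(self):
        """Return the current LR of each param group."""
        return [g["lr"] for g in self.opt.param_groups]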


## Cosine Annealing — from scratch (short recap)

- **Formula:**

  $$
  \alpha_t \;=\; \eta_{\min} \;+\; \tfrac{1}{2}(\alpha_0 - \eta_{\min})
  \left( 1 + \cos\!\left( \frac{\pi t}{T_{\max}} \right) \right)
  $$

- Avoids abrupt jumps; often improves late-stage convergence.


In [3]:
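# Cosine Annealing (compact, no restarts)
from math import pi, cos

class CosineAnnealingLR:
    """alpha_t = eta_min + 0.5*(alpha0-eta_min)*(1+cos(pi*t/T_max))"""
    def __init__(self, optimizer, T_max: int, eta_min: float = 0.0):
        if int(T_max) <= 0:
            raise ValueError("T_max must be a positive integer")
        self.opt = optimizer
        self.T_max = int(T_max)
        self.eta_min = float(eta_min)
        self.t = -1  # step() below advances this to 0
        # Remember each param group's initial LR (alpha0).
        self.base = [g["lr"] for g in self.opt.param_groups]
        self.step()  # set t=0

    def step(self):
        """Advance the counter by one and write the new LR into every param group."""
        self.t += 1
        # Clamp t to [0, T_max] so the LR holds at eta_min once the schedule ends.
        tt = min(max(self.t, 0), self.T_max)
        for g, b in zip(self.opt.param_groups, self.base):
            g["lr"] = self.eta_min + 0.5 * (b - self.eta_min) * (1 + cos(pi * tt / self.T_max))

    def last_lr(self):
        """Return the current LR of each param group."""
        return [g["lr"] for g in self.opt.param_groups]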
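

One design choice in the class above is the clamp `min(max(self.t, 0), self.T_max)`: once `t` passes `T_max`, the LR stays at `eta_min`. Without the clamp, the same formula is periodic and the LR climbs back up. The standalone function below (a hypothetical helper with no optimizer involved) is a small sketch of the difference, using assumed example values.

```python
from math import cos, pi

def cosine_lr(t, alpha0, eta_min, t_max, clamp=True):
    """Cosine annealing formula; with clamp=True, t is limited to [0, t_max]."""
    tt = min(max(t, 0), t_max) if clamp else t
    return eta_min + 0.5 * (alpha0 - eta_min) * (1 + cos(pi * tt / t_max))

# Compare both variants at the end of the schedule and beyond it.
for t in (20, 30, 40):
    held = cosine_lr(t, 0.1, 0.001, t_max=20, clamp=True)
    free = cosine_lr(t, 0.1, 0.001, t_max=20, clamp=False)
    print(f"t={t}  clamped={held:.4f}  unclamped={free:.4f}")
```

Both variants agree at `t = T_max`; past that point the clamped version holds at `eta_min`, while the unclamped formula returns to `alpha0` at `t = 2 * T_max`. This matters if training runs longer than the `T_max` passed to the scheduler.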


## Quick experiment on toy data (uses the custom schedulers)

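- Trains a tiny MLP on synthetic linear-regression data for 20 epochs; each epoch is a single full-batch update.  
- Switch the scheduler block to compare Step Decay vs. Cosine.  
- Logs the LR after each `sched.step()` call, so the printed value is the LR the next epoch would use.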


In [7]:
torch.manual_seed(7)

x = torch.randn(512, 10)
true_w = torch.randn(10, 1)
y = x @ true_w + 0.1 * torch.randn(512, 1)

model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

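# ---- choose ONE ----
#sched = StepDecayLR(opt, step_size=5, gamma=0.1)
sched = CosineAnnealingLR(opt, T_max=20, eta_min=1e-3)

EPOCHS = 20  # matches T_max, so the cosine schedule ends exactly at eta_min
for ep in range(EPOCHS):
    opt.zero_grad()
    loss = loss_fn(model(x), y)  # full-batch loss: one update per epoch
    loss.backward()
    opt.step()    # update parameters with the current LR
    sched.step()  # then advance the schedule
    # The LR printed here is the post-step value, i.e. the LR for the next epoch.
    print(f"epoch {ep+1:02d} | lr={sched.last_lr()[0]:.6f} | loss={loss.item():.4f}")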


epoch 01 | lr=0.099391 | loss=8.1559
epoch 02 | lr=0.097577 | loss=6.6710
epoch 03 | lr=0.094605 | loss=5.3163
epoch 04 | lr=0.090546 | loss=3.8573
epoch 05 | lr=0.085502 | loss=2.4531
epoch 06 | lr=0.079595 | loss=1.3830
epoch 07 | lr=0.072973 | loss=0.7649
epoch 08 | lr=0.065796 | loss=0.4699
epoch 09 | lr=0.058244 | loss=0.3321
epoch 10 | lr=0.050500 | loss=0.2622
epoch 11 | lr=0.042756 | loss=0.2230
epoch 12 | lr=0.035204 | loss=0.1989
epoch 13 | lr=0.028027 | loss=0.1832
epoch 14 | lr=0.021405 | loss=0.1727
epoch 15 | lr=0.015498 | loss=0.1655
epoch 16 | lr=0.010454 | loss=0.1606
epoch 17 | lr=0.006395 | loss=0.1574
epoch 18 | lr=0.003423 | loss=0.1553
epoch 19 | lr=0.001609 | loss=0.1540
epoch 20 | lr=0.001000 | loss=0.1534
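

The logged learning rates can be checked against the closed form. Because `sched.step()` runs before the `print`, the value logged at epoch `ep+1` is the formula evaluated at $ t = ep + 1 $. The small check below uses the same settings as the cell above (`lr=0.1`, `eta_min=1e-3`, `T_max=20`); the helper name `cosine_lr` exists only for this example.

```python
from math import cos, pi

def cosine_lr(t, alpha0=0.1, eta_min=1e-3, t_max=20):
    """Closed-form cosine annealing with the settings used in the experiment."""
    return eta_min + 0.5 * (alpha0 - eta_min) * (1 + cos(pi * t / t_max))

# Values copied from the log above for epochs 1, 10, and 20.
logged = {1: 0.099391, 10: 0.050500, 20: 0.001000}
for t, lr in logged.items():
    print(f"t={t:2d}  formula={cosine_lr(t):.6f}  logged={lr:.6f}")
```

The formula and the log agree to the six decimals printed, which confirms that the scheduler follows the half-cosine curve from 0.1 down to `eta_min`.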


## Using PyTorch built-ins for comparison

- `torch.optim.lr_scheduler.StepLR`  
- `torch.optim.lr_scheduler.CosineAnnealingLR`  

Both follow the same usage pattern: call `scheduler.step()` after `optimizer.step()`, at your chosen cadence (per epoch or per mini-batch step).


In [9]:
torch.manual_seed(7)

x = torch.randn(512, 10)
y = x @ torch.randn(10, 1) + 0.1 * torch.randn(512, 1)

m = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 1))
opt = torch.optim.SGD(m.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

#sch = torch.optim.lr_scheduler.StepLR(opt, step_size=5, gamma=0.1)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=20, eta_min=1e-3)

for ep in range(20):
    opt.zero_grad()
    loss = loss_fn(m(x), y)
    loss.backward()
    opt.step()
    sch.step()
    print(f"[builtin] epoch {ep+1:02d} | lr={opt.param_groups[0]['lr']:.6f} | loss={loss.item():.4f}")


[builtin] epoch 01 | lr=0.099391 | loss=8.1559
[builtin] epoch 02 | lr=0.097577 | loss=6.6710
[builtin] epoch 03 | lr=0.094605 | loss=5.3163
[builtin] epoch 04 | lr=0.090546 | loss=3.8573
[builtin] epoch 05 | lr=0.085502 | loss=2.4531
[builtin] epoch 06 | lr=0.079595 | loss=1.3830
[builtin] epoch 07 | lr=0.072973 | loss=0.7649
[builtin] epoch 08 | lr=0.065796 | loss=0.4699
[builtin] epoch 09 | lr=0.058244 | loss=0.3321
[builtin] epoch 10 | lr=0.050500 | loss=0.2622
[builtin] epoch 11 | lr=0.042756 | loss=0.2230
[builtin] epoch 12 | lr=0.035204 | loss=0.1989
[builtin] epoch 13 | lr=0.028027 | loss=0.1832
[builtin] epoch 14 | lr=0.021405 | loss=0.1727
[builtin] epoch 15 | lr=0.015498 | loss=0.1655
[builtin] epoch 16 | lr=0.010454 | loss=0.1606
[builtin] epoch 17 | lr=0.006395 | loss=0.1574
[builtin] epoch 18 | lr=0.003423 | loss=0.1553
[builtin] epoch 19 | lr=0.001609 | loss=0.1540
[builtin] epoch 20 | lr=0.001000 | loss=0.1534
