# Implementing Transformer Models
## Practical VIII
Carel van Niekerk & Hsien-Chin Lin

8-12.12.2025

---

In this practical we will implement the learning rate scheduler as well as initialise the optimiser to train our transformer model.

### 1. Learning rate scheduler

The learning rate scheduler adjusts the learning rate during training. It is called at every training step and returns the learning rate for that step. For transformer models the schedule matters because training runs for many steps: the original transformer increases the learning rate linearly during a warmup phase and then decays it in proportion to the inverse square root of the step number.
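As an illustration of what such a scheduler computes, the schedule from *Attention Is All You Need* can be sketched in plain Python. This is a minimal sketch, not a reference solution: the class name `WarmupInverseSqrtSchedule` and its constructor arguments are assumptions made for this example, and the class only tracks the step count rather than updating an optimiser.

```python
class WarmupInverseSqrtSchedule:
    """lr = d_model**-0.5 * min(step**-0.5, step * warmup_steps**-1.5)."""

    def __init__(self, d_model: int, warmup_steps: int = 4000):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        self.step_num = 0  # number of completed calls to step()

    def step(self) -> None:
        # Advance the schedule by one training step.
        self.step_num += 1

    def get_lr(self) -> float:
        # Clamp to 1 so that get_lr() before the first step() does not divide by zero.
        step = max(self.step_num, 1)
        scale = min(step ** -0.5, step * self.warmup_steps ** -1.5)
        return self.d_model ** -0.5 * scale


schedule = WarmupInverseSqrtSchedule(d_model=512, warmup_steps=4000)
lrs = []
for _ in range(8000):
    schedule.step()
    lrs.append(schedule.get_lr())

# The learning rate rises linearly, peaks at the end of warmup, then decays.
peak_step = lrs.index(max(lrs)) + 1
print(f"peak learning rate {max(lrs):.6f} at step {peak_step}")
```

In a real training loop the value returned by `get_lr()` would still have to be written into the optimiser's parameter groups at every step.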

### 2. Optimiser

In this practical we will use the AdamW optimiser. AdamW is a variant of the Adam optimiser that regularises the model by applying weight decay directly to the parameters (decoupled weight decay) instead of adding an L2 penalty to the gradient.

# Exercises

1. Study the learning rate scheduler used in the paper [Attention is all you need](https://arxiv.org/abs/1706.03762).
2. Implement the learning rate scheduler. Your scheduler class should have the same interface as the PyTorch learning rate scheduler classes: a `step()` method that updates the learning rate and a `get_lr()` method that returns the learning rate for the current step.
3. Study the AdamW optimiser used. Write down the update equations and explain the reasoning behind the bias correction and decoupled weight decay.
4. Initialise an AdamW optimiser for your transformer model. It is important not to apply weight decay to the bias and layer normalisation parameters.
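One possible way to keep weight decay off the bias and layer normalisation parameters is to split the parameters into two groups. The helper below is a hedged sketch: the name `split_weight_decay_groups` is hypothetical, and it relies on the assumption that biases and layer normalisation weights are exactly the parameters with fewer than two dimensions (or with a name ending in `.bias`), which should be checked against your own model.

```python
def split_weight_decay_groups(named_params, weight_decay=0.01):
    """Split (name, parameter) pairs into a decayed and a non-decayed group.

    Assumption: 1-D parameters (biases, layer norm scales and shifts) are
    excluded from weight decay; everything else is decayed.
    """
    decay, no_decay = [], []
    for name, param in named_params:
        # Skip frozen parameters; objects without the attribute count as trainable.
        if not getattr(param, "requires_grad", True):
            continue
        if param.ndim < 2 or name.endswith(".bias"):
            no_decay.append(param)
        else:
            decay.append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

With PyTorch, the returned list can be passed as parameter groups, for example `torch.optim.AdamW(split_weight_decay_groups(model.named_parameters()), lr=1e-3)`, where `model` stands for your transformer and the learning rate is a placeholder.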

## Exercise 3: AdamW Optimizer

### Update Equations

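Given parameters $\theta$, gradients $g_t$, learning rate $\alpha$, weight decay $\lambda$, exponential decay rates $\beta_1, \beta_2$ for the moment estimates, and a small constant $\epsilon$ for numerical stability (with $g_t^2$ taken element-wise):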

**1. Compute biased first moment estimate:**
$$m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t$$

**2. Compute biased second moment estimate:**
$$v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot g_t^2$$

**3. Bias correction:**
$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

**4. Parameter update with decoupled weight decay:**
$$\theta_t = \theta_{t-1} - \alpha \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda \cdot \theta_{t-1} \right)$$
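The four steps above can be sketched for a single scalar parameter in plain Python. This is an illustrative sketch of the equations as written here, not the implementation used by any library; the function name `adamw_step` and the default hyperparameter values are assumptions.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    # 1. Biased first moment estimate.
    m = beta1 * m + (1 - beta1) * grad
    # 2. Biased second moment estimate.
    v = beta2 * v + (1 - beta2) * grad ** 2
    # 3. Bias correction.
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    # 4. Adaptive update plus decoupled weight decay, both based on the old theta.
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v


theta, m, v = 1.0, 0.0, 0.0
theta, m, v = adamw_step(theta, grad=0.5, m=m, v=v, t=1, lr=0.1, weight_decay=0.01)
print(theta, m, v)
```

At the first step the bias-corrected ratio $\hat{m}_1 / \sqrt{\hat{v}_1}$ has magnitude close to 1, so the adaptive part of the update has size close to $\alpha$ regardless of the gradient's scale.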

### Bias Correction Reasoning

At initialization, $m_0 = 0$ and $v_0 = 0$. In early steps, both estimates are **biased toward zero**:
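- After step 1: $m_1 = (1-\beta_1) g_1$, i.e. only $0.1\, g_1$ for $\beta_1 = 0.9$, much smaller than the true mean gradient
- The bias is more severe the closer $\beta$ is to 1: with the common values $\beta_1 = 0.9$ and $\beta_2 = 0.999$, $v_1 = 0.001\, g_1^2$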

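Dividing by the correction factor $(1 - \beta^t)$ compensates for this. For a roughly stationary gradient, $\mathbb{E}[m_t] \approx (1 - \beta_1^t)\,\mathbb{E}[g_t]$, so $\hat{m}_t$ is approximately unbiased (and likewise for $\hat{v}_t$):
- At $t=1$: $m_1$ is divided by $(1 - 0.9) = 0.1$ (scaled up 10x) and $v_1$ by $(1 - 0.999) = 0.001$ (scaled up 1000x)
- As $t \to \infty$: $(1 - \beta^t) \to 1$, so the correction vanishes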
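The effect of the correction can be checked numerically. The sketch below assumes a constant gradient (chosen here only so that the true mean is known exactly) and compares the raw and bias-corrected estimates over the first few steps.

```python
beta1, beta2 = 0.9, 0.999
g = 2.0  # constant gradient: true mean of g is 2.0, true mean of g^2 is 4.0

m, v = 0.0, 0.0
rows = []
for t in range(1, 6):
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)  # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)  # bias-corrected second moment
    rows.append((t, m, m_hat, v, v_hat))
    print(f"t={t}  m={m:.4f}  m_hat={m_hat:.4f}  v={v:.5f}  v_hat={v_hat:.4f}")
```

The raw estimates start far below the true values and grow slowly, while the corrected estimates match the true moments from the first step onwards in this constant-gradient setting.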

### Decoupled Weight Decay Reasoning

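**Original Adam (L2 regularization):** Adds $\lambda \theta_{t-1}$ to the gradient before the moment estimates are computed, so the penalty passes through the adaptive scaling. Because $m_t$ is linear in the gradient and $\theta$ changes slowly, the update is approximately:
$$\theta_t \approx \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t + \lambda \theta_{t-1}}{\sqrt{\hat{v}_t} + \epsilon}$$
where $\hat{v}_t$ is now also computed from the penalised gradient.

**Problem:** The weight decay is scaled by $1/\sqrt{\hat{v}_t}$, so parameters with large gradient magnitudes are decayed less than parameters with small ones.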

**AdamW (decoupled):** Applies weight decay directly to the parameters, **separately from** the adaptive update, so it bypasses the $1/\sqrt{\hat{v}_t}$ scaling:
$$\theta_t = \theta_{t-1} - \alpha \cdot \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} - \alpha \lambda \theta_{t-1}$$
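To make the difference concrete, the sketch below compares the size of the decay term under both formulas for two parameters with the same value but different gradient scales. It is a deliberate simplification: $\hat{v}_t$ is held fixed at the squared gradient scale, and the function name `decay_contribution` and all numeric values are assumptions chosen for illustration.

```python
import math


def decay_contribution(theta, v_hat, lr=0.1, weight_decay=0.01, eps=1e-8):
    """Return the per-step decay term (L2-in-gradient, decoupled)."""
    # L2 inside the gradient: the penalty is divided by sqrt(v_hat) + eps.
    l2_in_gradient = lr * weight_decay * theta / (math.sqrt(v_hat) + eps)
    # Decoupled: the penalty is applied to the parameter directly.
    decoupled = lr * weight_decay * theta
    return l2_in_gradient, decoupled


# Same parameter value, gradient scales of 0.01 and 10.
small_grad = decay_contribution(theta=1.0, v_hat=0.01 ** 2)
large_grad = decay_contribution(theta=1.0, v_hat=10.0 ** 2)
print("L2 in gradient:", small_grad[0], large_grad[0])
print("decoupled:     ", small_grad[1], large_grad[1])
```

Under these assumptions the L2 variant shrinks the small-gradient parameter roughly a thousand times more strongly than the large-gradient one, whereas the decoupled term is identical for both.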

**Benefits:**
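1. Every parameter decays at the same relative rate $\alpha\lambda$ per step, regardless of its gradient history
2. Matches weight decay in plain SGD, where adding the penalty $\frac{\lambda}{2}\lVert\theta\rVert^2$ to the loss produces exactly the same $-\alpha\lambda\theta_{t-1}$ term
3. Often generalizes better than Adam with L2 regularization in practice
4. Hyperparameter $\lambda$ has the same interpretation as in SGD with weight decay, so its meaning carries over between optimizers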