# Implementing Transformer Models
## Practical VIII
Carel van Niekerk & Hsien-Chin Lin

2-6.12.2024

---

In this practical we will implement the learning rate scheduler as well as initialise the optimiser to train our transformer model.

### 1. Learning rate scheduler

The learning rate scheduler adjusts the learning rate during training. It is called at every training step and sets the learning rate for that step. For transformer models the schedule matters: the original paper increases the learning rate linearly during a warm-up phase and then decays it in proportion to the inverse square root of the step number, which keeps the early updates small and helps the model converge to a good solution.
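As a minimal sketch, the schedule from the paper, `lrate = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))`, can be written as a plain Python function. The name `transformer_lr` is hypothetical, and the defaults `d_model=512` and `warmup_steps=4000` are the base-model values from the paper.

```python
def transformer_lr(step: int, d_model: int = 512, warmup_steps: int = 4000) -> float:
    """Learning rate at a given (1-based) training step."""
    step = max(step, 1)  # step 0 would fail in step ** -0.5, so treat it as step 1
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


# The rate rises linearly during warm-up, peaks at step == warmup_steps,
# and then decays in proportion to 1 / sqrt(step).
for s in (1, 1000, 4000, 16000, 100000):
    print(f"step {s:>6}: lr = {transformer_lr(s):.2e}")
```

With these values the peak learning rate is reached at step 4000 and is about 7e-4.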

### 2. Optimiser

In this practical we will use the AdamW optimiser. AdamW is a variant of the Adam optimiser that regularises the model with decoupled weight decay: the decay is applied directly to the parameters instead of being added to the gradient as an L2 penalty.

# Exercises

1. Study the learning rate scheduler used in the paper [Attention is all you need](https://arxiv.org/abs/1706.03762).
2. Implement the learning rate scheduler. Your scheduler class should have the same interface as the PyTorch learning rate scheduler classes, that is, a `step()` method that updates the learning rate and a `get_lr()` method that returns the learning rate for the current step.
3. Study the AdamW optimiser used. Write down the update equations and explain the reasoning behind the bias correction and decoupled weight decay.
4. Initialise an AdamW optimiser for your transformer model. Do not apply weight decay to the bias and layer normalisation parameters.

In [None]:
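from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

class TransformerLRScheduler(LambdaLR):
    """Learning rate schedule from "Attention is all you need".

    LambdaLR multiplies each parameter group's initial learning rate by the
    value returned from lr_lambda, so create the optimiser with lr=1.0 to
    obtain the schedule from the paper unscaled.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 d_model: int,
                 warmup_steps: int = 4000) -> None:
        # Set the attributes before calling super().__init__(): LambdaLR
        # performs an initial step in its constructor, which already calls
        # self.lr_lambda.
        self.d_model = d_model
        self.warmup_steps = warmup_steps

        super().__init__(optimizer, self.lr_lambda)

    def lr_lambda(self, step: int) -> float:
        # step ** (-0.5) fails for step 0, so treat step 0 as step 1
        step = max(step, 1)
        return self.d_model ** (-0.5) * min(step ** (-0.5), step * self.warmup_steps ** (-1.5))

    # other methods (step(), get_lr(), ...) are inherited from LambdaLR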


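# AdamW optimiser update equations

Here $g_t$ is the gradient of the loss with respect to the parameters $\theta_{t-1}$, $m_t$ and $v_t$ are the running estimates of the first and second moments of the gradient (both initialised to zero), and $\beta_1$, $\beta_2$ are their decay rates.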

$$
m_t = \beta_1 \cdot m_{t-1} + (1 - \beta_1) \cdot g_t
$$

$$
v_t = \beta_2 \cdot v_{t-1} + (1 - \beta_2) \cdot (g_t)^2
$$

$$
\hat{m}_t = \frac{m_t}{1 - \beta_1^t}
$$

$$
\hat{v}_t = \frac{v_t}{1 - \beta_2^t}
$$

$$
\theta_t = \theta_{t-1} - \text{lr} \cdot \left( \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \text{weight\_decay} \cdot \theta_{t-1} \right)
$$
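The reasoning behind the two design choices can be checked on a single scalar parameter. Because $m_0 = v_0 = 0$, the raw moments are biased towards zero in the first steps; dividing by $1 - \beta^t$ removes that bias. The weight decay term sits outside the adaptive ratio, so it is not rescaled by $1 / \sqrt{\hat{v}_t}$ as an L2 penalty added to the gradient would be. The sketch below follows the equations above; the function name `adamw_step` and the default hyperparameter values are illustrative assumptions.

```python
import math


def adamw_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999,
               eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter (t is 1-based)."""
    m = beta1 * m + (1 - beta1) * grad        # first moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2   # second moment estimate
    m_hat = m / (1 - beta1 ** t)              # bias-corrected first moment
    v_hat = v / (1 - beta2 ** t)              # bias-corrected second moment
    # Decoupled weight decay: the decay term is added outside the adaptive
    # ratio, so it is not divided by sqrt(v_hat).
    theta = theta - lr * (m_hat / (math.sqrt(v_hat) + eps) + weight_decay * theta)
    return theta, m, v


theta, m, v = adamw_step(theta=1.0, grad=0.5, m=0.0, v=0.0, t=1)
# At t = 1 the raw moments are m = 0.05 and v = 0.00025 (biased towards zero),
# while the corrected values recover grad = 0.5 and grad ** 2 = 0.25.
print(theta, m, v)
```

With a zero gradient the adaptive term vanishes and the parameter simply shrinks by `lr * weight_decay * theta`, which is the decoupled decay in isolation.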


In [None]:
# Initialise AdamW optimiser
import torch

model = ...  # your transformer model
# Parameters whose names contain one of these substrings get no weight decay.
# Adjust the patterns to the names of the layer normalisation modules in your model.
no_decay = ['bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
]
# lr=1.0 because LambdaLR multiplies this base learning rate by the scheduler's value
optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=1.0)

# Initialise learning rate scheduler
scheduler = TransformerLRScheduler(optimizer, d_model=512, warmup_steps=4000)
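
The name-based filter above only excludes layer normalisation weights if those modules are registered under names containing `LayerNorm`. As a hedged alternative, the sketch below splits the parameters by shape instead. It assumes that every bias and normalisation gain is one-dimensional and that every weight matrix or embedding has at least two dimensions. The helper `split_decay_groups` is hypothetical and not part of PyTorch; it only relies on the `(name, parameter)` pairs returned by `model.named_parameters()` and on the `ndim` and `requires_grad` attributes of each parameter.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, List, Tuple


def split_decay_groups(named_parameters: Iterable[Tuple[str, Any]],
                       weight_decay: float = 0.01) -> List[Dict[str, Any]]:
    """Build optimiser parameter groups, decaying only tensors with ndim >= 2."""
    decay, no_decay = [], []
    for _name, param in named_parameters:
        if not param.requires_grad:
            continue  # frozen parameters are not passed to the optimiser
        (decay if param.ndim >= 2 else no_decay).append(param)
    return [
        {'params': decay, 'weight_decay': weight_decay},
        {'params': no_decay, 'weight_decay': 0.0},
    ]


# Stand-in objects so the sketch runs without a model; real parameters
# expose the same two attributes.
fake_named_parameters = [
    ('linear.weight', SimpleNamespace(ndim=2, requires_grad=True)),
    ('linear.bias', SimpleNamespace(ndim=1, requires_grad=True)),
    ('norm.weight', SimpleNamespace(ndim=1, requires_grad=True)),
]
groups = split_decay_groups(fake_named_parameters)
print([len(g['params']) for g in groups])  # → [1, 2]

# Intended usage (assumes `model` is a torch.nn.Module):
# optimizer = torch.optim.AdamW(split_decay_groups(model.named_parameters()))
```

Splitting by shape does not depend on how the modules are named, but it is only correct if the shape assumption holds for your model, so it is worth printing both groups once to check.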