### Warmup steps

The primacy effect refers to a model disproportionately weighting the patterns it learns from the earliest batches of training data. If those early batches happen to contain a cluster of highly similar or unusually distinctive observations, the model may overfit to them early on and generalize poorly.

This is especially problematic when:

- The dataset is highly differentiated (e.g., contains distinct clusters or subgroups).

- The data is not perfectly shuffled, leading to batches that are not representative of the overall dataset.

- The model is sensitive to initialization (e.g., deep neural networks).
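To make the point about imperfect shuffling concrete, here is a minimal sketch using PyTorch's `DataLoader`. The dataset is hypothetical: it is stored sorted by class, which is an assumption chosen to exaggerate the problem. Without shuffling, the first batch contains only one class; with `shuffle=True` it draws from both.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical dataset stored sorted by class: 500 samples of class 0, then 500 of class 1.
features = torch.randn(1000, 10)
labels = torch.cat([torch.zeros(500), torch.ones(500)])
dataset = TensorDataset(features, labels)


def first_batch_class1_fraction(shuffle: bool) -> float:
    """Fraction of class-1 samples in the first batch the loader produces."""
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=shuffle,
        generator=torch.Generator().manual_seed(0),  # fixed seed so the shuffle is reproducible
    )
    _, batch_labels = next(iter(loader))
    return batch_labels.mean().item()


print(first_batch_class1_fraction(shuffle=False))  # 0.0: the first batch sees only class 0
print(first_batch_class1_fraction(shuffle=True))   # strictly between 0 and 1: both classes appear
```

Real datasets are rarely this badly ordered, but partially sorted data (for example, files grouped by source or date) produces the same effect in a milder form.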


**Learning rate warmup** reduces the primacy effect by:

1. **Starting with a small learning rate:** early updates are small, so the model doesn't make drastic changes based on the first few batches.

2. **Gradually increasing the learning rate:** as the model sees more data, the learning rate rises, allowing larger updates once the parameters have been shaped by a broader sample of the dataset.

This helps keep the model from overcommitting to patterns in the early batches, which may not be representative of the entire dataset.
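A rough way to see the first point: with plain SGD the parameter change is `-lr * grad`, so on the same batch a warmed-up step is smaller in exact proportion to the learning rate. The sketch below assumes plain `torch.optim.SGD` rather than Adam, random dummy data, and float64 weights to keep rounding noise out of the comparison. For a given gradient history, Adam's update is also a multiple of `lr`, so the same proportionality holds there.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
base_model = nn.Linear(10, 1).double()
# A dummy "first batch" (random inputs and targets).
x = torch.randn(32, 10, dtype=torch.float64)
y = torch.randn(32, 1, dtype=torch.float64)


def first_step_change(lr: float) -> float:
    """Norm of the parameter change made by one plain-SGD step on the same batch."""
    model = copy.deepcopy(base_model)  # identical starting weights for every lr
    before = torch.cat([p.detach().flatten().clone() for p in model.parameters()])
    optimizer = optim.SGD(model.parameters(), lr=lr)
    loss = ((model(x) - y) ** 2).mean()
    loss.backward()
    optimizer.step()
    after = torch.cat([p.detach().flatten() for p in model.parameters()])
    return (after - before).norm().item()


warm = first_step_change(1e-6)  # first step of a 1000-step linear warmup to a peak lr of 1e-3
full = first_step_change(1e-3)  # the same step taken at the peak lr, with no warmup
print(f"{full / warm:.1f}")  # 1000.0: the warmed-up step moves the weights 1000x less
```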

**Warmup duration:**

- A common heuristic is to warm up for one epoch (i.e., the learning rate reaches its maximum value after the model has seen the entire dataset once). In practice, warmup is also often specified directly as a number of optimizer steps.
- For highly skewed datasets, a longer warmup (e.g., 2-3 epochs) gives the model time to see enough of the data to offset the influence of the early batches.
- For homogeneous datasets, a shorter warmup (e.g., half an epoch) may suffice.
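Schedulers usually count optimizer steps rather than epochs, so an epoch-based warmup length has to be converted first. A minimal sketch, assuming one optimizer step per batch and a loader that keeps the final partial batch (with `drop_last=True` you would round `num_samples / batch_size` down instead); the helper name `warmup_steps_for` is hypothetical:

```python
import math


def warmup_steps_for(num_samples: int, batch_size: int, warmup_epochs: float) -> int:
    """Convert a warmup length given in epochs into a number of optimizer steps."""
    # One step per batch; the last, partial batch still counts as a step.
    steps_per_epoch = math.ceil(num_samples / batch_size)
    # Round up so that a fractional-epoch warmup never ends early.
    return math.ceil(warmup_epochs * steps_per_epoch)


print(warmup_steps_for(50_000, 32, 1))    # 1563 (one epoch)
print(warmup_steps_for(50_000, 32, 3))    # 4689 (longer warmup for a skewed dataset)
print(warmup_steps_for(50_000, 32, 0.5))  # 782  (half an epoch for a homogeneous dataset)
```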


There are two common types of warmup:
1. **Constant warmup:** a fixed, low learning rate is used for the first *n* steps to warm up the network, after which training switches to the actual (target) learning rate.
2. **Linear warmup:** in the first few steps, the learning rate is set below the base learning rate and increased gradually with each step until it reaches the base value.
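Both variants can be written as a multiplier on the target learning rate. The sketch below plugs them into PyTorch's `torch.optim.lr_scheduler.LambdaLR`. The 5-step warmup, the 10% constant-warmup factor, and the helper names are illustrative assumptions, not values taken from this notebook.

```python
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR

BASE_LR = 1e-3                # target ("actual") learning rate
WARMUP_STEPS = 5              # hypothetical warmup length, kept tiny for illustration
CONSTANT_WARMUP_FACTOR = 0.1  # hypothetical: constant warmup runs at 10% of BASE_LR


def constant_warmup(step: int) -> float:
    """Multiplier on BASE_LR: a fixed low value for the first WARMUP_STEPS steps, then 1."""
    return CONSTANT_WARMUP_FACTOR if step < WARMUP_STEPS else 1.0


def linear_warmup(step: int) -> float:
    """Multiplier on BASE_LR: rises linearly from 1/WARMUP_STEPS to 1, then stays at 1."""
    return min(1.0, (step + 1) / WARMUP_STEPS)


def lr_trace(factor_fn, num_steps: int = 8) -> list[float]:
    """Record the learning rate in effect at each optimizer step under a LambdaLR schedule."""
    model = nn.Linear(10, 1)
    optimizer = optim.SGD(model.parameters(), lr=BASE_LR)
    scheduler = LambdaLR(optimizer, lr_lambda=factor_fn)
    lrs = []
    for _ in range(num_steps):
        lrs.append(optimizer.param_groups[0]["lr"])
        optimizer.step()  # a real loop would compute a loss and call backward() first
        scheduler.step()  # advance the schedule after the optimizer step
    return lrs


print([round(lr, 6) for lr in lr_trace(constant_warmup)])
# [0.0001, 0.0001, 0.0001, 0.0001, 0.0001, 0.001, 0.001, 0.001]
print([round(lr, 6) for lr in lr_trace(linear_warmup)])
# [0.0002, 0.0004, 0.0006, 0.0008, 0.001, 0.001, 0.001, 0.001]
```

`LambdaLR` multiplies the optimizer's initial `lr` by the value returned for the current step, so the same training loop works with either schedule. Setting `param_group['lr']` by hand, as in the cell below, is an equivalent and more explicit alternative.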

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Warmup parameters
warmup_iterations = 1000
max_learning_rate = 0.001

model = nn.Linear(10, 1)
optimizer = optim.Adam(model.parameters(), lr=max_learning_rate)

for iteration in range(1, 5001):
    # Linear warmup: scale the lr from max_learning_rate / warmup_iterations up to
    # max_learning_rate, then hold it constant for the remaining iterations.
    lr = min(1.0, iteration / warmup_iterations) * max_learning_rate
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    optimizer.zero_grad()
    outputs = model(torch.randn(32, 10))  # Dummy input
    loss = ((outputs - torch.randn(32, 1)) ** 2).mean()  # Dummy MSE loss against random targets
    loss.backward()
    optimizer.step()

    if iteration % 100 == 0:
        print(f"Iteration {iteration}, Learning Rate: {lr:.6f}, Loss: {loss.item():.6f}")
```

Iteration 100, Learning Rate: 0.000100, Loss: 1.307066
Iteration 200, Learning Rate: 0.000200, Loss: 2.096521
Iteration 300, Learning Rate: 0.000300, Loss: 1.370643
Iteration 400, Learning Rate: 0.000400, Loss: 1.518445
Iteration 500, Learning Rate: 0.000500, Loss: 1.430872
Iteration 600, Learning Rate: 0.000600, Loss: 1.472371
Iteration 700, Learning Rate: 0.000700, Loss: 0.922527
Iteration 800, Learning Rate: 0.000800, Loss: 1.153891
Iteration 900, Learning Rate: 0.000900, Loss: 1.240970
Iteration 1000, Learning Rate: 0.001000, Loss: 0.952746
Iteration 1100, Learning Rate: 0.001000, Loss: 0.969623
Iteration 1200, Learning Rate: 0.001000, Loss: 1.225448
Iteration 1300, Learning Rate: 0.001000, Loss: 0.889674
Iteration 1400, Learning Rate: 0.001000, Loss: 0.923594
Iteration 1500, Learning Rate: 0.001000, Loss: 0.818584
Iteration 1600, Learning Rate: 0.001000, Loss: 1.305761
Iteration 1700, Learning Rate: 0.001000, Loss: 0.796680
Iteration 1800, Learning Rate: 0.001000, Loss: 1.208030
I