# Transformer Fundamentals: Multi-Head Attention

This notebook demonstrates the core mechanism of transformers: **multi-head attention**. We'll train a simple attention model on a toy task and visualize the attention patterns it learns.
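For reference, here are the standard scaled dot-product and multi-head attention definitions from the original Transformer paper. They use $h$ heads and a per-head key dimension $d_k$, which is presumably `DIM_MODEL / NUM_HEADS` = 16 in this notebook. This is the textbook formulation. `neural_stack`'s implementation may differ in details such as bias terms or scaling.

$$
\begin{aligned}
\mathrm{Attention}(Q, K, V) &= \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V \\
\mathrm{head}_i &= \mathrm{Attention}\left(Q W_i^Q,\; K W_i^K,\; V W_i^V\right), \quad i = 1, \dots, h \\
\mathrm{MultiHead}(Q, K, V) &= \mathrm{Concat}(\mathrm{head}_1, \dots, \mathrm{head}_h)\, W^O
\end{aligned}
$$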

---

## Imports

In [1]:
import torch
from tqdm import tqdm

from neural_stack.attention import MultiHeadAttention
from neural_stack.visualization import plot_attention_heatmap
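`neural_stack` is not part of PyTorch, and its internals are not shown in this notebook. The class below is a hypothetical stand-in, `SketchMultiHeadAttention`. It is a minimal sketch of a module that returns `(output, attention_weights)` like the one used here, assuming the standard per-head split of `dim_model`. Its `(query, key, value)` argument order is also an assumption, because `neural_stack`'s positional order may differ. For that reason, the usage line passes keywords. Following the Conclusion, the queries come from the reversed sequence `y` and the keys and values from `x`.

```python
import math

import torch
import torch.nn as nn


class SketchMultiHeadAttention(nn.Module):
    """Hypothetical stand-in for neural_stack.attention.MultiHeadAttention."""

    def __init__(self, num_heads: int, dim_model: int, dropout: float = 0.0):
        super().__init__()
        assert dim_model % num_heads == 0, "dim_model must be divisible by num_heads"
        self.num_heads = num_heads
        self.dim_head = dim_model // num_heads
        # Learned projections for queries, keys, values and the merged heads
        self.w_q = nn.Linear(dim_model, dim_model)
        self.w_k = nn.Linear(dim_model, dim_model)
        self.w_v = nn.Linear(dim_model, dim_model)
        self.w_o = nn.Linear(dim_model, dim_model)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t: torch.Tensor) -> torch.Tensor:
        # [batch, seq, dim_model] -> [batch, heads, seq, dim_head]
        batch, seq, _ = t.shape
        return t.view(batch, seq, self.num_heads, self.dim_head).transpose(1, 2)

    def forward(self, query, key, value):
        q = self._split_heads(self.w_q(query))
        k = self._split_heads(self.w_k(key))
        v = self._split_heads(self.w_v(value))
        # Scaled dot-product scores: [batch, heads, seq_q, seq_k]
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.dim_head)
        attn = torch.softmax(scores, dim=-1)
        # Weighted sum of values, then merge the heads back into dim_model
        context = self.dropout(attn) @ v
        batch, heads, seq, dim_head = context.shape
        context = context.transpose(1, 2).reshape(batch, seq, heads * dim_head)
        return self.w_o(context), attn


# Usage: queries from the reversed sequence, keys/values from the original
x = torch.randn(2, 10, 64)
y = torch.flip(x, dims=[1])
mha = SketchMultiHeadAttention(num_heads=4, dim_model=64)
out, attn = mha(query=y, key=x, value=x)
print(out.shape, attn.shape)  # torch.Size([2, 10, 64]) torch.Size([2, 4, 10, 10])
```

Each head returns one row-normalised `seq_q x seq_k` matrix, which is the kind of per-head pattern the visualization at the end plots.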

---

## Configuration

Hyperparameters for the multi-head attention experiment.

In [2]:
SEQ_LENGTH = 10
DIM_MODEL = 64
NUM_HEADS = 4

BATCH_SIZE = 32
NUM_ITERS = 2000

LEARNING_RATE = 1e-2

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
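The configuration implies a few derived quantities. The sketch below computes them, assuming `neural_stack` splits `DIM_MODEL` evenly across heads and uses four `DIM_MODEL x DIM_MODEL` projections with biases. Both are standard choices but are not confirmed by this notebook. `DIM_HEAD`, `ATTN_WEIGHTS_PER_SAMPLE` and `NUM_PARAMS` are illustrative names, not part of the original code.

```python
SEQ_LENGTH = 10
DIM_MODEL = 64
NUM_HEADS = 4

# Standard multi-head attention splits DIM_MODEL evenly across the heads
assert DIM_MODEL % NUM_HEADS == 0, "DIM_MODEL must be divisible by NUM_HEADS"
DIM_HEAD = DIM_MODEL // NUM_HEADS

# Each head produces one SEQ_LENGTH x SEQ_LENGTH attention matrix per sample
ATTN_WEIGHTS_PER_SAMPLE = NUM_HEADS * SEQ_LENGTH * SEQ_LENGTH

# Assumption: four DIM_MODEL x DIM_MODEL projections (Q, K, V, output), each with a bias
NUM_PARAMS = 4 * (DIM_MODEL * DIM_MODEL + DIM_MODEL)

print(DIM_HEAD, ATTN_WEIGHTS_PER_SAMPLE, NUM_PARAMS)  # 16 400 16640
```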

---

## Data Generation

**Task**: Sequence reversal. The model learns to reverse input sequences using attention.

For each sample:
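- Input `x`: Random sequence of shape `[seq_length, dim_model]`, with entries drawn from a standard normal distribution
- Target `y`: `x` reversed along the sequence dimension (dim 0), so that `y[i]` equals `x[seq_length - 1 - i]`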
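The reversal applies to the order of positions, not to the features inside each vector. A small sketch makes the difference visible. It uses readable `arange` values in place of `torch.randn` (an illustrative choice) and contrasts `dims=[0]` with the incorrect `dims=[1]`.

```python
import torch

seq_length, dim_model = 4, 3
# Small, readable values instead of torch.randn
x = torch.arange(seq_length * dim_model, dtype=torch.float32).reshape(seq_length, dim_model)

y = torch.flip(x, dims=[0])      # reverses positions; each vector stays intact
wrong = torch.flip(x, dims=[1])  # reverses features inside each vector instead

print(x[0].tolist(), y[0].tolist())  # [0.0, 1.0, 2.0] [9.0, 10.0, 11.0]
print(wrong[0].tolist())             # [2.0, 1.0, 0.0]
```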

In [3]:
class RandomSequenceGenerator(torch.utils.data.IterableDataset):
    """Infinite stream of (sequence, reversed sequence) pairs."""

    def __init__(self, seq_length, dim_model):
        super().__init__()
        self.seq_length = seq_length
        self.dim_model = dim_model

    def __iter__(self):
        while True:
            # Random vectors of shape [seq_length, dim_model]
            x = torch.randn((self.seq_length, self.dim_model))
            # Reverse the order of positions (dim 0); each vector stays intact
            y = torch.flip(x, dims=[0])
            yield x, y
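Because this `IterableDataset` never ends, a `DataLoader` over it only stacks samples into batches, and something else has to stop the iteration. The sketch below repeats the generator so it runs standalone and checks the batched shapes. It then uses `itertools.islice` to take a fixed number of batches, one possible alternative to the manual `iter_idx` counter in the training loop. After batching, the reversal runs along dim 1, because dim 0 is now the batch.

```python
import itertools

import torch


class RandomSequenceGenerator(torch.utils.data.IterableDataset):
    # Same generator as above, repeated so this block runs on its own
    def __init__(self, seq_length, dim_model):
        super().__init__()
        self.seq_length = seq_length
        self.dim_model = dim_model

    def __iter__(self):
        while True:
            x = torch.randn((self.seq_length, self.dim_model))
            yield x, torch.flip(x, dims=[0])


loader = torch.utils.data.DataLoader(RandomSequenceGenerator(10, 64), batch_size=32)

# The stream never ends, so take a fixed number of batches explicitly
batches = list(itertools.islice(loader, 3))
x, y = batches[0]
print(x.shape, y.shape)  # torch.Size([32, 10, 64]) torch.Size([32, 10, 64])
```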

---

## Model Setup

Initialize multi-head attention model, optimizer, and loss function.

In [4]:
multi_head_attn = MultiHeadAttention(num_heads=NUM_HEADS, dim_model=DIM_MODEL, dropout=0.0)
multi_head_attn = multi_head_attn.to(device)

dataset_iterator = RandomSequenceGenerator(seq_length=SEQ_LENGTH, dim_model=DIM_MODEL)
dataloader = torch.utils.data.DataLoader(
    dataset_iterator,
    batch_size=BATCH_SIZE
)

optimizer = torch.optim.Adam(
    params=multi_head_attn.parameters(),
    lr=LEARNING_RATE
)

# Decay the learning rate by 10x every 400 iterations
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=400, gamma=0.1)

criterion = torch.nn.MSELoss()
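`StepLR` multiplies the learning rate by `gamma=0.1` every `step_size=400` calls to `scheduler.step()`. Since the training loop calls it once per iteration, the rate falls from `1e-2` to `1e-6` over the 2000 iterations. The sketch below traces that schedule with a throwaway one-parameter model. No gradients are computed, so `optimizer.step()` leaves the parameter untouched and only the schedule advances.

```python
import torch

# A throwaway parameter so the optimizer has something to manage
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=400, gamma=0.1)

lrs = []
for _ in range(2000):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used at this iteration
    optimizer.step()  # no gradients here, so parameters are left unchanged
    scheduler.step()

for start in range(0, 2000, 400):
    print(f"iterations {start}-{start + 399}: lr = {lrs[start]:.0e}")
```

With this schedule, the last 400 iterations run at `1e-6`, so most of the learning presumably happens in the first few hundred steps.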

---

## Training Loop

Train the attention mechanism to learn the sequence reversal task.

In [5]:
iter_idx = 0
progressbar = tqdm(dataloader, total=NUM_ITERS)
for x, y in progressbar:
    x = x.to(device).float()
    y = y.to(device).float()
    
    out, attn_scores = multi_head_attn(x, y, x)
    loss = criterion(out, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    if iter_idx % 10 == 0:
        progressbar.set_postfix({'loss': f'{loss.item():.4f}'})

    iter_idx += 1
    if iter_idx >= NUM_ITERS:
        break

100%|█████████▉| 1999/2000 [00:06<00:00, 311.23it/s, loss=0.0055]
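The logged loss is the MSE averaged over every element of a `32 x 10 x 64` batch. For scale, the sketch below computes a trivial baseline: a model that always outputs zeros. Targets are standard normal, so that baseline scores about `Var(y) = 1`, and the final logged value of 0.0055 sits far below it.

```python
import torch

torch.manual_seed(0)
y = torch.randn(32, 10, 64)  # one batch of targets: BATCH_SIZE x SEQ_LENGTH x DIM_MODEL
criterion = torch.nn.MSELoss()

# A model that always outputs zeros scores roughly Var(y) = 1
zero_loss = criterion(torch.zeros_like(y), y).item()
print(f"all-zeros baseline MSE: {zero_loss:.3f}")
```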


---

## Validation

Generate a validation sample and capture attention scores for visualization.

In [6]:
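multi_head_attn.eval()
val_x, val_y = next(iter(dataset_iterator))

# Add a batch dimension of 1: [seq_length, dim_model] -> [1, seq_length, dim_model]
val_x = val_x.to(device).float().unsqueeze(0)
val_y = val_y.to(device).float().unsqueeze(0)

# No gradients are needed for inference
with torch.no_grad():
    val_out, val_attn_scores = multi_head_attn(val_x, val_y, val_x)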
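To put a number on the pattern instead of only eyeballing the heatmap, a hypothetical helper `antidiagonal_hit_rate` can count how often each query's largest attention weight falls on the anti-diagonal. It assumes the scores have shape `[batch, heads, seq_q, seq_k]`. The layout `neural_stack` actually returns isn't shown here, though an anti-diagonal reads the same if the last two axes are swapped. The checks below use synthetic matrices rather than the trained model.

```python
import torch


def antidiagonal_hit_rate(attn: torch.Tensor) -> torch.Tensor:
    """Fraction of query rows whose largest weight sits on the anti-diagonal.

    Assumes attn has shape [batch, heads, seq_q, seq_k] with seq_q == seq_k.
    Returns one value per head, averaged over the batch.
    """
    seq_len = attn.shape[-1]
    # For query i, the mirrored key index is seq_len - 1 - i
    expected = torch.arange(seq_len - 1, -1, -1, device=attn.device)
    hits = attn.argmax(dim=-1) == expected  # [batch, heads, seq_q]
    return hits.float().mean(dim=(0, 2))


# Synthetic checks: a perfect anti-diagonal vs. a plain diagonal, 4 heads
perfect = torch.flip(torch.eye(10), dims=[1]).expand(1, 4, 10, 10)
identity = torch.eye(10).expand(1, 4, 10, 10)
print(antidiagonal_hit_rate(perfect))   # 1.0 for every head
print(antidiagonal_hit_rate(identity))  # 0.0 for every head
```

If the layout assumption holds, `antidiagonal_hit_rate(val_attn_scores)` would give one score per head for the validation sample.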

---

## Attention Visualization

Visualize how each attention head learned to attend to different positions in the sequence.

The plot shows:
- **Top**: Average attention pattern across all heads
- **Bottom**: Individual attention patterns for each of the 4 heads

In [7]:
fig = plot_attention_heatmap(val_attn_scores, batch_idx=0)
fig.show()

---

## Conclusion

This toy task forces the attention mechanism to relate each desired output position (supplied through the Query) to an input position (supplied through the Keys and Values). Because the Query tensor is constructed from the reversed input, the loss is minimised when query $i$ places its attention on key/value position $n - 1 - i$, where $n$ is `SEQ_LENGTH`. The MSE loss penalises any mismatch between output and target. Since the sequence elements are independent random vectors, attending to a wrong position costs roughly the same whether it is one index or several indices away from the correct one. The attention weights are therefore forced into an anti-diagonal pattern, which reverses the Key/Value inputs.
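The claim about the MSE loss can be checked with a short sketch. It compares outputs that copy an input position `offset` places (cyclically) away from the correct one. For independent standard-normal entries $a$ and $b$, $a - b \sim \mathcal{N}(0, 2)$, so each mismatched element contributes about $\mathbb{E}[(a-b)^2] = 2$ whatever the offset. The batch size of 256 is an arbitrary choice made to reduce noise.

```python
import torch

torch.manual_seed(0)
batch, seq_length, dim_model = 256, 10, 64
x = torch.randn(batch, seq_length, dim_model)
y = torch.flip(x, dims=[1])  # target: reversed sequence
mse = torch.nn.MSELoss()

losses = {}
for offset in (1, 3, 5):
    # Each output position copies the input that is `offset` places from the correct one
    wrong = torch.roll(y, shifts=offset, dims=1)
    losses[offset] = mse(wrong, y).item()
    print(f"offset {offset}: MSE = {losses[offset]:.3f}")
```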

Using purely random values for every element produced by `RandomSequenceGenerator` is unusual for a classic deep-learning problem, but it simplifies the toy problem of testing the attention mechanism. The model does not have to (and cannot) learn anything from the data distribution, since it is completely random. Thus, the only exploitable information in the data is its positional ordering.
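A rough sketch of why random data is enough: independent high-dimensional random vectors are nearly orthogonal. As a result, the dot product of query $y_i$ with key $x_j$ is large only when $x_j$ is the same vector, that is, when $j = n - 1 - i$. The sketch makes a deliberate simplification: it uses identity projections (no learned weights at all). This is not what the trained model computes, but it shows that the anti-diagonal already falls out of the raw scores.

```python
import math

import torch

torch.manual_seed(0)
batch, seq_length, dim_model = 32, 10, 64
x = torch.randn(batch, seq_length, dim_model)
y = torch.flip(x, dims=[1])

# Identity projections: queries come from y, keys from x (no learned weights)
scores = y @ x.transpose(-2, -1) / math.sqrt(dim_model)  # [batch, seq_q, seq_k]
attn = torch.softmax(scores, dim=-1)

# Query i should pick key n - 1 - i
expected = torch.arange(seq_length - 1, -1, -1)
hit_rate = (attn.argmax(dim=-1) == expected).float().mean().item()
print(f"fraction of queries whose argmax is on the anti-diagonal: {hit_rate:.3f}")
```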