# Attention Mechanism with PyTorch Modules

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/attention_with_pytorch.ipynb)

This notebook implements the same attention mechanism as `attention_from_scratch.ipynb`, but using **PyTorch's `nn.Module` and `nn.Linear`** instead of raw tensor ops.

> **Companion notebook:** [attention_from_scratch.ipynb](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/attention_from_scratch.ipynb) — same content using only raw tensor ops + math foundations.

In [None]:
# Install dependencies (Colab preinstalls torch and matplotlib; this matters only when running elsewhere)
!pip install torch matplotlib -q

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import matplotlib.pyplot as plt

torch.manual_seed(42)

# Use GPU if available (Colab: Runtime > Change runtime type > GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Setup: Token Embeddings via `nn.Embedding`

Instead of hand-made random tensors, we use `nn.Embedding`, a learnable lookup table that maps each token index to a dense vector of size `d_model`.

In [None]:
vocab_size = 50
seq_len = 4
d_model = 8

embedding = nn.Embedding(vocab_size, d_model).to(device)

# Simulate token indices for ["The", "cat", "sat", "down"]
token_ids = torch.tensor([5, 12, 31, 7], device=device)
X = embedding(token_ids)  # (seq_len, d_model)

print("Token IDs:", token_ids)
print("Embeddings shape:", X.shape)
print(X)
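`nn.Embedding` is just a learnable lookup table: its weight is a `(vocab_size, d_model)` matrix (initialized from $\mathcal{N}(0, 1)$ by default), and calling it on indices selects rows. A minimal standalone check, using a fresh layer with the same sizes as above, that the lookup equals plain row indexing and also equals multiplying one-hot vectors by the weight:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, d_model = 50, 8
emb = nn.Embedding(vocab_size, d_model)
ids = torch.tensor([5, 12, 31, 7])

looked_up = emb(ids)                          # (4, d_model)
by_index = emb.weight[ids]                    # same rows, selected by plain indexing
one_hot = F.one_hot(ids, vocab_size).float()  # (4, vocab_size)
by_matmul = one_hot @ emb.weight              # each one-hot row picks out one weight row

print(emb.weight.shape)                       # torch.Size([50, 8])
print(torch.equal(looked_up, by_index))       # True
print(torch.allclose(looked_up, by_matmul))   # True
```

The one-hot view explains why the embedding is trainable like any other linear layer; the index lookup is simply the cheap way to compute it.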

## 2. Scaled Dot-Product Attention

$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$

We use `nn.Linear` for the Q, K, V projections and `F.softmax` for the softmax.
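Why divide by $\sqrt{d_k}$? If the entries of a query and a key are independent with mean 0 and variance 1, their dot product is a sum of $d_k$ such products and has variance $d_k$. Large $d_k$ therefore pushes scores into the range where softmax saturates, and dividing by $\sqrt{d_k}$ brings the variance back to about 1. A quick empirical check, assuming standard-normal entries (an idealization; real projected vectors are not exactly unit-variance):

```python
import math
import torch

torch.manual_seed(0)
n_pairs, d_k = 20_000, 64

q = torch.randn(n_pairs, d_k)   # iid N(0, 1) query entries
k = torch.randn(n_pairs, d_k)   # iid N(0, 1) key entries

raw = (q * k).sum(dim=-1)       # n_pairs unscaled dot products
scaled = raw / math.sqrt(d_k)   # the scaling used in the attention formula

print(f"var(q.k)            ~ {raw.var().item():.1f}   (theory: {d_k})")
print(f"var(q.k / sqrt(d_k)) ~ {scaled.var().item():.2f}  (theory: 1)")
```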

In [None]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_model, d_k, d_v):
        super().__init__()
        self.W_Q = nn.Linear(d_model, d_k, bias=False)
        self.W_K = nn.Linear(d_model, d_k, bias=False)
        self.W_V = nn.Linear(d_model, d_v, bias=False)
        self.scale = math.sqrt(d_k)

    def forward(self, X, mask=None):
        Q = self.W_Q(X)  # (seq_len, d_k)
        K = self.W_K(X)  # (seq_len, d_k)
        V = self.W_V(X)  # (seq_len, d_v)

        scores = Q @ K.T / self.scale  # (seq_len, seq_len)

        if mask is not None:
            scores = scores.masked_fill(mask, float('-inf'))

        weights = F.softmax(scores, dim=-1)
        output = weights @ V  # (seq_len, d_v)

        return output, weights

In [None]:
d_k = 6
d_v = 6

attn = ScaledDotProductAttention(d_model, d_k, d_v).to(device)
output, weights = attn(X)

print("Output shape:", output.shape)
print("Attention weights (each row sums to 1):")
print(weights)
print("Row sums:", weights.sum(dim=-1))
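PyTorch 2.0 and later also ship a fused functional version of this formula, `F.scaled_dot_product_attention`. It takes already-projected Q, K, V, so it would replace the body of `forward` after the projections, not the `nn.Linear` layers. A standalone sanity check, assuming PyTorch ≥ 2.0, that the manual formula agrees with it up to floating-point rounding:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
seq_len, d_k, d_v = 4, 6, 6
Q = torch.randn(1, seq_len, d_k)  # leading batch dimension of 1
K = torch.randn(1, seq_len, d_k)
V = torch.randn(1, seq_len, d_v)

# Manual version: same arithmetic as ScaledDotProductAttention.forward
weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
manual = weights @ V

# Fused version; its default scale is 1/sqrt(Q.size(-1))
fused = F.scaled_dot_product_attention(Q, K, V)

print(torch.allclose(manual, fused, atol=1e-5))
```

The sketch uses `K.transpose(-2, -1)` rather than `K.T`: `.T` is fine for the 2-D inputs in the class above, but on tensors with a batch dimension it reverses all dimensions and is deprecated.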

### Key difference from the from-scratch version

| From scratch | With modules |
|---|---|
| `W_Q = torch.randn(d_model, d_k) * 0.1` | `self.W_Q = nn.Linear(d_model, d_k, bias=False)` |
| `Q = X @ W_Q` | `Q = self.W_Q(X)` |
| Custom `softmax()` function | `F.softmax(scores, dim=-1)` |

`nn.Linear` manages the weight tensor internally, handles initialization, and registers it as a trainable parameter.
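One detail the table hides: `nn.Linear(d_model, d_k)` stores its weight with shape `(d_k, d_model)`, i.e. `(out_features, in_features)`, and computes `X @ weight.T`. So the from-scratch `W_Q` of shape `(d_model, d_k)` corresponds to `self.W_Q.weight.T`. A small standalone check, which also confirms the default initialization bound of $1/\sqrt{\text{in\_features}}$:

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_k = 8, 6
layer = nn.Linear(d_model, d_k, bias=False)
X = torch.randn(4, d_model)

print(layer.weight.shape)                             # torch.Size([6, 8])
print(layer.weight.requires_grad)                     # True
print(torch.allclose(layer(X), X @ layer.weight.T))   # True

# Default init draws from U(-1/sqrt(in_features), 1/sqrt(in_features))
bound = 1 / math.sqrt(d_model)
print(layer.weight.abs().max().item() <= bound + 1e-6)  # True
```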

## 3. Causal (Decoder) Mask

As in the from-scratch version, we block future tokens so that token $i$ attends only to positions $0, \ldots, i$.

In [None]:
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), diagonal=1)
print("Causal mask (True = blocked):")
print(causal_mask.int())

causal_output, causal_weights = attn(X, mask=causal_mask)
print("\nCausal attention weights:")
print(causal_weights)
print("\nUpper triangle is 0 — no attending to future tokens.")
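Masking works because $e^{-\infty} = 0$: every score set to `-inf` before the softmax gets exactly zero weight, and the remaining weights in the row still sum to 1. One pitfall worth knowing, though a causal mask never triggers it because every row keeps its diagonal: if an entire row is masked, the softmax has no finite score to normalize and returns NaN. A standalone sketch:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n = 4
scores = torch.randn(n, n)
causal = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)

w = F.softmax(scores.masked_fill(causal, float('-inf')), dim=-1)
print(w[causal].abs().max().item())   # 0.0: future positions get no weight
print(w.sum(dim=-1))                  # each row still sums to 1

# Pitfall: a fully masked row has no finite score, so softmax yields NaN
all_masked = torch.ones(n, n, dtype=torch.bool)
w_bad = F.softmax(scores.masked_fill(all_masked, float('-inf')), dim=-1)
print(torch.isnan(w_bad).all().item())  # True
```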

## 4. Multi-Head Attention with `nn.Module`

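$$\text{MultiHead}(X) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)\, W_O, \qquad \text{head}_i = \text{Attention}\big(X W_Q^{(i)},\, X W_K^{(i)},\, X W_V^{(i)}\big)$$

where each $W_Q^{(i)}, W_K^{(i)}, W_V^{(i)} \in \mathbb{R}^{d_\text{model} \times d_k}$ with $d_k = d_\text{model} / h$.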

Instead of one `nn.Linear` per head, we pack the projections of all heads into a single `nn.Linear(d_model, d_model)` each for Q, K, and V, then reshape to split and merge the heads. This is the same efficient approach as the from-scratch version, with cleaner code.

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"

        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # All heads packed into single projections
        self.W_Q = nn.Linear(d_model, d_model, bias=False)
        self.W_K = nn.Linear(d_model, d_model, bias=False)
        self.W_V = nn.Linear(d_model, d_model, bias=False)
        self.W_O = nn.Linear(d_model, d_model, bias=False)

    def forward(self, X, mask=None):
        seq_len = X.shape[0]

        # 1. Project
        Q = self.W_Q(X)  # (seq_len, d_model)
        K = self.W_K(X)
        V = self.W_V(X)

        # 2. Split into heads: (seq_len, d_model) -> (n_heads, seq_len, d_k)
        Q = Q.view(seq_len, self.n_heads, self.d_k).transpose(0, 1)
        K = K.view(seq_len, self.n_heads, self.d_k).transpose(0, 1)
        V = V.view(seq_len, self.n_heads, self.d_k).transpose(0, 1)

        # 3. Attention
        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))
        weights = F.softmax(scores, dim=-1)  # (n_heads, seq_len, seq_len)
        attn_out = weights @ V               # (n_heads, seq_len, d_k)

        # 4. Merge heads: (n_heads, seq_len, d_k) -> (seq_len, d_model)
        attn_out = attn_out.transpose(0, 1).contiguous().view(seq_len, -1)

        # 5. Output projection
        output = self.W_O(attn_out)

        return output, weights
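The split in step 2 is pure reshaping: head $h$ owns columns $h \cdot d_k$ through $(h+1) \cdot d_k - 1$ of the packed projection, so one `(d_model, d_model)` layer behaves like $h$ separate `(d_model, d_k)` projections placed side by side. A standalone check that the split matches column slicing and that the merge in step 4 exactly undoes it:

```python
import torch

torch.manual_seed(0)
seq_len, d_model, n_heads = 4, 8, 2
d_k = d_model // n_heads
Q = torch.randn(seq_len, d_model)   # stands in for a packed projection W_Q(X)

# Split: (seq_len, d_model) -> (n_heads, seq_len, d_k)
heads = Q.view(seq_len, n_heads, d_k).transpose(0, 1)

# Head h is exactly the h-th block of d_k columns
for h in range(n_heads):
    assert torch.equal(heads[h], Q[:, h * d_k:(h + 1) * d_k])

# In the real forward pass, attn_out = weights @ V is a fresh contiguous tensor
attn_out = heads.contiguous()

# Merge: transpose returns a non-contiguous view, so .contiguous() is required before .view()
merged = attn_out.transpose(0, 1).contiguous().view(seq_len, d_model)
print(torch.equal(merged, Q))  # True
```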

In [None]:
n_heads = 2
mha = MultiHeadAttention(d_model, n_heads).to(device)

mha_output, mha_weights = mha(X)

print("MHA output shape:", mha_output.shape)
print("Weights shape:", mha_weights.shape, "— (n_heads, seq_len, seq_len)")
print("\nHead 0 weights:")
print(mha_weights[0])
print("\nHead 1 weights:")
print(mha_weights[1])

## 5. Inspecting Parameters

A major advantage of `nn.Module`: all learnable parameters are automatically tracked.

In [None]:
print("Trainable parameters in MultiHeadAttention:")
total = 0
for name, param in mha.named_parameters():
    print(f"  {name:12s}  shape={str(list(param.shape)):16s}  params={param.numel()}")
    total += param.numel()
print(f"  {'Total':12s}  {'':{16}}  params={total}")
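For the `MultiHeadAttention` above, the four bias-free `(d_model, d_model)` projections give $4 \cdot d_\text{model}^2 = 4 \cdot 64 = 256$ parameters. Registration only covers `nn.Parameter` objects and submodules assigned as attributes; a plain tensor attribute is invisible to `named_parameters()` and `state_dict()`, and `register_buffer` is the way to track non-trainable state. A hypothetical toy module, for illustration only:

```python
import torch
import torch.nn as nn

class Toy(nn.Module):  # hypothetical module, only to show the registration rules
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4, bias=False)      # submodule: its weight is registered
        self.scale = nn.Parameter(torch.ones(1))     # registered and trainable
        self.register_buffer("mask", torch.ones(4))  # registered but NOT trainable
        self.plain = torch.zeros(4)                  # NOT registered at all

toy = Toy()
print(sorted(name for name, _ in toy.named_parameters()))  # ['proj.weight', 'scale']
print(sorted(toy.state_dict().keys()))                     # ['mask', 'proj.weight', 'scale']
```

Only registered tensors are moved by `.to(device)` and saved in checkpoints, which is why learnable weights should always be `nn.Parameter`s or live inside submodules.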

## 6. Training Loop Example

With `nn.Module`, we can plug attention into a gradient-based training loop. Here's a toy example that trains the attention layer to copy its input.

In [None]:
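# Toy task: train attention to reconstruct X from X
# Detach X from the embedding's autograd graph. That graph was built once, outside
# the loop, so without detach() the second backward() would try to reuse its freed
# buffers and fail. Only the attention weights are trained here.
X_target = X.detach()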

mha_train = MultiHeadAttention(d_model, n_heads).to(device)
optimizer = torch.optim.Adam(mha_train.parameters(), lr=0.01)

losses = []
for step in range(200):
    optimizer.zero_grad()
    out, _ = mha_train(X_target)
    loss = F.mse_loss(out, X_target)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"Initial loss: {losses[0]:.4f}")
print(f"Final loss:   {losses[-1]:.6f}")

In [None]:
plt.figure(figsize=(8, 3))
plt.plot(losses)
plt.xlabel("Step")
plt.ylabel("MSE Loss")
plt.title("Training: Attention learns to copy input")
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 7. Visualizing Learned Attention Weights

In [None]:
tokens = ["The", "cat", "sat", "down"]

with torch.no_grad():
    _, learned_weights = mha_train(X_target)

fig, axes = plt.subplots(1, n_heads, figsize=(5 * n_heads, 4))
if n_heads == 1:
    axes = [axes]

for i, ax in enumerate(axes):
    w = learned_weights[i].cpu().numpy()
    im = ax.imshow(w, cmap='Blues', vmin=0, vmax=1)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens)
    ax.set_yticklabels(tokens)
    ax.set_xlabel("Key (attending to)")
    ax.set_ylabel("Query (token)")
    ax.set_title(f"Head {i} (trained)")

    for row in range(len(tokens)):
        for col in range(len(tokens)):
            ax.text(col, row, f"{w[row, col]:.2f}",
                    ha='center', va='center', fontsize=10)

plt.tight_layout()
plt.suptitle("Learned Attention Weights (after training)", y=1.02, fontsize=14)
plt.show()

## 8. Using PyTorch's Built-in `nn.MultiheadAttention`

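PyTorch provides a ready-made implementation, `nn.MultiheadAttention`. Below we run it on the same input and compare its shapes and outputs with ours. It initializes its own weights, so its numbers will differ from our module's.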

In [None]:
# With batch_first=False (the default), nn.MultiheadAttention takes (seq_len, batch, d_model)
builtin_mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads,
                                     bias=False, batch_first=False).to(device)

# Add batch dimension: (seq_len, 1, d_model)
X_batched = X.unsqueeze(1)

with torch.no_grad():
    builtin_out, builtin_weights = builtin_mha(X_batched, X_batched, X_batched)

print("Built-in output shape:", builtin_out.squeeze(1).shape)
print("Built-in weights shape:", builtin_weights.shape)  # (batch, seq_len, seq_len)
print("\nBuilt-in attention weights (averaged over heads by default):")
print(builtin_weights.squeeze(0))
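To confirm that the built-in module computes the same function as ours, we can copy one set of weights into both and compare. A standalone sketch, assuming a PyTorch version where `nn.MultiheadAttention(..., bias=False)` exposes a packed `in_proj_weight` (rows stacked as `[W_Q; W_K; W_V]`) plus an `out_proj` layer. The helper `our_mha` is a hypothetical compact restatement of `MultiHeadAttention.forward` from section 4:

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, n_heads, seq_len = 8, 2, 4
d_k = d_model // n_heads

# Four bias-free projections, laid out like the notebook's MultiHeadAttention
W_Q, W_K, W_V, W_O = (nn.Linear(d_model, d_model, bias=False) for _ in range(4))

def our_mha(X):  # hypothetical helper mirroring MultiHeadAttention.forward
    L = X.shape[0]
    Q = W_Q(X).view(L, n_heads, d_k).transpose(0, 1)
    K = W_K(X).view(L, n_heads, d_k).transpose(0, 1)
    V = W_V(X).view(L, n_heads, d_k).transpose(0, 1)
    weights = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_k), dim=-1)
    out = (weights @ V).transpose(0, 1).contiguous().view(L, d_model)
    return W_O(out), weights  # weights: (n_heads, L, L)

builtin = nn.MultiheadAttention(d_model, n_heads, bias=False)
with torch.no_grad():
    # Copy our projections into the built-in module's packed layout
    builtin.in_proj_weight.copy_(torch.cat([W_Q.weight, W_K.weight, W_V.weight], dim=0))
    builtin.out_proj.weight.copy_(W_O.weight)

X = torch.randn(seq_len, d_model)
Xb = X.unsqueeze(1)  # (seq_len, batch=1, d_model)
with torch.no_grad():
    ours_out, ours_w = our_mha(X)
    b_out, b_w = builtin(Xb, Xb, Xb)

outputs_match = torch.allclose(ours_out, b_out.squeeze(1), atol=1e-5)
# The built-in returns weights averaged over heads by default
weights_match = torch.allclose(ours_w.mean(dim=0), b_w.squeeze(0), atol=1e-5)
print("Outputs match:", outputs_match)
print("Head-averaged weights match:", weights_match)
```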

## 9. Comparison: Three Levels of Abstraction

| Level | Notebook | What you manage | What PyTorch manages |
|-------|----------|-----------------|----------------------|
| **Low** | `attention_from_scratch.ipynb` | Weight tensors, matmul, softmax, masking, reshaping | Nothing |
| **Mid** | This notebook (`MultiHeadAttention`) | Forward logic, reshaping | Weight init, parameter tracking, gradients |
| **High** | `nn.MultiheadAttention` | Just call it | Everything |