# Coding Encoder–Decoder Attention and Multi‑Head Attention in PyTorch

A hands-on lesson that implements generic attention (self/cross) and a simple multi‑head attention wrapper, with explanations, shapes, and references.

![Transformer architecture diagram (Wikimedia Commons)](https://commons.wikimedia.org/wiki/Special:FilePath/Transformer%2C_full_architecture.png)

- Paper: [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- Visual guide: [The Illustrated Transformer (Jay Alammar)](https://jalammar.github.io/illustrated-transformer/)
- Course: [Hugging Face — Transformer Architectures](https://huggingface.co/learn/llm-course/en/chapter1/6)
- Docs: [PyTorch nn.MultiheadAttention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)

## Imports and prerequisites

We use PyTorch to build generic attention (self or cross) and a simple multi‑head wrapper:

- `torch`: tensors and linear algebra helpers.
- `torch.nn` (`nn`): layers like `Linear` and the `Module` base class.
- `torch.nn.functional` (`F`): stateless ops like `softmax` used in attention.

Note: Tensors are multi‑dimensional arrays. PyTorch tensors can also run on GPUs and track gradients for automatic differentiation.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


## Generic Attention (self or cross)

This module can perform self-attention (queries, keys, and values are all projected from the same encodings) or cross-attention (queries come from decoder states, keys and values from encoder outputs).

### Init arguments

| Argument | Meaning | Why it matters |
| --- | --- | --- |
| `d_model` | Features per token (embedding width) | Sets sizes of `W_q`, `W_k`, `W_v` |
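| `row_dim` | Dimension that indexes tokens (rows); `0` for unbatched `n × d_model` input | Swapped with `col_dim` to form `k^T` |
| `col_dim` | Dimension that indexes features (columns); `1` for unbatched input | Softmax runs along it, and `k.size(col_dim)` supplies `d_k` for the `sqrt(d_k)` scaling |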
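
To see why the two dimensions are configurable, here is a minimal sketch that runs the same computation on unbatched input (dims `0` and `1`) and on batched input (dims `1` and `2`). The helper `attention_weights` and the batch size of 4 are illustrative assumptions, not part of the lesson's code.

```python
import torch
import torch.nn.functional as F

def attention_weights(q, k, row_dim, col_dim):
    """Scaled dot-product attention weights (illustrative helper)."""
    # Swap the token and feature dimensions of k to get k^T.
    sims = torch.matmul(q, k.transpose(row_dim, col_dim))
    # k.size(col_dim) is the feature width, i.e. d_k.
    scaled = sims / (k.size(col_dim) ** 0.5)
    return F.softmax(scaled, dim=col_dim)

torch.manual_seed(0)
x = torch.randn(3, 2)                 # unbatched: n=3 tokens, d_model=2
xb = x.unsqueeze(0).repeat(4, 1, 1)   # batched: 4 copies, shape 4 x 3 x 2

w = attention_weights(x, x, row_dim=0, col_dim=1)
wb = attention_weights(xb, xb, row_dim=1, col_dim=2)

print(w.shape)   # torch.Size([3, 3])
print(wb.shape)  # torch.Size([4, 3, 3])
```

Each batch element should give the same weights as the unbatched call, since only the position of the token and feature dimensions changed.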

### Forward (single head)

```text
q = W_q(encodings_for_q)
k = W_k(encodings_for_k)
v = W_v(encodings_for_v)
S = q · k^T
S_scaled = S / sqrt(d_k)
if mask is not None:
  S_scaled = S_scaled.masked_fill(mask, -1e9)  # True in mask = position to hide
A = softmax(S_scaled, dim=col_dim)
O = A · v
```
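
The masking step can be sketched with a causal mask, where each token may attend only to itself and earlier tokens. This is a minimal illustration under assumptions: the all-zero `scores` tensor is a placeholder for real scaled similarities, and `True` in the mask marks positions to hide, which matches how `masked_fill` is used above.

```python
import torch
import torch.nn.functional as F

n = 3
# Lower-triangular ones mark allowed positions; "== 0" flips that so
# True marks the positions to hide (future tokens above the diagonal).
mask = torch.tril(torch.ones(n, n)) == 0

scores = torch.zeros(n, n)               # placeholder scaled similarities
masked = scores.masked_fill(mask, -1e9)  # hidden positions get a huge negative score
weights = F.softmax(masked, dim=1)       # softmax over keys

print(weights)
# tensor([[1.0000, 0.0000, 0.0000],
#         [0.5000, 0.5000, 0.0000],
#         [0.3333, 0.3333, 0.3333]])
```

Because the placeholder scores are all equal, each row spreads its weight evenly over the positions it is allowed to see.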

### Shapes

| Tensor | Shape |
| --- | --- |
| `encodings_for_q` | `n_q × d_model` |
| `encodings_for_k`, `encodings_for_v` | `n_kv × d_model` |
| `q`, `k`, `v` | `n_q × d_model` for `q`; `n_kv × d_model` for `k`, `v` (here `d_k = d_v = d_model`) |
| `S = qk^T` | `n_q × n_kv` |
| `A` | `n_q × n_kv` |
| `O = Av` | `n_q × d_model` |

## Sample inputs: self vs cross attention

For self-attention, we set `encodings_for_q = encodings_for_k = encodings_for_v`.
This keeps the demo simple and lets you verify shapes by hand.

- Shape here: each is `n × d_model = 3 × 2`.
- In cross-attention, you would pass different sources for `encodings_for_q` (decoder states) versus `encodings_for_k, encodings_for_v` (encoder outputs).

| Tensor | Meaning | Shape |
| --- | --- | --- |
| `encodings_for_q` | Inputs for queries | `n_q × d_model` |
| `encodings_for_k` | Inputs for keys | `n_kv × d_model` |
| `encodings_for_v` | Inputs for values | `n_kv × d_model` |
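
For the cross-attention case, a minimal sketch with different sequence lengths makes the shapes concrete. The tensors `decoder_states` and `encoder_outputs` are random, hypothetical stand-ins, and the three `nn.Linear` layers are created inline rather than taken from the lesson's module.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model = 2
decoder_states = torch.randn(2, d_model)   # n_q = 2 (hypothetical decoder tokens)
encoder_outputs = torch.randn(3, d_model)  # n_kv = 3 (hypothetical encoder tokens)

W_q = nn.Linear(d_model, d_model, bias=False)
W_k = nn.Linear(d_model, d_model, bias=False)
W_v = nn.Linear(d_model, d_model, bias=False)

q = W_q(decoder_states)    # queries come from the decoder side
k = W_k(encoder_outputs)   # keys and values come from the encoder side
v = W_v(encoder_outputs)

weights = F.softmax((q @ k.T) / (d_model ** 0.5), dim=1)  # n_q x n_kv
output = weights @ v                                      # n_q x d_model

print(weights.shape, output.shape)  # torch.Size([2, 3]) torch.Size([2, 2])
```

The output has one row per query, so its length follows the decoder sequence, not the encoder sequence.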

In [3]:
class Attention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.d_model = d_model
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

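    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # Project the inputs into queries, keys, and values.
        q = self.W_q(encodings_for_q)
        k = self.W_k(encodings_for_k)
        v = self.W_v(encodings_for_v)

        # Similarity of every query with every key: q @ k^T.
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by sqrt(d_k); a plain float avoids creating an extra tensor.
        scaled_sims = sims / (k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # Positions where mask is True get a large negative score,
            # so softmax gives them effectively zero weight.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Softmax over the key dimension: each row of weights sums to 1.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        # Weighted sum of the values (the attention output, despite the name).
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores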

In [4]:
encodings_for_q = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_k = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

encodings_for_v = torch.tensor([[1.16, 0.23],
                                [0.57, 1.36],
                                [4.41, -2.16]])

## Multi‑Head Attention (simple wrapper)

We wrap multiple single‑head attention modules and concatenate their outputs along the feature dimension.

### Idea

- Each head has its own `W_q, W_k, W_v` and computes scaled dot‑product attention independently.
- We concatenate per‑head outputs: `O_cat = concat(O_1, …, O_H)`.
- In full Transformers, a learned projection `W_O` maps `O_cat → d_model`. Here we focus on concatenation to see the effect of multiple heads.

### Shapes (typical)

| Item | Shape |
| --- | --- |
| Input `X` | `n × d_model` |
| Per‑head `Q_h, K_h, V_h` | `n × d_k`, `n × d_k`, `n × d_v` |
| Per‑head output `O_h` | `n × d_v` |
| Concatenated `O_cat` | `n × (H·d_v)` |
| Output projection `W_O` (not shown here) | `(H·d_v) × d_model` |

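This follows the structure of the original paper’s MHA block but stays minimal for teaching. The paper uses `d_k = d_v = d_model / H`, whereas each head here keeps the full width (`d_k = d_v = d_model`), so `O_cat` is `n × (H·d_model)`, and `W_O` is omitted.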
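
The output projection from the table can be sketched as a single bias-free linear layer. This is an illustration under assumptions: `o_cat` is a random stand-in for concatenated head outputs, and `W_o` is a hypothetical layer that the lesson's wrapper does not include.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, d_model, num_heads = 3, 2, 2
d_v = d_model  # assumption: each head keeps the full width, as in this lesson

# Stand-in for the concatenated head outputs, shape n x (H * d_v).
o_cat = torch.randn(n, num_heads * d_v)

# Hypothetical output projection W_O: (H * d_v) -> d_model.
W_o = nn.Linear(num_heads * d_v, d_model, bias=False)
projected = W_o(o_cat)

print(o_cat.shape)      # torch.Size([3, 4])
print(projected.shape)  # torch.Size([3, 2])
```

Projecting back to `d_model` keeps the block's output the same width as its input, which is what lets Transformer layers be stacked.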

In [5]:
torch.manual_seed(42)

attention = Attention(d_model=2, row_dim=0, col_dim=1)

In [6]:
original_output = attention(encodings_for_q, encodings_for_k, encodings_for_v)
original_output

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)

In [7]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        self.row_dim = row_dim
        self.col_dim = col_dim
        self.d_model = d_model

        self.heads = nn.ModuleList(
            [
                Attention(d_model, row_dim, col_dim)
                for _ in range(num_heads)
            ]
        )

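    def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):
        # Run every head on the same inputs (and the same optional mask),
        # then concatenate the per-head outputs along the feature dimension.
        return torch.cat([
            head(encodings_for_q, encodings_for_k, encodings_for_v, mask)
            for head in self.heads
        ], dim=self.col_dim)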

In [9]:
torch.manual_seed(42)

multi_head_attention = MultiHeadAttention(
    d_model=2,
    row_dim=0,
    col_dim=1,
    num_heads=1
)

original_output_2 = multi_head_attention(encodings_for_q, encodings_for_k, encodings_for_v)
original_output_2

tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<CatBackward0>)

In [10]:
torch.manual_seed(42)

multi_head_attention = MultiHeadAttention(
    d_model=2,
    row_dim=0,
    col_dim=1,
    num_heads=2
)

original_output_3 = multi_head_attention(encodings_for_q, encodings_for_k, encodings_for_v)
original_output_3

tensor([[ 1.0100,  1.0641, -0.7081, -0.8268],
        [ 0.2040,  0.7057, -0.7417, -0.9193],
        [ 3.4989,  2.2427, -0.7190, -0.8447]], grad_fn=<CatBackward0>)
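
With two heads the wrapper's output grows to `H·d_model = 4` columns. For comparison, here is a minimal sketch using PyTorch's built-in `nn.MultiheadAttention`, which splits `embed_dim` across heads and applies an output projection, so the output width stays at `embed_dim`. It assumes a PyTorch version recent enough to accept unbatched `(seq_len, embed_dim)` input; the values will differ from the outputs above because the weights are initialized differently.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
x = torch.tensor([[1.16, 0.23],
                  [0.57, 1.36],
                  [4.41, -2.16]])

# embed_dim must be divisible by num_heads; each head here has width 1.
mha = nn.MultiheadAttention(embed_dim=2, num_heads=2, bias=False)

# Self-attention: the same tensor serves as query, key, and value.
out, weights = mha(x, x, x)

print(out.shape)      # torch.Size([3, 2])
print(weights.shape)  # torch.Size([3, 3]), averaged over heads by default
```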