# Coding Self-Attention in PyTorch

A hands-on lesson to implement scaled dot-product self-attention in PyTorch, verify calculations numerically, and connect the math to the code.

![Transformer architecture diagram (Wikimedia Commons)](https://commons.wikimedia.org/wiki/Special:FilePath/Transformer%2C_full_architecture.png)

- Paper: [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- Visual guide: [The Illustrated Transformer (Jay Alammar)](https://jalammar.github.io/illustrated-transformer/)
- PyTorch docs: [nn.MultiheadAttention](https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)

## Imports and prerequisites

In this lesson, we use PyTorch to build and run self-attention:

- `torch`: tensors and linear algebra helpers.
- `torch.nn` (`nn`): neural network layers like `Linear` and base class `Module`.
- `torch.nn.functional` (`F`): stateless ops like `softmax` used in attention.

Note: a tensor is a multi-dimensional array, much like a NumPy array, that can also live on a GPU and record the operations needed to compute gradients during training.
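To make the three namespaces concrete, here is a small standalone sketch that touches each of them once. The tensor values are illustrative (two of the toy encodings used later in this lesson), and the layer is freshly initialized, so its outputs are arbitrary.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# torch: a 2 x 2 tensor holding 2 tokens with 2 features each
x = torch.tensor([[1.16, 0.23],
                  [0.57, 1.36]])
print(x.shape)  # torch.Size([2, 2])

# nn: a learnable linear map with no bias term
layer = nn.Linear(in_features=2, out_features=2, bias=False)
print(layer(x).shape)  # torch.Size([2, 2])

# F: a stateless softmax; every row of the result sums to 1
probs = F.softmax(x, dim=1)
print(probs.sum(dim=1))
```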

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```



## Implementing the SelfAttention class

We implement a self-attention module as a standard `nn.Module` so it plugs into PyTorch pipelines.

### `__init__` parameters

| Parameter | Meaning | Why it matters |
| --- | --- | --- |
| `d_model` | Number of features per token (token embedding plus positional encoding) | Sets the shapes of the Q/K/V weight matrices |
| `row_dim` | Axis that indexes tokens (rows) | Tells `k.transpose` which axes to swap when forming `QK^T` |
| `col_dim` | Axis that indexes features (columns) | Sets the softmax axis over the score matrix and the axis whose size gives `d_k` |

- We create three `nn.Linear` layers without bias (`bias=False`) to realize the learnable matrices `W_q`, `W_k`, and `W_v`.
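- This mirrors the original Transformer paper, which writes the projections as plain matrices `W^Q`, `W^K`, `W^V` with no bias. Some implementations do add biases; for example, PyTorch's `nn.MultiheadAttention` uses `bias=True` by default.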
- We store `row_dim` and `col_dim` for flexible batching later (here we use a simple 2D example without batches).
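A small sketch of what `bias=False` means in practice: the layer gets no bias parameter, and `nn.Linear` stores its weight with shape `(out_features, in_features)`, so it computes `x @ weight.T`. The 3 → 2 layer size and the random input are hypothetical choices that make the layout visible.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical layer with different in/out sizes so the weight layout is visible
layer = nn.Linear(in_features=3, out_features=2, bias=False)

print(layer.bias)          # None: no bias parameter was created
print(layer.weight.shape)  # torch.Size([2, 3]), i.e. (out_features, in_features)

x = torch.randn(4, 3)         # 4 tokens, 3 features each
manual = x @ layer.weight.T   # the product the layer computes
print(torch.allclose(layer(x), manual))  # True
```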

### Forward pass sketch

```text
q = W_q(X)
k = W_k(X)
v = W_v(X)
S = q · k^T
S_scaled = S / sqrt(d_k)
A = softmax(S_scaled, dim=col_dim)
O = A · v
```

This matches the scaled dot-product attention described in Vaswani et al. (2017).
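In the paper's notation, the sketch collapses to a single formula, with $S = QK^{\top}$, the softmax applied row by row (one row per query), and $n$ the number of tokens:

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right) V,
\qquad
A_{ij} = \frac{\exp\!\left(S_{ij}/\sqrt{d_k}\right)}{\sum_{m=1}^{n} \exp\!\left(S_{im}/\sqrt{d_k}\right)},
\qquad
O = A V
$$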

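```python
class SelfAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        # Learnable projection matrices W_q, W_k, W_v (no bias terms)
        self.W_q = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False,
        )
        self.W_k = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False,
        )
        self.W_v = nn.Linear(
            in_features=d_model,
            out_features=d_model,
            bias=False,
        )

        # Token (row) and feature (column) axes of the input tensor
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings):
        # Project the token encodings into queries, keys, and values
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # Unscaled similarities q · k^T: one row per query, one column per key
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Divide by sqrt(d_k), where d_k is the size of k along col_dim
        scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)
        # Softmax across keys, so each row of weights sums to 1
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        # Weighted sum of the value vectors
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores
```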

## Sample token encodings (toy example)

We’ll work with a tiny 2D encoding per token so we can verify all math by hand.

- Shape: `encodings_matrix` is `n_tokens × d_model = 3 × 2`.
- In real models `d_model` ranges from hundreds to several thousand (the original Transformer used 512); 2 features keep the arithmetic checkable by hand.

| Token index | Encoded values (example) |
| --- | --- |
| 0 | [1.16, 0.23] |
| 1 | [0.57, 1.36] |
| 2 | [4.41, -2.16] |

```python
encodings_matrix = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])
```

## Seeding and instantiation

We create a reproducible run and instantiate our self-attention module.

- `torch.manual_seed(42)`: sets the random seed so weights and results are repeatable.
- `SelfAttention(d_model=2, row_dim=0, col_dim=1)`: tiny model (2 features/token) for hand-checking math.

| Argument | Value | Purpose |
| --- | --- | --- |
| `d_model` | 2 | Features per token; sets sizes of `W_q`, `W_k`, `W_v` |
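| `row_dim` | 0 | Token (row) axis; swapped with `col_dim` in `k.transpose` to form `k^T` for `QK^T` |
| `col_dim` | 1 | Feature (column) axis; the softmax over the 3 × 3 score matrix runs along it (across keys), and `k.size(col_dim)` gives `d_k` |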

Tip: In batched settings you'd add a leading batch dimension and shift the axes to match; for a `(batch, tokens, features)` tensor that means `row_dim=1` and `col_dim=2`.
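A minimal sketch of the batched case, assuming a `(batch, tokens, features)` layout. The class is repeated so the block runs on its own, and the batch of 4 random sequences is hypothetical. The second half copies the weights into an unbatched instance to check that each sequence in the batch is handled independently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    # Same module as in the lesson, repeated so this block is self-contained
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)
        self.row_dim = row_dim
        self.col_dim = col_dim

    def forward(self, token_encodings):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        return torch.matmul(attention_percents, v)

# Axis 0 is now the batch, so tokens move to axis 1 and features to axis 2
torch.manual_seed(42)
batched_attention = SelfAttention(d_model=2, row_dim=1, col_dim=2)

batch = torch.randn(4, 3, 2)   # 4 hypothetical sequences of 3 tokens each
out = batched_attention(batch)
print(out.shape)  # torch.Size([4, 3, 2])

# Same weights, unbatched layout: item 0 of the batch should match
unbatched_attention = SelfAttention(d_model=2, row_dim=0, col_dim=1)
unbatched_attention.load_state_dict(batched_attention.state_dict())
print(torch.allclose(out[0], unbatched_attention(batch[0])))  # True
```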

## Shapes and intermediate tensors

With `n = 3` tokens and `d_model = d_k = d_v = 2`, the shapes are:

| Tensor | Meaning | Shape |
| --- | --- | --- |
| `X` | Encoded tokens (embeddings + positions) | 3 × 2 |
| `q = W_q(X)` | Queries | 3 × 2 |
| `k = W_k(X)` | Keys | 3 × 2 |
| `v = W_v(X)` | Values | 3 × 2 |
| `sims = q · k^T` | Unscaled similarities | 3 × 3 |
| `scaled_sims = sims / sqrt(d_k)` | Scaled similarities | 3 × 3 |
| `attention_percents = softmax(scaled_sims, dim=1)` | Attention weights (rows sum to 1) | 3 × 3 |
| `attention_scores = attention_percents · v` | Output (context-aware) | 3 × 2 |
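To check the table, a minimal functional sketch of the same pipeline prints each intermediate shape. It assumes random stand-in weight matrices applied as `X @ W` (hypothetical, not the module's weights), so only the shapes, not the values, correspond to the lesson.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n, d_model = 3, 2
X = torch.randn(n, d_model)  # stand-in for encodings_matrix
# Hypothetical weight matrices, applied as X @ W
W_q, W_k, W_v = (torch.randn(d_model, d_model) for _ in range(3))

q, k, v = X @ W_q, X @ W_k, X @ W_v
sims = q @ k.T                               # n x n
scaled_sims = sims / math.sqrt(k.size(1))    # d_k = number of columns of k
attention_percents = F.softmax(scaled_sims, dim=1)
attention_scores = attention_percents @ v    # n x d_model

for name, t in [("q", q), ("k", k), ("v", v), ("sims", sims),
                ("scaled_sims", scaled_sims),
                ("attention_percents", attention_percents),
                ("attention_scores", attention_scores)]:
    print(f"{name:>18}: {tuple(t.shape)}")

# Each row of the attention weights sums to 1
print(torch.allclose(attention_percents.sum(dim=1), torch.ones(n)))  # True
```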

The cells below seed the random number generator, instantiate the module, and run the full forward pass, which returns `attention_scores`.

```python
torch.manual_seed(42)

self_attention = SelfAttention(
    d_model=2,
    row_dim=0,
    col_dim=1
)
```

```python
original_output = self_attention(encodings_matrix)
original_output
```

```text
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
```

### Validating weights and manual calculations

Now we inspect the randomly initialized (untrained) weights and use them to validate the math step by step.

- `nn.Linear` stores `weight` with shape `(out_features, in_features)` and computes `X @ weight.T`, so we print `weight.transpose(0, 1)`: that is the matrix multiplied on the right of `X`, matching `q = X W_q`.
- By combining the printed `W_q`, `W_k`, `W_v` with the original encodings, you can recompute Q, K, V and verify each step by hand.
- Reproducing the module's output from these intermediate tensors confirms that the implementation computes scaled dot-product attention.

```python
self_attention.W_q.weight.transpose(0, 1)
```

```text
tensor([[ 0.5406, -0.1657],
        [ 0.5869,  0.6496]], grad_fn=<TransposeBackward0>)
```

```python
self_attention.W_k.weight.transpose(0, 1)
```

```text
tensor([[-0.1549, -0.3443],
        [ 0.1427,  0.4153]], grad_fn=<TransposeBackward0>)
```

```python
self_attention.W_v.weight.transpose(0, 1)
```

```text
tensor([[ 0.6233,  0.6146],
        [-0.5188,  0.1323]], grad_fn=<TransposeBackward0>)
```
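With the three transposed weight matrices printed, the "by hand" check can be sketched in plain Python for token 0. This is an illustrative sketch: it copies the printed values, which are rounded to 4 decimals, so its results should agree with the PyTorch tensors only up to small rounding differences. The helper names `row_times_matrix` and `dot` are hypothetical.

```python
import math

# Values copied from the printed outputs (rounded to 4 decimals)
X = [[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]]
W_q_T = [[0.5406, -0.1657], [0.5869, 0.6496]]
W_k_T = [[-0.1549, -0.3443], [0.1427, 0.4153]]
W_v_T = [[0.6233, 0.6146], [-0.5188, 0.1323]]

def row_times_matrix(row, matrix):
    """Return row @ matrix for plain Python lists."""
    return [sum(row[i] * matrix[i][j] for i in range(len(row)))
            for j in range(len(matrix[0]))]

def dot(a, b):
    """Dot product of two equal-length lists."""
    return sum(x * y for x, y in zip(a, b))

# Query for token 0; keys and values for every token
q0 = row_times_matrix(X[0], W_q_T)
K = [row_times_matrix(x, W_k_T) for x in X]
V = [row_times_matrix(x, W_v_T) for x in X]

# Scaled similarities of token 0 against each key, then softmax
d_k = len(K[0])
scores = [dot(q0, k) / math.sqrt(d_k) for k in K]
exps = [math.exp(s) for s in scores]
weights = [e / sum(exps) for e in exps]

# Attention output for token 0: weighted sum of the value vectors
out0 = [sum(w * v[j] for w, v in zip(weights, V)) for j in range(len(V[0]))]

print([round(val, 4) for val in q0])  # [0.7621, -0.0428]
print([round(w, 4) for w in weights])
print([round(val, 4) for val in out0])
```

The last two lines should land close to the first rows of `attention_percents` and `original_output`.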

```python
q = self_attention.W_q(encodings_matrix)
q
```

```text
tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)
```

```python
k = self_attention.W_k(encodings_matrix)
k
```

```text
tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)
```

```python
v = self_attention.W_v(encodings_matrix)
v
```

```text
tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)
```

```python
sims = torch.matmul(q, k.transpose(dim0=0, dim1=1))
sims
```

```text
tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)
```

```python
# d_k = 2: the number of columns (features) in k
scaled_sims = sims / (torch.tensor(2)**0.5)
scaled_sims
```

```text
tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)
```

```python
attention_percents = F.softmax(scaled_sims, dim=1)
attention_percents
```

```text
tensor([[0.3573, 0.4011, 0.2416],
        [0.3410, 0.6047, 0.0542],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)
```

```python
final_output = torch.matmul(attention_percents, self_attention.W_v(encodings_matrix))
final_output
```

```text
tensor([[1.0100, 1.0641],
        [0.2040, 0.7057],
        [3.4989, 2.2427]], grad_fn=<MmBackward0>)
```

```python
final_output == original_output
```

```text
tensor([[True, True],
        [True, True],
        [True, True]])
```
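Exact `==` works here because the module and the manual cells run the same operations in the same order. For floating-point comparisons in general, `torch.allclose` is the safer check. The sketch below also cross-checks the manual formula against PyTorch's built-in `F.scaled_dot_product_attention` (PyTorch 2.0 and later, default scale `1/sqrt(d_k)`). It rebuilds the three bias-free projections with the same seed and creation order as `SelfAttention`, so they should start from the same weights; the 1 × 1 batch and head axes added for the built-in call are an assumption about its expected `(batch, heads, tokens, features)` layout.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
X = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])

# Same three bias-free projections as SelfAttention(d_model=2)
W_q = nn.Linear(2, 2, bias=False)
W_k = nn.Linear(2, 2, bias=False)
W_v = nn.Linear(2, 2, bias=False)
q, k, v = W_q(X), W_k(X), W_v(X)

# Manual scaled dot-product attention, as in the cells above
manual = F.softmax(q @ k.T / k.size(1) ** 0.5, dim=1) @ v

# Built-in version: add size-1 batch and head axes, then drop them again
builtin = F.scaled_dot_product_attention(
    q[None, None], k[None, None], v[None, None]
)[0, 0]

print(torch.allclose(manual, builtin, atol=1e-5))  # True
```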

## References and further reading

- Vaswani et al., 2017 — [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- Jay Alammar — [The Illustrated Transformer](https://jalammar.github.io/illustrated-transformer/)
- PyTorch — [nn.MultiheadAttention docs](https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)
- UvA DL — [Transformers and Multi-Head Attention](https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html)
- Interactive — [Transformer Explainer (GPT‑2 attention)](https://poloclub.github.io/transformer-explainer/)