# Coding Masked Self-Attention in PyTorch

A hands-on lesson to implement masked (causal) self-attention in PyTorch, verify the math step-by-step, and understand how the causal mask works.

![Transformer architecture diagram (Wikimedia Commons)](https://commons.wikimedia.org/wiki/Special:FilePath/Transformer%2C_full_architecture.png)

- Paper: [Attention Is All You Need (Vaswani et al., 2017)](https://arxiv.org/abs/1706.03762)
- Visual guide: [The Illustrated Transformer (Jay Alammar)](https://jalammar.github.io/illustrated-transformer/)
- Interactive: [Transformer Explainer (GPT‑2 attention)](https://poloclub.github.io/transformer-explainer/)
- PyTorch docs: [nn.MultiheadAttention](https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)

## Imports and prerequisites

We use PyTorch to implement masked self-attention:

- `torch`: tensors and linear algebra helpers.
- `torch.nn` (`nn`): layers like `Linear` and base class `Module`.
- `torch.nn.functional` (`F`): stateless ops like `softmax` used in attention.

Note: Tensors are multi-dimensional arrays, similar to NumPy arrays, that can also run on GPUs and record operations for automatic differentiation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
```

## Implementing the `MaskedSelfAttention` class

We implement masked self-attention as a standard `nn.Module` and add support for an optional mask.

### Init arguments

| Argument | Meaning | Why it matters |
| --- | --- | --- |
| `d_model` | Features per token (embedding size, after positional encoding) | Sets the sizes of `W_q`, `W_k`, `W_v` |
| `row_dim` | Axis that indexes tokens (rows) | Used when transposing `k` to form `q · kᵀ` |
| `col_dim` | Axis that indexes columns | Softmax is taken over this axis of the score matrix, and `k.size(col_dim)` gives `d_k` for the `sqrt(d_k)` scaling |

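- Three bias-free `nn.Linear` layers parameterize `W_q`, `W_k`, `W_v`.
- This follows the original Transformer paper, which writes the projections as plain weight matrices; some implementations, such as `nn.MultiheadAttention` (`bias=True` by default), do include bias terms.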

### Forward pass with causal mask

```text
q = W_q(X)
k = W_k(X)
v = W_v(X)
S = q · k^T
S_scaled = S / sqrt(d_k)
if mask is not None:
    S_scaled = S_scaled.masked_fill(mask, -1e9)
A = softmax(S_scaled, dim=col_dim)
O = A · v
```

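- `mask` is boolean: `True` = disallow (the score is replaced by `-1e9`), `False` = allow (the score is left unchanged).
- After softmax, disallowed positions receive (effectively) zero weight, which yields causal (look-ahead) masking for decoder-style attention.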
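To cross-check the pseudocode against PyTorch's built-in kernel, here is a minimal sketch assuming PyTorch 2.0 or later, where `F.scaled_dot_product_attention` is available. Its boolean `attn_mask` uses the opposite convention from the pseudocode: `True` means the position *may* be attended to, so the disallow mask is inverted with `~`. Passing `is_causal=True` instead builds the lower-triangular pattern internally. The random `q`, `k`, `v` are stand-ins, not the lesson's tensors.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_tokens, d_k = 3, 2
# Stand-in projections shaped (batch, tokens, features).
q = torch.randn(1, n_tokens, d_k)
k = torch.randn(1, n_tokens, d_k)
v = torch.randn(1, n_tokens, d_k)

# Manual version following the pseudocode: True = disallow.
disallow = torch.triu(torch.ones(n_tokens, n_tokens, dtype=torch.bool), diagonal=1)
scores = torch.matmul(q, k.transpose(-2, -1)) / d_k ** 0.5
scores = scores.masked_fill(disallow, -1e9)
manual = torch.matmul(F.softmax(scores, dim=-1), v)

# Built-in version: for a boolean attn_mask, True = allowed, hence ~disallow.
builtin_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=~disallow)
# Alternatively, let PyTorch build the causal mask itself.
builtin_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin_mask, atol=1e-5))
print(torch.allclose(manual, builtin_causal, atol=1e-5))
```

Both comparisons should agree to within floating-point tolerance; mixing up the two mask conventions is a common source of silently wrong attention.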

## Sample token encodings (toy example)

We’ll reuse the same tiny 2D encodings so we can verify all math by hand.

- Shape: `encodings_matrix` is `n_tokens × d_model = 3 × 2`.
- In practice, `d_model` is often much larger (e.g., 512), but 2 keeps examples simple.

| Token index | Encoded values (example) |
| --- | --- |
| 0 | [1.16, 0.23] |
| 1 | [0.57, 1.36] |
| 2 | [4.41, -2.16] |

## Seeding, instantiation, and causal mask

We use a fixed seed for reproducibility and instantiate the masked self-attention module.

- `torch.manual_seed(42)`: deterministic weights/outputs for this demo.
- `MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)`: 2 features per token; softmax along dim=1.

### Building a causal (look-ahead) mask

We want each token i to attend only to tokens ≤ i.

- Start with ones: `torch.ones(n, n)`.
- Keep the lower triangle (incl. diagonal): `torch.tril(...)`.
- Convert to boolean "disallow" mask by comparing to zero: `mask = (tril == 0)`.
  - `True` → disallow (the score becomes −1e9 via `masked_fill`).
  - `False` → allow (the score is left unchanged).

Result: a lower-triangular allow-pattern that prevents looking ahead.
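The same boolean mask can be built in one step by keeping only the entries strictly above the diagonal. A minimal sketch comparing the two constructions for `n = 3` (both use only standard `torch` calls):

```python
import torch

n = 3
# Construction used in this lesson: lower triangle of ones, then "== 0".
mask_tril = torch.tril(torch.ones(n, n)) == 0
# Equivalent one-step construction: True strictly above the diagonal.
mask_triu = torch.ones(n, n, dtype=torch.bool).triu(diagonal=1)

print(torch.equal(mask_tril, mask_triu))  # True
print(mask_triu)
```

`diagonal=1` excludes the main diagonal, so each token can still attend to itself.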

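```python
class MaskedSelfAttention(nn.Module):
    def __init__(self, d_model=2, row_dim=0, col_dim=1):
        super().__init__()

        self.d_model = d_model
        self.row_dim = row_dim  # token axis of q and k (used in the transpose)
        self.col_dim = col_dim  # feature axis of k; softmax axis of the scores

        # Bias-free projections for queries, keys, and values.
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, token_encodings, mask=None):
        q = self.W_q(token_encodings)
        k = self.W_k(token_encodings)
        v = self.W_v(token_encodings)

        # Raw similarities q · k^T, shape (n_tokens, n_tokens).
        sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))
        # Scale by sqrt(d_k), where d_k is the number of features in k.
        scaled_sims = sims / torch.tensor(k.size(self.col_dim) ** 0.5)

        if mask is not None:
            # True entries (future tokens) get a large negative score,
            # so softmax assigns them (effectively) zero weight.
            scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)

        # Each row becomes a probability distribution over the keys.
        attention_percents = F.softmax(scaled_sims, dim=self.col_dim)
        # Weighted sum of the value vectors.
        attention_scores = torch.matmul(attention_percents, v)

        return attention_scores
```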
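The class above assumes a single 2-D sequence (`row_dim=0`, `col_dim=1`). As a hypothetical extension that is not part of the original lesson, the same forward pass can handle a batch shaped `(batch, n_tokens, d_model)` by always transposing and normalizing over the last two axes (`-2`, `-1`); an `(n_tokens, n_tokens)` mask then broadcasts across the batch. The class name `BatchedMaskedSelfAttention` is illustrative.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BatchedMaskedSelfAttention(nn.Module):
    """Hypothetical batched variant; expects x of shape (batch, n_tokens, d_model)."""

    def __init__(self, d_model=2):
        super().__init__()
        self.W_q = nn.Linear(d_model, d_model, bias=False)
        self.W_k = nn.Linear(d_model, d_model, bias=False)
        self.W_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, mask=None):
        q, k, v = self.W_q(x), self.W_k(x), self.W_v(x)
        # Swap only the last two axes so the batch axis is left alone.
        scaled = torch.matmul(q, k.transpose(-2, -1)) / k.size(-1) ** 0.5
        if mask is not None:
            # A (n_tokens, n_tokens) boolean mask broadcasts over the batch axis.
            scaled = scaled.masked_fill(mask, -1e9)
        weights = F.softmax(scaled, dim=-1)  # normalize over keys
        return torch.matmul(weights, v)


torch.manual_seed(0)
attn = BatchedMaskedSelfAttention(d_model=2)
x = torch.randn(4, 3, 2)  # 4 sequences of 3 tokens each
causal = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)
out = attn(x, causal)
print(out.shape)  # torch.Size([4, 3, 2])
```

Using negative axis indices is the design choice that makes the code indifferent to how many leading batch dimensions there are.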

```python
encodings_matrix = torch.tensor([
    [1.16, 0.23],
    [0.57, 1.36],
    [4.41, -2.16],
])
```

```python
torch.manual_seed(42)

masked_self_attention = MaskedSelfAttention(d_model=2, row_dim=0, col_dim=1)
mask = torch.tril(torch.ones(3, 3))
mask = mask == 0

mask
```

```text
tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
```

```python
original_output = masked_self_attention(encodings_matrix, mask)
original_output
```

```text
tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)
```

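### Validate weights and manual calculations

Now we recompute each step outside the module, using its randomly initialized (not yet trained) weights, and check that we get the same result.

- Calling `masked_self_attention.W_q(encodings_matrix)` (and likewise `W_k`, `W_v`) applies a projection, which computes `X @ weight.T`: PyTorch stores a linear layer's `weight` with shape `(out_features, in_features)`, so it must be transposed before multiplying by hand.
- Starting from Q, K, and V, each cell below reproduces one line of the forward pass: similarities, scaling, masking, softmax, and the weighted sum of values.
- If the final result matches `original_output`, the module implements the masked scaled dot-product attention math described above.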
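To make the transpose explicit, the sketch below applies a single bias-free projection both ways: through the layer, and as `X @ weight.T` by hand. It creates a standalone `W_q` after the same seed; assuming PyTorch's default initialization consumes the random stream in the same order as the module, it should match the module's first projection, but the check itself does not depend on that.

```python
import torch
import torch.nn as nn

torch.manual_seed(42)
W_q = nn.Linear(2, 2, bias=False)
X = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])

# weight has shape (out_features, in_features); its transpose is the
# matrix you multiply on the right when computing Q by hand.
print(W_q.weight.T)
q_by_hand = X @ W_q.weight.T
print(torch.allclose(W_q(X), q_by_hand))  # True
```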

```python
q = masked_self_attention.W_q(encodings_matrix)
q
```

```text
tensor([[ 0.7621, -0.0428],
        [ 1.1063,  0.7890],
        [ 1.1164, -2.1336]], grad_fn=<MmBackward0>)
```

```python
k = masked_self_attention.W_k(encodings_matrix)
k
```

```text
tensor([[-0.1469, -0.3038],
        [ 0.1057,  0.3685],
        [-0.9914, -2.4152]], grad_fn=<MmBackward0>)
```

```python
v = masked_self_attention.W_v(encodings_matrix)
v
```

```text
tensor([[ 0.6038,  0.7434],
        [-0.3502,  0.5303],
        [ 3.8695,  2.4246]], grad_fn=<MmBackward0>)
```

```python
sims = torch.matmul(q, k.transpose(dim0=0, dim1=1))
sims
```

```text
tensor([[-0.0990,  0.0648, -0.6523],
        [-0.4022,  0.4078, -3.0024],
        [ 0.4842, -0.6683,  4.0461]], grad_fn=<MmBackward0>)
```

```python
scaled_sims = sims / (torch.tensor(2)**0.5)
scaled_sims
```

```text
tensor([[-0.0700,  0.0458, -0.4612],
        [-0.2844,  0.2883, -2.1230],
        [ 0.3424, -0.4725,  2.8610]], grad_fn=<DivBackward0>)
```
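As a hand check of one entry (token 0 attending to token 0), using the rounded values printed above and $d_k = 2$:

$$
\begin{aligned}
\text{sims}_{0,0} &= q_0 \cdot k_0 = (0.7621)(-0.1469) + (-0.0428)(-0.3038) \approx -0.1120 + 0.0130 = -0.0990 \\
\text{scaled}_{0,0} &= \frac{\text{sims}_{0,0}}{\sqrt{d_k}} = \frac{-0.0990}{\sqrt{2}} \approx \frac{-0.0990}{1.4142} \approx -0.0700
\end{aligned}
$$

Every other entry follows the same pattern: dot product of a query row with a key row, divided by $\sqrt{2}$.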

```python
masked_scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)
masked_scaled_sims
```

```text
tensor([[-6.9975e-02, -1.0000e+09, -1.0000e+09],
        [-2.8442e-01,  2.8833e-01, -1.0000e+09],
        [ 3.4241e-01, -4.7253e-01,  2.8610e+00]],
       grad_fn=<MaskedFillBackward0>)
```
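Why `-1e9` rather than `float("-inf")`? For a row with at least one allowed position, both give the same softmax weights. They differ only when an entire row is masked: `-1e9` yields uniform weights, while `-inf` yields NaN (because $-\infty - (-\infty)$ is undefined inside the softmax). A small sketch using row 0 of `scaled_sims` from above:

```python
import torch
import torch.nn.functional as F

row = torch.tensor([-0.0700, 0.0458, -0.4612])  # row 0 of scaled_sims (rounded)
disallow = torch.tensor([False, True, True])

# With one allowed position, both fill values give weights [1, 0, 0].
w_large = F.softmax(row.masked_fill(disallow, -1e9), dim=0)
w_inf = F.softmax(row.masked_fill(disallow, float("-inf")), dim=0)

# If every position is masked, the two choices diverge.
all_masked = torch.ones(3, dtype=torch.bool)
w_all_large = F.softmax(row.masked_fill(all_masked, -1e9), dim=0)           # 1/3 each
w_all_inf = F.softmax(row.masked_fill(all_masked, float("-inf")), dim=0)    # all NaN

print(w_large, w_inf, w_all_large, w_all_inf)
```

A causal mask never masks a full row (the diagonal is always allowed), so either fill value works here; the finite constant is simply more forgiving when padding masks are combined with it.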

```python
attention_percents = F.softmax(masked_scaled_sims, dim=1)
attention_percents
```

```text
tensor([[1.0000, 0.0000, 0.0000],
        [0.3606, 0.6394, 0.0000],
        [0.0722, 0.0320, 0.8959]], grad_fn=<SoftmaxBackward0>)
```
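To see where one row comes from, here is row 1 (the second token), which may attend only to tokens 0 and 1. The inputs are the rounded values printed above, so the last digit can differ slightly:

$$
\begin{aligned}
A_{1,j} &= \frac{e^{s_{1,j}}}{\sum_{m} e^{s_{1,m}}}, \qquad s_{1,\cdot} = (-0.2844,\ 0.2883,\ -10^{9}) \\
e^{-0.2844} &\approx 0.7525, \qquad e^{0.2883} \approx 1.3342, \qquad e^{-10^{9}} \approx 0 \\
A_{1,\cdot} &\approx \left(\tfrac{0.7525}{2.0866},\ \tfrac{1.3342}{2.0866},\ 0\right) \approx (0.3606,\ 0.6394,\ 0)
\end{aligned}
$$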

```python
final_output = torch.matmul(attention_percents, masked_self_attention.W_v(encodings_matrix))
final_output
```

```text
tensor([[ 0.6038,  0.7434],
        [-0.0062,  0.6072],
        [ 3.4989,  2.2427]], grad_fn=<MmBackward0>)
```
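Row 0 of the output is exactly $v_0 = (0.6038,\ 0.7434)$, because the first token can attend only to itself (weights $(1, 0, 0)$). Row 1 mixes $v_0$ and $v_1$ with the weights computed above; with the rounded printed values:

$$
O_1 = 0.3606\,v_0 + 0.6394\,v_1 = 0.3606\,(0.6038,\ 0.7434) + 0.6394\,(-0.3502,\ 0.5303) \approx (-0.0062,\ 0.6071)
$$

This matches the printed row $(-0.0062,\ 0.6072)$ up to rounding in the fourth decimal.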

```python
original_output == final_output
```

```text
tensor([[True, True],
        [True, True],
        [True, True]])
```
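Exact `==` succeeds here because the manual cells repeat the module's operations; when two computations differ even slightly in order, `torch.allclose` is the safer comparison. As a further check that is not part of the original lesson, the sketch below tests the causal property directly: changing only the last token should leave the outputs for earlier tokens untouched. It re-creates the projections as standalone `nn.Linear` layers and uses a helper named `masked_attention`; both are illustrative, not part of the class above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)
d_model, n = 2, 3
# Standalone projections (hypothetical), mirroring W_q, W_k, W_v above.
W_q, W_k, W_v = (nn.Linear(d_model, d_model, bias=False) for _ in range(3))
disallow = torch.triu(torch.ones(n, n, dtype=torch.bool), diagonal=1)


def masked_attention(x):
    """Hypothetical helper: masked scaled dot-product attention on one sequence."""
    q, k, v = W_q(x), W_k(x), W_v(x)
    scores = torch.matmul(q, k.T) / d_model ** 0.5
    weights = F.softmax(scores.masked_fill(disallow, -1e9), dim=1)
    return torch.matmul(weights, v)


x = torch.tensor([[1.16, 0.23], [0.57, 1.36], [4.41, -2.16]])
x_changed = x.clone()
x_changed[2] = torch.tensor([-3.0, 5.0])  # alter only the last token

out, out_changed = masked_attention(x), masked_attention(x_changed)
# Tokens 0 and 1 never attend to token 2, so their outputs should not move.
print(torch.allclose(out[:2], out_changed[:2]))  # True
```

Without the mask, the same perturbation would change every row, which is exactly the look-ahead the causal mask prevents.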