A clean, faithful PyTorch implementation of the original Transformer architecture from:
Attention Is All You Need
Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Łukasz Kaiser, Illia Polosukhin
NeurIPS 2017 — https://arxiv.org/abs/1706.03762
Every component maps directly to a numbered section of the paper. The base-model hyperparameters from Table 3 are used as defaults throughout.
- Architecture Overview
- Project Structure
- Quick Start
- Installation
- Module Reference
- Hyperparameters
- Usage Examples
- Design Decisions
- Scope & Limitations
## Architecture Overview

The Transformer is a sequence-to-sequence model built entirely on attention — no recurrence and no convolutions.
```
src tokens ──► Embedding × √d_model ──► + Positional Encoding
                        │
             ┌──────────▼──────────┐
             │  Encoder Layer × N  │
             │ ┌─────────────────┐ │
             │ │ Self-Attention  │ │
             │ ├─────────────────┤ │
             │ │   Add & Norm    │ │
             │ ├─────────────────┤ │
             │ │  Feed-Forward   │ │
             │ ├─────────────────┤ │
             │ │   Add & Norm    │ │
             │ └─────────────────┘ │
             └──────────┬──────────┘
          encoder memory (B, S, d_model)
                        │
tgt tokens ──► Embedding × √d_model ──► + Positional Encoding
                        │
             ┌──────────▼──────────┐
             │  Decoder Layer × N  │
             │ ┌─────────────────┐ │
             │ │ Masked Self-Attn│ │
             │ ├─────────────────┤ │
             │ │   Add & Norm    │ │
             │ ├─────────────────┤ │
             │ │ Cross-Attention │◄──── encoder memory
             │ ├─────────────────┤ │
             │ │   Add & Norm    │ │
             │ ├─────────────────┤ │
             │ │  Feed-Forward   │ │
             │ ├─────────────────┤ │
             │ │   Add & Norm    │ │
             │ └─────────────────┘ │
             └──────────┬──────────┘
                        │
             Linear (d_model → vocab_size)
                        │
                 logits (B, T, V)
```
## Project Structure

```
transformer/
├── transformer/          # Python package — one file per component
│   ├── __init__.py       # Public API (all classes re-exported here)
│   ├── masks.py          # make_causal_mask, make_padding_mask
│   ├── attention.py      # ScaledDotProductAttention, MultiHeadAttention
│   ├── blocks.py         # PositionwiseFeedForward, PositionalEncoding,
│   │                     #   SublayerConnection
│   ├── encoder.py        # EncoderLayer, Encoder
│   ├── decoder.py        # DecoderLayer, Decoder
│   └── model.py          # Transformer (full model)
├── transformer.py        # Original single-file implementation (reference)
├── smoke_test.py         # Shape + correctness checks
├── requirements.txt      # PyTorch ≥ 2.0
└── README.md
```
## Quick Start

```shell
# 1. Create and activate the virtual environment
python3 -m venv venv
source venv/bin/activate

# 2. Install dependencies
pip install -r requirements.txt

# 3. Run the smoke test
python smoke_test.py
```

Expected output:

```
Device       : cpu
Output shape : (2, 12, 10000) ✓
Causal mask  : future tokens masked ✓
Head dim     : d_k = d_v = 64 ✓
Parameters   : 59,463,680 (~60-65 M expected for base model)
Param count  : within expected range ✓
All checks passed.
```
## Installation

Requirements:

- Python 3.9+
- pip

```shell
# Clone / navigate to the project folder
cd "path/to/transformer"

# Create a virtual environment
python3 -m venv venv

# Activate it
source venv/bin/activate   # Linux / macOS
# venv\Scripts\activate    # Windows

# Install PyTorch (CPU-only; see https://pytorch.org for GPU builds)
pip install -r requirements.txt

# Optionally install an editable local package
pip install -e .
```

GPU support: replace the `torch` line in `requirements.txt` with the appropriate CUDA wheel from https://pytorch.org/get-started/locally/.
## Module Reference

### `transformer.masks` — `make_causal_mask`, `make_padding_mask`
| Function | Signature | Description |
|---|---|---|
| `make_causal_mask` | `(size, device) → (1, 1, T, T)` | Upper-triangular bool mask; `True` = future position, blocked. |
| `make_padding_mask` | `(seq, pad_idx=0) → (B, 1, 1, T)` | `True` at pad token positions. |
Both masks broadcast over `(B, h, T_q, T_k)` and can be combined with `|`.
```python
from transformer import make_causal_mask, make_padding_mask

causal = make_causal_mask(T, device)  # (1, 1, T, T)
pad_mask = make_padding_mask(tgt)     # (B, 1, 1, T)
tgt_mask = causal | pad_mask          # (B, 1, T, T)
```

### `transformer.attention.ScaledDotProductAttention` — §3.2.1
- Scales scores by $1/\sqrt{d_k}$ to stabilise gradients.
- Fills masked positions with $-\infty$ before softmax → zero weight.
- Applies dropout to attention weights.
```python
attn = ScaledDotProductAttention(dropout=0.1)
output, weights = attn(q, k, v, mask=tgt_mask)
# output:  (..., T_q, d_v)
# weights: (..., T_q, T_k)
```

### `transformer.attention.MultiHeadAttention` — §3.2.2
- Projection matrices $W^Q, W^K, W^V, W^O$ are all `Linear` with no bias.
- Per-head projections are implemented as a single batched `Linear(d_model → d_model)` then reshaped for efficiency.
- $d_k = d_v = d_\text{model} / h = 64$ for the base model.
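The batched projection in the second bullet can be sketched with shape bookkeeping alone (a minimal illustration, not this repo's exact module; `B` and `T` are arbitrary batch and sequence sizes):

```python
import torch

B, T, d_model, h = 2, 5, 512, 8
d_k = d_model // h  # 64 per head

x = torch.randn(B, T, d_model)
W_q = torch.nn.Linear(d_model, d_model, bias=False)  # one GEMM covers all h heads

q = W_q(x)                                # (B, T, d_model)
q = q.view(B, T, h, d_k).transpose(1, 2)  # (B, h, T, d_k) — per-head view, no copy of the GEMM

# ... per-head attention would happen here ...

q = q.transpose(1, 2).contiguous().view(B, T, d_model)  # merge heads back
```

The `view` + `transpose` is pure reshaping, so one `d_model → d_model` projection behaves exactly like `h` separate `d_model → d_k` projections stacked side by side.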
| Parameter | Default | Paper |
|---|---|---|
| `d_model` | 512 | 512 |
| `h` | 8 | 8 |
| `dropout` | 0.1 | 0.1 |
```python
mha = MultiHeadAttention(d_model=512, h=8, dropout=0.1)
out = mha(query, key, value, mask=None)  # (B, T_q, d_model)
```

Three uses in the full model:
| Location | Q source | K, V source | Mask |
|---|---|---|---|
| Encoder self-attention | encoder input | encoder input | padding mask |
| Decoder masked self-attention | decoder input | decoder input | causal + padding mask |
| Decoder cross-attention | decoder hidden | encoder output | source padding mask |
### `transformer.blocks.PositionwiseFeedForward` — §3.3
Applied identically and independently to each position (like a 1×1 convolution over the sequence dimension).
| Parameter | Default | Paper |
|---|---|---|
| `d_model` | 512 | 512 |
| `d_ff` | 2048 | 2048 |
| `dropout` | 0.1 | 0.1 |
The hidden dimension `d_ff = 2048` is 4× `d_model`.
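The §3.3 formula FFN(x) = max(0, xW₁ + b₁)W₂ + b₂ can be sketched in a few lines (a minimal sketch with hypothetical names, not necessarily this repo's exact module):

```python
import torch
import torch.nn as nn

class FFNSketch(nn.Module):
    """FFN(x) = max(0, x W1 + b1) W2 + b2, applied independently at each position."""
    def __init__(self, d_model=512, d_ff=2048, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)   # expand: d_model → d_ff
        self.w2 = nn.Linear(d_ff, d_model)   # contract: d_ff → d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):  # x: (B, T, d_model)
        return self.w2(self.dropout(torch.relu(self.w1(x))))

ffn = FFNSketch()
y = ffn(torch.randn(2, 5, 512))  # shape preserved: (2, 5, 512)
```

Because `nn.Linear` acts on the last dimension only, every position in the sequence passes through the same two matrices independently, which is what makes it equivalent to a 1×1 convolution.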
### `transformer.blocks.PositionalEncoding` — §3.5
- Non-learned — computed once at construction and stored as a buffer.
- Allows the model to attend to relative positions via linear combinations of sin/cos.
- Added to embeddings that are first scaled by $\sqrt{d_\text{model}}$ (§3.4).
- Division terms are computed in log-space for numerical stability.
- Supports sequences up to `max_len = 5000` by default.
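The log-space trick from the bullets above can be checked with plain Python (a sketch of the §3.5 formula, independent of this repo's module):

```python
import math

d_model = 512

# Computed in log-space: exp(2i · −ln(10000)/d_model) equals 10000^(−2i/d_model)
# without ever forming the large intermediate power.
div_term = [math.exp(i * -math.log(10000.0) / d_model) for i in range(0, d_model, 2)]

def pe_row(pos):
    """Positional encoding for one position: sin on even dims, cos on odd dims."""
    row = []
    for dt in div_term:
        row.append(math.sin(pos * dt))
        row.append(math.cos(pos * dt))
    return row

row0 = pe_row(0)  # sin(0) = 0 and cos(0) = 1 on every (even, odd) pair
```

At position 0 the row alternates 0, 1, 0, 1, …, and `div_term[0]` is exactly 1, matching the direct `1/10000^(2i/d_model)` formulation.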
### `transformer.blocks.SublayerConnection` — §3.1
Wraps every attention and FFN sub-layer in both the encoder and decoder. The post-norm formulation matches the original paper (pre-norm variants are popular in practice but not in this implementation).
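In post-norm form each sub-layer computes `LayerNorm(x + Dropout(sublayer(x)))`; a minimal sketch (hypothetical name, not necessarily this repo's exact module):

```python
import torch
import torch.nn as nn

class PostNormSublayer(nn.Module):
    """Residual connection followed by LayerNorm (post-norm, §3.1)."""
    def __init__(self, d_model=512, dropout=0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, sublayer):
        # Norm applied AFTER the residual add — the paper's ordering.
        return self.norm(x + self.dropout(sublayer(x)))

wrap = PostNormSublayer()
x = torch.randn(2, 5, 512)
out = wrap(x, lambda t: t * 2.0)  # any callable sub-layer; shape preserved
```

The pre-norm variant would instead compute `x + Dropout(sublayer(norm(x)))`; this repo deliberately keeps the paper's post-norm ordering.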
### `transformer.encoder` — §3.1
Encoder stacks N = 6 identical EncoderLayer blocks:
```
src → Embedding × √d_model → + PositionalEncoding → EncoderLayer × N → LayerNorm → memory
```
Each EncoderLayer contains:
```
x → [Self-Attention] → SublayerConnection → [FFN] → SublayerConnection → x'
```
### `transformer.decoder` — §3.1
Decoder stacks N = 6 identical DecoderLayer blocks:
```
tgt → Embedding × √d_model → + PositionalEncoding → DecoderLayer × N → LayerNorm → hidden
```
Each DecoderLayer contains three sub-layers:
```
x → [Masked Self-Attention] → SublayerConnection
  → [Cross-Attention (K,V from encoder)] → SublayerConnection
  → [FFN] → SublayerConnection → x'
```
The masked self-attention uses the causal mask to prevent each position from attending to subsequent positions.
### `transformer.model.Transformer` — §3
The top-level model:
```python
model = Transformer(
    src_vocab_size=32_000,
    tgt_vocab_size=32_000,
    d_model=512,   # model dimension
    h=8,           # attention heads
    d_ff=2048,     # FFN inner dim
    N=6,           # encoder / decoder layers
    dropout=0.1,
    max_len=5000,
)

logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
# logits: (B, T, tgt_vocab_size) — raw pre-softmax scores
```

Weight initialisation (§5.3): Xavier uniform for all `Linear` layers; scaled normal for `Embedding` layers.
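One common way to apply that scheme is a single pass over the modules (a sketch under my assumptions; the repo's actual init code may be organised differently):

```python
import torch.nn as nn

def init_weights(model: nn.Module, d_model: int = 512):
    """Xavier uniform for Linear weights; N(0, 1/sqrt(d_model)) for Embeddings."""
    for module in model.modules():
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=d_model ** -0.5)

demo = nn.Sequential(nn.Embedding(100, 512), nn.Linear(512, 512))
init_weights(demo)
```

Xavier uniform bounds the `Linear` weights by ±√(6/(fan_in + fan_out)), while the embedding std of `1/√d_model` keeps the scaled embeddings (multiplied by √d_model in §3.4) at roughly unit variance.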
Convenience methods for stepwise inference:
```python
memory = model.encode(src, src_mask)                    # (B, S, d_model)
hidden = model.decode(tgt, memory, tgt_mask, src_mask)  # (B, T, d_model)
```

## Hyperparameters

Paper base model defaults (Table 3 of the paper):
| Parameter | Symbol | Value | Notes |
|---|---|---|---|
| Model dimension | `d_model` | 512 | Embedding and hidden size |
| Feed-forward dim | `d_ff` | 2048 | FFN inner layer; 4× `d_model` |
| Attention heads | `h` | 8 | |
| Head dimension | `d_k = d_v` | 64 | `d_model / h` |
| Encoder/Decoder layers | `N` | 6 | |
| Dropout | — | 0.1 | Applied after attention, FFN, embedding |
| Max sequence length | — | 5000 | PE buffer size |
| ~Parameters | — | ~60 M | Depends on vocab size / weight tying |
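The ~60 M figure can be reproduced with back-of-the-envelope arithmetic. The accounting assumptions below are mine (bias-free attention and output projections, biased FFN layers, a final LayerNorm in each stack, and the smoke test's vocab of 10,000), chosen to match the smoke test's reported count:

```python
d_model, d_ff, h, N, vocab = 512, 2048, 8, 6, 10_000

mha = 4 * d_model * d_model                              # W^Q, W^K, W^V, W^O — no bias
ffn = d_model * d_ff + d_ff + d_ff * d_model + d_model   # two biased Linear layers
ln  = 2 * d_model                                        # LayerNorm gain + bias

enc_layer = mha + ffn + 2 * ln            # self-attn, FFN, two norms
dec_layer = 2 * mha + ffn + 3 * ln        # + cross-attention, three norms

total = (
    N * enc_layer + N * dec_layer
    + 2 * ln                              # final LayerNorm in each stack
    + 2 * vocab * d_model                 # src + tgt embedding tables
    + vocab * d_model                     # output projection — no bias
)
print(f"{total:,}")  # 59,463,680 — matches the smoke test output
```

Under these assumptions the count lands exactly on the 59,463,680 parameters the smoke test reports; with the paper's 37,000-token shared vocabulary the same arithmetic gives the familiar ~65 M.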
All parameters are constructor arguments so the "big model" (or any other variant) is easy to configure:
```python
# Paper "big" model
big = Transformer(
    src_vocab_size=37_000,
    tgt_vocab_size=37_000,
    d_model=1024,
    h=16,
    d_ff=4096,
    N=6,
    dropout=0.3,
)
```

## Usage Examples

Building masks:

```python
import torch
from transformer import make_causal_mask, make_padding_mask

PAD = 0
src = torch.tensor([[5, 3, 7, PAD, PAD]])           # (1, 5)
tgt = torch.tensor([[2, 8, PAD]])                   # (1, 3)

src_mask = make_padding_mask(src, PAD)              # (1, 1, 1, 5)
causal = make_causal_mask(tgt.size(1), src.device)  # (1, 1, 3, 3)
tgt_mask = causal | make_padding_mask(tgt, PAD)     # (1, 1, 3, 3)
```

Forward pass:

```python
from transformer import Transformer

model = Transformer(src_vocab_size=10_000, tgt_vocab_size=10_000)
model.eval()
with torch.no_grad():
    logits = model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
    # logits: (1, 3, 10000)

probs = logits.softmax(dim=-1)
predicted_ids = logits.argmax(dim=-1)  # (1, 3)
```

Greedy decoding (`BOS_IDX`, `EOS_IDX`, and `max_len` assumed defined):

```python
model.eval()
with torch.no_grad():
    memory = model.encode(src, src_mask)  # encode once

    # Start with <BOS> token
    ys = torch.full((1, 1), BOS_IDX, dtype=torch.long)
    for _ in range(max_len):
        T = ys.size(1)
        tgt_mask = make_causal_mask(T, src.device)
        hidden = model.decode(ys, memory, tgt_mask, src_mask)
        next_logits = model.output_projection(hidden[:, -1])   # (1, vocab)
        next_token = next_logits.argmax(dim=-1, keepdim=True)  # (1, 1)
        ys = torch.cat([ys, next_token], dim=1)
        if next_token.item() == EOS_IDX:
            break
```

Using components individually:

```python
from transformer import MultiHeadAttention, PositionalEncoding

# Stand-alone multi-head attention
mha = MultiHeadAttention(d_model=512, h=8, dropout=0.1)
out = mha(query, key, value, mask=None)

# Stand-alone positional encoding
pe = PositionalEncoding(d_model=512, dropout=0.1, max_len=5000)
x_with_pe = pe(x)  # x: (B, T, 512)
```

## Design Decisions

| Decision | Rationale |
|---|---|
| Post-norm (LayerNorm after residual) | Matches the original paper §3.1; pre-norm is more stable to train but changes the architecture. |
| No bias on projection matrices | The paper uses pure linear projections. |
| Batched head projections | A single `Linear(d_model → d_model)` then reshape is equivalent to `h` separate `Linear(d_model → d_k)` calls but faster due to a single GEMM. |
| Sinusoidal PE as a buffer | Non-learned; matches §3.5 exactly. Can be swapped for learned PE by replacing `PositionalEncoding`. |
| Division terms in log-space | `exp(arange * -log(10000) / d_model)` avoids computing large intermediate powers. |
| Xavier uniform init | Standard for Transformer-like models; the paper describes a warm-up learning-rate schedule but does not specify an exact init (§5.3). |
| Separate src/tgt vocabularies | Allows different source and target languages. The paper optionally shares weights — easy to add by passing the same `Embedding`. |
| `mask: True = block` | Consistent convention across all mask functions; positions where the mask is `True` receive $-\infty$ before softmax. |
## Scope & Limitations

This implementation covers every architectural component from the paper. The following are intentionally out of scope:
- No training loop / optimizer / learning rate scheduler — the paper's schedule with warm-up steps is non-trivial; adding one would conflate architecture with training code.
- No tokenization or dataset loading — task-agnostic.
- No beam search — the greedy decoding sketch above is included as an example only.
- No weight tying — the paper optionally ties the src embedding, tgt embedding, and output projection weights (§3.4); straightforward to add.
- No label smoothing — used in the paper's training (§5.4) but belongs in the loss function, not the model.
This implementation is released for educational and research use.
Paper copyright belongs to the original authors.