# Building a Transformer from Scratch

This notebook will lead you through building your own transformer from the ground up. We'll go through positional encoding, multi-headed attention, and more. Then we'll address a simple next-token-prediction problem where we try to predict the next number in a cyclic sequence (0->1->2->3...->9->0).

Note: Below some of the code cells there are notes on useful methods, helpful syntax, or hints that can make the code simpler and cleaner.

## Why are transformers important?

As covered in lecture, transformers use a mechanism called **attention** to measure how relevant each token in a sequence is to every other token in that sequence.

Tokens are often words, with sequences being sentences or paragraphs, but the idea extends to other domains. We briefly touched on ViTs (Vision Transformers), which treat small image patches as tokens and the whole image as the sequence, and on audio transformers, which treat short segments of audio (for example, frames or spectrogram patches) as tokens and the stream of audio as the sequence.

In our task, each token is a single digit (0-9) and a sequence is a string of digits; for each digit, the model must predict the number that follows it in the cycle.
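To make the setup concrete, here is a minimal sketch of how input tokens map to targets, assuming the same rule the training cell uses later (`(x + 1) % V`). The digits below are arbitrary examples.

```python
import torch
import torch.nn.functional as F

V = 10  # vocabulary size: the digits 0-9

# An arbitrary example sequence of digit tokens, shape (batch=1, seq_len=5)
x_idx = torch.tensor([[3, 7, 9, 0, 5]])

# Each target is the next digit in the cycle, so 9 wraps around to 0
y_idx = (x_idx + 1) % V

# The model consumes one-hot vectors of length V, one per token
x_onehot = F.one_hot(x_idx, num_classes=V).float()

print(y_idx.tolist())   # [[4, 8, 0, 1, 6]]
print(x_onehot.shape)   # torch.Size([1, 5, 10])
```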

In [1]:
from __future__ import annotations
import math
import random
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F

As you may recall from lecture, transformers use positional encoding to keep track of the order of words so they know which words came first.

Since transformers process input tokens in parallel rather than sequentially, they lack an inherent sense of word order. To address this, positional encodings are added to the input embeddings, allowing the model to capture the relative and absolute positions of tokens.
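A quick way to see why order has to be injected: without positional information, self-attention is permutation-equivariant, so shuffling the input tokens only shuffles the outputs in the same way. The sketch below uses PyTorch's built-in `nn.MultiheadAttention` (not the layer we build later) with arbitrary sizes, assuming a PyTorch version that supports `batch_first=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Self-attention with no positional information (PyTorch's built-in layer, no dropout)
attn = nn.MultiheadAttention(embed_dim=16, num_heads=4, batch_first=True)
attn.eval()

x = torch.randn(1, 6, 16)   # (batch, seq_len, embed_dim)
perm = torch.randperm(6)    # a random reordering of the 6 tokens

with torch.no_grad():
    out, _ = attn(x, x, x)                                   # original order
    out_perm, _ = attn(x[:, perm], x[:, perm], x[:, perm])   # shuffled order

# Shuffling the inputs just shuffles the outputs the same way:
# the layer itself cannot tell which token came first.
print(torch.allclose(out[:, perm], out_perm, atol=1e-5))  # expected: True
```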

There are two common types of positional encoding: sinusoidal and learned. The original transformer paper used sine and cosine functions of different frequencies to generate fixed position vectors, a choice the authors made because it may allow the model to extrapolate to sequences longer than those seen during training. Alternatively, learned positional embeddings treat positions like tokens and learn a vector for each position during training. In both cases, the encodings are added to the token embeddings before the first self-attention layer, which lets the model reason about position-dependent meaning.
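For comparison with the sinusoidal version we'll build, here is a minimal sketch of a learned positional encoding. The class name `LearnedPositionalEncoding` is hypothetical and is not used elsewhere in this notebook; it simply stores one trainable vector per position in an `nn.Embedding`.

```python
import torch
import torch.nn as nn

class LearnedPositionalEncoding(nn.Module):
    """Hypothetical learned alternative to sinusoidal encodings (not used in this notebook)."""

    def __init__(self, max_len: int, dim: int):
        super().__init__()
        # One trainable vector per position, updated by backprop like token embeddings
        self.pos_emb = nn.Embedding(max_len, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); look up positions 0..seq_len-1 and broadcast over the batch
        positions = torch.arange(x.size(1), device=x.device)
        return x + self.pos_emb(positions).unsqueeze(0)

pe = LearnedPositionalEncoding(max_len=50, dim=64)
x = torch.zeros(2, 20, 64)
out = pe(x)
print(out.shape)                                # torch.Size([2, 20, 64])
print(sum(p.numel() for p in pe.parameters()))  # 3200 trainable values (50 * 64)
```

Unlike the buffer used for the sinusoidal encodings, these `max_len * dim` values are returned by `parameters()` and updated by the optimizer, and the table has no entry for positions at or beyond `max_len`.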


Today we'll be using **sinusoidal encodings**.



In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, max_len: int, dim: int):
        super().__init__()

        # Sinusoidal encodings pair sin/cos columns, so dim must be even
        assert dim % 2 == 0, "dim must be even for sinusoidal positional encoding"

        # Explanations of these variables are below this cell
        pos_enc = torch.zeros(max_len, dim)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * (-math.log(10000.0) / dim))

        # TODO
        # Set every EVEN column (0, 2, 4...) in pos_enc to sin(position * div_term)
        # Set every ODD column (1, 3, 5...) in pos_enc to cos(position * div_term)
        # Either loop over the positions, or use broadcasting: position has shape (max_len, 1) and
        # div_term has shape (dim / 2,), so position * div_term is (max_len, dim / 2) in one line
        pos_enc[:, 0::2] = torch.sin(position * div_term)
        pos_enc[:, 1::2] = torch.cos(position * div_term)

        # register_buffer (instead of self.pos_enc = ...) stores pos_enc with the module, so it is saved
        # in state_dict and moved by .to(device), but it is NOT a trainable parameter. That is what we
        # want here: the sinusoidal encodings are fixed and should not change during training.
        self.register_buffer("pos_enc", pos_enc.unsqueeze(0))

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        return x + self.pos_enc[:, : x.size(1)]

**Variable Descriptions**

`pos_enc`: A matrix of shape (max_len, dim) holding the positional encoding for each position.

`position`: A column vector of position indices ([0, 1, 2, ..., max_len-1]) with shape (max_len, 1).

`div_term`: The frequency for each sin/cos column pair; every pair of columns gets a different frequency, so the encodings don't repeat across dimensions.
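The `div_term` line is the original paper's frequency term, computed in log space. With $d$ = `dim`, the encodings are

$$PE_{(pos,\,2i)} = \sin\left(\frac{pos}{10000^{2i/d}}\right), \qquad PE_{(pos,\,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d}}\right)$$

and $\exp\left(2i \cdot \frac{-\ln 10000}{d}\right) = 10000^{-2i/d}$. The sketch below, with a small arbitrary `dim` and `max_len`, checks that identity numerically and builds the same table as `PositionalEncoding`.

```python
import math
import torch

dim, max_len = 8, 4
i2 = torch.arange(0, dim, 2, dtype=torch.float32)  # the even indices 2i = 0, 2, 4, 6

# Log-space form used in PositionalEncoding
div_term = torch.exp(i2 * (-math.log(10000.0) / dim))
# Direct form from the formula: 1 / 10000^(2i/dim)
direct = 1.0 / (10000.0 ** (i2 / dim))
print(torch.allclose(div_term, direct))  # True

# Build the table the same way as the class and inspect position 0
position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
pos_enc = torch.zeros(max_len, dim)
pos_enc[:, 0::2] = torch.sin(position * div_term)  # broadcasts to (max_len, dim/2)
pos_enc[:, 1::2] = torch.cos(position * div_term)
print(pos_enc[0])  # tensor([0., 1., 0., 1., 0., 1., 0., 1.]): sin(0)=0, cos(0)=1
```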

**Useful Functions**

`torch.sin`, `torch.cos`

**Hints**

- You can use the syntax `arr[:, 0::2]` to set all values in the even columns (start at column 0 and step 2 columns at a time until the end)
- You can use the syntax `pos_enc[:, 0::2] = torch.sin(position * div_term)` to set the positional encodings for the even columns and modify this code for the odd columns

In this cell we'll focus on implementing the Query, Key, and Value projections for our attention block and some of the calculations involved. The Q, K, and V matrices are produced by multiplying the input by trainable projection weights, so we can represent each projection as a linear layer and learn its weights during training.

<center width="100%"><img src="https://drive.google.com/uc?export=view&id=1-J6feRuS3qP9dD79Zq-VdVsL9432WMCF" width="600px"></center>


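The graphic above gives some intuition for the sizes of the Q, K, and V matrices (and the linear layers that produce them) and shows the equation we need to implement (highlighted in yellow), whose result is then used to compute the final output (in grey). The equation to implement is also listed below, where $d_k$ = `proj_dim // num_heads` is the per-head dimension (`self.d_k` in the code):

$$\frac{Q K^\top}{\sqrt{d_k}}$$

The full attention output is then $Z = \mathrm{softmax}\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$.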
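Here is a minimal single-head sketch of the full computation on random tensors (sizes are arbitrary), followed by a quick numerical look at why we divide by $\sqrt{d_k}$: for random vectors with unit-variance entries, the dot product $q \cdot k$ has variance about $d_k$, which would push the softmax toward near one-hot weights.

```python
import math
import torch

torch.manual_seed(0)
L, d_k = 4, 64  # sequence length and per-head dimension (arbitrary)

queries = torch.randn(L, d_k)
keys = torch.randn(L, d_k)
values = torch.randn(L, d_k)

scores = queries @ keys.T / math.sqrt(d_k)  # (L, L): query i vs. key j
weights = torch.softmax(scores, dim=-1)     # each row sums to 1
Z = weights @ values                        # (L, d_k): weighted mix of value rows

print(weights.sum(dim=-1))  # all ones (up to float rounding)
print(Z.shape)              # torch.Size([4, 64])

# Why divide by sqrt(d_k): for random unit-variance vectors, q . k has variance ~d_k;
# scaling by 1/sqrt(d_k) brings it back to ~1.
q = torch.randn(10_000, d_k)
k = torch.randn(10_000, d_k)
raw = (q * k).sum(dim=-1)
print(raw.var().item(), (raw / math.sqrt(d_k)).var().item())  # roughly 64 and 1
```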

In [3]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, proj_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()

        # Double check that the proj_dim is divisible by the number of heads
        assert proj_dim % num_heads == 0, "Hidden dimension (proj_dim) must be divisible by the number of heads (num_heads)"
        self.num_heads = num_heads
        self.d_k = proj_dim // num_heads

        # TODO: Map a linear layer to each input projection
        self.q_proj = nn.Linear(proj_dim, proj_dim)
        self.k_proj = nn.Linear(proj_dim, proj_dim)
        self.v_proj = nn.Linear(proj_dim, proj_dim)
        self.o_proj = nn.Linear(proj_dim, proj_dim)

        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x: torch.Tensor) -> torch.Tensor:  # (B, L, H) -> (B, h, L, d)
        B, L, H = x.shape
        x = x.view(B, L, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def combine_heads(self, x: torch.Tensor) -> torch.Tensor:  # (B, h, L, d) -> (B, L, H)
        B, h, L, d = x.shape
        return x.transpose(1, 2).contiguous().view(B, L, h * d)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:

        # TODO: Use one of the methods above to split the heads for each matrix
        Q = self.split_heads(self.q_proj(x))
        K = self.split_heads(self.k_proj(x))
        V = self.split_heads(self.v_proj(x))

        K_transpose = K.transpose(-2, -1)
        # TODO: Calculate the part of the Z matrix highlighted above (the inside of the softmax function)
        scores = (Q @ K_transpose) / math.sqrt(self.d_k)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e9)
        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # @ is Python's matrix-multiplication operator; PyTorch applies it batch-wise over the last two dims
        context = attn @ V  # (B, h, L, d)
        context = self.combine_heads(context)

        return self.o_proj(context)

**Useful Syntax**

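`@` is Python's matrix-multiplication operator, supported by both PyTorch and NumPy (e.g. `Q @ K` multiplies matrix Q by matrix K). For tensors with more than two dimensions, it multiplies the last two dimensions and treats the leading ones as batch dimensions.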

**Hints**

- Matrix multiplication is NOT commutative, so the ORDER MATTERS.

- All of the projection layers (e.g. `self.q_proj`) should be created the same way, with the same input and output sizes.

- When calling `split_heads`, first pass the input through the corresponding projection (e.g. `Q = self.split_heads(self.q_proj(x))`).
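The sketch below, with arbitrary sizes, checks that the reshapes used in `split_heads` and `combine_heads` are exact inverses, and shows one way to build a `mask` that matches the `mask == 0` convention in `forward`. The causal (lower-triangular) mask is only an illustration: the training loop in this notebook calls the model without a mask, so every position can attend to every other position.

```python
import torch

B, L, h, d = 2, 5, 4, 8       # batch, seq_len, heads, per-head dim (arbitrary)
x = torch.randn(B, L, h * d)  # (B, L, H) with H = h * d

# Same reshapes as split_heads / combine_heads
split = x.view(B, L, h, d).transpose(1, 2)                     # (B, h, L, d)
merged = split.transpose(1, 2).contiguous().view(B, L, h * d)  # (B, L, H)
print(torch.equal(x, merged))  # True: the round trip is lossless

# A causal mask in the convention used by forward(): 1 = may attend, 0 = blocked.
# Shape (L, L) broadcasts against the (B, h, L, L) score tensor.
mask = torch.tril(torch.ones(L, L))
scores = torch.randn(B, h, L, L)
attn = torch.softmax(scores.masked_fill(mask == 0, -1e9), dim=-1)
print(attn[0, 0])  # entries above the diagonal are 0: no attending to later tokens
```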

In the following code block we'll set up a position-wise feed-forward block using existing PyTorch layers and a ReLU activation. "Position-wise" means the same two-layer network is applied to each token's vector independently.

In [4]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, hidden_dim: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        # TODO: Fill in the layers with appropriate parameters
        self.fc1 = nn.Linear(hidden_dim, ff_dim)
        self.act_fn = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(ff_dim, hidden_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # TODO: Fill out the forward pass
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x
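To see what "position-wise" means in practice, here is a sketch that rebuilds the same layer stack with `nn.Sequential` (dropout omitted so the check is deterministic; sizes are arbitrary) and confirms that running it on a whole sequence matches running it on each position separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_dim, ff_dim = 16, 32
# Same layers as PositionwiseFeedForward, minus dropout
ffn = nn.Sequential(nn.Linear(hidden_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, hidden_dim))

x = torch.randn(2, 5, hidden_dim)  # (batch, seq_len, hidden_dim)
whole = ffn(x)                     # nn.Linear acts on the last dimension only

# Applying the network one position at a time gives the same result
per_position = torch.stack([ffn(x[:, t]) for t in range(x.size(1))], dim=1)
print(torch.allclose(whole, per_position, atol=1e-5))  # True
```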

Next we'll put all of that together along with some layer normalization to create our transformer block. Take a look at the image below to see how our block aligns with the transformer architecture.

<center width="100%"><img src="https://d1.awsstatic.com/GENAI-1.151ded5440b4c997bac0642ec669a00acff2cca1.png" width="300px"></center>

In [5]:
class TransformerBlock(nn.Module):
    def __init__(self, proj_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        # TODO: Fill out our main transformer block according to the architecture above and the layers we've made
        self.multihead_attn = MultiHeadSelfAttention(proj_dim, num_heads, dropout)
        self.ln1 = nn.LayerNorm(proj_dim)
        self.pos_wise_ff = PositionwiseFeedForward(proj_dim, ff_dim, dropout)
        self.ln2 = nn.LayerNorm(proj_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        x = x + self.dropout(self.multihead_attn(x, mask))
        x = self.ln1(x)
        x = x + self.dropout(self.pos_wise_ff(x))
        x = self.ln2(x)
        return x
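This block applies layer normalization after each residual addition ("post-LN"), which is the arrangement in the original Transformer. As a sketch of what `nn.LayerNorm` computes at initialization (gain 1, bias 0), the code below normalizes each position's feature vector to zero mean and unit variance by hand and compares the two. Sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
proj_dim = 16
x = torch.randn(2, 5, proj_dim)

ln = nn.LayerNorm(proj_dim)  # learnable gain starts at 1 and bias at 0

# Normalize each position's feature vector independently over the last dimension
mean = x.mean(dim=-1, keepdim=True)
var = ((x - mean) ** 2).mean(dim=-1, keepdim=True)  # biased variance, as LayerNorm uses
manual = (x - mean) / torch.sqrt(var + ln.eps)

print(torch.allclose(ln(x), manual, atol=1e-5))  # True
```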

Next we'll combine a token embedding, the positional encoding layer from earlier, and a stack of the transformer blocks from above to create our full transformer encoder.

In [6]:
class TransformerEncoder(nn.Module):
    def __init__(
        self,
        num_layers: int,
        hidden_dim: int,
        num_heads: int,
        ff_dim: int,
        vocab_size: int,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Simple one‑hot → hidden projection (equivalent to an embedding matrix)

        # TODO: Follow the Transformer Encoder architecture above to implement this block using the layers we have created
        self.token_emb = nn.Linear(vocab_size, hidden_dim, bias=False)
        self.pos_enc = PositionalEncoding(max_len, hidden_dim)
        self.transformer_blocks = nn.ModuleList(
            [TransformerBlock(hidden_dim, num_heads, ff_dim, dropout) for _ in range(num_layers)]
        )
        self.ln_final = nn.LayerNorm(hidden_dim)

    def forward(self, x_onehot: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:  # type: ignore[override]
        x = self.token_emb(x_onehot)  # (B, L, H)
        x = self.pos_enc(x)
        # TODO: Replace this if not using nn.ModuleList
        for layer in self.transformer_blocks:
            x = layer(x, mask)

        return self.ln_final(x)

**Useful Functions**

`nn.ModuleList()` takes a list of layers and registers each one as a submodule (so their parameters are trained), which lets us set the number of transformer blocks with the `num_layers` parameter. Feel free to hardcode the number of transformer blocks instead and ignore `num_layers`.

If you opt not to use `nn.ModuleList()` you will have to update the forward pass slightly to match your setup.
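The comment in `__init__` says the bias-free `nn.Linear` applied to one-hot vectors is equivalent to an embedding matrix. Here is a minimal sketch, with arbitrary indices, that checks this against `nn.Embedding` after copying the same weights: multiplying by a one-hot vector simply selects one column of the linear layer's weight.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size, hidden_dim = 10, 64

proj = nn.Linear(vocab_size, hidden_dim, bias=False)  # weight shape: (hidden_dim, vocab_size)
emb = nn.Embedding(vocab_size, hidden_dim)             # weight shape: (vocab_size, hidden_dim)
with torch.no_grad():
    emb.weight.copy_(proj.weight.T)                    # give both layers the same values

idx = torch.tensor([[3, 1, 4, 1, 5]])
via_onehot = proj(F.one_hot(idx, num_classes=vocab_size).float())
via_lookup = emb(idx)
print(torch.allclose(via_onehot, via_lookup))  # True
```

`nn.Embedding` takes integer indices directly and skips building the one-hot tensor; this notebook keeps the linear version because `generate_batch` returns one-hot inputs.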

In [7]:
torch.manual_seed(0)
random.seed(0)

# Synthetic digit dataset
vocab = [str(d) for d in range(10)]
V = len(vocab)
itos = {i: ch for i, ch in enumerate(vocab)}

def generate_batch(batch_size: int, seq_len: int):
    x_idx = torch.randint(0, V, (batch_size, seq_len))
    y_idx = (x_idx + 1) % V  # target is next digit
    x_onehot = F.one_hot(x_idx, num_classes=V).float()
    return x_onehot, y_idx

# Model, head, loss, optimiser
model = TransformerEncoder(num_layers=2, hidden_dim=64, num_heads=8, ff_dim=128, vocab_size=V, max_len=50)
lm_head = nn.Linear(64, V)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(list(model.parameters()) + list(lm_head.parameters()), lr=3e-3)


# TODO: Set model to train mode
model.train()

# Training loop: each "epoch" here is a single freshly generated batch
epochs, batch_size, seq_len = 200, 32, 20
for epoch in range(1, epochs + 1):
    digits, labels = generate_batch(batch_size, seq_len)

    # TODO: Write the training loop
    optimizer.zero_grad()
    logits = lm_head(model(digits))
    loss = criterion(logits.view(-1, V), labels.view(-1))
    loss.backward()
    optimizer.step()

    if epoch % 20 == 0:
        print(f"epoch {epoch:3d} | loss = {loss.item():.4f}")

# Simple Evaluation
model.eval()
with torch.no_grad():
    x_seed, _ = generate_batch(1, 10)
    logits = lm_head(model(x_seed))
    pred_idx = logits.argmax(dim=-1)[0].tolist()
    print("Input :", "".join(itos[i] for i in x_seed.argmax(dim=-1)[0].tolist()))
    print("Output:", "".join(itos[i] for i in pred_idx))


epoch  20 | loss = 0.0756
epoch  40 | loss = 0.0108
epoch  60 | loss = 0.0054
epoch  80 | loss = 0.0038
epoch 100 | loss = 0.0030
epoch 120 | loss = 0.0025
epoch 140 | loss = 0.0021
epoch 160 | loss = 0.0018
epoch 180 | loss = 0.0015
epoch 200 | loss = 0.0013
Input : 8640358829
Output: 9751469930
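The training loop flattens the logits to `(batch * seq_len, V)` because `nn.CrossEntropyLoss` expects the class dimension second. Here is a sketch on random tensors (shapes match the training cell) showing that passing `(B, V, L)` logits with `(B, L)` labels gives the same mean loss, plus a per-token accuracy computation that could be applied the same way to `lm_head(model(x))` and the labels from `generate_batch`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, V = 32, 20, 10
logits = torch.randn(B, L, V)         # (batch, seq_len, classes), like lm_head's output
labels = torch.randint(0, V, (B, L))  # (batch, seq_len) integer class indices

criterion = nn.CrossEntropyLoss()
flat = criterion(logits.view(-1, V), labels.view(-1))  # (B*L, V) vs (B*L,), as in the loop
channels = criterion(logits.permute(0, 2, 1), labels)  # (B, V, L) vs (B, L) is also accepted
print(torch.allclose(flat, channels))  # True: both average over all B*L tokens

# Fraction of tokens whose highest-scoring class matches the label
accuracy = (logits.argmax(dim=-1) == labels).float().mean().item()
```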
