# Transformer from Scratch

I reimplemented *Attention Is All You Need* and built up the encoder/decoder layers in a very modular way. I first wrote a test suite to make sure that gradients flow and shapes match, and then implemented each block independently.

My goal here is to make my model generate somewhat coherent text. I'm using the famous Tiny Shakespeare dataset (completely raw, no text cleaning) and a decoder-only network to generate text one character at a time (the same decoder-only setup that models like ChatGPT use).

First, I need to grab the data.

In [1]:
import subprocess

subprocess.run("curl -o data/input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt", shell=True)

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
  0 1089k    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0
curl: (56) Failure writing output to destination, passed 1369 returned 4294967295


CompletedProcess(args='curl -o data/input.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt', returncode=56)
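The curl call above exited with code 56 ("Failure writing output to destination"), most likely because `data/` does not exist relative to the notebook's working directory (later cells read `../data/input.txt`). A minimal standard-library sketch that creates the folder first and skips the download when the file is already present; the `download_if_missing` name is hypothetical, not part of the repo:

```python
import os
import urllib.request

DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"


def download_if_missing(url: str, dest: str) -> str:
    """Hypothetical helper: fetch url into dest unless dest already exists. Returns dest."""
    # create the parent folder so the write cannot fail on a missing directory
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
    return dest


# usage from the notebook folder (path is relative to the notebook's working directory):
# download_if_missing(DATA_URL, "../data/input.txt")
```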

I also need to add the repo root (one level up) to `sys.path` so the `transformer` package can be imported.

In [2]:
import sys
import os
sys.path.append(os.path.abspath(".."))  # transformer/ is up one level

We can start with the main imports.

In [3]:
import torch
import torch.nn as nn
from transformer.transformer import Transformer
from transformer.decoder import Decoder
from transformer.utils import CharDataset, generate_subsequent_mask
import yaml

Before we dive into the code, I want to make sure gradients are flowing, forward and backward passes work, and shapes match (all covered by the test suite I wrote).

In [4]:
import subprocess
import sys

test_files = [
    "test_attention.py",
    "test_decoder.py",
    "test_encoder.py",
    "test_transformer.py",
    "test_utils.py"
]

for test_file in test_files:
    print(f"\nRunning {test_file}...")
    # run pytest with this kernel's interpreter, from tests/, without changing the notebook's cwd
    result = subprocess.run(
        [sys.executable, "-m", "pytest", test_file],
        cwd="../tests",
        capture_output=True,
        text=True,
    )
    print(result.stdout)
    print(result.stderr)


Running test_attention.py...
platform darwin -- Python 3.13.2, pytest-8.4.1, pluggy-1.6.0
rootdir: /Users/akhilvreddy/Documents/transformers-reimplementation
configfile: pytest.ini
collected 2 items

test_attention.py ..                                                     [100%]




Running test_decoder.py...
platform darwin -- Python 3.13.2, pytest-8.4.1, pluggy-1.6.0
rootdir: /Users/akhilvreddy/Documents/transformers-reimplementation
configfile: pytest.ini
collected 1 item

test_decoder.py .                                                        [100%]




Running test_encoder.py...
platform darwin -- Python 3.13.2, pytest-8.4.1, pluggy-1.6.0
rootdir: /Users/akhilvreddy/Documents/transformers-reimplementation
configfile: pytest.ini
collected 1 item

test_encoder.py .                                                        [100%]




Running test_transformer.py...
platform darwin -- Python 3.13.2, pytest-8.4.1, pluggy-1.6.
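The actual tests live in `tests/` and target the repo's own modules. As a hypothetical, self-contained example of the same idea, here is a pytest-style check that uses PyTorch's built-in `torch.nn.MultiheadAttention` as a stand-in attention module and verifies the output shapes and that every parameter receives a finite gradient:

```python
import torch
import torch.nn as nn


def test_attention_shapes_and_grads():
    torch.manual_seed(0)
    B, T, d_model, num_heads = 2, 10, 32, 4
    # stand-in for a custom multi-head attention module
    attn = nn.MultiheadAttention(d_model, num_heads, batch_first=True)
    x = torch.randn(B, T, d_model, requires_grad=True)

    out, weights = attn(x, x, x)  # self-attention: query = key = value
    assert out.shape == (B, T, d_model)
    # attention weights are averaged over heads by default: (B, T, T)
    assert weights.shape == (B, T, T)

    out.sum().backward()
    # every parameter (and the input) should receive a finite gradient
    for name, p in attn.named_parameters():
        assert p.grad is not None, name
        assert torch.isfinite(p.grad).all(), name
    assert x.grad is not None
```

Saved as `test_*.py`, pytest collects the function automatically.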

We now have our data in `data/input.txt`, our test suite is passing, and the model hyperparameters live in `config.yaml`.

Let's load our data.

In [5]:
with open("../data/input.txt") as f:
    text = f.read()
block_size = 64

dataset = CharDataset(text, block_size)
vocab_size = dataset.vocab_size
print("Vocab size:", vocab_size)

Vocab size: 65
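`CharDataset` itself isn't shown in this notebook. A minimal pure-Python sketch of what a character-level dataset like it typically does; the `TinyCharDataset` name and the exact windowing are assumptions, not the code in `transformer/utils.py` (which presumably returns tensors, since the notebook calls `.unsqueeze` on its outputs):

```python
class TinyCharDataset:
    """Hypothetical stand-in for CharDataset: (input, target) windows over raw text."""

    def __init__(self, text: str, block_size: int):
        chars = sorted(set(text))                 # every distinct character becomes a token
        self.vocab_size = len(chars)
        self.stoi = {c: i for i, c in enumerate(chars)}
        self.itos = {i: c for i, c in enumerate(chars)}
        self.block_size = block_size
        self.data = [self.stoi[c] for c in text]  # the whole text as integer ids

    def __len__(self) -> int:
        # each sample needs block_size inputs plus one extra character for the last target
        return len(self.data) - self.block_size

    def __getitem__(self, i: int):
        chunk = self.data[i : i + self.block_size + 1]
        x, y = chunk[:-1], chunk[1:]              # y is x shifted by one: next-character targets
        return x, y
```

For example, `TinyCharDataset("hello world", block_size=4)` has 8 distinct characters, and its first sample pairs the ids of `"hell"` with the ids of `"ello"`.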


Our vocab size of 65 means the text contains 65 distinct characters, and each one is fed into the model as its own token. Think of characters like a, b, A, B, `\n`, . and ?: each is treated as a unique token with its own embedding.

So our model has an embedding layer of shape `(65, d_model)`, and the output logits have shape `(B, T, 65)`: one score for each possible character at every time step, for each sequence in the batch.
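A quick shape check with stock PyTorch layers. It assumes `d_model = 128` purely for illustration (the real value lives in `config.yaml`) and assumes the output head is a linear projection back to the vocabulary; the attention blocks that sit in between keep the `(B, T, d_model)` shape, so they are left out:

```python
import torch
import torch.nn as nn

vocab_size, d_model = 65, 128   # d_model is illustrative, not the config value
B, T = 4, 64

embed = nn.Embedding(vocab_size, d_model)   # weight: (65, d_model), one vector per character
to_logits = nn.Linear(d_model, vocab_size)  # maps each hidden state back to 65 scores

tokens = torch.randint(0, vocab_size, (B, T))  # (B, T) integer character ids
hidden = embed(tokens)                         # (B, T, d_model)
logits = to_logits(hidden)                     # (B, T, 65): one score per character per step

print(embed.weight.shape, hidden.shape, logits.shape)
```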

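During generation I'll either take the argmax of that 65-dimensional vector or sample from it (e.g. top-k sampling; beam search is another option that searches over candidate sequences). To start, I'll use argmax (greedy decoding) just to see whether the text is coherent, and later I'll switch to techniques that give more varied, natural-sounding output.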
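To make the difference concrete, a minimal sketch of greedy decoding versus temperature plus top-k sampling on one `(B, vocab_size)` slice of logits. The `pick_next_token` name and its defaults are hypothetical; beam search is a separate search procedure and isn't shown:

```python
from typing import Optional

import torch


def pick_next_token(logits: torch.Tensor, temperature: float = 1.0,
                    top_k: Optional[int] = None, greedy: bool = False) -> torch.Tensor:
    """logits: (B, vocab_size) scores for the next position. Returns (B, 1) token ids."""
    if greedy:
        # argmax: always take the single highest-scoring character
        return torch.argmax(logits, dim=-1, keepdim=True)
    logits = logits / temperature  # < 1 sharpens the distribution, > 1 flattens it
    if top_k is not None:
        # keep only the k largest scores; everything else gets probability 0
        kth_value = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
        logits = logits.masked_fill(logits < kth_value, float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)  # random draw from the kept characters
```

With `top_k=1` this reduces to greedy decoding; larger `top_k` and higher `temperature` trade coherence for variety.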

We can go ahead and initialize our decoder-only autoregressive model.

In [6]:
with open("../config.yaml") as f:
    cfg = yaml.safe_load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
decoder = Decoder(
    num_layers=cfg["num_layers"],
    d_model=cfg["d_model"],
    num_heads=cfg["num_heads"],
    dim_ff=cfg["dim_ff"],
    vocab_size=vocab_size,  # overwrite the config value to match the actual dataset
).to(device)
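With no encoder, each decoder layer only needs masked self-attention plus the feed-forward block; calling the decoder with `enc_out=None` presumably skips cross-attention. A minimal sketch of one such layer built from stock PyTorch modules (post-norm, as in the paper). The `DecoderOnlyBlock` name and the use of `nn.MultiheadAttention` are assumptions for illustration, not how `transformer/decoder.py` is written:

```python
import torch
import torch.nn as nn


class DecoderOnlyBlock(nn.Module):
    """Hypothetical sketch: masked self-attention + feed-forward, each with residual + LayerNorm."""

    def __init__(self, d_model: int, num_heads: int, dim_ff: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(nn.Linear(d_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        T = x.size(1)
        # True above the diagonal: position t may not attend to any position after t
        causal = torch.ones(T, T, device=x.device).triu(diagonal=1).bool()
        attn_out, _ = self.self_attn(x, x, x, attn_mask=causal, need_weights=False)
        x = self.norm1(x + self.drop(attn_out))     # residual + LayerNorm (post-norm)
        x = self.norm2(x + self.drop(self.ff(x)))   # position-wise feed-forward
        return x
```

Because of the causal mask, changing tokens at later positions leaves the outputs at earlier positions unchanged, which is what makes next-character training valid.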

Before we start with a full-blown training loop, I want to make sure a single training step works.

In [8]:
from torch.nn import CrossEntropyLoss
from torch.optim import Adam

# grab one (x, y) pair and move to device
x, y = dataset[0]
x = x.unsqueeze(0).to(device) # shape: (1, T)
y = y.unsqueeze(0).to(device) # shape: (1, T)

# get causal mask
tgt_mask = generate_subsequent_mask(x.size(1)).to(device)

# forward pass
logits = decoder(x, enc_out=None, src_mask=None, tgt_mask=tgt_mask)

# loss
loss_fn = CrossEntropyLoss()
loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))
print("Loss:", loss.item())

# backward pass
decoder.zero_grad()
loss.backward()

# quick optimizer step to verify gradient flow
optimizer = Adam(decoder.parameters(), lr=float(cfg["learning_rate"]))
optimizer.step()

Loss: 4.40109395980835


After a lot of debugging and editing my transformer so that it supports a decoder-only model, it finally returned a loss! The loss was **4.401**, which is about what you'd expect from an untrained model guessing among 65 characters. I'm now confident enough to put the model in a training loop.
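A quick check of the chance-level number: uniform guessing over 65 characters gives a cross-entropy of ln(65), about 4.17, so 4.40 is in the right neighbourhood for random weights (slightly worse, since random weights give a non-uniform distribution). Using the same loss function:

```python
import math

import torch
import torch.nn as nn

vocab_size = 65
print(math.log(vocab_size))  # ≈ 4.174, cross-entropy of a perfectly uniform guess

# identical logits -> uniform softmax -> loss equals ln(vocab_size) whatever the targets are
logits = torch.zeros(10, vocab_size)
targets = torch.randint(0, vocab_size, (10,))
print(nn.CrossEntropyLoss()(logits, targets).item())
```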

In [9]:
from tqdm import tqdm

num_steps = 100

for step in tqdm(range(num_steps), desc="Training"):
    x, y = dataset[step]
    x = x.unsqueeze(0).to(device)
    y = y.unsqueeze(0).to(device)
    tgt_mask = generate_subsequent_mask(x.size(1)).to(device)

    logits = decoder(x, enc_out=None, src_mask=None, tgt_mask=tgt_mask)
    loss = loss_fn(logits.view(-1, vocab_size), y.view(-1))

    decoder.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"Step {step} | Loss: {loss.item():.4f}")

Step 0 | Loss: 2.8867
Step 10 | Loss: 0.9605
Step 20 | Loss: 0.8871
Step 30 | Loss: 0.7898
Step 40 | Loss: 0.9317
Step 50 | Loss: 0.9714
Step 60 | Loss: 0.7389
Step 70 | Loss: 0.8672
Step 80 | Loss: 1.2584
Step 90 | Loss: 1.0607

Training: 100%|██████████| 100/100 [00:04<00:00, 20.20it/s]
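One caveat on those numbers, assuming `CharDataset` returns the window that starts at character `i`: `dataset[step]` for `step in range(100)` walks 100 consecutive windows that overlap in all but one position, so this loop only ever sees the first 164 characters of the file. That likely explains why the loss falls below 1 within ten steps. Counting the positions covered:

```python
block_size = 64   # value used in the cell above
num_steps = 100

covered = set()
for step in range(num_steps):
    # input window is [step, step + block_size); the target window reaches one character further
    covered.update(range(step, step + block_size + 1))

print(len(covered), "distinct character positions seen")
```

The `DataLoader(..., shuffle=True)` used later draws windows from across the whole (truncated) text instead.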


In [10]:
@torch.no_grad()
def generate_text(model, dataset, prompt, max_new_tokens=100):
    model.eval()
    stoi, itos = dataset.stoi, dataset.itos

    # encode the prompt
    idx = torch.tensor([stoi[c] for c in prompt], dtype=torch.long)[None].to(device)

    for _ in range(max_new_tokens):
        tgt_mask = generate_subsequent_mask(idx.size(1)).to(device)
        logits = model(idx, enc_out=None, src_mask=None, tgt_mask=tgt_mask)

        # take the last time step
        next_token_logits = logits[:, -1, :]
        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # append to sequence
        idx = torch.cat([idx, next_token], dim=1)

    return ''.join([itos[i.item()] for i in idx[0]])
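`generate_text` feeds the whole growing sequence back in: with an 8-character prompt and 200 new tokens, the last forward pass sees 207 positions, while the model only trained on windows of `block_size = 64`, so it's unclear how well those later positions are handled. A common remedy in small GPT-style reimplementations is to pass only the most recent `block_size` tokens. A minimal sketch of that loop; `generate_cropped` and the `model_fn` callable (mapping `(B, T)` ids to `(B, T, vocab)` logits) are hypothetical, so it runs without the repo's `Decoder`:

```python
import torch


@torch.no_grad()
def generate_cropped(model_fn, idx: torch.Tensor, max_new_tokens: int, block_size: int) -> torch.Tensor:
    """Greedy generation that never feeds the model more than block_size positions."""
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -block_size:]            # keep only the most recent context
        logits = model_fn(idx_cond)                # (B, T_cond, vocab)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)  # append and continue
    return idx
```

In this notebook `model_fn` could be `lambda ids: decoder(ids, enc_out=None, src_mask=None, tgt_mask=generate_subsequent_mask(ids.size(1)).to(device))`.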

In [11]:
print(generate_text(decoder, dataset, prompt="The king", max_new_tokens=200))

The kingzen:
You are are alvesolvesolvesolvesolvesolvesolvesolvesolvesolverarararararararatolverererererererererererererererererererererererererererererererererererererereramiesolvesolvesolveramiesolvesolvera


It's amazing to see that the model generated text! The text looks horrible, but it's cool to see that our decoder-only model learned *something*. It went on a long tangent repeating "olves" and "er" over and over, and the output looks like a garbled echo of the play's opening lines ("First Citizen:" ... "You are all resolved rather to die than to famish?"). One cool thing is that it picked up some formatting: right after "kingzen:" there is a line break followed by a capital "Y", just like the speaker-name lines in the dataset. That put a smile on my face, since it suggests some attention heads are picking up on that structure.

I'm going to keep scaling up training incrementally. The loop that produced that text ran for only about 5 seconds (100 steps at roughly 20 it/s). I want to scale up so that training takes at least 2 to 3 minutes and gets the loss below 0.5.

I'm going to make a bunch of changes here to keep training manageable on my laptop.

First, I'm reducing the input to the first 50,000 characters and doubling the context length (`block_size`) to 128.

In [24]:
with open("../data/input.txt") as f:
    text = f.read()
text = text[:50000]
block_size = 128

dataset = CharDataset(text, block_size)
vocab_size = dataset.vocab_size
print("Vocab size:", vocab_size)

Vocab size: 59


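The vocab shrinks from 65 to 59 because the first 50,000 characters don't contain every character in the full text. Next, I'm setting up a DataLoader so we can run a full training loop with batching.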

In [25]:
from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=cfg["batch_size"], shuffle=True)

I'm also switching to `mps` (PyTorch's GPU backend for Apple silicon) now, because this training loop is much larger.

In [26]:
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Using device:", device)

Using device: mps


In [None]:
decoder = decoder.to(device)

# MPS needs this to prevent the embedding "Placeholder" error (thank you Claude)
with torch.no_grad():
    dummy = torch.randint(0, vocab_size, (1, block_size)).to(device)
    mask = generate_subsequent_mask(block_size).to(device)
    _ = decoder(dummy, enc_out=None, src_mask=None, tgt_mask=mask)

In [21]:
from tqdm import tqdm

num_epochs = 1
loss_fn = CrossEntropyLoss()
optimizer = Adam(decoder.parameters(), lr=float(cfg["learning_rate"]))

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    decoder.train()  # generate_text() calls model.eval(); switch back to training mode
    pbar = tqdm(loader, total=len(loader), desc="Training")

    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        tgt_mask = generate_subsequent_mask(x.size(1)).to(device)

        logits = decoder(x, enc_out=None, src_mask=None, tgt_mask=tgt_mask)
        # the decoder was built with the full 65-char vocab but this dataset has only 59 characters,
        # so size the loss by the logits; output slots 59-64 simply never appear as targets
        loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

        decoder.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

Epoch 1/1


Training: 100%|██████████| 780/780 [03:04<00:00,  4.23it/s, loss=0.1381]


In [22]:
print(generate_text(decoder, dataset, prompt="The king", max_new_tokens=200))

The king fury, or go condition.

AUFIDIUS:
Condition!
I would I were a Roman; for I cannot,
Being a Volsce, be that I am. Condition!
Where hough addites and one this last?

MENENIUS:
The then shalle be the sh


Those are real words! And it definitely looks like it's following the same structure as Shakespeare. I'm super bullish on this method now.

I'm also super grateful for having MPS, since it definitely sped up training massively.

Let's retrain the model for 2 or 4 epochs and see how much better the generated text gets (before we get into token sampling techniques).

(I froze the two cells above and re-ran the notebook from the beginning so that the model starts from scratch.)

In [29]:
from tqdm import tqdm

num_epochs = 2
loss_fn = CrossEntropyLoss()
optimizer = Adam(decoder.parameters(), lr=float(cfg["learning_rate"]))

for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    decoder.train()  # generate_text() calls model.eval(); switch back to training mode
    pbar = tqdm(loader, total=len(loader), desc="Training")

    for x, y in pbar:
        x = x.to(device)
        y = y.to(device)
        tgt_mask = generate_subsequent_mask(x.size(1)).to(device)

        logits = decoder(x, enc_out=None, src_mask=None, tgt_mask=tgt_mask)
        # the decoder was built with the full 65-char vocab but this dataset has only 59 characters,
        # so size the loss by the logits; output slots 59-64 simply never appear as targets
        loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

        decoder.zero_grad()
        loss.backward()
        optimizer.step()

        pbar.set_postfix({"loss": f"{loss.item():.4f}"})

Epoch 1/2


Training: 100%|██████████| 780/780 [03:05<00:00,  4.20it/s, loss=0.1175]


Epoch 2/2


Training: 100%|██████████| 780/780 [03:10<00:00,  4.10it/s, loss=0.1012]


Time to see if this helped:

In [30]:
print(generate_text(decoder, dataset, prompt="The king", max_new_tokens=200))

The kingly-crowned head, the vigilant eye,
The counsellor heart, the arm our soldier,
Our steed the leg, the tongue our trumpeter.
With other muniments and petty helpere.

SICINIUS:
He the to the revere yedse


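So just *2 epochs in*, this clearly reads like Shakespearean English. I'm confident our transformer is learning and that autoregressive generation is working. One caveat: the first few generated lines are nearly verbatim from the opening scene of *Coriolanus*, the play the dataset starts with, so with a training loss around 0.1 on only 50,000 characters the model is probably memorizing a good chunk of its training text rather than only generalizing.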

I can keep scaling up the training loop, but we can already see the core architecture working. I'm really happy that the model generates coherent text from character-level tokens alone. I can only imagine how much better this would be with better preprocessing, BPE tokenization, and a real GPU to train on.

## Recap

Before I finish, I want to recap what I did.

1) **Test suite**

I struggled with shapes and masking logic in a previous version, so I wanted to get that right here. I started by writing test cases checking that attention (self and multi-head) is calculated properly and returns the right shapes. I also had a bunch of utility functions (mainly to help with masking) that I needed to verify were returning the right shapes. Once I had tests for those parts, writing tests for the encoder, decoder, and transformer was pretty straightforward.
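Since masking is where bugs tend to hide, here is a sketch of the kind of helper and check such tests cover. The repo's `generate_subsequent_mask` isn't shown, and codebases differ on convention (boolean "blocked" masks vs additive `-inf` masks), so this hypothetical `causal_mask` uses the additive convention, the same one PyTorch's `nn.Transformer.generate_square_subsequent_mask` produces, and the test confirms no position puts attention weight on the future:

```python
import torch


def causal_mask(T: int) -> torch.Tensor:
    """Hypothetical additive (T, T) mask: 0 where attention is allowed, -inf above the diagonal."""
    return torch.triu(torch.full((T, T), float("-inf")), diagonal=1)


def test_causal_mask():
    T = 5
    mask = causal_mask(T)
    assert mask.shape == (T, T)
    scores = torch.randn(2, T, T) + mask           # broadcasts over the batch dimension
    weights = torch.softmax(scores, dim=-1)
    future = torch.triu(torch.ones(T, T), diagonal=1).bool()
    assert (weights[:, future] == 0).all()         # nothing attends to later positions
    assert torch.allclose(weights.sum(-1), torch.ones(2, T))  # each row is still a distribution
```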

2) **Writing blocks from scratch**

I wrote embeddings and attention first. Embeddings weren't that bad because I used PyTorch's built-in embedding layer for token embeddings, and I was able to refer to a lot of code online for positional encoding (since those calculations never change).
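For reference, the sinusoidal encoding from the paper is `PE(pos, 2i) = sin(pos / 10000^(2i/d_model))` and `PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))`. A minimal sketch that precomputes it as a `(max_len, d_model)` table, assuming the notebook uses this sinusoidal variant and an even `d_model` (the function name is hypothetical):

```python
import math

import torch


def sinusoidal_positional_encoding(max_len: int, d_model: int) -> torch.Tensor:
    """Returns a (max_len, d_model) table; row p is added to the token embedding at position p."""
    assert d_model % 2 == 0, "this sketch assumes an even d_model"
    position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (max_len, 1)
    # 1 / 10000^(2i / d_model) for each even dimension index 2i
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(max_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)  # even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # odd dimensions
    return pe
```

Usage would look like `x = token_embedding(tokens) + pe[:T]`, with `pe` typically registered as a non-trainable buffer since it never changes.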

After I had those down, writing the encoder, decoder, and transformer was super easy: it was just stacking the blocks I already had in different ways.

3) **Writing and tweaking my config.yaml**

This was my favorite part: it's super simple, and I like being able to change everything about my model from that one YAML file. I played around with a bunch of different batch sizes and architecture settings.
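A likely reason `learning_rate` gets wrapped in `float(...)` in the cells above: PyYAML follows YAML 1.1, whose float pattern requires a decimal point, so a value written as `3e-4` loads as the string `'3e-4'`, while `3.0e-4` loads as a float. The keys below mirror the ones the notebook reads; the values are made up for illustration and are not the real `config.yaml`:

```python
import yaml

# illustrative config: keys mirror the ones the notebook reads, values are made up
CONFIG_TEXT = """
num_layers: 4
d_model: 128
num_heads: 4
dim_ff: 512
batch_size: 64
learning_rate: 3e-4
"""

cfg = yaml.safe_load(CONFIG_TEXT)
print(type(cfg["learning_rate"]), cfg["learning_rate"])  # PyYAML leaves '3e-4' as a str

# cast explicitly (as the notebook does), or write 3.0e-4 in the YAML file
lr = float(cfg["learning_rate"])
```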


Throughout this project, I really learned the power of repeated *Attention + FF + LayerNorm* blocks - it's quite literally the strongest architecture I've ever worked with.