<a href="https://colab.research.google.com/github/GiX007/llm-from-scratch/blob/main/00_tiny_gpt_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Building a Tiny GPT from Scratch

In this notebook, we build a character-level language model step by step using PyTorch.
Starting from a simple bigram baseline, we progressively introduce batching, context windows, and self-attention, ending with a small Transformer-style model that can generate text.

In [None]:
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(123)

<torch._C.Generator at 0x7c173416aa70>

## Get the dataset

We will use the **tiny-shakespeare** dataset.

In [None]:
# read the text dataset into a single Python string
with open("tiny-shakespeare.txt", "r", encoding="utf-8") as f:
  text = f.read()

In [None]:
# inspect the length of our text dataset (number of characters = number of training tokens later)
print("Dataset length in characters:", len(text))

Dataset length in characters: 1115393


In [None]:
# preview the first 1000 characters
print("\n--- Dataset preview (first 1000 characters) ---\n")
print(text[:1000])


--- Dataset preview (first 1000 characters) ---

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this

## Character-level Tokenization

We convert raw text into numbers so it can be processed by a neural network.
Here we use a simple character-level tokenizer: each unique character gets an integer ID.

In [None]:
# get all unique characters in the dataset
chars = sorted(list(set(text)))

# vocabulary size = number of unique characters
vocab_size = len(chars)

# inspect the vocabulary
print(''.join(chars))
print("Vocab size:", vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65


In [None]:
# map characters to integers (string → numbers)
stoi = {ch: i for i, ch in enumerate(chars)}

# map integers back to characters (numbers → string)
itos = {i: ch for i, ch in enumerate(chars)}

# encoder: string → list of integers
encode = lambda s: [stoi[c] for c in s]

# decoder: list of integers → string
decode = lambda l: ''.join([itos[i] for i in l])

# quick sanity check
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [None]:
# encode the entire dataset as a tensor of token IDs
data = torch.tensor(encode(text), dtype=torch.long)

# inspect shape and type
print(data.shape, data.dtype)

# preview first tokens
print(data[:1000])

torch.Size([1115393]) torch.int64
tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
      

This tokenizer is intentionally simple.
Modern LLMs use subword tokenizers (e.g., tiktoken; see https://tiktokenizer.vercel.app/ for an interactive demo), but character-level models are ideal for learning core concepts.
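
As a minimal, self-contained sketch of the same character-level scheme, the helper below wraps the two lookup tables in a small class. The name `CharTokenizer` is hypothetical and not part of this notebook. It also raises a readable error for characters outside the vocabulary, which the `encode` lambda above would report as a bare `KeyError`.

```python
class CharTokenizer:
    """Character-level tokenizer sketch; the class name is illustrative only."""

    def __init__(self, text):
        chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(chars)}  # string -> int
        self.itos = {i: ch for i, ch in enumerate(chars)}  # int -> string

    @property
    def vocab_size(self):
        return len(self.stoi)

    def encode(self, s):
        # fail early with a clear message on out-of-vocabulary characters
        unknown = sorted(set(s) - set(self.stoi))
        if unknown:
            raise ValueError(f"characters not in vocabulary: {unknown}")
        return [self.stoi[c] for c in s]

    def decode(self, ids):
        return "".join(self.itos[i] for i in ids)


tok = CharTokenizer("hello world")
ids = tok.encode("hello")
print(ids, tok.decode(ids))
```

Encoding followed by decoding returns the original string for any text built from the vocabulary, which is the only property the rest of the notebook relies on.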

## Training Sequences: Inputs and Targets

We split the tokenized data into training and validation sets.
Then we build fixed-length input sequences and their next-token targets for language modeling.

In [None]:
# split data into train and validation sets
n = int(0.9 * len(data)) # 90% train, 10% validation
train_data = data[:n]
val_data = data[n:]

In [None]:
# block_size = number of tokens the model can see (context window)
block_size = 8

# inspect one chunk (+1 for target shift)
print(train_data[:block_size + 1])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])


In [None]:
# input tokens
x = train_data[:block_size]

# target tokens (shifted by one)
y = train_data[1:block_size + 1]

# show how each position predicts the next token
for t in range(block_size):
  context = x[:t + 1]
  target = y[t]
  print(f"when input is {context}, the target: {target}")

when input is tensor([18]), the target: 47
when input is tensor([18, 47]), the target: 56
when input is tensor([18, 47, 56]), the target: 57
when input is tensor([18, 47, 56, 57]), the target: 58
when input is tensor([18, 47, 56, 57, 58]), the target: 1
when input is tensor([18, 47, 56, 57, 58,  1]), the target: 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]), the target: 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), the target: 58


In [None]:
# batch generation
batch_size = 4 # how many independent sequences will we process in parallel (number of sequences per batch)
block_size = 8 # context length

# data loading
def get_batch(split):
  # select train or validation data
  data_split = train_data if split == "train" else val_data

  # randomly choose starting positions
  ix = torch.randint(len(data_split) - block_size, (batch_size,))

  # build input and target batches
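  # slice from the selected split (not the full dataset), so "val" batches never contain training text
  x = torch.stack([data_split[i:i+block_size] for i in ix]) # stack all vectors into rows
  y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])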

  return x, y

xb, yb = get_batch("train")
print("inputs:")
print(xb.shape)
print(xb)

print("\ntargets:")
print(yb.shape)
print(yb)

inputs:
torch.Size([4, 8])
tensor([[59, 56,  6,  1, 46, 39, 58, 46],
        [53, 56, 40, 47, 42,  1, 53, 59],
        [41, 43,  0, 32, 39, 49, 43,  1],
        [27, 30, 23, 10,  0, 20, 47, 57]])

targets:
torch.Size([4, 8])
tensor([[56,  6,  1, 46, 39, 58, 46,  1],
        [56, 40, 47, 42,  1, 53, 59, 56],
        [43,  0, 32, 39, 49, 43,  1, 53],
        [30, 23, 10,  0, 20, 47, 57,  1]])


In [None]:
# show how every token predicts the next one (per batch element)
for b in range(batch_size):
  for t in range(block_size):
    context = xb[b, :t+1]
    target = yb[b,t]
    print(f"when input is {context.tolist()} the target: {target}")

when input is [59] the target: 56
when input is [59, 56] the target: 6
when input is [59, 56, 6] the target: 1
when input is [59, 56, 6, 1] the target: 46
when input is [59, 56, 6, 1, 46] the target: 39
when input is [59, 56, 6, 1, 46, 39] the target: 58
when input is [59, 56, 6, 1, 46, 39, 58] the target: 46
when input is [59, 56, 6, 1, 46, 39, 58, 46] the target: 1
when input is [53] the target: 56
when input is [53, 56] the target: 40
when input is [53, 56, 40] the target: 47
when input is [53, 56, 40, 47] the target: 42
when input is [53, 56, 40, 47, 42] the target: 1
when input is [53, 56, 40, 47, 42, 1] the target: 53
when input is [53, 56, 40, 47, 42, 1, 53] the target: 59
when input is [53, 56, 40, 47, 42, 1, 53, 59] the target: 56
when input is [41] the target: 43
when input is [41, 43] the target: 0
when input is [41, 43, 0] the target: 32
when input is [41, 43, 0, 32] the target: 39
when input is [41, 43, 0, 32, 39] the target: 49
when input is [41, 43, 0, 32, 39, 49] the target: 43

In [None]:
# example model input
print(xb.shape)
print(xb)

torch.Size([4, 8])
tensor([[59, 56,  6,  1, 46, 39, 58, 46],
        [53, 56, 40, 47, 42,  1, 53, 59],
        [41, 43,  0, 32, 39, 49, 43,  1],
        [27, 30, 23, 10,  0, 20, 47, 57]])


Each input sequence is trained to predict the next token at every position.
This is the core supervision signal used to train autoregressive language models.
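
The input/target relationship can be checked on a toy dataset where the answer is known in advance. The sketch below assumes a made-up sequence `toy_data` in which token `i` has ID `i`, so the correct next token is always the current ID plus one. It uses the same sampling scheme as `get_batch`.

```python
import torch

torch.manual_seed(0)

# toy "dataset": token i has ID i, so the correct next token is always ID + 1
toy_data = torch.arange(100)
block_size, batch_size = 8, 4

# same sampling scheme as get_batch: random start offsets, then shifted windows
ix = torch.randint(len(toy_data) - block_size, (batch_size,))
x = torch.stack([toy_data[i:i + block_size] for i in ix])
y = torch.stack([toy_data[i + 1:i + block_size + 1] for i in ix])

# every target is the token that follows the corresponding input position
assert torch.equal(y, x + 1)

# one batch therefore contributes batch_size * block_size prediction tasks
print(x.shape, y.shape, x.numel())
```

With `batch_size = 4` and `block_size = 8`, a single batch supplies 32 next-token prediction tasks, one per position.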

## Bigram Language Model (Baseline)

We start with a very small baseline: a *bigram* language model.
It predicts the next token using only the current token (no attention, no context mixing).

In [None]:
# predicts next token using only the current token (the simplest NN)
class BigramLanguageModel(nn.Module):

  def __init__(self, vocab_size):
    super().__init__()
    # lookup table: token_id -> logits over next token
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) # (C, C)

  def forward(self, idx, targets=None):
    # idx: (B, T) -> logits: (B, T, C), B for batch size, T (Time) for context length, C (Channels) for vocab size
    logits = self.token_embedding_table(idx)

    loss = None
    if targets is not None:
      B, T, C = logits.shape
      logits = logits.view(B * T, C)
      targets = targets.view(B * T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      logits, loss = self(idx) # get the predictions
      logits = logits[:, -1, :] # focus only on the last time step, it becomes (B, C)
      probs = F.softmax(logits, dim=-1) # next-token distribution, (B, C)
      idx_next = torch.multinomial(probs, num_samples=1) # sample from the distribution, (B, 1)
      idx = torch.cat((idx, idx_next), dim=1) # append sampled index to the running sequence, (B, T+1)
    return idx

In [None]:
# forward pass for sanity check
m = BigramLanguageModel(vocab_size)
logits, loss = m(xb, yb)

print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.6669, grad_fn=<NllLossBackward0>)
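
A useful reference point for this number is the loss of a model that guesses uniformly over the vocabulary, which is `ln(65)`, about 4.17. The short calculation below reproduces that baseline. The measured 4.6669 is higher, which is what one would expect from a randomly initialized embedding table whose logits are not uniform.

```python
import math

vocab_size = 65  # number of unique characters found above

# cross-entropy of a uniform guess over the vocabulary: -ln(1 / vocab_size)
uniform_loss = -math.log(1.0 / vocab_size)
print(f"uniform-guess loss: {uniform_loss:.4f}")
```

Any trained model should end up well below this baseline; a loss above it means the model is confidently wrong on some tokens.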


In [None]:
# generate before training
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))


MuD&sYMZTlXMP?HZTnJfpsh&omS$ApW3zEYQ&rrvjhGy?AYvB;'ECISU
xTA
vCNhscX;aiXMHnk,TPI;D?f&Fb&FZblxzqi.abd


In [None]:
# train the model (few steps, educational)

# optimizer
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)

batch_size = 32

for steps in range(100):
  xb, yb = get_batch('train')

  logits, loss = m(xb, yb)

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

print(loss.item())

4.593630790710449


In [None]:
# generate after training
print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))


:SSqd,tvmXQh:R&hcheHMuxmXdgoNWVa,HHuA
Fhs;aVkGcwvNnwlx?&3-J3zqAo3ihKHKn?HSlf3bHj'v$hbROpQSwwTkXcGn,H,AVtetlzD:YorwTpFPHvLmPzB;ZTzUETq
ZDJ.-WawbH w-GOAdmvCjKmwosn..WGjDca&sca,J$3h-F'w
3oIxpYH,FEZ;lOVY:S$d!GJVCgfVDAnwyhkdT wlfaRHyDNsP!cjhheAEIwG'MPH,H3goJ&QQxKG3&sIBtPHj$-nuZEoZCVA'D-wLZhKJ
3$?IgImxKp'etpYYe..a:o w wa&d$gJCaCvxU?ub
 WjAlaZO H-dnOLYoaUHHKZDTI!cjvvlrQysn
x weiiXdfAJPefreoB-dUvvBbJOV;,ECAYHdXCK?uL
'Tw$!eih
B
FhWfYxHa w FKddxu!mvFRk' VPHuhcyBTX!vn-Xrs;UvWVYK;3iMsCEkTzNauFa$TXTB

:uwqXK


This model is a strong “hello world”: it learns token-to-token transitions.
Next we'll upgrade it to use context via attention.

## The math behind self-attention

The bigram model only looks at the current token, which severely limits what it can learn. To model longer dependencies, we need a mechanism that lets tokens communicate and share information across the sequence.
Self-attention provides exactly this: each token builds a weighted summary of itself and earlier tokens.

Before we implement attention, we'll learn the core trick: masked weighted averaging via matrix multiplication.

**Goal:** allow each token position `t` to combine information from tokens `≤ t` (causal / autoregressive).

In [None]:
# toy example: matrix multiplication as a weighted sum of rows (weighted aggregation)
# tril() makes it causal (only current + previous)
torch.manual_seed(42)

a = torch.tril(torch.ones(3, 3))
a = a / torch.sum(a, 1, keepdim=True) # normalize rows -> weights sum to 1

b = torch.randint(0, 10, (3, 2)).float() # "values" to aggregate
c = a @ b # weighted average

print('a='); print(a); print()
print('b='); print(b); print()
print('c='); print(c)

a=
tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

b=
tensor([[2., 7.],
        [6., 4.],
        [6., 5.]])

c=
tensor([[2.0000, 7.0000],
        [4.0000, 5.5000],
        [4.6667, 5.3333]])


Each row of `a` controls how much information is taken from the rows of `b`.
Because `a` is lower-triangular, row `t` puts zero weight on every position `> t`, so each row of `c` is the average of the current and all earlier rows of `b` and never sees the future.

To make this concrete, we'll work with a small toy batch.
Each sequence contains a few tokens, and each token is represented by a simple feature vector.

In [None]:
# create a toy sequence batch of: B sequences, each of length T, each token has C features
B,T,C = 4, 8, 2
x = torch.randn(B, T, C)

print(x.shape)
print(x)

torch.Size([4, 8, 2])
tensor([[[-0.0431, -1.6047],
         [ 1.7878, -0.4780],
         [-0.2429, -0.9342],
         [-0.2483, -1.2082],
         [-0.7688,  0.7624],
         [-1.5673, -0.2394],
         [ 2.3228, -0.9634],
         [ 2.0024,  0.4664]],

        [[ 0.8008,  1.6806],
         [ 0.3559, -0.6866],
         [-0.4934,  0.2415],
         [-1.1109,  0.0915],
         [-0.2516,  0.8599],
         [-0.3097, -0.3957],
         [ 0.8034, -0.6216],
         [-0.5920, -0.0631]],

        [[ 0.3057, -0.7746],
         [ 0.0349,  0.3211],
         [ 1.5736, -0.8455],
         [ 1.3123,  0.6872],
         [-1.2347, -0.4879],
         [-1.4181,  0.8963],
         [ 0.0499,  2.2667],
         [ 1.1790, -0.4345]],

        [[-0.8140, -0.7360],
         [-0.8371, -0.9224],
         [ 1.8113,  0.1606],
         [ 0.3672,  0.1754],
         [-1.1845,  1.3835],
         [-1.2024,  0.7078],
         [-1.0759,  0.5357],
         [ 1.1754,  0.5612]]])


### Step 1: Explicit averaging over previous tokens

In [None]:
# goal: xbow[b,t] = mean_{i<=t} x[b,i]
xbow = torch.zeros((B,T,C)) # bow from bag of words

for b in range(B):
  for t in range(T):
    xprev = x[b,:t+1] # tokens up to and including t, (t+1, C)
    xbow[b,t] = torch.mean(xprev, 0) # mean over time

In [None]:
print(x[0])

tensor([[-0.0431, -1.6047],
        [ 1.7878, -0.4780],
        [-0.2429, -0.9342],
        [-0.2483, -1.2082],
        [-0.7688,  0.7624],
        [-1.5673, -0.2394],
        [ 2.3228, -0.9634],
        [ 2.0024,  0.4664]])


In [None]:
print(xbow[0])

tensor([[-0.0431, -1.6047],
        [ 0.8724, -1.0414],
        [ 0.5006, -1.0056],
        [ 0.3134, -1.0563],
        [ 0.0970, -0.6925],
        [-0.1804, -0.6170],
        [ 0.1772, -0.6665],
        [ 0.4053, -0.5249]])


Notice how `xbow[0, t]` is the average of `x[0, 0:t+1]`.


### Step 2: Efficient averaging via matrix multiplication

In [None]:
# build causal averaging weights (T x T)
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True) # row-normalize

# (T, T) @ (B, T, C) -> (B, T, C) via broadcasting
xbow2 = wei @ x

torch.allclose(xbow, xbow2)

True
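
For the special case of uniform averaging there is a third route: a running sum divided by the number of tokens seen so far. The sketch below is not part of the original notebook. It checks that route against the matrix form, using freshly sampled data so that it runs on its own.

```python
import torch

torch.manual_seed(0)
B, T, C = 4, 8, 2
x = torch.randn(B, T, C)

# reference: row-normalized lower-triangular weights, as in Step 2
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(1, keepdim=True)
xbow_matmul = wei @ x

# alternative: running sum along time divided by the number of tokens seen so far
counts = torch.arange(1, T + 1).view(1, T, 1)  # 1, 2, ..., T
xbow_cumsum = x.cumsum(dim=1) / counts

print(torch.allclose(xbow_matmul, xbow_cumsum, atol=1e-6))
```

The running-sum version only works because every earlier token gets the same weight. The matrix form is the one that generalizes, since its weights can be replaced by arbitrary, learned values.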

### Step 3: Softmax-based causal weighting

In [None]:
# softmax turns scores into probabilities
tril = torch.tril(torch.ones(T, T))

wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf')) # block future positions
print(wei)

wei = F.softmax(wei, dim=-1) # row sums to 1
print(wei)

xbow3 = wei @ x
torch.allclose(xbow, xbow3)

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


True

So far, the weights are *fixed* (not learned, not data-dependent).
Self-attention makes these weights depend on the token content.

### Step 4: Learned, data-dependent weighting (self-attention)

**Self-attention intuition**

Each token decides how much to attend to other tokens based on their content.
This is done by comparing what a token is looking for with what other tokens contain.

Each token produces three vectors:
- **Query (Q):** what I am looking for
- **Key (K):** what I contain
- **Value (V):** the information I provide

Attention weights are computed from query-key similarity and used to mix values.

At each position:
1. compute similarity scores between queries and keys
2. apply causal masking to block future tokens
3. normalize scores with softmax
4. use the weights to aggregate values


Self-attention (single head).

In [None]:
# self-attention with learned, data-dependent weights
B, T, C = 4, 8, 32 # batch, time, channels
x = torch.randn(B, T, C)

head_size = 16

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

k = key(x) # (B, T, head_size)
q = query(x) # (B, T, head_size)

# attention scores
wei =  q @ k.transpose(-2, -1) # (B, T, head_size) @ (B, head_size, T) ---> (B, T, T)

# causal mask (decoder-style)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1) # normalize

# aggregate values
v = value(x) # (B, T, head_size)
out = wei @ v # (B, T, head_size)

In [None]:
# display the output
print(out.shape)
print(out)

torch.Size([4, 8, 16])
tensor([[[ 1.2309e+00, -5.8830e-01,  8.3699e-02, -2.4703e-01,  5.7573e-02,
          -1.2661e+00,  1.3852e-01,  8.2426e-01,  1.2837e-02, -2.9015e-01,
           2.6208e-01,  3.7443e-01,  6.7970e-01,  5.9684e-01,  3.1133e-01,
          -4.7922e-01],
         [ 7.4428e-01, -2.3934e-02, -6.6331e-02, -3.7515e-02,  3.5757e-02,
          -3.4072e-01,  2.2588e-01,  4.4565e-01,  3.7820e-02,  2.1374e-01,
           6.0595e-01,  2.6574e-01,  2.4792e-01,  4.7638e-01, -1.4091e-01,
           2.7625e-01],
         [ 4.3012e-01, -3.1783e-02,  1.4615e-01, -3.6217e-01, -2.1957e-01,
          -1.4638e-01, -5.8746e-01,  1.2682e-01, -5.6536e-02,  2.9793e-01,
          -1.8272e-01,  6.8918e-02,  1.8934e-01,  5.2959e-01, -4.0646e-02,
          -2.5657e-02],
         [ 3.7924e-01,  1.3171e-03,  1.5383e-01, -4.3043e-01, -2.2283e-01,
          -1.5539e-01, -5.5159e-01,  8.2522e-02, -1.0567e-01,  2.8703e-01,
          -8.5874e-02,  1.3697e-02,  5.1500e-02,  4.9496e-01, -9.5157e-02,
     

In [None]:
# inspect attention weights for one example
print(wei[0])

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3131, 0.6869, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3071, 0.1665, 0.5264, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2865, 0.1539, 0.4501, 0.1096, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0935, 0.1449, 0.1449, 0.2585, 0.3581, 0.0000, 0.0000, 0.0000],
        [0.0685, 0.1716, 0.0297, 0.0232, 0.2059, 0.5010, 0.0000, 0.0000],
        [0.1452, 0.0654, 0.2293, 0.0887, 0.1483, 0.1025, 0.2206, 0.0000],
        [0.0633, 0.0456, 0.0105, 0.0078, 0.1231, 0.3747, 0.0454, 0.3295]],
       grad_fn=<SelectBackward0>)


Attention is a communication mechanism: each token aggregates information from other tokens using learned weights.

There is no notion of order in attention itself, which is why positional information must be added separately.
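
This order-blindness can be checked directly when the causal mask is left out. The sketch below uses freshly created projection layers (so its numbers are unrelated to the cells above) and shows that shuffling the input tokens merely shuffles the outputs in the same way.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, C, head_size = 8, 32, 16
x = torch.randn(T, C)  # one sequence, no positional information added

key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)

def unmasked_attention(x):
    # plain (non-causal) scaled dot-product attention over one sequence
    q, k, v = query(x), key(x), value(x)
    wei = F.softmax(q @ k.T * head_size**-0.5, dim=-1)  # (T, T)
    return wei @ v  # (T, head_size)

perm = torch.randperm(T)
out = unmasked_attention(x)
out_shuffled = unmasked_attention(x[perm])

# shuffling the input tokens only shuffles the outputs in the same way
print(torch.allclose(out[perm], out_shuffled, atol=1e-5))
```

Each token's output depends on which other tokens are present, not on where they sit in the sequence.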

Sequences in the batch are processed independently and never "talk" to each other.

Causal masking makes this a decoder-style attention block, suitable for autoregressive language modeling.

Scaling by `1 / sqrt(head_size)` keeps the variance of the attention scores near 1 regardless of the head size.
Without it, the scores grow with `head_size` and softmax becomes too sharp, approaching a one-hot distribution.

In [None]:
k = torch.randn(B, T, head_size)
q = torch.randn(B, T, head_size)
print(k.var(), q.var())

# unscaled dot-product attention
wei = q @ k.transpose(-2, -1)  # (B, T, T)
print(wei.var())

# scaled dot-product attention
wei = q @ k.transpose(-2, -1) * head_size**-0.5
print(wei.var())

tensor(1.0116) tensor(0.9358)
tensor(14.3098)
tensor(0.8944)


In [None]:
# softmax behavior example
print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1))

print(torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]) * 8, dim=-1)) # gets too peaky, converges to one-hot

tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


This “masked, scaled, softmax weighted sum” is the core math of decoder self-attention.
After a short detour on normalization, we'll use it to upgrade the bigram model.
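
As a cross-check, the manual computation can be compared with PyTorch's fused implementation. The sketch below assumes PyTorch 2.0 or newer, where `F.scaled_dot_product_attention` is available. It uses random `q`, `k`, `v` tensors in place of learned projections.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, head_size = 4, 8, 16
q = torch.randn(B, T, head_size)
k = torch.randn(B, T, head_size)
v = torch.randn(B, T, head_size)

# manual version: scale, mask the future, softmax, weighted sum
wei = q @ k.transpose(-2, -1) * head_size**-0.5  # (B, T, T)
tril = torch.tril(torch.ones(T, T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
out_manual = wei @ v  # (B, T, head_size)

# built-in fused version (assumes PyTorch >= 2.0)
out_builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(out_manual, out_builtin, atol=1e-5))
```

The built-in call applies the same `1 / sqrt(head_size)` scaling by default. The hand-written version is kept in this notebook because it exposes every step.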

## Normalization in Transformers (LayerNorm)

Transformers use **LayerNorm** rather than BatchNorm because it normalizes each token's feature vector independently of the other tokens and of the rest of the batch.
This works well for variable-length sequences and avoids batch-dependent behavior.

- **BatchNorm** normalizes each feature using statistics across the batch (and often time).
- **LayerNorm** normalizes each example (row) using statistics across its features.

In [None]:
# minimal LayerNorm implementation
class LayerNorm1d:

  def __init__(self, dim, eps=1e-5):
    self.eps = eps
    self.gamma = torch.ones(dim)
    self.beta = torch.zeros(dim)

  def __call__(self, x):
    # forward pass: normalize each row (one example) across features
    xmean = x.mean(1, keepdim=True)
    xvar = x.var(1, keepdim=True, unbiased=False)
    xhat = (x - xmean) / torch.sqrt(xvar + self.eps)
    return self.gamma * xhat + self.beta

  def parameters(self):
    return [self.gamma, self.beta]

In [None]:
# test it
module = LayerNorm1d(100)
x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors

x = module(x)
print(x.shape)
# print(x)

torch.Size([32, 100])


In [None]:
# feature stats across batch (not guaranteed by LayerNorm)
print(x[:, 0].mean(), x[:, 0].std())

tensor(0.0640) tensor(0.7472)


In [None]:
# per-sample stats across features (what LayerNorm enforces)
print(x[0, :].mean(), x[0, :].std())

tensor(4.7684e-09) tensor(1.0050)


LayerNorm makes each token's feature vector have stable scale.
This helps attention + MLP layers train reliably.
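
The `LayerNorm1d` class above reduces over dimension 1, which is the feature axis only for 2-D input. Inside a Transformer the activations have shape `(B, T, C)`, so the reduction has to run over the last dimension. The sketch below is an illustration rather than the notebook's own code. It does that by hand and compares the result with `F.layer_norm`, leaving out the learnable scale and shift.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, C = 4, 8, 32
x = torch.randn(B, T, C) * 3.0 + 1.0  # deliberately off-center, wide inputs

# manual normalization over the last (feature) dimension of each token
mean = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, keepdim=True, unbiased=False)
manual = (x - mean) / torch.sqrt(var + 1e-5)

# PyTorch's functional form, without learnable scale and shift
builtin = F.layer_norm(x, normalized_shape=(C,), eps=1e-5)

print(torch.allclose(manual, builtin, atol=1e-5))
print(builtin.mean(dim=-1).abs().max())  # close to 0 for every token
```

Every one of the `B * T` token vectors is normalized on its own, so the result for one token does not change when the batch around it changes.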

## Bigram Language Model with Self-Attention

The original bigram model predicts the next token using only the current one.
Here, we upgrade it by adding embeddings, positional information, and stacked self-attention blocks, turning it into a small Transformer-style language model.

In [None]:
# training and model hyperparameters
batch_size = 16 # sequences per batch
block_size = 32 # context length
max_iters = 5000
eval_interval = 100
learning_rate = 1e-3
eval_iters = 200

# model size
n_embd = 64 # embedding dimension
n_head = 4 # attention heads
n_layer = 4 # transformer blocks
dropout = 0.0

# device
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
@torch.no_grad()
def estimate_loss():
  # evaluate average loss on train and validation splits
  out = {}
  model.eval()

  for split in ['train', 'val']:
    losses = torch.zeros(eval_iters)
    for k in range(eval_iters):
      X, Y = get_batch(split)
      X, Y = X.to(device), Y.to(device)
      logits, loss = model(X, Y)
      losses[k] = loss.item()
    out[split] = losses.mean()

  model.train()
  return out

class Head(nn.Module):
  """One causal self-attention head."""

  def __init__(self, head_size):
    super().__init__()
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)

    # causal mask (not a parameter)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B,T,C = x.shape

    k = self.key(x)   # (B, T, head_size)
    q = self.query(x) # (B, T, head_size)

    # scaled dot-product attention (scale by the key dimension, i.e. head_size, not n_embd)
    wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, head_size) @ (B, head_size, T) -> (B, T, T)
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
    wei = F.softmax(wei, dim=-1) # (B, T, T)
    wei = self.dropout(wei)

    # weighted aggregation
    v = self.value(x) # (B, T, head_size)
    out = wei @ v # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

    return out

class MultiHeadAttention(nn.Module):
  """Multiple attention heads in parallel."""

  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
    self.proj = nn.Linear(n_embd, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    # concatenate heads and project back
    out = torch.cat([h(x) for h in self.heads], dim=-1)
    out = self.dropout(self.proj(out))
    return out

class FeedFoward(nn.Module):
  """Position-wise MLP."""

  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
        nn.Dropout(dropout),
      )

  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  """Transformer block: Attention followed by computation."""

  def __init__(self, n_embd, n_head):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ffwd = FeedFoward(n_embd)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    # residual connections + pre-norm
    x = x + self.sa(self.ln1(x))
    x = x + self.ffwd(self.ln2(x))
    return x

class BigramLanguageModel(nn.Module):
  """Transformer-style autoregressive language model."""

  def __init__(self):
    super().__init__()

    # token and position embeddings
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.position_embedding_table = nn.Embedding(block_size, n_embd)

    # transformer blocks
    self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])

    self.ln_f = nn.LayerNorm(n_embd) # final layer norm
    self.lm_head = nn.Linear(n_embd, vocab_size) # projection to vocab size

  def forward(self, idx, targets=None):
    B, T = idx.shape

    tok_emb = self.token_embedding_table(idx) # (B, T, C)
    pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device)) # (T, C)

    x = tok_emb + pos_emb # (B, T, C)
    x = self.blocks(x) # (B, T, C)
    x = self.ln_f(x) # (B, T, C)
    logits = self.lm_head(x) # (B, T, vocab_size)

    loss = None
    if targets is not None:
      B, T = targets.shape
      logits = logits.view(B * T, logits.size(-1)) # last dim = vocab_size
      targets = targets.view(B * T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      idx_cond = idx[:, -block_size:] # crop context (to the last block_size tokens)
      logits, loss = self(idx_cond)
      logits = logits[:, -1, :] # last token
      probs = F.softmax(logits, dim=-1) # (B, C)
      idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
      idx = torch.cat((idx, idx_next), dim=1) # (B, T + 1)
    return idx

In [None]:
# create the model
model = BigramLanguageModel()
m = model.to(device)

# number of parameters
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

0.209729 M parameters
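
The parameter count can be reproduced by hand from the hyperparameters, which is a handy check that the architecture is wired as intended. The arithmetic below assumes the layer shapes defined in the classes above. The causal-mask buffers are not parameters and are therefore not counted.

```python
# hyperparameters from the cells above; vocab_size is the 65 characters found earlier
vocab_size, block_size = 65, 32
n_embd, n_head, n_layer = 64, 4, 4
head_size = n_embd // n_head

embeddings = vocab_size * n_embd + block_size * n_embd  # token + position tables

attn_heads = n_head * 3 * n_embd * head_size  # key, query, value (no bias)
attn_proj = n_embd * n_embd + n_embd          # output projection (weight + bias)
mlp = (n_embd * 4 * n_embd + 4 * n_embd) + (4 * n_embd * n_embd + n_embd)
layer_norms = 2 * 2 * n_embd                  # two LayerNorms, scale + shift each
per_block = attn_heads + attn_proj + mlp + layer_norms

final_ln = 2 * n_embd
lm_head = n_embd * vocab_size + vocab_size

total = embeddings + n_layer * per_block + final_ln + lm_head
print(total)  # 209729
```

The total agrees with the 0.209729 M reported by `numel()`. About 95% of the parameters sit in the four Transformer blocks.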


In [None]:
# training loop
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

  # periodic evaluation
  if iter % eval_interval == 0 or iter == max_iters - 1:
    losses = estimate_loss()
    print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  xb, yb = get_batch('train')

  # move batch to the same device as the model
  xb, yb = xb.to(device), yb.to(device)

  logits, loss = model(xb, yb)

  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

step 0: train loss 4.3478, val loss 4.3493
step 100: train loss 2.6413, val loss 2.6349
step 200: train loss 2.5119, val loss 2.4959
step 300: train loss 2.4183, val loss 2.4009
step 400: train loss 2.3519, val loss 2.3355
step 500: train loss 2.3160, val loss 2.3115
step 600: train loss 2.2693, val loss 2.2505
step 700: train loss 2.2204, val loss 2.2142
step 800: train loss 2.1657, val loss 2.1726
step 900: train loss 2.1433, val loss 2.1419
step 1000: train loss 2.1077, val loss 2.1031
step 1100: train loss 2.0846, val loss 2.0855
step 1200: train loss 2.0552, val loss 2.0803
step 1300: train loss 2.0198, val loss 2.0265
step 1400: train loss 2.0187, val loss 1.9980
step 1500: train loss 1.9861, val loss 1.9839
step 1600: train loss 1.9628, val loss 1.9560
step 1700: train loss 1.9413, val loss 1.9347
step 1800: train loss 1.9294, val loss 1.9307
step 1900: train loss 1.9011, val loss 1.8965
step 2000: train loss 1.8925, val loss 1.8882
step 2100: train loss 1.8827, val loss 1.8737


In [None]:
# generate text from the trained model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))


Is time fits this seggains modistand heard:
Therefore imploish beepapinish.
IfI sakes and state her confain?
This did atch the will posdeling.

KING EDWARD IFI:
He forch not shall he is will wull hich sends thou asters theee?
This metter lesply?

VOLUS:
What woll!--follish he must lest brone
down: a thou takely and ford priceptering;
Let sons, my lundly for gods?

Proyal, ban tell mine?

LEONTES:
what'snow this speak:
Know my assure on my fair?

QIEss tiver my lardshall and and barck duke not
The son.

KING RICHEbold my let from the spoys appon this
Strock-stondlandled cargelands very your gone.

WARLID IFILLIZEL:
And this?

WARWILK:
Fortabant to the luck accusel leaven's will de'th.
Now, A leaven'd shame hath I speech
That I commannible-my spicks you woran,
Be forthen this of old my lord frot myscans blons and in toerclesive.

NFRORS MARUTIUS:

PlENTER:
Hy men. Engellars, what agoty, my doniesce.

PARIAN:
What barn thy king Caming;! anfull somembet quollif,
Or him and watencous own T

This model can now use context, position, and learned attention to make predictions. In the training log above, the validation loss falls from 4.35 at step 0 to 1.87 by step 2100 and stays close to the training loss, which indicates that the model is learning patterns in the data rather than memorizing the training split.
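
One way to read these loss values is to convert them to perplexity, `exp(loss)`, which can be interpreted as the effective number of equally likely characters the model is choosing between. The sketch below uses two validation losses copied from the log above; the perplexity framing is an added interpretation, not something the notebook computes.

```python
import math

vocab_size = 65
val_loss_start = 4.3493  # step 0 in the log above
val_loss_late = 1.8737   # step 2100, the last step shown in the log above

# perplexity = exp(cross-entropy): the effective number of equally likely choices
for name, loss in [("step 0", val_loss_start), ("step 2100", val_loss_late)]:
    ppl = math.exp(loss)
    print(f"{name}: loss {loss:.4f} -> perplexity {ppl:.1f} (vocabulary: {vocab_size})")
```

At initialization the model is slightly worse than a uniform guess over 65 characters. By step 2100 it has narrowed each prediction to the equivalent of roughly six or seven candidates.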

The generated text, while still far from fully consistent, already shows structure resembling Shakespeare-like dialogue. Its limitations come from the model's small size, short context window, character-level tokenization, and limited training time. Using subword or word-level tokens, along with larger models, longer contexts, and more training data, would significantly improve generation quality.


We started with a minimal bigram model and incrementally added the core ideas behind Transformers.
By building everything from scratch, we saw how attention replaces fixed averaging with learned, data-dependent communication between tokens.
The same principles scale directly to large language models used in practice today.