In [1]:
# The way to think about this is that:
#   - self-attention is the communication between the tokens
#   - then once they've gathered all the data
#   - now they need to "think" on that data individually
#       - i.e. the per-token computation: a Linear layer followed by a ReLU

# We can stack these layers now:
#   - attention block 1
#       - MultiHeadAttention 1
#       - Linear 1
#       - Relu 1
#   - attention block 2
#       - MultiHeadAttention 2
#       - Linear 2
#       - Relu 2
#   - ...

# communication/computation sandwiches:
#   - communication -> computation -> communication -> computation -> ...

# This network ends up getting pretty deep, hence "deep learning".
# Without residual connections and normalization layers, the network is unstable and hard to train.
# Just adding the residual connections stabilizes the network (and it performs way better... wtf!).
# Slightly better stability with layer norm
# todo: add comment on dropout
# todo: add comment on hyperparameters

import torch
import torch.nn as nn
from torch.nn import functional as F
import datetime

# Pick the best available accelerator instead of assuming Apple-Silicon "mps",
# so the notebook also runs on CUDA machines and on plain CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
torch.set_default_device(device)  # tensors/modules created without an explicit device land here
torch.manual_seed(1337)  # fixed seed for reproducibility

# hyperparameters
batch_size = 32  # how many independent sequences will we process in parallel?
block_size = 512  # what is the maximum context length for predictions?
n_embd = 768  # embedding width used throughout the model
n_head = 12  # every head is n_embd // n_head = 64 dimensional
n_transformer_blocks = 12  # number of stacked TransformerBlock layers
dropout = 0.2  # dropout probability shared by attention, projections and feed-forward
# ------------

# corpus: wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('../res/tinyshakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()

# character-level vocabulary: every distinct character of the corpus, in sorted order
chars = sorted(set(text))
vocab_size = len(chars)
# lookup tables between characters and their integer ids
stoi = {symbol: index for index, symbol in enumerate(chars)}
itos = dict(enumerate(chars))


def encode(s):
    """Encoder: take a string, output a list of integers (one id per character)."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output the corresponding string."""
    return ''.join(itos[i] for i in l)

# Train / validation split: encode the whole corpus once, then cut it at the 90% mark.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(data))  # index of the first validation token
train_data, val_data = data[:n], data[n:]

# data loading
def get_batch(split):
    """Sample a random mini-batch of input/target windows from one data split.

    Args:
        split: 'train' selects ``train_data``; any other value selects ``val_data``.

    Returns:
        A pair ``(x, y)`` of stacked windows, each ``batch_size`` rows of
        ``block_size`` tokens. ``y`` is ``x`` shifted one position to the right,
        i.e. the next-token target for every position of ``x``.
    """
    source = train_data if split == 'train' else val_data
    # one random start offset per sequence; the upper bound leaves room for the shifted target
    starts = torch.randint(len(source) - block_size, (batch_size,))
    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + block_size])
        targets.append(source[start + 1:start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

@torch.no_grad()
def estimate_loss():
    """Estimate the model's mean loss on the train and validation splits.

    Runs the global model ``m`` in eval mode (dropout disabled) on
    ``eval_iters`` random batches per split, with gradient tracking turned off.

    Returns:
        dict with keys 'train' and 'val', each holding the mean of the sampled
        batch losses as a 0-d tensor.
    """
    out = {}
    m.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split)
                _, loss = m(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Always hand the model back in training mode, even when evaluation is
        # interrupted (e.g. KeyboardInterrupt) or raises; otherwise dropout
        # would silently stay disabled for whatever runs next.
        m.train()
    return out

class Head(nn.Module):
    """ a single causal self-attention head """

    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size

        def make_projection():
            # bias-free linear map from the embedding space into this head's space
            return nn.Linear(n_embd, head_size, bias=False)

        self.key = make_projection()
        self.query = make_projection()
        self.value = make_projection()
        # lower-triangular mask; a buffer, so it is stored with the module but is not a trainable parameter
        self.register_buffer('tril', torch.ones(block_size, block_size).tril())
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over x of shape (B, T, C); returns a (B, T, head_size) tensor."""
        _, T, _ = x.shape
        q = self.query(x)  # (B, T, head_size)
        k = self.key(x)  # (B, T, head_size)
        v = self.value(x)  # (B, T, head_size)

        # scaled dot-product affinities: (B, T, head_size) @ (B, head_size, T)  -->  (B, T, T)
        scale = self.head_size ** -0.5
        scores = (q @ k.transpose(-2, -1)) * scale
        # causal mask: a position may only look at itself and earlier positions
        blocked = self.tril[:T, :T] == 0
        scores = scores.masked_fill(blocked, float('-inf'))  # (B, T, T)
        weights = F.softmax(scores, dim=-1)
        # randomly drop some token-to-token links during training (regularisation)
        weights = self.dropout(weights)

        # weighted aggregation of the values: (B, T, T) @ (B, T, head_size)  -->  (B, T, head_size)
        return weights @ v

# todo: implementing this was very simple, but what does it mean and why is it better than a single head?
class MultiHeadAttention(nn.Module):
    """ multiple heads of self-attention in parallel

    Each head attends independently; their outputs are concatenated along the
    channel dimension, projected back to n_embd channels and passed through
    dropout.
    """

    def __init__(self, num_heads, head_size):
        # num_heads: how many Head modules run side by side
        # head_size: output width of each individual head
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])  # heads of attention
        # The concatenated head outputs are num_heads * head_size wide. That only
        # equals n_embd when n_embd is divisible by the number of heads, so size
        # the projection from the heads themselves rather than assuming n_embd.
        self.proj = nn.Linear(num_heads * head_size, n_embd)  # "projection back into the residual pathway"
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)  # concatenate over the channel dimension (B, T, num_heads * head_size)
        out = self.proj(out)  # "project back into the residual pathway" (B, T, n_embd)
        out = self.dropout(out)  # apply dropout on the residual path
        return out

class FeedForward(nn.Module):
    """ per-token MLP: expand to 4x width, ReLU, project back, dropout """

    def __init__(self, n_embd):
        super().__init__()
        hidden_width = 4 * n_embd
        # expansion + non-linearity
        self.net = nn.Sequential(nn.Linear(n_embd, hidden_width), nn.ReLU())
        # "project back into the residual pathway"
        self.proj = nn.Linear(hidden_width, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = self.net(x)
        projected = self.proj(hidden)  # back to n_embd channels
        return self.dropout(projected)  # dropout on the residual path

class TransformerBlock(nn.Module):
    """ Transformer block: communication (self-attention) followed by computation (feed-forward)

    Both sub-layers are applied to a layer-normalised copy of their input and
    their result is added back onto the input (residual connection).
    """

    def __init__(self, n_embd, n_head):
        # n_embd: embedding dimension, n_head: the number of heads we'd like
        super().__init__()
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)  # each head gets n_embd // n_head channels
        self.ffwd = FeedForward(n_embd)  # applied to every token independently (acts on the last dimension only)
        self.ln1 = nn.LayerNorm(n_embd)  # layer norm in front of attention (normalises over the last dimension)
        self.ln2 = nn.LayerNorm(n_embd)  # layer norm in front of the feed-forward (normalises over the last dimension)

    def forward(self, x):
        # residual pathway: each sub-layer's output is added to its own input
        after_attention = x + self.sa(self.ln1(x))
        return after_attention + self.ffwd(self.ln2(after_attention))

# no longer a simple bigram model: a decoder-only transformer language model (the class name still says "Bigram")
class BigramLanguageModel(nn.Module):
    """Character-level transformer language model.

    Despite its name this is not a bigram model: token and position embeddings
    are summed, passed through a stack of TransformerBlock layers and a final
    LayerNorm, then mapped to next-token logits over the vocabulary.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)  # token information (in token embedding space)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)  # positional information (in position embedding space)
        self.blocks = nn.Sequential(*[TransformerBlock(n_embd, n_head) for _ in range(n_transformer_blocks)])
        self.ln_f = nn.LayerNorm(n_embd)  # final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)  # embedding space --> vocabulary space

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids, with T at most block_size.
            targets: optional (B, T) tensor of next-token ids.

        Returns:
            ``(logits, loss)``. Without targets, logits is (B, T, vocab_size)
            and loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the cross-entropy against the
            flattened targets.
        """
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)  # (B, T, C)
        # Build the position ids on the input's own device rather than the
        # process-wide default device, which other cells switch around.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions)  # (T, C)
        x = tok_emb + pos_emb  # (B, T, C), pos_emb broadcasts over the batch
        x = self.blocks(x)  # (B, T, C)
        x = self.ln_f(x)  # (B, T, C)
        logits = self.lm_head(x)  # (B, T, vocab_size)

        if targets is None:
            loss = None
        else:
            # F.cross_entropy wants (N, classes) logits and (N,) targets
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Sample max_new_tokens tokens, extending each sequence in idx.

        Args:
            idx: (B, T) tensor of token ids used as the starting context.
            max_new_tokens: number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.

        Sampling runs in eval mode (dropout off) and without gradient tracking;
        the module's previous train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # never pass more than block_size tokens: the position table has only block_size rows
                idx_cond = idx[:, -block_size:]
                # get the predictions
                logits, _ = self(idx_cond)
                # focus only on the last time step
                logits = logits[:, -1, :]  # becomes (B, C)
                # apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1)  # (B, C)
                # sample from the distribution
                idx_next = torch.multinomial(probs, num_samples=1)  # (B, 1)
                # append sampled index to the running sequence
                idx = torch.cat((idx, idx_next), dim=1)  # (B, T+1)
        finally:
            self.train(was_training)
        return idx

# The single global model instance: estimate_loss() and the training and
# generation cells below all refer to it as `m`.
m = BigramLanguageModel()

In [19]:
# Train the model
max_iters = 5000  # total number of optimizer steps
ping_interval = 50  # print a timestamped heartbeat every this many steps
eval_interval = 500  # estimate train/val loss every this many steps
learning_rate = 3e-4
eval_iters = 50  # batches averaged per split inside estimate_loss()

# create a PyTorch optimizer
# The optimizer is constructed while the default device is temporarily "cpu".
# Afterwards the default device is restored to wherever the model's parameters
# actually live, instead of a hard-coded device name.
model_device = next(m.parameters()).device
torch.set_default_device("cpu")  # use cpu
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate)
torch.set_default_device(model_device)  # back to the model's device

# `step` rather than `iter`, so the builtin iter() is not shadowed for the rest of the session
for step in range(max_iters):
    # heartbeat so a long run shows progress (and how long each stretch takes)
    if step % ping_interval == 0:
        print(f"Iteration {step} - {datetime.datetime.now()}")
    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # forward pass, then one optimizer update
    logits, loss = m(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

losses = estimate_loss()
print(f"Final: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

Iteration 0 - 2025-03-17 09:37:19.792961
step 0: train loss 0.0788, val loss 3.1145
Iteration 50 - 2025-03-17 09:41:13.630140
Iteration 100 - 2025-03-17 09:42:58.828082
Iteration 150 - 2025-03-17 09:44:43.702476
Iteration 200 - 2025-03-17 09:46:29.231300
Iteration 250 - 2025-03-17 09:49:45.625336
Iteration 300 - 2025-03-17 09:51:35.061981
Iteration 350 - 2025-03-17 09:53:22.197388
Iteration 400 - 2025-03-17 09:55:07.869515
Iteration 450 - 2025-03-17 09:56:52.600465
Iteration 500 - 2025-03-17 09:58:37.116374
step 500: train loss 0.0716, val loss 3.2624
Iteration 550 - 2025-03-17 10:01:24.741699
Iteration 600 - 2025-03-17 10:03:08.923660
Iteration 650 - 2025-03-17 10:04:55.466170
Iteration 700 - 2025-03-17 10:06:42.796788
Iteration 750 - 2025-03-17 10:08:30.047997
Iteration 800 - 2025-03-17 10:10:16.773955
Iteration 850 - 2025-03-17 10:12:02.040344
Iteration 900 - 2025-03-17 10:13:48.262719
Iteration 950 - 2025-03-17 10:15:36.336642
Iteration 1000 - 2025-03-17 10:17:23.571704
step 1000: 

KeyboardInterrupt: 

In [20]:
# generate from the model, seeded with a hand-written prompt
context_tokens = encode("""To be, or not to be: that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles,
And by opposing end them? To die: to sleep;
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to, 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep: perchance to dream: ay, there's the rub;""")
print(len(context_tokens))
# alternative: start from a single token with id 0 instead of a prompt
# context = torch.zeros((1, 1), dtype=torch.long)
context = torch.tensor(context_tokens, dtype=torch.long).reshape(1, -1)  # batch of one sequence
print(context.shape)
# sample 1000 new characters and decode the only row of the batch back to text
generated = m.generate(context, max_new_tokens=1000)
first_sequence = generated[0].tolist()
print(decode(first_sequence))

436
torch.Size([1, 436])
To be, or not to be: that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles,
And by opposing end them? To die: to sleep;
No more; and by a sleep to say we end
The heart-ache and the thousand natural shocks
That flesh is heir to, 'tis a consummation
Devoutly to be wish'd. To die, to sleep;
To sleep: perchance to dream: ay, there's the rub;
Or in thy poor breath, still thy land and bosom,
That then ranks thought to bear them at the heel
With indistrumenting this feast. There is not find
Mis hate no exile of us: stay we the counter
And see the noble gone, and with them follow.

HORTENSIO:
How! traitor, go Mercutio!

MARCIUS:
A said I would
Not lose thee, but that 'tis hate, the one
Of my poor heart's her means: 'twas a noble wench'd
When I have said, but not honey: but I'll doubt not
With that he would said in her before keeping a
For the earth and natural for complaint

In [None]:
# As I stead--O kname'st to be secuted for Rome.
#
# BUCKINGHAM:
# I do believe mine.
#
# BUCKINGHAM:
# How, soft! my lord, ignorant, let it go.
#
# KING RICHARD III:
# Now, in good time: God my lords, 'tis gone.
#
# BUCKINGHAM:
# Farewell, on that fault of much, when thou wert so!
#
# QUEEN ELIZABETH:
# Shall I might, give me thy heart with hope,
# I'll mperise thee in my heart.
#
# KING RICHARD III:
# Stry than a wchild and longers die,
# Why thou hast socian he still above his hinour:
# When hast he said aspect, thy sovereign's heir,
# His heart of will hence to France to forgive Engles,
# Whose dim two do much up his with him;
# And walk will teach the wings shame to wield as my
# draws and unpiting it to my comfort,
# And made it broking it from my brother,
# And balm to myself-coatrived soul at his.
#
# DUCHESS:
# If I do not, sin such a medle with thy brother,
# Where I was slain imprisonment made
# That when I was guilty in my power.
#
# KING EDWARD IV:
# Take men by this young Keng of beast;
# Young Next of the next which now your grace,
# Spe

In [None]:
# Avast it was bags; for that I do keep it not.
# And yet, being so, to get a jade of thee;
# Thou art the world, all the world will be so.
# Arise, the heavens look upon thy heaven,
# And still thy lips will dispatch thee;
# And for my woman crush the will be set,
# As he, and my nurse, when he doth rage on France,
# O thou art sent for a hot and worthy death!
# Did villain too fawn upon this king,
# Down with down with rail down their purpose!
#
# KING EDWARD IV:
# Away with her; go, bear her hence perforce.
#
# QUEEN MARGARET:
# And give me her hunce, and my vowery short,
# My crown king away from her deceit.
# I spy the traitor of the people's eyes,
# Are they to will false 'O.' You putt return'd,
# Most gracious sovereign, not wed have made good
# This even would desire have cut off.
# I cannot bear a man tray, good deserves a left,
# And do not incurre her talk of battle heralm;
# For while I undo the man of my breast,
# Compare me him away: wanting at the gates,
# Gaunt by some white all defects I live;
# Or as I have by occasion