In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

In [2]:
# Read the whole corpus into memory as one string; the last line reports its length in characters.
with open("shakespeare.txt", mode="r", encoding="utf-8") as corpus_file:
    text = corpus_file.read()

f"dataset length {len(text)}"

'dataset length 1115394'

In [3]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)
f"vocab size: {vocab_size}; vocab: {''.join(chars)}"

"vocab size: 65; vocab: \n !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"

In [4]:
# Lookup tables between characters and their integer ids (ids follow the sorted vocab order).
itos = dict(enumerate(chars))
stoi = {ch: i for i, ch in itos.items()}


def encode(s):
    """Encoder: take a string, output a list of integers."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integers, output a string."""
    return ''.join(itos[i] for i in l)


# Round-trip sanity check: decoding the encoding must give back the input.
print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [5]:
# Encode the whole corpus once, then hold out everything after the first 90% for validation.
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9 * len(text))
train_data, val_data = data[:n], data[n:]

In [6]:
# Self attention head
# Walk-through of a single causal self-attention head on random data.
B, T, C = 4, 8, 32
x = torch.randn(B, T, C)

head_size = 16
# bias=False - no added constant :)
key = nn.Linear(C, head_size, bias=False) # Kinda the metadata contained by each token (how relevant each token is)
query = nn.Linear(C, head_size, bias=False) # Kinda the information each token is looking for
value = nn.Linear(C, head_size, bias=False) # The actual data contained at each token
k = key(x) # (B, T, 16)
q = query(x) # (B, T, 16)
# Similarity between each key and query is the weights
wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) --> (B, T, T)
# Scale the weights down so softmax doesn't turn into one hot.
# The dot product sums over head_size terms, so the factor is 1/sqrt(head_size),
# not 1/sqrt(C): q and k live in the projected head space, not the embedding space.
wei = wei * head_size**-0.5
# Weights are now data dependent

tril = torch.tril(torch.ones(T, T)) # (T, T)
# Mask future tokens
wei = wei.masked_fill(tril==0, float("-inf")) # (B, T, T)
wei = F.softmax(wei, dim=-1) # each row (one query position) sums to 1

# Value tensor, the actual data contained by each token, what is offered by each token
v = value(x) # (B, T, 16)

# The Query-Key mechanism decides how much of each value should contribute to the final representation of a token.
out = wei @ v # (B, T, T) @ (B, T, 16) --> (B, T, 16)
out.shape

torch.Size([4, 8, 16])

Notes:
- Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.
- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.
- Examples across the batch dimension are of course processed completely independently and never "talk" to each other.
- In an "encoder" attention block, just delete the single line that does masking with tril, allowing all tokens to communicate. This block here is called a "decoder" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.
- "Self-attention" just means that the keys and values are produced from the same source as the queries. In "cross-attention", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module).
- "Scaled" attention additionally multiplies wei by 1/sqrt(head_size). This makes it so that when the inputs Q, K are unit variance, wei will be unit variance too, and softmax will stay diffuse and not saturate too much. Illustration below.

In [11]:
# Hyperparams

# -- batching --
batch_size = 16 # how many independent sequences will we process in parallel?
block_size = 32 # what is the maximum context length for predictions?

# -- optimisation / evaluation schedule --
max_iters = 5000
learning_rate = 1e-3
eval_interval = 500 # training steps between two loss estimates
eval_iters = 100 # batches averaged per split for one loss estimate

# -- model size --
n_embd = 64
n_head = 4
n_layer = 4
dropout = 0.0

# -- accelerator: prefer CUDA, then Apple MPS, otherwise fall back to the CPU --
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

device

'mps'

In [12]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Projects the input to keys, queries and values of width ``head_size``,
    computes scaled dot-product attention weights with future positions
    masked out, and returns the weighted sum of the values.

    Args:
        head_size: width of the key/query/value projections.

    Reads the module-level ``n_embd``, ``block_size`` and ``dropout`` settings.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Stored with the model but not treated as trainable weights with register_buffer
        self.register_buffer("tril", torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over ``x`` of shape (B, T, n_embd); returns (B, T, head_size).

        T must not exceed ``block_size`` (the size of the causal mask).
        """
        _, T, _ = x.shape

        k = self.key(x)    # (B, T, head_size)
        q = self.query(x)  # (B, T, head_size)

        # Scale by 1/sqrt(head_size): the dot product sums over head_size terms,
        # so this keeps the scores near unit variance (n_embd would over-shrink them).
        wei = q @ k.transpose(-2, -1) * k.shape[-1] ** -0.5  # (B, T, T)
        # Position t may only look at positions <= t.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float("-inf"))
        # Normalise over the key axis (last dim) so each query's weights sum to 1.
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout(wei)

        v = self.value(x)  # (B, T, head_size)
        out = wei @ v      # (B, T, head_size)
        return out

class MultiHeadAttention(nn.Module):
    """Several attention heads run in parallel, concatenated and projected to ``n_embd``.

    Args:
        num_heads: number of ``Head`` modules to run side by side.
        head_size: output width of each head.

    Reads the module-level ``n_embd`` and ``dropout`` settings.
    """

    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        # Projection back into residual pathway. Its input width is the real
        # concatenated width, so it also works when num_heads * head_size != n_embd
        # (e.g. n_embd not divisible by the number of heads).
        self.proj = nn.Linear(num_heads * head_size, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x to (B, T, n_embd) by concatenating all heads and projecting."""
        out = torch.cat([h(x) for h in self.heads], dim=-1) # from (B, T, head_size) to (B, T, num_heads * head_size)
        out = self.dropout(self.proj(out)) # (B, T, n_embd)
        return out

class FeedForward(nn.Module):
    """Per-token MLP: widen to 4 * n_embd, apply ReLU, narrow back, then dropout.

    The dropout probability comes from the module-level ``dropout`` setting.
    """

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.ReLU(),
            nn.Linear(hidden, n_embd),  # Projection back into residual pathway
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of x; the shape is unchanged."""
        return self.net(x)

class Block(nn.Module):
    """Transformer block: multi-head self-attention, then a feed-forward MLP.

    Both sub-layers are wrapped in a residual connection and are fed a
    layer-normalised copy of their input (pre-norm).
    """

    def __init__(self, n_embd, n_head):
        super().__init__()
        # Split the embedding width evenly across the heads.
        self.sa = MultiHeadAttention(n_head, n_embd // n_head)
        self.ffwd = FeedForward(n_embd)
        # Unit gaussian distribution across layers at init
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return x after the attention and feed-forward residual updates."""
        # Residual connections (x + ...) help training. Normalising *before* each
        # sub-layer (pre-norm) differs from the original paper.
        attended = x + self.sa(self.ln1(x))
        return attended + self.ffwd(self.ln2(attended))

class Transformer(nn.Module):
    """Decoder-only language model over integer token ids.

    Token and position embeddings are summed, passed through ``n_layer``
    ``Block`` modules and a final LayerNorm, then projected to ``vocab_size``
    logits. Sizes come from the module-level hyperparameters
    (``vocab_size``, ``n_embd``, ``block_size``, ``n_head``, ``n_layer``).
    """

    def __init__(self):
        super().__init__()

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embd) # Final layer norm
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, if targets are given, the loss.

        Args:
            idx: (B, T) integer token ids, with T <= block_size.
            targets: optional (B, T) integer token ids to predict.

        Returns:
            (logits, loss). Without targets, logits is (B, T, vocab_size) and
            loss is None. With targets, logits is flattened to
            (B*T, vocab_size) and loss is the mean cross-entropy.
        """
        B, T = idx.shape

        # idx and targets are both (B,T) tensor of integers
        tok_emb = self.token_embedding_table(idx) # (B,T,C)
        # Build the position ids on the same device as the input rather than on
        # the global `device`, so the model works wherever it (and idx) live.
        positions = torch.arange(T, device=idx.device) # (T,)
        pos_emb = self.position_embedding_table(positions) # (T,C)
        x = tok_emb + pos_emb # (B,T,C), pos_emb broadcasts over the batch
        x = self.blocks(x) # (B,T,C)
        x = self.ln_f(x) # (B,T,C)
        logits = self.lm_head(x) # (B,T,vocab_size)

        if targets is None:
            loss = None
        else:
            # cross_entropy wants (N, classes) logits and (N,) targets.
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; skip building the autograd graph
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after ``idx``.

        Args:
            idx: (B, T) integer token ids forming the starting context.
            max_new_tokens: number of tokens to append.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        # idx is (B, T) array of indices in the current context
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx

In [15]:
# Build the model and move it to the selected device.
model = Transformer().to(device)

# print the number of parameters in the model
num_parameters = sum(p.numel() for p in model.parameters())
print(num_parameters/1e3, 'K parameters')

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

def get_batch(split):
    """Sample a random batch of inputs and next-token targets.

    Uses ``train_data`` when ``split == 'train'`` and ``val_data`` for any
    other value. Returns ``(x, y)``, each of shape (batch_size, block_size),
    where ``y`` is ``x`` shifted one position to the right in the data.
    Tensors are not moved to a device here; the callers do that.
    """
    source = train_data if split == 'train' else val_data
    # Random window starts; the upper bound leaves room for the shifted targets.
    starts = torch.randint(len(source) - block_size, (batch_size,))

    def windows(shift):
        # One block_size-long slice per start, beginning `shift` tokens later.
        return torch.stack([source[s + shift:s + shift + block_size] for s in starts])

    return windows(0), windows(1)

@torch.no_grad()
def estimate_loss(model):
    """Average the loss over ``eval_iters`` random batches for each split.

    Puts the model in eval mode while measuring and back in train mode
    afterwards. Returns ``{'train': mean_loss, 'val': mean_loss}`` where each
    value is a 0-dim tensor.
    """
    model.eval()
    averages = {}
    for split in ('train', 'val'):
        per_batch = torch.zeros(eval_iters)
        for step in range(eval_iters):
            xb, yb = (t.to(device) for t in get_batch(split))
            _, batch_loss = model(xb, yb)
            per_batch[step] = batch_loss.item()
        averages[split] = per_batch.mean()
    model.train()
    return averages

209.729 K parameters


In [None]:
# Training loop. The counter is named `step` rather than `iter` so the builtin
# `iter()` is not shadowed for the rest of the notebook session.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data (get_batch returns CPU tensors; move them to the model's device)
    xb, yb = get_batch('train')
    xb, yb = xb.to(device), yb.to(device)

    # evaluate the loss
    logits, loss = model(xb, yb)
    # clear old gradients, backpropagate, then take one optimizer step
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

step 0: train loss 4.3625, val loss 4.3644


RuntimeError: Placeholder storage has not been allocated on MPS device!

In [None]:
# generate from the model
# Start from a single token with id 0 (batch of one), sample 2000 more tokens,
# then decode the one sequence in the batch back to text.
context = torch.zeros((1, 1), dtype=torch.long, device=device)
sampled_ids = model.generate(context, max_new_tokens=2000)
print(decode(sampled_ids[0].tolist()))


YOfzKkkkzknnkuu--uuullllllyond is inget undinst she han gould mere ber ansk chrar be, in;er ary theNi.

LUARd I bar.
Ura chlore, ad thear tue.

Yor herit ind gun ta mrin gove! anlan the dimed be be, aunarr qure ager coroud Beamer' herear oorgaard
Wautin;
Fer tho-wel tord neer Wigel thess ont hrac here! I thel deme hen om hapke!
End Enve ye jor tauur t ond to son srity fil tinep coude ayor le, to 
afr me rand old andomingwticr:
Anco om fnebset gong thor;
TE wer ndanlangod ind atal ice qxugulid hirde me tond acecingen ond and daar bersed weer ave ed do we, din bey tol; ave:
s Ke mof aung tr thor yor mo, muk hinoals ake wer, ice llrecary com,
Hares the ore ba.
E
PUCKILYRD: nor, I Jmern qunto nin he,
Fove tut y ree;
End my be, sind mine!
 ingud jon dod, kir! rib'al sr arey
Y Longe nro. so mande yonur trend
Topem'gs nesg indam yn ton wr bente,
Dr Dom mor wom are ther and uard vece ima, 
And hat ot genill ton baepr oro, nod waved oned va ey ude,ur he hiry.

Gor toe ter,
I Takr yivest herbic