# Building GPT (Transformer)

Here we are going to build up the 'GPT' model which is based on a Transformer architecture from: https://arxiv.org/pdf/1706.03762

Compared to Bengio/WaveNet, the Transformer model (in wishy-washy terms):
- Introduces an 'attention' layer, which lets the tokens in the existing context spend compute/time 'communicating' with each other
- Introduces a feed-forward layer, which lets each token spend compute/time properly processing all the 'communications' it received, so it can find connections to other words.
- Adds a normalization layer (LayerNorm rather than BatchNorm, i.e. we normalize across the feature dims rather than the batch dim. More stable for variable-length contexts)

We then repeat this block multiple times to fit our model better to the dataset.

In [21]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
g = torch.Generator().manual_seed(2147483647)

In [22]:
# Download the dataset
# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

In [23]:
# First up, let's use a new dataset. We'll be using the shakespeare dataset from karpathy's makemore gpt episode.
# Read the whole corpus into memory as one string. The encoding is passed explicitly so that the decoded
# characters (and therefore the vocabulary built from them) do not depend on the platform's default locale.
with open("input.txt", 'r', encoding="utf-8") as f:
    data = f.read()

print(f"Length of dataset in chars: {len(data)}")
print(data[:50])

Length of dataset in chars: 1115394
First Citizen:
Before we proceed any further, hear


In [24]:
# Build a character-level vocabulary plus the encode/decode helpers
block_size = 256  # context length; re-assigned further down before any batch is drawn

# Every distinct character of the corpus, in sorted order, is one token
vocab = sorted(set(data))
vocab_size = len(vocab)
print(''.join(vocab))
print(f"Vocab size: {vocab_size}")

# Two lookup tables: token id -> char, and its inverse char -> token id
itos = dict(enumerate(vocab))
stoi = {ch: ix for ix, ch in itos.items()}

print(stoi)
print(itos)

# For fun & extra practice, these stay as lambdas like Karpathy's version
encode = lambda s: [stoi[ch] for ch in s]  # str -> list of token ids
decode = lambda x: ''.join(itos[tok] for tok in x)  # iterable of token ids -> str

# Round-trip sanity check: decode(encode(s)) should give back s
test = "hi there"
print(encode(test))
print(decode(encode(test)))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Vocab size: 65
{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}
{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z',

In [25]:
# Let's prepare the dataset & split it into train/val sets

# 1. Encode the entire dataset as token ids, stored as a single 1-D torch.tensor
dataset = torch.tensor(encode(data))
print(dataset[:4])

# 2. Split the dataset into training (first 90%) and validation (remaining 10%).
#    (A test split would be usual too, but karpathy doesn't have a test set so I'm sticking with train/val to compare)
train_split = int(0.9*dataset.shape[0])

train_data = dataset[:train_split]
val_data = dataset[train_split:]

# 3. Generate X,Y pairs
# a. We want to generate pairs of context length. Thus, on a single context of block_size=8, our model makes 8 predictions
#    (one per prefix length 1..8).
# This is more efficient (data loaded once to GPU, processed 8 times) and allows it to infer using variable length input

# b. We want to batch together independent contexts. This will give us better utilization of the GPU (highly parallelizable).

torch.manual_seed(1337) #To compare with Karpathy's generation

batch_size = 4 # number of independent sequences per batch
block_size = 8 # context length; overrides the 256 assigned in the encoding cell

def get_batch(split):
    """Sample a random mini-batch of (input, target) windows from one split.

    `split == 'train'` draws from `train_data`; any other value draws from
    `val_data`. Both returned tensors have shape (batch_size, block_size), and
    each target row is its input row shifted one position to the right.
    """
    source = train_data if split == 'train' else val_data

    # Random window starts; the upper bound leaves room for the shifted target window
    starts = torch.randint(len(source) - block_size, (batch_size,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start : start + block_size])
        targets.append(source[start + 1 : start + block_size + 1])
    return torch.stack(inputs), torch.stack(targets)

# Draw one demo batch from the training split
xb, yb = get_batch('train')

print(xb) # inputs,  shape (B, T) = (batch_size, block_size)
print(yb) # targets, shape (B, T): each row is the matching xb row shifted one position to the right

# Now let's generate the batches for the entire dataset

tensor([18, 47, 56, 57])
tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [61]:
# Now, we'll be using PyTorch modules (rather than creating from scratch)

import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

# Predict the next token given the input token
class BigramLanguageModel(nn.Module):
    """Bigram language model: every token directly looks up the logits of its successor.

    The only parameters are a (vocab_size, vocab_size) embedding table whose
    row `i` holds the unnormalised next-token scores for token `i`.
    """

    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) #embedding output size = vocab size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: integer token ids of shape (B, T).
            targets: optional integer token ids of shape (B, T).

        Returns:
            A `(logits, loss)` tuple. Without targets, logits has shape
            (B, T, C) and loss is None. With targets, logits is flattened to
            (B*T, C) and loss is the cross-entropy over all B*T positions.
        """
        logits = self.token_embedding_table(idx) #B, T, C

        if targets is None:
            return logits, None

        # F.cross_entropy wants the class dim second (i.e. (N, C) or (B, C, T), not (B, T, C)),
        # so we merge batch and time into one dimension.
        B, T, C = logits.shape
        logits = logits.view(B * T, C)
        targets = targets.view(B * T)
        loss = F.cross_entropy(logits, targets)
        return logits, loss

    @torch.no_grad()  # sampling never needs gradients; avoids building an autograd graph on every step
    def generate(self, idx, max_new_tokens):
        """Extend each sequence in `idx` by `max_new_tokens` sampled tokens.

        Args:
            idx: integer token ids of shape (B, T) used as the starting context.
            max_new_tokens: number of tokens to append to every row.

        Returns:
            Integer tensor of shape (B, T + max_new_tokens).
        """
        for _ in range(max_new_tokens): #We iterate as many times as tokens we want (unlike before, we don't stop on a '.' token)
            logits, _ = self(idx)
            #With the bigram model, only the last time step matters for the next token
            logits = logits[:, -1, :] #(B, T, C) -> (B, C)
            probs = torch.softmax(logits, dim=-1) #normalise over the vocabulary (last) dim, not over the batch

            idx_next = torch.multinomial(probs, num_samples=1, replacement=True) #(B, 1)
            idx = torch.cat((idx, idx_next), dim=1) #(B, T+1)
        return idx

# Instantiate the model and run one forward pass (with targets) on the demo batch
m = BigramLanguageModel(vocab_size)
out, loss = m(xb, yb)
print(out)
print(loss)
print(out.shape) #(B*T, C) = (32, 65): because targets were passed, forward() flattened the (B=4, T=8, C=65) logits before computing the loss

tensor([[-1.5101, -0.0948,  1.0927,  ..., -0.6126, -0.6597,  0.7624],
        [ 0.3323, -0.0872, -0.7470,  ..., -0.6716, -0.9572, -0.9594],
        [ 0.2475, -0.6349, -1.2909,  ...,  1.3064, -0.2256, -1.8305],
        ...,
        [-2.1910, -0.7574,  1.9656,  ..., -0.3580,  0.8585, -0.6161],
        [ 0.5978, -0.0514, -0.0646,  ..., -1.4649, -2.0555,  1.8275],
        [-0.6787,  0.8662, -1.6433,  ...,  2.3671, -0.7775, -0.2586]],
       grad_fn=<ViewBackward0>)
tensor(4.8786, grad_fn=<NllLossBackward0>)
torch.Size([32, 65])


In [60]:
# Start generation from a single token with id 0 (the newline character in our vocab), batch size 1
idx = torch.zeros((1,1), dtype=torch.long)
print(decode(m.generate(idx, 100)[0].tolist()))


oU,pxIBFAYuxKXe.jeh
sa!3MGFrSjuM:wX!?BTMl!.?,M:bQzPHpYfN!Cbo'MmtDxBkDD3SBjyFdmY'DOqkWeRjlxyJB-bVbfd&


In [64]:
# AdamW over all of the model's parameters
optimizer = torch.optim.AdamW(m.parameters(), lr=1e-4) #Can also use higher lr for smaller networks (e.g. 1e-3)

In [87]:
# Train the bigram model with mini-batch gradient steps
batch_size = 32  # get_batch reads this global, so batches are now 32 rows instead of 4
for _ in range(10000):
    # Clear the gradients left over from the previous step
    optimizer.zero_grad(set_to_none=True)

    # Forward pass on a freshly sampled random training batch
    xb, yb = get_batch('train')
    logits, loss = m(xb, yb)

    # Backward pass, then parameter update
    loss.backward()
    optimizer.step()

# Loss of the final mini-batch only
print(loss.item())


2.6773767471313477


In [110]:
# Sample 400 new tokens from the trained bigram model, starting from `idx`
print(decode(m.generate(idx, 400)[0].tolist()))

#Output obviously very bad (for bigram model), but we'll improve it using transformer architecture.


woaterirsth!atsigmpre!Yke n my;
Angdsbou se ngro'yYCO:
h RIs stine tchu nes oiVDWh My q ANCK:pe cth h ske thity to'ravir igUCXwer dul.
Rose'ror.W:MftiswIrower.
W?Y peyergise w, han kW:3: lvesk

MAas dsereYoonond, thoullqNhagmet maistf hyfre w;RThe!UK:woom zes seakzzkHe p,
RIOR:
ofqm dmpptoXJOLLLLIN:Cybannclve s;


OULItyatre sw, makilof WBandveZELCAhetrengratswh,
s cueathureat gheiest nthabuis nd 


## Messing around with self-attention

In [144]:
# Toy example: random activations standing in for token embeddings
torch.manual_seed(1337)

B, T, C = 4, 8, 2
x = torch.randn(B, T, C) #batch, time, channels
x.shape

# We want the tokens to talk with each other, but only backwards in time:
# e.g. the 5th token may look at tokens 1-4 (previous context -> current) and itself, never at later tokens.
# Easiest way -> just average the past tokens + current ("current context of me with the past history")

torch.Size([4, 8, 2])

In [145]:
# Version 1: explicit loops. xbow[b, t] is the mean of x[b, 0..t] (bow -> bag of words, i.e. 'averaging info')
xbow = torch.zeros((B, T, C))

# Inefficient implementation (using matrix multiplication will be much faster)
for batch_ix in range(B):
    for last in range(1, T + 1):
        history = x[batch_ix, :last]  # (last, C): the current token and everything before it
        xbow[batch_ix, last - 1] = history.mean(dim=0)

print(x[0])
print(xbow[0])

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


In [151]:
# Version 2: the same running average expressed as one matrix multiply.
# Lower-triangular ones with every row divided by its row sum -> row t puts weight 1/(t+1) on positions 0..t
wei = torch.tril(torch.ones(T, T))
wei = wei / wei.sum(dim=1, keepdim=True)
print(wei)

#Our dot product becomes an efficient weighted sum average across the previous terms.
xbow2 = wei @ x # (8, 8) @ (4, 8, 2) -> torch broadcasts wei over the batch dim: (4, 8, 8) @ (4, 8, 2) -> (4, 8, 2)
# NOTE(review): the recorded output of this check is False although the printed values agree to 4 decimals —
# presumably float32 rounding differences exceeding allclose's strict default tolerances; TODO confirm with a looser atol
torch.allclose(xbow, xbow2)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3333, 0.3333, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2500, 0.2500, 0.2500, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2000, 0.2000, 0.2000, 0.2000, 0.2000, 0.0000, 0.0000, 0.0000],
        [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000, 0.0000],
        [0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.1429, 0.0000],
        [0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250, 0.1250]])


False

In [170]:
# 3rd version: the same causal average again, but the weights now come from a softmax
# (exponentiate, then normalize each row). Future positions are set to -inf first, so
# exp(-inf) = 0 gives them zero weight and each row only averages over the allowed positions.

tril = torch.tril(torch.ones(T, T))
wei = torch.zeros(T, T).masked_fill(tril == 0, float('-inf'))  # zeros on/below the diagonal, -inf above it
wei = F.softmax(wei, dim=-1)  # wei is 2-D here, so the last dim is dim 1 (row-wise)
xbow3 = wei @ x
print(x[0])
print(xbow3[0])
torch.allclose(xbow2, xbow3)

tensor([[ 0.1808, -0.0700],
        [-0.3596, -0.9152],
        [ 0.6258,  0.0255],
        [ 0.9545,  0.0643],
        [ 0.3612,  1.1679],
        [-1.3499, -0.5102],
        [ 0.2360, -0.2398],
        [-0.9211,  1.5433]])
tensor([[ 0.1808, -0.0700],
        [-0.0894, -0.4926],
        [ 0.1490, -0.3199],
        [ 0.3504, -0.2238],
        [ 0.3525,  0.0545],
        [ 0.0688, -0.0396],
        [ 0.0927, -0.0682],
        [-0.0341,  0.1332]])


True

In [190]:
# Version 4: a single self-attention head
torch.manual_seed(1337)
B,T,C = 4,8,32
x = torch.randn(B, T, C)

head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False) #We don't aggregate the raw x values but rather v, the 'public' info each token shares about x
k = key(x) # (B, T, 16): one key vector of size head_size per token ("what do I contain?")
q = query(x) # (B, T, 16): one query vector of size head_size per token ("what am I looking for?")

# 'communication' between tokens: every query is dotted with every key
wei = q @ k.transpose(-2, -1) #(B, T, 16) @ (B, 16, T) -> (B, T, T) | wei[b, i, j] = affinity of token i's query with token j's key

# Here we create a mask to ensure each token only focuses on the previous context (and is not polluted by future context)


# DECODER architecture (in some cases, we want fully connected ie 'encoder' architectures)
# Unlike the earlier versions this is no longer a uniform average: the weights come from the q/k affinities above
tril = torch.tril(torch.ones(T, T))
# wei = torch.zeros(T,T)   <- the old constant starting point, now replaced by the data-dependent q @ k^T
wei = wei.masked_fill(tril == 0, float('-inf')) #only present in decoder block, without this it becomes encoder.
wei = F.softmax(wei, dim=-1) #wei is 3-D (B, T, T) here, so we normalize over the last axis (dim=-1, i.e. dim=2); dim=1 was only the last axis for the 2-D wei used earlier

v = value(x) #V is the thing that gets aggregated (and not the raw value of x)
out = wei @ v # (B, T, T) @ (B, T, 16) -> (B, T, 16)

# This should basically tell us how much each previous token context is important for the current token

# Notes:
# Different tokens find other tokens more/less interesting. We want to gather this in a data dependent way
# I.e. we don't want the weights to be uniform/0s; they should depend on the content of the tokens.

# Every node emits a query and a key. Query -> what am I looking for? Key -> what do I contain?
# Dot product of queries and keys (i.e. my query is dotted with all the keys). If they align, the affinity is high, so I aggregate more from that token than from the others

#Before wei was just a constant. But now every batch element will have different weis
out.shape # (B, T, head_size) = (4, 8, 16). NOTE(review): the recorded output below shows 32 in the last dim, which does not match head_size = 16 — presumably stale from an earlier run; TODO re-run

# Keys/Values could come from other blocks/nodes -> 'cross-attention'

# Wei needs to be fairly diffuse (otherwise softmax converges towards 'one-hot' values, thus we're only aggregating info from a single node)
# Thus we need to scale by 1/sqrt(head_size)
# Before the softmax, make sure wei = wei * (head_size ** -0.5)
# NOTE(review): that scaling is not applied in the code above

torch.Size([4, 8, 32])

In [191]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)