# Objectives
- Intro to PyTorch
- In-depth understanding of transformers
- Decoder-only transformer network

Heavy credit to Andrej Karpathy and this video: https://youtu.be/kCc8FmEb1nY although I made many changes.
Some highlights:
- Hyperparameters and model more closely following the original paper. (See comments with double quotes for references to the original paper.)
- Utilizing multiple GPUs
- Use datasets/dataloaders instead of homegrown batching
- Label smoothing
- Better corpus for training :)

Using the paper's settings as-is results in heavy overfitting (at least with this training data). A further extension would be to investigate how to reduce the overfitting (increase dropout, add it in more places, etc.).

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import os
import time
from torch.nn import functional as F

# Define hyperparameters
# Quotes in double quotes are taken from the original transformer paper for comparison.
val_split = 0.1 # fraction of the corpus (taken from the end of the text) held out for validation
batch_size = 64 # how many independent sequences will we process in parallel
block_size = 256 # maximum context length, in characters
max_iters = 5000 # number of optimizer steps performed by the training loop
eval_interval = 500 # estimate train/val loss (and print a generated sample) every this many steps
learning_rate = 2e-4 # Only the initial value handed to the optimizer; the training loop overwrites it with the warmup schedule before every step
eval_iters = 200 # number of batches averaged per split when estimating the loss
d_model = 384 # the paper uses "d_model = 512"
n_head = 6 # the paper uses 8: "In this work we employ h = 8"
n_layer = 6 # "a stack of N = 6 identical layers"
d_k = d_model // n_head # "For each of these we use d_k = d_v = d_model/h = 64" aka head size
d_ff = 4 * d_model # inner width of the feed-forward sub-layer (the paper's d_ff = 2048 is also 4x its d_model)
dropout = 0.1 # dropout probability used by every nn.Dropout in the model
smoothing = 0.1 # label-smoothing factor passed to the cross-entropy loss
warmup_steps=4000 # steps of linear learning-rate warmup before the inverse-square-root decay starts

In [2]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
# (Spreading work over several GPUs happens later, when the model is wrapped in nn.DataParallel.)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Make this the default device for tensors and modules created without an explicit device= argument.
torch.set_default_device(device)

# Download and Preprocess Data

Here, we download the script of Monty Python and the Holy Grail, which ships as `grail.txt` in NLTK's `webtext` corpus.

In [3]:
import nltk
# Fetch the 'webtext' corpus package (the captured output below shows it is skipped when already up to date)
nltk.download('webtext')
from nltk.corpus import webtext

# Load the whole script as a single string
text = webtext.raw('grail.txt')
# Preview: the first 45 whitespace-separated words, joined with single spaces
" ".join(text.split()[:45])



[nltk_data] Downloading package webtext to /usr/share/nltk_data...
[nltk_data]   Package webtext is already up-to-date!


'SCENE 1: [wind] [clop clop clop] KING ARTHUR: Whoa there! [clop clop clop] SOLDIER #1: Halt! Who goes there? ARTHUR: It is I, Arthur, son of Uther Pendragon, from the castle of Camelot. King of the Britons, defeator of the Saxons, sovereign of all England!'

Get all unique characters used in the text. These characters form our vocabulary, and their count is the vocabulary size.

In [4]:
# The vocabulary is every distinct character of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

print(f"Chars: {chars}")
print(f"Vocab size: {vocab_size}")

Chars: ['\n', ' ', '!', '#', "'", '(', ')', ',', '-', '.', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'Y', 'Z', '[', ']', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
Vocab size: 76


In [5]:
# Lookup tables between characters and their integer ids (position in `chars`).
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}


def encode(s):
    """Encoder: take a string, output the list of integer ids of its characters."""
    return [stoi[c] for c in s]


def decode(l):
    """Decoder: take a list of integer ids, output the corresponding string."""
    return ''.join(itos[i] for i in l)

In [6]:
class CustomWebtextDataset(Dataset):
    """Sliding-window character dataset for next-character prediction.

    The raw text is encoded once with the module-level ``encode``. Item ``idx``
    is the pair ``(x, y)`` where ``x`` holds ``block_size`` consecutive token
    ids starting at ``idx`` and ``y`` is the same window shifted one position
    to the right, i.e. the target for every position of ``x``.

    Args:
        raw: The text to encode.
        block_size: Length of every input/target window.
    """

    def __init__(self, raw, block_size):
        self.encoded_text = encode(raw)
        self.block_size = block_size

    def __len__(self):
        # Each item needs block_size + 1 tokens (inputs plus the shifted
        # targets). Clamp at zero so a text shorter than one window reports
        # an empty dataset instead of making len() raise on a negative value.
        return max(0, len(self.encoded_text) - self.block_size)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` window pair at ``idx``, both of shape ``(block_size,)``.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        size = len(self)
        if idx < 0:
            idx += size  # sequence-style negative indexing
        # Without this check, out-of-range indices would silently yield
        # truncated windows, and plain iteration over the dataset (which
        # relies on IndexError to stop) would never terminate.
        if not 0 <= idx < size:
            raise IndexError("dataset index out of range")

        stop = idx + self.block_size
        x = torch.tensor(self.encoded_text[idx:stop], dtype=torch.long, device=device)
        y = torch.tensor(self.encoded_text[idx + 1:stop + 1], dtype=torch.long, device=device)
        return x, y


In [7]:
# Hold out the final `val_split` fraction of the corpus for validation.
# The already-loaded `text` is reused rather than re-reading the corpus, and the
# split point is expressed as a forward index: slicing with `[:-split_ind]`
# would produce an EMPTY training set whenever split_ind happens to be 0.
split_ind = int(len(text)*val_split)
split_at = len(text) - split_ind

train = CustomWebtextDataset(text[:split_at], block_size)
val = CustomWebtextDataset(text[split_at:], block_size)

# Both loaders shuffle: batches are drawn one at a time from a fresh iterator,
# so shuffling is what makes each drawn batch a random sample of its split.
# The generator is created on `device` to match the default device set earlier.
train_dataloader = DataLoader(train, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=device))
val_dataloader = DataLoader(val, batch_size=batch_size, shuffle=True, generator=torch.Generator(device=device))

In [8]:
# Draw a single (input, target) batch and check the shape of the targets
first_batch = next(enumerate(train_dataloader))
i, xy = first_batch
x, y = xy
print(y.shape)

torch.Size([64, 256])


# Utility functions

In [9]:
# Estimate how well the model currently does on the training and validation data.
@torch.no_grad() # evaluation only, so no autograd bookkeeping is needed
def estimate_loss():
    """Average the loss over `eval_iters` random batches of each split.

    Puts the module-level `model` in eval mode, measures both splits, prints a
    500-token generated sample, and switches the model back to train mode.

    Returns:
        dict mapping 'train' and 'val' to the mean loss as a 0-d tensor.
    """
    model.eval()
    loaders = {'train': train_dataloader, 'val': val_dataloader}
    out = {}
    for split, loader in loaders.items():
        losses = torch.zeros((eval_iters))
        for i in range(eval_iters):
            # A fresh pass over the shuffled loader yields one random batch.
            _, (x, y) = next(enumerate(loader))
            _, loss = model(x, y)
            # .mean() collapses the per-replica losses the parallel wrapper may return.
            losses[i] = loss.mean().item()
        out[split] = losses.mean()
    generate_example(model, 500) # See how the model is doing
    model.train()
    return out

In [10]:
def generate_example(model, new_tokens):
    """Sample text from the model and print it.

    Generation starts from a single token with id 0 and appends `new_tokens`
    sampled tokens; the decoded result is printed.

    Sampling always runs in eval mode (dropout disabled) and without gradient
    tracking, whatever mode the caller left the model in; the previous
    train/eval mode is restored afterwards.

    Args:
        model: The transformer, either bare or wrapped in a parallel container
            that exposes the underlying network as `.module`.
        new_tokens: Number of tokens to sample.
    """
    # Unwrap nn.DataParallel-style containers; a bare model is used as-is.
    net = getattr(model, "module", model)
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            context = torch.zeros((1, 1), dtype=torch.long, device=device)
            sampled = net.generate(context, max_new_tokens=new_tokens)
    finally:
        # Put the network back in whatever mode the caller had it in.
        net.train(was_training)
    print(decode(sampled[0].tolist()))

# Classes

In [11]:
class SelfAttentionHead(nn.Module):
    """A single head of masked (causal) scaled dot-product self-attention.

    "An attention function can be described as mapping a query and a set of
    key-value pairs to an output, where the query, keys, values, and output are
    all vectors. The output is computed as a weighted sum of the values, where
    the weight assigned to each value is computed by a compatibility function
    of the query with the corresponding key."

    In self-attention the keys, queries and values are all projections of the
    same input (cross-attention would take keys/values from another source).
    Every token emits a query ("what am I looking for?") and a key ("what do I
    contain?"); the query-key dot product is the affinity between two tokens,
    so information is gathered from the past in a data-dependent way.
    """

    def __init__(self):
        super().__init__()
        # Bias-free projections from the model width down to the head size.
        self.key = nn.Linear(d_model, d_k, bias=False)
        self.query = nn.Linear(d_model, d_k, bias=False)
        self.value = nn.Linear(d_model, d_k, bias=False)
        # Lower-triangular matrix of ones: entry (i, j) is 1 when position i may attend to j.
        causal = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer('tril', causal)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, d_k):
        """Attend over `x` of shape (B, T, C); `d_k` is the head size used for scaling."""
        _, T, _ = x.shape # batch size, time dimension, num channels

        # "Scaled Dot-Product Attention": Attention(Q, K, V) = softmax(Q@K.T/sqrt(d_k))@V
        queries = self.query(x) # (B, T, d_k)
        keys = self.key(x) # (B, T, d_k)
        # (B, T, d_k) @ (B, d_k, T) -> (B, T, T)
        affinities = torch.matmul(queries, keys.transpose(-2, -1))

        # Dividing by sqrt(d_k) keeps the softmax diffuse; without it the
        # distribution collapses towards a one-hot vector, which is not useful.
        scores = affinities / (d_k**0.5)

        # "We also modify the self-attention sub-layer in the decoder stack to prevent positions
        # from attending to subsequent positions. This masking... ensures that the predictions
        # for position i can depend only on the known outputs at positions less than i"
        # Future positions get -inf so they receive zero weight after the softmax:
        # information only flows from earlier context to the current timestep.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float('-inf'))

        # Normalise, regularise, then take the weighted sum of the values.
        weights = self.dropout(F.softmax(scores, dim=-1))
        return weights @ self.value(x)
        

In [12]:
class MultiSelfAttentionHead(nn.Module):
    """Multi-head self-attention: `n_head` heads run side by side, then are mixed.

    MultiHead(Q,K,V) = Concat(head1,...,headh)W^O
    where headi = Attention(QW_i^Q, KW_i^K, VW_i^V)
    """

    def __init__(self):
        super().__init__()
        self.heads = nn.ModuleList([SelfAttentionHead() for _ in range(n_head)])
        # W^O: maps the concatenated head outputs back to the model width.
        self.proj = nn.Linear(d_k * n_head, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x): # input (B, T, C)
        per_head = []
        for head in self.heads:
            per_head.append(head(x=x, d_k=d_k))
        # Concatenate along the channel axis, project, then apply dropout.
        concatenated = torch.cat(per_head, dim=-1)
        projected = self.proj(concatenated)
        return self.dropout(projected)

In [13]:
class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer: FFN(x) = max(0, xW1 + b1)W2 + b2.

    "The dimensionality of input and output is d_model = 512, and the
    inner-layer has dimensionality d_ff = 2048." (Here the widths come from the
    module-level `d_model` and `d_ff`.)
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
            # "We apply dropout to the output of each sub-layer, before it is
            # added to the sub-layer input and normalized."
            nn.Dropout(p=dropout),
        ]
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        out = self.ff(x)
        return out

In [14]:
class TransformerBlock(nn.Module):
    """One decoder block: multi-head self-attention followed by a feed-forward net.

    "That is, the output of each sub-layer is LayerNorm(x + Sublayer(x))", so
    each sub-layer is wrapped in a residual (skip) connection and the layer norm
    comes after the addition, as in the paper's figure. Intuitively, attention
    is the "communication" step and the feed-forward net is the "thinking" on
    what was communicated; stacking blocks intersperses the two.
    """

    def __init__(self):
        super().__init__()
        self.msah = MultiSelfAttentionHead()
        self.ff = FeedForward()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x): # input (B, T, C)
        attended = self.msah(x)
        x = self.ln1(x + attended)
        transformed = self.ff(x)
        return self.ln2(x + transformed)
        

In [15]:
class Transformer(nn.Module):
    """Decoder-only, character-level transformer language model.

    Token ids are embedded, summed with learned position embeddings, passed
    through `n_layer` transformer blocks and a final layer norm, and projected
    to one logit per vocabulary entry.
    """

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, d_model) # (vocab_size, C)
        # From the paper, creating a position embedding table yields "nearly identical results" to using the sinusoidal positional encoding
        self.position_embedding_table = nn.Embedding(block_size, d_model) # (block_size, C)
        self.dropout = nn.Dropout(p=dropout)
        self.blocks = nn.Sequential(*[TransformerBlock() for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, idx, targets=None):
        """Compute next-token logits and, when targets are given, the loss.

        Args:
            idx: (B, T) tensor of token ids, with T at most `block_size`.
            targets: Optional (B, T) tensor of target token ids.

        Returns:
            (logits, loss). Without targets, logits has shape
            (B, T, vocab_size) and loss is None. With targets, logits is
            flattened to (B*T, vocab_size) and loss is the label-smoothed
            cross-entropy over all positions.
        """
        B, T = idx.shape

        # Convert tokens to embedding vectors, each with length d_model
        tok_emb = self.token_embedding_table(idx) # (B, T, C)
        # Add positional data via embeddings. The position indices are created
        # on the same device as the input, so the module does not depend on the
        # global `device` matching wherever this copy of the model lives.
        positions = torch.arange(T, device=idx.device)
        pos_emb = self.position_embedding_table(positions) # (T, C)

        x = tok_emb + pos_emb # (B, T, C)
        x = self.dropout(x) # "In addition, we apply dropout to the sums of the embeddings and the positional encodings in both the encoder and decoder stacks."
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x) # (B, T, vocab_size)

        if targets is None:
            return logits, None

        # cross_entropy expects (N, C) logits against (N,) class indices.
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets, label_smoothing=smoothing)
        return logits, loss

    @torch.no_grad() # sampling never needs gradients, so skip building autograd graphs
    def generate(self, idx, max_new_tokens):
        """Autoregressively extend `idx` by `max_new_tokens` sampled tokens.

        Args:
            idx: (B, T) tensor of token ids forming the current context.
            max_new_tokens: Number of tokens to append to every sequence.

        Returns:
            (B, T + max_new_tokens) tensor: the context followed by the samples.
        """
        for _ in range(max_new_tokens):
            # crop idx to the last block_size tokens (the position table has no more entries)
            idx_cond = idx[:, -block_size:]
            # get the predictions
            logits, _ = self(idx_cond)
            # focus only on the last time step
            logits = logits[:, -1, :] # becomes (B, C)
            # apply softmax to get probabilities
            probs = F.softmax(logits, dim=-1) # (B, C)
            # sample from the distribution
            idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)
            # append sampled index to the running sequence
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
        
        

In [16]:
# Train loop

model = Transformer()
# DataParallel splits every batch across the visible GPUs.
model = nn.DataParallel(model)
model.to(device)
# Adam settings from the paper (beta1 = 0.9, beta2 = 0.98, eps = 1e-9); the lr given here
# is only a placeholder because the schedule below sets it before every step.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-09)

start = time.time()
# The loop variable is named `step` rather than `iter`: at notebook scope `iter`
# would shadow the builtin for every cell executed afterwards.
for step in range(max_iters):
    # Vary the learning rate as in the paper - increase learning linearly for the first warmup_steps training steps,
    # and decreasing it thereafter proportionally to the inverse square root of the step number.
    step_num = step + 1
    lr = (d_model**-0.5)*min(step_num**-0.5, step_num*(warmup_steps**-1.5))
    for g in optimizer.param_groups:
        g['lr'] = lr

    # Periodically (and on the final step) report the estimated losses.
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        now = time.time()
        print(f"After {(now - start)} seconds: step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data: a fresh pass over the shuffled loader yields one random batch
    _, (xb, yb) = next(enumerate(train_dataloader))

    # evaluate the loss
    _, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    # .mean() collapses the per-replica losses the parallel wrapper may return into one scalar.
    loss.mean().backward()
    optimizer.step()

generate_example(model, 500)





R# .S?13k[EIH
nL;y2J;(RB;KQeNo))AuwClKJ),.6)Fb1z'T0-2(68vyv8:s(7Y#).l9o?7S
b
ZYQB
im
lT]wxqilOb872Nq fSFq)SpbPKYpTRavQYRy'TWEN:E-P#SDI:jI.S8gW9FEyxeS(Wg:SFSFITYx pRGDQb1U9,YV
(VH5VI:pxKT8,QMCJuljpH2SNt'[ChiLg?,(:!REI#M9mVsrAUmzdgpW-c8.2z'WS5p(.:9#tx#iCzC.8I?]'Qun73Rz0pSCRRIyLr
Ae[qj hJ1z4mAve7sLkAhyF1owAN8'J6FRCI90J'0:i4O.#KR?4rMl,YINKGrx(q
Sa]a0e0t4JIx)KI'IMmzBAEpF5rLgKy9RYIsP92yR]qBb
N
,RJz!q0ZJce['Fd3HWeG[sg;vpZS'b
#Y[IVemYtgAJ'J)O 2nt0D]ar]eS[RiA(Vt.AN[RIFhU#3yq[
viqk2j9k,N'icjbv[DRp!U1pJMC7
After 59.64755892753601 seconds: step 0: train loss 4.4978, val loss 4.5022

KNGH6AHA?
AHTHUNCR: URTHUR: Bly ba, or. ShZo..   Yole yowhing s pee !
ARTHEREm NCEN KNCEVIS: bRopp s2: d s o, ayo! Conery, yis...p  Bldem.. BGUerju's ahr l cicjStoresquacoue OTay Umain! 
ON WDERT BE:3:Dis bet'sm Hes, VERch, Sor9 oncou nohut!
cep!   Wh!
ARTECERT: INon!.
D:, utcioap [m,RENell, yof BERTI'thit c0jur ck CHSOTE: andat, bo, c] Zoondin.5: cahe . rg] w Sanyot!  Caghmbe Ife mst, yoURERve ad meat t, Founomverrd 