# Learn GPT from scratch

In [36]:
import os

# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
if not os.path.isfile("./datasets/corpora/shakespeare.txt"):
    !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt > datasets/corpora/shakespeare.txt

In [10]:
with open("datasets/corpora/shakespeare.txt", 'r', encoding='utf-8') as f:
    text = f.read()

## Tokenization and dataset creation

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
import math

torch.manual_seed(1337)

<torch._C.Generator at 0x7f7b543cb430>

In [12]:
# Simple dumb ASCII character-level "encoding" since all training data is ASCII
def encode_text(text):
    return([ord(t) for t in text])

def decode_text(indices):
    return([chr(x) for x in indices])


In [13]:
# Tensorify data, put it in dataset
data = torch.tensor(encode_text(text), dtype=torch.int32)

split_idx = int(0.9 * len(data))
train_data = data[:split_idx]
test_data = data[split_idx:]

We have to make a custom PyTorch dataset class to automatically generate the "context" windows at load time. This allows us to avoid keeping these windows around in memory when not in use:

In [31]:
class TextDataset(Dataset):
    def __init__(self, data_tensor, context_size):
        self.data_tensor = data_tensor
        self.context_size = context_size
    
    def __len__(self):
        return len(self.data_tensor)

    def __getitem__(self, index):
        if index < self.context_size:
            x = F.pad(self.data_tensor[:index], (self.context_size - index, 0), value=0)
        else:
            x = self.data_tensor[index - self.context_size:index]
        
        y = self.data_tensor[index]
        return x, y

NOTE 2023-03-25: I think this is bugged, and that's the reason the training loss is so damn high. Testing:

In [34]:
train_dataset = TextDataset(train_data, 8)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=False)

step = 0
for x, y in train_dataloader:
    print(f"Step {step}:")
    for b in x.tolist():
        print(b)
        print("---")

    print(decode_text(y.tolist()))
    step += 1
    if step > 5:
        break



Step 0:
[0, 0, 0, 0, 0, 0, 0, 0]
---
[0, 0, 0, 0, 0, 0, 0, 70]
---
['F', 'i']
Step 1:
[0, 0, 0, 0, 0, 0, 70, 105]
---
[0, 0, 0, 0, 0, 70, 105, 114]
---
['r', 's']
Step 2:
[0, 0, 0, 0, 70, 105, 114, 115]
---
[0, 0, 0, 70, 105, 114, 115, 116]
---
['t', ' ']
Step 3:
[0, 0, 70, 105, 114, 115, 116, 32]
---
[0, 70, 105, 114, 115, 116, 32, 67]
---
['C', 'i']
Step 4:
[70, 105, 114, 115, 116, 32, 67, 105]
---
[105, 114, 115, 116, 32, 67, 105, 116]
---
['t', 'i']
Step 5:
[114, 115, 116, 32, 67, 105, 116, 105]
---
[115, 116, 32, 67, 105, 116, 105, 122]
---
['z', 'e']


## Attention is all you need (注目こそが必要なすべて)

In [8]:
class MultiheadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=True, device=None, dtype=None):
        super(MultiheadAttention, self).__init__()

        # Save variables
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.d_k = embed_dim // num_heads

        self.Q = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.K = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.V = nn.Linear(embed_dim, embed_dim, bias=bias)

        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        nn.init.kaiming_normal_(self.out_proj.weight, mode='fan_in', nonlinearity='linear')

    def forward(self, query, key, value, key_padding_mask=None):
        batch_size = query.size(0)

        # Apply linear layers
        q = self.Q(query) # [B, C, E]
        k = self.K(key) # [B, C, E]
        v = self.V(value) # [B, C, E]

        # Mutate dimensions so the attention matmul can get rid of the inner d_k
        q = q.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]
        k = k.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]
        v = v.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)  # [batch_size, num_heads, C, d_k]
        
        # Get raw attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) # [B, num_heads, C, C]

        # Apply mask, if necessary
        if key_padding_mask is not None:
            """
            MAY BE WORTH DEBUGGING

            if key_padding_mask.dim() == 3:
                # If the mask is 3D, add an extra dimension for the num_heads
                key_padding_mask = key_padding_mask.unsqueeze(1)  # [batch_size, 1, seq_len, seq_len]
            else:
                # If the mask is 2D, add dimensions for the num_heads and the 'query' sequence length
                key_padding_mask = key_padding_mask.unsqueeze(1).unsqueeze(2)  # [batch_size, 1, 1, seq_len]
            """
            # Apply the mask to attention scores
            scores = scores.masked_fill(key_padding_mask, float('-inf'))

        # Scale by sqrt(k)
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        out = attn @ v # [B, num_heads, C, d_k]

        # Concat and project
        # Swap C and num_heads, force memory to coalesce, then fuse back num_heads and d_k together
        out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.embed_dim)
        # Project: give attention "time to think". Maybe this should be part of a different module but whatever
        out = self.out_proj(out)
        return(out)



In [9]:
class FeedForward(nn.Module):
    def __init__(self, embed_dim, dropout):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim),
            nn.ReLU(),
            nn.Dropout(dropout)
            nn.Linear(4 * embed_dim, embed_dim),
        )

    def forward(self, x):
        return(self.net(x))

In [10]:
class Block(nn.Module):
    """Self-attention"""
    def __init__(self, embed_dim, num_heads, mask, dropout=0.2):
        super(Block, self).__init__()  
        self.register_buffer("mask", mask)
        self.head = MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, dropout=dropout)
        self.ffwd = FeedForward(embed_dim=embed_dim, dropout=dropout)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        # Residual connections
        x = self.ln1(x)
        x = x + self.head.forward(x, x, x, key_padding_mask=self.mask) 
        out = x + self.ffwd(self.ln2(x))
        return out


In [11]:
class GPT(nn.Module):
    def __init__(self, embedding_dim, vocab_size, context_size, lr=1e-3):
        # Inherit PyTorch stuff
        super(GPT, self).__init__()

        # Save variables for later
        self.embedding_dim = embedding_dim
        self.output_dim = vocab_size
        self.context_size = context_size

        # Initialize layers. Sadly this breaks the whole "self.layers: concept but whatever
        self.tok_embed = nn.Embedding(vocab_size, embedding_dim)
        self.pos_embed = nn.Embedding(context_size, embedding_dim)

        NUM_HEADS=6
        NUM_LAYERS=6
        
        mask = torch.tril(torch.ones(self.context_size, self.context_size)).bool()
        mask = ~mask
        self.register_buffer(mask)

        self.blocks = nn.Sequential(
            *[Block(embed_dim=embedding_dim, num_heads=NUM_HEADS, mask=mask) for _ in range(NUM_LAYERS)],
            nn.Dropout(0.2)
        )

        # Final feed-forward layer from embeddings
        self.ffwd = nn.Linear(embedding_dim, out_features=vocab_size)

    def forward(self, x):
        tok_embed = self.tok_embed(x)
        tok_embed = tok_embed.view(-1, self.context_size, self.embedding_dim)
        pos_embed = self.pos_embed(torch.arange(0, self.context_size, device="cuda")).unsqueeze(0)
        x = tok_embed + pos_embed

        # The actual attention is all you need here!
        # B*C*C cutting out the future
        x = self.blocks(x)

        preds = self.ffwd(x)
        return(preds)
    
    def infer(self, x):
        with torch.no_grad():
            res = self.forward(x)
            return(res)


## Training

In [19]:
def compute_loss(model, criterion, x, y):
    logits = model(x)
    last_logits = logits[:, -1, :]
    log_probs = nn.LogSoftmax(dim=1)(last_logits)
    loss = criterion(log_probs, y.view(-1).long())
    return loss

In [47]:
EMBEDDING_NDIM = 384
VOCAB_SIZE = 128
BATCH_SIZE=64
# "Context window"
BLOCK_SIZE=256
LR=1e-3

train_dataset = TextDataset(train_data, BLOCK_SIZE)
test_dataset = TextDataset(train_data, BLOCK_SIZE)

# Janky training code
model = GPT(
    embedding_dim=EMBEDDING_NDIM, 
    vocab_size=VOCAB_SIZE,
    context_size=BLOCK_SIZE,
    lr=LR
    )

model = model.to('cuda')
optimizer = optim.AdamW(model.parameters(), lr=LR)
# TODO Fix this!
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10000, gamma=0.2)
criterion = nn.NLLLoss()

In [50]:
from torch.utils.
EPOCHS = 1
STEPS = 5000
VAL_INTERVAL = 100

losses = []
model.train()

train_dataloader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    num_workers=4
)

test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4, shuffle=True)

step = 0
for epoch in range(EPOCHS):
    for data, target in train_dataloader:
        data = data.to('cuda')
        target = target.to('cuda')

        loss = compute_loss(model, criterion, data, target)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        losses.append(loss.cpu().detach().numpy())

        if step % VAL_INTERVAL == 0:
            with torch.no_grad():
                model.eval()
                for x, y in test_dataloader:
                    x = x.to("cuda")
                    y = y.to("cuda")

                    batch_loss = compute_loss(model, criterion, x, y)
                    total_loss += batch_loss.item() * 512
                    total_samples += 512
                    if total_samples > 10:
                        break

                average_loss = total_loss / total_samples
                print(f"Step {step}; loss: {average_loss}")
                model.train()

        step += 1
        if step >= STEPS:
            break


Step 0; loss: 3.3686537742614746
Step 100; loss: 3.3535483678181968
Step 200; loss: 3.3484479188919067
Step 300; loss: 3.344235420227051
Step 400; loss: 3.338580369949341
Step 500; loss: 3.330465725490025
Step 600; loss: 3.333183079957962
Step 700; loss: 3.3319032986958823
Step 800; loss: 3.332624101638794
Step 900; loss: 3.3325188810175117
Step 1000; loss: 3.331260542074839
Step 1100; loss: 3.3311657355381894


KeyboardInterrupt: 

In [15]:
PATH = "checkpoints/model.pt"

In [36]:

# Store
torch.save({
    'steps': step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, PATH)

In [18]:
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

Now we test for overfitting:

In [37]:
import gc
gc.collect()

2399

In [51]:
model.eval()
total_loss = 0.0
total_samples = 0

test_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4)
with torch.no_grad():
    for x, y in test_dataloader:
        x = x.to("cuda")
        y = y.to("cuda")

        batch_loss = compute_loss(model, criterion, x, y)
        total_loss += batch_loss.item() * x.size(0)
        total_samples += x.size(0)
        if total_samples > 100:
            break

    average_loss = total_loss / total_samples
    print(average_loss)

3.4188449382781982


Finally, we generate:

In [52]:
g_cuda = torch.Generator(device='cuda')

contexts = torch.tensor(encode_text("God"), dtype=torch.int32).to('cuda')
GEN_LENGTH=256

model.eval()
for i in range(GEN_LENGTH):
    transform = nn.LogSoftmax(1)
    # What happens if GEN_LENGTH > CONTEXT? don't worry about it
    #x = F.pad(contexts[:, -BLOCK_SIZE:], (0, BLOCK_SIZE - contexts.size(0)), "constant", 0)
    x = contexts[-BLOCK_SIZE:]
    x = F.pad(x, (0, BLOCK_SIZE - x.size(0)), "constant", 0).unsqueeze(0) # B*T
    preds = model.infer(x)
    preds = preds.squeeze(0)
    probs = torch.softmax(preds, dim=-1)

    # TODO: Broken because of bug with the trailing 0s. FIX THIS
    next_char = torch.multinomial(torch.exp(preds[(-1 if i >= BLOCK_SIZE else i), :]), num_samples=1, generator=g_cuda)
    #context = torch.cat(context, next_char)
    contexts = torch.cat((contexts, next_char), dim=0)
    print(decode_text(next_char.cpu().numpy())[-1], end="")

#print("".join(decode_text(contexts.cpu().numpy())))

,n  aon mr
nr
egtel  s.mangtVk h
 -hinSfii ol ihIraddeioi akpshaC.n trU d aamooaa eoeEhl:daoUabo'm-fddE auh hpyHs wv'erstiInnmwt hnAuNu ufl
I: rl.T   l!eool'lIhl:aynet nna:i yaneehtea hdel
  hse l;imi
  hgy f iuto eoh gBum.umhemvt
a hFo lNsute oaaenh;byeon