<a href="https://colab.research.google.com/github/Ankur-singh/UnderstandingLLMs/blob/main/nbs/LLM_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To create the most basic GPT-style model, we will need the following:
 - Model Architecture
 - Data for Training
 - Training Loop
 - Inference (generate next token)

 We will build each component one by one in the simplest way possible. The goal is to make sure I understand each component and how they all fit together. In future notebooks, we will go deeper and focus on improving each of these components.

In [8]:
# prettytable is imported by the next cell but is not a dependency of torch,
# datasets or tiktoken, so install it explicitly.
!pip install -Uq torch datasets tiktoken prettytable

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m821.2/821.2 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m393.1/393.1 MB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.9/8.9 MB[0m [31m101.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m75.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.7/897.7 kB[0m [31m39.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m571.0/571.0 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.2/200.2 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Misc
import math
import tiktoken
from tqdm.notebook import tqdm
from datasets import load_dataset
from dataclasses import dataclass
from prettytable import PrettyTable

## Model

We will start by defining the model architecture and then try to generate some text to make sure everything is working as expected.

In [2]:
class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention built on ``nn.MultiheadAttention``.

    A boolean, strictly upper-triangular ``context x context`` mask is stored
    as a buffer. ``True`` entries mark the (future) key positions a query is
    not allowed to attend to.

    Args:
        emb_dim: Embedding width of the input and output.
        heads: Number of attention heads; must divide ``emb_dim``.
        context: Longest sequence length the stored mask can cover.
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        assert emb_dim % heads == 0, "`emb_dim` should be a multiple of `heads`"
        self.context = context
        self.mha = nn.MultiheadAttention(emb_dim, heads, batch_first=True)
        self.register_buffer("mask", torch.triu(torch.ones(context, context), diagonal=1).bool())

    def forward(self, x):
        """Apply masked self-attention to ``x``.

        Args:
            x: Tensor of shape ``(batch, seq_len, emb_dim)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: If ``seq_len`` is larger than ``context``.
        """
        _, seq_len, _ = x.shape
        # Clamping seq_len would only shrink the mask, not the input, and the
        # mismatch would surface later as an opaque shape error. Fail early.
        if seq_len > self.context:
            raise ValueError(
                f"sequence length {seq_len} exceeds the attention context of {self.context}"
            )
        attn_mask = self.mask[:seq_len, :seq_len]
        # Self-attention: the same tensor serves as query, key and value.
        return self.mha(x, x, x, attn_mask=attn_mask, need_weights=False)[0]

class Block(nn.Module):
    """Pre-norm transformer block.

    Each of the two sub-layers (self-attention, then a GELU MLP that expands
    to ``4 * emb_dim``) is applied to a layer-normalised copy of the input
    and added back through a residual connection.

    Args:
        emb_dim: Embedding width.
        heads: Number of attention heads.
        context: Maximum sequence length handled by the attention layer.
    """

    def __init__(self, emb_dim, heads, context):
        super().__init__()
        hidden_dim = 4 * emb_dim
        self.mha = MultiheadAttention(emb_dim, heads, context)
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, emb_dim),
        )
        self.sa_norm = nn.LayerNorm(emb_dim)
        self.mlp_norm = nn.LayerNorm(emb_dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attn_out = self.mha(self.sa_norm(x))
        x = x + attn_out
        mlp_out = self.mlp(self.mlp_norm(x))
        return x + mlp_out

class GPT(nn.Module):
    """Decoder-only language model.

    Token and learned positional embeddings are summed, passed through
    ``config.layers`` transformer blocks, layer-normalised, and projected to
    vocabulary logits by a bias-free linear layer.

    Args:
        config: Object exposing ``vocab``, ``emb_dim``, ``heads``, ``layers``
            and ``context`` attributes.
    """

    def __init__(self, config):
        super().__init__()
        self.pos_embedding = nn.Embedding(config.context, config.emb_dim)
        self.tok_embedding = nn.Embedding(config.vocab, config.emb_dim)
        self.decoder = nn.Sequential(*[Block(config.emb_dim, config.heads, config.context)
                                        for _ in range(config.layers)])
        self.output = nn.Linear(config.emb_dim, config.vocab, bias=False)
        self.norm = nn.LayerNorm(config.emb_dim)

    def forward(self, x):
        """Compute next-token logits.

        Args:
            x: Integer tensor of token ids with shape ``(batch, seq_len)``.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab)``.

        Raises:
            ValueError: If ``seq_len`` exceeds the number of learned positions.
        """
        _, seq_len = x.shape
        max_len = self.pos_embedding.num_embeddings
        # Without this check an over-long input indexes past the positional
        # table, which surfaces as an opaque out-of-range embedding error.
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the model context of {max_len}"
            )
        pos = torch.arange(seq_len, device=x.device)
        # The (seq_len, emb_dim) positional term broadcasts over the batch.
        x = self.tok_embedding(x) + self.pos_embedding(pos)
        x = self.decoder(x)
        return self.output(self.norm(x))

In [3]:
@dataclass
class ModelConfig:
    """Model hyperparameters; the defaults follow the GPT2 architecture."""

    # 50,257 rounded up to the next multiple of 64 (ceiling division) -> 50,304
    vocab: int = -(-50_257 // 64) * 64
    # Embedding width
    emb_dim: int = 768
    # Attention heads per block
    heads: int = 12
    # Number of transformer blocks
    layers: int = 12
    # Maximum sequence length
    context: int = 1024

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass a config *instance*: handing over the class itself only works while
# every field has a class-level default.
model = GPT(ModelConfig())

In [4]:
# Utility: report how many trainable parameters a model has
def count_parameters(model, verbose=False):
    """Print the number of trainable parameters of ``model`` in millions.

    Args:
        model: Module to inspect; only parameters with ``requires_grad`` set
            are counted.
        verbose: When true, first print a table with one row per parameter.
    """
    sizes = [
        (name, param.numel())
        for name, param in model.named_parameters()
        if param.requires_grad
    ]
    if verbose:
        table = PrettyTable(["Module", "Parameters"])
        for name, size in sizes:
            table.add_row([name, size])
        print(table)
    n_trainable = sum(size for _, size in sizes)
    print(f"Total Trainable Params: {n_trainable / 1e6:.2f} M")

# Non-verbose mode: print only the total trainable-parameter count.
count_parameters(model)

Total Trainable Params: 163.11 M


Based on my calculations, this looks good.

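> **Note:** This is not exactly the same as GPT2 (124M). That is intentional. I tried to incorporate some of the architectural advancements made since the GPT2 model. Most of the extra parameters come from the output projection, which has its own weights (50,304 × 768 ≈ 38.6M) instead of sharing them with the token embedding.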

# Inference / Generation

In [5]:
# GPT-2 encoding from tiktoken; used to tokenize both the prompts and the training text.
tokenizer = tiktoken.get_encoding("gpt2")

In [6]:
def generate(prefix, max_new_tokens=10):
    """Greedily extend ``prefix`` with ``max_new_tokens`` tokens.

    Relies on the notebook-level ``model`` and ``tokenizer``. Decoding is
    deterministic: the highest-scoring token is chosen at every step.

    Args:
        prefix: Prompt text; it must encode to at least one token.
        max_new_tokens: Number of tokens to append.

    Returns:
        ``prefix`` followed by the decoded continuation.
    """
    # Build the prompt on whatever device the model currently lives on, so
    # generation keeps working after the model has been moved to the GPU.
    model_device = next(model.parameters()).device
    max_context = model.pos_embedding.num_embeddings
    token_ids = torch.tensor(
        tokenizer.encode(prefix), dtype=torch.long, device=model_device
    ).unsqueeze(0)

    new_ids = []
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Feed at most the last `max_context` tokens: the model has no
            # positional embedding beyond that length.
            logits = model(token_ids[:, -max_context:])
            next_idx = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
            new_ids.append(next_idx.item())
            token_ids = torch.cat((token_ids, next_idx), dim=1)
    # Decode the continuation in one call: a character whose bytes span
    # several tokens would be mangled if each token were decoded separately.
    return prefix + tokenizer.decode(new_ids)

# Smoke test with the still-untrained model: checks that the pipeline runs end to end.
prefix = "Once upon a time"
print(generate(prefix))

Once upon a time el chargingeligible Investigators tandem Eff solvingMediclinkedgie


The generated text is all gibberish, as the model is not trained yet.

> Note: We will keep getting the same output if we run the above cell multiple times, as there is no randomness in the sampling process. We initialize the model with random weights and that is it. So, we must reinitialize our model to get a different output.

# Data

In [7]:
# Load the "stas/openwebtext-10k" dataset with the Hugging Face `datasets` library.
dataset = load_dataset("stas/openwebtext-10k")
# Hold out 0.5% of the rows as a validation ("test") split; the fixed seed keeps the split reproducible.
dataset = dataset["train"].train_test_split(test_size=0.005, seed=47, shuffle=True)
dataset

Access to the secret `HF_TOKEN` has not been granted on this notebook.
You will not be requested again.
Please restart the session if you want to be prompted again.


README.md:   0%|          | 0.00/951 [00:00<?, ?B/s]

openwebtext-10k.py: 0.00B [00:00, ?B/s]

0000.parquet:   0%|          | 0.00/30.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 9950
    })
    test: Dataset({
        features: ['text'],
        num_rows: 50
    })
})

In [8]:
# Concatenate every document of a split into one string, separated by a blank line.
val_ds = "\n\n".join(dataset["test"]["text"])
train_ds = "\n\n".join(dataset["train"]["text"])

# Tokenize each split in a single call, giving one long sequence of token ids per split.
val_tokens = tokenizer.encode(val_ds)
train_tokens = tokenizer.encode(train_ds)
len(val_tokens), len(train_tokens)

(52529, 11210521)

In [9]:
class OpenWebTextDataset(Dataset):
    """Serve one long token sequence as fixed-length next-token samples.

    Sample ``i`` uses tokens ``[i * max_len, (i + 1) * max_len)`` as input and
    the same window shifted right by one token as target. A window that runs
    past the end of the sequence is right-padded to ``max_len``.

    Args:
        tokens: Sequence of integer token ids.
        max_len: Number of tokens per sample.
        pad_token: Id used for padding. Defaults to the end-of-text token of
            the notebook-level ``tokenizer``.
    """

    def __init__(self, tokens, max_len, pad_token=None):
        self.tokens = tokens
        self.max_len = max_len
        self.pad_token = tokenizer.eot_token if pad_token is None else pad_token

    def __getitem__(self, idx):
        """Return the ``(input, target)`` tensor pair for sample ``idx``.

        Negative indices count from the end, as with a list.

        Raises:
            IndexError: If ``idx`` is out of range. Besides guarding against
                all-padding samples, this is what lets plain iteration over
                the dataset terminate.
        """
        n_samples = len(self)
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            raise IndexError(f"sample index out of range for dataset of length {n_samples}")
        start = idx * self.max_len
        x = self._window(start)
        y = self._window(start + 1)
        return (torch.tensor(x), torch.tensor(y))

    def _window(self, start):
        """Return ``max_len`` token ids beginning at ``start``, right-padded."""
        chunk = list(self.tokens[start: start + self.max_len])
        chunk.extend([self.pad_token] * (self.max_len - len(chunk)))
        return chunk

    def __len__(self):
        # The final, possibly partial, window counts as a sample.
        return math.ceil(len(self.tokens) / self.max_len)

# Cut both token sequences into windows of the model's context length.
# NOTE: this rebinds `val_ds` / `train_ds`, which held the joined raw text until now.
val_ds = OpenWebTextDataset(val_tokens, ModelConfig.context)
train_ds = OpenWebTextDataset(train_tokens, ModelConfig.context)
len(val_ds), len(train_ds)

(52, 10948)

In [10]:
batch_size = 6
# Only the training data is shuffled. `drop_last=True` discards a final
# incomplete batch, so every batch has exactly `batch_size` rows.
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, drop_last=True)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=True)

In [11]:
# Peek at one validation batch: inputs first, then targets (the inputs shifted by one token).
next(iter(val_dl))

[tensor([[ 4723,   198,   198,  ...,   198,   198,   464],
         [  749,  2408,  4876,  ...,   198,   198,     6],
         [ 1135,  2251,  4113,  ...,   319,    12, 28550],
         [26890,     8,  4433,  ...,   286,  1578,  1829],
         [ 2422, 12333,    11,  ...,  7417,    11,   318],
         [ 1327, 12070,   284,  ...,   287,  2253,    13]]),
 tensor([[  198,   198, 11708,  ...,   198,   464,   749],
         [ 2408,  4876,   532,  ...,   198,     6,  1135],
         [ 2251,  4113,  4032,  ...,    12, 28550, 26890],
         [    8,  4433,   281,  ...,  1578,  1829,  2422],
         [12333,    11,  1390,  ...,    11,   318,  1327],
         [12070,   284,  4727,  ...,  2253,    13,   383]])]

# Training Loop

In [12]:
@torch.no_grad()
def evaluate(model, dl):
    """Return the mean per-batch cross-entropy loss of ``model`` over ``dl``.

    The model is switched to eval mode for the duration of the call and put
    back into whichever mode it was in before, even if a batch raises.

    Args:
        model: Module whose output, once its first two dimensions are
            flattened, gives one row of class logits per target token.
        dl: Sized iterable of ``(x, y)`` tensor batches.

    Returns:
        The average of the per-batch losses as a float.

    Raises:
        ValueError: If ``dl`` contains no batches (for example when
            ``drop_last=True`` removes the only, incomplete, batch).
    """
    n_batches = len(dl)
    if n_batches == 0:
        raise ValueError("cannot evaluate on an empty dataloader")

    # Run on the device that holds the weights instead of a global setting.
    first_param = next(model.parameters(), None)
    target_device = None if first_param is None else first_param.device

    was_training = model.training
    model.eval()
    total = 0.0
    try:
        for x, y in dl:
            if target_device is not None:
                x, y = x.to(target_device), y.to(target_device)
            logits = model(x)
            total += F.cross_entropy(logits.flatten(0, 1), y.flatten()).item()
    finally:
        # Restore the caller's mode rather than forcing training mode.
        model.train(was_training)
    return total / n_batches

In [13]:
# Move the weights to the selected device, then wrap the model with torch.compile.
model.to(device)
model = torch.compile(model)
# Validation loss of the untrained model, as a baseline.
evaluate(model, val_dl)

W0706 22:45:56.928000 4164 torch/_inductor/utils.py:1250] [0/0] Not enough SMs to use max_autotune_gemm mode


10.990858793258667

This looks right: initially the probability is spread roughly evenly, i.e. each token has about the same probability. As a result, we can calculate the expected value of the loss: `-ln(1/50304) ~= 10.826`

In [14]:
# AdamW over all model parameters with a constant learning rate of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [None]:
# Validate and checkpoint every `log_freq` training steps.
log_freq = 10
# Validation-loss history, one entry per evaluation.
losses = []

model.train()
for i, (x, y) in enumerate(pbar := tqdm(train_dl, desc="Training")):
    if i % log_freq == 0:
        val_loss = evaluate(model, val_dl)
        losses.append(val_loss)
        pbar.set_postfix_str(f"Val Loss: {val_loss:.3f}")
        # torch.compile wraps the network, and the wrapper's state-dict keys
        # carry an "_orig_mod." prefix. Save the wrapped module's weights so
        # the checkpoint loads straight into a plain, uncompiled model.
        torch.save(getattr(model, "_orig_mod", model).state_dict(), "model.pth")

    x, y = x.to(device), y.to(device)
    logits = model(x)
    # Flatten (batch, seq) so every position is one classification example.
    loss = F.cross_entropy(logits.flatten(0, 1), y.flatten())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Let's try generating some sample text...

In [None]:
# Same prompt as the pre-training smoke test, for a before/after comparison.
print(generate("Once upon a time"))

In [None]:
# Second sample prompt.
print(generate("Internet is an"))

In [None]:
# Third sample prompt.
print(generate("AI will"))

In [None]:
# Fourth sample prompt.
print(generate("The meaning of life is"))