<a href="https://colab.research.google.com/github/howsam/Building-a-ChatGPT-like-Model-from-Scratch/blob/main/Build_GPT_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#  <font color='#FFE15D'><b>💎 Build GPT-2 </b></font>

# 🔴 **Environment Setup**

## 🟠 Change the font size of the output cells

In [5]:
print('Salam Howsam!')

Salam Howsam!


In [6]:
from IPython.display import HTML
shell = get_ipython()

def adjust_font_size():
  display(HTML('''<style>
    body {
      font-size: 24px;
    }
  '''))

if adjust_font_size not in shell.events.callbacks['pre_execute']:
  shell.events.register('pre_execute', adjust_font_size)

In [7]:
print('Salam Howsam!')

Salam Howsam!


# 🔴 **Import**

In [211]:
import time
from dataclasses import dataclass

from datasets import load_dataset
from tokenizers import Tokenizer

import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn import functional as F

# 🔴 **Utils**

In [8]:
def prepare_data(tokens, seq_len):
    # Trim tokens so that total length is divisible by seq_len
    n_tokens = (tokens.shape[0] // seq_len) * seq_len
    tokens = tokens[:n_tokens]

    # Reshape to 2D tensor
    return tokens.view(-1, seq_len)


In [9]:
def num_trainable_params(model):
  nums = sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6
  return nums

In [10]:
def calculate_time(model, x, num_runs=10):
    torch.cuda.synchronize()
    start = time.time()
    for _ in range(num_runs):
        model(*x)
    torch.cuda.synchronize()
    return (time.time() - start) / num_runs

# 🔴 **Init**

In [11]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

# 🔴 **Dataset**

In [12]:
dataset = load_dataset("roneneldan/TinyStories")
dataset

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

In [13]:
tokenizer = Tokenizer.from_file("bpe-tokenizer_tinystories.json")
tokenizer

<tokenizers.Tokenizer at 0x1497f363230>

In [14]:
# Load tokens from pytorch file
train_token_ids = torch.load('tokenized-train-samples_vocab-10k.pt')
valid_token_ids = torch.load('tokenized-valid-samples_vocab-10k.pt')

print("📊 Number of Tokens")
print(f"🔹 Train: {len(train_token_ids):,} tokens")
print(f"🔹 Valid: {len(valid_token_ids):,} tokens")

📊 Number of Tokens
🔹 Train: 464,965,814 tokens
🔹 Valid: 4,673,588 tokens


In [15]:
class TinyStoriesDataset(Dataset):

    def __init__(self, data, seq_len):
        self.seq_len = seq_len
        self.data = prepare_data(data, seq_len+1)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        return sample[:-1], sample[1:]

In [16]:
seq_len = 128

train_set = TinyStoriesDataset(train_token_ids, seq_len)
valid_set = TinyStoriesDataset(valid_token_ids, seq_len)

print(f"📊 Number of Samples")
print(f"🔹 Train: {len(train_set):,} samples")
print(f"🔹 Valid: {len(valid_set):,} samples")

📊 Number of Samples
🔹 Train: 3,604,386 samples
🔹 Valid: 36,229 samples


In [17]:
x, y = next(iter(train_set))

print(f"📊 Sample Shapes")
print(f"🔹 Input: {x.shape}")
print(f"🔹 Target: {y.shape}")

📊 Sample Shapes
🔹 Input: torch.Size([128])
🔹 Target: torch.Size([128])


In [18]:
torch.manual_seed(1337)
batch_size = 64

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, pin_memory=True)#, num_workers=4)
valid_loader = DataLoader(valid_set, batch_size=batch_size, shuffle=False, pin_memory=True)#, num_workers=4)

print(f"📊 Number of Batches")
print(f"🔹 Train: {len(train_loader):,} batches")
print(f"🔹 Valid: {len(valid_loader):,} batches")

📊 Number of Batches
🔹 Train: 56,319 batches
🔹 Valid: 567 batches


In [19]:
x_batch, y_batch = next(iter(train_loader))

print(f"📊 Batch Shapes")
print(f"🔹 Input: {x_batch.shape}")
print(f"🔹 Target: {y_batch.shape}")

📊 Batch Shapes
🔹 Input: torch.Size([64, 128])
🔹 Target: torch.Size([64, 128])


# 🔴 **Model**

## 🟠 Embedding

In [20]:
wte = nn.Embedding(tokenizer.get_vocab_size(), 100)
wte(torch.tensor([1, 2, 100])).shape

torch.Size([3, 100])

In [21]:
wpe = nn.Embedding(seq_len, 100)
wpe(torch.tensor([1, 2, 100])).shape

torch.Size([3, 100])

In [22]:
x = wte(x_batch) + wpe(torch.arange(x_batch.shape[1]))
x.shape

torch.Size([64, 128, 100])

## 🟠 Scaled Dot-Product Attention

In [124]:
q = k = v = x
print(q.shape)

mask = torch.tril(torch.ones(seq_len, seq_len))

scores = q @ k.transpose(-2, -1) / (k.shape[-1]**0.5)
scores.masked_fill_(mask ==0, float(-torch.inf))
scores = scores.softmax(dim=-1)
print(scores.shape)

z = scores @ v
z.shape

torch.Size([64, 128, 100])
torch.Size([64, 128, 128])


torch.Size([64, 128, 100])

In [123]:
# scores = torch.randn(3, 5, 5)
# mask = torch.tril(torch.ones(5, 5))
# scores.masked_fill_(mask ==0, float(-torch.inf))
# scores = scores.softmax(dim=-1)
# scores

tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.2138, 0.7862, 0.0000, 0.0000, 0.0000],
         [0.0648, 0.0727, 0.8625, 0.0000, 0.0000],
         [0.0503, 0.0670, 0.3075, 0.5752, 0.0000],
         [0.1718, 0.2235, 0.3567, 0.1625, 0.0854]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.3436, 0.6564, 0.0000, 0.0000, 0.0000],
         [0.7898, 0.0147, 0.1956, 0.0000, 0.0000],
         [0.1449, 0.3177, 0.2102, 0.3273, 0.0000],
         [0.5786, 0.0909, 0.1245, 0.1555, 0.0505]],

        [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.9238, 0.0762, 0.0000, 0.0000, 0.0000],
         [0.0626, 0.5604, 0.3770, 0.0000, 0.0000],
         [0.1473, 0.5825, 0.1588, 0.1114, 0.0000],
         [0.3903, 0.3668, 0.0837, 0.1365, 0.0227]]])

In [144]:
def scaled_dot_product_attention(q, k, v):
    mask = torch.tril(torch.ones(q.shape[-2], q.shape[-2])).to(device)
    scores = q @ k.transpose(-2, -1) / (k.shape[-1]**0.5)
    scores.masked_fill_(mask==0, float(-torch.inf))
    scores = scores.softmax(dim=-1)
    z = scores @ v
    return z

In [145]:
scaled_dot_product_attention(x.to(device), x.to(device), x.to(device)).shape

torch.Size([64, 128, 100])

In [146]:
q = torch.randn((128, 1024, 768), device=device)
k = torch.randn((128, 1024, 768), device=device)
v = torch.randn((128, 1024, 768), device=device)
q.shape

torch.Size([128, 1024, 768])

In [147]:
scaled_dot_product_attention(q, k, v).shape

torch.Size([128, 1024, 768])

In [148]:
calculate_time(scaled_dot_product_attention, (q, k, v), num_runs=20)

0.13596891164779662

In [149]:
F.scaled_dot_product_attention(q, k, v, is_causal=True).shape

  F.scaled_dot_product_attention(q, k, v, is_causal=True).shape


torch.Size([128, 1024, 768])

In [150]:
torch.abs(scaled_dot_product_attention(q, k, v) - F.scaled_dot_product_attention(q, k, v, is_causal=True)).max()

tensor(5.8189e-06, device='cuda:0')

In [154]:
calculate_time(F.scaled_dot_product_attention, (q, k, v), num_runs=20)

0.13060171604156495

## 🟠 Multi Head Attention

In [48]:
# class MultiHeadAttention(nn.Module):

#     def __init__(self):
#         super().__init__()
#         self.fc1 = nn.Linear(100, 1000)
#         self.fc2 = nn.Linear(1000, 100)
#         self.fc3 = nn.Linear(1000, 100)

#     def forward(self, x):
#         y = F.relu(self.fc1(x))
#         y1 = self.fc2(y)
#         y2 = self.fc3(y)
#         return F.relu(torch.concat([y1, y2], dim=-1))

In [49]:
# mha = MultiHeadAttention()
# num_trainable_params(mha)
# mha.forward(torch.rand(10, 100)).shape

torch.Size([10, 200])

In [67]:
class GPTConfig:
    n_embd: int = 100
    n_head: int = 5

config = GPTConfig()
config.n_embd

100

In [197]:
class MultiHeadAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_size = self.n_embd // self.n_head

        self.qkv_proj = nn.Linear(self.n_embd, 3*self.n_embd, bias=False)

        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)

    def forward(self, x):
        B, T, C = x.shape
        q, k, v = self.qkv_proj(x).view(B, T, 3*self.n_head, self.head_size).transpose(1, 2).chunk(3, dim=-3)

        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)
        return y

In [106]:
mha = MultiHeadAttention(config)
mha(x).shape

torch.Size([64, 128, 100])

In [112]:
xx = torch.arange(24).view(2, 2, 3, 2)
print(xx)
xx.reshape(2, 3, 4)

tensor([[[[ 0,  1],
          [ 2,  3],
          [ 4,  5]],

         [[ 6,  7],
          [ 8,  9],
          [10, 11]]],


        [[[12, 13],
          [14, 15],
          [16, 17]],

         [[18, 19],
          [20, 21],
          [22, 23]]]])


tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [113]:
calculate_time(mha.to(device), (x.to(device),), num_runs=20)

  y = F.scaled_dot_product_attention(q, k, v, is_causal=True)


0.010302698612213135

## 🟠 Feed Forward (MLP)

In [128]:
class GPTConfig:
    n_embd: int = 100
    n_head: int = 5
    f_expnd: float = 4

config = GPTConfig()
config.n_embd

100

In [198]:
class FeedForward(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.f_expnd = config.f_expnd

        self.up_proj = nn.Linear(self.n_embd, int(self.f_expnd*self.n_embd), bias=False)
        self.down_proj = nn.Linear(int(self.f_expnd*self.n_embd), self.n_embd, bias=False)

    def forward(self, x):
        return self.down_proj(F.gelu(self.up_proj(x)))

In [132]:
mlp = FeedForward(config)
mlp(x).shape

torch.Size([64, 128, 100])

In [134]:
num_trainable_params(mlp)*1000

80.5

In [135]:
calculate_time(mlp, (x, ), num_runs=20)

0.014536631107330323

## 🟠 Decoder Block

In [142]:
class DecoderBlock(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd

        self.ln1 = nn.LayerNorm(config.n_embd)
        self.mha = MultiHeadAttention(config)

        self.ln2 = nn.LayerNorm(config.n_embd)
        self.mlp = FeedForward(config)

    def forward(self, x):
        x = x + self.mha(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

In [143]:
decoder = DecoderBlock(config)
decoder(x).shape

torch.Size([64, 128, 100])

In [145]:
num_trainable_params(decoder) * 1e3

121.30000000000001

In [148]:
calculate_time(decoder, (x, ), num_runs=20) * 1e3

49.8690128326416

## 🟠 GPT

In [153]:
class GPTConfig:
    vocab_size: int = 10_000
    seq_len: int = 128
    n_layer: int = 12
    n_embd: int = 100
    n_head: int = 5
    f_expnd: float = 4


config = GPTConfig()
config.n_embd

100

In [220]:
class GPT(nn.Module):

    def __init__(self, config):
        super().__init__()
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.seq_len, config.n_embd)
        # self.decoders = nn.Sequential(*[DecoderBlock(config) for _ in range(config.n_layer)])
        self.decoders = nn.ModuleList([DecoderBlock(config) for _ in range(config.n_layer)])
        self.lnf = nn.LayerNorm(config.n_embd)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.lm_head.weight = self.wte.weight

    def forward(self, idx):
        B, T = idx.shape

        x = self.wte(idx) + self.wpe(torch.arange(T, device=device))

        # x = self.decoders(x)
        for decoder in self.decoders:
            x = decoder(x)

        x = self.lnf(x)
        logits = self.lm_head(x)
        return logits

In [200]:
model = GPT(config).to(device)
model(x_batch.to(device)).shape

torch.Size([64, 128, 10000])

In [201]:
num_trainable_params(model), num_trainable_params(model.decoders), num_trainable_params(model.lm_head)

(3.4578, 1.4448, 1.0)

In [203]:
calculate_time(model, (x_batch.to(device),), num_runs=100) * 1e3

22.965524196624756

# 🔴 **Config**

In [212]:
@dataclass
class GPTConfig:
    vocab_size: int = 50257 # number of tokens
    seq_len: int = 1024 # max sequence length
    n_layer: int = 12 # number of layers
    n_head: int = 12 # number of heads
    n_embd: int = 768 # embedding dimension
    f_expnd: int = 4 # expansion factor in mlp

In [213]:
model = GPT(GPTConfig()).to(device)
model

GPT(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (decoders): ModuleList(
    (0-11): 12 x DecoderBlock(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (qkv_proj): Linear(in_features=768, out_features=2304, bias=False)
        (c_proj): Linear(in_features=768, out_features=768, bias=False)
      )
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): FeedForward(
        (up_proj): Linear(in_features=768, out_features=3072, bias=False)
        (down_proj): Linear(in_features=3072, out_features=768, bias=False)
      )
    )
  )
  (lnf): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [214]:
num_trainable_params(model)

print(f"🔹 Trainable parameters: {num_trainable_params(model):,}")

🔹 Trainable parameters: 162.95424


In [217]:
config = GPTConfig(vocab_size=10_000, seq_len=128)
config

GPTConfig(vocab_size=10000, seq_len=128, n_layer=12, n_head=12, n_embd=768, f_expnd=4)

In [218]:
model = GPT(
    GPTConfig(
        seq_len=256, vocab_size=10_000, n_layer=4, n_embd=256, n_head=4
        )).to(device)

print(f"🔹 Trainable parameters: {num_trainable_params(model):,}")
print(model)

🔹 Trainable parameters: 8.335872
GPT(
  (wte): Embedding(10000, 256)
  (wpe): Embedding(256, 256)
  (decoders): ModuleList(
    (0-3): 4 x DecoderBlock(
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mha): MultiHeadAttention(
        (qkv_proj): Linear(in_features=256, out_features=768, bias=False)
        (c_proj): Linear(in_features=256, out_features=256, bias=False)
      )
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): FeedForward(
        (up_proj): Linear(in_features=256, out_features=1024, bias=False)
        (down_proj): Linear(in_features=1024, out_features=256, bias=False)
      )
    )
  )
  (lnf): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (lm_head): Linear(in_features=256, out_features=10000, bias=False)
)


In [219]:
# Embedding layers
wte_params = num_trainable_params(model.wte)  # Word Token Embedding
wpe_params = num_trainable_params(model.wpe)  # Position Embedding

# Optional: Classifier head
lm_head_params = num_trainable_params(model.lm_head)

# Total trainable parameters without weight tying
total_params_without_tying = num_trainable_params(model)

# Total trainable parameters with weight tying
total_params_with_tying = total_params_without_tying - wte_params

# Core model params (excluding embeddings and head)
core_params = total_params_without_tying - wte_params - wpe_params - lm_head_params

# Print results
print(f"🔹 Total trainable parameters without weight tying: {total_params_without_tying:,}")
print(f"🔹 Total trainable parameters with weight tying: {total_params_with_tying:,}")
print(f"🔹 Embedding parameters (WTE + WPE): {wte_params + wpe_params:,}")
print(f"🔹  └─ WTE: {wte_params:,}")
print(f"🔹  └─ WPE: {wpe_params:,}")
print(f"🔹 Classifier head parameters: {lm_head_params:,}")
print(f"🔹 Transformer core (excluding embeddings & head): {core_params:,}")

🔹 Total trainable parameters without weight tying: 8.335872
🔹 Total trainable parameters with weight tying: 5.775872
🔹 Embedding parameters (WTE + WPE): 2.625536
🔹  └─ WTE: 2.56
🔹  └─ WPE: 0.065536
🔹 Classifier head parameters: 2.56
🔹 Transformer core (excluding embeddings & head): 3.150336


# 🔴 **Usage**

## 🟠 Train & Evaluate

## 🟠 Generate