# GPT from scratch (using PyTorch)


In [99]:
# 🧐 GPT from Scratch (Decoder-only Transformer + Training + Decoding)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt

# ========== Vocabulary and Tokenization ==========

PAD_TOKEN = 0
SOS_TOKEN = 1
EOS_TOKEN = 2
SPECIAL_TOKENS = ["<pad>", "<sos>", "<eos>"]

class Vocab:
    def __init__(self):
        self.token2id = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
        self.id2token = {i: tok for tok, i in self.token2id.items()}
        self.next_id = len(self.token2id)

    def add_sentence(self, sentence):
        for word in sentence.split():
            if word not in self.token2id:
                self.token2id[word] = self.next_id
                self.id2token[self.next_id] = word
                self.next_id += 1

    def encode(self, sentence):
        tokens = [self.token2id.get(w, PAD_TOKEN) for w in sentence.split()]
        return [SOS_TOKEN] + tokens + [EOS_TOKEN]

    def decode(self, ids):
        return " ".join([self.id2token.get(i, "<unk>") for i in ids])

    def __len__(self):
        return self.next_id

# ========== Toy Dataset (English Forward Text) ==========

def generate_toy_text_data(n=1000):
    words = ["i", "like", "to", "eat", "code", "every", "day", "hello", "world", "test"]
    pairs = []
    for _ in range(n):
        length = random.randint(4, 7)
        sentence = " ".join(random.choices(words, k=length))
        pairs.append(sentence)
    return pairs

class GPTDataset(Dataset):
    def __init__(self, texts, vocab):
        self.vocab = vocab
        self.data = [torch.tensor(vocab.encode(text)) for text in texts]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        seq = self.data[idx]
        return seq[:-1], seq[1:]  # input, target

def collate_batch(batch):
    xs, ys = zip(*batch)
    max_len = max(len(x) for x in xs)
    pad_xs = [F.pad(x, (0, max_len - len(x)), value=PAD_TOKEN) for x in xs]
    pad_ys = [F.pad(y, (0, max_len - len(y)), value=PAD_TOKEN) for y in ys]
    return torch.stack(pad_xs), torch.stack(pad_ys)

# ========== Positional Encoding ==========

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# ========== Multi-Head Attention ==========

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, Q, K, V, mask=None):
        B = Q.size(0)
        def split(x):
            return x.view(B, -1, self.num_heads, self.d_k).transpose(1, 2)
        Q, K, V = split(self.W_q(Q)), split(self.W_k(K)), split(self.W_v(V))
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = torch.softmax(scores, dim=-1)
        x = torch.matmul(attn, V)
        x = x.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.d_k)
        return self.W_o(x)

# ========== Feed Forward ==========

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)

# ========== Decoder Block ==========

class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = FeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask):
        x = self.norm1(x + self.attn(x, x, x, mask))
        x = self.norm2(x + self.ffn(x))
        return x

# ========== GPT Model ==========

class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len)
        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, d_ff) for _ in range(num_layers)
        ])
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        mask = create_causal_mask(x.size(1)).to(x.device)
        x = self.pos_enc(self.embed(x))
        for block in self.blocks:
            x = block(x, mask)
        return self.fc(x)

# ========== Masks ==========

def create_causal_mask(size):
    return torch.tril(torch.ones(size, size)).bool().unsqueeze(0).unsqueeze(0)

# ========== Training + Generation ==========

def train_epoch(model, loader, loss_fn, optimizer, device):
    model.train()
    total_loss = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        out = model(x)
        loss = loss_fn(out.view(-1, out.size(-1)), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def generate_text(model, vocab, prompt, max_len=20):
    model.eval()
    tokens = vocab.encode(prompt)
    input_ids = torch.tensor([tokens], dtype=torch.long).to(next(model.parameters()).device)
    for _ in range(max_len):
        logits = model(input_ids)
        next_id = torch.argmax(logits[:, -1, :], dim=-1).item()
        input_ids = torch.cat([input_ids, torch.tensor([[next_id]], device=input_ids.device)], dim=1)
        if next_id == EOS_TOKEN:
            break
    return vocab.decode(input_ids[0].tolist()[1:])


In [100]:

# ========== Run Training ==========

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
texts = generate_toy_text_data(1000)
vocab = Vocab()
for s in texts: vocab.add_sentence(s)
dataset = GPTDataset(texts, vocab)
loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_batch)

model = GPT(vocab_size=len(vocab), d_model=128, num_heads=4, d_ff=256, num_layers=2, max_len=64).to(device)
loss_fn = nn.CrossEntropyLoss(ignore_index=PAD_TOKEN)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)

for epoch in range(1, 6):
    loss = train_epoch(model, loader, loss_fn, optimizer, device)
    print(f"Epoch {epoch} | Loss: {loss:.4f}")
    print("Sample:", generate_text(model, vocab, "i like", max_len=10))

Epoch 1 | Loss: 2.3763
Sample: i like <eos> test <eos>
Epoch 2 | Loss: 2.2090
Sample: i like <eos> every <eos>
Epoch 3 | Loss: 2.1997
Sample: i like <eos> every <eos>
Epoch 4 | Loss: 2.1926
Sample: i like <eos> day <eos>
Epoch 5 | Loss: 2.1829
Sample: i like <eos> every every <eos>
