<a href="https://colab.research.google.com/github/Frederick-Stein/Data-Science-Playground/blob/main/GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import math

In [3]:
## encode data
def Encode(sentences: list):
    """Build a whitespace-token vocabulary and encode sentences as padded ids.

    Ids 0 and 1 are reserved for ``<PAD>`` and ``<EOS>``. Every other distinct
    token receives the next free id in order of first appearance, so the
    mapping is identical from run to run.

    Args:
        sentences: Strings that are tokenized with ``str.split()``.

    Returns:
        A tuple ``(vocab_size, padded_sentences, token_to_id, id_to_token)``:
        the number of vocabulary entries (reserved tokens included), a long
        tensor of shape ``(len(sentences), longest_sentence)`` right-padded
        with id 0, and the two lookup dictionaries.
    """
    token_to_id = {"<PAD>": 0, "<EOS>": 1}
    for sentence in sentences:
        for word in sentence.split():
            # setdefault leaves existing entries alone, so a literal "<EOS>" or
            # "<PAD>" in the text keeps its reserved id and ids stay contiguous.
            token_to_id.setdefault(word, len(token_to_id))

    id_to_token = {idx: token for token, idx in token_to_id.items()}
    vocab_size = len(token_to_id)

    def encode(sentence):
        return [token_to_id[word] for word in sentence.split()]

    # An explicit dtype keeps empty sentences integer-typed (torch.tensor([])
    # would otherwise default to float32).
    encoded_sentences = [
        torch.tensor(encode(sentence), dtype=torch.long) for sentence in sentences
    ]
    if not encoded_sentences:
        # pad_sequence rejects an empty list; return an empty batch instead.
        padded_sentences = torch.empty((0, 0), dtype=torch.long)
    else:
        padded_sentences = nn.utils.rnn.pad_sequence(
            encoded_sentences, batch_first=True, padding_value=0
        )
    return vocab_size, padded_sentences, token_to_id, id_to_token

In [43]:
class Embedding(nn.Module):

    def __init__(self, vocab_size: int, max_context_length: int, embedding_dim: int, dropout: float = 0.1, pad_idx: int | None = None):

        super().__init__()
        self.embedding_dim = embedding_dim
        self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embeddings = nn.Embedding(max_context_length, embedding_dim, padding_idx = pad_idx)
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Cache [0, 1, 2, ..., max_context_len-1] as a buffer
        self.register_buffer(
            "position_ids",
            torch.arange(max_context_length, dtype=torch.long).unsqueeze(0),  # (1, max_len)
            persistent=False
        )


    def forward(self, context: torch.Tensor):
        # context (B, L) B : batch size, L: context length
        assert context.dtype == torch.long
        B, L = context.shape
        word_embeddings = self.word_embeddings(context) * math.sqrt(self.embedding_dim)# (B, L, embedding_dim)
        positions = self.position_ids[:, :L] # (1, L)
        position_embeddings = self.position_embeddings(positions) # (1，L, embedding_dim)
        output = word_embeddings + position_embeddings
        return self.dropout(output)

In [44]:
## Smoke-test the embedding layer: output shape and finiteness
B, L, d = 2, 5, 32
vocab, max_len = 1000, 16
emb = Embedding(vocab, max_len, d, dropout=0.1, pad_idx=0)

# Draw ids from [1, vocab) so id 0 only appears where it is placed below.
x = torch.randint(low=1, high=vocab, size=(B, L))
x[0, L - 1] = 0  # last token of the first sequence becomes id 0 (the pad_idx)
y = emb(x)
assert tuple(y.shape) == (B, L, d)
assert bool(torch.isfinite(y).all())

In [37]:
class MultiHeadedSelfAttention(nn.Module):

    def __init__(self, embedding_dim: int, num_heads: int, dropout: float = 0.1, causal: bool = True):
        super().__init__()

        assert embedding_dim % num_heads == 0 # divisible
        self.head_dim = embedding_dim // num_heads

        # Bulid head
        self.att_heads = nn.ModuleList()
        for i in range(num_heads):
            self.att_heads.append(self.SingleHeadAttention(embedding_dim, self.head_dim, dropout, causal))

        # Output projection back to embedding_dim
        self.W_output = nn.Linear(num_heads * self.head_dim, embedding_dim, bias=False)
        self.proj_dropout = nn.Dropout(dropout)

    def forward(self, embedded: torch.Tensor, attn_mask: torch.Tensor | None = None):
        # embedded: (B, L, embedding_dim)

        head_outputs = []
        for head in self.att_heads:
            head_outputs.append(head(embedded, attn_mask = attn_mask)) # (B, L, head_dim)
        concat_heads= torch.cat(head_outputs, dim = 2) # (B, L, num_heads * head_dim)
        output = self.W_output(concat_heads) # (B, L, embedding_dim)

        return self.proj_dropout(output)



    class SingleHeadAttention(nn.Module):
        def __init__(self, embedding_dim: int, head_dim: int, dropout: float = 0.1, causal: bool = True):
            super().__init__()

            self.causal = causal
            self.head_dim = head_dim
            self.W_q = nn.Linear(embedding_dim, head_dim, bias=False)
            self.W_k = nn.Linear(embedding_dim, head_dim, bias=False)
            self.W_v = nn.Linear(embedding_dim, head_dim, bias=False)
            self.attn_dropout = nn.Dropout(dropout)

        def forward(self, embedded: torch.Tensor, attn_mask: torch.Tensor | None = None):
            # embedded: (B, L, embedding_dim)
            B, L, _ = embedded.shape

            Q = self.W_q(embedded) # (B, L, head_dim)
            K = self.W_k(embedded)
            V = self.W_v(embedded)

            scores = Q @ K.transpose(-2, -1)
            scaled_scores = scores / (self.head_dim ** 0.5)

            # Causal mask
            if self.causal:
                mask = torch.triu(torch.ones(L, L, device=scores.device, dtype=torch.bool), diagonal=1)
                scaled_scores = scaled_scores.masked_fill(mask.unsqueeze(0), float('-inf'))

            # padding mask: attn_mask expected shape (B, L) with True for keep/1 for tokens
            if attn_mask is not None:
                key_keep = attn_mask.to(torch.bool).unsqueeze(1) # (B, 1, L)
                scaled_scores = scaled_scores.masked_fill(~key_keep, float('-inf'))

            attention_weights = F.softmax(scaled_scores, dim = -1) # (B, L, L)
            attention_weights = self.attn_dropout(attention_weights)
            attention_out = attention_weights @ V # (B, L, head_dim)

            return attention_out

In [38]:
## test attention
def test_mhsa_basic():
    """Check output shape, finiteness and gradient flow for one forward pass."""
    torch.manual_seed(0)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    batch, seq_len, width, n_heads = 2, 5, 32, 4
    inputs = torch.randn(batch, seq_len, width, device=device)

    attention = MultiHeadedSelfAttention(embedding_dim=width, num_heads=n_heads).to(device)
    out = attention(inputs)

    # The output keeps the input shape.
    assert out.shape == (batch, seq_len, width), f"bad shape: {out.shape}"

    # Every output value is a finite number.
    assert torch.isfinite(out).all(), "NaN/Inf in output"

    # Backpropagate a scalar and look for at least one parameter whose
    # gradient exists and is finite.
    out.pow(2).mean().backward()
    grads_ok = False
    for param in attention.parameters():
        if param.grad is not None and torch.isfinite(param.grad).all():
            grads_ok = True
            break
    assert grads_ok, "no finite gradients"

    print("✓ basic: shape, finiteness, gradients")

def test_mhsa_causality():
    """
    If we change *future* tokens, earlier outputs must remain unchanged.

    The module is put in eval mode so dropout cannot make the two forward
    passes differ, and every position carries a random signal so the earlier
    outputs are not trivially zero.
    """
    torch.manual_seed(1)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    B, L, d_model, heads = 1, 6, 24, 3
    x = torch.randn(B, L, d_model, device=device)

    mhsa = MultiHeadedSelfAttention(embedding_dim=d_model, num_heads=heads).to(device)
    mhsa.eval()  # disable dropout: the two passes must be comparable

    # Same prefix, different content in the last (future) position
    x2 = x.clone()
    x2[:, -1, :] = torch.randn(d_model, device=device)

    with torch.no_grad():
        y1 = mhsa(x)
        y2 = mhsa(x2)

    # Earlier positions [0 .. L-2] should be identical; last position can change
    diff_early = (y1[:, :L-1, :] - y2[:, :L-1, :]).abs().max().item()
    diff_last = (y1[:, L-1:, :] - y2[:, L-1:, :]).abs().max().item()

    assert diff_early < 1e-6, f"causality broken: early diff={diff_early:.3e}"
    assert diff_last >= 1e-6, "last position should change when its input changes"

    print(f"✓ causality: early max diff={diff_early:.3e}, last changed")

def test_mhsa_divisibility_guard():
    """Constructing attention with embedding_dim not divisible by num_heads must fail.

    Raises:
        AssertionError: If the constructor accepts the non-divisible sizes.
    """
    try:
        _ = MultiHeadedSelfAttention(embedding_dim=30, num_heads=8)  # not divisible
    except AssertionError:
        print("✓ divisibility check triggers")
    else:
        # Raised outside the try body so the handler above cannot swallow it;
        # otherwise this test would pass even when the guard is missing.
        raise AssertionError("expected an assertion for non-divisible dims")

if __name__ == "__main__":
    # Run the attention checks in order; any failure raises and stops the run.
    for check in (test_mhsa_basic, test_mhsa_causality, test_mhsa_divisibility_guard):
        check()
    print("All tests passed ✔")

✓ basic: shape, finiteness, gradients
✓ causality: early max diff=0.000e+00, last changed
✓ divisibility check triggers
All tests passed ✔


In [39]:
class FFN(nn.Module):
    """Position-wise feed-forward network: expand, GELU, project back.

    Args:
        embedding_dim: Input and output feature width.
        dropout: Dropout probability used after the activation and after the
            final projection.
        expansion_factor: Hidden width as a multiple of ``embedding_dim``.
    """

    def __init__(self, embedding_dim: int, dropout: float = 0.1, expansion_factor: int = 4):
        super().__init__()
        inner_width = embedding_dim * expansion_factor
        stages = [
            nn.Linear(embedding_dim, inner_width),  # widen
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_width, embedding_dim),  # narrow back
            nn.Dropout(dropout),
        ]
        self.block = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor):
        """Apply the network to the last dimension of ``x``."""
        out = self.block(x)
        return out

In [40]:
class TransformerBlock(nn.Module):
    """Transformer layer with LayerNorm applied before each sub-layer.

    The input passes through self-attention and then a feed-forward network;
    each sub-layer output goes through dropout and is added back to its input.

    Args:
        embedding_dim: Feature width of the residual stream.
        num_heads: Number of attention heads.
        dropout: Dropout probability passed to both sub-layers and used for
            the two residual dropouts.
        causal: Forwarded to the attention module.
    """

    def __init__(self, embedding_dim: int, num_heads: int, dropout: float = 0.1, causal: bool = True):
        super().__init__()
        # Sub-layers
        self.attention = MultiHeadedSelfAttention(
            embedding_dim, num_heads, dropout=dropout, causal=causal
        )
        self.ffn = FFN(embedding_dim, dropout=dropout)
        # One normalization per sub-layer
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)
        # Dropout on each residual branch
        self.attn_dropout = nn.Dropout(dropout)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None):
        """Run the layer on ``x`` of shape (B, L, embedding_dim).

        ``attn_mask`` is handed to the attention module unchanged.
        """
        attended = self.attention(self.norm1(x), attn_mask=attn_mask)
        x = x + self.attn_dropout(attended)
        transformed = self.ffn(self.norm2(x))
        return x + self.ffn_dropout(transformed)

In [42]:
## test TransformerBlock
def make_pad_batch(lengths, d_model, pad_id=0):
    """Build a random (B, L, d_model) batch and its (B, L) boolean padding mask.

    ``B`` is ``len(lengths)`` and ``L`` is ``max(lengths)``. In the mask, True
    marks a real token and False marks padding; padded feature vectors are set
    to zero. ``pad_id`` is accepted but not used.
    """
    batch_size, longest = len(lengths), max(lengths)
    x = torch.randn(batch_size, longest, d_model)
    mask = torch.zeros(batch_size, longest, dtype=torch.bool)
    for row, n_tokens in enumerate(lengths):
        mask[row, :n_tokens] = True
    # Everything not marked as a token is padding; blank out its features.
    x[~mask] = 0.0
    return x, mask  # (B,L,d), (B,L)

# ====== tests ======
def test_basic_forward():
    """Run one non-causal block forward pass; check output shape and finiteness."""
    torch.manual_seed(0)
    batch, seq_len, width, n_heads = 2, 8, 64, 4
    inputs = torch.randn(batch, seq_len, width)
    layer = TransformerBlock(width, n_heads, dropout=0.1, causal=False)
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)  # no padding
    out = layer(inputs, attn_mask=keep_all)
    assert out.shape == (batch, seq_len, width), f"bad shape {out.shape}"
    assert torch.isfinite(out).all(), "NaN/Inf in output"
    print("✓ basic forward: shape & finiteness")

def test_backward_grad():
    """Backpropagate through one block and require a finite parameter gradient."""
    torch.manual_seed(1)
    batch, seq_len, width, n_heads = 2, 6, 32, 4
    inputs = torch.randn(batch, seq_len, width, requires_grad=True)
    layer = TransformerBlock(width, n_heads, dropout=0.1, causal=False)
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)
    out = layer(inputs, attn_mask=keep_all)
    out.pow(2).mean().backward()
    # Passes as soon as one parameter holds a finite gradient.
    grads_ok = False
    for param in layer.parameters():
        if param.grad is not None and torch.isfinite(param.grad).all():
            grads_ok = True
            break
    assert grads_ok, "no finite gradients"
    print("✓ backward: gradients flow")

def test_padding_mask_blocks_pads():
    """Check that the padding mask takes effect and really hides pad content.

    Two properties are asserted:
      1. Supplying the real mask changes the output compared with an
         all-True mask (the mask is not ignored).
      2. With the real mask, replacing the content of the padded positions
         leaves the outputs at real-token positions unchanged (pads cannot
         leak into tokens). Property 1 alone would also pass with an
         inverted mask.
    """
    torch.manual_seed(2)
    d, h = 48, 3
    lengths = [5, 3]    # batch with different sequence lengths
    x, mask = make_pad_batch(lengths, d_model=d)  # mask: True=token, False=pad

    blk = TransformerBlock(d, h, dropout=0.0, causal=False)  # turn off dropout for determinism
    blk.eval()

    pad_positions = ~mask

    with torch.no_grad():
        # Run once with the correct mask
        y_masked = blk(x.clone(), attn_mask=mask)

        # Run again pretending pads are real tokens (all True)
        y_unmasked = blk(x.clone(), attn_mask=torch.ones_like(mask))

        # Run with the correct mask but random content in the padded slots
        x_noisy = x.clone()
        x_noisy[pad_positions] = torch.randn(int(pad_positions.sum()), d)
        y_noisy = blk(x_noisy, attn_mask=mask)

    if pad_positions.any():
        diff_on_pads = (y_masked[pad_positions] - y_unmasked[pad_positions]).abs().mean().item()
        assert diff_on_pads > 1e-6, "padding mask seems ineffective"

    leak = (y_masked[mask] - y_noisy[mask]).abs().max().item()
    assert leak < 1e-6, f"pad content leaked into real tokens: {leak:.3e}"
    print("✓ padding mask: pads are blocked")

def test_causality():
    """With the causal mask on, editing the last token must not alter earlier outputs."""
    torch.manual_seed(3)
    batch, seq_len, width, n_heads = 1, 6, 32, 4
    inputs = torch.zeros(batch, seq_len, width)
    inputs[:, -1, :] = torch.randn(width)  # only the final position carries a signal

    layer = TransformerBlock(width, n_heads, dropout=0.0, causal=True)
    layer.eval()
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)

    before = layer(inputs.clone(), attn_mask=keep_all).detach()
    inputs[:, -1, :] = torch.randn(width)  # overwrite the final (future) token
    after = layer(inputs.clone(), attn_mask=keep_all).detach()

    # Compare the prefix [0 .. L-2] and the final position separately.
    delta = (before - after).abs()
    early_diff = delta[:, :seq_len - 1, :].max().item()
    last_diff = delta[:, seq_len - 1:, :].max().item()
    assert early_diff < 1e-6, f"causality broken: early_diff={early_diff:.3e}"
    assert last_diff >= 1e-6, "last position should change"
    print("✓ causality: future tokens don't affect the past")

def test_train_vs_eval_dropout():
    """Dropout must randomize train-mode passes and vanish in eval mode."""
    torch.manual_seed(4)
    batch, seq_len, width, n_heads = 2, 8, 64, 4
    inputs = torch.randn(batch, seq_len, width)
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)
    layer = TransformerBlock(width, n_heads, dropout=0.2, causal=False)

    # Train mode: two passes over the same input draw different dropout masks.
    layer.train()
    train_a = layer(inputs, attn_mask=keep_all)
    train_b = layer(inputs, attn_mask=keep_all)
    train_diff = (train_a - train_b).abs().mean().item()
    assert train_diff > 0.0, "dropout not active in train mode?"

    # Eval mode: repeated passes must agree.
    layer.eval()
    eval_a = layer(inputs, attn_mask=keep_all)
    eval_b = layer(inputs, attn_mask=keep_all)
    eval_diff = (eval_a - eval_b).abs().max().item()
    assert eval_diff < 1e-7, "eval should be deterministic (no dropout)"
    print("✓ dropout behavior: stochastic in train, deterministic in eval")

if __name__ == "__main__":
    # Run the transformer-block checks in order; any failure raises and stops the run.
    for check in (
        test_basic_forward,
        test_backward_grad,
        test_padding_mask_blocks_pads,
        test_causality,
        test_train_vs_eval_dropout,
    ):
        check()
    print("All tests passed ✔")

✓ basic forward: shape & finiteness
✓ backward: gradients flow
✓ padding mask: pads are blocked
✓ causality: future tokens don't affect the past
✓ dropout behavior: stochastic in train, deterministic in eval
All tests passed ✔


In [45]:
class GPT(nn.Module):
    """Stack of transformer blocks between an embedding layer and a vocabulary projection.

    Args:
        vocab_size: Number of token ids; also the size of the output logits.
        embedding_dim: Feature width used throughout the model.
        max_context_length: Passed to the embedding layer.
        num_blocks: Number of transformer blocks.
        num_heads: Attention heads per block.
        dropout: Dropout probability for the embedding layer and every block.
        causal: Forwarded to every block.
    """

    def __init__(self, vocab_size: int, embedding_dim: int, max_context_length: int, num_blocks: int, num_heads: int, dropout: float = 0.1, causal: bool = True):
        super().__init__()
        self.embedding = Embedding(
            vocab_size, max_context_length, embedding_dim, dropout=dropout, pad_idx=0
        )
        self.transformer_blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.transformer_blocks.append(
                TransformerBlock(embedding_dim, num_heads, dropout=dropout, causal=causal)
            )
        self.final_norm = nn.LayerNorm(embedding_dim)
        self.vocab_projection = nn.Linear(embedding_dim, vocab_size, bias=False)

    def forward(self, context: torch.Tensor, attn_mask: torch.Tensor | None = None):
        """Map token ids (B, L) to logits over the vocabulary.

        ``attn_mask`` (B, L) is handed to every block unchanged.
        """
        hidden = self.embedding(context)  # (B, L, d)
        for layer in self.transformer_blocks:
            hidden = layer(hidden, attn_mask=attn_mask)  # (B, L, d)
        normalized = self.final_norm(hidden)  # (B, L, d)
        return self.vocab_projection(normalized)  # logits

In [46]:
## test gpt

def test_forward_shapes():
    """The model must return one logit vector of vocabulary size per input token."""
    torch.manual_seed(0)
    vocab_size, d_model, max_len = 100, 32, 16
    n_layers, n_heads = 2, 4
    model = GPT(vocab_size, d_model, max_len, n_layers, n_heads, dropout=0.1, causal=True)

    batch, seq_len = 3, 10
    tokens = torch.randint(0, vocab_size, (batch, seq_len))  # random token ids
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)  # nothing is masked
    logits = model(tokens, attn_mask=keep_all)

    expected_shape = (batch, seq_len, vocab_size)
    assert logits.shape == expected_shape
    print("✓ forward: logits shape OK", logits.shape)

def test_backward_grad():
    """Backpropagate a next-token loss and require a finite parameter gradient."""
    torch.manual_seed(1)
    vocab_size, d_model, max_len = 50, 32, 12
    model = GPT(vocab_size, d_model, max_len, num_blocks=2, num_heads=4, dropout=0.1, causal=True)

    batch, seq_len = 2, 8
    tokens = torch.randint(0, vocab_size, (batch, seq_len))
    keep_all = torch.ones(batch, seq_len, dtype=torch.bool)
    logits = model(tokens, attn_mask=keep_all)

    # Next-token objective: the logits at position t are scored against the
    # token at position t + 1.
    predictions = logits[:, :-1, :].reshape(-1, vocab_size)
    targets = tokens[:, 1:].reshape(-1)
    loss = F.cross_entropy(predictions, targets)
    loss.backward()

    grads_ok = False
    for param in model.parameters():
        if param.grad is not None and torch.isfinite(param.grad).all():
            grads_ok = True
            break
    assert grads_ok, "No finite gradients!"
    print("✓ backward: loss computed & gradients flow")

def test_causality():
    """Changing the last token must leave the logits of earlier positions unchanged.

    Raises:
        AssertionError: If any logit before the last position moves.
    """
    torch.manual_seed(2)
    vocab_size, d_model, max_len = 40, 16, 10
    model = GPT(vocab_size, d_model, max_len, num_blocks=1, num_heads=2, dropout=0.0, causal=True)
    model.eval()

    B, L = 1, 6
    x = torch.randint(0, vocab_size, (B, L))

    # Run once
    y1 = model(x).detach()
    # Change only the last token (future)
    x2 = x.clone()
    x2[:, -1] = (x2[:, -1] + 1) % vocab_size
    y2 = model(x2).detach()

    # Earlier logits [0..L-2] should match if causal masking is working.
    diff_early = (y1[:, :-1, :] - y2[:, :-1, :]).abs().max().item()
    # Without this assertion the check mark below would be printed even when
    # future tokens leak into the past.
    assert diff_early < 1e-6, f"causality broken: early diff={diff_early:.3e}"
    print("✓ causality check: max diff before last =", diff_early)

def test_dropout_behavior():
    """Dropout must make train-mode passes differ and eval-mode passes agree.

    Raises:
        AssertionError: If two train-mode passes are identical or two
            eval-mode passes differ.
    """
    torch.manual_seed(3)
    vocab_size, d_model, max_len = 30, 16, 12
    model = GPT(vocab_size, d_model, max_len, num_blocks=1, num_heads=2, dropout=0.2, causal=True)

    B, L = 2, 5
    x = torch.randint(0, vocab_size, (B, L))
    mask = torch.ones(B, L, dtype=torch.bool)

    model.train()
    y1 = model(x, attn_mask=mask)
    y2 = model(x, attn_mask=mask)
    diff_train = (y1 - y2).abs().mean().item()
    # Same thresholds as the block-level dropout test; previously the result
    # was only printed, so the test could not fail.
    assert diff_train > 0.0, "dropout not active in train mode?"

    model.eval()
    y3 = model(x, attn_mask=mask)
    y4 = model(x, attn_mask=mask)
    diff_eval = (y3 - y4).abs().max().item()
    assert diff_eval < 1e-7, "eval should be deterministic (no dropout)"

    print(f"✓ dropout train-mode diff ~ {diff_train:.4f}, eval-mode diff ~ {diff_eval:.4e}")

if __name__ == "__main__":
    # Run the GPT checks in order; any failure raises and stops the run.
    for check in (
        test_forward_shapes,
        test_backward_grad,
        test_causality,
        test_dropout_behavior,
    ):
        check()
    print("All GPT tests passed ✔")


✓ forward: logits shape OK torch.Size([3, 10, 100])
✓ backward: loss computed & gradients flow
✓ causality check: max diff before last = 0.0
✓ dropout train-mode diff ~ 0.3606, eval-mode diff ~ 0.0000e+00
All GPT tests passed ✔
