In [4]:
# Tiny decoder-only Transformer (GPT-style) that learns multiple (input -> output) pairs
# and then generates step-by-step like an LLM, printing attention weights at each step.

import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

# ---------------------------
# 0) Reproducibility + device
# ---------------------------
seed = 7
random.seed(seed); torch.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------------------------
# 1) Toy dataset (prompt -> target)
# ---------------------------
pairs = [
    (["how","are","you"], ["i","am","fine"]),
    (["what","is","your","name"], ["my","name","is","gpt"]),
    (["feeling"],           ["good"]),
    (["say","animal"],      ["cat"]),
    (["say","pet"],         ["dog"]),
]

# Build vocabulary (input + output + special tokens)
special = ["<PAD>","<BOS>","<EOS>","<SEP>"]
tokens = set(special)
for x,y in pairs:
    tokens.update(x); tokens.update(y)
vocab = sorted(tokens, key=lambda t: (t not in special, t))
stoi = {t:i for i,t in enumerate(vocab)}
itos = {i:t for t,i in stoi.items()}
PAD, BOS, EOS, SEP = stoi["<PAD>"], stoi["<BOS>"], stoi["<EOS>"], stoi["<SEP>"]
vocab_size = len(vocab)
# Helpers
def encode(tok_list): return [stoi[t] for t in tok_list]
def decode(id_list):  return [itos[i] for i in id_list]

# Convert (input -> output) pair into training sequence
# Format: <BOS> X <SEP> Y <EOS>
def build_example(x_tokens, y_tokens):
    full = ["<BOS>"] + x_tokens + ["<SEP>"] + y_tokens + ["<EOS>"]
    ids = torch.tensor(encode(full), dtype=torch.long)
    labels = ids[1:].clone()             # next-token prediction
    sep_idx = full.index("<SEP>")
    labels[:sep_idx] = -100              # ignore predictions before <SEP>
    return ids, labels, sep_idx

dataset_raw = [build_example(x,y) for (x,y) in pairs]

# Collate batch: pad sequences to equal length
def collate(batch):
    max_len = max(ids.size(0) for (ids,_,_) in batch)
    B = len(batch)
    x = torch.full((B, max_len-1), PAD, dtype=torch.long)
    y = torch.full((B, max_len-1), -100, dtype=torch.long)
    attn_key_padding = torch.zeros((B, max_len-1), dtype=torch.bool)
    for i,(ids, labels, _) in enumerate(batch):
        Ti = ids.size(0)
        x[i, :Ti-1] = ids[:-1]
        y[i, :Ti-1] = labels
        attn_key_padding[i, :Ti-1] = True
    return x, y, attn_key_padding

# Training pool (repeat dataset many times to help overfit)
train_pool = dataset_raw * 64

# ---------------------------
# 2) Tiny GPT-like model
# ---------------------------

# ---- Multi-Head Self-Attention ----
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        assert d_model % n_heads == 0
        self.h = n_heads
        self.d = d_model // n_heads
        # Linear projections for Q,K,V
        self.q = nn.Linear(d_model, d_model)
        self.k = nn.Linear(d_model, d_model)
        self.v = nn.Linear(d_model, d_model)
        self.o = nn.Linear(d_model, d_model)  # output projection
        self.last_attn = None  # store attention weights for inspection

    def forward(self, x, attn_mask=None):
        B, T, C = x.shape
        # Project into Q, K, V
        q = self.q(x).view(B, T, self.h, self.d).transpose(1, 2)  # (B,h,T,d)
        k = self.k(x).view(B, T, self.h, self.d).transpose(1, 2)
        v = self.v(x).view(B, T, self.h, self.d).transpose(1, 2)
        # Scaled dot-product attention
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.d)       # (B,h,T,T)
        if attn_mask is not None:
            att = att.masked_fill(attn_mask == 0, float("-inf"))  # mask future + pad
        w = F.softmax(att, dim=-1)                                # attention weights
        self.last_attn = w                                        # save for tracing
        y = w @ v                                                 # weighted sum of values
        y = y.transpose(1, 2).contiguous().view(B, T, C)          # concat heads
        return self.o(y)

# ---- Feed-Forward MLP ----
class MLP(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
    def forward(self, x):
        return self.fc2(F.gelu(self.fc1(x)))

# ---- Transformer Block ----
class Block(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, n_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = MLP(d_model, d_ff)
    def forward(self, x, attn_mask):
        # Attention + residual
        x = x + self.attn(self.ln1(x), attn_mask)
        # MLP + residual
        x = x + self.mlp(self.ln2(x))
        return x

# ---- Full Tiny GPT Model ----
class TinyGPT(nn.Module):
    def __init__(self, vocab, d_model=128, n_layers=4, n_heads=4, d_ff=256, max_len=64, pad_id=0):
        super().__init__()
        self.pad_id = pad_id
        # Embedding for tokens + positions
        self.tok = nn.Embedding(vocab, d_model)
        self.pos = nn.Embedding(max_len, d_model)
        # Stack of Transformer blocks
        self.blocks = nn.ModuleList([Block(d_model, n_heads, d_ff) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)    # final normalization
        self.head = nn.Linear(d_model, vocab, bias=False)  # vocab prediction head
        # Tie weights with embedding
        self.head.weight = self.tok.weight

    # Build causal + padding mask
    def _build_attn_mask(self, idx, key_padding_mask):
        B, T = idx.shape
        causal = torch.tril(torch.ones(T, T, device=idx.device, dtype=torch.bool))  # lower-triangular
        kpad = key_padding_mask.unsqueeze(1).unsqueeze(2)  # (B,1,1,T)
        mask = causal.unsqueeze(0).unsqueeze(0) & kpad     # combine causal + key mask
        return mask

    def forward(self, idx, key_padding_mask):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device).unsqueeze(0)
        # Embedding lookup
        x = self.tok(idx) + self.pos(pos)
        attn_mask = self._build_attn_mask(idx, key_padding_mask)
        # Pass through blocks
        for blk in self.blocks:
            x = blk(x, attn_mask)
        x = self.ln_f(x)
        logits = self.head(x)  # vocab prediction
        return logits

    # Generate tokens step by step with attention tracing
    @torch.no_grad()
    def generate_with_trace(self, prefix_ids, max_new_tokens=8):
        self.eval()
        out = prefix_ids.clone()
        while out.size(1) < prefix_ids.size(1) + max_new_tokens:
            kpad = torch.ones_like(out, dtype=torch.bool, device=out.device)
            logits = self.forward(out, kpad)             # forward pass
            next_logits = logits[:, -1, :]
            next_id = torch.argmax(next_logits, dim=-1, keepdim=True)  # greedy choice
            # Trace attention from last block
            last_block = self.blocks[-1]
            attn = last_block.attn.last_attn  # (1,h,T,T)
            # Print attention context
            toks = decode(out[0].tolist())
            print(f"\n[Step {out.size(1)-prefix_ids.size(1)+1}] Context: {' '.join(toks)}")
            for h in range(attn.size(1)):
                weights = attn[0, h, -1, :out.size(1)].tolist()
                pairs = [f"{tok}:{w:.2f}" for tok, w in zip(toks, weights)]
                print(f"  Head {h}: " + " | ".join(pairs))
            # Append predicted token
            out = torch.cat([out, next_id], dim=1)
            print(f"  → Predicted: {decode([next_id.item()])[0]}")
            if next_id.item() == EOS:
                break
        return out

# ---------------------------
# 3) Training loop
# ---------------------------
model = TinyGPT(vocab_size, d_model=128, n_layers=4, n_heads=4, d_ff=256, max_len=64, pad_id=PAD).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-3, weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

def sample_batch(batch_size=8):
    batch = random.sample(train_pool, batch_size)
    x, y, kpad = collate(batch)
    return x.to(device), y.to(device), kpad.to(device)

steps = 1200
model.train()
for step in range(1, steps+1):
    opt.zero_grad(set_to_none=True)
    x, y, kpad = sample_batch(16)
    logits = model(x, kpad)
    loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
    loss.backward()
    opt.step()
    if step % 200 == 0:
        print(f"step {step}/{steps}  loss={loss.item():.4f}")

# ---------------------------
# 4) Inference with step-by-step attention
# ---------------------------
model.eval()
def run_prompt(prompt_tokens, max_new=6):
    prefix = ["<BOS>"] + prompt_tokens + ["<SEP>"]
    prefix_ids = torch.tensor([encode(prefix)], dtype=torch.long, device=device)
    print("\n========================")
    print("PROMPT:", " ".join(prompt_tokens))
    gen = model.generate_with_trace(prefix_ids, max_new_tokens=max_new)
    out_tokens = decode(gen[0].tolist())
    if "<SEP>" in out_tokens:
        start = out_tokens.index("<SEP>") + 1
        predicted = out_tokens[start:]
    else:
        predicted = out_tokens
    predicted = [t for t in predicted if t != "<PAD>"]
    print("OUTPUT:", " ".join(predicted))

# Try different prompts
# ---------------------------
# Interactive prediction
# ---------------------------
print("\n=== Chat with your tiny GPT ===")
while True:
    user_inp = input("\nEnter input (tokens separated by space, or 'quit'): ")
    if user_inp.lower() == "quit":
        break
    tokens = user_inp.strip().split()
    run_prompt(tokens)


step 200/1200  loss=0.0000
step 400/1200  loss=0.0000
step 600/1200  loss=0.0000
step 800/1200  loss=0.0000
step 1000/1200  loss=0.0000
step 1200/1200  loss=0.0000

=== Chat with your tiny GPT ===

PROMPT: pet

[Step 1] Context: <BOS> pet <SEP>
  Head 0: <BOS>:0.61 | pet:0.27 | <SEP>:0.12
  Head 1: <BOS>:0.05 | pet:0.46 | <SEP>:0.50
  Head 2: <BOS>:0.13 | pet:0.64 | <SEP>:0.24
  Head 3: <BOS>:0.04 | pet:0.16 | <SEP>:0.80
  → Predicted: good

[Step 2] Context: <BOS> pet <SEP> good
  Head 0: <BOS>:0.04 | pet:0.16 | <SEP>:0.31 | good:0.50
  Head 1: <BOS>:0.02 | pet:0.04 | <SEP>:0.85 | good:0.09
  Head 2: <BOS>:0.20 | pet:0.27 | <SEP>:0.14 | good:0.39
  Head 3: <BOS>:0.20 | pet:0.09 | <SEP>:0.67 | good:0.04
  → Predicted: <EOS>
OUTPUT: good <EOS>

PROMPT: what is my name

[Step 1] Context: <BOS> what is my name <SEP>
  Head 0: <BOS>:0.32 | what:0.17 | is:0.22 | my:0.14 | name:0.06 | <SEP>:0.10
  Head 1: <BOS>:0.01 | what:0.05 | is:0.38 | my:0.08 | name:0.15 | <SEP>:0.32
  Head 2: <BOS>:0.1

KeyboardInterrupt: Interrupted by user