In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from model import TinyGPT
from data import CharDataset, CharTokenizer

Attention weights:
 tensor([[[0.4711, 0.2645, 0.2645],
         [0.2645, 0.4711, 0.2645],
         [0.2645, 0.2645, 0.4711]]])
Output:
 tensor([[[-0.3374, -0.5534,  0.8631, -0.5376],
         [-0.6592, -0.2661,  0.9557, -0.8329],
         [-0.1743, -0.7239,  1.2220, -0.6736]]])


In [2]:
# Example toy dataset
text = "abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ"
tok = CharTokenizer(text)
print("Vocab size:", tok.vocab_size)

# Encode the full text into tokens
data = torch.tensor(tok.encode(text), dtype=torch.long)

# 90% train, 10% val split
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]

dataset = {
    "train": train_data,
    "val": val_data
}

block_size = 8
vocab_size = tok.vocab_size


Vocab size: 53


In [3]:
def get_batch(split, data, batch_size=4, block_size=8, device="cuda"):
    # split: "train" or "val"
    # data: dict with {"train": tensor, "val": tensor}
    
    # pick the right dataset
    data_split = data[split]

    # pick random starting indices
    ix = torch.randint(len(data_split) - block_size, (batch_size,))
    
    # slice out input (x) and target (y)
    x = torch.stack([data_split[i:i+block_size] for i in ix])
    y = torch.stack([data_split[i+1:i+block_size+1] for i in ix])
    
    return x.to(device), y.to(device)

In [5]:
device = "cpu"
print("Using device:", device)

model = TinyGPT(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_seq_len=block_size
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()

# Training loop
for step in range(1000):
    # 1. get batch
    x, y = get_batch("train", dataset, block_size=block_size, device=device)

    # 2. forward
    logits = model(x)

    # 3. compute loss
    loss = loss_fn(logits.view(-1, logits.size(-1)), y.view(-1))

    # 4. backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"Step {step}, loss = {loss.item():.4f}")

Using device: cpu
Step 0, loss = 82.2274
Step 100, loss = 6.5171
Step 200, loss = 1.4414
Step 300, loss = 0.5110
Step 400, loss = 0.0211
Step 500, loss = 0.0082
Step 600, loss = 0.0065
Step 700, loss = 0.0059
Step 800, loss = 0.0026
Step 900, loss = 0.0018


In [8]:
import torch
import torch.nn.functional as F

@torch.no_grad()
def sample(
    model,
    tokenizer,
    start: str,
    max_new_tokens: int = 50,
    temperature: float = 1.0,
    top_k: int | None = None,
    top_p: float | None = None,
    device: str | None = None,
):
    """
    Robust autoregressive sampling for TinyGPT-like models.

    - model: your TinyGPT instance (should have model.pos_emb)
    - tokenizer: your CharTokenizer with encode()/decode()
    - start: prompt string
    - max_new_tokens: how many tokens to generate
    - temperature: >0 float (small <1 => deterministic), 0 => greedy
    - top_k: keep only top_k logits (int) if not None
    - top_p: nucleus probability threshold (0<p<1) if not None
    """
    model.eval()
    if device is None:
        device = next(model.parameters()).device
    # encode prompt
    input_ids = tokenizer.encode(start)
    if len(input_ids) == 0:
        raise ValueError("Prompt must be non-empty (or handle a default token).")
    # make tensor shape (1, seq)
    generated = torch.tensor([input_ids], dtype=torch.long, device=device)

    # max context length supported by the model's positional embeddings
    # we read num_embeddings (safe even if model doesn't store max_seq_len explicitly)
    max_context = model.pos_emb.num_embeddings

    for _ in range(max_new_tokens):
        # trim to model context window (use only the last max_context tokens)
        if generated.size(1) > max_context:
            input_ids_tensor = generated[:, -max_context:]
        else:
            input_ids_tensor = generated

        # forward pass to get logits: shape (1, seq, vocab)
        logits = model(input_ids_tensor)  
        logits = logits[:, -1, :]  # take logits for the last position -> shape (1, vocab)

        # temperature handling
        if temperature == 0:
            # greedy
            next_token = torch.argmax(logits, dim=-1, keepdim=True)  # shape (1,1)
            generated = torch.cat((generated, next_token), dim=1)
            continue
        else:
            logits = logits / float(temperature)

        # top-k filtering (optional)
        if top_k is not None and top_k > 0:
            top_k = min(top_k, logits.size(-1))  # don't exceed vocab
            # get the kth largest logit value for each batch row and mask below it
            values, _ = torch.topk(logits, top_k, dim=-1)
            min_values = values[..., -1, None]  # threshold
            logits = torch.where(logits < min_values, torch.full_like(logits, float("-inf")), logits)

        # top-p (nucleus) filtering (optional)
        if top_p is not None and 0.0 < top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            sorted_probs = F.softmax(sorted_logits, dim=-1)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

            # mask tokens with cumulative prob > top_p
            sorted_mask = cumulative_probs > top_p
            # keep at least one token: shift mask right so first token included
            sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
            sorted_mask[..., 0] = False

            # now set logits of tokens to remove to -inf
            indices_to_remove = sorted_mask.scatter(-1, sorted_indices, sorted_mask)
            logits = logits.masked_fill(indices_to_remove, float("-inf"))

        # convert to probabilities
        probs = F.softmax(logits, dim=-1)  # (1, vocab)

        # sample
        next_token = torch.multinomial(probs, num_samples=1)  # (1,1)
        generated = torch.cat((generated, next_token), dim=1)

    # return the entire sequence as text (prompt + generated)
    return tokenizer.decode(generated[0].tolist())


In [11]:
text = sample(model, tokenizer=tok, start="Hello", max_new_tokens=50,
              temperature=0.8, top_k=20, top_p=0.9, device="cpu")
print(text)


['H', 'e', 'l', 'l', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T', 'T']
