<a href="https://colab.research.google.com/github/BhrgvPtl/decoder-transformer/blob/main/decoder_transformer_sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
 #@title 🔧 Setup
# You usually don't need to install torch on Colab; it's preinstalled.
# If you want the absolute latest: uncomment the next line.
# !pip install --quiet --upgrade torch

import re
import math
from dataclasses import dataclass
from typing import List, Tuple, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Config
# ----------------------------
@dataclass
class Config:
    """Hyperparameters for the toy decoder-only transformer and its sampler."""

    # --- Model / optimisation ---
    # Context window: maximum number of tokens per forward pass.
    block_size: int = 24
    # Embedding width (d_model).
    n_embd: int = 64
    # Number of attention heads.
    n_heads: int = 4
    # Number of stacked transformer blocks.
    n_layers: int = 2
    dropout: float = 0.1
    lr: float = 3e-4
    batch_size: int = 16
    max_iters: int = 300
    # Evaluated once, when the class body runs: prefer the GPU if one is visible.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    seed: int = 1337

    # --- Sampling controls ---
    # Must be > 0; lower values make sampling more deterministic.
    temperature: float = 1.0
    # Keep only the k most likely tokens (e.g. 5); None disables the filter.
    top_k: Optional[int] = None
    # Nucleus threshold (e.g. 0.9); None disables the filter.
    top_p: Optional[float] = None
# One shared configuration object for the whole notebook.
CFG = Config()
print("Device:", CFG.device)
# Seed PyTorch's global RNG so weight init and sampling are repeatable.
torch.manual_seed(CFG.seed)


Device: cpu


In [None]:
#@title 📚 Corpus + Dataloader

# Small corpus (a few short lines so it trains fast)
# Every training sentence is wrapped in the start/end markers.
CORPUS = [
    f"<SOS> {sentence} <EOS>"
    for sentence in (
        "I Love Transformers",
        "Transformers are awesome",
        "I Love attention mechanisms",
        "Self attention learns token relations",
    )
]

# Tokenize: keep <SOS>/<EOS>, words, punctuation
def tokenize(s: str) -> List[str]:
    """Split *s* into tokens.

    A token is, in order of preference, a literal ``<EOS>`` or ``<SOS>``
    marker, a run of word characters, or a single non-space, non-word
    character. Whitespace is discarded.
    """
    token_pattern = re.compile(r"<EOS>|<SOS>|[\w]+|[^\w\s]")
    return token_pattern.findall(s)

# Flatten the tokenized corpus into one list of tokens.
tokens_all: List[str] = [tok for line in CORPUS for tok in tokenize(line)]

# Make sure every special token has a vocabulary entry even if the corpus lacks it.
SPECIALS = ["<PAD>", "<SOS>", "<EOS>"]
tokens_all += [sp for sp in SPECIALS if sp not in tokens_all]

# Vocab: ids are positions in the sorted list of unique tokens.
itos = sorted(set(tokens_all))
stoi = {tok: idx for idx, tok in enumerate(itos)}
PAD_ID, SOS_ID, EOS_ID = (stoi[sp] for sp in SPECIALS)
VOCAB_SIZE = len(itos)

# Training stream: the encoded corpus repeated 32 times back to back.
_corpus_ids = [stoi[tok] for line in CORPUS for tok in tokenize(line)]
encoded_stream = _corpus_ids * 32
encoded_data = torch.tensor(encoded_stream, dtype=torch.long)

class LMWindowedDataset(Dataset):
    """Overlapping (x, y) windows of length block_size for next-token prediction.

    Item ``i`` is ``(ids[i : i + block_size], ids[i + 1 : i + block_size + 1])``,
    so ``y`` is ``x`` shifted one token ahead.

    Args:
        ids: Token-id tensor; windows are sliced along its first dimension.
        block_size: Number of tokens in each ``x`` / ``y`` window.
    """

    def __init__(self, ids: torch.Tensor, block_size: int):
        self.ids = ids
        self.block_size = block_size

    def __len__(self) -> int:
        # A window needs block_size inputs plus one extra token for the last target.
        return max(0, len(self.ids) - self.block_size)

    def __getitem__(self, idx) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(x, y)`` window starting at position ``idx``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n = len(self)
        pos = idx + n if idx < 0 else idx
        # Without this check an out-of-range index would silently yield a short or
        # empty window, and plain iteration (which stops on IndexError) would never end.
        if not 0 <= pos < n:
            raise IndexError(f"index {idx} out of range for dataset of length {n}")
        chunk = self.ids[pos : pos + self.block_size + 1]
        x = chunk[:-1]
        y = chunk[1:]
        return x, y

# Sliding-window view over the token stream, batched for training.
train_ds = LMWindowedDataset(ids=encoded_data, block_size=CFG.block_size)
train_loader = DataLoader(
    train_ds,
    batch_size=CFG.batch_size,
    shuffle=True,    # visit the windows in a fresh random order each pass
    drop_last=True,  # skip a trailing partial batch
)

print("Vocab size:", VOCAB_SIZE)
print("Dataset length:", len(train_ds))


Vocab size: 14
Dataset length: 712


In [None]:
#@title 🧠 Transformer Decoder Components

class SinusoidalPositionalEncoding(nn.Module):
    """Add fixed sine/cosine position signals to a batch of embeddings.

    Even feature indices carry ``sin(pos * w_i)`` and odd indices carry
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    stored as the buffer ``pe`` and is not trainable.

    Args:
        d_model: Embedding width; odd values are supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model: int, max_len: int = 10_000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # (T, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float()
                             * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (T, ceil(C / 2))

        pe = torch.zeros(max_len, d_model)  # (T, C)
        pe[:, 0::2] = torch.sin(angles)  # even dims
        # For odd d_model there is one fewer odd slot than even slots, so trim the
        # last frequency; for even d_model the slice is a no-op.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd dims
        self.register_buffer("pe", pe)  # not trainable

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` (B, T, C) with the encodings for positions ``0..T-1`` added.

        Raises:
            ValueError: If ``T`` exceeds the number of precomputed positions.
        """
        T = x.size(1)
        max_len = self.pe.size(0)
        if T > max_len:
            raise ValueError(f"Sequence length {T} > max_len {max_len}")
        return x + self.pe[:T].unsqueeze(0)

class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    Position ``t`` may only attend to positions ``<= t``.

    Args:
        n_embd: Model width; must be divisible by ``n_heads``.
        n_heads: Number of attention heads.
        dropout: Dropout probability for the attention weights and the output.
        block_size: Side length of the precomputed causal mask. When omitted,
            the module-level ``CFG.block_size`` is used (the previous behaviour).

    Raises:
        ValueError: If ``n_embd`` is not divisible by ``n_heads``.
    """

    def __init__(self, n_embd: int, n_heads: int, dropout: float,
                 block_size: Optional[int] = None):
        super().__init__()
        # Explicit raise rather than assert so the check survives `python -O`.
        if n_embd % n_heads != 0:
            raise ValueError("n_embd must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = n_embd // n_heads

        self.q_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.k_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.v_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=False)
        self.attn_drop = nn.Dropout(dropout)
        self.resid_drop = nn.Dropout(dropout)

        if block_size is None:
            block_size = CFG.block_size
        # Lower-triangular mask of the maximum size; sliced to the current T each forward.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(block_size, block_size)).unsqueeze(0).unsqueeze(0)
        )  # shape: (1, 1, block_size, block_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply causal self-attention to ``x`` of shape (B, T, C); returns (B, T, C)."""
        B, T, C = x.size()
        # Project, then split the channel axis into heads: (B, nh, T, hs).
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(self.head_dim))  # (B, nh, T, T)
        if T <= self.mask.size(-1):
            mask = self.mask[:, :, :T, :T]
        else:
            # Sequence is longer than the cached mask: build one of the right size
            # instead of failing on a shape mismatch.
            mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
        # -inf scores become exactly zero weight after the softmax.
        att = att.masked_fill(mask == 0, float("-inf"))
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        y = att @ v  # (B, nh, T, hs)
        # Merge the heads back into the channel axis.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_drop(self.out_proj(y))
        return y

class FeedForward(nn.Module):
    """Position-wise MLP: widen to 4 * n_embd, GELU, project back, then dropout."""

    def __init__(self, n_embd: int, dropout: float):
        super().__init__()
        hidden = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the MLP; the last dimension stays ``n_embd``."""
        out = self.net(x)
        return out

class TransformerBlock(nn.Module):
    """One decoder layer: self-attention then feed-forward, each normalised first
    and wrapped in a residual connection."""

    def __init__(self, n_embd: int, n_heads: int, dropout: float):
        super().__init__()
        # Registration order (ln1, attn, ln2, ff) fixes the parameter order; keep it.
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = CausalSelfAttention(n_embd, n_heads, dropout)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ff = FeedForward(n_embd, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` after both residual sub-layers; the shape is unchanged."""
        sublayers = ((self.ln1, self.attn), (self.ln2, self.ff))
        for norm, sublayer in sublayers:
            # Normalise, transform, and add the result back onto the running stream.
            x = x + sublayer(norm(x))
        return x

class DecoderOnlyTransformer(nn.Module):
    """Small GPT-style language model.

    Token embedding plus sinusoidal positions, a stack of ``TransformerBlock``
    layers, a final LayerNorm and a linear head producing vocabulary logits.

    Args:
        vocab_size: Number of distinct token ids.
        n_embd: Embedding / model width.
        n_heads: Attention heads per block.
        n_layers: Number of transformer blocks.
        dropout: Dropout probability used throughout.
        block_size: Maximum sequence length accepted by ``forward``.
    """

    def __init__(self, vocab_size: int, n_embd: int, n_heads: int, n_layers: int,
                 dropout: float, block_size: int):
        super().__init__()
        self.block_size = block_size
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_enc = SinusoidalPositionalEncoding(n_embd, max_len=block_size)
        self.drop = nn.Dropout(dropout)
        self.blocks = nn.ModuleList([
            TransformerBlock(n_embd, n_heads, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)

        # Re-initialise every Linear / Embedding submodule with the scheme below.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-uniform for Linear weights (zero bias), N(0, 0.02) for embeddings."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward(self, idx: torch.Tensor, targets: Optional[torch.Tensor] = None):
        """Compute next-token logits and, when ``targets`` is given, the loss.

        Args:
            idx: (B, T) tensor of token ids with ``T <= block_size``.
            targets: Optional (B, T) tensor of target ids; positions equal to
                ``PAD_ID`` are ignored by the loss.

        Returns:
            ``(logits, loss)`` where ``logits`` is (B, T, V) and ``loss`` is the
            mean cross-entropy, or ``None`` when ``targets`` is ``None``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"Sequence length {T} > block_size {self.block_size}")

        x = self.tok_emb(idx)     # (B, T, C)
        x = self.pos_enc(x)
        x = self.drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, V)

        loss = None
        if targets is not None:
            # reshape (not view) so non-contiguous targets are accepted too.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=PAD_ID
            )
        return logits, loss

    @torch.no_grad()
    def generate(
        self,
        idx: torch.Tensor,
        max_new_tokens: int,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.Tensor:
        """Sample ``max_new_tokens`` tokens autoregressively after the prompt ``idx``.

        The model is put in eval mode while sampling and its previous
        train/eval mode is restored afterwards.

        Args:
            idx: (B, T) tensor of prompt token ids.
            max_new_tokens: Number of tokens to append to every row.
            temperature: Logit divisor; must be > 0. Lower is more deterministic.
            top_k: If set, sample only among the k highest-scoring tokens.
            top_p: If set, sample only from the nucleus with this cumulative mass.

        Returns:
            (B, T + max_new_tokens) tensor: the prompt followed by the samples.

        Raises:
            ValueError: If ``temperature`` is not strictly positive.
        """
        # Loop-invariant, so validate once before touching the model.
        if temperature <= 0:
            raise ValueError("temperature must be > 0")

        was_training = self.training
        self.eval()
        try:
            for _ in range(max_new_tokens):
                # Feed at most the last block_size tokens as context.
                idx_cond = idx[:, -self.block_size:]
                logits, _ = self(idx_cond)
                next_token_logits = logits[:, -1, :] / temperature  # (B, V)

                next_token_logits = top_k_top_p_filtering(
                    next_token_logits, top_k=top_k, top_p=top_p
                )

                probs = F.softmax(next_token_logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_id), dim=1)
        finally:
            # Do not leave a model that was training stuck in eval mode
            # (that would silently disable dropout for later training).
            self.train(was_training)
        return idx


In [None]:
#@title 🎲 Sampling Filters (top-k / top-p)

def top_k_top_p_filtering(
    logits: torch.Tensor,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    filter_value: float = -float("Inf"),
) -> torch.Tensor:
    """Apply top-k and/or nucleus (top-p) filtering along the last dimension.

    Filtered-out entries are replaced by ``filter_value``; the input tensor is
    not modified. Top-k runs first, so top-p sees the already-trimmed logits.

    Args:
        logits: (..., V) scores; any number of leading dimensions (e.g. (B, V)).
        top_k: Keep only the k largest logits per row (entries tied with the
            k-th value are kept). Ignored when ``None`` or outside ``[1, V)``.
        top_p: Keep the smallest set of tokens whose cumulative prob >= top_p.
            Ignored when ``None`` or outside the open interval ``(0, 1)``.
        filter_value: Value written over removed logits.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    vocab_size = logits.size(-1)

    # Top-k
    if top_k is not None and 1 <= top_k < vocab_size:
        kth_vals = torch.topk(logits, top_k, dim=-1).values[..., -1:]  # (..., 1)
        logits = logits.masked_fill(logits < kth_vals, filter_value)

    # Top-p (nucleus)
    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Drop tokens once the cumulative mass has passed top_p. Shifting the mask
        # right by one keeps the token that crosses the threshold, and the best
        # token is always kept.
        remove_sorted = cumprobs > top_p
        remove_sorted[..., 1:] = remove_sorted[..., :-1].clone()
        remove_sorted[..., 0] = False

        # Map the mask from sorted order back to the original vocabulary order.
        remove = torch.zeros_like(remove_sorted).scatter_(-1, sorted_indices, remove_sorted)
        logits = logits.masked_fill(remove, filter_value)

    return logits


In [None]:
#@title 🚂 Train Helpers

def train_one_epoch(model, loader, optim, device) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(x, y)`` that returns ``(_, loss)``.
        loader: Iterable of ``(x, y)`` tensor batches.
        optim: Optimizer that updates ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        The average of the per-batch losses, or NaN when ``loader`` yields no
        batches (rather than dividing by zero).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optim.zero_grad(set_to_none=True)
        _, loss = model(x, y)
        loss.backward()
        optim.step()
        total_loss += loss.item()
        num_batches += 1
    # Count batches ourselves: works for iterables without __len__ and lets an
    # empty pass be reported explicitly.
    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def decode(ids: List[int]) -> str:
    """Turn token ids back into text: look each id up in ``itos`` and join with spaces."""
    words = [itos[token_id] for token_id in ids]
    return " ".join(words)


In [None]:
#@title 🏁 Train & Generate

# Build the model from the shared config and move it to the target device.
model = DecoderOnlyTransformer(
    VOCAB_SIZE,
    CFG.n_embd,
    CFG.n_heads,
    CFG.n_layers,
    CFG.dropout,
    CFG.block_size,
)
model = model.to(CFG.device)

optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.lr)

# Each "iter" is one full pass over train_loader.
for it in range(1, CFG.max_iters + 1):
    loss = train_one_epoch(model, train_loader, optimizer, CFG.device)
    # Report every 50th pass and always the final one.
    if it % 50 != 0 and it != CFG.max_iters:
        continue
    print(f"iter {it:4d} | train loss {loss:.4f}")


iter   50 | train loss 0.0388
iter  100 | train loss 0.0336
iter  150 | train loss 0.0319
iter  200 | train loss 0.0309
iter  250 | train loss 0.0313
iter  300 | train loss 0.0313


In [None]:
# --- Generation demos ---
# Every demo starts from a single <SOS> token and samples 12 new tokens.
start = torch.tensor([[stoi["<SOS>"]]], dtype=torch.long, device=CFG.device)

# (banner, sampling settings) pairs, run in order.
demos = [
    (
        "\n=== Greedy-ish (temperature=1.0, no top-k/p) ===",
        {"temperature": 1.0, "top_k": None, "top_p": None},
    ),
    (
        "\n=== Temperature=0.8, top_k=5 ===",
        {"temperature": 0.8, "top_k": 5, "top_p": None},
    ),
    (
        "\n=== Temperature=1.2, top_p=0.9 (nucleus) ===",
        {"temperature": 1.2, "top_k": None, "top_p": 0.9},
    ),
]

for banner, sampling in demos:
    print(banner)
    generated = model.generate(start.clone(), max_new_tokens=12, **sampling)
    out = generated[0].tolist()
    print(out)
    print("Decoded:", decode(out))



=== Greedy-ish (temperature=1.0, no top-k/p) ===
[2, 6, 7, 9, 0, 2, 3, 4, 8, 11, 0, 2, 5]
Decoded: <SOS> Transformers are awesome <EOS> <SOS> I Love attention mechanisms <EOS> <SOS> Self

=== Temperature=0.8, top_k=5 ===
[2, 3, 4, 6, 0, 2, 6, 7, 9, 0, 2, 3, 4]
Decoded: <SOS> I Love Transformers <EOS> <SOS> Transformers are awesome <EOS> <SOS> I Love

=== Temperature=1.2, top_p=0.9 (nucleus) ===
[2, 3, 4, 8, 11, 0, 2, 5, 8, 10, 13, 12, 0]
Decoded: <SOS> I Love attention mechanisms <EOS> <SOS> Self attention learns token relations <EOS>
