# Building a Tiny Decoder-Only Transformer

Welcome! We'll re-create a GPT-style decoder with nothing more than PyTorch. The walkthrough is broken into:

- Core building blocks (normalization, attention, feed-forward)
- Assembling the full language model
- Training on Tiny Shakespeare and generating text
- Experimenting with sampling tricks (top-k, top-p, temperature)

Each section intentionally slows down to explain both the math and the PyTorch mechanics.


In [None]:
import sys
if 'google.colab' in sys.modules:
    print("Running in Google Colab. Installing necessary packages...")
    !git clone https://github.com/auliyafitri/dl_basics_for_institutes.git
    %cd dl_basics_for_institutes

In [None]:
# Pin the datasets package version to avoid transient download errors on Colab runtimes.
!pip install datasets==3.6.0
!pip install seaborn
!pip install transformers

In [None]:
import math
import torch
from torch import nn
from itertools import chain
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import Dataset, DataLoader
from dataclasses import dataclass
import matplotlib.pyplot as plt
import seaborn as sns


def set_seed(value: int = 42) -> None:
    """Seed PyTorch's random number generators so demo runs are repeatable.

    Args:
        value: Seed handed to the CPU generator and, when CUDA is available,
            to the generators of all CUDA devices.
    """
    torch.manual_seed(value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(value)


set_seed()


## Model Blueprint

We'll keep every hyperparameter in a single dataclass so it's easy to tweak knobs later on.


In [None]:
@dataclass
class ModelSettings:
    """Hyperparameters shared by every model built in this notebook."""

    # Width of the token/position embeddings and of each block's hidden state.
    embedding_dim: int = 512
    # Hidden width of the per-block MLP: four times the default embedding width.
    feedforward_dim: int = 2048
    # Number of stacked transformer blocks.
    blocks: int = 2
    # Dropout probability passed to the model's dropout layers.
    dropout: float = 0.1
    # Placeholder vocabulary size; overwritten once the tokenizer is loaded.
    vocab_size: int = 10000
    # Maximum number of token positions the model can embed.
    context_length: int = 128


cfg = ModelSettings()
cfg


### Dissecting a Transformer Block

Each decoder block follows a *norm → attention → residual → norm → MLP → residual* pattern. We'll rebuild every piece so the math never feels like a black box.


In [None]:
class CustomLayerNorm(nn.Module):
    """Hand-written layer normalization over the last (feature) dimension.

    Each feature vector is shifted to zero mean and scaled to unit variance,
    then passed through a learnable per-feature scale (``weight``) and shift
    (``bias``).
    """

    def __init__(self, width: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # Scale starts at one and shift at zero, so the layer begins as a
        # plain normalization.
        self.weight = nn.Parameter(torch.ones(width))
        self.bias = nn.Parameter(torch.zeros(width))

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        """Normalize ``tensor`` along its last axis and apply scale and shift."""
        mu = tensor.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as layer normalization uses.
        sigma_sq = torch.var(tensor, dim=-1, keepdim=True, unbiased=False)
        # eps keeps the division finite when a vector has zero variance.
        standardized = (tensor - mu) / (sigma_sq + self.eps).sqrt()
        return standardized * self.weight + self.bias


demo_norm = CustomLayerNorm(cfg.embedding_dim)
dummy = torch.randn(1, 3, cfg.embedding_dim)
print("LayerNorm demo output shape:", demo_norm(dummy).shape)


In [None]:
class SelfAttentionModule(nn.Module):
    """Causal single-head self-attention.

    Every position attends to itself and to earlier positions only; later
    positions are masked out before the softmax.
    """

    def __init__(self, width: int, dropout: float = 0.1):
        super().__init__()
        self.width = width
        # One fused projection yields queries, keys and values side by side.
        self.qkv_proj = nn.Linear(width, 3 * width)
        self.out_proj = nn.Linear(width, width)
        self.attn_dropout = nn.Dropout(dropout)
        self.out_dropout = nn.Dropout(dropout)

    def forward(self, tokens: torch.Tensor, return_weights: bool = False):
        """Attend over ``tokens`` of shape (batch, sequence, width).

        Args:
            tokens: Input embeddings whose last dimension equals ``width``.
            return_weights: When true, also return the attention weights
                (taken after attention dropout).

        Returns:
            The projected output, or ``(output, weights)`` when
            ``return_weights`` is true.

        Raises:
            ValueError: If the last dimension of ``tokens`` differs from
                ``width``.
        """
        _, seq_len, features = tokens.shape
        if features != self.width:
            raise ValueError(
                f"Expected embedding size {self.width}, but received {features}."
            )

        queries, keys, values = torch.chunk(self.qkv_proj(tokens), 3, dim=-1)

        # Scaled dot-product similarity between every query and every key.
        scale = self.width**-0.5
        scores = (queries @ keys.transpose(-2, -1)) * scale

        # Lower-triangular ones mark the allowed (past and present) positions;
        # everything else is set to -inf so softmax assigns it zero weight.
        visible = torch.tril(torch.ones(seq_len, seq_len, device=tokens.device))
        scores = scores.masked_fill(visible == 0, float("-inf"))

        weights = self.attn_dropout(scores.softmax(dim=-1))
        projected = self.out_dropout(self.out_proj(weights @ values))

        return (projected, weights) if return_weights else projected


Let's peek at the attention map on random data to confirm the masking behaves as expected.


In [None]:
torch.manual_seed(0)
attention = SelfAttentionModule(cfg.embedding_dim, dropout=0.0)
sample_tokens = torch.randn(1, 6, cfg.embedding_dim)
_, attn_weights = attention(sample_tokens, return_weights=True)
weights_np = attn_weights.squeeze(0).detach().numpy()

plt.figure(figsize=(4.5, 4))
sns.heatmap(weights_np, cmap="magma", cbar=True, square=True)
plt.title("Attention mask in action")
plt.xlabel("Key index")
plt.ylabel("Query index")
plt.show()


In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: expand to ``hidden``, apply GELU, project back, then dropout."""

    def __init__(self, width: int, hidden: int, dropout: float = 0.1):
        super().__init__()
        expand = nn.Linear(width, hidden)
        activation = nn.GELU()
        contract = nn.Linear(hidden, width)
        regularize = nn.Dropout(dropout)
        # Layers stay inside one Sequential, in this order, under ``net``.
        self.net = nn.Sequential(expand, activation, contract, regularize)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the MLP; the last dimension stays ``width``."""
        transformed = self.net(x)
        return transformed


With normalization, attention, and the feed-forward network ready, we can wire up the full decoder block.


In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm decoder block: ``x + attention(norm(x))`` then ``x + mlp(norm(x))``.

    Args:
        width: Size of the feature dimension flowing through the block.
        ff_hidden: Hidden width of the feed-forward sub-layer.
        dropout: Dropout probability handed to both sub-layers.
    """

    def __init__(self, width: int, ff_hidden: int, dropout: float = 0.1):
        super().__init__()
        self.norm_attn = CustomLayerNorm(width)
        self.attention = SelfAttentionModule(width, dropout)
        self.norm_ff = CustomLayerNorm(width)
        self.feedforward = FeedForward(width, ff_hidden, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and feed-forward sub-layers with residual connections.

        The residual adds are taken from the *un-normalized* stream ``x``.
        Normalization is applied only to what each sub-layer reads, so the
        block keeps an identity path from input to output.
        """
        x = x + self.attention(self.norm_attn(x))
        x = x + self.feedforward(self.norm_ff(x))
        return x


test_block = TransformerBlock(cfg.embedding_dim, cfg.feedforward_dim)
print(
    "Transformer block output shape:",
    test_block(torch.randn(1, 3, cfg.embedding_dim)).shape,
)


### Assembling the Language Model

A minimal GPT-style model stacks several decoder blocks, adds token & position embeddings, and projects back to vocabulary logits.


In [None]:
class MiniDecoderLM(nn.Module):
    """Minimal GPT-style language model.

    Token and learned position embeddings are summed, passed through a stack
    of ``TransformerBlock`` layers and a final layer norm, then projected to
    vocabulary logits.

    Args:
        vocab_size: Number of rows in the token embedding and of output logits.
        embedding_dim: Width of the embeddings and of every block.
        ff_hidden_dim: Hidden width of each block's feed-forward sub-layer.
        context_length: Number of positions the position embedding can encode.
        layers: Number of transformer blocks to stack.
        dropout: Dropout probability for the embeddings and the blocks.
    """

    def __init__(
        self,
        *,
        vocab_size: int,
        embedding_dim: int,
        ff_hidden_dim: int,
        context_length: int,
        layers: int,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_embeddings = nn.Embedding(vocab_size, embedding_dim)
        self.position_embeddings = nn.Embedding(context_length, embedding_dim)
        self.dropout = nn.Dropout(dropout)

        self.blocks = nn.ModuleList(
            [
                TransformerBlock(embedding_dim, ff_hidden_dim, dropout)
                for _ in range(layers)
            ]
        )
        self.norm = nn.LayerNorm(embedding_dim)
        self.lm_head = nn.Linear(embedding_dim, vocab_size)

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        """Map token ids of shape (batch, sequence) to logits (batch, sequence, vocab).

        Raises:
            ValueError: If the sequence is longer than the position embedding
                table, i.e. longer than ``context_length``.
        """
        _, seq_len = token_ids.shape

        # Fail early with a readable message instead of an out-of-range
        # embedding lookup further down.
        max_positions = self.position_embeddings.num_embeddings
        if seq_len > max_positions:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the model's context length "
                f"{max_positions}."
            )

        # Positions 0..seq_len-1, broadcast across the batch dimension.
        positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
        hidden = self.token_embeddings(token_ids) + self.position_embeddings(positions)
        hidden = self.dropout(hidden)

        for block in self.blocks:
            hidden = block(hidden)

        hidden = self.norm(hidden)
        return self.lm_head(hidden)


demo_ids = torch.randint(0, cfg.vocab_size, (1, 4))
toy_model = MiniDecoderLM(
    vocab_size=cfg.vocab_size,
    embedding_dim=cfg.embedding_dim,
    ff_hidden_dim=cfg.feedforward_dim,
    context_length=cfg.context_length,
    layers=cfg.blocks,
    dropout=cfg.dropout,
)
print(toy_model)
print("Logit tensor shape:", toy_model(demo_ids).shape)


## Training and Text Generation

Now that the architecture stands, let's prepare data, train the model, and see it generate text.


### Preparing the Dataset

We'll lean on the Tiny Shakespeare corpus because it's small, public, and perfect for quick experiments.


In [None]:
# Load raw text splits
train_raw = load_dataset("karpathy/tiny_shakespeare", split="train")
valid_raw = load_dataset("karpathy/tiny_shakespeare", split="validation")

# GPT-2 tokenizer keeps things simple and gives us byte-level BPEs
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Update configuration now that we know the vocabulary size
cfg.vocab_size = tokenizer.vocab_size


def tokenize_batch(batch):
    """Run the module-level ``tokenizer`` over the ``"text"`` field of ``batch``."""
    texts = batch["text"]
    return tokenizer(texts)


train_tokenized = train_raw.map(tokenize_batch, batched=True, remove_columns=["text"])
valid_tokenized = valid_raw.map(tokenize_batch, batched=True, remove_columns=["text"])

train_stream = list(chain.from_iterable(train_tokenized["input_ids"]))
valid_stream = list(chain.from_iterable(valid_tokenized["input_ids"]))[:2000]


class AutoregressiveDataset(Dataset):
    """Sliding-window view over a token stream for next-token prediction.

    Item ``i`` is the pair ``(stream[i : i + block_size],
    stream[i + 1 : i + block_size + 1])``, i.e. the target is the input
    shifted one token to the left.

    Args:
        stream: Sequence of integer token ids.
        block_size: Number of tokens in each input (and each target) window.
    """

    def __init__(self, stream, block_size):
        self.stream = stream
        self.block_size = block_size

    def __len__(self):
        # A window needs block_size + 1 tokens; a shorter stream holds none.
        # Clamp at zero because __len__ must never be negative.
        return max(0, len(self.stream) - self.block_size)

    def __getitem__(self, index):
        """Return the ``(input, target)`` LongTensor pair starting at ``index``.

        Raises:
            IndexError: If ``index`` is outside ``[-len(self), len(self))``.
                Without this, plain iteration over the dataset never stops
                and out-of-range indices yield truncated windows.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError(f"index out of range for dataset of length {size}")

        window = self.stream[index : index + self.block_size + 1]
        x = torch.tensor(window[:-1], dtype=torch.long)
        y = torch.tensor(window[1:], dtype=torch.long)
        return x, y


sequence_length = cfg.context_length
train_dataset = AutoregressiveDataset(train_stream, sequence_length)
valid_dataset = AutoregressiveDataset(valid_stream, sequence_length)

train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True, drop_last=True, pin_memory=True
)
valid_loader = DataLoader(
    valid_dataset, batch_size=16, shuffle=False, drop_last=True, pin_memory=True
)

print(f"Batches in training loader: {len(train_loader)}")
print(f"Batches in validation loader: {len(valid_loader)}")


We'll also decide where checkpoints should live so we can resume or perform inference later.


In [None]:
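# Choose where to persist checkpoints. This relative path is resolved against the
# current working directory, so on Colab the file lives inside the temporary
# runtime and is lost when the runtime is recycled.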
model_checkpoint_path = "pretrained/mini_llm.pth"


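We're ready to launch the training loop. Feel free to interrupt the cell once you've seen enough iterations—the best checkpoint so far is written to disk at every validation check that improves the loss (checks run every 500 steps, so let at least one complete before interrupting).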


In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython.display import clear_output, display


def plot_history(train_steps, train_losses, val_steps, val_losses):
    """Redraw the loss curves in place of the cell's previous output.

    Args:
        train_steps: Step indices at which training loss was logged.
        train_losses: Training loss values matching ``train_steps``.
        val_steps: Step indices at which validation ran.
        val_losses: Validation loss values matching ``val_steps``.
    """
    # Replace the earlier plot instead of stacking a new one below it.
    clear_output(wait=True)
    plt.clf()

    curves = (
        (train_steps, train_losses, "Train loss", "o"),
        (val_steps, val_losses, "Validation loss", "x"),
    )
    for steps, losses, label, marker in curves:
        plt.plot(steps, losses, label=label, marker=marker)

    plt.xlabel("Step")
    plt.ylabel("Loss")
    plt.title("Training progress")
    plt.grid(True)
    plt.legend()

    figure = plt.gcf()
    display(figure)
    plt.close()


def _mean_validation_loss(model, val_loader, criterion, device):
    """Return the average per-batch loss of ``model`` over ``val_loader``.

    The model is switched to eval mode for the pass and back to training mode
    afterwards. An empty loader yields 0.0.
    """
    model.eval()
    total, batches = 0.0, 0
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.reshape(-1))
            total += loss.item()
            batches += 1
    model.train()
    return total / max(1, batches)


def run_training(
    model,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    epochs,
    *,
    device,
    log_every=100,
    eval_every=500,
    patience=5,
):
    """Train ``model`` with periodic logging, validation and early stopping.

    Whenever the validation loss improves, the model's ``state_dict`` is saved
    to the module-level ``model_checkpoint_path``.

    Args:
        model: Module mapping a batch of token ids to logits.
        train_loader: Iterable of ``(inputs, targets)`` training batches.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking flattened logits and flattened targets.
        epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to.
        log_every: Average and record the training loss every this many steps.
        eval_every: Run validation every this many steps.
        patience: Number of consecutive validation checks without improvement
            after which training stops.
    """
    import os

    # torch.save does not create missing directories, so make sure the target
    # folder exists before the first checkpoint is written.
    checkpoint_dir = os.path.dirname(model_checkpoint_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    model.train()
    global_step = 0
    loss_buffer = []
    train_steps, train_losses = [], []
    val_steps, val_losses = [], []
    best_val_loss = None
    checks_without_improvement = 0

    plt.figure(figsize=(8, 5))

    for epoch in range(epochs):
        for inputs, targets in tqdm(train_loader, desc=f"Epoch {epoch}"):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(inputs)
            loss = criterion(logits.view(-1, logits.size(-1)), targets.reshape(-1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_buffer.append(loss.item())
            global_step += 1
            refresh_plot = False

            if global_step % log_every == 0:
                # Record the mean loss since the previous log point.
                train_steps.append(global_step)
                train_losses.append(sum(loss_buffer) / len(loss_buffer))
                loss_buffer.clear()
                refresh_plot = True

            if global_step % eval_every == 0:
                val_loss = _mean_validation_loss(model, val_loader, criterion, device)

                # Record first, so the final plot also shows the check that
                # triggers early stopping.
                val_steps.append(global_step)
                val_losses.append(val_loss)
                refresh_plot = True

                if best_val_loss is None or val_loss < best_val_loss:
                    best_val_loss = val_loss
                    torch.save(model.state_dict(), model_checkpoint_path)
                    checks_without_improvement = 0
                else:
                    checks_without_improvement += 1
                    # Stop once `patience` consecutive checks failed to improve.
                    if checks_without_improvement >= patience:
                        print(
                            "Early stopping triggered after "
                            f"{checks_without_improvement} unimproved checks."
                        )
                        plot_history(train_steps, train_losses, val_steps, val_losses)
                        return

            if refresh_plot:
                plot_history(train_steps, train_losses, val_steps, val_losses)

    print("Training complete.")


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

model = MiniDecoderLM(
    vocab_size=cfg.vocab_size,
    embedding_dim=cfg.embedding_dim,
    ff_hidden_dim=cfg.feedforward_dim,
    context_length=cfg.context_length,
    layers=cfg.blocks,
    dropout=cfg.dropout,
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

run_training(
    model,
    train_loader,
    valid_loader,
    optimizer,
    criterion,
    epochs=1,
    device=device,
    log_every=100,
    eval_every=500,
    patience=5,
)


### Using the Trained Model


Load the checkpoint and rebuild the model with the exact same configuration.


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# The architecture has to match the one that produced the checkpoint, otherwise
# the saved tensors will not line up with the freshly built parameters.
model = MiniDecoderLM(
    vocab_size=cfg.vocab_size,
    embedding_dim=cfg.embedding_dim,
    ff_hidden_dim=cfg.feedforward_dim,
    context_length=cfg.context_length,
    layers=cfg.blocks,
    dropout=cfg.dropout,
).to(device)

# The checkpoint is a plain state_dict, so restrict unpickling to tensors and
# basic containers rather than allowing arbitrary pickled objects.
state_dict = torch.load(model_checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()


#### Greedy Decoding


In [None]:
def greedy_decode(
    model, input_ids, max_len, *, device=torch.device("cpu"), context_length=None
):
    """Extend ``input_ids`` by always appending the highest-scoring next token.

    Args:
        model: Callable mapping token ids (batch, sequence) to logits
            (batch, sequence, vocab). It is put into eval mode.
        input_ids: Prompt token ids of shape (batch, sequence).
        max_len: Total sequence length to reach, prompt included. Nothing is
            generated if the prompt is already that long.
        device: Device the prompt is moved to before decoding.
        context_length: Optional cap on how many of the most recent tokens are
            fed to the model at each step. ``None`` feeds the whole sequence.

    Returns:
        Tensor of shape (batch, max(prompt length, max_len)) holding the
        prompt followed by the generated tokens.

    Raises:
        ValueError: If ``context_length`` is given but not positive.
    """
    if context_length is not None and context_length <= 0:
        raise ValueError(
            f"context_length must be a positive integer, received {context_length}"
        )

    model.eval()
    generated = input_ids.to(device)

    # Decoding needs no gradients; skipping autograd avoids holding a graph
    # for every generated step.
    with torch.no_grad():
        for _ in range(max_len - generated.size(1)):
            window = (
                generated if context_length is None else generated[:, -context_length:]
            )
            logits = model(window)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=-1)

    return generated


prompt_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long)
prompt_ids = prompt_ids.to(device)

greedy_ids = greedy_decode(model, prompt_ids, max_len=128, device=device)
text = tokenizer.decode(greedy_ids[0].tolist(), skip_special_tokens=True)

print("==== Greedy decode ====")
print(text)


#### Sampling

Sampling adds controlled randomness to generation so outputs don't collapse to the same completion every time.


In [None]:
def apply_top_k(logits, top_k):
    """Keep the ``top_k`` largest logits along the last axis; set the rest to -inf.

    Args:
        logits: Tensor of scores whose last dimension indexes the vocabulary.
        top_k: Number of entries to keep. Values larger than the vocabulary
            size keep everything.

    Returns:
        A new tensor with entries below the k-th largest value replaced by
        ``-inf``. Entries tied with the k-th largest value are kept.

    Raises:
        ValueError: If ``top_k`` is not a positive integer.
    """
    if not isinstance(top_k, int) or top_k <= 0:
        raise ValueError(f"top_k must be a positive integer, received {top_k}")

    # torch.topk rejects k larger than the dimension, so clamp it first.
    k = min(top_k, logits.size(-1))
    threshold = torch.topk(logits, k, dim=-1).values[..., -1, None]
    return logits.masked_fill(logits < threshold, float("-inf"))


def apply_top_p(logits, top_p):
    """Nucleus filtering: keep the smallest set of tokens whose mass reaches ``top_p``.

    Tokens are ranked by probability. A token is dropped only if the tokens
    ranked above it already sum to more than ``top_p``, so the token that
    crosses the threshold is kept and the retained mass is at least ``top_p``.
    The highest-ranked token is always kept.

    Args:
        logits: Tensor of scores whose last dimension indexes the vocabulary.
        top_p: Cumulative probability threshold in ``(0, 1]``.

    Returns:
        A new tensor shaped like ``logits`` with dropped entries set to ``-inf``.

    Raises:
        ValueError: If ``top_p`` is outside ``(0, 1]``.
    """
    if not (0.0 < top_p <= 1.0):
        raise ValueError(f"top_p must lie in (0, 1], received {top_p}")

    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cumulative = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # `crossed[i]` means the mass up to and including rank i exceeds top_p.
    # Shifting it one rank to the right drops a token only when the ranks
    # *before* it were already enough, and leaves rank 0 untouched.
    crossed = cumulative > top_p
    drop = torch.zeros_like(crossed)
    drop[..., 1:] = crossed[..., :-1]
    sorted_logits = sorted_logits.masked_fill(drop, float("-inf"))

    # Undo the sort so every logit returns to its original vocabulary slot.
    restored = torch.full_like(logits, float("-inf"))
    restored.scatter_(-1, sorted_indices, sorted_logits)
    return restored


In [None]:
def sample_text(
    model,
    input_ids,
    max_len,
    *,
    top_k=None,
    top_p=None,
    temperature=1.0,
    device=torch.device("cpu"),
    context_length=None,
):
    """Extend ``input_ids`` by sampling each next token from the model.

    At every step the last-position logits are divided by ``temperature``,
    optionally filtered with top-k and then top-p, turned into probabilities
    and sampled once per batch row.

    Args:
        model: Callable mapping token ids (batch, sequence) to logits
            (batch, sequence, vocab). It is put into eval mode.
        input_ids: Prompt token ids of shape (batch, sequence).
        max_len: Total sequence length to reach, prompt included.
        top_k: If given, keep only this many highest logits before sampling.
        top_p: If given, apply nucleus filtering with this threshold.
        temperature: Positive divisor for the logits; values below 1 sharpen
            the distribution, values above 1 flatten it.
        device: Device the prompt is moved to before sampling.
        context_length: Optional cap on how many of the most recent tokens are
            fed to the model at each step. ``None`` feeds the whole sequence.

    Returns:
        Tensor holding the prompt followed by the sampled tokens.

    Raises:
        ValueError: If ``temperature`` is not positive, or ``context_length``
            is given but not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, received {temperature}")
    if context_length is not None and context_length <= 0:
        raise ValueError(
            f"context_length must be a positive integer, received {context_length}"
        )

    model.eval()
    generated = input_ids.to(device)

    # Sampling needs no gradients; skipping autograd avoids holding a graph
    # for every generated step.
    with torch.no_grad():
        for _ in range(max_len - generated.size(1)):
            window = (
                generated if context_length is None else generated[:, -context_length:]
            )
            logits = model(window)[:, -1, :] / temperature

            if top_k is not None:
                logits = apply_top_k(logits, top_k)
            if top_p is not None:
                logits = apply_top_p(logits, top_p)

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=-1)

    return generated


prompt_ids = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)

sampled_ids = sample_text(
    model,
    prompt_ids,
    max_len=128,
    top_k=3,
    top_p=0.9,
    temperature=1.0,
    device=device,
)

print("==== Sampled decode ====")
print(tokenizer.decode(sampled_ids[-1].tolist(), skip_special_tokens=True))


## Further Reading

- [Hugging Face GPT-2 implementation](https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py) for a production-grade reference.
- [The Illustrated GPT-2](https://jalammar.github.io/illustrated-gpt2/) for a visual introduction to the architecture.
- [llm.c](https://github.com/karpathy/llm.c) for a minimalist C implementation that mirrors these concepts.
- Vaswani et al., "Attention Is All You Need" — the original Transformer paper.
