# Mini GPT-style Transformer

This notebook walks through the implementation of a **minimal GPT-style transformer model**. It's intended for hands-on experimentation and learning about:
- Token embeddings
- Causal attention masks
- Transformer layers
- Sampling and generation

**Try changing the vocabulary, hidden size, or sampling strategy to see how the model behaves!**

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd

# Device string used when moving the model and token tensors:
# 'cuda' when a CUDA device is available, otherwise 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
# Toy vocabulary: three special tokens followed by a handful of everyday words.
vocab = ['<pad>', '<bos>', '<eos>'] + '''hello world how are you i am doing well thanks bye today tomorrow what is your name my nice good bad okay yes no sure maybe friend not really'''.split()

# A duplicated word would collapse to a single entry in `stoi`, so the model
# could emit an id that maps back to the wrong (or no) token. Fail loudly here
# instead of with a confusing error in the middle of sampling.
assert len(set(vocab)) == len(vocab), 'vocab contains duplicate tokens'

# token -> id lookup.
stoi = {s: i for i, s in enumerate(vocab)}
# id -> token lookup, built straight from `vocab` so that every id in
# range(VOCAB_SIZE) is guaranteed to have an entry.
itos = dict(enumerate(vocab))

VOCAB_SIZE = len(vocab)
print(f'Vocabulary size: {VOCAB_SIZE}')

Vocabulary size: 32


In [4]:
# Model hyperparameters
D_MODEL = 64  # embedding / hidden width, passed to the model as `d_model`
N_HEADS = 4  # attention heads per layer (nn.MultiheadAttention needs D_MODEL divisible by N_HEADS)
N_LAYERS = 4  # number of stacked transformer layers
MAX_LEN = 20  # number of learned positions; also the length cut-off used by the sampling loops

In [33]:
def generate_causal_mask(size):
    """Return an additive causal attention mask of shape ``(size, size)``.

    Entry ``[i, j]`` is ``0.`` when ``j <= i`` (query position ``i`` may look
    at key position ``j``) and ``-inf`` when ``j > i`` (future positions are
    blocked). The mask is a float tensor created on the CPU.
    """
    blocked = torch.full((size, size), float('-inf'))
    # triu keeps the -inf entries strictly above the main diagonal and sets
    # everything on or below it to zero.
    return torch.triu(blocked, diagonal=1)

# Show example
generate_causal_mask(5)

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])

In [6]:
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """``nn.TransformerEncoderLayer`` that also records its attention weights.

    The stock layer discards the attention probabilities. This subclass
    re-implements ``forward`` so that the per-head weights of the most recent
    call are kept in ``self.attn_weights`` for inspection/plotting.

    Constructor arguments are forwarded unchanged to
    ``nn.TransformerEncoderLayer``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Per-head attention weights from the last forward pass
        # (None until the layer has been called once).
        self.attn_weights = None

    def forward(self, src, src_mask=None, src_key_padding_mask=None, is_causal=False):
        """Run self-attention + feed-forward and remember the attention weights.

        Args:
            src: input sequence tensor.
            src_mask: optional attention mask passed to the attention module
                as ``attn_mask``.
            src_key_padding_mask: optional padding mask passed as
                ``key_padding_mask``.
            is_causal: accepted only so the signature matches the parent
                class (``nn.TransformerEncoder`` passes it as a keyword).
                It is not used: causality is enforced solely by ``src_mask``.

        Returns:
            The transformed sequence, same shape as ``src``.
        """
        # Honour the parent's `norm_first` option instead of silently always
        # running post-norm. getattr keeps this working on torch versions
        # that predate the attribute.
        norm_first = getattr(self, 'norm_first', False)

        # --- self-attention sub-layer ---
        attn_in = self.norm1(src) if norm_first else src
        attn_out, attn_weights = self.self_attn(
            attn_in, attn_in, attn_in,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
            need_weights=True,
            average_attn_weights=False  # keep one weight matrix per head
        )
        self.attn_weights = attn_weights
        src = src + self.dropout1(attn_out)
        if not norm_first:
            src = self.norm1(src)

        # --- position-wise feed-forward sub-layer ---
        ff_in = self.norm2(src) if norm_first else src
        ff_out = self.linear2(self.dropout(self.activation(self.linear1(ff_in))))
        src = src + self.dropout2(ff_out)
        if not norm_first:
            src = self.norm2(src)
        return src

In [7]:
class MiniGPT(nn.Module):
    """Minimal decoder-only (GPT-style) language model.

    Token embeddings plus learned positional embeddings are passed through a
    stack of causally-masked transformer layers and projected back to
    vocabulary logits.

    Args:
        vocab_size: number of distinct token ids.
        d_model: embedding / hidden width.
        n_heads: attention heads per layer.
        n_layers: number of stacked transformer layers.
        max_len: longest sequence (in tokens) the model can process.
    """

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embedding, zero-initialised, one row per position.
        self.pos_embedding = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.encoder_layers = nn.ModuleList([
            MyTransformerEncoderLayer(d_model, n_heads, batch_first=True)
            for _ in range(n_layers)
        ])
        self.output_layer = nn.Linear(d_model, vocab_size)
        # Attention weights of the final layer from the last forward pass.
        self.attn_weights = None

    def forward(self, x):
        """Compute next-token logits for a batch of token ids.

        Args:
            x: integer tensor of token ids with shape ``(batch, seq_len)``.

        Returns:
            Tensor of logits with shape ``(batch, seq_len, vocab_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the number of positions the
                model was built with.
        """
        _, T = x.size()
        max_len = self.pos_embedding.size(1)
        if T > max_len:
            # Without this check the addition below fails with an opaque
            # broadcasting error.
            raise ValueError(
                f'sequence length {T} exceeds the model maximum of {max_len}'
            )
        x = self.token_embedding(x) + self.pos_embedding[:, :T, :]
        mask = generate_causal_mask(T).to(x.device)
        for layer in self.encoder_layers:
            x = layer(x, src_mask=mask)
            # Overwritten on every iteration, so after the loop this holds
            # the weights of the last layer only.
            self.attn_weights = layer.attn_weights
        return self.output_layer(x)

In [51]:
def sample(model, prompt, max_new_tokens=5, temperature=1.0):
    """Autoregressively sample a continuation of ``prompt``.

    Args:
        model: module mapping a ``(1, T)`` tensor of token ids to logits of
            shape ``(1, T, vocab)``. It is switched to eval mode and left
            there.
        prompt: non-empty sequence of token strings, all present in ``stoi``.
        max_new_tokens: upper bound on the number of tokens to generate.
        temperature: softmax temperature; must be strictly positive. Values
            below 1 sharpen the distribution, values above 1 flatten it.

    Returns:
        The prompt followed by the generated tokens, as a list of strings.
        Generation stops early when ``<eos>`` is drawn (it is not appended)
        or once the sequence is longer than ``MAX_LEN``. A sequence of exactly
        ``MAX_LEN`` tokens is still fed to the model, so the result can hold
        up to ``MAX_LEN + 1`` tokens.

    Raises:
        ValueError: if ``prompt`` is empty or ``temperature`` is not positive.
    """
    if len(prompt) == 0:
        raise ValueError('prompt must contain at least one token')
    if temperature <= 0:
        # Dividing by zero would produce inf/NaN probabilities.
        raise ValueError('temperature must be > 0')

    model.eval()
    # Put the tokens wherever the model actually lives; fall back to the
    # notebook-wide DEVICE for a parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    tokens = torch.tensor(
        [stoi[t] for t in prompt], dtype=torch.long, device=device
    ).unsqueeze(0)

    # Generation never backpropagates, so do not build an autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if tokens.size(1) > MAX_LEN:
                break
            logits = model(tokens)
            # Only the prediction at the last position is used.
            next_token_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            if itos[next_token.item()] == '<eos>':
                break
            tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)
    return [itos[i] for i in tokens[0].tolist()]


def sample_verbose(model, prompt, max_new_tokens=5, temperature=1.0):
    """Sample a continuation of ``prompt`` and record every decoding step.

    Args:
        model: module mapping a ``(1, T)`` tensor of token ids to logits of
            shape ``(1, T, vocab)``. It is switched to eval mode and left
            there.
        prompt: non-empty sequence of token strings, all present in ``stoi``.
        max_new_tokens: upper bound on the number of tokens to generate.
        temperature: softmax temperature; must be strictly positive.

    Returns:
        ``(final_tokens, steps)`` where ``final_tokens`` is the prompt plus
        generated tokens as strings, and ``steps`` is a list with one dict
        per sampled token holding ``token_index``, ``token_str``, and the
        CPU tensors ``logits`` (temperature-scaled) and ``probs``. A drawn
        ``<eos>`` is recorded in ``steps`` but not appended to
        ``final_tokens``.

    Raises:
        ValueError: if ``prompt`` is empty or ``temperature`` is not positive.
    """
    if len(prompt) == 0:
        raise ValueError('prompt must contain at least one token')
    if temperature <= 0:
        # Dividing by zero would produce inf/NaN probabilities.
        raise ValueError('temperature must be > 0')

    model.eval()
    # Put the tokens wherever the model actually lives; fall back to the
    # notebook-wide DEVICE for a parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    tokens = torch.tensor(
        [stoi[t] for t in prompt], dtype=torch.long, device=device
    ).unsqueeze(0)

    steps = []

    # Generation never backpropagates, so do not build an autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if tokens.size(1) > MAX_LEN:
                break

            logits = model(tokens)
            # Only the prediction at the last position is used.
            next_token_logits = logits[0, -1, :] / temperature
            probs = F.softmax(next_token_logits, dim=-1)

            next_token = torch.multinomial(probs, num_samples=1)
            selected_token = next_token.item()
            token_str = itos[selected_token]

            # Save step info
            steps.append({
                'token_index': selected_token,
                'token_str': token_str,
                'logits': next_token_logits.detach().cpu(),
                'probs': probs.detach().cpu(),
            })

            if token_str == '<eos>':
                break

            tokens = torch.cat([tokens, next_token.unsqueeze(0)], dim=1)

    final_tokens = [itos[i] for i in tokens[0].tolist()]
    return final_tokens, steps


def sample_top_k(model, prompt, max_new_tokens=5, temperature=1.0, top_k=5):
    """Sample a continuation of ``prompt``, restricted to the ``top_k`` logits.

    At every step only the ``top_k`` highest-scoring tokens are kept, their
    logits are renormalised with a softmax, and the next token is drawn from
    that reduced distribution. ``top_k=1`` is greedy decoding.

    Args:
        model: module mapping a ``(1, T)`` tensor of token ids to logits of
            shape ``(1, T, vocab)``. It is switched to eval mode and left
            there.
        prompt: non-empty sequence of token strings, all present in ``stoi``.
        max_new_tokens: upper bound on the number of tokens to generate.
        temperature: softmax temperature; must be strictly positive.
        top_k: number of candidates kept per step; must be at least 1.
            Values larger than the vocabulary are clamped to its size.

    Returns:
        ``(final_tokens, steps)`` where ``final_tokens`` is the prompt plus
        generated tokens as strings, and ``steps`` is a list with one dict
        per sampled token holding ``token_index``, ``token_str``, and the
        CPU tensors ``logits`` (temperature-scaled, full vocabulary) and
        ``probs`` (zero everywhere except the kept candidates).

    Raises:
        ValueError: if ``prompt`` is empty, ``temperature`` is not positive,
            or ``top_k`` is smaller than 1.
    """
    if len(prompt) == 0:
        raise ValueError('prompt must contain at least one token')
    if temperature <= 0:
        # Dividing by zero would produce inf/NaN logits.
        raise ValueError('temperature must be > 0')
    if top_k < 1:
        raise ValueError('top_k must be >= 1')

    model.eval()
    # Put the tokens wherever the model actually lives; fall back to the
    # notebook-wide DEVICE for a parameter-free model.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else DEVICE
    tokens = torch.tensor(
        [stoi[t] for t in prompt], dtype=torch.long, device=device
    ).unsqueeze(0)

    steps = []

    # Generation never backpropagates, so do not build an autograd graph.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            if tokens.size(1) > MAX_LEN:
                break

            logits = model(tokens)
            # Only the prediction at the last position is used.
            next_token_logits = logits[0, -1, :] / temperature

            # Top-k filtering; torch.topk raises if k exceeds the vocabulary.
            k = min(top_k, next_token_logits.size(-1))
            topk_logits, topk_indices = torch.topk(next_token_logits, k)
            topk_probs = F.softmax(topk_logits, dim=-1)
            next_token_idx = torch.multinomial(topk_probs, num_samples=1)
            # Map the position within the top-k list back to a vocabulary id.
            selected_token = topk_indices[next_token_idx].item()
            token_str = itos[selected_token]

            # Save step info
            full_probs = torch.zeros_like(next_token_logits)
            full_probs[topk_indices] = topk_probs

            steps.append({
                'token_index': selected_token,
                'token_str': token_str,
                'logits': next_token_logits.detach().cpu(),
                'probs': full_probs.detach().cpu(),  # filled with zeros except top-k
            })

            if token_str == '<eos>':
                break

            tokens = torch.cat(
                [tokens, torch.tensor([[selected_token]], device=tokens.device)],
                dim=1,
            )

    final_tokens = [itos[i] for i in tokens[0].tolist()]
    return final_tokens, steps


In [26]:
def show_generation_steps(steps, top_k=10):
    """Display a table and a bar chart for every recorded decoding step.

    Args:
        steps: list of step dicts with the keys ``'logits'``, ``'probs'``
            and ``'token_str'``, as produced by the verbose samplers.
        top_k: number of highest-probability tokens shown per step.
    """
    for step_number, record in enumerate(steps, start=1):
        step_logits = record['logits']
        step_probs = record['probs']

        # One row per vocabulary entry, then keep the most probable ones.
        table = pd.DataFrame({
            'token': [itos[idx] for idx in range(len(step_logits))],
            'logits': step_logits.tolist(),
            'probs': step_probs.tolist()
        })
        table = table.sort_values('probs', ascending=False)
        table = table.head(top_k)

        print(f"\nStep {step_number}: Generated token = `{record['token_str']}`")
        display(table)

        # Logits and probabilities side by side for the displayed tokens.
        table.plot(x='token', y=['logits', 'probs'], kind='bar', figsize=(8, 4))
        plt.grid(True)
        plt.title(f"Step {step_number}: Logits vs Probs")
        plt.tight_layout()
        plt.show()


In [54]:
# Build a fresh model; no training step runs before sampling in this cell.
model = MiniGPT(VOCAB_SIZE, D_MODEL, N_HEADS, N_LAYERS, MAX_LEN).to(DEVICE)
prompt = ['<bos>', 'hello']
# Plain temperature sampling (this result is overwritten by the next two calls).
final_output = sample(model, prompt, max_new_tokens=5, temperature=1.0)
# top_k=1 keeps a single candidate per step, i.e. the highest-logit token.
final_output, steps = sample_top_k(model, prompt, max_new_tokens=5, temperature=1.0, top_k=1)
# A very low temperature concentrates the probability mass on the top logits.
final_output, steps = sample_verbose(model, prompt, max_new_tokens=5, temperature=0.01)
# Only the output of the last call (sample_verbose) is printed.
print("Generated:", final_output)
# show_generation_steps(steps, 100)


Generated: ['<bos>', 'hello', 'my', 'maybe', 'well', 'are', 'no']
