## Imports

In [73]:
import math
import numpy as np
import requests
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset


# Pick the fastest available backend: an NVIDIA GPU via CUDA (e.g. on Colab),
# then Apple-silicon GPUs via MPS (e.g. an M4 Pro), otherwise the CPU.
# (The previously commented-out MPS branch selected "cuda" by mistake, which
# fails on a Mac; MPS must be requested as "mps".)
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"


print("Selected device:", device)

Selected device: cpu


## Download the Shakespeare Dataset

In [74]:
# Download the Tiny Shakespeare corpus as a single string.
url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
# A timeout keeps the notebook from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)
# Fail loudly on HTTP errors instead of silently training on an error page.
response.raise_for_status()
text = response.text

print(f"Dataset length: {len(text)} characters")
print(f"First 200 characters:\n{text[:200]}")

Dataset length: 1115394 characters
First 200 characters:
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


In [75]:
class Config:
    """Model, training and regularization hyperparameters.

    ``vocab_size`` is not defined here: it is assigned onto the instance once
    the dataset has been built and its vocabulary is known.
    """

    # Model architecture
    n_layer = 12  # number of transformer blocks
    n_head = 8  # attention heads per block; must divide n_embd
    n_embd = 768  # embedding / residual-stream width
    block_size = 128  # maximum context length (in characters)

    # Training
    learning_rate = 3e-4
    batch_size = 128
    eval_iters = 200  # number of random batches averaged by estimate_loss

    # Regularization
    dropout = 0.0  # e.g. 0.2 to regularize; 0.0 disables dropout
    grad_clip = 1.0  # presumably the max grad-norm used by the training loop (not shown)


config = Config()

## Character Dataset Class

In [76]:
class CharDataset(Dataset):
    """Character-level dataset for language modeling.

    The vocabulary is the sorted set of distinct characters in ``data``.
    Item ``idx`` is a pair ``(x, y)`` of LongTensors of length
    ``config.block_size``: ``x`` encodes ``data[idx : idx + block_size]`` and
    ``y`` is the same window shifted one character to the right, so ``y[t]``
    is the next-character target for ``x[: t + 1]``.
    """

    def __init__(self, config, data):
        self.config = config
        self.data = data

        # Unique characters, sorted so the encoding is deterministic.
        chars = sorted(set(data))
        self.vocab_size = len(chars)

        print(f"Vocabulary size: {self.vocab_size}")
        print(f"Unique characters: {''.join(chars)}")

        # Character <-> index lookup tables.
        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}

    def get_vocab_size(self):
        """Return the number of distinct characters in the corpus."""
        return self.vocab_size

    def __len__(self):
        """Return the number of start positions with room for block_size + 1 chars."""
        # Clamp at 0: a corpus shorter than block_size would otherwise report a
        # negative length, which makes Python's len() raise ValueError.
        return max(0, len(self.data) - self.config.block_size)

    def __getitem__(self, idx):
        """Return the (input, target) pair of encoded windows starting at ``idx``."""
        # block_size + 1 characters: inputs and targets overlap by block_size.
        chunk = self.data[idx : idx + self.config.block_size + 1]
        dix = self.encode(chunk)

        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

    def encode(self, text):
        """Encode a string to a list of integer indices."""
        return [self.stoi[ch] for ch in text]

    def decode(self, indices):
        """Decode an iterable of indices (ints, or a 1-D tensor/array) to a string."""
        # int() lets callers pass tensor elements directly; a 0-d tensor is not
        # a valid key for the int-keyed itos dict.
        return "".join(self.itos[int(i)] for i in indices)


# Build the character dataset over the full corpus and record its vocabulary
# size on the shared config, which GPT uses to size wte and lm_head.
dataset = CharDataset(config, text)
config.vocab_size = dataset.vocab_size

Vocabulary size: 65
Unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


## Causal multi-head self-attention mechanism

In [77]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention where each position attends only to itself
    and earlier positions.

    Input and output have shape ``(batch, seq_len, n_embd)``; ``seq_len`` must
    not exceed ``config.block_size``, the size of the precomputed mask.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # A single fused projection yields queries, keys and values at once.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Mixes the concatenated head outputs back into n_embd channels.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)

        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd

        # Lower-triangular ones: entry [i, j] is 1 exactly when j <= i, i.e.
        # query i may look at key j. Registered as a buffer (name "bias").
        size = config.block_size
        lower = torch.tril(torch.ones(size, size))
        self.register_buffer("bias", lower.view(1, 1, size, size))

    def forward(self, x):
        batch, seq_len, channels = x.shape
        head_dim = channels // self.n_head

        def to_heads(t):
            # (batch, seq_len, channels) -> (batch, n_head, seq_len, head_dim)
            return t.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)

        q, k, v = map(to_heads, self.c_attn(x).split(self.n_embd, dim=2))

        # Scaled dot-product scores: (batch, n_head, seq_len, seq_len).
        scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
        visible = self.bias[:, :, :seq_len, :seq_len]
        scores = scores.masked_fill(visible == 0, float("-inf"))
        weights = self.attn_dropout(F.softmax(scores, dim=-1))

        mixed = weights @ v  # (batch, n_head, seq_len, head_dim)
        # Place the heads side by side again in one channel dimension.
        mixed = mixed.transpose(1, 2).contiguous().view(batch, seq_len, channels)

        return self.resid_dropout(self.c_proj(mixed))

## Feed-Forward Network

In [78]:
class MLP(nn.Module):
    """Position-wise feed-forward network.

    Expands each position to ``4 * n_embd`` channels, applies GELU, projects
    back to ``n_embd`` and finally applies dropout.
    """

    def __init__(self, config):
        super().__init__()
        width = 4 * config.n_embd
        self.c_fc = nn.Linear(config.n_embd, width)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(width, config.n_embd)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        expanded = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(expanded))

## Transformer block

In [79]:
class Block(nn.Module):
    """One pre-LayerNorm transformer layer.

    Causal self-attention ("communication") followed by an MLP
    ("computation"); each sub-layer sees a LayerNorm-ed input and its output
    is added back onto the residual stream.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        after_attn = x + self.attn(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))

## Mini GPT

In [80]:
class GPT(nn.Module):
    """GPT-style decoder-only language model.

    Learned token and position embeddings are summed and passed through
    ``config.n_layer`` transformer ``Block``s and a final LayerNorm. The
    result is projected to ``config.vocab_size`` logits by ``lm_head``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.vocab_size, config.n_embd),  # token embeddings
                wpe=nn.Embedding(config.block_size, config.n_embd),  # position embeddings
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # N(0, 0.02) for Linear/Embedding weights, zeros for Linear biases.
        self.apply(self._init_weights)

        # Scaled init for the projections that write into the residual stream
        # (attention and MLP ``c_proj``). Each block adds two such outputs to
        # the residual sum, so their std is shrunk by 1 / sqrt(2 * n_layer).
        for pn, p in self.named_parameters():
            if pn.endswith("c_proj.weight"):
                torch.nn.init.normal_(
                    p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer)
                )

        print(f"Number of parameters: {self.get_num_params()/1e6:.2f}M")

    def get_num_params(self):
        """Return the total number of parameters, embeddings included."""
        return sum(p.numel() for p in self.parameters())

    def _init_weights(self, module):
        """Initialize a single submodule; applied to every module via ``self.apply``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        """Compute next-token logits (and optionally the loss) for a batch.

        Args:
            idx: LongTensor of token indices, shape (b, t), t <= block_size.
            targets: optional LongTensor of next-token indices, shape (b, t).
                It need not be contiguous.

        Returns:
            ``(logits, loss)``: logits of shape (b, t, vocab_size), and the mean
            cross-entropy over all positions, or ``None`` without targets.
        """
        _, t = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        pos = torch.arange(0, t, dtype=torch.long, device=idx.device).unsqueeze(0)  # (1, t)

        tok_emb = self.transformer.wte(idx)  # (b, t, n_embd)
        pos_emb = self.transformer.wpe(pos)  # (1, t, n_embd), broadcast over b
        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)  # (b, t, vocab_size)

        loss = None
        if targets is not None:
            # reshape rather than view: view(-1) raises on non-contiguous
            # targets, e.g. a transposed or strided slice.
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)), targets.reshape(-1)
            )

        return logits, loss

    @torch.no_grad()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Autoregressively extend a conditioning sequence.

        Args:
            idx: conditioning sequence of indices (LongTensor of shape (b, t))
            max_new_tokens: number of tokens to append
            temperature: logits are divided by this before sampling (higher =
                more random); 0 selects greedy (argmax) decoding
            top_k: if set, only sample from the top k most likely tokens

        Returns:
            LongTensor of shape (b, t + max_new_tokens).
        """
        for _ in range(max_new_tokens):
            # Keep at most the last block_size tokens as context.
            idx_cond = idx[:, -self.config.block_size :]
            logits, _ = self(idx_cond)
            logits = logits[:, -1, :]  # only the final position predicts next

            if temperature == 0:
                # Greedy decoding; dividing by 0 would produce NaN probabilities.
                idx_next = logits.argmax(dim=-1, keepdim=True)
            else:
                logits = logits / temperature
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            idx = torch.cat((idx, idx_next), dim=1)

        return idx

## Initialize model and optimizer

In [81]:
# Instantiate the model directly on the selected device (Module.to returns the
# module itself), then optimize all of its parameters with Adam.
model = GPT(config).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)

Number of parameters: 85.25M


## Helpers: batch sampling and loss estimation

In [None]:
@torch.no_grad()
def estimate_loss(model, dataset, config):
    """Return the mean loss of ``model`` over several random batches of ``dataset``.

    The number of batches is ``config.eval_iters``, or 200 when the config
    does not define it. Batches come from ``get_batch``. The model runs in
    eval mode and is restored to its previous train/eval mode afterwards,
    even if an exception is raised.
    """
    # NOTE(review): if called with the training dataset this measures
    # training loss; confirm a held-out split is used for validation.
    eval_iters = getattr(config, "eval_iters", 200)

    was_training = model.training
    model.eval()
    try:
        losses = []
        for _ in range(eval_iters):
            x_batch, y_batch = get_batch(dataset, config)
            _, loss = model(x_batch, y_batch)
            losses.append(loss.item())
    finally:
        # Restore the caller's mode rather than forcing train mode.
        model.train(was_training)

    return np.mean(losses)


def get_batch(dataset, config):
    """Sample a random batch of (input, target) pairs.

    Draws ``config.batch_size`` start indices uniformly from
    ``[0, len(dataset))`` with NumPy's global RNG, stacks the resulting items
    along a new leading batch dimension, and moves both tensors to the
    module-level ``device``.
    """
    starts = np.random.randint(0, len(dataset), size=(config.batch_size,))
    pairs = [dataset[i] for i in starts]

    inputs = torch.stack([x for x, _ in pairs]).to(device)
    targets = torch.stack([y for _, y in pairs]).to(device)
    return inputs, targets