## Character-level RNN for Text Generation

Train a small character-level language model on a single file to generate new text

## Setup & Config
- imports, device, seed
- simple config dictionary for reproducibility

In [1]:
import math
import os
import random
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def set_seed(seed: int = 42):
    random.seed(seed)                           # Python's random functions
    torch.manual_seed(seed)                     # PyTorch CPU ops
    torch.cuda.manual_seed_all(seed)            # PyTorch GPU ops
    torch.backends.cudnn.deterministic = True   # cuDNN algortihm choice
    torch.backends.cudnn.benchmark = False               # cuDNN auto-tuner

set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = {
    "data_path": "data/tinyshakespeare.txt",    # path to text file (e.g., "data/shakespeare.text"); None uses built-in sample
    "seq_len": 128,         # sequence length for training (input and target chunks)
    "batch_size": 128,      # number of sequences per training batch
    "embedding_dim": 256,   # size of character embedding vectors
    "hidden_dim": 256,      # size of hidden state in RNN (GRU or LSTM)
    "num_layers": 1,        # number of stacked RNN layers
    "dropout": 0.1,         # dropout probability between layers
    "rnn_type": "GRU",      # type of RNN: "GRU" or "LSTM"
    "num_epochs": 5,        # number of full passes through the training dataset
    "learning_rate": 2e-3,  # inital learning rate for the optimizer
    "grad_clip": 1,         # gradient clipping threshold to prevent exploding gradients
    "log_every": 100,       # how often (in steps) to print training loss
    "sample_every": 100,    # how often (in steps) to generate sample text
    "max_generate": 400,    # number of characters tto generate during sampling
    "temperature": 0.9,     # sampling temperature (controls randomness in output)
    "top_k": 40,            # top-k sampling: consider only the top k most probable characters
    "top_p": 0.9,           # top-p (nucleus) sampling: consider top tokens whose probabilites sum to p
    "val_fraction": 0.05,   # fraction of the dataset to use for validation
    "overlap_step": None,   # if set, use overlapping training chunks (e.g, step size = seq_len // 2)
    "save_path": "char_rnn_checkpoint.pt",  # where to save the trained model checkpoint
}

## Load Data
- provide a path to corpus (plain .txt) in config["data-path"]
- if not provided, we use a built-in snippet

In [2]:
# load text
if config["data_path"] and os.path.exists(config["data_path"]):
    with open(config["data_path"], "r", encoding="utf-8") as f:
        text = f.read()
else:
    # snippet for testing
    text = (
        "ROMEO:\nBut soft, what light through yonder window breaks?\n"
        "It is the east, and Juliet is the sun.\nArise, fair sun, and kill the envious moon,\n"
        "Who is already sick and pale with grief.\n\n"
        "JULIET:\nO Romeo, Romeo! wherefore art thou Romeo?\n"
        "Deny thy father and refuse thy name;\nOr, if thou wilt not, be but sworn my love,\n"
        "And I'll no longer be a Capulet.\n"       
    )

## Character Vocabulary
- build stoi and itos
    - stoi encodes text into IDs for the model, and itos decodes model outputs back into readable text
    - stoi (string-to-index): a dict mapping each character to an integer
    - itos (index-to-string): a list mapping each ID back to its character
- encode/decode utilities

In [3]:
class CharVocab:
    def __init__(self, text: str):
        chars = sorted(list(set(text)))     # 45 unique chars
        self.itos = chars                   
        self.stoi = {ch: i for i, ch in enumerate(chars)}
    
    # takes text and return list of ids for each char
    def encode(self, s: str) -> List[int]:
        return [self.stoi[c] for c in s if c in self.stoi]

    def decode(self, ids: List[int]) -> str:
        return "". join(self.itos[i] for i in ids)

vocab = CharVocab(text)
vocab_size = len(vocab.itos)    # 45

## Encode & Split, Dataset & DataLoader
- encode the full text into integer IDs
- split into train/val by fraction
- create chunked dataset returning (x, y) where y is the next-char targets
### Additional Notes
for Shakespeare text:
- train/val split
    - len(text) = 347
        - the original text
    - len(vocab.encode(text)) = 347
        - the list of ids for each char
    - n_total = len(data_ids) = 347
        - turns the encoding into a tensor
    - n_val = 347 * 0.05 = 17

- CharChunkDataset
    - chunk: one training example (128 consecutive characters)
    - input: those 128 characters
    - target: the next 128 characters, shifted by one position
        - the model learns to predict the next character at each step
    - len(train_ds) = train_ds.num_chunks = (330 - 1 - 128) // 128 + 1 = 2
    - self.starts
        - list of starting indices where each chunk begins
        - makes getitem fast


In [4]:
# encode entire corpus
data_ids = torch.tensor(vocab.encode(text), dtype=torch.long)

# train/val split
n_total = len(data_ids)
n_val = max(1, int(n_total * config["val_fraction"]))   # at least one token for validation
train_ids = data_ids[:-n_val]   # 330 ids
val_ids = data_ids[-n_val:]     # 17 ids

# splits a long 1D tokenized tensor into (input, target) chunk pairs
# each sample: x of length T, y of length T (next-char prediction)
class CharChunkDataset(Dataset):
    def __init__(self, ids: torch.Tensor, seq_len: int, step: Optional[int] = None):
        self.ids = ids      # 1d tensor of all character ids
        self.T = seq_len    # chunk length
        self.step = step if step is not None else seq_len               # stride: None means non-overlapping
        self.num_chunks = (len(ids) - 1 - self.T) // self.step + 1      # number of chunks you can extract
        self.starts = [i * self.step for i in range(self.num_chunks)]
    
    # built-in behavior for len(obj)
    def __len__(self):
        return self.num_chunks
    
    # built-in behavior for obj[i]
    def __getitem__(self, idx):
        s = self.starts[idx]
        x = self.ids[s : s + self.T]
        y = self.ids[s + 1: s + 1 + self.T]
        return x, y

train_ds = CharChunkDataset(train_ids, config["seq_len"], config["overlap_step"])
val_ds = CharChunkDataset(val_ids, config["seq_len"], config["overlap_step"])

train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, drop_last=True)

## Model
- embedding -> RNN (GRU/LSTM) -> Linear
- batch_first=True so inputs are [B, T]
- dropout on embeddings and hidden outputs

In [5]:
class CharRNN(nn.Module):
    def __init__(self, vocab_size: int, emb: int, hidden: int, layers: int, dropout: float, rnn_type: str = "GRU"):
        super().__init__()
        self.vocab_size = vocab_size
        self.emb = nn.Embedding(vocab_size, emb)
        rnn_cls = {"GRU": nn.GRU, "LSTM": nn.LSTM}[rnn_type.upper()]
        self.rnn = rnn_cls(
            input_size=emb,
            hidden_size=hidden,
            num_layers=layers,
            dropout=dropout if layers > 1 else 0.0,
            batch_first=True,
        )
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden, vocab_size)
        self.rnn_type = rnn_type.upper()
        self.layers = layers
        self.hidden = hidden
    
    # x: [batch_size=128, seq_len=128]
    # h: 
    def forward(self, x, h=None):
        # input ids -> embeddings
        x = self.emb(x)         # x: [batch_size=128, seq_len=128, embedding_dim=256]
        # embeddings -> RNN
        x, h = self.rnn(x, h)   # x: [batch_size=128, seq_len=128, hidden_dim=256], h: [num_layers=1, batch_size=128, hidden_dim=256] (and c for LSTM)
        # dropout for regularization
        x = self.drop(x)
        # projection to logits
        # linear layer maps each hidden_dim vector to vocab_size classes
        logits = self.fc(x)     # [batch_size=128, seq_len=128, vocab_size=45]
        return logits, h

    def init_hidden(self, batch_size, device):
        if self.rnn_type == "LSTM":
            h = torch.zeros(self.layers, batch_size, self.hidden, device=device)
            c = torch.zeros(self.layers, batch_size, self.hidden, device=device)
            return (h, c)
        else:
            h = torch.zeros(self.layers, batch_size, self.hidden, device=device)
            return h

model = CharRNN(
    vocab_size=vocab_size,
    emb=config["embedding_dim"],
    hidden=config["hidden_dim"],
    layers=config["num_layers"],
    dropout=config["dropout"],
    rnn_type=config["rnn_type"],
).to(device)

## Training Utilities
- cross-entropy loss on next-char targets
- adam optimizer, gradient clipping
- bits-per-character (BPC) = CE / ln(2)
    - how many bits of information the model needs to encode each character in the text
        - lower BPC = better prediction
            - BPC = 1 means model needs 1 bit to choose next char
            - BPC = log2(V) means model is guessing randomly from V characters
    - bpc_from_loss
        - nn.CrossEntropyLoss returns negative log likelihood in nats (natural log base e)

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

def bpc_from_loss(loss_val: float) -> float:
    return loss_val / math.log(2.0)

## Helpers

In [7]:
def sample_text(
        model: nn.Module,
        vocab: CharVocab,
        max_new_tokens: int = 300,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
        prompt: str = "",
        device: str = "cpu"
):
    model.eval()
    with torch.no_grad():
        # if no prompt is provided, start from a random single char
        if not prompt:
            prompt = random.choice(vocab.itos)
        
        input_ids = torch.tensor(vocab.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)    # [1, T]
        # start with no hidden state and a list of generated chars initialized with the prompt
        h = None
        generated = list(prompt)

        for _ in range(max_new_tokens):
            logits, h = model(input_ids, h)     # logits: [1, T, V]
            last_logits = logits[0, -1, :] / max(1e-6, temperature)
        
        # convert logits to probabilities
        probs = torch.softmax(last_logits, dim=-1)

        # top-k / top-p filtering
        if top_k is not None:
            topk_vals, topk_idx = torch.topk(probs, k=min(top_k, probs.size(-1)))
            filtered = torch.zeros_like(probs).scatter_(0, topk_idx, topk_vals)
            probs = filtered / filtered.sum()
        
        if top_p is not None:
            sorted_probs, sorted_idx = torch.sort(probs, descending=True)
            cumsum = torch.cumsum(sorted_probs, dim=0)
            mask = cumsum - sorted_probs > top_p
            sorted_probs[mask] = 0.0
            sorted_probs = sorted_probs / sorted_probs.sum()
            idx_choice = torch.multinomial(sorted_probs, num_samples=1)
            next_id = sorted_idx[idx_choice]
        else:
            next_id = torch.multinomial(probs, num_samples=1)
        
        next_id = sorted_idx[idx_choice]
        generated.append(vocab.itos[next_id])

        input_ids = torch.tensor([next_id], device=device).view(1, 1)
    
    return "".join(generated)

def evaluate_loss(data_loader: DataLoader) -> Tuple[float, float]:
    model.eval()
    total_loss, count = 0.0, 0
    with torch.no_grad():
        for x, y in data_loader:
            x = x.to(device)
            y = y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            total_loss += loss.item()
            count += 1
    print(count)
    if count == 0:
        print("[warn] evaluate_loss: dataloader is empty (no batches).")
        return float("nan"), float("nan")
    avg = total_loss / max(1, count)
    return avg, bpc_from_loss(avg)

## Training Loop
- iterate over batches of (x, y)
- zero grad -> forward -> CE loss -> clip -> step
- log train loss & bpc
- periodically sample text to track progress

In [16]:
global_step = 0
best_val = float("inf")

print("len(train_ids):", len(train_ids))
print("len(val_ids):  ", len(val_ids))
print("len(train_ds): ", len(train_ds))
print("len(val_ds):   ", len(val_ds))
print("len(train_loader):", len(train_loader))
print("len(val_loader):  ", len(val_loader))

for epoch in range(1, config["num_epochs"] + 1):
    model.train()
    running = 0.0   # running sum of training loss for loggin averages

    # iterate over batches of (x, y) where:
    # x: [B, T] token ids (input characters)
    # y: [B, T] token ids (next characters, shifted by 1)
    for i, (x, y) in enumerate(train_loader, start=1):
        # move tensors to GPU/CPU as chosen earlier
        x = x.to(device)
        y = y.to(device)

        # forward pass
        logits, _ = model(x)
        loss = criterion(logits.view(-1, vocab_size), y.view(-1))

        # backward pass
        optimizer.zero_grad(set_to_none=True)   # clear old grads
        loss.backward()

        # gradient clipping
        if config["grad_clip"] is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])

        # parameter update
        optimizer.step()

        # logging
        running += loss.item()
        global_step += 1
        if global_step % config["log_every"] == 0:
            avg_loss = running / config["log_every"]
            avg_bpc = bpc_from_loss(avg_loss)
            print(f"Epoch {epoch:02d} | Step {global_step:06d} "
                  f"| train loss {avg_loss:.4f} | bpc {avg_bpc:.3f}")
            running = 0.0
        
        # sample generated text periodically
        if global_step % config["sample_every"] == 0:
            print("Sample")
    
    # end-of-epoch validation
    val_loss, val_bpc = evaluate_loss(val_loader)
    print(f"[Validation] Epoch {epoch}: loss {val_loss:.4f} | bpc {val_bpc:.3f}")

    if val_loss < best_val:
        best_val = val_loss
        torch.save({
            "model_state": model.state_dict(),
            "config": config,
            "stoi": vocab.stoi,
            "itos": vocab.itos,
        }, config["save_path"])
        print(f"Saved checkpoint to {config['save_path']}")


len(train_ids): 1059624
len(val_ids):   55769
len(train_ds):  8278
len(val_ds):    435
len(train_loader): 64
len(val_loader):   3
3
[Validation] Epoch 1: loss 1.5510 | bpc 2.238
Saved checkpoint to char_rnn_checkpoint.pt
Epoch 02 | Step 000100 | train loss 0.5148 | bpc 0.743
Sample
3
[Validation] Epoch 2: loss 1.5463 | bpc 2.231
Saved checkpoint to char_rnn_checkpoint.pt
3
[Validation] Epoch 3: loss 1.5399 | bpc 2.222
Saved checkpoint to char_rnn_checkpoint.pt
Epoch 04 | Step 000200 | train loss 0.1126 | bpc 0.163
Sample
3
[Validation] Epoch 4: loss 1.5391 | bpc 2.220
Saved checkpoint to char_rnn_checkpoint.pt
Epoch 05 | Step 000300 | train loss 0.6178 | bpc 0.891
Sample
3
[Validation] Epoch 5: loss 1.5359 | bpc 2.216
Saved checkpoint to char_rnn_checkpoint.pt


In [17]:
def load_model_checkpoint(path: str, rnn_type: str = None):
    ckpt = torch.load(path, map_location=device)
    itos = ckpt["itos"]
    stoi = ckpt["stoi"]

    # Rebuild vocab object
    loaded_vocab = CharVocab("".join(itos))
    loaded_vocab.itos = itos
    loaded_vocab.stoi = stoi

    cfg = ckpt["config"]
    if rnn_type is not None:
        cfg["rnn_type"] = rnn_type

    loaded_model = CharRNN(
        vocab_size=len(itos),
        emb=cfg["embedding_dim"],
        hidden=cfg["hidden_dim"],
        layers=cfg["num_layers"],
        dropout=cfg["dropout"],
        rnn_type=cfg["rnn_type"],
    ).to(device)
    loaded_model.load_state_dict(ckpt["model_state"])
    loaded_model.eval()
    return loaded_model, loaded_vocab, cfg

if os.path.exists(config["save_path"]):
    loaded_model, loaded_vocab, loaded_cfg = load_model_checkpoint(config["save_path"])
    print("Checkpoint loaded")
    print("Sample")
else:
    print("No checkpoint found yet. Train first.")

Checkpoint loaded
Sample
