[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/RichardPovinelli/home/blob/main/Neural_Networks_Course/train_char_rnn.ipynb)

# Char-level RNN Training

This notebook trains a small character-level RNN on the Tiny Shakespeare corpus (or any provided text file).

Instructions:
- In Google Colab, go to Runtime → Change runtime type → set Hardware accelerator = GPU for faster training.
- Edit the Parameters cell (Cell 3) to adjust training constants (epochs, learning rate, sequence length, etc.).
- Run the cells top-to-bottom. Checkpoints will be saved in a `checkpoints/` folder.
- To resume, set `RESUME_PATH` to a checkpoint filename (e.g., `"char_rnn_epoch10.pt"`) or an absolute path.

In [1]:
# Imports and environment checks
import os
import types
import requests
import torch
import torch.nn as nn
import torch.optim as optim

print(f"Torch version: {torch.__version__}")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Torch version: 2.9.0+cu128
Using device: cuda


In [2]:
# Parameters (edit these constants to control training and sampling)
DATA_URL = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
DATA_FILE = "tiny_shakespeare.txt"  # Set to a local file if you provide your own text

# Training hyperparameters
EPOCHS = 10
LR = 1e-3
SEQ_LEN = 100
BATCH_SIZE = 32
EMB_SIZE = 64
HIDDEN_SIZE = 256
NUM_LAYERS = 1
RNN_TYPE = "rnn"  # one of: "gru", "rnn", "lstm"
PRINT_EVERY = 200

# Sampling
START_TEXT = "To be, or not to be"
TEMP = 1.0
SAMPLE_LENGTH = 200

# Checkpointing
CHECKPOINT_DIR = "checkpoints"
RESUME_PATH = None  # e.g., "char_rnn_epoch10.pt" or an absolute path

In [7]:
# Model definition 
class CharRNN(nn.Module):
    """A small char-level RNN for next-character prediction.

    Architecture:
    - Embedding-like linear projection of one-hot inputs
    - RNN (nn.RNN, nn.GRU, or nn.LSTM)
    - Linear decoder to vocab-size
    """

    def __init__(self, vocab_size, embedding_size=32, hidden_size=128, num_layers=1, rnn_type="gru"):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn_type = rnn_type.lower()

        # Input projection: one-hot -> embedding-like linear
        self.input_proj = nn.Linear(vocab_size, embedding_size)

        if self.rnn_type == "gru":
            self.rnn = nn.GRU(embedding_size, hidden_size, num_layers, batch_first=True)
        elif self.rnn_type == "rnn":
            self.rnn = nn.RNN(embedding_size, hidden_size, num_layers, batch_first=True)
        elif self.rnn_type == "lstm":
            self.rnn = nn.LSTM(embedding_size, hidden_size, num_layers, batch_first=True)
        else:
            raise ValueError("Unknown rnn_type")

        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, x_onehot, hidden=None):
        # x_onehot: (batch, seq_len, vocab_size)
        emb = self.input_proj(x_onehot)
        out, hidden = self.rnn(emb, hidden)
        logits = self.decoder(out)
        return logits, hidden

    def init_hidden(self, batch_size=1, device=None):
        device = device or next(self.parameters()).device
        if self.rnn_type == "lstm":
            return (
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
                torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device),
            )
        else:
            return torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)

In [4]:
# Helper functions (ported from train_char_rnn.py)
def download_data(url=DATA_URL, path=DATA_FILE):
    if os.path.exists(path):
        return path
    print(f"Downloading {url} -> {path} ...")
    r = requests.get(url, timeout=30)
    r.raise_for_status()
    with open(path, "w", encoding="utf-8") as f:
        f.write(r.text)
    return path

def build_vocab(corpus_text):
    chars = sorted(list(set(corpus_text)))
    char_to_index = {ch: i for i, ch in enumerate(chars)}
    index_to_char = {i: ch for ch, i in char_to_index.items()}
    return chars, char_to_index, index_to_char

def create_batches(corpus_text, char_to_index, seq_len, batch_size):
    # Convert to indices
    data = torch.tensor([char_to_index[ch] for ch in corpus_text], dtype=torch.long)
    num_batches = data.size(0) // (batch_size * seq_len)
    if num_batches == 0:
        raise ValueError("Not enough data. Reduce seq_len or batch_size.")
    data = data[: num_batches * batch_size * seq_len]
    data = data.view(batch_size, -1)
    for i in range(0, data.size(1), seq_len):
        x = data[:, i : i + seq_len]
        y = data[:, i + 1 : i + 1 + seq_len]
        if x.size(1) != seq_len or y.size(1) != seq_len:
            continue
        yield x, y

def one_hot(indices, vocab_size, device):
    # indices: (batch, seq_len) long
    b, s = indices.size()
    oh = torch.zeros(b, s, vocab_size, device=device)
    oh.scatter_(2, indices.unsqueeze(-1), 1.0)
    return oh

def sample(model, start_str, char_to_index, index_to_char, length=200, temperature=1.0, device=None):
    model.eval()
    vocab_size = len(char_to_index)
    device = device or torch.device("cpu")
    if not isinstance(device, torch.device):
        device = torch.device(device)
    hidden = model.init_hidden(batch_size=1, device=device)

    input_idx = torch.tensor([[char_to_index[ch] for ch in start_str]], dtype=torch.long, device=device)
    with torch.no_grad():
        for i in range(input_idx.size(1)):
            x_oh = one_hot(input_idx[:, i : i + 1], vocab_size, device)
            logits, hidden = model(x_oh, hidden)
        out_chars = list(start_str)
        prev = input_idx[:, -1:]
        for _ in range(length):
            x_oh = one_hot(prev, vocab_size, device)
            logits, hidden = model(x_oh, hidden)
            logits = logits[:, -1, :] / max(1e-8, temperature)
            probs = torch.softmax(logits, dim=-1)
            idx = torch.multinomial(probs, num_samples=1)
            idx_item = int(idx.view(-1).cpu().numpy()[0])
            ch = index_to_char[idx_item]
            out_chars.append(ch)
            prev = idx
    return "".join(out_chars)

def train(config):
    preferred_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    path = download_data()
    corpus_text = open(path, encoding="utf-8").read()
    chars, char_to_index, index_to_char = build_vocab(corpus_text)
    vocab_size = len(chars)
    print(f"Loaded data: {len(corpus_text)} chars, {vocab_size} unique chars")

    model = CharRNN(
        vocab_size,
        embedding_size=config.emb_size,
        hidden_size=config.hidden,
        num_layers=config.layers,
        rnn_type=config.rnn,
    )
    try:
        model.to(preferred_device)
        dev = preferred_device
        print(f"Using device: {dev}")
    except (RuntimeError, OSError) as e:
        dev = torch.device("cpu")
        model.to(dev)
        print(f"Warning: failed to use preferred device {preferred_device} ({e}), falling back to CPU")

    optimizer = optim.Adam(model.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss()

    ckpt_dir = os.path.join(os.getcwd(), "checkpoints")
    os.makedirs(ckpt_dir, exist_ok=True)

    start_epoch = 1
    if config.resume:
        resume_path = config.resume
        if not os.path.isabs(resume_path):
            resume_path = os.path.join(ckpt_dir, resume_path)
        if not os.path.exists(resume_path):
            alt_path = os.path.join(os.getcwd(), config.resume)
            if os.path.exists(alt_path):
                resume_path = alt_path
            else:
                raise FileNotFoundError(f"Checkpoint to resume not found: {resume_path}")
        print(f"Loading checkpoint from {resume_path}")
        ckpt = torch.load(resume_path, map_location=dev)
        model.load_state_dict(ckpt["model_state"])
        if "optimizer_state" in ckpt:
            try:
                optimizer.load_state_dict(ckpt["optimizer_state"])
            except (RuntimeError, ValueError):
                print("Warning: Could not load optimizer state; continuing with fresh optimizer")
        if "char_to_index" in ckpt and "index_to_char" in ckpt:
            char_to_index = ckpt["char_to_index"]
            index_to_char = ckpt["index_to_char"]
            chars = list(index_to_char.values())
            vocab_size = len(chars)
        if "epoch" in ckpt:
            start_epoch = ckpt["epoch"] + 1
        print(f"Resuming from epoch {start_epoch}")

    steps = 0
    for epoch in range(start_epoch, config.epochs + 1):
        model.train()
        total_loss = 0.0
        batches = 0

        for x_batch, y_batch in create_batches(corpus_text, char_to_index, config.seq_len, config.batch_size):
            batches += 1
            x = one_hot(x_batch.to(dev), vocab_size, dev)
            y = y_batch.to(dev)

            optimizer.zero_grad()
            logits, _ = model(x)
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += float(loss)
            steps += 1
            if steps % config.print_every == 0:
                print(f"Epoch {epoch} step {steps} loss {total_loss / steps:.4f}")

        avg = total_loss / max(1, batches)
        print(f"Epoch {epoch} average loss {avg:.4f}")

        s = sample(
            model,
            start_str=config.start,
            char_to_index=char_to_index,
            index_to_char=index_to_char,
            length=200,
            temperature=config.temp,
            device=dev,
        )
        print(f"\nSample:\n{'-' * 100}\n{s[:1000]}\n{'-' * 100}\n")

        ckpt_name = f"char_rnn_epoch{epoch}.pt"
        ckpt_path = os.path.join(ckpt_dir, ckpt_name)
        torch.save(
            {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "vocab": chars,
                "char_to_index": char_to_index,
                "index_to_char": index_to_char,
            },
            ckpt_path,
        )
        print(f"Saved checkpoint {ckpt_path}")

    return model, chars, char_to_index, index_to_char, dev

In [5]:
# Launch training with the parameters above
config = types.SimpleNamespace(
    epochs=EPOCHS,
    lr=LR,
    seq_len=SEQ_LEN,
    batch_size=BATCH_SIZE,
    emb_size=EMB_SIZE,
    hidden=HIDDEN_SIZE,
    layers=NUM_LAYERS,
    rnn=RNN_TYPE,
    print_every=PRINT_EVERY,
    start=START_TEXT,
    temp=TEMP,
    resume=RESUME_PATH,
)

model, chars, char_to_index, index_to_char, dev = train(config)

Loaded data: 1115394 chars, 65 unique chars
Using device: cuda


Consider using tensor.detach() first. (Triggered internally at C:\actions-runner\_work\pytorch\pytorch\pytorch\torch\csrc\autograd\generated\python_variable_methods.cpp:837.)
  total_loss += float(loss)


Epoch 1 step 200 loss 2.8811
Epoch 1 average loss 2.6235

Sample:
----------------------------------------------------------------------------------------------------
To be, or not to be; chedl of the fipren;
Loml thiure pangsezef wheat, for darls.
Rundt ir laad'r wish cemysing comtunceror'e un-seete;
Shan whe angilu, sham bamd whow jut at alghr all willther
I hos
Apre: ann eneat' ab
----------------------------------------------------------------------------------------------------

Saved checkpoint c:\Users\richard\mu\teaching\Neural Networks\lectures\06 Recurent Networks\code\checkpoints\char_rnn_epoch1.pt
Epoch 2 step 400 loss 0.2860
Epoch 2 step 600 loss 0.8763
Epoch 2 average loss 2.0479

Sample:
----------------------------------------------------------------------------------------------------
To be, or not to ben pay, so futicewion macgoo sored;
Ses, is brootherst?
Mure thou kend you likswer.

VOZABESHO:
Bain dave ther Ifge raght meastsry? pay bus I mupg never the iare!
What G

In [6]:
# Generate a sample after training (edit START_TEXT, SAMPLE_LENGTH, TEMP in the Parameters cell)
print("\n--- Generation ---\n")
txt = sample(
    model,
    start_str=START_TEXT,
    char_to_index=char_to_index,
    index_to_char=index_to_char,
    length=SAMPLE_LENGTH,
    temperature=TEMP,
    device=dev,
)
print(txt)


--- Generation ---

To be, or not to ben of tull I child, homal herewnets theur strokeail, of you, come agains
Ro loys;
Stay were to thence again; I'll no man
Can I voot, earther and pility knees, I am sit. O, the mother.

SLARDIBA:
I make
