# Tiny Shakespeare RNN Walkthrough

This notebook explains, step by step, how the handcrafted PyTorch RNN from `11_rnn_pytorch_custom_ts.py` works. It is written for beginners who are comfortable with basic Python and want to understand how recurrent language models operate on real text.

## 1. Prerequisites

We use the Tiny Shakespeare corpus stored in `tiny.txt`. The companion script `08_prepare_tiny_shakespeare.py` tokenizes the dataset and saves ready-to-use NumPy arrays in the `processed/` folder.

Run the following code cell once to make sure those processed files exist (it will call the preparation script automatically if needed).

In [None]:
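from pathlib import Path
import subprocess
import sys

project_dir = Path.cwd()
processed_dir = project_dir / 'processed'
# Check for the two files we actually load below, not just the folder
required_files = [
    processed_dir / 'tiny_shakespeare_chars.json',
    processed_dir / 'tiny_shakespeare_char_windows.npy',
]
if not all(path.exists() for path in required_files):
    print('Processed data not found. Running 08_prepare_tiny_shakespeare.py ...')
    # sys.executable runs the script with the same Python interpreter as this kernel
    subprocess.run([sys.executable, '08_prepare_tiny_shakespeare.py'], check=True)
else:
    print('Found processed data directory:', processed_dir)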

## 2. Imports and Device Setup

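We stick to standard PyTorch and NumPy. Fixing the random seeds makes runs reproducible: on the same machine, re-running the notebook should give the same training curve and generated text (some GPU operations are not fully deterministic, so small differences are possible on CUDA).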

In [None]:
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(5)
np.random.seed(5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

## 3. Load the Processed Character Data

The preparation script saved two files that we need:
* `tiny_shakespeare_chars.json`: metadata such as the character vocabulary and sequence length.
* `tiny_shakespeare_char_windows.npy`: overlapping windows of characters for training.

We load them and inspect their shapes.

In [None]:
data_dir = project_dir / 'processed'
char_json_path = data_dir / 'tiny_shakespeare_chars.json'
char_windows_path = data_dir / 'tiny_shakespeare_char_windows.npy'

char_data = json.loads(char_json_path.read_text(encoding='utf-8'))
char_windows = np.load(char_windows_path)

chars = char_data['chars']
char_to_idx = {ch: idx for idx, ch in enumerate(chars)}
idx_to_char = {idx: ch for ch, idx in char_to_idx.items()}
sequence_length = char_data['sequence_length']
vocab_size = len(chars)

print('Character vocabulary size:', vocab_size)
print('Sequence window length:', sequence_length)
print('Number of windows in file:', len(char_windows))
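The preparation script is not shown here, so how exactly it builds the windows (the stride, and whether each window holds `sequence_length` or `sequence_length + 1` characters) is an assumption. A minimal sketch of how overlapping character windows *could* be produced, using a hypothetical helper `make_windows` with a stride of one:

```python
import numpy as np

def make_windows(text: str, window_len: int):
    """Hypothetical sketch: encode text as indices and cut overlapping windows (stride 1)."""
    vocab = sorted(set(text))                        # character vocabulary
    lookup = {ch: i for i, ch in enumerate(vocab)}   # character -> index
    encoded = np.array([lookup[ch] for ch in text], dtype=np.int64)
    n_windows = len(encoded) - window_len + 1        # stride 1: N - L + 1 windows
    # Row i holds characters i .. i + window_len - 1
    windows = np.stack([encoded[i:i + window_len] for i in range(n_windows)])
    return vocab, windows

vocab_demo, windows_demo = make_windows('to be or not', window_len=5)
print(windows_demo.shape)                                 # (8, 5)
print(''.join(vocab_demo[i] for i in windows_demo[0]))   # to be
```

With stride 1, a text of `N` characters yields `N - window_len + 1` windows, which is why even a modest corpus produces a very large number of them.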

The full file may contain a very large number of windows (the count is printed above). Training on all of them one window at a time would be slow, so we randomly select a subset of at most 12,000 windows that still captures the flavour of Shakespeare's writing.

In [None]:
rng = np.random.default_rng(1)
subset_size = min(12000, len(char_windows))
indices = rng.choice(len(char_windows), size=subset_size, replace=False)
train_windows = char_windows[indices]
print('Using', len(train_windows), 'windows for training')
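`rng.choice(..., replace=False)` draws distinct row indices, so no window appears twice in the subset, and fixing the generator seed (`1` above) reproduces the same subset on every run. A small check on a toy array standing in for `char_windows` (the `pick_subset` helper is hypothetical and just wraps the three lines above):

```python
import numpy as np

toy_windows = np.arange(40).reshape(20, 2)   # 20 fake "windows" of length 2

def pick_subset(windows: np.ndarray, size: int, seed: int) -> np.ndarray:
    rng = np.random.default_rng(seed)
    size = min(size, len(windows))           # never ask for more rows than exist
    indices = rng.choice(len(windows), size=size, replace=False)
    return windows[indices]

subset_a = pick_subset(toy_windows, 8, seed=1)
subset_b = pick_subset(toy_windows, 8, seed=1)
print(len({tuple(row) for row in subset_a}))   # 8 distinct rows
print(np.array_equal(subset_a, subset_b))      # True: same seed, same subset
```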

## 4. Understanding the Model Architecture

We build a custom RNN cell using PyTorch's `nn.Module`. At each time step it:
1. Receives the current character as a one-hot vector.
2. Updates the hidden state with a tanh activation.
3. Produces logits (raw scores) for the next character.

Weights are initialized manually (small Gaussian values scaled by 0.05, with zero biases) to keep the implementation close to the NumPy version.
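Written as equations, using the same names as the code below (with $x_t$ the one-hot input at step $t$, and $h_0 = 0$ unless a starting state is passed in), one time step computes:

$$
h_t = \tanh\left(W_{xh}\, x_t + W_{hh}\, h_{t-1} + b_h\right), \qquad
y_t = W_{hy}\, h_t + b_y
$$

Here $W_{hy}$ corresponds to `Why` in the code. Because $x_t$ is one-hot, the product $W_{xh} x_t$ simply selects one column of $W_{xh}$, so the input term behaves like an embedding lookup.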

In [None]:
from typing import Optional

class CustomCharRNN(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, output_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        # Input-to-hidden, hidden-to-hidden and hidden-to-output weights (small random init)
        self.Wxh = nn.Parameter(torch.randn(hidden_size, input_size) * 0.05)
        self.Whh = nn.Parameter(torch.randn(hidden_size, hidden_size) * 0.05)
        self.bh = nn.Parameter(torch.zeros(hidden_size))
        self.Why = nn.Parameter(torch.randn(output_size, hidden_size) * 0.05)
        self.by = nn.Parameter(torch.zeros(output_size))

    def forward(self, inputs: torch.Tensor, h0: Optional[torch.Tensor] = None):
        # inputs: (seq_len, input_size) one-hot rows; h0: (hidden_size,) or None
        h_t = inputs.new_zeros(self.hidden_size) if h0 is None else h0

        outputs = []
        for x_t in inputs:
            # New hidden state mixes the current input with the previous state
            h_t = torch.tanh(self.Wxh @ x_t + self.Whh @ h_t + self.bh)
            # Logits (unnormalized scores) over the vocabulary for the next character
            y_t = self.Why @ h_t + self.by
            outputs.append(y_t)
        # Returns logits of shape (seq_len, output_size) and the final hidden state
        return torch.stack(outputs), h_t

model = CustomCharRNN(vocab_size, hidden_size=192, output_size=vocab_size).to(device)
model
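As a quick sanity check on model size, the parameter count follows directly from the five tensors: $HV$ (`Wxh`) $+\,H^2$ (`Whh`) $+\,H$ (`bh`) $+\,VH$ (`Why`) $+\,V$ (`by`), for vocabulary size $V$ and hidden size $H$. The sketch below computes it; the vocabulary size of 65 is only an illustrative assumption (the real value is whatever `vocab_size` printed earlier), and in the notebook `sum(p.numel() for p in model.parameters())` should give the same number for the real vocabulary.

```python
def rnn_param_count(vocab_size: int, hidden_size: int) -> int:
    """Parameters of CustomCharRNN when input_size == output_size == vocab_size."""
    wxh = hidden_size * vocab_size    # Wxh
    whh = hidden_size * hidden_size   # Whh
    bh = hidden_size                  # bh
    why = vocab_size * hidden_size    # Why
    by = vocab_size                   # by
    return wxh + whh + bh + why + by

# Hypothetical 65-character vocabulary with the notebook's hidden size of 192
print(rnn_param_count(vocab_size=65, hidden_size=192))  # 62081
```

The hidden-to-hidden matrix dominates here, which is typical: it grows with $H^2$ while the input and output matrices grow only linearly in $H$.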

## 5. Training Setup

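To keep the notebook simple we train for only 10 epochs with the Adam optimizer, updating the weights after every window (there is no batching, so training may still take a while on CPU). Each window of length `L` provides `L - 1` next-character predictions, so even a few epochs cover many character transitions.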

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()
epochs = 10
print('Training for', epochs, 'epochs on device', device)
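Vanilla RNNs are prone to exploding gradients on long sequences, and a common safeguard is to rescale all gradients whenever their combined L2 norm exceeds a threshold. PyTorch provides this as `torch.nn.utils.clip_grad_norm_`; the NumPy sketch below only illustrates the rule it applies (the `clip_by_global_norm` name and the threshold of 1.0 are illustrative choices, and PyTorch additionally adds a tiny epsilon to the norm).

```python
import numpy as np

def clip_by_global_norm(grads, max_norm: float):
    """Scale all gradients by one shared factor so their joint L2 norm is at most max_norm."""
    total_norm = float(np.sqrt(sum(np.sum(g ** 2) for g in grads)))
    if total_norm <= max_norm:
        return grads, total_norm          # already small enough: leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm

grads = [np.array([3.0, 0.0]), np.array([[0.0], [4.0]])]   # joint norm = 5
clipped, norm_before = clip_by_global_norm(grads, max_norm=1.0)
print(norm_before)   # 5.0
```

Because every tensor is scaled by the same factor, the update keeps its direction and only its length shrinks.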

### Helper: Convert a window to tensors

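Each window stores character indices. The inputs are all characters except the last, converted to one-hot vectors; the targets are the integer indices of the characters that follow, i.e. the same window shifted left by one position.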

In [None]:
def window_to_tensors(window: np.ndarray):
    inputs_idx = torch.tensor(window[:-1], dtype=torch.long, device=device)
    targets = torch.tensor(window[1:], dtype=torch.long, device=device)
    inputs = F.one_hot(inputs_idx, num_classes=vocab_size).float()
    return inputs, targets
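To see the one-step shift concretely, here is the same transformation on a hypothetical 4-character vocabulary, using NumPy's identity-matrix indexing as a stand-in for `F.one_hot` so it runs without PyTorch:

```python
import numpy as np

toy_vocab = ['a', 'b', 'c', 'd']                 # hypothetical toy vocabulary
toy_window = np.array([0, 1, 2, 1, 3])           # encodes "abcbd"

toy_inputs_idx = toy_window[:-1]                 # a b c b  -> fed to the RNN
toy_targets = toy_window[1:]                     # b c b d  -> what it should predict
toy_inputs = np.eye(len(toy_vocab), dtype=np.float32)[toy_inputs_idx]  # one-hot rows, shape (4, 4)

for x_row, t in zip(toy_inputs, toy_targets):
    print(toy_vocab[int(x_row.argmax())], '->', toy_vocab[t])
```

A window of length `L` therefore yields `L - 1` (input, target) pairs, and `nn.CrossEntropyLoss` averages the loss over those positions by default.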

### Training Loop

We track the average loss per epoch. Lower loss means the model is getting better at predicting the next character, although simple RNNs still struggle with long-range structure.

In [None]:
loss_history = []
for epoch in range(1, epochs + 1):
    total_loss = 0.0
    np.random.shuffle(train_windows)
    for window in train_windows:
        inputs, targets = window_to_tensors(window)
        logits, _ = model(inputs)
        loss = criterion(logits, targets)

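        optimizer.zero_grad()
        loss.backward()
        # Rescale gradients whose global norm exceeds 1.0 (guards against exploding gradients)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()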

        total_loss += loss.item()

    avg_loss = total_loss / len(train_windows)
    loss_history.append(avg_loss)
    print(f'Epoch {epoch:02d} - average loss: {avg_loss:.4f}')
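`nn.CrossEntropyLoss` reports the average negative log-likelihood in nats, which gives a useful reference point: a model that assigns equal probability to every character scores $\ln V$ for a vocabulary of size $V$, and exponentiating a loss gives its perplexity (roughly, how many characters the model is effectively choosing between). A small sketch; the helper names are illustrative and 65 is a hypothetical vocabulary size:

```python
import math

def uniform_baseline_loss(vocab_size: int) -> float:
    """Cross-entropy (nats) of a model that predicts every character with probability 1/V."""
    return -math.log(1.0 / vocab_size)

def perplexity(avg_loss: float) -> float:
    """Convert an average cross-entropy in nats to perplexity."""
    return math.exp(avg_loss)

baseline = uniform_baseline_loss(65)   # hypothetical vocabulary size
print(round(baseline, 3), round(perplexity(baseline), 3))
```

In the notebook, comparing `uniform_baseline_loss(vocab_size)` with the values in `loss_history` shows how far the model has moved beyond uniform guessing.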

## 6. Plot Training Loss

A simple plot helps us see the training trend.

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

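# Plot against epoch numbers 1..N so the x-axis matches the printed epochs
plt.plot(range(1, len(loss_history) + 1), loss_history, marker='o')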
plt.title('Average Training Loss per Epoch')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.grid(True)
plt.show()

## 7. Generating New Text

We sample characters one at a time. The `temperature` parameter divides the logits before the softmax, which controls randomness: values below 1 sharpen the distribution so the model mostly picks high-probability characters, while values above 1 flatten it and encourage diversity (at the risk of gibberish).

In [None]:
def sample(model, seed: str, length: int = 400, temperature: float = 0.8):
    """Generate `length` new characters after `seed` (which must be non-empty)."""
    if not seed:
        raise ValueError('seed must contain at least one character')

    def one_hot(ch: str) -> torch.Tensor:
        # Characters outside the vocabulary fall back to index 0
        idx = char_to_idx.get(ch, 0)
        return F.one_hot(torch.tensor(idx, device=device), num_classes=vocab_size).float()

    model.eval()
    generated = list(seed)
    with torch.no_grad():
        # Warm up the hidden state on every seed character except the last
        h_t = None
        for ch in seed[:-1]:
            _, h_t = model(one_hot(ch).unsqueeze(0), h_t)

        char = seed[-1]
        for _ in range(length):
            logits, h_t = model(one_hot(char).unsqueeze(0), h_t)
            # Temperature-scaled softmax over the single time step we just ran
            probs = F.softmax(logits[-1] / temperature, dim=0)
            idx = torch.multinomial(probs, num_samples=1).item()
            char = idx_to_char[idx]
            generated.append(char)
    model.train()
    return ''.join(generated)
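Dividing the logits by the temperature before the softmax is what changes the sharpness of the distribution. A NumPy sketch with made-up logits for three characters (the values are illustrative, not taken from the trained model):

```python
import numpy as np

def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    scaled = logits / temperature
    scaled = scaled - scaled.max()          # subtract the max for numerical stability
    exp = np.exp(scaled)
    return exp / exp.sum()

toy_logits = np.array([2.0, 1.0, 0.1])       # hypothetical scores for three characters
for t in (0.5, 0.8, 1.2):
    print(t, np.round(softmax_with_temperature(toy_logits, t), 3))
```

The ranking of characters never changes; only how much probability mass the top choice receives. As the temperature approaches 0, sampling approaches always picking the most likely character, and at 1.0 the model's own distribution is used unchanged.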

In [None]:
seed_text = 'ROMEO:\n'
generated = sample(model, seed_text, length=400, temperature=0.8)
print(generated)

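Try experimenting with different seeds and temperature values (e.g., `0.5`, `1.2`) to see how the model's style changes: lower temperatures usually give safer, more repetitive text, while higher ones give more varied but less coherent output.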

## 8. Key Takeaways

* **Handcrafted weights:** We built an RNN cell manually, which demystifies what `nn.RNN` does under the hood.
* **Autograd convenience:** Even though the cell is custom, PyTorch handled the gradient calculations.
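* **Ungated limitations:** Vanilla RNNs struggle to retain long-range patterns, because gradients tend to vanish or explode as they flow back through many tanh steps, and they often repeat words. LSTM and GRU cells add gates that help mitigate this.
* **Data hunger:** More data and longer training usually make the generated text more realistic, but they also increase compute time.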
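To make the gating idea concrete, here is a minimal NumPy sketch of one step of a GRU cell, one of the gated alternatives mentioned above. The weight names and dictionary layout are illustrative assumptions, and it uses a single bias per gate rather than any particular library's exact parameterization.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, p):
    """One GRU step; p is a dict of hypothetical weight names ('W_xz', 'W_hz', 'b_z', ...)."""
    z = sigmoid(p['W_xz'] @ x_t + p['W_hz'] @ h_prev + p['b_z'])   # update gate
    r = sigmoid(p['W_xr'] @ x_t + p['W_hr'] @ h_prev + p['b_r'])   # reset gate
    # Candidate state; the reset gate scales how much of the old state contributes
    n = np.tanh(p['W_xn'] @ x_t + r * (p['W_hn'] @ h_prev) + p['b_n'])
    # z close to 1 keeps the old state, z close to 0 takes the candidate
    return (1.0 - z) * n + z * h_prev

demo_rng = np.random.default_rng(0)
gru_in, gru_hidden = 4, 3
gru_params = {}
for gate in ('z', 'r', 'n'):
    gru_params[f'W_x{gate}'] = demo_rng.normal(scale=0.1, size=(gru_hidden, gru_in))
    gru_params[f'W_h{gate}'] = demo_rng.normal(scale=0.1, size=(gru_hidden, gru_hidden))
    gru_params[f'b_{gate}'] = np.zeros(gru_hidden)

gru_h = np.zeros(gru_hidden)
for char_idx in [0, 2, 1]:                       # a short sequence of one-hot inputs
    gru_h = gru_step(np.eye(gru_in)[char_idx], gru_h, gru_params)
print(gru_h)
```

Because the new state is a gated blend of the old state and the candidate, an update gate near 1 lets information pass through many steps almost untouched, which is how gating helps with long-range patterns.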

Next steps you can try:
1. Replace the hidden cell with an LSTM-like implementation.
2. Batch the training windows for faster training.
3. Switch to a word-level model by adapting the preprocessing artifacts.
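For next step 2, the key change is to process several windows together, so each matrix product handles a whole batch at once. A NumPy sketch of a batched forward pass with the same weight shapes as `CustomCharRNN` (the function names and the `(seq_len, batch, vocab)` layout are assumptions), checked against a one-sequence-at-a-time loop:

```python
import numpy as np

def rnn_forward_batched(X, Wxh, Whh, bh, Why, by):
    """X: (seq_len, batch, input_size) one-hot inputs -> logits (seq_len, batch, output_size)."""
    h = np.zeros((X.shape[1], Whh.shape[0]))        # one hidden state per sequence
    logits = []
    for x_t in X:                                    # x_t: (batch, input_size)
        # Row-vector form of tanh(Wxh @ x + Whh @ h + bh), for all sequences at once
        h = np.tanh(x_t @ Wxh.T + h @ Whh.T + bh)
        logits.append(h @ Why.T + by)
    return np.stack(logits)

def rnn_forward_single(X1, Wxh, Whh, bh, Why, by):
    """Reference loop for one sequence, mirroring CustomCharRNN.forward."""
    h = np.zeros(Whh.shape[0])
    out = []
    for x_t in X1:
        h = np.tanh(Wxh @ x_t + Whh @ h + bh)
        out.append(Why @ h + by)
    return np.stack(out)

demo_rng = np.random.default_rng(0)
V, H, T, B = 5, 8, 6, 3                              # toy vocab, hidden, time, batch sizes
Wxh = demo_rng.normal(0.0, 0.05, size=(H, V))
Whh = demo_rng.normal(0.0, 0.05, size=(H, H))
Why = demo_rng.normal(0.0, 0.05, size=(V, H))
bh, by = np.zeros(H), np.zeros(V)
X = np.eye(V)[demo_rng.integers(0, V, size=(T, B))]  # (T, B, V) one-hot batch
batched = rnn_forward_batched(X, Wxh, Whh, bh, Why, by)
print(batched.shape)   # (6, 3, 5)
```

The same idea carries over to PyTorch: stack windows into a `(batch, seq_len)` index tensor, one-hot it, and reshape the logits to `(seq_len * batch, vocab_size)` with matching flattened targets before calling `nn.CrossEntropyLoss`.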