In [16]:
import math
from timeit import default_timer as timer

import torch
import torch.nn as nn
import torch.nn.functional as F

from datasets import load_dataset
from tqdm.auto import tqdm

In [17]:
device = torch.accelerator.current_accelerator().type if torch.accelerator.is_available() else "cpu"
print(f"Using {device} device")

Using mps device


In [18]:
dataset = load_dataset('tiny_shakespeare')
text = dataset['train']['text'][0]

print("Num Chars. in dataset: ", len(text))
# print("Data: \n", text[:500])

Using the latest cached version of the dataset since tiny_shakespeare couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/bransonboggia/.cache/huggingface/datasets/tiny_shakespeare/default/1.0.0/b5b13969f09fe8707337f6cb296314fbe06960bd9a868dca39e713e163d27b5e (last modified on Fri Feb 13 12:10:22 2026).


Num Chars. in dataset:  1003854


In [19]:
chars = sorted(list(set(text)))
print(chars)
vocab_size = len(chars)

char_encode = {ch: i for i, ch in enumerate(chars)}
rev_char_encode = {i: ch for ch, i in char_encode.items()}

['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']


In [20]:
fall_back = char_encode.get(" ", 0)

def encode(s: str) -> torch.tensor:
    return torch.tensor([char_encode.get(ch, fall_back) for ch in s], dtype=torch.long)
    
def decode(ids) -> str:
    return "".join([rev_char_encode[int(i)] for i in ids])

data = encode(text)
print("Encoded data shape: ", data[0])

Encoded data shape:  tensor(18)


In [21]:
N = len(data)

train_end = int(0.9 * N)
val_end = int(0.95 * N)

train_ids = data[:train_end]
val_ids = data[train_end:val_end]
test_ids = data[val_end:]

print(len(train_ids), len(val_ids), len(test_ids))

903468 50193 50193


In [22]:
def batch(data: torch.Tensor, batch_size: int) -> torch.Tensor:
    n_batches = data.size(0) // batch_size
    data = data[:n_batches * batch_size]
    return data.view(batch_size, -1)

In [23]:
train_batch_size = 64
eval_batch_size = 256

train_data = batch(train_ids, train_batch_size).to(device)
val_data = batch(val_ids, eval_batch_size).to(device)
test_data = batch(test_ids, eval_batch_size).to(device)

print(train_data.shape)

torch.Size([64, 14116])


In [24]:
def get_batch(data: torch.Tensor, i: int, seq_len: int):
    seq_len = min(seq_len, data.size(1) - 1 - i)
    X = data[:, i:i+seq_len]
    Y = data[:, i+1:i+seq_len+1]
    return X, Y

In [25]:
class CharRNN(nn.Module):
    
    def __init__(
        self,
        vocab_size: int,
        emb_dim: int = 128,
        hidden_size: int = 256,
        num_layers: int = 2,
        dropout: float = 0.2,
        nonlinearity: str = "tanh", 
        bidirectional: bool = False
    ):
        super().__init__()
        
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_directions = 2 if bidirectional else 1
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        
        self.rnn = nn.RNN(
            input_size=emb_dim,
            hidden_size=hidden_size,
            num_layers=num_layers,
            nonlinearity=nonlinearity,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidirectional
        )
        
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_size * self.num_directions, vocab_size)
        
    def init_hidden(self, batch_size, device):
        return torch.zeros(self.num_layers, batch_size, self.hidden_size, device=device)
    
    
    def forward(self, x, h):
        embed = self.embedding(x)
        out, h = self.rnn(embed, h)
        out = self.dropout(out)
        logits = self.fc(out)
        return logits, h

model = CharRNN(vocab_size=vocab_size).to(device)
print(model)

CharRNN(
  (embedding): Embedding(65, 128)
  (rnn): RNN(128, 256, num_layers=2, batch_first=True, dropout=0.2)
  (dropout): Dropout(p=0.2, inplace=False)
  (fc): Linear(in_features=256, out_features=65, bias=True)
)


In [26]:
X_tmp, Y_tmp = get_batch(train_data, 0, seq_len=32)
h_tmp = model.init_hidden(batch_size=X_tmp.size(0), device=device)
logits_tmp, h_tmp2 = model(X_tmp, h_tmp)
print("X:", X_tmp.shape, "logits:", logits_tmp.shape, "hidden:", h_tmp2.shape)

X: torch.Size([64, 32]) logits: torch.Size([64, 32, 65]) hidden: torch.Size([2, 64, 256])


In [27]:
loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [28]:
def train_one_epoch(model, data, optim, loss_fn, seq_len=128, clip_grad=1.0):
    model.train()
    
    total_loss = 0.0
    total_tokens = 0
    
    h = model.init_hidden(data.size(0), device)
    
    # THERE WAS AN ERROR HERE CAUSING PROBLEMS, MAKE SURE YOU FIX IT IN YOUR CODE.
    # WAS CAUSING TRAINING DEGREDATION BECAUSE IT WAS ONLY DOING 1 UPDATE/EPOCH
    # HAS THE WRONG DIMENSIONS, WAS: `data.size(0) - 1`, SHOULD BE data.size(1) - 1
    for i in range(0, data.size(1) - 1, seq_len):
        X, Y = get_batch(data, i, seq_len)
        
        logits, h = model(X, h)
        
        h = h.detach()
        
        loss = loss_fn(
            logits.reshape(-1, model.vocab_size),
            Y.reshape(-1)
        )
        
        optim.zero_grad(set_to_none=True)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad)
        
        optim.step()
        
        n_tokens = Y.numel()
        
        total_loss += loss.item() * n_tokens
        total_tokens += n_tokens
        
    return total_loss / total_tokens
    

In [29]:
@torch.inference_mode()
def eval(model, data, loss_fn, device, seq_len=128):
    model.eval()
    
    total_loss = 0.0
    total_tokens = 0

    h = model.init_hidden(batch_size=data.size(0), device=device)

    for i in range(0, data.size(1) - 1, seq_len):
        X, Y = get_batch(data, i, seq_len)
        logits, h = model(X, h)
        
        loss = loss_fn(
            logits.reshape(-1, model.vocab_size),
            Y.reshape(-1)
        )
        
        n_tokens = Y.numel()
        total_loss += loss.item() * n_tokens
        total_tokens += n_tokens

    return total_loss / total_tokens

In [30]:
epochs = 50
seq_len = 256

best_val_loss = float("inf")

start = timer()

for epoch in tqdm(range(1, epochs + 1)):
    train_loss = train_one_epoch(model, train_data, optim, loss_fn, seq_len=seq_len, clip_grad=1.0)
    val_loss = eval(model, val_data, loss_fn, device, seq_len=seq_len)

    print(
        f"Epoch {epoch:02d} | "
        f"train loss {train_loss:.4f} (ppl {math.exp(train_loss):.2f}) | "
        f"val loss {val_loss:.4f} (ppl {math.exp(val_loss):.2f})"
    )

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        ckpt = {
            "model_state": model.state_dict(),
            "char_encode": char_encode,
            "rev_char_encode": rev_char_encode,
            "config": {
                "vocab_size": vocab_size,
                "emb_dim": model.embedding.embedding_dim,
                "hidden_size": model.hidden_size,
                "num_layers": model.num_layers,
                "dropout": model.dropout.p,
                "nonlinearity": model.rnn.nonlinearity,
            }
        }
        torch.save(ckpt, "char_rnn_best.pt")
        print(f"  Saved new best checkpoint (val loss {best_val_loss:.4f})")

end = timer()
print(f"Training time: {end - start:.2f}s on {device}")

  2%|▏         | 1/50 [00:05<04:15,  5.22s/it]

Epoch 01 | train loss 2.6939 (ppl 14.79) | val loss 2.2843 (ppl 9.82)
  Saved new best checkpoint (val loss 2.2843)


  4%|▍         | 2/50 [00:10<04:12,  5.25s/it]

Epoch 02 | train loss 2.1490 (ppl 8.58) | val loss 2.0567 (ppl 7.82)
  Saved new best checkpoint (val loss 2.0567)


  6%|▌         | 3/50 [00:15<04:07,  5.28s/it]

Epoch 03 | train loss 1.9906 (ppl 7.32) | val loss 1.9356 (ppl 6.93)
  Saved new best checkpoint (val loss 1.9356)


  8%|▊         | 4/50 [00:21<04:02,  5.27s/it]

Epoch 04 | train loss 1.8957 (ppl 6.66) | val loss 1.8550 (ppl 6.39)
  Saved new best checkpoint (val loss 1.8550)


 10%|█         | 5/50 [00:26<03:56,  5.25s/it]

Epoch 05 | train loss 1.8303 (ppl 6.24) | val loss 1.7923 (ppl 6.00)
  Saved new best checkpoint (val loss 1.7923)


 12%|█▏        | 6/50 [00:31<03:50,  5.24s/it]

Epoch 06 | train loss 1.7830 (ppl 5.95) | val loss 1.7468 (ppl 5.74)
  Saved new best checkpoint (val loss 1.7468)


 14%|█▍        | 7/50 [00:36<03:45,  5.23s/it]

Epoch 07 | train loss 1.7450 (ppl 5.73) | val loss 1.7069 (ppl 5.51)
  Saved new best checkpoint (val loss 1.7069)


 16%|█▌        | 8/50 [00:41<03:39,  5.23s/it]

Epoch 08 | train loss 1.7159 (ppl 5.56) | val loss 1.6794 (ppl 5.36)
  Saved new best checkpoint (val loss 1.6794)


 18%|█▊        | 9/50 [00:47<03:34,  5.23s/it]

Epoch 09 | train loss 1.6919 (ppl 5.43) | val loss 1.6558 (ppl 5.24)
  Saved new best checkpoint (val loss 1.6558)


 20%|██        | 10/50 [00:52<03:29,  5.24s/it]

Epoch 10 | train loss 1.6714 (ppl 5.32) | val loss 1.6331 (ppl 5.12)
  Saved new best checkpoint (val loss 1.6331)


 22%|██▏       | 11/50 [00:57<03:23,  5.22s/it]

Epoch 11 | train loss 1.6549 (ppl 5.23) | val loss 1.6171 (ppl 5.04)
  Saved new best checkpoint (val loss 1.6171)


 24%|██▍       | 12/50 [01:02<03:18,  5.23s/it]

Epoch 12 | train loss 1.6384 (ppl 5.15) | val loss 1.6053 (ppl 4.98)
  Saved new best checkpoint (val loss 1.6053)


 26%|██▌       | 13/50 [01:08<03:13,  5.23s/it]

Epoch 13 | train loss 1.6254 (ppl 5.08) | val loss 1.5906 (ppl 4.91)
  Saved new best checkpoint (val loss 1.5906)


 28%|██▊       | 14/50 [01:13<03:07,  5.21s/it]

Epoch 14 | train loss 1.6148 (ppl 5.03) | val loss 1.5833 (ppl 4.87)
  Saved new best checkpoint (val loss 1.5833)


 30%|███       | 15/50 [01:18<03:02,  5.22s/it]

Epoch 15 | train loss 1.6034 (ppl 4.97) | val loss 1.5735 (ppl 4.82)
  Saved new best checkpoint (val loss 1.5735)


 32%|███▏      | 16/50 [01:23<02:56,  5.20s/it]

Epoch 16 | train loss 1.5935 (ppl 4.92) | val loss 1.5637 (ppl 4.78)
  Saved new best checkpoint (val loss 1.5637)


 34%|███▍      | 17/50 [01:28<02:51,  5.20s/it]

Epoch 17 | train loss 1.5835 (ppl 4.87) | val loss 1.5549 (ppl 4.73)
  Saved new best checkpoint (val loss 1.5549)


 36%|███▌      | 18/50 [01:34<02:45,  5.18s/it]

Epoch 18 | train loss 1.5775 (ppl 4.84) | val loss 1.5503 (ppl 4.71)
  Saved new best checkpoint (val loss 1.5503)


 38%|███▊      | 19/50 [01:39<02:41,  5.20s/it]

Epoch 19 | train loss 1.5694 (ppl 4.80) | val loss 1.5473 (ppl 4.70)
  Saved new best checkpoint (val loss 1.5473)


 40%|████      | 20/50 [01:44<02:36,  5.21s/it]

Epoch 20 | train loss 1.5633 (ppl 4.77) | val loss 1.5412 (ppl 4.67)
  Saved new best checkpoint (val loss 1.5412)


 42%|████▏     | 21/50 [01:49<02:31,  5.22s/it]

Epoch 21 | train loss 1.5572 (ppl 4.75) | val loss 1.5338 (ppl 4.64)
  Saved new best checkpoint (val loss 1.5338)


 44%|████▍     | 22/50 [01:55<02:26,  5.23s/it]

Epoch 22 | train loss 1.5512 (ppl 4.72) | val loss 1.5301 (ppl 4.62)
  Saved new best checkpoint (val loss 1.5301)


 46%|████▌     | 23/50 [02:00<02:20,  5.22s/it]

Epoch 23 | train loss 1.5452 (ppl 4.69) | val loss 1.5235 (ppl 4.59)
  Saved new best checkpoint (val loss 1.5235)


 48%|████▊     | 24/50 [02:05<02:15,  5.22s/it]

Epoch 24 | train loss 1.5401 (ppl 4.66) | val loss 1.5224 (ppl 4.58)
  Saved new best checkpoint (val loss 1.5224)


 50%|█████     | 25/50 [02:10<02:10,  5.22s/it]

Epoch 25 | train loss 1.5350 (ppl 4.64) | val loss 1.5163 (ppl 4.56)
  Saved new best checkpoint (val loss 1.5163)


 52%|█████▏    | 26/50 [02:15<02:05,  5.21s/it]

Epoch 26 | train loss 1.5311 (ppl 4.62) | val loss 1.5150 (ppl 4.55)
  Saved new best checkpoint (val loss 1.5150)


 54%|█████▍    | 27/50 [02:21<01:59,  5.21s/it]

Epoch 27 | train loss 1.5271 (ppl 4.60) | val loss 1.5115 (ppl 4.53)
  Saved new best checkpoint (val loss 1.5115)


 56%|█████▌    | 28/50 [02:26<01:54,  5.21s/it]

Epoch 28 | train loss 1.5235 (ppl 4.59) | val loss 1.5076 (ppl 4.52)
  Saved new best checkpoint (val loss 1.5076)


 58%|█████▊    | 29/50 [02:31<01:49,  5.23s/it]

Epoch 29 | train loss 1.5190 (ppl 4.57) | val loss 1.5086 (ppl 4.52)


 60%|██████    | 30/50 [02:36<01:44,  5.23s/it]

Epoch 30 | train loss 1.5156 (ppl 4.55) | val loss 1.5057 (ppl 4.51)
  Saved new best checkpoint (val loss 1.5057)


 62%|██████▏   | 31/50 [02:42<01:39,  5.25s/it]

Epoch 31 | train loss 1.5120 (ppl 4.54) | val loss 1.5005 (ppl 4.48)
  Saved new best checkpoint (val loss 1.5005)


 64%|██████▍   | 32/50 [02:47<01:34,  5.24s/it]

Epoch 32 | train loss 1.5082 (ppl 4.52) | val loss 1.5011 (ppl 4.49)


 66%|██████▌   | 33/50 [02:52<01:28,  5.23s/it]

Epoch 33 | train loss 1.5055 (ppl 4.51) | val loss 1.4990 (ppl 4.48)
  Saved new best checkpoint (val loss 1.4990)


 68%|██████▊   | 34/50 [02:57<01:23,  5.22s/it]

Epoch 34 | train loss 1.5028 (ppl 4.49) | val loss 1.4997 (ppl 4.48)


 70%|███████   | 35/50 [03:02<01:18,  5.23s/it]

Epoch 35 | train loss 1.4999 (ppl 4.48) | val loss 1.4983 (ppl 4.47)
  Saved new best checkpoint (val loss 1.4983)


 72%|███████▏  | 36/50 [03:08<01:12,  5.20s/it]

Epoch 36 | train loss 1.4975 (ppl 4.47) | val loss 1.4951 (ppl 4.46)
  Saved new best checkpoint (val loss 1.4951)


 74%|███████▍  | 37/50 [03:13<01:07,  5.20s/it]

Epoch 37 | train loss 1.4960 (ppl 4.46) | val loss 1.4935 (ppl 4.45)
  Saved new best checkpoint (val loss 1.4935)


 76%|███████▌  | 38/50 [03:18<01:02,  5.19s/it]

Epoch 38 | train loss 1.4926 (ppl 4.45) | val loss 1.4918 (ppl 4.44)
  Saved new best checkpoint (val loss 1.4918)


 78%|███████▊  | 39/50 [03:23<00:57,  5.19s/it]

Epoch 39 | train loss 1.4891 (ppl 4.43) | val loss 1.4903 (ppl 4.44)
  Saved new best checkpoint (val loss 1.4903)


 80%|████████  | 40/50 [03:28<00:51,  5.19s/it]

Epoch 40 | train loss 1.4871 (ppl 4.42) | val loss 1.4902 (ppl 4.44)
  Saved new best checkpoint (val loss 1.4902)


 82%|████████▏ | 41/50 [03:33<00:46,  5.18s/it]

Epoch 41 | train loss 1.4846 (ppl 4.41) | val loss 1.4866 (ppl 4.42)
  Saved new best checkpoint (val loss 1.4866)


 84%|████████▍ | 42/50 [03:39<00:41,  5.19s/it]

Epoch 42 | train loss 1.4822 (ppl 4.40) | val loss 1.4880 (ppl 4.43)


 86%|████████▌ | 43/50 [03:44<00:36,  5.22s/it]

Epoch 43 | train loss 1.4814 (ppl 4.40) | val loss 1.4903 (ppl 4.44)


 88%|████████▊ | 44/50 [03:49<00:31,  5.22s/it]

Epoch 44 | train loss 1.4793 (ppl 4.39) | val loss 1.4849 (ppl 4.41)
  Saved new best checkpoint (val loss 1.4849)


 90%|█████████ | 45/50 [03:54<00:26,  5.22s/it]

Epoch 45 | train loss 1.4754 (ppl 4.37) | val loss 1.4914 (ppl 4.44)


 92%|█████████▏| 46/50 [04:00<00:20,  5.24s/it]

Epoch 46 | train loss 1.4757 (ppl 4.37) | val loss 1.4874 (ppl 4.43)


 94%|█████████▍| 47/50 [04:05<00:15,  5.27s/it]

Epoch 47 | train loss 1.4733 (ppl 4.36) | val loss 1.4858 (ppl 4.42)


 96%|█████████▌| 48/50 [04:10<00:10,  5.28s/it]

Epoch 48 | train loss 1.4710 (ppl 4.35) | val loss 1.4847 (ppl 4.41)
  Saved new best checkpoint (val loss 1.4847)


 98%|█████████▊| 49/50 [04:16<00:05,  5.32s/it]

Epoch 49 | train loss 1.4693 (ppl 4.35) | val loss 1.4817 (ppl 4.40)
  Saved new best checkpoint (val loss 1.4817)


100%|██████████| 50/50 [04:21<00:00,  5.23s/it]

Epoch 50 | train loss 1.4678 (ppl 4.34) | val loss 1.4816 (ppl 4.40)
  Saved new best checkpoint (val loss 1.4816)
Training time: 261.65s on mps





In [31]:
ckpt = torch.load("char_rnn_best.pt", map_location=device)

model = CharRNN(**ckpt["config"]).to(device)
model.load_state_dict(ckpt["model_state"])
stoi = ckpt["char_encode"]
itos = ckpt["rev_char_encode"]

test_loss = eval(model, test_data, loss_fn, device, seq_len=seq_len)
print(f"Test loss: {test_loss:.4f} | Test ppl: {math.exp(test_loss):.2f}")

Test loss: 1.6282 | Test ppl: 5.09


In [32]:
@torch.inference_mode()
def generate_text(
    model,
    prompt: str,
    length: int = 500,
    temperature: float = 1.0,
    top_k: int | None = None,
):
    model.eval()

    fallback_id = stoi.get(" ", 0)

    prompt_ids = torch.tensor(
        [stoi.get(ch, fallback_id) for ch in prompt],
        dtype=torch.long,
        device=device
    ).unsqueeze(0)  

    h = model.init_hidden(batch_size=1, device=device)

    logits, h = model(prompt_ids, h)

    next_logits = logits[:, -1, :]
    generated = list(prompt)

    for _ in range(length):
        scaled = next_logits / max(temperature, 1e-6)

        if top_k is not None:
            topk_vals, topk_idx = torch.topk(scaled, k=top_k, dim=-1)
            filtered = torch.full_like(scaled, float("-inf"))
            filtered.scatter_(1, topk_idx, topk_vals)
            scaled = filtered

        probs = F.softmax(scaled, dim=-1)        
        next_id = torch.multinomial(probs, 1)    
        next_char = itos[int(next_id.item())]
        generated.append(next_char)

        logits, h = model(next_id, h)             
        next_logits = logits[:, -1, :]            

    return "".join(generated)

In [33]:
print(generate_text(model, prompt="ROMEO:\n", length=600, temperature=0.9, top_k=40))

ROMEO:
Now may you may burnt another bed:
What nor dislimmed, of the right is broke.
Bey that will be seing a stame is Romeo!

GLOUCESTER:
Let I for your hands.

CORIOLANUS:
Now, that is the wind!

KING RICHARD III:
Good primes
Sir, bisalishness bolels call it!

DUCHESS OF YORK:
Ay, thou entagual! what fas leave it best of a leave,
Is in the shame?

JULIET:
What, kinchers. You stour my what a wife, come of the light
And sarth at any wrong, tell war; then I she would.

KING CIWIS:
Look on the command to have you:
my supple not,
As this as traitor like it with thy grave and regit
That are you can conde
