In [1]:
import time

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils import clip_grad_norm_
from tqdm import tqdm

# Hint to allow faster matmuls where supported (safe no-op if unavailable)
try:
    torch.set_float32_matmul_precision("high")
except AttributeError:
    pass

# Seed torch's global RNG (torch.manual_seed seeds every device) so weight init,
# DataLoader shuffling and sampling in generate() repeat across Restart & Run All.
# Some GPU kernels may still be nondeterministic.
seed = 1337
torch.manual_seed(seed)

# hyperparameters (tune these for speed/quality tradeoff)
batch_size = 256  # try 256, 512 depending on memory
block_size = 64   # shorter sequences: faster steps, slightly less long-range context
embed_dim = 256
hidden_size = 256
num_layers = 2    # try 1 for even more speed if needed
learning_rate = 3e-4
num_epochs = 1000  # early stopping will usually stop before this

# DataLoader performance settings (adjust if needed)
# NOTE: On macOS + Jupyter, multi-worker DataLoader can crash; start with 0.
num_workers = 0  # set >0 only if you move training to a script and it is stable
pin_memory = False  # not critical on MPS/unified memory

# Prefer CUDA, then Apple MPS, then CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)


In [2]:
# Read the raw corpus and derive a character-level vocabulary from it.
with open("tiny-shakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

chars = sorted(set(text))  # unique characters in sorted order; position = token id
vocab_size = len(chars)

itos = dict(enumerate(chars))             # token id -> character
stoi = {ch: i for i, ch in itos.items()}  # character -> token id

def encode(s: str):
    """Translate a string into a list of token ids; raises KeyError for characters not in ``stoi``."""
    return list(map(stoi.__getitem__, s))

def decode(ids):
    """Translate an iterable of token ids back into a string.

    Accepts plain ints or integer tensors, e.g. one row of ``model.generate(...)``
    output. Each id goes through ``int()`` because torch tensors hash by identity,
    so a 0-d tensor would never match the integer keys of ``itos``.
    """
    return "".join(itos[int(i)] for i in ids)

# Encode the full corpus once as a 1-D tensor of token ids.
data = torch.tensor(encode(text), dtype=torch.long)

# Contiguous split: first 90% of characters for training, final 10% for validation.
n = int(0.9 * len(data))
train_data, val_data = data[:n], data[n:]



class CharDataset(Dataset):
    """Sliding-window next-character dataset over a 1-D token tensor.

    Item ``idx`` is ``(x, y)`` with ``x = data[idx : idx + block_size]`` and ``y`` the
    same window shifted one position right, so ``y[t]`` is the target for ``x[t]``.
    A window starts at every position (stride 1), so consecutive items overlap.
    """

    def __init__(self, data_tensor, block_size):
        self.data = data_tensor
        self.block_size = block_size

    def __len__(self):
        # A window starting at idx reads data[idx : idx + block_size + 1]. The last
        # valid start is therefore len - block_size - 1, which gives len - block_size
        # windows in total. Clamp at 0 so data shorter than one window yields an empty
        # dataset instead of an invalid negative length.
        return max(0, len(self.data) - self.block_size)

    def __getitem__(self, idx):
        x = self.data[idx: idx + self.block_size]
        y = self.data[idx + 1: idx + self.block_size + 1]
        return x, y

train_dataset = CharDataset(train_data, block_size)
val_dataset = CharDataset(val_data, block_size)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # consecutive windows overlap heavily; shuffle so a batch is not one contiguous span
    num_workers=num_workers,
    pin_memory=pin_memory,
)
# Validation only averages the loss, so order is irrelevant. Skip the per-epoch
# permutation and keep validation passes deterministic.
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=pin_memory,
)


In [3]:
class CharRNN(nn.Module):
    """Character-level language model: embedding -> stacked vanilla RNN -> linear head.

    Args:
        vocab_size: number of token ids (size of the embedding table and of the output layer).
        embed_dim: width of the token embeddings fed to the RNN.
        hidden_size: width of each RNN layer's hidden state.
        num_layers: number of stacked RNN layers.
        device: device the caller intends to place the model on. Stored for backward
            compatibility. ``generate`` uses the device of the model's parameters
            instead, so it stays correct if the model is later moved with ``.to(...)``.
    """

    def __init__(self, vocab_size, embed_dim, hidden_size, num_layers, device):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.device = device

        self.embed = nn.Embedding(vocab_size, embed_dim)
        # Vanilla RNN cell (simpler and lighter than GRU/LSTM)
        self.rnn = nn.RNN(embed_dim, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, h=None):
        """Compute next-token logits at every position.

        Args:
            x: integer token ids of shape (batch, seq_len), already on the model's device.
            h: optional initial hidden state of shape (num_layers, batch, hidden_size);
               ``None`` starts from zeros.

        Returns:
            ``(logits, h)``: logits of shape (batch, seq_len, vocab_size) and the final
            hidden state, which can be passed back in to continue the sequence.
        """
        x = self.embed(x)
        out, h = self.rnn(x, h)
        logits = self.fc(out)
        return logits, h

    @torch.no_grad()
    def generate(self, idx, max_new_tokens):
        """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

        The prompt is run through the RNN once. After that, only the newly sampled
        token is fed at each step, and the hidden state carries the history. Re-feeding
        a sliding window while also carrying ``h`` would process each token many times
        and corrupt the state.

        Args:
            idx: integer token ids of shape (batch, prompt_len) with prompt_len >= 1.
            max_new_tokens: number of tokens to append.

        Returns:
            Tensor of shape (batch, prompt_len + max_new_tokens): the prompt followed by
            the sampled tokens, on the model's device.

        Raises:
            ValueError: if ``idx`` is not 2-D or the prompt is empty.
        """
        if idx.dim() != 2 or idx.size(1) == 0:
            raise ValueError("generate() expects a non-empty prompt of shape (batch, prompt_len)")

        was_training = self.training
        self.eval()
        try:
            idx = idx.to(next(self.parameters()).device)
            logits, h = self.forward(idx)  # prime the hidden state with the whole prompt
            for _ in range(max_new_tokens):
                probs = torch.softmax(logits[:, -1, :], dim=-1)
                next_idx = torch.multinomial(probs, num_samples=1)
                idx = torch.cat((idx, next_idx), dim=-1)
                # Advance the recurrence by exactly one step using the carried state.
                logits, h = self.forward(next_idx, h)
        finally:
            # Restore the caller's mode instead of unconditionally forcing train mode.
            self.train(was_training)
        return idx


model = CharRNN(vocab_size, embed_dim, hidden_size, num_layers, device).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
import torch
from tqdm import tqdm

# The config cell may have produced a plain string ("cuda"/"mps"/"cpu");
# promote it to a torch.device so `.type` is available below.
if isinstance(device, str):
    device = torch.device(device)

# bfloat16 autocast is only switched on when running on Apple MPS.
use_amp = device.type == "mps"

@torch.no_grad()
def evaluate(loader):
    """Return the mean cross-entropy of the global ``model`` over every batch in ``loader``.

    Reads the notebook globals ``model``, ``criterion``, ``vocab_size``, ``device`` and
    ``use_amp``. Leaves the model in eval mode; the training loop switches it back at
    the start of each epoch. Returns NaN if ``loader`` yields no batches.
    """
    model.eval()
    total_loss = torch.zeros((), device=device)
    n_samples = 0

    for x, y in loader:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        # Use the actual device type rather than a hard-coded "mps": some PyTorch
        # releases reject device_type="mps" even when enabled=False, which would
        # crash this cell on CUDA/CPU machines.
        with torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_amp):
            logits, _ = model(x)
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))

        bs = x.size(0)
        # criterion returns a per-batch mean; weight by batch size so a short final
        # batch counts proportionally. Accumulate on-device to avoid a sync per batch.
        total_loss += loss * bs
        n_samples += bs

    if n_samples == 0:
        return float("nan")
    return (total_loss / n_samples).item()


best_val_loss = float("inf")
best_epoch = 0
best_state = None  # snapshot of the weights with the lowest val loss so far
patience = 10  # stop early if val loss doesn't improve for this many epochs
no_improve_epochs = 0

for epoch in range(1, num_epochs + 1):
    epoch_start = time.perf_counter()

    model.train()
    total_loss = torch.zeros((), device=device)
    # Separate name so the global train/val split index `n` is not overwritten.
    n_samples = 0

    for x, y in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", leave=False):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Actual device type, not a hard-coded "mps": some PyTorch releases reject
        # device_type="mps" even when enabled=False.
        with torch.autocast(device.type, dtype=torch.bfloat16, enabled=use_amp):
            logits, _ = model(x)
            loss = criterion(logits.reshape(-1, vocab_size), y.reshape(-1))

        loss.backward()
        # Cap the global gradient norm at 1.0 (guards against exploding gradients,
        # a common failure mode for vanilla RNNs).
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        bs = x.size(0)
        total_loss += loss.detach() * bs  # on-device accumulation: no host sync per step
        n_samples += bs

    avg_train_loss = (total_loss / n_samples).item()
    avg_val_loss = evaluate(val_loader)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch
        no_improve_epochs = 0
        # Clone so later optimizer steps cannot mutate the snapshot in place.
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    else:
        no_improve_epochs += 1

    epoch_time = time.perf_counter() - epoch_start
    print(
        f"Epoch {epoch}/{num_epochs} | "
        f"train loss: {avg_train_loss:.4f} | val loss: {avg_val_loss:.4f} | "
        f"time/epoch: {epoch_time:.2f}s"
    )

    if no_improve_epochs >= patience:
        print(f"Early stopping after {epoch} epochs (no val improvement for {patience} epochs)")
        break

# The loop ends on the last (or patience-exhausting) epoch, which is generally not the
# best one. Roll back to the weights that achieved the lowest validation loss.
if best_state is not None:
    model.load_state_dict(best_state)
    print(f"Restored weights from epoch {best_epoch} (val loss {best_val_loss:.4f})")


Epoch 1/1000:  63%|██████▎   | 2480/3922 [01:21<00:44, 32.67it/s]