## Clean Neural Network Implementation

In [6]:
import math, os, random, torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader

def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic=True
    torch.backends.cudnn.benchmark=False

set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = {
    "data_path": "data/modern_chronicle.txt",
    "seq_len": 128,
    "batch_size": 128,
    "embedding_dim": 256,
    "hidden_dim": 256,
    "num_layers": 1,
    "dropout": 0.1,
    "rnn_type": "GRU",
    "num_epochs": 5,
    "learning_rate": 2e-3,
    "grad_clip": 1.0,
    "log_every": 100,
    "sample_every": 100,
    "max_generate": 400,
    "temperature": 0.9,
    "top_k": 40,
    "top_p": 0.9,
    "val_fraction": 0.05,
    "overlap_step": None,
    "save_path": "char_rnn_checkpoint.pt"
}

if config["data_path"] and os.path.exists(config["data_path"]):
    with open(config["data_path"], "r", encoding="utf-8") as f: text = f.read()
else:
    text = "ROMEO:\nBut soft, what light through yonder window breaks?\nIt is the east, and Juliet is the sun.\n"

class CharVocab:
    def __init__(self, text):
        chars = sorted(list(set(text)))
        self.itos = chars
        self.stoi = {c: i for i, c in enumerate(chars)}
    
    def encode(self, s):
        return [self.stoi[c] for c in s if c in self.stoi]
    
    def decode(self, ids):
        return "".join(self.itos[i] for i in ids)

vocab = CharVocab(text)
vocab_size = len(vocab.itos)
data_ids = torch.tensor(vocab.encode(text), dtype=torch.long)
n_total = len(data_ids)
n_val = max(1, int(n_total * config["val_fraction"]))
train_ids = data_ids[:-n_val]
val_ids = data_ids[-n_val:]

class CharChunkDataset(Dataset):
    def __init__(self, ids, seq_len, step=None):
        self.ids = ids
        self.T = seq_len
        self.step = step if step is not None else seq_len
        self.num_chunks = (len(ids) - 1 - seq_len) // self.step + 1
        self.starts = [i * self.step for i in range(self.num_chunks)]
    
    def __len__(self):
        return self.num_chunks
    
    def __getitem__(self, idx):
        s = self.starts[idx]
        return self.ids[s:s + self.T], self.ids[s + 1:s + 1 + self.T]

train_ds = CharChunkDataset(train_ids, config["seq_len"], config["overlap_step"])
val_ds = CharChunkDataset(val_ids, config["seq_len"], config["overlap_step"])
train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True, drop_last=True)
val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False, drop_last=True)

class CharRNN(nn.Module):
    def __init__(self, vocab_size, emb, hidden, layers, dropout, rnn_type="GRU"):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb)
        rnn_cls = {"GRU": nn.GRU, "LSTM": nn.LSTM}[rnn_type.upper()]
        self.rnn = rnn_cls(emb, hidden, num_layers=layers, dropout=dropout if layers > 1 else 0.0, batch_first=True)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden, vocab_size)
        self.rnn_type = rnn_type.upper()
        self.layers = layers
        self.hidden = hidden
    
    def forward(self, x, h=None):
        x = self.emb(x)
        x, h = self.rnn(x, h)
        x = self.drop(x)
        return self.fc(x), h
    
    def init_hidden(self, batch_size, device):
        if self.rnn_type == "LSTM":
            return (torch.zeros(self.layers, batch_size, self.hidden, device=device),
                    torch.zeros(self.layers, batch_size, self.hidden, device=device))
        else:
            return torch.zeros(self.layers, batch_size, self.hidden, device=device)

model = CharRNN(vocab_size, config["embedding_dim"], config["hidden_dim"], config["num_layers"], config["dropout"], config["rnn_type"]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

def bpc_from_loss(loss_val):
    return loss_val / math.log(2.0)

def evaluate_loss(data_loader):
    model.eval(); total = 0.0; count = 0
    
    with torch.no_grad():
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            logits, _ = model(x)
            loss = criterion(logits.view(-1, vocab_size), y.view(-1))
            total += loss.item(); count += 1
    
    if count == 0:
        return float("nan"), float("nan")
    
    avg = total / count
    return avg, bpc_from_loss(avg)

def sample_text(model, vocab, max_new_tokens=300, temperature=1.0, top_k=None, top_p=None, prompt="", device="cpu"):
    model.eval()
    
    with torch.no_grad():
        if not prompt:
            prompt = random.choice(vocab.itos)
        
        input_ids = torch.tensor(vocab.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
        h = None
        out = list(prompt)
        
        for _ in range(max_new_tokens):
            logits, h = model(input_ids, h)
            last_logits = logits[0, -1, :] / max(1e-6, temperature)
            probs = torch.softmax(last_logits, dim=-1)
            
            if top_k is not None:
                k = min(top_k, probs.numel())
                topk_vals, topk_idx = torch.topk(probs, k)
                mask = torch.zeros_like(probs, dtype=torch.bool); mask[topk_idx] = True
                probs = probs.masked_fill(~mask, 0)
            
            if top_p is not None:
                sorted_probs, sorted_idx = torch.sort(probs, descending=True)
                cumsum = torch.cumsum(sorted_probs, dim=0)
                keep = cumsum <= top_p; keep[0] = True
                filtered = torch.zeros_like(sorted_probs).masked_scatter(keep, sorted_probs[keep])
                probs = torch.zeros_like(probs).scatter(0, sorted_idx, filtered)
            
            s = probs.sum()
            
            if s <= 0 or torch.isnan(s):
                next_id = torch.argmax(last_logits)
            
            else:
                probs = probs / s
                next_id = torch.multinomial(probs, 1).item()
            
            out.append(vocab.itos[int(next_id)])
            input_ids = torch.tensor([[next_id]], device=device)
        
        return "".join(out)

global_step = 0; best_val = float("inf")

for epoch in range(1, config["num_epochs"] + 1):
    model.train()
    running = 0.0
    
    for i, (x, y) in enumerate(train_loader, start=1):
        x, y = x.to(device), y.to(device)
        logits, _ = model(x)
        loss = criterion(logits.view(-1, vocab_size), y.view(-1))
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        
        if config["grad_clip"] is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
        
        optimizer.step()
        running += loss.item(); global_step += 1
        
        if global_step % config["log_every"] == 0:
            avg_loss = running / config["log_every"]
            avg_bpc = bpc_from_loss(avg_loss)
            print(f"Epoch {epoch:02d} | Step {global_step:06d} | train loss {avg_loss:.4f} | bpc {avg_bpc:.3f}")
            running = 0.0
        
        if global_step % config["sample_every"] == 0:
            print("\n--- Sample ---")
            print(sample_text(model, vocab, max_new_tokens=config["max_generate"], temperature=config["temperature"], top_k=config["top_k"], top_p=config["top_p"], prompt="ROMEO:", device=device))
            print("--------------\n")
    
    val_loss, val_bpc = evaluate_loss(val_loader)
    print(f"[Validation] Epoch {epoch}: loss {val_loss:.4f} | bpc {val_bpc:.3f}")
    
    if val_loss < best_val:
        best_val = val_loss
        torch.save({"model_state": model.state_dict(), "config": config, "stoi": vocab.stoi, "itos": vocab.itos}, config["save_path"])
        print(f"Saved checkpoint to {config['save_path']}")

[Validation] Epoch 1: loss 2.0261 | bpc 2.923
Saved checkpoint to char_rnn_checkpoint.pt
Epoch 02 | Step 000100 | train loss 0.9203 | bpc 1.328

--- Sample ---
ROMEO:”!;, “We a for the was was inssing so the with you was me was of the retion the starain. She had that that when her she at had she coment her been that hat excensition the dorsistired she had her when she was suched on in said. “I bay suraing was ming, and calug seen that ser the ressed. And the was nought still seen in the shat
him of her in the winking, Honora a for wo arter not crage of a sail 
--------------

[Validation] Epoch 2: loss 1.7737 | bpc 2.559
Saved checkpoint to char_rnn_checkpoint.pt
[Validation] Epoch 3: loss 1.6411 | bpc 2.368
Saved checkpoint to char_rnn_checkpoint.pt
Epoch 04 | Step 000200 | train loss 0.6837 | bpc 0.986

--- Sample ---
ROMEO:.. She said Honora sure. The sundost in the arriding and remissalishing of the for a worned in where dearion she had been such an onceling with it will on her tha