In [1]:
import os
import requests
import math
import random
from collections import Counter
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [2]:
window_size = 100
batch_size = 128
embedding_dim = 128
hidden_size = 256
num_layers = 2
lr = 0.001
epochs = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
min_freq = 1
print(f"Using device: {device}")

Using device: cpu


In [3]:
def download_text(url):
    try:
        print("Downloading text...")
        r = requests.get(url, timeout=30)
        r.raise_for_status()
        return r.text
    except Exception as e:
        print("Download failed:", e)
        return None

gutenberg_txt_url = "https://www.gutenberg.org/cache/epub/84/pg84.txt"

raw = download_text(gutenberg_txt_url)

if raw is None:
    print("Using fallback sample text.")
    raw = (
        "In the quiet village the sun rose slowly. Children ran across fields, "
        "the baker opened his shop, and a bell rang in the distance. People spoke "
        "of ordinary things, and the day seemed to promise little more than calm."
        "But small events have a way of changing the course of ordinary days."
        * 50
    )

print("Loaded text length (chars):", len(raw))

Downloading text...
Loaded text length (chars): 446544


In [4]:
import re
def clean_and_tokenize(text):
    txt = text
    txt = txt.replace("\r\n", "\n")
    txt = re.sub(r"\s+", " ", txt)  # normalize whitespace
    txt = txt.strip()
    # lowercase
    txt = txt.lower()
    # simple tokenization on whitespace and punctuation as separate tokens
    tokens = re.findall(r"\w+|[^\s\w]", txt, re.UNICODE)
    return tokens

tokens = clean_and_tokenize(raw)
print("Total tokens:", len(tokens))
if len(tokens) < window_size + 10:
    repeats = math.ceil((window_size + 10) / max(1, len(tokens)))
    tokens = tokens * repeats
    print(f"Text was short; repeated tokens {repeats} times -> new length {len(tokens)}")

Total tokens: 89701


In [5]:
# build vocabulary
counter = Counter(tokens)
vocab_tokens = [w for w, c in counter.items() if c >= min_freq]
vocab_tokens.sort(key=lambda w: (-counter[w], w))

PAD = "<PAD>"
UNK = "<UNK>"
START = "<START>"

itos = [PAD, UNK, START] + vocab_tokens
stoi = {w: i for i, w in enumerate(itos)}
vocab_size = len(itos)
print("Vocab size:", vocab_size)

def tok2idx(tok):
    return stoi.get(tok, stoi[UNK])

seq_len = window_size - 1
inputs = []
targets = []

for i in range(0, len(tokens) - seq_len):
    inp = tokens[i : i + seq_len]
    targ = tokens[i + seq_len]
    inputs.append([tok2idx(t) for t in inp])
    targets.append(tok2idx(targ))

print("Number of sequences:", len(inputs))

Vocab size: 7362
Number of sequences: 89602


In [6]:
class TextSeqDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets
    def __len__(self):
        return len(self.inputs)
    def __getitem__(self, idx):
        return torch.tensor(self.inputs[idx], dtype=torch.long), torch.tensor(self.targets[idx], dtype=torch.long)

dataset = TextSeqDataset(inputs, targets)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [7]:
class RNNGenerator(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.gru = nn.GRU(embedding_dim, hidden_size, num_layers=num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)
    def forward(self, x, hidden=None):
        # x: (batch, seq_len)
        emb = self.embedding(x)
        out, hidden = self.gru(emb, hidden)
        out = out[:, -1, :]
        logits = self.fc(out)
        return logits, hidden

model = RNNGenerator(vocab_size, embedding_dim, hidden_size, num_layers=num_layers).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
print(model)

RNNGenerator(
  (embedding): Embedding(7362, 128, padding_idx=0)
  (gru): GRU(128, 256, num_layers=2, batch_first=True)
  (fc): Linear(in_features=256, out_features=7362, bias=True)
)


# Training

In [8]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    model.train()
    total_loss = 0.0
    for xb, yb in dataloader:
        xb = xb.to(device)
        yb = yb.to(device)
        optimizer.zero_grad()
        logits, _ = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * xb.size(0)
    return total_loss / len(dataloader.dataset)

for epoch in range(1, epochs+1):
    loss = train_epoch(model, dataloader, optimizer, criterion, device)
    print(f"Epoch {epoch}/{epochs} — loss: {loss:.4f}")

Epoch 1/8 — loss: 6.5189
Epoch 2/8 — loss: 5.4833
Epoch 3/8 — loss: 4.8622
Epoch 4/8 — loss: 4.3158
Epoch 5/8 — loss: 3.7544
Epoch 6/8 — loss: 3.1887
Epoch 7/8 — loss: 2.6529
Epoch 8/8 — loss: 2.1767


In [9]:
import torch.nn.functional as F

def generate_text(model, seed_text, max_words=120, temperature=1.0):
    model.eval()
    words = clean_and_tokenize(seed_text)
    while len(words) < seq_len:
        words = [START] + words
    generated = words[:]
    hidden = None
    for _ in range(max_words):
        input_seq = generated[-seq_len:]
        idxs = torch.tensor([tok2idx(w) for w in input_seq], dtype=torch.long).unsqueeze(0).to(device)
        with torch.no_grad():
            logits, hidden = model(idxs, hidden)
            logits = logits.squeeze(0) / (temperature if temperature>0 else 1.0)
            probs = F.softmax(logits, dim=-1)
            next_idx = torch.multinomial(probs, num_samples=1).item()
            next_word = itos[next_idx] if next_idx < len(itos) else UNK
            generated.append(next_word)
    out = []
    for token in generated:
        if re.match(r"^\w+$", token):  # word
            if out:
                out.append(" ")
            out.append(token)
        else:  # punctuation
            out.append(token)
    return "".join(out)

seed = "the monster"
print("Generated sample:\n")
print(generate_text(model, seed, max_words=120, temperature=1.0))

Generated sample:

<START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START><START> the monster both sight as their home added. during my residence days scarcely conjectured when beaufort did not a little wants might perform a blessing to visit them; but i followed that would pass my words and permit me with pleasure.[ o night; and when he appeared to betray them in his illness, and

# Save model

In [10]:
torch.save({
    "model_state_dict": model.state_dict(),
    "stoi": stoi,
    "itos": itos,
    "config": {
        "window_size": window_size,
        "embedding_dim": embedding_dim,
        "hidden_size": hidden_size,
        "num_layers": num_layers
    }
}, "rnn_text_gen.pth")
print("Saved model to rnn_text_gen.pth")

Saved model to rnn_text_gen.pth
