# In [None]:
# lstm_next_word.py
import os
import math
import random
from collections import Counter
from functools import partial

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import Dataset, DataLoader, random_split

import nltk
from nltk.tokenize import word_tokenize

# ---------------------------
# 0) Reproducibility
# ---------------------------
def set_seed(seed=42):
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device) with *seed*."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(42)

# ---------------------------
# 1) Data loading
# ---------------------------
def load_document(path="kgdata3.txt"):
    """Read the whole UTF-8 text corpus at *path* and return it as a single string.

    Prints the number of characters loaded.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    if os.path.exists(path):
        with open(path, "r", encoding="utf-8") as handle:
            corpus = handle.read()
        print(f"[INFO] Loaded corpus from {path} ({len(corpus):,} chars)")
        return corpus
    raise FileNotFoundError(f"Dataset file {path} not found. Put it next to this script.")

# ---------------------------
# 2) Tokenize + vocab
# ---------------------------
def build_vocab(tokens, min_freq=2):
    """Build a token -> id mapping from *tokens*.

    Ids 0 and 1 are reserved for ``<pad>`` and ``<unk>``. Every other token that
    occurs at least *min_freq* times gets the next free id, in order of first
    appearance in *tokens*.
    """
    word_to_id = {"<pad>": 0, "<unk>": 1}
    frequencies = Counter(tokens)
    frequent_words = (word for word, count in frequencies.items() if count >= min_freq)
    for word in frequent_words:
        # setdefault leaves already-present entries (e.g. the reserved ones) untouched.
        word_to_id.setdefault(word, len(word_to_id))
    return word_to_id

def numericalize(tokens, vocab):
    """Convert *tokens* to ids via *vocab*; tokens missing from it map to the ``<unk>`` id."""
    fallback_id = vocab["<unk>"]
    return [vocab[tok] if tok in vocab else fallback_id for tok in tokens]

# ---------------------------
# 3) Dataset: prefixes -> next token
# ---------------------------
class NextWordDataset(Dataset):
    """Next-word prediction examples built from individual lines of text.

    Each example is ``(prefix_tensor, next_token_id)``. For a tokenized line with
    ids ``ids`` and every ``i`` in ``1..len(ids)-1``, the prefix ``ids[:i]`` (a 1-D
    long tensor) is paired with the target ``ids[i]`` (a Python int). Inputs keep
    their variable length; padding happens in ``collate_fn``.
    """

    def __init__(self, sentences, vocab, min_len=2, tokenizer=None):
        """
        Args:
            sentences: iterable of raw text lines. Each line is lower-cased and
                tokenized on its own, so no prefix spans two lines.
            vocab: token -> id mapping; must contain ``"<unk>"`` (used for tokens
                that are not in the vocabulary).
            min_len: lines with fewer tokens than this are skipped.
            tokenizer: callable ``str -> list[str]``; defaults to nltk's
                ``word_tokenize``. It should match the tokenization used to build *vocab*.
        """
        tokenize = word_tokenize if tokenizer is None else tokenizer
        self.samples = []
        for sent in sentences:
            # Drop whitespace-only tokens, mirroring how the vocabulary tokens are filtered.
            toks = [t for t in tokenize(sent.lower()) if t.strip()]
            if len(toks) < min_len:
                continue
            ids = numericalize(toks, vocab)
            # prefix ids[:i] predicts ids[i] for i in 1..len(ids)-1
            for i in range(1, len(ids)):
                self.samples.append((torch.tensor(ids[:i], dtype=torch.long), ids[i]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

def collate_fn(batch, pad_idx=0):
    """Pad a list of ``(prefix_tensor, target_id)`` pairs into batch tensors.

    Args:
        batch: list of (1-D long prefix tensor, int target) pairs.
        pad_idx: value used to right-pad shorter prefixes.

    Returns:
        (x_padded, lengths, targets): a (B, T) long tensor of padded prefixes, a
        (B,) long tensor of the original prefix lengths, and a (B,) long tensor
        of targets.
    """
    prefixes, raw_targets = zip(*batch)
    padded = pad_sequence(prefixes, batch_first=True, padding_value=pad_idx)
    seq_lengths = torch.tensor(list(map(len, prefixes)), dtype=torch.long)
    return padded, seq_lengths, torch.tensor(raw_targets, dtype=torch.long)

# ---------------------------
# 4) Model
# ---------------------------
class LSTMNextWord(nn.Module):
    """Embedding -> stacked LSTM -> linear head that predicts the next token.

    ``forward`` returns one row of vocabulary logits per sequence. The logits are
    computed from the top LSTM layer's final hidden state at each sequence's true
    (unpadded) length.
    """

    def __init__(self, vocab_size, emb_dim=100, hidden_dim=256, num_layers=2, pad_idx=0, dropout=0.3):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            # nn.LSTM applies dropout only between stacked layers and warns when
            # dropout > 0 with a single layer, so disable it in that case.
            dropout=dropout if num_layers > 1 else 0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(hidden_dim, vocab_size)

        # Init: xavier input weights, orthogonal recurrent weights, forget-gate bias +1.
        for name, param in self.lstm.named_parameters():
            if "weight_ih" in name:
                nn.init.xavier_uniform_(param)
            elif "weight_hh" in name:
                nn.init.orthogonal_(param)
            elif "bias" in name:
                nn.init.zeros_(param)
                # nn.LSTM adds bias_ih and bias_hh. Set the forget-gate slice in only
                # one of them so the effective forget bias is +1, not +2.
                # Gate order is [i, f, g, o], so f is the 2nd quarter.
                if "bias_ih" in name:
                    n = param.shape[0] // 4
                    with torch.no_grad():
                        param[n:2 * n].fill_(1.0)

    def forward(self, x, lengths):
        """Return next-token logits of shape (B, V).

        Args:
            x: (B, T) token ids, right-padded.
            lengths: (B,) true lengths before padding. Every entry must be >= 1,
                because pack_padded_sequence rejects zero-length sequences.
        """
        emb = self.embedding(x)  # (B, T, E)
        # Packing keeps padded positions out of the recurrence; enforce_sorted=False
        # lets callers pass batches in any order (h_n comes back in input order).
        packed = pack_padded_sequence(emb, lengths.cpu(), batch_first=True, enforce_sorted=False)
        _, (h_n, _) = self.lstm(packed)  # h_n: (num_layers, B, H)
        last = h_n[-1]                   # (B, H) from the top layer
        return self.fc(last)             # (B, V)

# ---------------------------
# 5) Training / Eval helpers
# ---------------------------
def accuracy_from_logits(logits, targets):
    """Return, as a Python float in [0, 1], the fraction of rows whose arg-max logit equals the target."""
    hits = logits.argmax(dim=-1).eq(targets)
    return hits.float().mean().item()

@torch.no_grad()
def evaluate(model, loader, device, criterion):
    """Compute example-weighted loss, perplexity and accuracy of *model* over *loader*.

    Args:
        model: called as ``model(x, lengths)``; expected to return (B, V) logits.
        loader: iterable of ``(x, lengths, targets)`` batches.
        device: device the batch tensors are moved to.
        criterion: loss function. It is assumed to use mean reduction (the
            nn.CrossEntropyLoss default) and is re-weighted by batch size here.

    Returns:
        (avg_loss, ppl, acc_percent). The averages are taken over examples, so a
        short final batch is not over-weighted. An empty loader yields (0.0, 1.0, 0.0).
    """
    model.eval()
    total_loss, total_correct, total_examples = 0.0, 0, 0
    for x, lengths, y in loader:
        x, lengths, y = x.to(device), lengths.to(device), y.to(device)
        logits = model(x, lengths)
        batch_size = y.size(0)
        total_loss += criterion(logits, y).item() * batch_size
        total_correct += (logits.argmax(dim=-1) == y).sum().item()
        total_examples += batch_size
    denom = max(1, total_examples)
    avg_loss = total_loss / denom
    ppl = math.exp(min(20, avg_loss))  # clamp so exp() cannot overflow on a diverged model
    avg_acc = 100.0 * total_correct / denom
    return avg_loss, ppl, avg_acc

# ---------------------------
# 6) Training Loop
# ---------------------------
def _make_loaders(train_ds, val_ds, batch_size, pad_idx, device, num_workers):
    """Build a shuffled train DataLoader and an ordered val DataLoader that pad with ``pad_idx``."""
    # A functools.partial of a module-level function is picklable (a lambda is not),
    # so worker processes can also start under the "spawn" start method.
    collate = partial(collate_fn, pad_idx=pad_idx)
    # Pinned host memory only helps host->CUDA copies; skip it on CPU-only runs.
    pin = device.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=pin, collate_fn=collate)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=num_workers, pin_memory=pin, collate_fn=collate)
    return train_loader, val_loader


def _train_one_epoch(model, loader, device, criterion, optimizer, grad_clip):
    """Run one optimisation pass over *loader*; return (example-weighted loss, accuracy in %)."""
    model.train()
    total_loss, total_correct, total_examples = 0.0, 0.0, 0
    for x, lengths, y in loader:
        x, lengths, y = x.to(device), lengths.to(device), y.to(device)

        optimizer.zero_grad()
        logits = model(x, lengths)
        loss = criterion(logits, y)
        loss.backward()
        # Gradient clipping guards against exploding gradients in the LSTM.
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=grad_clip)
        optimizer.step()

        batch_size = y.size(0)
        total_loss += loss.item() * batch_size
        total_correct += accuracy_from_logits(logits, y) * batch_size
        total_examples += batch_size
    denom = max(1, total_examples)  # avoid ZeroDivisionError on an empty loader
    return total_loss / denom, 100.0 * total_correct / denom


def train(
    text_path="kgdata3.txt",
    min_freq=2,
    batch_size=64,
    emb_dim=100,
    hidden_dim=256,
    num_layers=2,
    dropout=0.3,
    epochs=15,
    lr=1e-3,
    weight_decay=1e-6,
    grad_clip=5.0,
    val_split=0.1,
    device=None,
    ckpt_path="best_lstm_next_word.pth",
    num_workers=2,
):
    """Train an LSTMNextWord model on a text corpus and return the best-validation model.

    Pipeline:
      1. Tokenize the whole corpus to build the vocabulary.
      2. Turn each line into prefix -> next-token examples.
      3. Split the examples randomly (seed 42) into train and val sets.
      4. Train with AdamW, ReduceLROnPlateau on the val loss, and gradient clipping.

    Args:
        text_path: UTF-8 corpus file.
        min_freq: minimum count for a token to get its own vocabulary id.
        val_split: fraction of examples held out for validation (at least 1).
        device: torch.device or device string; defaults to CUDA when available, else CPU.
        ckpt_path: where the best checkpoint ``{"model_state", "vocab"}`` is saved.
        num_workers: DataLoader worker processes.
        Other arguments are model / optimiser hyper-parameters.

    Returns:
        (model, vocab, device). ``model`` holds the weights from the epoch with the
        lowest validation loss, the same weights written to *ckpt_path*.

    Raises:
        FileNotFoundError: if *text_path* does not exist.
        ValueError: if the corpus yields fewer than 2 examples.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)  # also accept strings such as "cuda:0"
    print("[INFO] Device:", device)

    # tokenize
    nltk.download("punkt", quiet=True)
    nltk.download("punkt_tab", quiet=True)

    raw = load_document(text_path)
    tokens = [t for t in word_tokenize(raw.lower()) if t.strip()]

    # vocab
    vocab = build_vocab(tokens, min_freq=min_freq)
    pad_idx = vocab["<pad>"]
    print(f"[INFO] Vocab size: {len(vocab)}  (min_freq={min_freq})")

    # sentences -> dataset
    full_ds = NextWordDataset(raw.splitlines(), vocab)
    if len(full_ds) < 2:
        raise ValueError(f"Need at least 2 examples to split train/val, got {len(full_ds)} from {text_path}.")

    # NOTE(review): the split is over prefix examples, so prefixes of the same line
    # can land in both train and val; a line-level split would be a stricter estimate.
    val_size = max(1, int(len(full_ds) * val_split))
    train_size = len(full_ds) - val_size
    train_ds, val_ds = random_split(full_ds, [train_size, val_size], generator=torch.Generator().manual_seed(42))
    train_loader, val_loader = _make_loaders(train_ds, val_ds, batch_size, pad_idx, device, num_workers)

    # model
    model = LSTMNextWord(vocab_size=len(vocab), emb_dim=emb_dim, hidden_dim=hidden_dim,
                         num_layers=num_layers, pad_idx=pad_idx, dropout=dropout).to(device)

    # loss/opt/sched
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=2)

    print(f"[INFO] Train examples: {len(train_ds)} | Val examples: {len(val_ds)}")
    best_val = float("inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _train_one_epoch(model, train_loader, device, criterion, optimizer, grad_clip)
        train_ppl = math.exp(min(20, train_loss))

        val_loss, val_ppl, val_acc = evaluate(model, val_loader, device, criterion)
        scheduler.step(val_loss)

        print(f"Epoch {epoch:02d} | "
              f"train loss {train_loss:.4f} ppl {train_ppl:.2f} acc {train_acc:.2f}% | "
              f"val loss {val_loss:.4f} ppl {val_ppl:.2f} acc {val_acc:.2f}% | "
              f"lr {optimizer.param_groups[0]['lr']:.2e}")

        # keep best (on disk and in memory)
        if val_loss < best_val:
            best_val = val_loss
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({"model_state": model.state_dict(), "vocab": vocab}, ckpt_path)
            print(f"[INFO] Saved best model -> {ckpt_path}")

    # Return the best-validation weights rather than the last epoch's, which may
    # have over-fitted after the best epoch.
    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"[INFO] Restored best model (val loss {best_val:.4f})")

    return model, vocab, device

# ---------------------------
# 7) Inference (top-k + temperature)
# ---------------------------
@torch.no_grad()
def generate_next_words(model, vocab, prompt, device, max_new_tokens=20, context_len=60, top_k=5,
                        temperature=1.0, tokenizer=None):
    """Autoregressively extend *prompt* by sampling tokens from *model*.

    Args:
        model: called as ``model(x, lengths)`` with a (1, T) id tensor; returns (1, V) logits.
        vocab: token -> id mapping containing ``"<pad>"`` and ``"<unk>"``.
        prompt: seed text; it is lower-cased and tokenized.
        device: device the input tensors are created on.
        max_new_tokens: number of tokens to generate.
        context_len: only the last ``context_len`` ids are fed to the model each step.
        top_k: sample only among the k most probable tokens; None or <= 0 samples
            from the full distribution.
        temperature: logits are divided by ``max(1e-6, temperature)`` before softmax.
        tokenizer: callable ``str -> list[str]``; defaults to nltk's ``word_tokenize``.
            It should match the tokenization used in training.

    Returns:
        The prompt tokens followed by the generated tokens, joined by single spaces.
        ``<pad>`` is never generated.

    Raises:
        ValueError: if the prompt contains no tokens, because the model cannot run
            on an empty sequence.
    """
    model.eval()
    tokenize = word_tokenize if tokenizer is None else tokenizer
    id2word = {i: w for w, i in vocab.items()}
    pad_idx = vocab["<pad>"]
    unk_idx = vocab["<unk>"]

    toks = [t for t in tokenize(prompt.lower()) if t.strip()]
    if not toks:
        raise ValueError("prompt must contain at least one token")
    ids = [vocab.get(t, unk_idx) for t in toks]
    generated = list(toks)

    for _ in range(max_new_tokens):
        ctx = ids[-context_len:]
        x = torch.tensor(ctx, dtype=torch.long, device=device).unsqueeze(0)  # (1, T)
        lengths = torch.tensor([x.size(1)], dtype=torch.long, device=device)

        # The division yields a new tensor, so masking it below does not touch model outputs.
        logits = model(x, lengths) / max(1e-6, temperature)  # (1, V)
        # <pad> is a padding filler, not a word: never sample it.
        logits[:, pad_idx] = float("-inf")
        probs = torch.softmax(logits, dim=-1)

        if top_k is not None and top_k > 0:
            top_probs, top_idx = torch.topk(probs, k=min(top_k, probs.size(-1)), dim=-1)
            next_id = top_idx[0, torch.multinomial(top_probs[0], 1)].item()
        else:
            next_id = torch.multinomial(probs[0], 1).item()

        generated.append(id2word.get(next_id, "<unk>"))
        ids.append(next_id)

    return " ".join(generated)

# ---------------------------
# 8) Run
# ---------------------------
if __name__ == "__main__":
    # Hyper-parameters for the full training run.
    config = dict(
        text_path="kgdata3.txt",
        min_freq=2,
        batch_size=64,
        emb_dim=100,
        hidden_dim=256,
        num_layers=2,
        dropout=0.3,
        epochs=15,
        lr=1e-3,
        weight_decay=1e-6,
        grad_clip=5.0,
        val_split=0.1,
    )
    model, vocab, device = train(**config)

    print("\n=== SAMPLE GENERATION ===")
    sample = generate_next_words(
        model, vocab, "hi how are", device,
        max_new_tokens=20, context_len=60, top_k=5, temperature=1.0,
    )
    print(sample)


# --- Output captured from a previous run of this cell ---
# [INFO] Device: cuda
# [INFO] Loaded corpus from kgdata3.txt (581,888 chars)
# [INFO] Vocab size: 4722  (min_freq=2)
# [INFO] Train examples: 108210 | Val examples: 12023
# Epoch 01 | train loss 5.4863 ppl 241.36 acc 12.02% | val loss 5.0153 ppl 150.70 acc 16.23% | lr 1.00e-03
# [INFO] Saved best model -> best_lstm_next_word.pth
# Epoch 02 | train loss 4.7713 ppl 118.07 acc 17.54% | val loss 4.7759 ppl 118.61 acc 18.20% | lr 1.00e-03
# [INFO] Saved best model -> best_lstm_next_word.pth
# Epoch 03 | train loss 4.4726 ppl 87.58 acc 19.23% | val loss 4.6814 ppl 107.93 acc 18.86% | lr 1.00e-03
# [INFO] Saved best model -> best_lstm_next_word.pth
# Epoch 04 | train loss 4.2352 ppl 69.08 acc 20.65% | val loss 4.6352 ppl 103.05 acc 19.92% | lr 1.00e-03
# [INFO] Saved best model -> best_lstm_next_word.pth
# Epoch 05 | train loss 4.0219 ppl 55.80 acc 22.02% | val loss 4.6600 ppl 105.63 acc 20.31% | lr 1.00e-03
# Epoch 06 | train loss 3.8218 ppl 45.68 acc 23.42% | val loss 4.7104 ppl 111.09 acc 20.62% | lr 1.00e-03
# Epoch 