# In [None]:
import math
import re
import random
import requests
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel

# config
# --- data / batching ---
SEQ_LENGTH = 48          # tokens per training window
BATCH_SIZE = 128

# --- model shape ---
EMBED_DIM = 768          # must equal the width of the pretrained embedding table to reuse it
HIDDEN_DIM = 128
NUM_LAYERS = 2
NUM_HEADS = 4
DROPOUT = 0.1

# --- optimisation ---
LEARNING_RATE = 2e-3
EPOCHS = 10
PATIENCE = 2             # early-stopping patience, in epochs

# --- runtime ---
MODEL_NAME = "gpt2"
SEED = 5
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Seed every RNG the script draws from so runs are repeatable.
torch.manual_seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# data handling
def download_sherlock_holmes():
    """Download "The Adventures of Sherlock Holmes" from Project Gutenberg.

    Returns:
        str: The text between the Gutenberg START and END marker lines,
        with the START marker line itself excluded. If a marker cannot be
        found, the text is left untrimmed on that side instead of being
        sliced with a ``-1`` index.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failure or timeout.
    """
    url = "https://www.gutenberg.org/files/1661/1661-0.txt"
    r = requests.get(url, timeout=15)
    # Fail loudly rather than tokenising an HTML error page.
    r.raise_for_status()
    text = r.text

    start = text.find("*** START OF THE PROJECT")
    if start == -1:
        start = 0
    else:
        # Skip the remainder of the marker line so it is not part of the corpus.
        line_end = text.find("\n", start)
        start = len(text) if line_end == -1 else line_end + 1

    end = text.find("*** END OF THE PROJECT", start)
    if end == -1:
        end = len(text)

    return text[start:end]

# handling newline chars
def preprocess_text(text):
    """Flatten *text* onto a single line.

    Every CRLF or LF line break is replaced by one space. A carriage
    return that is not immediately followed by a line feed is left as is.
    """
    return re.sub(r"\r?\n", " ", text)


#use pretrained gpt2 tokenizer
def get_tokenizer():
    """Load the pretrained tokenizer named by ``MODEL_NAME``.

    If the loaded tokenizer defines no pad token, its end-of-sequence
    token is assigned as the pad token before it is returned.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is not None:
        return tokenizer
    tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

def load_gpt2_embeddings():
    """Return a copy of the token-embedding matrix of the ``MODEL_NAME`` model.

    The ``wte`` module is looked up first on the loaded model itself and
    then on its ``transformer`` attribute.

    Raises:
        RuntimeError: If neither location provides a ``wte`` attribute.
    """
    base = AutoModel.from_pretrained(MODEL_NAME)
    # Candidate owners of `wte`, in lookup order; a missing `transformer` becomes None.
    for owner in (base, getattr(base, "transformer", None)):
        if hasattr(owner, "wte"):
            return owner.wte.weight.data.clone()
    raise RuntimeError("Could not locate GPT-2 token embeddings")

class SherlockDataset(Dataset):
    """Non-overlapping next-token-prediction windows over a token sequence.

    Each item is a pair ``(inputs, targets)``, both of length
    ``seq_length``, where ``targets`` is ``inputs`` shifted one position
    to the right within ``ids``.

    Args:
        ids: Sliceable sequence of token ids (anything supporting ``len``
            and slice indexing).
        seq_length: Number of tokens per input window; must be >= 1.

    Raises:
        ValueError: If ``seq_length`` is smaller than 1.
    """

    def __init__(self, ids, seq_length):
        if seq_length < 1:
            raise ValueError(f"seq_length must be >= 1, got {seq_length}")
        self.ids = ids
        self.seq_len = seq_length
        self.stride = seq_length
        # A window spans seq_length + 1 tokens, so the last valid start index
        # is len(ids) - seq_length - 1 (inclusive); range() needs that + 1 as
        # its stop, otherwise the final complete window is dropped.
        stop = len(ids) - seq_length
        self.samples = [
            (start, start + seq_length + 1)
            for start in range(0, stop, self.stride)
        ]

    def __len__(self):
        """Return the number of complete windows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(inputs, targets)`` for window *idx*."""
        s, e = self.samples[idx]
        seq = self.ids[s:e]
        return seq[:-1], seq[1:]

# main architecture

class SherlockLSTMAttn(nn.Module):
    """LSTM language model with a causal self-attention layer on top.

    Pipeline: token embedding -> linear projection to ``hidden_dim`` ->
    stacked LSTM -> causal multi-head self-attention with a residual
    connection and LayerNorm -> linear projection to vocabulary logits.

    Args:
        vocab_size: Number of token ids.
        embed_dim: Width of the token-embedding table.
        hidden_dim: Width of the LSTM / attention layers.
        num_layers: Number of stacked LSTM layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability used in the LSTM (between layers),
            the attention weights and the residual branch.
        pretrained_embeddings: Optional ``(vocab_size, embed_dim)`` tensor
            copied into the embedding table. A tensor of any other shape
            is ignored (a notice is printed) and the table stays randomly
            initialised and trainable.
        freeze_embeddings: When pretrained embeddings are loaded, whether
            to exclude them from gradient updates.
    """

    def __init__(self, vocab_size, embed_dim=768, hidden_dim=768,
                 num_layers=3, num_heads=8, dropout=0.2,
                 pretrained_embeddings=None, freeze_embeddings=True):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        if pretrained_embeddings is not None:
            expected_shape = (vocab_size, embed_dim)
            actual_shape = tuple(pretrained_embeddings.shape)
            if actual_shape == expected_shape:
                with torch.no_grad():
                    self.embedding.weight.copy_(pretrained_embeddings)
                self.embedding.weight.requires_grad = not freeze_embeddings
                print(f"Loaded GPT-2 embeddings (freeze={freeze_embeddings})")
            else:
                # Keep the original fallback (random, trainable table) but
                # make it visible instead of silently dropping the weights.
                print(
                    f"Ignoring pretrained embeddings: shape {actual_shape} "
                    f"does not match expected {expected_shape}"
                )

        self.input_proj = nn.Linear(embed_dim, hidden_dim)

        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            # nn.LSTM only applies dropout between stacked layers.
            dropout=dropout if num_layers > 1 else 0.0
        )

        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True
        )

        self.ln = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.output_proj = nn.Linear(hidden_dim, vocab_size, bias=False)

    def forward(self, x, hidden=None, key_padding_mask=None):
        """Compute next-token logits for every position of *x*.

        Args:
            x: ``(B, T)`` tensor of token ids.
            hidden: Optional initial LSTM state, passed through to ``nn.LSTM``.
            key_padding_mask: Optional ``(B, T)`` boolean tensor, ``True``
                at positions attention must ignore.

        Returns:
            tuple: ``(logits, new_hidden)`` where ``logits`` is
            ``(B, T, vocab_size)`` and ``new_hidden`` is the final LSTM state.
        """
        emb = self.embedding(x)  # (B, T, E)

        # project embeddings to hidden dim
        h = self.input_proj(emb)  # (B, T, H)

        lstm_out, new_hidden = self.lstm(h, hidden)  # (B, T, H)

        # Boolean causal mask (True = may NOT attend). Using the same dtype
        # as the boolean key_padding_mask avoids PyTorch's deprecated
        # mixed float/bool mask combination.
        T = lstm_out.size(1)
        causal = torch.ones(T, T, device=x.device).triu(diagonal=1).bool()

        # self-attention on LSTM outputs
        attn_out, _ = self.attn(
            query=lstm_out, key=lstm_out, value=lstm_out,
            attn_mask=causal,
            key_padding_mask=key_padding_mask,
            need_weights=False
        )

        # residual connection + LayerNorm
        h = self.ln(lstm_out + self.dropout(attn_out))  # (B, T, H)
        logits = self.output_proj(h)  # (B, T, vocab_size)
        return logits, new_hidden

    @torch.no_grad()
    def generate_text(self, tokenizer, prompt,
                                 max_new_tokens=200, min_new_tokens=50,
                                 temperature=0.85, top_p=0.92,
                                 repetition_penalty=1.10,
                                 stop_on_sentence=True,
                                 device=DEVICE):
        """Sample a continuation of *prompt* with temperature + nucleus sampling.

        The model is switched to eval mode and moved to *device*.

        Args:
            tokenizer: Tokenizer providing ``encode``, ``decode``,
                ``pad_token_id`` and ``bos_token_id``.
            prompt: Text to continue; must encode to at least one token.
            max_new_tokens: Upper bound on generated tokens.
            min_new_tokens: Tokens to generate before sentence-end stopping
                is considered.
            temperature: Logit temperature (clamped below at 1e-8).
            top_p: Nucleus mass; the smallest set of most-probable tokens
                whose cumulative probability reaches ``top_p`` is kept.
            repetition_penalty: Values other than 1.0 subtract
                ``log1p(count) * (repetition_penalty - 1)`` from the logit
                of each token already present in the sequence.
            stop_on_sentence: Stop once the generated text ends with
                sentence punctuation (after ``min_new_tokens``).
            device: Device to run generation on.

        Returns:
            str: The generated continuation only (prompt excluded), with
            leading whitespace stripped.

        Raises:
            ValueError: If *prompt* encodes to zero tokens.
        """
        self.eval().to(device)
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        if len(prompt_ids) == 0:
            raise ValueError("prompt must encode to at least one token")
        cur = torch.tensor([list(prompt_ids)], dtype=torch.long, device=device)
        prompt_len = cur.size(1)

        banned_ids = [
            tid for tid in (tokenizer.pad_token_id, tokenizer.bos_token_id)
            if tid is not None
        ]
        sentence_end = re.compile(r"[.!?][\"')\]]?\s*$")

        for step in range(max_new_tokens):
            logits, _ = self(cur)
            logits = logits[:, -1, :]

            # ban pad/bos tokens
            for tid in banned_ids:
                logits[:, tid] = float("-inf")

            # temperature scaling
            logits = logits / max(temperature, 1e-8)

            # repetition penalty
            if repetition_penalty != 1.0:
                seen = torch.bincount(cur[0], minlength=logits.size(-1)).float().unsqueeze(0)
                logits = logits - torch.log1p(seen) * (repetition_penalty - 1.0)

            # top-p sampling: keep a token while the probability mass of the
            # tokens ranked before it is still below top_p, so the token that
            # crosses the threshold is included in the nucleus.
            probs = torch.softmax(logits, dim=-1)
            sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
            cumprobs = torch.cumsum(sorted_probs, dim=-1)
            keep = (cumprobs - sorted_probs) < top_p
            keep[..., 0] = True  # always keep the most likely token

            filtered = torch.where(keep, sorted_probs, torch.zeros_like(sorted_probs))
            filtered = filtered / filtered.sum(dim=-1, keepdim=True).clamp_min(1e-12)
            next_sorted = torch.multinomial(filtered, num_samples=1)
            next_token = sorted_idx.gather(-1, next_sorted)

            cur = torch.cat([cur, next_token], dim=1)

            # early stopping on sentence end
            if stop_on_sentence and (step + 1) >= min_new_tokens:
                text = tokenizer.decode(cur[0, prompt_len:].tolist(), skip_special_tokens=True)
                if sentence_end.search(text):
                    break

        # Decode only the new tokens: slicing the full decoded text by
        # len(prompt) breaks whenever decoding does not reproduce the prompt
        # character-for-character.
        out = tokenizer.decode(cur[0, prompt_len:].tolist(), skip_special_tokens=True)
        return out.lstrip()

# training 
def train(model, train_loader, val_loader, pad_id, epochs=EPOCHS, patience=PATIENCE, device=DEVICE):
    """Train *model* with AdamW, plateau-based LR decay and early stopping.

    After every epoch the model is validated; the state dict with the
    lowest validation loss so far is written to ``sherlock_model.pth``.
    Training stops early after *patience* consecutive epochs without
    improvement.

    Args:
        model: Module whose forward returns ``(logits, hidden)`` and accepts
            a ``key_padding_mask`` keyword.
        train_loader: Iterable of ``(inputs, targets)`` batches for training.
        val_loader: Iterable of ``(inputs, targets)`` batches for validation.
        pad_id: Token id excluded from the loss and masked in attention.
        epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.
        device: Device to train on.

    Raises:
        ValueError: If the training loader yields no batches or the
            validation loader yields no non-padding target tokens.
    """
    model.to(device)
    opt = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)
    # `verbose` is deprecated on PyTorch LR schedulers, so it is not passed;
    # learning-rate drops are reported manually after each scheduler step.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=1)
    crit = nn.CrossEntropyLoss(ignore_index=pad_id)
    # Summed loss lets validation be averaged per token rather than per batch.
    val_crit = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")

    best_loss = float("inf")
    wait = 0

    for ep in range(epochs):
        # Training
        model.train()
        train_loss = 0.0
        num_batches = 0
        pbar = tqdm(train_loader, desc=f"Epoch {ep+1}/{epochs} [Train]")

        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            pad_mask = (x == pad_id)

            opt.zero_grad()
            logits, _ = model(x, key_padding_mask=pad_mask)
            loss = crit(logits.view(-1, logits.size(-1)), y.view(-1))
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()

            train_loss += loss.item()
            num_batches += 1
            pbar.set_postfix(loss=f"{loss.item():.4f}")

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")
        train_loss /= num_batches

        # validation
        model.eval()
        val_nll = 0.0
        val_tokens = 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                pad_mask = (x == pad_id)
                logits, _ = model(x, key_padding_mask=pad_mask)
                val_nll += val_crit(logits.view(-1, logits.size(-1)), y.view(-1)).item()
                val_tokens += (y != pad_id).sum().item()

        if val_tokens == 0:
            raise ValueError("val_loader yielded no non-padding target tokens")
        # Per-token mean: a smaller final batch no longer gets the same
        # weight as a full one.
        val_loss = val_nll / val_tokens
        val_ppl = math.exp(val_loss)

        lr_before = opt.param_groups[0]["lr"]
        sched.step(val_loss)
        lr_after = opt.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"  → Reduced learning rate to {lr_after:.4e}")

        print(f"\nEpoch {ep+1}: Train {train_loss:.4f} | Val {val_loss:.4f} | Val PPL {val_ppl:.2f}")

        # save best model
        if val_loss < best_loss:
            best_loss = val_loss
            wait = 0
            torch.save(model.state_dict(), "sherlock_model.pth")
            print("  → Saved best model")
        else:
            wait += 1
            print(f"  → No improvement ({wait}/{patience})")
            if wait >= patience:
                print("Early stopping")
                break

def evaluate(model, test_loader, pad_id, device=DEVICE):
    """Print test loss, perplexity and top-1 accuracy of *model*.

    Loss is the mean negative log-likelihood per non-padding target token
    over the whole loader (not a mean of per-batch means), and perplexity
    is its exponential. If the loader yields no non-padding targets, loss
    and perplexity are reported as NaN and accuracy as 0.

    Args:
        model: Module whose forward returns ``(logits, hidden)`` and accepts
            a ``key_padding_mask`` keyword.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        pad_id: Token id excluded from the loss/accuracy and masked in
            attention.
        device: Device to evaluate on.
    """
    model.to(device).eval()
    # Sum the loss so it can be normalised by the global token count.
    crit = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
    total_nll = 0.0
    correct, total = 0, 0

    with torch.no_grad():
        for x, y in tqdm(test_loader, desc="Testing"):
            x, y = x.to(device), y.to(device)
            pad_mask = (x == pad_id)
            logits, _ = model(x, key_padding_mask=pad_mask)
            total_nll += crit(logits.view(-1, logits.size(-1)), y.view(-1)).item()

            preds = logits.argmax(dim=-1)
            mask = (y != pad_id)
            correct += (preds[mask] == y[mask]).sum().item()
            total += mask.sum().item()

    test_loss = total_nll / total if total > 0 else float("nan")
    test_ppl = math.exp(test_loss)
    acc = 100.0 * correct / total if total > 0 else 0.0

    print(f"\nTest Loss: {test_loss:.4f} | Test PPL: {test_ppl:.2f} | Top-1 Acc: {acc:.2f}%")

# main
if __name__ == "__main__":
    # End-to-end run: download + tokenise the corpus, train, evaluate the
    # best checkpoint on the held-out split, then sample a few continuations.

    # load data
    raw_text = download_sherlock_holmes()
    text = preprocess_text(raw_text)
    tokenizer = get_tokenizer()
    ids = torch.tensor(tokenizer.encode(text, add_special_tokens=False), dtype=torch.long)

    # split data into train, test, val (contiguous 80/10/10 split)
    n = len(ids)
    train_end = int(0.8 * n)
    val_end = int(0.9 * n)

    train_ids = ids[:train_end]
    val_ids = ids[train_end:val_end]
    test_ids = ids[val_end:]

    # create datasets and loaders
    train_ds = SherlockDataset(train_ids, SEQ_LENGTH)
    val_ds = SherlockDataset(val_ids, SEQ_LENGTH)
    test_ds = SherlockDataset(test_ids, SEQ_LENGTH)

    # Pinned host memory only helps host-to-GPU copies; requesting it on a
    # CPU-only machine just triggers a DataLoader warning.
    pin = DEVICE.type == "cuda"
    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, pin_memory=pin)

    # build model
    gpt2_embeddings = load_gpt2_embeddings()
    model = SherlockLSTMAttn(
        vocab_size=tokenizer.vocab_size,
        embed_dim=EMBED_DIM,
        hidden_dim=HIDDEN_DIM,
        num_layers=NUM_LAYERS,
        num_heads=NUM_HEADS,
        dropout=DROPOUT,
        pretrained_embeddings=gpt2_embeddings,
        freeze_embeddings=True
    )

    print(f"Vocab: {tokenizer.vocab_size} | Params: {sum(p.numel() for p in model.parameters()):,}")

    # train
    pad_id = tokenizer.pad_token_id
    train(model, train_loader, val_loader, pad_id, epochs=EPOCHS, patience=PATIENCE, device=DEVICE)

    # evaluate the best checkpoint written by train()
    model.load_state_dict(torch.load("sherlock_model.pth", map_location=DEVICE))
    evaluate(model, test_loader, pad_id, device=DEVICE)

    # generation
    prompts = [
        "“Well, Holmes,” said I,",
        "“My dear Watson,” Holmes replied,",
        "Holmes asked, “And what do you make of it?”",
        "“Come in, Inspector,” said Holmes,",
    ]

    # Single source of truth for the sampling settings, so the banner below
    # cannot drift from what is actually passed to generate_text().
    temperature = 0.85
    top_p = 0.92

    print(f"\n{'='*80}")
    print(f"Generation with temperature={temperature} | top_p={top_p}")
    print('='*80)

    for prompt in prompts:
        output = model.generate_text(
            tokenizer, prompt,
            max_new_tokens=160, min_new_tokens=40,
            temperature=temperature, top_p=top_p,
            repetition_penalty=1.12, stop_on_sentence=True
        )
        print(f"\n▶ Prompt: {prompt}")
        print(f"   {prompt} {output}")