In [32]:
import math
import re
import time
from collections import Counter
from typing import List, Tuple

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler

# Optional: sentencepiece for subword tokenization (recommended)
try:
    import sentencepiece as spm
    HAS_SP = True
except Exception:
    HAS_SP = False

## **CONFIG**

In [33]:
# Hyperparameters and runtime settings read by the cells below.
CONFIG = {
    # Model
    "d_model": 256,
    "nhead": 8,
    "num_encoder_layers": 4,
    "num_decoder_layers": 4,
    "d_ff": 1024,
    "dropout": 0.1,

    # Training
    "batch_size": 32,
    "learning_rate": 3e-5,
    "num_epochs": 10,
    "max_encoder_len": 512,
    "max_decoder_len": 128,
    "gradient_accumulation_steps": 2,
    "warmup_steps": 4000,

    # Tokens & data
    "vocab_size": 30000,
    "pad_token": "<PAD>",
    "sos_token": "<SOS>",
    "eos_token": "<EOS>",
    "unk_token": "<UNK>",

    # Optimization
    "label_smoothing": 0.1,
    "grad_clip": 1.0,

    # Decoding
    "beam_size": 4,
    "length_penalty": 0.6,       # length normalization exponent
    "repetition_penalty": 1.2,   # penalize repeated tokens in beam search
}

# Device: prefer CUDA, then Apple MPS, otherwise fall back to CPU.
if torch.cuda.is_available():
    CONFIG["device"] = "cuda"
elif torch.backends.mps.is_available():
    CONFIG["device"] = "mps"
else:
    CONFIG["device"] = "cpu"

# AMP: mixed precision is switched on only when running on CUDA.
CONFIG["use_amp"] = CONFIG["device"] == "cuda"
print(f"Device: {CONFIG['device']}, AMP: {CONFIG['use_amp']}")

# Optional: path to a trained sentencepiece model (e.g. "bpe.model").
# Leave as None to use the word-level fallback tokenizer.
SENTENCEPIECE_MODEL_PATH = None



Device: cuda, AMP: True


## **Tokenizer**

In [34]:
class SimpleTokenizer:
    """Fallback word-level tokenizer used when sentencepiece is unavailable.

    The four special tokens (pad, sos, eos, unk) are read from ``CONFIG`` and
    always occupy ids 0-3, in that order. The remaining ids are assigned by
    ``build_vocab`` to the most frequent words.
    """

    def __init__(self, vocab_size=30000):
        self.vocab_size = vocab_size
        self.word2idx = {}
        self.idx2word = {}
        self.word_freq = Counter()
        self.pad_token = CONFIG["pad_token"]
        self.sos_token = CONFIG["sos_token"]
        self.eos_token = CONFIG["eos_token"]
        self.unk_token = CONFIG["unk_token"]
        self.special_tokens = [self.pad_token, self.sos_token, self.eos_token, self.unk_token]
        self._reset_mappings()

    def _reset_mappings(self):
        """Reset both lookup tables so they contain only the special tokens."""
        self.word2idx = {tok: i for i, tok in enumerate(self.special_tokens)}
        self.idx2word = {i: tok for i, tok in enumerate(self.special_tokens)}

    def clean_text(self, text):
        """Lowercase, drop everything except ``a-z 0-9 . , ! ?`` and whitespace, collapse spaces."""
        text = str(text).lower()
        text = re.sub(r'[^a-z0-9\s\.\,\!\?]', '', text)
        text = re.sub(r'\s+', ' ', text).strip()
        return text

    def tokenize(self, text):
        """Return the whitespace-separated words of the cleaned text."""
        return self.clean_text(text).split()

    def build_vocab(self, texts):
        """Count words in ``texts`` and assign ids to the most frequent ones.

        Word counts accumulate across calls, but the id tables are rebuilt from
        scratch each time so an earlier call cannot leave two words sharing an
        id. Also writes the resulting vocabulary size to ``CONFIG["vocab_size"]``.
        """
        print("Building simple word-level vocab...")
        for t in texts:
            self.word_freq.update(self.tokenize(t))
        self._reset_mappings()
        most_common = self.word_freq.most_common(self.vocab_size - len(self.special_tokens))
        for idx, (w, _) in enumerate(most_common, start=len(self.special_tokens)):
            self.word2idx[w] = idx
            self.idx2word[idx] = w
        print("Vocab size:", len(self.word2idx))
        CONFIG["vocab_size"] = len(self.word2idx)

    def encode(self, text, max_length=None, add_special_tokens=True):
        """Convert ``text`` to a list of token ids.

        Args:
            text: raw text; it is cleaned and split on whitespace.
            max_length: optional upper bound on the returned length.
            add_special_tokens: wrap the words in SOS ... EOS.

        Returns:
            List of ids; unknown words map to the UNK id. When truncation is
            needed the *words* are truncated, so the sequence still ends with
            EOS (previously EOS was cut off whenever the text was too long).
        """
        toks = self.tokenize(text)
        if add_special_tokens:
            if max_length is not None:
                # Reserve two slots for SOS/EOS before truncating the words.
                toks = toks[:max(max_length - 2, 0)]
            toks = [self.sos_token] + toks + [self.eos_token]
        if max_length is not None:
            # Final guard: also covers add_special_tokens=False and max_length < 2.
            toks = toks[:max_length]
        unk_id = self.word2idx[self.unk_token]
        return [self.word2idx.get(t, unk_id) for t in toks]

    def decode(self, ids, skip_special_tokens=True):
        """Map ids back to a space-joined string; unknown ids become the UNK token."""
        toks = [self.idx2word.get(i, self.unk_token) for i in ids]
        if skip_special_tokens:
            toks = [t for t in toks if t not in self.special_tokens]
        return " ".join(toks)

    @property
    def pad_token_id(self):
        return self.word2idx[self.pad_token]

    @property
    def sos_token_id(self):
        return self.word2idx[self.sos_token]

    @property
    def eos_token_id(self):
        return self.word2idx[self.eos_token]

    @property
    def unk_token_id(self):
        return self.word2idx[self.unk_token]


class SPTokenizer:
    """Wrapper around a trained sentencepiece processor.

    Special ids are taken from the loaded model. SentencePiece reports an
    undefined special id as -1; such ids are exposed here as ``None`` (for
    BOS/EOS) so callers can detect their absence.
    """

    def __init__(self, model_path):
        self.sp = spm.SentencePieceProcessor(model_file=model_path)
        # SentencePiece has its own id space; the model must be sized to it.
        self.vocab_size = self.sp.get_piece_size()
        CONFIG["vocab_size"] = self.vocab_size

    @staticmethod
    def _defined(token_id):
        """Return ``token_id`` if it is a real id (>= 0), else ``None``."""
        return token_id if token_id is not None and token_id >= 0 else None

    def build_vocab(self, *args, **kwargs):
        raise RuntimeError("SPTokenizer assumes pre-trained sentencepiece model; no build_vocab required.")

    def encode(self, text, max_length=None, add_special_tokens=True):
        """Encode ``text`` to piece ids.

        When ``add_special_tokens`` is true, the model's BOS/EOS ids are added
        if the model defines them (previously the flag was silently ignored).
        Truncation shortens the pieces, not the markers, so a truncated
        sequence still ends with EOS.
        """
        ids = list(self.sp.encode(text, out_type=int))
        bos = self.sos_token_id if add_special_tokens else None
        eos = self.eos_token_id if add_special_tokens else None
        if max_length is not None:
            reserved = (bos is not None) + (eos is not None)
            ids = ids[:max(max_length - reserved, 0)]
        if bos is not None:
            ids = [bos] + ids
        if eos is not None:
            ids = ids + [eos]
        if max_length is not None:
            ids = ids[:max_length]  # final guard for very small max_length
        return ids

    def decode(self, ids, skip_special_tokens=True):
        """Decode piece ids to text, optionally dropping BOS/EOS/PAD ids the model defines."""
        ids = [int(i) for i in ids]
        if skip_special_tokens:
            special = {
                i for i in (self.sp.bos_id(), self.sp.eos_id(), self.sp.pad_id())
                if i is not None and i >= 0
            }
            ids = [i for i in ids if i not in special]
        return self.sp.decode(ids)

    @property
    def pad_token_id(self):
        # Use the model's own pad id when it has one. The fallback of 0 is kept
        # for backward compatibility.
        # NOTE(review): in a default SentencePiece model id 0 is <unk>, so with
        # the fallback unknown pieces are treated as padding — train the model
        # with an explicit pad id to avoid this.
        pad = self._defined(self.sp.pad_id())
        return pad if pad is not None else 0

    @property
    def sos_token_id(self):
        # None if the model does not define BOS.
        return self._defined(self.sp.bos_id())

    @property
    def eos_token_id(self):
        # None if the model does not define EOS.
        return self._defined(self.sp.eos_id())



## **Dataset**

In [35]:
class SummarizationDataset(Dataset):
    """Pairs of (article, summary) strings tokenized on access.

    Each item is a dict with three 1-D long tensors:
    ``encoder_input`` (article ids, no special tokens), ``decoder_input``
    (summary ids with special tokens) and ``decoder_target`` (the decoder
    input shifted left by one position, with a pad id appended).
    """

    def __init__(self, articles: List[str], summaries: List[str], tokenizer, max_encoder_len, max_decoder_len):
        self.articles, self.summaries = articles, summaries
        self.tokenizer = tokenizer
        self.max_encoder_len, self.max_decoder_len = max_encoder_len, max_decoder_len

    def __len__(self):
        # The dataset length is driven by the article list.
        return len(self.articles)

    def __getitem__(self, idx):
        tok = self.tokenizer
        source_ids = tok.encode(
            self.articles[idx], max_length=self.max_encoder_len, add_special_tokens=False
        )
        summary_ids = tok.encode(
            self.summaries[idx], max_length=self.max_decoder_len, add_special_tokens=True
        )
        # Next-token targets: drop the first id and pad the freed last slot.
        shifted_ids = summary_ids[1:] + [tok.pad_token_id]

        fields = {
            "encoder_input": source_ids,
            "decoder_input": summary_ids,
            "decoder_target": shifted_ids,
        }
        return {name: torch.tensor(ids, dtype=torch.long) for name, ids in fields.items()}


def collate_fn(batch, pad_token_id):
    """Right-pad every field of a batch to the longest sequence in that field.

    Args:
        batch: list of dicts holding ``encoder_input``, ``decoder_input`` and
            ``decoder_target`` 1-D tensors.
        pad_token_id: value used to fill the padded positions.

    Returns:
        Tuple ``(encoder_inputs, decoder_inputs, decoder_targets)``, each of
        shape (batch, longest length of that field).
    """
    def _stack(field):
        sequences = [sample[field] for sample in batch]
        return nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=pad_token_id)

    return _stack("encoder_input"), _stack("decoder_input"), _stack("decoder_target")


## **Transformer components (Pre-LN)**

In [36]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to a (batch, seq, d_model) input, then applies dropout.

    Args:
        d_model: embedding width (odd values are supported).
        max_len: number of positions precomputed; longer inputs are rejected.
        dropout: dropout probability applied after the addition.
    """

    def __init__(self, d_model, max_len=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one fewer cosine column than sine columns;
        # slicing `div` keeps the shapes aligned (no-op when d_model is even).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Fail with a clear message instead of a broadcasting error.
            raise ValueError(
                f"Sequence length {seq_len} exceeds positional encoding max_len {self.pe.size(1)}"
            )
        x = x + self.pe[:, :seq_len, :].to(x.dtype)
        return self.dropout(x)


class MultiHeadAttention(nn.Module):
    """Scaled dot-product attention over ``nhead`` heads with learned Q/K/V/output projections."""

    def __init__(self, d_model, nhead, dropout=0.1):
        super().__init__()
        assert d_model % nhead == 0
        self.d_model, self.nhead = d_model, nhead
        self.d_k = d_model // nhead  # per-head width

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def split_heads(self, x):
        """(batch, seq, d_model) -> (batch, nhead, seq, d_k)."""
        batch, length, _ = x.size()
        per_head = x.view(batch, length, self.nhead, self.d_k)
        return per_head.transpose(1, 2)

    def combine_heads(self, x):
        """(batch, nhead, seq, d_k) -> (batch, seq, nhead * d_k)."""
        merged = x.transpose(1, 2).contiguous()
        batch, length, heads, head_dim = merged.size()
        return merged.view(batch, length, heads * head_dim)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        ``mask`` is truthy where attention is allowed and must broadcast to the
        score tensor (batch, nhead, q_len, k_len).
        """
        q_heads = self.split_heads(self.W_q(query))
        k_heads = self.split_heads(self.W_k(key))
        v_heads = self.split_heads(self.W_v(value))

        scale = math.sqrt(self.d_k)
        scores = torch.matmul(q_heads, k_heads.transpose(-2, -1)) / scale  # (b, nhead, q, k)

        if mask is not None:
            # Disallowed positions get a large negative score; -1e4 (rather
            # than -inf) stays representable in half precision.
            blocked = ~mask.to(torch.bool)
            scores = scores.masked_fill(blocked, float(-1e4))

        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v_heads)  # (b, nhead, q, d_k)
        return self.W_o(self.combine_heads(context))


class PositionwiseFF(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)   # project back to d_model
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)


class EncoderLayer(nn.Module):
    """Pre-LN encoder layer: each sub-layer normalizes its input and is added back residually."""

    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super().__init__()
        # Self-attention sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(d_model, nhead, dropout)
        self.drop1 = nn.Dropout(dropout)
        # Feed-forward sub-layer
        self.norm2 = nn.LayerNorm(d_model)
        self.ff = PositionwiseFF(d_model, d_ff, dropout)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x, src_mask=None):
        # Self-attention over the normalized input, then residual add.
        normed = self.norm1(x)
        x = x + self.drop1(self.attn(normed, normed, normed, src_mask))
        # Feed-forward over the normalized result, then residual add.
        return x + self.drop2(self.ff(self.norm2(x)))


class DecoderLayer(nn.Module):
    """Pre-LN decoder layer: masked self-attention, cross-attention over ``memory``, feed-forward."""

    def __init__(self, d_model, nhead, d_ff, dropout=0.1):
        super().__init__()
        # Masked self-attention sub-layer
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = MultiHeadAttention(d_model, nhead, dropout)
        self.drop1 = nn.Dropout(dropout)
        # Cross-attention sub-layer (queries from the decoder, keys/values from memory)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = MultiHeadAttention(d_model, nhead, dropout)
        self.drop2 = nn.Dropout(dropout)
        # Feed-forward sub-layer
        self.norm3 = nn.LayerNorm(d_model)
        self.ff = PositionwiseFF(d_model, d_ff, dropout)
        self.drop3 = nn.Dropout(dropout)

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        # 1) self-attention restricted by tgt_mask
        normed = self.norm1(x)
        x = x + self.drop1(self.self_attn(normed, normed, normed, tgt_mask))
        # 2) cross-attention: memory is used as-is (not normalized here)
        x = x + self.drop2(self.cross_attn(self.norm2(x), memory, memory, src_mask))
        # 3) position-wise feed-forward
        return x + self.drop3(self.ff(self.norm3(x)))


class Transformer(nn.Module):
    """Encoder-decoder Transformer built from the Pre-LN layers defined above.

    Args:
        vocab_size: size of the (shared-id-space) vocabulary for both embeddings
            and the output projection.
        d_model, nhead, num_encoder_layers, num_decoder_layers, d_ff, dropout:
            usual Transformer hyperparameters.
        max_len: number of positions available in the positional encoding.
        pad_token_id: id whose embedding row is kept at zero and receives no
            gradient (``padding_idx`` of both embeddings). Defaults to 0, the
            value that was previously hard-coded.
    """

    def __init__(self, vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6,
                 d_ff=2048, dropout=0.1, max_len=5000, pad_token_id=0):
        super().__init__()
        self.d_model = d_model
        self.encoder_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.decoder_embedding = nn.Embedding(vocab_size, d_model, padding_idx=pad_token_id)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout)

        self.encoder_layers = nn.ModuleList(
            [EncoderLayer(d_model, nhead, d_ff, dropout) for _ in range(num_encoder_layers)]
        )
        self.decoder_layers = nn.ModuleList(
            [DecoderLayer(d_model, nhead, d_ff, dropout) for _ in range(num_decoder_layers)]
        )

        self.output_projection = nn.Linear(d_model, vocab_size)
        self._init_weights()

    def _init_weights(self):
        """Xavier-initialize every matrix parameter, then restore the zero padding rows."""
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)
        # Xavier init above also overwrites the embedding matrices, including
        # the row nn.Embedding had zeroed for padding_idx. That row never gets
        # a gradient, so it would stay random forever; zero it again.
        with torch.no_grad():
            for emb in (self.encoder_embedding, self.decoder_embedding):
                if emb.padding_idx is not None:
                    emb.weight[emb.padding_idx].zero_()

    def create_pad_mask(self, seq, pad_token_id):
        # seq: (batch, seq_len) -> True where token is NOT pad
        return (seq != pad_token_id).unsqueeze(1).unsqueeze(2)  # (batch,1,1,seq_len)

    def create_causal_mask(self, size, device):
        # True on and below the diagonal: position i may attend to positions <= i.
        mask = torch.triu(torch.ones(size, size, dtype=torch.bool, device=device), diagonal=1)
        return (~mask).unsqueeze(0).unsqueeze(0)  # (1,1,size,size)

    def encode(self, src, src_mask=None):
        """Embed ``src`` (scaled by sqrt(d_model)), add positions, run the encoder stack."""
        x = self.encoder_embedding(src) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.encoder_layers:
            x = layer(x, src_mask)
        return x

    def decode(self, tgt, memory, src_mask=None, tgt_mask=None):
        """Embed ``tgt``, add positions, run the decoder stack over ``memory``; returns hidden states."""
        x = self.decoder_embedding(tgt) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        for layer in self.decoder_layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return x

    def forward(self, src, tgt, pad_token_id):
        """Return logits of shape (batch, tgt_len, vocab_size) for teacher-forced ``tgt``."""
        src_mask = self.create_pad_mask(src, pad_token_id)
        tgt_pad_mask = self.create_pad_mask(tgt, pad_token_id)
        tgt_causal_mask = self.create_causal_mask(tgt.size(1), device=tgt.device)
        # (batch,1,1,T) & (1,1,T,T) -> (batch,1,T,T): causal and non-pad keys only.
        tgt_mask = tgt_pad_mask & tgt_causal_mask
        memory = self.encode(src, src_mask)
        dec = self.decode(tgt, memory, src_mask, tgt_mask)
        logits = self.output_projection(dec)
        return logits



## **Training utilities**

In [37]:
def train_epoch(model, dataloader, optimizer, criterion, scaler, device, config, pad_id):
    """Run one training epoch with gradient accumulation and (optional) AMP.

    Args:
        model: called as ``model(src, tgt_input, pad_id)`` and expected to
            return logits of shape (batch, tgt_len, vocab).
        dataloader: yields ``(src, tgt_input, tgt_output)`` tensor triples.
        optimizer, criterion, scaler: optimizer, loss and ``GradScaler``.
        device: device the batches are moved to.
        config: reads ``use_amp``, ``gradient_accumulation_steps``, ``grad_clip``.
        pad_id: target id excluded from the token count.

    Returns:
        Token-weighted mean loss over the epoch (``nan`` if no non-pad target
        token was seen).
    """
    model.train()
    accum_steps = config["gradient_accumulation_steps"]
    total_loss = 0.0
    total_tokens = 0

    def _apply_update():
        # Unscale before clipping so the clip threshold applies to true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), config["grad_clip"])
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    # Start from clean gradients so nothing leaks in from a previous epoch.
    optimizer.zero_grad()
    pending = False  # True while accumulated gradients have not been applied

    for batch_idx, (src, tgt_input, tgt_output) in enumerate(dataloader):
        src = src.to(device)
        tgt_input = tgt_input.to(device)
        tgt_output = tgt_output.to(device)

        with autocast(enabled=config["use_amp"]):
            logits = model(src, tgt_input, pad_id)  # (batch, tgt_len, vocab)
            vocab = logits.size(-1)
            logits_flat = logits.contiguous().view(-1, vocab)
            tgt_flat = tgt_output.contiguous().view(-1)

            # compute loss in float32 if AMP is used
            if config["use_amp"]:
                loss = criterion(logits_flat.float(), tgt_flat)
            else:
                loss = criterion(logits_flat, tgt_flat)
            loss = loss / accum_steps

        scaler.scale(loss).backward()
        pending = True

        if (batch_idx + 1) % accum_steps == 0:
            _apply_update()
            pending = False

        num_tokens = (tgt_flat != pad_id).sum().item()
        # Undo the accumulation scaling so the logged value is the per-token loss.
        total_loss += loss.item() * accum_steps * num_tokens
        total_tokens += num_tokens

        if (batch_idx + 1) % 50 == 0:
            print(f"  Batch {batch_idx + 1}/{len(dataloader)}, Loss: {total_loss/max(total_tokens, 1):.4f}")

    # The number of batches need not be a multiple of accum_steps: apply the
    # last partial window instead of dropping it (and leaking it into the next
    # epoch). Its gradient is scaled by 1/accum_steps like any other window.
    if pending:
        _apply_update()

    return total_loss / total_tokens if total_tokens > 0 else float("nan")


def evaluate(model, dataloader, criterion, device, config, pad_id):
    """Compute the token-weighted mean loss over ``dataloader`` without gradients.

    Args:
        model: called as ``model(src, tgt_input, pad_id)``; put into eval mode.
        dataloader: yields ``(src, tgt_input, tgt_output)`` tensor triples.
        criterion: loss applied to flattened logits and targets.
        device: device the batches are moved to.
        config: reads ``use_amp``.
        pad_id: target id excluded from the token count.

    Returns:
        Mean loss per non-pad target token, or ``nan`` when the loader is
        empty / contains no non-pad target (instead of ZeroDivisionError).
    """
    model.eval()
    total_loss = 0.0
    total_tokens = 0
    with torch.no_grad():
        for src, tgt_input, tgt_output in dataloader:
            src = src.to(device)
            tgt_input = tgt_input.to(device)
            tgt_output = tgt_output.to(device)

            logits = model(src, tgt_input, pad_id)
            vocab = logits.size(-1)
            logits_flat = logits.contiguous().view(-1, vocab)
            tgt_flat = tgt_output.contiguous().view(-1)
            if config["use_amp"]:
                # keep the loss computation in float32
                logits_flat = logits_flat.float()
            loss = criterion(logits_flat, tgt_flat)

            num_tokens = (tgt_flat != pad_id).sum().item()
            total_loss += loss.item() * num_tokens
            total_tokens += num_tokens

    if total_tokens == 0:
        return float("nan")
    return total_loss / total_tokens



## **Beam search with length norm + repetition penalty**

In [38]:
def apply_repetition_penalty(scores: torch.Tensor, seq: List[int], penalty: float):
    """Lower the scores of tokens that already occur in ``seq``.

    Args:
        scores: (vocab,) next-step scores (logits or log-probabilities).
            Modified in place; pass a clone to keep the original.
        seq: token ids generated so far.
        penalty: factor > 1 discourages repetition; 1.0 is a no-op.

    Returns:
        The same ``scores`` tensor.

    Negative scores are multiplied by the penalty and non-negative ones are
    divided by it, so the score always moves down. Plain division (the
    previous behaviour) moves a negative log-probability towards zero and
    therefore *rewarded* repeated tokens.
    """
    if penalty == 1.0 or len(seq) == 0:
        return scores
    seen = torch.tensor(sorted(set(seq)), dtype=torch.long, device=scores.device)
    current = scores[seen]
    scores[seen] = torch.where(current < 0, current * penalty, current / penalty)
    return scores


def beam_search_decode(model: "Transformer", tokenizer, src_tokens: List[int], device, config,
                       max_len=80, beam_size=4):
    """Decode a summary for ``src_tokens`` with beam search.

    Args:
        model: provides ``create_pad_mask``, ``create_causal_mask``, ``encode``,
            ``decode`` and ``output_projection``.
        tokenizer: provides ``pad_token_id``, ``sos_token_id``, ``eos_token_id``
            and ``decode``.
        src_tokens: source token ids (one sequence).
        device: device used for the tensors.
        config: reads ``repetition_penalty`` and ``length_penalty``.
        max_len: maximum number of decoding steps.
        beam_size: number of hypotheses kept per step.

    Returns:
        The decoded text of the hypothesis with the best length-normalized
        score (``score / len**length_penalty``), or "" if none was produced.

    Raises:
        ValueError: if the tokenizer has no start-of-sequence id; decoding
            cannot start from an empty decoder input.
    """
    model.eval()
    pad_id = tokenizer.pad_token_id
    sos = getattr(tokenizer, "sos_token_id", None)
    eos = getattr(tokenizer, "eos_token_id", None)
    if sos is None:
        raise ValueError("beam_search_decode requires a tokenizer that defines sos_token_id")

    alpha = config["length_penalty"]
    rep_penalty = config["repetition_penalty"]

    # Beam item: (cumulative log-probability, token id sequence)
    beams = [(0.0, [sos])]
    completed = []

    # Inference only: avoid building an autograd graph for every step.
    with torch.no_grad():
        src = torch.tensor([src_tokens], dtype=torch.long).to(device)
        src_mask = model.create_pad_mask(src, pad_id)
        memory = model.encode(src, src_mask)

        for _ in range(max_len):
            all_candidates = []
            # Finished hypotheses are moved to `completed` below, so every beam
            # here is still open.
            for score, seq in beams:
                tgt = torch.tensor([seq], dtype=torch.long).to(device)
                tgt_mask = model.create_pad_mask(tgt, pad_id) & model.create_causal_mask(tgt.size(1), device=device)
                dec_out = model.decode(tgt, memory, src_mask, tgt_mask)
                next_logits = model.output_projection(dec_out)[0, -1]  # (vocab,)
                log_probs = F.log_softmax(next_logits, dim=-1).detach().cpu()

                if rep_penalty != 1.0:
                    log_probs = apply_repetition_penalty(log_probs.clone(), seq, rep_penalty)

                # topk cannot return more entries than the vocabulary holds.
                k = min(beam_size, log_probs.size(-1))
                top = torch.topk(log_probs, k)
                for tok_logprob, tok in zip(top.values.tolist(), top.indices.tolist()):
                    all_candidates.append((score + tok_logprob, seq + [tok]))

            # Keep the best `beam_size` by raw score; length normalization is
            # applied only to the final choice.
            all_candidates.sort(key=lambda c: c[0], reverse=True)
            beams = []
            for cand in all_candidates[:beam_size]:
                if eos is not None and cand[1][-1] == eos:
                    completed.append(cand)
                else:
                    beams.append(cand)
            if not beams:
                break

    # combine completed and ongoing beams
    finalists = completed + beams
    if not finalists:
        return ""  # nothing generated

    # apply length normalization score / (len^alpha)
    _, best_seq = max(finalists, key=lambda c: c[0] / ((len(c[1]) ** alpha) + 1e-9))

    # remove SOS if present and cut after EOS
    if best_seq and best_seq[0] == sos:
        best_seq = best_seq[1:]
    if eos is not None and eos in best_seq:
        best_seq = best_seq[: best_seq.index(eos)]

    return tokenizer.decode(best_seq, skip_special_tokens=True)



## **Inference wrapper**

In [39]:
def generate_summary(model, tokenizer, article: str, device, config, max_len=80):
    """Generate a summary string for ``article``.

    Uses beam search when ``config["beam_size"] > 1`` and greedy decoding
    otherwise. The article is encoded without special tokens and truncated to
    ``config["max_encoder_len"]``.

    Raises:
        ValueError: (greedy path) if the tokenizer has no start-of-sequence id.
    """
    src_tokens = tokenizer.encode(article, max_length=config["max_encoder_len"], add_special_tokens=False)
    if config["beam_size"] > 1:
        return beam_search_decode(model, tokenizer, src_tokens, device, config,
                                  max_len=max_len, beam_size=config["beam_size"])

    # fallback greedy generation
    model.eval()
    pad_id = tokenizer.pad_token_id
    sos = tokenizer.sos_token_id
    eos = tokenizer.eos_token_id
    if sos is None:
        # Without a start token the decoder input would be [None], which
        # cannot be turned into a tensor.
        raise ValueError("generate_summary requires a tokenizer that defines sos_token_id")

    generated = [sos]
    # Inference only: no autograd graph needed.
    with torch.no_grad():
        src = torch.tensor([src_tokens], dtype=torch.long).to(device)
        src_mask = model.create_pad_mask(src, pad_id)  # constant across steps
        memory = model.encode(src, src_mask)
        for _ in range(max_len):
            tgt = torch.tensor([generated], dtype=torch.long).to(device)
            tgt_mask = model.create_pad_mask(tgt, pad_id) & model.create_causal_mask(tgt.size(1), device=device)
            dec = model.decode(tgt, memory, src_mask, tgt_mask)
            logits = model.output_projection(dec)
            next_tok = int(logits[0, -1].argmax().item())
            generated.append(next_tok)
            if eos is not None and next_tok == eos:
                break

    # drop the leading SOS and everything from EOS onwards
    generated = generated[1:]
    if eos is not None and eos in generated:
        generated = generated[: generated.index(eos)]
    return tokenizer.decode(generated, skip_special_tokens=True)



## **Main training orchestration**

In [31]:
def main():
    """Load the dataset, build tokenizer/model, train, and print sample generations.

    Steps: locate and read the dataset, auto-detect the article/summary
    columns, filter and sample rows, build the tokenizer, split 90/10 into
    train/validation, train for ``CONFIG["num_epochs"]`` epochs, save the best
    checkpoint to ``best_model.pt`` and print a sample summary every 2 epochs.
    """
    from functools import partial  # local import: picklable collate function

    print("=" * 80)
    print("IMPROVED TRANSFORMER - TRAIN + INFERENCE")
    print("=" * 80)

    # Seed torch so weight init and DataLoader shuffling are repeatable
    # (row sampling below already uses random_state=42).
    torch.manual_seed(42)

    # ---------- load dataset (auto-detect columns) ----------
    dataset_paths = [
        "/content/gdrive/MyDrive/practical_data/Inshorts-Cleaned-Data.xlsx",
        "/content/Inshorts-Cleaned-Data.xlsx",
        "Inshorts-Cleaned-Data.xlsx",
        "news_summary.csv",
    ]
    df = None
    for p in dataset_paths:
        try:
            if p.endswith(".xlsx"):
                df = pd.read_excel(p)
            elif p.endswith(".csv"):
                df = pd.read_csv(p, encoding="latin-1")
            else:
                continue
        except FileNotFoundError:
            # Expected for candidate paths that do not exist on this machine.
            continue
        except Exception as exc:
            # The file exists but could not be read (missing Excel engine,
            # corrupt file, ...): report it instead of hiding the cause.
            print(f"Could not read {p}: {exc!r}")
            continue
        print("Loaded", p)
        break

    if df is None:
        print("Dataset not found. Update dataset_paths.")
        return

    article_cols = ["Short", "short", "content", "article", "text", "news", "Content", "Article", "Text"]
    summary_cols = ["Headline", "headline", "summary", "title", "headlines", "Headline", "Summary", "Title"]

    article_col = next((c for c in article_cols if c in df.columns), None)
    summary_col = next((c for c in summary_cols if c in df.columns), None)

    if article_col is None or summary_col is None:
        print("Could not auto-detect article/summary columns. Columns:", df.columns.tolist())
        return

    df = df.dropna(subset=[article_col, summary_col])
    # Cast before using the .str accessor so non-string cells do not raise.
    df = df[df[article_col].astype(str).str.len() > 50]
    df = df[df[summary_col].astype(str).str.len() > 10]

    if len(df) == 0:
        print("No usable rows after filtering; nothing to train on.")
        return

    sample_size = min(50000, len(df))
    df = df.sample(n=sample_size, random_state=42).reset_index(drop=True)
    articles = df[article_col].astype(str).tolist()
    summaries = df[summary_col].astype(str).tolist()

    # ---------- tokenizer: prefer sentencepiece if available & model provided ----------
    if HAS_SP and SENTENCEPIECE_MODEL_PATH:
        print("Using SentencePiece tokenizer:", SENTENCEPIECE_MODEL_PATH)
        tokenizer = SPTokenizer(SENTENCEPIECE_MODEL_PATH)
    else:
        print("Using SimpleTokenizer (word-level). To get better results use sentencepiece BPE.")
        tokenizer = SimpleTokenizer(vocab_size=CONFIG["vocab_size"])
        tokenizer.build_vocab(articles + summaries)

    pad_id = tokenizer.pad_token_id
    print("Pad id:", pad_id, "Vocab size:", CONFIG.get("vocab_size"))

    # ---------- dataset & loaders ----------
    train_size = int(0.9 * len(articles))
    train_ds = SummarizationDataset(articles[:train_size], summaries[:train_size], tokenizer, CONFIG["max_encoder_len"], CONFIG["max_decoder_len"])
    val_ds = SummarizationDataset(articles[train_size:], summaries[train_size:], tokenizer, CONFIG["max_encoder_len"], CONFIG["max_decoder_len"])

    # partial() instead of a lambda: worker processes started with "spawn"
    # must pickle the collate function, and lambdas cannot be pickled.
    collate = partial(collate_fn, pad_token_id=pad_id)
    # Pinned memory only helps host-to-CUDA copies.
    pin = CONFIG["device"] == "cuda"
    train_loader = DataLoader(train_ds, batch_size=CONFIG["batch_size"], shuffle=True, collate_fn=collate, num_workers=2, pin_memory=pin)
    val_loader = DataLoader(val_ds, batch_size=CONFIG["batch_size"], shuffle=False, collate_fn=collate, num_workers=2, pin_memory=pin)

    # ---------- model ----------
    model = Transformer(vocab_size=CONFIG["vocab_size"], d_model=CONFIG["d_model"], nhead=CONFIG["nhead"],
                        num_encoder_layers=CONFIG["num_encoder_layers"], num_decoder_layers=CONFIG["num_decoder_layers"],
                        d_ff=CONFIG["d_ff"], dropout=CONFIG["dropout"], max_len=max(CONFIG["max_encoder_len"], CONFIG["max_decoder_len"])
                        ).to(CONFIG["device"])
    print("Model params:", sum(p.numel() for p in model.parameters() if p.requires_grad))

    # ---------- optimizer + loss + scaler ----------
    # NOTE(review): CONFIG["warmup_steps"] is defined but no LR schedule is
    # created here; the learning rate stays constant.
    optimizer = torch.optim.AdamW(model.parameters(), lr=CONFIG["learning_rate"], betas=(0.9, 0.98), eps=1e-8)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=CONFIG["label_smoothing"])
    scaler = GradScaler(enabled=CONFIG["use_amp"])

    best_val = float("inf")
    for epoch in range(CONFIG["num_epochs"]):
        start_time = time.time()
        print(f"\nEpoch {epoch+1}/{CONFIG['num_epochs']}")
        train_loss = train_epoch(model, train_loader, optimizer, criterion, scaler, CONFIG["device"], CONFIG, pad_id)
        val_loss = evaluate(model, val_loader, criterion, CONFIG["device"], CONFIG, pad_id)
        duration = time.time() - start_time
        print(f"Epoch {epoch+1} — Train loss: {train_loss:.4f} Val loss: {val_loss:.4f} Time: {duration:.1f}s")

        if val_loss < best_val:
            best_val = val_loss
            torch.save({"epoch": epoch, "model_state_dict": model.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "val_loss": val_loss, "config": CONFIG}, "best_model.pt")
            print("Saved best model.")

        if (epoch + 1) % 2 == 0:
            print("\nSample generation (validation sample 0):")
            sample_article = articles[train_size][:500]
            print("Article:", sample_article[:300], "...")
            gen = generate_summary(model, tokenizer, sample_article, CONFIG["device"], CONFIG, max_len=CONFIG["max_decoder_len"])
            print("Generated:", gen)
            print("Reference:", summaries[train_size][:200])

    print("Training complete — best val:", best_val)


if __name__ == "__main__":
    main()


IMPROVED TRANSFORMER - TRAIN + INFERENCE
Loaded /content/gdrive/MyDrive/practical_data/Inshorts-Cleaned-Data.xlsx
Using SimpleTokenizer (word-level). To get better results use sentencepiece BPE.
Building simple word-level vocab...
Vocab size: 30000
Pad id: 0 Vocab size: 30000
Model params: 30442800

Epoch 1/10
  Batch 50/1407, Loss: 9.4873
  Batch 100/1407, Loss: 8.9518
  Batch 150/1407, Loss: 8.6469
  Batch 200/1407, Loss: 8.4466
  Batch 250/1407, Loss: 8.3124
  Batch 300/1407, Loss: 8.2164
  Batch 350/1407, Loss: 8.1482
  Batch 400/1407, Loss: 8.0947
  Batch 450/1407, Loss: 8.0505
  Batch 500/1407, Loss: 8.0110
  Batch 550/1407, Loss: 7.9778
  Batch 600/1407, Loss: 7.9524
  Batch 650/1407, Loss: 7.9315
  Batch 700/1407, Loss: 7.9105
  Batch 750/1407, Loss: 7.8906
  Batch 800/1407, Loss: 7.8732
  Batch 850/1407, Loss: 7.8581
  Batch 900/1407, Loss: 7.8441
  Batch 950/1407, Loss: 7.8301
  Batch 1000/1407, Loss: 7.8162
  Batch 1050/1407, Loss: 7.8059
  Batch 1100/1407, Loss: 7.7946
  Ba