In [1]:
# Imports
import os, json, math, random, time
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
import pandas as pd
from tqdm.auto import tqdm
from rouge_score import rouge_scorer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Paths
DATA_DIR = '../../data/summarization/'
MODEL_DIR = '../../models/summarization/'
VOCAB_PATH = os.path.join(MODEL_DIR, 'vocab.json')
CHECKPOINT_PATH = os.path.join(MODEL_DIR, 'best_summarization_model.pt')

# Device Setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Hyperparams
MAX_ART_LEN = 400
MAX_SUM_LEN = 50
EMB_DIM = 256
N_HEADS = 8
FF_DIM = 512
LAYERS = 4
BATCH_SIZE = 32
EXTRA_EPOCHS = 12
INIT_LR = 2e-4
WARMUP_EPOCHS = 1
LABEL_SMOOTH = 0.1
CLIP_NORM = 1.0
TEACHER_P0 = 1.0
TEACHER_DECAY = 0.9
PATIENCE = 3

In [4]:
# Load Vocabulary
with open(VOCAB_PATH, 'r') as f:
    vocab = json.load(f)

PAD_TOKEN   = '<PAD>'
UNK_TOKEN   = '<UNK>'
START_TOKEN = '<START>'
END_TOKEN   = '<END>'
pad_idx = vocab[PAD_TOKEN]

print(f"Loaded vocab with size: {len(vocab)}")
print(f"PAD token index: {pad_idx}")

Loaded vocab with size: 30000
PAD token index: 0


In [5]:
# Dataset class
def tokenize(txt): 
    return txt.split()

class SummarizationDataset(Dataset):
    def __init__(self, df, vocab, max_article_len=400, max_summary_len=50):
        self.articles = df['clean_article'].values
        self.summaries = df['clean_summary'].values
        self.vocab = vocab
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def encode(self, text, max_len, add_specials=False):
        tokens = tokenize(text)
        if add_specials:
            tokens = [START_TOKEN] + tokens[:max_len-2] + [END_TOKEN]
        else:
            tokens = tokens[:max_len]
        ids = [self.vocab.get(w, self.vocab[UNK_TOKEN]) for w in tokens]
        if len(ids) < max_len:
            ids += [self.vocab[PAD_TOKEN]] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        src = torch.tensor(self.encode(self.articles[idx], self.max_article_len, add_specials=False), dtype=torch.long)
        tgt = torch.tensor(self.encode(self.summaries[idx], self.max_summary_len, add_specials=True), dtype=torch.long)
        return src, tgt

In [6]:
# Load data
train_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_train_split.csv'))
val_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_val_split.csv'))

train_loader = DataLoader(SummarizationDataset(train_df, vocab, MAX_ART_LEN, MAX_SUM_LEN), BATCH_SIZE, True, num_workers=4, pin_memory=True)
val_loader = DataLoader(SummarizationDataset(val_df, vocab, MAX_ART_LEN, MAX_SUM_LEN), BATCH_SIZE, False, num_workers=4, pin_memory=True)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

Train batches: 1515 | Val batches: 268


In [7]:
# TransformerSummarizer Model Class Definition
class TransformerSummarizer(nn.Module):
    def __init__(self, vocab_size, emb_dim, nhead, ff_dim, num_layers, max_article_len, max_summary_len, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos_encoder = nn.Embedding(max_article_len, emb_dim)
        self.pos_decoder = nn.Embedding(max_summary_len, emb_dim)
        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.fc_out = nn.Linear(emb_dim, vocab_size)
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_pos = self.pos_encoder(torch.arange(self.max_article_len, device=src.device)).unsqueeze(0)
        tgt_pos = self.pos_decoder(torch.arange(self.max_summary_len, device=tgt.device)).unsqueeze(0)
        
        src_emb = self.embedding(src) + src_pos[:, :src.size(1), :]
        tgt_emb = self.embedding(tgt) + tgt_pos[:, :tgt.size(1), :]
        
        outs = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc_out(outs)

In [8]:
# Instantiate and load model
model = TransformerSummarizer(
    vocab_size=len(vocab),
    emb_dim=EMB_DIM,
    nhead=N_HEADS,
    ff_dim=FF_DIM,
    num_layers=LAYERS,
    max_article_len=MAX_ART_LEN,
    max_summary_len=MAX_SUM_LEN,
    pad_idx=pad_idx
).to(device)

# Load checkpoint
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
print('Checkpoint loaded successfully!')

Checkpoint loaded successfully!


In [9]:
# Setup optimizer, criterion, scheduler
optim = torch.optim.Adam(model.parameters(), lr=INIT_LR)
sched = CosineAnnealingWarmRestarts(optim, T_0=max(len(train_loader)//WARMUP_EPOCHS,1), T_mult=2, eta_min=1e-6)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

In [10]:
# Verify checkpoint loading by checking validation loss
def eval_loss():
    model.eval()
    tot = 0
    with torch.no_grad():
        for src, tgt in val_loader:
            src, tgt = src.to(device), tgt.to(device)
            # CONSISTENT with training: use tgt[:, :-1] as input
            logits = model(src, tgt[:, :-1])
            loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))
            tot += loss.item()
    return tot/len(val_loader)

print('Initial val-loss:', eval_loss())

Initial val-loss: 4.843511923035579


In [11]:
# ROUGE-L helper function
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

def rougeL_sample(model, n=50):
    model.eval()
    scores = []
    with torch.no_grad():
        for i in range(n):
            art = val_df['clean_article'].iloc[i].split()[:MAX_ART_LEN]
            art_ids = [vocab.get(t, vocab[UNK_TOKEN]) for t in art] + [pad_idx]*(MAX_ART_LEN-len(art))
            src = torch.tensor(art_ids, device=device).unsqueeze(0)
            
            # Simple greedy decode
            tgt = torch.tensor([[vocab[START_TOKEN]]], device=device)
            for _ in range(MAX_SUM_LEN-1):
                logits = model(src, tgt)
                next_id = logits[:,-1].argmax(-1, keepdim=True)
                tgt = torch.cat([tgt, next_id], 1)
                if next_id.item() == vocab[END_TOKEN]:
                    break
            
            # Convert to text (skip START, stop at END or PAD)
            pred_ids = tgt[0, 1:].tolist()
            if vocab[END_TOKEN] in pred_ids:
                pred_ids = pred_ids[:pred_ids.index(vocab[END_TOKEN])]
            pred_text = ' '.join([list(vocab.keys())[list(vocab.values()).index(idx)] 
                                for idx in pred_ids if idx not in [vocab[PAD_TOKEN], vocab[START_TOKEN]]])
            
            ref_text = val_df['clean_summary'].iloc[i]
            scores.append(scorer.score(ref_text, pred_text)['rougeL'].fmeasure)
    return sum(scores)/len(scores)


In [12]:
# Fine-tuning Loop with Scheduled Sampling
best_val = float('inf')
patience_counter = 0

for epoch in range(1, EXTRA_EPOCHS+1):
    tf_prob = TEACHER_P0 * (TEACHER_DECAY ** (epoch-1))
    model.train()
    tot = 0
    
    for src, tgt in tqdm(train_loader, desc=f'Epoch {epoch}'):
        src, tgt = src.to(device), tgt.to(device)
        optim.zero_grad()

        # Scheduled sampling input build
        decoder_in = tgt[:, :-1].clone()  # Ground truth input
        with torch.no_grad():
            # Get model predictions for previous time steps
            pred_ids = model(src, decoder_in).argmax(-1)
        
        # Apply scheduled sampling mask
        mask = (torch.rand_like(decoder_in.float()) > tf_prob) & (decoder_in != pad_idx)
        decoder_in[mask] = pred_ids[mask]

        logits = model(src, decoder_in)
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), CLIP_NORM)
        optim.step()
        sched.step()
        tot += loss.item()

    train_loss = tot/len(train_loader)
    val_loss = eval_loss()
    
    print(f'Epoch {epoch}: train {train_loss:.3f} | val {val_loss:.3f} | tf_prob {tf_prob:.2f}')

    if val_loss < best_val - 1e-4:
        best_val = val_loss
        patience_counter = 0
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print('New best model saved.')
    else:
        patience_counter += 1
        print(f'  No improvement (patience {patience_counter}/{PATIENCE})')
        if patience_counter >= PATIENCE:
            print('Early stopping.')
            break

    # ROUGE-L sample check every 3 epochs
    if epoch % 3 == 0:
        rouge_score = rougeL_sample(model)
        print(f'  ROUGE-L (50 samples): {rouge_score:.3f}')

print(f'Fine-tuning complete. Best validation loss: {best_val:.4f}')

Epoch 1: 100%|██████████████████████████████████████████████████████████████████████| 1515/1515 [08:59<00:00,  2.81it/s]


Epoch 1: train 4.273 | val 4.850 | tf_prob 1.00
New best model saved.


Epoch 2: 100%|██████████████████████████████████████████████████████████████████████| 1515/1515 [09:07<00:00,  2.77it/s]


Epoch 2: train 4.533 | val 4.871 | tf_prob 0.90
  No improvement (patience 1/5)


Epoch 3: 100%|██████████████████████████████████████████████████████████████████████| 1515/1515 [09:12<00:00,  2.74it/s]


Epoch 3: train 4.524 | val 4.876 | tf_prob 0.81
  No improvement (patience 2/5)
  ROUGE-L (50 samples): 0.083


Epoch 4:   1%|▍                                                                        | 8/1515 [00:03<10:31,  2.39it/s]


KeyboardInterrupt: 