In [1]:
# Imports
import os
import json
import torch
import torch.nn as nn
import torch.optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
from torch.utils.data import Dataset

In [2]:
# Paths (relative; resolved against the current working directory)
DATA_DIR = '../../data/summarization/'
MODEL_DIR = '../../models/summarization/'

# Artifacts that live inside MODEL_DIR: the vocabulary and the best checkpoint.
VOCAB_PATH, CHECKPOINT_PATH = (
    os.path.join(MODEL_DIR, artifact_name)
    for artifact_name in ('vocab.json', 'best_summarization_model.pt')
)

# Device Setup: use the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Hyperparams

# Sequence lengths, in tokens, for encoded articles and summaries
MAX_ARTICLE_LEN, MAX_SUMMARY_LEN = 400, 50

# Transformer architecture
EMBEDDING_DIM = 256
NUM_HEADS = 8
FF_DIM = 512
NUM_LAYERS = 4

# Optimisation and training schedule
BATCH_SIZE = 32
LEARNING_RATE = 5e-5
NUM_EPOCHS = 10
PATIENCE = 3  # epochs without validation improvement before early stopping

In [4]:
# Load Vocabulary
# JSON text is UTF-8; pin the encoding so decoding does not depend on the
# platform's default locale.
with open(VOCAB_PATH, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Tokens
PAD_TOKEN   = '<PAD>'
UNK_TOKEN   = '<UNK>'
START_TOKEN = '<START>'
END_TOKEN   = '<END>'

# Fail fast with a clear message: PAD and UNK are looked up with vocab[...]
# during encoding, so a vocabulary without them would otherwise only surface
# as a bare KeyError once data loading starts.
missing_tokens = [tok for tok in (PAD_TOKEN, UNK_TOKEN) if tok not in vocab]
if missing_tokens:
    raise KeyError(f"Vocabulary file {VOCAB_PATH} is missing required tokens: {missing_tokens}")

pad_idx = vocab[PAD_TOKEN]
print(f"Loaded vocab with size: {len(vocab)}")
print(f"PAD token index: {pad_idx}")

Loaded vocab with size: 30000
PAD token index: 0


In [5]:
def tokenize(text):
    """Split ``text`` into whitespace-delimited tokens.

    Missing values are treated as empty text: ``None`` and float NaN (what
    pandas produces for an empty CSV cell) yield an empty list instead of
    raising ``AttributeError``.

    Args:
        text: The string to tokenize, or ``None``/NaN for a missing value.

    Returns:
        list: The tokens in order; empty for empty or missing input.
    """
    # NaN is the only float value that is not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return []
    return text.split()

class SummarizationDataset(Dataset):
    """Map-style dataset of ``(article_ids, summary_ids)`` tensor pairs.

    Articles are read from the ``clean_article`` column of ``df`` and
    summaries from ``clean_summary``. Each text is tokenized, mapped to
    vocabulary indices and padded or truncated to a fixed length.
    """

    def __init__(self, df, vocab, max_article_len=MAX_ARTICLE_LEN, max_summary_len=MAX_SUMMARY_LEN):
        self.vocab = vocab
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len
        self.articles = df['clean_article'].values
        self.summaries = df['clean_summary'].values

    def encode(self, text, max_len, add_specials=False):
        """Convert ``text`` into a list of ``max_len`` vocabulary indices.

        Args:
            text: Raw text; it is split with ``tokenize``.
            max_len: Length of the returned list.
            add_specials: When true, wrap the tokens in START/END markers,
                keeping at most ``max_len - 2`` ordinary tokens.

        Returns:
            list: Token indices. Out-of-vocabulary tokens map to the UNK
            index and short sequences are right-padded with the PAD index.
        """
        words = tokenize(text)
        if add_specials:
            # Reserve two slots for the START and END markers.
            words = [START_TOKEN, *words[:max_len - 2], END_TOKEN]
        else:
            words = words[:max_len]

        ids = []
        for word in words:
            ids.append(self.vocab.get(word, self.vocab[UNK_TOKEN]))

        shortfall = max_len - len(ids)
        if shortfall > 0:
            ids.extend([self.vocab[PAD_TOKEN]] * shortfall)
        return ids[:max_len]

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        article_ids = self.encode(self.articles[idx], self.max_article_len, add_specials=False)
        summary_ids = self.encode(self.summaries[idx], self.max_summary_len, add_specials=True)
        return (
            torch.tensor(article_ids, dtype=torch.long),
            torch.tensor(summary_ids, dtype=torch.long),
        )

In [6]:
# Load train & val DataFrames
train_df, val_df = (
    pd.read_csv(os.path.join(DATA_DIR, csv_name))
    for csv_name in ('processed_train_split.csv', 'processed_val_split.csv')
)

# Create datasets and dataloaders
train_dataset, val_dataset = (
    SummarizationDataset(split_df, vocab, MAX_ARTICLE_LEN, MAX_SUMMARY_LEN)
    for split_df in (train_df, val_df)
)

# Options shared by both loaders; only the training loader shuffles.
loader_options = dict(batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)

print(f"Train batches: {len(train_loader)} | Val batches: {len(val_loader)}")

Train batches: 1515 | Val batches: 268


In [7]:
# TransformerSummarizer Model Class Definition
class TransformerSummarizer(nn.Module):
    """Encoder-decoder Transformer that maps article token ids to summary logits.

    Source and target tokens share one embedding table. Learned positional
    embeddings (separate tables for the encoder and the decoder) are added
    before the sequences are passed to ``nn.Transformer``, and a linear layer
    projects the decoder output onto the vocabulary.

    Args:
        vocab_size: Number of entries in the vocabulary.
        emb_dim: Embedding / model dimension.
        nhead: Number of attention heads.
        ff_dim: Hidden size of the feed-forward sub-layers.
        num_layers: Number of encoder layers and of decoder layers.
        max_article_len: Longest source sequence that can be embedded.
        max_summary_len: Longest target sequence that can be embedded.
        pad_idx: Vocabulary index of the padding token.
    """

    def __init__(self, vocab_size, emb_dim, nhead, ff_dim, num_layers, max_article_len, max_summary_len, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos_encoder = nn.Embedding(max_article_len, emb_dim)
        self.pos_decoder = nn.Embedding(max_summary_len, emb_dim)
        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.fc_out = nn.Linear(emb_dim, vocab_size)
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def forward(self, src, tgt, src_key_padding_mask=None, tgt_key_padding_mask=None):
        """Compute next-token logits for ``tgt`` conditioned on ``src``.

        Args:
            src: Source token ids of shape ``(batch, src_len)``.
            tgt: Decoder input token ids of shape ``(batch, tgt_len)``.
            src_key_padding_mask: Optional ``(batch, src_len)`` boolean mask,
                ``True`` at padded source positions that attention should
                ignore. It is applied in the encoder and to the decoder's
                attention over the encoder output. ``None`` (the default)
                applies no padding mask.
            tgt_key_padding_mask: Optional ``(batch, tgt_len)`` boolean mask,
                ``True`` at padded target positions. ``None`` (the default)
                applies no padding mask.

        Returns:
            Tensor of shape ``(batch, tgt_len, vocab_size)`` with unnormalised
            scores over the vocabulary for every target position.

        Raises:
            ValueError: If ``src`` or ``tgt`` is longer than the positional
                embedding table it has to be combined with.
        """
        src_len, tgt_len = src.size(1), tgt.size(1)
        # Report over-long inputs explicitly; otherwise they only fail later
        # with a shape-mismatch error when the embeddings are added.
        if src_len > self.max_article_len:
            raise ValueError(
                f"Source length {src_len} exceeds max_article_len={self.max_article_len}"
            )
        if tgt_len > self.max_summary_len:
            raise ValueError(
                f"Target length {tgt_len} exceeds max_summary_len={self.max_summary_len}"
            )

        # Causal mask: a target position may only attend to itself and earlier positions.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_len).to(tgt.device)

        # Embed only the positions that are actually present in this batch.
        src_pos = self.pos_encoder(torch.arange(src_len, device=src.device)).unsqueeze(0)
        tgt_pos = self.pos_decoder(torch.arange(tgt_len, device=tgt.device)).unsqueeze(0)

        src_emb = self.embedding(src) + src_pos
        tgt_emb = self.embedding(tgt) + tgt_pos

        outs = self.transformer(
            src_emb,
            tgt_emb,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_key_padding_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            # Encoder outputs at padded source positions must be ignored by
            # the decoder as well, so the source mask is reused for the memory.
            memory_key_padding_mask=src_key_padding_mask,
        )
        return self.fc_out(outs)

In [8]:
# Instantiate Model and Load Checkpoint
model = TransformerSummarizer(
    vocab_size=len(vocab),
    emb_dim=EMBEDDING_DIM,
    nhead=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_article_len=MAX_ARTICLE_LEN,
    max_summary_len=MAX_SUMMARY_LEN,
    pad_idx=pad_idx
).to(device)

# Load checkpoint
if os.path.exists(CHECKPOINT_PATH):
    # The checkpoint file holds a plain state_dict, so restrict unpickling to
    # tensors and basic containers (weights_only=True) rather than allowing
    # arbitrary pickled objects, which could execute code on load.
    state_dict = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    print("Checkpoint loaded successfully.")
else:
    print("No checkpoint found; training from scratch.")

Checkpoint loaded successfully.


In [9]:
# Setup optimizer, criterion, scheduler

# Token-level cross entropy; positions whose target is the PAD index are ignored.
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Scheduler is stepped with the validation loss (mode='min'): the learning
# rate is multiplied by 0.5 when that loss plateaus, with a floor of 1e-6.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
    min_lr=1e-6,
)

# Early-stopping bookkeeping: best loss seen so far and epochs without improvement.
best_val_loss, counter = float('inf'), 0

In [10]:
# VERIFY checkpoint loading by checking validation loss BEFORE fine-tuning
print("Verifying checkpoint loading...")
model.eval()  # evaluation mode (disables dropout)
val_loss = 0
with torch.no_grad():
    for src, tgt in val_loader:
        src = src.to(device)
        tgt = tgt.to(device)

        # Teacher forcing: the decoder is fed every token but the last and is
        # scored against the same sequence shifted left by one position.
        decoder_input, expected = tgt[:, :-1], tgt[:, 1:]
        logits = model(src, decoder_input)

        batch_loss = criterion(
            logits.reshape(-1, logits.shape[-1]),
            expected.reshape(-1),
        )
        val_loss += batch_loss.item()

# Mean of the per-batch losses over the validation loader.
initial_val_loss = val_loss / len(val_loader)
print(f"Initial validation loss after loading checkpoint: {initial_val_loss:.4f}")
print(f"Expected to be around 4.8632 from training")

Verifying checkpoint loading...
Initial validation loss after loading checkpoint: 4.8626
Expected to be around 4.8632 from training


In [12]:
# Fine-tuning Loop with Early Stopping
#
# Each epoch does one teacher-forced pass over train_loader, then computes the
# mean per-batch loss on val_loader. The scheduler is stepped with that
# validation loss, the weights are saved to CHECKPOINT_PATH whenever it
# improves, and training stops after PATIENCE epochs without improvement.

# The score to beat is the validation loss of the model as it stood before
# fine-tuning (`initial_val_loss`, measured in the verification cell), not
# +inf: otherwise the first epoch would overwrite the checkpoint on disk even
# if it were worse than the checkpoint that was loaded. `min` keeps any better
# score already reached earlier in this kernel session, so re-running this
# cell cannot replace a better saved model with a worse one.
best_val_loss = min(best_val_loss, initial_val_loss)
# Restart the early-stopping count on every run of this cell.
counter = 0

print("Starting fine-tuning...")
for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0

    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch + 1} [Train]", leave=False):
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()

        # Teacher forcing: feed every token but the last, predict the sequence
        # shifted left by one position.
        output = model(src, tgt[:, :-1])
        output = output.reshape(-1, output.shape[-1])
        target = tgt[:, 1:].reshape(-1)

        loss = criterion(output, target)
        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimizer step.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        train_loss += loss.item()

    avg_train_loss = train_loss / len(train_loader)

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"Epoch {epoch + 1} [Val]", leave=False):
            src, tgt = src.to(device), tgt.to(device)
            output = model(src, tgt[:, :-1])
            output = output.reshape(-1, output.shape[-1])
            target = tgt[:, 1:].reshape(-1)
            loss = criterion(output, target)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch + 1}: Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f}")

    # Early stopping and checkpointing
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), CHECKPOINT_PATH)
        print(f"--> New best model saved at epoch {epoch + 1} with val loss {best_val_loss:.4f}")
        counter = 0
    else:
        counter += 1
        print(f"Validation loss did not improve. Early stopping count: {counter}/{PATIENCE}")
        if counter >= PATIENCE:
            print("Early stopping triggered. Finishing training.")
            break

# NOTE: `model` holds the weights of the last epoch that ran, which is not
# necessarily the best one; the best weights are those stored at CHECKPOINT_PATH.
print(f"Fine-tuning complete. Best validation loss: {best_val_loss:.4f}")

Starting fine-tuning...


                                                                                                                        

Epoch 1: Train Loss=4.1774 | Val Loss=4.8435
--> New best model saved at epoch 1 with val loss 4.8435


                                                                                                                        

Epoch 2: Train Loss=4.1386 | Val Loss=4.8478
Validation loss did not improve. Early stopping count: 1/3


                                                                                                                        

Epoch 3: Train Loss=4.1064 | Val Loss=4.8442
Validation loss did not improve. Early stopping count: 2/3


                                                                                                                        

Epoch 4: Train Loss=4.0775 | Val Loss=4.8485
Validation loss did not improve. Early stopping count: 3/3
Early stopping triggered. Finishing training.
Fine-tuning complete. Best validation loss: 4.8435


