In [1]:
# Imports
import os
import json
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
from tqdm import tqdm
import numpy as np

In [2]:
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


In [3]:
# Paths and tokens
DATA_DIR = '../../data/summarization/'
SAVE_DIR = '../../models/summarization/'
VOCAB_PATH = os.path.join(SAVE_DIR, 'vocab.json')
os.makedirs(SAVE_DIR, exist_ok=True)

PAD_TOKEN   = '<PAD>'
UNK_TOKEN   = '<UNK>'
START_TOKEN = '<START>'
END_TOKEN   = '<END>'

In [4]:
# Data Loading
train_df = pd.read_csv(os.path.join(DATA_DIR, 'processed_train_split.csv'))
val_df   = pd.read_csv(os.path.join(DATA_DIR, 'processed_val_split.csv'))

In [5]:
# Vocab
def tokenize(text):
    return text.split()

def build_vocab(samples, min_freq=2, max_vocab_size=30000):
    counter = Counter()
    for text in samples:
        counter.update(tokenize(text))
    vocab = [PAD_TOKEN, UNK_TOKEN, START_TOKEN, END_TOKEN] + \
        [w for w, f in counter.items() if f >= min_freq][:max_vocab_size-4]
    word2idx = {w: i for i, w in enumerate(vocab)}
    return word2idx

combined_texts = list(train_df['clean_article']) + list(train_df['clean_summary'])
vocab = build_vocab(combined_texts)

with open(VOCAB_PATH, 'w') as f:
    json.dump(vocab, f)
print(f"Vocab size: {len(vocab)}")

Vocab size: 30000


In [6]:
# Hyperparams
MAX_ARTICLE_LEN = 400
MAX_SUMMARY_LEN = 50
BATCH_SIZE = 32
EMBEDDING_DIM = 256
ENC_HIDDEN_DIM = 256
DEC_HIDDEN_DIM = 256
NUM_EPOCHS = 20
PATIENCE = 3
VOCAB_SIZE = len(vocab)
NUM_HEADS = 8
FF_DIM = 512
NUM_LAYERS = 4
SAVE_PATH = os.path.join(SAVE_DIR, 'best_summarization_model.pt')

In [7]:
# Dataset & Loader
class SummarizationDataset(Dataset):
    def __init__(self, df, vocab, max_article_len=400, max_summary_len=50):
        self.articles = df['clean_article'].values
        self.summaries = df['clean_summary'].values
        self.vocab = vocab
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def encode(self, text, max_len, add_specials=False):
        tokens = tokenize(text)
        if add_specials:
            tokens = [START_TOKEN] + tokens[:max_len-2] + [END_TOKEN]
        else:
            tokens = tokens[:max_len]
        ids = [self.vocab.get(w, self.vocab[UNK_TOKEN]) for w in tokens]
        if len(ids) < max_len:
            ids += [self.vocab[PAD_TOKEN]] * (max_len - len(ids))
        return ids[:max_len]

    def __len__(self):
        return len(self.articles)

    def __getitem__(self, idx):
        src = torch.tensor(self.encode(self.articles[idx],  self.max_article_len, add_specials=False), dtype=torch.long)
        tgt = torch.tensor(self.encode(self.summaries[idx], self.max_summary_len,  add_specials=True),  dtype=torch.long)
        return src, tgt

train_dataset = SummarizationDataset(train_df, vocab, MAX_ARTICLE_LEN, MAX_SUMMARY_LEN)
val_dataset   = SummarizationDataset(val_df, vocab, MAX_ARTICLE_LEN, MAX_SUMMARY_LEN)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, pin_memory=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE, num_workers=4, pin_memory=True)

In [8]:
# Transformer Model
class TransformerSummarizer(nn.Module):
    def __init__(self, vocab_size, emb_dim, nhead, ff_dim, num_layers, max_article_len, max_summary_len, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos_encoder = nn.Embedding(max_article_len, emb_dim)
        self.pos_decoder = nn.Embedding(max_summary_len, emb_dim)
        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        self.fc_out = nn.Linear(emb_dim, vocab_size)
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        src_pos = self.pos_encoder(torch.arange(self.max_article_len, device=src.device)).unsqueeze(0)
        tgt_pos = self.pos_decoder(torch.arange(self.max_summary_len, device=tgt.device)).unsqueeze(0)
        src_emb = self.embedding(src) + src_pos[:, :src.size(1), :]
        tgt_emb = self.embedding(tgt) + tgt_pos[:, :tgt.size(1), :]
        outs = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc_out(outs)

pad_idx = vocab[PAD_TOKEN]
model = TransformerSummarizer(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    nhead=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_article_len=MAX_ARTICLE_LEN,
    max_summary_len=MAX_SUMMARY_LEN,
    pad_idx=pad_idx
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
import torch.optim.lr_scheduler as lr_scheduler
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min',
    factor=0.5, patience=2,
    min_lr=1e-6
)

In [None]:
# Training Loop
best_val_loss = float('inf')
counter = 0

for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0
    for src, tgt in tqdm(train_loader, desc=f"Epoch {epoch+1} [Train]", leave=False):
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        output = model(src, tgt[:, :-1])
        output = output.reshape(-1, output.shape[-1])
        target = tgt[:, 1:].reshape(-1)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    avg_train_loss = train_loss / len(train_loader)

    model.eval()
    val_loss = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_loader, desc=f"Epoch {epoch+1} [Val]", leave=False):
            src, tgt = src.to(device), tgt.to(device)
            output = model(src, tgt[:, :-1])
            output = output.reshape(-1, output.shape[-1])
            target = tgt[:, 1:].reshape(-1)
            loss = criterion(output, target)
            val_loss += loss.item()
    avg_val_loss = val_loss / len(val_loader)
    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), SAVE_PATH)
        print(f"--> New best model saved at epoch {epoch+1} with val loss {best_val_loss:.4f}")
        counter = 0
    else:
        counter += 1
        print(f"Validation loss did not improve. Early stopping counter: {counter}/{PATIENCE}")
        if counter >= PATIENCE:
            print("Early stopping triggered. Training halted.")
            break

print(f"Training complete. Best validation loss: {best_val_loss:.4f}")

                                                                                                                        

Epoch 1: Train Loss: 6.4991 | Val Loss: 6.0570
--> New best model saved at epoch 1 with val loss 6.0570


                                                                                                                        

Epoch 2: Train Loss: 5.8925 | Val Loss: 5.7337
--> New best model saved at epoch 2 with val loss 5.7337


                                                                                                                        

Epoch 3: Train Loss: 5.6341 | Val Loss: 5.5503
--> New best model saved at epoch 3 with val loss 5.5503


                                                                                                                        

Epoch 4: Train Loss: 5.4555 | Val Loss: 5.4206
--> New best model saved at epoch 4 with val loss 5.4206


                                                                                                                        

Epoch 5: Train Loss: 5.3142 | Val Loss: 5.3190
--> New best model saved at epoch 5 with val loss 5.3190


                                                                                                                        

Epoch 6: Train Loss: 5.1942 | Val Loss: 5.2396
--> New best model saved at epoch 6 with val loss 5.2396


                                                                                                                        

Epoch 7: Train Loss: 5.0888 | Val Loss: 5.1717
--> New best model saved at epoch 7 with val loss 5.1717


                                                                                                                        

Epoch 8: Train Loss: 4.9926 | Val Loss: 5.1124
--> New best model saved at epoch 8 with val loss 5.1124


Epoch 9 [Train]:   5%|███                                                             | 73/1515 [00:30<09:44,  2.47it/s]