In [1]:
# Imports
import os
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from tqdm import tqdm
from rouge_score import rouge_scorer

# Pick the compute device once for the whole notebook: use the GPU when
# PyTorch can see a CUDA device, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [2]:
# Load joint vocabulary (shared by encoder and decoder inputs)
vocab_path = '../../models/summarization/vocab.json'

# Special tokens
PAD_TOKEN   = '<PAD>'
UNK_TOKEN   = '<UNK>'
START_TOKEN = '<START>'
END_TOKEN   = '<END>'

# JSON text is UTF-8 by specification; pass the encoding explicitly so decoding
# does not depend on the platform's locale default (e.g. cp1252 on Windows).
with open(vocab_path, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# Inverse mapping id -> token; int() normalises ids to Python ints.
idx2word = {int(i): w for w, i in vocab.items()}
word2idx = vocab  # alias: token -> id

# PAD is looked up directly when building the model and START/END when
# decoding, so fail here with a clear message instead of a bare KeyError later.
# (UNK is optional: preprocessing falls back to a default id when it is absent.)
_missing_specials = [t for t in (PAD_TOKEN, START_TOKEN, END_TOKEN) if t not in vocab]
if _missing_specials:
    raise ValueError(f"{vocab_path} is missing special tokens: {_missing_specials}")

In [3]:
# TransformerSummarizer Model Definition
class TransformerSummarizer(nn.Module):
    def __init__(self, vocab_size, emb_dim, nhead, ff_dim, num_layers,
                 max_article_len, max_summary_len, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.pos_encoder = nn.Embedding(max_article_len, emb_dim)
        self.pos_decoder = nn.Embedding(max_summary_len, emb_dim)
        
        self.transformer = nn.Transformer(
            d_model=emb_dim,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=ff_dim,
            dropout=0.1,
            batch_first=True
        )
        
        self.fc_out = nn.Linear(emb_dim, vocab_size)
        self.max_article_len = max_article_len
        self.max_summary_len = max_summary_len

    def forward(self, src, tgt):
        src_mask = None
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(1)).to(tgt.device)
        
        src_pos = self.pos_encoder(torch.arange(self.max_article_len, device=src.device)).unsqueeze(0)
        tgt_pos = self.pos_decoder(torch.arange(self.max_summary_len, device=tgt.device)).unsqueeze(0)
        
        src_emb = self.embedding(src) + src_pos[:, :src.size(1), :]
        tgt_emb = self.embedding(tgt) + tgt_pos[:, :tgt.size(1), :]
        
        outs = self.transformer(src_emb, tgt_emb, src_mask=src_mask, tgt_mask=tgt_mask)
        return self.fc_out(outs)

In [4]:
# Instantiate Model and Load Checkpoint
# Hyperparameters: these must match the configuration the checkpoint was trained
# with. Most mismatches make load_state_dict fail (missing keys or shape errors),
# but NUM_HEADS changes no weight shape, so a wrong value would load silently.
VOCAB_SIZE = len(vocab)
EMBEDDING_DIM = 256
NUM_HEADS = 8
FF_DIM = 512
NUM_LAYERS = 4
MAX_ARTICLE_LEN = 400   # rows in the encoder positional table
MAX_SUMMARY_LEN = 50    # rows in the decoder positional table

pad_idx = vocab[PAD_TOKEN]

model = TransformerSummarizer(
    vocab_size=VOCAB_SIZE,
    emb_dim=EMBEDDING_DIM,
    nhead=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_article_len=MAX_ARTICLE_LEN,
    max_summary_len=MAX_SUMMARY_LEN,
    pad_idx=pad_idx
).to(device)

checkpoint_path = '../../models/summarization/best_summarization_model.pt'
# The file goes straight into load_state_dict, so it only needs to hold tensors;
# weights_only=True stops torch.load from unpickling arbitrary objects (which can
# execute code) and matches the default behaviour of newer PyTorch releases.
model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
model.eval()  # disable dropout for deterministic inference

print("Model loaded and set to eval mode.")

Model loaded and set to eval mode.


In [5]:
# Preprocessing Function
def preprocess_text(text, vocab, max_len):
    """Turn whitespace-separated text into a list of ``max_len`` token ids.

    Words after the first ``max_len`` are discarded. Words missing from
    ``vocab`` map to the ``UNK_TOKEN`` id (1 if the vocab has no such entry).
    A shorter result is right-padded with the ``PAD_TOKEN`` id (0 if absent),
    so for non-negative ``max_len`` the output always has length ``max_len``.
    """
    unk_id = vocab.get(UNK_TOKEN, 1)
    kept_words = text.split()[:max_len]
    ids = [vocab.get(word, unk_id) for word in kept_words]

    n_missing = max_len - len(ids)
    if n_missing > 0:
        ids.extend([vocab.get(PAD_TOKEN, 0)] * n_missing)
    return ids

In [6]:
# Beam Search Decoder for Transformer
def beam_search_decode_transformer(model, src_indices, vocab, idx2word, beam_width=4, max_summary_len=50, alpha=0.7):
    """Decode one article into a summary with length-normalised beam search.

    The encoder runs once; each hypothesis is then extended by re-running the
    decoder on its full prefix (no key/value cache).

    Args:
        model: A ``TransformerSummarizer``; its ``embedding``, ``pos_encoder``,
            ``pos_decoder``, ``transformer`` and ``fc_out`` submodules are used
            directly. The model is switched to eval mode.
        src_indices: Sequence of article token ids (e.g. from ``preprocess_text``).
        vocab: Token -> id mapping; must contain ``START_TOKEN`` and ``END_TOKEN``.
        idx2word: Id -> token mapping used to render the result.
        beam_width: Hypotheses kept after each step, and candidates expanded
            per unfinished hypothesis.
        max_summary_len: Maximum number of decoding steps. Clamped to
            ``model.max_summary_len``, the size of the decoder positional table.
        alpha: Length-normalisation exponent; hypotheses are ranked by
            ``score / len(seq) ** alpha``, where ``len`` includes the start token.

    Returns:
        The best hypothesis as a space-joined string, without the start token
        and cut before the first end token. Ids missing from ``idx2word`` are
        rendered as ``UNK_TOKEN``.

    Raises:
        ValueError: If ``beam_width`` < 1, or ``src_indices`` is empty or longer
            than ``model.max_article_len``.
    """
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")
    n_src = len(src_indices)
    if not 0 < n_src <= model.max_article_len:
        raise ValueError(
            f"src_indices must contain between 1 and {model.max_article_len} ids, got {n_src}"
        )
    # The prefix fed to the decoder grows by one token per step, and the decoder
    # positional table has model.max_summary_len rows; clamping keeps the longest
    # prefix inside the table instead of raising IndexError mid-decode.
    n_steps = min(max_summary_len, model.max_summary_len)

    model.eval()
    with torch.no_grad():
        device = next(model.parameters()).device
        src = torch.tensor([src_indices], dtype=torch.long, device=device)

        # Encode the article once; every hypothesis reuses the same memory.
        src_positions = torch.arange(src.size(1), device=device)
        src_emb = model.embedding(src) + model.pos_encoder(src_positions).unsqueeze(0)
        memory = model.transformer.encoder(src_emb)

        start_id = vocab[START_TOKEN]
        end_id = vocab[END_TOKEN]

        beams = [(0.0, [start_id])]  # (cumulative log-prob, token ids incl. start)

        for _ in range(n_steps):
            all_candidates = []
            for score, seq in beams:
                if seq[-1] == end_id:
                    # Finished hypotheses are carried over unchanged so they keep
                    # competing with the ones still being extended.
                    all_candidates.append((score, seq))
                    continue

                tgt_seq = torch.tensor([seq], dtype=torch.long, device=device)
                tgt_pos = model.pos_decoder(torch.arange(len(seq), device=device)).unsqueeze(0)
                tgt_emb = model.embedding(tgt_seq) + tgt_pos
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(len(seq)).to(device)

                decoder_output = model.transformer.decoder(tgt_emb, memory, tgt_mask=tgt_mask)

                # Only the last position predicts the next token.
                logits = model.fc_out(decoder_output[:, -1, :])
                log_probs = F.log_softmax(logits, dim=-1).squeeze(0)  # (vocab_size,)

                # topk raises if k exceeds the number of vocabulary entries.
                k = min(beam_width, log_probs.size(0))
                top_log_probs, top_ids = torch.topk(log_probs, k)

                for log_p, token_id in zip(top_log_probs.tolist(), top_ids.tolist()):
                    all_candidates.append((score + log_p, seq + [token_id]))

            # Rank by length-normalised score so longer hypotheses are not
            # penalised merely for accumulating more negative log-probs.
            beams = sorted(
                all_candidates,
                key=lambda cand: cand[0] / (len(cand[1]) ** alpha),
                reverse=True,
            )[:beam_width]

            if all(seq[-1] == end_id for _, seq in beams):
                break

        best_seq = beams[0][1]

        # Drop the start token and everything from the first end token onward.
        if end_id in best_seq:
            best_seq = best_seq[1:best_seq.index(end_id)]
        else:
            best_seq = best_seq[1:]

        summary = ' '.join(idx2word.get(i, UNK_TOKEN) for i in best_seq)
        return summary

In [7]:
# Load Validation Data
# Pre-processed validation split; the evaluation loop reads its
# 'clean_article' and 'clean_summary' columns.
val_csv_path = '../../data/summarization/processed_val_split.csv'
val_df = pd.read_csv(val_csv_path)

In [8]:
# Setup ROUGE Scorer
# ROUGE-1 (unigram overlap) and ROUGE-L (longest common subsequence), with
# Porter stemming so inflected forms of the same word count as matches.
scorer = rouge_scorer.RougeScorer(
    rouge_types=['rouge1', 'rougeL'],
    use_stemmer=True,
)

In [9]:
# Evaluation Loop with Beam Search Decoding
# Number of validation articles to decode; capped at the split size below so
# the loop never indexes past the end of val_df.
n_samples = 100
n_eval = min(n_samples, len(val_df))
rouge1_scores = []
rougeL_scores = []

for i in tqdm(range(n_eval), desc="Evaluating samples"):
    article_text = val_df['clean_article'].iloc[i]
    reference_summary = val_df['clean_summary'].iloc[i]

    src_indices = preprocess_text(article_text, vocab, MAX_ARTICLE_LEN)

    pred_summary = beam_search_decode_transformer(
        model, src_indices, vocab, idx2word,
        beam_width=4,
        max_summary_len=MAX_SUMMARY_LEN,
        alpha=0.7
    )

    # RougeScorer.score takes (target, prediction) in that order.
    scores = scorer.score(reference_summary, pred_summary)
    rouge1_scores.append(scores['rouge1'].fmeasure)
    rougeL_scores.append(scores['rougeL'].fmeasure)

    if i < 3:
        # tqdm.write prints above the progress bar instead of splitting it.
        tqdm.write(f"\nSample {i+1}")
        tqdm.write(f"Article:\n {article_text} \n")
        tqdm.write(f"Reference Summary:\n {reference_summary} \n")
        tqdm.write(f"Predicted Summary:\n {pred_summary} \n")
        tqdm.write(f"ROUGE-1: {scores['rouge1'].fmeasure:.4f}, ROUGE-L: {scores['rougeL'].fmeasure:.4f}\n")

n_scored = len(rouge1_scores)
if n_scored:
    avg_rouge1 = sum(rouge1_scores) / n_scored
    avg_rougeL = sum(rougeL_scores) / n_scored
    # Report the number actually evaluated, which may be fewer than n_samples.
    print(f"\nAverage ROUGE-1 score over {n_scored} samples: {avg_rouge1:.4f}")
    print(f"Average ROUGE-L score over {n_scored} samples: {avg_rougeL:.4f}")
else:
    # Avoid ZeroDivisionError when the split is empty or n_samples is 0.
    avg_rouge1 = avg_rougeL = float('nan')
    print("\nNo validation samples were evaluated.")

Evaluating samples:   1%|▋                                                              | 1/100 [00:01<02:30,  1.52s/it]


Sample 1
Article:
 cnn sixty six people have died from west nile virus infections this year, and the number of human cases has grown to 1,590, the u.s. centers for disease control and prevention said wednesday. that's the highest case count through the last week of august since the virus was first detected in the united states in 1999. nearly half of all the infections have occurred in texas, where officials said later wednesday that 894 cases have been reported along with 34 deaths. those numbers are going to go up, said dr. david lakey, commissioner for the texas department of state health services. lakey said it looks like 2012 will be the worst year so far when it comes to west nile virus cases. in 2003, texas reported 40 deaths because of the virus, and health officials believe they will surpass that number this year. all lower 48 states are now reporting west nile activity, and 43 states have reported at least one person infected with the virus. fast facts on west nile virus . m

Evaluating samples:   2%|█▎                                                             | 2/100 [00:02<02:11,  1.34s/it]


Sample 2
Article:
 it should come as no surprise that joe simpson appears wise beyond his years. the 26 year old wasps scrum half speaks with the air of a man twice his age after a year of emotional turmoil following the tragically premature death of his 58 year old mother brigid, who passed away last january after a six month battle with skin cancer. that simpson has come through as a calm and reflective man, and happens to be playing the best rugby of his life, is testament to him and the club. last season was tough, simpson said. it was an emotionally draining year. she sacrificed so much for me, driving me all over the country from the age of 10, to club games and to england camps. some of the locations were far from glamorous and i remember her standing watching in the pouring rain, all wrapped up. my family have always been incredibly supportive. wasps scrum half joe simpson left tries to evade the tackle of bath fly half george ford at the rec . simpson left puts his body on th

Evaluating samples:   3%|█▉                                                             | 3/100 [00:04<02:18,  1.43s/it]


Sample 3
Article:
 cnn when malcolm x was assassinated on february 21, 1965, many americans viewed his killing as simply the result of an ongoing feud between him and the nation of islam. he had publicly left the nation of islam in march 1964, and as the months wore on the animus between malcolm's camp and the nation of islam grew increasingly caustic, with bitter denunciations coming from both sides. a week before he was killed, malcolm's home owned by the nation of islam, which was seeking to evict him was firebombed, and malcolm believed members of the nation of islam to be responsible. for investigators and commentators alike, then, his death was an open and shut case muslims did it. yet although three members of the nation of islam were tried and found guilty for the killing, two of them maintained their innocence and decades of research has since cast doubt on the outcome of the case. tens of thousands of declassified pages documenting government surveillance, infiltration and d

Evaluating samples: 100%|█████████████████████████████████████████████████████████████| 100/100 [01:52<00:00,  1.12s/it]


Average ROUGE-1 score over 100 samples: 0.1853
Average ROUGE-L score over 100 samples: 0.1387



