In [8]:
# Imports
import os
import json
import torch
import torch.nn as nn
import pandas as pd
from torch.utils.data import Dataset

In [9]:
# Paths — relative paths resolve against the kernel's current working directory.
save_path = '../../models/summarization/'
# Both model artifacts live side by side under save_path.
vocab_file, model_file = (
    os.path.join(save_path, artifact)
    for artifact in ('vocab.json', 'best_summarization_model.pt')
)
val_data_file = '../../data/summarization/processed_val_split.csv'

In [10]:
# Detect device: prefer a CUDA GPU when PyTorch can see one, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [11]:
# Load vocab.
# JSON is UTF-8 by specification; pinning the encoding keeps non-ASCII words
# intact on platforms whose default locale encoding is not UTF-8.
with open(vocab_file, 'r', encoding='utf-8') as f:
    vocab = json.load(f)

# encode() and greedy_decode() index these special tokens directly, so fail
# fast here with a clear message instead of a KeyError deep inside decoding.
missing_specials = {'<PAD>', '<UNK>'} - vocab.keys()
assert not missing_specials, f"vocab is missing special tokens: {sorted(missing_specials)}"

# Invert vocab for decoding (id -> word); ids are coerced to int in case the
# JSON stored them as strings.
idx2word = {int(v): k for k, v in vocab.items()}

In [12]:
# Tokenizer consistent with training
def tokenize(text):
    """Split ``text`` on runs of whitespace (no lower-casing or punctuation handling)."""
    # sep=None is str.split's whitespace mode: leading/trailing whitespace is
    # ignored and consecutive separators collapse, so no empty tokens appear.
    pieces = text.split(sep=None)
    return pieces

In [13]:
# Encoding function (similar to Dataset)
def encode(text, vocab, max_len):
    """Map text to a fixed-length list of token ids.

    Tokens come from ``tokenize``; words missing from ``vocab`` map to
    ``vocab['<UNK>']``. The result is truncated to ``max_len`` ids and
    right-padded with ``vocab['<PAD>']``.

    Args:
        text: Input string.
        vocab: Word -> id mapping; must contain ``'<UNK>'`` and ``'<PAD>'``.
        max_len: Target length of the returned list.

    Returns:
        A list of ``max_len`` integer ids (for ``max_len >= 0``).
    """
    # Hoisted out of the comprehension: these lookups are loop-invariant.
    unk_id = vocab['<UNK>']
    pad_id = vocab['<PAD>']
    # Truncate before the vocabulary lookup: only the first max_len tokens
    # survive, so mapping the rest of a long article is wasted work.
    ids = [vocab.get(word, unk_id) for word in tokenize(text)[:max_len]]
    ids.extend([pad_id] * (max_len - len(ids)))
    return ids

In [14]:
# Model architecture
class Seq2SeqBaseline(nn.Module):
    """Plain LSTM encoder-decoder without attention.

    The encoder's final (hidden, cell) state initialises the decoder, which
    consumes the whole ``tgt`` sequence in a single call in ``forward``.
    Submodule names form the state_dict keys that ``load_state_dict`` matches
    against, so they must not be renamed.

    NOTE(review): both embeddings use ``padding_idx=0``; this assumes
    ``vocab['<PAD>'] == 0`` — TODO confirm against the saved vocab.
    """

    def __init__(self, input_vocab_size, target_vocab_size, emb_dim=128, hidden_dim=256):
        super().__init__()
        # Registration order is deliberate: parameter initialisation draws from
        # the RNG in this sequence.
        self.encoder_emb = nn.Embedding(
            num_embeddings=input_vocab_size, embedding_dim=emb_dim, padding_idx=0
        )
        self.encoder_lstm = nn.LSTM(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.decoder_emb = nn.Embedding(
            num_embeddings=target_vocab_size, embedding_dim=emb_dim, padding_idx=0
        )
        self.decoder_lstm = nn.LSTM(input_size=emb_dim, hidden_size=hidden_dim, batch_first=True)
        self.fc_out = nn.Linear(in_features=hidden_dim, out_features=target_vocab_size)

    def forward(self, src, tgt):
        """Return decoder logits.

        For 2-D ``(batch, seq_len)`` id tensors (``batch_first=True``), the
        result has shape ``(batch, tgt_len, target_vocab_size)``.
        """
        src_embedded = self.encoder_emb(src)
        _, encoder_state = self.encoder_lstm(src_embedded)  # encoder_state == (hidden, cell)
        tgt_embedded = self.decoder_emb(tgt)
        decoder_outputs, _ = self.decoder_lstm(tgt_embedded, encoder_state)
        return self.fc_out(decoder_outputs)

In [15]:
# Instantiate model with vocab sizes.
# Source and target share one vocabulary, so both sizes are len(vocab).
input_vocab_size = target_vocab_size = len(vocab)
model = Seq2SeqBaseline(input_vocab_size, target_vocab_size)
# NOTE(review): torch.load unpickles the file and can execute arbitrary code if
# the checkpoint is untrusted. The loaded object goes straight into
# load_state_dict, so consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load(model_file, map_location=device))
# Module.to() and Module.eval() both return the module itself, so they chain.
model.to(device).eval()
print("Model loaded and set to eval mode.")

Model loaded and set to eval mode.


In [16]:
# Greedy decoding function
def greedy_decode(model, article, vocab, idx2word, max_article_len=400, max_summary_len=50,
                  start_token='<PAD>', end_token='<PAD>'):
    """Generate a summary for one article by greedy (argmax) decoding.

    The article is encoded with ``encode`` (truncated/padded to
    ``max_article_len``) and run through the encoder. The decoder is then fed
    its own previous argmax prediction one step at a time, starting from
    ``start_token``.

    Args:
        model: Module exposing ``encoder_emb``, ``encoder_lstm``,
            ``decoder_emb``, ``decoder_lstm`` and ``fc_out`` (as
            ``Seq2SeqBaseline`` does). Switched to eval mode as a side effect.
        article: Article text to summarise.
        vocab: Word -> id mapping; must contain ``'<PAD>'``, ``'<UNK>'``,
            ``start_token`` and ``end_token``.
        idx2word: Id -> word mapping used to turn predictions back into text.
        max_article_len: Number of source token ids fed to the encoder.
        max_summary_len: Maximum number of decoding steps.
        start_token: Token fed to the decoder at the first step. Defaults to
            ``'<PAD>'``, which is what this notebook originally used.
        end_token: Token that stops decoding when predicted. Defaults to ``'<PAD>'``.

    Returns:
        The predicted summary as a space-joined string, with padding and end
        tokens removed. Ids missing from ``idx2word`` are rendered as ``'<UNK>'``.
    """
    model.eval()
    # Build tensors on whatever device the model's weights live on rather than a
    # notebook-global `device`, so inputs and weights can never disagree.
    model_device = next(model.parameters()).device
    pad_id = vocab['<PAD>']
    start_id = vocab[start_token]
    end_id = vocab[end_token]

    # Encode input article
    src_ids = encode(article, vocab, max_article_len)
    src_tensor = torch.tensor([src_ids], dtype=torch.long, device=model_device)

    summary_ids = []
    with torch.no_grad():
        _, (hidden, cell) = model.encoder_lstm(model.encoder_emb(src_tensor))
        inputs = torch.tensor([[start_id]], dtype=torch.long, device=model_device)
        for _ in range(max_summary_len):
            emb = model.decoder_emb(inputs)
            output, (hidden, cell) = model.decoder_lstm(emb, (hidden, cell))
            next_token = model.fc_out(output[:, -1, :]).argmax(dim=1)
            token_id = next_token.item()
            summary_ids.append(token_id)
            if token_id == end_id:
                break
            # Feed the prediction back as the next decoder input, shape (1, 1).
            inputs = next_token.unsqueeze(1)

    # Convert token IDs to words, dropping padding and the end token.
    skip_ids = {pad_id, end_id}
    return ' '.join(idx2word.get(i, '<UNK>') for i in summary_ids if i not in skip_ids)

In [17]:
# Load validation data
val_df = pd.read_csv(val_data_file)
# The evaluation loop reads these columns; fail here with a clear message
# rather than with a KeyError mid-loop.
required_columns = {'clean_article', 'clean_highlights'}
assert required_columns <= set(val_df.columns), (
    f"validation CSV is missing columns: {sorted(required_columns - set(val_df.columns))}"
)

In [18]:
# Few samples for evaluation
num_samples = 3

# head() stops at the last available row, so a validation split with fewer
# than num_samples rows no longer raises IndexError (as .iloc[i] did).
eval_rows = val_df.head(num_samples)
for i, (article, reference) in enumerate(zip(eval_rows['clean_article'], eval_rows['clean_highlights'])):
    prediction = greedy_decode(model, article, vocab, idx2word)
    print(f"\nSample {i+1}")
    print("ARTICLE:")
    print(article)
    print("\nREFERENCE SUMMARY:")
    print(reference)
    print("\nMODEL SUMMARY:")
    print(prediction)
    print("-" * 80)


Sample 1
ARTICLE:
cnn sixty six people have died from west nile virus infections this year, and the number of human cases has grown to 1,590, the u.s. centers for disease control and prevention said wednesday. that's the highest case count through the last week of august since the virus was first detected in the united states in 1999. nearly half of all the infections have occurred in texas, where officials said later wednesday that 894 cases have been reported along with 34 deaths. those numbers are going to go up, said dr. david lakey, commissioner for the texas department of state health services. lakey said it looks like 2012 will be the worst year so far when it comes to west nile virus cases. in 2003, texas reported 40 deaths because of the virus, and health officials believe they will surpass that number this year. all lower 48 states are now reporting west nile activity, and 43 states have reported at least one person infected with the virus. fast facts on west nile virus . mo