In [None]:
# Install the third-party packages imported below, then fetch the small English spaCy model.
# torch, numpy and scikit-learn are imported by this notebook but not installed here —
# presumably they are preinstalled in the runtime; verify on a fresh environment.
!pip install -q transformers datasets rouge-score bert-score spacy tqdm
!python -m spacy download en_core_web_sm

In [None]:
import os
import math
import random
from tqdm.auto import tqdm
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertTokenizerFast, BertModel
from datasets import load_dataset
from rouge_score import rouge_scorer
from bert_score import score as bertscore_score

import spacy
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# spaCy pipeline used later by split_sentences for sentence segmentation.
nlp = spacy.load("en_core_web_sm")

# Run on the GPU when PyTorch reports one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

In [None]:
# ---------- Hyperparams ----------
NUM_SAMPLES = 1000                 # number of test-split articles loaded and evaluated
SENT_MAX_TOKENS = 128              # tokenizer max_length (with truncation) per sentence
SIM_THRESHOLD = 0.30               # minimum cosine similarity for an edge between two sentences
GAT_HID = 256                      # output width of each GAT layer
BERT_MODEL = "bert-base-uncased"   # checkpoint name used for both tokenizer and encoder
TOP_K_SENT = 3                     # number of sentences extracted per summary
BATCH_ENCODING = 32                # sentences per BERT forward pass in encode_sentences

In [None]:
# Load dataset
# Takes the first NUM_SAMPLES rows of the "test" split of cnn_dailymail (config "3.0.0").
dataset = load_dataset("cnn_dailymail", "3.0.0", split=f"test[:{NUM_SAMPLES}]")
print("Loaded dataset samples:", len(dataset))

In [None]:
# Utility: sentence-splitter
def split_sentences(article):
    """Break an article into sentences using the module-level spaCy pipeline ``nlp``.

    Every sentence is whitespace-stripped and kept only if more than 10
    characters remain (very short fragments are treated as noise). If nothing
    survives the filter, the first 300 characters of ``article`` are returned
    as a single-element list so callers always get at least one entry.
    """
    kept = []
    for span in nlp(article).sents:
        text = span.text.strip()
        if len(text) > 10:
            kept.append(text)
    if kept:
        return kept
    return [article[:300]]  # fallback

In [None]:
# BERT sentence encoder (averaged token embeddings)
# Tokenizer and encoder are both loaded from the BERT_MODEL checkpoint; the encoder is moved to `device`.
tokenizer = BertTokenizerFast.from_pretrained(BERT_MODEL)
bert = BertModel.from_pretrained(BERT_MODEL).to(device)
bert.eval()  # eval mode for inference; as the cell's last expression this also displays the module

In [None]:
@torch.no_grad()
def encode_sentences(sent_list):
    """Embed sentences with BERT using attention-masked mean pooling.

    Args:
        sent_list: list[str] of sentences to encode.

    Returns:
        numpy array of shape (N, hidden), one row per sentence in input order.
        An empty ``sent_list`` yields an array of shape (0, hidden) instead of
        raising.

    Sentences are processed in batches of BATCH_ENCODING and truncated to
    SENT_MAX_TOKENS tokens. Each embedding is the mean of the last hidden
    states over non-padding positions (this is mean pooling, not CLS pooling).
    Relies on the module-level ``tokenizer``, ``bert`` and ``device``.
    """
    if len(sent_list) == 0:
        # np.vstack([]) raises ValueError, so return an explicitly empty matrix.
        return np.zeros((0, bert.config.hidden_size), dtype=np.float32)

    embs = []
    for start in range(0, len(sent_list), BATCH_ENCODING):
        batch = sent_list[start:start + BATCH_ENCODING]
        encoded = tokenizer(batch, truncation=True, padding=True, max_length=SENT_MAX_TOKENS, return_tensors="pt")
        input_ids = encoded["input_ids"].to(device)
        attn = encoded["attention_mask"].to(device)
        out = bert(input_ids=input_ids, attention_mask=attn)
        last = out.last_hidden_state                # (B, L, H)
        # Cast the integer attention mask to the hidden-state dtype so the
        # sum, clamp and division below are all done in floating point
        # rather than depending on implicit int/float promotion.
        mask = attn.unsqueeze(-1).to(last.dtype)    # (B, L, 1)
        summed = (last * mask).sum(dim=1)           # (B, H), padding contributes zero
        denom = mask.sum(dim=1).clamp(min=1e-9)     # token count per sentence, guarded against 0
        embs.append((summed / denom).cpu().numpy())
    return np.vstack(embs)  # (N, H)


In [None]:
# GAT layer (single attention head); GETSumGAT stacks two of these
class GATLayer(nn.Module):
    """Single-head graph attention layer driven by an adjacency list.

    For node i with neighbours N(i):
        e_ij    = LeakyReLU(a^T [W h_i || W h_j])   for j in N(i)
        alpha_i = softmax_j(e_ij)
        out_i   = sum_j alpha_ij * W h_j
    A node is not attended to by itself unless its own index appears in its
    neighbour list. A node with an empty neighbour list passes through as
    W h_i.

    Args:
        in_dim: size of the input node features.
        out_dim: size of the projected / output node features.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.a = nn.Linear(2*out_dim, 1, bias=False)
        self.leaky = nn.LeakyReLU(0.2)

    def forward(self, h, adj_list):
        """Apply attention-weighted neighbour aggregation.

        Args:
            h: tensor (N, in_dim) of node features.
            adj_list: sequence of length N; ``adj_list[i]`` holds the
                neighbour indices of node i.

        Returns:
            Tensor (N, out_dim).
        """
        Wh = self.W(h)  # (N, out_dim)
        out = torch.zeros_like(Wh)
        # Scores are computed per node over its own edges only, in a single
        # pass, so the neighbour features are gathered once per node.
        for i in range(Wh.size(0)):
            neigh = adj_list[i]
            if len(neigh) == 0:
                out[i] = Wh[i]  # isolated node: keep its own projection
                continue
            neigh_feat = Wh[neigh]                                   # (deg, out)
            wi = Wh[i].unsqueeze(0).expand(len(neigh), -1)           # (deg, out), no copy
            e_ij = self.leaky(self.a(torch.cat([wi, neigh_feat], dim=1))).squeeze(-1)  # (deg,)
            alpha = F.softmax(e_ij, dim=0)                           # (deg,)
            out[i] = (alpha.unsqueeze(1) * neigh_feat).sum(0)        # (out,)
        return out

In [None]:
class GETSumGAT(nn.Module):
    """Two stacked GATLayer modules with an ELU between them.

    Args:
        in_dim: width of the incoming node features.
        hid_dim: output width of both layers.
    """

    def __init__(self, in_dim, hid_dim):
        super().__init__()
        self.gat1 = GATLayer(in_dim, hid_dim)
        self.gat2 = GATLayer(hid_dim, hid_dim)

    def forward(self, h, adj_list):
        """Run both layers over the same adjacency list; returns (N, hid_dim)."""
        hidden = self.gat1(h, adj_list)
        hidden = F.elu(hidden)
        # No activation after the second layer: its raw output is returned.
        return self.gat2(hidden, adj_list)

In [None]:
# Ranking head and gating
class GETSumModel(nn.Module):
    """Scores sentences from their embeddings plus graph-attention features.

    The sentence embedding and the GAT output are concatenated; a linear
    ``scorer`` maps that joint vector to one score per sentence. A small MLP
    ``gate`` also produces a sigmoid value per sentence from the same joint
    vector. The gate is returned to the caller but is not used when computing
    the scores.

    Args:
        sent_dim: width of the input sentence embeddings.
        gat_hid: output width of the GAT stack.
    """
    def __init__(self, sent_dim, gat_hid):
        super().__init__()
        self.gat = GETSumGAT(sent_dim, gat_hid)
        # gating: combine sent emb and gat emb
        self.gate = nn.Sequential(
            nn.Linear(sent_dim + gat_hid, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
        self.scorer = nn.Linear(sent_dim + gat_hid, 1)

    def forward(self, sent_emb_np, adj_list):
        """Score every sentence of one document.

        Args:
            sent_emb_np: (N, sent_dim) sentence embeddings, as a numpy array
                (a torch tensor or nested list is accepted as well).
            adj_list: length-N list of neighbour index lists.

        Returns:
            Tuple of numpy arrays ``(scores, gate, joint)`` with shapes
            (N,), (N, 1) and (N, sent_dim + gat_hid). All three are detached
            from the autograd graph.
        """
        # Place the input on the device that holds this module's parameters,
        # so forward() does not depend on a module-level `device` variable.
        target_device = self.scorer.weight.device
        h = torch.as_tensor(sent_emb_np, dtype=torch.float32, device=target_device)
        gat_h = self.gat(h, adj_list)                    # (N, gat_hid)
        concat = torch.cat([h, gat_h], dim=1)            # (N, sent_dim + gat_hid)
        gate = torch.sigmoid(self.gate(concat))          # (N, 1)
        # Sentence and GAT widths may differ, so the concatenation is scored directly.
        scores = self.scorer(concat).squeeze(-1)         # (N,)
        return scores.detach().cpu().numpy(), gate.detach().cpu().numpy(), concat.detach().cpu().numpy()

In [None]:
# Helper: build adjacency list per document using cosine sim threshold
def build_adj_list(sent_embs, threshold=SIM_THRESHOLD):
    """Build a sentence-graph adjacency list from pairwise cosine similarity.

    Args:
        sent_embs: array of shape (N, H), one embedding per sentence.
        threshold: minimum cosine similarity for j to be a neighbour of i.

    Returns:
        List of N lists of int indices. Node i's neighbours are every
        j != i with sim[i, j] >= threshold. If there are none, the (up to)
        two most similar other nodes are used instead, so every node has at
        least one neighbour whenever N >= 2. Fallback edges are one-directional
        and the graph as a whole is not guaranteed to be connected.
        For N <= 1 there can be no edges, so N empty lists are returned.
    """
    N = sent_embs.shape[0]
    if N <= 1:
        return [[] for _ in range(N)]  # no edges possible
    sim = cosine_similarity(sent_embs)  # (N, N)
    adj_list = []
    for i in range(N):
        row = sim[i]
        neigh = [j for j in range(N) if j != i and row[j] >= threshold]
        if not neigh:
            # Drop self by index rather than by rank: self is not guaranteed
            # to sort first (e.g. an all-zero embedding, or ties), and a
            # positional slice could then keep self and skip a real neighbour.
            ranked = np.argsort(row)[::-1]
            neigh = [int(j) for j in ranked if j != i][:2]
        adj_list.append(neigh)
    return adj_list


In [None]:
# Initialize GETSum model (weights are created on CPU, then moved to `device`).
# NOTE(review): no training step or checkpoint load is visible in this notebook, so the
# GAT / gate / scorer weights appear to stay at their random initialization. Seeding
# PyTorch right before construction makes the resulting sentence rankings (and
# therefore the reported metrics) repeatable across kernel restarts.
SEED = 42
torch.manual_seed(SEED)
SENT_DIM = bert.config.hidden_size  # typically 768
model_getsum = GETSumModel(sent_dim=SENT_DIM, gat_hid=GAT_HID).to(device)
model_getsum.eval()

In [None]:
# Run inference over dataset, build summaries, and evaluate
# Accumulators filled by the per-article loop: reference highlights and extracted summaries.
references = []
predictions = []
sample_count = len(dataset)  # number of articles the loop iterates over

In [None]:
# Build one extractive summary per article.
# The accumulators are re-created here so that re-running this cell starts from
# empty lists instead of appending a second copy of every reference/summary.
references, predictions = [], []
MAX_SENTS_PER_DOC = 80   # cap on sentences kept per article
HEAD_SENTS_KEPT = 60     # leading sentences always kept when the cap applies

for idx in tqdm(range(sample_count), desc="Processing articles"):
    sample = dataset[idx]
    article = sample["article"]
    ref = sample["highlights"]
    sentences = split_sentences(article)
    # Limit sentence count for very long docs: keep the first HEAD_SENTS_KEPT
    # plus the longest of the remainder (heuristic). The selected tail
    # sentences are put back in document order, so that sorting the chosen
    # indices below really does restore the article's original order.
    if len(sentences) > MAX_SENTS_PER_DOC:
        tail_budget = MAX_SENTS_PER_DOC - HEAD_SENTS_KEPT
        longest_tail = sorted(
            range(HEAD_SENTS_KEPT, len(sentences)),
            key=lambda j: len(sentences[j]),
            reverse=True,
        )[:tail_budget]
        sentences = sentences[:HEAD_SENTS_KEPT] + [sentences[j] for j in sorted(longest_tail)]
    sent_embs = encode_sentences(sentences)  # (n_sent, H)
    adj_list = build_adj_list(sent_embs, threshold=SIM_THRESHOLD)
    # forward through GETSum scoring head
    with torch.no_grad():
        scores, gate_vals, joint_reps = model_getsum(sent_embs, adj_list)
    # Pick the top-K sentences by score, then emit them in document order.
    # k is clamped explicitly: a plain [-K:] slice would return *every*
    # sentence when K == 0.
    ranked = np.argsort(scores)
    k = max(0, min(TOP_K_SENT, ranked.size))
    topk_idx = ranked[ranked.size - k:]
    topk_idx_sorted = sorted(int(i) for i in topk_idx)
    summary = " ".join(sentences[i] for i in topk_idx_sorted)
    references.append(ref)
    predictions.append(summary)


In [None]:
# Evaluation: ROUGE
# Averages the ROUGE-1 / ROUGE-2 / ROUGE-L F-measure (stemming enabled) over
# every (reference, prediction) pair collected above.
scorer = rouge_scorer.RougeScorer(['rouge1','rouge2','rougeL'], use_stemmer=True)
rouge_totals = {'rouge1': 0.0, 'rouge2': 0.0, 'rougeL': 0.0}
for ref, pred in zip(references, predictions):
    sc = scorer.score(ref, pred)  # reference first, prediction second
    for metric in rouge_totals:
        rouge_totals[metric] += sc[metric].fmeasure
r1 = rouge_totals['rouge1']
r2 = rouge_totals['rouge2']
rl = rouge_totals['rougeL']
n = len(predictions)
print(f"\nROUGE-1: {r1/n:.4f}, ROUGE-2: {r2/n:.4f}, ROUGE-L: {rl/n:.4f}")

In [None]:
# Evaluation: BERTScore
# Predictions are passed first and references second. P, R, F1 are presumably the
# per-pair precision / recall / F1 tensors — confirm against the bert-score docs.
P, R, F1 = bertscore_score(predictions, references, lang="en", verbose=True)
print("BERTScore F1 (mean):", F1.mean().item())

In [None]:
# Show a few examples
# At most three are printed; the bound is clamped to the number of available
# predictions so a run with fewer than 3 samples does not raise IndexError.
print("\n--- Examples ---\n")
for i in range(min(3, len(predictions))):
    print("ARTICLE (start):", dataset[i]["article"][:400].replace("\n"," "), "...\n")
    print("REFERENCE:", references[i], "\n")
    print("GETSum (extract):", predictions[i], "\n")
    print("-"*80)