In [None]:
# Uncomment to install if needed
# !pip install transformers datasets torch dgl spacy scikit-learn nltk
# !python -m spacy download en_core_web_sm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from transformers import BertTokenizer, BertModel
from datasets import load_dataset

import spacy
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import re
import nltk
from nltk.corpus import stopwords

import dgl
from dgl.nn.pytorch import GATConv

# Fetch NLTK's English stopword list and load the small English spaCy pipeline
# (spaCy provides the sentence segmenter used during preprocessing).
nltk.download('stopwords')
nlp = spacy.load('en_core_web_sm')

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# All splits (train / validation / test) of CNN/DailyMail, config "3.0.0".
dataset = load_dataset(path="abisee/cnn_dailymail", name="3.0.0")

# Summarize what was loaded.
print(dataset)

In [None]:
# Shared resources for preprocess_article, which follows the document's pipeline:
# normalization, segmentation, tokenization, stopword filtering, rare-token handling.
# NLTK's English stopword set and BERT's uncased WordPiece tokenizer:
stop_words = set(stopwords.words('english'))
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Normalization regexes, compiled once instead of on every article.
_HTML_TAG_RE = re.compile(r'<.*?>')
_URL_RE = re.compile(r'http\S+|www\S+', flags=re.IGNORECASE)
_SPECIAL_CHAR_RE = re.compile(r'[^\w\s]')
_WHITESPACE_RE = re.compile(r'\s+')


def _strip_markup(text):
    """Remove HTML tags and URLs and collapse whitespace, keeping case and punctuation."""
    text = _HTML_TAG_RE.sub('', text)
    text = _URL_RE.sub('', text)
    return _WHITESPACE_RE.sub(' ', text).strip()


def _normalize_sentence(sentence):
    """Lowercase one sentence, drop special characters/emojis, and collapse whitespace."""
    sentence = _SPECIAL_CHAR_RE.sub('', sentence.lower())
    return _WHITESPACE_RE.sub(' ', sentence).strip()


def preprocess_article(article, min_freq=2, unk_token='<UNK>'):
    """Normalize, sentence-split and tokenize one news article.

    Steps:
      1. Strip HTML tags and URLs and collapse whitespace.
      2. Sentence segmentation with the module-level spaCy pipeline ``nlp``. This
         runs on the still-punctuated text. Each sentence is then lowercased and
         stripped of special characters/emojis.
      3. WordPiece tokenization with the module-level BERT ``tokenizer``, dropping
         tokens contained in ``stop_words``.
      4. Rare-token handling: tokens occurring fewer than ``min_freq`` times in this
         article are replaced by ``unk_token`` (per article, not corpus-level).

    Args:
        article: raw article text. ``None`` or ``''`` yields two empty lists.
        min_freq: minimum per-article frequency for a token to be kept.
        unk_token: replacement string for rare tokens.

    Returns:
        ``(sentences, tokenized_sentences)``. ``sentences`` holds the normalized
        sentence strings (non-empty only). ``tokenized_sentences[i]`` is the filtered
        token list for ``sentences[i]``.
    """
    if not article:
        return [], []

    # 1-2. Segment before removing punctuation. spaCy's sentence boundaries rely on
    # '.', '!', '?', so stripping them first makes the splits unreliable.
    doc = nlp(_strip_markup(article))
    sentences = [_normalize_sentence(sent.text) for sent in doc.sents]
    sentences = [sent for sent in sentences if sent]

    # 3. WordPiece tokenization + stopword filtering (for later similarity computations).
    tokenized_sentences = [
        [tok for tok in tokenizer.tokenize(sent) if tok not in stop_words]
        for sent in sentences
    ]

    # 4. Replace tokens seen fewer than `min_freq` times in this article.
    freq = nltk.FreqDist(tok for sent in tokenized_sentences for tok in sent)
    tokenized_sentences = [
        [tok if freq[tok] >= min_freq else unk_token for tok in sent]
        for sent in tokenized_sentences
    ]

    return sentences, tokenized_sentences


In [None]:
class SummarizationDataset(Dataset):
    """Map-style dataset over one split of a CNN/DailyMail-style dataset.

    Items are preprocessed lazily in ``__getitem__``. Each is a dict with keys
    'article', 'summary' (taken from the 'highlights' field), 'sentences' and
    'tokenized_sentences' (the two outputs of ``preprocess_article``).
    """

    def __init__(self, hf_dataset, split='train'):
        # Keep only the requested split (e.g. 'train', 'validation', 'test').
        self.data = hf_dataset[split]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        sentences, tokenized_sentences = preprocess_article(record['article'])
        return dict(
            article=record['article'],
            summary=record['highlights'],  # ground-truth reference summary
            sentences=sentences,
            tokenized_sentences=tokenized_sentences,
        )

def collate_summaries(items):
    """Collate dataset items into a dict mapping each field to a per-article list.

    PyTorch's default collate_fn requires list-valued fields to have the same length
    across the batch. It raises "each element in list of batch should be of equal
    size" otherwise, and would also transpose them. Articles have different sentence
    counts, so each field is kept as a plain list with one entry per article. That
    matches how GETSum.forward iterates ``batch['sentences']``.
    """
    if not items:
        return {}
    return {key: [item[key] for item in items] for key in items[0]}


# Create datasets
train_dataset = SummarizationDataset(dataset, 'train')
val_dataset = SummarizationDataset(dataset, 'validation')
test_dataset = SummarizationDataset(dataset, 'test')

# DataLoaders (list-preserving collate; see collate_summaries)
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_summaries)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=collate_summaries)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=collate_summaries)

In [None]:
class GETSum(nn.Module):
    """Graph-Enhanced Transformer summarizer (GETSum).

    For each batch of articles:
      1. Encode every sentence with BERT (attention-masked mean pooling over tokens).
      2. Build one graph per article. Nodes are its sentences, and an edge joins two
         sentences whose cosine similarity exceeds ``sim_threshold``. Every node also
         gets a self-loop, so DGL's GATConv never sees a zero-in-degree node.
      3. Refine the sentence vectors with a stack of GAT layers over the batched graphs.
      4. Fuse BERT and GAT vectors with a learned sigmoid gate.
      5. Either rank sentences (extractive), or run a causal Transformer decoder over
         the tokenized reference summary under teacher forcing (abstractive).

    Uses the module-level BERT ``tokenizer`` for sentence and summary tokenization.

    Args:
        hidden_dim: width of sentence/GAT/decoder vectors. Must equal BERT's hidden
            size (768 for bert-base) and be divisible by ``num_heads``.
        num_heads: attention heads per GAT layer; head outputs are concatenated.
        num_gat_layers: number of stacked GATConv layers.
        dropout: dropout after each GAT layer and on decoder input embeddings.
        sim_threshold: cosine-similarity cutoff for linking two sentences.
        freeze_bert: if True, BERT runs without gradient tracking (not fine-tuned).
    """

    def __init__(self, hidden_dim=768, num_heads=8, num_gat_layers=2, dropout=0.1,
                 sim_threshold=0.3, freeze_bert=True):
        super(GETSum, self).__init__()
        if hidden_dim % num_heads:
            raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})")
        self.hidden_dim = hidden_dim
        self.sim_threshold = sim_threshold
        self.freeze_bert = freeze_bert

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.max_positions = self.bert.config.max_position_embeddings

        # Multi-head GAT; flatten(1) in _graph_attend concatenates heads back to hidden_dim.
        self.gat_layers = nn.ModuleList(
            [GATConv(hidden_dim, hidden_dim // num_heads, num_heads) for _ in range(num_gat_layers)]
        )
        self.dropout = nn.Dropout(dropout)

        # Gating mechanism for fusion: [bert ; gat] -> per-dimension weight in (0, 1).
        self.gate = nn.Linear(hidden_dim * 2, hidden_dim)

        # Abstractive head. nn.TransformerDecoder consumes float embeddings, not token ids,
        # so target tokens go through token + learned-position embeddings first.
        self.tgt_embed = nn.Embedding(tokenizer.vocab_size, hidden_dim, padding_idx=tokenizer.pad_token_id)
        self.tgt_pos = nn.Embedding(self.max_positions, hidden_dim)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=hidden_dim, nhead=12, batch_first=True), num_layers=6
        )
        self.fc_out = nn.Linear(hidden_dim, tokenizer.vocab_size)  # Output to vocab size for generation

    def forward(self, batch, mode='abstractive', top_k=3):
        """Score sentences or produce teacher-forced summary logits for a batch.

        Args:
            batch: dict with 'sentences' (one list of sentence strings per article).
                Abstractive mode also needs 'summary' (one reference string per article).
            mode: 'extractive' or 'abstractive'.
            top_k: number of sentences to select in extractive mode.

        Returns:
            extractive: LongTensor (batch, min(top_k, max_sents)) of sentence indices,
                best first. Padding rows rank last, but an article with fewer than
                ``top_k`` sentences can still receive padding indices.
            abstractive: logits (batch, tgt_len, vocab). Position t scores token t of
                ``tokenizer(batch['summary'], padding=True, truncation=True).input_ids``.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode not in ('extractive', 'abstractive'):
            raise ValueError("Mode must be 'extractive' or 'abstractive'")
        dev = next(self.parameters()).device

        # 1. Transformer encoder: one pooled BERT vector per sentence, per article.
        bert_vectors = [self._encode_sentences(sents, dev) for sents in batch['sentences']]

        # 2-3. Similarity graph per article + GAT.
        gat_vectors = self._graph_attend(bert_vectors)

        # Pad to (batch, max_sents, hidden); sentence_mask is True on real sentence rows.
        sentence_embeddings = nn.utils.rnn.pad_sequence(bert_vectors, batch_first=True)
        gat_embeddings = nn.utils.rnn.pad_sequence(gat_vectors, batch_first=True)
        max_sents = sentence_embeddings.shape[1]
        counts = torch.tensor([v.shape[0] for v in bert_vectors], device=dev)
        sentence_mask = torch.arange(max_sents, device=dev).unsqueeze(0) < counts.unsqueeze(1)

        # 4. Representation integration: concat + sigmoid gating.
        gate_weight = torch.sigmoid(self.gate(torch.cat([sentence_embeddings, gat_embeddings], dim=-1)))
        fused_embeddings = gate_weight * sentence_embeddings + (1 - gate_weight) * gat_embeddings

        # 5. Summary generation.
        if mode == 'extractive':
            # Heuristic score: sum of fused features. Padding is pushed to -inf so it ranks last.
            scores = fused_embeddings.sum(dim=2).masked_fill(~sentence_mask, float('-inf'))
            return torch.topk(scores, k=min(top_k, max_sents), dim=1).indices
        return self._decode(batch['summary'], fused_embeddings, sentence_mask, dev)

    def _encode_sentences(self, sentences, dev):
        """Mean-pool BERT token states into one vector per sentence: (num_sentences, hidden)."""
        # An article with no sentences still gets one (empty) node, so later tensors are non-empty.
        sentences = list(sentences) or ['']
        enc = tokenizer(sentences, padding=True, truncation=True,
                        max_length=self.max_positions, return_tensors='pt').to(dev)
        # Never re-enables grad inside an outer no_grad (e.g. evaluation).
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_bert):
            token_states = self.bert(**enc).last_hidden_state
        # Masked mean: average only over real tokens ([CLS]/[SEP] included), not padding.
        mask = enc['attention_mask'].unsqueeze(-1).to(token_states.dtype)
        return (token_states * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1.0)

    def _graph_attend(self, sentence_vectors):
        """Run the GAT stack over one cosine-similarity graph per article.

        Takes and returns a list of (n_i, hidden) tensors, one per article.
        """
        graphs = []
        for vecs in sentence_vectors:
            n = vecs.shape[0]
            with torch.no_grad():  # graph structure is not differentiated through
                unit = vecs / vecs.norm(dim=-1, keepdim=True).clamp_min(1e-12)
                adjacency = (unit @ unit.T) > self.sim_threshold
                # Self-loops on every node: GATConv rejects zero-in-degree nodes by default.
                adjacency |= torch.eye(n, dtype=torch.bool, device=vecs.device)
            src, dst = adjacency.nonzero(as_tuple=True)
            graphs.append(dgl.graph((src, dst), num_nodes=n).to(vecs.device))
        batched_graph = dgl.batch(graphs)

        # dgl.batch concatenates the graphs' node sets in order, so features line up with torch.cat.
        h = torch.cat(sentence_vectors, dim=0)
        for layer in self.gat_layers:
            h = self.dropout(layer(batched_graph, h).flatten(1))  # multi-head concat
        return list(torch.split(h, [v.shape[0] for v in sentence_vectors]))

    def _decode(self, summaries, memory, memory_mask, dev):
        """Teacher-forced decoding over the tokenized summaries: logits (batch, tgt_len, vocab)."""
        # Same tokenization as the training labels so logits and labels stay aligned.
        tgt_ids = tokenizer(summaries, padding=True, truncation=True, return_tensors='pt').to(dev).input_ids
        # Shift right, using [CLS] as the start symbol. With the causal mask, position t sees
        # only target tokens before t and cannot simply copy its own label.
        dec_in = torch.full_like(tgt_ids, tokenizer.cls_token_id)
        dec_in[:, 1:] = tgt_ids[:, :-1]
        seq_len = dec_in.shape[1]
        positions = torch.arange(seq_len, device=dev).unsqueeze(0)
        tgt = self.dropout(self.tgt_embed(dec_in) + self.tgt_pos(positions))
        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=dev), diagonal=1)
        out = self.decoder(
            tgt,
            memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=dec_in.eq(tokenizer.pad_token_id),
            memory_key_padding_mask=~memory_mask,  # ignore padded sentence rows
        )
        return self.fc_out(out)

# Build the summarizer, then move its parameters onto the selected device.
model = GETSum()
model = model.to(device)

In [None]:
# Training objectives:
#   - abstractive: token-level cross-entropy that skips the tokenizer's padding id
#   - extractive: pairwise margin ranking between positive and negative scores
abstractive_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
extractive_loss_fn = nn.MarginRankingLoss(margin=1.0)

# AdamW over every model parameter (the BERT encoder included).
optimizer = optim.AdamW(params=model.parameters(), lr=2e-5, weight_decay=0.01)

def train_epoch(loader, model, optimizer, mode='abstractive', epoch=0):
    """Train ``model`` for one pass over ``loader`` and return the mean batch loss.

    Only the abstractive objective can be trained. Labels are the tokenized
    ``batch['summary']``, scored with the module-level ``abstractive_loss_fn``.

    Args:
        loader: iterable of batch dicts containing at least 'summary'.
        model: module whose ``model(batch, mode='abstractive')`` is expected to return
            logits (batch, tgt_len, vocab) aligned with
            ``tokenizer(batch['summary'], padding=True, truncation=True).input_ids``.
        optimizer: optimizer over ``model``'s parameters.
        mode: 'abstractive' (trainable) or 'extractive' (not implemented).
        epoch: zero-based epoch index, used only in the progress message.

    Returns:
        Average loss per batch (0.0 for an empty loader).

    Raises:
        NotImplementedError: for mode='extractive'. The model returns only top-k
            indices, and there are no oracle labels to build a differentiable ranking loss.
        ValueError: for any other mode (e.g. 'hybrid').
    """
    if mode == 'extractive':
        raise NotImplementedError(
            "Extractive training needs ROUGE-oracle sentence labels and differentiable "
            "sentence scores; the model's extractive mode only returns top-k indices."
        )
    if mode != 'abstractive':
        raise ValueError(f"Unsupported training mode {mode!r}; expected 'abstractive' or 'extractive'")

    model.train()
    total_loss = 0.0
    num_batches = 0
    for batch in loader:
        optimizer.zero_grad()

        logits = model(batch, mode='abstractive')
        tgt = tokenizer(batch['summary'], padding=True, truncation=True, return_tensors='pt')['input_ids'].to(logits.device)
        loss = abstractive_loss_fn(logits.reshape(-1, logits.size(-1)), tgt.reshape(-1))

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}: Avg Loss = {avg_loss}")
    return avg_loss

# Fine-tune for 10 epochs (as per document for CNN/DM).
# train_epoch supports only 'abstractive' training (there is no 'hybrid' mode).
num_epochs = 10
train_losses = []
for epoch in range(num_epochs):
    train_loss = train_epoch(train_loader, model, optimizer, mode='abstractive', epoch=epoch)
    train_losses.append(train_loss)
    # TODO: validate on val_loader each epoch (evaluate() is defined in a later cell).

# Save model
torch.save(model.state_dict(), 'getsum_model.pth')

In [None]:
# Implement evaluation metrics (ROUGE, BERTScore)
# !pip install rouge-score bert-score
from rouge_score import rouge_scorer
from bert_score import score

def evaluate(model, loader, mode='abstractive'):
    """Produce summaries for every batch in ``loader`` and report ROUGE and BERTScore.

    Args:
        model: GETSum-style module (see its ``forward``).
        loader: iterable of batch dicts with 'sentences' (one list of sentence strings
            per article) and 'summary' (reference strings).
        mode: 'extractive' joins the top-ranked sentences in document order.
            'abstractive' decodes the argmax token at every decoder position.

    Returns:
        dict with 'rouge' (mean F-measure per rouge1/rouge2/rougeL) and
        'bertscore' (mean BERTScore F1).

    Raises:
        ValueError: for an unknown ``mode`` or when ``loader`` yields no examples.

    NOTE(review): the abstractive branch is teacher-forced. The model is fed the
    reference summary, so argmax predictions are conditioned on the ground truth and
    the scores are optimistic. Proper evaluation needs autoregressive
    (greedy/beam) decoding.
    """
    if mode not in ('extractive', 'abstractive'):
        raise ValueError(f"Unsupported evaluation mode {mode!r}; expected 'extractive' or 'abstractive'")

    model.eval()
    predictions, references = [], []
    with torch.no_grad():
        for batch in loader:
            if mode == 'extractive':
                top_indices = model(batch, mode='extractive').tolist()
                pred_summary = []
                # Iterate the actual batch, not the global batch_size, so a short last batch works.
                for sents, idxs in zip(batch['sentences'], top_indices):
                    # Drop indices that point at padding rows; restore document order.
                    chosen = sorted(j for j in idxs if j < len(sents))
                    pred_summary.append(' '.join(sents[j] for j in chosen))
            else:
                logits = model(batch, mode='abstractive')
                pred_ids = torch.argmax(logits, dim=-1)
                pred_summary = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

            predictions.extend(pred_summary)
            references.extend(batch['summary'])

    if not predictions:
        raise ValueError("evaluate() received an empty loader; nothing to score")

    # ROUGE: score each (reference, prediction) pair once and reuse it for all three metrics.
    rouge_types = ['rouge1', 'rouge2', 'rougeL']
    scorer = rouge_scorer.RougeScorer(rouge_types, use_stemmer=True)
    pair_scores = [scorer.score(ref, pred) for ref, pred in zip(references, predictions)]
    rouge_scores = {k: float(np.mean([s[k].fmeasure for s in pair_scores])) for k in rouge_types}

    # BERTScore
    P, R, F1 = score(predictions, references, lang='en', verbose=True)
    bertscore = F1.mean().item()

    print(f"ROUGE: {rouge_scores}\nBERTScore: {bertscore}")
    return {'rouge': rouge_scores, 'bertscore': bertscore}

# Example: Evaluate on test
test_scores = evaluate(model, test_loader)