In [None]:
# Attach Google Drive to the Colab runtime's filesystem.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
!pip install transformers datasets bert-score rouge-score nltk

Collecting datasets
  Downloading datasets-3.5.1-py3-none-any.whl.metadata (19 kB)
Collecting bert-score
  Downloading bert_score-0.3.13-py3-none-any.whl.metadata (15 kB)
Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2025.3.0,>=2023.1.0 (from fsspec[http]<=2025.3.0,>=2023.1.0->datasets)
  Downloading fsspec-2025.3.0-py3-none-any.whl.metadata (11 kB)
Downloading datasets-3.5.1-py3-none-any.whl (491 kB)
[2K   [90m━━━━━━━

In [None]:
# The packages are installed by the preceding cell, so the duplicate install
# was dropped; this cell only fetches the NLTK "punkt_tab" tokenizer data.
import nltk
nltk.download('punkt_tab')



[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [None]:
import torch
import torch.nn as nn
from transformers import (
    BertTokenizer, BertModel,
    RobertaTokenizer, RobertaModel,
    BartTokenizer, BartModel,
    T5Tokenizer, T5EncoderModel,
    GPT2Tokenizer, GPT2Model
)
from datasets import load_dataset
from nltk.tokenize import sent_tokenize

# === Sentence Splitter ===
def split_into_sentences(text):
    """Segment ``text`` into sentences with NLTK's ``sent_tokenize``.

    Args:
        text: The document string to segment.

    Returns:
        The value returned by ``sent_tokenize(text)``.
    """
    sentence_list = sent_tokenize(text)
    return sentence_list

# === Tokenizer Helper ===
def tokenize_sentences(sentences, tokenizer, max_length=128):
    """Tokenize a batch of sentences into fixed-length PyTorch tensors.

    Args:
        sentences: The sentences to encode; passed to ``tokenizer`` unchanged.
        tokenizer: Callable tokenizer invoked with the sentences and the
            padding/truncation options below.
        max_length: Length every sequence is padded or truncated to.
            Defaults to 128, the value this helper previously hard-coded.

    Returns:
        Whatever ``tokenizer`` returns when asked for ``return_tensors="pt"``.
    """
    return tokenizer(
        sentences,
        padding="max_length",   # pad every row to exactly ``max_length``
        truncation=True,        # cut longer rows down to ``max_length``
        max_length=max_length,
        return_tensors="pt",
    )

# === BERT Model ===
class BertForExtractiveSummarization(nn.Module):
    """BERT encoder topped with a one-unit linear head that scores each row.

    Each row of the batch is scored from the encoder's hidden state at
    position 0, producing one scalar per row.
    """

    def __init__(self, pretrained_model="bert-base-uncased"):
        """Load the pretrained encoder and create the scoring head.

        Args:
            pretrained_model: Identifier handed to ``BertModel.from_pretrained``.
        """
        super().__init__()
        encoder = BertModel.from_pretrained(pretrained_model)
        # The attribute names ``bert`` and ``classifier`` define the
        # state-dict keys, so they must stay as they are.
        self.bert = encoder
        self.classifier = nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one score per row of the batch.

        Args:
            input_ids: Tensor forwarded to the encoder as ``input_ids``.
            attention_mask: Tensor forwarded to the encoder as ``attention_mask``.

        Returns:
            The linear head's output with its last dimension squeezed out.
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_position = encoded.last_hidden_state[:, 0]
        scores = self.classifier(first_position)
        return scores.squeeze(-1)

# === RoBERTa Model ===
class RobertaForExtractiveSummarization(nn.Module):
    """RoBERTa encoder followed by a one-unit linear scoring head.

    The hidden state at position 0 of each row is mapped to a single scalar.
    """

    def __init__(self, pretrained_model="roberta-base"):
        """Load the pretrained encoder and create the scoring head.

        Args:
            pretrained_model: Identifier handed to ``RobertaModel.from_pretrained``.
        """
        super().__init__()
        backbone = RobertaModel.from_pretrained(pretrained_model)
        # ``roberta`` / ``classifier`` are the state-dict key prefixes; keep them.
        self.roberta = backbone
        self.classifier = nn.Linear(backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one score per row of the batch.

        Args:
            input_ids: Tensor forwarded to the encoder as ``input_ids``.
            attention_mask: Tensor forwarded to the encoder as ``attention_mask``.

        Returns:
            The linear head's output with its last dimension squeezed out.
        """
        hidden_states = self.roberta(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        leading_state = hidden_states[:, 0]
        return self.classifier(leading_state).squeeze(-1)

# === BART Model ===
class BartForExtractiveSummarization(nn.Module):
    """BART model followed by a one-unit linear scoring head.

    The position-0 slice of ``BartModel``'s ``last_hidden_state`` is mapped to
    a single scalar per row.
    """

    def __init__(self, pretrained_model="facebook/bart-base"):
        """Load the pretrained model and create the scoring head.

        Args:
            pretrained_model: Identifier handed to ``BartModel.from_pretrained``.
        """
        super().__init__()
        bart_backbone = BartModel.from_pretrained(pretrained_model)
        # ``bart`` / ``classifier`` are the state-dict key prefixes; keep them.
        self.bart = bart_backbone
        self.classifier = nn.Linear(bart_backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one score per row of the batch.

        Args:
            input_ids: Tensor forwarded to the model as ``input_ids``.
            attention_mask: Tensor forwarded to the model as ``attention_mask``.

        Returns:
            The linear head's output with its last dimension squeezed out.
        """
        model_out = self.bart(input_ids=input_ids, attention_mask=attention_mask)
        pooled = model_out.last_hidden_state[:, 0]
        row_scores = self.classifier(pooled)
        return row_scores.squeeze(-1)

# === FLAN-T5-Small Model ===
class FlanT5ForExtractiveSummarization(nn.Module):
    """T5 encoder (FLAN-T5 weights by default) with a one-unit linear head.

    The encoder's hidden state at position 0 of each row is mapped to a
    single scalar.
    """

    def __init__(self, pretrained_model="google/flan-t5-small"):
        """Load the pretrained encoder and create the scoring head.

        Args:
            pretrained_model: Identifier handed to
                ``T5EncoderModel.from_pretrained``.
        """
        super().__init__()
        t5_encoder = T5EncoderModel.from_pretrained(pretrained_model)
        # ``encoder`` / ``classifier`` are the state-dict key prefixes; keep them.
        self.encoder = t5_encoder
        # T5 configs expose the hidden width as ``d_model``.
        self.classifier = nn.Linear(t5_encoder.config.d_model, 1)

    def forward(self, input_ids, attention_mask):
        """Return one score per row of the batch.

        Args:
            input_ids: Tensor forwarded to the encoder as ``input_ids``.
            attention_mask: Tensor forwarded to the encoder as ``attention_mask``.

        Returns:
            The linear head's output with its last dimension squeezed out.
        """
        enc_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        position_zero = enc_out.last_hidden_state[:, 0]
        return self.classifier(position_zero).squeeze(-1)

# === GPT-2 Model ===
class GPT2ForExtractiveSummarization(nn.Module):
    """GPT-2 backbone with a one-unit linear head that scores each row.

    GPT-2 uses causal attention, so the hidden state at position 0 depends
    only on the first token of the row. The default ``pooling="first"`` keeps
    that original behaviour so existing checkpoints score exactly as before.
    ``pooling="last"`` instead uses the last non-padded position, which has
    attended to the whole sentence; use it when training a new head.
    """

    def __init__(self, pretrained_model="gpt2", pooling="first"):
        """Load the pretrained backbone and create the scoring head.

        Args:
            pretrained_model: Identifier handed to ``GPT2Model.from_pretrained``.
            pooling: ``"first"`` (default, original behaviour) to score from
                position 0, or ``"last"`` to score from the last position
                whose ``attention_mask`` entry is non-zero.

        Raises:
            ValueError: If ``pooling`` is neither ``"first"`` nor ``"last"``.
        """
        super().__init__()
        if pooling not in ("first", "last"):
            raise ValueError(f"pooling must be 'first' or 'last', got {pooling!r}")
        # ``encoder`` / ``classifier`` are the state-dict key prefixes; keep them.
        self.encoder = GPT2Model.from_pretrained(pretrained_model)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, 1)
        # Plain attribute (not a parameter/buffer), so state dicts are unchanged.
        self.pooling = pooling

    def forward(self, input_ids, attention_mask):
        """Return one score per row of the batch.

        Args:
            input_ids: Tensor forwarded to the backbone as ``input_ids``.
            attention_mask: Tensor forwarded to the backbone as
                ``attention_mask``; with ``pooling="last"`` it also locates the
                last real token of each row (assumes right-side padding).

        Returns:
            The linear head's output with its last dimension squeezed out.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        hidden = outputs.last_hidden_state
        # getattr keeps instances created without __init__ on the default path.
        if getattr(self, "pooling", "first") == "last":
            # Index of the last attended token per row; clamp guards all-pad rows.
            last_index = (attention_mask.sum(dim=1) - 1).clamp(min=0).to(hidden.device)
            rows = torch.arange(hidden.size(0), device=hidden.device)
            pooled = hidden[rows, last_index]
        else:
            pooled = hidden[:, 0, :]
        return self.classifier(pooled).squeeze(-1)

# === Summary Generator ===
def generate_summary(model, tokenizer, sentences, device, top_k=3):
    """Build an extractive summary from the highest-scoring sentences.

    Each sentence is tokenized as one batch row, scored by ``model``, and the
    ``top_k`` best sentences are joined with single spaces, highest score
    first.

    The caller's ``sentences`` list is never modified, and blank sentences are
    neither scored nor emitted, so short documents no longer produce a summary
    padded with empty strings.

    Args:
        model: Callable taking ``(input_ids, attention_mask)`` and returning
            one score per input row.
        tokenizer: Tokenizer forwarded to ``tokenize_sentences``.
        sentences: Iterable of sentence strings to choose from.
        device: Device the token tensors are moved to before scoring.
        top_k: Maximum number of sentences in the summary (default 3).

    Returns:
        The selected sentences joined by spaces, or ``""`` when there is
        nothing to score or the model's output does not contain exactly one
        score per sentence.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    # Filtered copy: leaves the caller's list untouched and drops blank entries.
    candidates = [sentence for sentence in sentences if sentence.strip()]
    if not candidates:
        return ""

    tokenized = tokenize_sentences(candidates, tokenizer)
    input_ids = tokenized['input_ids'].to(device)
    attention_mask = tokenized['attention_mask'].to(device)

    with torch.no_grad():
        # reshape(-1) keeps a 1-D score vector even for a single sentence.
        logits = model(input_ids, attention_mask).reshape(-1)

    if logits.numel() != len(candidates):
        return ""

    # Rank on plain floats: one device-to-host transfer rather than a tensor
    # comparison per pair. sorted() is stable, so ties keep document order.
    scores = logits.tolist()
    ranked = sorted(range(len(candidates)), key=scores.__getitem__, reverse=True)
    return " ".join(candidates[i] for i in ranked[:top_k])

# === Main Execution ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_finetuned(model_cls, checkpoint_path, device):
    """Instantiate ``model_cls``, load fine-tuned weights, and switch to eval mode.

    Args:
        model_cls: Zero-argument callable that builds the model.
        checkpoint_path: Path of a state-dict file readable by ``torch.load``.
        device: Device the model is moved to and the weights are mapped to.

    Returns:
        The model on ``device`` with the checkpoint loaded, in eval mode.
    """
    model = model_cls().to(device)
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    state = torch.load(checkpoint_path, map_location=device)
    # Strip only a leading "module." wrapper prefix. A plain str.replace would
    # also corrupt keys that contain "module." elsewhere, e.g. "my_module.weight".
    prefix = "module."
    state = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state.items()
    }
    model.load_state_dict(state)
    model.eval()
    # The raw state dict goes out of scope here, so no second copy of the
    # weights stays resident on the device.
    return model


# Load one GIGAWORD test sample (index given by SAMPLE_INDEX)
SAMPLE_INDEX = 79
dataset = load_dataset("gigaword", split="test")
sample = dataset[SAMPLE_INDEX]
article = sample["document"]
reference_summary = sample["summary"]
sentences = split_into_sentences(article)

# === Load and Run Each Model ===

# BERT
bert_model = _load_finetuned(BertForExtractiveSummarization, "/best_bert_gigaword_model.pt", device)
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert_summary = generate_summary(bert_model, bert_tokenizer, sentences, device)

# RoBERTa
roberta_model = _load_finetuned(RobertaForExtractiveSummarization, "/best_roberta_gigaword_model.pt", device)
roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
roberta_summary = generate_summary(roberta_model, roberta_tokenizer, sentences, device)

# BART
bart_model = _load_finetuned(BartForExtractiveSummarization, "/best_bart_gigaword_model.pt", device)
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
bart_summary = generate_summary(bart_model, bart_tokenizer, sentences, device)

# FLAN-T5
flan_model = _load_finetuned(FlanT5ForExtractiveSummarization, "/best_flan_t5_gigaword_model.pt", device)
flan_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
flan_summary = generate_summary(flan_model, flan_tokenizer, sentences, device)

# GPT-2
gpt2_model = _load_finetuned(GPT2ForExtractiveSummarization, "/best_gpt2_extractive_gigaword.pt", device)
gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# The tokenizer needs a pad token because sentences are padded to a fixed length.
gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
gpt2_summary = generate_summary(gpt2_model, gpt2_tokenizer, sentences, device)

# === Print Results ===
print("\n   Original Document:\n", article)
print("\n   Reference Summary:\n", reference_summary)
print("\n   BERT Predicted Summary:\n", bert_summary)
print("\n   RoBERTa Predicted Summary:\n", roberta_summary)
print("\n   BART Predicted Summary:\n", bart_summary)
print("\n   FLAN-T5-Small Predicted Summary:\n", flan_summary)
print("\n   GPT-2 Predicted Summary:\n", gpt2_summary)


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.



   Original Document:
 a french crocodile farm said thursday it had stepped up efforts to breed one of the world 's most endangered species , the indian UNK , with the hope of ultimately returning animals to their habitat in south asia .

   Reference Summary:
 french farm offers hope for endangered asian crocs UNK picture

   BERT Predicted Summary:
 a french crocodile farm said thursday it had stepped up efforts to breed one of the world 's most endangered species , the indian UNK , with the hope of ultimately returning animals to their habitat in south asia .  

   RoBERTa Predicted Summary:
 a french crocodile farm said thursday it had stepped up efforts to breed one of the world 's most endangered species , the indian UNK , with the hope of ultimately returning animals to their habitat in south asia .  

   BART Predicted Summary:
 a french crocodile farm said thursday it had stepped up efforts to breed one of the world 's most endangered species , the indian UNK , with the hope 