# BERT-based Extractive Summarizer

This pipeline performs extractive summarization using a sentence-level BERT encoder and binary classification to select salient sentences.


In [None]:
!pip install transformers datasets torch evaluate accelerate --quiet

In [None]:
!pip install sacrebleu sacremoses --quiet

In [4]:
!pip install -q sentence-transformers scikit-learn nltk

### 🔹 Step 1: Imports & Device Setup


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel
from torch.optim import AdamW
from datasets import load_dataset
import nltk

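# sent_tokenize needs NLTK's Punkt sentence-tokenizer data; uncomment on the first run to download it
# (newer NLTK releases may ask for 'punkt_tab' instead — follow the LookupError message if one appears).
# nltk.download('punkt')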
from nltk.tokenize import sent_tokenize
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### 🔹 Step 2: Dataset Class for Sentence Labeling


In [6]:
class BertSumDataset(Dataset):
    """Per-article sentence tensors and binary salience labels.

    Each item is one article split into sentences with ``sent_tokenize``.
    Every sentence is tokenized on its own to ``max_len`` tokens, and the
    sentence axis is truncated / zero-padded to ``max_sentences`` so that all
    items share one shape and can be stacked into a batch.

    A sentence is labelled 1 when its exact text occurs as a substring of the
    reference summary, otherwise 0.
    NOTE(review): with abstractive reference summaries exact matches are
    presumably rare, which would leave almost every label at 0 — confirm the
    label distribution before trusting the training loss.

    Args:
        articles: Indexable collection of article strings.
        summaries: Indexable collection of summary strings, aligned with
            ``articles`` by position.
        tokenizer: Callable invoked as ``tokenizer(sentences, padding=...,
            truncation=..., max_length=..., return_tensors='pt')`` whose
            result exposes ``'input_ids'`` and ``'attention_mask'``.
        max_len: Token length every sentence is padded / truncated to.
        max_sentences: Number of sentence rows in every returned item.
    """

    def __init__(self, articles, summaries, tokenizer, max_len=128, max_sentences=60):
        self.articles = articles
        self.summaries = summaries
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.max_sentences = max_sentences

    def __len__(self):
        """Return the number of articles."""
        return len(self.articles)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, labels)`` for article ``idx``.

        ``input_ids`` and ``attention_mask`` are long tensors of shape
        ``(max_sentences, max_len)``; ``labels`` is a float tensor of shape
        ``(max_sentences,)``. Rows beyond the article's real sentences are
        all zeros (ids, mask and label).
        """
        article = self.articles[idx]
        summary = self.summaries[idx]

        # A blank article has no sentences, so the splitter is skipped.
        sentences = sent_tokenize(article) if article.strip() else []

        # Trim first so labels are only computed for sentences that are kept.
        sentences = sentences[:self.max_sentences]
        labels = [1 if sent in summary else 0 for sent in sentences]

        if sentences:
            encoding = self.tokenizer(
                sentences,
                padding='max_length',
                truncation=True,
                max_length=self.max_len,
                return_tensors='pt'
            )
            input_ids = encoding['input_ids']
            attention_mask = encoding['attention_mask']
        else:
            # Nothing to tokenize: start from zero rows so the padding below
            # still produces a full (max_sentences, max_len) item.
            input_ids = torch.zeros((0, self.max_len), dtype=torch.long)
            attention_mask = torch.zeros((0, self.max_len), dtype=torch.long)

        # Pad the sentence axis up to max_sentences with all-zero rows.
        pad_len = self.max_sentences - input_ids.size(0)
        if pad_len > 0:
            pad_rows = torch.zeros((pad_len, self.max_len), dtype=torch.long)
            input_ids = torch.cat([input_ids, pad_rows], dim=0)
            attention_mask = torch.cat([attention_mask, pad_rows.clone()], dim=0)
            labels += [0] * pad_len  # padded rows are never salient

        labels = torch.tensor(labels, dtype=torch.float)

        return input_ids, attention_mask, labels

### 🔹 Step 3: BERT-based Sentence Classifier


In [7]:
class BertSumClassifier(nn.Module):
    """Score every sentence of a document with a BERT encoder and a linear head.

    Each sentence is encoded independently; the hidden state at position 0
    (the [CLS] token) is projected to a single raw logit. No sigmoid is
    applied here, so the output is meant for a logits-based loss.

    Args:
        pretrained_model: Name or path passed to ``BertModel.from_pretrained``.
    """

    def __init__(self, pretrained_model='bert-base-uncased'):
        super().__init__()
        self.bert = BertModel.from_pretrained(pretrained_model)
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one logit per sentence.

        Args:
            input_ids: Tensor of shape (batch_size, num_sentences, max_len).
            attention_mask: Tensor with the same shape as ``input_ids``.

        Returns:
            Tensor of shape (batch_size, num_sentences) holding raw logits.
        """
        batch_size, num_sentences, max_len = input_ids.shape

        # Fold the sentence axis into the batch axis so BERT sees plain
        # (N, max_len) input. reshape (not view) also accepts non-contiguous
        # tensors.
        flat_ids = input_ids.reshape(-1, max_len)
        flat_mask = attention_mask.reshape(-1, max_len)

        outputs = self.bert(input_ids=flat_ids, attention_mask=flat_mask)
        cls_embeddings = outputs.last_hidden_state[:, 0, :]  # CLS token embedding

        logits = self.classifier(cls_embeddings)  # (batch_size*num_sentences, 1)

        # Restore the sentence axis. No trailing squeeze: it would silently
        # drop that axis whenever num_sentences == 1.
        return logits.reshape(batch_size, num_sentences)

### 🔹 Step 4: Training Function


In [8]:
from tqdm import tqdm

def train(model, dataloader, optimizer, criterion, device):
    """Run one optimization pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module called as ``model(input_ids=..., attention_mask=...)``
            that returns one logit per label entry.
        dataloader: Iterable yielding ``(input_ids, attention_mask, labels)``
            tensor triples.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(logits, labels)``.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        The average of the per-batch loss values, or ``0.0`` when the
        dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for input_ids, attention_mask, labels in tqdm(dataloader, desc="Training Batches"):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

        # Align logits with the labels' shape rather than squeeze(-1), which
        # would drop the sentence axis when it has length 1 and make the
        # loss reject the mismatched shapes.
        loss = criterion(logits.reshape(labels.shape), labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # Count the batches actually seen: works for loaders without __len__ and
    # avoids dividing by zero on an empty loader.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches

### 🔹 Step 5: Load Data & Prepare Training


In [9]:
# WordPiece tokenizer for the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# CNN/DailyMail 3.0.0 — only the first 1% of the training split, to keep the demo small.
dataset = load_dataset(path='cnn_dailymail', name='3.0.0', split='train[:1%]')

# Source documents and their reference highlights. The highlights are usually
# abstractive, so deriving sentence-level labels from them is tricky.
articles = dataset['article']
summaries = dataset['highlights']

# Wrap the raw text in the sentence-labeling dataset and batch it for training.
train_dataset = BertSumDataset(articles=articles, summaries=summaries, tokenizer=tokenizer)
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)

### 🔹 Step 6: Train the Model


In [10]:
# Build the sentence classifier and move its weights to the selected device.
model = BertSumClassifier()
model = model.to(device)

# AdamW over all model parameters; the loss takes raw logits directly.
optimizer = AdamW(params=model.parameters(), lr=2e-5)
criterion = nn.BCEWithLogitsLoss()

In [11]:
# Training loop. The demo runs a single epoch; raise num_epochs to train longer.
num_epochs = 1
for epoch in range(num_epochs):
    loss = train(model, train_loader, optimizer, criterion, device)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")

Training Batches: 100%|██████████| 1436/1436 [32:54<00:00,  1.37s/it]

Epoch 1, Loss: 0.0048





### 🔹 Step 7: Evaluate with ROUGE


In [None]:
!pip install -q rouge_score

In [14]:
from rouge_score import rouge_scorer

# Shared scorer for ROUGE-1, ROUGE-2 and ROUGE-L, with stemming enabled.
scorer = rouge_scorer.RougeScorer(
    rouge_types=['rouge1', 'rouge2', 'rougeL'],
    use_stemmer=True,
)

def compute_rouge(predicted_summary, reference_summary):
    """Return the ROUGE-1, ROUGE-2 and ROUGE-L F-measures for one summary pair.

    The reference is passed to the scorer first and the prediction second.
    The result maps 'rouge1', 'rouge2' and 'rougeL' to their F-measure.
    """
    scores = scorer.score(reference_summary, predicted_summary)
    return {
        metric: scores[metric].fmeasure
        for metric in ('rouge1', 'rouge2', 'rougeL')
    }

# Quick sanity check on a toy prediction/reference pair.
pred, ref = "The cat sat on the mat.", "The cat is sitting on the mat."

print(compute_rouge(predicted_summary=pred, reference_summary=ref))

{'rouge1': 0.7692307692307692, 'rouge2': 0.5454545454545454, 'rougeL': 0.7692307692307692}
