In [None]:
# --- 1. Environment Setup ---
!pip install transformers[torch] pandas numpy scikit-learn nltk rouge-score tqdm

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW 
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from rouge_score import rouge_scorer
from nltk.tokenize import sent_tokenize
from tqdm.auto import tqdm
import os
from pathlib import Path

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

PRE_TRAINED_MODEL_NAME = 'bert-base-uncased'
MAX_LEN = 512 

MAX_EPOCHS = 6
PATIENCE = 2 # Stop after 2 epochs with no improvement

Collecting rouge-score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<2.7,>=2.1->transformers[torch])
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<2.7,>=2.1->transformers[torch])
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<2.7,>=2.1->transformers[torch])
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<2.7,>=2.1->transformers[torch])
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<2.7,>=2.1->transformers[torch])
  Downloading nvidia_cublas_cu12-12.4.5.8-p

2025-09-02 16:26:16.859193: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1756830377.026594      19 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1756830377.085819      19 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Using device: cuda


In [None]:
import os
from pathlib import Path
import pandas as pd
from sklearn.model_selection import train_test_split

def find_dataset_path(start_path="/kaggle/input/"):
    """
    Searches robustly for the correct 'BBC News Summary' directory that
    contains the category subfolders.
    """
    print("--- Searching for dataset directory ---")
    
    # We are looking for the directory that CONTAINS 'News Articles' and 'Summaries'
    for root, dirs, files in os.walk(start_path):
        if "News Articles" in dirs and "Summaries" in dirs:
            articles_path = Path(root) / "News Articles"
            
            # Check for at least one category subfolder (e.g., 'business')
            if any(p.is_dir() for p in articles_path.iterdir()):
                print(f"Found valid dataset base at: {root}")
                return root
            
    # If the loop finishes without finding a valid path
    raise FileNotFoundError("Could not automatically locate the 'BBC News Summary' dataset with category subfolders. "
                            "Please check the input directory structure in the Kaggle sidebar.")


def load_bbc_dataset(base_path):
    """Loads and pivots the BBC News Summary dataset from the specified path."""
    print(f"Attempting to load dataset from: {base_path}")
    all_data = []
    articles_path = Path(base_path) / "News Articles"
    summaries_path = Path(base_path) / "Summaries"
    # This loop will now be guaranteed to run on the correct directory
    for category_path in articles_path.iterdir():
        if category_path.is_dir():
            category = category_path.name
            for article_file in category_path.glob("*.txt"):
                try:
                    with open(article_file, 'r', encoding='utf-8', errors='ignore') as f: article_content = f.read()
                    summary_file = summaries_path / category / article_file.name
                    with open(summary_file, 'r', encoding='utf-8', errors='ignore') as f: summary_content = f.read()
                    all_data.append({"article": article_content, "reference_summary": summary_content})
                except Exception:
                    continue
    return pd.DataFrame(all_data)

try:
    DATASET_PATH = find_dataset_path()
    df = load_bbc_dataset(DATASET_PATH)

    if df.empty:
        raise ValueError("The loaded DataFrame is empty. The file paths might be correct, but no data was read.")

    main_train_df, test_df = train_test_split(df, test_size=0.2, random_state=42)
    train_df, val_df = train_test_split(main_train_df, test_size=0.1, random_state=42)

    print(f"\nSuccessfully loaded and split the data.")
    print(f"Total articles: {len(df)}")
    print(f"Training set size: {len(train_df)}")
    print(f"Validation set size: {len(val_df)}")
    print(f"Test set size (for final evaluation later): {len(test_df)}")

except (FileNotFoundError, ValueError) as e:
    print(f"\nERROR: {e}")
    train_df, val_df, test_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()

--- Searching for dataset directory ---
Found valid dataset base at: /kaggle/input/bbc-news-summary/BBC News Summary
Attempting to load dataset from: /kaggle/input/bbc-news-summary/BBC News Summary

Successfully loaded and split the data.
Total articles: 2225
Training set size: 1602
Validation set size: 178
Test set size (for final evaluation later): 445


In [3]:
# --- 4. Oracle Label Generation ---
def create_oracle_labels(article_text, reference_summary):
    """Greedily selects sentences to maximize ROUGE-2 F1-score."""
    try:
        article_sentences = sent_tokenize(article_text)
    except:
        return [], []
    
    if not article_sentences or not reference_summary:
        return article_sentences, [0] * len(article_sentences)

    scorer = rouge_scorer.RougeScorer(['rouge2'], use_stemmer=True)
    
    selected_indices = []
    
    # Greedily add sentences
    while True:
        best_candidate_idx = -1
        best_rouge_gain = -1.0
        
        current_summary = " ".join([article_sentences[j] for j in sorted(selected_indices)])
        base_rouge = scorer.score(reference_summary, current_summary)['rouge2'].fmeasure
        
        for i in range(len(article_sentences)):
            if i in selected_indices:
                continue
            
            # Try adding this sentence
            temp_selection = sorted(selected_indices + [i])
            summary_text = " ".join([article_sentences[j] for j in temp_selection])
            scores = scorer.score(reference_summary, summary_text)
            rouge_score = scores['rouge2'].fmeasure
            
            if rouge_score > base_rouge and (rouge_score - base_rouge) > best_rouge_gain:
                best_rouge_gain = rouge_score - base_rouge
                best_candidate_idx = i
                
        if best_candidate_idx != -1:
            selected_indices.append(best_candidate_idx)
        else:
            break
            
    labels = [1 if i in selected_indices else 0 for i in range(len(article_sentences))]
    return article_sentences, labels

# --- Example of Oracle Labeling ---
print("--- Oracle Labeling Example ---")
example_sents, example_labels = create_oracle_labels(train_df.iloc[0].article, train_df.iloc[0].reference_summary)
for sent, label in zip(example_sents, example_labels):
    print(f"LABEL: {label} | SENTENCE: {sent[:80]}...")

--- Oracle Labeling Example ---
LABEL: 0 | SENTENCE: Budget to set scene for election

Gordon Brown will seek to put the economy at t...
LABEL: 0 | SENTENCE: He is expected to stress the importance of continued economic stability, with lo...
LABEL: 1 | SENTENCE: The chancellor is expected to freeze petrol duty and raise the stamp duty thresh...
LABEL: 0 | SENTENCE: But the Conservatives and Lib Dems insist voters face higher taxes and more mean...
LABEL: 0 | SENTENCE: Treasury officials have said there will not be a pre-election giveaway, but Mr B...
LABEL: 1 | SENTENCE: - Increase in the stamp duty threshold from £60,000 
 - A freeze on petrol duty ...
LABEL: 0 | SENTENCE: Ten years ago, buyers had a much greater chance of avoiding stamp duty, with clo...
LABEL: 1 | SENTENCE: Since then, average UK property prices have more than doubled while the starting...
LABEL: 1 | SENTENCE: Tax credits As a result, the number of properties incurring stamp duty has rocke...
LABEL: 1 | SENTENCE: Th

In [4]:
# --- 5. PyTorch Dataset Class ---
class SummarizationDataset(Dataset):
    def __init__(self, dataframe, tokenizer, max_len=MAX_LEN):
        self.tokenizer = tokenizer
        self.dataframe = dataframe
        self.max_len = max_len

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, item):
        row = self.dataframe.iloc[item]
        article_sentences, labels = create_oracle_labels(row.article, row.reference_summary)
        
        if not article_sentences:
            return {'is_empty': True}

        text_for_bert = ""
        for sent in article_sentences:
             text_for_bert += sent + " [SEP] [CLS] "
        
        inputs = self.tokenizer.encode_plus(
            text_for_bert, max_length=self.max_len, padding='max_length',
            truncation=True, return_tensors='pt'
        )
        
        input_ids = inputs['input_ids'].flatten()
        attention_mask = inputs['attention_mask'].flatten()
        
        cls_indices = (input_ids == self.tokenizer.cls_token_id).nonzero().flatten()
        
        num_cls_tokens = len(cls_indices)
        labels = labels[:num_cls_tokens]
        
        padded_labels = np.zeros(self.max_len)
        if len(labels) > 0:
            padded_labels[:len(labels)] = labels
        
        return {
            'is_empty': False,
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'cls_indices': cls_indices,
            'labels': torch.tensor(padded_labels, dtype=torch.float)
        }

In [None]:
# --- 6. BERTSum Model ---
class BERTSummarizer(torch.nn.Module):
    def __init__(self, model_name=PRE_TRAINED_MODEL_NAME):
        super(BERTSummarizer, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(self.bert.config.hidden_size, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1)
        )

    def forward(self, input_ids, attention_mask, cls_indices):
        input_ids = input_ids.squeeze(0)
        attention_mask = attention_mask.squeeze(0)
        
        outputs = self.bert(input_ids=input_ids.unsqueeze(0), attention_mask=attention_mask.unsqueeze(0))
        last_hidden_state = outputs.last_hidden_state.squeeze(0)

        cls_embeddings = last_hidden_state[cls_indices]
        
        logits = self.classifier(cls_embeddings)
        return torch.sigmoid(logits)

# --- Initialize tokenizer and model ---
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME)
model = BERTSummarizer().to(DEVICE)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [None]:
# --- 7. Training and Evaluation Functions ---
def train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training"):
        if batch['is_empty'][0]: continue
        input_ids, attention_mask, cls_indices, labels = (
            batch['input_ids'].to(device), batch['attention_mask'].to(device),
            batch['cls_indices'].to(device), batch['labels'].to(device)
        )
        optimizer.zero_grad()
        predictions = model(input_ids, attention_mask, cls_indices.squeeze(0)).squeeze()
        num_predictions = predictions.shape[0]
        true_labels = labels.squeeze(0)[:num_predictions]
        loss = loss_fn(predictions, true_labels)
        total_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
    return total_loss / len(data_loader)

def eval_epoch(model, data_loader, loss_fn, device):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Validating"):
            if batch['is_empty'][0]: continue
            input_ids, attention_mask, cls_indices, labels = (
                batch['input_ids'].to(device), batch['attention_mask'].to(device),
                batch['cls_indices'].to(device), batch['labels'].to(device)
            )
            predictions = model(input_ids, attention_mask, cls_indices.squeeze(0)).squeeze()
            num_predictions = predictions.shape[0]
            true_labels = labels.squeeze(0)[:num_predictions]
            loss = loss_fn(predictions, true_labels)
            total_loss += loss.item()
    return total_loss / len(data_loader)

# --- Setup for Training ---
train_dataset = SummarizationDataset(train_df, tokenizer)
val_dataset = SummarizationDataset(val_df, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=1)
val_loader = DataLoader(val_dataset, batch_size=1)

optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = torch.nn.BCELoss()
total_steps = len(train_loader) * MAX_EPOCHS
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)

# --- Main Training Loop with Early Stopping ---
best_validation_loss = float('inf')
epochs_no_improve = 0
for epoch in range(MAX_EPOCHS):
    print(f'--- Epoch {epoch + 1}/{MAX_EPOCHS} ---')
    train_loss = train_epoch(model, train_loader, loss_fn, optimizer, DEVICE, scheduler)
    print(f'Train loss: {train_loss:.4f}')
    
    val_loss = eval_epoch(model, val_loader, loss_fn, DEVICE)
    print(f'Validation loss: {val_loss:.4f}')
    
    if val_loss < best_validation_loss:
        best_validation_loss = val_loss
        torch.save(model.state_dict(), 'bertsum_best_model.bin')
        epochs_no_improve = 0
        print("Validation loss improved. Saving model.")
    else:
        epochs_no_improve += 1
        print(f"Validation loss did not improve. Counter: {epochs_no_improve}/{PATIENCE}")

    if epochs_no_improve >= PATIENCE:
        print("Early stopping triggered.")
        break

--- Epoch 1/6 ---


Training:   0%|          | 0/1602 [00:00<?, ?it/s]

Train loss: 0.5836


Validating:   0%|          | 0/178 [00:00<?, ?it/s]

Validation loss: 0.4839
Validation loss improved. Saving model.
--- Epoch 2/6 ---


Training:   0%|          | 0/1602 [00:00<?, ?it/s]

Train loss: 0.3882


Validating:   0%|          | 0/178 [00:00<?, ?it/s]

Validation loss: 0.5581
Validation loss did not improve. Counter: 1/2
--- Epoch 3/6 ---


Training:   0%|          | 0/1602 [00:00<?, ?it/s]

Train loss: 0.2532


Validating:   0%|          | 0/178 [00:00<?, ?it/s]

Validation loss: 0.8254
Validation loss did not improve. Counter: 2/2
Early stopping triggered.


In [None]:
# --- 8. Inference Function ---
def summarize_with_bertsum(text, model, tokenizer, device, max_sents=3):
    model.eval()
    try:
        article_sentences = sent_tokenize(text)
    except:
        return "Could not process text."
    if not article_sentences: return ""

    text_for_bert = " [SEP] [CLS] ".join(article_sentences)
    
    inputs = tokenizer.encode_plus(
        text_for_bert, max_length=MAX_LEN, padding='max_length',
        truncation=True, return_tensors='pt'
    )
    
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    cls_indices = (input_ids.squeeze(0) == tokenizer.cls_token_id).nonzero().flatten()

    with torch.no_grad():
        predictions = model(input_ids, attention_mask, cls_indices).squeeze()
        
    sentence_scores = predictions.cpu().numpy()
    
    # Handle the case where there is only one sentence/prediction
    if sentence_scores.ndim == 0:
        sentence_scores = np.array([sentence_scores])

    num_sentences_to_select = min(max_sents, len(sentence_scores))
    top_indices = np.argsort(sentence_scores)[-num_sentences_to_select:]
    top_indices.sort()
    
    summary = " ".join([article_sentences[i] for i in top_indices if i < len(article_sentences)])
    return summary

# --- Example Usage ---
# Load the BEST model's weights (saved by early stopping)
model.load_state_dict(torch.load('bertsum_best_model.bin'))
model = model.to(DEVICE)

sample_article = val_df.iloc[10]['article']
reference_summary = val_df.iloc[10]['reference_summary']

print("\n\n--- Summarizing Sample Article ---")
print(f"REFERENCE SUMMARY:\n{reference_summary}")
s
# Determine summary length based on reference for a fair comparison
num_sents = len(sent_tokenize(reference_summary))
summary = summarize_with_bertsum(sample_article, model, tokenizer, DEVICE, max_sents=num_sents)
print(f"\nGENERATED SUMMARY (BERTSum):\n{summary}")



--- Summarizing Sample Article ---
REFERENCE SUMMARY:
"I can confirm that BaFin has passed through the case to the public prosecutor," a BaFin spokeswoman said."We are disappointed that the BaFin has referred to the prosecutor the question of whether action should be brought against individuals involved," Citigroup said.Traders at US banking giant Citigroup are facing a criminal investigation in Germany over a controversial bond deal.Germany's financial watchdog BaFin told BBC News it had now transferred the investigation to the public prosecutor.However, under German criminal law, prosecutors cannot pursue Citigroup itself.The move was widely criticised at the time, and now the German regulator has said it has found evidence of possible market manipulation.

GENERATED SUMMARY (BERTSum):
Citigroup said it would continue to co-operate fully with the authorities. "I can confirm that BaFin has passed through the case to the public prosecutor," a BaFin spokeswoman said.
