In [None]:
# %pip (not !pip) installs into the running kernel's environment.
# rouge_score is required by evaluate's "rouge" metric; accelerate is required by the Trainer.
# TODO: pin versions once a working set is confirmed, for reproducibility.
%pip install datasets transformers sentence-transformers evaluate bert-score rouge_score accelerate nltk


In [None]:
import torch
from datasets import load_dataset
from transformers import (
    BartTokenizer,
    BartForConditionalGeneration,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    DataCollatorForSeq2Seq
)
from sentence_transformers import SentenceTransformer, util
import nltk
from nltk.tokenize import sent_tokenize
import evaluate
import numpy as np

# sent_tokenize needs the Punkt sentence models. Older NLTK releases load them
# from the "punkt" resource, recent ones from "punkt_tab"; fetch both.
nltk.download("punkt")
nltk.download("punkt_tab")


In [None]:
# CNN/DailyMail (config "3.0.0"): each example has an "article" and its
# reference "highlights", used below as the summarisation target.
dataset = load_dataset("cnn_dailymail", "3.0.0")
print(dataset)

# Number of examples in each split.
train_size, val_size, test_size = (
    len(dataset[split_name]) for split_name in ("train", "validation", "test")
)

print("Train:", train_size, "Val:", val_size, "Test:", test_size)


In [None]:
# Sentence-embedding model (MiniLM) used to score sentences in the extractive stage.
bert_model = SentenceTransformer("all-MiniLM-L6-v2")

def extractive_stage(article, top_k=5):
    """Reduce an article to its ``top_k`` most central sentences.

    Each sentence is embedded with ``bert_model`` and scored by cosine
    similarity to the mean of all sentence embeddings. The ``top_k``
    highest-scoring sentences are kept in their original document order and
    joined with single spaces.

    Args:
        article: raw article text.
        top_k: number of sentences to keep.

    Returns:
        The reduced article. Articles with ``top_k`` or fewer sentences
        (including empty text) are returned unchanged.
    """
    sentences = sent_tokenize(article)
    if len(sentences) <= top_k:
        return article

    embeddings = bert_model.encode(sentences, convert_to_tensor=True)
    doc_embedding = torch.mean(embeddings, dim=0, keepdim=True)
    scores = util.cos_sim(doc_embedding, embeddings)[0]

    # topk returns indices ordered by score. Sort them so the reduced article
    # keeps the original sentence order, which the abstractive model then reads.
    # .tolist() gives plain ints (and moves off the GPU if needed).
    top_indices = sorted(torch.topk(scores, k=top_k).indices.tolist())
    return " ".join(sentences[i] for i in top_indices)


In [None]:
# Batch-wise extractive preprocessing → save results as JSON Lines.
import json
from tqdm.auto import tqdm

batch_size = 500  # rows sliced from a split per iteration
top_k = 5         # sentences kept per article by extractive_stage

def process_and_save(split, filename, chunk_size=batch_size, k=top_k):
    """Run ``extractive_stage`` over every example of ``split`` and write JSONL.

    Each output line is ``{"article": <reduced article>, "highlights": <reference>}``.

    Args:
        split: a ``datasets.Dataset`` (or any object where ``split[i:j]`` returns a
            dict of column lists) with "article" and "highlights" columns.
        filename: path of the .jsonl file to create/overwrite.
        chunk_size: rows sliced per iteration. The default is bound when the
            function is defined, so later rebinding of ``batch_size`` has no effect.
        k: number of sentences kept per article.
    """
    with open(filename, "w", encoding="utf-8") as f:
        for start in tqdm(range(0, len(split), chunk_size), desc=filename):
            # Slicing a Dataset returns {column: [values...]}, not a list of rows,
            # so walk the two needed columns in lockstep.
            chunk = split[start:start + chunk_size]
            for article, highlights in zip(chunk["article"], chunk["highlights"]):
                record = {
                    "article": extractive_stage(article, top_k=k),
                    "highlights": highlights,
                }
                f.write(json.dumps(record) + "\n")

# Run for all splits
process_and_save(dataset["train"], "train_extractive.jsonl")
process_and_save(dataset["validation"], "val_extractive.jsonl")
process_and_save(dataset["test"], "test_extractive.jsonl")

In [None]:
# Load BART tokenizer & model
tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

max_input_length = 1024
max_target_length = 256


def load_extractive_jsonl(path):
    """Read a JSONL file of {"article", "highlights"} records into column lists.

    Returns ``{"article": [...], "highlights": [...]}``, the column layout
    ``tokenize_batch`` consumes. Blank lines are skipped.
    """
    import json

    columns = {"article": [], "highlights": []}
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            columns["article"].append(record["article"])
            columns["highlights"].append(record["highlights"])
    return columns


def tokenize_batch(batch):
    """Tokenize articles as model inputs and highlights as labels.

    Args:
        batch: mapping with "article" and "highlights" sequences of strings.

    Returns:
        The tokenizer encoding of the articles (``input_ids``, ``attention_mask``)
        plus ``labels`` holding the highlight token ids. Sequences are truncated
        but not padded. DataCollatorForSeq2Seq pads each batch and fills label
        padding with -100, which the loss ignores. Padding labels here with
        pad_token_id would make padding count as target tokens.
    """
    # list(...) accepts any sequence of strings (e.g. a lazy dataset column).
    inputs = tokenizer(list(batch["article"]), max_length=max_input_length, truncation=True)
    labels = tokenizer(text_target=list(batch["highlights"]), max_length=max_target_length, truncation=True)
    inputs["labels"] = labels["input_ids"]
    return inputs


# Load the reduced articles written by the extractive preprocessing step.
train_processed = load_extractive_jsonl("train_extractive.jsonl")
val_processed = load_extractive_jsonl("val_extractive.jsonl")
test_processed = load_extractive_jsonl("test_extractive.jsonl")

train_tokenized = tokenize_batch(train_processed)
val_tokenized = tokenize_batch(val_processed)
test_tokenized = tokenize_batch(test_processed)


In [None]:
from datasets import Dataset

# Wrap each tokenized split (a mapping of column name -> per-example values)
# as an in-memory Dataset for the trainer.
train_dataset, val_dataset, test_dataset = (
    Dataset.from_dict(tokenized_split)
    for tokenized_split in (train_tokenized, val_tokenized, test_tokenized)
)


In [None]:
batch_size = 4  # per-device train/eval batch size; increase if GPU memory allows

training_args = Seq2SeqTrainingArguments(
    output_dir="./hybrid_bert_bart",
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword,
    # which recent transformers releases no longer accept.
    eval_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=2,   # keep at most two checkpoints on disk
    num_train_epochs=1,   # increase to 3-5 if compute allows
    predict_with_generate=True,  # evaluate on generated summaries so compute_metrics can score text
    fp16=torch.cuda.is_available(),  # mixed precision only when a CUDA GPU is present
    logging_dir="./logs",
    logging_steps=500,
    report_to="none",
)


In [None]:
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

def compute_metrics(eval_pred):
    """Score generated summaries against the references with ROUGE and BERTScore.

    Args:
        eval_pred: ``(predictions, label_ids)`` as passed by the trainer; with
            ``predict_with_generate=True`` the predictions are generated token
            ids. If predictions arrive as a tuple, its first element is used.

    Returns:
        dict with ``rouge1``, ``rouge2``, ``rougeL``, ``rougeLsum`` and
        ``bertscore_f1`` (mean BERTScore F1 over examples).
    """
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 marks ignored positions (label padding; predictions of different
    # lengths may be padded with it too). It is not a valid token id, so map
    # it to the pad id before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # ROUGE-Lsum splits summaries on newlines, so put one sentence per line.
    # The other ROUGE variants tokenize on whitespace and are unaffected.
    preds_by_line = ["\n".join(sent_tokenize(text.strip())) for text in decoded_preds]
    refs_by_line = ["\n".join(sent_tokenize(text.strip())) for text in decoded_labels]

    rouge_result = rouge.compute(predictions=preds_by_line, references=refs_by_line)
    bertscore_result = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")

    return {
        "rouge1": rouge_result["rouge1"],
        "rouge2": rouge_result["rouge2"],
        "rougeL": rouge_result["rougeL"],
        "rougeLsum": rouge_result["rougeLsum"],
        "bertscore_f1": float(np.mean(bertscore_result["f1"])),
    }


In [None]:
# Pads each batch dynamically; label padding uses -100 (the collator's default
# label_pad_token_id), so it is excluded from the loss.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,  # replaces the deprecated `tokenizer=` argument
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Fine-tune on the extractively reduced articles. Without this call, any later
# evaluation only measures the untouched pretrained checkpoint.
train_result = trainer.train()


In [None]:
# Score the trainer's current model on the held-out test split. The "test"
# prefix labels the metrics test_rouge1, test_bertscore_f1, ... instead of the
# default "eval_" prefix used for validation.
test_results = trainer.evaluate(eval_dataset=test_dataset, metric_key_prefix="test")
print(test_results)