In [None]:
# Cell 1 - Imports
import inspect

import numpy as np
import torch
from datasets import load_dataset
from evaluate import load
from transformers import (
    BertTokenizerFast,
    DataCollatorForSeq2Seq,
    EncoderDecoderModel,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)


In [None]:
# Cell 2 - Load CNN/DailyMail full dataset
# cnn_dailymail only exposes versioned configs ("1.0.0", "2.0.0", "3.0.0"); there is
# no "default" config, so asking for it fails. "3.0.0" is the version typically used
# for abstractive summarization (article -> highlights).
DATASET_NAME = "cnn_dailymail"
DATASET_CONFIG = "3.0.0"
dataset = load_dataset(DATASET_NAME, DATASET_CONFIG)

print("Train size:", len(dataset["train"]))
print("Validation size:", len(dataset["validation"]))
print("Test size:", len(dataset["test"]))


In [None]:
# Cell 3 - Load BERT2BERT model
SEED = 42
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)

# The decoder's cross-attention layers are not in the BERT checkpoint and are randomly
# initialized, so seed *before* building the model to make runs reproducible.
set_seed(SEED)

# BERT encoder + BERT decoder
model = EncoderDecoderModel.from_encoder_decoder_pretrained(model_name, model_name)

# BERT has no BOS/EOS tokens: decoding starts from [CLS] and stops at [SEP].
# model.config is still needed (the model uses pad/decoder_start ids to shift labels
# into decoder inputs during training).
model.config.decoder_start_token_id = tokenizer.cls_token_id
model.config.eos_token_id = tokenizer.sep_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.vocab_size = model.config.encoder.vocab_size

# Newer transformers versions take generation settings from model.generation_config
# rather than model.config; mirror the ids so generate() uses the same tokens.
if getattr(model, "generation_config", None) is not None:
    model.generation_config.decoder_start_token_id = tokenizer.cls_token_id
    model.generation_config.eos_token_id = tokenizer.sep_token_id
    model.generation_config.pad_token_id = tokenizer.pad_token_id


In [None]:
# Cell 4 - Preprocessing
max_input_length = 512   # bert-base's maximum sequence length
max_target_length = 128

def preprocess_function(examples):
    """Tokenize a batch of articles (encoder inputs) and highlights (labels).

    Args:
        examples: a batched slice of the dataset with "article" and "highlights" lists.

    Returns:
        The tokenizer output for the articles with an added "labels" field holding the
        token ids of the highlights.

    Sequences are truncated but NOT padded here. DataCollatorForSeq2Seq pads each batch
    dynamically and pads labels with -100, which the loss ignores. Padding labels with
    pad_token_id (what padding="max_length" did) makes the model learn to emit [PAD].
    """
    model_inputs = tokenizer(
        examples["article"], truncation=True, max_length=max_input_length
    )
    targets = tokenizer(
        examples["highlights"], truncation=True, max_length=max_target_length
    )
    model_inputs["labels"] = targets["input_ids"]
    return model_inputs

# Drop every raw column (article, highlights, id) so only model inputs remain.
tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
)


In [None]:
# Cell 5 - Training setup
batch_size = 4

# transformers 4.41 renamed `evaluation_strategy` to `eval_strategy`, and 4.46 removed
# the old name (passing it raises TypeError). Use whichever the installed version accepts.
_eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(Seq2SeqTrainingArguments).parameters
    else "evaluation_strategy"
)

args = Seq2SeqTrainingArguments(
    output_dir="./bert2bert_cnn",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    save_total_limit=2,
    predict_with_generate=True,
    # Without this, eval-time generation falls back to the model's generation
    # max_length (20 tokens by default), truncating summaries and deflating ROUGE.
    generation_max_length=max_target_length,
    fp16=torch.cuda.is_available(),
    logging_dir="./logs",
    logging_steps=500,
    report_to="none",
    # NOTE: per-epoch eval generates summaries for the whole validation split and runs
    # BERTScore on them, which is slow; consider subsampling validation if needed.
    **{_eval_strategy_key: "epoch"},
)


In [None]:
# Cell 6 - Metrics
rouge = load("rouge")
bertscore = load("bertscore")

def compute_metrics(eval_pred):
    """Compute ROUGE-1/2/L and mean BERTScore F1 for generated summaries.

    Args:
        eval_pred: (predictions, label_ids) from Seq2SeqTrainer with
            predict_with_generate=True; predictions are generated token ids.

    Returns:
        Dict of plain-float metrics: rouge1, rouge2, rougeL, bertscore_f1.
    """
    preds, labels = eval_pred
    # Some models/configs return a tuple; the generated ids are the first element.
    if isinstance(preds, tuple):
        preds = preds[0]

    # When the Trainer concatenates predictions from batches of different lengths it
    # pads them with -100; labels use -100 for ignored positions. Fast tokenizers
    # cannot decode negative ids (OverflowError), so map -100 back to [PAD] first.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)]
    decoded_labels = [text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)]

    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    bertscore_result = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")

    # Cast to built-in floats so the values log/serialize cleanly.
    return {
        "rouge1": float(rouge_result["rouge1"]),
        "rouge2": float(rouge_result["rouge2"]),
        "rougeL": float(rouge_result["rougeL"]),
        "bertscore_f1": float(np.mean(bertscore_result["f1"])),
    }


In [None]:
# Cell 7 - Trainer
# Batches tokenized examples; passing the model lets the collator prepare decoder
# inputs from the labels when the model supports it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

# Wire together model, arguments, data splits, batching, and metrics.
trainer = Seq2SeqTrainer(
    model=model,
    args=args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)


In [None]:
# Cell 8 - Train BERT2BERT
# Runs the fine-tuning loop configured by `args` on the train split, evaluating on the
# validation split as configured. This is the long-running cell of the notebook; the
# value returned by train() is shown as the cell output.
trainer.train()


In [None]:
# Cell 9 - Evaluate on test set
# metric_key_prefix="test" reports these as test_* rather than the default eval_*,
# so held-out test scores are not confused with the per-epoch validation metrics.
test_results = trainer.evaluate(
    eval_dataset=tokenized_datasets["test"],
    metric_key_prefix="test",
)
print(test_results)
