In [None]:
!pip install -q transformers datasets evaluate accelerate sentencepiece

In [None]:
import numpy as np
import torch
from datasets import load_dataset
import evaluate
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
)

In [None]:
dataset_name = "cnn_dailymail"
dataset_config = "3.0.0"

raw_datasets = load_dataset(dataset_name, dataset_config)
print(raw_datasets)

# Columns: "article" (input), "highlights" (summary)
text_column = "article"
summary_column = "highlights"

# ⚡ To keep training fast in Colab, we take a small subset.
# Set either size to None to use the full split.
train_subset = 2000
eval_subset = 1000


def take_first(split, n):
    """Return the first ``n`` rows of a dataset split.

    Args:
        split: a ``datasets.Dataset`` (anything with ``len`` and ``select``).
        n: number of leading rows to keep; ``None`` (or a value >= the split
            length) keeps the whole split.

    Returns:
        The (possibly shortened) split.
    """
    if n is None or n >= len(split):
        return split
    return split.select(range(n))


raw_datasets["train"] = take_first(raw_datasets["train"], train_subset)
raw_datasets["validation"] = take_first(raw_datasets["validation"], eval_subset)

In [None]:
# Tokenizer and model must come from the same checkpoint so token ids agree.
# NOTE(review): this checkpoint appears to already be a CNN/DailyMail
# summarizer; confirm on its model card that further fine-tuning is intended.
model_name = "facebook/bart-large-cnn"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)


In [None]:
max_source_length = 512
max_target_length = 128


def preprocess_function(examples):
    """Tokenize a batch of articles and reference summaries for seq2seq training.

    Args:
        examples: a batched slice from ``Dataset.map(batched=True)`` mapping
            column names to lists; reads ``text_column`` (model inputs) and
            ``summary_column`` (targets).

    Returns:
        The tokenizer output for the inputs plus a ``labels`` entry holding
        the target token ids.

    Sequences are truncated but deliberately NOT padded here.
    ``DataCollatorForSeq2Seq`` pads each batch dynamically and fills label
    padding with -100, so the loss ignores it. Pre-padding labels to
    ``max_length`` would leave ``pad_token_id`` in the labels and train the
    model to predict padding.
    """
    model_inputs = tokenizer(
        examples[text_column],
        max_length=max_source_length,
        truncation=True,
    )
    # `text_target=` tokenizes the summaries in target mode (replacing the
    # older `as_target_tokenizer()` context manager).
    labels = tokenizer(
        text_target=examples[summary_column],
        max_length=max_target_length,
        truncation=True,
    )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize every split in batches and drop the original text columns so only
# the fields produced by preprocess_function remain.
# NOTE(review): the train split's column names are removed from all splits,
# which assumes every split shares the same schema.
tokenized_datasets = raw_datasets.map(
    function=preprocess_function,
    batched=True,
    desc="Tokenizing",
    remove_columns=raw_datasets["train"].column_names,
)

In [None]:
# Pads inputs/labels per batch; label padding uses -100 so the loss skips it.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)
rouge = evaluate.load("rouge")


def compute_metrics(eval_preds):
    """Decode generated and reference summaries and score them with ROUGE.

    Args:
        eval_preds: ``(predictions, label_ids)`` from ``Seq2SeqTrainer``. With
            ``predict_with_generate=True`` the predictions are generated token
            ids; they may arrive wrapped in a tuple.

    Returns:
        dict mapping each key returned by ``rouge.compute`` to its score
        scaled to a percentage and rounded to 2 decimals.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # When results from several eval batches are concatenated, shorter
    # sequences are padded with -100. That is not a valid token id and breaks
    # decoding, so map it back to the pad token for predictions AND labels.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = [
        text.strip() for text in tokenizer.batch_decode(preds, skip_special_tokens=True)
    ]
    decoded_labels = [
        text.strip() for text in tokenizer.batch_decode(labels, skip_special_tokens=True)
    ]

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    return {key: round(value * 100, 2) for key, value in result.items()}


In [None]:
batch_size = 1
# Effective batch size = batch_size * gradient_accumulation_steps (= 8 here).
training_args = Seq2SeqTrainingArguments(
    output_dir="./bart-cnn-finetune",
    # `evaluation_strategy` was renamed to `eval_strategy` (deprecated in
    # transformers 4.41 and slated for removal in 4.46); the unpinned install
    # pulls a release that expects the new name.
    eval_strategy="steps",
    eval_steps=500,
    save_steps=500,
    logging_steps=50,
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=3,   # ⚡ Lower to 1 for a quicker run on a small GPU
    predict_with_generate=True,
    # Generate up to the same length the labels were truncated to.
    generation_max_length=max_target_length,
    generation_num_beams=4,
    fp16=torch.cuda.is_available(),
    report_to="none"
)

In [None]:
# `processing_class` replaces the deprecated `tokenizer=` argument
# (transformers >= 4.46). The Trainer also saves it alongside checkpoints.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning with the Trainer configured above. The return value is left
# as the cell's last expression so the notebook displays it.
trainer.train()

In [None]:
# Save the final weights and tokenizer to the same directory the Trainer
# wrote checkpoints to, taken from the arguments so the path is defined once.
final_model_dir = training_args.output_dir
trainer.save_model(final_model_dir)
tokenizer.save_pretrained(final_model_dir)

In [None]:
# Qualitative check: summarize one validation article with the fine-tuned model.
# After trainer.train() the model may still be in training mode (no evaluation
# runs after the final step with these settings), and generate() does not
# switch modes itself. Without eval(), dropout stays active and the summary is
# noisy and non-deterministic.
model.eval()

sample_text = raw_datasets["validation"][0][text_column]
# Reuse the truncation / generation limits used during training and evaluation.
inputs = tokenizer(
    [sample_text], max_length=max_source_length, truncation=True, return_tensors="pt"
).to(model.device)
summary_ids = model.generate(**inputs, num_beams=4, max_length=max_target_length)
print("ARTICLE:", sample_text[:400], "...")
print("SUMMARY:", tokenizer.decode(summary_ids[0], skip_special_tokens=True))