In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
import torch

In [None]:
# Load only the first 200 examples of the CNN/DailyMail train split
# (dataset configuration "3.0.0").
dataset = load_dataset(
    path="cnn_dailymail",
    name="3.0.0",
    split="train[:200]",
)

def clean_text(example):
    """Lower-case and trim the article and its reference summary.

    Args:
        example: Mapping with string values under ``"article"`` and
            ``"highlights"``.

    Returns:
        The same ``example`` object; the two fields are overwritten in place
        and every other key is left untouched.
    """
    for field in ("article", "highlights"):
        example[field] = example[field].lower().strip()
    return example

# Apply clean_text to every record, then print the first record for inspection.
dataset = dataset.map(function=clean_text)
print(dataset[0])

In [None]:
# Tokenizer and model are loaded from one checkpoint name so the two
# cannot drift apart.
MODEL_CHECKPOINT = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_CHECKPOINT)

def preprocess_function(examples):
    """Tokenize a batch of articles and summaries for seq2seq training.

    Each article is prefixed with ``"summarize: "`` and tokenized to a fixed
    length of 512; each summary is tokenized to a fixed length of 150 and
    stored under ``"labels"``.

    Args:
        examples: Batch mapping with lists of strings under ``"article"`` and
            ``"highlights"`` (as passed by ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer output for the articles with an added ``"labels"`` entry
        (one list of token ids per example).
    """
    inputs = ["summarize: " + doc for doc in examples["article"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")
    labels = tokenizer(examples["highlights"], max_length=150, truncation=True, padding="max_length")

    # Labels are padded to a fixed length; replace pad ids with -100 so the
    # padding positions are ignored by the loss instead of being trained on.
    pad_id = tokenizer.pad_token_id
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in label_ids]
        for label_ids in labels["input_ids"]
    ]
    return model_inputs

# Tokenize the whole dataset in batches, then hold out 20% of it as a test split.
# The fixed seed keeps train/test membership the same across kernel restarts.
tokenized_dataset = dataset.map(preprocess_function, batched=True)
tokenized_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=42)

In [None]:
training_args = TrainingArguments(
    # Checkpoint location and retention.
    output_dir="./results",
    save_steps=500,
    save_total_limit=2,
    # Optimisation settings.
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    # Batch sizes and evaluation cadence.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    # NOTE(review): newer transformers releases appear to rename this argument
    # to `eval_strategy` - confirm against the installed version.
    evaluation_strategy="epoch",
)

# Train on the "train" split and evaluate on the held-out "test" split.
trainer = Trainer(
    args=training_args,
    model=model,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

trainer.train()

# Write the fine-tuned weights and the tokenizer files to the same directory.
model.save_pretrained("./fine_tuned_t5")
tokenizer.save_pretrained("./fine_tuned_t5")

In [None]:
from datasets import load_metric

# ROUGE metric object used to score the generated summaries.
# NOTE(review): `load_metric` is reported as deprecated in newer `datasets`
# releases in favour of the separate `evaluate` package - confirm against the
# installed version before upgrading.
rouge = load_metric(path="rouge")

def generate_summary(batch):
    """Generate a summary for one example and store it on the example.

    Intended for an unbatched ``Dataset.map`` call: ``batch["article"]`` is a
    single string and only the first generated sequence is decoded.

    Args:
        batch: Mapping with the article text under ``"article"``.

    Returns:
        The same ``batch`` object with the decoded summary added under
        ``"predicted_summary"``.
    """
    # Use the same task prefix that preprocess_function adds during training,
    # so inference sees the input format the model was fine-tuned on.
    inputs = tokenizer(
        "summarize: " + batch["article"],
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # The model may have been moved to an accelerator during training; the
    # input tensors must live on the same device as the model weights.
    device = model.device
    outputs = model.generate(
        inputs["input_ids"].to(device),
        attention_mask=inputs["attention_mask"].to(device),
        max_length=150,
        min_length=40,
        length_penalty=2.0,
        num_beams=4,
    )
    batch["predicted_summary"] = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return batch

# Generate a summary for every held-out example (one example per call).
results = tokenized_dataset["test"].map(function=generate_summary)

# Pair each generated summary with its reference highlights.
references = results["highlights"]
predictions = results["predicted_summary"]

rouge_scores = rouge.compute(
    references=references,
    predictions=predictions,
)
print(rouge_scores)