In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
import evaluate
import numpy as np
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import torch



In [None]:
## Load data. Training on the whole split was too slow on an M1 Mac, so this
## demo works on a small, fixed-size subset.
DEMO_SIZE = 100  # number of examples kept for the demo
SPLIT_SEED = 42  # fixed seed so the train/test split is identical on every run

billsum_full = load_dataset("billsum", split="ca_test")
# print billsum size
print("billsum size: ", len(billsum_full))

# only keep the first DEMO_SIZE examples (or all of them if the split is smaller)
billsum_subset = billsum_full.select(range(min(DEMO_SIZE, len(billsum_full))))

print("billsum size: ", len(billsum_subset))
# 80/20 train/test split; seeded so a Restart & Run All reproduces the same partition
billsum = billsum_subset.train_test_split(test_size=0.2, seed=SPLIT_SEED)
billsum["train"][0]

In [None]:
## Base checkpoint used throughout the notebook: t5-small
checkpoint = "t5-small"
# Tokenizer loaded from the same checkpoint name
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Task prefix prepended to every document, plus the batch preprocessing function
prefix = "summarize: "


def preprocess_function(examples):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        examples: mapping with a "text" entry (documents) and a "summary"
            entry (reference summaries), both iterables of strings.

    Returns:
        The tokenizer output for the prefixed documents (truncated to 512
        tokens) with an extra "labels" entry holding the token ids of the
        summaries (truncated to 128 tokens).
    """
    prefixed_docs = [prefix + document for document in examples["text"]]
    encoded = tokenizer(prefixed_docs, max_length=512, truncation=True)

    # Summaries are tokenized as targets and attached as the labels.
    target_encoding = tokenizer(
        text_target=examples["summary"],
        max_length=128,
        truncation=True,
    )
    encoded["labels"] = target_encoding["input_ids"]

    return encoded



In [None]:
# Tokenize both splits; batched=True hands preprocess_function whole batches at a time
tokenized_billsum = billsum.map(
    function=preprocess_function,
    batched=True,
)

In [None]:
# Collator that assembles seq2seq batches for the trainer.
# NOTE(review): `model` receives the checkpoint name (a string), not a loaded
# model object — confirm this is intended.
data_collator = DataCollatorForSeq2Seq(
    model=checkpoint,
    tokenizer=tokenizer,
)

In [None]:
# ROUGE is the metric used to score generated summaries during evaluation
rouge = evaluate.load(path="rouge")

In [None]:
# compute metrics
def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for an evaluation pass.

    Args:
        eval_pred: pair ``(predictions, labels)`` of token-id arrays. If
            ``predictions`` is a tuple, its first element is used.

    Returns:
        dict mapping each metric name returned by ``rouge.compute`` (plus
        ``"gen_len"``) to its value rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    pad_id = tokenizer.pad_token_id
    # -100 is an "ignore" marker, not a real token id. It must be swapped for
    # the pad id in BOTH arrays before decoding: a negative id makes
    # batch_decode fail, and it would also be miscounted as a generated token.
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Generated length = number of non-padding tokens in each prediction.
    prediction_lens = [np.count_nonzero(pred != pad_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [None]:
# Load the seq2seq model weights from the same checkpoint as the tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

In [None]:
# Pick the best available accelerator automatically instead of hand-editing
# this cell: a CUDA GPU if present, otherwise Apple-silicon MPS (e.g. an M1
# Mac), otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
model.to(device)
device

In [None]:
## Trainer setup: hyperparameters first, then the trainer that ties everything together
training_config = {
    "output_dir": "my_awesome_billsum_model",
    "evaluation_strategy": "epoch",  # evaluate once per epoch
    "learning_rate": 2e-5,
    "per_device_train_batch_size": 8,
    "per_device_eval_batch_size": 8,
    "weight_decay": 0.01,
    "save_total_limit": 3,
    "num_train_epochs": 4,
    "predict_with_generate": True,
    "fp16": False,
    # NOTE(review): pushing to the Hub presumably needs an authenticated
    # session, yet the login cell further down is commented out — confirm.
    "push_to_hub": True,
}
training_args = Seq2SeqTrainingArguments(**training_config)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
)

In [None]:
# start training
# Runs training with the `trainer` configured above. The call is kept as the
# cell's last expression so the notebook displays its return value.
trainer.train()

In [None]:
# Optional: log in and push the fine-tuned model to the Hugging Face Hub
# from huggingface_hub import notebook_login
# notebook_login()
# trainer.push_to_hub()

In [None]:
# Try the fine-tuned model: build a summarization pipeline and run it on a sample text
from transformers import pipeline

summarizer = pipeline(
    task="summarization",
    model="KRayRay/my_awesome_billsum_model",
)

text = "summarize: The Inflation Reduction Act lowers prescription drug costs, health care costs, and energy costs. It's the most aggressive action on tackling the climate crisis in American history, which will lift up American workers and create good-paying, union jobs across the country. It'll lower the deficit and ask the ultra-wealthy and corporations to pay their fair share. And no one making under $400,000 per year will pay a penny more in taxes."

summary = summarizer(text)
print(summary)