In [1]:
# source: https://huggingface.co/docs/transformers/tasks/summarization

In [10]:
import os
from huggingface_hub import notebook_login
# notebook_login()

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer
import numpy as np

# One-time setup if these packages are missing (%pip installs into this kernel's environment):
# %pip install evaluate rouge_score


import evaluate
# Load the ROUGE metric once; compute_metrics calls rouge.compute() on decoded text.
rouge = evaluate.load("rouge")


In [3]:
# Load the `ca_test` split of the BillSum dataset and show one raw record.
billsum = load_dataset("billsum", split="ca_test")
print(billsum[0])

# Hold out 20% of the records for evaluation. The fixed seed makes the
# train/test partition (and therefore every downstream metric) reproducible
# across kernel restarts instead of changing on every run.
billsum = billsum.train_test_split(test_size=0.2, seed=42)
print(billsum["train"][0])

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nThe Legislature finds and declares all of the following:\n(a) (1) Since 1899 congressionally chartered veterans’ organizations have provided a valuable service to our nation’s returning service members. These organizations help preserve the memories and incidents of the great hostilities fought by our nation, and preserve and strengthen comradeship among members.\n(2) These veterans’ organizations also own and manage various properties including lodges, posts, and fraternal halls. These properties act as a safe haven where veterans of all ages and their families can gather together to find camaraderie and fellowship, share stories, and seek support from people who understand their unique experiences. This aids in the healing process for these returning veterans, and ensures their health and happiness.\n(b) As a result of congressional chartering of these veterans’ organizations, the United States Inte

In [4]:
# One checkpoint name is used for both the tokenizer and the seq2seq model
# so the two always come from the same pretrained release.
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

In [5]:
# Task instruction prepended to every input document.
prefix = "summarize: "

def preprocess_function(examples):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        examples: A batch mapping with a "text" sequence (source documents)
            and a "summary" sequence (reference summaries).

    Returns:
        The tokenizer output for the prefixed documents (truncated to
        ``max_length=1024``) with an extra "labels" entry holding the
        "input_ids" of the summaries (truncated to ``max_length=128``).
    """
    prompts = [prefix + document for document in examples["text"]]
    encoded = tokenizer(prompts, max_length=1024, truncation=True)

    # Summaries are tokenized as targets; only their ids are kept as labels.
    targets = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)
    encoded["labels"] = targets["input_ids"]

    return encoded

In [6]:
# Apply the preprocessing to every split in batches, then show one
# processed training record to confirm the new token fields are present.
tokenized_billsum = billsum.map(
    function=preprocess_function,
    batched=True,
)
print(tokenized_billsum["train"][0])

Map:   0%|          | 0/989 [00:00<?, ? examples/s]

Map:   0%|          | 0/248 [00:00<?, ? examples/s]

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\n(a) The Legislature finds and declares as follows:\n(1) The waters of the state are of limited supply and are subject to ever-increasing demand.\n(2) Landscapes are essential to the quality of life in California by providing areas for active and passive recreation and as an enhancement to the environment by cleaning air and water, preventing erosion, offering fire protection, and replacing ecosystems lost to development, among other benefits.\n(3) Landscape design, installation, maintenance, and management can and should be water efficient.\n(4) Section 2 of Article X of the California Constitution specifies that the right to use water is limited to the amount reasonably required for the beneficial use to be served and that the right does not extend to the waste or unreasonable use of water.\n(5) Landscapes that are planned, designed, installed, managed, and maintained with a watershed-based approach 

In [7]:
# Batch collator built from the tokenizer and model loaded earlier in the notebook.
data_collator = DataCollatorForSeq2Seq(
    model=model,
    tokenizer=tokenizer,
)

In [9]:
def compute_metrics(eval_preds):
    """Score generated summaries against references with ROUGE.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays.
            ``predictions`` may also arrive wrapped in a tuple, in which case
            its first element is used.

    Returns:
        A dict with every score returned by ``rouge.compute`` plus
        ``"gen_len"`` (mean number of non-padding tokens per prediction),
        all rounded to 4 decimal places.
    """
    predictions, labels = eval_preds
    if isinstance(predictions, tuple):
        # Some model outputs carry extra tensors; the token ids come first.
        predictions = predictions[0]

    # -100 is the ignore/padding sentinel and is not a valid vocabulary id.
    # It can appear in the predictions as well as the labels (when batches
    # of different lengths are gathered), so both arrays must be mapped back
    # to the pad token before decoding.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Length is counted after the -100 replacement so padding never inflates it.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    return {k: round(v, 4) for k, v in result.items()}

In [12]:
# Fine-tuning hyperparameters; checkpoints are written under `output_dir`
# and evaluation runs at the end of every epoch.
training_args = Seq2SeqTrainingArguments(
    output_dir="my_awesome_billsum_model",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    # Evaluate by calling generate() so that compute_metrics receives token
    # ids it can decode. Without this flag the trainer hands it raw logits,
    # which tokenizer.batch_decode cannot turn into text for ROUGE.
    predict_with_generate=True,
    # NOTE(review): fp16 needs a CUDA device, and T5-family checkpoints are
    # commonly reported to overflow in fp16 -- consider bf16=True where the
    # hardware supports it; confirm before a long run.
    fp16=True,
)

# Wire the model, hyperparameters, tokenized splits, collator and metric
# function into a single trainer object.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=tokenized_billsum["train"],
    eval_dataset=tokenized_billsum["test"],
    compute_metrics=compute_metrics,
)

In [None]:
# trainer.train()