In [30]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, AutoModelForSeq2SeqLM
import evaluate
import numpy as np

In [36]:
import wandb


# Open a wandb run for this experiment and attach the hyperparameters
# to it so they are stored alongside the logged metrics.
wandb.init(
    project="briefly-v1",
    name="briefly-trialsv1",
    config=dict(learning_rate=2e-5, epochs=4, batch_size=16),
)

In [7]:
# Load only the "ca_test" split of the "billsum" dataset.
billsum = load_dataset("billsum", split="ca_test")

In [8]:
# Hold out 25% of the examples for evaluation. The seed is fixed so the
# train/test membership (and therefore every reported metric) is the same
# on each run instead of changing with every re-execution of this cell.
billsum = billsum.train_test_split(test_size=0.25, seed=42)

In [11]:
# Display one training example (index 2) to inspect the record's fields.
billsum['train'][2]

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nSection 1250.8 of the Health and Safety Code is amended to read:\n1250.8.\n(a) Notwithstanding subdivision (a) of Section 127170, the department, upon application of a general acute care hospital that meets all the criteria of subdivision (b), and other applicable requirements of licensure, shall issue a single consolidated license to a general acute care hospital that includes more than one physical plant maintained and operated on separate premises or that has multiple licenses for a single health facility on the same premises. A single consolidated license shall not be issued where the separate freestanding physical plant is a skilled nursing facility or an intermediate care facility, whether or not the location of the skilled nursing facility or intermediate care facility is contiguous to the general acute care hospital unless the hospital is exempt from the requirements of subdivision (b) of Sect

In [20]:
# Checkpoint identifier used when loading the pretrained tokenizer below.
model_checkpoint = "google-t5/t5-small"
# Tokenizer that belongs to the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [14]:
# Task prefix prepended to every input document. T5 checkpoints were trained
# with the lower-case "summarize: " prefix; a capitalised variant tokenizes
# differently and does not match the pretrained summarization task prompt.
prefix_prompt = "summarize: "

In [16]:
def preprocess_text(txt):
    """Tokenize a batch of documents and their reference summaries.

    Args:
        txt: A batch with a "text" sequence (source documents) and a
            "summary" sequence (target summaries).

    Returns:
        The tokenizer output for the prefixed documents, with the target
        token ids stored under the "labels" key.
    """
    # Put the task prefix in front of every source document.
    prompted_docs = []
    for document in txt["text"]:
        prompted_docs.append(prefix_prompt + document)

    # Sources are truncated to 1024 tokens, targets to 128 tokens.
    model_inputs = tokenizer(prompted_docs, max_length=1024, truncation=True)
    target_encoding = tokenizer(
        text_target=txt["summary"],
        max_length=128,
        truncation=True,
    )

    model_inputs["labels"] = target_encoding["input_ids"]
    return model_inputs

In [17]:
# Run preprocess_text over the dataset in batches (batched=True), so the
# function receives lists of documents/summaries rather than single rows.
billsum_tokeniz = billsum.map(preprocess_text, batched=True)

Map: 100%|██████████| 927/927 [00:00<00:00, 989.75 examples/s] 
Map: 100%|██████████| 310/310 [00:00<00:00, 947.58 examples/s] 


In [21]:
# Collator handed to the trainer to assemble batches.
# NOTE(review): `model` receives model_checkpoint, which is the checkpoint
# name string, not a loaded model object. The collator presumably expects the
# model instance here; confirm whether this argument has any effect as written.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model_checkpoint)

In [24]:
# Load the "rouge" metric; compute_metrics calls rouge.compute on decoded text.
rouge = evaluate.load("rouge")
# Bare expression so the notebook displays the metric's description.
rouge

EvaluationModule(name: "rouge", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id=None)}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/

In [26]:
def compute_metrics(pred):
    """Decode generated and reference token ids and score them with ROUGE.

    Args:
        pred: A ``(predictions, label_ids)`` pair. Positions equal to -100
            are treated as "ignore" markers and replaced with the
            tokenizer's padding id before decoding.

    Returns:
        A dict with every entry returned by ``rouge.compute`` plus
        ``"gen_len"`` (mean count of non-padding tokens per prediction),
        each rounded to 4 decimal places.
    """
    preds, labels = pred

    # Predictions may arrive wrapped in a tuple; the token ids come first.
    if isinstance(preds, tuple):
        preds = preds[0]

    # Use the padding *token* id. `pad_token_type_id` is the segment/type id
    # used for padding and only matches the pad token id by coincidence.
    pad_id = tokenizer.pad_token_id

    # -100 is not a valid vocabulary id, so swap it for padding before
    # decoding; this applies to predictions as well as labels.
    preds = np.where(preds != -100, preds, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decode_pred = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    rouge_metric = rouge.compute(
        predictions=decode_pred,
        references=decode_labels,
        use_stemmer=True,
    )

    # Generated length = number of non-padding tokens in each prediction.
    prediction_lens = [np.count_nonzero(seq != pad_id) for seq in preds]
    rouge_metric["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in rouge_metric.items()}

In [31]:
# Load the pretrained seq2seq model from the checkpoint named in model_checkpoint.
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [32]:
# Training configuration, grouped by concern.
training_args = Seq2SeqTrainingArguments(
    # Where outputs go and how many checkpoints to keep.
    output_dir="text_summarizerv1",
    save_total_limit=3,
    # Optimisation settings.
    num_train_epochs=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    # Evaluate once per epoch, using generate() to produce predictions.
    eval_strategy="epoch",
    predict_with_generate=True,
    # Send logs to wandb.
    report_to="wandb",
)

In [33]:
# Wire the model, data, tokenizer, collator and metric function into a trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    data_collator=data_collator,
    train_dataset=billsum_tokeniz["train"],
    eval_dataset=billsum_tokeniz["test"],
    compute_metrics=compute_metrics,
)

In [34]:
# Start fine-tuning. NOTE: the recorded output for this cell shows the run
# was interrupted (KeyboardInterrupt) at step 21 of 232, so it did not finish.
trainer.train()

  0%|          | 0/232 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
  9%|▉         | 21/232 [03:18<56:18, 16.01s/it]  

KeyboardInterrupt: 