In [4]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, AutoModelForSeq2SeqLM
import evaluate
import numpy as np

In [5]:
import wandb


wandb.init(
    project="briefly-v1",  
    name="briefly-trialsv1",    
    config={                   
        "learning_rate": 2e-5,
        "epochs": 4,
        "batch_size": 16
    }
)

In [7]:
billsum = load_dataset("billsum", split="ca_test")

In [8]:
billsum = billsum.train_test_split(test_size = 0.25)

In [9]:
billsum['train'][2]

{'text': 'The people of the State of California do enact as follows:\n\n\nSECTION 1.\nSection 51203 of the Government Code is amended to read:\n51203.\n(a) The assessor shall determine the current fair market value of the land as if it were free of the contractual restriction pursuant to Section 51283. The Department of Conservation or the landowner, also referred to in this section as “parties,” may provide information to assist the assessor to determine the value. Any information provided to the assessor shall be served on the other party, unless the information was provided at the request of the assessor, and would be confidential under law if required of an assessee.\n(b) Within 45 days of receiving the assessor’s notice pursuant to subdivision (a) of Section 51283 or Section 51283.4, if the Department of Conservation or the landowner believes that the current fair market valuation certified pursuant to subdivision (b) of Section 51283 or Section 51283.4 is not accurate, the depart

In [10]:
model_checkpoint = "google-t5/t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [11]:
prefix_prompt = "Summarize: "

In [12]:
def preprocess_text(txt):
    inputs = [prefix_prompt + doc for doc in txt["text"]]
    tokenix = tokenizer(inputs, max_length=1024, truncation=True)
    labels = tokenizer(text_target=txt["summary"], max_length=128, truncation=True)
    tokenix["labels"] = labels["input_ids"]
    return tokenix

In [13]:
billsum_tokeniz = billsum.map(preprocess_text, batched=True)

Map: 100%|██████████| 927/927 [00:00<00:00, 1032.30 examples/s]
Map: 100%|██████████| 310/310 [00:00<00:00, 970.33 examples/s] 


In [14]:
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model_checkpoint)

In [15]:
rouge = evaluate.load("rouge")
rouge

EvaluationModule(name: "rouge", module_type: "metric", features: [{'predictions': Value(dtype='string', id='sequence'), 'references': Sequence(feature=Value(dtype='string', id='sequence'), length=-1, id=None)}, {'predictions': Value(dtype='string', id='sequence'), 'references': Value(dtype='string', id='sequence')}], usage: """
Calculates average rouge scores for a list of hypotheses and references
Args:
    predictions: list of predictions to score. Each prediction
        should be a string with tokens separated by spaces.
    references: list of reference for each prediction. Each
        reference should be a string with tokens separated by spaces.
    rouge_types: A list of rouge types to calculate.
        Valid names:
        `"rouge{n}"` (e.g. `"rouge1"`, `"rouge2"`) where: {n} is the n-gram based scoring,
        `"rougeL"`: Longest common subsequence based scoring.
        `"rougeLsum"`: rougeLsum splits text using `"
"`.
        See details in https://github.com/huggingface/

In [16]:
def compute_metrics(pred):
    preds, labels = pred
    decode_pred = tokenizer.batch_decode(preds, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_type_id)
    decode_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    rouge_metric = rouge.compute(predictions=decode_pred, references=decode_labels, use_stemmer = True)
    prediction_lens = [np.count_nonzero(pred!=tokenizer.pad_token_type_id) for pred in preds]
    rouge_metric["gen_len"] = np.mean(prediction_lens)
    return {k: round(v,4) for k,v in rouge_metric.items()}

In [17]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

In [18]:
training_args = Seq2SeqTrainingArguments(
    output_dir="text_summarizerv1",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16, 
    per_device_eval_batch_size=16, 
    weight_decay=0.01,
    save_total_limit=3, 
    num_train_epochs=4, 
    predict_with_generate=True,
    report_to="wandb",
    fp16=True, 
)

In [19]:
trainer =Seq2SeqTrainer(
    model = model, 
    args= training_args,
    train_dataset=billsum_tokeniz["train"], 
    eval_dataset=billsum_tokeniz["test"],
    processing_class=tokenizer, 
    data_collator= data_collator, 
    compute_metrics=compute_metrics,
)

In [20]:
trainer.train()

  0%|          | 0/232 [00:00<?, ?it/s]Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
 25%|██▌       | 58/232 [00:55<02:32,  1.14it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 10%|█         | 2/20 [00:00<00:07,  2.55it/s]
 15%|█▌        | 3/20 [00:01<00:09,  1.78it/s]
 20%|██        | 4/20 [00:02<00:10,  1.57it/s]
 25%|██▌       | 5/20 [00:03<00:10,  1.47it/s]
 30%|███       | 6/20 [00:03<00:09,  1.41it/s]
 35%|███▌      | 7/20 [00:04<00:09,  1.37it/s]
 40%|████      | 8/20 [00:05<00:08,  1.35it/s]
 45%|████▌     | 9/20 [00:06<00:08,  1.34it/s]
 50%|█████     | 10/20 [00:06<00:07,  1.33it/s]
 55%|█████▌    | 11/20 [00:07<00:06,  1.32it/s]
 60%|██████    | 12/20 [00:08<00:06,  1.31it/s]
 65%|██████▌   | 13/20 [00:09<00:05,  1.32it/s]
 70%|███████   | 14/20 [00:09<00:04,  1.32it/s]
 75%|███████▌  | 15/

{'eval_loss': 2.8220036029815674, 'eval_rouge1': 0.1283, 'eval_rouge2': 0.0393, 'eval_rougeL': 0.1055, 'eval_rougeLsum': 0.1055, 'eval_gen_len': 20.0, 'eval_runtime': 16.0072, 'eval_samples_per_second': 19.366, 'eval_steps_per_second': 1.249, 'epoch': 1.0}


 25%|██▌       | 58/232 [01:11<02:32,  1.14it/s]
100%|██████████| 20/20 [00:15<00:00,  1.47it/s]
 50%|█████     | 116/232 [02:04<01:32,  1.25it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 10%|█         | 2/20 [00:00<00:06,  2.82it/s]
 15%|█▌        | 3/20 [00:01<00:08,  1.96it/s]
 20%|██        | 4/20 [00:02<00:09,  1.71it/s]
 25%|██▌       | 5/20 [00:02<00:09,  1.58it/s]
 30%|███       | 6/20 [00:03<00:09,  1.50it/s]
 35%|███▌      | 7/20 [00:04<00:08,  1.46it/s]
 40%|████      | 8/20 [00:05<00:08,  1.44it/s]
 45%|████▌     | 9/20 [00:05<00:07,  1.42it/s]
 50%|█████     | 10/20 [00:06<00:07,  1.41it/s]
 55%|█████▌    | 11/20 [00:07<00:06,  1.41it/s]
 60%|██████    | 12/20 [00:07<00:05,  1.40it/s]
 65%|██████▌   | 13/20 [00:08<00:04,  1.40it/s]
 70%|███████   | 14/20 [00:09<00:04,  1.40it/s]
 75%|███████▌  | 15/20 [00:10<00:03,  1.40it/s]
 80%|████████  | 16/20 [00:10<00:02,  1.40it/s]
 85%|████████▌ | 17/20 [00:11<00:02,  1.40it/s]
 90%|█████████ | 18/20 [00:12<00:01,  1.39it/s]
 95%|█

{'eval_loss': 2.5943057537078857, 'eval_rouge1': 0.1374, 'eval_rouge2': 0.0482, 'eval_rougeL': 0.1121, 'eval_rougeLsum': 0.1121, 'eval_gen_len': 20.0, 'eval_runtime': 14.8041, 'eval_samples_per_second': 20.94, 'eval_steps_per_second': 1.351, 'epoch': 2.0}


 50%|█████     | 116/232 [02:19<01:32,  1.25it/s]
100%|██████████| 20/20 [00:14<00:00,  1.64it/s]
 75%|███████▌  | 174/232 [03:09<00:47,  1.23it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 10%|█         | 2/20 [00:00<00:06,  2.80it/s]
 15%|█▌        | 3/20 [00:01<00:08,  1.97it/s]
 20%|██        | 4/20 [00:02<00:09,  1.72it/s]
 25%|██▌       | 5/20 [00:02<00:09,  1.59it/s]
 30%|███       | 6/20 [00:03<00:09,  1.52it/s]
 35%|███▌      | 7/20 [00:04<00:08,  1.47it/s]
 40%|████      | 8/20 [00:05<00:08,  1.45it/s]
 45%|████▌     | 9/20 [00:05<00:07,  1.42it/s]
 50%|█████     | 10/20 [00:06<00:07,  1.42it/s]
 55%|█████▌    | 11/20 [00:07<00:06,  1.41it/s]
 60%|██████    | 12/20 [00:07<00:05,  1.41it/s]
 65%|██████▌   | 13/20 [00:08<00:05,  1.40it/s]
 70%|███████   | 14/20 [00:09<00:04,  1.40it/s]
 75%|███████▌  | 15/20 [00:10<00:03,  1.40it/s]
 80%|████████  | 16/20 [00:10<00:02,  1.40it/s]
 85%|████████▌ | 17/20 [00:11<00:02,  1.41it/s]
 90%|█████████ | 18/20 [00:12<00:01,  1.40it/s]
 95%|

{'eval_loss': 2.5324058532714844, 'eval_rouge1': 0.1414, 'eval_rouge2': 0.0497, 'eval_rougeL': 0.1158, 'eval_rougeLsum': 0.1159, 'eval_gen_len': 20.0, 'eval_runtime': 14.7292, 'eval_samples_per_second': 21.047, 'eval_steps_per_second': 1.358, 'epoch': 3.0}


 75%|███████▌  | 174/232 [03:24<00:47,  1.23it/s]
100%|██████████| 20/20 [00:14<00:00,  1.65it/s]
100%|██████████| 232/232 [04:14<00:00,  1.25it/s]
  0%|          | 0/20 [00:00<?, ?it/s]
 10%|█         | 2/20 [00:00<00:06,  2.85it/s]
 15%|█▌        | 3/20 [00:01<00:08,  1.98it/s]
 20%|██        | 4/20 [00:02<00:09,  1.69it/s]
 25%|██▌       | 5/20 [00:02<00:09,  1.56it/s]
 30%|███       | 6/20 [00:03<00:09,  1.51it/s]
 35%|███▌      | 7/20 [00:04<00:08,  1.47it/s]
 40%|████      | 8/20 [00:05<00:08,  1.44it/s]
 45%|████▌     | 9/20 [00:05<00:07,  1.42it/s]
 50%|█████     | 10/20 [00:06<00:07,  1.41it/s]
 55%|█████▌    | 11/20 [00:07<00:06,  1.39it/s]
 60%|██████    | 12/20 [00:07<00:05,  1.39it/s]
 65%|██████▌   | 13/20 [00:08<00:05,  1.39it/s]
 70%|███████   | 14/20 [00:09<00:04,  1.39it/s]
 75%|███████▌  | 15/20 [00:10<00:03,  1.39it/s]
 80%|████████  | 16/20 [00:10<00:02,  1.40it/s]
 85%|████████▌ | 17/20 [00:11<00:02,  1.39it/s]
 90%|█████████ | 18/20 [00:12<00:01,  1.40it/s]
 95%|

{'eval_loss': 2.5171501636505127, 'eval_rouge1': 0.1434, 'eval_rouge2': 0.0516, 'eval_rougeL': 0.118, 'eval_rougeLsum': 0.1178, 'eval_gen_len': 20.0, 'eval_runtime': 14.8075, 'eval_samples_per_second': 20.935, 'eval_steps_per_second': 1.351, 'epoch': 4.0}


100%|██████████| 232/232 [04:29<00:00,  1.25it/s]
100%|██████████| 20/20 [00:14<00:00,  1.64it/s]
                                                 

{'train_runtime': 269.4984, 'train_samples_per_second': 13.759, 'train_steps_per_second': 0.861, 'train_loss': 3.109038780475485, 'epoch': 4.0}


100%|██████████| 232/232 [04:29<00:00,  1.16s/it]


TrainOutput(global_step=232, training_loss=3.109038780475485, metrics={'train_runtime': 269.4984, 'train_samples_per_second': 13.759, 'train_steps_per_second': 0.861, 'total_flos': 1003694799716352.0, 'train_loss': 3.109038780475485, 'epoch': 4.0})

In [21]:
wandb.finish()

0,1
eval/gen_len,▁▁▁▁
eval/loss,█▃▁▁
eval/rouge1,▁▅▇█
eval/rouge2,▁▆▇█
eval/rougeL,▁▅▇█
eval/rougeLsum,▁▅▇█
eval/runtime,█▁▁▁
eval/samples_per_second,▁███
eval/steps_per_second,▁███
train/epoch,▁▃▆██

0,1
eval/gen_len,20.0
eval/loss,2.51715
eval/rouge1,0.1434
eval/rouge2,0.0516
eval/rougeL,0.118
eval/rougeLsum,0.1178
eval/runtime,14.8075
eval/samples_per_second,20.935
eval/steps_per_second,1.351
total_flos,1003694799716352.0


In [22]:
model.save_pretrained('textsum-v1-model')
tokenizer.save_pretrained('tokenix-textsum-v1')

('tokenix-textsum-v1\\tokenizer_config.json',
 'tokenix-textsum-v1\\special_tokens_map.json',
 'tokenix-textsum-v1\\spiece.model',
 'tokenix-textsum-v1\\added_tokens.json',
 'tokenix-textsum-v1\\tokenizer.json')