In [1]:
import os
import torch
import torch.nn as nn
import bitsandbytes as bnb
import pandas as pd


from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig



In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [3]:
from utils import setup_logger

logger = setup_logger(__name__)
logger.info('Logging Successfully set up')

2025-04-18 16:02:47,401 - __main__ - INFO - Logging Successfully set up


In [17]:
checkpoint = 'facebook/bart-large-cnn'
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    load_in_8bit=True,
    device_map='auto',
)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [5]:
for param in model.parameters():
    param.requires_grad = False
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)
model.gradient_checkpointing_enable()  # reduce number of stored activations
model.enable_input_require_grads()

class CastOutputToFloat(nn.Sequential):
    def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)

In [6]:
def print_trainable_parameters(model):
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f'trainable: {trainable_params}, all params: {all_param}, ratio: {trainable_params / all_param:.5f}'
    )


In [7]:
from peft import LoraConfig, get_peft_model

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias='none',
    # Valid task types: SEQ_CLS, SEQ_2_SEQ_LM, CAUSAL_LM, TOKEN_CLS, QUESTION_ANS, FEATURE_EXTRACTION
    task_type='SEQ_2_SEQ_LM'
)

model = get_peft_model(model, config)

print_trainable_parameters(model)

trainable: 2359296, all params: 408649728, ratio: 0.00577


In [8]:
from datasets import load_dataset

billsum = load_dataset("FiscalNote/billsum")

In [39]:
billsum["train"][0]

{'text': "SECTION 1. LIABILITY OF BUSINESS ENTITIES PROVIDING USE OF FACILITIES \n              TO NONPROFIT ORGANIZATIONS.\n\n    (a) Definitions.--In this section:\n            (1) Business entity.--The term ``business entity'' means a \n        firm, corporation, association, partnership, consortium, joint \n        venture, or other form of enterprise.\n            (2) Facility.--The term ``facility'' means any real \n        property, including any building, improvement, or appurtenance.\n            (3) Gross negligence.--The term ``gross negligence'' means \n        voluntary and conscious conduct by a person with knowledge (at \n        the time of the conduct) that the conduct is likely to be \n        harmful to the health or well-being of another person.\n            (4) Intentional misconduct.--The term ``intentional \n        misconduct'' means conduct by a person with knowledge (at the \n        time of the conduct) that the conduct is harmful to the health \n        or w

In [9]:
prefix = "summarize: "


def preprocess_function(examples):
    inputs = [prefix + doc for doc in examples["text"]]
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)

    labels = tokenizer(text_target=examples["summary"], max_length=128, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

tokenized_billsum = billsum.map(preprocess_function, batched=True)

Map:   0%|          | 0/3269 [00:00<?, ? examples/s]

In [12]:
import evaluate

rouge = evaluate.load("rouge")

In [13]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

In [18]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

seq2seq_args = Seq2SeqTrainingArguments(
    output_dir="../models/ragsum-bart-billsum",
    eval_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    warmup_steps=100,
    max_steps=200,
    fp16=True,
    logging_steps=8,
)

data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

trainer = Seq2SeqTrainer(
    model=model,
    args=seq2seq_args,
    train_dataset=tokenized_billsum['train'],
    eval_dataset=tokenized_billsum['test'],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)
trainer.train()


Epoch,Training Loss,Validation Loss,Rouge1,Rouge2,Rougel,Rougelsum,Gen Len
0,1.5528,1.495035,0.5526,0.3337,0.4163,0.4505,128.952




TrainOutput(global_step=200, training_loss=1.7837346076965332, metrics={'train_runtime': 1007.3298, 'train_samples_per_second': 3.177, 'train_steps_per_second': 0.199, 'total_flos': 6934734726758400.0, 'train_loss': 1.7837346076965332, 'epoch': 0.16877637130801687})

In [20]:
import os

save_path = os.path.join("..", "models", "ragsum-bart-billsum")
os.makedirs(save_path, exist_ok=True)

trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

('../models/ragsum-bart-billsum/tokenizer_config.json',
 '../models/ragsum-bart-billsum/special_tokens_map.json',
 '../models/ragsum-bart-billsum/vocab.json',
 '../models/ragsum-bart-billsum/merges.txt',
 '../models/ragsum-bart-billsum/added_tokens.json',
 '../models/ragsum-bart-billsum/tokenizer.json')

In [21]:
model = AutoModelForSeq2SeqLM.from_pretrained("../models/ragsum-bart-billsum")
tokenizer = AutoTokenizer.from_pretrained("../models/ragsum-bart-billsum")



In [22]:
text = """
summarize: The United States Congress has approved a sweeping infrastructure bill, marking one of the most significant legislative efforts in recent years. The bill, which totals $1.2 trillion in funding, aims to revamp aging infrastructure nationwide. Key areas of investment include transportation — such as roads, railways, and bridges — clean energy initiatives, and expanded broadband internet access. The legislation received bipartisan support in both the House and Senate, signaling rare political cooperation in an otherwise divided climate. Proponents argue that this investment will create jobs, stimulate the economy, and lay the groundwork for long-term national competitiveness.
"""

inputs = tokenizer(text, return_tensors="pt").input_ids
inputs

tensor([[    0, 50118, 18581,  3916,  2072,    35,    20,   315,   532,  1148,
            34,  2033,    10,  9893,  2112,  1087,     6, 10032,    65,     9,
             5,   144,  1233,  5615,  1170,    11,   485,   107,     4,    20,
          1087,     6,    61, 17582,    68,   134,     4,   176,  4700,    11,
          1435,     6,  5026,     7,  6910,  3914, 10662,  2112,  5807,     4,
          4300,   911,     9,   915,   680,  4264,    93,   215,    25,  3197,
             6, 24396,     6,     8, 11879,    93,  2382,  1007,  5287,     6,
             8,  4939, 11451,  2888,   899,     4,    20,  2309,   829, 10094,
           323,    11,   258,     5,   446,     8,  1112,     6, 22436,  3159,
           559,  4601,    11,    41,  3680,  6408,  2147,     4, 13695, 19245,
          5848,    14,    42,   915,    40,  1045,  1315,     6, 19770,     5,
           866,     6,     8,  4477,     5, 27615,    13,   251,    12,  1279,
           632, 17755,     4, 50118,     2]])

In [25]:
outputs = model.generate(inputs, max_new_tokens=100, do_sample=False)
summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

In [26]:
reference_summary = """
Congress passed a new bill aimed at improving infrastructure across the U.S., allocating $1.2 trillion in funding over the next ten years. The legislation focuses on roads, bridges, clean energy, and broadband access, with bipartisan support marking a significant political achievement.
"""

In [27]:
results = rouge.compute(predictions=[summary], references=[reference_summary], use_stemmer=True)

# Optional: round the results
results = {k: round(v, 4) for k, v in results.items()}

print(results)


{'rouge1': 0.413, 'rouge2': 0.1333, 'rougeL': 0.3043, 'rougeLsum': 0.3043}
