# Load Model

## Base Model

### Tokenizer

In [None]:
# Tokenizer from the stock "t5-small" hub checkpoint.
# NOTE(review): the model is loaded separately from the local "T5_Small_pruned"
# directory; this assumes that checkpoint shares the t5-small vocabulary.
from transformers import AutoTokenizer

TK_ckpt = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=TK_ckpt)

### Load T5-Model/Checkpoint

In [None]:
from transformers import AutoModelForSeq2SeqLM

# Local directory that the save cell at the end of training writes to.
checkpoint = "T5_Small_pruned"
model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=checkpoint)

### Restore the pruned model structure (pruning masks) and weights

In [None]:
import os

import torch
import torch.nn.utils.prune as prune

# Re-create the pruning re-parametrization on every Linear layer so the module
# layout matches the saved state dict: prune.identity replaces `weight` with a
# `weight_orig` parameter plus an all-ones `weight_mask` buffer, which is the
# layout a pruned model's state_dict() contains.
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        prune.identity(module, 'weight')

# Load the saved weights *and* pruning masks. The state dict is written after
# fp16 (GPU) training, so its tensors may be CUDA tensors; map_location='cpu'
# lets it load on a CPU-only machine too. load_state_dict copies the values
# into the model's existing parameters/buffers, so their device is unchanged.
model.load_state_dict(
    torch.load(os.path.join(checkpoint, 'pruneModel.pth'), map_location='cpu')
)

# Model Pruning

### Get the fraction of remaining (unpruned) parameters

In [None]:
def show_param_ratio(model):
    """Print the fraction of parameters that are *not* pruned away.

    Counts every parameter element of ``model`` and every zero entry in the
    buffers whose name contains ``"mask"`` (e.g. the ``weight_mask`` buffers
    created by ``torch.nn.utils.prune``). It then prints
    ``(num_param - num_masked) / num_param`` as a plain Python float.

    Args:
        model: a ``torch.nn.Module``, possibly carrying pruning masks.
    """
    num_param = sum(param.numel() for param in model.parameters())
    # int(...) turns each 0-dim count tensor into a Python int, so the ratio
    # prints as a float (e.g. 0.29) rather than a float32 tensor repr.
    num_mask = sum(
        int((buf == 0).sum())
        for name, buf in model.named_buffers()
        if "mask" in name
    )
    print((num_param - num_mask) / num_param)

show_param_ratio(model=model)  # remaining (unpruned) fraction after restoring masks

### Prune the model

In [None]:
import torch
import torch.nn.utils.prune

# Randomly zero out a 0.71 fraction of the Linear weights. The selection is
# drawn over all Linear layers' weights together (global_unstructured), not
# layer by layer.
# NOTE(review): if the masks already contain zeros (e.g. the checkpoint loaded
# above was pruned before), the amount is applied on top of the existing
# masks, so re-running this cell compounds the sparsity. Confirm that is
# intended.
pruning_method = torch.nn.utils.prune.RandomUnstructured
pruning_rate = 0.71

parameters_to_prune = [
    (layer, "weight")
    for layer in model.modules()
    if isinstance(layer, torch.nn.Linear)
]
torch.nn.utils.prune.global_unstructured(
    parameters_to_prune,
    pruning_method=pruning_method,
    amount=pruning_rate,
)
show_param_ratio(model)

# Training

### Preprocess the training dataset (add the "summarize: " input prefix)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of BillSum examples for T5 summarization.

    Each source document gets the ``"summarize: "`` task prefix. Sources are
    truncated to 1024 tokens and target summaries to 128 tokens.

    Args:
        examples: a batched mapping with ``"text"`` and ``"summary"`` columns,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        The tokenizer encoding of the prefixed documents, with the summary
        token ids stored under ``"labels"``.
    """
    prefixed_docs = ["summarize: " + document for document in examples["text"]]
    encoded = tokenizer(prefixed_docs, max_length=1024, truncation=True)

    target_encoding = tokenizer(
        text_target=examples["summary"], max_length=128, truncation=True
    )
    encoded["labels"] = target_encoding["input_ids"]
    return encoded

### Data collator

In [None]:
from transformers import DataCollatorForSeq2Seq

# Pads inputs and labels per batch. `model` must be the model object, not the
# checkpoint path: the collator builds `decoder_input_ids` only when its
# `model` has a `prepare_decoder_input_ids_from_labels` method, and a plain
# string never has one, so that step was silently skipped.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

### Compute metrics

In [None]:
import numpy as np
import evaluate

rouge = evaluate.load("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean generated length for evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` from the trainer. With
            ``predict_with_generate=True`` the predictions are generated
            token ids.

    Returns:
        dict mapping each ROUGE key, plus ``"gen_len"``, to a value rounded
        to 4 decimals.
    """
    predictions, labels = eval_pred
    # Some models return a tuple of outputs; the token ids come first.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is the padding/ignore value (labels from the collator, predictions
    # padded across batches by the trainer). It is not a valid token id and
    # breaks batch_decode, so map it back to the pad token before decoding.
    # Doing this for predictions also keeps -100 out of the gen_len count.
    predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Generated length = number of non-pad tokens per prediction.
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

### Load training and evaluation datasets

In [None]:
from datasets import load_dataset

# Hold out 20% of BillSum's "train" split for evaluation during training; the
# separate "test" split is loaded later for the final predictions.
# seed=42 (matching Seq2SeqTrainingArguments.seed) makes the hold-out
# reproducible across runs. Without it, re-running the notebook on a reloaded,
# already-trained checkpoint would evaluate on examples it was trained on.
billsum = load_dataset("billsum", split="train")
billsum = billsum.train_test_split(test_size=0.2, seed=42)
tokenized_billsum = billsum.map(preprocess_function, batched=True)

### Training config

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Output location and reproducibility.
    output_dir="T5_Small_pruned",
    seed=42,
    # Optimization schedule.
    num_train_epochs=10,
    learning_rate=3e-5,
    lr_scheduler_type="cosine",
    warmup_steps=2000,
    weight_decay=0.01,
    fp16=True,
    # Batch sizes.
    per_device_train_batch_size=10,
    per_device_eval_batch_size=10,
    # Logging / evaluation / checkpoint cadence, in training steps.
    # save_total_limit is left unset, so no checkpoint is deleted.
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=1000,
    save_steps=1000,
    load_best_model_at_end=True,
    # Use generate() during evaluation so compute_metrics gets token ids to score.
    predict_with_generate=True,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    tokenizer=tokenizer,
    train_dataset=tokenized_billsum["train"],
    # The 20% hold-out from train_test_split, not BillSum's own test split.
    eval_dataset=tokenized_billsum["test"],
    compute_metrics=compute_metrics,
)

### Training

In [None]:
trainer.train(resume_from_checkpoint=None)  # fresh run; no checkpoint resume

### Evaluate

In [None]:
trainer.evaluate(eval_dataset=tokenized_billsum["test"])  # score the 20% hold-out

# Save pruned model after training

In [None]:
import os

import torch

path = 'T5_Small_pruned'
# Hugging Face format (config + weights). `from_pt` is a from_pretrained
# option, not a save_pretrained one, so it is not passed here.
model.save_pretrained(path)
# Also save the raw PyTorch state dict. It keeps the pruning
# re-parametrization (`weight_orig` parameters + `weight_mask` buffers), which
# the loading cell at the top of the notebook restores with prune.identity and
# load_state_dict.
torch.save(model.state_dict(), os.path.join(path, 'pruneModel.pth'))

# Dump prediction

### Load test dataset

In [None]:
from datasets import load_dataset

# BillSum's own test split, tokenized the same way as the training data.
billsum_test = load_dataset(path="billsum", split="test")
tokenized_billsum_test = billsum_test.map(function=preprocess_function, batched=True)

### Predict test dataset

In [None]:
results = trainer.predict(test_dataset=tokenized_billsum_test)  # generated ids + metrics

### Decode the predicted token ids into text

In [None]:
import numpy as np

# results.predictions holds the generated token ids (predict_with_generate).
# Rows gathered from different batches may be padded with -100, which
# batch_decode cannot convert, so map it back to the pad token first.
predicted_ids = np.where(results.predictions != -100, results.predictions, tokenizer.pad_token_id)
decoded_prediction = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)

### Collect all predictions into a table and print it

In [None]:
import pandas as pd
import csv

# One row per test example: its position as "ID" and its predicted summary.
# The frame is built in one step; concatenating a row per iteration was
# quadratic in the number of predictions.
# Commas in the prediction text are replaced by periods.
# NOTE(review): the export uses QUOTE_ALL, so commas would already be safe
# inside quoted fields. This looks like a workaround for whatever consumes
# result.csv; confirm before removing it.
df_results = pd.DataFrame(
    {
        "ID": list(range(len(decoded_prediction))),
        "Predict": [prediction.replace(",", ".") for prediction in decoded_prediction],
    }
)

print(df_results)

### Normalize special characters in predictions before CSV export

In [None]:
def escape_special_characters(text):
    return text.replace('"', '""').replace('\n', ' ')

df_results['Predict'] = df_results['Predict'].map(escape_special_characters)  # per-row cleanup

### Dump prediction

In [None]:
# Every field is quoted (QUOTE_ALL); the row index is not written.
df_results.to_csv(path_or_buf='result.csv', index=False, quoting=csv.QUOTE_ALL, encoding='utf-8')