In [1]:
import os
import sys
sys.path.append("../../../")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment configuration: the dataset/model key used by ModelConfig and
# load_data, and the compute device (first CUDA GPU if present, else CPU).
name = "OSDG"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

# load_model returns a 3-tuple unpacked into the model, its tokenizer and a
# checkpoint object; this unpacking is what defines `checkpoint` (a separate
# `checkpoint = None` beforehand would be dead, since it is overwritten here).
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [4]:
# Build train/valid/test loaders for the configured dataset.
# The worker count is capped at the number of CPUs on this machine: PyTorch's
# DataLoader warns when num_workers exceeds the available cores, and extra
# workers only oversubscribe the CPU. On machines with >= 48 cores this is
# identical to the previous hard-coded 48.
# NOTE(review): assumes load_data forwards num_workers to torch DataLoader — confirm.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=32, num_workers=min(48, os.cpu_count() or 1)
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded


In [5]:
color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
# NOTE(review): num_samples is not referenced anywhere in this notebook; the
# SamplingDataset call below passes literal positional values (200, 20, ...).
# Confirm whether one of those was meant to be num_samples.
num_samples = 64

# Seed torch's global RNG immediately before drawing calibration samples so
# that re-running this cell, or Restart & Run All, reproduces the same samples
# when the randomness comes from torch (e.g. a shuffling DataLoader draws its
# seed from torch's default generator).
# TODO(review): if SamplingDataset also uses Python's `random` or NumPy, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Calibration samples drawn from the training split; passed to prune_wanda later.
all_samples = SamplingDataset(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

Start Time:21:05:30


In [6]:
# Baseline: score the (not yet pruned) model on the test split.
print("Evaluate the original model")
result = evaluate_model(
    model,
    model_config,
    test_dataloader,
)

Evaluate the original model


Evaluating: 100%|██████████| 200/200 [02:56<00:00,  1.13it/s]


Loss: 0.9485
Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793
              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.61      0.54       473
           8       0.66      0.85      0.74       746
           9       0.62      0.73      0.67       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

In [None]:
# Wanda-prune the model using the calibration samples, evaluate it on the
# test split, and save it. The output filename is derived from the same
# SPARSITY_RATIO used for pruning, so the "<N>p" suffix cannot drift out of
# sync with the ratio (0.4 -> "wanda_osdg_40p.pt").
# NOTE: prune_wanda's return value is unused, so it is expected to modify
# `model` in place. After this cell `model` is the pruned network, and
# re-running the cell would presumably prune the already-pruned weights again.
SPARSITY_RATIO = 0.4

print("Evaluate the pruned model")
prune_wanda(model, all_samples, sparsity_ratio=SPARSITY_RATIO)

# Bind to a new name so the baseline metrics in `result` are not overwritten.
pruned_result = evaluate_model(model, model_config, test_dataloader)

save_module(
    model,
    "Modules/",
    f"wanda_{name.lower()}_{round(SPARSITY_RATIO * 100)}p.pt",
)

Evaluate the pruned model


Evaluating:  16%|█▌        | 31/200 [00:27<02:47,  1.01it/s]

In [None]:
# Measure the sparsity of `model` as left by the pruning cell above; the
# returned value (if any) is shown as this cell's output.
model_sparsity = get_sparsity(model)
model_sparsity