In [1]:
import os
import sys
# Make the directory four levels above the working directory importable
# (presumably the project root that holds the `utils` package imported below).
# The path is resolved to an absolute one so it stays valid if the working
# directory changes later, and it is only appended once so re-running this cell
# does not pile up duplicate entries.
project_root = os.path.abspath(os.path.join(os.getcwd(), "../../../../"))
if project_root not in sys.path:
    sys.path.append(project_root)

# Turn off parallelism in the HuggingFace `tokenizers` library; presumably this is
# here to avoid its fork warning when DataLoader worker processes are used.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration. Every value a reader might want to tune lives here;
# the cells below read these names directly, so the names must not change.

# Key handed to ModelConfig and load_data to select the model/dataset pair.
name = "OSDG"
# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook can still be executed (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder; overwritten by the value returned from load_model further down.
checkpoint = None
# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4
# Forwarded to SamplingDataset and similar as the sample count argument.
num_samples = 4
# Passed to prune_concern_identification as `sparsity_ratio`.
ci_ratio = 0.3
# Seed forwarded to SamplingDataset and similar.
seed = 44
# Layer filters forwarded to prune_concern_identification.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [None]:
# Capture the wall-clock time at which the run begins and echo it in a
# human-readable "YYYY-MM-DD HH:MM:SS" form.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {start_stamp}")

In [4]:
# Build the model configuration for the selected experiment and target device.
model_config = ModelConfig(name, device)

# Number of output classes; this drives the per-concern loop later on.
num_labels = model_config.config["num_labels"]

# load_model returns three values; note that `checkpoint` from the config cell
# is replaced by the value returned here.
loaded = load_model(model_config)
model, tokenizer, checkpoint = loaded

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [5]:
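# Baseline evaluation of the unpruned model (currently disabled).
# NOTE: `test_dataloader` is not defined at this point; it is first created by
# `load_data` in the pruning cell below, so load the data before re-enabling this.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)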

In [7]:
# Load the dataloaders once. None of the arguments depend on `concern`, so the
# original per-iteration reload repeated identical work `num_labels` times.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

# Concern-independent sample set, also built once instead of on every iteration.
# NOTE(review): `all_samples` is not consumed anywhere in this cell; it is kept
# so the name stays defined for any later cell. The target value 200 is outside
# range(num_labels) -- presumably a sentinel meaning "no class excluded"; confirm
# against SamplingDataset.
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

# For every class ("concern"): prune a fresh copy of the model using samples of
# that class versus the rest, then report accuracy, sparsity and similarity.
for concern in range(num_labels):
    # The fifth positional argument is True for the positive set and False for
    # the negative set. The literal 4 is presumably the sampling batch size --
    # TODO confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Work on a deep copy so the original `model` stays untouched and can serve
    # as the reference in the `similar` comparison below.
    module = copy.deepcopy(model)

    # The return value is not used, so this presumably prunes `module` in place.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the pruned module against the original model on validation data.
    similar(model, module, valid_dataloader, concern, num_samples, num_labels, device=device, seed=seed)

    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded
Evaluate the pruned model 0


Evaluating: 100%|██████████| 200/200 [01:44<00:00,  1.92it/s]


Loss: 0.9423
Precision: 0.7781, Recall: 0.7828, F1-Score: 0.7762
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.59      0.73      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.80      0.70       312
          12       0.71      0.82      0.76       665
          13       0.85      0.84      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:45<00:00,  1.90it/s]


Loss: 0.9416
Precision: 0.7776, Recall: 0.7840, F1-Score: 0.7766
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.86      0.74       746
           9       0.60      0.73      0.66       689
          10       0.77      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.71      0.80      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:46<00:00,  1.87it/s]


Loss: 0.9413
Precision: 0.7772, Recall: 0.7829, F1-Score: 0.7762
              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.85      0.70      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.47      0.60      0.53       473
           8       0.67      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.63      0.79      0.70       312
          12       0.72      0.80      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:47<00:00,  1.85it/s]


Loss: 0.9454
Precision: 0.7792, Recall: 0.7828, F1-Score: 0.7767
              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.85      0.70      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.90      0.67      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.60      0.53       473
           8       0.65      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.76      0.79      0.77       670
          11       0.64      0.80      0.71       312
          12       0.71      0.81      0.76       665
          13       0.86      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:50<00:00,  1.81it/s]


Loss: 0.9437
Precision: 0.7787, Recall: 0.7835, F1-Score: 0.7770
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.76      0.78      0.77       670
          11       0.63      0.81      0.71       312
          12       0.72      0.81      0.76       665
          13       0.84      0.85      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:51<00:00,  1.80it/s]


Loss: 0.9402
Precision: 0.7775, Recall: 0.7839, F1-Score: 0.7767
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.86      0.79      0.83       940
           7       0.49      0.59      0.54       473
           8       0.66      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.76      0.79      0.77       670
          11       0.62      0.80      0.70       312
          12       0.71      0.81      0.76       665
          13       0.83      0.86      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:51<00:00,  1.80it/s]


Loss: 0.9406
Precision: 0.7781, Recall: 0.7841, F1-Score: 0.7771
              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.84      0.80      0.82       940
           7       0.48      0.60      0.53       473
           8       0.67      0.86      0.75       746
           9       0.59      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.63      0.80      0.71       312
          12       0.72      0.80      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:53<00:00,  1.76it/s]


Loss: 0.9451
Precision: 0.7779, Recall: 0.7836, F1-Score: 0.7766
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.59      0.53       473
           8       0.66      0.86      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.80      0.70       312
          12       0.72      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:52<00:00,  1.77it/s]


Loss: 0.9435
Precision: 0.7772, Recall: 0.7832, F1-Score: 0.7759
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.72      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.49      0.59      0.54       473
           8       0.65      0.86      0.74       746
           9       0.59      0.73      0.65       689
          10       0.76      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.71      0.81      0.76       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:51<00:00,  1.79it/s]


Loss: 0.9440
Precision: 0.7784, Recall: 0.7834, F1-Score: 0.7768
              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.72      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.79      0.82       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.64      0.80      0.71       312
          12       0.71      0.82      0.76       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:53<00:00,  1.76it/s]


Loss: 0.9431
Precision: 0.7780, Recall: 0.7841, F1-Score: 0.7768
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.70      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.63      0.81      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:54<00:00,  1.75it/s]


Loss: 0.9428
Precision: 0.7776, Recall: 0.7842, F1-Score: 0.7767
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.78       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.59      0.53       473
           8       0.67      0.85      0.75       746
           9       0.59      0.74      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.72      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:54<00:00,  1.74it/s]


Loss: 0.9417
Precision: 0.7773, Recall: 0.7830, F1-Score: 0.7759
              precision    recall  f1-score   support

           0       0.77      0.65      0.71       797
           1       0.85      0.71      0.78       775
           2       0.87      0.88      0.88       795
           3       0.87      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.86      0.75       746
           9       0.59      0.73      0.65       689
          10       0.75      0.79      0.77       670
          11       0.63      0.79      0.70       312
          12       0.71      0.82      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:54<00:00,  1.75it/s]


Loss: 0.9383
Precision: 0.7764, Recall: 0.7831, F1-Score: 0.7752
              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.88      0.80      0.84      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.58      0.74      0.65       689
          10       0.77      0.78      0.78       670
          11       0.61      0.81      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:54<00:00,  1.74it/s]


Loss: 0.9418
Precision: 0.7778, Recall: 0.7839, F1-Score: 0.7765
              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.79      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.80      0.70       312
          12       0.72      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 200/200 [01:59<00:00,  1.67it/s]


Loss: 0.9434
Precision: 0.7769, Recall: 0.7813, F1-Score: 0.7745
              precision    recall  f1-score   support

           0       0.77      0.64      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.58      0.74      0.65       689
          10       0.78      0.78      0.78       670
          11       0.63      0.80      0.71       312
          12       0.70      0.81      0.75       665
          13       0.82      0.85      0.84       314
          14       0.86      0.76      0.81       756
          15       0.96      0.98      0.97      1607

    accuracy   