In [1]:
import os
import sys

# Turn off the tokenizers library's own parallelism before any model code is
# imported (presumably to avoid fork-related warnings once DataLoader workers
# are spawned -- TODO confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Put the directory four levels up on the import path so the project-local
# `utils.*` imports in the next cell resolve. The path is relative, so this
# presumably expects the kernel's working directory to be the notebook's folder.
sys.path.append("../../../../")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# ---------------------------------------------------------------------------
# Experiment configuration: every value a reader may want to tune lives here.
# ---------------------------------------------------------------------------

# Key used to build the ModelConfig and to select the dataset in load_data().
name = "OSDG"

# Prefer the first CUDA device, but fall back to the CPU so the notebook can
# still be executed (slowly) on a machine without a GPU instead of failing.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder only: this name is rebound by the load_model(...) call below.
checkpoint = None

# DataLoader settings forwarded to load_data().
batch_size = 16
num_workers = 4

# Sample budget per concern; half of it (num_samples // 2) is requested for
# each of the positive and negative SamplingDataset instances.
num_samples = 16

# Passed as `sparsity_ratio` to prune_concern_identification().
ci_ratio = 0.6

# Seed forwarded to load_data(), SamplingDataset() and similar().
seed = 44

# Layer filters forwarded to prune_concern_identification().
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Remember when this run began and echo it in a human-readable form.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {start_stamp}")

Script started at: 2024-08-25 15:53:40


In [5]:
# Build the model configuration for `name`, targeting `device`.
model_config = ModelConfig(name, device)
# Number of labels as recorded in the configuration; it bounds the per-concern
# loop further down (range(num_labels)).
num_labels = model_config.config["num_labels"]
# Load the model and its tokenizer. Note that this rebinds `checkpoint`,
# replacing the None assigned in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Build the train / validation / test dataloaders for the configured dataset.
# `do_cache=True` is forwarded to load_data(); the log below shows a cached
# copy of the dataset being picked up.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Optional baseline (currently disabled): evaluate the original, unpruned model on the test set.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every concern (label index) build a pruned copy of the model from a few
# positive/negative samples, evaluate it on the test set, report its sparsity
# and compare its representations against the original model.

# Evaluation result of each pruned module, keyed by concern index. Previously
# `result` was overwritten on every iteration, so only the last one survived.
results = {}

for concern in range(num_labels):
    # Release the pruned copy from the previous iteration (or a previous run of
    # this cell) *before* the new deep copy is created; otherwise the old and
    # the new copy are alive at the same time and peak memory is higher.
    module = None

    # Work on private copies of the dataloaders for this concern.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Half of the sample budget each for the positive (True) and negative
    # (False) sets of this concern.
    # NOTE(review): the positional flag and the literal 4 are interpreted by
    # SamplingDataset, which is not visible here -- confirm their meaning and
    # consider passing them by keyword.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )
    # NOTE(review): `all_samples` is not read anywhere in this cell, and the
    # concern index 200 lies outside range(num_labels) (presumably a sentinel).
    # It is kept because later cells and SamplingDataset's side effects are not
    # visible from here -- confirm whether it can be dropped.
    all_samples = SamplingDataset(
        train, 200, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a fresh copy so the original `model` stays untouched for the
    # similarity comparison below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    results[concern] = result  # keep every concern's result, not only the last
    get_sparsity(module)

    # Compare original vs. pruned model on the validation copy for this concern.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # NOTE(review): if this is re-enabled, the file name below does not contain
    # the concern index, so every iteration would target the same path.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model 0




Evaluating:   0%|                                                                              | 0/800 [00:15<…

Loss: 0.9746




Precision: 0.7428, Recall: 0.7126, F1-Score: 0.7190




              precision    recall  f1-score   support

           0       0.74      0.58      0.65       797
           1       0.82      0.59      0.68       775
           2       0.88      0.81      0.84       795
           3       0.86      0.78      0.82      1110
           4       0.79      0.80      0.79      1260
           5       0.92      0.60      0.73       882
           6       0.81      0.74      0.77       940
           7       0.43      0.42      0.43       473
           8       0.58      0.79      0.67       746
           9       0.49      0.72      0.58       689
          10       0.77      0.68      0.72       670
          11       0.68      0.64      0.66       312
          12       0.61      0.79      0.69       665
          13       0.88      0.75      0.81       314
          14       0.84      0.74      0.79       756
          15       0.79      0.97      0.87      1607

    accuracy                           0.74     12791
   macro avg       0.74   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6355810489943264, 0.6355810489943264)




CCA coefficients mean non-concern: (0.6235407541074869, 0.6235407541074869)




Linear CKA concern: 0.7086540436850537




Linear CKA non-concern: 0.5697484551881546




Kernel CKA concern: 0.7191332009924182




Kernel CKA non-concern: 0.5913748734142253




Evaluate the pruned model 1




Evaluating:   0%|                                                                              | 0/800 [00:22<…

Loss: 0.9500




Precision: 0.7444, Recall: 0.7151, F1-Score: 0.7212




              precision    recall  f1-score   support

           0       0.76      0.55      0.64       797
           1       0.80      0.65      0.72       775
           2       0.88      0.80      0.84       795
           3       0.87      0.76      0.82      1110
           4       0.78      0.79      0.78      1260
           5       0.91      0.63      0.74       882
           6       0.85      0.73      0.78       940
           7       0.45      0.38      0.41       473
           8       0.62      0.77      0.69       746
           9       0.45      0.73      0.56       689
          10       0.77      0.70      0.73       670
          11       0.71      0.64      0.67       312
          12       0.60      0.78      0.68       665
          13       0.84      0.82      0.83       314
          14       0.83      0.75      0.79       756
          15       0.77      0.97      0.86      1607

    accuracy                           0.74     12791
   macro avg       0.74   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6196653138933315, 0.6196653138933315)




CCA coefficients mean non-concern: (0.6242995270822193, 0.6242995270822193)




Linear CKA concern: 0.6207120206374394




Linear CKA non-concern: 0.5884946688826138




Kernel CKA concern: 0.6425899020485808




Kernel CKA non-concern: 0.6270819288727043




Evaluate the pruned model 2




Evaluating:   0%|                                                                              | 0/800 [00:22<…

Loss: 0.9980




Precision: 0.7374, Recall: 0.6959, F1-Score: 0.7029




              precision    recall  f1-score   support

           0       0.78      0.51      0.62       797
           1       0.83      0.54      0.65       775
           2       0.86      0.83      0.85       795
           3       0.87      0.75      0.81      1110
           4       0.74      0.82      0.78      1260
           5       0.91      0.62      0.73       882
           6       0.86      0.68      0.76       940
           7       0.47      0.31      0.37       473
           8       0.53      0.83      0.65       746
           9       0.44      0.73      0.55       689
          10       0.74      0.68      0.71       670
          11       0.72      0.60      0.65       312
          12       0.60      0.77      0.68       665
          13       0.84      0.79      0.81       314
          14       0.86      0.72      0.78       756
          15       0.76      0.97      0.85      1607

    accuracy                           0.73     12791
   macro avg       0.74   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6180122619499531, 0.6180122619499531)




CCA coefficients mean non-concern: (0.6167620328845297, 0.6167620328845297)




Linear CKA concern: 0.754676212791344




Linear CKA non-concern: 0.5102879397110864




Kernel CKA concern: 0.7369416039376466




Kernel CKA non-concern: 0.5326641069288583




Evaluate the pruned model 3




Evaluating:   0%|                                                                              | 0/800 [00:23<…