In [1]:
import os
import sys

# Make the directory four levels up importable (presumably the project root that
# holds the `utils` package imported in the next cell — confirm). The path is
# relative, so it is resolved against the kernel's current working directory.
PROJECT_ROOT = "../../../../"
if PROJECT_ROOT not in sys.path:  # guard: re-running this cell must not stack duplicate entries
    sys.path.append(PROJECT_ROOT)

# Set in the first cell, ahead of any model/tokenizer loading.
# NOTE(review): presumably silences the HuggingFace tokenizers parallelism
# warning when DataLoader workers are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration: every tunable value of this notebook lives here ---
name = "OSDG"  # dataset/model key handed to ModelConfig and load_data
# Prefer the first GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by load_model() in a later cell
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and similar()
ci_ratio = 0.4  # passed as sparsity_ratio to prune_concern_identification
seed = 44  # shared seed for load_data, SamplingDataset and similar()
include_layers = ["attention", "intermediate", "output"]  # passed to prune_concern_identification
exclude_layers = None  # passed to prune_concern_identification

In [4]:
# Record when this run began and echo it in a human-readable form.
script_start_time = datetime.now()
start_stamp = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {start_stamp}")

Script started at: 2024-08-23 16:43:13


In [5]:
# Build the configuration object for the selected model name and target device.
model_config = ModelConfig(name, device)
# Label count read from the loaded config; it bounds the per-concern loop further down.
num_labels = model_config.config["num_labels"]
# load_model returns the model, its tokenizer and a checkpoint value; the latter
# rebinds the `checkpoint` placeholder set in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Build the train / validation / test dataloaders for the configured dataset,
# using the batch size, worker count and seed from the configuration cell.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Optional baseline: evaluate the unpruned model on the test dataloader.
# Currently disabled; uncomment the two lines below to run it.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# Per-concern evaluation results keyed by label index, so the metrics of every
# iteration remain available after the loop instead of only the last one.
evaluation_results = {}

# One pass per label ("concern"): sample data, prune a copy of the model for that
# concern, evaluate the pruned copy, and compare its representations with the original.
for concern in range(num_labels):
    # Work on private copies of the dataloaders for this iteration.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # NOTE(review): the boolean presumably selects samples of the concern (True)
    # versus all other labels (False), and the positional `4` is defined by the
    # SamplingDataset signature (not visible here) — confirm both.
    positive_samples = SamplingDataset(
        train, concern, num_samples, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a fresh copy so the original `model` stays available for the
    # similarity comparison below.
    module = copy.deepcopy(model)

    # The return value is not captured; `module` is evaluated as the pruned model
    # right after, so the call presumably modifies it in place.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    evaluation_results[concern] = result  # keep this concern's metrics past the iteration
    get_sparsity(module)  # return value is not used here

    # Compare original vs. pruned model on the validation copy for this concern
    # (the captured output shows CCA and linear/kernel CKA scores).
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # NOTE(review): if saving is re-enabled, add `concern` to the file name — the
    # name below is identical on every iteration.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model 0




Evaluating:   0%|                                                                              | 0/800 [00:15<…

Loss: 0.9387




Precision: 0.7763, Recall: 0.7786, F1-Score: 0.7733




              precision    recall  f1-score   support

           0       0.74      0.67      0.70       797
           1       0.83      0.70      0.76       775
           2       0.88      0.87      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.81      0.83      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.58      0.73      0.65       689
          10       0.78      0.77      0.78       670
          11       0.66      0.79      0.72       312
          12       0.69      0.81      0.75       665
          13       0.85      0.84      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8333735403863127, 0.8333735403863127)




CCA coefficients mean non-concern: (0.823329157369746, 0.823329157369746)




Linear CKA concern: 0.9769091416223766




Linear CKA non-concern: 0.9429818515971297




Kernel CKA concern: 0.9712940254329858




Kernel CKA non-concern: 0.9444524971245073




Evaluate the pruned model 1




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9338




Precision: 0.7764, Recall: 0.7802, F1-Score: 0.7743




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.84      0.72      0.78       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.80      0.82       940
           7       0.47      0.58      0.52       473
           8       0.65      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.76      0.78      0.77       670
          11       0.66      0.79      0.72       312
          12       0.70      0.80      0.75       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.96      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8281555489984868, 0.8281555489984868)




CCA coefficients mean non-concern: (0.8290349954592928, 0.8290349954592928)




Linear CKA concern: 0.9728861443975234




Linear CKA non-concern: 0.9576527666021868




Kernel CKA concern: 0.9696820464770753




Kernel CKA non-concern: 0.960366003033304




Evaluate the pruned model 2




Evaluating:   0%|                                                                              | 0/800 [00:23<…