In [1]:
import os
import sys

# Extend the import search path so the project-local `utils` package imported
# by the following cell can be resolved (presumably this relative path points
# at the repository root, relative to the notebook's working directory -- confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
extra_import_path = "../../../../"
if extra_import_path not in sys.path:
    sys.path.append(extra_import_path)

# Turn off parallelism in the Hugging Face `tokenizers` library.
# NOTE(review): presumably done to avoid its fork warning when DataLoader
# workers are used (num_workers > 0) -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# --- Run configuration -------------------------------------------------------
# Every tunable for this notebook lives here; later cells only read these names.

# Key passed to ModelConfig and load_data to select the model/dataset.
name = "OSDG"

# Prefer the first GPU, but fall back to CPU when CUDA is unavailable so the
# notebook still runs (slowly) on machines without a GPU instead of failing
# once the device is first used.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder; rebound by the value returned from load_model in a later cell.
checkpoint = None

# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4

# Forwarded to SamplingDataset and to similar().
num_samples = 16

# Passed to prune_wanda as `sparsity_ratio`.
wanda_ratio = 0.4

# Seed forwarded to load_data, SamplingDataset and similar().
seed = 44

# Layer filters forwarded to prune_wanda.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Record the wall-clock time at which the run began and echo it in a
# human-readable form.
script_start_time = datetime.now()
formatted_start = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {formatted_start}")

Script started at: 2024-08-21 09:57:34


In [5]:
# Build the model configuration for the selected name/device.
model_config = ModelConfig(name, device)

# Number of target classes, read from the configuration dictionary; reused
# below for sampling and for the per-concern similarity loop.
num_labels = model_config.config["num_labels"]

# load_model returns three values; the third rebinds `checkpoint`, which the
# configuration cell initialised to None.
loaded_parts = load_model(model_config)
model, tokenizer, checkpoint = loaded_parts

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Load the three data loaders for the configured dataset.
# All loader settings come from the configuration cell; caching is requested
# via do_cache=True.
loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)
train_dataloader, valid_dataloader, test_dataloader = load_data(name, **loader_options)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Build the sample set from the training loader; it is handed to prune_wanda
# in the pruning cell.
# NOTE(review): 200, False and 4 are passed positionally, so their meaning
# depends on SamplingDataset's signature, which is not visible here -- consider
# switching to keyword arguments once the parameter names are confirmed.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [8]:
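# Optional baseline: evaluate the unpruned model for comparison with the pruned result.
# Disabled in this run; uncomment the two lines below to enable it.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)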

In [9]:
# Work on a deep copy so that `model` keeps its original weights; the
# similarity cell later compares `model` against the pruned `module`.
module = copy.deepcopy(model)

# The return value of prune_wanda is not used: `module` itself is what gets
# evaluated and inspected afterwards.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                              | 0/800 [00:16<…

Loss: 0.9367




Precision: 0.7769, Recall: 0.7810, F1-Score: 0.7752




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.72      0.78       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.81      0.83      1260
           5       0.90      0.68      0.78       882
           6       0.84      0.80      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.60      0.72      0.65       689
          10       0.76      0.78      0.77       670
          11       0.67      0.79      0.72       312
          12       0.69      0.81      0.74       665
          13       0.84      0.86      0.85       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




In [10]:
# Compare the original model with the pruned module once per class label
# ("concern"). Each iteration receives its own deep copy of the validation
# loader, so `valid_dataloader` itself is never handed to similar().
for concern in range(num_labels):
    print(f"--{concern}--")
    valid_loader_copy = copy.deepcopy(valid_dataloader)
    similar(
        model,
        module,
        valid_loader_copy,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8332801475083546, 0.8332801475083546)




CCA coefficients mean non-concern: (0.8322319467625684, 0.8322319467625684)




Linear CKA concern: 0.970693149745014




Linear CKA non-concern: 0.9588653122891923




Kernel CKA concern: 0.9673370603606058




Kernel CKA non-concern: 0.96228566894193




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8254329470363782, 0.8254329470363782)




CCA coefficients mean non-concern: (0.8325598289575019, 0.8325598289575019)




Linear CKA concern: 0.9677947672715879




Linear CKA non-concern: 0.9633738739303833




Kernel CKA concern: 0.9637257606508626




Kernel CKA non-concern: 0.9659433148161277




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8301273213754259, 0.8301273213754259)




CCA coefficients mean non-concern: (0.8329239873807601, 0.8329239873807601)




Linear CKA concern: 0.9772957536592796




Linear CKA non-concern: 0.9625745985484878




Kernel CKA concern: 0.9706549242545205




Kernel CKA non-concern: 0.9646970420484654




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8285286133072457, 0.8285286133072457)




CCA coefficients mean non-concern: (0.832079445631144, 0.832079445631144)




Linear CKA concern: 0.9647626186628611




Linear CKA non-concern: 0.9632055729681996




Kernel CKA concern: 0.9598589679406552




Kernel CKA non-concern: 0.9650899288303049




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8362311788595046, 0.8362311788595046)




CCA coefficients mean non-concern: (0.8327767130026891, 0.8327767130026891)




Linear CKA concern: 0.974697322099949




Linear CKA non-concern: 0.9630911529286468




Kernel CKA concern: 0.9700371175702833




Kernel CKA non-concern: 0.9652749466073808




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8280101272428146, 0.8280101272428146)




CCA coefficients mean non-concern: (0.8319578100709363, 0.8319578100709363)




Linear CKA concern: 0.95949704187912




Linear CKA non-concern: 0.9635987836388471




Kernel CKA concern: 0.9507253413577733




Kernel CKA non-concern: 0.9660254352780531




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8230657732081166, 0.8230657732081166)




CCA coefficients mean non-concern: (0.8343705704096568, 0.8343705704096568)




Linear CKA concern: 0.9554337443637916




Linear CKA non-concern: 0.9641707500793566




Kernel CKA concern: 0.9471171667411156




Kernel CKA non-concern: 0.966473257305303




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8311088887537903, 0.8311088887537903)




CCA coefficients mean non-concern: (0.8324328718455354, 0.8324328718455354)




Linear CKA concern: 0.9637099846490427




Linear CKA non-concern: 0.9633350603732865




Kernel CKA concern: 0.9588248576670464




Kernel CKA non-concern: 0.9656520884463665




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8299208053435069, 0.8299208053435069)




CCA coefficients mean non-concern: (0.8320357799743859, 0.8320357799743859)




Linear CKA concern: 0.970368197894445




Linear CKA non-concern: 0.9626246120923132




Kernel CKA concern: 0.9635843208306192




Kernel CKA non-concern: 0.9646838476606178




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8312251302307482, 0.8312251302307482)




CCA coefficients mean non-concern: (0.8339907176475138, 0.8339907176475138)




Linear CKA concern: 0.9698461174482388




Linear CKA non-concern: 0.9628328063328694




Kernel CKA concern: 0.9639199940286515




Kernel CKA non-concern: 0.9656018985901595




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8276426095738941, 0.8276426095738941)




CCA coefficients mean non-concern: (0.8317252909854727, 0.8317252909854727)




Linear CKA concern: 0.9584726973760329




Linear CKA non-concern: 0.96327510417113




Kernel CKA concern: 0.9515732682616064




Kernel CKA non-concern: 0.9657055747710639




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8283878203523504, 0.8283878203523504)




CCA coefficients mean non-concern: (0.8328152420950927, 0.8328152420950927)




Linear CKA concern: 0.9751388437703845




Linear CKA non-concern: 0.9627359994528797




Kernel CKA concern: 0.9681041032968737




Kernel CKA non-concern: 0.965012666630189




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8284333327131849, 0.8284333327131849)




CCA coefficients mean non-concern: (0.8325747881854222, 0.8325747881854222)




Linear CKA concern: 0.9702069560129302




Linear CKA non-concern: 0.9628159255712787




Kernel CKA concern: 0.9642866049676295




Kernel CKA non-concern: 0.9648230424548571




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8263773766689781, 0.8263773766689781)




CCA coefficients mean non-concern: (0.8308393910684241, 0.8308393910684241)




Linear CKA concern: 0.9730435719481849




Linear CKA non-concern: 0.963229658850819




Kernel CKA concern: 0.962444209177505




Kernel CKA non-concern: 0.9650237176421114




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8294849775911006, 0.8294849775911006)




CCA coefficients mean non-concern: (0.8333531228166672, 0.8333531228166672)




Linear CKA concern: 0.966003285921889




Linear CKA non-concern: 0.9618039411168157




Kernel CKA concern: 0.9626744086720571




Kernel CKA non-concern: 0.9639035993888921




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8174905417062116, 0.8174905417062116)




CCA coefficients mean non-concern: (0.831662163019995, 0.831662163019995)




Linear CKA concern: 0.94748538350146




Linear CKA non-concern: 0.9638495398819817




Kernel CKA concern: 0.9394024881069722




Kernel CKA non-concern: 0.9659887418115665




In [11]:
# Inspect the sparsity of the pruned module. The displayed result is a pair:
# an overall ratio followed by a per-parameter mapping of ratios.
sparsity_report = get_sparsity(module)
sparsity_report

(0.39653757670359674,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3997395833333333,
  'bert.e