In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
name = "OSDG"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
wanda_ratio = 0.3
seed = 44
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-08-21 09:34:44


In [5]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
module = copy.deepcopy(model)
prune_wanda(module, model_config, all_samples, sparsity_ratio=wanda_ratio, include_layers=include_layers,
            exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                              | 0/800 [00:16<…

Loss: 0.9419




Precision: 0.7778, Recall: 0.7837, F1-Score: 0.7767




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.59      0.53       473
           8       0.67      0.86      0.75       746
           9       0.59      0.73      0.65       689
          10       0.76      0.79      0.77       670
          11       0.62      0.80      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




In [10]:
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9180078096571286, 0.9180078096571286)




CCA coefficients mean non-concern: (0.9170642071036617, 0.9170642071036617)




Linear CKA concern: 0.9890781224050023




Linear CKA non-concern: 0.9855645763981632




Kernel CKA concern: 0.987726734877741




Kernel CKA non-concern: 0.9866448032568543




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9125142767795833, 0.9125142767795833)




CCA coefficients mean non-concern: (0.9169609465686503, 0.9169609465686503)




Linear CKA concern: 0.987480077547152




Linear CKA non-concern: 0.9865758075646984




Kernel CKA concern: 0.9861566860473862




Kernel CKA non-concern: 0.9874678577461732




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.913267666897723, 0.913267666897723)




CCA coefficients mean non-concern: (0.9180725301615058, 0.9180725301615058)




Linear CKA concern: 0.9910076251949411




Linear CKA non-concern: 0.9864492919935406




Kernel CKA concern: 0.9879117359463473




Kernel CKA non-concern: 0.9871581539047202




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9138299061859047, 0.9138299061859047)




CCA coefficients mean non-concern: (0.9168206017853172, 0.9168206017853172)




Linear CKA concern: 0.9874660105532584




Linear CKA non-concern: 0.9867205344657427




Kernel CKA concern: 0.9852527290570856




Kernel CKA non-concern: 0.9873183274772868




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9191909149966452, 0.9191909149966452)




CCA coefficients mean non-concern: (0.9177605925579688, 0.9177605925579688)




Linear CKA concern: 0.9912028239966256




Linear CKA non-concern: 0.9866424722276314




Kernel CKA concern: 0.9893097958376907




Kernel CKA non-concern: 0.9873598175000319




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9129532796984301, 0.9129532796984301)




CCA coefficients mean non-concern: (0.9169503660122814, 0.9169503660122814)




Linear CKA concern: 0.9864610293228716




Linear CKA non-concern: 0.9868565970638038




Kernel CKA concern: 0.9833133414479511




Kernel CKA non-concern: 0.9876434626329769




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.913583387602989, 0.913583387602989)




CCA coefficients mean non-concern: (0.9179689846681115, 0.9179689846681115)




Linear CKA concern: 0.9855578729564217




Linear CKA non-concern: 0.9868654250285334




Kernel CKA concern: 0.9830093373760183




Kernel CKA non-concern: 0.9875831347403717




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9173573718063137, 0.9173573718063137)




CCA coefficients mean non-concern: (0.9169549442295805, 0.9169549442295805)




Linear CKA concern: 0.9873529304372128




Linear CKA non-concern: 0.9865225048164435




Kernel CKA concern: 0.9852705945339199




Kernel CKA non-concern: 0.9874524558576856




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9152997424808256, 0.9152997424808256)




CCA coefficients mean non-concern: (0.9172359455569411, 0.9172359455569411)




Linear CKA concern: 0.9896383061696528




Linear CKA non-concern: 0.9863825173313217




Kernel CKA concern: 0.9870875511733544




Kernel CKA non-concern: 0.987042968046777




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9159593748312767, 0.9159593748312767)




CCA coefficients mean non-concern: (0.9183445303977195, 0.9183445303977195)




Linear CKA concern: 0.9895133995540113




Linear CKA non-concern: 0.9866212763429887




Kernel CKA concern: 0.9870917305480402




Kernel CKA non-concern: 0.9875682551055722




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9171794874678649, 0.9171794874678649)




CCA coefficients mean non-concern: (0.9168691825729025, 0.9168691825729025)




Linear CKA concern: 0.9856835489370984




Linear CKA non-concern: 0.986668945636117




Kernel CKA concern: 0.9830434254418271




Kernel CKA non-concern: 0.9875701977810786




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9157879208513504, 0.9157879208513504)




CCA coefficients mean non-concern: (0.9172986038374868, 0.9172986038374868)




Linear CKA concern: 0.9905562446330839




Linear CKA non-concern: 0.9865337040044406




Kernel CKA concern: 0.9878048221262276




Kernel CKA non-concern: 0.9872543633681957




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9147898659764554, 0.9147898659764554)




CCA coefficients mean non-concern: (0.9177465317045599, 0.9177465317045599)




Linear CKA concern: 0.9873103469346765




Linear CKA non-concern: 0.9863657701025675




Kernel CKA concern: 0.9851605345650184




Kernel CKA non-concern: 0.9869479136676127




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9108241649344982, 0.9108241649344982)




CCA coefficients mean non-concern: (0.9165053541391719, 0.9165053541391719)




Linear CKA concern: 0.9889461563639661




Linear CKA non-concern: 0.9866590620919243




Kernel CKA concern: 0.9846927024362466




Kernel CKA non-concern: 0.9872992453366907




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9154534841528413, 0.9154534841528413)




CCA coefficients mean non-concern: (0.9174202138243331, 0.9174202138243331)




Linear CKA concern: 0.9881140044397354




Linear CKA non-concern: 0.9861255587429407




Kernel CKA concern: 0.9875247779763664




Kernel CKA non-concern: 0.9867846782737043




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9093572917315118, 0.9093572917315118)




CCA coefficients mean non-concern: (0.9169085101794114, 0.9169085101794114)




Linear CKA concern: 0.980174785702076




Linear CKA non-concern: 0.9869673130906924




Kernel CKA concern: 0.9762139530163257




Kernel CKA non-concern: 0.987709962455761




In [11]:
get_sparsity(module)

(0.29718790697031233,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.2994791666666667,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.2998046875,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.2994791666666667,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.2994791666666667,
  'bert.encoder