In [1]:
import os
import sys

# Make the project packages importable from this notebook's working directory.
# The relative path presumably points at the directory that holds the `utils`
# package imported in the next cell -- TODO confirm if the notebook is moved.
# The membership guard keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path every time.
if "../../../../" not in sys.path:
    sys.path.append("../../../../")
# Must be set before any tokenizer is used.
# NOTE(review): presumably silences the tokenizers parallelism warning when
# DataLoader worker processes are spawned -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment configuration: every tunable value used by the cells below.
name = "OSDG"  # key passed to ModelConfig and load_data
# Prefer the first GPU, but fall back to the CPU so the notebook can still run
# on a machine without CUDA instead of erroring on a missing device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by the load_model() call further down
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and similar()
wanda_ratio = 0.5  # forwarded to prune_wanda as sparsity_ratio
seed = 44  # shared seed for data loading, sampling and similarity
include_layers = ["attention", "intermediate", "output"]  # forwarded to prune_wanda
exclude_layers = None  # forwarded to prune_wanda

In [4]:
# Record when this run began and echo it at the top of the notebook output.
script_start_time = datetime.now()
started_label = f"{script_start_time:%Y-%m-%d %H:%M:%S}"
print("Script started at: " + started_label)

Script started at: 2024-08-21 10:20:14


In [5]:
# Build the configuration object for the model registered under `name`,
# targeting `device`.
model_config = ModelConfig(name, device)
# Number of target classes, read from the config dict; reused below for
# sampling and as the number of "concerns" in the similarity loop.
num_labels = model_config.config["num_labels"]
# load_model returns three values; the third rebinds the `checkpoint`
# placeholder set in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Fetch the train / validation / test loaders for the configured dataset.
# The keyword options are gathered first so they are easy to scan and tweak.
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
    "seed": seed,
}
data_splits = load_data(name, **loader_options)
train_dataloader, valid_dataloader, test_dataloader = data_splits

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Build the sample set from the training loader; it is handed to prune_wanda
# in the pruning cell below.
# NOTE(review): the positional literals 200, False and 4 are unnamed here and
# their meaning depends on SamplingDataset's signature -- TODO confirm and
# pass them by keyword (or lift them into the configuration cell).
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# Baseline evaluation of the unpruned model, left disabled in this run;
# uncomment both lines to execute it.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Work on a deep copy so the original `model` stays available for the
# similarity comparison further down.
module = copy.deepcopy(model)

# prune_wanda's return value is not used; `module` itself is evaluated next.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                              | 0/800 [00:16<…

Loss: 0.9304




Precision: 0.7667, Recall: 0.7680, F1-Score: 0.7628




              precision    recall  f1-score   support

           0       0.76      0.62      0.69       797
           1       0.84      0.70      0.76       775
           2       0.86      0.87      0.86       795
           3       0.87      0.81      0.84      1110
           4       0.83      0.80      0.82      1260
           5       0.90      0.67      0.77       882
           6       0.84      0.78      0.81       940
           7       0.45      0.55      0.50       473
           8       0.66      0.84      0.74       746
           9       0.55      0.73      0.63       689
          10       0.73      0.78      0.75       670
          11       0.65      0.77      0.71       312
          12       0.68      0.80      0.74       665
          13       0.83      0.84      0.83       314
          14       0.86      0.76      0.81       756
          15       0.95      0.96      0.96      1607

    accuracy                           0.78     12791
   macro avg       0.77   




In [10]:
# Compare the original and pruned models once per class label ("concern").
for concern in range(num_labels):
    print(f"--{concern}--")
    # Give `similar` a fresh deep copy of the validation loader each iteration.
    # NOTE(review): presumably so one concern's pass cannot affect the next --
    # confirm whether the copy is actually required.
    valid = copy.deepcopy(valid_dataloader)
    # Per the printed output, this reports mean CCA coefficients, linear CKA
    # and kernel CKA for concern vs. non-concern samples.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7341689247496554, 0.7341689247496554)




CCA coefficients mean non-concern: (0.7311421070976558, 0.7311421070976558)




Linear CKA concern: 0.9144267778813367




Linear CKA non-concern: 0.8838276851639474




Kernel CKA concern: 0.9092112579015226




Kernel CKA non-concern: 0.8969913487907665




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7254913346781982, 0.7254913346781982)




CCA coefficients mean non-concern: (0.7339131497670578, 0.7339131497670578)




Linear CKA concern: 0.8926110038347485




Linear CKA non-concern: 0.8979632213561406




Kernel CKA concern: 0.8923510691226207




Kernel CKA non-concern: 0.9091412600090318




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7274286598652877, 0.7274286598652877)




CCA coefficients mean non-concern: (0.7333061667223894, 0.7333061667223894)




Linear CKA concern: 0.926877127876574




Linear CKA non-concern: 0.8964952421530624




Kernel CKA concern: 0.9100252954318238




Kernel CKA non-concern: 0.9062004754989016




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7274967545656176, 0.7274967545656176)




CCA coefficients mean non-concern: (0.7330163636572424, 0.7330163636572424)




Linear CKA concern: 0.8939325217394954




Linear CKA non-concern: 0.8995945296057973




Kernel CKA concern: 0.8810661713042279




Kernel CKA non-concern: 0.9083662616004031




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7335163789945026, 0.7335163789945026)




CCA coefficients mean non-concern: (0.7336930821323588, 0.7336930821323588)




Linear CKA concern: 0.9191548007534658




Linear CKA non-concern: 0.8986088621092694




Kernel CKA concern: 0.9098762658177246




Kernel CKA non-concern: 0.9081953394166419




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7232888303899835, 0.7232888303899835)




CCA coefficients mean non-concern: (0.7324224642515198, 0.7324224642515198)




Linear CKA concern: 0.8824697616877224




Linear CKA non-concern: 0.8985783536651133




Kernel CKA concern: 0.8657825205967129




Kernel CKA non-concern: 0.9091363362407827




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7188872125780031, 0.7188872125780031)




CCA coefficients mean non-concern: (0.7340117223610587, 0.7340117223610587)




Linear CKA concern: 0.8862236378828172




Linear CKA non-concern: 0.9021375008464708




Kernel CKA concern: 0.8720718773995062




Kernel CKA non-concern: 0.9113142986200798




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7300847827585287, 0.7300847827585287)




CCA coefficients mean non-concern: (0.7331975707131432, 0.7331975707131432)




Linear CKA concern: 0.8952122848060091




Linear CKA non-concern: 0.8976303592898388




Kernel CKA concern: 0.8802193280146519




Kernel CKA non-concern: 0.9085788522601167




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7300846383564453, 0.7300846383564453)




CCA coefficients mean non-concern: (0.7328489249148828, 0.7328489249148828)




Linear CKA concern: 0.9090489565168502




Linear CKA non-concern: 0.8960284370255507




Kernel CKA concern: 0.8951593145232153




Kernel CKA non-concern: 0.9055425302706887




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7301811221948978, 0.7301811221948978)




CCA coefficients mean non-concern: (0.7345086780379357, 0.7345086780379357)




Linear CKA concern: 0.9124855618061777




Linear CKA non-concern: 0.8965161070463226




Kernel CKA concern: 0.8979842997080901




Kernel CKA non-concern: 0.9081078290341865




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7215892729103434, 0.7215892729103434)




CCA coefficients mean non-concern: (0.7315967365025073, 0.7315967365025073)




Linear CKA concern: 0.8939587153832372




Linear CKA non-concern: 0.8963265378832546




Kernel CKA concern: 0.876866907192164




Kernel CKA non-concern: 0.9068635789612874




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7274883199939701, 0.7274883199939701)




CCA coefficients mean non-concern: (0.7326507857851318, 0.7326507857851318)




Linear CKA concern: 0.9290323980192164




Linear CKA non-concern: 0.8956873232529348




Kernel CKA concern: 0.9140500875415292




Kernel CKA non-concern: 0.9062713981339585




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7289984267453934, 0.7289984267453934)




CCA coefficients mean non-concern: (0.7326749312486697, 0.7326749312486697)




Linear CKA concern: 0.9186971997307151




Linear CKA non-concern: 0.897199037864867




Kernel CKA concern: 0.905834738257643




Kernel CKA non-concern: 0.9056453404355881




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7278479326722501, 0.7278479326722501)




CCA coefficients mean non-concern: (0.731678670325465, 0.731678670325465)




Linear CKA concern: 0.9202710760529369




Linear CKA non-concern: 0.8988393999527582




Kernel CKA concern: 0.896889775450057




Kernel CKA non-concern: 0.907540553037614




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7252125418339107, 0.7252125418339107)




CCA coefficients mean non-concern: (0.7338620667075519, 0.7338620667075519)




Linear CKA concern: 0.9102002562618281




Linear CKA non-concern: 0.8934873492367076




Kernel CKA concern: 0.9053096422877424




Kernel CKA non-concern: 0.902574665751194




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7145181941868, 0.7145181941868)




CCA coefficients mean non-concern: (0.7315920197630629, 0.7315920197630629)




Linear CKA concern: 0.8720913359845404




Linear CKA non-concern: 0.8989473524192331




Kernel CKA concern: 0.8502629591939355




Kernel CKA non-concern: 0.9088498042399513




In [11]:
# Show the sparsity report for the pruned module. Kept as the bare last
# expression so the notebook renders the return value; per the output it is a
# tuple of an overall ratio and a per-parameter mapping.
get_sparsity(module)

(0.4959948842155738,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5,
  'bert.encoder.layer.1.attention.self.key.bias': 0.0,
  'bert.encoder.layer.1.attention.self.value.weight': 0.5,
  'bert.encoder.