In [1]:
import os
import sys

# Make the directory four levels above the working directory importable so
# the project-local `utils.*` packages imported in the next cell resolve, and
# turn off HuggingFace tokenizers' internal parallelism (it warns / can
# deadlock when DataLoader worker processes fork).
sys.path += ["../../../../"]
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment configuration: every later cell reads these names.
name = "OSDG"  # key passed to ModelConfig / load_data
# Fall back to CPU when no GPU is present so the notebook still runs
# (slowly) instead of failing on the first transfer to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; overwritten by the value load_model returns
batch_size = 16
num_workers = 4
num_samples = 16  # forwarded to SamplingDataset and similar(); TODO confirm exact semantics
wanda_ratio = 0.6  # passed to prune_wanda as sparsity_ratio
seed = 44
# `seed` is only forwarded to helpers that accept it; also seed torch's global
# RNGs (CPU and CUDA) in case any other step draws random numbers, so reruns
# are reproducible.
torch.manual_seed(seed)
include_layers = ["attention", "intermediate", "output"]  # layer-name filters for prune_wanda
exclude_layers = None

In [4]:
# Record when the run began and echo it as "YYYY-MM-DD HH:MM:SS".
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-21 10:42:50


In [5]:
# Build the configuration for `name` on `device`, read the label count from
# its config mapping, then load the model, tokenizer and checkpoint it
# describes (the returned checkpoint replaces the config-cell placeholder).
model_config = ModelConfig(name, device)
model_settings = model_config.config
num_labels = model_settings["num_labels"]
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Build the train / validation / test loaders for `name`. With do_cache=True
# the recorded output shows load_data reusing its cached copy of the dataset.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    seed=seed,
    do_cache=True,
    num_workers=num_workers,
    batch_size=batch_size,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Draw a sample set from the training loader; prune_wanda consumes it below.
# NOTE(review): 200, False and 4 are bare positional arguments whose meaning
# is defined by SamplingDataset's signature (not visible here) - consider
# passing them by keyword and lifting them into the config cell.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [8]:
# Optional baseline: evaluate the unpruned model on the test split so the
# pruned model's metrics below have a reference point. Off by default because
# it runs a full pass over the test set; set to True to produce the baseline.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    original_result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Prune a deep copy, so `model` stays intact for the original-vs-pruned
# similarity comparison below. prune_wanda's return value is unused: it
# modifies `module` in place at `wanda_ratio` sparsity, restricted by the
# include/exclude layer filters. Then evaluate the pruned copy on the test split.
module = copy.deepcopy(model)
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)

# Persisting the pruned weights is opt-in; set to True to write them under Modules/.
save_pruned_module = False
if save_pruned_module:
    save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                              | 0/800 [00:16<…

Loss: 0.9467




Precision: 0.7446, Recall: 0.7237, F1-Score: 0.7253




              precision    recall  f1-score   support

           0       0.76      0.57      0.65       797
           1       0.81      0.61      0.70       775
           2       0.87      0.82      0.85       795
           3       0.85      0.78      0.82      1110
           4       0.78      0.81      0.79      1260
           5       0.91      0.63      0.74       882
           6       0.84      0.74      0.79       940
           7       0.48      0.35      0.41       473
           8       0.58      0.82      0.68       746
           9       0.45      0.76      0.56       689
          10       0.76      0.70      0.73       670
          11       0.68      0.68      0.68       312
          12       0.64      0.77      0.70       665
          13       0.82      0.83      0.82       314
          14       0.83      0.75      0.79       756
          15       0.86      0.96      0.91      1607

    accuracy                           0.75     12791
   macro avg       0.74   




In [10]:
# For every label ("concern"), compare representations of the original
# `model` and the pruned `module` via similar(), which prints CCA and
# linear/kernel CKA scores for concern vs. non-concern samples.
# A fresh deep copy of the validation loader is made each iteration -
# presumably because similar() consumes or mutates it; TODO confirm.
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6323435557370096, 0.6323435557370096)




CCA coefficients mean non-concern: (0.6272089721426088, 0.6272089721426088)




Linear CKA concern: 0.7030531235187791




Linear CKA non-concern: 0.6291548814181243




Kernel CKA concern: 0.7019598330886319




Kernel CKA non-concern: 0.6604190355161136




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6214229009420731, 0.6214229009420731)




CCA coefficients mean non-concern: (0.6316633053654399, 0.6316633053654399)




Linear CKA concern: 0.5873858167125781




Linear CKA non-concern: 0.6547153101202279




Kernel CKA concern: 0.606974940203829




Kernel CKA non-concern: 0.6974362920192854




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6232737814189386, 0.6232737814189386)




CCA coefficients mean non-concern: (0.6297147184506765, 0.6297147184506765)




Linear CKA concern: 0.7339969041147629




Linear CKA non-concern: 0.6515896062047349




Kernel CKA concern: 0.704150899860229




Kernel CKA non-concern: 0.6865353257714596




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.62725753387432, 0.62725753387432)




CCA coefficients mean non-concern: (0.6296902271378448, 0.6296902271378448)




Linear CKA concern: 0.6801901223925564




Linear CKA non-concern: 0.6606263100549751




Kernel CKA concern: 0.6540313241633617




Kernel CKA non-concern: 0.6938629789864651




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6276917409501409, 0.6276917409501409)




CCA coefficients mean non-concern: (0.631265973683704, 0.631265973683704)




Linear CKA concern: 0.7095390865003002




Linear CKA non-concern: 0.6546333871727639




Kernel CKA concern: 0.6779793227989083




Kernel CKA non-concern: 0.6939612079141338




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6198068759627037, 0.6198068759627037)




CCA coefficients mean non-concern: (0.6291979801167416, 0.6291979801167416)




Linear CKA concern: 0.617298376472084




Linear CKA non-concern: 0.655156175657154




Kernel CKA concern: 0.6022147465724084




Kernel CKA non-concern: 0.6969576149918925




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6132381521876703, 0.6132381521876703)




CCA coefficients mean non-concern: (0.6311474681893828, 0.6311474681893828)




Linear CKA concern: 0.6452312622040676




Linear CKA non-concern: 0.6669392975500789




Kernel CKA concern: 0.6398485540892472




Kernel CKA non-concern: 0.7053038279532524




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6275007201962132, 0.6275007201962132)




CCA coefficients mean non-concern: (0.6309598566393969, 0.6309598566393969)




Linear CKA concern: 0.6834375316127114




Linear CKA non-concern: 0.656249192824264




Kernel CKA concern: 0.6481532702645372




Kernel CKA non-concern: 0.6986158775856662




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6318506112348806, 0.6318506112348806)




CCA coefficients mean non-concern: (0.629681193649084, 0.629681193649084)




Linear CKA concern: 0.6900308071318179




Linear CKA non-concern: 0.6532791667809397




Kernel CKA concern: 0.6736466849990266




Kernel CKA non-concern: 0.6864889251363823




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6313504096058589, 0.6313504096058589)




CCA coefficients mean non-concern: (0.6325214141491997, 0.6325214141491997)




Linear CKA concern: 0.7234714785119939




Linear CKA non-concern: 0.6522172345360787




Kernel CKA concern: 0.693351433907621




Kernel CKA non-concern: 0.6956276911620957




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6183358934884438, 0.6183358934884438)




CCA coefficients mean non-concern: (0.6294162556790494, 0.6294162556790494)




Linear CKA concern: 0.6350033231928771




Linear CKA non-concern: 0.6523727498532184




Kernel CKA concern: 0.6076523155886231




Kernel CKA non-concern: 0.691737017332923




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6266041400991835, 0.6266041400991835)




CCA coefficients mean non-concern: (0.6300499908059852, 0.6300499908059852)




Linear CKA concern: 0.7018480207885625




Linear CKA non-concern: 0.653242098827822




Kernel CKA concern: 0.6906186285711677




Kernel CKA non-concern: 0.6889352594121919




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6256986486491161, 0.6256986486491161)




CCA coefficients mean non-concern: (0.629657537830721, 0.629657537830721)




Linear CKA concern: 0.7256779482906686




Linear CKA non-concern: 0.6573609278408437




Kernel CKA concern: 0.7068128099594496




Kernel CKA non-concern: 0.685712009000472




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6267264081280038, 0.6267264081280038)




CCA coefficients mean non-concern: (0.6298305038606331, 0.6298305038606331)




Linear CKA concern: 0.7498833663660164




Linear CKA non-concern: 0.6549518498808722




Kernel CKA concern: 0.7133464719334548




Kernel CKA non-concern: 0.6887467663438372




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6214821420970165, 0.6214821420970165)




CCA coefficients mean non-concern: (0.6317392936466857, 0.6317392936466857)




Linear CKA concern: 0.6913052030355149




Linear CKA non-concern: 0.6467544491610645




Kernel CKA concern: 0.6901823929772118




Kernel CKA non-concern: 0.6750199788704587




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6169169495181247, 0.6169169495181247)




CCA coefficients mean non-concern: (0.6289140996728038, 0.6289140996728038)




Linear CKA concern: 0.695399242611741




Linear CKA non-concern: 0.6510906397336795




Kernel CKA concern: 0.6597915030855149




Kernel CKA non-concern: 0.6902359329253667




In [11]:
# Show the pruned copy's sparsity: the displayed tuple is (overall value,
# mapping of parameter name -> per-parameter value).
sparsity_summary = get_sparsity(module)
sparsity_summary

(0.5944834517193173,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5999348958333334,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5989583333333334,
  'bert.en