In [1]:
import os
import sys

# Extend the import search path so the project-local `utils` package imported
# in the next cell can be resolved (presumably four directories up is the
# repository root — confirm against the project layout). The path is relative,
# so it is interpreted against the kernel's current working directory.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
if "../../../../" not in sys.path:
    sys.path.append("../../../../")

# Must be set before the tokenizer is loaded further down.
# NOTE(review): presumably silences the Hugging Face `tokenizers` parallelism
# warning when DataLoader worker processes are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# --- Experiment configuration: every value a reader might tune lives here ---

# Key handed to ModelConfig and load_data to select the model/dataset pair.
name = "YahooAnswersTopics"

# Prefer the first GPU, but fall back to the CPU when CUDA is not available so
# the notebook still runs (much more slowly) instead of failing at model load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder only: rebound to the value returned by load_model() below.
checkpoint = None

# Forwarded to load_data() when the dataloaders are built.
batch_size = 16
num_workers = 4

# Forwarded to SamplingDataset() and to similar().
num_samples = 16

# Passed to prune_wanda() as `sparsity_ratio`.
wanda_ratio = 0.3

# Shared seed forwarded to load_data(), SamplingDataset() and similar().
seed = 44

# Layer filters forwarded to prune_wanda().
# NOTE(review): presumably matched against parameter/module names — confirm.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Capture the wall-clock start of the run and echo it as a
# "YYYY-MM-DD HH:MM:SS" timestamp.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-21 11:05:25


In [5]:
# Build the configuration object for the selected model name and target device.
model_config = ModelConfig(name, device)
# Number of labels, read from the configuration dictionary.
num_labels = model_config.config["num_labels"]
# Load the model and its tokenizer. load_model also returns a checkpoint value,
# which replaces the `checkpoint = None` placeholder from the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the train / validation / test dataloaders for the selected dataset.
# NOTE(review): do_cache=True presumably reuses an on-disk copy of the dataset
# (the run log prints "Loading cached dataset") — confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Draw the sample set from the training dataloader; it is later handed to
# prune_wanda().
# NOTE(review): the positional literals 200, False and 4 are opaque from this
# notebook — confirm their meaning against SamplingDataset's signature and
# consider passing them by keyword.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [8]:
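# Baseline (disabled): uncomment the two lines below to evaluate the unpruned
# model on the test dataloader before pruning.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)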

In [9]:
# Work on a deep copy so the original `model` stays available for the
# similarity comparison performed in a later cell.
module = copy.deepcopy(model)

# Prune the copy with the configured ratio and layer filters.
# NOTE(review): the return value is discarded, so prune_wanda presumably
# modifies `module` in place — confirm.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

# Score the pruned copy on the test dataloader.
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:48<…

Loss: 1.0045




Precision: 0.6832, Recall: 0.6814, F1-Score: 0.6790




              precision    recall  f1-score   support

           0       0.56      0.55      0.56      2972
           1       0.74      0.65      0.69      3016
           2       0.71      0.77      0.74      2985
           3       0.53      0.52      0.53      3023
           4       0.81      0.82      0.81      3039
           5       0.91      0.82      0.86      3076
           6       0.58      0.42      0.49      2965
           7       0.59      0.74      0.66      3031
           8       0.64      0.76      0.70      2932
           9       0.74      0.75      0.75      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





In [10]:
# Compare the original model with the pruned module once per label index.
for concern in range(num_labels):
    print(f"--{concern}--")
    # Each iteration gets its own deep copy of the validation dataloader.
    # NOTE(review): presumably so that similar() cannot affect the shared
    # loader between concerns — confirm.
    valid = copy.deepcopy(valid_dataloader)
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9082890058343849, 0.9082890058343849)




CCA coefficients mean non-concern: (0.9085018479383714, 0.9085018479383714)




Linear CKA concern: 0.9873974167333879




Linear CKA non-concern: 0.9867177822305584




Kernel CKA concern: 0.9779174441135444




Kernel CKA non-concern: 0.9811148572934952




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9073651463809838, 0.9073651463809838)




CCA coefficients mean non-concern: (0.9086120367435762, 0.9086120367435762)




Linear CKA concern: 0.9874925645872052




Linear CKA non-concern: 0.9869270636444352




Kernel CKA concern: 0.9784242907605992




Kernel CKA non-concern: 0.9810789250699902




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9017747413279408, 0.9017747413279408)




CCA coefficients mean non-concern: (0.9088750289341786, 0.9088750289341786)




Linear CKA concern: 0.9922656406841022




Linear CKA non-concern: 0.9869002694984098




Kernel CKA concern: 0.9866546045441437




Kernel CKA non-concern: 0.9792782968441538




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9054085246967546, 0.9054085246967546)




CCA coefficients mean non-concern: (0.9086053546879873, 0.9086053546879873)




Linear CKA concern: 0.9863127939232614




Linear CKA non-concern: 0.9870008821627867




Kernel CKA concern: 0.9772523307044756




Kernel CKA non-concern: 0.9818927440176797




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.909362035708767, 0.909362035708767)




CCA coefficients mean non-concern: (0.9080609242794107, 0.9080609242794107)




Linear CKA concern: 0.9880991865978465




Linear CKA non-concern: 0.9871826186017553




Kernel CKA concern: 0.9795717925129672




Kernel CKA non-concern: 0.9811791784976988




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8994616085964775, 0.8994616085964775)




CCA coefficients mean non-concern: (0.9077006401085795, 0.9077006401085795)




Linear CKA concern: 0.991549568453489




Linear CKA non-concern: 0.9868864207351687




Kernel CKA concern: 0.985392905994908




Kernel CKA non-concern: 0.9802390826222795




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9062210281578303, 0.9062210281578303)




CCA coefficients mean non-concern: (0.9104617161686808, 0.9104617161686808)




Linear CKA concern: 0.9862373871607858




Linear CKA non-concern: 0.9868955723103916




Kernel CKA concern: 0.9728491457971062




Kernel CKA non-concern: 0.9815900461034053




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9015530345701761, 0.9015530345701761)




CCA coefficients mean non-concern: (0.9107058428707053, 0.9107058428707053)




Linear CKA concern: 0.989298821426235




Linear CKA non-concern: 0.9869799000708331




Kernel CKA concern: 0.9823407444487018




Kernel CKA non-concern: 0.9819754466940485




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9033207747543198, 0.9033207747543198)




CCA coefficients mean non-concern: (0.9090290603902158, 0.9090290603902158)




Linear CKA concern: 0.9884762366409864




Linear CKA non-concern: 0.9867763714874253




Kernel CKA concern: 0.9814450804221491




Kernel CKA non-concern: 0.9813998633769082




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.902645296654479, 0.902645296654479)




CCA coefficients mean non-concern: (0.9103298348090592, 0.9103298348090592)




Linear CKA concern: 0.9858267427915623




Linear CKA non-concern: 0.9869623160246519




Kernel CKA concern: 0.9770875225082783




Kernel CKA non-concern: 0.9816853887827743




In [11]:
# Report the sparsity of the pruned module. As the cell output shows, the call
# returns a tuple: an overall ratio followed by a dict mapping each parameter
# name to its own ratio. Left as the last expression so the value is displayed.
get_sparsity(module)

(0.2972039229824205,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.2994791666666667,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.2994791666666667,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.2998046875,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.2994791666666667,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.2994791666666667,
  'bert.encoder.