In [1]:
import os
import sys

# Make the project-local `utils` package importable. The path is relative to
# the kernel's working directory (presumably it points at the repository
# root - confirm if the notebook is moved).
# Guarded so that re-running this cell does not stack duplicate entries on
# sys.path.
PROJECT_ROOT = "../../../../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Set before any tokenizer is loaded. NOTE(review): this is presumably here to
# silence the Hugging Face tokenizers fork/parallelism warning when DataLoader
# workers are used - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# --- Experiment configuration: everything a reader might tune ---------------

# Key handed to ModelConfig and load_data to select the model and dataset.
name = "YahooAnswersTopics"

# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs on a machine without a usable GPU instead of failing at model load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder; rebound by the value load_model() returns.
checkpoint = None

# Data loading.
batch_size = 16
num_workers = 4

# Passed to SamplingDataset and to similar().
num_samples = 16

# Passed to prune_wanda as sparsity_ratio.
wanda_ratio = 0.6

# Shared seed handed to load_data, SamplingDataset and similar().
seed = 44

# Layer-name filters forwarded to prune_wanda.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Capture the wall-clock start of the run and echo it into the cell output.
script_start_time = datetime.now()
# A datetime format spec inside an f-string is rendered via strftime, so the
# printed text is identical to calling .strftime() explicitly.
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-21 14:34:59


In [5]:
# Build the configuration object for the selected model/dataset key and device.
model_config = ModelConfig(name, device)
# Label count read from the config mapping (the printed config shows 10 for
# this dataset).
num_labels = model_config.config["num_labels"]
# load_model returns three values: the model, its tokenizer and a checkpoint
# value. This rebinds the `checkpoint` placeholder from the configuration cell.
# NOTE(review): what `checkpoint` holds is not visible from this notebook.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Fetch the three dataset splits for the selected dataset key.
# NOTE(review): do_cache=True presumably reuses an on-disk copy (the cell
# output reports "Loading cached dataset") - confirm against load_data.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Build the sample set drawn from the training split; it is the data that
# prune_wanda receives in the pruning cell.
# NOTE(review): the positional literals 200, False and 4 are unnamed magic
# values - confirm their meaning against SamplingDataset's signature and
# consider passing them by keyword or lifting them into the configuration cell.
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Work on a deep copy so the original `model` stays available for the
# similarity comparison against the pruned `module`.
module = copy.deepcopy(model)

# The return value is not used; the pruned weights are then read from `module`
# itself (presumably prune_wanda modifies it in place - confirm).
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:52<…

Loss: 1.5111




Precision: 0.6468, Recall: 0.5525, F1-Score: 0.5629




              precision    recall  f1-score   support

           0       0.47      0.52      0.49      2972
           1       0.82      0.37      0.51      3016
           2       0.82      0.46      0.59      2985
           3       0.51      0.33      0.40      3023
           4       0.77      0.73      0.75      3039
           5       0.96      0.62      0.75      3076
           6       0.47      0.35      0.40      2965
           7       0.30      0.87      0.44      3031
           8       0.58      0.72      0.64      2932
           9       0.79      0.55      0.65      2961

    accuracy                           0.55     30000
   macro avg       0.65      0.55      0.56     30000
weighted avg       0.65      0.55      0.56     30000





In [10]:
# Compare the original `model` with the pruned `module` once per label index
# ("concern"). The cell output shows CCA and CKA scores printed for the concern
# and non-concern cases on each pass.
for concern in range(num_labels):
    print(f"--{concern}--")
    # Each pass hands `similar` its own deep copy of the validation loader.
    # NOTE(review): presumably this isolates passes from one another - confirm
    # whether `similar` actually mutates the loader; if not, the copy is
    # avoidable overhead.
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7045616515774349, 0.7045616515774349)




CCA coefficients mean non-concern: (0.7021380336521397, 0.7021380336521397)




Linear CKA concern: 0.7308721575368944




Linear CKA non-concern: 0.6969461797565524




Kernel CKA concern: 0.5300520854332607




Kernel CKA non-concern: 0.5351014546304967




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6964180770746912, 0.6964180770746912)




CCA coefficients mean non-concern: (0.7023380078935889, 0.7023380078935889)




Linear CKA concern: 0.5706694336385603




Linear CKA non-concern: 0.7032264007853558




Kernel CKA concern: 0.3733426463016561




Kernel CKA non-concern: 0.5419256627901735




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6921004283792584, 0.6921004283792584)




CCA coefficients mean non-concern: (0.7027614696380072, 0.7027614696380072)




Linear CKA concern: 0.5420200124153493




Linear CKA non-concern: 0.721271408530677




Kernel CKA concern: 0.49258254650627936




Kernel CKA non-concern: 0.5207452707681948




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.698953119199126, 0.698953119199126)




CCA coefficients mean non-concern: (0.7017845204445379, 0.7017845204445379)




Linear CKA concern: 0.626093836776931




Linear CKA non-concern: 0.6895027116504678




Kernel CKA concern: 0.4525777960117719




Kernel CKA non-concern: 0.540460480028849




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6965404302765912, 0.6965404302765912)




CCA coefficients mean non-concern: (0.7035785820054458, 0.7035785820054458)




Linear CKA concern: 0.6104991385392214




Linear CKA non-concern: 0.7062719937198025




Kernel CKA concern: 0.5214165774927736




Kernel CKA non-concern: 0.5472689770420074




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6961793982540715, 0.6961793982540715)




CCA coefficients mean non-concern: (0.7019133451328837, 0.7019133451328837)




Linear CKA concern: 0.7508853072417139




Linear CKA non-concern: 0.70477127120757




Kernel CKA concern: 0.6969386764772991




Kernel CKA non-concern: 0.5130896765418902




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7052811224293565, 0.7052811224293565)




CCA coefficients mean non-concern: (0.7041483356017768, 0.7041483356017768)




Linear CKA concern: 0.7101709061918275




Linear CKA non-concern: 0.6967104332525175




Kernel CKA concern: 0.4309276882780908




Kernel CKA non-concern: 0.5474392513145705




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7005671760279174, 0.7005671760279174)




CCA coefficients mean non-concern: (0.7026471481729617, 0.7026471481729617)




Linear CKA concern: 0.7179790162089168




Linear CKA non-concern: 0.6818336500482038




Kernel CKA concern: 0.5813292815996467




Kernel CKA non-concern: 0.5368731186793976




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6999727605744192, 0.6999727605744192)




CCA coefficients mean non-concern: (0.7009898693773929, 0.7009898693773929)




Linear CKA concern: 0.7006283672120215




Linear CKA non-concern: 0.6815713652832683




Kernel CKA concern: 0.6000195275367434




Kernel CKA non-concern: 0.5247489308239931




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6865474511646446, 0.6865474511646446)




CCA coefficients mean non-concern: (0.702956084345054, 0.702956084345054)




Linear CKA concern: 0.6157588429266698




Linear CKA non-concern: 0.6790982019288924




Kernel CKA concern: 0.5050566382979226




Kernel CKA non-concern: 0.53059663337478




In [11]:
# Show the sparsity report for the pruned module as the cell's output. The
# displayed value is a tuple: a single float followed by a mapping from
# parameter name to a per-parameter figure (presumably the overall sparsity
# and the per-tensor sparsities - confirm against get_sparsity).
# NOTE(review): the full mapping is long; consider displaying only the first
# element when sharing the notebook.
get_sparsity(module)

(0.5945154895443348,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5999348958333334,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5989583333333334,
  'bert.en