In [1]:
import os
import sys

# Make the repository root (four directories up) importable so the `utils`
# package imported in the next cell can be found.
# NOTE(review): this path is resolved relative to the kernel's current working
# directory, not this notebook file — it breaks if the notebook is launched elsewhere.
sys.path.append("../../../../")
# Disable HuggingFace tokenizers' internal parallelism; it otherwise warns (and
# can deadlock) once DataLoader worker processes fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment configuration — every value a reader might want to tune.
name = "YahooAnswersTopics"  # key used by ModelConfig and load_data
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # NOTE: replaced by the value returned from load_model
batch_size = 16
num_workers = 4
num_samples = 16
wanda_ratio = 0.5  # sparsity_ratio passed to prune_wanda
seed = 44
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

# The helpers below receive `seed` explicitly, but any torch randomness they do
# not seed themselves would otherwise differ between runs; seed it globally too
# (torch.manual_seed also seeds all CUDA devices).
torch.manual_seed(seed)

In [4]:
# Record the wall-clock start of the run and announce it.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-21 13:23:34


In [5]:
# Resolve the configuration registered under `name` and load the model/tokenizer.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
# NOTE(review): this rebinds `checkpoint`, discarding the value set in the config cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build train/validation/test loaders; with do_cache=True load_data can reuse a
# previously cached copy of the dataset (this run printed "Loading cached dataset").
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Calibration samples drawn from the training split; these are what prune_wanda
# uses to score weights.
# NOTE(review): the positional arguments 200, False and 4 are unexplained here —
# confirm their meaning against SamplingDataset's signature and pass them by keyword.
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Prune a deep copy so that `model` keeps its original weights; the original and
# pruned models are compared against each other later in the notebook.
module = copy.deepcopy(model)
# prune_wanda's return value is unused, so it is expected to modify `module` in place.
prune_wanda(module, model_config, all_samples, sparsity_ratio=wanda_ratio, include_layers=include_layers,
            exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:49<…

Loss: 1.1111




Precision: 0.6661, Recall: 0.6502, F1-Score: 0.6506




              precision    recall  f1-score   support

           0       0.54      0.55      0.54      2972
           1       0.75      0.60      0.66      3016
           2       0.75      0.69      0.72      2985
           3       0.53      0.47      0.50      3023
           4       0.81      0.79      0.80      3039
           5       0.93      0.75      0.83      3076
           6       0.53      0.39      0.45      2965
           7       0.48      0.79      0.60      3031
           8       0.62      0.76      0.68      2932
           9       0.74      0.72      0.73      2961

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





In [10]:
# For each class label ("concern"), compare the representations of the original
# and pruned models; `similar` reports CCA and CKA scores for concern vs.
# non-concern samples.
for concern in range(num_labels):
    print(f"--{concern}--")
    # Fresh copy each iteration — presumably so `similar` cannot alter the shared
    # validation loader's state. NOTE(review): confirm this is required; deep-copying
    # a DataLoader that uses worker processes can be expensive.
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7641142137256084, 0.7641142137256084)




CCA coefficients mean non-concern: (0.7624275370523613, 0.7624275370523613)




Linear CKA concern: 0.8930833072928498




Linear CKA non-concern: 0.8845642136573368




Kernel CKA concern: 0.8223399935275532




Kernel CKA non-concern: 0.8423408021468561




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7562383829846545, 0.7562383829846545)




CCA coefficients mean non-concern: (0.7627860724863366, 0.7627860724863366)




Linear CKA concern: 0.8806745152175834




Linear CKA non-concern: 0.8808095289134658




Kernel CKA concern: 0.8207052071439697




Kernel CKA non-concern: 0.8364406787289212




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.750525549023552, 0.750525549023552)




CCA coefficients mean non-concern: (0.7626382523279888, 0.7626382523279888)




Linear CKA concern: 0.8971855150253051




Linear CKA non-concern: 0.8851524806832959




Kernel CKA concern: 0.8690285695230751




Kernel CKA non-concern: 0.8241037839298286




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7600826443442064, 0.7600826443442064)




CCA coefficients mean non-concern: (0.7621547845028693, 0.7621547845028693)




Linear CKA concern: 0.8629511565918652




Linear CKA non-concern: 0.8817067973155605




Kernel CKA concern: 0.8089949842213362




Kernel CKA non-concern: 0.8470616769153688




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.756459899470824, 0.756459899470824)




CCA coefficients mean non-concern: (0.7629447154762881, 0.7629447154762881)




Linear CKA concern: 0.8920695385259018




Linear CKA non-concern: 0.8874489258486952




Kernel CKA concern: 0.8456766215153482




Kernel CKA non-concern: 0.8448660309552093




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7538411145960198, 0.7538411145960198)




CCA coefficients mean non-concern: (0.7611984627565757, 0.7611984627565757)




Linear CKA concern: 0.9295490119665065




Linear CKA non-concern: 0.8812170517705011




Kernel CKA concern: 0.8958145190345647




Kernel CKA non-concern: 0.8270490172250053




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7627345553330626, 0.7627345553330626)




CCA coefficients mean non-concern: (0.7646991480881008, 0.7646991480881008)




Linear CKA concern: 0.8770618083918009




Linear CKA non-concern: 0.885179567317306




Kernel CKA concern: 0.762792818282675




Kernel CKA non-concern: 0.8482781932413312




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7574235826332436, 0.7574235826332436)




CCA coefficients mean non-concern: (0.7634728337417341, 0.7634728337417341)




Linear CKA concern: 0.903711013305006




Linear CKA non-concern: 0.885149378144105




Kernel CKA concern: 0.852872985605087




Kernel CKA non-concern: 0.8522538115888525




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7572783317101228, 0.7572783317101228)




CCA coefficients mean non-concern: (0.7622464046504989, 0.7622464046504989)




Linear CKA concern: 0.9068079979141839




Linear CKA non-concern: 0.878326606422166




Kernel CKA concern: 0.8685519705522178




Kernel CKA non-concern: 0.8402330595369355




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7498890977696999, 0.7498890977696999)




CCA coefficients mean non-concern: (0.7649247944220907, 0.7649247944220907)




Linear CKA concern: 0.8858908092506167




Linear CKA non-concern: 0.8814363352216226




Kernel CKA concern: 0.8435764512807781




Kernel CKA non-concern: 0.8477104059940819




In [11]:
# Overall sparsity of the pruned model, followed by a per-parameter breakdown.
get_sparsity(module)

(0.496021614307495,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5,
  'bert.encoder.layer.1.attention.self.key.bias': 0.0,
  'bert.encoder.layer.1.attention.self.value.weight': 0.5,
  'bert.encoder.l