In [1]:
import os
import sys

# Make the project root importable so the `utils.*` packages used below resolve.
# The path is relative, so it is interpreted against the kernel's working directory.
PROJECT_ROOT = "../../../../"
# Guard the append so re-running this cell does not keep adding duplicate entries.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
# Must be set before the tokenizer is used; forced to "false" on every run.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment configuration: the values later cells read.
name = "YahooAnswersTopics"  # key passed to ModelConfig and load_data
# Prefer the first GPU, but fall back to CPU so the notebook can still run on
# machines where CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by the load_model() call further down
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # forwarded to SamplingDataset and similar()
wanda_ratio = 0.4  # passed to prune_wanda as sparsity_ratio
seed = 44  # shared seed for data loading, sampling and the similarity check
include_layers = ["attention", "intermediate", "output"]  # passed to prune_wanda
exclude_layers = None  # passed to prune_wanda

In [4]:
# Record when the run began and echo it in "YYYY-MM-DD HH:MM:SS" form.
script_start_time = datetime.now()
# A datetime format spec inside the f-string is equivalent to calling strftime().
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-21 12:13:06


In [5]:
# Build the model configuration object for `name`, bound to `device`.
model_config = ModelConfig(name, device)
# Number of target classes, read from the configuration's `config` mapping;
# reused below for sampling and for the per-class similarity loop.
num_labels = model_config.config["num_labels"]
# NOTE: this rebinds `checkpoint`, replacing the None placeholder assigned in
# the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the three data loaders for the dataset identified by `name`.
# The same `seed` is passed to the sampling and similarity steps below.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Draw the sample set from the training loader that prune_wanda consumes below.
# NOTE(review): the positional literals 200, False and 4 are not explained in
# this notebook — confirm their meaning against SamplingDataset's signature and
# consider passing them by keyword.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [8]:
# Baseline evaluation of the unpruned model; currently disabled (commented out).
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Work on a deep copy so `model` stays available, unmodified, for the
# model-vs-module comparison performed later.
module = copy.deepcopy(model)
# The return value is not used; presumably `module` is pruned in place — confirm.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)
print("Evaluate the pruned model")
# Score the pruned copy on the test loader.
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:50<…

Loss: 1.0236




Precision: 0.6766, Recall: 0.6724, F1-Score: 0.6699




              precision    recall  f1-score   support

           0       0.56      0.55      0.55      2972
           1       0.74      0.63      0.68      3016
           2       0.71      0.76      0.74      2985
           3       0.53      0.50      0.52      3023
           4       0.81      0.81      0.81      3039
           5       0.92      0.81      0.86      3076
           6       0.57      0.40      0.47      2965
           7       0.56      0.76      0.64      3031
           8       0.63      0.76      0.69      2932
           9       0.73      0.74      0.74      2961

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





In [10]:
# For every class label ("concern"), compare the original `model` with the
# pruned `module`; similar() prints CCA and linear/kernel CKA scores for the
# concern and non-concern cases.
for concern in range(num_labels):
    print(f"--{concern}--")
    # Each label gets its own deep copy of the validation loader, so the
    # shared `valid_dataloader` object itself is never handed to similar().
    valid = copy.deepcopy(valid_dataloader)
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8299086027492952, 0.8299086027492952)




CCA coefficients mean non-concern: (0.8338302370981038, 0.8338302370981038)




Linear CKA concern: 0.9646284058391678




Linear CKA non-concern: 0.9626927774720758




Kernel CKA concern: 0.9384434753710436




Kernel CKA non-concern: 0.9470408411814522




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8269942660406296, 0.8269942660406296)




CCA coefficients mean non-concern: (0.8329809290081603, 0.8329809290081603)




Linear CKA concern: 0.9653907119344363




Linear CKA non-concern: 0.9621128054385024




Kernel CKA concern: 0.9416935683743134




Kernel CKA non-concern: 0.9460453396822337




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8208185574992831, 0.8208185574992831)




CCA coefficients mean non-concern: (0.8323960970831258, 0.8323960970831258)




Linear CKA concern: 0.9735889337913666




Linear CKA non-concern: 0.9631065355590067




Kernel CKA concern: 0.9578931457100227




Kernel CKA non-concern: 0.9424046082432055




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8277867695695487, 0.8277867695695487)




CCA coefficients mean non-concern: (0.8324888083537998, 0.8324888083537998)




Linear CKA concern: 0.9579459460328154




Linear CKA non-concern: 0.9617308880021063




Kernel CKA concern: 0.9343286878762178




Kernel CKA non-concern: 0.9478417172715295




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8295830968399089, 0.8295830968399089)




CCA coefficients mean non-concern: (0.8332953360862955, 0.8332953360862955)




Linear CKA concern: 0.9682919265682798




Linear CKA non-concern: 0.962774781388758




Kernel CKA concern: 0.9471562189098162




Kernel CKA non-concern: 0.9457874569367295




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8196632795173803, 0.8196632795173803)




CCA coefficients mean non-concern: (0.8314426025851602, 0.8314426025851602)




Linear CKA concern: 0.9778698190448047




Linear CKA non-concern: 0.9614456977505201




Kernel CKA concern: 0.9619997387398661




Kernel CKA non-concern: 0.9430038498443786




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8312997905423316, 0.8312997905423316)




CCA coefficients mean non-concern: (0.8335478981609974, 0.8335478981609974)




Linear CKA concern: 0.9568573118297451




Linear CKA non-concern: 0.9630692076628307




Kernel CKA concern: 0.9141231037754686




Kernel CKA non-concern: 0.9487247200577685




--7--




adding eps to diagonal and taking inverse

In [None]:
# Inspect the sparsity of the pruned `module`. As the cell's last expression,
# any value returned by the call is displayed as the cell output.
get_sparsity(module)