In [1]:
import os
import sys

sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude
)

In [3]:
name = "YahooAnswersTopics"
device = torch.device("cuda:0")
checkpoint = None
batch_size = 16
num_workers = 4
num_samples = 16
magnitude_ratio = 0.6
seed = 44
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-08-21 21:53:57


In [5]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
module = copy.deepcopy(model)
prune_magnitude(module, sparsity_ratio=magnitude_ratio, include_layers=include_layers, exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(model, model_config, test_dataloader)
# save_module(module, "Modules/", f"magnitude_{name}_{magnitude_ratio}p.pt")

Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:25<…

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Loss: 1.0014




Precision: 0.6875, Recall: 0.6865, F1-Score: 0.6838




              precision    recall  f1-score   support

           0       0.57      0.55      0.56      2972
           1       0.74      0.67      0.70      3016
           2       0.71      0.78      0.74      2985
           3       0.54      0.53      0.53      3023
           4       0.81      0.82      0.82      3039
           5       0.90      0.84      0.87      3076
           6       0.60      0.43      0.50      2965
           7       0.62      0.74      0.67      3031
           8       0.63      0.76      0.69      2932
           9       0.75      0.75      0.75      2961

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





In [10]:
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6320496708370995, 0.6320496708370995)




CCA coefficients mean non-concern: (0.6342362728577334, 0.6342362728577334)




Linear CKA concern: 0.6979935714859637




Linear CKA non-concern: 0.6523790384746131




Kernel CKA concern: 0.49871462911672404




Kernel CKA non-concern: 0.503570410938894




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6315616644112375, 0.6315616644112375)




CCA coefficients mean non-concern: (0.6342404845105718, 0.6342404845105718)




Linear CKA concern: 0.6091402052711687




Linear CKA non-concern: 0.6651753952892827




Kernel CKA concern: 0.4255892815452718




Kernel CKA non-concern: 0.5049692635377862




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6215537732419549, 0.6215537732419549)




CCA coefficients mean non-concern: (0.6329442895535833, 0.6329442895535833)




Linear CKA concern: 0.4860901195376556




Linear CKA non-concern: 0.6921681828590261




Kernel CKA concern: 0.36933307627677253




Kernel CKA non-concern: 0.5264626471308201




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6301605061501799, 0.6301605061501799)




CCA coefficients mean non-concern: (0.6332583245089879, 0.6332583245089879)




Linear CKA concern: 0.6463829812164652




Linear CKA non-concern: 0.6512686777838854




Kernel CKA concern: 0.45415737257713207




Kernel CKA non-concern: 0.518266853206734




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6223317118276445, 0.6223317118276445)




CCA coefficients mean non-concern: (0.6356566754615601, 0.6356566754615601)




Linear CKA concern: 0.5889520213675158




Linear CKA non-concern: 0.6732261287771146




Kernel CKA concern: 0.43953932529670037




Kernel CKA non-concern: 0.5228691560066707




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6203759511500948, 0.6203759511500948)




CCA coefficients mean non-concern: (0.6322953842957673, 0.6322953842957673)




Linear CKA concern: 0.7140668716381098




Linear CKA non-concern: 0.6592511304032497




Kernel CKA concern: 0.5906797785188532




Kernel CKA non-concern: 0.48278853846761877




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.641532674649352, 0.641532674649352)




CCA coefficients mean non-concern: (0.6331294146659255, 0.6331294146659255)




Linear CKA concern: 0.681205437897264




Linear CKA non-concern: 0.6589987451564431




Kernel CKA concern: 0.40986566140875996




Kernel CKA non-concern: 0.521703351504095




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6324005264306142, 0.6324005264306142)




CCA coefficients mean non-concern: (0.6337340780048992, 0.6337340780048992)




Linear CKA concern: 0.731311764498252




Linear CKA non-concern: 0.6561984953126722




Kernel CKA concern: 0.5764149425001889




Kernel CKA non-concern: 0.5195170416422097




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6244909151332085, 0.6244909151332085)




CCA coefficients mean non-concern: (0.6325358107909346, 0.6325358107909346)




Linear CKA concern: 0.6958981048571966




Linear CKA non-concern: 0.6541739185335494




Kernel CKA concern: 0.5402696621252097




Kernel CKA non-concern: 0.5058478073502244




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6195998216452386, 0.6195998216452386)




CCA coefficients mean non-concern: (0.6351487530992227, 0.6351487530992227)




Linear CKA concern: 0.65700833657672




Linear CKA non-concern: 0.654959432985232




Kernel CKA concern: 0.5103348152522571




Kernel CKA non-concern: 0.5186907025800185




In [11]:
get_sparsity(module)

(0.595225509678216,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5999993218315972,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5999976264105903,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5999993218315972,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5999993218315972,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.599999745686849,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.599999745686849,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5999993218315972,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5999993218315972,
  'bert.encod