In [1]:
import os
import sys
sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude
)

In [3]:
name= "YahooAnswersTopics"
device = torch.device("cuda:0")
checkpoint = None
batch_size=16
num_workers=4
num_samples=16
magnitude_ratio=0.3
seed=44
include_layers=["attention", "intermediate", "output"]
exclude_layers=None

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-08-21 19:36:30


In [5]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
module = copy.deepcopy(model)
prune_magnitude(module, sparsity_ratio=magnitude_ratio, include_layers=include_layers, exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(model, model_config, test_dataloader)
# save_module(module, "Modules/", f"magnitude_{name}_{magnitude_ratio}p.pt")

Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:25<…

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Loss: 1.0014




Precision: 0.6875, Recall: 0.6865, F1-Score: 0.6838




              precision    recall  f1-score   support

           0       0.57      0.55      0.56      2972
           1       0.74      0.67      0.70      3016
           2       0.71      0.78      0.74      2985
           3       0.54      0.53      0.53      3023
           4       0.81      0.82      0.82      3039
           5       0.90      0.84      0.87      3076
           6       0.60      0.43      0.50      2965
           7       0.62      0.74      0.67      3031
           8       0.63      0.76      0.69      2932
           9       0.75      0.75      0.75      2961

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





In [10]:
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.828924435388757, 0.828924435388757)




CCA coefficients mean non-concern: (0.8373104745093987, 0.8373104745093987)




Linear CKA concern: 0.9557404129039558




Linear CKA non-concern: 0.9560405225572791




Kernel CKA concern: 0.9280090922397243




Kernel CKA non-concern: 0.9413872826315083




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.830357039754326, 0.830357039754326)




CCA coefficients mean non-concern: (0.8363449166404511, 0.8363449166404511)




Linear CKA concern: 0.9601061860160381




Linear CKA non-concern: 0.9558649425506218




Kernel CKA concern: 0.9343310941761948




Kernel CKA non-concern: 0.940192054859092




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8283118979673819, 0.8283118979673819)




CCA coefficients mean non-concern: (0.8343003910713932, 0.8343003910713932)




Linear CKA concern: 0.9639001081422921




Linear CKA non-concern: 0.9550709770373105




Kernel CKA concern: 0.9446263814632723




Kernel CKA non-concern: 0.934009245156313




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8299144430023583, 0.8299144430023583)




CCA coefficients mean non-concern: (0.8377393356276445, 0.8377393356276445)




Linear CKA concern: 0.9607611711481588




Linear CKA non-concern: 0.9563576969322096




Kernel CKA concern: 0.9376920731476355




Kernel CKA non-concern: 0.9431789160391022




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8298584542894335, 0.8298584542894335)




CCA coefficients mean non-concern: (0.8355962741268612, 0.8355962741268612)




Linear CKA concern: 0.9641704723910135




Linear CKA non-concern: 0.9557753860165835




Kernel CKA concern: 0.9409174838593223




Kernel CKA non-concern: 0.9396883592686857




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8308264752721318, 0.8308264752721318)




CCA coefficients mean non-concern: (0.8318297137625015, 0.8318297137625015)




Linear CKA concern: 0.9682543157189152




Linear CKA non-concern: 0.9552651383985432




Kernel CKA concern: 0.9482607178260063




Kernel CKA non-concern: 0.9382523322720756




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.835510443597024, 0.835510443597024)




CCA coefficients mean non-concern: (0.8360096027129184, 0.8360096027129184)




Linear CKA concern: 0.9553448242480703




Linear CKA non-concern: 0.9551150609026475




Kernel CKA concern: 0.9176703083884029




Kernel CKA non-concern: 0.9416628819251532




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8287242710839295, 0.8287242710839295)




CCA coefficients mean non-concern: (0.8383956107246566, 0.8383956107246566)




Linear CKA concern: 0.964474027034688




Linear CKA non-concern: 0.9570617067348958




Kernel CKA concern: 0.9435208137843148




Kernel CKA non-concern: 0.9438120987342815




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8216147324417091, 0.8216147324417091)




CCA coefficients mean non-concern: (0.8374581391067525, 0.8374581391067525)




Linear CKA concern: 0.9575634951436686




Linear CKA non-concern: 0.9561694665771445




Kernel CKA concern: 0.9363818937715511




Kernel CKA non-concern: 0.9420536418270591




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8332330564769816, 0.8332330564769816)




CCA coefficients mean non-concern: (0.8373251023387003, 0.8373251023387003)




Linear CKA concern: 0.9612389809631996




Linear CKA non-concern: 0.955293676650583




Kernel CKA concern: 0.9401257685139446




Kernel CKA non-concern: 0.94171835981785




In [11]:
get_sparsity(module)

(0.29761262051823517,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.2999996609157986,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.2999996609157986,
  'bert.e