In [1]:
import os
import sys
sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude
)

In [3]:
name= "YahooAnswersTopics"
device = torch.device("cuda:0")
checkpoint = None
batch_size=16
num_workers=4
num_samples=16
magnitude_ratio=0.5
seed=44
include_layers=["attention", "intermediate", "output"]
exclude_layers=None

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-08-21 21:08:39


In [5]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
module = copy.deepcopy(model)
prune_magnitude(module, sparsity_ratio=magnitude_ratio, include_layers=include_layers, exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(model, model_config, test_dataloader)
# save_module(module, "Modules/", f"magnitude_{name}_{magnitude_ratio}p.pt")

Evaluate the pruned model




Evaluating:   0%|                                                                             | 0/1875 [00:25<…

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Loss: 1.0014




Precision: 0.6875, Recall: 0.6865, F1-Score: 0.6838




              precision    recall  f1-score   support

           0       0.57      0.55      0.56      2972
           1       0.74      0.67      0.70      3016
           2       0.71      0.78      0.74      2985
           3       0.54      0.53      0.53      3023
           4       0.81      0.82      0.82      3039
           5       0.90      0.84      0.87      3076
           6       0.60      0.43      0.50      2965
           7       0.62      0.74      0.67      3031
           8       0.63      0.76      0.69      2932
           9       0.75      0.75      0.75      2961

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





In [10]:
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6859985612616648, 0.6859985612616648)




CCA coefficients mean non-concern: (0.6916821379235377, 0.6916821379235377)




Linear CKA concern: 0.819021829306645




Linear CKA non-concern: 0.826003127585358




Kernel CKA concern: 0.7092300075467171




Kernel CKA non-concern: 0.7626592096636872




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6849073994377715, 0.6849073994377715)




CCA coefficients mean non-concern: (0.6931575948990766, 0.6931575948990766)




Linear CKA concern: 0.8396511590409833




Linear CKA non-concern: 0.8249050354427412




Kernel CKA concern: 0.7345432658668438




Kernel CKA non-concern: 0.7582447772777179




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6759830827333598, 0.6759830827333598)




CCA coefficients mean non-concern: (0.692207143248664, 0.692207143248664)




Linear CKA concern: 0.8138335447955515




Linear CKA non-concern: 0.8280491637310524




Kernel CKA concern: 0.7349511254875309




Kernel CKA non-concern: 0.7467722527104474




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6844404753834676, 0.6844404753834676)




CCA coefficients mean non-concern: (0.6920461377538463, 0.6920461377538463)




Linear CKA concern: 0.8376509129684216




Linear CKA non-concern: 0.8260316198386168




Kernel CKA concern: 0.7502654031544675




Kernel CKA non-concern: 0.770980837822248




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6802236348506058, 0.6802236348506058)




CCA coefficients mean non-concern: (0.6926045966634268, 0.6926045966634268)




Linear CKA concern: 0.8193453110616616




Linear CKA non-concern: 0.8294432990127144




Kernel CKA concern: 0.7204207462013779




Kernel CKA non-concern: 0.7644229919050942




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6793252703151496, 0.6793252703151496)




CCA coefficients mean non-concern: (0.6899138069013551, 0.6899138069013551)




Linear CKA concern: 0.8606415636826593




Linear CKA non-concern: 0.8235537648969005




Kernel CKA concern: 0.7728515918608435




Kernel CKA non-concern: 0.7491634655569166




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6953725434682048, 0.6953725434682048)




CCA coefficients mean non-concern: (0.6927756519122718, 0.6927756519122718)




Linear CKA concern: 0.8084944858056227




Linear CKA non-concern: 0.8256753580864502




Kernel CKA concern: 0.6420609354502717




Kernel CKA non-concern: 0.7692256303038869




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6890831684449795, 0.6890831684449795)




CCA coefficients mean non-concern: (0.6924652819920655, 0.6924652819920655)




Linear CKA concern: 0.8416188032816265




Linear CKA non-concern: 0.8277013235417742




Kernel CKA concern: 0.7511528640446645




Kernel CKA non-concern: 0.7701188126912274




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6831892946314326, 0.6831892946314326)




CCA coefficients mean non-concern: (0.692151933242562, 0.692151933242562)




Linear CKA concern: 0.8405439267359528




Linear CKA non-concern: 0.8248801577643313




Kernel CKA concern: 0.7442006476985709




Kernel CKA non-concern: 0.7632856003837674




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6749258466413084, 0.6749258466413084)




CCA coefficients mean non-concern: (0.693816045976728, 0.693816045976728)




Linear CKA concern: 0.8478078894466059




Linear CKA non-concern: 0.831070125892685




Kernel CKA concern: 0.7570961709361909




Kernel CKA non-concern: 0.7729852045274318




In [11]:
get_sparsity(module)

(0.496021614307495,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5,
  'bert.encoder.layer.1.attention.self.key.bias': 0.0,
  'bert.encoder.layer.1.attention.self.value.weight': 0.5,
  'bert.encoder.l