In [1]:
import os
import sys

# Extend the import search path four directory levels up from the working
# directory, presumably so the project-local `utils` package imported in the
# next cell resolves.
# NOTE(review): the path is relative to the kernel's working directory —
# confirm the notebook is always launched from its own folder.
sys.path += ["../../../../"]

# Set before any tokenizer library is imported further down.
# NOTE(review): presumably silences the tokenizers parallelism/fork warning
# when DataLoader workers are used — confirm.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# --- Experiment configuration -------------------------------------------------
# Everything a reader might want to tune lives in this cell.

# Dataset/model key; passed to ModelConfig(...) and load_data(...) below.
name = "YahooAnswersTopics"

# Preferred accelerator is the second GPU ("cuda:1"), as in the original run.
# Fall back to the first GPU, then to the CPU, so the notebook still executes
# on hosts that do not expose two CUDA devices instead of failing at model load.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Placeholder only: rebound by the `load_model(model_config)` unpacking below.
checkpoint = None

# Data loading parameters (forwarded to load_data).
batch_size = 32
num_workers = 48

# Forwarded to SamplingDataset and to `similar` below.
num_samples = 16

# Initial value only: rebound by the `for concern in range(num_labels)` loop.
concern = 0

# Passed to prune_wanda as `sparsity_ratio`.
wanda_ratio = 0.6

# Forwarded to load_data and SamplingDataset as `seed`.
seed = 44

# Layer-name filters forwarded to prune_wanda.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [None]:
# Capture the wall-clock start of the run and echo it in a readable format.
script_start_time = datetime.now()
print(
    f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}"
)

In [4]:
# Build the model configuration for dataset key `name` on `device`, read its
# `num_labels` entry, then load the model and tokenizer described by it.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
# load_model returns three values; the third rebinds the `checkpoint` name that
# the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}
The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.


In [5]:
# Build the three data loaders for the configured dataset. The names follow the
# order in which load_data's return value is unpacked here.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}
Loading cached dataset YahooAnswersTopics.
The dataset YahooAnswersTopics is loaded


In [6]:
# Samples drawn from the training loader; handed to prune_wanda further down.
# NOTE(review): the second argument sits in the position where the per-label
# loop below passes `concern`. 200 is outside range(num_labels), so together
# with the False flag this presumably selects examples from every class —
# confirm against SamplingDataset. The meaning of the literal 4 is not visible
# in this notebook — TODO confirm.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [7]:
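# Baseline evaluation of the unpruned model, left disabled; the log kept in the
# next cell shows that one full pass over the test set took about 30 minutes.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)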

In [8]:
# Saved log of an earlier baseline run on the unpruned model, kept for comparison:
# Evaluate the original model
# Evaluating: 100%|███████████████████████████████████████████████████████████████████| 1875/1875 [30:03<00:00,  1.04it/s]
# Loss: 1.0044
# Precision: 0.6874, Recall: 0.6865, F1-Score: 0.6839
#               precision    recall  f1-score   support

#            0       0.57      0.57      0.57      6000
#            1       0.74      0.66      0.69      6000
#            2       0.71      0.78      0.74      6000
#            3       0.54      0.53      0.53      6000
#            4       0.80      0.82      0.81      6000
#            5       0.90      0.84      0.87      6000
#            6       0.61      0.43      0.50      6000
#            7       0.62      0.73      0.67      6000
#            8       0.64      0.76      0.70      6000
#            9       0.75      0.75      0.75      6000

#     accuracy                           0.69     60000
#    macro avg       0.69      0.69      0.68     60000
# weighted avg       0.69      0.69      0.68     60000

In [9]:
# Work on a deep copy so the original `model` stays available for the
# original-vs-pruned comparison in the following cell.
module = copy.deepcopy(model)

# Apply Wanda pruning to the copy using the samples collected above.
# NOTE(review): the return value is discarded, so prune_wanda presumably
# modifies `module` in place — confirm.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

# Score the pruned copy on the test split.
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

Evaluate the pruned model


Evaluating: 100%|██████████| 1875/1875 [29:09<00:00,  1.07it/s]


Loss: 1.5007
Precision: 0.6495, Recall: 0.5565, F1-Score: 0.5658
              precision    recall  f1-score   support

           0       0.48      0.53      0.50      6000
           1       0.81      0.38      0.51      6000
           2       0.83      0.43      0.57      6000
           3       0.51      0.36      0.42      6000
           4       0.76      0.76      0.76      6000
           5       0.95      0.60      0.74      6000
           6       0.48      0.34      0.40      6000
           7       0.30      0.85      0.44      6000
           8       0.58      0.74      0.65      6000
           9       0.79      0.58      0.67      6000

    accuracy                           0.56     60000
   macro avg       0.65      0.56      0.57     60000
weighted avg       0.65      0.56      0.57     60000



In [10]:
# For every label index, compare the original model with the pruned module via
# `similar` on the validation loader.
# NOTE(review): `positive_samples` and `negative_samples` are rebuilt on each
# iteration but are not passed to `similar` (it receives `valid_dataloader`);
# confirm whether they are still needed. The loop variable also rebinds the
# `concern` value set in the configuration cell.
for concern in range(num_labels):
    print(f"--{concern}--")

    # Same arguments as below except for the True flag.
    positive_samples = SamplingDataset(
        train_dataloader,
        concern,
        num_samples,
        num_labels,
        True,
        4,
        device=device,
        resample=False,
        seed=seed,
    )

    # Counterpart draw with the flag set to False.
    negative_samples = SamplingDataset(
        train_dataloader,
        concern,
        num_samples,
        num_labels,
        False,
        4,
        device=device,
        resample=False,
        seed=seed,
    )

    similar(
        model,
        module,
        valid_dataloader,
        concern,
        num_samples,
        num_labels,
        device=device,
    )

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
CCA coefficients mean concern: (0.7020227225376106, 0.7020227225376106)
CCA coefficients mean non-concern: (0.7004565747052875, 0.7004565747052875)
Linear CKA concern: 0.7277756339268255
Linear CKA non-concern: 0.70545823333982
Kernel CKA concern: 0.5095953891649833
Kernel CKA non-concern: 0.4767323097394551


In [11]:
# Display the sparsity report for the pruned module. As rendered below, the
# value is a pair: an overall fraction and a dict keyed by parameter name.
get_sparsity(module)

(0.5945154895443348,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5999348958333334,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5989583333333334,
  'bert.en