In [1]:
import os
import sys

# Make the project's `utils` package (imported in the next cell) resolvable by
# adding the directory four levels up to the import path.
sys.path.append("../" * 4)
# Turn off the `tokenizers` library's internal parallelism — presumably to avoid
# clashes with the DataLoader worker processes (num_workers) used later.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
)

In [3]:
# Experiment configuration — every value a reader might tune lives here.
name = "YahooAnswersTopics"  # passed to ModelConfig and load_data
# Use the first GPU when available; otherwise fall back to CPU so the notebook
# still runs (slowly) on machines without CUDA instead of failing outright.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by load_model() in the model-loading cell
batch_size = 16
num_workers = 4
num_samples = 16  # per concern; split evenly into positive and negative samples
ci_ratio = 0.3  # sparsity_ratio handed to prune_concern_identification
seed = 44
include_layers = ["attention"]  # layer groups that pruning may touch
exclude_layers = None

In [4]:
# Capture the wall-clock start of the run (presumably for computing total
# runtime later) and log it in a human-readable form.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-31 23:36:11


In [5]:
# Resolve the model configuration for `name` on `device`; its `config` dict
# provides the label count, which the concern loop iterates over.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
# load_model yields a (model, tokenizer, checkpoint) triple. Note that this
# rebinds the `checkpoint = None` placeholder from the config cell.
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the train / validation / test loaders. do_cache=True lets load_data
# reuse a cached copy of the dataset (cf. "Loading cached dataset" in the output).
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Baseline: metrics of the unpruned model on the test set, for comparison with
# the per-concern pruned results in the next cell. Currently disabled —
# presumably because a full test-set pass is slow; uncomment to regenerate the
# reference numbers.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For each label ("concern"):
#   1. prune a fresh copy of the model using positive and negative samples
#      for that concern;
#   2. report test-set metrics and sparsity for the pruned copy;
#   3. compare its representations against the unpruned `model` via `similar`.
for concern in range(num_labels):
    # Work on private copies of the loaders — presumably so that sampling for
    # one concern cannot consume or alter the shared loaders used by later
    # iterations (TODO confirm whether SamplingDataset mutates its input).
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # num_samples is split evenly between the two sample sets. The True/False
    # flag selects positive vs. negative samples for `concern`; the meaning of
    # the positional `4` is not visible here — TODO give it a named constant.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so that `model` stays intact as the reference for `similar`.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # To persist each pruned module, include the concern in the filename so
    # iterations do not overwrite each other's file:
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0011




Precision: 0.6850, Recall: 0.6834, F1-Score: 0.6808




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.53      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.48      3037
           7       0.61      0.74      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.69      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9158266944493085, 0.9158266944493085)




CCA coefficients mean non-concern: (0.912970932846153, 0.912970932846153)




Linear CKA concern: 0.9887163825010571




Linear CKA non-concern: 0.9823804145400087




Kernel CKA concern: 0.9804613890154908




Kernel CKA non-concern: 0.9765548244898472




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 0.9970




Precision: 0.6854, Recall: 0.6849, F1-Score: 0.6819




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.73      0.67      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.55      0.52      0.53      2978
           4       0.80      0.83      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.65      0.76      0.70      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.68      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9086767319023554, 0.9086767319023554)




CCA coefficients mean non-concern: (0.9116583348726048, 0.9116583348726048)




Linear CKA concern: 0.9895957632113066




Linear CKA non-concern: 0.9815806693679392




Kernel CKA concern: 0.9840467712304075




Kernel CKA non-concern: 0.9742220622054953




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 0.9993




Precision: 0.6856, Recall: 0.6845, F1-Score: 0.6819




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.69      2997
           2       0.72      0.78      0.75      3016
           3       0.55      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.42      0.49      3037
           7       0.60      0.75      0.66      3026
           8       0.65      0.76      0.70      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.69      0.68      0.68     30000
weighted avg       0.69      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.908467919768572, 0.908467919768572)




CCA coefficients mean non-concern: (0.9105914315917224, 0.9105914315917224)




Linear CKA concern: 0.9921782428515649




Linear CKA non-concern: 0.9806497881092372




Kernel CKA concern: 0.9890596081277938




Kernel CKA non-concern: 0.9707870346756555




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 0.9969




Precision: 0.6861, Recall: 0.6854, F1-Score: 0.6826




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.73      0.67      0.70      2997
           2       0.72      0.77      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.83      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.65      0.75      0.70      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9142995708162149, 0.9142995708162149)




CCA coefficients mean non-concern: (0.9142990791525503, 0.9142990791525503)




Linear CKA concern: 0.9894854268771297




Linear CKA non-concern: 0.98317570952536




Kernel CKA concern: 0.9831758107361486




Kernel CKA non-concern: 0.9768167824271052




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 0.9983




Precision: 0.6859, Recall: 0.6849, F1-Score: 0.6820




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.55      0.52      0.53      2978
           4       0.79      0.83      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.49      3037
           7       0.60      0.75      0.66      3026
           8       0.65      0.76      0.70      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.68      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.908726639218357, 0.908726639218357)




CCA coefficients mean non-concern: (0.9095602971198986, 0.9095602971198986)




Linear CKA concern: 0.9938562252369042




Linear CKA non-concern: 0.9807066615460336




Kernel CKA concern: 0.9896152349046802




Kernel CKA non-concern: 0.9721582726854582




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0010




Precision: 0.6839, Recall: 0.6832, F1-Score: 0.6805




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.65      0.69      2997
           2       0.72      0.77      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.58      0.41      0.48      3037
           7       0.61      0.74      0.67      3026
           8       0.64      0.76      0.69      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.901654265764883, 0.901654265764883)




CCA coefficients mean non-concern: (0.9125601713499061, 0.9125601713499061)




Linear CKA concern: 0.9928555358938769




Linear CKA non-concern: 0.9811257688269067




Kernel CKA concern: 0.9875577080321405




Kernel CKA non-concern: 0.9740265243788165




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 0.9999




Precision: 0.6860, Recall: 0.6844, F1-Score: 0.6815




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.73      0.66      0.69      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.61      0.41      0.49      3037
           7       0.60      0.74      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.69      0.68      0.68     30000
weighted avg       0.69      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9114917140964004, 0.9114917140964004)




CCA coefficients mean non-concern: (0.9139664733113274, 0.9139664733113274)




Linear CKA concern: 0.9882639153709867




Linear CKA non-concern: 0.9832055154476743




Kernel CKA concern: 0.9807603466191872




Kernel CKA non-concern: 0.9768067769761112




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0009




Precision: 0.6850, Recall: 0.6835, F1-Score: 0.6809




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.70      2997
           2       0.73      0.77      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.48      3037
           7       0.60      0.75      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.69      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9066325391646998, 0.9066325391646998)




CCA coefficients mean non-concern: (0.9129673550242196, 0.9129673550242196)




Linear CKA concern: 0.9921451266744423




Linear CKA non-concern: 0.9810253650896779




Kernel CKA concern: 0.9876343962762764




Kernel CKA non-concern: 0.9728092466243229




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0023




Precision: 0.6846, Recall: 0.6834, F1-Score: 0.6807




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.74      0.65      0.69      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.48      3037
           7       0.62      0.73      0.67      3026
           8       0.63      0.77      0.69      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9135327331494212, 0.9135327331494212)




CCA coefficients mean non-concern: (0.9140132012334798, 0.9140132012334798)




Linear CKA concern: 0.9911145954161241




Linear CKA non-concern: 0.9807573676146479




Kernel CKA concern: 0.9858775415735225




Kernel CKA non-concern: 0.9735461674829975




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 0.9972




Precision: 0.6860, Recall: 0.6852, F1-Score: 0.6826




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.73      0.67      0.70      2997
           2       0.73      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.83      0.87      3004
           6       0.60      0.41      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.65      0.75      0.70      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9125049513924424, 0.9125049513924424)




CCA coefficients mean non-concern: (0.9146344858300461, 0.9146344858300461)




Linear CKA concern: 0.9920110768690977




Linear CKA non-concern: 0.9818024146974894




Kernel CKA concern: 0.9862081285166987




Kernel CKA non-concern: 0.9729069952183872


