In [1]:
import os
import sys

# Make the project packages importable: the `utils.*` imports in the next cell
# are resolved through this entry. The path is relative to the notebook's
# working directory (presumably it points at the repository root — confirm).
# Guarded so that re-running this cell does not append duplicate entries.
repo_root = "../../../../"
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Must be set before the tokenizer library is imported/used. Turns off the
# Hugging Face tokenizers' own parallelism, which is the usual way to avoid its
# fork warning when DataLoader worker processes are used (num_workers > 0).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
)

In [3]:
# --- Experiment configuration: everything a reader might want to tune. ---
name = "YahooAnswersTopics"  # key handed to ModelConfig(...) and load_data(...)
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of failing at the first device transfer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by the load_model(...) call further down
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # halved for each of the positive / negative sample sets
ci_ratio = 0.4  # passed as sparsity_ratio to prune_concern_identification
seed = 44  # forwarded to load_data, SamplingDataset and similar
include_layers = ["attention"]  # passed through to prune_concern_identification
exclude_layers = None  # passed through to prune_concern_identification

In [4]:
# Capture the wall-clock time at which this run began and echo it, formatted
# as "YYYY-MM-DD HH:MM:SS", so the recorded output is timestamped.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-09-01 00:06:55


In [5]:
# Build the model configuration for the configured `name` and `device`.
model_config = ModelConfig(name, device)
# The "num_labels" entry of the configuration (10 in the recorded output);
# it sets how many concerns the pruning loop below iterates over.
num_labels = model_config.config["num_labels"]
# load_model returns three values. Note that this rebinds `checkpoint`,
# replacing the None assigned in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Obtain the three dataloaders for the configured dataset, reusing the batch
# size, worker count and seed from the configuration cell.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every label ("concern"): prune a fresh copy of the model with concern
# identification, evaluate it on the test set, report its sparsity, and compare
# it against the original model with `similar`.
for concern in range(num_labels):
    # Each iteration works on its own copies of the dataloaders.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Half of `num_samples` for the concern (flag True) and half for the rest
    # (flag False).
    # NOTE(review): the positional `4` is presumably a batch size for the
    # sampled subsets — confirm against SamplingDataset's signature and lift it
    # into the configuration cell.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so that `model` stays available as the unpruned reference
    # for the similarity comparison below. The return value of
    # prune_concern_identification is not used, so it presumably modifies
    # `module` in place — confirm.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)

    # Jupyter does not display a bare expression inside a loop, so the
    # sparsity has to be printed explicitly to show up in the output.
    # NOTE(review): assumes get_sparsity returns its result rather than
    # printing it (the recorded output contains no sparsity line) — confirm.
    print(f"Sparsity of the pruned model {concern}: {get_sparsity(module)}")

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # The concern id is part of the file name so that per-concern modules do
    # not overwrite each other when saving is re-enabled.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p_concern{concern}.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0173




Precision: 0.6785, Recall: 0.6759, F1-Score: 0.6732




              precision    recall  f1-score   support

           0       0.55      0.56      0.56      2941
           1       0.73      0.64      0.68      2997
           2       0.73      0.76      0.75      3016
           3       0.53      0.52      0.52      2978
           4       0.79      0.81      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.58      0.39      0.46      3037
           7       0.59      0.74      0.66      3026
           8       0.64      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8388344151713067, 0.8388344151713067)




CCA coefficients mean non-concern: (0.8368248947411939, 0.8368248947411939)




Linear CKA concern: 0.9674942450905125




Linear CKA non-concern: 0.9493551132744115




Kernel CKA concern: 0.9441874246562784




Kernel CKA non-concern: 0.93400581044909




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.0126




Precision: 0.6809, Recall: 0.6791, F1-Score: 0.6760




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.73      0.66      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.55      0.50      0.52      2978
           4       0.79      0.82      0.81      3017
           5       0.91      0.81      0.86      3004
           6       0.59      0.40      0.47      3037
           7       0.59      0.74      0.66      3026
           8       0.64      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.833400850989348, 0.833400850989348)




CCA coefficients mean non-concern: (0.8356511268737326, 0.8356511268737326)




Linear CKA concern: 0.9730576783956988




Linear CKA non-concern: 0.950575394909527




Kernel CKA concern: 0.9574833647304765




Kernel CKA non-concern: 0.9302760533778657




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.0168




Precision: 0.6799, Recall: 0.6758, F1-Score: 0.6732




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.63      0.68      2997
           2       0.72      0.77      0.74      3016
           3       0.54      0.50      0.52      2978
           4       0.80      0.81      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.59      0.40      0.48      3037
           7       0.57      0.75      0.65      3026
           8       0.63      0.76      0.69      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8298088132332487, 0.8298088132332487)




CCA coefficients mean non-concern: (0.836269399412066, 0.836269399412066)




Linear CKA concern: 0.9783815857268331




Linear CKA non-concern: 0.9479122016153093




Kernel CKA concern: 0.9685238119152147




Kernel CKA non-concern: 0.9215789666544761




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.0125




Precision: 0.6807, Recall: 0.6779, F1-Score: 0.6753




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.65      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.54      0.51      0.52      2978
           4       0.79      0.82      0.81      3017
           5       0.91      0.81      0.86      3004
           6       0.58      0.40      0.48      3037
           7       0.56      0.75      0.64      3026
           8       0.65      0.76      0.70      2997
           9       0.74      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.836807383622231, 0.836807383622231)




CCA coefficients mean non-concern: (0.8394026938498685, 0.8394026938498685)




Linear CKA concern: 0.9620310729129677




Linear CKA non-concern: 0.9548558830024905




Kernel CKA concern: 0.9435880217194449




Kernel CKA non-concern: 0.9371130187754807




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.0199




Precision: 0.6786, Recall: 0.6756, F1-Score: 0.6729




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.73      0.63      0.68      2997
           2       0.74      0.76      0.75      3016
           3       0.54      0.50      0.52      2978
           4       0.78      0.83      0.81      3017
           5       0.92      0.80      0.86      3004
           6       0.58      0.40      0.47      3037
           7       0.58      0.75      0.65      3026
           8       0.64      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8308898459391137, 0.8308898459391137)




CCA coefficients mean non-concern: (0.8317748319239606, 0.8317748319239606)




Linear CKA concern: 0.9811541719866529




Linear CKA non-concern: 0.941237284493822




Kernel CKA concern: 0.9684893484939491




Kernel CKA non-concern: 0.9197129440875693




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0149




Precision: 0.6789, Recall: 0.6769, F1-Score: 0.6748




              precision    recall  f1-score   support

           0       0.56      0.55      0.56      2941
           1       0.73      0.64      0.68      2997
           2       0.74      0.76      0.75      3016
           3       0.52      0.52      0.52      2978
           4       0.80      0.81      0.81      3017
           5       0.91      0.82      0.86      3004
           6       0.56      0.41      0.47      3037
           7       0.59      0.74      0.66      3026
           8       0.64      0.76      0.69      2997
           9       0.74      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8208918422067251, 0.8208918422067251)




CCA coefficients mean non-concern: (0.8341393674029053, 0.8341393674029053)




Linear CKA concern: 0.9790565640175445




Linear CKA non-concern: 0.9499553194583297




Kernel CKA concern: 0.9655793845913915




Kernel CKA non-concern: 0.9323624643256261




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0149




Precision: 0.6802, Recall: 0.6779, F1-Score: 0.6753




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.73      0.63      0.68      2997
           2       0.73      0.76      0.74      3016
           3       0.54      0.50      0.52      2978
           4       0.79      0.83      0.81      3017
           5       0.92      0.81      0.86      3004
           6       0.58      0.41      0.48      3037
           7       0.59      0.74      0.66      3026
           8       0.63      0.77      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.833419012287808, 0.833419012287808)




CCA coefficients mean non-concern: (0.8353297951353243, 0.8353297951353243)




Linear CKA concern: 0.9653678310927379




Linear CKA non-concern: 0.9529033708568843




Kernel CKA concern: 0.9447862882265623




Kernel CKA non-concern: 0.9363991799953677




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0194




Precision: 0.6797, Recall: 0.6765, F1-Score: 0.6740




              precision    recall  f1-score   support

           0       0.55      0.56      0.56      2941
           1       0.73      0.63      0.68      2997
           2       0.74      0.76      0.75      3016
           3       0.54      0.51      0.52      2978
           4       0.78      0.82      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.57      0.40      0.47      3037
           7       0.57      0.75      0.65      3026
           8       0.63      0.77      0.69      2997
           9       0.75      0.75      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8270675108458779, 0.8270675108458779)




CCA coefficients mean non-concern: (0.8357783283253646, 0.8357783283253646)




Linear CKA concern: 0.9748579376617703




Linear CKA non-concern: 0.9491390339836183




Kernel CKA concern: 0.9623374525743424




Kernel CKA non-concern: 0.928316393432754




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0207




Precision: 0.6787, Recall: 0.6748, F1-Score: 0.6727




              precision    recall  f1-score   support

           0       0.55      0.56      0.56      2941
           1       0.74      0.61      0.67      2997
           2       0.72      0.77      0.74      3016
           3       0.52      0.52      0.52      2978
           4       0.80      0.81      0.80      3017
           5       0.91      0.81      0.86      3004
           6       0.57      0.41      0.48      3037
           7       0.60      0.73      0.66      3026
           8       0.62      0.78      0.69      2997
           9       0.75      0.74      0.75      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8357215299482674, 0.8357215299482674)




CCA coefficients mean non-concern: (0.8345711284524975, 0.8345711284524975)




Linear CKA concern: 0.9743397125169536




Linear CKA non-concern: 0.9459131978924409




Kernel CKA concern: 0.9607225150400887




Kernel CKA non-concern: 0.9282539703101161




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.0131




Precision: 0.6797, Recall: 0.6782, F1-Score: 0.6761




              precision    recall  f1-score   support

           0       0.55      0.58      0.56      2941
           1       0.73      0.64      0.68      2997
           2       0.73      0.76      0.75      3016
           3       0.54      0.50      0.52      2978
           4       0.80      0.81      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.55      0.41      0.47      3037
           7       0.60      0.74      0.66      3026
           8       0.65      0.76      0.70      2997
           9       0.73      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8308326844931228, 0.8308326844931228)




CCA coefficients mean non-concern: (0.8369474497272718, 0.8369474497272718)




Linear CKA concern: 0.9757963856323834




Linear CKA non-concern: 0.9491639787279608




Kernel CKA concern: 0.9580065068801981




Kernel CKA non-concern: 0.9264390200649371


