In [1]:
import os
import sys

# Put the directory four levels above the working directory on the import
# path; the project-local `utils.*` imports in the next cell presumably
# resolve through it. The membership check keeps this cell idempotent, so
# re-running it does not pile up duplicate sys.path entries.
_extra_import_path = "../../../../"
if _extra_import_path not in sys.path:
    sys.path.append(_extra_import_path)

# Must be set before a tokenizer is created. Presumably this silences the
# Hugging Face tokenizers fork/parallelism warning when DataLoader workers are
# used — confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
)

In [3]:
# Experiment configuration: every tunable value used by the cells below.
name = "YahooAnswersTopics"  # key passed to ModelConfig and load_data

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing at model load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = None  # placeholder; rebound by the load_model call further down
batch_size = 16  # forwarded to load_data
num_workers = 4  # forwarded to load_data
num_samples = 16  # half goes to each SamplingDataset; the full value goes to `similar`
ci_ratio = 0.6  # passed to prune_concern_identification as sparsity_ratio
seed = 44  # shared by load_data, SamplingDataset and `similar`
include_layers = ["attention"]  # forwarded to prune_concern_identification
exclude_layers = None  # forwarded to prune_concern_identification

In [4]:
# Record the wall-clock start of the run and echo it as "YYYY-MM-DD HH:MM:SS".
script_start_time = datetime.now()
print("Script started at:", script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-09-01 01:08:13


In [5]:
# Build the model configuration for `name` on `device`, read the label count
# from it, and load the model.
model_config = ModelConfig(name, device)
# Number of class labels; the pruning loop below iterates one "concern" per label.
num_labels = model_config.config["num_labels"]
# load_model returns three values; `checkpoint` rebinds the None placeholder
# assigned in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Load the data for `name`; the three results are unpacked as train /
# validation / test dataloaders. Batch size, worker count and seed come from
# the configuration cell.
# NOTE(review): with do_cache=True the recorded output reports "Loading cached
# dataset", so this presumably reuses a previously cached copy — confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
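# Optional baseline (disabled in this run): uncomment to evaluate the original,
# unpruned `model` on the test dataloader before the pruning loop.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)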

In [8]:
# Concern-identification sweep: one pruned module per class label ("concern").
# Every iteration prunes a fresh deep copy of `model`, evaluates it on the test
# dataloader, and compares it with the original `model` via `similar`.

# Loop-invariant: the sample count handed to each of the two sampling sets.
samples_per_group = num_samples // 2

for concern in range(num_labels):
    # Work on private copies of the shared dataloaders in every iteration.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # NOTE(review): the fifth positional flag (True / False) presumably selects
    # concern vs. non-concern examples, and the literal 4 equals `num_workers`
    # from the configuration cell, but neither meaning is visible here —
    # confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train, concern, samples_per_group, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_group, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so that `model` stays available as the reference for `similar`.
    module = copy.deepcopy(model)

    # Called for its effect on `module`; the return value (if any) is not used.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    # `result` is rebound on every iteration; only the last concern's value
    # survives the loop.
    result = evaluate_model(module, model_config, test_dataloader)
    # NOTE(review): the return value is discarded and the recorded output shows
    # no sparsity line, so this call currently reports nothing — confirm whether
    # get_sparsity is meant to print or its result should be displayed.
    get_sparsity(module)

    # Compare original and pruned model on the validation copy; the recorded
    # output shows CCA / CKA scores for concern and non-concern at this point,
    # presumably printed by this call.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # NOTE(review): the file name below does not include `concern`, so enabling
    # it as-is would target the same file name on every iteration.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0703




Precision: 0.6735, Recall: 0.6626, F1-Score: 0.6614




              precision    recall  f1-score   support

           0       0.54      0.56      0.55      2941
           1       0.75      0.61      0.67      2997
           2       0.74      0.72      0.73      3016
           3       0.54      0.48      0.51      2978
           4       0.79      0.81      0.80      3017
           5       0.92      0.78      0.84      3004
           6       0.58      0.39      0.47      3037
           7       0.52      0.74      0.61      3026
           8       0.59      0.79      0.68      2997
           9       0.77      0.73      0.75      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7438777708471513, 0.7438777708471513)




CCA coefficients mean non-concern: (0.7422323705180839, 0.7422323705180839)




Linear CKA concern: 0.9297137998522202




Linear CKA non-concern: 0.9040675844307716




Kernel CKA concern: 0.8902572341864721




Kernel CKA non-concern: 0.887712816515034




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.0597




Precision: 0.6741, Recall: 0.6670, F1-Score: 0.6656




              precision    recall  f1-score   support

           0       0.54      0.56      0.55      2941
           1       0.71      0.68      0.69      2997
           2       0.73      0.73      0.73      3016
           3       0.54      0.49      0.52      2978
           4       0.79      0.81      0.80      3017
           5       0.92      0.78      0.84      3004
           6       0.57      0.40      0.47      3037
           7       0.53      0.74      0.62      3026
           8       0.63      0.77      0.69      2997
           9       0.77      0.72      0.75      2987

    accuracy                           0.67     30000
   macro avg       0.67      0.67      0.67     30000
weighted avg       0.67      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7401727116920016, 0.7401727116920016)




CCA coefficients mean non-concern: (0.7479971874105625, 0.7479971874105625)




Linear CKA concern: 0.9322136697093446




Linear CKA non-concern: 0.9056890692678002




Kernel CKA concern: 0.9004606180416449




Kernel CKA non-concern: 0.8837378377421928




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.0606




Precision: 0.6720, Recall: 0.6659, F1-Score: 0.6651




              precision    recall  f1-score   support

           0       0.55      0.55      0.55      2941
           1       0.74      0.62      0.67      2997
           2       0.70      0.76      0.73      3016
           3       0.51      0.51      0.51      2978
           4       0.79      0.82      0.80      3017
           5       0.92      0.77      0.84      3004
           6       0.54      0.41      0.47      3037
           7       0.56      0.73      0.64      3026
           8       0.64      0.76      0.69      2997
           9       0.77      0.72      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.67      0.67      0.67     30000
weighted avg       0.67      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7265419327174072, 0.7265419327174072)




CCA coefficients mean non-concern: (0.7422652617238152, 0.7422652617238152)




Linear CKA concern: 0.935773029301702




Linear CKA non-concern: 0.9053924548953941




Kernel CKA concern: 0.9219550072866116




Kernel CKA non-concern: 0.8706355952493438




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.0517




Precision: 0.6730, Recall: 0.6674, F1-Score: 0.6662




              precision    recall  f1-score   support

           0       0.55      0.55      0.55      2941
           1       0.73      0.64      0.68      2997
           2       0.74      0.72      0.73      3016
           3       0.53      0.51      0.52      2978
           4       0.79      0.81      0.80      3017
           5       0.92      0.79      0.85      3004
           6       0.54      0.40      0.46      3037
           7       0.55      0.73      0.63      3026
           8       0.63      0.77      0.69      2997
           9       0.75      0.74      0.75      2987

    accuracy                           0.67     30000
   macro avg       0.67      0.67      0.67     30000
weighted avg       0.67      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7431022783441559, 0.7431022783441559)




CCA coefficients mean non-concern: (0.7456482476932711, 0.7456482476932711)




Linear CKA concern: 0.9175470031038728




Linear CKA non-concern: 0.9158468698341252




Kernel CKA concern: 0.8856005631554175




Kernel CKA non-concern: 0.8944708585349268




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.0727




Precision: 0.6718, Recall: 0.6638, F1-Score: 0.6615




              precision    recall  f1-score   support

           0       0.56      0.54      0.55      2941
           1       0.74      0.63      0.68      2997
           2       0.74      0.71      0.73      3016
           3       0.55      0.48      0.52      2978
           4       0.75      0.85      0.80      3017
           5       0.92      0.77      0.84      3004
           6       0.56      0.39      0.46      3037
           7       0.51      0.75      0.61      3026
           8       0.62      0.78      0.69      2997
           9       0.76      0.74      0.75      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7240918476097191, 0.7240918476097191)




CCA coefficients mean non-concern: (0.7382793034166434, 0.7382793034166434)




Linear CKA concern: 0.9547840733296653




Linear CKA non-concern: 0.8928336340811365




Kernel CKA concern: 0.9285267837208635




Kernel CKA non-concern: 0.8686956048709598




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0616




Precision: 0.6732, Recall: 0.6652, F1-Score: 0.6639




              precision    recall  f1-score   support

           0       0.56      0.55      0.55      2941
           1       0.75      0.61      0.67      2997
           2       0.74      0.72      0.73      3016
           3       0.52      0.50      0.51      2978
           4       0.79      0.82      0.80      3017
           5       0.90      0.80      0.85      3004
           6       0.56      0.40      0.47      3037
           7       0.52      0.75      0.62      3026
           8       0.62      0.77      0.69      2997
           9       0.77      0.73      0.75      2987

    accuracy                           0.67     30000
   macro avg       0.67      0.67      0.66     30000
weighted avg       0.67      0.67      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7302507570638197, 0.7302507570638197)




CCA coefficients mean non-concern: (0.7409636774914765, 0.7409636774914765)




Linear CKA concern: 0.9567243765406274




Linear CKA non-concern: 0.9030883607809052




Kernel CKA concern: 0.9312407501913251




Kernel CKA non-concern: 0.885193171034956




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0763




Precision: 0.6730, Recall: 0.6618, F1-Score: 0.6605




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.75      0.61      0.67      2997
           2       0.75      0.71      0.73      3016
           3       0.54      0.49      0.52      2978
           4       0.77      0.82      0.80      3017
           5       0.93      0.77      0.84      3004
           6       0.58      0.39      0.46      3037
           7       0.53      0.74      0.62      3026
           8       0.59      0.79      0.68      2997
           9       0.77      0.72      0.75      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7368810207791505, 0.7368810207791505)




CCA coefficients mean non-concern: (0.7428630459543385, 0.7428630459543385)




Linear CKA concern: 0.9082720248617199




Linear CKA non-concern: 0.9097638885578273




Kernel CKA concern: 0.8662468066652588




Kernel CKA non-concern: 0.8860137353400979




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0726




Precision: 0.6746, Recall: 0.6610, F1-Score: 0.6597




              precision    recall  f1-score   support

           0       0.55      0.56      0.55      2941
           1       0.76      0.60      0.67      2997
           2       0.74      0.71      0.73      3016
           3       0.53      0.49      0.51      2978
           4       0.78      0.82      0.80      3017
           5       0.92      0.78      0.84      3004
           6       0.59      0.38      0.46      3037
           7       0.50      0.77      0.61      3026
           8       0.61      0.78      0.68      2997
           9       0.77      0.72      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7302133637963079, 0.7302133637963079)




CCA coefficients mean non-concern: (0.7438562465249059, 0.7438562465249059)




Linear CKA concern: 0.9485766165820652




Linear CKA non-concern: 0.8980088504167146




Kernel CKA concern: 0.9263990277738949




Kernel CKA non-concern: 0.8718670887531503




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0707




Precision: 0.6738, Recall: 0.6636, F1-Score: 0.6625




              precision    recall  f1-score   support

           0       0.53      0.56      0.55      2941
           1       0.75      0.61      0.67      2997
           2       0.73      0.73      0.73      3016
           3       0.52      0.51      0.52      2978
           4       0.80      0.80      0.80      3017
           5       0.93      0.78      0.84      3004
           6       0.59      0.39      0.47      3037
           7       0.55      0.73      0.62      3026
           8       0.59      0.80      0.68      2997
           9       0.76      0.73      0.75      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.733835742342467, 0.733835742342467)




CCA coefficients mean non-concern: (0.7467738154731557, 0.7467738154731557)




Linear CKA concern: 0.9592063901107731




Linear CKA non-concern: 0.8968998997823386




Kernel CKA concern: 0.9334198724618678




Kernel CKA non-concern: 0.8770547936498068




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.0621




Precision: 0.6739, Recall: 0.6647, F1-Score: 0.6621




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.74      0.62      0.68      2997
           2       0.74      0.72      0.73      3016
           3       0.54      0.49      0.51      2978
           4       0.78      0.81      0.80      3017
           5       0.92      0.78      0.85      3004
           6       0.61      0.37      0.46      3037
           7       0.52      0.75      0.62      3026
           8       0.62      0.77      0.69      2997
           9       0.73      0.76      0.75      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7282837660241349, 0.7282837660241349)




CCA coefficients mean non-concern: (0.7455482511111741, 0.7455482511111741)




Linear CKA concern: 0.941143574592491




Linear CKA non-concern: 0.9126240303819777




Kernel CKA concern: 0.901714729368755




Kernel CKA non-concern: 0.8893286201592769


