In [1]:
import sys
import os

# Make the project's `utils` package importable. The path is resolved relative to
# the kernel's working directory; presumably that is four levels below the repo root.
sys.path.extend(["../../../../"])

# Disable HuggingFace tokenizers' internal parallelism (presumably to avoid its
# fork warning once DataLoader workers are spawned).
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration: every value a reader might want to tune lives here.
name = "YahooAnswersTopics"  # key passed to ModelConfig and load_data
# Fall back to CPU when no CUDA device is present, so the notebook still runs
# (slowly) instead of failing at the first tensor transfer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; reassigned by load_model(...) in the model-loading cell
batch_size = 16
num_workers = 4
num_samples = 16  # per concern; num_samples // 2 go to each of the positive / negative sample sets
ci_ratio = 0.6  # sparsity_ratio passed to prune_concern_identification
seed = 44
include_layers = ["intermediate", "output"]  # layer-name filters forwarded to the pruner
exclude_layers = ["attention"]

In [4]:
# Record the wall-clock start of the run and echo it in the cell output.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-09-01 13:04:43


In [5]:
# Resolve the settings registered for `name`, then load the pretrained model,
# its tokenizer and checkpoint through the project loader.
model_config = ModelConfig(name, device)
# Number of classes; the pruning loop below builds one concern module per label.
num_labels = model_config.config["num_labels"]
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the train / validation / test DataLoaders. With do_cache=True the loader
# reuses a cached copy of the dataset when one exists (see its printed log).
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    seed=seed,
    do_cache=True,
    num_workers=num_workers,
    batch_size=batch_size,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Optional baseline: evaluate the unpruned model on the full test set so the pruned
# modules below can be compared against it. Off by default, matching the
# previously commented-out cell; set to True to run it.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every label ("concern"): sample positive/negative examples, prune a copy of
# the model for that concern, then report test metrics, sparsity, and similarity
# (CCA / CKA) between the original and the pruned module.
for concern in range(num_labels):
    # Fresh loader copies per concern, so sampling in one iteration cannot leak
    # iterator state into the next.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # The boolean presumably selects samples of `concern` (True) vs. other labels (False).
    # NOTE(review): the positional `4` presumably mirrors num_workers; confirm against
    # SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # prune_concern_identification's return value is ignored, so it mutates its
    # argument in place; prune a copy to keep `model` intact as the reference.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # The filename includes the concern id; a fixed name would make every
    # iteration overwrite the previously saved module.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p_{concern}.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.1001




Precision: 0.6664, Recall: 0.6517, F1-Score: 0.6517




              precision    recall  f1-score   support

           0       0.53      0.58      0.55      2941
           1       0.77      0.52      0.62      2997
           2       0.75      0.71      0.73      3016
           3       0.51      0.48      0.50      2978
           4       0.80      0.79      0.80      3017
           5       0.93      0.77      0.84      3004
           6       0.52      0.41      0.46      3037
           7       0.52      0.76      0.62      3026
           8       0.61      0.77      0.68      2997
           9       0.73      0.74      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7599866907093483, 0.7599866907093483)




CCA coefficients mean non-concern: (0.7618393562424952, 0.7618393562424952)




Linear CKA concern: 0.9059465950584221




Linear CKA non-concern: 0.8700297248190032




Kernel CKA concern: 0.8531205971129672




Kernel CKA non-concern: 0.8270428222780593




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.1118




Precision: 0.6646, Recall: 0.6456, F1-Score: 0.6453




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.77      0.53      0.63      2997
           2       0.76      0.68      0.72      3016
           3       0.53      0.43      0.47      2978
           4       0.81      0.79      0.80      3017
           5       0.93      0.77      0.84      3004
           6       0.54      0.39      0.46      3037
           7       0.48      0.77      0.59      3026
           8       0.59      0.77      0.67      2997
           9       0.72      0.75      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.66      0.65      0.65     30000
weighted avg       0.66      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7605384502937064, 0.7605384502937064)




CCA coefficients mean non-concern: (0.7625739086621819, 0.7625739086621819)




Linear CKA concern: 0.8831756392939565




Linear CKA non-concern: 0.8744808836344546




Kernel CKA concern: 0.8412698325187772




Kernel CKA non-concern: 0.8229393954296389




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.1102




Precision: 0.6685, Recall: 0.6486, F1-Score: 0.6488




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.77      0.53      0.63      2997
           2       0.76      0.69      0.72      3016
           3       0.52      0.46      0.49      2978
           4       0.83      0.77      0.80      3017
           5       0.93      0.78      0.85      3004
           6       0.56      0.39      0.46      3037
           7       0.48      0.78      0.59      3026
           8       0.60      0.76      0.67      2997
           9       0.71      0.75      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7430740071626574, 0.7430740071626574)




CCA coefficients mean non-concern: (0.7582934627945708, 0.7582934627945708)




Linear CKA concern: 0.9005090829263463




Linear CKA non-concern: 0.8699801788309893




Kernel CKA concern: 0.8883995550188523




Kernel CKA non-concern: 0.8045273694681436




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.1018




Precision: 0.6670, Recall: 0.6502, F1-Score: 0.6499




              precision    recall  f1-score   support

           0       0.50      0.59      0.54      2941
           1       0.76      0.56      0.64      2997
           2       0.76      0.69      0.72      3016
           3       0.55      0.44      0.49      2978
           4       0.82      0.79      0.80      3017
           5       0.94      0.76      0.84      3004
           6       0.54      0.39      0.45      3037
           7       0.50      0.77      0.60      3026
           8       0.60      0.77      0.67      2997
           9       0.72      0.74      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7614723505529916, 0.7614723505529916)




CCA coefficients mean non-concern: (0.7640389307686751, 0.7640389307686751)




Linear CKA concern: 0.8720266958064877




Linear CKA non-concern: 0.8785830367519056




Kernel CKA concern: 0.8250395799269075




Kernel CKA non-concern: 0.8294745776627587




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.1322




Precision: 0.6630, Recall: 0.6411, F1-Score: 0.6411




              precision    recall  f1-score   support

           0       0.46      0.61      0.52      2941
           1       0.78      0.51      0.61      2997
           2       0.77      0.67      0.72      3016
           3       0.51      0.45      0.48      2978
           4       0.81      0.78      0.80      3017
           5       0.94      0.76      0.84      3004
           6       0.55      0.37      0.44      3037
           7       0.51      0.75      0.61      3026
           8       0.58      0.78      0.67      2997
           9       0.72      0.74      0.73      2987

    accuracy                           0.64     30000
   macro avg       0.66      0.64      0.64     30000
weighted avg       0.66      0.64      0.64     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7518573248307242, 0.7518573248307242)




CCA coefficients mean non-concern: (0.7571107631607663, 0.7571107631607663)




Linear CKA concern: 0.9121085679428116




Linear CKA non-concern: 0.8440468642792737




Kernel CKA concern: 0.8892356362923618




Kernel CKA non-concern: 0.7813777003891834




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.1215




Precision: 0.6635, Recall: 0.6474, F1-Score: 0.6479




              precision    recall  f1-score   support

           0       0.50      0.59      0.54      2941
           1       0.78      0.50      0.61      2997
           2       0.76      0.68      0.72      3016
           3       0.50      0.49      0.49      2978
           4       0.83      0.76      0.79      3017
           5       0.93      0.79      0.85      3004
           6       0.49      0.41      0.45      3037
           7       0.56      0.73      0.63      3026
           8       0.58      0.78      0.67      2997
           9       0.72      0.74      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.66      0.65      0.65     30000
weighted avg       0.66      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7522730508769664, 0.7522730508769664)




CCA coefficients mean non-concern: (0.7590909337358624, 0.7590909337358624)




Linear CKA concern: 0.9223999575472192




Linear CKA non-concern: 0.849532713030342




Kernel CKA concern: 0.8942941097431044




Kernel CKA non-concern: 0.7967502061581436




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0969




Precision: 0.6665, Recall: 0.6523, F1-Score: 0.6524




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.77      0.52      0.62      2997
           2       0.76      0.69      0.72      3016
           3       0.50      0.49      0.49      2978
           4       0.81      0.79      0.80      3017
           5       0.92      0.77      0.84      3004
           6       0.53      0.41      0.46      3037
           7       0.52      0.75      0.62      3026
           8       0.61      0.77      0.68      2997
           9       0.72      0.75      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7654326234877962, 0.7654326234877962)




CCA coefficients mean non-concern: (0.7656076513604961, 0.7656076513604961)




Linear CKA concern: 0.9044968493094984




Linear CKA non-concern: 0.8766800141568583




Kernel CKA concern: 0.8560908757607031




Kernel CKA non-concern: 0.8280025325763819




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.1044




Precision: 0.6657, Recall: 0.6518, F1-Score: 0.6522




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.77      0.50      0.61      2997
           2       0.78      0.69      0.73      3016
           3       0.47      0.51      0.49      2978
           4       0.80      0.79      0.80      3017
           5       0.93      0.78      0.85      3004
           6       0.48      0.41      0.44      3037
           7       0.59      0.72      0.65      3026
           8       0.59      0.78      0.68      2997
           9       0.71      0.75      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7611676632956571, 0.7611676632956571)




CCA coefficients mean non-concern: (0.7641967220961128, 0.7641967220961128)




Linear CKA concern: 0.9181414490876448




Linear CKA non-concern: 0.8528494758974149




Kernel CKA concern: 0.8834756600316909




Kernel CKA non-concern: 0.794581765336382




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.1046




Precision: 0.6664, Recall: 0.6505, F1-Score: 0.6511




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.78      0.50      0.61      2997
           2       0.76      0.70      0.73      3016
           3       0.46      0.52      0.49      2978
           4       0.81      0.78      0.80      3017
           5       0.93      0.77      0.84      3004
           6       0.52      0.41      0.46      3037
           7       0.55      0.74      0.63      3026
           8       0.61      0.76      0.68      2997
           9       0.72      0.74      0.73      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7610756156074493, 0.7610756156074493)




CCA coefficients mean non-concern: (0.7645055977497337, 0.7645055977497337)




Linear CKA concern: 0.8921230152503306




Linear CKA non-concern: 0.8623850878896075




Kernel CKA concern: 0.8627716369361785




Kernel CKA non-concern: 0.8159046436399461




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.1086




Precision: 0.6667, Recall: 0.6502, F1-Score: 0.6510




              precision    recall  f1-score   support

           0       0.52      0.57      0.55      2941
           1       0.77      0.51      0.61      2997
           2       0.76      0.70      0.73      3016
           3       0.50      0.48      0.49      2978
           4       0.82      0.77      0.80      3017
           5       0.93      0.77      0.84      3004
           6       0.50      0.43      0.46      3037
           7       0.52      0.76      0.62      3026
           8       0.61      0.76      0.68      2997
           9       0.73      0.75      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7572487211168909, 0.7572487211168909)




CCA coefficients mean non-concern: (0.7649919610141125, 0.7649919610141125)




Linear CKA concern: 0.935065508575107




Linear CKA non-concern: 0.8604533637577678




Kernel CKA concern: 0.8973787619065285




Kernel CKA non-concern: 0.7955454620506534


