In [1]:
import os
import sys

# Directory (relative to the notebook's working directory) that is put on the
# import path; presumably where the project-local `utils` package lives.
PROJECT_ROOT = "../../../../"

# Guard the append so re-running this cell does not pile up duplicate entries
# in sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Must be set before any tokenizer is created. "false" turns off parallelism in
# the Hugging Face `tokenizers` library (commonly done to avoid fork-related
# warnings when DataLoader worker processes are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration -------------------------------------------------
# Every tunable used by the cells below lives here.

# Dataset/model key passed to ModelConfig and load_data.
name = "YahooAnswersTopics"

# Prefer the first CUDA device; fall back to CPU so the notebook still runs
# (slowly) on a machine without a GPU instead of failing later on.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder; rebound by the load_model(...) call in a later cell.
checkpoint = None

# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4

# Total sample budget per concern; the pruning cell uses num_samples // 2 for
# each of the positive and negative sets.
num_samples = 16

# Passed to prune_concern_identification as `sparsity_ratio`.
ci_ratio = 0.3

# Seed forwarded to data loading, sampling and the similarity analysis.
seed = 44

# Layer-name filters forwarded to prune_concern_identification.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Remember when the run began and echo it into the notebook log.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-09-01 11:28:19


In [5]:
# Build the configuration object for the chosen dataset/model key and device.
model_config = ModelConfig(name, device)

# Number of target classes, read from the configuration mapping.
num_labels = model_config.config["num_labels"]

# load_model yields a (model, tokenizer, checkpoint) triple; `checkpoint`
# replaces the None placeholder set in the configuration cell.
loaded = load_model(model_config)
model, tokenizer, checkpoint = loaded

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Load the dataset identified by `name`. The call returns three dataloaders,
# unpacked below as train / validation / test.
# do_cache=True presumably reuses an on-disk copy of the dataset — TODO confirm.
dataloaders = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)
train_dataloader, valid_dataloader, test_dataloader = dataloaders

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
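# Optional baseline (disabled): uncomment the two lines below to evaluate the
# unpruned model on the test split before any pruning is applied.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)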

In [8]:
# Per-concern pruning: for every class label, prune a fresh copy of the model
# using samples drawn for that label, then report test metrics, sparsity and
# the representation similarity between the original and the pruned model.
for concern in range(num_labels):
    # Each iteration works on its own copies of the loaders, presumably so that
    # sampling cannot disturb the shared dataloaders — TODO confirm.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Half of the sample budget goes to each sample set.
    samples_per_side = num_samples // 2
    # Keyword options that are the same for every SamplingDataset below.
    sampler_kwargs = dict(device=device, resample=False, seed=seed)

    # The fifth positional argument (True/False) presumably selects samples of
    # the concern versus samples of the other classes; the literal 4 is
    # presumably a batch size — TODO confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train, concern, samples_per_side, num_labels, True, 4, **sampler_kwargs
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_side, num_labels, False, 4, **sampler_kwargs
    )
    # NOTE(review): `all_samples` is not read anywhere in this cell, and 200 lies
    # outside range(num_labels) — presumably a sentinel; confirm it is still needed.
    all_samples = SamplingDataset(
        train, 200, samples_per_side, num_labels, False, 4, **sampler_kwargs
    )

    # Prune a copy so the original `model` stays available for comparison.
    # The pruning call's return value is not used, so it appears to modify
    # `module` in place.
    module = copy.deepcopy(model)
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    # NOTE(review): `result` is overwritten on every iteration and not read here.
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original and pruned models on the validation copy.
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

    # NOTE(review): the file name below has no concern index, so enabling this
    # line as-is would write every iteration to the same path.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0013




Precision: 0.6865, Recall: 0.6856, F1-Score: 0.6829




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.74      0.66      0.69      2997
           2       0.72      0.77      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.83      0.87      3004
           6       0.60      0.42      0.49      3037
           7       0.62      0.74      0.67      3026
           8       0.64      0.76      0.69      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9498753865341298, 0.9498753865341298)




CCA coefficients mean non-concern: (0.9480891295600393, 0.9480891295600393)




Linear CKA concern: 0.9945937695445058




Linear CKA non-concern: 0.990955402248768




Kernel CKA concern: 0.9907992839332629




Kernel CKA non-concern: 0.9872535507782105




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 0.9984




Precision: 0.6869, Recall: 0.6865, F1-Score: 0.6836




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.60      0.42      0.49      3037
           7       0.62      0.74      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9463623807704357, 0.9463623807704357)




CCA coefficients mean non-concern: (0.9473035765305189, 0.9473035765305189)




Linear CKA concern: 0.9955170695743882




Linear CKA non-concern: 0.9898117396817114




Kernel CKA concern: 0.9926179499323039




Kernel CKA non-concern: 0.9846895186672568




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 0.9993




Precision: 0.6874, Recall: 0.6867, F1-Score: 0.6840




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.66      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.81      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.60      0.42      0.50      3037
           7       0.61      0.74      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9438870750567631, 0.9438870750567631)




CCA coefficients mean non-concern: (0.9471535177513608, 0.9471535177513608)




Linear CKA concern: 0.9962057565766402




Linear CKA non-concern: 0.990255716875542




Kernel CKA concern: 0.9941811051003606




Kernel CKA non-concern: 0.9843143843119113




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 0.9983




Precision: 0.6871, Recall: 0.6864, F1-Score: 0.6836




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.73      0.66      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.81      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.60      0.42      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9503836391867191, 0.9503836391867191)




CCA coefficients mean non-concern: (0.9483144324749507, 0.9483144324749507)




Linear CKA concern: 0.995847440876123




Linear CKA non-concern: 0.9917487984435036




Kernel CKA concern: 0.9930437262826466




Kernel CKA non-concern: 0.9878539415930643




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 0.9984




Precision: 0.6871, Recall: 0.6867, F1-Score: 0.6839




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.66      0.70      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.60      0.42      0.49      3037
           7       0.62      0.73      0.67      3026
           8       0.64      0.76      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9486929314273385, 0.9486929314273385)




CCA coefficients mean non-concern: (0.9452854344880256, 0.9452854344880256)




Linear CKA concern: 0.9974400346817415




Linear CKA non-concern: 0.9896051707973021




Kernel CKA concern: 0.995516850973867




Kernel CKA non-concern: 0.984166947101387




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0018




Precision: 0.6849, Recall: 0.6844, F1-Score: 0.6817




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.65      0.69      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.58      0.42      0.49      3037
           7       0.63      0.73      0.67      3026
           8       0.63      0.77      0.69      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.69      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9413454447224016, 0.9413454447224016)




CCA coefficients mean non-concern: (0.9467103258292967, 0.9467103258292967)




Linear CKA concern: 0.9965617668058033




Linear CKA non-concern: 0.9899239190478301




Kernel CKA concern: 0.9939304369497481




Kernel CKA non-concern: 0.9849639731243132




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0000




Precision: 0.6869, Recall: 0.6860, F1-Score: 0.6831




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.69      2997
           2       0.73      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.81      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.61      0.42      0.49      3037
           7       0.62      0.74      0.67      3026
           8       0.64      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9501103012660594, 0.9501103012660594)




CCA coefficients mean non-concern: (0.9492817907889618, 0.9492817907889618)




Linear CKA concern: 0.9946506177096616




Linear CKA non-concern: 0.9914602455479978




Kernel CKA concern: 0.991287466488067




Kernel CKA non-concern: 0.9875356793670595




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0001




Precision: 0.6864, Recall: 0.6857, F1-Score: 0.6831




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.66      0.70      2997
           2       0.73      0.77      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.83      0.87      3004
           6       0.59      0.42      0.49      3037
           7       0.62      0.73      0.67      3026
           8       0.64      0.76      0.69      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9461823593148011, 0.9461823593148011)




CCA coefficients mean non-concern: (0.9489421424304005, 0.9489421424304005)




Linear CKA concern: 0.9961203184081842




Linear CKA non-concern: 0.9902553885782723




Kernel CKA concern: 0.9937623148093739




Kernel CKA non-concern: 0.9852244017872028




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0018




Precision: 0.6863, Recall: 0.6856, F1-Score: 0.6830




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.65      0.69      2997
           2       0.73      0.78      0.75      3016
           3       0.54      0.52      0.53      2978
           4       0.81      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.59      0.42      0.49      3037
           7       0.63      0.73      0.68      3026
           8       0.63      0.77      0.69      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9513636243986805, 0.9513636243986805)




CCA coefficients mean non-concern: (0.9488457368180416, 0.9488457368180416)




Linear CKA concern: 0.9959083513013652




Linear CKA non-concern: 0.9907618706524863




Kernel CKA concern: 0.9932134932166584




Kernel CKA non-concern: 0.9863770088686624




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 0.9980




Precision: 0.6874, Recall: 0.6868, F1-Score: 0.6841




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.67      0.70      2997
           2       0.73      0.78      0.75      3016
           3       0.53      0.53      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.90      0.84      0.87      3004
           6       0.61      0.42      0.50      3037
           7       0.62      0.74      0.67      3026
           8       0.65      0.75      0.70      2997
           9       0.75      0.76      0.75      2987

    accuracy                           0.69     30000
   macro avg       0.69      0.69      0.68     30000
weighted avg       0.69      0.69      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9489323206542344, 0.9489323206542344)




CCA coefficients mean non-concern: (0.9465746331266627, 0.9465746331266627)




Linear CKA concern: 0.9958375099502121




Linear CKA non-concern: 0.9888504163728028




Kernel CKA concern: 0.9931048490153889




Kernel CKA non-concern: 0.9827379502574436


