In [1]:
import os
import sys

# Make the repository root (four directories above this notebook) importable so
# the project-local `utils` package resolves. Resolve it to an absolute path once
# (so a later cwd change cannot break imports) and skip the append on re-runs so
# sys.path does not accumulate duplicate entries.
_repo_root = os.path.abspath("../../../../")
if _repo_root not in sys.path:
    sys.path.append(_repo_root)
# Disable tokenizers' internal parallelism; presumably to avoid fork-related
# warnings/deadlocks with multi-worker DataLoaders (num_workers > 0) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration: every value a reader might want to tune.
name = "YahooAnswersTopics"  # key passed to ModelConfig and load_data
# Fall back to CPU so the notebook still runs on machines without a CUDA GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # rebound by load_model(...) in the model-loading cell
batch_size = 16
num_workers = 4
num_samples = 16  # passed to SamplingDataset and similar(); presumably samples per concern — confirm
ci_ratio = 0.3  # sparsity_ratio for prune_concern_identification
seed = 44
# The helpers receive `seed` explicitly; also seed torch's global RNG so any
# other torch randomness is reproducible on Restart & Run All.
torch.manual_seed(seed)
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Record when the run began so the total runtime can be reported later.
script_start_time = datetime.now()
_start_stamp = script_start_time.strftime('%Y-%m-%d %H:%M:%S')
print(f"Script started at: {_start_stamp}")

Script started at: 2024-08-23 11:04:58


In [5]:
# Build the model configuration for the chosen dataset/device, read the label
# count from it, then load the model. load_model returns a 3-tuple that also
# rebinds `checkpoint`.
model_config = ModelConfig(name, device)
_model_settings = model_config.config
num_labels = _model_settings["num_labels"]
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Train / validation / test loaders. With do_cache=True a previously cached
# copy of the dataset is loaded when available.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Optional baseline: evaluate the unpruned model on the test set so the
# per-concern pruned results below can be compared against it. Disabled by
# default (it was previously commented out); set to True to run it.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    baseline_result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For each class label ("concern"): prune a fresh copy of `model` using that
# label's positive/negative training samples, evaluate the pruned module on the
# test set, and compare its representations against the unpruned `model`.
save_modules = False  # set True to write each concern's pruned module to Modules/
results_by_concern = {}   # concern -> return value of evaluate_model
sparsity_by_concern = {}  # concern -> return value of get_sparsity

for concern in range(num_labels):
    # Independent copies of the loaders for this iteration.
    # NOTE(review): presumably required because sampling/similarity iterate the
    # loaders — confirm the deepcopies are necessary.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # NOTE(review): the positional `4` is a SamplingDataset argument whose
    # meaning is not visible here — confirm and consider a named config constant.
    positive_samples = SamplingDataset(
        train, concern, num_samples, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
    )
    # Not used in this cell (kept in case later cells rely on it). Concern index
    # 200 lies outside range(num_labels); with positive=False this presumably
    # samples from every class — confirm against SamplingDataset.
    all_samples = SamplingDataset(
        train, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so the original `model` stays available for `similar` below.
    module = copy.deepcopy(model)

    # The return value is not used, so pruning presumably happens in place on `module`.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    results_by_concern[concern] = result
    # A bare expression inside a loop is never displayed, so keep the value
    # (if any) per concern instead of silently discarding it.
    sparsity_by_concern[concern] = get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    if save_modules:
        # Include the concern index: one shared filename would be overwritten
        # on every iteration, leaving only the last concern's module on disk.
        save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model 0




Evaluating:   0%|                                                                             | 0/1875 [00:49<…

Loss: 1.0063




Precision: 0.6817, Recall: 0.6808, F1-Score: 0.6781




              precision    recall  f1-score   support

           0       0.57      0.55      0.56      2972
           1       0.74      0.66      0.70      3016
           2       0.71      0.77      0.74      2985
           3       0.54      0.51      0.52      3023
           4       0.81      0.82      0.81      3039
           5       0.90      0.82      0.86      3076
           6       0.58      0.42      0.49      2965
           7       0.60      0.74      0.66      3031
           8       0.63      0.76      0.69      2932
           9       0.74      0.75      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9026282809742092, 0.9026282809742092)




CCA coefficients mean non-concern: (0.8959718538891461, 0.8959718538891461)




Linear CKA concern: 0.9867761112324681




Linear CKA non-concern: 0.9792155724070866




Kernel CKA concern: 0.977450116258375




Kernel CKA non-concern: 0.9682459682677322




Evaluate the pruned model 1




Evaluating:   0%|                                                                             | 0/1875 [01:47<…

Loss: 1.0029




Precision: 0.6822, Recall: 0.6801, F1-Score: 0.6782




              precision    recall  f1-score   support

           0       0.56      0.56      0.56      2972
           1       0.73      0.67      0.70      3016
           2       0.70      0.78      0.74      2985
           3       0.52      0.52      0.52      3023
           4       0.82      0.81      0.82      3039
           5       0.91      0.81      0.86      3076
           6       0.58      0.42      0.49      2965
           7       0.59      0.74      0.66      3031
           8       0.66      0.74      0.70      2932
           9       0.74      0.75      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9049435379242682, 0.9049435379242682)




CCA coefficients mean non-concern: (0.8968633198261061, 0.8968633198261061)




Linear CKA concern: 0.989606168728902




Linear CKA non-concern: 0.9795176392018907




Kernel CKA concern: 0.9823060876983803




Kernel CKA non-concern: 0.9695370459739422




Evaluate the pruned model 2




Evaluating:   0%|                                                                             | 0/1875 [01:58<…

Loss: 1.0079




Precision: 0.6805, Recall: 0.6769, F1-Score: 0.6745




              precision    recall  f1-score   support

           0       0.56      0.55      0.55      2972
           1       0.75      0.64      0.69      3016
           2       0.70      0.78      0.74      2985
           3       0.52      0.52      0.52      3023
           4       0.82      0.81      0.81      3039
           5       0.90      0.82      0.86      3076
           6       0.59      0.41      0.48      2965
           7       0.58      0.75      0.65      3031
           8       0.64      0.76      0.69      2932
           9       0.74      0.75      0.75      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.67     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8940650601936768, 0.8940650601936768)




CCA coefficients mean non-concern: (0.8904506915120399, 0.8904506915120399)




Linear CKA concern: 0.9921405995801105




Linear CKA non-concern: 0.9670898210368519




Kernel CKA concern: 0.9876190731753187




Kernel CKA non-concern: 0.9486759178242435




Evaluate the pruned model 3




Evaluating:   0%|                                                                             | 0/1875 [01:48<…

Loss: 1.0029




Precision: 0.6820, Recall: 0.6807, F1-Score: 0.6781




              precision    recall  f1-score   support

           0       0.56      0.55      0.56      2972
           1       0.74      0.67      0.70      3016
           2       0.71      0.77      0.74      2985
           3       0.54      0.52      0.53      3023
           4       0.82      0.82      0.82      3039
           5       0.90      0.82      0.86      3076
           6       0.58      0.41      0.48      2965
           7       0.60      0.74      0.66      3031
           8       0.64      0.76      0.69      2932
           9       0.74      0.75      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9067164349244691, 0.9067164349244691)




CCA coefficients mean non-concern: (0.9044047966013363, 0.9044047966013363)




Linear CKA concern: 0.9895723364069964




Linear CKA non-concern: 0.9800781109416334




Kernel CKA concern: 0.9822667678436775




Kernel CKA non-concern: 0.9711598300241799




Evaluate the pruned model 4




Evaluating:   0%|                                                                             | 0/1875 [01:48<…

Loss: 1.0064




Precision: 0.6818, Recall: 0.6802, F1-Score: 0.6777




              precision    recall  f1-score   support

           0       0.57      0.54      0.55      2972
           1       0.73      0.67      0.70      3016
           2       0.71      0.77      0.74      2985
           3       0.54      0.51      0.52      3023
           4       0.81      0.82      0.81      3039
           5       0.92      0.81      0.86      3076
           6       0.58      0.42      0.49      2965
           7       0.60      0.75      0.66      3031
           8       0.64      0.76      0.69      2932
           9       0.74      0.75      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9077116166707708, 0.9077116166707708)




CCA coefficients mean non-concern: (0.8933689703440257, 0.8933689703440257)




Linear CKA concern: 0.9923673351377825




Linear CKA non-concern: 0.9725829496257825




Kernel CKA concern: 0.9875498101698463




Kernel CKA non-concern: 0.9592005844251066




Evaluate the pruned model 5




Evaluating:   0%|                                                                             | 0/1875 [01:47<…

Loss: 1.0145




Precision: 0.6809, Recall: 0.6779, F1-Score: 0.6763




              precision    recall  f1-score   support

           0       0.56      0.55      0.56      2972
           1       0.75      0.62      0.68      3016
           2       0.73      0.76      0.74      2985
           3       0.52      0.52      0.52      3023
           4       0.81      0.81      0.81      3039
           5       0.91      0.82      0.86      3076
           6       0.55      0.43      0.49      2965
           7       0.63      0.73      0.68      3031
           8       0.61      0.78      0.68      2932
           9       0.74      0.75      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8956315172839576, 0.8956315172839576)




CCA coefficients mean non-concern: (0.888926198265946, 0.888926198265946)




Linear CKA concern: 0.9920879940472352




Linear CKA non-concern: 0.9656082174124325




Kernel CKA concern: 0.9867993741359067




Kernel CKA non-concern: 0.9449587022118608




Evaluate the pruned model 6




Evaluating:   0%|                                                                             | 0/1875 [01:52<…

Loss: 1.0060




Precision: 0.6822, Recall: 0.6799, F1-Score: 0.6779




              precision    recall  f1-score   support

           0       0.56      0.55      0.56      2972
           1       0.74      0.66      0.70      3016
           2       0.72      0.76      0.74      2985
           3       0.53      0.51      0.52      3023
           4       0.82      0.81      0.81      3039
           5       0.90      0.82      0.86      3076
           6       0.58      0.42      0.49      2965
           7       0.59      0.75      0.66      3031
           8       0.64      0.76      0.69      2932
           9       0.74      0.74      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9056272671133859, 0.9056272671133859)




CCA coefficients mean non-concern: (0.9065630754083001, 0.9065630754083001)




Linear CKA concern: 0.9854798642026481




Linear CKA non-concern: 0.9735549408601419




Kernel CKA concern: 0.9722919602036714




Kernel CKA non-concern: 0.9587888989542533




Evaluate the pruned model 7




Evaluating:   0%|                                                                             | 0/1875 [01:47<…

Loss: 1.0079




Precision: 0.6826, Recall: 0.6800, F1-Score: 0.6778




              precision    recall  f1-score   support

           0       0.57      0.55      0.56      2972
           1       0.74      0.64      0.69      3016
           2       0.72      0.77      0.74      2985
           3       0.52      0.52      0.52      3023
           4       0.82      0.81      0.82      3039
           5       0.91      0.82      0.86      3076
           6       0.57      0.42      0.49      2965
           7       0.59      0.75      0.66      3031
           8       0.64      0.76      0.69      2932
           9       0.73      0.76      0.74      2961

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse