In [1]:
import os
import sys

# Extend the import search path four directories up from the working directory.
# NOTE(review): presumably this is the folder that holds the `utils` package
# imported in the next cell, and it is resolved relative to the kernel's
# current working directory — confirm the notebook is launched from its own folder.
sys.path += ["../../../../"]

# Set before any tokenizer is created.
# NOTE(review): presumably meant to silence the Hugging Face tokenizers
# fork/parallelism warning when DataLoader workers are used — confirm.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration -------------------------------------------------
# Key handed to ModelConfig and load_data to select the model/dataset pair.
name = "YahooAnswersTopics"
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of failing at model-load time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder only: load_model() returns a checkpoint value that replaces this.
checkpoint = None
# DataLoader settings forwarded to load_data().
batch_size = 16
num_workers = 4
# Sampling budget per concern; the pruning loop uses num_samples // 2 for each
# of its two sample sets.
num_samples = 16
# Passed to prune_concern_identification() as `sparsity_ratio`.
ci_ratio = 0.4
# Shared seed for data loading, sampling and the similarity comparison.
seed = 44
# Layer-name filters forwarded to prune_concern_identification().
# NOTE(review): presumably matched against module names — confirm in prune utils.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record the wall-clock start of the run and echo it in a fixed, sortable format.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-09-01 11:59:51


In [5]:
# Build the model configuration for the selected experiment name and device.
model_config = ModelConfig(name, device)
# Number of class labels; the pruning loop iterates once per label.
num_labels = model_config.config["num_labels"]
# load_model returns three values; `checkpoint` replaces the None placeholder
# set in the configuration cell.
# NOTE(review): whether weights are downloaded or read from a local cache is
# decided inside load_model — confirm there.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Load the three data loaders (train / validation / test) for the dataset
# selected by `name`, using the batch size, worker count and seed from the
# configuration cell.
# NOTE(review): do_cache=True presumably reuses a previously cached copy of the
# dataset — confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
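# Optional baseline: uncomment the two lines below to evaluate the original
# (unpruned) model on the test set before running the pruning loop.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)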

In [8]:
# Build, evaluate and compare one pruned module per class label ("concern").
for concern in range(num_labels):
    # Per-iteration copies of the loaders; the shared ones are left untouched.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Two sample sets of num_samples // 2 items each that differ only in the
    # boolean flag (True for the first, False for the second).
    # NOTE(review): the flag presumably selects samples of / not of `concern`,
    # and the meaning of the positional 4 is not visible here — confirm both
    # against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )
    # No third, concern-agnostic sample set is built: nothing in this loop
    # consumes one, so constructing it would only cost sampling time.

    # Prune a copy so the original `model` stays available as the reference
    # for the similarity comparison below.
    module = copy.deepcopy(model)

    # NOTE(review): the return value is not used, so the pruning presumably
    # happens in place on `module` — confirm in prune utils.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    # `result` is kept as a notebook-level name in case a later cell inspects
    # the metrics of the last evaluated module.
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original and the pruned model on the validation copy.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # When re-enabling the save, keep the concern id in the file name;
    # without it every iteration would write to the same file.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p_concern{concern}.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0067




Precision: 0.6835, Recall: 0.6815, F1-Score: 0.6788




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.64      0.68      2997
           2       0.72      0.77      0.74      3016
           3       0.53      0.52      0.53      2978
           4       0.81      0.81      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.60      0.41      0.48      3037
           7       0.60      0.74      0.67      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8887380552103885, 0.8887380552103885)




CCA coefficients mean non-concern: (0.8888937905362275, 0.8888937905362275)




Linear CKA concern: 0.9852001839227059




Linear CKA non-concern: 0.9764153915371342




Kernel CKA concern: 0.9745363378972145




Kernel CKA non-concern: 0.9664186866664759




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.0045




Precision: 0.6840, Recall: 0.6823, F1-Score: 0.6795




              precision    recall  f1-score   support

           0       0.55      0.58      0.56      2941
           1       0.74      0.65      0.69      2997
           2       0.72      0.78      0.75      3016
           3       0.54      0.51      0.53      2978
           4       0.81      0.81      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.48      3037
           7       0.60      0.74      0.66      3026
           8       0.64      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8905509256264482, 0.8905509256264482)




CCA coefficients mean non-concern: (0.887670986138214, 0.887670986138214)




Linear CKA concern: 0.9870686722526483




Linear CKA non-concern: 0.9750656950633584




Kernel CKA concern: 0.9787893896421752




Kernel CKA non-concern: 0.9623925633033019




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.0057




Precision: 0.6826, Recall: 0.6802, F1-Score: 0.6774




              precision    recall  f1-score   support

           0       0.55      0.57      0.56      2941
           1       0.73      0.64      0.68      2997
           2       0.72      0.78      0.75      3016
           3       0.53      0.51      0.52      2978
           4       0.81      0.81      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.60      0.40      0.48      3037
           7       0.59      0.74      0.66      3026
           8       0.64      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8806595435332054, 0.8806595435332054)




CCA coefficients mean non-concern: (0.887027562551019, 0.887027562551019)




Linear CKA concern: 0.989865816678302




Linear CKA non-concern: 0.974987347231819




Kernel CKA concern: 0.9847287848294366




Kernel CKA non-concern: 0.9594015019503964




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.0033




Precision: 0.6844, Recall: 0.6829, F1-Score: 0.6801




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.66      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.54      0.51      0.53      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.60      0.41      0.49      3037
           7       0.59      0.74      0.66      3026
           8       0.64      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.892905752506112, 0.892905752506112)




CCA coefficients mean non-concern: (0.8905469917138421, 0.8905469917138421)




Linear CKA concern: 0.9868664159713989




Linear CKA non-concern: 0.9774339487947545




Kernel CKA concern: 0.9792806129210333




Kernel CKA non-concern: 0.9672415781483025




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.0070




Precision: 0.6826, Recall: 0.6810, F1-Score: 0.6781




              precision    recall  f1-score   support

           0       0.55      0.58      0.56      2941
           1       0.73      0.65      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.54      0.51      0.52      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.82      0.87      3004
           6       0.59      0.41      0.48      3037
           7       0.60      0.74      0.66      3026
           8       0.64      0.76      0.69      2997
           9       0.73      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8891695238620398, 0.8891695238620398)




CCA coefficients mean non-concern: (0.8834505121823389, 0.8834505121823389)




Linear CKA concern: 0.9923525757643568




Linear CKA non-concern: 0.9731686704602361




Kernel CKA concern: 0.9869321381711689




Kernel CKA non-concern: 0.9590697088470065




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0059




Precision: 0.6833, Recall: 0.6819, F1-Score: 0.6796




              precision    recall  f1-score   support

           0       0.56      0.56      0.56      2941
           1       0.74      0.64      0.69      2997
           2       0.73      0.77      0.75      3016
           3       0.53      0.53      0.53      2978
           4       0.80      0.81      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.57      0.42      0.48      3037
           7       0.61      0.73      0.67      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8772621207385465, 0.8772621207385465)




CCA coefficients mean non-concern: (0.8862853851595558, 0.8862853851595558)




Linear CKA concern: 0.990622556006843




Linear CKA non-concern: 0.97411556861686




Kernel CKA concern: 0.9841615200119471




Kernel CKA non-concern: 0.9624902264256433




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0051




Precision: 0.6837, Recall: 0.6826, F1-Score: 0.6797




              precision    recall  f1-score   support

           0       0.56      0.57      0.57      2941
           1       0.74      0.65      0.69      2997
           2       0.72      0.77      0.75      3016
           3       0.54      0.51      0.52      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.41      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.63      0.76      0.69      2997
           9       0.73      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8923363406904631, 0.8923363406904631)




CCA coefficients mean non-concern: (0.8896290013725277, 0.8896290013725277)




Linear CKA concern: 0.9870836961805131




Linear CKA non-concern: 0.9787607643963528




Kernel CKA concern: 0.9786833568937434




Kernel CKA non-concern: 0.9690890815688602




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0070




Precision: 0.6827, Recall: 0.6812, F1-Score: 0.6787




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.63      0.68      2997
           2       0.73      0.77      0.75      3016
           3       0.53      0.52      0.52      2978
           4       0.80      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.58      0.42      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8855994858654773, 0.8855994858654773)




CCA coefficients mean non-concern: (0.8896101841727015, 0.8896101841727015)




Linear CKA concern: 0.9881367439095557




Linear CKA non-concern: 0.9745901115182706




Kernel CKA concern: 0.9813953722589589




Kernel CKA non-concern: 0.9615236674142161




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0078




Precision: 0.6822, Recall: 0.6797, F1-Score: 0.6775




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.74      0.62      0.67      2997
           2       0.72      0.78      0.75      3016
           3       0.52      0.53      0.52      2978
           4       0.81      0.81      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.59      0.42      0.49      3037
           7       0.62      0.73      0.67      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8894923637323721, 0.8894923637323721)




CCA coefficients mean non-concern: (0.8878082555910742, 0.8878082555910742)




Linear CKA concern: 0.9879685918815285




Linear CKA non-concern: 0.9751296404459938




Kernel CKA concern: 0.9809395669568576




Kernel CKA non-concern: 0.9635743977482815




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.0047




Precision: 0.6830, Recall: 0.6822, F1-Score: 0.6798




              precision    recall  f1-score   support

           0       0.56      0.57      0.56      2941
           1       0.73      0.64      0.68      2997
           2       0.73      0.77      0.75      3016
           3       0.53      0.51      0.52      2978
           4       0.81      0.82      0.81      3017
           5       0.91      0.83      0.87      3004
           6       0.58      0.42      0.49      3037
           7       0.61      0.74      0.67      3026
           8       0.65      0.76      0.70      2997
           9       0.74      0.76      0.75      2987

    accuracy                           0.68     30000
   macro avg       0.68      0.68      0.68     30000
weighted avg       0.68      0.68      0.68     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.889960382271644, 0.889960382271644)




CCA coefficients mean non-concern: (0.888049513891567, 0.888049513891567)




Linear CKA concern: 0.9882137196712775




Linear CKA non-concern: 0.9713294952562211




Kernel CKA concern: 0.9806144354974871




Kernel CKA non-concern: 0.9558082348227567


