In [1]:
import os
import sys

# Environment setup:
# - make the repository root importable so `utils.*` resolves; the path is
#   relative to the current working directory (four levels up),
# - disable the `tokenizers` library's internal parallelism.
sys.path += ["../../../../"]
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
)

In [3]:
# Experiment configuration: everything a reader might want to tune lives here.
name = "YahooAnswersTopics"  # task key passed to ModelConfig and load_data
# Fall back to CPU when CUDA is unavailable so the notebook can still run
# (slowly) without a GPU; on a CUDA machine this is exactly torch.device("cuda:0").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; reassigned by load_model(...) in a later cell
batch_size = 16  # passed to load_data
num_workers = 4  # passed to load_data
num_samples = 16  # sample budget per concern; the loop draws num_samples // 2 positives and num_samples // 2 negatives
ci_ratio = 0.5  # sparsity_ratio passed to prune_concern_identification
seed = 44  # forwarded to load_data, SamplingDataset and similar
include_layers = ["attention"]  # layer filter passed to prune_concern_identification
exclude_layers = None  # no layers explicitly excluded from pruning

In [4]:
# Record the wall-clock time at which this run began and announce it.
script_start_time = datetime.now()
print("Script started at:", format(script_start_time, "%Y-%m-%d %H:%M:%S"))

Script started at: 2024-09-01 00:37:36


In [5]:
# Resolve the task configuration, read how many labels it defines (each label
# is later treated as one "concern"), then load the pretrained model.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
loaded_artifacts = load_model(model_config)
model, tokenizer, checkpoint = loaded_artifacts

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the train / validation / test loaders for the task, reusing the
# on-disk cache when available (do_cache=True).
loader_options = dict(
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)
train_dataloader, valid_dataloader, test_dataloader = load_data(name, **loader_options)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Baseline: evaluate the original (unpruned) model on the test set, as a
# reference point for the per-concern pruned models below. Currently disabled;
# uncomment both lines to produce the baseline report.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# Concern identification: for every label ("concern"), prune a fresh copy of the
# model using samples of that label (positive) and of other labels (negative),
# evaluate it on the full test set, report sparsity, and compare its
# representations against the original model.
ci_results = {}  # concern -> evaluate_model(...) result, so no evaluation is lost
for concern in range(num_labels):
    # Fresh loader copies per concern so sampling for one concern cannot disturb
    # later iterations (presumably SamplingDataset iterates the loader — TODO
    # confirm whether the deep copies are actually required).
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)
    # NOTE(review): `4` is an undocumented positional argument of SamplingDataset;
    # confirm its meaning and pass it by keyword.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so the original `model` stays intact for `similar` below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    ci_results[concern] = result
    get_sparsity(module)

    # Representation similarity (CCA / CKA) between original and pruned model.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # Include `concern` in the filename; otherwise every concern would overwrite
    # the same file when saving is re-enabled.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.1042




Precision: 0.6667, Recall: 0.6524, F1-Score: 0.6516




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.76      0.56      0.64      2997
           2       0.75      0.70      0.73      3016
           3       0.54      0.46      0.49      2978
           4       0.77      0.81      0.79      3017
           5       0.93      0.75      0.83      3004
           6       0.53      0.39      0.45      3037
           7       0.50      0.77      0.60      3026
           8       0.62      0.77      0.69      2997
           9       0.74      0.74      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7703866686781426, 0.7703866686781426)




CCA coefficients mean non-concern: (0.7671964053134034, 0.7671964053134034)




Linear CKA concern: 0.9067333678497405




Linear CKA non-concern: 0.863385374950487




Kernel CKA concern: 0.8550811161734908




Kernel CKA non-concern: 0.8300182292898883




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.1034




Precision: 0.6692, Recall: 0.6548, F1-Score: 0.6543




              precision    recall  f1-score   support

           0       0.55      0.55      0.55      2941
           1       0.72      0.62      0.67      2997
           2       0.75      0.71      0.73      3016
           3       0.55      0.44      0.49      2978
           4       0.77      0.82      0.79      3017
           5       0.94      0.75      0.83      3004
           6       0.56      0.40      0.46      3037
           7       0.48      0.78      0.59      3026
           8       0.63      0.76      0.69      2997
           9       0.75      0.72      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.66      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7633966744492939, 0.7633966744492939)




CCA coefficients mean non-concern: (0.7696084323619775, 0.7696084323619775)




Linear CKA concern: 0.9052986240240094




Linear CKA non-concern: 0.8673076647757896




Kernel CKA concern: 0.8674285072003483




Kernel CKA non-concern: 0.8227516491500001




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.1036




Precision: 0.6690, Recall: 0.6556, F1-Score: 0.6555




              precision    recall  f1-score   support

           0       0.55      0.56      0.55      2941
           1       0.76      0.57      0.65      2997
           2       0.73      0.73      0.73      3016
           3       0.52      0.49      0.50      2978
           4       0.79      0.81      0.80      3017
           5       0.93      0.75      0.83      3004
           6       0.53      0.40      0.46      3037
           7       0.50      0.76      0.61      3026
           8       0.61      0.77      0.68      2997
           9       0.76      0.72      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7537779772051525, 0.7537779772051525)




CCA coefficients mean non-concern: (0.7648271285911397, 0.7648271285911397)




Linear CKA concern: 0.9197300419533943




Linear CKA non-concern: 0.8715127518585485




Kernel CKA concern: 0.9015492900798987




Kernel CKA non-concern: 0.8103376869548191




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.0936




Precision: 0.6695, Recall: 0.6566, F1-Score: 0.6563




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.74      0.61      0.67      2997
           2       0.75      0.72      0.73      3016
           3       0.55      0.47      0.50      2978
           4       0.79      0.81      0.80      3017
           5       0.93      0.75      0.83      3004
           6       0.54      0.39      0.45      3037
           7       0.49      0.76      0.60      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.74      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7685245773994431, 0.7685245773994431)




CCA coefficients mean non-concern: (0.7692809828287933, 0.7692809828287933)




Linear CKA concern: 0.8861520809685693




Linear CKA non-concern: 0.8839024323408705




Kernel CKA concern: 0.8416384375541871




Kernel CKA non-concern: 0.8426579135961251




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.1252




Precision: 0.6644, Recall: 0.6496, F1-Score: 0.6491




              precision    recall  f1-score   support

           0       0.51      0.57      0.54      2941
           1       0.75      0.58      0.65      2997
           2       0.77      0.68      0.72      3016
           3       0.54      0.45      0.49      2978
           4       0.76      0.84      0.79      3017
           5       0.94      0.74      0.83      3004
           6       0.53      0.39      0.45      3037
           7       0.49      0.75      0.60      3026
           8       0.62      0.77      0.68      2997
           9       0.75      0.73      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.66      0.65      0.65     30000
weighted avg       0.66      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7581748163914891, 0.7581748163914891)




CCA coefficients mean non-concern: (0.765716195240935, 0.765716195240935)




Linear CKA concern: 0.9309651375110749




Linear CKA non-concern: 0.8424012098974013




Kernel CKA concern: 0.9013825421883799




Kernel CKA non-concern: 0.7904297051719932




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.1019




Precision: 0.6692, Recall: 0.6572, F1-Score: 0.6577




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.76      0.56      0.65      2997
           2       0.76      0.70      0.73      3016
           3       0.51      0.51      0.51      2978
           4       0.80      0.79      0.80      3017
           5       0.92      0.78      0.85      3004
           6       0.51      0.41      0.46      3037
           7       0.53      0.76      0.62      3026
           8       0.63      0.76      0.69      2997
           9       0.74      0.73      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.66     30000
weighted avg       0.67      0.66      0.66     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7576502590279584, 0.7576502590279584)




CCA coefficients mean non-concern: (0.7679491247470237, 0.7679491247470237)




Linear CKA concern: 0.9397013060079498




Linear CKA non-concern: 0.8654542768133897




Kernel CKA concern: 0.9141161300606075




Kernel CKA non-concern: 0.8270119373197282




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.1040




Precision: 0.6681, Recall: 0.6516, F1-Score: 0.6520




              precision    recall  f1-score   support

           0       0.52      0.58      0.55      2941
           1       0.75      0.56      0.64      2997
           2       0.75      0.70      0.73      3016
           3       0.53      0.47      0.50      2978
           4       0.79      0.80      0.80      3017
           5       0.94      0.75      0.83      3004
           6       0.54      0.40      0.46      3037
           7       0.49      0.77      0.60      3026
           8       0.62      0.77      0.69      2997
           9       0.75      0.73      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7678272977705392, 0.7678272977705392)




CCA coefficients mean non-concern: (0.7693658172331863, 0.7693658172331863)




Linear CKA concern: 0.8890429697403354




Linear CKA non-concern: 0.877116315492155




Kernel CKA concern: 0.8377826144740546




Kernel CKA non-concern: 0.8370544729857199




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.1044




Precision: 0.6679, Recall: 0.6546, F1-Score: 0.6548




              precision    recall  f1-score   support

           0       0.53      0.57      0.55      2941
           1       0.76      0.55      0.64      2997
           2       0.77      0.70      0.73      3016
           3       0.51      0.49      0.50      2978
           4       0.78      0.81      0.79      3017
           5       0.93      0.76      0.84      3004
           6       0.50      0.41      0.45      3037
           7       0.52      0.76      0.62      3026
           8       0.62      0.77      0.69      2997
           9       0.75      0.73      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7598956283390047, 0.7598956283390047)




CCA coefficients mean non-concern: (0.7704256800895413, 0.7704256800895413)




Linear CKA concern: 0.9242650686665457




Linear CKA non-concern: 0.8644960884925297




Kernel CKA concern: 0.8952202475002551




Kernel CKA non-concern: 0.816302110762877




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.1197




Precision: 0.6653, Recall: 0.6504, F1-Score: 0.6515




              precision    recall  f1-score   support

           0       0.51      0.57      0.54      2941
           1       0.77      0.53      0.63      2997
           2       0.75      0.71      0.73      3016
           3       0.50      0.50      0.50      2978
           4       0.79      0.80      0.79      3017
           5       0.94      0.75      0.83      3004
           6       0.48      0.42      0.45      3037
           7       0.56      0.71      0.63      3026
           8       0.59      0.80      0.68      2997
           9       0.76      0.71      0.74      2987

    accuracy                           0.65     30000
   macro avg       0.67      0.65      0.65     30000
weighted avg       0.67      0.65      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7636553678396967, 0.7636553678396967)




CCA coefficients mean non-concern: (0.769393916807169, 0.769393916807169)




Linear CKA concern: 0.9230876381498477




Linear CKA non-concern: 0.8507281972942968




Kernel CKA concern: 0.8967589414540875




Kernel CKA non-concern: 0.8091926079008714




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.1032




Precision: 0.6662, Recall: 0.6551, F1-Score: 0.6546




              precision    recall  f1-score   support

           0       0.52      0.57      0.55      2941
           1       0.75      0.56      0.64      2997
           2       0.75      0.72      0.73      3016
           3       0.53      0.47      0.50      2978
           4       0.79      0.80      0.79      3017
           5       0.93      0.76      0.84      3004
           6       0.51      0.41      0.45      3037
           7       0.52      0.76      0.62      3026
           8       0.63      0.76      0.69      2997
           9       0.72      0.75      0.74      2987

    accuracy                           0.66     30000
   macro avg       0.67      0.66      0.65     30000
weighted avg       0.67      0.66      0.65     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7618437683207059, 0.7618437683207059)




CCA coefficients mean non-concern: (0.7719681885571481, 0.7719681885571481)




Linear CKA concern: 0.9348705995280927




Linear CKA non-concern: 0.8674389100288677




Kernel CKA concern: 0.8912785953901643




Kernel CKA non-concern: 0.8153350456273464


