In [1]:
import os
import sys

# Make the directory four levels up importable (presumably where the `utils`
# package imported in the next cell lives -- confirm against the repo layout).
# The entry is relative, so it is resolved against the kernel's working directory.
_project_root = "../../../../"
# Guard the append so that re-running this cell does not stack duplicate entries.
if _project_root not in sys.path:
    sys.path.append(_project_root)

# Set before any tokenizer is created. NOTE(review): this is the variable the
# Hugging Face `tokenizers` library is documented to read; "false" is expected to
# turn off its internal parallelism -- confirm against the installed version.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration: every value a reader might want to tune lives here.

# Model/dataset key; passed to ModelConfig(...) and load_data(...).
name = "YahooAnswersTopics"
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines where CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder; rebound by the tuple returned from load_model(...) further down.
checkpoint = None
# Forwarded to load_data(...) as batch_size / num_workers.
batch_size = 16
num_workers = 4
# Passed to similar(...) as-is; the sampling datasets are built with num_samples // 2.
num_samples = 16
# Forwarded to prune_concern_identification(...) as sparsity_ratio.
ci_ratio = 0.6
# Shared seed handed to load_data, SamplingDataset and similar.
seed = 44
# Layer filters forwarded to prune_concern_identification(...).
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Capture the wall-clock time at which this run began and echo it, so the saved
# output records when the results below were produced.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-09-01 01:05:51


In [5]:
# Build the configuration object for the selected model name and target device.
model_config = ModelConfig(name, device)
# Number of classes, read from the resolved config; it bounds the per-concern loop below.
num_labels = model_config.config["num_labels"]
# load_model returns a (model, tokenizer, checkpoint) triple. Note that this
# rebinds `checkpoint`, which the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Build the three dataloaders (train / validation / test) for the dataset `name`,
# using the batch size, worker count and seed from the configuration cell.
# NOTE(review): do_cache=True presumably reuses a previously cached dataset (the
# saved output reports "Loading cached dataset") -- confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
# Evaluation of the original (unpruned) model is currently disabled; uncomment
# the two lines below to run it on the test dataloader.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
def _build_sampling_dataset(dataloader, target_class, flag):
    """Construct a SamplingDataset using the settings shared by every call in this cell.

    Only the dataloader, the target class and the boolean flag vary between
    calls; the sample count (num_samples // 2), num_labels, device, resample
    and seed are always the same.

    NOTE(review): the meaning of the positional `4` and of the boolean flag is
    not visible from this notebook -- check the SamplingDataset signature. The
    flag is True for `positive_samples` and False for the other two datasets.
    """
    return SamplingDataset(
        dataloader,
        target_class,
        num_samples // 2,
        num_labels,
        flag,
        4,
        device=device,
        resample=False,
        seed=seed,
    )


# Keyword arguments that are identical for every pruning call.
_pruning_options = {
    "include_layers": include_layers,
    "exclude_layers": exclude_layers,
    "sparsity_ratio": ci_ratio,
}

# One pruned module per class label ("concern").
for concern in range(num_labels):
    # Work on private copies of the loaders (presumably so sampling cannot
    # disturb the shared ones -- confirm).
    train, valid = copy.deepcopy(train_dataloader), copy.deepcopy(valid_dataloader)

    positive_samples = _build_sampling_dataset(train, concern, True)
    negative_samples = _build_sampling_dataset(train, concern, False)
    # NOTE(review): `all_samples` is not referenced again in this cell, and its
    # target class (200) lies outside range(num_labels). It is kept here only
    # to preserve behaviour, since constructing it may have side effects.
    all_samples = _build_sampling_dataset(train, 200, False)

    # Prune a copy so that `model` stays untouched for the comparison below.
    module = copy.deepcopy(model)
    prune_concern_identification(
        module, model_config, positive_samples, negative_samples, **_pruning_options
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original model with the pruned module on the validation copy.
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

    # NOTE(review): if re-enabled, the file name below does not include `concern`,
    # so every iteration would target the same file name.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.5176




Precision: 0.6435, Recall: 0.5540, F1-Score: 0.5595




              precision    recall  f1-score   support

           0       0.50      0.50      0.50      2941
           1       0.81      0.35      0.49      2997
           2       0.84      0.45      0.59      3016
           3       0.53      0.33      0.41      2978
           4       0.70      0.79      0.74      3017
           5       0.96      0.57      0.71      3004
           6       0.44      0.36      0.39      3037
           7       0.31      0.83      0.45      3026
           8       0.53      0.79      0.64      2997
           9       0.81      0.57      0.67      2987

    accuracy                           0.55     30000
   macro avg       0.64      0.55      0.56     30000
weighted avg       0.64      0.55      0.56     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7111887564781855, 0.7111887564781855)




CCA coefficients mean non-concern: (0.704762427906554, 0.704762427906554)




Linear CKA concern: 0.7265832522933692




Linear CKA non-concern: 0.6537770385241913




Kernel CKA concern: 0.5858917500972699




Kernel CKA non-concern: 0.4976493833115187




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.5196




Precision: 0.6502, Recall: 0.5585, F1-Score: 0.5676




              precision    recall  f1-score   support

           0       0.48      0.52      0.50      2941
           1       0.77      0.45      0.57      2997
           2       0.84      0.48      0.61      3016
           3       0.54      0.29      0.38      2978
           4       0.75      0.76      0.75      3017
           5       0.96      0.57      0.72      3004
           6       0.52      0.35      0.41      3037
           7       0.30      0.84      0.44      3026
           8       0.55      0.77      0.64      2997
           9       0.80      0.56      0.66      2987

    accuracy                           0.56     30000
   macro avg       0.65      0.56      0.57     30000
weighted avg       0.65      0.56      0.57     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7031911020324019, 0.7031911020324019)




CCA coefficients mean non-concern: (0.7086413342233396, 0.7086413342233396)




Linear CKA concern: 0.6636751022060128




Linear CKA non-concern: 0.6782356818623102




Kernel CKA concern: 0.5424841935834391




Kernel CKA non-concern: 0.513165151737096




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.5184




Precision: 0.6460, Recall: 0.5672, F1-Score: 0.5755




              precision    recall  f1-score   support

           0       0.50      0.50      0.50      2941
           1       0.80      0.37      0.51      2997
           2       0.80      0.54      0.65      3016
           3       0.50      0.39      0.44      2978
           4       0.77      0.77      0.77      3017
           5       0.96      0.57      0.72      3004
           6       0.45      0.39      0.42      3037
           7       0.33      0.82      0.47      3026
           8       0.55      0.76      0.64      2997
           9       0.81      0.55      0.66      2987

    accuracy                           0.57     30000
   macro avg       0.65      0.57      0.58     30000
weighted avg       0.65      0.57      0.58     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6899703438033551, 0.6899703438033551)




CCA coefficients mean non-concern: (0.7066664016327217, 0.7066664016327217)




Linear CKA concern: 0.6061038496192391




Linear CKA non-concern: 0.704913808996724




Kernel CKA concern: 0.5936634834376516




Kernel CKA non-concern: 0.5110780761453722




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.4744




Precision: 0.6479, Recall: 0.5735, F1-Score: 0.5806




              precision    recall  f1-score   support

           0       0.47      0.54      0.50      2941
           1       0.77      0.44      0.56      2997
           2       0.84      0.49      0.62      3016
           3       0.57      0.34      0.42      2978
           4       0.76      0.77      0.77      3017
           5       0.96      0.58      0.73      3004
           6       0.45      0.37      0.41      3037
           7       0.33      0.81      0.47      3026
           8       0.55      0.77      0.64      2997
           9       0.78      0.61      0.68      2987

    accuracy                           0.57     30000
   macro avg       0.65      0.57      0.58     30000
weighted avg       0.65      0.57      0.58     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7059473300566279, 0.7059473300566279)




CCA coefficients mean non-concern: (0.7090795990908809, 0.7090795990908809)




Linear CKA concern: 0.649254385794843




Linear CKA non-concern: 0.6985227110456984




Kernel CKA concern: 0.48703547587183554




Kernel CKA non-concern: 0.5368922855658279




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.5927




Precision: 0.6338, Recall: 0.5326, F1-Score: 0.5343




              precision    recall  f1-score   support

           0       0.39      0.57      0.46      2941
           1       0.81      0.32      0.46      2997
           2       0.86      0.39      0.54      3016
           3       0.51      0.30      0.38      2978
           4       0.71      0.79      0.75      3017
           5       0.97      0.52      0.67      3004
           6       0.46      0.31      0.37      3037
           7       0.31      0.80      0.45      3026
           8       0.52      0.80      0.63      2997
           9       0.80      0.52      0.63      2987

    accuracy                           0.53     30000
   macro avg       0.63      0.53      0.53     30000
weighted avg       0.63      0.53      0.53     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6921378568013757, 0.6921378568013757)




CCA coefficients mean non-concern: (0.701489052710325, 0.701489052710325)




Linear CKA concern: 0.6835908211989755




Linear CKA non-concern: 0.6402830744894694




Kernel CKA concern: 0.6226162916266209




Kernel CKA non-concern: 0.4402458759513696




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.5216




Precision: 0.6375, Recall: 0.5628, F1-Score: 0.5654




              precision    recall  f1-score   support

           0       0.44      0.55      0.49      2941
           1       0.81      0.30      0.44      2997
           2       0.85      0.42      0.56      3016
           3       0.49      0.38      0.43      2978
           4       0.78      0.74      0.76      3017
           5       0.94      0.66      0.78      3004
           6       0.35      0.44      0.39      3037
           7       0.40      0.76      0.53      3026
           8       0.50      0.81      0.62      2997
           9       0.81      0.56      0.66      2987

    accuracy                           0.56     30000
   macro avg       0.64      0.56      0.57     30000
weighted avg       0.64      0.56      0.57     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.6990357037842815, 0.6990357037842815)




CCA coefficients mean non-concern: (0.7035504950847306, 0.7035504950847306)




Linear CKA concern: 0.7843841618370814




Linear CKA non-concern: 0.647249556865931




Kernel CKA concern: 0.7386270379184363




Kernel CKA non-concern: 0.4767842979239457




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.5300




Precision: 0.6355, Recall: 0.5553, F1-Score: 0.5574




              precision    recall  f1-score   support

           0       0.45      0.55      0.50      2941
           1       0.79      0.37      0.51      2997
           2       0.86      0.41      0.55      3016
           3       0.51      0.36      0.42      2978
           4       0.67      0.82      0.73      3017
           5       0.97      0.54      0.69      3004
           6       0.43      0.37      0.39      3037
           7       0.34      0.79      0.48      3026
           8       0.53      0.80      0.64      2997
           9       0.81      0.56      0.66      2987

    accuracy                           0.56     30000
   macro avg       0.64      0.56      0.56     30000
weighted avg       0.64      0.56      0.56     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7068570648487894, 0.7068570648487894)




CCA coefficients mean non-concern: (0.7046968378112989, 0.7046968378112989)




Linear CKA concern: 0.6761870375358328




Linear CKA non-concern: 0.6752754879613501




Kernel CKA concern: 0.4871713359378417




Kernel CKA non-concern: 0.5065942769507867




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.4856




Precision: 0.6329, Recall: 0.5692, F1-Score: 0.5678




              precision    recall  f1-score   support

           0       0.49      0.53      0.51      2941
           1       0.81      0.31      0.45      2997
           2       0.86      0.44      0.58      3016
           3       0.45      0.41      0.43      2978
           4       0.65      0.83      0.73      3017
           5       0.96      0.62      0.76      3004
           6       0.39      0.38      0.39      3037
           7       0.40      0.78      0.53      3026
           8       0.52      0.80      0.63      2997
           9       0.80      0.58      0.67      2987

    accuracy                           0.57     30000
   macro avg       0.63      0.57      0.57     30000
weighted avg       0.63      0.57      0.57     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.7020857620835779, 0.7020857620835779)




CCA coefficients mean non-concern: (0.7092168495097785, 0.7092168495097785)




Linear CKA concern: 0.7488413858591194




Linear CKA non-concern: 0.651896601387826




Kernel CKA concern: 0.6841366286560631




Kernel CKA non-concern: 0.4611211936139131




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.5287




Precision: 0.6329, Recall: 0.5674, F1-Score: 0.5667




              precision    recall  f1-score   support

           0       0.45      0.54      0.49      2941
           1       0.81      0.31      0.45      2997
           2       0.84      0.48      0.61      3016
           3       0.44      0.47      0.45      2978
           4       0.68      0.82      0.74      3017
           5       0.96      0.57      0.71      3004
           6       0.43      0.36      0.39      3037
           7       0.41      0.75      0.53      3026
           8       0.51      0.82      0.63      2997
           9       0.80      0.57      0.67      2987

    accuracy                           0.57     30000
   macro avg       0.63      0.57      0.57     30000
weighted avg       0.63      0.57      0.57     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.700994814507426, 0.700994814507426)




CCA coefficients mean non-concern: (0.7068045877920573, 0.7068045877920573)




Linear CKA concern: 0.7486264851170247




Linear CKA non-concern: 0.640598872144429




Kernel CKA concern: 0.6665507429444598




Kernel CKA non-concern: 0.4744730719978601




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.4642




Precision: 0.6438, Recall: 0.5747, F1-Score: 0.5783




              precision    recall  f1-score   support

           0       0.47      0.55      0.51      2941
           1       0.80      0.35      0.48      2997
           2       0.84      0.47      0.60      3016
           3       0.49      0.39      0.44      2978
           4       0.75      0.78      0.76      3017
           5       0.96      0.60      0.74      3004
           6       0.46      0.38      0.42      3037
           7       0.35      0.80      0.48      3026
           8       0.55      0.78      0.65      2997
           9       0.76      0.66      0.71      2987

    accuracy                           0.57     30000
   macro avg       0.64      0.57      0.58     30000
weighted avg       0.64      0.57      0.58     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.699636431551733, 0.699636431551733)




CCA coefficients mean non-concern: (0.7097146597454965, 0.7097146597454965)




Linear CKA concern: 0.7777486670924082




Linear CKA non-concern: 0.6809395735742829




Kernel CKA concern: 0.6792357121267204




Kernel CKA non-concern: 0.5083974214625183


