In [1]:
import os
import sys

# Extend the import search path four directories up from the working directory;
# presumably this is the repository root that hosts the `utils` package imported
# by the following cell (confirm if the notebook is moved).
repo_root = "../../../../"
# Guard the append so re-running this cell does not stack duplicate entries.
if repo_root not in sys.path:
    sys.path.append(repo_root)

# Set before any tokenizer is created: turns off the HuggingFace tokenizers'
# internal parallelism, which otherwise warns when DataLoader workers fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Key of the model/dataset pair; passed to both ModelConfig() and load_data().
name = "YahooAnswersTopics"
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing when the model is loaded.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder only; rebound by the tuple returned from load_model() further down.
checkpoint = None
# DataLoader settings forwarded to load_data().
batch_size = 16
num_workers = 4
# Sample budget: the pruning loop draws num_samples // 2 items per SamplingDataset
# and hands the full value to similar().
num_samples = 16
# Forwarded as `sparsity_ratio` to prune_concern_identification().
ci_ratio = 0.5
# One seed shared by load_data(), SamplingDataset() and similar().
seed = 44
# Layer-name filters forwarded to prune_concern_identification().
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Capture the wall-clock start of the run and echo it into the notebook log.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-09-01 12:31:46


In [5]:
# Build the project's ModelConfig from the model/dataset key and the target device.
model_config = ModelConfig(name, device)
# Number of target classes, read from the config dict; the pruning loop iterates
# over range(num_labels), one pass per class ("concern").
num_labels = model_config.config["num_labels"]
# load_model() returns a (model, tokenizer, checkpoint) triple; this rebinds the
# `checkpoint` placeholder that the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}




The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.




In [6]:
# Fetch the train / validation / test loaders for the configured dataset.
# `do_cache=True` is passed through to load_data() unchanged.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}




Loading cached dataset YahooAnswersTopics.




The dataset YahooAnswersTopics is loaded




In [7]:
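# Optional baseline: evaluate the unpruned model on the test split so its metrics
# can be compared with the per-concern pruned models. Left disabled in this run.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)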

In [8]:
# One pass per class label ("concern"): sample data, prune a fresh copy of the
# model with prune_concern_identification(), then evaluate it and compare it
# against the original model.
for concern in range(num_labels):
    # Each pass works on its own deep copies of the shared loaders.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Values shared by the three SamplingDataset constructions below.
    samples_per_set = num_samples // 2
    sampling_kwargs = dict(device=device, resample=False, seed=seed)
    # NOTE(review): the fifth positional argument (True/False) presumably selects
    # samples of the concern vs. the remaining classes, and the meaning of the
    # positional `4` is not visible from this notebook - confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train, concern, samples_per_set, num_labels, True, 4, **sampling_kwargs
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_set, num_labels, False, 4, **sampling_kwargs
    )
    # NOTE(review): `all_samples` is not read anywhere in this cell, and 200 lies
    # outside range(num_labels) - confirm whether it is still needed.
    all_samples = SamplingDataset(
        train, 200, samples_per_set, num_labels, False, 4, **sampling_kwargs
    )

    # Prune a private copy so `model` stays intact as the reference for similar().
    module = copy.deepcopy(model)
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    # Report test-set metrics and the sparsity of the pruned copy.
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original and pruned models on the validation copy.
    similar(
        model,
        module,
        valid,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

    # NOTE(review): the file name below carries no concern index, so enabling the
    # line as written would target the same path on every iteration.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 1.0286




Precision: 0.6761, Recall: 0.6708, F1-Score: 0.6688




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.75      0.59      0.66      2997
           2       0.74      0.75      0.75      3016
           3       0.53      0.49      0.51      2978
           4       0.80      0.80      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.56      0.40      0.47      3037
           7       0.57      0.75      0.65      3026
           8       0.63      0.76      0.69      2997
           9       0.72      0.77      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8188904384402779, 0.8188904384402779)




CCA coefficients mean non-concern: (0.8223248133126354, 0.8223248133126354)




Linear CKA concern: 0.9603106564147763




Linear CKA non-concern: 0.9444500548923892




Kernel CKA concern: 0.9355414727888755




Kernel CKA non-concern: 0.9232875333193942




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 1.0266




Precision: 0.6769, Recall: 0.6712, F1-Score: 0.6692




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.74      0.62      0.67      2997
           2       0.74      0.75      0.75      3016
           3       0.53      0.48      0.51      2978
           4       0.81      0.80      0.81      3017
           5       0.92      0.81      0.86      3004
           6       0.58      0.40      0.47      3037
           7       0.55      0.76      0.64      3026
           8       0.63      0.75      0.69      2997
           9       0.72      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8174170628617983, 0.8174170628617983)




CCA coefficients mean non-concern: (0.8214807942617788, 0.8214807942617788)




Linear CKA concern: 0.9618228297708645




Linear CKA non-concern: 0.9435815671336039




Kernel CKA concern: 0.9427506170076738




Kernel CKA non-concern: 0.9166615200764701




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 1.0285




Precision: 0.6761, Recall: 0.6704, F1-Score: 0.6683




              precision    recall  f1-score   support

           0       0.54      0.57      0.56      2941
           1       0.75      0.61      0.67      2997
           2       0.74      0.75      0.74      3016
           3       0.53      0.50      0.51      2978
           4       0.81      0.79      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.57      0.39      0.47      3037
           7       0.56      0.75      0.64      3026
           8       0.62      0.76      0.68      2997
           9       0.72      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8098426348670122, 0.8098426348670122)




CCA coefficients mean non-concern: (0.8199622098672588, 0.8199622098672588)




Linear CKA concern: 0.9716953600789158




Linear CKA non-concern: 0.9422779796006443




Kernel CKA concern: 0.9610212037915411




Kernel CKA non-concern: 0.910353881443783




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 1.0243




Precision: 0.6777, Recall: 0.6726, F1-Score: 0.6707




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.74      0.62      0.67      2997
           2       0.74      0.75      0.74      3016
           3       0.55      0.49      0.52      2978
           4       0.81      0.80      0.81      3017
           5       0.92      0.81      0.86      3004
           6       0.57      0.40      0.47      3037
           7       0.56      0.75      0.64      3026
           8       0.63      0.76      0.69      2997
           9       0.72      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8234174782664515, 0.8234174782664515)




CCA coefficients mean non-concern: (0.8194462356736345, 0.8194462356736345)




Linear CKA concern: 0.9606815196949182




Linear CKA non-concern: 0.9470026517464258




Kernel CKA concern: 0.9422155766862107




Kernel CKA non-concern: 0.9240648911868424




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 1.0335




Precision: 0.6763, Recall: 0.6704, F1-Score: 0.6685




              precision    recall  f1-score   support

           0       0.52      0.59      0.55      2941
           1       0.75      0.61      0.67      2997
           2       0.74      0.75      0.74      3016
           3       0.54      0.49      0.51      2978
           4       0.80      0.80      0.80      3017
           5       0.92      0.80      0.86      3004
           6       0.57      0.39      0.46      3037
           7       0.57      0.75      0.65      3026
           8       0.62      0.76      0.69      2997
           9       0.73      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8139895592725522, 0.8139895592725522)




CCA coefficients mean non-concern: (0.8159376119133191, 0.8159376119133191)




Linear CKA concern: 0.9713474410399582




Linear CKA non-concern: 0.935710354830488




Kernel CKA concern: 0.9586549545501012




Kernel CKA non-concern: 0.9053692242528661




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 1.0282




Precision: 0.6766, Recall: 0.6719, F1-Score: 0.6705




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.75      0.60      0.67      2997
           2       0.74      0.74      0.74      3016
           3       0.51      0.52      0.52      2978
           4       0.81      0.79      0.80      3017
           5       0.91      0.82      0.86      3004
           6       0.55      0.41      0.47      3037
           7       0.60      0.73      0.66      3026
           8       0.62      0.76      0.68      2997
           9       0.73      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8125805319351229, 0.8125805319351229)




CCA coefficients mean non-concern: (0.8208092047268778, 0.8208092047268778)




Linear CKA concern: 0.9746836918697439




Linear CKA non-concern: 0.9398462785760928




Kernel CKA concern: 0.9597310813674133




Kernel CKA non-concern: 0.9147322349553536




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 1.0263




Precision: 0.6764, Recall: 0.6711, F1-Score: 0.6694




              precision    recall  f1-score   support

           0       0.54      0.59      0.56      2941
           1       0.74      0.59      0.66      2997
           2       0.73      0.75      0.74      3016
           3       0.52      0.50      0.51      2978
           4       0.81      0.80      0.80      3017
           5       0.91      0.81      0.86      3004
           6       0.57      0.41      0.48      3037
           7       0.57      0.75      0.65      3026
           8       0.64      0.75      0.69      2997
           9       0.73      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8227974999538683, 0.8227974999538683)




CCA coefficients mean non-concern: (0.8221965995015336, 0.8221965995015336)




Linear CKA concern: 0.9620110363594048




Linear CKA non-concern: 0.9482623293126168




Kernel CKA concern: 0.9401046148528934




Kernel CKA non-concern: 0.926248864311361




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 1.0284




Precision: 0.6761, Recall: 0.6715, F1-Score: 0.6700




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.75      0.60      0.66      2997
           2       0.75      0.75      0.75      3016
           3       0.52      0.51      0.51      2978
           4       0.80      0.80      0.80      3017
           5       0.91      0.81      0.86      3004
           6       0.55      0.41      0.47      3037
           7       0.59      0.74      0.66      3026
           8       0.62      0.76      0.69      2997
           9       0.73      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8173000926988727, 0.8173000926988727)




CCA coefficients mean non-concern: (0.8241231797876616, 0.8241231797876616)




Linear CKA concern: 0.9702665821829748




Linear CKA non-concern: 0.9398469302227954




Kernel CKA concern: 0.9544127507471939




Kernel CKA non-concern: 0.9098636569164312




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 1.0292




Precision: 0.6758, Recall: 0.6699, F1-Score: 0.6691




              precision    recall  f1-score   support

           0       0.53      0.58      0.56      2941
           1       0.76      0.58      0.66      2997
           2       0.73      0.75      0.74      3016
           3       0.50      0.52      0.51      2978
           4       0.82      0.79      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.54      0.42      0.47      3037
           7       0.61      0.72      0.66      3026
           8       0.62      0.77      0.69      2997
           9       0.73      0.76      0.74      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8200175953067471, 0.8200175953067471)




CCA coefficients mean non-concern: (0.8229083592377939, 0.8229083592377939)




Linear CKA concern: 0.9639089984706861




Linear CKA non-concern: 0.9406661503822499




Kernel CKA concern: 0.9510269241309239




Kernel CKA non-concern: 0.9155989813833021




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 1.0291




Precision: 0.6760, Recall: 0.6715, F1-Score: 0.6699




              precision    recall  f1-score   support

           0       0.54      0.58      0.56      2941
           1       0.75      0.60      0.66      2997
           2       0.73      0.76      0.75      3016
           3       0.53      0.49      0.51      2978
           4       0.81      0.79      0.80      3017
           5       0.92      0.81      0.86      3004
           6       0.54      0.42      0.47      3037
           7       0.57      0.75      0.65      3026
           8       0.63      0.76      0.69      2997
           9       0.73      0.76      0.75      2987

    accuracy                           0.67     30000
   macro avg       0.68      0.67      0.67     30000
weighted avg       0.68      0.67      0.67     30000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8220278807799585, 0.8220278807799585)




CCA coefficients mean non-concern: (0.8243741749936527, 0.8243741749936527)




Linear CKA concern: 0.9732439719771536




Linear CKA non-concern: 0.9375607935126709




Kernel CKA concern: 0.9552383819059106




Kernel CKA non-concern: 0.9048670276237831


