In [1]:
import os
import sys

# Make the project's `utils` package (imported in the next cell) resolvable from
# this notebook's location. The relative path presumably points at the
# repository root -- confirm if the notebook is ever moved.
# The membership check keeps the cell idempotent: re-running it must not keep
# appending the same entry to sys.path.
_project_root = "../../../../"
if _project_root not in sys.path:
    sys.path.append(_project_root)

# NOTE(review): presumably set to silence the Hugging Face tokenizers
# fork/parallelism warning when DataLoader worker processes are used -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration: every value a reader might tune lives in this cell.
name = "OSDG"  # key passed to ModelConfig() and load_data()
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on hosts without a GPU instead of failing at model load.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; rebound by load_model() in the model-loading cell
batch_size = 16  # forwarded to load_data()
num_workers = 4  # forwarded to load_data()
num_samples = 16  # the pruning loop draws num_samples // 2 items per sample set
ci_ratio = 0.4  # passed as sparsity_ratio to prune_concern_identification()
seed = 44  # shared by load_data(), SamplingDataset() and similar()
include_layers = ["attention", "intermediate", "output"]  # passed to prune_concern_identification()
exclude_layers = None  # passed to prune_concern_identification()

In [4]:
# Remember when the run began and echo it as "YYYY-MM-DD HH:MM:SS".
script_start_time = datetime.now()
start_stamp = script_start_time.strftime("%Y-%m-%d %H:%M:%S")
print(f"Script started at: {start_stamp}")

Script started at: 2024-08-25 02:42:04


In [5]:
# Build the model configuration for the selected `name` and target `device`.
model_config = ModelConfig(name, device)
# Label count read from the configuration; the pruning loop below iterates
# over range(num_labels) and passes it on to the sampling helpers.
num_labels = model_config.config["num_labels"]
# load_model() returns the model, its tokenizer and a checkpoint value; this
# rebinds the `checkpoint` placeholder assigned in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Fetch the three splits returned by load_data() for the dataset `name`,
# forwarding the batch size, worker count and seed from the configuration cell.
loader_options = {
    "batch_size": batch_size,
    "num_workers": num_workers,
    "do_cache": True,
    "seed": seed,
}
train_dataloader, valid_dataloader, test_dataloader = load_data(name, **loader_options)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# Concern identification: for every label id ("concern"), prune a fresh copy of
# the model using sample sets drawn for that concern, then report the pruned
# copy's test metrics, its sparsity, and its similarity to the original model.
samples_per_set = num_samples // 2  # loop-invariant: each sample set gets half of num_samples
results_by_concern = {}  # evaluate_model() return value for each concern, keyed by label id

for concern in range(num_labels):
    # Fresh copies of the loaders for every concern -- presumably so that
    # nothing done while sampling/comparing carries over between iterations.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # The two sets differ only in the boolean positional argument (True for
    # `positive_samples`, False for `negative_samples`). The literal 4 is
    # presumably a sampling batch size -- TODO confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train, concern, samples_per_set, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_set, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a deep copy so the untouched `model` remains available for the
    # similarity comparison below and for the next concern.
    module = copy.deepcopy(model)

    # The return value is not used; `module` is expected to be modified in place.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    # Keep every concern's metrics; `result` alone only holds the last one.
    results_by_concern[concern] = result
    get_sparsity(module)

    # Compare the original and pruned models on the validation copy.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # To persist each pruned module, enable the call below. The concern id is
    # part of the file name so successive iterations do not overwrite each other.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model 0




Evaluating:   0%|                                                                    | 0/800 [00:15<?, ?it/s]

Loss: 0.9414




Precision: 0.7787, Recall: 0.7800, F1-Score: 0.7751




              precision    recall  f1-score   support

           0       0.74      0.67      0.70       797
           1       0.84      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.88      0.82      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.57      0.74      0.65       689
          10       0.78      0.78      0.78       670
          11       0.66      0.80      0.72       312
          12       0.68      0.82      0.74       665
          13       0.87      0.83      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8323210073407598, 0.8323210073407598)




CCA coefficients mean non-concern: (0.8235624879608412, 0.8235624879608412)




Linear CKA concern: 0.9770241969908591




Linear CKA non-concern: 0.9439718186609288




Kernel CKA concern: 0.971828230168721




Kernel CKA non-concern: 0.9448751194839987




Evaluate the pruned model 1




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9283




Precision: 0.7761, Recall: 0.7796, F1-Score: 0.7741




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.84      0.73      0.78       775
           2       0.88      0.88      0.88       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.57      0.52       473
           8       0.66      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.75      0.78      0.77       670
          11       0.66      0.79      0.72       312
          12       0.70      0.81      0.75       665
          13       0.84      0.85      0.84       314
          14       0.85      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8231539414897405, 0.8231539414897405)




CCA coefficients mean non-concern: (0.8260402537458588, 0.8260402537458588)




Linear CKA concern: 0.9697374358836919




Linear CKA non-concern: 0.9511291096659219




Kernel CKA concern: 0.9663490262105696




Kernel CKA non-concern: 0.953214732274874




Evaluate the pruned model 2




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9284




Precision: 0.7762, Recall: 0.7794, F1-Score: 0.7740




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.70      0.77       775
           2       0.86      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.81      0.83      1260
           5       0.89      0.68      0.78       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.67      0.84      0.74       746
           9       0.59      0.72      0.65       689
          10       0.75      0.79      0.77       670
          11       0.66      0.78      0.72       312
          12       0.69      0.81      0.75       665
          13       0.85      0.84      0.85       314
          14       0.85      0.78      0.81       756
          15       0.96      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8243448453092745, 0.8243448453092745)




CCA coefficients mean non-concern: (0.8162833874884108, 0.8162833874884108)




Linear CKA concern: 0.9820871205563386




Linear CKA non-concern: 0.9371435397348825




Kernel CKA concern: 0.9763578122065768




Kernel CKA non-concern: 0.9403602853837321




Evaluate the pruned model 3




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9400




Precision: 0.7771, Recall: 0.7793, F1-Score: 0.7743




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.58      0.53       473
           8       0.66      0.85      0.74       746
           9       0.58      0.73      0.65       689
          10       0.76      0.78      0.77       670
          11       0.66      0.78      0.72       312
          12       0.69      0.81      0.74       665
          13       0.85      0.84      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8310790322220368, 0.8310790322220368)




CCA coefficients mean non-concern: (0.8270235946489202, 0.8270235946489202)




Linear CKA concern: 0.9714713595387137




Linear CKA non-concern: 0.9519721049039636




Kernel CKA concern: 0.9673719242163661




Kernel CKA non-concern: 0.9526662314589425




Evaluate the pruned model 4




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9333




Precision: 0.7768, Recall: 0.7788, F1-Score: 0.7736




              precision    recall  f1-score   support

           0       0.74      0.65      0.70       797
           1       0.84      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.89      0.80      0.84      1110
           4       0.83      0.81      0.82      1260
           5       0.90      0.68      0.78       882
           6       0.85      0.80      0.82       940
           7       0.48      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.58      0.74      0.65       689
          10       0.77      0.78      0.78       670
          11       0.67      0.79      0.72       312
          12       0.68      0.82      0.74       665
          13       0.85      0.84      0.84       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8360265416115085, 0.8360265416115085)




CCA coefficients mean non-concern: (0.8239366626660455, 0.8239366626660455)




Linear CKA concern: 0.980902390409245




Linear CKA non-concern: 0.9459325417867867




Kernel CKA concern: 0.9751539486201157




Kernel CKA non-concern: 0.945620037758603




Evaluate the pruned model 5




Evaluating:   0%|                                                                    | 0/800 [00:22<?, ?it/s]

Loss: 0.9275




Precision: 0.7781, Recall: 0.7806, F1-Score: 0.7754




              precision    recall  f1-score   support

           0       0.77      0.65      0.70       797
           1       0.84      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.81      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.80      0.83       940
           7       0.46      0.60      0.52       473
           8       0.67      0.84      0.75       746
           9       0.58      0.74      0.65       689
          10       0.75      0.78      0.76       670
          11       0.67      0.78      0.72       312
          12       0.70      0.81      0.75       665
          13       0.84      0.84      0.84       314
          14       0.85      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8253096692482754, 0.8253096692482754)




CCA coefficients mean non-concern: (0.8195984435818823, 0.8195984435818823)




Linear CKA concern: 0.9734948953163728




Linear CKA non-concern: 0.950139911781089




Kernel CKA concern: 0.9671123093779261




Kernel CKA non-concern: 0.9534495232097946




Evaluate the pruned model 6




Evaluating:   0%|                                                                    | 0/800 [00:22<?, ?it/s]

Loss: 0.9391




Precision: 0.7753, Recall: 0.7790, F1-Score: 0.7731




              precision    recall  f1-score   support

           0       0.75      0.65      0.70       797
           1       0.85      0.70      0.76       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.83      0.81      0.82       940
           7       0.47      0.59      0.53       473
           8       0.67      0.85      0.75       746
           9       0.59      0.73      0.65       689
          10       0.77      0.78      0.77       670
          11       0.63      0.79      0.70       312
          12       0.69      0.81      0.74       665
          13       0.84      0.85      0.84       314
          14       0.87      0.77      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.824891890797168, 0.824891890797168)




CCA coefficients mean non-concern: (0.828121253123054, 0.828121253123054)




Linear CKA concern: 0.9758964129455993




Linear CKA non-concern: 0.9549964793284077




Kernel CKA concern: 0.9708650769005042




Kernel CKA non-concern: 0.956133312306847




Evaluate the pruned model 7




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9413




Precision: 0.7764, Recall: 0.7795, F1-Score: 0.7737




              precision    recall  f1-score   support

           0       0.74      0.66      0.70       797
           1       0.85      0.69      0.76       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.84      0.81      0.83      1260
           5       0.90      0.68      0.78       882
           6       0.84      0.80      0.82       940
           7       0.46      0.60      0.52       473
           8       0.67      0.84      0.75       746
           9       0.59      0.73      0.65       689
          10       0.78      0.78      0.78       670
          11       0.66      0.79      0.72       312
          12       0.69      0.81      0.74       665
          13       0.84      0.84      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.96      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8312879316698543, 0.8312879316698543)




CCA coefficients mean non-concern: (0.8252127249050504, 0.8252127249050504)




Linear CKA concern: 0.9694163178755473




Linear CKA non-concern: 0.9566447808040203




Kernel CKA concern: 0.964067198704471




Kernel CKA non-concern: 0.957673118587018




Evaluate the pruned model 8




Evaluating:   0%|                                                                    | 0/800 [00:22<?, ?it/s]

Loss: 0.9389




Precision: 0.7762, Recall: 0.7804, F1-Score: 0.7739




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.85      0.70      0.77       775
           2       0.86      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.47      0.59      0.52       473
           8       0.65      0.86      0.74       746
           9       0.59      0.72      0.65       689
          10       0.76      0.78      0.77       670
          11       0.65      0.80      0.72       312
          12       0.69      0.81      0.75       665
          13       0.83      0.86      0.84       314
          14       0.86      0.76      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8295586552988088, 0.8295586552988088)




CCA coefficients mean non-concern: (0.8191225436573929, 0.8191225436573929)




Linear CKA concern: 0.9765885984517895




Linear CKA non-concern: 0.9435384857579705




Kernel CKA concern: 0.9711871835048673




Kernel CKA non-concern: 0.9442636769707252




Evaluate the pruned model 9




Evaluating:   0%|                                                                    | 0/800 [00:22<?, ?it/s]

Loss: 0.9477




Precision: 0.7759, Recall: 0.7778, F1-Score: 0.7726




              precision    recall  f1-score   support

           0       0.75      0.65      0.70       797
           1       0.85      0.69      0.76       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.84      0.80      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.76      0.78      0.77       670
          11       0.67      0.78      0.72       312
          12       0.69      0.81      0.74       665
          13       0.84      0.85      0.85       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8284427370051644, 0.8284427370051644)




CCA coefficients mean non-concern: (0.8258503594845444, 0.8258503594845444)




Linear CKA concern: 0.9746162481503985




Linear CKA non-concern: 0.9532625229980092




Kernel CKA concern: 0.96757785387297




Kernel CKA non-concern: 0.9549091576074621




Evaluate the pruned model 10




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9313




Precision: 0.7786, Recall: 0.7809, F1-Score: 0.7759




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.84      0.81      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.84      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.84      0.74       746
           9       0.59      0.73      0.65       689
          10       0.75      0.79      0.77       670
          11       0.67      0.79      0.73       312
          12       0.69      0.81      0.75       665
          13       0.86      0.84      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8280391296574087, 0.8280391296574087)




CCA coefficients mean non-concern: (0.8261762412808512, 0.8261762412808512)




Linear CKA concern: 0.9707741905165033




Linear CKA non-concern: 0.9524923536713038




Kernel CKA concern: 0.9668790280178096




Kernel CKA non-concern: 0.9538760964019585




Evaluate the pruned model 11




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9405




Precision: 0.7742, Recall: 0.7796, F1-Score: 0.7719




              precision    recall  f1-score   support

           0       0.78      0.64      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.86      0.79      0.82      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.80      0.82       940
           7       0.45      0.60      0.51       473
           8       0.67      0.85      0.75       746
           9       0.57      0.73      0.64       689
          10       0.75      0.78      0.77       670
          11       0.61      0.80      0.69       312
          12       0.70      0.80      0.75       665
          13       0.83      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.77   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8272222291165698, 0.8272222291165698)




CCA coefficients mean non-concern: (0.8223238991582918, 0.8223238991582918)




Linear CKA concern: 0.9800435165905604




Linear CKA non-concern: 0.9514732073192373




Kernel CKA concern: 0.9740522154818905




Kernel CKA non-concern: 0.9538427775639872




Evaluate the pruned model 12




Evaluating:   0%|                                                                    | 0/800 [00:22<?, ?it/s]

Loss: 0.9387




Precision: 0.7748, Recall: 0.7779, F1-Score: 0.7718




              precision    recall  f1-score   support

           0       0.77      0.65      0.70       797
           1       0.85      0.70      0.76       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.86      0.79      0.82      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.46      0.59      0.52       473
           8       0.66      0.84      0.74       746
           9       0.57      0.73      0.64       689
          10       0.75      0.78      0.77       670
          11       0.65      0.78      0.71       312
          12       0.69      0.81      0.75       665
          13       0.84      0.85      0.85       314
          14       0.85      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.77   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.824900317040104, 0.824900317040104)




CCA coefficients mean non-concern: (0.822828619517673, 0.822828619517673)




Linear CKA concern: 0.9705738079191472




Linear CKA non-concern: 0.9492611182845221




Kernel CKA concern: 0.964901851956852




Kernel CKA non-concern: 0.9519511955203662




Evaluate the pruned model 13




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9310




Precision: 0.7752, Recall: 0.7796, F1-Score: 0.7729




              precision    recall  f1-score   support

           0       0.78      0.65      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.69      0.78       882
           6       0.86      0.79      0.82       940
           7       0.47      0.58      0.52       473
           8       0.65      0.85      0.74       746
           9       0.57      0.74      0.64       689
          10       0.78      0.77      0.78       670
          11       0.63      0.79      0.70       312
          12       0.69      0.81      0.74       665
          13       0.82      0.86      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8263396368848608, 0.8263396368848608)




CCA coefficients mean non-concern: (0.8162216299697819, 0.8162216299697819)




Linear CKA concern: 0.9803620125344554




Linear CKA non-concern: 0.9415373522808179




Kernel CKA concern: 0.9728915393782865




Kernel CKA non-concern: 0.9443894375967172




Evaluate the pruned model 14




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9302




Precision: 0.7765, Recall: 0.7793, F1-Score: 0.7742




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.68      0.78       882
           6       0.85      0.79      0.82       940
           7       0.47      0.58      0.52       473
           8       0.66      0.85      0.75       746
           9       0.58      0.72      0.65       689
          10       0.75      0.79      0.77       670
          11       0.68      0.79      0.73       312
          12       0.69      0.80      0.75       665
          13       0.84      0.84      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8274386023349147, 0.8274386023349147)




CCA coefficients mean non-concern: (0.8244737783110592, 0.8244737783110592)




Linear CKA concern: 0.9726552785920776




Linear CKA non-concern: 0.9486606959291514




Kernel CKA concern: 0.9698443661656467




Kernel CKA non-concern: 0.9499200680352631




Evaluate the pruned model 15




Evaluating:   0%|                                                                    | 0/800 [00:23<?, ?it/s]

Loss: 0.9289




Precision: 0.7731, Recall: 0.7733, F1-Score: 0.7680




              precision    recall  f1-score   support

           0       0.76      0.63      0.69       797
           1       0.83      0.69      0.75       775
           2       0.87      0.87      0.87       795
           3       0.88      0.81      0.84      1110
           4       0.86      0.79      0.82      1260
           5       0.90      0.67      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.58      0.52       473
           8       0.66      0.85      0.74       746
           9       0.54      0.75      0.63       689
          10       0.77      0.77      0.77       670
          11       0.64      0.79      0.71       312
          12       0.69      0.81      0.75       665
          13       0.85      0.84      0.84       314
          14       0.87      0.75      0.81       756
          15       0.95      0.98      0.96      1607

    accuracy                           0.79     12791
   macro avg       0.77   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8226726764625745, 0.8226726764625745)




CCA coefficients mean non-concern: (0.7908071645890764, 0.7908071645890764)




Linear CKA concern: 0.9694129228175768




Linear CKA non-concern: 0.9025059512248046




Kernel CKA concern: 0.9651996831339438




Kernel CKA non-concern: 0.9055714044158193


