In [1]:
import os
import sys

# Make the project-local `utils` package (imported in the next cell) importable.
# The path is resolved against the kernel's current working directory, which
# Jupyter usually sets to the notebook's own folder.
_project_root = "../../../../"
if _project_root not in sys.path:  # keep the cell idempotent across re-runs
    sys.path.append(_project_root)

# Presumably set to silence the HuggingFace tokenizers parallelism/fork warning
# when DataLoader workers are used (num_workers > 0 below) — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration: everything a reader might want to tune.
name = "OSDG"  # key passed to ModelConfig and load_data
# Fall back to CPU when no GPU is visible, so the notebook still runs (slowly).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; reassigned by load_model() in the model-loading cell
batch_size = 16
num_workers = 4
num_samples = 16  # half is used for positive samples, half for negative samples
ci_ratio = 0.3  # passed as sparsity_ratio to prune_concern_identification
seed = 44
# `seed` is otherwise only handed explicitly to load_data/SamplingDataset/similar;
# also seed torch's global RNG so any other torch randomness is reproducible.
torch.manual_seed(seed)
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Record the wall-clock start so the run can be timed, and echo it.
script_start_time = datetime.now()
print("Script started at: " + format(script_start_time, "%Y-%m-%d %H:%M:%S"))

Script started at: 2024-08-30 05:23:32


In [5]:
# Build the per-dataset model configuration for the chosen device and read the
# label count from it (the printed config below shows 16 for OSDG).
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

# load_model yields a (model, tokenizer, checkpoint) triple; `checkpoint`
# replaces the `None` placeholder from the configuration cell.
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Build train/validation/test loaders for the dataset, reusing the on-disk cache.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Optional baseline: evaluate the unpruned model on the test split, for comparison
# with the per-concern results below. Off by default (presumably to save a full
# pass over the test loader); set to True to enable.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every concern (class label): sample positive/negative examples, prune a copy
# of the model with concern identification, then report test metrics, sparsity and
# similarity of the pruned copy to the original model.
save_modules = False  # set True to write each pruned module under Modules/
samples_per_side = num_samples // 2  # size used for both the positive and negative sample sets

for concern in range(num_labels):
    # Work on copies of the loaders; presumably SamplingDataset consumes or mutates
    # the loader it is given — TODO confirm whether these copies are required.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)
    positive_samples = SamplingDataset(
        train, concern, samples_per_side, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_side, num_labels, False, 4, device=device, resample=False, seed=seed
    )
    # NOTE(review): `all_samples` is not used in this cell, and 200 is outside the
    # label range (num_labels is 16 for OSDG). Kept in case a later cell reads it —
    # TODO confirm and drop it if unused, since it costs a sampling pass per concern.
    all_samples = SamplingDataset(
        train, 200, samples_per_side, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Drop the previous iteration's pruned copy before cloning, so only one pruned
    # copy coexists with `model` during deepcopy (lower peak device memory).
    module = None
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    if save_modules:
        # Include the concern index in the filename: the previous name had none, so
        # every iteration of this loop would have written to the same file.
        save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.6306




Precision: 0.6987, Recall: 0.4327, F1-Score: 0.4845




              precision    recall  f1-score   support

           0       0.72      0.39      0.51       797
           1       0.85      0.20      0.33       775
           2       0.92      0.35      0.51       795
           3       0.85      0.56      0.68      1110
           4       0.78      0.60      0.68      1260
           5       0.88      0.38      0.54       882
           6       0.75      0.57      0.65       940
           7       0.45      0.11      0.17       473
           8       0.75      0.23      0.35       746
           9       0.44      0.44      0.44       689
          10       0.51      0.52      0.51       670
          11       0.69      0.13      0.22       312
          12       0.66      0.54      0.59       665
          13       0.82      0.41      0.55       314
          14       0.83      0.49      0.62       756
          15       0.26      0.99      0.41      1607

    accuracy                           0.50     12791
   macro avg       0.70   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5394276249100889, 0.5394276249100889)




CCA coefficients mean non-concern: (0.5324794477292245, 0.5324794477292245)




Linear CKA concern: 0.2626729035814128




Linear CKA non-concern: 0.1951516567994571




Kernel CKA concern: 0.22409648183014247




Kernel CKA non-concern: 0.14966321150216522




Evaluate the pruned model 1




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7352




Precision: 0.7167, Recall: 0.4525, F1-Score: 0.4888




              precision    recall  f1-score   support

           0       0.85      0.28      0.42       797
           1       0.76      0.48      0.58       775
           2       0.82      0.59      0.68       795
           3       0.80      0.68      0.73      1110
           4       0.76      0.58      0.66      1260
           5       0.91      0.23      0.37       882
           6       0.87      0.41      0.55       940
           7       0.37      0.11      0.16       473
           8       0.77      0.10      0.18       746
           9       0.45      0.46      0.46       689
          10       0.13      0.92      0.22       670
          11       0.72      0.17      0.28       312
          12       0.81      0.30      0.44       665
          13       0.87      0.46      0.61       314
          14       0.85      0.54      0.66       756
          15       0.71      0.93      0.81      1607

    accuracy                           0.51     12791
   macro avg       0.72   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5214404407418529, 0.5214404407418529)




CCA coefficients mean non-concern: (0.5334118894121527, 0.5334118894121527)




Linear CKA concern: 0.2085108584041901




Linear CKA non-concern: 0.199360796822189




Kernel CKA concern: 0.19731114247522838




Kernel CKA non-concern: 0.15730665879969064




Evaluate the pruned model 2




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7457




Precision: 0.7154, Recall: 0.3899, F1-Score: 0.4408




              precision    recall  f1-score   support

           0       0.87      0.23      0.37       797
           1       0.81      0.15      0.26       775
           2       0.91      0.46      0.61       795
           3       0.86      0.54      0.66      1110
           4       0.82      0.55      0.66      1260
           5       0.90      0.15      0.26       882
           6       0.88      0.30      0.44       940
           7       0.48      0.06      0.11       473
           8       0.75      0.26      0.39       746
           9       0.41      0.47      0.44       689
          10       0.55      0.46      0.50       670
          11       0.72      0.12      0.20       312
          12       0.70      0.49      0.58       665
          13       0.72      0.49      0.59       314
          14       0.84      0.51      0.64       756
          15       0.22      0.99      0.36      1607

    accuracy                           0.45     12791
   macro avg       0.72   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5276064764033851, 0.5276064764033851)




CCA coefficients mean non-concern: (0.5312491543409525, 0.5312491543409525)




Linear CKA concern: 0.31591821849114776




Linear CKA non-concern: 0.18993720016474153




Kernel CKA concern: 0.2593151834719676




Kernel CKA non-concern: 0.1366125923804057




Evaluate the pruned model 3




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.6352




Precision: 0.6823, Recall: 0.5226, F1-Score: 0.5444




              precision    recall  f1-score   support

           0       0.63      0.45      0.52       797
           1       0.80      0.39      0.52       775
           2       0.89      0.54      0.68       795
           3       0.76      0.78      0.77      1110
           4       0.78      0.60      0.68      1260
           5       0.92      0.31      0.46       882
           6       0.83      0.59      0.69       940
           7       0.50      0.11      0.17       473
           8       0.76      0.31      0.44       746
           9       0.42      0.52      0.47       689
          10       0.23      0.84      0.36       670
          11       0.68      0.27      0.39       312
          12       0.64      0.60      0.62       665
          13       0.87      0.47      0.61       314
          14       0.70      0.62      0.66       756
          15       0.53      0.97      0.68      1607

    accuracy                           0.58     12791
   macro avg       0.68   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5400568812705955, 0.5400568812705955)




CCA coefficients mean non-concern: (0.5339495437925517, 0.5339495437925517)




Linear CKA concern: 0.37116245184484326




Linear CKA non-concern: 0.1889898295049059




Kernel CKA concern: 0.32883271891484406




Kernel CKA non-concern: 0.12983396812056514




Evaluate the pruned model 4




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.6088




Precision: 0.6552, Recall: 0.5287, F1-Score: 0.5325




              precision    recall  f1-score   support

           0       0.70      0.41      0.52       797
           1       0.85      0.25      0.38       775
           2       0.92      0.45      0.60       795
           3       0.88      0.58      0.70      1110
           4       0.81      0.63      0.71      1260
           5       0.91      0.31      0.46       882
           6       0.79      0.60      0.68       940
           7       0.45      0.14      0.21       473
           8       0.63      0.51      0.56       746
           9       0.48      0.47      0.48       689
          10       0.23      0.85      0.36       670
          11       0.66      0.24      0.35       312
          12       0.57      0.65      0.61       665
          13       0.55      0.67      0.60       314
          14       0.41      0.75      0.53       756
          15       0.64      0.96      0.77      1607

    accuracy                           0.57     12791
   macro avg       0.66   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5293036268340411, 0.5293036268340411)




CCA coefficients mean non-concern: (0.5334617593328111, 0.5334617593328111)




Linear CKA concern: 0.17435825254455853




Linear CKA non-concern: 0.18366653465598942




Kernel CKA concern: 0.13450246495295334




Kernel CKA non-concern: 0.1550692594017371




Evaluate the pruned model 5




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7337




Precision: 0.7114, Recall: 0.3969, F1-Score: 0.4557




              precision    recall  f1-score   support

           0       0.80      0.30      0.44       797
           1       0.82      0.27      0.41       775
           2       0.94      0.31      0.47       795
           3       0.86      0.56      0.68      1110
           4       0.62      0.59      0.61      1260
           5       0.88      0.44      0.58       882
           6       0.91      0.31      0.46       940
           7       0.42      0.10      0.16       473
           8       0.72      0.18      0.29       746
           9       0.41      0.45      0.43       689
          10       0.57      0.37      0.45       670
          11       0.75      0.12      0.20       312
          12       0.75      0.46      0.57       665
          13       0.82      0.47      0.60       314
          14       0.87      0.42      0.56       756
          15       0.24      0.99      0.38      1607

    accuracy                           0.46     12791
   macro avg       0.71   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5343034083678058, 0.5343034083678058)




CCA coefficients mean non-concern: (0.5315596103404803, 0.5315596103404803)




Linear CKA concern: 0.27110274009994983




Linear CKA non-concern: 0.21761208735157025




Kernel CKA concern: 0.23527260664409416




Kernel CKA non-concern: 0.17407449063636624




Evaluate the pruned model 6




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7624




Precision: 0.7111, Recall: 0.3574, F1-Score: 0.3997




              precision    recall  f1-score   support

           0       0.86      0.20      0.32       797
           1       0.80      0.13      0.23       775
           2       0.88      0.39      0.54       795
           3       0.76      0.69      0.72      1110
           4       0.83      0.36      0.51      1260
           5       0.88      0.27      0.42       882
           6       0.73      0.68      0.70       940
           7       0.57      0.02      0.03       473
           8       0.85      0.15      0.25       746
           9       0.39      0.39      0.39       689
          10       0.67      0.23      0.34       670
          11       0.67      0.08      0.15       312
          12       0.63      0.56      0.59       665
          13       0.76      0.21      0.33       314
          14       0.88      0.37      0.53       756
          15       0.22      0.99      0.36      1607

    accuracy                           0.43     12791
   macro avg       0.71   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5262760021809354, 0.5262760021809354)




CCA coefficients mean non-concern: (0.5337519931879138, 0.5337519931879138)




Linear CKA concern: 0.33892698918523273




Linear CKA non-concern: 0.1952182560104155




Kernel CKA concern: 0.3154121254886678




Kernel CKA non-concern: 0.14109039125562572




Evaluate the pruned model 7




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.5306




Precision: 0.6766, Recall: 0.5103, F1-Score: 0.5288




              precision    recall  f1-score   support

           0       0.80      0.33      0.47       797
           1       0.85      0.19      0.31       775
           2       0.88      0.35      0.50       795
           3       0.77      0.75      0.76      1110
           4       0.78      0.68      0.73      1260
           5       0.90      0.37      0.53       882
           6       0.80      0.60      0.68       940
           7       0.48      0.12      0.19       473
           8       0.48      0.64      0.55       746
           9       0.36      0.61      0.45       689
          10       0.40      0.64      0.49       670
          11       0.70      0.24      0.36       312
          12       0.56      0.66      0.60       665
          13       0.83      0.49      0.62       314
          14       0.82      0.54      0.65       756
          15       0.41      0.96      0.57      1607

    accuracy                           0.57     12791
   macro avg       0.68   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5413875164127935, 0.5413875164127935)




CCA coefficients mean non-concern: (0.5348171959806106, 0.5348171959806106)




Linear CKA concern: 0.2967417668117031




Linear CKA non-concern: 0.22773468890155044




Kernel CKA concern: 0.21591173366316327




Kernel CKA non-concern: 0.1823722639783286




Evaluate the pruned model 8




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7625




Precision: 0.6810, Recall: 0.4164, F1-Score: 0.4480




              precision    recall  f1-score   support

           0       0.73      0.34      0.46       797
           1       0.82      0.13      0.22       775
           2       0.81      0.47      0.59       795
           3       0.81      0.65      0.72      1110
           4       0.52      0.65      0.58      1260
           5       0.88      0.30      0.45       882
           6       0.87      0.41      0.55       940
           7       0.50      0.02      0.03       473
           8       0.76      0.28      0.41       746
           9       0.32      0.60      0.42       689
          10       0.55      0.46      0.50       670
          11       0.61      0.04      0.07       312
          12       0.81      0.41      0.55       665
          13       0.84      0.40      0.54       314
          14       0.78      0.53      0.63       756
          15       0.29      0.99      0.45      1607

    accuracy                           0.49     12791
   macro avg       0.68   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5414032180028954, 0.5414032180028954)




CCA coefficients mean non-concern: (0.5343293430451999, 0.5343293430451999)




Linear CKA concern: 0.1829447965825324




Linear CKA non-concern: 0.21653163496747463




Kernel CKA concern: 0.1311128445294884




Kernel CKA non-concern: 0.14708004405433975




Evaluate the pruned model 9




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7786




Precision: 0.7033, Recall: 0.3528, F1-Score: 0.3874




              precision    recall  f1-score   support

           0       0.84      0.14      0.24       797
           1       0.83      0.07      0.14       775
           2       0.92      0.17      0.28       795
           3       0.77      0.68      0.72      1110
           4       0.79      0.62      0.70      1260
           5       0.87      0.13      0.22       882
           6       0.86      0.39      0.54       940
           7       0.39      0.05      0.08       473
           8       0.68      0.35      0.47       746
           9       0.31      0.60      0.41       689
          10       0.74      0.34      0.46       670
          11       0.58      0.05      0.09       312
          12       0.76      0.39      0.52       665
          13       0.79      0.35      0.48       314
          14       0.89      0.34      0.49       756
          15       0.22      0.99      0.37      1607

    accuracy                           0.43     12791
   macro avg       0.70   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.537582004180795, 0.537582004180795)




CCA coefficients mean non-concern: (0.5317838514566278, 0.5317838514566278)




Linear CKA concern: 0.345610867371294




Linear CKA non-concern: 0.19744310143216773




Kernel CKA concern: 0.27853066737172316




Kernel CKA non-concern: 0.15300192306900515




Evaluate the pruned model 10




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.6701




Precision: 0.6938, Recall: 0.4248, F1-Score: 0.4652




              precision    recall  f1-score   support

           0       0.66      0.41      0.51       797
           1       0.78      0.28      0.41       775
           2       0.75      0.50      0.60       795
           3       0.88      0.48      0.62      1110
           4       0.59      0.75      0.66      1260
           5       0.91      0.30      0.45       882
           6       0.87      0.38      0.53       940
           7       0.56      0.04      0.07       473
           8       0.76      0.26      0.39       746
           9       0.37      0.51      0.43       689
          10       0.50      0.61      0.55       670
          11       0.75      0.12      0.20       312
          12       0.83      0.33      0.47       665
          13       0.76      0.41      0.53       314
          14       0.86      0.46      0.60       756
          15       0.28      0.97      0.43      1607

    accuracy                           0.49     12791
   macro avg       0.69   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5362922271599231, 0.5362922271599231)




CCA coefficients mean non-concern: (0.532784496122194, 0.532784496122194)




Linear CKA concern: 0.29821700852680416




Linear CKA non-concern: 0.2133563379927985




Kernel CKA concern: 0.25128585347278953




Kernel CKA non-concern: 0.16367556426376428




Evaluate the pruned model 11




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.8889




Precision: 0.6865, Recall: 0.3715, F1-Score: 0.4016




              precision    recall  f1-score   support

           0       0.72      0.30      0.42       797
           1       0.77      0.08      0.14       775
           2       0.92      0.42      0.58       795
           3       0.80      0.64      0.71      1110
           4       0.66      0.40      0.50      1260
           5       0.91      0.26      0.41       882
           6       0.88      0.18      0.30       940
           7       0.47      0.09      0.15       473
           8       0.76      0.14      0.24       746
           9       0.26      0.64      0.37       689
          10       0.38      0.56      0.46       670
          11       0.69      0.21      0.32       312
          12       0.81      0.22      0.35       665
          13       0.79      0.47      0.59       314
          14       0.90      0.33      0.48       756
          15       0.25      0.99      0.40      1607

    accuracy                           0.42     12791
   macro avg       0.69   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.534824887807352, 0.534824887807352)




CCA coefficients mean non-concern: (0.5323968046335731, 0.5323968046335731)




Linear CKA concern: 0.23610770166494544




Linear CKA non-concern: 0.19855948075071625




Kernel CKA concern: 0.22154930006290846




Kernel CKA non-concern: 0.1429215391310387




Evaluate the pruned model 12




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.7091




Precision: 0.6837, Recall: 0.4977, F1-Score: 0.5185




              precision    recall  f1-score   support

           0       0.84      0.27      0.40       797
           1       0.82      0.25      0.38       775
           2       0.79      0.63      0.70       795
           3       0.81      0.70      0.75      1110
           4       0.73      0.60      0.66      1260
           5       0.91      0.45      0.60       882
           6       0.85      0.52      0.65       940
           7       0.40      0.13      0.20       473
           8       0.75      0.28      0.41       746
           9       0.44      0.45      0.44       689
          10       0.21      0.83      0.34       670
          11       0.77      0.16      0.26       312
          12       0.67      0.51      0.58       665
          13       0.66      0.70      0.68       314
          14       0.86      0.51      0.64       756
          15       0.44      0.98      0.61      1607

    accuracy                           0.55     12791
   macro avg       0.68   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5337331894075009, 0.5337331894075009)




CCA coefficients mean non-concern: (0.530404273711001, 0.530404273711001)




Linear CKA concern: 0.21045672834677479




Linear CKA non-concern: 0.17607275178842047




Kernel CKA concern: 0.1519244414825909




Kernel CKA non-concern: 0.12252105320061418




Evaluate the pruned model 13




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.9459




Precision: 0.6855, Recall: 0.3773, F1-Score: 0.3872




              precision    recall  f1-score   support

           0       0.69      0.31      0.42       797
           1       0.76      0.16      0.27       775
           2       0.95      0.34      0.50       795
           3       0.89      0.50      0.64      1110
           4       0.43      0.71      0.54      1260
           5       0.90      0.27      0.42       882
           6       0.90      0.08      0.14       940
           7       0.42      0.08      0.13       473
           8       0.82      0.16      0.27       746
           9       0.28      0.64      0.39       689
          10       0.45      0.50      0.48       670
          11       0.73      0.06      0.11       312
          12       0.91      0.16      0.27       665
          13       0.65      0.75      0.70       314
          14       0.91      0.33      0.49       756
          15       0.28      0.99      0.44      1607

    accuracy                           0.43     12791
   macro avg       0.69   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5351827810571271, 0.5351827810571271)




CCA coefficients mean non-concern: (0.5291987000482947, 0.5291987000482947)




Linear CKA concern: 0.37703003079605923




Linear CKA non-concern: 0.20516758721526188




Kernel CKA concern: 0.2975250280586392




Kernel CKA non-concern: 0.13931461580280732




Evaluate the pruned model 14




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 1.8362




Precision: 0.6999, Recall: 0.3886, F1-Score: 0.4209




              precision    recall  f1-score   support

           0       0.84      0.26      0.40       797
           1       0.78      0.23      0.35       775
           2       0.95      0.31      0.47       795
           3       0.88      0.46      0.60      1110
           4       0.70      0.41      0.52      1260
           5       0.89      0.33      0.48       882
           6       0.91      0.16      0.27       940
           7       0.43      0.03      0.06       473
           8       0.75      0.27      0.39       746
           9       0.31      0.58      0.40       689
          10       0.33      0.64      0.44       670
          11       0.75      0.07      0.12       312
          12       0.80      0.37      0.50       665
          13       0.82      0.54      0.65       314
          14       0.81      0.58      0.67       756
          15       0.26      0.99      0.41      1607

    accuracy                           0.44     12791
   macro avg       0.70   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5255628808783187, 0.5255628808783187)




CCA coefficients mean non-concern: (0.534827535659655, 0.534827535659655)




Linear CKA concern: 0.15684830975427616




Linear CKA non-concern: 0.20973423428638985




Kernel CKA concern: 0.1336089458302425




Kernel CKA non-concern: 0.155797212887236




Evaluate the pruned model 15




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.0599




Precision: 0.6985, Recall: 0.2637, F1-Score: 0.2934




              precision    recall  f1-score   support

           0       0.72      0.33      0.45       797
           1       0.81      0.04      0.07       775
           2       1.00      0.01      0.02       795
           3       0.92      0.15      0.26      1110
           4       0.79      0.22      0.34      1260
           5       0.87      0.18      0.29       882
           6       0.90      0.25      0.40       940
           7       0.20      0.40      0.27       473
           8       0.57      0.30      0.40       746
           9       0.42      0.31      0.35       689
          10       0.59      0.32      0.42       670
          11       0.68      0.13      0.22       312
          12       0.77      0.34      0.47       665
          13       0.83      0.12      0.22       314
          14       0.92      0.12      0.21       756
          15       0.18      1.00      0.31      1607

    accuracy                           0.31     12791
   macro avg       0.70   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5385226743969855, 0.5385226743969855)




CCA coefficients mean non-concern: (0.5253897968154191, 0.5253897968154191)




Linear CKA concern: 0.7471482256380946




Linear CKA non-concern: 0.26302060996937326




Kernel CKA concern: 0.6808166760615172




Kernel CKA non-concern: 0.18616418882442384


