In [1]:
import os
import sys

# Extend the import search path four directories up so the project-local
# `utils` package imported in the next cell can be resolved.
# NOTE(review): the relative path depends on the kernel's working directory.
sys.path += ["../../../../"]
# Turn off parallelism in the tokenizers library before any tokenizer is
# created (commonly done to avoid fork-related warnings when DataLoader
# worker processes are used).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration ------------------------------------------------
# Identifier passed to ModelConfig and load_data to select model and dataset.
name = "OSDG"
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook can still
# be executed (slowly) on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder; rebound by the load_model(...) call in a later cell.
checkpoint = None
# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4
# Sample count handed to SamplingDataset and similar() for each concern.
num_samples = 16
# Sparsity ratio passed to prune_concern_identification.
ci_ratio = 0.3
# Seed shared by data loading, sampling and the similarity computation.
seed = 44
# Layer-name filters forwarded to the pruning routine.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Remember when the run began and echo it so the saved output is timestamped.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-23 15:13:20


In [5]:
# Build the configuration object for the model selected by `name` on `device`.
model_config = ModelConfig(name, device)
# Number of target classes, read from the configuration; the pruning loop
# later treats each label index as one "concern".
num_labels = model_config.config["num_labels"]
# load_model returns the model, its tokenizer and a checkpoint value; the
# latter replaces the `checkpoint = None` placeholder set in the config cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Load the train / validation / test dataloaders for the dataset named `name`,
# using the batch size, worker count and seed from the config cell.
# NOTE(review): do_cache=True presumably reuses an on-disk copy of the
# processed dataset — confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Baseline evaluation of the unpruned model is currently disabled; uncomment
# the two lines below to report its test-set metrics for comparison with the
# pruned modules.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# Per-concern pruning experiment: for every label index ("concern"), prune a
# fresh copy of the model with prune_concern_identification, then report test
# metrics and representation similarity against the unpruned model.

# Keyword options shared by the positive and negative sample sets.
sampling_options = dict(device=device, resample=False, seed=seed)

for concern in range(num_labels):
    # Work on private copies of the dataloaders so nothing done while sampling
    # or computing similarity can leak into the shared loaders.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # NOTE(review): the positional boolean appears to select samples of the
    # concern (True) versus the remaining labels (False); the meaning of the
    # positional 4 is not visible here — confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train, concern, num_samples, num_labels, True, 4, **sampling_options
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples, num_labels, False, 4, **sampling_options
    )
    # An additional `all_samples` set (built with concern id 200) used to be
    # constructed here, but nothing consumed it, so it is no longer built.

    # Prune a copy so the original `model` stays intact for the comparison.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    # NOTE(review): the return value is discarded and a bare expression inside
    # a loop is not displayed; capture/print it if the sparsity should be shown.
    get_sparsity(module)

    # Compare representations of the original and pruned model on validation
    # data for this concern (the helper prints CCA / CKA scores).
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # Include the concern index in the file name so each iteration writes its
    # own file instead of overwriting the previous one.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Evaluate the pruned model 0




Evaluating:   0%|                                                                              | 0/800 [00:15<…

Loss: 0.9440




Precision: 0.7777, Recall: 0.7826, F1-Score: 0.7759




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.60      0.53       473
           8       0.65      0.85      0.74       746
           9       0.60      0.74      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.80      0.70       312
          12       0.71      0.81      0.76       665
          13       0.85      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9193124957301936, 0.9193124957301936)




CCA coefficients mean non-concern: (0.91266331296639, 0.91266331296639)




Linear CKA concern: 0.9908229619675725




Linear CKA non-concern: 0.9794335145776694




Kernel CKA concern: 0.9887720431004353




Kernel CKA non-concern: 0.9793532711754057




Evaluate the pruned model 1




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9402




Precision: 0.7778, Recall: 0.7842, F1-Score: 0.7769




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.87      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.75      0.78      0.77       670
          11       0.63      0.80      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9124525412544445, 0.9124525412544445)




CCA coefficients mean non-concern: (0.9132289052519235, 0.9132289052519235)




Linear CKA concern: 0.9894103551503298




Linear CKA non-concern: 0.9847218880050923




Kernel CKA concern: 0.9878910795675058




Kernel CKA non-concern: 0.9848688087382931




Evaluate the pruned model 2




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9400




Precision: 0.7773, Recall: 0.7827, F1-Score: 0.7758




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.59      0.73      0.65       689
          10       0.75      0.78      0.77       670
          11       0.63      0.80      0.71       312
          12       0.71      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9159352268034964, 0.9159352268034964)




CCA coefficients mean non-concern: (0.9117008423794506, 0.9117008423794506)




Linear CKA concern: 0.99418450129213




Linear CKA non-concern: 0.980641138824491




Kernel CKA concern: 0.9921456511219472




Kernel CKA non-concern: 0.981370519960456




Evaluate the pruned model 3




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9423




Precision: 0.7777, Recall: 0.7831, F1-Score: 0.7765




              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.63      0.79      0.70       312
          12       0.72      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.91652184147496, 0.91652184147496)




CCA coefficients mean non-concern: (0.913023972771409, 0.913023972771409)




Linear CKA concern: 0.9921610366683729




Linear CKA non-concern: 0.9816067962630018




Kernel CKA concern: 0.9904902604666389




Kernel CKA non-concern: 0.9815474170075618




Evaluate the pruned model 4




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9377




Precision: 0.7770, Recall: 0.7826, F1-Score: 0.7753




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.88      0.80      0.84      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.65       689
          10       0.75      0.78      0.77       670
          11       0.63      0.80      0.71       312
          12       0.71      0.82      0.76       665
          13       0.84      0.86      0.85       314
          14       0.86      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9193560950737388, 0.9193560950737388)




CCA coefficients mean non-concern: (0.9094789170003488, 0.9094789170003488)




Linear CKA concern: 0.9927184235518137




Linear CKA non-concern: 0.9760076205002839




Kernel CKA concern: 0.9902087297450256




Kernel CKA non-concern: 0.9754159393182291




Evaluate the pruned model 5




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9377




Precision: 0.7770, Recall: 0.7828, F1-Score: 0.7758




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.86      0.79      0.83       940
           7       0.48      0.59      0.53       473
           8       0.66      0.86      0.75       746
           9       0.59      0.73      0.66       689
          10       0.74      0.78      0.76       670
          11       0.63      0.79      0.70       312
          12       0.71      0.81      0.75       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9148222389731554, 0.9148222389731554)




CCA coefficients mean non-concern: (0.9112934014378272, 0.9112934014378272)




Linear CKA concern: 0.9908640392953464




Linear CKA non-concern: 0.9822974215165987




Kernel CKA concern: 0.9887859524468375




Kernel CKA non-concern: 0.9833000089035107




Evaluate the pruned model 6




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9424




Precision: 0.7780, Recall: 0.7843, F1-Score: 0.7770




              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.84      0.81      0.82       940
           7       0.48      0.60      0.53       473
           8       0.67      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.77      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.71      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9162884558808385, 0.9162884558808385)




CCA coefficients mean non-concern: (0.9145705638925775, 0.9145705638925775)




Linear CKA concern: 0.9923922434999461




Linear CKA non-concern: 0.9830757148688706




Kernel CKA concern: 0.9909139502202282




Kernel CKA non-concern: 0.9829993130630095




Evaluate the pruned model 7




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9439




Precision: 0.7783, Recall: 0.7835, F1-Score: 0.7767




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.65      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.75      0.78      0.76       670
          11       0.63      0.79      0.70       312
          12       0.71      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9187009513161823, 0.9187009513161823)




CCA coefficients mean non-concern: (0.9143910482673601, 0.9143910482673601)




Linear CKA concern: 0.9895041873210622




Linear CKA non-concern: 0.9843508958007784




Kernel CKA concern: 0.9877028590645769




Kernel CKA non-concern: 0.9845388465745311




Evaluate the pruned model 8




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9452




Precision: 0.7768, Recall: 0.7826, F1-Score: 0.7755




              precision    recall  f1-score   support

           0       0.76      0.66      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.65      0.86      0.74       746
           9       0.60      0.73      0.66       689
          10       0.76      0.79      0.77       670
          11       0.63      0.80      0.70       312
          12       0.72      0.80      0.76       665
          13       0.83      0.85      0.84       314
          14       0.85      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9166504820821633, 0.9166504820821633)




CCA coefficients mean non-concern: (0.9092130938382859, 0.9092130938382859)




Linear CKA concern: 0.9925227564244851




Linear CKA non-concern: 0.977771078355977




Kernel CKA concern: 0.990926400768598




Kernel CKA non-concern: 0.9776837078349959




Evaluate the pruned model 9




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9425




Precision: 0.7785, Recall: 0.7835, F1-Score: 0.7769




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.47      0.59      0.53       473
           8       0.66      0.86      0.75       746
           9       0.59      0.73      0.65       689
          10       0.75      0.79      0.77       670
          11       0.64      0.79      0.71       312
          12       0.71      0.82      0.76       665
          13       0.84      0.86      0.85       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9158323345617861, 0.9158323345617861)




CCA coefficients mean non-concern: (0.9143509316560726, 0.9143509316560726)




Linear CKA concern: 0.9904082703575066




Linear CKA non-concern: 0.9819932364750831




Kernel CKA concern: 0.9879377085965413




Kernel CKA non-concern: 0.9822383953388301




Evaluate the pruned model 10




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9412




Precision: 0.7777, Recall: 0.7835, F1-Score: 0.7766




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.67      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.80      0.70       312
          12       0.71      0.82      0.76       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9163444759614677, 0.9163444759614677)




CCA coefficients mean non-concern: (0.9127984154506947, 0.9127984154506947)




Linear CKA concern: 0.9899454601343135




Linear CKA non-concern: 0.9823610955628191




Kernel CKA concern: 0.9883255033143247




Kernel CKA non-concern: 0.9829628244590586




Evaluate the pruned model 11




Evaluating:   0%|                                                                              | 0/800 [00:24<…

Loss: 0.9431




Precision: 0.7775, Recall: 0.7838, F1-Score: 0.7764




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.86      0.80      0.83       940
           7       0.48      0.60      0.53       473
           8       0.66      0.85      0.74       746
           9       0.59      0.73      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9154807112494509, 0.9154807112494509)




CCA coefficients mean non-concern: (0.9115943321921114, 0.9115943321921114)




Linear CKA concern: 0.992875421263338




Linear CKA non-concern: 0.9821612070523186




Kernel CKA concern: 0.9911497403342683




Kernel CKA non-concern: 0.9830628026685023




Evaluate the pruned model 12




Evaluating:   0%|                                                                              | 0/800 [00:25<…

Loss: 0.9419




Precision: 0.7766, Recall: 0.7822, F1-Score: 0.7751




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.70      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.79      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.58      0.52       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.65       689
          10       0.74      0.79      0.77       670
          11       0.63      0.79      0.70       312
          12       0.71      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9132891727668296, 0.9132891727668296)




CCA coefficients mean non-concern: (0.9113029724376531, 0.9113029724376531)




Linear CKA concern: 0.9885252946684355




Linear CKA non-concern: 0.9816061559894808




Kernel CKA concern: 0.986971593746534




Kernel CKA non-concern: 0.9822684245189065




Evaluate the pruned model 13




Evaluating:   0%|                                                                              | 0/800 [00:25<…

Loss: 0.9390




Precision: 0.7776, Recall: 0.7839, F1-Score: 0.7765




              precision    recall  f1-score   support

           0       0.75      0.66      0.71       797
           1       0.84      0.70      0.77       775
           2       0.87      0.88      0.87       795
           3       0.88      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.86      0.74       746
           9       0.60      0.73      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.72      0.81      0.76       665
          13       0.83      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.915666291759241, 0.915666291759241)




CCA coefficients mean non-concern: (0.9109924775849277, 0.9109924775849277)




Linear CKA concern: 0.9930433399085622




Linear CKA non-concern: 0.9811681758307919




Kernel CKA concern: 0.9904148840503455




Kernel CKA non-concern: 0.9815273009369637




Evaluate the pruned model 14




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9400




Precision: 0.7767, Recall: 0.7826, F1-Score: 0.7756




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.48      0.58      0.53       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.63      0.80      0.70       312
          12       0.71      0.81      0.76       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9154672966793579, 0.9154672966793579)




CCA coefficients mean non-concern: (0.9116706124125982, 0.9116706124125982)




Linear CKA concern: 0.9901336509564084




Linear CKA non-concern: 0.9816844496146544




Kernel CKA concern: 0.9893988247039248




Kernel CKA non-concern: 0.9820615794314419




Evaluate the pruned model 15




Evaluating:   0%|                                                                              | 0/800 [00:23<…

Loss: 0.9408




Precision: 0.7770, Recall: 0.7801, F1-Score: 0.7738




              precision    recall  f1-score   support

           0       0.78      0.64      0.70       797
           1       0.85      0.70      0.77       775
           2       0.87      0.87      0.87       795
           3       0.88      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.78       882
           6       0.85      0.80      0.82       940
           7       0.46      0.60      0.52       473
           8       0.65      0.84      0.73       746
           9       0.57      0.74      0.65       689
          10       0.77      0.78      0.77       670
          11       0.63      0.80      0.71       312
          12       0.71      0.80      0.76       665
          13       0.84      0.84      0.84       314
          14       0.85      0.78      0.81       756
          15       0.96      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9109156465672228, 0.9109156465672228)




CCA coefficients mean non-concern: (0.8889157088878646, 0.8889157088878646)




Linear CKA concern: 0.9892197594104027




Linear CKA non-concern: 0.9710207897871301




Kernel CKA concern: 0.9875590877697454




Kernel CKA non-concern: 0.9708840006109575


