In [1]:
import os
import sys

# Extend the import search path four directories up (relative to the current
# working directory) so the `utils` package imported below can be resolved.
sys.path.append("../../../../")
# Turn off tokenizer parallelism via its environment switch before any
# tokenizer is created.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Run configuration -------------------------------------------------------
# Dataset/model key; passed to both ModelConfig and load_data.
name = "OSDG"
# Prefer the first CUDA device, but fall back to CPU so the notebook still runs
# on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder; rebound by load_model() in the model-loading cell.
checkpoint = None
# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4
# Sample budget per concern; the pruning loop passes num_samples // 2 to each of
# the positive and negative SamplingDataset instances.
num_samples = 16
# Passed as `sparsity_ratio` to prune_concern_identification.
ci_ratio = 0.3
# Seed forwarded to load_data, SamplingDataset and similar.
seed = 44
# Layer-name filters forwarded to prune_concern_identification.
include_layers = ["intermediate", "output"]
exclude_layers = ["attention"]

In [4]:
# Record when the run began and echo it in "YYYY-MM-DD HH:MM:SS" form.
script_start_time = datetime.now()
print("Script started at: {:%Y-%m-%d %H:%M:%S}".format(script_start_time))

Script started at: 2024-09-01 09:59:38


In [5]:
# Build the model configuration for `name` on `device`.
model_config = ModelConfig(name, device)
# Number of class labels; the pruning loop below iterates over range(num_labels).
num_labels = model_config.config["num_labels"]
# Load the model and tokenizer. This also rebinds `checkpoint`, which the
# configuration cell initialised to None, to the third value load_model returns.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Load the train / validation / test dataloaders for dataset `name`, using the
# batch size, worker count and seed from the configuration cell.
# NOTE(review): do_cache=True presumably reuses a cached copy of the dataset —
# confirm against load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every concern (class label index): prune a fresh copy of the model using
# positive/negative samples for that concern, evaluate the pruned copy on the
# test set, report its sparsity, and compare it against the unpruned model.
for concern in range(num_labels):
    # Each iteration works on its own copies of the train/validation loaders.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Half of the sample budget goes to each set. The fifth positional argument
    # is True for the positive set and False for the negative set.
    # NOTE(review): the literal 4 is passed positionally to SamplingDataset;
    # its meaning is not visible here (possibly a batch size) — confirm.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a deep copy so the original `model` stays intact for the
    # similarity comparison below and for the next concern.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original and pruned models on the validation copy for this
    # concern (the captured output shows CCA / CKA scores from this call).
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 0.9466




Precision: 0.7790, Recall: 0.7848, F1-Score: 0.7777




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.85      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.59      0.73      0.66       689
          10       0.75      0.78      0.77       670
          11       0.62      0.81      0.71       312
          12       0.73      0.81      0.77       665
          13       0.85      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9540669772454914, 0.9540669772454914)




CCA coefficients mean non-concern: (0.9506332937035338, 0.9506332937035338)




Linear CKA concern: 0.995836565145103




Linear CKA non-concern: 0.9901470284105335




Kernel CKA concern: 0.9947225178001808




Kernel CKA non-concern: 0.9891516742511443




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 0.9450




Precision: 0.7797, Recall: 0.7860, F1-Score: 0.7786




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.74       746
           9       0.60      0.74      0.66       689
          10       0.75      0.78      0.77       670
          11       0.62      0.82      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9511079077265905, 0.9511079077265905)




CCA coefficients mean non-concern: (0.9508733412168882, 0.9508733412168882)




Linear CKA concern: 0.9961583938966241




Linear CKA non-concern: 0.9920008998799796




Kernel CKA concern: 0.9951226450791594




Kernel CKA non-concern: 0.9914169168432116




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 0.9455




Precision: 0.7791, Recall: 0.7854, F1-Score: 0.7780




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.72      0.78       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.74       746
           9       0.60      0.73      0.66       689
          10       0.75      0.78      0.76       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.951326067788823, 0.951326067788823)




CCA coefficients mean non-concern: (0.9495535153240132, 0.9495535153240132)




Linear CKA concern: 0.9978338983038798




Linear CKA non-concern: 0.9908963486038542




Kernel CKA concern: 0.9968872685029151




Kernel CKA non-concern: 0.9900342902449439




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 0.9468




Precision: 0.7792, Recall: 0.7858, F1-Score: 0.7783




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.67      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.61      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.84      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9521664540745722, 0.9521664540745722)




CCA coefficients mean non-concern: (0.950298327328068, 0.950298327328068)




Linear CKA concern: 0.9955854669686657




Linear CKA non-concern: 0.9927584232799028




Kernel CKA concern: 0.9943415544989385




Kernel CKA non-concern: 0.992099081186322




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 0.9459




Precision: 0.7789, Recall: 0.7848, F1-Score: 0.7777




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9555276899081369, 0.9555276899081369)




CCA coefficients mean non-concern: (0.9508549062600099, 0.9508549062600099)




Linear CKA concern: 0.9976361701670792




Linear CKA non-concern: 0.9924612977187124




Kernel CKA concern: 0.9966142033185044




Kernel CKA non-concern: 0.9917457089094553




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 0.9443




Precision: 0.7800, Recall: 0.7865, F1-Score: 0.7792




              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.85      0.72      0.78       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.69      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9518863249652216, 0.9518863249652216)




CCA coefficients mean non-concern: (0.9486600974126126, 0.9486600974126126)




Linear CKA concern: 0.9968242771086481




Linear CKA non-concern: 0.9924540993232112




Kernel CKA concern: 0.9958003222408051




Kernel CKA non-concern: 0.9920426832587107




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 0.9474




Precision: 0.7790, Recall: 0.7857, F1-Score: 0.7781




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.67      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.61      0.82      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9520078239843726, 0.9520078239843726)




CCA coefficients mean non-concern: (0.9499090708648046, 0.9499090708648046)




Linear CKA concern: 0.9965129954298346




Linear CKA non-concern: 0.9922963755437197




Kernel CKA concern: 0.9955786837118507




Kernel CKA non-concern: 0.9915963090011228




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 0.9469




Precision: 0.7791, Recall: 0.7856, F1-Score: 0.7780




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.82      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9524164982053857, 0.9524164982053857)




CCA coefficients mean non-concern: (0.9510656288686525, 0.9510656288686525)




Linear CKA concern: 0.9956930561938677




Linear CKA non-concern: 0.9926151191715574




Kernel CKA concern: 0.9943729388218377




Kernel CKA non-concern: 0.992054352192508




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 0.9463




Precision: 0.7795, Recall: 0.7858, F1-Score: 0.7783




              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x799b4c09b3a0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x799b4c09b3a0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x799b4c09b3a0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x799b4c09b3a0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9527852668604011, 0.9527852668604011)




CCA coefficients mean non-concern: (0.9480288124875259, 0.9480288124875259)




Linear CKA concern: 0.9972978365272177




Linear CKA non-concern: 0.9917819974805743




Kernel CKA concern: 0.9962344976855155




Kernel CKA non-concern: 0.9909954288933422




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 0.9473




Precision: 0.7788, Recall: 0.7848, F1-Score: 0.7776




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.70      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.75       746
           9       0.59      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9531781149970219, 0.9531781149970219)




CCA coefficients mean non-concern: (0.9515286977935444, 0.9515286977935444)




Linear CKA concern: 0.9957235412228682




Linear CKA non-concern: 0.9923889510298448




Kernel CKA concern: 0.9945312876317438




Kernel CKA non-concern: 0.9918092234664403




Evaluate the pruned model 10




Evaluating:   0%|                                                                                             …

Loss: 0.9469




Precision: 0.7797, Recall: 0.7858, F1-Score: 0.7786




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.74       746
           9       0.60      0.74      0.66       689
          10       0.75      0.78      0.76       670
          11       0.62      0.82      0.71       312
          12       0.73      0.81      0.77       665
          13       0.85      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9538815367560928, 0.9538815367560928)




CCA coefficients mean non-concern: (0.9494939193784263, 0.9494939193784263)




Linear CKA concern: 0.9964602094859324




Linear CKA non-concern: 0.9926741249250615




Kernel CKA concern: 0.9954836019911123




Kernel CKA non-concern: 0.9921502178710394




Evaluate the pruned model 11




Evaluating:   0%|                                                                                             …

Loss: 0.9463




Precision: 0.7794, Recall: 0.7862, F1-Score: 0.7785




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.78       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.48      0.60      0.54       473
           8       0.67      0.85      0.75       746
           9       0.60      0.73      0.66       689
          10       0.75      0.79      0.77       670
          11       0.61      0.82      0.70       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9521717499413126, 0.9521717499413126)




CCA coefficients mean non-concern: (0.9482963307714837, 0.9482963307714837)




Linear CKA concern: 0.996501967885802




Linear CKA non-concern: 0.9922373692714863




Kernel CKA concern: 0.9951990679149009




Kernel CKA non-concern: 0.9916648914744952




Evaluate the pruned model 12




Evaluating:   0%|                                                                                             …

Loss: 0.9456




Precision: 0.7798, Recall: 0.7867, F1-Score: 0.7789




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.60      0.54       473
           8       0.67      0.85      0.75       746
           9       0.60      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.82      0.71       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9517890857461121, 0.9517890857461121)




CCA coefficients mean non-concern: (0.9500145519538469, 0.9500145519538469)




Linear CKA concern: 0.9962515224582403




Linear CKA non-concern: 0.9919616787085579




Kernel CKA concern: 0.9950685921976998




Kernel CKA non-concern: 0.9913048143979944




Evaluate the pruned model 13




Evaluating:   0%|                                                                                             …

Loss: 0.9458




Precision: 0.7793, Recall: 0.7857, F1-Score: 0.7782




              precision    recall  f1-score   support

           0       0.76      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.88      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.69      0.77       882
           6       0.85      0.80      0.82       940
           7       0.49      0.60      0.54       473
           8       0.66      0.86      0.74       746
           9       0.59      0.74      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.82      0.71       312
          12       0.73      0.81      0.77       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9497675954451266, 0.9497675954451266)




CCA coefficients mean non-concern: (0.9476668701744495, 0.9476668701744495)




Linear CKA concern: 0.9971982232087736




Linear CKA non-concern: 0.9909981332160653




Kernel CKA concern: 0.9956915104321402




Kernel CKA non-concern: 0.9901350365324503




Evaluate the pruned model 14




Evaluating:   0%|                                                                                             …

Loss: 0.9464




Precision: 0.7796, Recall: 0.7855, F1-Score: 0.7783




              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.68      0.77       882
           6       0.85      0.81      0.83       940
           7       0.49      0.60      0.54       473
           8       0.66      0.85      0.74       746
           9       0.59      0.74      0.66       689
          10       0.75      0.79      0.77       670
          11       0.62      0.82      0.70       312
          12       0.73      0.80      0.77       665
          13       0.85      0.85      0.85       314
          14       0.84      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.9537106131699208, 0.9537106131699208)




CCA coefficients mean non-concern: (0.9508407748021825, 0.9508407748021825)




Linear CKA concern: 0.9966608580088452




Linear CKA non-concern: 0.9920799769302061




Kernel CKA concern: 0.9957870482530276




Kernel CKA non-concern: 0.9913775787042993




Evaluate the pruned model 15




Evaluating:   0%|                                                                                             …

Loss: 0.9499




Precision: 0.7795, Recall: 0.7846, F1-Score: 0.7775




              precision    recall  f1-score   support

           0       0.77      0.65      0.71       797
           1       0.85      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.82      0.85      1110
           4       0.87      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.81      0.83       940
           7       0.49      0.59      0.53       473
           8       0.66      0.86      0.74       746
           9       0.59      0.74      0.65       689
          10       0.75      0.78      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.85      0.86      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.954176201667557, 0.954176201667557)




CCA coefficients mean non-concern: (0.9397653160313101, 0.9397653160313101)




Linear CKA concern: 0.9967433831706075




Linear CKA non-concern: 0.9853116606957745




Kernel CKA concern: 0.9955821993423902




Kernel CKA non-concern: 0.9838835124868853


