In [1]:
import os
import sys

# Extend the import search path four directories up from the working directory;
# presumably this is the repository root that provides the `utils` package.
sys.path += ["../../../../"]
# Set TOKENIZERS_PARALLELISM to "false" in this process's environment before
# any tokenizer is created.
os.environ.update(TOKENIZERS_PARALLELISM="false")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration; the cells below read these module-level names.
name = "OSDG"  # key passed to ModelConfig and load_data
# Use the first CUDA device when one is available, otherwise fall back to the
# CPU so the notebook can still run on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # rebound by the load_model(...) call in a later cell
batch_size = 16  # passed to load_data
num_workers = 4  # passed to load_data
num_samples = 16  # num_samples // 2 goes to each SamplingDataset; the full value goes to similar
ci_ratio = 0.4  # passed as sparsity_ratio to prune_concern_identification
seed = 44  # passed to load_data, SamplingDataset and similar
include_layers = ["attention", "intermediate", "output"]  # forwarded to prune_concern_identification
exclude_layers = None  # forwarded to prune_concern_identification

In [4]:
# Capture the wall-clock time at which the run began and echo it as
# "YYYY-MM-DD HH:MM:SS".
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-31 22:23:21


In [5]:
# Build the configuration object for the model selected by `name` on `device`.
model_config = ModelConfig(name, device)
# Number of labels, read from the configuration; the pruning loop below
# iterates over range(num_labels).
num_labels = model_config.config["num_labels"]
# load_model returns three values; note that this rebinds `checkpoint`.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Load the train / validation / test dataloaders for the dataset `name`, using
# the batch size, worker count and seed defined in the configuration cell.
# NOTE(review): do_cache=True presumably reuses a cached copy of the dataset —
# confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For every label index ("concern"): prune a fresh copy of the model with
# prune_concern_identification, evaluate the pruned copy on the test set,
# report its sparsity, and compare it against the unpruned model with `similar`.
for concern in range(num_labels):
    # Per-iteration copies, so the sampling and similarity helpers never
    # receive the shared dataloaders themselves.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # Two sample sets of num_samples // 2 each; they differ only in the fifth
    # positional argument (True / False).
    # NOTE(review): presumably True draws samples of `concern` and False draws
    # samples of the other labels, and the literal 4 is a batch size —
    # confirm against SamplingDataset.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so `model` stays untouched as the reference for `similar`.
    module = copy.deepcopy(model)

    # Return value is not used; presumably `module` is pruned in place.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # NOTE(review): if saving is re-enabled, add `concern` to the filename;
    # as written every iteration would write to the same path.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|                                                                                             …

Loss: 0.9400




Precision: 0.7784, Recall: 0.7804, F1-Score: 0.7752




              precision    recall  f1-score   support

           0       0.74      0.66      0.70       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.85      0.81      0.83      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.57      0.74      0.64       689
          10       0.78      0.78      0.78       670
          11       0.67      0.79      0.73       312
          12       0.69      0.81      0.75       665
          13       0.85      0.84      0.85       314
          14       0.86      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8351391954062671, 0.8351391954062671)




CCA coefficients mean non-concern: (0.8276816797279373, 0.8276816797279373)




Linear CKA concern: 0.9781988473413965




Linear CKA non-concern: 0.9493404689476994




Kernel CKA concern: 0.9733705819152783




Kernel CKA non-concern: 0.9500517153726854




Evaluate the pruned model 1




Evaluating:   0%|                                                                                             …

Loss: 0.9374




Precision: 0.7767, Recall: 0.7807, F1-Score: 0.7746




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.84      0.72      0.78       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.47      0.58      0.52       473
           8       0.65      0.86      0.74       746
           9       0.58      0.73      0.65       689
          10       0.77      0.78      0.77       670
          11       0.66      0.79      0.72       312
          12       0.70      0.81      0.75       665
          13       0.83      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8284251749818424, 0.8284251749818424)




CCA coefficients mean non-concern: (0.8308429187888139, 0.8308429187888139)




Linear CKA concern: 0.9733304932210476




Linear CKA non-concern: 0.9583979890338145




Kernel CKA concern: 0.9695114509896925




Kernel CKA non-concern: 0.9592634096148827




Evaluate the pruned model 2




Evaluating:   0%|                                                                                             …

Loss: 0.9326




Precision: 0.7763, Recall: 0.7800, F1-Score: 0.7741




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.70      0.77       775
           2       0.86      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.67      0.85      0.75       746
           9       0.59      0.73      0.65       689
          10       0.74      0.79      0.76       670
          11       0.66      0.80      0.72       312
          12       0.70      0.81      0.75       665
          13       0.85      0.84      0.84       314
          14       0.86      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8301022336406346, 0.8301022336406346)




CCA coefficients mean non-concern: (0.8253113144718512, 0.8253113144718512)




Linear CKA concern: 0.9833965987559726




Linear CKA non-concern: 0.9515816564572487




Kernel CKA concern: 0.9779089332980436




Kernel CKA non-concern: 0.9520052190053653




Evaluate the pruned model 3




Evaluating:   0%|                                                                                             …

Loss: 0.9401




Precision: 0.7782, Recall: 0.7805, F1-Score: 0.7755




              precision    recall  f1-score   support

           0       0.76      0.66      0.70       797
           1       0.84      0.71      0.77       775
           2       0.86      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.58      0.73      0.65       689
          10       0.77      0.78      0.77       670
          11       0.67      0.79      0.73       312
          12       0.70      0.81      0.75       665
          13       0.85      0.84      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>

    




self._shutdown_workers()

Traceback (most recent call last):





  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

    

self._shutdown_workers()

if w.is_alive():







  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

    

assert self._parent_pid == os.getpid(), 'can only test a child process'

if w.is_alive():







AssertionError

  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


: 

can only test a child process

    

AssertionError




assert self._parent_pid == os.getpid(), 'can only test a child process'




: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x711d24d1f820>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.832435012806202, 0.832435012806202)




CCA coefficients mean non-concern: (0.8297806333256929, 0.8297806333256929)




Linear CKA concern: 0.9706427820917973




Linear CKA non-concern: 0.9579819130437068




Kernel CKA concern: 0.9670580603516934




Kernel CKA non-concern: 0.9590010183918145




Evaluate the pruned model 4




Evaluating:   0%|                                                                                             …

Loss: 0.9390




Precision: 0.7779, Recall: 0.7799, F1-Score: 0.7750




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.70      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.84      0.81      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.84      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.65      0.85      0.74       746
           9       0.59      0.73      0.65       689
          10       0.77      0.78      0.77       670
          11       0.67      0.79      0.73       312
          12       0.69      0.81      0.74       665
          13       0.85      0.85      0.85       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8406302372261398, 0.8406302372261398)




CCA coefficients mean non-concern: (0.8309326543906819, 0.8309326543906819)




Linear CKA concern: 0.9831411452800918




Linear CKA non-concern: 0.959117413756798




Kernel CKA concern: 0.978750998128986




Kernel CKA non-concern: 0.9592527226592147




Evaluate the pruned model 5




Evaluating:   0%|                                                                                             …

Loss: 0.9337




Precision: 0.7772, Recall: 0.7808, F1-Score: 0.7750




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.84      0.72      0.78       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.67      0.84      0.75       746
           9       0.58      0.74      0.65       689
          10       0.76      0.78      0.77       670
          11       0.66      0.78      0.72       312
          12       0.70      0.81      0.75       665
          13       0.84      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8325532518867677, 0.8325532518867677)




CCA coefficients mean non-concern: (0.8284939596367116, 0.8284939596367116)




Linear CKA concern: 0.9761421655709529




Linear CKA non-concern: 0.9581803309721605




Kernel CKA concern: 0.970203936040936




Kernel CKA non-concern: 0.9597916360439428




Evaluate the pruned model 6




Evaluating:   0%|                                                                                             …

Loss: 0.9402




Precision: 0.7763, Recall: 0.7795, F1-Score: 0.7738




              precision    recall  f1-score   support

           0       0.76      0.65      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.84      0.80      0.82       940
           7       0.47      0.59      0.52       473
           8       0.65      0.85      0.74       746
           9       0.59      0.72      0.65       689
          10       0.77      0.78      0.77       670
          11       0.66      0.79      0.72       312
          12       0.69      0.81      0.75       665
          13       0.84      0.84      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8282120075369787, 0.8282120075369787)




CCA coefficients mean non-concern: (0.8327694228205752, 0.8327694228205752)




Linear CKA concern: 0.9767245514493731




Linear CKA non-concern: 0.9593776849022133




Kernel CKA concern: 0.97204692750772




Kernel CKA non-concern: 0.9602489658455098




Evaluate the pruned model 7




Evaluating:   0%|                                                                                             …

Loss: 0.9392




Precision: 0.7773, Recall: 0.7805, F1-Score: 0.7749




              precision    recall  f1-score   support

           0       0.74      0.66      0.70       797
           1       0.84      0.70      0.76       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.84      0.81      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.46      0.60      0.52       473
           8       0.67      0.85      0.75       746
           9       0.59      0.72      0.65       689
          10       0.78      0.78      0.78       670
          11       0.67      0.79      0.73       312
          12       0.69      0.82      0.75       665
          13       0.84      0.85      0.85       314
          14       0.86      0.77      0.81       756
          15       0.97      0.96      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8350735195394887, 0.8350735195394887)




CCA coefficients mean non-concern: (0.8292993697851875, 0.8292993697851875)




Linear CKA concern: 0.9725389974062076




Linear CKA non-concern: 0.9608640333475049




Kernel CKA concern: 0.9671483735056515




Kernel CKA non-concern: 0.9615525266348303




Evaluate the pruned model 8




Evaluating:   0%|                                                                                             …

Loss: 0.9426




Precision: 0.7767, Recall: 0.7796, F1-Score: 0.7739




              precision    recall  f1-score   support

           0       0.77      0.65      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.84      0.80      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.46      0.59      0.52       473
           8       0.65      0.86      0.74       746
           9       0.58      0.73      0.65       689
          10       0.77      0.78      0.77       670
          11       0.67      0.79      0.72       312
          12       0.70      0.81      0.75       665
          13       0.83      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8328071387107717, 0.8328071387107717)




CCA coefficients mean non-concern: (0.8281632664707769, 0.8281632664707769)




Linear CKA concern: 0.9790403489924211




Linear CKA non-concern: 0.9541814602487589




Kernel CKA concern: 0.9744147685363064




Kernel CKA non-concern: 0.9544059504279785




Evaluate the pruned model 9




Evaluating:   0%|                                                                                             …

Loss: 0.9431




Precision: 0.7766, Recall: 0.7787, F1-Score: 0.7734




              precision    recall  f1-score   support

           0       0.75      0.65      0.70       797
           1       0.84      0.69      0.76       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.84      0.80      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.77      0.79      0.78       670
          11       0.68      0.78      0.72       312
          12       0.69      0.81      0.74       665
          13       0.84      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8319318977553726, 0.8319318977553726)




CCA coefficients mean non-concern: (0.8287967471143275, 0.8287967471143275)




Linear CKA concern: 0.9758875684195986




Linear CKA non-concern: 0.9548807938449535




Kernel CKA concern: 0.9695430467392394




Kernel CKA non-concern: 0.9565476265956246




Evaluate the pruned model 10




Evaluating:   0%|                                                                                             …

Loss: 0.9366




Precision: 0.7765, Recall: 0.7797, F1-Score: 0.7742




              precision    recall  f1-score   support

           0       0.75      0.66      0.70       797
           1       0.84      0.71      0.77       775
           2       0.88      0.87      0.88       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.81      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.84      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.59      0.73      0.65       689
          10       0.75      0.79      0.77       670
          11       0.67      0.79      0.72       312
          12       0.69      0.81      0.74       665
          13       0.84      0.84      0.84       314
          14       0.86      0.78      0.82       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8324241371698773, 0.8324241371698773)




CCA coefficients mean non-concern: (0.8311990024757022, 0.8311990024757022)




Linear CKA concern: 0.9742144025310986




Linear CKA non-concern: 0.9593748229022692




Kernel CKA concern: 0.9695169666556139




Kernel CKA non-concern: 0.9607097871579037




Evaluate the pruned model 11




Evaluating:   0%|                                                                                             …

Loss: 0.9409




Precision: 0.7753, Recall: 0.7802, F1-Score: 0.7730




              precision    recall  f1-score   support

           0       0.78      0.65      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.88      0.81      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.80      0.82       940
           7       0.46      0.59      0.51       473
           8       0.65      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.76      0.78      0.77       670
          11       0.63      0.80      0.71       312
          12       0.70      0.80      0.75       665
          13       0.83      0.86      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8304975569343751, 0.8304975569343751)




CCA coefficients mean non-concern: (0.8291627365599312, 0.8291627365599312)




Linear CKA concern: 0.9815553787173171




Linear CKA non-concern: 0.956400273506724




Kernel CKA concern: 0.9758907074270738




Kernel CKA non-concern: 0.9577735534921281




Evaluate the pruned model 12




Evaluating:   0%|                                                                                             …

Loss: 0.9375




Precision: 0.7779, Recall: 0.7805, F1-Score: 0.7750




              precision    recall  f1-score   support

           0       0.77      0.65      0.71       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.80      0.82       940
           7       0.47      0.60      0.52       473
           8       0.66      0.84      0.74       746
           9       0.58      0.74      0.65       689
          10       0.77      0.78      0.77       670
          11       0.67      0.78      0.72       312
          12       0.69      0.81      0.75       665
          13       0.84      0.85      0.85       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8299641775288333, 0.8299641775288333)




CCA coefficients mean non-concern: (0.8286259019474012, 0.8286259019474012)




Linear CKA concern: 0.9735432205435307




Linear CKA non-concern: 0.9572038647745083




Kernel CKA concern: 0.9678828293446895




Kernel CKA non-concern: 0.9580538839205189




Evaluate the pruned model 13




Evaluating:   0%|                                                                                             …

Loss: 0.9365




Precision: 0.7753, Recall: 0.7801, F1-Score: 0.7732




              precision    recall  f1-score   support

           0       0.77      0.65      0.71       797
           1       0.84      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.88      0.81      0.84      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.69      0.78       882
           6       0.85      0.79      0.82       940
           7       0.46      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.57      0.73      0.64       689
          10       0.77      0.78      0.77       670
          11       0.64      0.79      0.71       312
          12       0.69      0.81      0.74       665
          13       0.83      0.86      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8304041929925502, 0.8304041929925502)




CCA coefficients mean non-concern: (0.8237314864788988, 0.8237314864788988)




Linear CKA concern: 0.9822597560417343




Linear CKA non-concern: 0.9526316675626083




Kernel CKA concern: 0.9747878156699655




Kernel CKA non-concern: 0.9534443431151464




Evaluate the pruned model 14




Evaluating:   0%|                                                                                             …

Loss: 0.9341




Precision: 0.7771, Recall: 0.7805, F1-Score: 0.7749




              precision    recall  f1-score   support

           0       0.76      0.66      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.87      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.75       746
           9       0.59      0.72      0.65       689
          10       0.76      0.78      0.77       670
          11       0.67      0.79      0.73       312
          12       0.68      0.81      0.74       665
          13       0.84      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.831514812179185, 0.831514812179185)




CCA coefficients mean non-concern: (0.830132420134685, 0.830132420134685)




Linear CKA concern: 0.9744945223669494




Linear CKA non-concern: 0.9557962210386877




Kernel CKA concern: 0.9720433935156177




Kernel CKA non-concern: 0.956624717941665




Evaluate the pruned model 15




Evaluating:   0%|                                                                                             …

Loss: 0.9377




Precision: 0.7765, Recall: 0.7774, F1-Score: 0.7722




              precision    recall  f1-score   support

           0       0.77      0.63      0.69       797
           1       0.84      0.70      0.77       775
           2       0.88      0.87      0.87       795
           3       0.88      0.81      0.84      1110
           4       0.85      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.48      0.59      0.53       473
           8       0.64      0.86      0.73       746
           9       0.56      0.74      0.64       689
          10       0.77      0.77      0.77       670
          11       0.67      0.79      0.72       312
          12       0.70      0.81      0.75       665
          13       0.83      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.96      0.98      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.78   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.829996096809566, 0.829996096809566)




CCA coefficients mean non-concern: (0.8142171011224574, 0.8142171011224574)




Linear CKA concern: 0.9726415054552747




Linear CKA non-concern: 0.9368389967571633




Kernel CKA concern: 0.9692731119565329




Kernel CKA non-concern: 0.9360199658670252


