In [1]:
import os
import sys

# Make the repository root (four directories up) importable so the `utils.*`
# imports in the next cell resolve. The path is relative to the kernel's current
# working directory, so the notebook is expected to run from its own folder.
repo_root = "../../../../"
if repo_root not in sys.path:  # keep the cell idempotent: no duplicate entries on re-run
    sys.path.append(repo_root)

# Turn off HuggingFace tokenizers' internal parallelism. Presumably this silences the
# tokenizers fork warning, since the DataLoaders below use worker processes
# (num_workers > 0) — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# Experiment configuration — everything a reader might tune lives here.
name = "OSDG"  # model/dataset key passed to ModelConfig and load_data
# Use the first GPU when one is present; otherwise fall back to CPU instead of
# failing later with a CUDA error on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; replaced by the checkpoint returned from load_model(...)
batch_size = 16
num_workers = 4  # DataLoader worker processes
num_samples = 16  # split in half between positive and negative samples when pruning
ci_ratio = 0.4  # sparsity_ratio passed to prune_concern_identification
seed = 44  # passed to load_data, SamplingDataset and similar
include_layers = ["attention", "intermediate", "output"]  # include_layers for pruning
exclude_layers = None  # exclude_layers for pruning (None = no exclusions)

In [4]:
# Record when the run began so the total wall-clock time can be reported.
script_start_time = datetime.now()
print("Script started at:", script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-08-30 05:40:21


In [5]:
# Resolve the model configuration for `name` on `device`, then load the pretrained
# model together with its tokenizer and checkpoint.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]  # one concern per label in the loop below
(model, tokenizer, checkpoint) = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Build the train/valid/test DataLoaders for the dataset; do_cache=True lets
# load_data use its cached copy of the dataset.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    seed=seed,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Optional baseline: set to True to evaluate the unpruned model on the test set
# before decomposition. Off by default because it is a full pass over test_dataloader.
evaluate_original = False
if evaluate_original:
    print("Evaluate the original model")
    result = evaluate_model(model, model_config, test_dataloader)

In [8]:
# For each concern (class label): prune a fresh copy of the model using positive
# and negative samples for that label, evaluate it on the full test set, and compare
# its representations with the original model (the CCA / CKA scores printed below).
results_by_concern = {}  # concern index -> evaluate_model(...) result for its pruned module
for concern in range(num_labels):
    # Give each concern its own copies of the loaders so every iteration starts from
    # the same state. NOTE(review): presumably SamplingDataset iterates/consumes the
    # loader it is given — confirm.
    # NOTE(review): the "can only test a child process" AssertionErrors in this cell's
    # output typically come from DataLoader iterators left unfinished while other
    # loaders fork workers; if SamplingDataset stops iterating early, consider
    # num_workers=0 for these sampling copies — confirm.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a copy so `model` stays the unpruned reference for `similar` below.
    # prune_concern_identification's return value is ignored, so it presumably
    # prunes `module` in place.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    results_by_concern[concern] = result  # keep every concern's result, not just the last
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # The filename includes the concern index so each pruned module gets its own
    # file instead of overwriting the previous concern's.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4051




Precision: 0.6198, Recall: 0.1705, F1-Score: 0.1360




              precision    recall  f1-score   support

           0       0.91      0.01      0.02       797
           1       1.00      0.00      0.01       775
           2       1.00      0.01      0.02       795
           3       0.90      0.28      0.43      1110
           4       0.74      0.12      0.21      1260
           5       0.80      0.01      0.02       882
           6       0.85      0.09      0.15       940
           7       0.00      0.00      0.00       473
           8       0.87      0.02      0.03       746
           9       0.44      0.04      0.07       689
          10       0.16      0.74      0.27       670
          11       0.00      0.00      0.00       312
          12       0.62      0.01      0.01       665
          13       0.80      0.14      0.23       314
          14       0.64      0.28      0.39       756
          15       0.19      0.99      0.31      1607

    accuracy                           0.23     12791
   macro avg       0.62   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5101434215142439, 0.5101434215142439)




CCA coefficients mean non-concern: (0.5009616043362842, 0.5009616043362842)




Linear CKA concern: 0.05852129908701758




Linear CKA non-concern: 0.10782792519352234




Kernel CKA concern: 0.04914409047379938




Kernel CKA non-concern: 0.06995611145163565




Evaluate the pruned model 1




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4146




Precision: 0.5917, Recall: 0.1818, F1-Score: 0.1452




              precision    recall  f1-score   support

           0       0.00      0.00      0.00       797
           1       0.80      0.07      0.13       775
           2       0.92      0.12      0.21       795
           3       0.86      0.42      0.56      1110
           4       0.84      0.20      0.32      1260
           5       0.00      0.00      0.00       882
           6       0.83      0.01      0.01       940
           7       0.00      0.00      0.00       473
           8       1.00      0.01      0.01       746
           9       0.36      0.02      0.04       689
          10       0.08      0.94      0.16       670
          11       1.00      0.00      0.01       312
          12       0.75      0.00      0.01       665
          13       0.81      0.07      0.13       314
          14       0.84      0.13      0.22       756
          15       0.36      0.93      0.52      1607

    accuracy                           0.24     12791
   macro avg       0.59   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.4924725186711448, 0.4924725186711448)




CCA coefficients mean non-concern: (0.5035669936613943, 0.5035669936613943)




Linear CKA concern: 0.05423231702395039




Linear CKA non-concern: 0.11365121903482615




Kernel CKA concern: 0.05876678527301401




Kernel CKA non-concern: 0.05837561617074943




Evaluate the pruned model 2




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4531




Precision: 0.4461, Recall: 0.0815, F1-Score: 0.0492




              precision    recall  f1-score   support

           0       0.00      0.00      0.00       797
           1       0.00      0.00      0.00       775
           2       1.00      0.04      0.08       795
           3       0.97      0.03      0.06      1110
           4       0.88      0.04      0.07      1260
           5       0.00      0.00      0.00       882
           6       0.00      0.00      0.00       940
           7       0.00      0.00      0.00       473
           8       1.00      0.00      0.00       746
           9       0.33      0.00      0.00       689
          10       0.75      0.10      0.17       670
          11       0.00      0.00      0.00       312
          12       0.25      0.00      0.00       665
          13       1.00      0.00      0.01       314
          14       0.82      0.09      0.17       756
          15       0.13      1.00      0.23      1607

    accuracy                           0.15     12791
   macro avg       0.45   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.4971510847373771, 0.4971510847373771)




CCA coefficients mean non-concern: (0.5001444677799807, 0.5001444677799807)




Linear CKA concern: 0.08001197663087674




Linear CKA non-concern: 0.10557730829357516




Kernel CKA concern: 0.04046381486752681




Kernel CKA non-concern: 0.05306811617142356




Evaluate the pruned model 3




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4306




Precision: 0.5230, Recall: 0.2437, F1-Score: 0.2165




              precision    recall  f1-score   support

           0       0.60      0.00      0.01       797
           1       0.50      0.00      0.00       775
           2       0.97      0.12      0.22       795
           3       0.84      0.55      0.66      1110
           4       0.51      0.40      0.45      1260
           5       0.80      0.03      0.05       882
           6       0.83      0.04      0.07       940
           7       0.00      0.00      0.00       473
           8       0.86      0.02      0.03       746
           9       0.20      0.00      0.01       689
          10       0.09      0.93      0.16       670
          11       0.00      0.00      0.00       312
          12       0.65      0.30      0.41       665
          13       0.65      0.26      0.37       314
          14       0.38      0.41      0.39       756
          15       0.50      0.84      0.62      1607

    accuracy                           0.30     12791
   macro avg       0.52   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5041298126922544, 0.5041298126922544)




CCA coefficients mean non-concern: (0.499442530241953, 0.499442530241953)




Linear CKA concern: 0.1340313640256479




Linear CKA non-concern: 0.11060518670149982




Kernel CKA concern: 0.08067980732346691




Kernel CKA non-concern: 0.04517638027061818




Evaluate the pruned model 4




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4483




Precision: 0.6154, Recall: 0.1877, F1-Score: 0.1569




              precision    recall  f1-score   support

           0       0.70      0.01      0.02       797
           1       0.94      0.02      0.04       775
           2       1.00      0.00      0.01       795
           3       0.88      0.03      0.05      1110
           4       0.86      0.32      0.46      1260
           5       0.80      0.00      0.01       882
           6       0.90      0.05      0.09       940
           7       0.00      0.00      0.00       473
           8       1.00      0.00      0.00       746
           9       0.46      0.07      0.12       689
          10       0.07      0.97      0.13       670
          11       0.00      0.00      0.00       312
          12       0.75      0.06      0.12       665
          13       0.45      0.43      0.44       314
          14       0.50      0.42      0.46       756
          15       0.52      0.62      0.57      1607

    accuracy                           0.21     12791
   macro avg       0.62   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5008733040717932, 0.5008733040717932)




CCA coefficients mean non-concern: (0.5037993380392521, 0.5037993380392521)




Linear CKA concern: 0.04744980044769999




Linear CKA non-concern: 0.08834283767596839




Kernel CKA concern: 0.0342389553405869




Kernel CKA non-concern: 0.051994299055434026




Evaluate the pruned model 5




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.5233




Precision: 0.5327, Recall: 0.0735, F1-Score: 0.0348




              precision    recall  f1-score   support

           0       0.67      0.00      0.00       797
           1       0.80      0.01      0.01       775
           2       1.00      0.01      0.03       795
           3       1.00      0.03      0.05      1110
           4       0.69      0.09      0.16      1260
           5       0.75      0.01      0.02       882
           6       0.00      0.00      0.00       940
           7       0.00      0.00      0.00       473
           8       1.00      0.00      0.00       746
           9       0.26      0.01      0.02       689
          10       0.73      0.01      0.02       670
          11       0.00      0.00      0.00       312
          12       0.50      0.00      0.00       665
          13       0.00      0.00      0.00       314
          14       1.00      0.00      0.01       756
          15       0.13      1.00      0.23      1607

    accuracy                           0.14     12791
   macro avg       0.53   




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




Exception ignored in: 

<function _MultiProcessingDataLoaderIter.__del__ at 0x7e17cd7261f0>




Traceback (most recent call last):


  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1477, in __del__


    

self._shutdown_workers()




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1460, in _shutdown_workers


    

if w.is_alive():




  File "/home/jieungkim/anaconda3/envs/DecomposeTransformer/lib/python3.8/multiprocessing/process.py", line 160, in is_alive


    

assert self._parent_pid == os.getpid(), 'can only test a child process'




AssertionError

: 

can only test a child process




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.496259153200641, 0.496259153200641)




CCA coefficients mean non-concern: (0.4977055434743795, 0.4977055434743795)




Linear CKA concern: 0.05052839289274509




Linear CKA non-concern: 0.12534376179568227




Kernel CKA concern: 0.0346092349309519




Kernel CKA non-concern: 0.06423837623016865




Evaluate the pruned model 6




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.3855




Precision: 0.4645, Recall: 0.1887, F1-Score: 0.1742




              precision    recall  f1-score   support

           0       0.00      0.00      0.00       797
           1       0.00      0.00      0.00       775
           2       0.71      0.31      0.43       795
           3       0.45      0.79      0.58      1110
           4       0.68      0.16      0.26      1260
           5       0.00      0.00      0.00       882
           6       0.81      0.34      0.48       940
           7       0.00      0.00      0.00       473
           8       1.00      0.00      0.00       746
           9       0.44      0.02      0.04       689
          10       0.83      0.04      0.08       670
          11       0.00      0.00      0.00       312
          12       0.81      0.03      0.06       665
          13       0.67      0.15      0.25       314
          14       0.86      0.19      0.32       756
          15       0.16      0.97      0.28      1607

    accuracy                           0.27     12791
   macro avg       0.46   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.49347877927263134, 0.49347877927263134)




CCA coefficients mean non-concern: (0.500900215130649, 0.500900215130649)




Linear CKA concern: 0.09729710978765299




Linear CKA non-concern: 0.10913101674547901




Kernel CKA concern: 0.07644919143273078




Kernel CKA non-concern: 0.05415365460285517




Evaluate the pruned model 7




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4964




Precision: 0.3640, Recall: 0.0875, F1-Score: 0.0543




              precision    recall  f1-score   support

           0       0.00      0.00      0.00       797
           1       0.00      0.00      0.00       775
           2       0.00      0.00      0.00       795
           3       0.92      0.32      0.47      1110
           4       0.89      0.02      0.04      1260
           5       0.00      0.00      0.00       882
           6       0.89      0.02      0.04       940
           7       0.00      0.00      0.00       473
           8       1.00      0.01      0.01       746
           9       0.32      0.01      0.02       689
          10       1.00      0.01      0.02       670
          11       0.00      0.00      0.00       312
          12       0.67      0.02      0.04       665
          13       0.00      0.00      0.00       314
          14       0.00      0.00      0.00       756
          15       0.13      1.00      0.23      1607

    accuracy                           0.16     12791
   macro avg       0.36   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.507109046468378, 0.507109046468378)




CCA coefficients mean non-concern: (0.501294359857932, 0.501294359857932)




Linear CKA concern: 0.09809548473151462




Linear CKA non-concern: 0.10016336029978755




Kernel CKA concern: 0.04825841277206293




Kernel CKA non-concern: 0.05788215938301816




Evaluate the pruned model 8




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.3889




Precision: 0.5014, Recall: 0.1493, F1-Score: 0.1282




              precision    recall  f1-score   support

           0       0.84      0.03      0.07       797
           1       0.00      0.00      0.00       775
           2       0.73      0.08      0.15       795
           3       0.58      0.66      0.62      1110
           4       0.45      0.31      0.37      1260
           5       0.00      0.00      0.00       882
           6       0.87      0.05      0.09       940
           7       0.00      0.00      0.00       473
           8       0.82      0.02      0.04       746
           9       0.52      0.09      0.15       689
          10       0.84      0.06      0.11       670
          11       0.00      0.00      0.00       312
          12       0.67      0.02      0.03       665
          13       0.70      0.02      0.04       314
          14       0.85      0.07      0.12       756
          15       0.15      0.98      0.27      1607

    accuracy                           0.24     12791
   macro avg       0.50   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5084070024646143, 0.5084070024646143)




CCA coefficients mean non-concern: (0.5036865474910605, 0.5036865474910605)




Linear CKA concern: 0.05798949132712282




Linear CKA non-concern: 0.10444264496429236




Kernel CKA concern: 0.030280660280681285




Kernel CKA non-concern: 0.051451487703207545




Evaluate the pruned model 9




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.3766




Precision: 0.5509, Recall: 0.2497, F1-Score: 0.2037




              precision    recall  f1-score   support

           0       0.00      0.00      0.00       797
           1       0.00      0.00      0.00       775
           2       0.95      0.03      0.05       795
           3       0.48      0.80      0.60      1110
           4       0.79      0.48      0.60      1260
           5       0.75      0.01      0.02       882
           6       0.84      0.14      0.24       940
           7       0.27      0.01      0.01       473
           8       1.00      0.02      0.03       746
           9       0.34      0.42      0.38       689
          10       0.19      0.74      0.31       670
          11       0.67      0.01      0.01       312
          12       0.80      0.08      0.15       665
          13       0.62      0.31      0.42       314
          14       0.87      0.03      0.07       756
          15       0.24      0.92      0.38      1607

    accuracy                           0.32     12791
   macro avg       0.55   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5092458243142747, 0.5092458243142747)




CCA coefficients mean non-concern: (0.5029219225952486, 0.5029219225952486)




Linear CKA concern: 0.14075783373257658




Linear CKA non-concern: 0.1006865442126072




Kernel CKA concern: 0.07831477384372042




Kernel CKA non-concern: 0.05646242394870411




Evaluate the pruned model 10




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.3377




Precision: 0.6455, Recall: 0.2270, F1-Score: 0.1753




              precision    recall  f1-score   support

           0       0.82      0.01      0.02       797
           1       0.89      0.01      0.02       775
           2       0.51      0.24      0.33       795
           3       0.93      0.20      0.33      1110
           4       0.20      0.89      0.33      1260
           5       0.80      0.00      0.01       882
           6       1.00      0.00      0.00       940
           7       0.00      0.00      0.00       473
           8       1.00      0.01      0.01       746
           9       0.30      0.51      0.38       689
          10       0.26      0.66      0.37       670
          11       0.83      0.02      0.03       312
          12       0.82      0.06      0.10       665
          13       0.79      0.04      0.07       314
          14       0.80      0.17      0.28       756
          15       0.38      0.83      0.52      1607

    accuracy                           0.30     12791
   macro avg       0.65   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5067094377061084, 0.5067094377061084)




CCA coefficients mean non-concern: (0.502865557995241, 0.502865557995241)




Linear CKA concern: 0.11656431851157158




Linear CKA non-concern: 0.11578853076401499




Kernel CKA concern: 0.06462951840075346




Kernel CKA non-concern: 0.06291768194368474




Evaluate the pruned model 11




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4758




Precision: 0.4177, Recall: 0.1693, F1-Score: 0.1443




              precision    recall  f1-score   support

           0       0.50      0.00      0.00       797
           1       0.00      0.00      0.00       775
           2       0.93      0.19      0.32       795
           3       0.76      0.54      0.63      1110
           4       0.33      0.38      0.35      1260
           5       0.56      0.01      0.01       882
           6       0.00      0.00      0.00       940
           7       0.00      0.00      0.00       473
           8       1.00      0.00      0.01       746
           9       0.35      0.25      0.30       689
          10       0.31      0.29      0.30       670
          11       1.00      0.00      0.01       312
          12       0.00      0.00      0.00       665
          13       0.00      0.00      0.00       314
          14       0.77      0.05      0.09       756
          15       0.17      1.00      0.30      1607

    accuracy                           0.25     12791
   macro avg       0.42   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5021473683436096, 0.5021473683436096)




CCA coefficients mean non-concern: (0.49921649159646964, 0.49921649159646964)




Linear CKA concern: 0.04358137698285093




Linear CKA non-concern: 0.12266070099633304




Kernel CKA concern: 0.047680051224673545




Kernel CKA non-concern: 0.0584284217280743




Evaluate the pruned model 12




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.3656




Precision: 0.4814, Recall: 0.2257, F1-Score: 0.2183




              precision    recall  f1-score   support

           0       0.67      0.00      0.00       797
           1       0.80      0.01      0.01       775
           2       0.76      0.28      0.41       795
           3       0.49      0.79      0.60      1110
           4       0.51      0.30      0.38      1260
           5       0.60      0.00      0.01       882
           6       0.80      0.24      0.37       940
           7       0.00      0.00      0.00       473
           8       0.00      0.00      0.00       746
           9       0.25      0.01      0.01       689
          10       0.32      0.33      0.33       670
          11       0.00      0.00      0.00       312
          12       0.73      0.24      0.37       665
          13       0.74      0.25      0.38       314
          14       0.85      0.20      0.32       756
          15       0.18      0.96      0.31      1607

    accuracy                           0.30     12791
   macro avg       0.48   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5036797552508637, 0.5036797552508637)




CCA coefficients mean non-concern: (0.5024771867284393, 0.5024771867284393)




Linear CKA concern: 0.05794266579646472




Linear CKA non-concern: 0.09405436568828603




Kernel CKA concern: 0.03191690226956986




Kernel CKA non-concern: 0.04979803224837604




Evaluate the pruned model 13




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4708




Precision: 0.4926, Recall: 0.1561, F1-Score: 0.1318




              precision    recall  f1-score   support

           0       0.80      0.01      0.01       797
           1       0.80      0.02      0.04       775
           2       0.77      0.09      0.17       795
           3       0.96      0.18      0.30      1110
           4       0.41      0.52      0.46      1260
           5       1.00      0.00      0.00       882
           6       0.00      0.00      0.00       940
           7       0.67      0.00      0.01       473
           8       0.00      0.00      0.00       746
           9       0.46      0.08      0.14       689
          10       0.21      0.27      0.23       670
          11       0.00      0.00      0.00       312
          12       0.00      0.00      0.00       665
          13       0.76      0.32      0.45       314
          14       0.88      0.01      0.02       756
          15       0.16      0.99      0.28      1607

    accuracy                           0.23     12791
   macro avg       0.49   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5050471424352861, 0.5050471424352861)




CCA coefficients mean non-concern: (0.4994819447061824, 0.4994819447061824)




Linear CKA concern: 0.11257538801150936




Linear CKA non-concern: 0.1145358100267737




Kernel CKA concern: 0.07545639858204312




Kernel CKA non-concern: 0.05556128063560024




Evaluate the pruned model 14




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4746




Precision: 0.5129, Recall: 0.2187, F1-Score: 0.1770




              precision    recall  f1-score   support

           0       1.00      0.00      0.00       797
           1       0.88      0.01      0.02       775
           2       0.97      0.04      0.08       795
           3       0.76      0.51      0.61      1110
           4       0.73      0.27      0.39      1260
           5       0.60      0.00      0.01       882
           6       0.00      0.00      0.00       940
           7       0.50      0.00      0.01       473
           8       0.73      0.04      0.07       746
           9       0.34      0.30      0.32       689
          10       0.11      0.84      0.19       670
          11       0.00      0.00      0.00       312
          12       0.00      0.00      0.00       665
          13       0.48      0.41      0.44       314
          14       0.84      0.16      0.27       756
          15       0.27      0.92      0.42      1607

    accuracy                           0.27     12791
   macro avg       0.51   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.49098738264298863, 0.49098738264298863)




CCA coefficients mean non-concern: (0.5024371254451416, 0.5024371254451416)




Linear CKA concern: 0.045236609613800646




Linear CKA non-concern: 0.11084423508252414




Kernel CKA concern: 0.030163199369694544




Kernel CKA non-concern: 0.059021299637961745




Evaluate the pruned model 15




Evaluating:   0%|          | 0/800 [00:00<?, ?it/s]

Loss: 2.4961




Precision: 0.4240, Recall: 0.1209, F1-Score: 0.0682




              precision    recall  f1-score   support

           0       0.93      0.02      0.03       797
           1       0.00      0.00      0.00       775
           2       0.00      0.00      0.00       795
           3       0.00      0.00      0.00      1110
           4       0.89      0.04      0.07      1260
           5       0.67      0.00      0.00       882
           6       0.00      0.00      0.00       940
           7       0.09      0.49      0.16       473
           8       0.39      0.09      0.14       746
           9       0.37      0.05      0.08       689
          10       0.78      0.02      0.04       670
          11       0.78      0.02      0.04       312
          12       0.50      0.00      0.01       665
          13       0.23      0.21      0.22       314
          14       1.00      0.00      0.00       756
          15       0.17      1.00      0.28      1607

    accuracy                           0.16     12791
   macro avg       0.42   




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.49828740646156666, 0.49828740646156666)




CCA coefficients mean non-concern: (0.4957884974196339, 0.4957884974196339)




Linear CKA concern: 0.4353100435794515




Linear CKA non-concern: 0.16694895067158627




Kernel CKA concern: 0.3546911987684




Kernel CKA non-concern: 0.08241152532370716


