In [1]:
import os
import sys

# Make the project-local `utils` package (imported in the next cell) importable.
# The path is relative to the notebook's working directory — presumably the
# repository root four levels up; confirm if the notebook is moved.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
_project_root = "../../../../"
if _project_root not in sys.path:
    sys.path.append(_project_root)

# Disable the Hugging Face `tokenizers` internal parallelism. The data loaders
# below are created with num_workers > 0, and this setting avoids the
# fork-after-parallelism warning emitted by the tokenizers library.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# --- Experiment configuration (everything a reader might tune) ---------------

# Key used for both the model configuration and the dataset loader.
name = "OSDG"

# Prefer the first CUDA GPU, but fall back to the CPU when CUDA is unavailable
# so the notebook does not fail later on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder; rebound by the tuple returned from load_model() further down.
checkpoint = None

# Data-loading settings forwarded to load_data().
batch_size = 32
num_workers = 48

# Sample count forwarded to SamplingDataset() and similar().
num_samples = 16

# Target sparsity passed to prune_wanda() as `sparsity_ratio`.
wanda_ratio = 0.4

# Seed forwarded to the data loading, sampling and similarity helpers.
seed = 44

# Layer filters passed to prune_wanda().
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Remember when the run began and echo it in a human-readable form.
script_start_time = datetime.now()
print("Script started at: " + script_start_time.strftime("%Y-%m-%d %H:%M:%S"))

Script started at: 2024-08-19 14:41:59


In [5]:
# Create the ModelConfig from the experiment `name` and the target `device`.
model_config = ModelConfig(name, device)
# Number of target classes, read from the configuration dictionary
# (16 for this dataset according to the recorded output).
num_labels = model_config.config["num_labels"]
# load_model() returns three values. Note that this rebinds `checkpoint`,
# which the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
# Obtain the train / validation / test dataloaders for the configured dataset.
# Batch size, worker count and seed all come from the configuration cell.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
# Samples drawn from the training data; handed to prune_wanda() below.
all_samples = SamplingDataset(
    train_dataloader,
    # NOTE(review): 200 is outside range(num_labels), so with the flag below
    # set to False this presumably samples without targeting one label — TODO confirm.
    200,
    num_samples,
    num_labels,
    False,
    4,  # NOTE(review): undocumented positional constant — TODO confirm its meaning
    device=device,
    resample=False,
    seed=seed,
)

In [8]:
# Baseline evaluation of the unpruned model. Currently disabled; the output of
# an earlier run is preserved as comments in the next cell for reference.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
# Evaluate the original model
# Evaluating: 100%|██████████| 200/200 [03:16<00:00,  1.02it/s]
# Loss: 0.9485
# Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793
#               precision    recall  f1-score   support

#            0       0.77      0.66      0.71       797
#            1       0.84      0.72      0.78       775
#            2       0.88      0.87      0.88       795
#            3       0.87      0.83      0.85      1110
#            4       0.86      0.80      0.83      1260
#            5       0.88      0.69      0.77       882
#            6       0.85      0.80      0.83       940
#            7       0.49      0.61      0.54       473
#            8       0.66      0.85      0.74       746
#            9       0.62      0.73      0.67       689
#           10       0.75      0.79      0.77       670
#           11       0.62      0.81      0.70       312
#           12       0.73      0.81      0.77       665
#           13       0.83      0.85      0.84       314
#           14       0.85      0.78      0.81       756
#           15       0.97      0.98      0.97      1607

#     accuracy                           0.80     12791
#    macro avg       0.78      0.79      0.78     12791
# weighted avg       0.81      0.80      0.80     12791

In [10]:
# Work on a deep copy so the original `model` stays available for the
# model-vs-module similarity comparison later in the notebook.
module = copy.deepcopy(model)

# Apply Wanda pruning to the copy, using the sampled data and the sparsity
# ratio / layer filters from the configuration cell.
prune_wanda(
    module,
    model_config,
    all_samples,
    sparsity_ratio=wanda_ratio,
    include_layers=include_layers,
    exclude_layers=exclude_layers,
)

# Evaluate the pruned copy on the test dataloader.
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"wanda_{name}_{wanda_ratio}p.pt")

Evaluate the pruned model




Evaluating:   0%|          | 0/200 [00:00<?, ?it/s]

Evaluating:   0%|          | 1/200 [00:00<01:25,  2.32it/s]

Evaluating:   1%|          | 2/200 [00:00<01:24,  2.35it/s]

Evaluating:   2%|?둞         | 3/200 [00:01<01:24,  2.34it/s]

Evaluating:   2%|?둞         | 4/200 [00:01<01:23,  2.34it/s]

Evaluating:   2%|?둝         | 5/200 [00:02<01:23,  2.33it/s]

Evaluating:   3%|?둝         | 6/200 [00:02<01:23,  2.34it/s]

Evaluating:   4%|?둝         | 7/200 [00:02<01:22,  2.33it/s]

Evaluating:   4%|?둜         | 8/200 [00:03<01:22,  2.33it/s]

Evaluating:   4%|?둜         | 9/200 [00:03<01:21,  2.33it/s]

Evaluating:   5%|?둛         | 10/200 [00:04<01:21,  2.33it/s]

Evaluating:   6%|?둛         | 11/200 [00:04<01:21,  2.33it/s]

Evaluating:   6%|?둛         | 12/200 [00:05<01:20,  2.33it/s]

Evaluating:   6%|?둚         | 13/200 [00:05<01:20,  2.32it/s]

Evaluating:   7%|?둚         | 14/200 [00:06<01:20,  2.32it/s]

Evaluating:   8%|?둙         | 15/200 [00:06<01:19,  2.32it/s]

Evaluating:   8%|?둙         | 16/200 [00:06<01:19,  2.32it/s]

Evaluating:   8%|?둙         | 17/200 [00:07<01:18,  2.32it/s]

Evaluating:   9%|?둘         | 18/200 [00:07<01:18,  2.32it/s]

Evaluating:  10%|?둘         | 19/200 [00:08<01:18,  2.31it/s]

Evaluating:  10%|?둗         | 20/200 [00:08<01:17,  2.31it/s]

Evaluating:  10%|?둗         | 21/200 [00:09<01:17,  2.31it/s]

Evaluating:  11%|?둗         | 22/200 [00:09<01:17,  2.31it/s]

Evaluating:  12%|?둗?둞        | 23/200 [00:09<01:16,  2.31it/s]

Evaluating:  12%|?둗?둞        | 24/200 [00:10<01:16,  2.30it/s]

Evaluating:  12%|?둗?둝        | 25/200 [00:10<01:16,  2.29it/s]

Evaluating:  13%|?둗?둝        | 26/200 [00:11<01:17,  2.26it/s]

Evaluating:  14%|?둗?둝        | 27/200 [00:11<01:18,  2.19it/s]

Evaluating:  14%|?둗?둜        | 28/200 [00:12<01:20,  2.14it/s]

Evaluating:  14%|?둗?둜        | 29/200 [00:12<01:23,  2.05it/s]

Evaluating:  15%|?둗?둛        | 30/200 [00:13<01:25,  1.98it/s]

Evaluating:  16%|?둗?둛        | 31/200 [00:13<01:29,  1.89it/s]

Evaluating:  16%|?둗?둛        | 32/200 [00:14<01:33,  1.80it/s]

Evaluating:  16%|?둗?둚        | 33/200 [00:15<01:35,  1.75it/s]

Evaluating:  17%|?둗?둚        | 34/200 [00:15<01:41,  1.64it/s]

Evaluating:  18%|?둗?둙        | 35/200 [00:16<01:46,  1.54it/s]

Evaluating:  18%|?둗?둙        | 36/200 [00:17<01:50,  1.48it/s]

Evaluating:  18%|?둗?둙        | 37/200 [00:18<01:53,  1.44it/s]

Evaluating:  19%|?둗?둘        | 38/200 [00:18<01:54,  1.41it/s]

Evaluating:  20%|?둗?둘        | 39/200 [00:19<01:55,  1.39it/s]

Evaluating:  20%|?둗?둗        | 40/200 [00:20<01:55,  1.38it/s]

Evaluating:  20%|?둗?둗        | 41/200 [00:20<01:55,  1.38it/s]

Evaluating:  21%|?둗?둗        | 42/200 [00:21<01:57,  1.35it/s]

Evaluating:  22%|?둗?둗?둞       | 43/200 [00:22<01:58,  1.32it/s]

Evaluating:  22%|?둗?둗?둞       | 44/200 [00:23<01:59,  1.30it/s]

Evaluating:  22%|?둗?둗?둝       | 45/200 [00:24<02:00,  1.29it/s]

Evaluating:  23%|?둗?둗?둝       | 46/200 [00:24<02:00,  1.28it/s]

Evaluating:  24%|?둗?둗?둝       | 47/200 [00:25<02:00,  1.27it/s]

Evaluating:  24%|?둗?둗?둜       | 48/200 [00:26<02:00,  1.27it/s]

Evaluating:  24%|?둗?둗?둜       | 49/200 [00:27<01:59,  1.26it/s]

Evaluating:  25%|?둗?둗?둛       | 50/200 [00:28<02:02,  1.23it/s]

Evaluating:  26%|?둗?둗?둛       | 51/200 [00:29<02:05,  1.18it/s]

Evaluating:  26%|?둗?둗?둛       | 52/200 [00:30<02:08,  1.16it/s]

Evaluating:  26%|?둗?둗?둚       | 53/200 [00:30<02:09,  1.14it/s]

Evaluating:  27%|?둗?둗?둚       | 54/200 [00:31<02:09,  1.13it/s]

Evaluating:  28%|?둗?둗?둙       | 55/200 [00:32<02:09,  1.12it/s]

Evaluating:  28%|?둗?둗?둙       | 56/200 [00:33<02:09,  1.11it/s]

Evaluating:  28%|?둗?둗?둙       | 57/200 [00:34<02:09,  1.11it/s]

Evaluating:  29%|?둗?둗?둘       | 58/200 [00:35<02:08,  1.11it/s]

Evaluating:  30%|?둗?둗?둘       | 59/200 [00:36<02:07,  1.11it/s]

Evaluating:  30%|?둗?둗?둗       | 60/200 [00:37<02:06,  1.11it/s]

Evaluating:  30%|?둗?둗?둗       | 61/200 [00:38<02:05,  1.10it/s]

Evaluating:  31%|?둗?둗?둗       | 62/200 [00:39<02:05,  1.10it/s]

Evaluating:  32%|?둗?둗?둗?둞      | 63/200 [00:40<02:04,  1.10it/s]

Evaluating:  32%|?둗?둗?둗?둞      | 64/200 [00:40<02:03,  1.10it/s]

Evaluating:  32%|?둗?둗?둗?둝      | 65/200 [00:41<02:02,  1.10it/s]

Evaluating:  33%|?둗?둗?둗?둝      | 66/200 [00:42<02:01,  1.10it/s]

Evaluating:  34%|?둗?둗?둗?둝      | 67/200 [00:43<02:00,  1.10it/s]

Evaluating:  34%|?둗?둗?둗?둜      | 68/200 [00:44<01:59,  1.10it/s]

Evaluating:  34%|?둗?둗?둗?둜      | 69/200 [00:45<01:58,  1.10it/s]

Evaluating:  35%|?둗?둗?둗?둛      | 70/200 [00:46<01:57,  1.10it/s]

Evaluating:  36%|?둗?둗?둗?둛      | 71/200 [00:47<01:56,  1.11it/s]

Evaluating:  36%|?둗?둗?둗?둛      | 72/200 [00:48<01:55,  1.10it/s]

Evaluating:  36%|?둗?둗?둗?둚      | 73/200 [00:49<01:55,  1.10it/s]

Evaluating:  37%|?둗?둗?둗?둚      | 74/200 [00:49<01:54,  1.10it/s]

Evaluating:  38%|?둗?둗?둗?둙      | 75/200 [00:50<01:53,  1.10it/s]

Evaluating:  38%|?둗?둗?둗?둙      | 76/200 [00:51<01:52,  1.10it/s]

Evaluating:  38%|?둗?둗?둗?둙      | 77/200 [00:52<01:51,  1.10it/s]

Evaluating:  39%|?둗?둗?둗?둘      | 78/200 [00:53<01:50,  1.10it/s]

Evaluating:  40%|?둗?둗?둗?둘      | 79/200 [00:54<01:49,  1.11it/s]

Evaluating:  40%|?둗?둗?둗?둗      | 80/200 [00:55<01:48,  1.11it/s]

Evaluating:  40%|?둗?둗?둗?둗      | 81/200 [00:56<01:47,  1.11it/s]

Evaluating:  41%|?둗?둗?둗?둗      | 82/200 [00:57<01:46,  1.11it/s]

Evaluating:  42%|?둗?둗?둗?둗?둞     | 83/200 [00:58<01:45,  1.11it/s]

Evaluating:  42%|?둗?둗?둗?둗?둞     | 84/200 [00:59<01:44,  1.11it/s]

Evaluating:  42%|?둗?둗?둗?둗?둝     | 85/200 [00:59<01:43,  1.11it/s]

Evaluating:  43%|?둗?둗?둗?둗?둝     | 86/200 [01:00<01:42,  1.11it/s]

Evaluating:  44%|?둗?둗?둗?둗?둝     | 87/200 [01:01<01:41,  1.11it/s]

Evaluating:  44%|?둗?둗?둗?둗?둜     | 88/200 [01:02<01:40,  1.11it/s]

Evaluating:  44%|?둗?둗?둗?둗?둜     | 89/200 [01:03<01:39,  1.11it/s]

Evaluating:  45%|?둗?둗?둗?둗?둛     | 90/200 [01:04<01:39,  1.11it/s]

Evaluating:  46%|?둗?둗?둗?둗?둛     | 91/200 [01:05<01:38,  1.11it/s]

Evaluating:  46%|?둗?둗?둗?둗?둛     | 92/200 [01:06<01:37,  1.11it/s]

Evaluating:  46%|?둗?둗?둗?둗?둚     | 93/200 [01:07<01:36,  1.11it/s]

Evaluating:  47%|?둗?둗?둗?둗?둚     | 94/200 [01:08<01:35,  1.11it/s]

Evaluating:  48%|?둗?둗?둗?둗?둙     | 95/200 [01:08<01:34,  1.11it/s]

Evaluating:  48%|?둗?둗?둗?둗?둙     | 96/200 [01:09<01:33,  1.11it/s]

Evaluating:  48%|?둗?둗?둗?둗?둙     | 97/200 [01:10<01:32,  1.11it/s]

Evaluating:  49%|?둗?둗?둗?둗?둘     | 98/200 [01:11<01:31,  1.11it/s]

Evaluating:  50%|?둗?둗?둗?둗?둘     | 99/200 [01:12<01:30,  1.11it/s]

Evaluating:  50%|?둗?둗?둗?둗?둗     | 100/200 [01:13<01:30,  1.11it/s]

Evaluating:  50%|?둗?둗?둗?둗?둗     | 101/200 [01:14<01:29,  1.11it/s]

Evaluating:  51%|?둗?둗?둗?둗?둗     | 102/200 [01:15<01:28,  1.11it/s]

Evaluating:  52%|?둗?둗?둗?둗?둗?둞    | 103/200 [01:16<01:27,  1.11it/s]

Evaluating:  52%|?둗?둗?둗?둗?둗?둞    | 104/200 [01:17<01:26,  1.11it/s]

Evaluating:  52%|?둗?둗?둗?둗?둗?둝    | 105/200 [01:17<01:25,  1.11it/s]

Evaluating:  53%|?둗?둗?둗?둗?둗?둝    | 106/200 [01:18<01:24,  1.11it/s]

Evaluating:  54%|?둗?둗?둗?둗?둗?둝    | 107/200 [01:19<01:23,  1.11it/s]

Evaluating:  54%|?둗?둗?둗?둗?둗?둜    | 108/200 [01:20<01:22,  1.11it/s]

Evaluating:  55%|?둗?둗?둗?둗?둗?둜    | 109/200 [01:21<01:22,  1.11it/s]

Evaluating:  55%|?둗?둗?둗?둗?둗?둛    | 110/200 [01:22<01:21,  1.11it/s]

Evaluating:  56%|?둗?둗?둗?둗?둗?둛    | 111/200 [01:23<01:20,  1.11it/s]

Evaluating:  56%|?둗?둗?둗?둗?둗?둛    | 112/200 [01:24<01:19,  1.11it/s]

Evaluating:  56%|?둗?둗?둗?둗?둗?둚    | 113/200 [01:25<01:18,  1.11it/s]

Evaluating:  57%|?둗?둗?둗?둗?둗?둚    | 114/200 [01:26<01:17,  1.11it/s]

Evaluating:  57%|?둗?둗?둗?둗?둗?둙    | 115/200 [01:26<01:16,  1.11it/s]

Evaluating:  58%|?둗?둗?둗?둗?둗?둙    | 116/200 [01:27<01:15,  1.11it/s]

Evaluating:  58%|?둗?둗?둗?둗?둗?둙    | 117/200 [01:28<01:14,  1.11it/s]

Evaluating:  59%|?둗?둗?둗?둗?둗?둘    | 118/200 [01:29<01:13,  1.11it/s]

Evaluating:  60%|?둗?둗?둗?둗?둗?둘    | 119/200 [01:30<01:12,  1.11it/s]

Evaluating:  60%|?둗?둗?둗?둗?둗?둗    | 120/200 [01:31<01:12,  1.11it/s]

Evaluating:  60%|?둗?둗?둗?둗?둗?둗    | 121/200 [01:32<01:11,  1.11it/s]

Evaluating:  61%|?둗?둗?둗?둗?둗?둗    | 122/200 [01:33<01:10,  1.11it/s]

Evaluating:  62%|?둗?둗?둗?둗?둗?둗?둞   | 123/200 [01:34<01:09,  1.11it/s]

Evaluating:  62%|?둗?둗?둗?둗?둗?둗?둞   | 124/200 [01:35<01:08,  1.11it/s]

Evaluating:  62%|?둗?둗?둗?둗?둗?둗?둝   | 125/200 [01:35<01:07,  1.11it/s]

Evaluating:  63%|?둗?둗?둗?둗?둗?둗?둝   | 126/200 [01:36<01:06,  1.11it/s]

Evaluating:  64%|?둗?둗?둗?둗?둗?둗?둝   | 127/200 [01:37<01:05,  1.11it/s]

Evaluating:  64%|?둗?둗?둗?둗?둗?둗?둜   | 128/200 [01:38<01:04,  1.11it/s]

Evaluating:  64%|?둗?둗?둗?둗?둗?둗?둜   | 129/200 [01:39<01:04,  1.11it/s]

Evaluating:  65%|?둗?둗?둗?둗?둗?둗?둛   | 130/200 [01:40<01:03,  1.11it/s]

Evaluating:  66%|?둗?둗?둗?둗?둗?둗?둛   | 131/200 [01:41<01:02,  1.11it/s]

Evaluating:  66%|?둗?둗?둗?둗?둗?둗?둛   | 132/200 [01:42<01:01,  1.11it/s]

Evaluating:  66%|?둗?둗?둗?둗?둗?둗?둚   | 133/200 [01:43<01:00,  1.11it/s]

Evaluating:  67%|?둗?둗?둗?둗?둗?둗?둚   | 134/200 [01:44<00:59,  1.11it/s]

Evaluating:  68%|?둗?둗?둗?둗?둗?둗?둙   | 135/200 [01:44<00:58,  1.11it/s]

Evaluating:  68%|?둗?둗?둗?둗?둗?둗?둙   | 136/200 [01:45<00:57,  1.11it/s]

Evaluating:  68%|?둗?둗?둗?둗?둗?둗?둙   | 137/200 [01:46<00:56,  1.11it/s]

Evaluating:  69%|?둗?둗?둗?둗?둗?둗?둘   | 138/200 [01:47<00:55,  1.11it/s]

Evaluating:  70%|?둗?둗?둗?둗?둗?둗?둘   | 139/200 [01:48<00:54,  1.11it/s]

Evaluating:  70%|?둗?둗?둗?둗?둗?둗?둗   | 140/200 [01:49<00:53,  1.11it/s]

Evaluating:  70%|?둗?둗?둗?둗?둗?둗?둗   | 141/200 [01:50<00:53,  1.11it/s]

Evaluating:  71%|?둗?둗?둗?둗?둗?둗?둗   | 142/200 [01:51<00:52,  1.11it/s]

Evaluating:  72%|?둗?둗?둗?둗?둗?둗?둗?둞  | 143/200 [01:52<00:51,  1.11it/s]

Evaluating:  72%|?둗?둗?둗?둗?둗?둗?둗?둞  | 144/200 [01:53<00:50,  1.11it/s]

Evaluating:  72%|?둗?둗?둗?둗?둗?둗?둗?둝  | 145/200 [01:53<00:49,  1.11it/s]

Evaluating:  73%|?둗?둗?둗?둗?둗?둗?둗?둝  | 146/200 [01:54<00:48,  1.11it/s]

Evaluating:  74%|?둗?둗?둗?둗?둗?둗?둗?둝  | 147/200 [01:55<00:47,  1.11it/s]

Evaluating:  74%|?둗?둗?둗?둗?둗?둗?둗?둜  | 148/200 [01:56<00:46,  1.11it/s]

Evaluating:  74%|?둗?둗?둗?둗?둗?둗?둗?둜  | 149/200 [01:57<00:45,  1.11it/s]

Evaluating:  75%|?둗?둗?둗?둗?둗?둗?둗?둛  | 150/200 [01:58<00:45,  1.11it/s]

Evaluating:  76%|?둗?둗?둗?둗?둗?둗?둗?둛  | 151/200 [01:59<00:44,  1.11it/s]

Evaluating:  76%|?둗?둗?둗?둗?둗?둗?둗?둛  | 152/200 [02:00<00:43,  1.11it/s]

Evaluating:  76%|?둗?둗?둗?둗?둗?둗?둗?둚  | 153/200 [02:01<00:42,  1.11it/s]

Evaluating:  77%|?둗?둗?둗?둗?둗?둗?둗?둚  | 154/200 [02:02<00:41,  1.11it/s]

Evaluating:  78%|?둗?둗?둗?둗?둗?둗?둗?둙  | 155/200 [02:02<00:40,  1.11it/s]

Evaluating:  78%|?둗?둗?둗?둗?둗?둗?둗?둙  | 156/200 [02:03<00:39,  1.11it/s]

Evaluating:  78%|?둗?둗?둗?둗?둗?둗?둗?둙  | 157/200 [02:04<00:38,  1.11it/s]

Evaluating:  79%|?둗?둗?둗?둗?둗?둗?둗?둘  | 158/200 [02:05<00:37,  1.11it/s]

Evaluating:  80%|?둗?둗?둗?둗?둗?둗?둗?둘  | 159/200 [02:06<00:36,  1.11it/s]

Evaluating:  80%|?둗?둗?둗?둗?둗?둗?둗?둗  | 160/200 [02:07<00:36,  1.11it/s]

Evaluating:  80%|?둗?둗?둗?둗?둗?둗?둗?둗  | 161/200 [02:08<00:35,  1.11it/s]

Evaluating:  81%|?둗?둗?둗?둗?둗?둗?둗?둗  | 162/200 [02:09<00:34,  1.11it/s]

Evaluating:  82%|?둗?둗?둗?둗?둗?둗?둗?둗?둞 | 163/200 [02:10<00:33,  1.11it/s]

Evaluating:  82%|?둗?둗?둗?둗?둗?둗?둗?둗?둞 | 164/200 [02:11<00:32,  1.11it/s]

Evaluating:  82%|?둗?둗?둗?둗?둗?둗?둗?둗?둝 | 165/200 [02:11<00:31,  1.11it/s]

Evaluating:  83%|?둗?둗?둗?둗?둗?둗?둗?둗?둝 | 166/200 [02:12<00:30,  1.11it/s]

Evaluating:  84%|?둗?둗?둗?둗?둗?둗?둗?둗?둝 | 167/200 [02:13<00:29,  1.11it/s]

Evaluating:  84%|?둗?둗?둗?둗?둗?둗?둗?둗?둜 | 168/200 [02:14<00:28,  1.11it/s]

Evaluating:  84%|?둗?둗?둗?둗?둗?둗?둗?둗?둜 | 169/200 [02:15<00:27,  1.11it/s]

Evaluating:  85%|?둗?둗?둗?둗?둗?둗?둗?둗?둛 | 170/200 [02:16<00:27,  1.11it/s]

Evaluating:  86%|?둗?둗?둗?둗?둗?둗?둗?둗?둛 | 171/200 [02:17<00:26,  1.11it/s]

Evaluating:  86%|?둗?둗?둗?둗?둗?둗?둗?둗?둛 | 172/200 [02:18<00:25,  1.11it/s]

Evaluating:  86%|?둗?둗?둗?둗?둗?둗?둗?둗?둚 | 173/200 [02:19<00:24,  1.11it/s]

Evaluating:  87%|?둗?둗?둗?둗?둗?둗?둗?둗?둚 | 174/200 [02:20<00:23,  1.11it/s]

Evaluating:  88%|?둗?둗?둗?둗?둗?둗?둗?둗?둙 | 175/200 [02:21<00:22,  1.11it/s]

Evaluating:  88%|?둗?둗?둗?둗?둗?둗?둗?둗?둙 | 176/200 [02:21<00:21,  1.11it/s]

Evaluating:  88%|?둗?둗?둗?둗?둗?둗?둗?둗?둙 | 177/200 [02:22<00:20,  1.11it/s]

Evaluating:  89%|?둗?둗?둗?둗?둗?둗?둗?둗?둘 | 178/200 [02:23<00:19,  1.11it/s]

Evaluating:  90%|?둗?둗?둗?둗?둗?둗?둗?둗?둘 | 179/200 [02:24<00:18,  1.11it/s]

Evaluating:  90%|?둗?둗?둗?둗?둗?둗?둗?둗?둗 | 180/200 [02:25<00:18,  1.11it/s]

Evaluating:  90%|?둗?둗?둗?둗?둗?둗?둗?둗?둗 | 181/200 [02:26<00:17,  1.11it/s]

Evaluating:  91%|?둗?둗?둗?둗?둗?둗?둗?둗?둗 | 182/200 [02:27<00:16,  1.11it/s]

Evaluating:  92%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둞| 183/200 [02:28<00:15,  1.11it/s]

Evaluating:  92%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둞| 184/200 [02:29<00:14,  1.11it/s]

Evaluating:  92%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둝| 185/200 [02:30<00:13,  1.11it/s]

Evaluating:  93%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둝| 186/200 [02:30<00:12,  1.11it/s]

Evaluating:  94%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둝| 187/200 [02:31<00:11,  1.11it/s]

Evaluating:  94%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둜| 188/200 [02:32<00:10,  1.11it/s]

Evaluating:  94%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둜| 189/200 [02:33<00:09,  1.11it/s]

Evaluating:  95%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둛| 190/200 [02:34<00:08,  1.11it/s]

Evaluating:  96%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둛| 191/200 [02:35<00:08,  1.11it/s]

Evaluating:  96%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둛| 192/200 [02:36<00:07,  1.11it/s]

Evaluating:  96%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둚| 193/200 [02:37<00:06,  1.12it/s]

Evaluating:  97%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둚| 194/200 [02:38<00:05,  1.12it/s]

Evaluating:  98%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둙| 195/200 [02:38<00:04,  1.12it/s]

Evaluating:  98%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둙| 196/200 [02:39<00:03,  1.11it/s]

Evaluating:  98%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둙| 197/200 [02:40<00:02,  1.11it/s]

Evaluating:  99%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둘| 198/200 [02:41<00:01,  1.11it/s]

Evaluating: 100%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둘| 199/200 [02:42<00:00,  1.11it/s]

Evaluating: 100%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둗| 200/200 [02:43<00:00,  1.16it/s]

Evaluating: 100%|?둗?둗?둗?둗?둗?둗?둗?둗?둗?둗| 200/200 [02:43<00:00,  1.22it/s]




Loss: 0.9428




Precision: 0.7766, Recall: 0.7810, F1-Score: 0.7748




              precision    recall  f1-score   support

           0       0.76      0.66      0.70       797
           1       0.85      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.82      0.85      1110
           4       0.85      0.80      0.82      1260
           5       0.89      0.68      0.77       882
           6       0.84      0.80      0.82       940
           7       0.47      0.59      0.52       473
           8       0.66      0.85      0.74       746
           9       0.59      0.73      0.65       689
          10       0.76      0.79      0.77       670
          11       0.66      0.80      0.72       312
          12       0.69      0.81      0.75       665
          13       0.83      0.85      0.84       314
          14       0.86      0.77      0.81       756
          15       0.97      0.97      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




In [11]:
# Compare the original `model` with the pruned `module` once per class label.
# According to this cell's recorded output, each call reports mean CCA
# coefficients plus linear and kernel CKA for "concern" and "non-concern" data.
#
# No SamplingDataset is built in this loop: `similar` receives the validation
# dataloader together with the label index, sample count and seed, so
# pre-sampling positive/negative sets here would be discarded work on every
# iteration.
for concern in range(num_labels):
    print(f"--{concern}--")
    similar(model, module, valid_dataloader, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8226286097933246, 0.8226286097933246)




CCA coefficients mean non-concern: (0.8367720065178548, 0.8367720065178548)




Linear CKA concern: 0.9783394716334197




Linear CKA non-concern: 0.9608236129224744




Kernel CKA concern: 0.9730719476944569




Kernel CKA non-concern: 0.9622991798107942




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8294714514241879, 0.8294714514241879)




CCA coefficients mean non-concern: (0.8356634587091096, 0.8356634587091096)




Linear CKA concern: 0.9671119140115609




Linear CKA non-concern: 0.9590982300082449




Kernel CKA concern: 0.9626117924317332




Kernel CKA non-concern: 0.9609736980512601




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8274993916000501, 0.8274993916000501)




CCA coefficients mean non-concern: (0.8351424955890031, 0.8351424955890031)




Linear CKA concern: 0.9778631578498195




Linear CKA non-concern: 0.959155960873212




Kernel CKA concern: 0.9721124740736766




Kernel CKA non-concern: 0.9612715408607847




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8291847107362506, 0.8291847107362506)




CCA coefficients mean non-concern: (0.8343154347770338, 0.8343154347770338)




Linear CKA concern: 0.9637792403554236




Linear CKA non-concern: 0.9604282015030198




Kernel CKA concern: 0.9619452024300883




Kernel CKA non-concern: 0.9626001628838934




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8376996704755532, 0.8376996704755532)




CCA coefficients mean non-concern: (0.8347075618420705, 0.8347075618420705)




Linear CKA concern: 0.9746932887179454




Linear CKA non-concern: 0.9595125785456102




Kernel CKA concern: 0.9697857789005053




Kernel CKA non-concern: 0.9622451664586336




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.826181847514096, 0.826181847514096)




CCA coefficients mean non-concern: (0.835922423082001, 0.835922423082001)




Linear CKA concern: 0.9542848152693344




Linear CKA non-concern: 0.9618804859420383




Kernel CKA concern: 0.946144046177586




Kernel CKA non-concern: 0.9638875692980095




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.824526812772124, 0.824526812772124)




CCA coefficients mean non-concern: (0.8359432546828469, 0.8359432546828469)




Linear CKA concern: 0.9587361292024391




Linear CKA non-concern: 0.9602947933818162




Kernel CKA concern: 0.9522901212851893




Kernel CKA non-concern: 0.9630679257103381




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8346600753026564, 0.8346600753026564)




CCA coefficients mean non-concern: (0.8342875801129447, 0.8342875801129447)




Linear CKA concern: 0.9668550338495071




Linear CKA non-concern: 0.9615493396295551




Kernel CKA concern: 0.9638206002172789




Kernel CKA non-concern: 0.9635719209390149




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8341932121766125, 0.8341932121766125)




CCA coefficients mean non-concern: (0.8357776842502765, 0.8357776842502765)




Linear CKA concern: 0.9693368752609363




Linear CKA non-concern: 0.960324393435019




Kernel CKA concern: 0.964012581364805




Kernel CKA non-concern: 0.9627795462055394




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8328644261241628, 0.8328644261241628)




CCA coefficients mean non-concern: (0.8344623879199119, 0.8344623879199119)




Linear CKA concern: 0.9753984873740874




Linear CKA non-concern: 0.959525874297616




Kernel CKA concern: 0.9688668952445518




Kernel CKA non-concern: 0.9620840768161482




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8304124724374354, 0.8304124724374354)




CCA coefficients mean non-concern: (0.8364196863577948, 0.8364196863577948)




Linear CKA concern: 0.9722205695134822




Linear CKA non-concern: 0.9601746597825337




Kernel CKA concern: 0.9667031962579608




Kernel CKA non-concern: 0.9628825621684443




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8311163706151093, 0.8311163706151093)




CCA coefficients mean non-concern: (0.8349856726076968, 0.8349856726076968)




Linear CKA concern: 0.9679692223919872




Linear CKA non-concern: 0.9610183895547093




Kernel CKA concern: 0.9618900108750613




Kernel CKA non-concern: 0.9632737158340113




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8320831872659384, 0.8320831872659384)




CCA coefficients mean non-concern: (0.8354464694050086, 0.8354464694050086)




Linear CKA concern: 0.9715399586710445




Linear CKA non-concern: 0.9612208421252896




Kernel CKA concern: 0.9673567543478372




Kernel CKA non-concern: 0.9636911914563707




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.836552878975126, 0.836552878975126)




CCA coefficients mean non-concern: (0.8352360367932641, 0.8352360367932641)




Linear CKA concern: 0.9725934791033818




Linear CKA non-concern: 0.9603372371515883




Kernel CKA concern: 0.9659144769641006




Kernel CKA non-concern: 0.9622102066091679




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8366596188629685, 0.8366596188629685)




CCA coefficients mean non-concern: (0.8347246781779508, 0.8347246781779508)




Linear CKA concern: 0.971674352496807




Linear CKA non-concern: 0.9600910694616267




Kernel CKA concern: 0.9661713567541994




Kernel CKA non-concern: 0.9630235633692065




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8259826494118568, 0.8259826494118568)




CCA coefficients mean non-concern: (0.834847198642709, 0.834847198642709)




Linear CKA concern: 0.9509139554109393




Linear CKA non-concern: 0.9619178119137102




Kernel CKA concern: 0.9451170593695163




Kernel CKA non-concern: 0.9640993752421094




In [12]:
# Report how sparse the pruned module is. Per the recorded output, the result
# is a tuple: an overall sparsity ratio followed by a dict mapping parameter
# names to their individual sparsity. Left as the cell's last expression so
# the notebook displays it.
get_sparsity(module)

(0.39653757670359674,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3997395833333333,
  'bert.e