In [1]:
import copy
import os.path
import sys

# Make the parent of the notebook's working directory importable so the
# `utils.*` packages imported below can be found (assumes they live there).
pwd = os.getcwd()
project_root = os.path.dirname(pwd)
sys.path.append(project_root)

from utils.model_utils.evaluate import evaluate_model
from utils.model_utils.load_model import *
from utils.model_utils.model_config import ModelConfig
from utils.dataset_utils.load_dataset import load_data
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from transformers import AutoConfig
from utils.model_utils.save_module import save_module
from datetime import datetime
from utils.decompose_utils.concern_modularization import ConcernModularizationBert
from utils.decompose_utils.sampling import sampling_class
from utils.dataset_utils.load_dataset import convert_dataset_labels_to_binary, extract_and_convert_dataloader
import torch


In [2]:
# Experiment presets: dataset name -> (Hugging Face model id, model type, number of labels).
# Switching experiments is a one-line change to EXPERIMENT, so the model id and the
# label count cannot drift out of sync with the selected dataset.
EXPERIMENT_PRESETS = {
    "OSDG": ("sadickam/sdg-classification-bert", "pretrained", 16),
    "IMDb": ("textattack/bert-base-uncased-imdb", "pretrained", 2),
    "Yahoo": ("fabriceyhc/bert-base-uncased-yahoo_answers_topics", "pretrained", 10),
}
EXPERIMENT = "OSDG"

data = EXPERIMENT
model_name, model_type, num_labels = EXPERIMENT_PRESETS[EXPERIMENT]

In [3]:
# Compute device: the first CUDA GPU when one is available, otherwise the CPU.
# (Use "cuda:1" instead to target the second GPU.)
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# No checkpoint name is given; NOTE(review): presumably this makes ModelConfig
# use the base weights of `model_name` — confirm in ModelConfig.
checkpoint_name = None

# Transformer config for `model_name`, with the label count overridden to `num_labels`.
config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)

model_config_kwargs = dict(
    _model_name=model_name,
    _model_type=model_type,
    _data=data,
    _transformer_config=config,
    _checkpoint_name=checkpoint_name,
    _device=device,
)
model_config = ModelConfig(**model_config_kwargs)


In [None]:
# Positive/negative samples drawn per class for concern / tangling identification.
num_samples = 64
# Batch interval for intermediate evaluation. Intermediate evaluation is currently
# disabled; the value is kept so it can be re-enabled.
eval_step = 5


def _propagate_samples(identifier, module, samples):
    """Run `identifier.propagate(module, input_ids)` on every sampled batch, without autograd.

    Each batch must unpack into four items; only the first (the input ids) is
    forwarded to the identifier.
    """
    for batch in samples:
        input_ids, _attn_mask, _labels, _total_sampled = batch
        with torch.no_grad():
            identifier.propagate(module, input_ids)


for i in range(num_labels):
    # Reload the model and data for every class so each module starts from the same weights.
    model, tokenizer, checkpoint = load_classification_model(model_config, train_mode=False)

    train_dataloader, valid_dataloader, test_dataloader = load_data(
        model_config, batch_size=32, test_size=0.3
    )

    print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
    print("#Module " + str(i) + " in progress....")

    positive_samples = sampling_class(
        train_dataloader, i, num_samples, num_labels, True, 4, device=device
    )
    negative_samples = sampling_class(
        train_dataloader, i, num_samples, num_labels, False, 4, device=device
    )
    # NOTE(review): 200 exceeds num_labels, so with the `False` flag this presumably
    # samples across all classes — confirm against sampling_class.
    all_samples = sampling_class(
        train_dataloader, 200, 20, num_labels, False, 4, device=device
    )

    # Baseline: the unmodified model on the full multi-class test set.
    print("origin")
    evaluate_model(model, model_config, test_dataloader)

    # `module` is the copy that gets pruned. NOTE(review): the identifiers are built on
    # `model` but propagate through `module` — confirm this split is intended.
    module = copy.deepcopy(model)
    wr = WeightRemoverBert(model, p=0.9)
    ci = ConcernIdentificationBert(model, p=0.4)
    ti = TanglingIdentification(model, p=0.5)

    print("Start Positive CI sparse")
    _propagate_samples(wr, module, all_samples)

    print("Start Positive CI after sparse")
    _propagate_samples(ci, module, positive_samples)

    print("Start Negative TI")
    _propagate_samples(ti, module, negative_samples)

    ConcernModularizationBert.channeling(module, ci.active_node, ti.dead_node, i, model_config.device)
    binary_module = ConcernModularizationBert.convert2binary(model_config, module)
    # save_module(binary_module, model_config.module_dir, model_config.model_name)

    for m in range(num_labels):
        if m == i:
            continue
        # Show which (module class, other class) pair is being evaluated.
        print(f"[{i}, {m}]")
        # TODO(review): `m` is not passed to the conversion, so every iteration re-evaluates
        # the same class-`i` one-vs-rest task; pass `m` if a pairwise i-vs-m split was intended.
        converted_test_dataloader = convert_dataset_labels_to_binary(test_dataloader, i, True)
        result = evaluate_model(module, model_config, converted_test_dataloader)


Directory /home/Minwoo/LESN/Decompose/DecomposeBERT/Models/Configs/pretrained/sadickam/sdg-classification-bert exists.
Loading the model.
Start Time:21:33:05
#Module 0 in progress....
origin


Evaluating: 100%|██████████| 400/400 [01:35<00:00,  4.19it/s]


Loss: 0.9480
Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793
              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.61      0.54       473
           8       0.66      0.85      0.74       746
           9       0.62      0.73      0.67       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.22it/s]


Loss: 0.4936
Precision: 0.7851, Recall: 0.7359, F1-Score: 0.7240
              precision    recall  f1-score   support

           0       0.90      0.53      0.67       797
           1       0.67      0.94      0.78       797

    accuracy                           0.74      1594
   macro avg       0.79      0.74      0.72      1594
weighted avg       0.79      0.74      0.72      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.19it/s]


Loss: 0.5039
Precision: 0.7751, Recall: 0.7202, F1-Score: 0.7055
              precision    recall  f1-score   support

           0       0.90      0.50      0.64       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.78      0.72      0.71      1594
weighted avg       0.78      0.72      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.19it/s]


Loss: 0.5068
Precision: 0.7807, Recall: 0.7290, F1-Score: 0.7159
              precision    recall  f1-score   support

           0       0.90      0.51      0.65       797
           1       0.66      0.94      0.78       797

    accuracy                           0.73      1594
   macro avg       0.78      0.73      0.72      1594
weighted avg       0.78      0.73      0.72      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.17it/s]


Loss: 0.5137
Precision: 0.7699, Recall: 0.7120, F1-Score: 0.6957
              precision    recall  f1-score   support

           0       0.89      0.48      0.63       797
           1       0.64      0.94      0.77       797

    accuracy                           0.71      1594
   macro avg       0.77      0.71      0.70      1594
weighted avg       0.77      0.71      0.70      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.18it/s]


Loss: 0.5066
Precision: 0.7703, Recall: 0.7127, F1-Score: 0.6965
              precision    recall  f1-score   support

           0       0.90      0.48      0.63       797
           1       0.65      0.94      0.77       797

    accuracy                           0.71      1594
   macro avg       0.77      0.71      0.70      1594
weighted avg       0.77      0.71      0.70      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.17it/s]


Loss: 0.5013
Precision: 0.7759, Recall: 0.7215, F1-Score: 0.7070
              precision    recall  f1-score   support

           0       0.90      0.50      0.64       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.78      0.72      0.71      1594
weighted avg       0.78      0.72      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:11<00:00,  4.17it/s]


Loss: 0.5169
Precision: 0.7711, Recall: 0.7139, F1-Score: 0.6980
              precision    recall  f1-score   support

           0       0.90      0.48      0.63       797
           1       0.65      0.94      0.77       797

    accuracy                           0.71      1594
   macro avg       0.77      0.71      0.70      1594
weighted avg       0.77      0.71      0.70      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.5203
Precision: 0.7727, Recall: 0.7164, F1-Score: 0.7010
              precision    recall  f1-score   support

           0       0.90      0.49      0.63       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.77      0.72      0.70      1594
weighted avg       0.77      0.72      0.70      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.4937
Precision: 0.7815, Recall: 0.7302, F1-Score: 0.7174
              precision    recall  f1-score   support

           0       0.90      0.52      0.66       797
           1       0.66      0.94      0.78       797

    accuracy                           0.73      1594
   macro avg       0.78      0.73      0.72      1594
weighted avg       0.78      0.73      0.72      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.4988
Precision: 0.7835, Recall: 0.7334, F1-Score: 0.7211
              precision    recall  f1-score   support

           0       0.90      0.52      0.66       797
           1       0.66      0.94      0.78       797

    accuracy                           0.73      1594
   macro avg       0.78      0.73      0.72      1594
weighted avg       0.78      0.73      0.72      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.5065
Precision: 0.7775, Recall: 0.7240, F1-Score: 0.7100
              precision    recall  f1-score   support

           0       0.90      0.50      0.65       797
           1       0.66      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.78      0.72      0.71      1594
weighted avg       0.78      0.72      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.15it/s]


Loss: 0.5096
Precision: 0.7723, Recall: 0.7158, F1-Score: 0.7003
              precision    recall  f1-score   support

           0       0.90      0.49      0.63       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.77      0.72      0.70      1594
weighted avg       0.77      0.72      0.70      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.5010
Precision: 0.7795, Recall: 0.7271, F1-Score: 0.7137
              precision    recall  f1-score   support

           0       0.90      0.51      0.65       797
           1       0.66      0.94      0.78       797

    accuracy                           0.73      1594
   macro avg       0.78      0.73      0.71      1594
weighted avg       0.78      0.73      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.4970
Precision: 0.7771, Recall: 0.7233, F1-Score: 0.7092
              precision    recall  f1-score   support

           0       0.90      0.50      0.65       797
           1       0.66      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.78      0.72      0.71      1594
weighted avg       0.78      0.72      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.4997
Precision: 0.7763, Recall: 0.7221, F1-Score: 0.7078
              precision    recall  f1-score   support

           0       0.90      0.50      0.64       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.78      0.72      0.71      1594
weighted avg       0.78      0.72      0.71      1594



Evaluating: 100%|██████████| 50/50 [00:12<00:00,  4.16it/s]


Loss: 0.5113
Precision: 0.7747, Recall: 0.7196, F1-Score: 0.7048
              precision    recall  f1-score   support

           0       0.90      0.50      0.64       797
           1       0.65      0.94      0.77       797

    accuracy                           0.72      1594
   macro avg       0.77      0.72      0.70      1594
weighted avg       0.77      0.72      0.70      1594

Directory /home/Minwoo/LESN/Decompose/DecomposeBERT/Models/Configs/pretrained/sadickam/sdg-classification-bert exists.
Loading the model.
Start Time:21:38:22
#Module 1 in progress....
origin


Evaluating: 100%|██████████| 400/400 [01:36<00:00,  4.14it/s]


Loss: 0.9480
Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793
              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.61      0.54       473
           8       0.66      0.85      0.74       746
           9       0.62      0.73      0.67       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 49/49 [00:11<00:00,  4.22it/s]


Loss: 0.4344
Precision: 0.8206, Recall: 0.8013, F1-Score: 0.7983
              precision    recall  f1-score   support

           0       0.90      0.68      0.77       775
           1       0.74      0.92      0.82       775

    accuracy                           0.80      1550
   macro avg       0.82      0.80      0.80      1550
weighted avg       0.82      0.80      0.80      1550



Evaluating: 100%|██████████| 49/49 [00:11<00:00,  4.21it/s]


Loss: 0.4304
Precision: 0.8239, Recall: 0.8058, F1-Score: 0.8031
              precision    recall  f1-score   support

           0       0.90      0.69      0.78       775
           1       0.75      0.92      0.83       775

    accuracy                           0.81      1550
   macro avg       0.82      0.81      0.80      1550
weighted avg       0.82      0.81      0.80      1550



Evaluating: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]


Loss: 0.4460
Precision: 0.8201, Recall: 0.8006, F1-Score: 0.7976
              precision    recall  f1-score   support

           0       0.90      0.68      0.77       775
           1       0.74      0.92      0.82       775

    accuracy                           0.80      1550
   macro avg       0.82      0.80      0.80      1550
weighted avg       0.82      0.80      0.80      1550



Evaluating: 100%|██████████| 49/49 [00:11<00:00,  4.20it/s]


Loss: 0.4450
Precision: 0.8220, Recall: 0.8032, F1-Score: 0.8003
              precision    recall  f1-score   support

           0       0.90      0.68      0.78       775
           1       0.74      0.92      0.82       775

    accuracy                           0.80      1550
   macro avg       0.82      0.80      0.80      1550
weighted avg       0.82      0.80      0.80      1550



Evaluating:  29%|██▊       | 14/49 [00:03<00:08,  4.16it/s]