In [1]:
import copy
import os.path
import sys

In [2]:
# Make the project root (the parent of the notebook's working directory) importable so
# the `utils.*` imports in the next cell resolve. Guarded so re-running the cell does
# not keep appending duplicate entries to sys.path.
pwd = os.getcwd()
project_root = os.path.dirname(pwd)
if project_root not in sys.path:
    sys.path.append(project_root)

In [3]:
from utils.model_utils.evaluate import evaluate_model
from utils.model_utils.load_model import *
from utils.model_utils.model_config import ModelConfig
from utils.dataset_utils.load_dataset import load_data
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from transformers import AutoConfig
from utils.model_utils.save_module import save_module
from datetime import datetime
from utils.decompose_utils.sampling import sampling_class
import torch

In [4]:
# Model / dataset presets. Select one with DATASET instead of (un)commenting blocks.
# Each entry is (model_name, model_type, data, num_labels).
MODEL_PRESETS = {
    "OSDG": ("sadickam/sdg-classification-bert", "pretrained", "OSDG", 16),
    "IMDb": ("textattack/bert-base-uncased-imdb", "pretrained", "IMDb", 2),
    "Yahoo": (
        "fabriceyhc/bert-base-uncased-yahoo_answers_topics",
        "pretrained",
        "Yahoo",
        10,
    ),
}
DATASET = "IMDb"

model_name, model_type, data, num_labels = MODEL_PRESETS[DATASET]

In [None]:
# Pick the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the transformer config for the selected model and wrap everything the
# project helpers need into a ModelConfig.
checkpoint_name = None
config = AutoConfig.from_pretrained(model_name, num_labels=num_labels)
model_config = ModelConfig(
    _model_name=model_name,
    _model_type=model_type,
    _data=data,
    _transformer_config=config,
    _checkpoint_name=checkpoint_name,
    _device=device,
)

# Load the classifier (inference mode) and the train/valid/test dataloaders.
model, tokenizer, checkpoint = load_classification_model(model_config, train_mode=False)
train_dataloader, valid_dataloader, test_dataloader = load_data(
    model_config, test_size=0.3, batch_size=32
)

print(f"Start Time:{datetime.now():%H:%M:%S}")

Directory /home/Minwoo/LESN/Decompose/DecomposeBERT/Models/Configs/pretrained/textattack/bert-base-uncased-imdb exists.
Loading the model.


  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


In [None]:
# Decompose module `i`:
#   1. sparsify a copy of the model (WeightRemoverBert) on samples from all classes,
#   2. run concern identification (CI) on samples of class `i`,
#   3. run tangling identification (TI) on samples of the other classes,
# recording after every batch how many weights remain non-zero in each tracked layer.
i = 0
print("#Module " + str(i) + " in progress....")
num_samples = 64

positive_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)
# NOTE(review): 200 is not a valid class index for any configured dataset; combined
# with the `False` flag this presumably samples from every class — confirm in
# sampling_class.
all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

# `module1` is the copy that gets propagated through; the propagators themselves are
# constructed from `model`.
module1 = copy.deepcopy(model)
w = WeightRemoverBert(model, p=0.9)
ci1 = ConcernIdentificationBert(model, p=0.4)
ti1 = TanglingIdentification(model, p=0.5)


def count_nonzero_weights(module):
    """Count non-zero weights in the tracked linear layers of ``module``.

    Returns:
        ``(ff1_counts, ff2_counts, pooler_count, classifier_count)``; the first two are
        per-encoder-layer lists for ``intermediate.dense`` and ``output.dense``.
    """
    layers = [module.bert.encoder.layer[num] for num in range(config.num_hidden_layers)]
    ff1_counts = [torch.sum(layer.intermediate.dense.weight != 0).item() for layer in layers]
    ff2_counts = [torch.sum(layer.output.dense.weight != 0).item() for layer in layers]
    pooler_count = torch.sum(module.bert.pooler.dense.weight != 0).item()
    classifier_count = torch.sum(module.classifier.weight != 0).item()
    return ff1_counts, ff2_counts, pooler_count, classifier_count


def record_nonzero_weights(module):
    """Append ``module``'s current non-zero counts to the ff1/ff2/pooler/classifier histories."""
    ff1_counts, ff2_counts, pooler_count, classifier_count = count_nonzero_weights(module)
    for num in range(config.num_hidden_layers):
        ff1[num].append(ff1_counts[num])
        ff2[num].append(ff2_counts[num])
    pooler.append(pooler_count)
    classifier.append(classifier_count)


def run_pass(propagator, samples, step):
    """Propagate every batch of ``samples`` through ``module1`` using ``propagator``.

    After each batch the non-zero weight counts of ``module1`` are recorded and the
    running ``step`` counter is incremented and printed.

    Returns:
        The updated step counter.
    """
    for batch in samples:
        # Only the input ids are fed to the propagator; the other fields are unused.
        input_ids, _attn_mask, _, _total_sampled = batch
        with torch.no_grad():
            propagator.propagate(module1, input_ids)
        record_nonzero_weights(module1)
        step += 1
        print(step)
    return step


# Baseline counts come from the original model; each pass appends one entry per batch.
ff1_init, ff2_init, pooler_init, classifier_init = count_nonzero_weights(model)
ff1 = [[count] for count in ff1_init]
ff2 = [[count] for count in ff2_init]
pooler = [pooler_init]
classifier = [classifier_init]
print("origin")
j = 0
print(j)
result = evaluate_model(model, model_config, test_dataloader)

print("Start Positive CI sparse")
j = run_pass(w, all_samples, j)

print("Start Positive CI after sparse")
j = run_pass(ci1, positive_samples, j)

print("Start Negative TI")
j = run_pass(ti1, negative_samples, j)

# NOTE(review): this evaluates `model` again (it was already evaluated as "origin"
# above); use `module1` here if the decomposed copy was intended.
result = evaluate_model(model, model_config, test_dataloader)
Loss: 0.9480
Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793
              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.61      0.54       473
           8       0.66      0.85      0.74       746
           9       0.62      0.73      0.67       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78      0.79      0.78     12791
weighted avg       0.81      0.80      0.80     12791

# Evaluate the modified copy `module1` on the same test split, for comparison with `model`.
result = evaluate_model(module1, model_config, test_dataloader)
Loss: 0.8861
Precision: 0.7749, Recall: 0.7805, F1-Score: 0.7729
              precision    recall  f1-score   support

           0       0.74      0.65      0.70       797
           1       0.86      0.69      0.76       775
           2       0.86      0.88      0.87       795
           3       0.88      0.81      0.84      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.85      0.81      0.83       940
           7       0.48      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.57      0.74      0.64       689
          10       0.78      0.76      0.77       670
          11       0.61      0.81      0.70       312
          12       0.71      0.81      0.76       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.98      0.97      0.97      1607

    accuracy                           0.79     12791
   macro avg       0.77      0.78      0.77     12791
weighted avg       0.81      0.79      0.80     12791


In [None]:
ci1.active_node  # display the concern identifier's `active_node` attribute for inspection

In [None]:
ti1.dead_node  # display the tangling identifier's `dead_node` attribute for inspection

In [None]:
from utils.decompose_utils.concern_modularization import ConcernModularizationBert

In [None]:
# Gather whole test batches (moved to the model's device) until at least 100 examples
# have been collected. Each list holds one tensor per batch. The loop variables
# input_ids / attention_mask / labels remain bound to the last batch afterwards.
collected_input_ids, collected_attention_mask, collected_labels = [], [], []
count = 0

for batch in test_dataloader:
    if count >= 100:
        break

    input_ids, attention_mask, labels = (
        batch[key].to(model_config.device)
        for key in ("input_ids", "attention_mask", "labels")
    )

    collected_input_ids.append(input_ids)
    collected_attention_mask.append(attention_mask)
    collected_labels.append(labels)

    # Track progress in examples, not batches.
    count += input_ids.size(0)

In [None]:
from utils.dataset_utils.load_dataset import convert_dataset_labels_to_binary
# Relabel every split as binary with respect to class `i`.
converted_train_dataloader, converted_valid_dataloader, converted_test_dataloader = (
    convert_dataset_labels_to_binary(loader, i)
    for loader in (train_dataloader, valid_dataloader, test_dataloader)
)

In [None]:
# Same collection as for the original test set, but over the binary-relabelled test
# loader: gather whole batches on the model's device until at least 100 examples are
# collected. input_ids / attention_mask / labels stay bound to the last batch.
conv_collected_input_ids, conv_collected_attention_mask, conv_collected_labels = [], [], []
count = 0

for batch in converted_test_dataloader:
    if count >= 100:
        break

    input_ids, attention_mask, labels = (
        batch[key].to(model_config.device)
        for key in ("input_ids", "attention_mask", "labels")
    )

    conv_collected_input_ids.append(input_ids)
    conv_collected_attention_mask.append(attention_mask)
    conv_collected_labels.append(labels)

    # Track progress in examples, not batches.
    count += input_ids.size(0)

# Next step: turn the module into a binary model (module -> binary_model).

In [None]:
def pppp(module1, ci1, ti1, model_config, concern_idx=0):
    """Channel ``module1`` for one concern and move it to the configured device.

    ``ConcernModularizationBert.channeling`` is called only for its effect (its return
    value is not used), so ``module1`` is modified in place, and the SAME object is
    returned (``nn.Module.to`` returns ``self``). Pass a copy if the input must be kept.

    Args:
        module1: model to modularize; modified in place.
        ci1: concern identifier whose ``active_node`` is forwarded to channeling.
        ti1: tangling identifier whose ``active_node`` is forwarded to channeling.
        model_config: object providing ``device``.
        concern_idx: concern/class index passed to channeling (default 0, as before).

    Returns:
        ``module1``, moved to ``model_config.device``.
    """
    ConcernModularizationBert.channeling(
        module1, ci1.active_node, ti1.active_node, concern_idx, model_config.device
    )
    # No separate BertForSequenceClassification is built here: the former code created
    # one from a binary config but immediately replaced it with `module1`, and then
    # loaded `module1`'s own state dict back into itself — both were no-ops.
    # NOTE(review): if a distinct binary-head model (id2label {0: "negative",
    # 1: "positive"}) is wanted, build it from a 2-label config and load this module's
    # weights into it; that requires the channeled classifier to have 2 outputs — confirm.
    return module1.to(model_config.device)

In [None]:
# `pppp` modifies the module it receives in place and returns that same object, so the
# original `model` is given as a copy: otherwise `model` itself would be channeled and
# re-running earlier cells (e.g. the deepcopy that creates `module1`) would start from
# an already-modified model.
module2 = pppp(module1, ci1, ti1, model_config)
module3 = pppp(copy.deepcopy(model), ci1, ti1, model_config)

In [None]:
def qqqq(module2, conv_collected_input_ids, conv_collected_attention_mask, labels=None):
    """Inspect ``module2`` on one batch and evaluate it on the binary test loader.

    Prints the batch logits, the gold ``labels`` (when given), and the argmax
    predictions. Evaluation uses the notebook-level ``model_config`` and
    ``converted_test_dataloader``.

    Args:
        module2: classifier whose forward output exposes ``.logits``.
        conv_collected_input_ids: input ids of a single batch.
        conv_collected_attention_mask: attention mask for the same batch.
        labels: optional gold labels of that same batch. (The former version printed a
            leftover global ``labels`` from an unrelated batch instead.)

    Returns:
        Whatever ``evaluate_model`` returns for ``module2``.
    """
    # Inference only: no autograd graph is needed for inspection.
    with torch.no_grad():
        logits = module2(conv_collected_input_ids, conv_collected_attention_mask).logits
    result = evaluate_model(module2, model_config, converted_test_dataloader)
    print(logits)
    pred = logits.argmax(dim=1)
    if labels is not None:
        print(labels)
    print(pred)
    return result

In [None]:
def pad(numbers):
    """Print ``numbers`` as a list of zero-padded strings at least two characters wide."""
    # Convert each number to a 2-digit string (wider numbers are kept as-is).
    padded = list(map("{:02}".format, numbers))
    print(padded)

In [None]:
# Print, in order: CI active, CI dead, TI active, TI dead node attributes.
for node_values in (ci1.active_node, ci1.dead_node, ti1.active_node, ti1.dead_node):
    pad(node_values)

In [None]:
# These lists appear to be pasted output of the cell that calls `pad` on, in order:
# ci1.active_node, ci1.dead_node, ti1.active_node, ti1.dead_node — one zero-padded
# entry per class (16 for SDG/OSDG, 2 for IMDb, 10 for Yahoo).
# NOTE(review): kept for reference only; running this cell just displays the last list.
# SDG
['16', '06', '00', '00', '00', '00', '00', '01', '00', '16', '11', '00', '00', '00', '00', '08']
['00', '03', '11', '16', '16', '15', '16', '07', '15', '00', '01', '16', '16', '15', '16', '00']
['15', '00', '00', '00', '00', '00', '00', '01', '00', '15', '03', '00', '00', '00', '00', '11']
['00', '09', '11', '14', '11', '13', '12', '10', '11', '00', '02', '12', '05', '12', '10', '00']

# IMDb
['16', '00']
['00', '16']
['15', '03']
['01', '15']
# Yahoo
['14', '05', '00', '13', '00', '01', '14', '07', '00', '03']
['00', '05', '13', '00', '16', '11', '00', '01', '12', '02']
['04', '00', '00', '15', '00', '00', '14', '00', '00', '07']
['00', '03', '15', '00', '12', '08', '00', '04', '11', '00']

In [None]:
# Inspect both channeled modules on the same (second) converted test batch.
for candidate_module in (module2, module3):
    qqqq(candidate_module, conv_collected_input_ids[1], conv_collected_attention_mask[1])