In [1]:
import copy
import os.path
import sys

# Make the parent of the current working directory importable so that the
# `utils.*` imports below can be resolved when the notebook runs from a subfolder.
pwd = os.getcwd()
project_root = os.path.dirname(pwd)
# Guard against duplicates: re-running this cell must not keep appending the
# same entry to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.model_utils.load_model import load_model
from utils.helper import ModelConfig
from utils.dataset_utils.load_dataset import load_data
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from transformers import AutoConfig
from utils.model_utils.save_module import save_module
from datetime import datetime
from utils.decompose_utils.concern_modularization import ConcernModularizationBert
from utils.decompose_utils.sampling import sampling_class
from utils.dataset_utils.load_dataset import convert_dataset_labels_to_binary, extract_and_convert_dataloader
import torch
from utils.prune_utils.prune import prune_magnitude


In [2]:
from utils.dataset_utils.load_dataset import *

In [3]:
from datasets import load_dataset

In [4]:
# Key of the dataset/model configuration used throughout the notebook.
name = "OSDG"
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU instead of failing at first use.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
# Build the run configuration from the selected name and target device.
model_config = ModelConfig(name, device)
# Number of labels read from the configuration mapping; it bounds the
# per-label loop in the next cell.
num_labels = model_config.config["num_labels"]

In [6]:
def _propagate_samples(concern, target_module, samples):
    """Run ``concern.propagate`` on every sampled batch with gradients disabled.

    Args:
        concern: Object exposing ``propagate(module, input_ids)``; here the
            weight remover, the concern identifier or the tangling identifier.
        target_module: Model instance handed to ``propagate`` for each batch.
        samples: Iterable of batches. Only the first element of each batch
            (the input ids) is consumed.
    """
    # NOTE(review): the remaining batch fields (presumably the attention mask
    # and labels) are not forwarded to propagate - confirm it does not need them.
    for input_ids, *_ in samples:
        with torch.no_grad():
            concern.propagate(target_module, input_ids)


for i in range(num_labels):
    # Reload the model and the dataloaders for the module being built.
    model, tokenizer, checkpoint = load_model(model_config)

    train_dataloader, valid_dataloader, test_dataloader = load_data(
        name, batch_size=32
    )

    print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
    print("#Module " + str(i) + " in progress....")
    num_samples = 64

    # Sample sets for label i: the boolean flag presumably selects the target
    # class (True) versus the remaining classes (False) - verify in sampling_class.
    positive_samples = sampling_class(
        train_dataloader, i, num_samples, num_labels, True, 4, device=device
    )
    negative_samples = sampling_class(
        train_dataloader, i, num_samples, num_labels, False, 4, device=device
    )

    # NOTE(review): 200 lies outside range(num_labels); together with the False
    # flag this presumably draws from every class - confirm against sampling_class.
    all_samples = sampling_class(
        train_dataloader, 200, 20, num_labels, False, 4, device=device
    )

    print("origin")
    # Baseline evaluation of the unmodified model is currently disabled:
    # evaluate_model(model, model_config, test_dataloader)

    # All propagation below operates on a deep copy, leaving `model` as loaded.
    module = copy.deepcopy(model)
    # NOTE(review): the three helpers are constructed from `model` but are applied
    # to the copy `module` - confirm this pairing is intended.
    wr = WeightRemoverBert(model, p=0.9)
    ci = ConcernIdentificationBert(model, p=0.4)
    ti = TanglingIdentification(model, p=0.6)

    # Stage 1: weight removal on the class-agnostic samples, then evaluate.
    print("Start Positive CI sparse")
    _propagate_samples(wr, module, all_samples)
    result = evaluate_model(module, model_config, test_dataloader)

    # Stage 2: concern identification on the positive samples.
    print("Start Positive CI after sparse")
    _propagate_samples(ci, module, positive_samples)

    # Stage 3: tangling identification on the negative samples, then evaluate.
    print("Start Negative TI")
    _propagate_samples(ti, module, negative_samples)
    result = evaluate_model(module, model_config, test_dataloader)

    # Modularization, saving and pairwise evaluation are currently disabled:
    # ConcernModularizationBert.channeling(module, ci.active_node, ti.dead_node, i, model_config.device)
    # binary_module = ConcernModularizationBert.convert2binary(model_config, module)
    # save_module(binary_module, model_config.module_dir, model_config.model_name)

    # for m in range(num_labels):
    #     if i == m:
    #         continue
    #     print(f"[{i}, {m}]")
    #     converted_test_dataloader = extract_and_convert_dataloader(test_dataloader, i, m)
    #     result = evaluate_model(module, model_config, converted_test_dataloader)

    # Only the first module (label 0) is processed in this experiment.
    break

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.
{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded
Start Time:08:05:46
#Module 0 in progress....
origin
Start Positive CI sparse


Evaluating: 100%|██████████| 400/400 [03:54<00:00,  1.71it/s]


Loss: 0.9466
Precision: 0.7789, Recall: 0.7857, F1-Score: 0.7781
              precision    recall  f1-score   support

           0       0.76      0.67      0.71       797
           1       0.84      0.71      0.77       775
           2       0.87      0.88      0.88       795
           3       0.87      0.82      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.86      0.80      0.82       940
           7       0.48      0.61      0.54       473
           8       0.65      0.86      0.74       746
           9       0.61      0.72      0.66       689
          10       0.76      0.78      0.77       670
          11       0.62      0.81      0.71       312
          12       0.72      0.81      0.76       665
          13       0.83      0.86      0.84       314
          14       0.86      0.78      0.82       756
          15       0.97      0.98      0.97      1607

    accuracy   

Evaluating: 100%|██████████| 400/400 [04:20<00:00,  1.54it/s]


Loss: 0.9261
Precision: 0.7732, Recall: 0.7811, F1-Score: 0.7725
              precision    recall  f1-score   support

           0       0.73      0.67      0.70       797
           1       0.85      0.70      0.77       775
           2       0.85      0.88      0.86       795
           3       0.88      0.80      0.84      1110
           4       0.86      0.80      0.83      1260
           5       0.89      0.68      0.77       882
           6       0.86      0.80      0.83       940
           7       0.49      0.59      0.53       473
           8       0.66      0.85      0.74       746
           9       0.58      0.73      0.65       689
          10       0.78      0.77      0.77       670
          11       0.61      0.81      0.70       312
          12       0.70      0.81      0.75       665
          13       0.81      0.86      0.83       314
          14       0.85      0.77      0.81       756
          15       0.98      0.97      0.97      1607

    accuracy   

In [7]:
# Inspect the sparsity of the module produced by the loop above. The recorded
# output is a tuple: one overall value followed by a per-parameter mapping.
get_sparsity(module)

(0.20453685801915392,
 {'bert.embeddings.word_embeddings.weight': 0.0,
  'bert.embeddings.position_embeddings.weight': 0.0,
  'bert.embeddings.token_type_embeddings.weight': 0.0,
  'bert.embeddings.LayerNorm.weight': 0.0,
  'bert.embeddings.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.attention.self.query.weight': 0.0,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.0,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.0,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.0,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.weight': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.39404212103949654,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.l