In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
    convert_dataset_labels_to_binary,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.model_utils.save_module import save_module
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from utils.decompose_utils.concern_modularization import ConcernModularizationBert
from utils.decompose_utils.sampling import sampling_class
from utils.prune_utils.prune import prune_magnitude, find_layers, LayerWrapper, prune_concern_identification

In [3]:
name = "OSDG"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = None
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

In [4]:
i = 0
model, tokenizer, checkpoint = load_model(model_config)

train_dataloader, valid_dataloader, test_dataloader = load_data(name, batch_size=64)

color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
color_print("#Module " + str(i) + " in progress....")
num_samples = 64

positive_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)

all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

print("origin")
# evaluate_model(model, model_config, test_dataloader)

module = copy.deepcopy(model)
wr = WeightRemoverBert(model, p=0.9)
ci = ConcernIdentificationBert(model, p=0.4)
ti = TanglingIdentification(model, p=0.5)

print("Start Magnitude pruning")
prune_magnitude(module, sparsity_ratio=0.1)
print(get_sparsity(module)[0])
print("Start Positive CI after sparse")

prune_concern_identification(
    model,
    module,
    positive_samples,
    include_layers=["attention", "intermediate", "output"],
    sparsity_ratio=0.5,
    p=1,
)

print(get_sparsity(module))
result = evaluate_model(module, model_config, test_dataloader)
torch.cuda.empty_cache()

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.
{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded
Start Time:11:16:00
#Module 0 in progress....
origin
Start Magnitude pruning
0.09990180388583593
Start Positive CI after sparse
(0.49670961962191856, {'bert.encoder.layer.0.attention.self.query.weight': 0.5000033908420138, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5001390245225694, 'bert.encoder.layer.0.attention.self.value.bia

Evaluating: 100%|██████████| 200/200 [02:25<00:00,  1.38it/s]


Loss: 1.7187
Precision: 0.6786, Recall: 0.4710, F1-Score: 0.4769
              precision    recall  f1-score   support

           0       0.45      0.51      0.48       797
           1       0.79      0.03      0.06       775
           2       0.89      0.54      0.68       795
           3       0.90      0.47      0.62      1110
           4       0.37      0.87      0.52      1260
           5       0.92      0.31      0.46       882
           6       0.90      0.35      0.50       940
           7       0.48      0.06      0.11       473
           8       0.48      0.59      0.53       746
           9       0.24      0.80      0.37       689
          10       0.69      0.48      0.57       670
          11       0.76      0.21      0.32       312
          12       0.48      0.65      0.55       665
          13       0.84      0.35      0.49       314
          14       0.89      0.37      0.53       756
          15       0.77      0.94      0.85      1607

    accuracy   

In [5]:
for idx, batch in enumerate(negative_samples):
    input_ids, attn_mask, _, total_sampled = batch
    with torch.no_grad():
        ti.propagate(module, input_ids)
    # if idx % eval_step:
    #     evaluate_model(module, model_config, test_dataloader)
result = evaluate_model(module, model_config, test_dataloader)

Evaluating: 100%|██████████| 200/200 [02:35<00:00,  1.29it/s]


Loss: 1.7527
Precision: 0.6874, Recall: 0.4578, F1-Score: 0.4623
              precision    recall  f1-score   support

           0       0.46      0.50      0.48       797
           1       0.86      0.02      0.05       775
           2       0.91      0.51      0.66       795
           3       0.91      0.46      0.61      1110
           4       0.36      0.87      0.51      1260
           5       0.92      0.29      0.44       882
           6       0.91      0.32      0.47       940
           7       0.56      0.06      0.11       473
           8       0.47      0.58      0.52       746
           9       0.23      0.81      0.36       689
          10       0.65      0.50      0.57       670
          11       0.77      0.19      0.30       312
          12       0.50      0.62      0.55       665
          13       0.86      0.32      0.46       314
          14       0.89      0.33      0.48       756
          15       0.74      0.95      0.83      1607

    accuracy   

In [6]:
print(get_sparsity(module))

(0.4939691192632194, {'bert.encoder.layer.0.attention.self.query.weight': 0.5000033908420138, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5001390245225694, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.5, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.5, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.5, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.5, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.5, 'bert.encoder.layer.1.attention.self.key.bias': 0.0, 'bert.encoder.layer.1.attention.self.value.weight': 0.5, 'bert.encoder.layer