In [1]:
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.getcwd())))

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
    convert_dataset_labels_to_binary,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.model_utils.save_module import save_module
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from utils.decompose_utils.concern_modularization import ConcernModularizationBert
from utils.decompose_utils.sampling import sampling_class
from utils.prune_utils.prune import prune_magnitude, find_layers, LayerWrapper

In [3]:
def prune_concern_identification(
    model,
    module,
    dataloader,
    sparsity_ratio=0.4,
    p=1,
    include_layers=None,
    exclude_layers=None,
):
    ref_layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )
    target_layers = find_layers(
        module, include_layers=include_layers, exclude_layers=exclude_layers
    )
    device = next(model.parameters()).device

    wrappers = {}
    ref_handle_list = []
    target_handle_list = []

    def get_hook(wrapper):
        def hook(module, input, output):
            wrapper.update(input, output)

        return hook

    for (ref_name, ref_layer), (target_name, target_layer) in zip(
        ref_layers.items(), target_layers.items()
    ):
        ref_wrapper = LayerWrapper(ref_name, ref_layer)
        target_wrapper = LayerWrapper(target_name, target_layer)

        wrappers[ref_name] = {"ref": ref_wrapper, "target": target_wrapper}

        ref_handle = ref_layer.register_forward_hook(get_hook(ref_wrapper))
        target_handle = target_layer.register_forward_hook(get_hook(target_wrapper))

        ref_handle_list.append(ref_handle)
        target_handle_list.append(target_handle)

    for batch in dataloader:
        input_ids, attn_mask, _, _ = batch
        input_ids = input_ids.to(device)
        attn_mask = attn_mask.to(device)
        with torch.no_grad():
            model(input_ids, attention_mask=attn_mask)
            module(input_ids, attention_mask=attn_mask)

    for handle in ref_handle_list + target_handle_list:
        handle.remove()

    for name, wrapper_pair in wrappers.items():
        wrapper_pair["ref"].update_batch()
        wrapper_pair["target"].update_batch()

        ref_outputs = wrapper_pair["ref"].outputs
        target_outputs = wrapper_pair["target"].outputs

        current_weight = wrapper_pair["target"].layer.weight.data
        original_weight = wrapper_pair["ref"].layer.weight.data

        # output_loss = (target_outputs - ref_outputs) ** 2
        output_loss = target_outputs - ref_outputs  # (batch_size, seq_dim, output_dim)

        # Normalize batch size
        # if p == 'mean':
        #     output_loss = torch.mean(output_loss, dim=0)
        # elif p == 1:
        #     output_loss = torch.norm(output_loss, p=1, dim=0)
        # elif p == 2:
        #     output_loss = torch.norm(output_loss, p=2, dim=0)
        # elif p == "inf":
        #     output_loss = torch.norm(output_loss, p=float('inf'), dim=0)
        # else:
        #     raise ValueError("Unsupported norm type")

        # Normalize output_loss
        # output_loss = (output_loss - output_loss.mean(dim=0)) /output_loss.std(dim=0)

        # weight_score = torch.mean(output_loss, dim=0).reshape(-1, 1).to(device)
        output_loss_flat = output_loss.to(device).view(
            -1, output_loss.size(-1)
        )  # (batch_size * seq_dim, output_dim)

        inputs = wrapper_pair["target"].inputs  # (batch_size, seq_dim, input_dim)
        inputs_flat = inputs.view(
            -1, inputs.size(-1)
        )  # (batch_size * seq_dim, input_dim)
        inverse_inputs = torch.linalg.pinv(inputs_flat).to(
            device
        )  # (input_dim, batch_size * seq_dim)
        pseudo_weight_matrix = torch.matmul(
            inverse_inputs, output_loss_flat
        )  # (input_dim, output_dim)

        importance_score = torch.abs(
            pseudo_weight_matrix.T * current_weight
        )  # (output_dim, input_dim)

        W_mask = torch.zeros_like(importance_score) == 1
        sort_res = torch.sort(importance_score, dim=-1, stable=True)
        indices = sort_res[1][:, : int(importance_score.shape[1] * sparsity_ratio)]
        W_mask.scatter_(1, indices, True)
        current_weight[W_mask] = 0

        # importance_score = importance_score.reshape(-1)
        # W_mask = (torch.zeros_like(importance_score) == 1)
        # sort_res = torch.sort(importance_score, stable=True)
        # indices = sort_res[1][:int(importance_score.shape[0] * sparsity_ratio)]
        # W_mask.scatter_(0, indices, True)
        # W_mask = W_mask.view_as(current_weight)
        # current_weight[W_mask] = 0

In [4]:
name = "OSDG"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = None
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

In [5]:
i = 2
model, tokenizer, checkpoint = load_model(model_config)

train_dataloader, valid_dataloader, test_dataloader = load_data(name, batch_size=64)

color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
color_print("#Module " + str(i) + " in progress....")
num_samples = 64

positive_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)

all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

print("origin")
# evaluate_model(model, model_config, test_dataloader)

module = copy.deepcopy(model)
wr = WeightRemoverBert(model, p=0.9)
ci = ConcernIdentificationBert(model, p=0.4)
ti = TanglingIdentification(model, p=0.5)

print("Start Magnitude pruning")
prune_magnitude(module, sparsity_ratio=0.1)
print(get_sparsity(module)[0])
print("Start Positive CI after sparse")

prune_concern_identification(
    model,
    module,
    positive_samples,
    include_layers=["attention", "intermediate", "output"],
    sparsity_ratio=0.5,
    p=1,
)

print(get_sparsity(module))
result = evaluate_model(module, model_config, test_dataloader)
torch.cuda.empty_cache()

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.
{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded
Start Time:01:18:59
#Module 2 in progress....
origin
Start Magnitude pruning
0.09990180388583593
Start Positive CI after sparse
(0.49669810368769646, {'bert.encoder.layer.0.attention.self.query.weight': 0.5, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0

Evaluating: 100%|██████████| 200/200 [04:30<00:00,  1.35s/it]


Loss: 1.7504
Precision: 0.6866, Recall: 0.4419, F1-Score: 0.4567
              precision    recall  f1-score   support

           0       0.67      0.24      0.35       797
           1       0.82      0.10      0.19       775
           2       0.87      0.48      0.62       795
           3       0.91      0.47      0.62      1110
           4       0.28      0.87      0.42      1260
           5       0.92      0.12      0.21       882
           6       0.89      0.20      0.33       940
           7       0.15      0.50      0.23       473
           8       0.52      0.54      0.53       746
           9       0.29      0.73      0.42       689
          10       0.66      0.49      0.57       670
          11       0.74      0.26      0.38       312
          12       0.54      0.51      0.53       665
          13       0.89      0.32      0.47       314
          14       0.84      0.40      0.55       756
          15       0.99      0.82      0.90      1607

    accuracy   