In [1]:
import sys
import os
# Make the directory two levels above the current working directory importable
# (presumably the repository root that holds the `utils` package -- confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
project_root = os.path.dirname(os.path.dirname(os.getcwd()))
if project_root not in sys.path:
    sys.path.append(project_root)

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
    convert_dataset_labels_to_binary,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.model_utils.save_module import save_module
from utils.decompose_utils.weight_remover import WeightRemoverBert
from utils.decompose_utils.concern_identification import ConcernIdentificationBert
from utils.decompose_utils.tangling_identification import TanglingIdentification
from utils.decompose_utils.concern_modularization import ConcernModularizationBert
from utils.decompose_utils.sampling import sampling_class
from utils.prune_utils.prune import prune_magnitude, find_layers, LayerWrapper, prune_ci

In [3]:
# Experiment configuration: dataset/model key and the compute device.
name = "OSDG"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# `checkpoint` starts unset; it is rebound when the model is loaded later on.
checkpoint = None
model_config = ModelConfig(name, device)
# Number of target classes, read from the loaded model configuration.
num_labels = model_config.config["num_labels"]

In [4]:
def prune_concern_identification(
    model,
    module,
    dataloader,
    sparsity_ratio=0.4,
    step_size=4,
    p=1,
    include_layers=None,
    exclude_layers=None,
):
    """Prune ``module`` in place, guided by how its outputs deviate from ``model``.

    Forward hooks record the outputs of every selected layer in both the
    reference ``model`` and the ``module`` being pruned while the batches of
    ``dataloader`` are pushed through both networks.  For each layer the
    squared L2 norm of the (module - model) output difference is computed per
    output unit and multiplied by the absolute weight values.  Within every
    weight column, the ``int(out_dim * sparsity_ratio)`` rows with the lowest
    resulting score are set to zero.

    Args:
        model: Reference network; only read (forward passes under ``no_grad``).
        module: Network whose selected layer weights are modified in place.
        dataloader: Iterable of 4-tuples whose first two items are the input
            ids and the attention mask.  It is iterated exactly once.
        sparsity_ratio: Fraction of rows zeroed in every weight column.
        step_size: Currently unused; kept for interface compatibility.
        p: Currently unused; kept for interface compatibility.
        include_layers: Forwarded to ``find_layers`` for both networks.
        exclude_layers: Forwarded to ``find_layers`` for both networks.

    Raises:
        ValueError: If the two networks expose a different number of selected
            layers, or if ``dataloader`` yields no batches (for example an
            iterator that has already been consumed).
    """
    ref_layers = find_layers(
        model, include_layers=include_layers, exclude_layers=exclude_layers
    )
    target_layers = find_layers(
        module, include_layers=include_layers, exclude_layers=exclude_layers
    )
    # zip() below would silently truncate and leave some layers unpruned.
    if len(ref_layers) != len(target_layers):
        raise ValueError(
            "model and module expose a different number of prunable layers: "
            f"{len(ref_layers)} vs {len(target_layers)}"
        )
    device = next(model.parameters()).device

    wrappers = {}
    handles = []

    def get_hook(wrapper):
        # Parameter names avoid shadowing the outer `module` and builtin `input`.
        def hook(_layer, layer_input, layer_output):
            wrapper.update(layer_input, layer_output)

        return hook

    num_batches = 0
    try:
        for (ref_name, ref_layer), (target_name, target_layer) in zip(
            ref_layers.items(), target_layers.items()
        ):
            ref_wrapper = LayerWrapper(ref_name, ref_layer)
            target_wrapper = LayerWrapper(target_name, target_layer)

            wrappers[ref_name] = {"ref": ref_wrapper, "target": target_wrapper}

            handles.append(ref_layer.register_forward_hook(get_hook(ref_wrapper)))
            handles.append(
                target_layer.register_forward_hook(get_hook(target_wrapper))
            )

        for batch in dataloader:
            input_ids, attn_mask, _, _ = batch
            input_ids = input_ids.to(device)
            attn_mask = attn_mask.to(device)
            with torch.no_grad():
                model(input_ids, attention_mask=attn_mask)
                module(input_ids, attention_mask=attn_mask)
            num_batches += 1
    finally:
        # Always detach the hooks, even when a forward pass fails, so neither
        # network keeps recording activations after this call returns.
        for handle in handles:
            handle.remove()

    if num_batches == 0:
        raise ValueError(
            "dataloader yielded no batches; nothing was recorded to score the "
            "weights (was a single-use iterator already consumed?)"
        )

    for wrapper_pair in wrappers.values():
        ref_wrapper = wrapper_pair["ref"]
        target_wrapper = wrapper_pair["target"]

        ref_wrapper.update_batch()
        target_wrapper.update_batch()

        ref_outputs = ref_wrapper.outputs
        target_outputs = target_wrapper.outputs

        current_weight = target_wrapper.layer.weight.data

        output_loss = (
            target_outputs - ref_outputs
        )  # (batch_size, seq_dim, output_dim)

        output_loss = output_loss.reshape(
            (-1, output_loss.shape[-1])
        )  # (batch_size * seq_dim, output_dim)

        output_loss = output_loss.t().to(device)  # (output_dim, batch_size * seq_dim)
        importance_score = torch.norm(output_loss, p=2, dim=1) ** 2  # (output_dim)

        # (output_dim, input_dim) * (output_dim, 1) -> (output_dim, input_dim)
        importance_score = torch.abs(current_weight) * importance_score.reshape(
            (-1, 1)
        )

        # Per column, mark the rows with the lowest scores; the stable sort
        # keeps the selection deterministic when scores tie.
        W_mask = torch.zeros_like(importance_score, dtype=torch.bool)
        num_pruned_rows = int(importance_score.shape[0] * sparsity_ratio)
        sort_res = torch.sort(importance_score, dim=0, stable=True)
        indices = sort_res[1][:num_pruned_rows, :]
        W_mask.scatter_(0, indices, True)
        current_weight[W_mask] = 0

        ref_wrapper.remove()
        target_wrapper.remove()



In [5]:
# Driver cell: build and evaluate the pruned module for index `i`.
# `i` is passed to sampling_class below and echoed in the progress message.
i = 0
# Rebinds `checkpoint` (previously None) with the value returned by load_model.
model, tokenizer, checkpoint = load_model(model_config)

train_dataloader, valid_dataloader, test_dataloader = load_data(name, batch_size=64)

color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
color_print("#Module " + str(i) + " in progress....")
num_samples = 64

# The two calls differ only in the boolean flag (True vs False).
# NOTE(review): presumably the flag selects samples of class `i` vs. all other
# classes, and the literal 4 is a batch size -- confirm against sampling_class.
positive_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)

# NOTE(review): `all_samples` is not used in any visible cell, and the index
# 200 is outside the 16 labels reported for this model -- confirm intent.
all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

print("origin")
# Baseline evaluation of the unpruned model is currently disabled.
# evaluate_model(model, model_config, test_dataloader)

# Work on a deep copy so `model` stays intact as the pruning reference.
module = copy.deepcopy(model)
# NOTE(review): `wr` and `ci` are not used in any visible cell; `ti` is used
# by the later cell that iterates over `negative_samples`.
wr = WeightRemoverBert(model, p=0.9)
ci = ConcernIdentificationBert(model, p=0.4)
ti = TanglingIdentification(model, p=0.5)

print("Start Magnitude pruning")
prune_magnitude(module, include_layers=["attention", "intermediate", "output"], sparsity_ratio=0.1)
# Index 0 of get_sparsity's result is the overall value (a single float in the
# recorded output); the second item is a per-parameter dict.
print(get_sparsity(module)[0])
print("Start Positive CI after sparse")

# This call iterates over `positive_samples` once.
# NOTE(review): if sampling_class returns a single-pass iterator, it is
# exhausted after this call and cannot be reused by later cells -- confirm.
prune_concern_identification(
    model,
    module,
    positive_samples,
    include_layers=["attention", "intermediate", "output"],
    sparsity_ratio=0.6,
    p=1,
)

print(get_sparsity(module))
result = evaluate_model(module, model_config, test_dataloader)
torch.cuda.empty_cache()

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.
{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded
Start Time:13:23:28
#Module 0 in progress....
origin
Start Magnitude pruning
0.09919858441371328
Start Positive CI after sparse
(0.5944862664659172, {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5989634195963541, 'bert.encoder.layer.0.attention.

Evaluating: 100%|██████████| 200/200 [02:24<00:00,  1.39it/s]


Loss: 1.1752
Precision: 0.6840, Recall: 0.6382, F1-Score: 0.6415
              precision    recall  f1-score   support

           0       0.62      0.51      0.56       797
           1       0.83      0.41      0.55       775
           2       0.79      0.81      0.80       795
           3       0.88      0.69      0.77      1110
           4       0.61      0.83      0.70      1260
           5       0.87      0.57      0.69       882
           6       0.84      0.59      0.69       940
           7       0.32      0.47      0.38       473
           8       0.53      0.77      0.62       746
           9       0.46      0.55      0.50       689
          10       0.59      0.70      0.64       670
          11       0.49      0.54      0.51       312
          12       0.44      0.75      0.55       665
          13       0.88      0.51      0.65       314
          14       0.84      0.64      0.72       756
          15       0.97      0.88      0.93      1607

    accuracy   

In [6]:
# Feed every negative-sample batch through the tangling-identification step,
# which operates on `module`, then score the resulting module on the test set.
# NOTE(review): only the input ids are handed to `ti.propagate`; the attention
# mask in each batch is ignored -- confirm that this is intended.
for input_ids, _attn_mask, _, _total_sampled in negative_samples:
    with torch.no_grad():
        ti.propagate(module, input_ids)
result = evaluate_model(module, model_config, test_dataloader)

Evaluating: 100%|██████████| 200/200 [02:37<00:00,  1.27it/s]


Loss: 1.0054
Precision: 0.7293, Recall: 0.7089, F1-Score: 0.7091
              precision    recall  f1-score   support

           0       0.65      0.57      0.61       797
           1       0.85      0.51      0.63       775
           2       0.86      0.83      0.84       795
           3       0.87      0.74      0.80      1110
           4       0.68      0.85      0.75      1260
           5       0.89      0.65      0.75       882
           6       0.85      0.65      0.74       940
           7       0.44      0.47      0.45       473
           8       0.54      0.82      0.65       746
           9       0.47      0.71      0.56       689
          10       0.71      0.74      0.72       670
          11       0.63      0.64      0.64       312
          12       0.60      0.74      0.67       665
          13       0.81      0.79      0.80       314
          14       0.84      0.73      0.78       756
          15       0.96      0.91      0.93      1607

    accuracy   

In [7]:
# Print the full get_sparsity result for `module` as it stands after the
# preceding cells: a tuple of (overall value, per-parameter dict), per the
# recorded output below.
print(get_sparsity(module))

(0.5096870847372806, {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334, 'bert.encoder.layer.0.attention.self.query.bias': 0.0, 'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334, 'bert.encoder.layer.0.attention.self.key.bias': 0.0, 'bert.encoder.layer.0.attention.self.value.weight': 0.5989634195963541, 'bert.encoder.layer.0.attention.self.value.bias': 0.0, 'bert.encoder.layer.0.attention.output.dense.weight': 0.5989583333333334, 'bert.encoder.layer.0.attention.output.dense.bias': 0.0, 'bert.encoder.layer.0.intermediate.dense.weight': 0.46533838907877606, 'bert.encoder.layer.0.intermediate.dense.bias': 0.0, 'bert.encoder.layer.0.output.dense.weight': 0.4637468126085069, 'bert.encoder.layer.0.output.dense.bias': 0.0, 'bert.encoder.layer.1.attention.self.query.weight': 0.5989583333333334, 'bert.encoder.layer.1.attention.self.query.bias': 0.0, 'bert.encoder.layer.1.attention.self.key.weight': 0.5989583333333334, 'bert.encoder.layer.1.attention.self.k

In [8]:
# Repeat the experiment with the library implementation `prune_ci`, starting
# again from a fresh copy of the unpruned reference model.
module = copy.deepcopy(model)
prune_magnitude(module, include_layers=["attention", "intermediate", "output"], sparsity_ratio=0.1)
print(get_sparsity(module)[0])
print("Start Positive CI after sparse")

# Draw the positive samples again instead of reusing the object that an earlier
# cell already iterated over. The recorded failure of this cell
# ("torch.cat(): expected a non-empty list of Tensors") is consistent with a
# single-pass iterable that was already exhausted.
# NOTE(review): confirm that sampling_class returns a one-shot iterator.
positive_samples = sampling_class(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)

prune_ci(
    model,
    module,
    positive_samples,
    include_layers=["attention", "intermediate", "output"],
    sparsity_ratio=0.6,
    p=1,
)

print(get_sparsity(module))
result = evaluate_model(module, model_config, test_dataloader)
torch.cuda.empty_cache()

0.09919858441371328
Start Positive CI after sparse


RuntimeError: torch.cat(): expected a non-empty list of Tensors