In [1]:
import os
import sys

# Extend the import search path four directories up from the working directory;
# presumably this is where the project's `utils` package lives -- TODO confirm.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path each time.
PROJECT_ROOT = "../../../../"
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Must be set before any tokenizer is created. NOTE(review): this is the
# Hugging Face tokenizers switch for its internal thread pool, usually disabled
# to avoid fork warnings when DataLoader workers are used -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# ---- Experiment configuration: everything a reader might want to tune ----

# Dataset / model key handed to ModelConfig and load_data.
name = "IMDB"
# Prefer the first CUDA device, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a GPU instead of failing on first transfer.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Placeholder only; it is rebound by the load_model call in the next cell.
checkpoint = None
# DataLoader settings forwarded to load_data.
batch_size = 16
num_workers = 4
# Total sample budget per concern; the loop gives half to the positive set and
# half to the negative set.
num_samples = 16
# Passed as `sparsity_ratio` to prune_concern_identification.
ci_ratio = 0.4
# Passed as `recovery_ratio` to recover_tangling_identification.
ti_ratio = 0.1
# Seed forwarded to data loading, sampling and the similarity check.
seed = 44

# Layer-name filters forwarded unchanged to both pruning helpers.
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Build the model configuration for the selected dataset key and device.
model_config = ModelConfig(name, device)
# Number of class labels; the modularization loop creates one module per label.
num_labels = model_config.config["num_labels"]
# load_model returns three values; the third rebinds `checkpoint`, which was
# only a None placeholder in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}




The model textattack/bert-base-uncased-imdb is loaded.




In [5]:
# Load the three dataset splits as loaders, using the batch size, worker count
# and seed from the configuration cell.
# NOTE(review): do_cache=True presumably reuses an on-disk copy of the
# processed dataset when one exists -- confirm against load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}




Loading cached dataset IMDB.




The dataset IMDB is loaded




In [6]:
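# Optional baseline (disabled): uncomment the two lines below to score the
# unmodified model on the test set before any pruning is applied.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)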

In [7]:
# Build one module per class label ("concern"): prune a copy of the model with
# concern identification, run tangling-identification recovery against the
# original model, then evaluate the module and compare it with the original.

# Loop-invariant: the positive and the negative sample set each get half of
# the total sample budget.
samples_per_polarity = num_samples // 2

for concern in range(num_labels):
    # Fresh copies of the loaders for every concern; presumably so sampling
    # cannot disturb the shared train/valid loaders -- TODO confirm.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # The two sets differ only in the fifth positional argument (True/False).
    # NOTE(review): the literal 4 is forwarded unchanged; confirm its meaning
    # against SamplingDataset's signature and lift it into the config cell.
    positive_samples = SamplingDataset(
        train, concern, samples_per_polarity, num_labels, True, 4,
        device=device, resample=False, seed=seed,
    )
    negative_samples = SamplingDataset(
        train, concern, samples_per_polarity, num_labels, False, 4,
        device=device, resample=False, seed=seed,
    )
    # The former `all_samples` dataset (built with concern id 200 although
    # there are only `num_labels` classes) was dropped: nothing in this cell
    # read it, so it was constructed on every iteration for no effect.

    # Prune a private copy so `model` stays untouched as the reference.
    module = copy.deepcopy(model)

    # Called for its effect on `module`; the return value is not used.
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )
    # Optional intermediate check (disabled): evaluate right after pruning,
    # before the recovery step.
    # print(f"Evaluate the pruned model {concern}")
    # result = evaluate_model(module, model_config, test_dataloader)
    # similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # Recovery step: receives both the original `model` and the pruned `module`.
    recover_tangling_identification(
        model,
        module,
        model_config,
        positive_samples,
        negative_samples,
        recovery_ratio=ti_ratio,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
    )

    print(f"Evaluate the pruned model {concern}")
    # `result` holds the metrics of the most recently evaluated concern.
    result = evaluate_model(module, model_config, test_dataloader)
    # Representation-similarity report between the original model and the module.
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # Saving is disabled. NOTE(review): before enabling, fix the file name --
    # `ci_ratio-ti_ratio` is a float difference (0.4 - 0.1 formats as
    # 0.30000000000000004) and the name has no concern index, so every concern
    # would target the same file.
    # save_module(module, "Modules/", f"citi_{name}_{ci_ratio-ti_ratio}p.pt")

  attn_output = torch.nn.functional.scaled_dot_product_attention(


333




Evaluate the pruned model 0




Evaluating:   0%|                                                                             | 0/1563 [00:25<…

Loss: 0.3145




Precision: 0.9300, Recall: 0.9293, F1-Score: 0.9293




              precision    recall  f1-score   support

           0       0.91      0.95      0.93     12500
           1       0.95      0.91      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8440066389558538, 0.8440066389558538)




CCA coefficients mean non-concern: (0.8393574314468292, 0.8393574314468292)




Linear CKA concern: 0.960461105240837




Linear CKA non-concern: 0.940832960594569




Kernel CKA concern: 0.9570607767322195




Kernel CKA non-concern: 0.941119121223524




333




Evaluate the pruned model 1




Evaluating:   0%|                                                                             | 0/1563 [00:38<…

Loss: 0.2933




Precision: 0.9319, Recall: 0.9319, F1-Score: 0.9319




              precision    recall  f1-score   support

           0       0.93      0.93      0.93     12500
           1       0.93      0.93      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8473124364688593, 0.8473124364688593)




CCA coefficients mean non-concern: (0.8427681895916757, 0.8427681895916757)




Linear CKA concern: 0.9657415783794806




Linear CKA non-concern: 0.927397066539306




Kernel CKA concern: 0.9650676694730327




Kernel CKA non-concern: 0.9235928884348048


