In [1]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [2]:
# Experiment configuration: everything a reader might tune lives here.
name = "IMDB"  # key passed to ModelConfig(...) and load_data(...)
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU,
# instead of failing on the first tensor moved to "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None  # placeholder; reassigned by load_model(...) in the model-loading cell
batch_size = 16
num_workers = 4
num_samples = 16  # halved (num_samples // 2) for the positive / negative sample sets
ci_ratio = 0.3  # forwarded as sparsity_ratio to prune_concern_identification
seed = 44
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

# The helpers receive `seed` explicitly, but any other torch randomness (e.g. inside
# pruning or evaluation) would otherwise differ between runs; seed the global RNGs too.
torch.manual_seed(seed)

In [3]:
# Capture the wall-clock time at which this run began and echo it, second precision.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-31 08:28:34


In [4]:
# Build the configuration for dataset/model `name` on `device`, read the number of
# class labels from it (used as the range of concerns in the pruning loop), and load
# the model, tokenizer and checkpoint through the project's load_model helper.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
# Reassigns the `checkpoint = None` placeholder set in the configuration cell.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}




The model textattack/bert-base-uncased-imdb is loaded.




In [5]:
# Build the train / validation / test dataloaders for `name`.
# NOTE(review): do_cache=True presumably lets load_data reuse an on-disk copy (the
# recorded output reports "Loading cached dataset IMDB.") — confirm in load_data.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}




Loading cached dataset IMDB.




The dataset IMDB is loaded




In [6]:
# Optional baseline: evaluate the unpruned model on the test set for comparison with
# the pruned modules below. Off by default to match the recorded run; set to True to
# include the baseline metrics.
evaluate_baseline = False
if evaluate_baseline:
    print("Evaluate the original model")
    result = evaluate_model(model, model_config, test_dataloader)

In [7]:
# For each concern (class label), prune a deep copy of `model` using positive and
# negative samples for that concern, evaluate the pruned module on the test set,
# report its sparsity, and compare it against the original model with `similar`.
for concern in range(num_labels):
    # Per-concern copies of the loaders so sampling starts from the same state for
    # every concern.
    # NOTE(review): assumes SamplingDataset may consume/mutate the loader it is given — confirm.
    train = copy.deepcopy(train_dataloader)
    valid = copy.deepcopy(valid_dataloader)

    # The 5th positional argument selects positive (True) vs. negative (False) samples
    # for `concern`. NOTE(review): the positional `4` is forwarded unchanged; its meaning
    # is defined by SamplingDataset — confirm before tuning.
    positive_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, True, 4, device=device, resample=False, seed=seed
    )
    negative_samples = SamplingDataset(
        train, concern, num_samples // 2, num_labels, False, 4, device=device, resample=False, seed=seed
    )

    # Prune a fresh copy so `model` stays intact for the similarity comparison below.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # To persist each pruned module, uncomment. The filename includes `concern` so the
    # module for one concern does not overwrite the module saved for another.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

Evaluate the pruned model 0




Evaluating:   0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 0.7203




Precision: 0.7689, Recall: 0.6106, F1-Score: 0.5433




              precision    recall  f1-score   support

           0       0.56      0.99      0.72     12500
           1       0.98      0.23      0.37     12500

    accuracy                           0.61     25000
   macro avg       0.77      0.61      0.54     25000
weighted avg       0.77      0.61      0.54     25000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5095995852380395, 0.5095995852380395)




CCA coefficients mean non-concern: (0.5082305403405525, 0.5082305403405525)




Linear CKA concern: 0.3329378581486473




Linear CKA non-concern: 0.07023412275331505




Kernel CKA concern: 0.2597236834204278




Kernel CKA non-concern: 0.05403883750708562




Evaluate the pruned model 1




Evaluating:   0%|          | 0/1563 [00:00<?, ?it/s]

Loss: 0.5666




Precision: 0.7841, Recall: 0.6799, F1-Score: 0.6476




              precision    recall  f1-score   support

           0       0.61      0.98      0.75     12500
           1       0.96      0.38      0.54     12500

    accuracy                           0.68     25000
   macro avg       0.78      0.68      0.65     25000
weighted avg       0.78      0.68      0.65     25000





adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.5057256844665154, 0.5057256844665154)




CCA coefficients mean non-concern: (0.5062955042340898, 0.5062955042340898)




Linear CKA concern: 0.08168522159816696




Linear CKA non-concern: 0.22164780650294327




Kernel CKA concern: 0.06290294132011119




Kernel CKA non-concern: 0.16347500715876043


