In [None]:
import os
import sys

# Make the repository root importable so the `utils.*` imports in the next cell
# resolve. The path is resolved to an absolute one now (relative to the kernel's
# current working directory, normally the notebook's folder), so a later
# os.chdir cannot break imports. It is added only if missing, so re-running this
# cell does not keep appending duplicate entries to sys.path.
_REPO_ROOT = os.path.abspath("../../../")
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

# Turn off the HuggingFace `tokenizers` library's internal parallelism
# (commonly set to avoid its fork/deadlock warning when DataLoader workers are used).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune_head import (
    compute_heads_importance,
    head_importance_prunning
)

In [None]:
# --- Experiment configuration ------------------------------------------------
name = "IMDB"               # task/dataset name passed to ModelConfig and load_data
# Fall back to CPU so the notebook still runs (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = None           # rebound by load_model(...) when the model is loaded
batch_size = 16
num_workers = 4             # passed to load_data
num_samples = 16            # passed to SamplingDataset / similar (presumably a per-class sample count)
head_pruning_ratio = 0.3    # passed to head_importance_prunning (presumably fraction of heads pruned)
seed = 44

# The helpers below receive `seed` explicitly; also seed torch's global RNG so
# anything drawing from the default generator is reproducible too.
torch.manual_seed(seed)

In [None]:
# Capture and announce the wall-clock time at which this run began.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

In [None]:

# Task configuration for the selected dataset on the selected device.
model_config = ModelConfig(name, device)
# Number of classes; the per-concern pruning loop iterates over range(num_labels).
num_labels = model_config.config["num_labels"]

# load_model hands back a 3-tuple. Note that this rebinds `checkpoint`
# (initialised to None in the config cell) to its third element.
(
    model,
    tokenizer,
    checkpoint,
) = load_model(model_config)

In [None]:
# Build the train/validation/test loaders for `name`, seeded for reproducibility.
# do_cache=True presumably reuses a cached copy of the prepared data; confirm in load_data.
(
    train_dataloader,
    valid_dataloader,
    test_dataloader,
) = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

In [None]:
# Per-concern head pruning. For each class label ("concern"):
#   1. score attention heads on samples drawn for that concern,
#   2. prune a copy of the model according to those scores,
#   3. evaluate the pruned copy once and compare it against the original model.
for concern in range(num_labels):
    # `similar` receives this copy, so valid_dataloader itself is never handed
    # to it; a fresh copy is made for every concern.
    valid = copy.deepcopy(valid_dataloader)

    # Samples for this concern drawn from the training loader (the positional
    # `True` presumably selects samples belonging to `concern`; confirm).
    # NOTE(review): the literal 4 is an unnamed positional SamplingDataset
    # argument (possibly a worker count) — confirm and pass it by keyword.
    positive_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, True, 4,
        device=device, resample=False, seed=seed,
    )

    # Prune a deep copy so `model` stays intact as the reference for `similar`.
    module = copy.deepcopy(model)

    # Only the per-class importance list is consumed below; the other outputs
    # are kept under their original names for inspection after the loop.
    (
        attn_entropy,
        head_importance,
        preds,
        labels,
        per_class_head_importance_list,
    ) = compute_heads_importance(
        module,
        model_config,
        positive_samples,
    )

    # Called for its effect on `module` (its return value is not used).
    head_importance_prunning(
        module, concern, per_class_head_importance_list, head_pruning_ratio
    )

    # Evaluate exactly once, after the header, so the metrics printed by the
    # evaluation are attributed to the right concern.
    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

    # The file name includes `concern`; without it each iteration would overwrite the previous module.
    # save_module(module, "Modules/", f"head_prune_{name}_{concern}_{head_pruning_ratio}p.pt")