In [1]:
import os
import sys
# Extend the import search path four directories up from the working directory.
# NOTE(review): presumably this is the repository root that holds the `utils`
# package imported in the next cell. The entry is relative, so it only resolves
# correctly when the kernel is started from this notebook's folder; confirm.
sys.path.append("../../../../")
# NOTE(review): presumably read by the Hugging Face tokenizers library to turn
# off its internal parallelism (commonly set to avoid fork warnings when
# DataLoader workers are used); confirm. It is set here, before the cell that
# imports the project's model utilities.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration (everything a reader might want to tune) ---

# Key handed to ModelConfig(...) and load_data(...) in later cells.
name = "YahooAnswersTopics"
device = torch.device("cuda:0")
# Placeholder only: rebound by the load_model(...) call in a later cell.
checkpoint = None

# Data loading.
batch_size = 16
num_workers = 4

# Sampling.
num_samples = 4
seed = 44

# Pruning: ci_ratio is passed as `sparsity_ratio` to prune_concern_identification.
ci_ratio = 0.5
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [None]:
# Capture the wall-clock time at which this run begins and echo it.
script_start_time = datetime.now()
# The format spec inside the f-string is applied via datetime.__format__,
# which yields the same text as an explicit strftime() call.
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

In [4]:
# Build the project configuration object for the selected dataset/model name
# and target device.
model_config = ModelConfig(name, device)
# Number of class labels; the pruning loop in a later cell iterates over
# range(num_labels), treating each label as one "concern".
num_labels = model_config.config["num_labels"]
# load_model returns a (model, tokenizer, checkpoint) triple. This rebinds
# `checkpoint`, which the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}
The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.


In [5]:
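# NOTE: baseline evaluation is disabled. `test_dataloader` is not defined at
# this point in the notebook (it is first created by load_data(...) in the
# pruning cell below), so load the data before re-enabling these lines.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)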

In [7]:
# Load the train/valid/test splits once. None of the load_data arguments depend
# on `concern`, so calling it inside the loop only repeated identical work on
# every iteration.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True
)

# For each class label ("concern"): sample data for that label, prune a fresh
# copy of the original model, then evaluate the pruned copy and compare it
# against the unpruned model.
for concern in range(num_labels):
    # NOTE(review): the 5th positional argument (True/False) presumably selects
    # samples that do / do not belong to `concern`, and the literal 4 is
    # presumably a batch size; confirm against SamplingDataset's signature.
    positive_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, True, 4,
        device=device, resample=False, seed=seed,
    )
    negative_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, False, 4,
        device=device, resample=False, seed=seed,
    )
    # NOTE(review): `all_samples` is never used below, and 200 is outside
    # range(num_labels). It is kept only because side effects of constructing a
    # SamplingDataset cannot be ruled out from this notebook; remove it once
    # that is confirmed safe.
    all_samples = SamplingDataset(
        train_dataloader, 200, num_samples, num_labels, False, 4,
        device=device, resample=False, seed=seed,
    )

    # Work on a deep copy so the original `model` stays untouched and every
    # concern starts from the same unpruned weights.
    module = copy.deepcopy(model)

    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        exclude_layers=exclude_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)
    get_sparsity(module)

    # Compare the original and pruned models on the validation split.
    similar(model, module, valid_dataloader, concern, num_samples, num_labels, device=device, seed=seed)

    # NOTE(review): if re-enabled as written, the file name below does not
    # include `concern`, so each iteration would overwrite the same file.
    # save_module(module, "Modules/", f"ci_{name}_{ci_ratio}p.pt")

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}
Loading cached dataset YahooAnswersTopics.
The dataset YahooAnswersTopics is loaded
Evaluate the pruned model 0


Evaluating: 100%|██████████| 1875/1875 [15:36<00:00,  2.00it/s]


Loss: 1.1605
Precision: 0.6613, Recall: 0.6402, F1-Score: 0.6417
              precision    recall  f1-score   support

           0       0.55      0.55      0.55      6000
           1       0.77      0.53      0.63      6000
           2       0.74      0.69      0.71      6000
           3       0.53      0.45      0.48      6000
           4       0.79      0.79      0.79      6000
           5       0.93      0.72      0.82      6000
           6       0.48      0.42      0.45      6000
           7       0.46      0.78      0.58      6000
           8       0.61      0.77      0.68      6000
           9       0.74      0.71      0.73      6000

    accuracy                           0.64     60000
   macro avg       0.66      0.64      0.64     60000
weighted avg       0.66      0.64      0.64     60000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating:  14%|█▍        | 260/1875 [02:04<12:54,  2.08it/s]


KeyboardInterrupt: 