In [1]:
import os
import sys

# Disable the `tokenizers` library's internal parallelism. This is set here,
# before the imports in the next cell, so it is in effect when they load.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Make the directory four levels above the notebook's working directory
# importable, so the project-local `utils` package (imported next) resolves.
sys.path += ["../../../../"]

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude,
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment identity --------------------------------------------------
name = "YahooAnswersTopics"  # key passed to ModelConfig and load_data
device = torch.device("cuda:0")
checkpoint = None  # placeholder; rebound by load_model in the model-loading cell

# --- Data loading / sampling ----------------------------------------------
batch_size, num_workers = 32, 48
num_samples = 16  # passed as num_samples to SamplingDataset and similar

# --- Pruning ----------------------------------------------------------------
concern = 0  # default only; the per-concern loop rebinds this for every label
ci_ratio = 0.4  # sparsity_ratio for prune_concern_identification
include_layers = ["attention", "intermediate", "output"]

In [4]:
# Resolve the configuration for `name` and the number of output classes
# (each class is treated as one "concern" in the pruning loop).
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

# load_model returns three values; the third rebinds the `checkpoint`
# placeholder from the configuration cell.
(
    model,
    tokenizer,
    checkpoint,
) = load_model(model_config)

Loading the model.
{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}
The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.


In [5]:
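# Baseline evaluation of the unpruned model. It is disabled because a full pass
# over the test set takes ~30 minutes; the recorded result is kept in the next
# cell. Note: `test_dataloader` is only created by `load_data(...)` in the
# pruning cell below, so run that first if re-enabling these lines.
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)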

In [6]:
# Recorded output of evaluating the original (unpruned) model on the test set:
# Evaluating: 100%|███████████████████████████████████████████████████████████████████| 1875/1875 [30:03<00:00,  1.04it/s]
# Loss: 1.0044
# Precision: 0.6874, Recall: 0.6865, F1-Score: 0.6839
#               precision    recall  f1-score   support

#            0       0.57      0.57      0.57      6000
#            1       0.74      0.66      0.69      6000
#            2       0.71      0.78      0.74      6000
#            3       0.54      0.53      0.53      6000
#            4       0.80      0.82      0.81      6000
#            5       0.90      0.84      0.87      6000
#            6       0.61      0.43      0.50      6000
#            7       0.62      0.73      0.67      6000
#            8       0.64      0.76      0.70      6000
#            9       0.75      0.75      0.75      6000

#     accuracy                           0.69     60000
#    macro avg       0.69      0.69      0.68     60000
# weighted avg       0.69      0.69      0.68     60000

In [None]:
# The splits do not depend on `concern`, so load them once. Calling load_data
# inside the loop rebuilt all three dataloaders (each with `num_workers`
# workers) on every iteration. `train_dataloader` is already passed to
# several SamplingDataset calls per iteration, so sharing it is expected.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers
)

# Per-concern outputs keyed by class index. Previously `result` was
# overwritten every iteration and the sparsity/similarity values were lost.
results = {}

for concern in range(num_labels):
    # Positive (flag True) and negative (flag False) sample sets for this
    # concern, drawn from the training split and used to guide pruning.
    # NOTE(review): the positional `4` is interpreted by SamplingDataset —
    # confirm its meaning and give it a named constant in the config cell.
    positive_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, True, 4, device=device
    )
    negative_samples = SamplingDataset(
        train_dataloader, concern, num_samples, num_labels, False, 4, device=device
    )
    # The former `all_samples = SamplingDataset(train_dataloader, 200, ...)`
    # was never used in this cell and has been dropped to save sampling work.

    # Prune a fresh copy so `model` remains the unpruned reference.
    module = copy.deepcopy(model)
    prune_concern_identification(
        module,
        model_config,
        positive_samples,
        negative_samples,
        include_layers=include_layers,
        sparsity_ratio=ci_ratio,
    )

    print(f"Evaluate the pruned model {concern}")
    result = evaluate_model(module, model_config, test_dataloader)

    # A bare call inside a loop is never displayed by Jupyter, so the value
    # used to be silently discarded; print it explicitly.
    sparsity = get_sparsity(module)
    print(f"Sparsity of pruned model {concern}: {sparsity}")

    # Compare the original and pruned models on validation data.
    similarity = similar(
        model, module, valid_dataloader, concern, num_samples, num_labels, device=device
    )

    results[concern] = {
        "evaluation": result,
        "sparsity": sparsity,
        "similarity": similarity,
    }

    # To persist each pruned module, uncomment. The previous template used an
    # undefined `ti_ratio` and omitted `concern`, so every iteration would
    # have overwritten the same file.
    # save_module(module, "Modules/", f"ci_{name}_{concern}_{ci_ratio}p.pt")

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}
Loading cached dataset YahooAnswersTopics.
The dataset YahooAnswersTopics is loaded
Evaluate the pruned model 0


Evaluating: 100%|██████████| 1875/1875 [16:27<00:00,  1.90it/s]


Loss: 1.0367
Precision: 0.6741, Recall: 0.6699, F1-Score: 0.6684
              precision    recall  f1-score   support

           0       0.57      0.56      0.56      6000
           1       0.73      0.63      0.68      6000
           2       0.71      0.75      0.73      6000
           3       0.54      0.49      0.51      6000
           4       0.80      0.80      0.80      6000
           5       0.91      0.80      0.85      6000
           6       0.53      0.42      0.47      6000
           7       0.56      0.75      0.64      6000
           8       0.64      0.76      0.69      6000
           9       0.75      0.73      0.74      6000

    accuracy                           0.67     60000
   macro avg       0.67      0.67      0.67     60000
weighted avg       0.67      0.67      0.67     60000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating: 100%|██████████| 1875/1875 [12:40<00:00,  2.47it/s]


Loss: 1.0286
Precision: 0.6772, Recall: 0.6718, F1-Score: 0.6702
              precision    recall  f1-score   support

           0       0.55      0.57      0.56      6000
           1       0.74      0.63      0.68      6000
           2       0.71      0.76      0.74      6000
           3       0.54      0.51      0.52      6000
           4       0.81      0.80      0.81      6000
           5       0.91      0.80      0.85      6000
           6       0.57      0.40      0.47      6000
           7       0.55      0.75      0.64      6000
           8       0.65      0.75      0.70      6000
           9       0.74      0.73      0.74      6000

    accuracy                           0.67     60000
   macro avg       0.68      0.67      0.67     60000
weighted avg       0.68      0.67      0.67     60000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating: 100%|██████████| 1875/1875 [08:52<00:00,  3.52it/s]


Loss: 1.0356
Precision: 0.6761, Recall: 0.6702, F1-Score: 0.6687
              precision    recall  f1-score   support

           0       0.55      0.57      0.56      6000
           1       0.74      0.62      0.68      6000
           2       0.72      0.76      0.74      6000
           3       0.52      0.51      0.52      6000
           4       0.81      0.80      0.81      6000
           5       0.92      0.78      0.85      6000
           6       0.57      0.40      0.47      6000
           7       0.56      0.75      0.64      6000
           8       0.64      0.76      0.69      6000
           9       0.73      0.74      0.74      6000

    accuracy                           0.67     60000
   macro avg       0.68      0.67      0.67     60000
weighted avg       0.68      0.67      0.67     60000

adding eps to diagonal and taking inverse
taking square root
dot products...
trying to take final svd
computed everything!
adding eps to diagonal and taking inverse
taking squa

Evaluating:  49%|████▊     | 912/1875 [04:13<04:22,  3.67it/s]