In [1]:
import os
import sys
# Make the repository root importable for the project-local imports that
# follow. The path is relative, so it is resolved against the kernel's
# working directory.
REPO_ROOT = "../../../../"
# Guard the append so that re-running this cell does not stack duplicate
# entries on sys.path.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

# Set before any tokenizer is created.
# NOTE(review): presumably silences the tokenizers parallelism/fork warning
# when DataLoader worker processes are used — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
name = "YahooAnswersTopics"

# Prefer the second GPU when the host has one; otherwise fall back to the
# first GPU, and to the CPU when CUDA is unavailable. A hard-coded "cuda:1"
# names a device that does not exist on a single-GPU machine even though
# torch.cuda.is_available() is true there.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Project configuration object for this dataset/model pair; its `config`
# mapping supplies the number of target classes.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

# load_model returns the checkpoint as its third value, so no placeholder
# assignment is needed beforehand.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}
The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.


In [4]:
# Build the train / validation / test loaders for the dataset selected by
# `name`. The loader settings are grouped so they can be tuned in one place.
# NOTE(review): num_workers=48 is tied to the original host — confirm it
# suits the machine this runs on.
loader_options = {"batch_size": 32, "num_workers": 48}
train_dataloader, valid_dataloader, test_dataloader = load_data(name, **loader_options)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}
Loading cached dataset YahooAnswersTopics.
The dataset YahooAnswersTopics is loaded


In [5]:
# Log the wall-clock time at which sampling starts.
color_print(f"Start Time:{datetime.now():%H:%M:%S}")

# NOTE(review): not referenced by any cell visible in this section — confirm
# it is still needed further down.
num_samples = 64

# Draw samples from the training loader; `all_samples` is what the pruning
# cell consumes.
# NOTE(review): 200, 20, False and 4 are passed positionally without names —
# check SamplingDataset's signature for what each one controls.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    20,
    num_labels,
    False,
    4,
    device=device,
)

Start Time:18:19:20


In [6]:
# Baseline step: only the banner is printed. The evaluation call below is
# commented out, so this cell produces no metrics for the unpruned model.
print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

Evaluate the original model


In [7]:
print("Evaluate the pruned model")

# Apply Wanda pruning with a 0.4 sparsity ratio to the listed layer groups.
# NOTE(review): the return value is ignored and the same `model` object is
# evaluated right after, so prune_wanda presumably modifies the model in
# place — if so, re-running this cell prunes an already-pruned model.
prune_wanda(
    model,
    all_samples,
    sparsity_ratio=0.4,
    include_layers=["attention", "intermediate", "output"],
)

# Score the model on the test split after pruning.
result = evaluate_model(model, model_config, test_dataloader)

# Persisting the pruned module is currently disabled.
# save_module(model, "Modules/", "wanda_yahoo_40p.pt")

Evaluate the pruned model


Evaluating: 100%|██████████| 1875/1875 [21:42<00:00,  1.44it/s]


Loss: 1.0238
Precision: 0.6769, Recall: 0.6742, F1-Score: 0.6713
              precision    recall  f1-score   support

           0       0.57      0.56      0.56      6000
           1       0.73      0.64      0.68      6000
           2       0.71      0.76      0.74      6000
           3       0.54      0.51      0.52      6000
           4       0.80      0.81      0.81      6000
           5       0.91      0.81      0.85      6000
           6       0.57      0.40      0.47      6000
           7       0.57      0.75      0.65      6000
           8       0.64      0.76      0.70      6000
           9       0.73      0.75      0.74      6000

    accuracy                           0.67     60000
   macro avg       0.68      0.67      0.67     60000
weighted avg       0.68      0.67      0.67     60000



In [8]:
# Left as the cell's last expression so the notebook displays the returned value.
# NOTE(review): judging by the recorded output, the result is a pair of
# (overall sparsity, per-parameter sparsity dict) — confirm against get_sparsity.
get_sparsity(model)

(0.3965589468552108,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3997395833333333,
  'bert.en