In [1]:
import os
import sys
# Set the flag before any tokenizer is created. Presumably this silences the
# HuggingFace `tokenizers` parallelism/fork warning when data-loading workers
# are spawned -- TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Extend the import search path three directories up from the working
# directory (presumably where the `utils` package lives). The path is relative,
# so the notebook must be launched from its own folder.
sys.path.append("../../../")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_wanda
)

In [3]:
# Experiment selection: `name` is the dataset/model preset key that is also
# passed to load_data later; `device` prefers the first CUDA GPU, else the CPU.
name = "IMDB"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the configuration for the preset and read the label count from it;
# `num_labels` is passed to SamplingDataset in a later cell.
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

# load_model returns three values; the third is bound to `checkpoint` here, so
# no placeholder `checkpoint = None` assignment is needed beforehand.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}
The model textattack/bert-base-uncased-imdb is loaded.


In [4]:
# Build the train / validation / test loaders for the dataset chosen by `name`.
# The worker count is capped at the number of CPUs the OS reports so the cell
# behaves the same on a 48+ core machine but does not oversubscribe smaller
# ones; os.cpu_count() may return None, hence the fallback to 1.
# NOTE(review): assumes load_data forwards `num_workers` to the underlying
# DataLoader(s) -- confirm against utils.dataset_utils.load_dataset.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=32, num_workers=min(48, os.cpu_count() or 1)
)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}
Loading cached dataset IMDB.
The dataset IMDB is loaded


In [5]:
# Log the wall-clock time (HH:MM:SS) at which sampling starts.
color_print(f"Start Time:{datetime.now():%H:%M:%S}")

# NOTE(review): `num_samples` is not referenced by any cell visible here --
# confirm whether a later cell uses it.
num_samples = 64

# Build the sample set from the training loader; it is handed to prune_wanda
# in a later cell.
# NOTE(review): the positional values 200, 20, False and 4 are opaque from this
# notebook -- confirm their meaning against SamplingDataset's signature.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    20,
    num_labels,
    False,
    4,
    device=device,
)

Start Time:20:51:06


In [6]:
# Baseline measurement: evaluate the model on the test loader before any
# pruning is applied, so the pruned metrics have a reference point.
# NOTE(review): `result` is reassigned by the pruned-model evaluation cell
# further down, so the baseline value is not retained under this name.
print("Evaluate the original model")
result = evaluate_model(model, model_config, test_dataloader)

Evaluate the original model


Evaluating: 100%|██████████| 782/782 [05:48<00:00,  2.24it/s]


Loss: 0.3423
Precision: 0.9306, Recall: 0.9303, F1-Score: 0.9303
              precision    recall  f1-score   support

           0       0.92      0.94      0.93     12500
           1       0.94      0.92      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000



In [7]:
# Single source of truth for the target sparsity: it drives both the pruning
# call and the saved file name, so the two can no longer drift apart.
sparsity_ratio = 0.4

print("Evaluate the pruned model")
# prune_wanda's return value is not used; the same `model` object is evaluated
# and saved below, i.e. pruning modifies it in place. Re-running this cell
# therefore prunes an already-pruned model.
prune_wanda(model, all_samples, sparsity_ratio=sparsity_ratio)
result = evaluate_model(model, model_config, test_dataloader)
# For sparsity_ratio = 0.4 the file name is "wanda_imdb_40p.pt".
save_module(model, "Modules/", f"wanda_imdb_{round(sparsity_ratio * 100)}p.pt")

Evaluate the pruned model


Evaluating: 100%|██████████| 782/782 [05:45<00:00,  2.26it/s]


Loss: 0.3012
Precision: 0.9317, Recall: 0.9315, F1-Score: 0.9315
              precision    recall  f1-score   support

           0       0.92      0.94      0.93     12500
           1       0.94      0.92      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000



In [9]:
# Report the sparsity actually achieved by pruning. As the cell's last
# expression, the returned pair (overall sparsity, per-parameter mapping of
# name -> sparsity) is displayed in full.
get_sparsity(model)

(0.39934869552794994,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3997395833333333,
  'bert.e