In [1]:
import copy
import os
import sys
import torch
# Extend the module search path four directory levels up, presumably so the
# project-local `utils` package imported in the next cell can be resolved.
# NOTE(review): the path is relative to the kernel's working directory — confirm
# the notebook is always launched from its own folder.
EXTRA_IMPORT_PATH = "../../../../"
# Guard keeps re-running this cell from stacking duplicate sys.path entries.
if EXTRA_IMPORT_PATH not in sys.path:
    sys.path.append(EXTRA_IMPORT_PATH)

# Set before any model/tokenizer is loaded. Presumably read by the HuggingFace
# `tokenizers` library to switch off its internal parallelism — TODO confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import utils
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.helper import ModelConfig
from utils.dataset_utils.load_dataset import load_data

In [3]:
# Key handed to both ModelConfig (here) and load_data (next cell).
name = "IMDB"

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Configuration object consumed later by load_model and evaluate_model.
model_config = ModelConfig(name, device)

In [4]:
# load_model yields three values; only the first two are kept, the third is discarded.
model, tokenizer, _ = load_model(model_config)

# Three dataloaders (train / validation / test) built with a batch size of 32.
train_dataloader, valid_dataloader, test_dataloader = load_data(name, batch_size=32)

Loading the model.
{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}
The model textattack/bert-base-uncased-imdb is loaded.
{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}
Loading cached dataset IMDB.
The dataset IMDB is loaded


In [5]:
from utils.prune_utils.prune import prune_magnitude

In [6]:
# Pruning settings: the layer-name filters passed as include_layers and the
# sparsity ratio requested from prune_magnitude.
layers_to_prune = ["attention", "intermediate", "output"]
target_sparsity = 0.4

# The return value is not captured; the later get_sparsity(model) report reflects the
# change, so the model appears to be modified in place.
prune_magnitude(model, include_layers=layers_to_prune, sparsity_ratio=target_sparsity)

In [7]:
# Score the model on the test dataloader. The model has already been passed through
# prune_magnitude in the preceding cell, so the metrics printed by this call describe
# the pruned weights.
result = evaluate_model(model, model_config, test_dataloader)

Evaluating: 100%|██████████| 782/782 [06:30<00:00,  2.00it/s]


Loss: 0.3078
Precision: 0.9281, Recall: 0.9278, F1-Score: 0.9278
              precision    recall  f1-score   support

           0       0.92      0.94      0.93     12500
           1       0.94      0.91      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000



In [8]:
# Displays a 2-tuple (see the rendered output): one overall sparsity figure, followed
# by a dict keyed by parameter name holding that parameter's sparsity.
get_sparsity(model)

(0.39684534690157297,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3999989827473958,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3999989827473958,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3999989827473958,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3999989827473958,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.39999983045789933,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.39999983045789933,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3999989827473958,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3999989827473958,
  'bert