In [1]:
import os
import sys
sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude,
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
name = "IMDB"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

checkpoint = None
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}
The model textattack/bert-base-uncased-imdb is loaded.


In [4]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=32, num_workers=48
)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}
Loading cached dataset IMDB.
The dataset IMDB is loaded


In [5]:
i = 0
color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
color_print("#Module " + str(i) + " in progress....")
num_samples = 64

positive_samples = SamplingDataset(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = SamplingDataset(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)
all_samples = SamplingDataset(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

Start Time:07:13:31
#Module 0 in progress....


In [6]:
print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

Evaluate the original model


In [7]:
module = copy.deepcopy(model)
prune_magnitude(
    module, include_layers=["attention", "intermediate", "output"], sparsity_ratio=0.1
)

In [8]:
# result = evaluate_model(module, model_config, test_dataloader)
get_sparsity(module)

(0.09921105930365626,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.09999932183159722,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.09999932183159722,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.09999932183159722,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.09999932183159722,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.09999974568684895,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.09999974568684895,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.09999932183159722,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.09999932183159722,
 

In [9]:
prune_concern_identification(
    model,
    module,
    positive_samples,
    include_layers=["attention", "intermediate", "output"],
    sparsity_ratio=0.6,
)

In [10]:
# result = evaluate_model(module, model_config, test_dataloader)
get_sparsity(module)

(0.5945614125870973,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.5990075005425347,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.5999348958333334,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.5989583333333334,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.5989583333333334,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.5989583333333334,
  'bert.en

In [11]:
recover_tangling_identification(
    model,
    module,
    negative_samples,
    recovery_ratio=0.1,
    include_layers=["attention", "intermediate", "output"],
)

In [12]:
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", "citi_imdb_50p.pt")
get_sparsity(module)

Evaluating: 100%|██████████| 782/782 [06:36<00:00,  1.97it/s]


Loss: 0.6885
Precision: 0.8089, Recall: 0.7133, F1-Score: 0.6893
              precision    recall  f1-score   support

           0       0.64      0.99      0.78     12500
           1       0.98      0.44      0.60     12500

    accuracy                           0.71     25000
   macro avg       0.81      0.71      0.69     25000
weighted avg       0.81      0.71      0.69     25000



(0.495350353283441,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.4989590115017361,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.4989590115017361,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.4990081787109375,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.4989590115017361,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.4999351501464844,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.4989585876464844,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.4989590115017361,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.4989590115017361,
  'bert.enc