In [1]:
import os
import sys

# Extend the import search path so the project-local `utils` package imported
# in the next cell can be resolved (presumably "../../../../" is the repository
# root relative to this notebook -- confirm if the notebook is moved).
# The membership check keeps this cell idempotent: re-running it no longer
# appends a duplicate entry to sys.path on every execution.
_PROJECT_ROOT = "../../../../"
if _PROJECT_ROOT not in sys.path:
    sys.path.append(_PROJECT_ROOT)

# Turn off parallelism in the Hugging Face `tokenizers` library. Presumably
# this avoids its fork warning when DataLoader worker processes are spawned
# (num_workers > 0) -- confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.save_module import save_module
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude,
    prune_concern_identification,
    recover_tangling_identification,
)

In [3]:
# --- Experiment configuration ------------------------------------------------
# Every tunable value of the notebook lives in this cell.

# Key handed to both ModelConfig(...) and load_data(...).
name = "IMDB"

# Prefer the first GPU, but fall back to the CPU so the notebook can still be
# executed (slowly) on a machine without CUDA instead of failing outright.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder only: rebound by the tuple returned from load_model(...) below.
checkpoint = None

# DataLoader settings forwarded to load_data(...).
batch_size = 32
num_workers = 48

# Sampling settings forwarded to SamplingDataset(...); `concern` is presumably
# the class label treated as the target concern -- TODO confirm.
num_samples = 16
concern = 0

# Ratios for the three pruning stages, in the order they are applied:
magnitude_ratio = 0.1  # sparsity_ratio for prune_magnitude
ci_ratio = 0.4         # sparsity_ratio for prune_concern_identification
ti_ratio = 0.1         # recovery_ratio for recover_tangling_identification

# Layer groups the pruning/recovery steps operate on.
include_layers = ["attention", "intermediate", "output"]

In [4]:
# Build the configuration object for the selected model/dataset name and the
# target device.
model_config = ModelConfig(name, device)
# Number of output classes, read from the configuration dictionary
# (2 for IMDB according to the configuration printed below).
num_labels = model_config.config["num_labels"]
# load_model returns three values; the third one rebinds the `checkpoint`
# placeholder that the configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}
The model textattack/bert-base-uncased-imdb is loaded.


In [5]:
# Fetch the train / validation / test loaders for the configured dataset,
# forwarding the batch size and worker count from the configuration cell.
loader_options = {"batch_size": batch_size, "num_workers": num_workers}
train_dataloader, valid_dataloader, test_dataloader = load_data(name, **loader_options)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}
Loading cached dataset IMDB.
The dataset IMDB is loaded


In [6]:
def _draw_samples(concern_id, flag):
    """Build a SamplingDataset over the training loader.

    Only the concern id and the boolean flag differ between the calls below;
    the sample count, label count, the literal 4 and the device are shared.
    The meaning of the boolean and of the literal 4 is defined by
    SamplingDataset and is not visible from this notebook.
    """
    return SamplingDataset(
        train_dataloader, concern_id, num_samples, num_labels, flag, 4, device=device
    )


# Samples for the target concern (flag=True) and its counterpart (flag=False).
positive_samples = _draw_samples(concern, True)
negative_samples = _draw_samples(concern, False)

# NOTE(review): `all_samples` is not referenced by any later cell shown here,
# and 200 is not a valid label index for a 2-class task -- presumably a
# sentinel meaning "no specific concern"; confirm or remove.
all_samples = _draw_samples(200, False)

In [7]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [8]:
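# Recorded output from an earlier run of the baseline evaluation cell above
# (original, unpruned model), kept for reference:
# Evaluate the original model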
# Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 782/782 [05:36<00:00,  2.32it/s]
# Loss: 0.3423
# Precision: 0.9306, Recall: 0.9303, F1-Score: 0.9303
#               precision    recall  f1-score   support

#            0       0.92      0.94      0.93     12500
#            1       0.94      0.92      0.93     12500

#     accuracy                           0.93     25000
#    macro avg       0.93      0.93      0.93     25000
# weighted avg       0.93      0.93      0.93     25000

In [9]:
# Prune a deep copy so that `model` remains available as a separate object;
# later cells pass both `model` and `module` to the pruning/similarity helpers.
module = copy.deepcopy(model)

# Stage 1: magnitude pruning of the configured layer groups.
magnitude_options = {
    "include_layers": include_layers,
    "sparsity_ratio": magnitude_ratio,
}
prune_magnitude(module, **magnitude_options)

In [10]:
# result = evaluate_model(module, model_config, test_dataloader)
# get_sparsity(module)
# similar(model, module, positive_samples, negative_samples, include_layers=["attention", "intermediate", "output", "pooler", "classifier"], device=device)

In [11]:
# Stage 2: concern identification. The helper receives the original `model`,
# the already magnitude-pruned `module`, and both sample sets; its return
# value is not captured, so the effect is on `module` itself. In the recorded
# run the overall sparsity reported afterwards is ~0.40, matching `ci_ratio`.
prune_concern_identification(
    model,
    module,
    model_config,
    positive_samples,
    negative_samples,
    include_layers=include_layers,
    sparsity_ratio=ci_ratio,
)

In [12]:
# Evaluate `module` on the test loader after magnitude pruning and concern
# identification.
print("Evaluate the pruned model")
result = evaluate_model(module, model_config, test_dataloader)
# get_sparsity returns a tuple (overall sparsity, per-parameter sparsity dict);
# as the last expression of the cell it is displayed in full.
get_sparsity(module)

Evaluate the pruned model


Evaluating: 100%|██████████| 782/782 [05:57<00:00,  2.19it/s]


Loss: 0.3089
Precision: 0.9286, Recall: 0.9281, F1-Score: 0.9281
              precision    recall  f1-score   support

           0       0.91      0.95      0.93     12500
           1       0.94      0.91      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000



(0.39661782603449397,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.4000481499565972,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.3997395833333333,
  'bert.e

In [13]:
# Print per-layer cosine similarity for concern and non-concern samples
# (presumably comparing `model` against the pruned `module` -- confirm in
# utils.model_utils.evaluate). The pooler and classifier are inspected here in
# addition to the layer groups that were pruned.
similarity_layers = ["attention", "intermediate", "output", "pooler", "classifier"]
similar(
    model,
    module,
    positive_samples,
    negative_samples,
    include_layers=similarity_layers,
    device=device,
)

Cosine similarity for concern bert.encoder.layer.0.attention.self.query: 0.666
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.query: 0.755
Cosine similarity for concern bert.encoder.layer.0.attention.self.key: 0.616
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.key: 0.728
Cosine similarity for concern bert.encoder.layer.0.attention.self.value: 0.331
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.value: 0.506
Cosine similarity for concern bert.encoder.layer.0.attention.output.dense: 0.399
Cosine similarity for non_concern bert.encoder.layer.0.attention.output.dense: 0.483
Cosine similarity for concern bert.encoder.layer.0.intermediate.dense: 0.866
Cosine similarity for non_concern bert.encoder.layer.0.intermediate.dense: 0.902
Cosine similarity for concern bert.encoder.layer.0.output.dense: 0.373
Cosine similarity for non_concern bert.encoder.layer.0.output.dense: 0.547
Cosine similarity for concern bert.encoder.layer

In [14]:
# Stage 3: tangling identification. Only the negative samples are passed here,
# and `ti_ratio` is used as the recovery ratio. The return value is not
# captured, so the effect is on `module` itself. In the recorded run the
# overall sparsity falls from ~0.40 to ~0.30 after this step.
recover_tangling_identification(
    model,
    module,
    model_config,
    negative_samples,
    recovery_ratio=ti_ratio,
    include_layers=include_layers,
)

In [15]:
# Re-evaluate `module` after tangling recovery, for comparison with the
# metrics printed after the concern-identification step. This rebinds `result`.
result = evaluate_model(module, model_config, test_dataloader)

Evaluating: 100%|██████████| 782/782 [06:12<00:00,  2.10it/s]


Loss: 0.3349
Precision: 0.9288, Recall: 0.9280, F1-Score: 0.9279
              precision    recall  f1-score   support

           0       0.91      0.95      0.93     12500
           1       0.95      0.91      0.93     12500

    accuracy                           0.93     25000
   macro avg       0.93      0.93      0.93     25000
weighted avg       0.93      0.93      0.93     25000



In [16]:
# Repeat the per-layer cosine-similarity report now that tangling recovery has
# been applied to `module`, using the same layer selection as the earlier
# similarity cell.
similarity_layers = ["attention", "intermediate", "output", "pooler", "classifier"]
similar(
    model,
    module,
    positive_samples,
    negative_samples,
    include_layers=similarity_layers,
    device=device,
)

Cosine similarity for concern bert.encoder.layer.0.attention.self.query: 0.683
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.query: 0.709
Cosine similarity for concern bert.encoder.layer.0.attention.self.key: 0.635
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.key: 0.665
Cosine similarity for concern bert.encoder.layer.0.attention.self.value: 0.363
Cosine similarity for non_concern bert.encoder.layer.0.attention.self.value: 0.411
Cosine similarity for concern bert.encoder.layer.0.attention.output.dense: 0.427
Cosine similarity for non_concern bert.encoder.layer.0.attention.output.dense: 0.451
Cosine similarity for concern bert.encoder.layer.0.intermediate.dense: 0.873
Cosine similarity for non_concern bert.encoder.layer.0.intermediate.dense: 0.881
Cosine similarity for concern bert.encoder.layer.0.output.dense: 0.404
Cosine similarity for non_concern bert.encoder.layer.0.output.dense: 0.450
Cosine similarity for concern bert.encoder.layer

In [17]:
# Display (overall sparsity, per-parameter sparsity dict) after recovery; in
# the recorded run the overall value is ~0.30, i.e. roughly ci_ratio - ti_ratio.
get_sparsity(module)

(0.2974067667308377,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.2997402615017361,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.2997402615017361,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.300048828125,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.2997402615017361,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.2997398376464844,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.2997398376464844,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.2997402615017361,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.2997402615017361,
  'bert.encode

In [18]:
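# NOTE: `ci_ratio - ti_ratio` evaluates to 0.30000000000000004 in floating
# point, so the ratio is formatted with `:.1f` to produce "citi_IMDB_0.3p.pt"
# instead of embedding the rounding noise in the file name.
# save_module(module, "Modules/", f"citi_{name}_{ci_ratio - ti_ratio:.1f}p.pt")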

In [19]:
        # importance_score = torch.abs(current_weight) * torch.abs(2 * a - b)