In [1]:
import os
import sys

# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Put the directory four levels above the working directory on the import path;
# presumably this is where the project-local `utils` package lives — TODO confirm.
sys.path.append("../../../../")

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude
)

In [3]:
# --- Experiment configuration -------------------------------------------------
# Dataset/model key handed to ModelConfig(...) and load_data(...).
name = "IMDB"
device = torch.device("cuda:0")
# Placeholder only: the model-loading cell rebinds `checkpoint` from load_model(...).
checkpoint = None

# --- Data loading and sampling --------------------------------------------------
batch_size = 32
num_workers = 48
num_samples = 16
seed = 44

# --- Magnitude pruning ----------------------------------------------------------
# Forwarded to prune_magnitude(...) as `sparsity_ratio`.
magnitude_ratio = 0.3
include_layers = ["attention", "intermediate", "output"]
exclude_layers = None

In [4]:
# Record the wall-clock start of the run and echo it with one-second resolution.
script_start_time = datetime.now()
print(f"Script started at: {script_start_time:%Y-%m-%d %H:%M:%S}")

Script started at: 2024-08-19 22:53:34


In [5]:
# Build the model configuration for the dataset key `name` on the chosen `device`.
model_config = ModelConfig(name, device)
# Label count is read from the configuration dictionary under "num_labels".
num_labels = model_config.config["num_labels"]
# load_model returns three values; `checkpoint` here rebinds the name that the
# configuration cell initialised to None.
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'textattack/bert-base-uncased-imdb', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'IMDB', 'num_labels': 2, 'cache_dir': 'Models'}




The model textattack/bert-base-uncased-imdb is loaded.




In [6]:
# Unpack the three objects returned by load_data into the train / validation /
# test loader names used by the later cells.
(train_dataloader, valid_dataloader, test_dataloader) = load_data(
    name,
    batch_size=batch_size,
    num_workers=num_workers,
    do_cache=True,
    seed=seed,
)

{'dataset_name': 'IMDB', 'path': 'imdb', 'config_name': 'plain_text', 'text_column': 'text', 'label_column': 'label', 'cache_dir': 'Datasets/IMDB', 'task_type': 'classification'}




Downloading the Dataset IMDB




Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Tokenizing dataset:   0%|          | 0/25000 [00:00<?, ?it/s]

Tokenizing dataset:   0%|          | 43/25000 [00:00<00:58, 429.58it/s]

Tokenizing dataset:   0%|          | 112/25000 [00:00<00:43, 576.46it/s]

Tokenizing dataset:   1%|          | 190/25000 [00:00<00:37, 668.68it/s]

Tokenizing dataset:   1%|          | 263/25000 [00:00<00:35, 691.23it/s]

Tokenizing dataset:   1%|?둞         | 347/25000 [00:00<00:33, 742.94it/s]

Tokenizing dataset:   2%|?둞         | 423/25000 [00:00<00:32, 745.38it/s]

Tokenizing dataset:   2%|?둞         | 498/25000 [00:00<00:33, 731.22it/s]

Tokenizing dataset:   2%|?둞         | 572/25000 [00:00<00:33, 718.50it/s]

Tokenizing dataset:   3%|?둝         | 647/25000 [00:00<00:33, 726.46it/s]

Tokenizing dataset:   3%|?둝         | 720/25000 [00:01<00:35, 685.66it/s]

Tokenizing dataset:   3%|?둝         | 790/25000 [00:01<00:36, 656.54it/s]

Tokenizing dataset:   3%|?둝         | 865/25000 [00:01<00:35, 681.18it/s]

Tokenizing dataset:   4%|?둜         | 939/25000 [00:01<00:34, 696.63it/s]

Tokenizing dataset:   4%|?둜         | 1016/25000 [00:01<00:33, 715.32it/s]

Tokenizing dataset:   4%|?둜         | 1094/25000 [00:01<00:32, 732.31it/s]

Tokenizing dataset:   5%|?둜         | 1173/25000 [00:01<00:31, 748.93it/s]

Tokenizing dataset:   5%|?둜         | 1249/25000 [00:01<00:32, 724.12it/s]

Tokenizing dataset:   5%|?둛         | 1327/25000 [00:01<00:32, 738.52it/s]

Tokenizing dataset:   6%|?둛         | 1402/25000 [00:01<00:32, 727.99it/s]

Tokenizing dataset:   6%|?둛         | 1477/25000 [00:02<00:32, 731.14it/s]

Tokenizing dataset:   6%|?둛         | 1551/25000 [00:02<00:32, 717.89it/s]

Tokenizing dataset:   6%|?둚         | 1623/25000 [00:02<00:32, 711.36it/s]

Tokenizing dataset:   7%|?둚         | 1710/25000 [00:02<00:30, 756.55it/s]

Tokenizing dataset:   7%|?둚         | 1787/25000 [00:02<00:30, 758.40it/s]

Tokenizing dataset:   7%|?둚         | 1865/25000 [00:02<00:30, 761.16it/s]

Tokenizing dataset:   8%|?둙         | 1942/25000 [00:02<00:31, 728.14it/s]

Tokenizing dataset:   8%|?둙         | 2017/25000 [00:02<00:31, 734.15it/s]

Tokenizing dataset:   8%|?둙         | 2091/25000 [00:02<00:33, 685.41it/s]

Tokenizing dataset:   9%|?둙         | 2161/25000 [00:03<00:33, 679.37it/s]

Tokenizing dataset:   9%|?둘         | 2230/25000 [00:03<00:33, 677.43it/s]

Tokenizing dataset:   9%|?둘         | 2332/25000 [00:03<00:29, 775.25it/s]

Tokenizing dataset:  10%|?둘         | 2458/25000 [00:03<00:24, 907.52it/s]

Tokenizing dataset:  10%|?둗         | 2570/25000 [00:03<00:23, 968.14it/s]

Tokenizing dataset:  11%|?둗         | 2697/25000 [00:03<00:21, 1055.61it/s]

Tokenizing dataset:  11%|?둗         | 2807/25000 [00:03<00:20, 1063.55it/s]

Tokenizing dataset:  12%|?둗?둞        | 2926/25000 [00:03<00:20, 1099.07it/s]

Tokenizing dataset:  12%|?둗?둞        | 3037/25000 [00:03<00:21, 1040.87it/s]

Tokenizing dataset:  13%|?둗?둝        | 3142/25000 [00:03<00:21, 1020.22it/s]

Tokenizing dataset:  13%|?둗?둝        | 3245/25000 [00:04<00:21, 1011.21it/s]

Tokenizing dataset:  13%|?둗?둝        | 3347/25000 [00:04<00:22, 971.33it/s] 

Tokenizing dataset:  14%|?둗?둜        | 3447/25000 [00:04<00:22, 977.83it/s]

Tokenizing dataset:  14%|?둗?둜        | 3546/25000 [00:04<00:22, 969.15it/s]

Tokenizing dataset:  15%|?둗?둜        | 3644/25000 [00:04<00:22, 962.44it/s]

Tokenizing dataset:  15%|?둗?둜        | 3741/25000 [00:04<00:22, 934.70it/s]

Tokenizing dataset:  15%|?둗?둛        | 3836/25000 [00:04<00:22, 938.82it/s]

Tokenizing dataset:  16%|?둗?둛        | 3931/25000 [00:04<00:22, 939.00it/s]

Tokenizing dataset:  16%|?둗?둛        | 4029/25000 [00:04<00:22, 949.51it/s]

Tokenizing dataset:  16%|?둗?둚        | 4125/25000 [00:05<00:22, 938.68it/s]

Tokenizing dataset:  17%|?둗?둚        | 4225/25000 [00:05<00:21, 955.99it/s]

Tokenizing dataset:  17%|?둗?둚        | 4326/25000 [00:05<00:21, 971.58it/s]

Tokenizing dataset:  18%|?둗?둙        | 4424/25000 [00:05<00:21, 972.18it/s]

Tokenizing dataset:  18%|?둗?둙        | 4522/25000 [00:05<00:27, 742.22it/s]

Tokenizing dataset:  18%|?둗?둙        | 4605/25000 [00:05<00:27, 749.43it/s]

Tokenizing dataset:  19%|?둗?둘        | 4714/25000 [00:05<00:24, 836.15it/s]

Tokenizing dataset:  19%|?둗?둘        | 4826/25000 [00:05<00:22, 908.85it/s]

Tokenizing dataset:  20%|?둗?둘        | 4927/25000 [00:05<00:21, 936.65it/s]

Tokenizing dataset:  20%|?둗?둗        | 5049/25000 [00:06<00:19, 1015.76it/s]

Tokenizing dataset:  21%|?둗?둗        | 5176/25000 [00:06<00:18, 1087.61it/s]

Tokenizing dataset:  21%|?둗?둗        | 5288/25000 [00:06<00:17, 1096.00it/s]

Tokenizing dataset:  22%|?둗?둗?둞       | 5404/25000 [00:06<00:17, 1112.91it/s]

Tokenizing dataset:  22%|?둗?둗?둞       | 5517/25000 [00:06<00:18, 1073.58it/s]

Tokenizing dataset:  23%|?둗?둗?둝       | 5626/25000 [00:06<00:18, 1054.84it/s]

Tokenizing dataset:  23%|?둗?둗?둝       | 5733/25000 [00:06<00:18, 1043.19it/s]

Tokenizing dataset:  23%|?둗?둗?둝       | 5846/25000 [00:06<00:17, 1067.90it/s]

Tokenizing dataset:  24%|?둗?둗?둜       | 5954/25000 [00:06<00:17, 1062.93it/s]

Tokenizing dataset:  24%|?둗?둗?둜       | 6070/25000 [00:06<00:17, 1088.27it/s]

Tokenizing dataset:  25%|?둗?둗?둜       | 6180/25000 [00:07<00:18, 1012.03it/s]

Tokenizing dataset:  25%|?둗?둗?둛       | 6283/25000 [00:07<00:19, 947.39it/s] 

Tokenizing dataset:  26%|?둗?둗?둛       | 6380/25000 [00:07<00:20, 923.42it/s]

Tokenizing dataset:  26%|?둗?둗?둛       | 6474/25000 [00:07<00:22, 828.04it/s]

Tokenizing dataset:  26%|?둗?둗?둛       | 6559/25000 [00:07<00:23, 799.87it/s]

Tokenizing dataset:  27%|?둗?둗?둚       | 6641/25000 [00:07<00:24, 753.27it/s]

Tokenizing dataset:  27%|?둗?둗?둚       | 6718/25000 [00:07<00:25, 726.48it/s]

Tokenizing dataset:  27%|?둗?둗?둚       | 6795/25000 [00:07<00:24, 735.90it/s]

Tokenizing dataset:  27%|?둗?둗?둚       | 6870/25000 [00:08<00:24, 734.87it/s]

Tokenizing dataset:  28%|?둗?둗?둙       | 6944/25000 [00:08<00:25, 721.52it/s]

Tokenizing dataset:  28%|?둗?둗?둙       | 7017/25000 [00:08<00:25, 694.05it/s]

Tokenizing dataset:  28%|?둗?둗?둙       | 7094/25000 [00:08<00:25, 714.31it/s]

Tokenizing dataset:  29%|?둗?둗?둙       | 7172/25000 [00:08<00:24, 730.99it/s]

Tokenizing dataset:  29%|?둗?둗?둘       | 7246/25000 [00:08<00:25, 700.69it/s]

Tokenizing dataset:  29%|?둗?둗?둘       | 7323/25000 [00:08<00:24, 719.67it/s]

Tokenizing dataset:  30%|?둗?둗?둘       | 7410/25000 [00:08<00:23, 758.62it/s]

Tokenizing dataset:  30%|?둗?둗?둗       | 7500/25000 [00:08<00:21, 798.84it/s]

Tokenizing dataset:  30%|?둗?둗?둗       | 7582/25000 [00:08<00:21, 804.20it/s]

Tokenizing dataset:  31%|?둗?둗?둗       | 7676/25000 [00:09<00:20, 837.94it/s]

Tokenizing dataset:  31%|?둗?둗?둗       | 7765/25000 [00:09<00:20, 851.70it/s]

Tokenizing dataset:  31%|?둗?둗?둗?둞      | 7854/25000 [00:09<00:19, 861.65it/s]

Tokenizing dataset:  32%|?둗?둗?둗?둞      | 7941/25000 [00:09<00:20, 830.71it/s]

Tokenizing dataset:  32%|?둗?둗?둗?둞      | 8032/25000 [00:09<00:19, 852.44it/s]

Tokenizing dataset:  32%|?둗?둗?둗?둝      | 8125/25000 [00:09<00:19, 874.20it/s]

Tokenizing dataset:  33%|?둗?둗?둗?둝      | 8219/25000 [00:09<00:18, 891.78it/s]

Tokenizing dataset:  33%|?둗?둗?둗?둝      | 8309/25000 [00:09<00:18, 889.52it/s]

Tokenizing dataset:  34%|?둗?둗?둗?둝      | 8399/25000 [00:09<00:19, 862.48it/s]

Tokenizing dataset:  34%|?둗?둗?둗?둜      | 8489/25000 [00:10<00:18, 871.73it/s]

Tokenizing dataset:  34%|?둗?둗?둗?둜      | 8591/25000 [00:10<00:17, 912.76it/s]

Tokenizing dataset:  35%|?둗?둗?둗?둜      | 8687/25000 [00:10<00:17, 925.85it/s]

Tokenizing dataset:  35%|?둗?둗?둗?둛      | 8780/25000 [00:10<00:17, 919.30it/s]

Tokenizing dataset:  35%|?둗?둗?둗?둛      | 8873/25000 [00:10<00:18, 877.77it/s]

In [None]:
# NOTE(review): the second positional argument is the literal 200 here, while the
# per-concern loop later in this notebook passes a label index from
# range(num_labels) in the same position — confirm 200 is intended.
# NOTE(review): `all_samples` is not referenced by any later cell visible here.
all_samples = SamplingDataset(
    train_dataloader,
    200,
    num_samples,
    num_labels,
    False,
    4,
    device=device,
    resample=False,
    seed=seed,
)

In [None]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [None]:
# Evaluate the original model
# Evaluating: 100%|█████████████████████████████████████████████████████████████████████| 782/782 [05:36<00:00,  2.32it/s]
# Loss: 0.3423
# Precision: 0.9306, Recall: 0.9303, F1-Score: 0.9303
#               precision    recall  f1-score   support

#            0       0.92      0.94      0.93     12500
#            1       0.94      0.92      0.93     12500

#     accuracy                           0.93     25000
#    macro avg       0.93      0.93      0.93     25000
# weighted avg       0.93      0.93      0.93     25000

In [None]:
# Work on a deep copy so the original `model` stays available for the later
# model-vs-module comparison.
module = copy.deepcopy(model)
# The return value of prune_magnitude is not used, so the pruning is presumably
# applied to `module` in place — TODO confirm against utils.prune_utils.prune.
prune_magnitude(module, sparsity_ratio=magnitude_ratio, include_layers=include_layers, exclude_layers=exclude_layers)
print("Evaluate the pruned model")
# Evaluate the pruned copy. Passing `model` here would score the unpruned
# network and contradict the message printed above.
result = evaluate_model(module, model_config, test_dataloader)
# save_module(module, "Modules/", f"magnitude_{name}_{magnitude_ratio}p.pt")

In [None]:
# Iterate over every label index ("concern") and compare the original model with
# the pruned copy on the validation loader.
for concern in range(num_labels):
    print(f"--{concern}--")
    # Two samplers per concern that differ only in the fifth positional argument;
    # the generator is consumed left to right, so True is built first, then False.
    positive_samples, negative_samples = (
        SamplingDataset(
            train_dataloader,
            concern,
            num_samples,
            num_labels,
            flag,
            4,
            device=device,
            resample=False,
            seed=seed,
        )
        for flag in (True, False)
    )
    # NOTE(review): neither sampler is passed to similar() in this cell.
    similar(
        model,
        module,
        valid_dataloader,
        concern,
        num_samples,
        num_labels,
        device=device,
        seed=seed,
    )

In [None]:
# Inspect the sparsity of `module`, the deep copy that was passed to prune_magnitude.
# As the cell's last expression, any value returned is shown as the cell output.
get_sparsity(module)