In [1]:
import copy
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

In [2]:
import utils
import torch

In [3]:
from utils.model_utils.load_model import load_model

In [4]:
from utils.model_utils.evaluate import evaluate_model

In [5]:
from utils.helper import ModelConfig
from utils.decompose_utils.sampling import sampling_class

In [6]:
name = "OSDG"
device = torch.device("cuda:0")

In [7]:
from utils.dataset_utils.load_dataset import load_data

In [8]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

In [9]:
model, tokenizer, _ = load_model(model_config)

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [10]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
        name, batch_size=32
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded


In [11]:
all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

In [12]:
import torch.nn as nn

In [13]:
from utils.prune_utils.prune import prune_wanda, prune_magnitude

In [14]:
from utils.model_utils.evaluate import get_sparsity

In [15]:
model, tokenizer, _ = load_model(model_config)

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [16]:
prune_wanda(model, all_samples, model_config)

In [17]:
get_sparsity(model)

(0.31227521254392365,
 {'bert.embeddings.word_embeddings.weight': 0.0,
  'bert.embeddings.position_embeddings.weight': 0.0,
  'bert.embeddings.token_type_embeddings.weight': 0.0,
  'bert.embeddings.LayerNorm.weight': 0.0,
  'bert.embeddings.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.attention.self.query.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.3997395833333333,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.weight': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.3997395833333333,
  'bert.encod

In [18]:
evaluate_model(model,model_config, test_dataloader)

Evaluating: 100%|██████████| 400/400 [04:54<00:00,  1.36it/s]


Loss: 0.8860
Precision: 0.7770, Recall: 0.7790, F1-Score: 0.7743
              precision    recall  f1-score   support

           0       0.73      0.67      0.70       797
           1       0.84      0.71      0.77       775
           2       0.87      0.88      0.87       795
           3       0.87      0.83      0.85      1110
           4       0.85      0.80      0.82      1260
           5       0.90      0.68      0.77       882
           6       0.85      0.79      0.82       940
           7       0.49      0.58      0.53       473
           8       0.65      0.85      0.74       746
           9       0.58      0.73      0.65       689
          10       0.77      0.77      0.77       670
          11       0.68      0.79      0.73       312
          12       0.69      0.81      0.75       665
          13       0.85      0.84      0.85       314
          14       0.85      0.77      0.81       756
          15       0.98      0.96      0.97      1607

    accuracy   

{'loss': 0.8860318766534329,
 'precision': 0.7770481058741912,
 'recall': 0.7790442994499596,
 'f1_score': 0.7742676774035022,
 'report': '              precision    recall  f1-score   support\n\n           0       0.73      0.67      0.70       797\n           1       0.84      0.71      0.77       775\n           2       0.87      0.88      0.87       795\n           3       0.87      0.83      0.85      1110\n           4       0.85      0.80      0.82      1260\n           5       0.90      0.68      0.77       882\n           6       0.85      0.79      0.82       940\n           7       0.49      0.58      0.53       473\n           8       0.65      0.85      0.74       746\n           9       0.58      0.73      0.65       689\n          10       0.77      0.77      0.77       670\n          11       0.68      0.79      0.73       312\n          12       0.69      0.81      0.75       665\n          13       0.85      0.84      0.85       314\n          14       0.85      0.77 