In [3]:
import copy
import os
import sys
sys.path.append(os.path.dirname(os.getcwd()))

In [4]:
import utils
import torch

In [5]:
from utils.model_utils.load_model import load_model

In [6]:
from utils.model_utils.evaluate import evaluate_model

In [7]:
from utils.helper import ModelConfig
from utils.decompose_utils.sampling import sampling_class

In [8]:
name = "OSDG"
device = torch.device("cuda:0")

In [9]:
from utils.dataset_utils.load_dataset import load_data

In [10]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]

In [9]:
model, tokenizer, _ = load_model(model_config)

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [10]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
        name, batch_size=32
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded


In [11]:
all_samples = sampling_class(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

In [12]:
import torch.nn as nn

In [13]:
from utils.prune_utils.prune import prune_wanda, prune_magnitude

In [14]:
from utils.model_utils.evaluate import get_sparsity

In [15]:
prune_magnitude(model, sparsity_ratio=0.4)

In [16]:
evaluate_model(model, model_config, test_dataloader)

Evaluating: 100%|██████████| 400/400 [03:48<00:00,  1.75it/s]


Loss: 0.8905
Precision: 0.7751, Recall: 0.7771, F1-Score: 0.7719
              precision    recall  f1-score   support

           0       0.76      0.64      0.70       797
           1       0.84      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.87      0.81      0.84      1110
           4       0.83      0.81      0.82      1260
           5       0.89      0.69      0.78       882
           6       0.86      0.78      0.81       940
           7       0.48      0.56      0.51       473
           8       0.65      0.85      0.73       746
           9       0.56      0.75      0.64       689
          10       0.75      0.78      0.76       670
          11       0.69      0.78      0.73       312
          12       0.68      0.81      0.74       665
          13       0.83      0.86      0.84       314
          14       0.85      0.78      0.82       756
          15       0.98      0.96      0.97      1607

    accuracy   

{'loss': 0.8905173048004508,
 'precision': 0.775142192944216,
 'recall': 0.777113407064294,
 'f1_score': 0.7719472762360807,
 'report': '              precision    recall  f1-score   support\n\n           0       0.76      0.64      0.70       797\n           1       0.84      0.71      0.77       775\n           2       0.88      0.88      0.88       795\n           3       0.87      0.81      0.84      1110\n           4       0.83      0.81      0.82      1260\n           5       0.89      0.69      0.78       882\n           6       0.86      0.78      0.81       940\n           7       0.48      0.56      0.51       473\n           8       0.65      0.85      0.73       746\n           9       0.56      0.75      0.64       689\n          10       0.75      0.78      0.76       670\n          11       0.69      0.78      0.73       312\n          12       0.68      0.81      0.74       665\n          13       0.83      0.86      0.84       314\n          14       0.85      0.78   

In [17]:
get_sparsity(model)

(0.3124789943871541,
 {'bert.embeddings.word_embeddings.weight': 0.0,
  'bert.embeddings.position_embeddings.weight': 0.0,
  'bert.embeddings.token_type_embeddings.weight': 0.0,
  'bert.embeddings.LayerNorm.weight': 0.0,
  'bert.embeddings.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.attention.self.query.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.weight': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.40000025431315106,
  'bert.encod

In [13]:
def prune_magnitude(model, sparsity_ratio=0.4):
    layers = find_layers(model)
    for name, layer in layers.items():
        if any(keyword in name for keyword in ['query', 'key', 'value']):
            continue  # 'query', 'key', 'value'가 이름에 포함된 레이어는 제외
        current_weight = layer.weight.data
        threshold = torch.sort(torch.abs(current_weight).flatten())[0][
            int(current_weight.numel() * sparsity_ratio)
        ]
        mask = torch.abs(current_weight) > threshold
        layer.weight.data.mul_(mask)

In [11]:
model, tokenizer, _ = load_model(model_config)

Loading the model.
{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}
The model sadickam/sdg-classification-bert is loaded.


In [12]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
        name, batch_size=32
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}
Loading cached dataset OSDG.
The dataset OSDG is loaded


In [15]:
from utils.prune_utils.prune import find_layers

In [16]:
prune_magnitude(model, sparsity_ratio=0.4)

In [18]:
from utils.model_utils.evaluate import get_sparsity

In [19]:
get_sparsity(model)

(0.23490910195488826,
 {'bert.embeddings.word_embeddings.weight': 0.0,
  'bert.embeddings.position_embeddings.weight': 0.0,
  'bert.embeddings.token_type_embeddings.weight': 0.0,
  'bert.embeddings.LayerNorm.weight': 0.0,
  'bert.embeddings.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.attention.self.query.weight': 0.0,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.0,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.0,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.weight': 0.0,
  'bert.encoder.layer.0.attention.output.LayerNorm.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.40000025431315106,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  

In [20]:
evaluate_model(model, model_config, test_dataloader)

Evaluating: 100%|██████████| 400/400 [04:01<00:00,  1.66it/s]


Loss: 0.8949
Precision: 0.7783, Recall: 0.7829, F1-Score: 0.7763
              precision    recall  f1-score   support

           0       0.77      0.65      0.70       797
           1       0.84      0.71      0.77       775
           2       0.88      0.88      0.88       795
           3       0.88      0.82      0.85      1110
           4       0.86      0.81      0.83      1260
           5       0.89      0.69      0.78       882
           6       0.86      0.79      0.82       940
           7       0.51      0.56      0.54       473
           8       0.64      0.86      0.73       746
           9       0.58      0.74      0.65       689
          10       0.74      0.79      0.76       670
          11       0.64      0.81      0.72       312
          12       0.71      0.80      0.75       665
          13       0.84      0.86      0.85       314
          14       0.85      0.78      0.82       756
          15       0.98      0.97      0.97      1607

    accuracy   

{'loss': 0.894946059230715,
 'precision': 0.7782835544536545,
 'recall': 0.7828768379788639,
 'f1_score': 0.7762836755597392,
 'report': '              precision    recall  f1-score   support\n\n           0       0.77      0.65      0.70       797\n           1       0.84      0.71      0.77       775\n           2       0.88      0.88      0.88       795\n           3       0.88      0.82      0.85      1110\n           4       0.86      0.81      0.83      1260\n           5       0.89      0.69      0.78       882\n           6       0.86      0.79      0.82       940\n           7       0.51      0.56      0.54       473\n           8       0.64      0.86      0.73       746\n           9       0.58      0.74      0.65       689\n          10       0.74      0.79      0.76       670\n          11       0.64      0.81      0.72       312\n          12       0.71      0.80      0.75       665\n          13       0.84      0.86      0.85       314\n          14       0.85      0.78  