In [1]:
import os
import sys
sys.path.append("../../../../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.save_module import save_module
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity, similar
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import (
    prune_magnitude
)

In [3]:
name= "OSDG"
device = torch.device("cuda:0")
checkpoint = None
batch_size=16
num_workers=4
num_samples=16
magnitude_ratio=0.3
seed=44
include_layers=["attention", "intermediate", "output"]
exclude_layers=None

In [4]:
script_start_time = datetime.now()
print(f"Script started at: {script_start_time.strftime('%Y-%m-%d %H:%M:%S')}")

Script started at: 2024-08-21 19:16:40


In [5]:
model_config = ModelConfig(name, device)
num_labels = model_config.config["num_labels"]
model, tokenizer, checkpoint = load_model(model_config)

Loading the model.




{'model_name': 'sadickam/sdg-classification-bert', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'OSDG', 'num_labels': 16, 'cache_dir': 'Models'}




The model sadickam/sdg-classification-bert is loaded.




In [6]:
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=batch_size, num_workers=num_workers, do_cache=True, seed=seed
)

{'dataset_name': 'OSDG', 'path': 'albertmartinez/OSDG', 'config_name': '2024-01-01', 'text_column': 'text', 'label_column': 'labels', 'cache_dir': 'Datasets/OSDG', 'task_type': 'classification'}




Loading cached dataset OSDG.




The dataset OSDG is loaded




In [7]:
all_samples = SamplingDataset(
    train_dataloader, 200, num_samples, num_labels, False, 4, device=device, resample=False, seed=seed
)

In [8]:
# print("Evaluate the original model")
# result = evaluate_model(model, model_config, test_dataloader)

In [9]:
module = copy.deepcopy(model)
prune_magnitude(module, sparsity_ratio=magnitude_ratio, include_layers=include_layers, exclude_layers=exclude_layers)
print("Evaluate the pruned model")
result = evaluate_model(model, model_config, test_dataloader)
# save_module(module, "Modules/", f"magnitude_{name}_{magnitude_ratio}p.pt")

Evaluate the pruned model




Evaluating:   0%|                                                                              | 0/800 [00:13<…

  attn_output = torch.nn.functional.scaled_dot_product_attention(


Loss: 0.9478




Precision: 0.7801, Recall: 0.7867, F1-Score: 0.7793




              precision    recall  f1-score   support

           0       0.77      0.66      0.71       797
           1       0.84      0.72      0.78       775
           2       0.88      0.87      0.88       795
           3       0.87      0.83      0.85      1110
           4       0.86      0.80      0.83      1260
           5       0.88      0.69      0.77       882
           6       0.85      0.80      0.83       940
           7       0.49      0.61      0.54       473
           8       0.66      0.85      0.74       746
           9       0.62      0.73      0.67       689
          10       0.75      0.79      0.77       670
          11       0.62      0.81      0.70       312
          12       0.73      0.81      0.77       665
          13       0.83      0.85      0.84       314
          14       0.85      0.78      0.81       756
          15       0.97      0.98      0.97      1607

    accuracy                           0.80     12791
   macro avg       0.78   




In [10]:
for concern in range(num_labels):
    print(f"--{concern}--")
    valid = copy.deepcopy(valid_dataloader)
    similar(model, module, valid, concern, num_samples, num_labels, device=device, seed=seed)

--0--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8425177702550711, 0.8425177702550711)




CCA coefficients mean non-concern: (0.8437628650424857, 0.8437628650424857)




Linear CKA concern: 0.974047125844879




Linear CKA non-concern: 0.9695997285310273




Kernel CKA concern: 0.9691173875936028




Kernel CKA non-concern: 0.9709354245714679




--1--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8449788607682597, 0.8449788607682597)




CCA coefficients mean non-concern: (0.8404295799943744, 0.8404295799943744)




Linear CKA concern: 0.9700331745913581




Linear CKA non-concern: 0.971111945390472




Kernel CKA concern: 0.9657701265693169




Kernel CKA non-concern: 0.9717987504187973




--2--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8368836462414982, 0.8368836462414982)




CCA coefficients mean non-concern: (0.8434139447989739, 0.8434139447989739)




Linear CKA concern: 0.9803780812611076




Linear CKA non-concern: 0.9714448112696421




Kernel CKA concern: 0.9734826144705763




Kernel CKA non-concern: 0.971727141487986




--3--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8388353549045644, 0.8388353549045644)




CCA coefficients mean non-concern: (0.842471827220395, 0.842471827220395)




Linear CKA concern: 0.9692589211022596




Linear CKA non-concern: 0.9718415105458642




Kernel CKA concern: 0.961496224564721




Kernel CKA non-concern: 0.9720107468473534




--4--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.851221034683349, 0.851221034683349)




CCA coefficients mean non-concern: (0.8419762044149111, 0.8419762044149111)




Linear CKA concern: 0.9797643899263282




Linear CKA non-concern: 0.9716559477201554




Kernel CKA concern: 0.9740635115502648




Kernel CKA non-concern: 0.9720533694857885




--5--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8390722884335842, 0.8390722884335842)




CCA coefficients mean non-concern: (0.8428627482975307, 0.8428627482975307)




Linear CKA concern: 0.973325285405836




Linear CKA non-concern: 0.9726175334512481




Kernel CKA concern: 0.9658387349701665




Kernel CKA non-concern: 0.9730810618433041




--6--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8373804306633096, 0.8373804306633096)




CCA coefficients mean non-concern: (0.8438143688712759, 0.8438143688712759)




Linear CKA concern: 0.9731672286510377




Linear CKA non-concern: 0.9727729306606749




Kernel CKA concern: 0.9693847621364344




Kernel CKA non-concern: 0.9730976912226128




--7--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8406055416335882, 0.8406055416335882)




CCA coefficients mean non-concern: (0.8421511655320231, 0.8421511655320231)




Linear CKA concern: 0.9706564176909634




Linear CKA non-concern: 0.9716570253746581




Kernel CKA concern: 0.9640915375682345




Kernel CKA non-concern: 0.9724239537868159




--8--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8313582733109359, 0.8313582733109359)




CCA coefficients mean non-concern: (0.8416036449440765, 0.8416036449440765)




Linear CKA concern: 0.9732075530740117




Linear CKA non-concern: 0.9711597174853525




Kernel CKA concern: 0.9672119493003865




Kernel CKA non-concern: 0.9713493840345343




--9--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8438123653876867, 0.8438123653876867)




CCA coefficients mean non-concern: (0.8440440549614383, 0.8440440549614383)




Linear CKA concern: 0.9746912579909466




Linear CKA non-concern: 0.9717290111024878




Kernel CKA concern: 0.967105674738507




Kernel CKA non-concern: 0.9725359788079635




--10--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8457183416751295, 0.8457183416751295)




CCA coefficients mean non-concern: (0.8425750405025146, 0.8425750405025146)




Linear CKA concern: 0.9731416553496997




Linear CKA non-concern: 0.9713018413698682




Kernel CKA concern: 0.9683555997213098




Kernel CKA non-concern: 0.9720997175845797




--11--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8390322798661889, 0.8390322798661889)




CCA coefficients mean non-concern: (0.8425083301119427, 0.8425083301119427)




Linear CKA concern: 0.9759281510101658




Linear CKA non-concern: 0.9714888516846416




Kernel CKA concern: 0.9683649579135907




Kernel CKA non-concern: 0.9717966319927194




--12--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8347351812102105, 0.8347351812102105)




CCA coefficients mean non-concern: (0.8424964038638293, 0.8424964038638293)




Linear CKA concern: 0.9728171508637563




Linear CKA non-concern: 0.9714043326108389




Kernel CKA concern: 0.9683001298036712




Kernel CKA non-concern: 0.9716960864627124




--13--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8369243578738818, 0.8369243578738818)




CCA coefficients mean non-concern: (0.8420296887916204, 0.8420296887916204)




Linear CKA concern: 0.9753658705052135




Linear CKA non-concern: 0.9716954363833203




Kernel CKA concern: 0.9654475537643427




Kernel CKA non-concern: 0.9719369404446215




--14--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8448082645603373, 0.8448082645603373)




CCA coefficients mean non-concern: (0.8416571344181075, 0.8416571344181075)




Linear CKA concern: 0.9747433403365108




Linear CKA non-concern: 0.9707662980798664




Kernel CKA concern: 0.9711202760851374




Kernel CKA non-concern: 0.9710360700606017




--15--




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




adding eps to diagonal and taking inverse




taking square root




dot products...




trying to take final svd




computed everything!




CCA coefficients mean concern: (0.8429579871765327, 0.8429579871765327)




CCA coefficients mean non-concern: (0.842164316185835, 0.842164316185835)




Linear CKA concern: 0.9627922693401134




Linear CKA non-concern: 0.9726151297753707




Kernel CKA concern: 0.9558722063956742




Kernel CKA non-concern: 0.9729758519191727




In [11]:
get_sparsity(module)

(0.29759659416128587,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.2999996609157986,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.2999996609157986,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.2999996609157986,
  'bert.e