# Bayesian optimization trials analysis

In [1]:
import sys
sys.path.append("..")
from src.data.datamodules import SpikingDataModule
from src.training.utils import get_model, load_config
from src.training.tasks import ClassificationTask
from torch.nn import CrossEntropyLoss
import pytorch_lightning as pl
import torch


  from .autonotebook import tqdm as notebook_tqdm


### Load datasets

In [2]:
# One datamodule per dataset, both reading from the same data root with a fixed
# seed (presumably controls splitting/shuffling — confirm in SpikingDataModule).
mnist_data_module = SpikingDataModule(dataset="mnist", data_dir="../data/", seed=42)
cifar10_data_module = SpikingDataModule(dataset="cifar10", data_dir="../data/", seed=42)

# Prepare both datamodules (MNIST first, then CIFAR-10) before any model is tested.
for _data_module in (mnist_data_module, cifar10_data_module):
    _data_module.setup()

### Load selected "best" models from Pareto plots

In [3]:
# Hyperparameter-search configs used to rebuild each architecture.
(
    config_spik_mnist,
    config_srnn_mnist,
    config_spik_cifar10,
    config_srnn_cifar10,
) = (
    "../config/optimize_spikformer_mnist.yaml",
    "../config/optimize_srnn_mnist.yaml",
    "../config/optimize_spikformer_cifar10.yaml",
    "../config/optimize_srnn_cifar10.yaml",
)

# Checkpoints of the runs picked from the Pareto plots, keyed "<model>_<dataset>".
model_weights = {
    # Spikformer runs
    "spikformer_mnist": "../experiments/thesis_hyperparameter_search_spikformer_mnist/dymo78zm/checkpoints/epoch=9-step=3760.ckpt",
    "spikformer_cifar10": "../experiments/thesis_hyperparameter_search_spikformer_cifar10/298swbe8/checkpoints/epoch=9-step=3130.ckpt",
    # SRNN runs
    "srnn_mnist": "../experiments/thesis_hyperparameter_search_srnn_mnist/zfc1310z/checkpoints/epoch=4-step=1880.ckpt",
    "srnn_cifar10": "../experiments/thesis_hyperparameter_search_srnn_cifar10/8a1y4lr1/checkpoints/epoch=4-step=1565.ckpt",
}


def get_trained_model(model_name, dataset):
    """Rebuild a model from its config and restore the selected checkpoint.

    Args:
        model_name: "spikformer" or "srnn".
        dataset: "mnist" or "cifar10".

    Returns:
        A ClassificationTask wrapping the model, with the checkpoint's
        ``state_dict`` loaded.

    Raises:
        KeyError: if ``f"{model_name}_{dataset}"`` has no checkpoint in
            ``model_weights``.
    """
    key = f"{model_name}_{dataset}"
    weights_path = model_weights[key]
    # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only
    # machines too; the Trainer moves the task to the accelerator when testing.
    # NOTE(review): on newer torch versions torch.load defaults to
    # weights_only=True, which may reject full Lightning checkpoints — confirm.
    state_dict = torch.load(weights_path, map_location="cpu")["state_dict"]

    # Config lookup mirrors the model_weights keys, one entry per (model, dataset).
    config = {
        "spikformer_mnist": config_spik_mnist,
        "srnn_mnist": config_srnn_mnist,
        "spikformer_cifar10": config_spik_cifar10,
        "srnn_cifar10": config_srnn_cifar10,
    }[key]

    model = get_model(load_config(config))
    backend = "spikingjelly" if model_name == "spikformer" else "pytorch"
    # The third positional argument (0) is passed through unchanged; presumably
    # a training-only setting such as a learning rate — confirm in ClassificationTask.
    task = ClassificationTask(model, CrossEntropyLoss(), 0, backend=backend)
    task.load_state_dict(state_dict)
    return task



# Restore the four selected models; the call order is unchanged, so any log
# lines printed while the models are built appear in the same sequence.
# MNIST models
srnn_mnist = get_trained_model(model_name="srnn", dataset="mnist")
spikformer_mnist = get_trained_model(model_name="spikformer", dataset="mnist")
# CIFAR-10 models
srnn_cifar10 = get_trained_model(model_name="srnn", dataset="cifar10")
spikformer_cifar10 = get_trained_model(model_name="spikformer", dataset="cifar10")

Using Recurrent Classifier: Recurrent LIF Neurons.
Using Recurrent Classifier: Recurrent LIF Neurons.


c:\Users\dzahariev\Desktop\Thesis\Thesis\myenv\Lib\site-packages\pytorch_lightning\utilities\parsing.py:209: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
c:\Users\dzahariev\Desktop\Thesis\Thesis\myenv\Lib\site-packages\pytorch_lightning\utilities\parsing.py:209: Attribute 'loss_fn' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss_fn'])`.


### Compute test scores

In [None]:
def pretty_print_scores(scores):
    """Print test accuracy (as a percentage) and spike density for one model.

    Args:
        scores: Either the list returned by ``pl.Trainer.test`` (one metrics
            dict per test dataloader; only the first is used) or a single
            metrics dict. It must contain ``"test_accuracy_epoch"`` (a fraction,
            e.g. 0.98) and ``"test_spike_density_epoch"``.

    Raises:
        ValueError: if ``scores`` is an empty list/tuple.
        KeyError: if one of the expected metric keys is missing.
    """
    metrics = scores
    # Trainer.test returns a list of dicts; unwrap the single dataloader's result.
    if isinstance(metrics, (list, tuple)):
        if not metrics:
            raise ValueError("scores is empty; expected at least one metrics dict")
        metrics = metrics[0]
    # Look metrics up by name: dict_keys is not subscriptable, and the key order
    # Lightning produces does not put accuracy first (test_loss_epoch comes first).
    accuracy = metrics["test_accuracy_epoch"]
    spike_density = metrics["test_spike_density_epoch"]
    # Logged accuracy is a fraction, so scale it to match the "%" suffix.
    print(f"Accuracy: {accuracy * 100:.2f}%")
    # Four decimals keep nearby densities (e.g. 0.0762 vs 0.0821) distinguishable.
    print(f"Spike Density: {spike_density:.4f}")

In [5]:
# max_epochs has no effect on trainer.test; it is kept in case later cells reuse this trainer.
trainer = pl.Trainer(max_epochs=10)

# Evaluate each model on the test split of the dataset it was trained on.
# Each call returns a list holding one metrics dict per test dataloader.
srnn_cifar10_scores = trainer.test(srnn_cifar10, cifar10_data_module)
spikformer_cifar10_scores = trainer.test(spikformer_cifar10, cifar10_data_module)
srnn_mnist_scores = trainer.test(srnn_mnist, mnist_data_module)
spikformer_mnist_scores = trainer.test(spikformer_mnist, mnist_data_module)

# Summarize every model the same way, with a separator after each one
# (previously only the first model was followed by a separator).
for label, scores in (
    ("SRNN CIFAR10", srnn_cifar10_scores),
    ("Spikformer CIFAR10", spikformer_cifar10_scores),
    ("SRNN MNIST", srnn_mnist_scores),
    ("Spikformer MNIST", spikformer_mnist_scores),
):
    print(label)
    pretty_print_scores(scores)
    print("-" * 100)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 2500/2500 [02:28<00:00, 16.83it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                   DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_accuracy_epoch            0.5148000121116638
test_activation_sparsity_epoch      0.5188103318214417
  test_binary_sparsity_epoch        0.7840859293937683
       test_loss_epoch              1.3383750915527344
   test_spike_density_epoch         0.2173953354358673
 test_temporal_sparsity_epoch       0.863655686378479
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 2500/2500 [03:27<00:00, 12.07it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                   DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_accuracy_epoch            0.7457000017166138
test_activation_sparsity_epoch      0.5902919173240662
  test_binary_sparsity_epoch        0.833260178565979
       test_loss_epoch              0.7273970246315002
   test_spike_density_epoch        0.16685990989208221
 test_temporal_sparsity_epoch       0.8336960077285767
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 2500/2500 [01:14<00:00, 33.58it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                   DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_accuracy_epoch            0.9610999822616577
test_activation_sparsity_epoch      0.6892492771148682
  test_binary_sparsity_epoch        0.9320600628852844
       test_loss_epoch             0.17872175574302673
   test_spike_density_epoch        0.07624271512031555
 test_temporal_sparsity_epoch       0.2717552185058594
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing DataLoader 0: 100%|██████████| 2500/2500 [02:59<00:00, 13.94it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         Test metric                   DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_accuracy_epoch            0.982200026512146
test_activation_sparsity_epoch      0.7166489958763123
  test_binary_sparsity_epoch        0.918038547039032
       test_loss_epoch             0.060833368450403214
   test_spike_density_epoch         0.0821024551987648
 test_temporal_sparsity_epoch       0.409807413816452
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss_epoch': 0.060833368450403214,
  'test_accuracy_epoch': 0.982200026512146,
  'test_activation_sparsity_epoch': 0.7166489958763123,
  'test_binary_sparsity_epoch': 0.918038547039032,
  'test_temporal_sparsity_epoch': 0.409807413816452,
  'test_spike_density_epoch': 0.0821024551987648}]