In [1]:
from configuration import load_args

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from utils import FormulateArgs, MakeFolder, SetSeed
from utils.Loader import GetDataLoader
from utils.logger import GetMessageLogger
import utils.training as training

import PrintedSpikingNN_lP_New as pSNN
from surrogate.RSNN import SpikeSynth

import snntorch as snn

import pprint
import os
import time

import torch

In [2]:
# Define all our arguments for training

args = load_args(overrides={
    "DATASET": 0,
    "SEED": 0,
    "projectname": "pLR-SNN",
    "DEVICE": "cpu",
    "PROGRESSIVE": True,
    "EPOCH": 100,
    "TIMELIMITATION": 0.1,
    "LR_MIN": 5e-2,
    "LR": 0.1,
})

args = FormulateArgs(args)

In [3]:
# Dataset Definition
train_loader, datainfo = GetDataLoader(args, 'train')
valid_loader, datainfo = GetDataLoader(args, 'valid')
test_loader, datainfo = GetDataLoader(args, 'test')
pprint.pprint(datainfo)

{'N_class': 2,
 'N_feature': 6,
 'N_test': 25,
 'N_time': 100,
 'N_train': 70,
 'N_valid': 23,
 'dataname': 'acuteinflammation'}


In [4]:
psnn = pSNN.LightningPrintedSpikingNetwork(
    topology=[datainfo['N_feature']] + args.hidden + [datainfo['N_class']], 
    args=args, 
    model_class=SpikeSynth, 
    ckpt_path="surrogate/checkpoints-srnn/spike_model-epoch=91-val_loss=0.07.ckpt",
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    surrogate_gradient=snn.surrogate.atan()
)

In [5]:
name = "Surrogate_SRNN_wTempSkip"

# Create a CometLogger instance
wandb_logger = WandbLogger(
                          log_model=True,
                          project="Spike-Synth-Full",
                          name=name,
                          )

# log gradients and model topology
wandb_logger.watch(psnn)
wandb_logger.experiment.log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))

[34m[1mwandb[0m: Currently logged in as: [33mlupos[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


<Artifact source-Spike-Synth-Full-train_pRSNN.ipynb>

In [6]:
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints-pLR-SNN/",
    filename=name + str("pLRSNN-{epoch:02d}-{val_loss:.2f}"),
    save_top_k=1,  # save only the best model
    monitor="val_loss",  # metric to monitor
    mode="min"
)

In [7]:
trainer = Trainer(
    # fast_dev_run=True,
    max_epochs=args.EPOCH,
    logger=wandb_logger,  
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
)

# Train the model
trainer.fit(psnn)

wandb_logger.experiment.finish()

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type                        | Params | Mode 
------------------------------------------------------------------
0 | network   | PrintedSpikingNeuralNetwork | 2.3 M  | train
1 | loss_fn   | LFLoss                      | 0      | train
2 | evaluator | Evaluator                   | 0      | train
------------------------------------------------------------------
64        Trainable params
2.3 M     Non-trainable params
2.3 M     Total params
9.299     Total estimated model params size (MB)
17        Modules in train mode
90        Modules in eval mode


Sanity Checking: |                                                   | 0/? [00:00<?, ?it/s]

/nix/store/mxb5f60mz822vg50ll0pz7063spw4bnr-python3-3.12.11-env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/nix/store/mxb5f60mz822vg50ll0pz7063spw4bnr-python3-3.12.11-env/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=19` in the `DataLoader` to improve performance.
/nix/store/mxb5f60mz822vg50ll0pz7063spw4bnr-python3-3.12.11-env/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (2) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to 

Training: |                                                          | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

Validation: |                                                        | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


0,1
epoch,▁▁▁▁▁▁▁▁▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇███
lr,█████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,▆▇▇▇▇██▇▇█▇▆▆▄▄▄▃▃▃▃▄▂▃▄▂▂▂▁▁▂▂▂▁▂▂▂▂▂▂▂
train_power,▅▁▁▁▁▂▁▁▁▁█▂▂▃▂▂▂▃▂▂▃▄▃▃▃▃▃▃▄▄▃▃▃▃▅▅▅▃▃▄
trainer/global_step,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
val_acc,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val_loss,▂▅▇▇███▇▇▇▇▇▄▃▃▃▃▃▄▄▄▂▃▃▂▂▃▃▂▁▃▁▁▃▂▁▂▃▂▁
val_power,▁▁▁▁▃▁▁▁▁▁▂▆▁▂▂▁▂▃▂▂▅▂▂▂▄▃▂▂█▄▂▂▃▃▃▂▂▂▃▃

0,1
epoch,99.0
lr,0.05
train_acc,0.47143
train_loss,0.68596
train_power,8e-05
trainer/global_step,199.0
val_acc,0.3913
val_loss,0.69277
val_power,9e-05
