## 1. Imports

In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import matplotlib.pyplot as plt

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
import pytorch_lightning as L

from pytorch_lightning import Trainer
import wandb

import os

from utils.RSNN import SpikeSynth

## 2. Dataset Definition

In [2]:
# Our data in in the shape: trainings samples(28k) * number of time steps (100 + 6) * time dimension(1)
# The time steps is voltage over time
data = torch.load(f'./data/dataset.ds')

print(data.keys())
print(data['X_train'].shape)
print(data['Y_train'].shape)

dict_keys(['X_train', 'Y_train', 'X_valid', 'Y_valid', 'X_test', 'Y_test'])
torch.Size([28797, 106, 1])
torch.Size([28797, 100])


In [3]:
# Extract tensors
X_train, Y_train = data['X_train'], data['Y_train']
X_valid, Y_valid = data['X_valid'], data['Y_valid']
X_test, Y_test = data['X_test'], data['Y_test']

train_dataset = TensorDataset(X_train, Y_train)
valid_dataset = TensorDataset(X_valid, Y_valid)
test_dataset  = TensorDataset(X_test, Y_test)

## 3. Model Definition

In [4]:
max_epochs=10
experiment_name="test" # Changes wandb expeirment name
project_name = "Spike-Synth-Surrogate" # Changes where wandb project
logging_directory = ".temp"
checkpoint_path = "models/SRNN" # where the final checkpoint is saved

model = SpikeSynth(
        optimizer_class=torch.optim.AdamW,
        beta=0.9,
        lr=0.005,
        num_hidden=256,
        batch_size=2048,
        gamma=0.9,
        num_hidden_layers=4,
        train_dataset=train_dataset,
        valid_dataset=valid_dataset,
        max_epochs=max_epochs,
        surrogate_gradient=snn.surrogate.atan(),
        temporal_skip=None,
        layer_skip=2
    )

script_dir = os.getcwd() 
logging_directory = os.path.join(script_dir, logging_directory)
logging_directory = os.path.abspath(logging_directory)
os.makedirs(logging_directory, exist_ok=True)
os.environ["WANDB_DIR"] = logging_directory

In [5]:
# Create a CometLogger instance
wandb_logger = WandbLogger(
                          log_model=True,
                          project=project_name,
                          name=experiment_name,
                          save_dir=logging_directory
                          )

# log gradients and model topology
wandb_logger.watch(model)
wandb_logger.experiment.log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))

[34m[1mwandb[0m: Currently logged in as: [33mlupos[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


<Artifact source-Spike-Synth-Surrogate-surrogate_2_create_rsnn_surrogate.ipynb>

## 4. Training

In [6]:
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_path,
    filename=experiment_name + str("spike_model-{epoch:02d}-{val_loss:.2f}"),
    save_top_k=1,  # save only the best model
    monitor="val_loss",  # metric to monitor
    mode="min"
)

trainer = Trainer(
    max_epochs=max_epochs,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    logger=wandb_logger, 
    callbacks=[checkpoint_callback]
)

model = torch.compile(model)

# Train the model
trainer.fit(model)

wandb_logger.finalize("sucess")

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/nix/store/mxb5f60mz822vg50ll0pz7063spw4bnr-python3-3.12.11-env/lib/python3.12/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:751: Checkpoint directory /home/monkeman/SpikeSynth/surrogate/models/SRNN exists and is not empty.

  | Name              | Type          | Params | Mode 
------------------------------------------------------------
0 | norm              | LayerNorm     | 14     | train
1 | lif_layers        | ModuleList    | 462 K  | train
2 | residual_alphas   | ParameterList | 4      | train
3 | residual_projs    | ModuleList    | 2.0 K  | train
4 | layer_skip_alphas | ParameterList | 2      | train
5 | layer_skip_projs  | ModuleList    | 131 K  | train
6 | output_layer      | Linear        | 257    | train
------------------------------------------------------------
334 K     Trainable params
262 K     Non-trainable params
596 K     Total params
2.386

Sanity Checking: |             | 0/? [00:00<?, ?it/s]

If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
  torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
/nix/store/mxb5f60mz822vg50ll0pz7063spw4bnr-python3-3.12.11-env/lib/python3.12/site-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (15) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                    | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
