In [1]:
from configuration import load_args

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger

from utils import FormulateArgs, MakeFolder, SetSeed
from utils.Loader import GetDataLoader
from utils.logger import GetMessageLogger
import utils.training as training

import PrintedSpikingNN_lP_New as pSNN
from surrogate.RSNN import SpikeSynth

import pprint
import os
import time

import torch

In [2]:
# Define all arguments for training. The overrides below are the values chosen
# for this experiment; load_args presumably supplies the remaining defaults.
args = load_args(overrides={
    "DATASET": 0,
    "SEED": 0,
    "projectname": "pLR-SNN",
    "DEVICE": "cpu",
    "PROGRESSIVE": True,
    "EPOCH": 100,
    "TIMELIMITATION": 0.1,
    "LR_MIN": 5e-2,
    "LR": 0.1,
})

# SEED is configured above but nothing in this notebook applies it (SetSeed is
# imported but never called), so seed Python/NumPy/torch RNGs explicitly before
# any data loading or model initialisation happens.
seed_everything(args.SEED)

args = FormulateArgs(args)

In [3]:
# Dataset definition: build one DataLoader per split, in train -> valid -> test order.
# GetDataLoader also returns the dataset metadata on every call; the copy kept in
# `datainfo` is the one from the final ('test') call.
loaders = {}
for split in ('train', 'valid', 'test'):
    loaders[split], datainfo = GetDataLoader(args, split)

train_loader = loaders['train']
valid_loader = loaders['valid']
test_loader = loaders['test']

pprint.pprint(datainfo)

{'N_class': 2,
 'N_feature': 6,
 'N_test': 25,
 'N_time': 100,
 'N_train': 70,
 'N_valid': 23,
 'dataname': 'acuteinflammation'}


In [4]:
# Build the printed spiking network. The topology list is the input feature
# count, then the hidden sizes from args.hidden, then the class count
# (presumably the layer widths — confirm against LightningPrintedSpikingNetwork).
# NOTE(review): the saved run of this cell failed with
#   RuntimeError: Error(s) in loading state_dict for SpikeSynth:
#   Missing key(s) in state_dict: "norm.weight", "norm.bias".
# The surrogate checkpoint below appears not to match the current SpikeSynth
# definition — TODO retrain the surrogate or point ckpt_path at a compatible
# checkpoint; the training cells below cannot run until this succeeds.
layer_sizes = [datainfo['N_feature']] + args.hidden + [datainfo['N_class']]
surrogate_ckpt = "surrogate/checkpoints-srnn/spike_model-epoch=92-val_loss=0.07.ckpt"

psnn = pSNN.LightningPrintedSpikingNetwork(
    args=args,
    topology=layer_sizes,
    model_class=SpikeSynth,
    ckpt_path=surrogate_ckpt,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
)

RuntimeError: Error(s) in loading state_dict for SpikeSynth:
	Missing key(s) in state_dict: "norm.weight", "norm.bias". 

In [None]:
# Create a Weights & Biases logger for this training run.
# log_model=True lets Lightning upload the checkpoints saved during training to W&B.
# NOTE(review): project/name look copied from the surrogate-training notebook
# (this experiment's args.projectname is "pLR-SNN") — confirm this is the
# intended W&B destination for pLR-SNN runs.
wandb_logger = WandbLogger(
                          log_model=True,
                          project="Spike-Synth-Full",
                          name="Surrogate_SRNN_ReduceLROnPlateau",
                          )

# Log gradients and the model graph of the printed SNN.
wandb_logger.watch(psnn)
# Snapshot every .py / .ipynb file under the current directory into the W&B run.
wandb_logger.experiment.log_code(".", include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))

In [None]:
# Keep only the single best checkpoint, ranked by the lowest validation loss.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    dirpath="checkpoints-pLR-SNN/",
    filename="pLRSNN-{epoch:02d}-{val_loss:.2f}",
)

In [None]:
# Train the printed SNN. checkpoint_callback must be passed via `callbacks`;
# otherwise the ModelCheckpoint defined above is never used and the best
# val_loss model is not written to checkpoints-pLR-SNN/.
# NOTE(review): args.DEVICE is "cpu", but the accelerator below switches to GPU
# whenever CUDA is available — confirm the model does not create tensors on
# args.DEVICE internally, or the two devices will disagree.
trainer = Trainer(
    fast_dev_run=False,  # set True for a quick one-batch smoke test of the pipeline
    max_epochs=args.EPOCH,
    logger=wandb_logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
)

try:
    trainer.fit(psnn)
finally:
    # Close the W&B run even if training raises or is interrupted, so the run
    # is finalised instead of being left open in the kernel.
    wandb_logger.experiment.finish()