# Example 1: Loading/training
In this example I explain:
- How to configure and load the desired dataset (either 'tonic' datasets or benchmarks such as the adding task).
- How to configure and instantiate an SNN.
- How to train it.
- How to load and test saved models.

### Prerequisites and environment variables

Requirements (specific versions recommended but not mandatory):
- python 3.6
- pytorch 1.7.1
- numpy 1.19.5
- torch-summary 1.4.5
- matplotlib 3.3.4
- seaborn 0.11.1
- scikit-learn 0.24.1
- scipy 1.5.2
- h5py 2.10.0

Setting up the snn library:
- Add the parent directory of `dsnn` to the environment variable `PYTHONPATH`.
- Create a directory with your preferred name, e.g. `SNNData`, and store its path in a new environment variable `PYTHON_DRIVE_PATH`. In this directory, create three empty folders:
  - `checkpoints`: where the saved models are stored.
  - `tonic_datasets`: where the full datasets from tonic are downloaded.
  - `tonic_cache`: where the disk-cached tonic datasets are stored.
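
The folder layout above can also be created from Python. This is a minimal sketch: `setup_snn_data_dir` is a hypothetical helper, not part of the library. It creates the three folders and sets the two environment variables for the current process only, so a permanent setting still has to be made in your shell or OS settings. If the library reads `PYTHON_DRIVE_PATH` at import time (an assumption), call it before the imports below.

```python
import os
import sys
from pathlib import Path


def setup_snn_data_dir(data_root, lib_parent=None):
    """Hypothetical helper: create the data folders and set the environment
    variables for the current Python process only."""
    root = Path(data_root).expanduser().resolve()
    for sub in ("checkpoints", "tonic_datasets", "tonic_cache"):
        (root / sub).mkdir(parents=True, exist_ok=True)
    os.environ["PYTHON_DRIVE_PATH"] = str(root)

    if lib_parent is not None:
        lib_parent = str(Path(lib_parent).expanduser().resolve())
        # Changing PYTHONPATH does not affect the running interpreter,
        # so the directory is also appended to sys.path.
        os.environ["PYTHONPATH"] = os.pathsep.join(
            p for p in (lib_parent, os.environ.get("PYTHONPATH", "")) if p
        )
        if lib_parent not in sys.path:
            sys.path.append(lib_parent)
    return root


# Example (paths are placeholders):
# setup_snn_data_dir("~/SNNData", lib_parent="~/code")
```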

```python
import torch
from my_snn.rsnn import RSNN_2l
from my_snn.utils import train, training_plots, ModelLoader
from my_snn.tonic_dataloader import DatasetLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Running on: {}'.format(device))
```

```
Running on: cuda:0
```


### Tonic Dataloader: SHD

Here I use pre-configured PyTorch dataloaders from [Tonic](https://tonic.readthedocs.io/en/latest/). \
Available datasets:
- 'shd': [Spiking Heidelberg Digits](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/), 700 inputs, 20 classes
- 'ssc': [Spiking Speech Commands](https://zenkelab.org/resources/spiking-heidelberg-datasets-shd/), 700 inputs, 35 classes
- 'nmnist': [Neuromorphic-MNIST](https://www.garrickorchard.com/datasets/n-mnist), 34x34x2 inputs, 10 classes
- 'ibmgesture': [DVS Gestures](https://research.ibm.com/interactive/dvsgesture/), 128x128x2 inputs, 11 classes

Caching options can be either:
- 'memory': the data is kept in GPU memory. Training is faster, but limited by the available GPU memory; use it for small models/datasets.
- 'disk': the data is loaded from the hard drive. Slower, but allows using larger models/datasets.

Other options:

- num_workers: number of worker processes used to load data in parallel (check the tonic documentation for more details). \
Keeping it at 0 is usually safest, but depending on the machine, tuning this number can speed up data loading.
- batch_size: choose this value carefully: a large batch size (one that still fits in memory) makes each epoch faster, \
but convergence may be slower, or may not happen at all.
- time_window: sequence length, i.e. the number of time steps (bins) produced by the 'ToFrame' transform in tonic.
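
To make `time_window` concrete: an event stream is turned into a fixed number of frames, one per time bin, and each frame counts the spikes per input channel. The sketch below is a simplified NumPy stand-in for tonic's `ToFrame` with `n_time_bins`, not tonic's implementation. It assumes events are given as separate timestamp and channel arrays (SHD-style, 700 channels) and splits the recording duration into equal-length bins; `events_to_frames` is a hypothetical name.

```python
import numpy as np


def events_to_frames(t, ch, num_channels, time_window):
    """Bin spike events into a (time_window, num_channels) count matrix.

    t  : event timestamps (any unit), shape (n_events,)
    ch : channel index of each event, shape (n_events,)
    """
    t = np.asarray(t, dtype=float)
    ch = np.asarray(ch, dtype=int)
    frames = np.zeros((time_window, num_channels), dtype=np.int64)
    if t.size == 0:
        return frames
    t0, span = t.min(), t.max() - t.min()
    if span == 0:
        bins = np.zeros(t.size, dtype=int)
    else:
        # Map each timestamp to one of `time_window` equal-length bins; the
        # last event lands exactly on the upper edge and is clipped into the final bin.
        bins = np.floor((t - t0) / span * time_window).astype(int)
        bins = np.minimum(bins, time_window - 1)
    # np.add.at accumulates correctly when several events share a (bin, channel) cell.
    np.add.at(frames, (bins, ch), 1)
    return frames


rng = np.random.default_rng(0)
t = rng.uniform(0.0, 1.0, size=1000)
ch = rng.integers(0, 700, size=1000)
frames = events_to_frames(t, ch, num_channels=700, time_window=50)
print(frames.shape)  # (50, 700)
```

A larger `time_window` gives finer temporal resolution but longer sequences, and therefore more simulation steps per sample.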

What we obtain:
- data: a tuple (train_loader, test_loader) used for training and testing the model.
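
Each loader yields mini-batches, so the number of training steps per epoch follows from the dataset size and the batch size. As a hedged sanity check against the `Step [x/31]` counter in the SHD training log further down: assuming the SHD training split has 8156 samples and that the loader drops the last incomplete batch (inferred from the log, not documented here), a batch size of 256 gives 31 steps per epoch.

```python
import math


def steps_per_epoch(num_samples, batch_size, drop_last):
    """Number of batches a loader yields per epoch."""
    if drop_last:
        return num_samples // batch_size        # incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # incomplete last batch is kept


# SHD training split (assumed 8156 samples) with batch_size=256
print(steps_per_epoch(8156, 256, drop_last=True))   # 31
print(steps_per_epoch(8156, 256, drop_last=False))  # 32
```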

```python
dataset = 'shd'
time_window = 50
batch_size = 256  # lr=1e-4
DL = DatasetLoader(dataset=dataset, caching='memory', num_workers=0, batch_size=batch_size, time_window=time_window)
data = DL.get_dataloaders()
```

### Network instantiation

Instantiating a network is simple: import the desired class and instantiate it with the desired parameters as arguments. \
In this example, we create a 2-layer recurrent SNN with 64 neurons per hidden layer, trainable membrane time constants (`tau_m='adp'`), a fast-sigmoid surrogate gradient (`surr='fs'`) and a max-over-time loss function (`loss_fn='mot'`). \
A summary of the layer connectivity and layer sizes is printed. \
For more options and details, see the [second example](02_models_and_parameters.ipynb).
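
For intuition about the options named above, here is a minimal NumPy sketch of one recurrent spiking layer (the model above stacks two) with a fast-sigmoid surrogate derivative and a max-over-time readout. It is not the library's implementation: the leaky integrate-and-fire update with reset to zero, the fixed decay `alpha` (standing in for a trainable `tau_m`), the surrogate slope and all function names are assumptions made for illustration. In training, the max-over-time values would then be passed to the cross-entropy criterion shown in the model printout.

```python
import numpy as np


def fast_sigmoid_grad(v, thresh=0.3, slope=10.0):
    # Surrogate for d(spike)/dv: smooth and peaked at the threshold; it is used in
    # the backward pass instead of the (zero almost everywhere) Heaviside derivative.
    return 1.0 / (1.0 + slope * np.abs(v - thresh)) ** 2


def recurrent_lif_layer(x, w_in, w_rec, alpha=0.9, thresh=0.3):
    """x: (time_steps, batch, n_in) -> spikes: (time_steps, batch, n_hidden)."""
    T, B, _ = x.shape
    H = w_in.shape[1]
    v = np.zeros((B, H))  # membrane potentials
    s = np.zeros((B, H))  # spikes from the previous step
    out = np.zeros((T, B, H))
    for t in range(T):
        # Leak, reset neurons that spiked at the previous step to zero,
        # then add the feed-forward and recurrent input currents.
        v = alpha * v * (1.0 - s) + x[t] @ w_in + s @ w_rec
        s = (v > thresh).astype(float)  # Heaviside spike function (forward pass)
        out[t] = s
    return out


def max_over_time_logits(readout):
    """readout: (time_steps, batch, classes) -> (batch, classes)."""
    return readout.max(axis=0)


rng = np.random.default_rng(0)
T, B, n_in, n_h, n_out = 50, 4, 700, 64, 20
x = (rng.random((T, B, n_in)) < 0.05).astype(float)  # random input spikes
w_in = rng.normal(0.0, 0.1, (n_in, n_h))
w_rec = rng.normal(0.0, 0.1, (n_h, n_h))
w_out = rng.normal(0.0, 0.1, (n_h, n_out))

spikes = recurrent_lif_layer(x, w_in, w_rec)
logits = max_over_time_logits(spikes @ w_out)
print(spikes.shape, logits.shape)  # (50, 4, 64) (4, 20)
```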

```python
snn = RSNN_2l(dataset, num_hidden=64, thresh=0.3, tau_m='adp', win=time_window, surr='fs',
              loss_fn='mot', batch_size=batch_size, device=device)
snn.to(device)
```

```
RSNN_2l(
  (criterion): CrossEntropyLoss()
  (fc_ih): Linear(in_features=700, out_features=64, bias=False)
  (fc_h1h1): Linear(in_features=64, out_features=64, bias=False)
  (fc_h1h2): Linear(in_features=64, out_features=64, bias=False)
  (fc_h2h2): Linear(in_features=64, out_features=64, bias=False)
  (fc_ho): Linear(in_features=64, out_features=20, bias=False)
)
```

### Training

Arguments for training:
- **snn**: the snn object to be trained.
- **data**: the dataloader tuple (train_loader, test_loader) used to train 'snn'.
- **learning_rate**: learning rate.
- **num_epochs**: number of epochs.
- **spkreg**: spiking activity regularizer, added to the loss as avg_spikes_hidden_layer x spkreg. Default = 0.0
- **l1_reg**: L1 regularizer (still in development). Default = 0.0
- **dropout**: random dropout ratio for the input spikes, ranging from 0.0 to 1.0. Default = 0.0
- **lr_scale**: scaling of the learning rate of tau_m and tau_adp. Default = (2.0, 5.0)
- **ckpt_dir**: creates a folder in PYTHON_DRIVE_PATH/checkpoints to store the saved models. Default = 'checkpoint'
- **test_fn**: controls the testing behavior (see the Adding Task example below). Default = None -> test every 5 epochs
- **scheduler**: learning rate scheduler. Default = (1, 0.98) -> multiply learning_rate by 0.98 every epoch
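
The interaction of `learning_rate`, `lr_scale`, `scheduler` and `spkreg` can be worked out by hand. The pure-Python sketch below assumes that `lr_scale` multiplies the base learning rate for the `tau_m` and `tau_adp` parameters respectively, that `scheduler=(step, gamma)` is a step decay (the same rule as PyTorch's `StepLR(step_size=1, gamma=0.98)`), that `scheduler=False` keeps the rate constant, and that `spkreg` is added as described in the list above. The function names are hypothetical.

```python
def learning_rates(base_lr, epoch, lr_scale=(2.0, 5.0), scheduler=(1, 0.98)):
    """Learning rates for (weights, tau_m, tau_adp) during `epoch` (0-based)."""
    if scheduler:
        step_size, gamma = scheduler
        lr = base_lr * gamma ** (epoch // step_size)  # decay every `step_size` epochs
    else:  # e.g. scheduler=False, as in the Adding Task example: constant rate
        lr = base_lr
    return lr, lr * lr_scale[0], lr * lr_scale[1]


def regularized_loss(task_loss, avg_spikes_hidden, spkreg=0.0):
    # Spike-activity penalty: avg_spikes_hidden_layer x spkreg, added to the task loss.
    return task_loss + spkreg * avg_spikes_hidden


for epoch in (0, 1, 9):
    w_lr, tau_m_lr, tau_adp_lr = learning_rates(1e-3, epoch)
    print(f"epoch {epoch}: weights {w_lr:.3e}, tau_m {tau_m_lr:.3e}, tau_adp {tau_adp_lr:.3e}")
```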

`train` prints the total number of synaptic parameters and the total mult-adds, \
then, for each epoch, the training loss at regular steps and the time elapsed.
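
The two printed totals can be reproduced from the layer shapes. The sketch assumes that "mult-adds" counts one multiply-accumulate per synapse per time step (parameters x `time_window`). This is an inference rather than documented behavior, but it matches both printouts in this notebook (58368 x 50 and 16768 x 50). The helper names are hypothetical.

```python
def count_params(layers):
    """layers: list of (in_features, out_features) for bias-free Linear layers."""
    return sum(n_in * n_out for n_in, n_out in layers)


def mult_adds_millions(num_params, time_window):
    # Assumption: every synapse performs one multiply-accumulate per time step.
    return num_params * time_window / 1e6


# RSNN_2l on SHD: fc_ih, fc_h1h1, fc_h1h2, fc_h2h2, fc_ho
rsnn_2l = [(700, 64), (64, 64), (64, 64), (64, 64), (64, 20)]
p = count_params(rsnn_2l)
print(p, mult_adds_millions(p, 50))  # 58368 2.9184

# RSNN_d_d on the adding task: f0_i, r1_r1, r1_o
rsnn_d_d = [(2, 128), (128, 128), (128, 1)]
q = count_params(rsnn_d_d)
print(q, mult_adds_millions(q, 50))  # 16768 0.8384
```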

```python
ckpt_dir = 'examples'  # where the model will be saved
train(snn, data, learning_rate=1e-3, num_epochs=10, ckpt_dir=ckpt_dir)
```

```
RSNN_2l(
  (criterion): CrossEntropyLoss()
  (fc_ih): Linear(in_features=700, out_features=64, bias=False)
  (fc_h1h1): Linear(in_features=64, out_features=64, bias=False)
  (fc_h1h2): Linear(in_features=64, out_features=64, bias=False)
  (fc_h2h2): Linear(in_features=64, out_features=64, bias=False)
  (fc_ho): Linear(in_features=64, out_features=20, bias=False)
)
Total params: 58368
Total mult-adds (M): 2.9184
training shd50_RSNN_2l_64.t7 for 10 epochs...
Epoch [1/10]
Step [10/31], Loss: 30.74082
Step [20/31], Loss: 29.65269
Step [30/31], Loss: 28.84436
Time elasped: 5.268250942230225
Epoch [2/10]
Step [10/31], Loss: 27.12155
Step [20/31], Loss: 25.67185
Step [30/31], Loss: 24.83320
Time elasped: 5.218309164047241
Epoch [3/10]
Step [10/31], Loss: 23.04344
Step [20/31], Loss: 21.67767
Step [30/31], Loss: 20.14252
Time elasped: 5.272017240524292
Epoch [4/10]
Step [10/31], Loss: 18.58522
Step [20/31], Loss: 17.68420
Step [30/31], Loss: 16.84487
Time elasped: 5.190589427947998
Epoch [5/10
```

```python
snn.save_model(snn.modelname, ckpt_dir)
```

### Loading models

If all went well in the previous steps, there should be a file 'shd50_RSNN_2l_64.t7' in *PYTHON_DRIVE_PATH/checkpoints/examples*. \
You can load any saved model for future use by providing its model name, checkpoint directory, batch size and device (GPU or CPU).
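
Before calling `ModelLoader`, it can help to check that the checkpoint actually exists where it is expected. A minimal sketch, assuming the layout described above (`PYTHON_DRIVE_PATH/checkpoints/<ckpt_dir>/<modelname>`); `checkpoint_path` is a hypothetical helper, not part of the library.

```python
import os
from pathlib import Path


def checkpoint_path(modelname, ckpt_dir, drive_path=None):
    """Return the expected location of a saved model, or raise a clear error."""
    drive_path = drive_path or os.environ.get("PYTHON_DRIVE_PATH")
    if not drive_path:
        raise EnvironmentError("PYTHON_DRIVE_PATH is not set")
    path = Path(drive_path) / "checkpoints" / ckpt_dir / modelname
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found: was the model saved with ckpt_dir={ckpt_dir!r}?"
        )
    return path


# Usage (assumes the SHD model above has been saved):
# checkpoint_path('shd50_RSNN_2l_64.t7', 'examples')
```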

```python
modelname = 'shd50_RSNN_2l_64.t7'
loaded_snn = ModelLoader(modelname, ckpt_dir, batch_size, device)
```


```python
loaded_snn.test(data[1])  # evaluate on the test loader
```

### Custom dataloader: Adding Task

Here is an example of training with a 'non-tonic' dataloader. \
The test_fn function customizes the testing behavior. In this case, it tells the training function to test every 5 epochs and to display a custom 'correct' measure: the percentage of predictions within 0.04 of the target.
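
For reference, a common formulation of the adding problem (introduced in Hochreiter and Schmidhuber's LSTM paper) feeds two input channels per time step: a random value and a marker that is 1 at exactly two positions; the target is the sum of the two marked values. This matches the 2-input, 1-output model below, but whether `AddTaskDatasetLoader` generates data exactly this way, and in a (time, batch, features) layout, is an assumption. The sketch also reproduces the 'correct' criterion used in `test_fn`. Under this formulation, always predicting the mean (1.0) gives an expected squared error of 1/6, about 0.167, which is a useful baseline when reading the MSE values in the log if the loss is a plain mean squared error.

```python
import numpy as np


def adding_task_batch(batch_size, time_window, rng):
    """Return inputs (time_window, batch, 2) and targets (batch, 1)."""
    values = rng.random((time_window, batch_size))  # channel 0: uniform in [0, 1)
    markers = np.zeros((time_window, batch_size))   # channel 1: two 1s per sequence
    for b in range(batch_size):
        i, j = rng.choice(time_window, size=2, replace=False)
        markers[[i, j], b] = 1.0
    x = np.stack([values, markers], axis=-1)
    y = (values * markers).sum(axis=0, keepdims=True).T  # (batch, 1)
    return x, y


def fraction_correct(pred, ref, tol=0.04):
    # Same criterion as test_fn: a prediction counts as correct if |pred - ref| < tol.
    return float(np.mean(np.abs(pred - ref) < tol))


rng = np.random.default_rng(0)
x, y = adding_task_batch(128, 50, rng)
baseline_mse = float(np.mean((y - 1.0) ** 2))  # constant prediction of the mean
print(x.shape, y.shape)  # (50, 128, 2) (128, 1)
print(f"constant-prediction MSE on this batch: {baseline_mse:.3f}")
```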

```python
import torch
from torch.utils.data import DataLoader
from my_snn.rsnn_delays import RSNN_d_d
from my_snn.custom_dataloader import AddTaskDatasetLoader

batch_size = 128  # 128: Anil Kag
time_window = 50
d_train = AddTaskDatasetLoader(time_window, batch_size, randomness=True)
d_test = AddTaskDatasetLoader(time_window, batch_size, randomness=True)  # 2560 from Schmidhuber paper
train_loader = DataLoader(d_train, batch_size=batch_size, num_workers=0)
test_loader = DataLoader(d_test, batch_size=batch_size, num_workers=0)

data = train_loader, test_loader  # the dataloader tuple

def test_fn(snn, ckpt_dir, test_loader, max_acc, epoch):
    """Every 5 epochs, report the share of predictions within 0.04 of the target."""
    if (epoch + 1) % 5 == 0:
        correct, total = 0, 0
        # Accumulate over all test batches, not just the last one.
        for images, labels in test_loader:
            pred, ref = snn.propagate(images.to(device), labels.to(device))
            correct += torch.sum(torch.abs(pred - ref) < 0.04).item()
            total += len(images)
        print(f'Test set accuracy: {100 * correct / total}% ')
        print('--------------------------')
    return max_acc
```

```python
surr = 'fs'
n_h = 128
name = f'add2_{time_window}_rnn_{n_h}_{surr}'
ckpt_dir = 'some-tests-add'

hidden = (n_h, 1, 'r')
snn = RSNN_d_d('custom_2_1_{}'.format(batch_size), hidden=hidden, delay=(1, 1), thresh=0.3,
               reset_to_zero=False, tau_m='adp', win=time_window, surr=surr,
               loss_fn='prediction', batch_size=batch_size, device=device)
snn.debug = True
snn.to(device)

train(snn, data, 1e-3, 500, ckpt_dir=ckpt_dir, l1_reg=0.0, test_fn=test_fn, scheduler=False)
snn.save_model(name, ckpt_dir)
```

```
delays: [0]
RSNN_d_d(
  (criterion): MSELoss()
  (f0_i): Linear(in_features=2, out_features=128, bias=False)
  (r1_r1): Linear(in_features=128, out_features=128, bias=False)
  (r1_o): Linear(in_features=128, out_features=1, bias=False)
)
Total params: 16768
Total mult-adds (M): 0.8384
training custom_2_1_12850_RSNN_d_d_1_l128_1d1.t7 for 500 epochs...
Epoch [1/500]
Step [1/1], Loss: 4.42521
Time elasped: 0.17685961723327637
Epoch [2/500]
Step [1/1], Loss: 2.02188
Time elasped: 0.14940452575683594
Epoch [3/500]
Step [1/1], Loss: 0.63342
Time elasped: 0.15138602256774902
Epoch [4/500]
Step [1/1], Loss: 0.13283
Time elasped: 0.14546990394592285
Epoch [5/500]
Step [1/1], Loss: 0.38527
Time elasped: 0.15038108825683594
Test set accuracy: 4.6875% 
--------------------------
Epoch [6/500]
Step [1/1], Loss: 0.75083
Time elasped: 0.15331125259399414
Epoch [7/500]
Step [1/1], Loss: 0.75132
Time elasped: 0.15233325958251953
Epoch [8/500]
Step [1/1], Loss: 0.75702
Time elasped: 0.14745235443115234


Step [1/1], Loss: 0.16116
Time elasped: 0.14159226417541504
Epoch [94/500]
Step [1/1], Loss: 0.17590
Time elasped: 0.14647507667541504
Epoch [95/500]
Step [1/1], Loss: 0.16779
Time elasped: 0.14940428733825684
Test set accuracy: 6.25% 
--------------------------
Epoch [96/500]
Step [1/1], Loss: 0.15892
Time elasped: 0.14354681968688965
Epoch [97/500]
Step [1/1], Loss: 0.17119
Time elasped: 0.15330958366394043
Epoch [98/500]
Step [1/1], Loss: 0.16001
Time elasped: 0.15038156509399414
Epoch [99/500]
Step [1/1], Loss: 0.17167
Time elasped: 0.14842748641967773
Epoch [100/500]
Step [1/1], Loss: 0.16239
Time elasped: 0.14940524101257324
Test set accuracy: 8.59375% 
--------------------------
Epoch [101/500]
Step [1/1], Loss: 0.16496
Time elasped: 0.14452338218688965
Epoch [102/500]
Step [1/1], Loss: 0.18259
Time elasped: 0.14940404891967773
Epoch [103/500]
Step [1/1], Loss: 0.17994
Time elasped: 0.15040922164916992
Epoch [104/500]
Step [1/1], Loss: 0.16948
Time elasped: 0.14546918869018555

Step [1/1], Loss: 0.15872
Time elasped: 0.15038156509399414
Epoch [189/500]
Step [1/1], Loss: 0.15321
Time elasped: 0.14842748641967773
Epoch [190/500]
Step [1/1], Loss: 0.15147
Time elasped: 0.14647555351257324
Test set accuracy: 6.25% 
--------------------------
Epoch [191/500]
Step [1/1], Loss: 0.14155
Time elasped: 0.14842844009399414
Epoch [192/500]
Step [1/1], Loss: 0.18935
Time elasped: 0.14940428733825684
Epoch [193/500]
Step [1/1], Loss: 0.17227
Time elasped: 0.15526437759399414
Epoch [194/500]
Step [1/1], Loss: 0.14983
Time elasped: 0.15330958366394043
Epoch [195/500]
Step [1/1], Loss: 0.17077
Time elasped: 0.14745211601257324
Test set accuracy: 7.8125% 
--------------------------
Epoch [196/500]
Step [1/1], Loss: 0.14361
Time elasped: 0.14452290534973145
Epoch [197/500]
Step [1/1], Loss: 0.18353
Time elasped: 0.14940476417541504
Epoch [198/500]
Step [1/1], Loss: 0.16373
Time elasped: 0.14256834983825684
Epoch [199/500]
Step [1/1], Loss: 0.14211
Time elasped: 0.14061665534973

Step [1/1], Loss: 0.15458
Time elasped: 0.15038156509399414
Epoch [284/500]
Step [1/1], Loss: 0.18979
Time elasped: 0.14842820167541504
Epoch [285/500]
Step [1/1], Loss: 0.18733
Time elasped: 0.15721654891967773
Test set accuracy: 6.25% 
--------------------------
Epoch [286/500]
Step [1/1], Loss: 0.18560
Time elasped: 0.14452338218688965
Epoch [287/500]
Step [1/1], Loss: 0.17438
Time elasped: 0.14452195167541504
Epoch [288/500]
Step [1/1], Loss: 0.13513
Time elasped: 0.15331029891967773
Epoch [289/500]
Step [1/1], Loss: 0.17517
Time elasped: 0.15822219848632812
Epoch [290/500]
Step [1/1], Loss: 0.17141
Time elasped: 0.15328049659729004
Test set accuracy: 4.6875% 
--------------------------
Epoch [291/500]
Step [1/1], Loss: 0.17867
Time elasped: 0.14549875259399414
Epoch [292/500]
Step [1/1], Loss: 0.17566
Time elasped: 0.14549827575683594
Epoch [293/500]
Step [1/1], Loss: 0.14382
Time elasped: 0.15233969688415527
Epoch [294/500]
Step [1/1], Loss: 0.14557
Time elasped: 0.14744806289672

Step [1/1], Loss: 0.19466
Time elasped: 0.15135622024536133
Epoch [379/500]
Step [1/1], Loss: 0.18706
Time elasped: 0.15041112899780273
Epoch [380/500]
Step [1/1], Loss: 0.15353
Time elasped: 0.15132737159729004
Test set accuracy: 9.375% 
--------------------------
Epoch [381/500]
Step [1/1], Loss: 0.14231
Time elasped: 0.14940404891967773
Epoch [382/500]
Step [1/1], Loss: 0.19871
Time elasped: 0.14257001876831055
Epoch [383/500]
Step [1/1], Loss: 0.17712
Time elasped: 0.14647459983825684
Epoch [384/500]
Step [1/1], Loss: 0.15066
Time elasped: 0.14940547943115234
Epoch [385/500]
Step [1/1], Loss: 0.18765
Time elasped: 0.15331006050109863
Test set accuracy: 7.03125% 
--------------------------
Epoch [386/500]
Step [1/1], Loss: 0.17152
Time elasped: 0.1543140411376953
Epoch [387/500]
Step [1/1], Loss: 0.17663
Time elasped: 0.14449596405029297
Epoch [388/500]
Step [1/1], Loss: 0.16063
Time elasped: 0.14749479293823242
Epoch [389/500]
Step [1/1], Loss: 0.14988
Time elasped: 0.1425690650939

Step [1/1], Loss: 0.17616
Time elasped: 0.14452171325683594
Epoch [474/500]
Step [1/1], Loss: 0.15817
Time elasped: 0.15331149101257324
Epoch [475/500]
Step [1/1], Loss: 0.13163
Time elasped: 0.14842724800109863
Test set accuracy: 7.8125% 
--------------------------
Epoch [476/500]
Step [1/1], Loss: 0.15822
Time elasped: 0.14940881729125977
Epoch [477/500]
Step [1/1], Loss: 0.15221
Time elasped: 0.1562356948852539
Epoch [478/500]
Step [1/1], Loss: 0.16708
Time elasped: 0.14647555351257324
Epoch [479/500]
Step [1/1], Loss: 0.15140
Time elasped: 0.15135788917541504
Epoch [480/500]
Step [1/1], Loss: 0.14869
Time elasped: 0.15135693550109863
Test set accuracy: 7.8125% 
--------------------------
Epoch [481/500]
Step [1/1], Loss: 0.17362
Time elasped: 0.14745211601257324
Epoch [482/500]
Step [1/1], Loss: 0.17927
Time elasped: 0.14647436141967773
Epoch [483/500]
Step [1/1], Loss: 0.14294
Time elasped: 0.14354562759399414
Epoch [484/500]
Step [1/1], Loss: 0.15877
Time elasped: 0.1562402248382
```