# Tutorial on MNN training
## Quick start: four steps to run your first MNN model

The following steps show how to train an MNN on the MNIST image classification task using a multi-layer perceptron (MLP) architecture.

1. Clone the repository to your local drive.
2. Copy the demo files **./example/mnist/mnist.py** and **./example/mnist/mnist_config.yaml** to the root directory of the repository.
3. Create two directories, **./checkpoint/** (for saving trained model results) and **./data/** (for downloading the MNIST dataset).
4. Run `mnist.py`, passing the config file through the `--config` option:

   ```
   python mnist.py --config=./mnist_config.yaml
   ```
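Steps 2 and 3 can also be scripted. The sketch below uses a hypothetical helper, `prepare_workspace`, which is not part of the repository. It relies only on the Python standard library and assumes the demo files sit in `example/mnist/` under the repository root, as written in step 2.

```python
import shutil
from pathlib import Path


def prepare_workspace(repo_root="."):
    """Perform steps 2 and 3 of the quick start inside `repo_root`.

    Hypothetical helper. Assumes the demo files live in
    <repo_root>/example/mnist/, as in step 2.
    Returns the paths of the files copied into the repository root.
    """
    root = Path(repo_root)
    demo_dir = root / "example" / "mnist"
    copied = []
    # Step 2: copy the demo script and its config file to the repository root.
    for name in ("mnist.py", "mnist_config.yaml"):
        src = demo_dir / name
        if not src.is_file():
            raise FileNotFoundError(f"demo file not found: {src}")
        shutil.copy(src, root / name)
        copied.append(root / name)
    # Step 3: create the output folders; exist_ok makes reruns harmless.
    for folder in ("checkpoint", "data"):
        (root / folder).mkdir(exist_ok=True)
    return copied
```

Calling `prepare_workspace('.')` from the repository root copies both files and creates both folders. Running it again is harmless: `mkdir(exist_ok=True)` skips existing folders and `shutil.copy` overwrites the earlier copies.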

After training finishes, you should find four files and one directory in the **./checkpoint/mnist/** folder:

- Two '.pth' files that contain the trained model parameters.
- One '.yaml' file, a copy of the config file used to train the model.
- One '.txt' log file that records the standard output during training (such as model performance).
- One directory called `mnn_net_snn_result` that stores the simulation results of the SNN reconstructed from the trained MNN (if enabled).

## A step-by-step explanation of how the MNN is trained

Here we illustrate how the code above works.
Before we start, we need to load the required packages.
Since MNN is still at an early stage and is not yet published on PyPI, you need to copy this notebook to the root directory of the repository (moment-neural-network) so that the `mnn` package can be imported.

```python
import torch
from mnn import snn, models, utils
from mnn.utils.training_tools import general_train, general_prepare
```

When you call the script with

```
python mnist.py --config=./mnist_config.yaml
```

it first loads all the hyperparameters needed for training, which is equivalent to running the following code:

```python
class TempArgs:
    def __init__(self):
        self.bs = 50  # batch size
        self.print_freq = 20  # print frequency
        self.dir = 'mnist'  # directory name to save the model
        self.save_name = 'mnn_mnist'  # name of the model to save
        self.use_cuda = True  # whether to use cuda
        self.seed = None  # random seed
        self.resume = False  # whether to resume training from a checkpoint
        self.distributed = False  # whether to use distributed training
        self.evaluate = False  # whether to evaluate the model only
        self.start_epoch = 0  # starting epoch
        self.local_rank = 0  # local rank for distributed training

args = TempArgs()
args.config = './examples/mnist/mnist_config.yaml'  # path to the config file
args = general_prepare.set_config2args(args)
for key, value in args.__dict__.items():
    print(f'{key}: {value}')
```

```
bs: 50
print_freq: 20
dir: mnist
save_name: mnn_mnist
use_cuda: True
seed: None
resume: False
distributed: False
evaluate: False
start_epoch: 0
local_rank: 0
config: ./examples/mnist/mnist_config.yaml
LR_SCHEDULER: None
OPTIMIZER: {'name': 'AdamW', 'args': {'lr': 0.001, 'weight_decay': 0.01}}
DATASET: None
DATALOADER: None
MODEL: {'meta': {'arch': 'mnn_mlp', 'cnn_type': None, 'mlp_type': 'mnn_mlp'}, 'mnn_mlp': {'structure': [784, 100], 'num_class': 10, 'bn_bias_var': False, 'predict_bias': True, 'predict_bias_var': False, 'special_init': True, 'dropout': None, 'momentum': 0.9, 'eps': 1e-05}, 'snn_mlp': {'structure': [784, 800], 'num_class': 10, 'use_cov': False, 'bn_bias_var': False}}
CRITERION: {'name': 'CrossEntropyOnMean', 'source': 'mnn_core', 'args': {'reduction': 'mean'}}
DATAAUG_TRAIN: {'aug_order': ['ToTensor'], 'RandomCrop': {'size': 28, 'padding': 2}}
DATAAUG_VAL: {'aug_order': ['ToTensor']}
workers: 2
lr: 0.001
epochs: 1
pin_mem: True
world_size: 1
dataset: mnist
dataset_typ
```
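The printed output shows that, after loading, every top-level key of the config file (for example `OPTIMIZER`, `lr`, `epochs`) is available as an attribute of `args`, and these attributes can still be changed before training starts. The sketch below mimics that behaviour with two hypothetical helpers, `merge_config_into_args` and `override_args`. They are not the code of `general_prepare.set_config2args`; a plain dict stands in for the YAML file so that no YAML parser is needed.

```python
from types import SimpleNamespace


def merge_config_into_args(args, config):
    """Copy every top-level config entry onto `args` as an attribute.

    Hypothetical helper that mimics the effect visible in the printed output;
    it is not the code of general_prepare.set_config2args. In this sketch,
    config entries overwrite attributes that `args` already has.
    """
    for key, value in config.items():
        setattr(args, key, value)
    return args


def override_args(args, **updates):
    """Change existing hyperparameters, refusing names `args` does not have.

    Hypothetical helper: rejecting unknown names catches typos such as
    `epoch=5`, which plain attribute assignment would silently accept.
    """
    for key, value in updates.items():
        if not hasattr(args, key):
            raise AttributeError(f"unknown hyperparameter: {key!r}")
        setattr(args, key, value)
    return args


# A small excerpt of the printed config, written as a Python dict.
config = {
    "OPTIMIZER": {"name": "AdamW", "args": {"lr": 0.001, "weight_decay": 0.01}},
    "lr": 0.001,
    "epochs": 1,
    "dataset": "mnist",
}
args = merge_config_into_args(SimpleNamespace(bs=50, print_freq=20), config)
override_args(args, print_freq=100, epochs=5)
print(args.OPTIMIZER["name"], args.bs, args.print_freq, args.epochs)  # AdamW 50 100 5
```

Note that the learning rate appears twice in the printed config (`lr` and `OPTIMIZER['args']['lr']`). Which one the training pipeline reads is not shown here, so check before overriding only one of them.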

In this configuration:
* `MODEL` specifies the network architecture and the hyperparameters needed to construct an MNN model. In this tutorial, it creates an MNN with an MLP architecture consisting of one hidden layer (100 neurons).
* `OPTIMIZER` specifies which optimizer (provided by PyTorch) to use and its hyperparameters.
* `CRITERION` specifies the criterion used to compute the output loss for optimizing the model, and `source` specifies where to get the corresponding module. By default, we use `CrossEntropyOnMean`.
* `DATASET` and `DATALOADER` are reserved for users who want to customize their own dataset. By default, we use the datasets provided by torchvision, where `dataset` specifies which dataset is used and `data_dir` is the path where the dataset is stored. In this tutorial, we use MNIST. The batch size is specified by `bs` and the number of training epochs by `epochs`.
* `DATAAUG_TRAIN` and `DATAAUG_VAL` specify the data augmentation applied to the training and validation sets, respectively.
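To see how `MODEL` encodes the network described above, the hypothetical helper below lists the shape of each fully connected layer. It assumes that `structure` gives the input size (28 × 28 = 784 MNIST pixels) followed by the hidden-layer sizes, and that a final readout layer maps the last hidden layer to `num_class` outputs. This reading is our assumption; the model code in `mnn.models` is the authority.

```python
def mlp_layer_shapes(mlp_cfg):
    """Return (in_features, out_features) for each fully connected layer.

    Assumption: `structure` lists the input size followed by the hidden-layer
    sizes, and a final readout layer maps the last hidden layer to
    `num_class` outputs.
    """
    sizes = list(mlp_cfg["structure"]) + [mlp_cfg["num_class"]]
    return list(zip(sizes[:-1], sizes[1:]))


# The 'mnn_mlp' entry of MODEL printed above, trimmed to the keys used here.
mnn_mlp = {"structure": [784, 100], "num_class": 10}
print(mlp_layer_shapes(mnn_mlp))  # [(784, 100), (100, 10)]
```

Under the same assumption, the `snn_mlp` entry (`structure: [784, 800]`) would describe the layers `[(784, 800), (800, 10)]`.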
  
After setting up all necessary hyperparameters, we can run the following code to start training:


```python
args.print_freq = 100  # set print frequency to 100 for this example
# Similarly, you can set other hyperparameters as needed
general_train.general_train_pipeline(args, train_func=general_train.TrainProcessCollections)
```

```
Epoch: [0][   0/1200]	Time  0.024 ( 0.024)	Data  0.007 ( 0.007)	Loss 2.3007e+00 (2.3007e+00)	Acc@1   8.00 (  8.00)
Epoch: [0][ 100/1200]	Time  0.030 ( 0.028)	Data  0.004 ( 0.004)	Loss 1.9582e+00 (2.0891e+00)	Acc@1  70.00 ( 64.28)
Epoch: [0][ 200/1200]	Time  0.031 ( 0.029)	Data  0.004 ( 0.004)	Loss 1.6539e+00 (1.9580e+00)	Acc@1  88.00 ( 71.00)
Epoch: [0][ 300/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.4744e+00 (1.8345e+00)	Acc@1  82.00 ( 74.86)
Epoch: [0][ 400/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.2726e+00 (1.7258e+00)	Acc@1  90.00 ( 76.98)
Epoch: [0][ 500/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.2473e+00 (1.6237e+00)	Acc@1  78.00 ( 78.71)
Epoch: [0][ 600/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 1.0248e+00 (1.5336e+00)	Acc@1  90.00 ( 79.94)
Epoch: [0][ 700/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 9.2382e-01 (1.4497e+00)	Acc@1  92.00 ( 81.05)
Epoch: [0][ 800/1200]	Time  0.030 ( 0.030)	Data  0.004 ( 0.004)	Loss 7.5740e-01 
```
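Each progress line follows a fixed pattern, so it can be parsed, for example to plot learning curves. The 1200 iterations per epoch equal the 60,000 MNIST training images divided by the batch size `bs` of 50. The value in parentheses after each field appears to be the running average over the epoch so far (e.g. `Acc@1  70.00 ( 64.28)` at iteration 100); we infer this from the log, not from the source code. The hypothetical helpers below also assume that the `*_log.txt` file uses the same line format as the console output.

```python
import re

# Matches lines such as
# "Epoch: [0][ 100/1200]  Time ... Loss 1.9582e+00 (2.0891e+00)  Acc@1  70.00 ( 64.28)"
LOG_PATTERN = re.compile(
    r"Epoch: \[(?P<epoch>\d+)\]\[\s*(?P<step>\d+)/(?P<total>\d+)\]"
    r".*?Loss (?P<loss>\S+) \(\s*(?P<loss_avg>\S+)\)"
    r".*?Acc@1\s+(?P<acc>\S+) \(\s*(?P<acc_avg>\S+)\)"
)


def parse_log_line(line):
    """Return the numbers in one progress line, or None if the line does not match."""
    m = LOG_PATTERN.search(line)
    if m is None:
        return None
    ints = {key: int(m[key]) for key in ("epoch", "step", "total")}
    floats = {key: float(m[key]) for key in ("loss", "loss_avg", "acc", "acc_avg")}
    return {**ints, **floats}


def parse_log(lines):
    """Parse every matching line, skipping headers and truncated lines."""
    return [rec for rec in map(parse_log_line, lines) if rec is not None]


line = ("Epoch: [0][ 100/1200]\tTime  0.030 ( 0.028)\tData  0.004 ( 0.004)\t"
        "Loss 1.9582e+00 (2.0891e+00)\tAcc@1  70.00 ( 64.28)")
record = parse_log_line(line)
print(record["step"], record["loss"], record["acc_avg"])  # 100 1.9582 64.28
```

To read a whole log, pass the open file object to `parse_log`, which iterates over its lines; lines that do not match (such as a truncated last line) are skipped rather than raising an error.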

When training finishes, you will find a directory named by `dir` in the path specified by `dump_path`. The directory contains four files named by `save_name` with different suffixes:
* `*_config.yaml` records all hyperparameters used in training, so you can reproduce the experiments.
* `*_log.txt` records the loss and accuracy of the model during training.
* `*.pth` contains the model parameters at the last epoch.
* `*_best_model.pth` contains the model parameters that achieved the highest accuracy on the validation set during training.
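Given `dump_path`, `dir` and `save_name`, the four paths can be computed ahead of time, for example to check whether a run has finished. `expected_outputs` and `missing_outputs` are hypothetical helpers. They assume each file name is `save_name` immediately followed by the suffix in the list above (so `mnn_mnist_log.txt` in this tutorial), and the final loop assumes `dump_path` is `./checkpoint`, matching the quick start.

```python
from pathlib import Path


def expected_outputs(dump_path, dir_name, save_name):
    """Build the paths of the four files described above.

    Assumption: each file name is `save_name` followed directly by its suffix.
    """
    folder = Path(dump_path) / dir_name
    suffixes = ["_config.yaml", "_log.txt", ".pth", "_best_model.pth"]
    return [folder / f"{save_name}{suffix}" for suffix in suffixes]


def missing_outputs(dump_path, dir_name, save_name):
    """Return the expected files that do not exist (yet)."""
    return [p for p in expected_outputs(dump_path, dir_name, save_name) if not p.is_file()]


# Assumed values: dump_path='./checkpoint', dir='mnist', save_name='mnn_mnist'.
for path in expected_outputs("./checkpoint", "mnist", "mnn_mnist"):
    print(path)
```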

## Reconstruct SNN based on trained MNN

The trained MNN parameters can be used directly in an SNN without further fine-tuning.
We also provide a pipeline that reconstructs an SNN from the trained MNN and runs simulations, as in the following code:

```python
dt = 1  # time step for simulation
input_type = 'poisson'  # use a Poisson process to generate input spikes
num_trial = 100  # number of trials for validation
running_time = 100  # running time for each trial in ms
pregenerate = False  # whether to pregenerate the input spikes
m = snn.functional.MnnSnnValidate(args, running_time=running_time, dt=dt, num_trials=num_trial,
                                  pregenerate=pregenerate, resume_best=False, input_type=input_type)
for index in range(5):  # run simulations with the first 5 samples in the validation set
    m.validate_one_sample(index, do_reset=True, dump_spike_train=True, record=True)
```

```
test set, Img idx: 0, target: 7, pred: tensor([7])
test set, Img idx: 1, target: 2, pred: tensor([2])
test set, Img idx: 2, target: 1, pred: tensor([1])
test set, Img idx: 3, target: 0, pred: tensor([0])
test set, Img idx: 4, target: 4, pred: tensor([4])
```
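With `dt = 1` and `running_time = 100` (in ms, per the code comments), each trial has `int(running_time / dt) = 100` time steps. The sketch below illustrates what a Poisson input means at that resolution, using the common Bernoulli approximation: in each step, a channel fires with probability rate × dt. It is a generic illustration built around a hypothetical `poisson_spike_train` helper; we have not checked how `MnnSnnValidate` converts pixel values into firing rates or how it generates its spikes.

```python
import random


def poisson_spike_train(rates_hz, running_time_ms=100, dt_ms=1.0, seed=None):
    """Bernoulli approximation of independent Poisson spike trains.

    In each time bin of width dt, channel i fires with probability
    rates_hz[i] * dt (dt converted to seconds), clipped at 1. This approximates
    a Poisson process when that probability is small.
    Returns a list of time steps, each a list of 0/1 spikes per channel.
    """
    rng = random.Random(seed)
    dt_s = dt_ms / 1000.0
    n_steps = int(running_time_ms / dt_ms)
    probs = [min(rate * dt_s, 1.0) for rate in rates_hz]
    return [[1 if rng.random() < p else 0 for p in probs] for _ in range(n_steps)]


# Three input channels firing at 0, 50 and 1000 Hz for 100 ms with dt = 1 ms.
spikes = poisson_spike_train([0.0, 50.0, 1000.0], running_time_ms=100, dt_ms=1, seed=0)
print(len(spikes))                         # 100 time steps
print([sum(col) for col in zip(*spikes)])  # spike counts per channel
```

Passing `seed` makes the trains reproducible. Rates at or above 1/dt (1000 Hz at dt = 1 ms) saturate at one spike per step, which is one reason to keep dt small relative to the expected inter-spike interval.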


You will find another directory in `dir`, named by `save_name` with the suffix `_snn_validate_result`.
It contains two types of files:
* `*.snnval` stores information about the simulation, such as spike counts and simulation duration.
* `*.spt` stores the spike trains of the hidden neurons during the simulation as a sparse tensor with shape `(int(running_time/dt), num_trial, hidden_neurons)`.
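For this tutorial's settings (`running_time = 100`, `dt = 1`, `num_trial = 100`, and 100 hidden neurons from `structure: [784, 100]`), that shape is `(100, 100, 100)`. The hypothetical helper below turns the coordinates of the nonzero entries of such a tensor into a mean firing rate per hidden neuron. It assumes each nonzero entry is one spike and that `dt` is in ms.

```python
from collections import Counter


def mean_firing_rates(spike_indices, num_steps, num_trial, num_neurons, dt_ms=1.0):
    """Mean firing rate in Hz of each neuron, averaged over trials.

    `spike_indices` holds (time_step, trial, neuron) coordinates of the
    nonzero entries of a spike tensor with shape
    (num_steps, num_trial, num_neurons); each entry counts as one spike.
    """
    counts = Counter(neuron for _, _, neuron in spike_indices)
    total_ms = num_trial * num_steps * dt_ms  # simulated time summed over trials
    return [counts[n] * 1000.0 / total_ms for n in range(num_neurons)]


running_time, dt, num_trial = 100, 1, 100
num_steps = int(running_time / dt)  # 100 time steps per trial
# Toy coordinates: neuron 0 spikes 3 times in total, neuron 2 once.
toy = [(0, 0, 0), (10, 0, 0), (5, 1, 0), (7, 1, 2)]
print(mean_firing_rates(toy, num_steps, num_trial, num_neurons=3, dt_ms=dt))
```

If the `.spt` file was written with `torch.save` (we have not confirmed this), `torch.load(path).coalesce().indices().t().tolist()` would give such `[time_step, trial, neuron]` triples.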