# **Tutorial on training the MNN (vanilla version)**

Here, we show a minimal example of training the MNN in standard PyTorch style, without the wrapper. This approach suits users who need to modify the model under the hood.

If you don't already have **PyTorch** installed, install it by following the instructions on this page: https://pytorch.org/get-started/locally/


Copy this notebook to the root directory of the repository (`moment-neural-network`) so that the `mnn` package can be imported.

First, the necessary imports.

```python
from mnn.mnn_core.mnn_pytorch import *
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
```

A quick check of your pytorch version and GPU availability.

```python
print('Using PyTorch version:', torch.__version__)
if torch.cuda.is_available():
    print('Using GPU, device name:', torch.cuda.get_device_name(0))
    device = torch.device('cuda')
else:
    print('No GPU found, using CPU instead.')
    device = torch.device('cpu')
```

```
Using PyTorch version: 2.7.0+cu126
Using GPU, device name: NVIDIA GeForce RTX 4090
```


## Loading the data & input encoding

PyTorch has two classes from [`torch.utils.data`](https://pytorch.org/docs/stable/data.html#module-torch.utils.data) to work with data: 
- [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) which represents the actual data items, such as images or pieces of text, and their labels
- [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) which is used for processing the dataset in batches during training.

Here we will use TorchVision and `torchvision.datasets` to access the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database). (By setting `download=True`, the code below will attempt to download the dataset if it doesn't already exist locally.)

```python
batch_size = 32

# ToTensor() converts 8-bit pixels to floats in [0, 1]; Normalize((0,), (1,))
# subtracts 0 and divides by 1, so it leaves the pixel values unchanged.
transform = transforms.Compose([ToTensor(), transforms.Normalize((0,), (1,))])

train_dataset = datasets.MNIST('./datasets/', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./datasets/', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
```

The data loaders provide a way of iterating through the datasets in batches. 

```python
# load the first batch of data
for (data, target) in train_loader:
    print('data:', data.size(), 'type:', data.type())
    print('target:', target.size(), 'type:', target.type())
    break
```

```
data: torch.Size([32, 1, 28, 28]) type: torch.FloatTensor
target: torch.Size([32]) type: torch.LongTensor
```
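As a quick sanity check on how the loaders split the data: MNIST has 60,000 training and 10,000 test images, and `DataLoader` keeps the final incomplete batch by default (`drop_last=False`). The plain-Python sketch below works out the resulting batch counts for `batch_size = 32`; the helper `num_batches` is hypothetical and only mirrors that counting rule.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Number of batches a DataLoader yields (hypothetical helper for illustration)."""
    if drop_last:
        return num_samples // batch_size          # incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)    # incomplete final batch is kept

batch_size = 32
train_batches = num_batches(60000, batch_size)    # 60000 / 32 divides exactly
test_batches = num_batches(10000, batch_size)     # 10000 / 32 = 312.5, rounded up
last_test_batch = 10000 - (test_batches - 1) * batch_size

print(train_batches, test_batches, last_test_batch)
```

Every training batch is therefore full, while the last test batch is smaller than the others.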


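We need an encoding scheme that turns each image into the moment representation the MNN takes as input. Here, we suppose that the inputs are the statistical moments of independent Poisson spike trains whose firing rates are proportional to the image pixel values. For a Poisson spike train, the spike-count variance per unit time equals the firing rate, and independence makes all cross-covariances zero, so the input covariance is a diagonal matrix with the firing rates on its diagonal. Below is a helper function that implements this input encoding. The `scale` parameter sets how input pixel values are converted to firing rates in spikes per millisecond (sp/ms).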

```python
def input_encoder(data, scale=1):
    """Encode images as the moments of independent Poisson spike trains."""
    data = torch.flatten(data, start_dim=1)    # (batch, 1, 28, 28) -> (batch, 784)
    input_mean = data * scale                  # firing rates in sp/ms
    input_cov = torch.diag_embed(input_mean)   # Poisson: variance = rate, zero cross-covariance
    return input_mean, input_cov

# load the first batch of data
for (data, target) in train_loader:
    input_mean, input_cov = input_encoder(data)
    print('input_mean:', input_mean.size(), 'type:', input_mean.type())
    print('input_cov:', input_cov.size(), 'type:', input_cov.type())
    print('target:', target.size(), 'type:', target.type())
    break
```

```
input_mean: torch.Size([32, 784]) type: torch.FloatTensor
input_cov: torch.Size([32, 784, 784]) type: torch.FloatTensor
target: torch.Size([32]) type: torch.LongTensor
```
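To see why `input_encoder` uses `diag_embed`, the NumPy sketch below simulates independent Poisson spike counts over a counting window `T` and checks that the estimated rate and covariance per unit time match the mean vector and diagonal covariance the encoder assigns. The rates, window length and number of trials are arbitrary choices for illustration.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

rates = np.array([0.0, 0.2, 0.5, 1.0])   # firing rates in sp/ms (e.g. pixel values with scale=1)
T = 100.0                                 # counting window in ms (arbitrary choice)
n_trials = 20000

# Independent Poisson spike counts: one row per trial, one column per input neuron
counts = rng.poisson(rates * T, size=(n_trials, rates.size))

emp_mean = counts.mean(axis=0) / T                # estimated firing rate per neuron
emp_cov = np.cov(counts, rowvar=False) / T        # estimated covariance per unit time

# What input_encoder assigns (here as NumPy arrays): mean = rate, cov = diag(rate)
model_mean = rates
model_cov = np.diag(rates)

print(np.round(emp_mean, 3))
print(np.round(emp_cov, 3))
```

The empirical diagonal tracks the rates and the off-diagonal entries stay near zero, which is exactly the structure `torch.diag_embed(input_mean)` builds.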


## Building a feedforward MNN

A single feedforward layer of MNN consists of the following components:

- **linear (bilinear) layer**: outputs synaptic current mean/covariance given pre-synaptic neuron spike mean/covariance. Accessed through the `LinearDuo` class under `mnn.mnn_core.nn.linear`.
- **moment batch normalization**: outputs batch-normalized synaptic current mean/covariance. This is a generalization of standard batchnorm to second-order moments and is required to avoid the vanishing gradient problem. Accessed through the `CustomBatchNorm1D` class under `mnn.mnn_core.nn.custom_batch_norm`.
- **moment activation**: outputs post-synaptic neuron spike mean/covariance, given input current mean/covariance. Accessed through the `OriginMnnActivation` class under `mnn.mnn_core.nn.activation`.
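For intuition about the linear (bilinear) layer: if a random vector `x` has mean `mu` and covariance `C`, then `W @ x` has mean `W @ mu` (linear in the input) and covariance `W @ C @ W.T` (bilinear in `W`). The NumPy sketch below applies this rule to a batch of moments. It is not the `LinearDuo` source code; we assume it matches `LinearDuo` only when its bias and dropout options are off, as in the printout further down.

```python
import numpy as np

def linear_moments(W, mean, cov):
    """Propagate batched means (B, n_in) and covariances (B, n_in, n_in) through x -> W x."""
    out_mean = mean @ W.T       # (B, n_out)
    out_cov = W @ cov @ W.T     # matmul broadcasts over the batch: (B, n_out, n_out)
    return out_mean, out_cov

rng = np.random.default_rng(seed=1)
B, n_in, n_out = 4, 6, 3
W = rng.normal(size=(n_out, n_in))
mean = rng.uniform(size=(B, n_in))
cov = np.stack([np.diag(m) for m in mean])   # Poisson-style diagonal input covariance

out_mean, out_cov = linear_moments(W, mean, cov)
print(out_mean.shape, out_cov.shape)
```

With `input_encoder`, `mean` has shape `(batch_size, 784)` and `cov` has shape `(batch_size, 784, 784)`; with a `(100, 784)` weight the same rule gives moments of shape `(batch_size, 100)` and `(batch_size, 100, 100)`. Note that the output covariance is generally no longer diagonal.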

A single feedforward layer can be stacked multiple times to form a deep MNN. For illustration, here we show an example consisting of a single hidden layer followed by a linear readout.

```python
from mnn.mnn_core.nn.activation import OriginMnnActivation
from mnn.mnn_core.nn.linear import LinearDuo
from mnn.mnn_core.nn.custom_batch_norm import CustomBatchNorm1D

class SimpleMNN(torch.nn.Module):
    """One hidden MNN layer (linear -> moment batchnorm -> moment activation) plus a linear readout."""

    def __init__(self, hidden_size=64, input_size=2, output_size=1):
        super().__init__()
        self.linear = LinearDuo(input_size, hidden_size)
        self.batchnorm = CustomBatchNorm1D(hidden_size)
        self.activate = OriginMnnActivation()
        self.readout = LinearDuo(hidden_size, output_size)

    def forward(self, input_mean, input_cov):
        # pre-synaptic spike mean/cov -> synaptic current mean/cov
        curr_mean, curr_cov = self.linear(input_mean, input_cov)
        # moment batch normalization of the current
        bn_mean, bn_cov = self.batchnorm(curr_mean, curr_cov)
        # current mean/cov -> hidden-neuron spike mean/cov
        hidden_mean, hidden_cov = self.activate(bn_mean, bn_cov)
        # linear readout of the hidden spike moments
        readout_mean, readout_cov = self.readout(hidden_mean, hidden_cov)
        return readout_mean, readout_cov
```

The following script creates an instance of the model and prints the name and shape of its trainable parameters.

```python
model = SimpleMNN(hidden_size=100, input_size=28*28, output_size=10)

# you can inspect the trainable parameters
print('Linear layer: ', model.linear)
for name, param in model.linear.named_parameters():
    print('    Name: {}, Shape: {}'.format(name, param.shape))
print('Moment batchnorm: ', model.batchnorm)
for name, param in model.batchnorm.named_parameters():
    print('    Name: {}, Shape: {}'.format(name, param.shape))
print('Moment activation', model.activate)
print('Readout layer:', model.readout)
for name, param in model.readout.named_parameters():
    print('    Name: {}, Shape: {}'.format(name, param.shape))
```

```
Linear layer:  LinearDuo(in_features: 784, out_features: 100, bias_mean: False, bias_var: False, dropout: False, scale: None)
    Name: weight, Shape: torch.Size([100, 784])
Moment batchnorm:  CustomBatchNorm1D(num_features: 100, bias_std=False, special_init=True, momentum=0.9, eps=1e-05, affine=True)
    Name: weight, Shape: torch.Size([100])
    Name: bias, Shape: torch.Size([100])
Moment activation OriginMnnActivation()
Readout layer: LinearDuo(in_features: 100, out_features: 10, bias_mean: False, bias_var: False, dropout: False, scale: None)
    Name: weight, Shape: torch.Size([10, 100])
```
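The printed shapes also determine how many parameters are trained. The plain-Python sketch below spells out that arithmetic from the shapes listed above; on the actual model, `sum(p.numel() for p in model.parameters())` should give the same total.

```python
from math import prod

# Parameter shapes as printed above for SimpleMNN(hidden_size=100, input_size=784, output_size=10)
param_shapes = {
    'linear.weight': (100, 784),
    'batchnorm.weight': (100,),
    'batchnorm.bias': (100,),
    'readout.weight': (10, 100),
}

counts = {name: prod(shape) for name, shape in param_shapes.items()}
total = sum(counts.values())
print(counts)
print('total trainable parameters:', total)
```

Almost all of the parameters sit in the input layer, because each of the 100 hidden neurons receives all 784 pixel inputs.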


## Training the MNN
So far we have defined the dataset, the data loader, and the model. To train the model, we need to specify the loss function and the optimizer.

For classification problems, we provide `CrossEntropyOnMean` and `GaussianSamplingCrossEntropyLoss` under `mnn.mnn_core.nn.criterion`. The former is the standard PyTorch cross-entropy loss evaluated on the output mean, whereas the latter is a generalized cross-entropy that also takes the second-order moments of the output into account.
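To make the distinction concrete, here is a NumPy sketch of the two ideas. `cross_entropy_on_mean` applies ordinary softmax cross-entropy to the output mean only. `sampled_cross_entropy` is a hypothetical Monte Carlo stand-in for a covariance-aware loss: it draws Gaussian samples of the output from `N(mean, cov)` and averages the cross-entropy over them. This reflects our assumption about the general idea, not the implementation of `GaussianSamplingCrossEntropyLoss`, whose exact definition may differ.

```python
import numpy as np

def log_softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # subtract the max for numerical stability
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

def cross_entropy_on_mean(mean, target):
    """Standard cross-entropy on the output mean. mean: (B, K), target: (B,) class indices."""
    logp = log_softmax(mean)
    return -logp[np.arange(len(target)), target].mean()

def sampled_cross_entropy(mean, cov, target, n_samples=1000, seed=0):
    """Hypothetical covariance-aware loss: average cross-entropy over samples z ~ N(mean, cov)."""
    rng = np.random.default_rng(seed)
    K = cov.shape[-1]
    L = np.linalg.cholesky(cov + 1e-9 * np.eye(K))         # (B, K, K); jitter keeps it positive definite
    eps = rng.standard_normal((n_samples,) + mean.shape)    # (S, B, K) standard normal draws
    z = mean + np.einsum('bij,sbj->sbi', L, eps)            # correlated Gaussian samples of the output
    logp = log_softmax(z)                                   # (S, B, K)
    return -logp[:, np.arange(len(target)), target].mean()

mean = np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
target = np.array([0, 2])
cov_zero = np.zeros((2, 3, 3))
cov_noisy = np.stack([4.0 * np.eye(3)] * 2)

loss_mean = cross_entropy_on_mean(mean, target)
loss_zero = sampled_cross_entropy(mean, cov_zero, target)    # reduces to the mean-only loss
loss_noisy = sampled_cross_entropy(mean, cov_noisy, target)  # output variability raises the loss
print(loss_mean, loss_zero, loss_noisy)
```

Because cross-entropy is convex in the logits, output variability can only increase its expected value, so a covariance-aware loss of this kind penalizes noisy outputs as well as wrong means.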

Below is a minimal example using the standard cross-entropy and Adam optimizer.

```python
from mnn.mnn_core.nn.criterion import CrossEntropyOnMean

batch_size = 32  # the data loaders above were already built with this batch size
num_epoch = 1
lr = 0.01
input_size = 28*28
hidden_size = 100
output_size = 10

model = SimpleMNN(hidden_size=hidden_size, input_size=input_size, output_size=output_size)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, amsgrad=True)
criterion = CrossEntropyOnMean()

for epoch in range(num_epoch):
    model.train()  # put batchnorm in training mode
    print('Training epoch {}/{}...'.format(epoch + 1, num_epoch))
    for images, target in train_loader:
        optimizer.zero_grad()
        input_mean, input_cov = input_encoder(images)  # encode input data to moment representation
        output_mean, output_cov = model(input_mean, input_cov)  # run the forward pass
        loss = criterion((output_mean, output_cov), target)  # calculate the loss function
        loss.backward()  # backpropagation
        optimizer.step()  # update model parameters

    model.eval()  # put batchnorm in evaluation mode
    with torch.no_grad():
        num_correct = 0
        for images, target in test_loader:
            input_mean, input_cov = input_encoder(images)
            output_mean, output_cov = model(input_mean, input_cov)
            prediction = output_mean.argmax(1)  # index of the largest entry in the output mean
            num_correct += torch.sum(prediction == target).item()  # count correct predictions

        acc = np.round(num_correct / len(test_dataset) * 100, 2)
        print('Validation accuracy = {}%'.format(acc))
```

```
Training epoch 1/1...
Validation accuracy = 94.43%
```


We can access all the trained parameters of the model and the state of the optimizer using the following lines of code:

```python
print('Model state dictionary: ', model.state_dict().keys())
print('Optimizer state dictionary: ', optimizer.state_dict().keys())
```

```
Model state dictionary:  odict_keys(['linear.weight', 'batchnorm.weight', 'batchnorm.bias', 'batchnorm.running_mean', 'batchnorm.running_var', 'readout.weight'])
Optimizer state dictionary:  dict_keys(['state', 'param_groups'])
```


## Reconstructing the spiking neural network

Because the MNN is derived on mathematically rigorous grounds from its corresponding spiking neural network (SNN) of current-based leaky integrate-and-fire neurons, the trained parameters can be used to reconstruct the SNN without further tuning.

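Note that the moment batchnorm is only needed during training. After training it uses fixed running statistics, so it reduces to a per-neuron gain and shift that can be absorbed into the preceding linear layer as a rescaled weight matrix plus a bias. The following helper function performs this fusion: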

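```python
@torch.no_grad()
def weight_fusion(ln, bn):
    """Fold the running statistics of moment batchnorm `bn` into the preceding linear layer `ln`."""
    ln_weight = ln.weight.detach()
    # per-neuron gain: gamma / sqrt(running_var + eps)
    bn_weight = (bn.weight / torch.sqrt(bn.running_var + bn.eps)).detach()
    weight = ln_weight * bn_weight.unsqueeze(-1)   # scale row i of the weight matrix by gain i
    bias = -bn.running_mean * bn_weight + bn.bias  # beta - running_mean * gain
    return weight, bias

weight, bias = weight_fusion(model.linear, model.batchnorm)
```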

The fused `weight` and `bias`, together with the readout weights `model.readout.weight` (which have no batchnorm to fuse), can then be used to reconstruct an SNN that generates the same firing statistics as the MNN.
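The fusion relies on the batchnorm acting on the current mean as `gamma * (mu - running_mean) / sqrt(running_var + eps) + beta`, which is the form `weight_fusion` assumes. A per-neuron gain on the mean rescales the covariance by the same gain on both sides, while the bias shifts only the mean, so the fused weight reproduces both moments. The NumPy sketch below checks this numerically with random stand-in values; the covariance rule for the batchnorm is our assumption, implied by the fusion rather than taken from the `CustomBatchNorm1D` source.

```python
import numpy as np

rng = np.random.default_rng(seed=2)
n_in, n_hid = 8, 5
W = rng.normal(size=(n_hid, n_in))              # stand-in for model.linear.weight
gamma, beta = rng.normal(size=n_hid), rng.normal(size=n_hid)
running_mean = rng.normal(size=n_hid)
running_var = rng.uniform(0.5, 2.0, size=n_hid)
eps = 1e-5

mu = rng.uniform(size=n_in)                     # input firing rates
C = np.diag(mu)                                 # Poisson-style input covariance

# Unfused path: linear layer, then batchnorm with running statistics
# (assumed: covariance rescaled by the same per-neuron gain, no shift)
gain = gamma / np.sqrt(running_var + eps)
curr_mean, curr_cov = W @ mu, W @ C @ W.T
bn_mean = gain * (curr_mean - running_mean) + beta
bn_cov = gain[:, None] * curr_cov * gain[None, :]

# Fused path, mirroring weight_fusion(): row-scaled weights plus a bias on the mean
W_fused = W * gain[:, None]
b_fused = -running_mean * gain + beta
fused_mean = W_fused @ mu + b_fused
fused_cov = W_fused @ C @ W_fused.T

print(np.abs(fused_mean - bn_mean).max(), np.abs(fused_cov - bn_cov).max())
```

Both printed differences are at the level of floating-point round-off, so the linear layer and batchnorm can be replaced by a single layer with weight `W_fused` and bias `b_fused`.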

## Exercises
1. Try modifying `SimpleMNN` by stacking multiple hidden layers to form a deep MNN.
2. Replace the task with a regression problem and change the loss function accordingly. Hint: see `MSEOnMean` and `LikelihoodMSE` under `mnn.mnn_core.nn.criterion`.