# **Tutorial on training the MNN (vanilla version)**

Here, we show a minimal example of training the MNN in standard PyTorch style, without the wrapper. This is suited to those who need to make under-the-hood modifications to the model.

If you don't already have **PyTorch** installed, install it by following the instructions on this page: https://pytorch.org/get-started/locally/


You need to copy this notebook to the root directory (under `moment-neural-network`).

First, the necessary imports.

In [None]:
from mnn.mnn_core.mnn_pytorch import *
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

A quick check of your PyTorch version and GPU availability.

In [5]:
print('Using PyTorch version:', torch.__version__)
if torch.cuda.is_available():
    print('Using GPU, device name:', torch.cuda.get_device_name(0))
    device = torch.device('cuda')
else:
    print('No GPU found, using CPU instead.') 
    device = torch.device('cpu')

Using PyTorch version: 2.7.0+cu126
Using GPU, device name: NVIDIA GeForce RTX 4090


## Loading the data

PyTorch has two classes from [`torch.utils.data`](https://pytorch.org/docs/stable/data.html#module-torch.utils.data) to work with data: 
- [Dataset](https://pytorch.org/docs/stable/data.html#torch.utils.data.Dataset) which represents the actual data items, such as images or pieces of text, and their labels
- [DataLoader](https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) which is used for processing the dataset in batches during training.

Here we will use TorchVision and `torchvision.datasets` to access the [MNIST dataset](https://en.wikipedia.org/wiki/MNIST_database). (By setting `download=True`, the code below will attempt to download the dataset if it doesn't already exist locally.)

In [6]:
batch_size = 32

train_dataset = datasets.MNIST('./datasets/', train=True, download=True,
                transform=transforms.Compose([ToTensor(),
                transforms.Normalize((0,), (1,))]))
test_dataset = datasets.MNIST('./datasets/', train=False, download=True,
                  transform=transforms.Compose([ToTensor(),
                  transforms.Normalize((0,), (1,))]))

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
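
A note on the transform used above: `transforms.Normalize(mean, std)` computes `(x - mean) / std` per channel, so with `(0,)` and `(1,)` it leaves the `ToTensor()` output (values in [0, 1]) unchanged. The NumPy sketch below illustrates that arithmetic only; the `normalize` helper and the 0.5/0.5 values are hypothetical, chosen for illustration, and are not torchvision's code or settings from this tutorial.

```python
import numpy as np

def normalize(x, mean, std):
    """Hypothetical stand-in for the per-channel arithmetic (x - mean) / std."""
    return (x - mean) / std

# A fake 1x28x28 "image" with values in [0, 1), similar to what ToTensor() produces.
rng = np.random.default_rng(0)
image = rng.random((1, 28, 28), dtype=np.float32)

# mean=0, std=1 is the identity, as in the cell above.
same = normalize(image, mean=0.0, std=1.0)

# Illustrative non-trivial values: maps [0, 1) into [-1, 1).
rescaled = normalize(image, mean=0.5, std=0.5)

print(np.array_equal(same, image))
print(float(rescaled.min()), float(rescaled.max()))
```

If a different input scaling is wanted, only the two tuples passed to `Normalize` need to change.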

The data loaders provide a way of iterating through the datasets in batches. 

In [None]:
# load the first batch of data
for (data, target) in train_loader:
    print('data:', data.size(), 'type:', data.type())
    print('target:', target.size(), 'type:', target.type())
    break

data: torch.Size([32, 1, 28, 28]) type: torch.FloatTensor
target: torch.Size([32]) type: torch.LongTensor
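
Because the network later in this tutorial is built with `input_size = 28*28`, each image batch would presumably need to be flattened from `[batch, 1, 28, 28]` to `[batch, 784]` before being fed to a fully connected layer. The sketch below shows this on synthetic data so that it runs without downloading MNIST; the random tensors and the dataset size of 100 are assumptions for illustration only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 100 random 1x28x28 "images" with labels 0-9.
images = torch.rand(100, 1, 28, 28)
labels = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(images, labels), batch_size=32, shuffle=False)

batch_shapes = []
for data, target in loader:
    # Collapse all dimensions except the batch dimension: (B, 1, 28, 28) -> (B, 784)
    flat = data.flatten(start_dim=1)
    batch_shapes.append((tuple(flat.shape), tuple(target.shape)))

print(batch_shapes)
```

Note that the last batch is smaller than `batch_size` when the dataset size is not a multiple of it; pass `drop_last=True` to `DataLoader` if a fixed batch size is required.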


## Building a feedforward MNN

A single feedforward layer of MNN consists of the following components:

- **linear (bilinear) layer**: outputs synaptic current mean/covariance given pre-synaptic neuron spike mean/covariance. Accessed through the `LinearDuo` class under `mnn.mnn_core.nn.linear`.
- **moment batch normalization**: outputs batch-normalized synaptic current mean/covariance. This is a generalization of standard batchnorm to second-order moments and is required to avoid the vanishing gradient problem. Accessed through the `BatchNorm1dDuo` class under `mnn.mnn_core.nn.batchnorm`.
- **moment activation**: outputs post-synaptic neuron spike mean/covariance, given input current mean/covariance. Accessed through the `OriginMnnActivation` class under `mnn.mnn_core.nn.activation`.

They can be accessed individually or as a single block using the `EnsembleLinearDuo` class under `mnn.mnn_core.nn.ensemble`.
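
For intuition about the linear component, recall the standard rule for how a linear map transforms the first two moments: a mean `mu` becomes `W @ mu + b` and a covariance `C` becomes `W @ C @ W.T`. The NumPy sketch below illustrates only that rule. It is not the `LinearDuo` implementation, and the function name `linear_moments` is hypothetical.

```python
import numpy as np

def linear_moments(mean, cov, weight, bias=None):
    """Propagate a mean vector and covariance matrix through y = W x + b.

    Hypothetical helper for illustration; not part of the mnn package.
    """
    out_mean = weight @ mean
    if bias is not None:
        out_mean = out_mean + bias
    # A constant bias shifts the mean but does not affect the covariance.
    out_cov = weight @ cov @ weight.T
    return out_mean, out_cov

rng = np.random.default_rng(0)
n_in, n_out = 4, 3
weight = rng.normal(size=(n_out, n_in))
mean = rng.random(n_in)

# Build a valid (symmetric positive semi-definite) input covariance.
a = rng.normal(size=(n_in, n_in))
cov = a @ a.T

out_mean, out_cov = linear_moments(mean, cov, weight)
print(out_mean.shape, out_cov.shape)
```

The output covariance stays symmetric and positive semi-definite, which is what allows such layers to be chained.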

A single feedforward layer can be stacked multiple times to form a deep MNN. For illustrative purposes, here we show an example consisting of a single hidden layer with linear readout.

In [None]:
class MoNet(torch.nn.Module):
    def __init__(self, num_hidden_layers=10, hidden_layer_size=64, input_size=2, output_size=1):
        super().__init__()
        # Layer widths: input, hidden (repeated num_hidden_layers times), output
        self.layer_sizes = [input_size] + [hidden_layer_size] * num_hidden_layers + [output_size]

        # NOTE: `MomentLayer` is not defined in this notebook; it is assumed to be
        # provided by the wildcard import from `mnn.mnn_core.mnn_pytorch` above.
        self.layers = torch.nn.ModuleList(
            [MomentLayer(self.layer_sizes[i], self.layer_sizes[i + 1])
             for i in range(len(self.layer_sizes) - 1)])

    def forward(self, u, s, rho):
        # Pass the three moment tensors through each layer in turn
        for layer in self.layers:
            u, s, rho = layer(u, s, rho)
        return u, s, rho
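
The `forward` method above passes three tensors, `u`, `s` and `rho`, between layers. Assuming these denote the mean, standard deviation and correlation coefficient (this notebook does not spell that out), a covariance matrix can be split into `s` and `rho` and reassembled as sketched below in NumPy. The helper names are hypothetical and not part of the `mnn` package.

```python
import numpy as np

def cov_to_std_corr(cov):
    """Split a covariance matrix into standard deviations and a correlation matrix."""
    std = np.sqrt(np.diag(cov))
    corr = cov / np.outer(std, std)
    return std, corr

def std_corr_to_cov(std, corr):
    """Rebuild the covariance matrix: C_ij = s_i * s_j * rho_ij."""
    return corr * np.outer(std, std)

cov = np.array([[4.0, 1.0],
                [1.0, 9.0]])

std, corr = cov_to_std_corr(cov)
rebuilt = std_corr_to_cov(std, corr)

print(std)
print(corr)
```

The correlation matrix has ones on its diagonal, so the variances are carried entirely by `s`.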

## Training the MNN

In [None]:
sample_size = 10000
batch_size = 32
num_batches = sample_size // batch_size  # not used below
num_epoch = 10
lr = 0.01
momentum = 0.9  # not used by the Adam optimizer below
input_size = 28*28
output_size = 10

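# NOTE: `config`, `Dataset`, `loss_mse_covariance` and `loss_function_mse` are not defined
# in this notebook; they must be available (from the wildcard import above or defined
# by you) before running this cell.
model = MoNet(num_hidden_layers=config['num_hidden_layers'], hidden_layer_size=config['hidden_layer_size'], input_size=input_size, output_size=output_size)
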
train_dataset = Dataset(config['dataset_name'], sample_size = sample_size, input_dim = input_size, output_dim = output_size, with_corr = config['with_corr'], fixed_rho = config['fixed_rho'])
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size)        

model.target_transform = train_dataset.transform

validation_dataset = Dataset(config['dataset_name'], sample_size = 32, input_dim = input_size, output_dim = output_size, transform = train_dataset.transform, with_corr = config['with_corr'], fixed_rho = config['fixed_rho'] )          
validation_dataloader = torch.utils.data.DataLoader(validation_dataset, batch_size = 32)

params = model.parameters()

# Adam requires a much smaller learning rate than SGD, otherwise training won't converge
optimizer = torch.optim.Adam(params, lr=lr, amsgrad=True)
    
for epoch in range(num_epoch):
    model.train()
    for i_batch, sample in enumerate(train_dataloader):
        optimizer.zero_grad()
        # 'input_data' and 'target_data' each hold three tensors, passed as (u, s, rho)
        u, s, rho = model(sample['input_data'][0], sample['input_data'][1], sample['input_data'][2])
        loss = loss_mse_covariance(u, s, rho, sample['target_data'][0], sample['target_data'][1], sample['target_data'][2])
        loss.backward()
        optimizer.step()

    print('Training epoch {}/{}'.format(epoch + 1, num_epoch))

    # Evaluate on the validation set without tracking gradients
    model.eval()
    with torch.no_grad():
        for i_batch, sample in enumerate(validation_dataloader):
            u, s, rho = model(sample['input_data'][0], sample['input_data'][1], sample['input_data'][2])

            loss = loss_function_mse(u, s, sample['target_data'][0], sample['target_data'][1])

            print('Validation loss:', loss.item())
            print('Epoch:', epoch + 1)

#model.checkpoint['model_state_dict'] =  model.state_dict()
#model.checkpoint['optimizer_state_dict'] = optimizer.state_dict()
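
The commented-out lines above hint at storing the model and optimizer state in a checkpoint. A minimal sketch of saving and restoring such a checkpoint with `torch.save` and `torch.load` follows. It uses a small `torch.nn.Linear` as a stand-in for the MNN and an in-memory buffer instead of a file, both assumptions made so that the example is self-contained. The plain `checkpoint` dictionary is also an assumption and may differ from how the `mnn` package handles checkpoints.

```python
import io
import torch

# Stand-in model and optimizer (assumption: any nn.Module is handled the same way).
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, amsgrad=True)

# Take one optimization step so that the optimizer has internal state to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}

# Save to an in-memory buffer; pass a file path instead to write to disk.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)

# Restore into freshly constructed objects with the same architecture and settings.
restored = torch.load(buffer)
new_model = torch.nn.Linear(4, 2)
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01, amsgrad=True)
new_model.load_state_dict(restored['model_state_dict'])
new_optimizer.load_state_dict(restored['optimizer_state_dict'])

print(torch.equal(new_model.weight, model.weight))
```

Saving the optimizer state alongside the weights allows training to resume with the same moment estimates rather than restarting Adam from scratch.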
