# Neural network training for Napari-LF

Napari-LF's neural-network integration relies on the PyTorch Lightning workflow, which provides general functions for loading data, training, inference, etc. that work with any neural network.
This script prepares a network for use with Napari-LF.

## Instructions
1. Import required libraries and desired network to train.
2. Gather needed information from the user.
3. Create a network.
4. Load data for training.
5. Train network.
6. Store network in a napari-LF-compatible format.

### 1. Import required libraries and desired network to train.

```python
import pytorch_lightning as pl
import torch
# Let's train a VCDNet. This defines which network we will train
from neural_nets.VCDNet import VCDNet as NN
# Or:
# from neural_nets.LFMNet import LFMNet as NN
```
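To switch between networks without editing the import line, the class can also be looked up by name. `load_network_class` below is a hypothetical helper built on the standard `importlib` module, offered as a sketch; only the module paths `neural_nets.VCDNet` and `neural_nets.LFMNet` come from the cell above.

```python
import importlib


def load_network_class(module_name: str, class_name: str):
    """Import `module_name` and return its attribute `class_name`.

    Hypothetical helper (not part of napari-LF): lets the network be chosen
    from strings instead of editing the import statement.
    """
    module = importlib.import_module(module_name)
    try:
        return getattr(module, class_name)
    except AttributeError as err:
        raise ImportError(f"{module_name} has no attribute {class_name!r}") from err


# Usage with the networks named above (requires the napari-LF neural_nets package):
# NN = load_network_class("neural_nets.VCDNet", "VCDNet")
# NN = load_network_class("neural_nets.LFMNet", "LFMNet")
```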

### 2. Gather needed information from the user.

```python
# Number of GPUs to train on
n_gpus = 1
# Shape of the light field: [angular-u, angular-v, spatial-s, spatial-t]
LFshape = [33,33,39,39]     # For the case of the MouseBrain dataset
# Shape of the 2D light-field image: [u*s, v*t]
LF_2D_shape = [LFshape[0]*LFshape[2], LFshape[1]*LFshape[3]]
# Number of depths in each volume
n_depths = 64

# Define training parameters
training_settings = {}
# Learning rate
training_settings['lr'] = 1e-3
# Batch size
training_settings['batch_size'] = 2
# Maximum number of epochs to train
training_settings['epochs'] = 500
# Which image-volume pairs to use
training_settings['images_ids'] = list(range(100))
# Path to the dataset
training_settings['dataset_path'] = 'D:/BrainImagesJosuePage/Brain_40x_64Depths_362imgs.h5'
# Where to store the trained network.
# If left blank, the logs and trained network are stored at ./lightning_logs/version_*
training_settings['output_dir'] = 'C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks/'
```
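A wrong path or value in `training_settings` usually only surfaces once data loading or training starts, so it can help to check the dictionary up front. `check_training_settings` is a hypothetical helper: the keys mirror the cell above, but the checks themselves are assumptions about sensible values, not rules enforced by Napari-LF.

```python
from pathlib import Path

REQUIRED_KEYS = ["lr", "batch_size", "epochs", "images_ids", "dataset_path", "output_dir"]


def check_training_settings(settings: dict) -> list:
    """Return a list of problems found in a training_settings dict (empty if none).

    Hypothetical sanity checks; napari-LF itself does not define this function.
    """
    problems = [f"missing key: {key}" for key in REQUIRED_KEYS if key not in settings]
    if settings.get("lr", 1) <= 0:
        problems.append("lr must be positive")
    if settings.get("batch_size", 1) < 1:
        problems.append("batch_size must be >= 1")
    if settings.get("epochs", 1) < 1:
        problems.append("epochs must be >= 1")
    ids = settings.get("images_ids", [])
    if len(set(ids)) != len(ids):
        problems.append("images_ids contains duplicates")
    path = settings.get("dataset_path")
    if path and not Path(path).is_file():
        problems.append(f"dataset not found: {path}")
    return problems


# Usage:
# for problem in check_training_settings(training_settings):
#     print(problem)
```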

### 3. Create a network.

```python
net = NN(LF_2D_shape, (n_depths,)+tuple(LF_2D_shape),
         network_settings_dict={'LFshape' : LFshape},
         training_settings_dict=training_settings)
```
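The network is built from a 2D input shape (the light-field image) and a 3D output shape (the volume). The sketch below reproduces that shape arithmetic with hypothetical helper names and estimates the memory of one tensor, assuming float32 storage (4 bytes per element, matching `precision=32` in the trainer).

```python
from math import prod


def network_io_shapes(lf_shape, n_depths):
    """Return (input_2d_shape, output_volume_shape) as passed to NN(...) above.

    lf_shape is [angular-u, angular-v, spatial-s, spatial-t]; the 2D light-field
    image has u*s rows and v*t columns, and the volume adds a depth axis in front.
    """
    u, v, s, t = lf_shape
    lf_2d_shape = (u * s, v * t)
    volume_shape = (n_depths,) + lf_2d_shape
    return lf_2d_shape, volume_shape


def float32_mib(shape):
    """Size of one float32 tensor with the given shape, in MiB."""
    return prod(shape) * 4 / 2**20


in_shape, out_shape = network_io_shapes([33, 33, 39, 39], 64)
print(in_shape, out_shape, f"{float32_mib(out_shape):.1f} MiB per volume")
```

For the MouseBrain settings this gives an input of (1287, 1287) and target volumes of shape (64, 1287, 1287), about 404 MiB each in float32. A batch of 2 volumes is then roughly 809 MiB before any activations or gradients, which is likely part of why the batch size is kept small.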

### 4. Load data for training.

```python
net.configure_dataloader()
```

```
Loading img:   0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99
```
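`configure_dataloader` loads the image-volume pairs listed in `images_ids` and creates `net.train_loader` and `net.val_loader`. How it divides the ids between the two is not shown here, so the following is only a hypothetical illustration of a reproducible train/validation split and of how many batches one epoch takes; the 20 % validation fraction and the helper names are assumptions.

```python
import math
import random


def split_image_ids(image_ids, val_fraction=0.2, seed=0):
    """Shuffle ids with a fixed seed and split them into (train_ids, val_ids).

    Hypothetical: napari-LF's configure_dataloader may split differently.
    """
    ids = list(image_ids)
    random.Random(seed).shuffle(ids)
    n_val = int(round(len(ids) * val_fraction))
    return ids[n_val:], ids[:n_val]


def steps_per_epoch(n_train, batch_size, drop_last=False):
    """Batches per epoch: ceil(n / batch_size), or floor if the last partial batch is dropped."""
    return n_train // batch_size if drop_last else math.ceil(n_train / batch_size)


train_ids, val_ids = split_image_ids(range(100))
print(len(train_ids), len(val_ids), steps_per_epoch(len(train_ids), batch_size=2))
```

Under these assumptions, 100 ids give 80 training and 20 validation pairs, i.e. 40 optimizer steps per epoch at batch size 2.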

### 5. Train network.

```python
# Requires the previous cells to have run in the same session (uses net, n_gpus, training_settings)
import os
from pytorch_lightning import loggers as pl_loggers

# Create a trainer.
# logger=True makes Lightning use its default logger, writing to ./lightning_logs/version_*
tb_logger = True
# Log to a user-specified directory if one was given, otherwise to the default one
if len(training_settings['output_dir']) > 0:
    # Name the log folder after the network class, e.g. "VCDNet"
    network_prefix = net.__class__.__name__
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=training_settings['output_dir'], name=network_prefix)
    output_path = os.path.join(training_settings['output_dir'], network_prefix)
else:
    output_path = './lightning_logs/'

# Note: `gpus` was removed in PyTorch Lightning 2.0; there, use accelerator="gpu", devices=n_gpus
trainer = pl.Trainer(logger=tb_logger, gpus=n_gpus, precision=32, max_epochs=net.get_train_setting('epochs'))

print(f'###################### Logging to: {output_path}')
print(f'run tensorboard --logdir={output_path} in the console')
trainer.fit(model=net, train_dataloaders=net.train_loader, val_dataloaders=net.val_loader)
# Repeat the reminder after training, since the training output pushes it out of view
print(f'###################### Logging to: {output_path}')
print(f'run tensorboard --logdir={output_path} in the console')
```
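Lightning's `TensorBoardLogger` writes each run to `save_dir/name/version_N`, so the directory passed to `tensorboard --logdir` collects every run of one network. The sketch below mirrors the path logic of the training cell using `os.path.join` (which avoids the doubled slash an f-string produces when `output_dir` ends in `/`) and guesses the next `version_N` folder by scanning existing ones. The helper names are hypothetical, and this reimplements the naming convention rather than calling Lightning.

```python
import os
import re


def log_root(output_dir, network_prefix):
    """Directory to point TensorBoard at; an empty output_dir falls back to ./lightning_logs."""
    if output_dir:
        return os.path.join(output_dir, network_prefix)
    return os.path.join(".", "lightning_logs")


def next_version_dir(root):
    """Guess the version_N folder the next run will create under root (sketch)."""
    versions = []
    if os.path.isdir(root):
        for name in os.listdir(root):
            match = re.fullmatch(r"version_(\d+)", name)
            if match and os.path.isdir(os.path.join(root, name)):
                versions.append(int(match.group(1)))
    n = max(versions) + 1 if versions else 0
    return os.path.join(root, f"version_{n}")
```

For example, `next_version_dir(log_root(training_settings['output_dir'], "VCDNet"))` would point at the folder where the trained VCDNet run is expected to appear.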

