# Neural network training for Napari-LF

Napari-LF's neural network integration is built on the PyTorch Lightning workflow, which provides general functions for loading data, training, inference, etc. that can be used with any neural network.
This notebook prepares a network for use with Napari-LF.
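
Lightning splits the work between two objects: the network (a `LightningModule`) defines what happens in a single training or validation step, and `pl.Trainer` owns the loops over epochs and batches and calls backward and the optimizer. As a rough, dependency-free illustration of that split, the toy classes here (`ToyModel` and `ToyTrainer` are hypothetical stand-ins, not Lightning classes) fit the one-parameter model y = w·x; the hand-written gradient replaces what autograd would compute.

```python
class ToyModel:
    """Stands in for a LightningModule: it defines single steps, not the loop."""

    def __init__(self, lr):
        self.w = 0.0  # the single trainable parameter of y = w * x
        self.lr = lr  # learning rate; a Lightning network declares its optimizer in configure_optimizers()

    def _mse(self, batch):
        return sum((self.w * x - y) ** 2 for x, y in batch) / len(batch)

    def training_step(self, batch):
        # Return the loss and its gradient with respect to w
        # (the hand-written gradient replaces loss.backward())
        grad = sum(2 * (self.w * x - y) * x for x, y in batch) / len(batch)
        return self._mse(batch), grad

    def validation_step(self, batch):
        # Evaluation only: no gradient, no update
        return self._mse(batch)


class ToyTrainer:
    """Stands in for pl.Trainer: it owns the epoch/batch loops and the updates."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs

    def fit(self, model, train_batches, val_batches):
        val_history = []
        for _ in range(self.max_epochs):
            for batch in train_batches:
                _, grad = model.training_step(batch)
                model.w -= model.lr * grad  # plain gradient-descent "optimizer step"
            val_loss = sum(model.validation_step(b) for b in val_batches) / len(val_batches)
            val_history.append(val_loss)
        return val_history


# Samples of y = 3x, already grouped into batches
train_batches = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0)]]
val_batches = [[(4.0, 12.0)]]

model = ToyModel(lr=0.05)
history = ToyTrainer(max_epochs=20).fit(model, train_batches, val_batches)
print(f"w = {model.w:.4f}, validation loss: first epoch {history[0]:.3f}, last epoch {history[-1]:.1e}")
```

In this notebook, the imported network `NN` plays the role of `ToyModel`, and `trainer.fit(model=net, ...)` in step 5 plays the role of `ToyTrainer.fit`.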

## Instructions
1. Import the required libraries and the network to train.
2. Gather the required information from the user.
3. Create the network.
4. Load the training data.
5. Train the network.
6. Store the network in a Napari-LF-compatible format.

### 1. Import the required libraries and the network to train

```python
import pytorch_lightning as pl
import torch

# Choose which network to train. It is imported as NN so the rest of the notebook stays the same.
# For a VCDNet:
# from neural_nets.VCDNet import VCDNet as NN
# For an LFMNet (used in this example):
from neural_nets.LFMNet import LFMNet as NN
```

### 2. Gather the required information from the user

```python
# Number of GPUs to train on
n_gpus = 1
# Shape of the light field: [angular-u, angular-v, spatial-s, spatial-t]
LFshape = [33, 33, 39, 39]     # For the MouseBrain dataset
# Size of the 2D light-field image: [u*s, v*t]
LF_2D_shape = [LFshape[0]*LFshape[2], LFshape[1]*LFshape[3]]
# Number of depths in each volume
n_depths = 64

# Training parameters
training_settings = {}
# Learning rate
training_settings['lr'] = 1e-3
# Batch size
training_settings['batch_size'] = 2
# Maximum number of training epochs
training_settings['epochs'] = 1
# Which image-volume pairs to use
training_settings['images_ids'] = list(range(10))
# Path to the dataset
training_settings['dataset_path'] = 'D:/BrainImagesJosuePage/Brain_40x_64Depths_362imgs.h5'
# Where to store the trained network and its logs.
# If left blank (''), they are stored in ./lightning_logs/version_*
training_settings['output_dir'] = 'C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks/'
```
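
The value of `output_dir` decides where step 5 writes its logs. As far as we understand Lightning's `TensorBoardLogger`, each run goes to `<save_dir>/<name>/version_N`, where N is one more than the highest existing `version_*` folder (0 if there is none); with the default logger the base is `./lightning_logs`. The helpers here (`next_version` and `resolve_log_dir` are hypothetical, not Lightning functions) sketch that rule. Joining with `os.path.join` instead of an f-string with an extra `/` also avoids the doubled slash (`pretrained_networks//LFMNet`) visible in the training output, since `output_dir` already ends in `/`.

```python
import os
import re


def next_version(existing_names):
    """Next free version number, given the names already present in the log directory."""
    versions = []
    for name in existing_names:
        match = re.fullmatch(r"version_(\d+)", name)
        if match:
            versions.append(int(match.group(1)))
    return max(versions) + 1 if versions else 0


def resolve_log_dir(output_dir, network_prefix, existing_names=()):
    """Where a run's logs would go, mirroring the two branches of the training cell."""
    if output_dir:
        # Custom directory: one subfolder per network class, e.g. .../LFMNet
        base = os.path.join(output_dir, network_prefix)
    else:
        # Blank output_dir: Lightning's default location
        base = os.path.join(".", "lightning_logs")
    return os.path.join(base, f"version_{next_version(existing_names)}")


print(resolve_log_dir("/data/nets/", "LFMNet"))
print(resolve_log_dir("", "LFMNet", ["version_0", "version_2", "notes.txt"]))
```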

### 3. Create the network

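```python
# Arguments: the 2D light-field image shape, then the output volume shape (n_depths, *LF_2D_shape)
net = NN(LF_2D_shape, (n_depths,) + tuple(LF_2D_shape),
         network_settings_dict={'LFshape': LFshape},
         training_settings_dict=training_settings)
```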
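
`LF_2D_shape` is the 4D light field laid out as a single 2D image: each spatial sample (s, t) contributes a U × V block of angular samples, so the image is (U·S) × (V·T) pixels, 1287 × 1287 for the MouseBrain shape, and the network is asked to produce `n_depths` slices of that size. The helpers here are hypothetical (not part of Napari-LF), and the layout they use (row = s·U + u, column = t·V + v, i.e. one contiguous block per lenslet) is an assumption about how the 2D image is organized.

```python
def lf_2d_shape(lf_shape):
    """[U, V, S, T] -> [U*S, V*T], the same product used for LF_2D_shape."""
    U, V, S, T = lf_shape
    return [U * S, V * T]


def volume_shape(n_depths, lf_shape):
    """Output shape handed to the network: (n_depths, U*S, V*T)."""
    return (n_depths,) + tuple(lf_2d_shape(lf_shape))


def lenslet_pixel(u, v, s, t, lf_shape):
    """(u, v, s, t) -> (row, col) in the 2D image.

    ASSUMPTION: lenslet (s, t) covers a contiguous U x V block of pixels,
    and (u, v) indexes the pixel inside that block.
    """
    U, V, S, T = lf_shape
    if not (0 <= u < U and 0 <= v < V and 0 <= s < S and 0 <= t < T):
        raise IndexError("(u, v, s, t) is outside the light-field shape")
    return s * U + u, t * V + v


def to_lenslet_image(lf, lf_shape):
    """Rearrange a nested-list light field lf[u][v][s][t] into a 2D list of lists."""
    rows, cols = lf_2d_shape(lf_shape)
    image = [[None] * cols for _ in range(rows)]
    U, V, S, T = lf_shape
    for u in range(U):
        for v in range(V):
            for s in range(S):
                for t in range(T):
                    row, col = lenslet_pixel(u, v, s, t, lf_shape)
                    image[row][col] = lf[u][v][s][t]
    return image


print(lf_2d_shape([33, 33, 39, 39]))       # [1287, 1287]
print(volume_shape(64, [33, 33, 39, 39]))  # (64, 1287, 1287)
```

Because the mapping is one-to-one, every pixel of the 2D image is filled exactly once; the real data loader presumably works on arrays rather than nested lists.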

### 4. Load the training data

```python
net.configure_dataloader()
```

```
Loading images: 100%|██████████| 10/10 [00:06<00:00,  1.52it/s]
```
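
Two settings from step 2 determine how much work one epoch is: `images_ids` (10 image-volume pairs here) and `batch_size` (2). How `configure_dataloader()` divides those pairs between training and validation is not shown in this notebook, so the sketch here assumes a hypothetical split that sends the last 20% of the IDs to validation. The batch count follows the usual PyTorch `DataLoader` rule: ceil(n / batch_size) batches, or floor(n / batch_size) when `drop_last=True`.

```python
import math


def split_image_ids(image_ids, val_fraction=0.2):
    """HYPOTHETICAL split: the last val_fraction of the IDs go to validation.

    configure_dataloader() may split the data differently; this only illustrates the arithmetic.
    """
    n_val = max(1, round(len(image_ids) * val_fraction))
    return image_ids[:-n_val], image_ids[-n_val:]


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Batches yielded per epoch: ceil(n / batch_size), or floor with drop_last."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


train_ids, val_ids = split_image_ids(list(range(10)), val_fraction=0.2)
print(len(train_ids), len(val_ids))                     # 8 2
print(batches_per_epoch(len(train_ids), batch_size=2))  # 4
```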


### 5. Train the network

```python
import os
from pytorch_lightning import loggers as pl_loggers

if len(training_settings['output_dir']) > 0:
    # Log to the specified directory, in a subfolder named after the network class (e.g. LFMNet)
    network_prefix = net.__class__.__name__
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=training_settings['output_dir'], name=network_prefix)
    output_path = os.path.join(training_settings['output_dir'], network_prefix)
else:
    # logger=True: Lightning creates its default logger under ./lightning_logs/
    tb_logger = True
    output_path = './lightning_logs/'

# Create the trainer (identical in both cases).
# accelerator/devices replace the `gpus` argument, deprecated in PyTorch Lightning 1.7 and removed in 2.0.
trainer = pl.Trainer(logger=tb_logger, accelerator='gpu', devices=n_gpus,
                     precision=32, max_epochs=net.get_train_setting('epochs'))

print(f'###################### Logging to: {output_path}')
print(f'run tensorboard --logdir={output_path} in the console')
trainer.fit(model=net, train_dataloaders=net.train_loader, val_dataloaders=net.val_loader)
# Repeat the logging location once training has finished
print(f'###################### Logging to: {output_path}')
print(f'run tensorboard --logdir={output_path} in the console')
```


```
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name               | Type       | Params
--------------------------------------------------
0 | lensletConvolution | Sequential | 46.7 K
1 | Unet               | UNetLF     | 29.2 K
--------------------------------------------------
75.9 K    Trainable params
0         Non-trainable params
75.9 K    Total params
0.304     Total estimated model params size (MB)

###################### Logging to: C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks//LFMNet
run tensorboard --logdir=C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks//LFMNet in the console

`Trainer.fit` stopped: `max_epochs=1` reached.

###################### Logging to: C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks//LFMNet
run tensorboard --logdir=C:/Users/OldenbourgLab2/Code/napari-LF-neural_nets/examples/pretrained_networks//LFMNet in the console
```
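
The model summary's size estimate can be checked by hand. Assuming 4 bytes per parameter for `precision=32` (which is how we understand Lightning's estimate to be computed) and 1 MB = 10^6 bytes, the 46.7 K + 29.2 K = 75.9 K parameters take about 75,900 × 4 / 10^6 ≈ 0.304 MB. The helper `model_size_mb` is hypothetical, and since the table's counts are already rounded to 0.1 K, the result is approximate.

```python
def model_size_mb(n_params, precision_bits=32):
    """Estimated parameter memory in MB (10**6 bytes) at the given precision."""
    bytes_per_param = precision_bits / 8
    return n_params * bytes_per_param / 1e6


# Parameter counts from the summary table (already rounded to 0.1 K)
n_params = 46_700 + 29_200      # lensletConvolution + Unet
print(n_params)                               # 75900
print(round(model_size_mb(n_params, 32), 3))  # 0.304
```

At 16-bit precision the same network would need roughly half of that, about 0.152 MB.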
