<hr style="height:2px;">

# Demo: Neural network training for denoising of *Tribolium castaneum*

### Notes 

- Assumes that the training data has already been generated via [1_datagen.ipynb](1_datagen.ipynb) and saved to disk as ``my_training_data.npz``.
- Training a neural network for actual use should be done on more (representative) data and with more training time.
- Documentation available: http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
from IPython.display import display, HTML
#display(HTML("<style>.rendered_html { font-size: 16px; }</style>"))
import os
from tifffile import imread

import csbdeep
from csbdeep.utils import axes_dict, plot_some
from csbdeep.io import load_training_data
from csbdeep.models import Config, CARE
from csbdeep.utils.tf import limit_gpu_memory


TensorFlow uses all available GPU memory by default, hence it can be useful to limit it:

In [None]:
# limit_gpu_memory(fraction=1/2)

<hr style="height:2px;">

# Training data

Load the training data generated via [1_datagen.ipynb](1_datagen.ipynb) and use 10% of it as validation data.

In [None]:
(X,Y), (X_val,Y_val), axes = load_training_data('my_training_data.npz', validation_split=0.1)
ax = axes_dict(axes)

n_train, n_val = len(X), len(X_val)
image_size = tuple(X.shape[i] for i in ((ax['Z'],ax['Y'],ax['X']) if (ax['Z'] is not None) else (ax['Y'],ax['X'])))
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]
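The axis bookkeeping in the cell above can be illustrated without loading any data. The following is a minimal sketch, assuming an axes string such as `'SZYXC'`; the helper `axes_to_index` is a hypothetical stand-in for `axes_dict`, and the array shape is made up for illustration.

```python
import numpy as np

def axes_to_index(axes, allowed='STCZYX'):
    """Hypothetical stand-in for axes_dict: map each axis letter to its
    position in the axes string, or to None if the axis is absent."""
    return {a: (axes.find(a) if a in axes else None) for a in allowed}

# Made-up stack: 8 patches of 16x64x64 voxels with a single channel.
axes = 'SZYXC'
X = np.zeros((8, 16, 64, 64, 1), dtype=np.float32)

ax = axes_to_index(axes)

# Keep only the spatial axes that are actually present (Z is optional).
spatial_axes = [a for a in 'ZYX' if ax[a] is not None]
image_size = tuple(X.shape[ax[a]] for a in spatial_axes)
n_dim = len(image_size)
n_channel = X.shape[ax['C']]

print(ax)
print(image_size, n_dim, n_channel)
```

For 2D data (for example axes `'SYXC'`), `ax['Z']` is `None` and `image_size` contains only the Y and X extents.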

In [None]:
print('number of training images:\t', n_train)
print('number of validation images:\t', n_val)
print('image size (%dD):\t\t'%n_dim, image_size)
print('axes:\t\t\t\t', axes)
print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)
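The training and validation counts printed above follow from the `validation_split` fraction. As a rough sketch of how such a split can work, the hypothetical helper below holds out the last 10% of the patches; that `load_training_data` does exactly this internally is an assumption, not something this notebook confirms.

```python
import numpy as np

def split_validation(X, Y, validation_split=0.1):
    """Hypothetical helper: hold out the last `validation_split` fraction
    of the paired arrays X and Y as validation data."""
    n_images = len(X)
    n_val = int(round(n_images * validation_split))
    n_train = n_images - n_val
    if not (0 < n_val and 0 < n_train):
        raise ValueError("validation_split leaves an empty training or validation set")
    return (X[:n_train], Y[:n_train]), (X[-n_val:], Y[-n_val:])

# Made-up data: 50 "patches", each a single number, with paired targets.
X_all = np.arange(50).reshape(50, 1)
Y_all = 2 * X_all

(X_tr, Y_tr), (X_va, Y_va) = split_validation(X_all, Y_all, validation_split=0.1)
print(len(X_tr), len(X_va))
```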

In [None]:
plt.figure(figsize=(10,4))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

<hr style="height:2px;">

# Model

Before we construct the actual CARE model, we have to define its (training) configuration via a `Config` object, which includes settings such as:

* type of network 
* learning rate
* number of steps per epoch
* loss function  
* whether the model is probabilistic or not

The defaults should be sensible in many cases, so a change should only be necessary if the training process fails.  

Note that for this notebook we use a very small number of iterations per epoch for immediate feedback, whereas for a properly trained model this number should be increased (e.g. `train_steps_per_epoch=400`).

In [None]:
config = Config(axes, n_channel_in, n_channel_out, train_epochs=30, train_steps_per_epoch=20)
print(config)
vars(config)
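To get a feel for how short this demo training run is, the number of weight updates can be worked out by hand. The sketch below uses the epoch and step counts from the cell above; the batch size of 16 is an assumed value for illustration only, so check the output of `print(config)` for the one actually in use.

```python
# Values taken from the Config call above.
train_epochs = 30
train_steps_per_epoch = 20

# Assumed batch size, for illustration only (see print(config) for the real value).
batch_size = 16

# One step is one gradient update on one batch.
total_steps = train_epochs * train_steps_per_epoch
patches_seen = total_steps * batch_size

# The same calculation with the larger step count suggested for real training.
total_steps_full = train_epochs * 400

print('demo run:  %d updates, %d patches seen' % (total_steps, patches_seen))
print('full run:  %d updates' % total_steps_full)
```

With these numbers the demo performs 600 updates, compared to 12000 with 400 steps per epoch.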

Now we can create a CARE model based on the chosen options:

In [None]:
model = CARE(config, 'my_model')

<hr style="height:2px;">

# Training

Now we actually train the model, which may take some time.

To monitor the progress during training one can use [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), by starting it from the current working directory:

`tensorboard --logdir=. --reload-interval=2`

and then connecting to [http://localhost:6006/](http://localhost:6006/) with your browser.

In [None]:
history = model.train(X,Y, validation_data=(X_val,Y_val))

Plot final training history (available in TensorBoard during training):

In [None]:
from csbdeep.utils import plot_history
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);

<hr style="height:2px;">

# Evaluation

Example results for validation images:

In [None]:
model.load_weights() # load best weights according to validation loss
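What "best weights according to validation loss" means can be shown with a toy history. The values below are made up, and the dictionary only mimics the layout of `history.history`; this is a sketch of the selection criterion, not of how the model stores its checkpoints.

```python
# Made-up per-epoch losses in the same layout as history.history.
toy_history = {
    'loss':     [0.90, 0.50, 0.30, 0.20],
    'val_loss': [1.00, 0.60, 0.45, 0.50],
}

val_loss = toy_history['val_loss']

# The "best" epoch is the one with the lowest validation loss,
# which need not be the last one (here the final epoch overfits slightly).
best_epoch = min(range(len(val_loss)), key=val_loss.__getitem__)
best_val_loss = val_loss[best_epoch]

print('best epoch (0-based):', best_epoch)
print('best validation loss:', best_val_loss)
```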

In [None]:
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5],Y_val[:5],_P,pmax=99.5)
plt.suptitle('5 example validation patches\n'       +
             'top row: input (source),  '           +
             'middle row: target (ground truth),  ' +
             'bottom row: predicted from source')
#plt.tight_layout()
None;
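The slicing `_P[...,:(_P.shape[-1]//2)]` in the cell above keeps only the first half of the output channels. A minimal sketch of that operation on a made-up array follows; it assumes that a probabilistic model returns twice as many channels as `n_channel_out`, with the per-pixel mean in the first half and a per-pixel scale (uncertainty) in the second half.

```python
import numpy as np

n_channel_out = 1  # number of target channels, as in this notebook's data

# Made-up network output for 5 patches of 16x64x64 voxels.
# Assumption: a probabilistic model emits 2 * n_channel_out channels.
rng = np.random.default_rng(0)
P = rng.random((5, 16, 64, 64, 2 * n_channel_out))

half = P.shape[-1] // 2
P_mean = P[..., :half]   # assumed: predicted mean, used as the restored image
P_scale = P[..., half:]  # assumed: predicted scale, a per-pixel uncertainty

print(P.shape, '->', P_mean.shape, P_scale.shape)
```

After the split, `P_mean` has the same number of channels as the target `Y_val`, which is why it can be plotted next to it.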

<hr style="height:2px;">

# Export model to be used with CSBDeep **Fiji** plugins and **KNIME** workflows

See https://github.com/CSBDeep/CSBDeep/wiki/Your-Model-in-Fiji for details.

In [None]:
model.export_TF()