<hr style="height:2px;">

# Demo: Probabilistic neural network training for denoising of 2D toy data

### Notes 

- Training a neural network for actual use should be done on more (representative) data and with more training time.
- Documentation available: http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import csbdeep
from csbdeep.utils import axes_dict
from csbdeep.train import load_data
from csbdeep.models import Config, CARE
from csbdeep.tf import limit_gpu_memory
from csbdeep.plot_utils import plot_some

By default, TensorFlow allocates all available GPU memory, so it can be useful to limit it:

In [None]:
# limit_gpu_memory(fraction=1/2)

<hr style="height:2px;">

# Training data

Load provided training images and use 10% as validation data.

In [None]:
from csbdeep.utils import download_and_extract_zip_file
download_and_extract_zip_file(
    url = 'https://cloud.mpi-cbg.de/index.php/s/OnQFwcB0vCASBMu/download',
    provides = ('toy_data.npz',)
)

In [None]:
(X,Y), (X_val,Y_val), data_axes = load_data('toy_data.npz', axes='SCYX', validation_split=0.1)
ax = axes_dict(data_axes)

n_train, n_val = len(X), len(X_val)
image_size = tuple(X.shape[i] for i in ((ax['Z'],ax['Y'],ax['X']) if (ax['Z'] is not None) else (ax['Y'],ax['X'])))
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[ax['C']], Y.shape[ax['C']]
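
To make the effect of `validation_split=0.1` concrete, here is a minimal NumPy sketch of a 10% hold-out on synthetic arrays. The helper `split_validation` is hypothetical and only for illustration. It assumes the validation samples are taken from the end of the sample axis, which may differ from how `load_data` selects them.

```python
import numpy as np

def split_validation(X, Y, validation_split=0.1):
    """Hold out a fraction of the samples along axis 0.

    Hypothetical helper for illustration only; not the csbdeep implementation.
    Assumption: the held-out samples are the last ones.
    """
    n = len(X)
    n_val = int(round(n * validation_split))
    n_train = n - n_val
    return (X[:n_train], Y[:n_train]), (X[n_train:], Y[n_train:])

# Synthetic stand-in data in 'SCYX' layout: 20 samples, 1 channel, 8x8 pixels.
X_demo = np.zeros((20, 1, 8, 8), dtype=np.float32)
Y_demo = np.ones((20, 1, 8, 8), dtype=np.float32)

(X_tr, Y_tr), (X_va, Y_va) = split_validation(X_demo, Y_demo, 0.1)
print('training samples:', len(X_tr), ' validation samples:', len(X_va))
```

Source and target arrays are split at the same index, so each validation input stays paired with its ground truth.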

In [None]:
print('number of training images:\t', n_train)
print('number of validation images:\t', n_val)
print('image size (%dD):\t\t'%n_dim, image_size)
print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)

In [None]:
plt.figure(figsize=(10,4))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

<hr style="height:2px;">

# Model

Choose configuration options (defaults should be sensible in many cases):

In [None]:
config = Config(data_axes, n_channel_in, n_channel_out,
                probabilistic=True, train_steps_per_epoch=50, train_learning_rate=0.0002)
print(config)
vars(config)

Create CARE model based on chosen options:

In [None]:
model = CARE(config, 'my_model')

<hr style="height:2px;">

# Training

[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) allows monitoring of progress during training.  
Start TensorBoard e.g. with **`tensorboard --logdir=. --reload-interval=2`** and connect to [http://localhost:6006/](http://localhost:6006/) with your browser.

In [None]:
history = model.train(X,Y, validation_data=(X_val,Y_val))

Plot final training history (available in TensorBoard during training):

In [None]:
from csbdeep.plot_utils import plot_history
print(sorted(history.history.keys()))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);

<hr style="height:2px;">

# Evaluation

Example results for validation images:

In [None]:
model.load_weights() # load best weights according to validation loss
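
As a rough illustration of what "best weights according to validation loss" refers to, the sketch below finds the epoch with the lowest `val_loss` in a dictionary shaped like `history.history`. The numbers in `history_demo` are made up, and the sketch assumes that "best" means the minimum validation loss over all epochs.

```python
import numpy as np

# Hypothetical per-epoch values, laid out like the `history.history` dict.
history_demo = {
    'loss':     [0.90, 0.60, 0.45, 0.40, 0.38],
    'val_loss': [0.95, 0.70, 0.52, 0.55, 0.57],
}

# Epoch with the lowest validation loss (0-based index).
best_epoch = int(np.argmin(history_demo['val_loss']))
print('best epoch (0-based):', best_epoch)
print('val_loss at best epoch:', history_demo['val_loss'][best_epoch])
```

In this made-up run the training loss keeps falling after the best epoch while the validation loss rises again, which is why the final weights are not necessarily the ones to evaluate.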

In [None]:
plt.figure(figsize=(15,12))
_P = model.keras_model.predict(X_val[:5])
_P_mean  = _P[...,:(_P.shape[-1]//2)]
_P_scale = _P[...,(_P.shape[-1]//2):]
plot_some(X_val[:5],Y_val[:5],_P_mean,_P_scale,pmax=99.5)
plt.suptitle('5 example validation patches\n'       +
             'first row: input (source),  '         +
             'second row: target (ground truth),  ' +
             'third row: predicted Laplace mean,  ' +
             'fourth row: predicted Laplace scale')
#plt.tight_layout()
None;
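
The cell above splits the last axis of the prediction into a Laplace mean and a Laplace scale. The following NumPy sketch shows what these two quantities can be used for: the per-pixel negative log-likelihood of a target under a Laplace distribution, and a central interval around the mean. All function names are hypothetical, the data is synthetic, and the formulas are the standard Laplace ones rather than code taken from CSBDeep.

```python
import numpy as np

def split_mean_scale(pred):
    """Split the last axis into (mean, scale) halves, as in the cell above."""
    half = pred.shape[-1] // 2
    return pred[..., :half], pred[..., half:]

def laplace_nll(y, mean, scale):
    """Per-pixel negative log-likelihood of y under Laplace(mean, scale)."""
    return np.abs(y - mean) / scale + np.log(2.0 * scale)

def laplace_interval(mean, scale, coverage=0.9):
    """Central interval that contains `coverage` of the probability mass."""
    # For a Laplace distribution, P(|x - mean| <= t) = 1 - exp(-t / scale).
    half_width = -scale * np.log(1.0 - coverage)
    return mean - half_width, mean + half_width

# Synthetic stand-in for a network output: 2 images, 4x4 pixels, 2 channels.
rng = np.random.default_rng(0)
pred_demo = np.concatenate(
    [rng.normal(size=(2, 4, 4, 1)),               # mean channel
     rng.uniform(0.1, 1.0, size=(2, 4, 4, 1))],   # scale channel (positive)
    axis=-1)

mean_demo, scale_demo = split_mean_scale(pred_demo)
target_demo = mean_demo.copy()  # pretend the mean matches the target exactly
nll_demo = laplace_nll(target_demo, mean_demo, scale_demo)
lo_demo, hi_demo = laplace_interval(mean_demo, scale_demo, coverage=0.9)
print('mean NLL:', nll_demo.mean())
print('mean 90% interval width:', (hi_demo - lo_demo).mean())
```

A larger predicted scale widens the interval, so the scale image can be read as a per-pixel uncertainty map for the predicted mean.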

<hr style="height:2px;">

# Export model to be used with CSBDeep **Fiji** plugins and **KNIME** workflows

See https://github.com/CSBDeep/CSBDeep/wiki/Your-Model-in-Fiji for details.

In [None]:
model.export_TF()