<hr style="height:2px;">

# Demo: Neural network training for denoising of *Tribolium castaneum*

### Notes 

- Assumes that training data has already been generated via [datagen.ipynb](datagen.ipynb) and saved to disk as the file ``my_training_data.npz``.
- A network intended for actual use should be trained on more (and more representative) data and for considerably longer.
- More documentation is available (within the CBG/CSBD network): http://myers-pc-8:8080/
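Before training, it can help to confirm what `my_training_data.npz` actually contains. The following is a minimal NumPy sketch. `describe_npz` is a hypothetical helper. The array names written by datagen.ipynb are not stated here, so it simply lists whatever arrays it finds. The toy file with arrays `X` and `Y` only stands in for the real one.

```python
import os
import tempfile

import numpy as np


def describe_npz(path):
    """Hypothetical helper: return {array name: (shape, dtype)} for an .npz file."""
    with np.load(path) as data:  # NpzFile supports the context-manager protocol
        return {key: (data[key].shape, str(data[key].dtype)) for key in data.files}


# Stand-in file; the array names 'X' and 'Y' are hypothetical here.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'toy_training_data.npz')
    np.savez(path,
             X=np.zeros((4, 32, 32, 1), np.float32),
             Y=np.zeros((4, 32, 32, 1), np.float32))
    summary = describe_npz(path)

print(summary)
```

On the real data, `describe_npz('my_training_data.npz')` would report the stored arrays and their shapes.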

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import csbdeep
from csbdeep.train import load_data
from csbdeep.models import Config, CARE
from csbdeep.tf import limit_gpu_memory
from csbdeep.plot_utils import plot_some

By default, TensorFlow allocates all available GPU memory, so it can be useful to limit it (e.g., uncomment the line below to use at most half):

In [None]:
# limit_gpu_memory(fraction=1/2)

<hr style="height:2px;">

# Training data

Load the training data generated via [datagen.ipynb](datagen.ipynb) and use 10% of it as validation data.

In [None]:
(X,Y), data_val = load_data('my_training_data.npz', validation_split=0.1)

X_val, Y_val = data_val
n_train, n_val = len(X), len(X_val)
image_size = X.shape[1:-1]
n_dim = len(image_size)
n_channel_in, n_channel_out = X.shape[-1], Y.shape[-1]
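`load_data` returns the training pairs plus a held-out validation set. The shape bookkeeping above implies the axis layout (sample, spatial axes..., channel). The sketch below illustrates one common way to hold out 10% of the patches, using a random permutation. The helper `split_validation` is hypothetical and is not necessarily how `load_data` splits the data. Toy variable names are used so that running it does not overwrite `X`, `Y` or `n_dim`.

```python
import numpy as np


def split_validation(X, Y, validation_split=0.1, seed=0):
    """Hypothetical helper: randomly hold out a fraction of the (source, target) pairs.

    Not necessarily how csbdeep's load_data splits the data.
    """
    assert len(X) == len(Y)
    n_val = max(1, int(round(validation_split * len(X))))
    idx = np.random.default_rng(seed).permutation(len(X))
    val, train = idx[:n_val], idx[n_val:]
    return (X[train], Y[train]), (X[val], Y[val])


# Toy 2D patches with layout (sample, y, x, channel)
X_toy = np.random.rand(50, 16, 16, 1).astype(np.float32)
Y_toy = np.random.rand(50, 16, 16, 1).astype(np.float32)
(X_tr, Y_tr), (X_va, Y_va) = split_validation(X_toy, Y_toy, validation_split=0.1)

toy_image_size = X_tr.shape[1:-1]  # spatial axes only: drop sample and channel axes
toy_n_dim = len(toy_image_size)
print(len(X_tr), len(X_va), toy_image_size, toy_n_dim)  # -> 45 5 (16, 16) 2
```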

In [None]:
print('number of training images:\t', n_train)
print('number of validation images:\t', n_val)
print('image size (%dD):\t\t'%n_dim, image_size)
print('Channels in / out:\t\t', n_channel_in, '/', n_channel_out)

In [None]:
plt.figure(figsize=(10,4))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

<hr style="height:2px;">

# Model

Configuration options (`train_steps_per_epoch=5` is deliberately small to keep this demo short):

In [None]:
config = Config(n_dim, n_channel_in, n_channel_out, train_steps_per_epoch=5)
print(config)
vars(config)
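Assuming each training step processes one batch of patches, the total training effort is easy to estimate. Only `train_steps_per_epoch=5` comes from this notebook. The epoch count and batch size below are hypothetical placeholders, so check `print(config)` for the values actually used.

```python
def training_budget(train_epochs, train_steps_per_epoch, batch_size):
    """Return (gradient updates, patch samples drawn); patches may be drawn repeatedly."""
    n_updates = train_epochs * train_steps_per_epoch
    return n_updates, n_updates * batch_size


# Hypothetical epoch count and batch size; only train_steps_per_epoch=5 is from this notebook.
print(training_budget(train_epochs=100, train_steps_per_epoch=5, batch_size=16))  # -> (500, 8000)
```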

Create a model named `my_model` based on the chosen options:

In [None]:
model = CARE(config,'my_model')

<hr style="height:2px;">

# Training

[TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard) allows monitoring of progress during training.  
Start TensorBoard, e.g., with **`tensorboard --logdir=. --reload-interval=2`**, and open [http://localhost:6006/](http://localhost:6006/) in your browser.

In [None]:
history = model.train(X,Y, validation_data=data_val)

Plot the final training history (also available in TensorBoard during training):

In [None]:
from csbdeep.plot_utils import plot_history
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
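The plotted `mse` and `mae` curves (and their `val_` counterparts) track the mean squared and mean absolute error between network output and target. Here is a minimal NumPy sketch of both metrics on a toy 2×2 image:

```python
import numpy as np


def mse(y_true, y_pred):
    """Mean squared error over all pixels and channels."""
    return float(np.mean((y_true - y_pred) ** 2))


def mae(y_true, y_pred):
    """Mean absolute error over all pixels and channels."""
    return float(np.mean(np.abs(y_true - y_pred)))


target = np.array([[0.0, 1.0], [2.0, 3.0]])
prediction = np.array([[0.5, 1.0], [2.0, 1.0]])
print(mse(target, prediction), mae(target, prediction))  # -> 1.0625 0.625
```

Because MSE squares the residuals, the single large error (2.0) dominates it. MAE weighs all residuals linearly.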

<hr style="height:2px;">

# Evaluation

Example results for the first five validation patches:

In [None]:
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    # a probabilistic model outputs twice as many channels; keep the first half (the predicted mean)
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5],Y_val[:5],_P,pmax=99.5)
plt.suptitle('5 example validation patches\n'+
             'top row: input (source),  '+
             'middle row: target (ground truth),  '+
             'bottom row: predicted from source');
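`plot_some` is called with `pmax=99.5`. Assume this sets the upper display limit at the 99.5th intensity percentile, so that a few very bright pixels do not wash out the rest of the image. In that case, the normalization could look like the sketch below. `normalize_percentile` and its `pmin` default are hypothetical and are not the `plot_utils` implementation.

```python
import numpy as np


def normalize_percentile(img, pmin=2.0, pmax=99.5, eps=1e-20):
    """Hypothetical helper: map the pmin-th percentile to 0 and the pmax-th to 1, then clip.

    Intended for display only; the values outside [0, 1] are clipped.
    """
    lo, hi = np.percentile(img, [pmin, pmax])
    return np.clip((img - lo) / (hi - lo + eps), 0.0, 1.0)


img = np.arange(101, dtype=np.float32)  # values 0..100
img[-1] = 1000.0                         # one bright outlier pixel
out = normalize_percentile(img, pmin=0, pmax=99)
print(out[:3], out[-2:])
```

Without the percentile limit, the single outlier would compress all other pixels into the bottom tenth of the display range. With it, they span the full range and only the outlier is clipped.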

<hr style="height:2px;">

# Export the model for use with CSBDeep **Fiji** plugins and **KNIME** workflows

See https://github.com/CSBDeep/CSBDeep/wiki/Your-Model-in-Fiji for details.

In [None]:
model.export_TF()