Train network
===========

This notebook is based on a [Jupyter notebook](https://nbviewer.jupyter.org/url/csbdeep.bioimagecomputing.com/examples/denoising3D/2_training.ipynb) from the CSBDeep examples.
With it you can train your own network once you have [created some training data](https://github.com/Rickmic/Deep_CLEM/blob/master/load_data.ipynb). <br/>
First, import all required Python packages:

In [None]:
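from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
# Show plots inside the notebook in high resolution
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
# CSBDeep helpers for handling axes, plotting and loading training data
from csbdeep.utils import axes_dict, plot_some, plot_history
# limit_gpu_memory can optionally restrict how much GPU memory TensorFlow uses
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.io import load_training_data
# The projection model: projects a 3D input stack to a 2D restored output
from csbdeep.models import ProjectionConfig, ProjectionCARE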
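
Whether the exported model works in the CSBDeep Fiji plugin depends on the installed TensorFlow version (see the note on TF 2 models in the last step). A minimal sketch, using only the Python standard library (Python 3.8 or newer), that reports which versions are installed; the helper names `installed_version` and `tf_major_version` are hypothetical, and the GPU build of TensorFlow 1 may be installed under the name `tensorflow-gpu`:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_version(package):
    """Return the installed version of `package`, or None if it is not installed."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None

def tf_major_version():
    """Major TensorFlow version as an int, or None if TensorFlow is not found."""
    # The GPU build of TensorFlow 1 was distributed as 'tensorflow-gpu'
    for name in ("tensorflow", "tensorflow-gpu"):
        v = installed_version(name)
        if v is not None:
            return int(v.split(".")[0])
    return None

for pkg in ("tensorflow", "tensorflow-gpu", "csbdeep"):
    print(pkg, installed_version(pkg) or "not installed")
print("TensorFlow major version:", tf_major_version())
```

If this reports version 2, expect the TF 2 warning when exporting the model at the end.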

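Next, specify the path to the *.npz* file that contains your generated training data (here *my_training_data.npz*). `validation_split=0.1` holds out 10% of the patches as validation data, and the number of input and output channels is read from the loaded arrays.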

In [None]:
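# Load the training patches and hold out 10% of them for validation
(X,Y), (X_val,Y_val), axes = load_training_data('my_training_data.npz', validation_split=0.1, verbose=True)

# Index of the channel axis 'C', used to read the number of input/output channels
c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]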
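
If loading fails or the shapes look unexpected, it can help to open the file directly with NumPy. A minimal sketch, assuming the file was written by CSBDeep's `save_training_data`, which as far as we know stores the arrays under the keys `X`, `Y` and `axes`. The helper `describe_training_file` and the tiny dummy file are hypothetical and only make the example runnable; the exact rounding CSBDeep uses for the validation split may differ from this sketch:

```python
import os
import tempfile

import numpy as np

def describe_training_file(path, validation_split=0.1):
    """Summarize a training .npz file (assumed keys: 'X', 'Y', 'axes')."""
    with np.load(path) as f:
        X, Y = f["X"], f["Y"]
        axes = f["axes"].item()  # stored as a 0-d string array
    n_total = X.shape[axes.index("S")]  # number of patches along the sample axis
    n_val = int(round(n_total * validation_split))  # approximate split size
    c = axes.index("C")
    return {
        "axes": axes,
        "X_shape": X.shape,
        "Y_shape": Y.shape,
        "n_channel_in": X.shape[c],
        "n_channel_out": Y.shape[c],
        "n_train": n_total - n_val,
        "n_val": n_val,
    }

# Hypothetical dummy file (20 patches), only to make the example self-contained
with tempfile.TemporaryDirectory() as tmp:
    dummy_path = os.path.join(tmp, "dummy_training_data.npz")
    X_dummy = np.zeros((20, 1, 4, 8, 8), dtype=np.float32)  # S, C, Z, Y, X
    Y_dummy = np.zeros((20, 1, 1, 8, 8), dtype=np.float32)
    np.savez(dummy_path, X=X_dummy, Y=Y_dummy, axes="SCZYX")
    info = describe_training_file(dummy_path)
print(info)
```

The dummy arrays use different names so that running the cell does not overwrite the loaded `X` and `Y`; for your own data, call `describe_training_file('my_training_data.npz')`.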


In [None]:
plt.figure(figsize=(12,5))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');


In [None]:
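# Training configuration: U-Net of depth 3, 150 epochs of 800 steps each,
# with 8 patches per batch
config = ProjectionConfig(axes, n_channel_in, n_channel_out,
                          unet_n_depth=3, train_batch_size=8,
                          train_steps_per_epoch=800, train_epochs=150)
print(config)
vars(config)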
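
To see what these settings mean, the numbers can be multiplied out: each epoch runs `train_steps_per_epoch` batches of `train_batch_size` patches. A small sketch; the helper `training_volume` is hypothetical, `n_train` stands for the number of training patches (`len(X)`), and the "passes" value assumes the training loop simply cycles through the patches:

```python
def training_volume(n_train, train_batch_size=8, train_steps_per_epoch=800, train_epochs=150):
    """Return (total steps, patches per epoch, passes over the training set per epoch)."""
    total_steps = train_steps_per_epoch * train_epochs
    patches_per_epoch = train_steps_per_epoch * train_batch_size
    passes_per_epoch = patches_per_epoch / n_train
    return total_steps, patches_per_epoch, passes_per_epoch

# Hypothetical example with 3200 training patches
steps, per_epoch, passes = training_volume(n_train=3200)
print(steps, per_epoch, passes)  # 120000 6400 2.0
```

With `len(X)` in place of 3200 you get the values for your own data.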


Next, choose the base directory (`basedir`) in which a folder named *my_model* is created; an empty string means the current working directory. The trained network is saved in this folder.

In [None]:
# Create the model; its files are stored in <basedir>/my_model
model = ProjectionCARE(config, 'my_model', basedir='')
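
CSBDeep places the model folder at `basedir`/`name`. A minimal sketch with `pathlib` showing that `basedir=''` resolves to a *my_model* folder in the current working directory; the helper `model_folder` is hypothetical and only mirrors the path construction, it does not create anything:

```python
from pathlib import Path

def model_folder(name, basedir=""):
    """Absolute path of <basedir>/<name>; an empty basedir means the working directory."""
    return (Path(basedir) / name).resolve()

print(model_folder("my_model"))            # my_model inside the current working directory
print(model_folder("my_model", "models"))  # models/my_model inside the working directory
```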

In [None]:
# Show the parameters of the projection part of the network
model.proj_params

The following cell starts training the network. Depending on the speed and memory of your GPU, this step takes roughly 2 to 3 hours with the settings above.

In [None]:
# Train the network; the returned history stores loss and metrics for every epoch
history = model.train(X,Y, validation_data=(X_val,Y_val))
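
The 2 to 3 hours quoted above are a rough figure. Keras normally prints how long each epoch took, so after the first epoch you can extrapolate the total training time. A small sketch; the helper `estimated_training_hours` is hypothetical, and since the first epoch can be somewhat slower than later ones, treat the result as an estimate:

```python
def estimated_training_hours(seconds_per_epoch, train_epochs=150):
    """Extrapolate the total training time in hours from the duration of one epoch."""
    return seconds_per_epoch * train_epochs / 3600

print(estimated_training_hours(60))  # 2.5
print(estimated_training_hours(72))  # 3.0
```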



In [None]:
# List the recorded metrics, then plot the loss and metric curves
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
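
`history.history` is a plain dictionary that maps each metric name to a list with one value per epoch. A sketch of how to find the epoch with the lowest validation loss; the helper `best_epoch` is hypothetical and the toy dictionary stands in for `history.history`:

```python
def best_epoch(history_dict, key="val_loss"):
    """Return (1-based epoch, value) of the smallest entry in history_dict[key]."""
    values = history_dict[key]
    i = min(range(len(values)), key=values.__getitem__)
    return i + 1, values[i]

# Toy values standing in for history.history
toy_history = {"loss": [0.9, 0.6, 0.5, 0.45], "val_loss": [0.8, 0.55, 0.58, 0.6]}
print(best_epoch(toy_history))  # (2, 0.55)
```

A validation loss that rises while the training loss keeps falling, as in the toy values, is a typical sign of overfitting.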



In [None]:
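# Predict five validation patches and compare them with source and target
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    # Probabilistic models output twice as many channels; keep the first half (predicted values)
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5],Y_val[:5],_P,pmax=99.5)
plt.suptitle('5 example validation patches\n'
             'top row: input (source),  '
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');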
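
For probabilistic models (`probabilistic=True` in the config; it is off by default and not set in this notebook), the network outputs two values per target channel, stacked along the last axis: first the predicted values, then a per-pixel scale (uncertainty) parameter. The cell above therefore keeps only the first half of the channels. A NumPy sketch of that slicing on a dummy array; the helper name and the shape (patches, Z, Y, X, channels) are assumptions for illustration:

```python
import numpy as np

def split_probabilistic_output(pred):
    """Split a probabilistic prediction into (values, scale) along the last (channel) axis."""
    n = pred.shape[-1] // 2
    return pred[..., :n], pred[..., n:]

# Dummy prediction for 5 patches with 1 target channel, i.e. 2 output channels
rng = np.random.default_rng(0)
pred = rng.random((5, 1, 16, 16, 2), dtype=np.float32)
values, scale = split_probabilistic_output(pred)
print(values.shape, scale.shape)  # (5, 1, 16, 16, 1) (5, 1, 16, 16, 1)
```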



In the last step, you export the network as a .zip file into the folder *my_model*. This .zip file contains the network, which you can then load in Fiji.

If you receive a warning that TF 2 models do not work with the CSBDeep Fiji plugin, follow the CSBDeep instructions referenced in that warning to convert the model to a TF 1 model.

In [None]:
# Export the trained network as a .zip file into the folder my_model
model.export_TF()
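
Before loading the model in Fiji, you can check that the archive was written and see what it contains. A sketch using the standard `zipfile` module; it assumes the default file name *TF_SavedModel.zip* inside *my_model* (check the folder for the actual name), and the helper `list_exported_model` as well as the dummy archive contents are hypothetical:

```python
import tempfile
import zipfile
from pathlib import Path

def list_exported_model(path="my_model/TF_SavedModel.zip"):
    """Return the names of the files stored in the exported model archive."""
    path = Path(path)
    if not path.is_file():
        raise FileNotFoundError(f"No exported model found at {path}")
    with zipfile.ZipFile(path) as zf:
        return zf.namelist()

# Self-contained demo on a dummy archive (hypothetical contents)
with tempfile.TemporaryDirectory() as tmp:
    dummy = Path(tmp) / "dummy_model.zip"
    with zipfile.ZipFile(dummy, "w") as zf:
        zf.writestr("saved_model.pb", b"")
    demo_names = list_exported_model(dummy)
print(demo_names)  # ['saved_model.pb']

# After model.export_TF() has run, list the real archive:
# print(list_exported_model())
```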

