<hr style="height:2px;">

# Train your first CARE network (supervised)

In this first example we will train a CARE network for a 3D denoising task, where corresponding pairs of low- and high-quality stacks can be acquired.

Each pair should be registered, which is best achieved by acquiring both stacks _interleaved_, i.e. as different channels that correspond to the different exposure/laser settings. 
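Suppose your microscope saves an interleaved acquisition as a single file, with the low- and high-quality planes alternating along Z. In that case, splitting it into a registered pair is a single slicing step. The sketch below assumes exactly that layout (even planes = first setting, odd planes = second setting). The helper `split_interleaved` is hypothetical, so check how your own acquisition software orders the planes before relying on it.

```python
import numpy as np

def split_interleaved(stack, n_channels=2):
    """Hypothetical helper: split a (Z*n_channels, Y, X) stack whose planes
    alternate between acquisition settings into one (Z, Y, X) stack per setting."""
    if stack.shape[0] % n_channels != 0:
        raise ValueError("number of planes is not a multiple of n_channels")
    # assumed ordering: plane k belongs to setting k % n_channels
    return [stack[c::n_channels] for c in range(n_channels)]

# toy stack: 4 Z positions x 2 settings, 8x8 pixels each
interleaved = np.arange(8 * 8 * 8, dtype=np.float32).reshape(8, 8, 8)
low, high = split_interleaved(interleaved)
print(low.shape, high.shape)  # (4, 8, 8) (4, 8, 8)
```

Because both stacks come from the same acquisition run, they stay registered by construction. This is why interleaved acquisition is recommended.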

We will use a single Tribolium stack pair for training, whereas in your real-life application you should aim to acquire at least 10-50 stacks from different developmental timepoints to ensure a well-trained model.

By the way, this exercise was adapted from the examples in the [CSBDeep repository](https://github.com/CSBDeep/CSBDeep) (as you will see, we also use that library extensively here).
More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

There are no tasks to fill in for this exercise, but go through each cell and try to understand what's going on; it will help you in the next part! We have added some questions along the way to help with that.

You'll want to select the kernel for the `'care'` environment for this exercise.
![title](nb_material/change_kernel.png)

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
import gc
%matplotlib inline
%load_ext tensorboard
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import axes_dict, download_and_extract_zip_file, Path, plot_history, plot_some
from csbdeep.utils.tf import limit_gpu_memory
from csbdeep.data import RawData, create_patches
from csbdeep.io import load_training_data, save_tiff_imagej_compatible
from csbdeep.models import Config, CARE

<hr style="height:2px;">

## Part 1: Training Data Generation
Network training usually happens on batches of images that are much smaller than those recorded on the microscope. In this first part of the exercise, we will load all of the image data and chop it into smaller pieces, a.k.a. patches.

### Download example data

First we download some example data, consisting of low-SNR and high-SNR 3D images of Tribolium.  
Note that `GT` stands for ground truth and represents high signal-to-noise ratio (SNR) stacks.

In [None]:
download_and_extract_zip_file (
    url       = 'http://csbdeep.bioimagecomputing.com/example_data/tribolium.zip',
    targetdir = 'data',
)

We can plot the training stack pair via maximum projection:

In [None]:
y_train = imread('data/tribolium/train/GT/nGFP_0.1_0.2_0.5_20_13_late.tif')
x_train = imread('data/tribolium/train/low/nGFP_0.1_0.2_0.5_20_13_late.tif')
print('image size =', x_train.shape)

plt.figure(figsize=(16,10))
plot_some(np.stack([x_train,y_train]),
          title_list=[['low (maximum projection)','GT (maximum projection)']], 
          pmin=2,pmax=99.8);
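`plot_some` shows each 3D stack as a maximum projection along Z. The display range is set by the `pmin`/`pmax` percentiles. Below is a rough NumPy equivalent, assuming a `(Z, Y, X)` array and a simple clip-to-[0, 1] normalization. CSBDeep's own normalization may differ in its details, and the helper names are hypothetical.

```python
import numpy as np

def max_project(stack, axis=0):
    """Maximum intensity projection along `axis` (Z for a ZYX stack)."""
    return np.max(stack, axis=axis)

def percentile_normalize(img, pmin=2, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1, then clip."""
    lo, hi = np.percentile(img, [pmin, pmax])
    return np.clip((img - lo) / (hi - lo + eps), 0, 1)

rng = np.random.default_rng(0)
stack = rng.poisson(5.0, size=(16, 64, 64)).astype(np.float32)  # toy noisy stack
mip = max_project(stack)             # shape (64, 64)
shown = percentile_normalize(mip)    # values in [0, 1], ready for display
print(mip.shape, shown.min(), shown.max())
```

Percentile-based limits keep a few very bright or very dark pixels from dominating the display range. This matters for noisy low-SNR data.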

### Generate training data for CARE

We first need to create a `RawData` object, which defines how to get the pairs of low/high SNR stacks and the semantics of each axis (e.g. which one is considered a color channel, etc.).

Here we have two folders, "low" and "GT", where corresponding low- and high-SNR stacks are TIFF images with identical filenames.  
For this case, we can simply use `RawData.from_folder` and set `axes = 'ZYX'` to indicate the semantic order of the image axes. 

In [None]:
raw_data = RawData.from_folder (
    basepath    = 'data/tribolium/train',
    source_dirs = ['low'],
    target_dir  = 'GT',
    axes        = 'ZYX',
)
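Since pairs are matched by filename, every file in `low` needs a counterpart with the same name in `GT`. A quick sanity check with `pathlib` can catch mismatches before you generate patches. This sketch assumes the layout used here (`basepath/low/*.tif` and `basepath/GT/*.tif`). The helper `check_pairs` is hypothetical and not part of CSBDeep.

```python
from pathlib import Path

def check_pairs(basepath, source_dir="low", target_dir="GT", pattern="*.tif"):
    """Hypothetical helper: return the filenames present in both folders and
    raise if either folder contains a file without a partner."""
    base = Path(basepath)
    src = {p.name for p in (base / source_dir).glob(pattern)}
    tgt = {p.name for p in (base / target_dir).glob(pattern)}
    unmatched = src ^ tgt  # names that appear in only one of the two folders
    if unmatched:
        raise ValueError(f"files without a partner: {sorted(unmatched)}")
    return sorted(src)

# e.g. check_pairs('data/tribolium/train')
```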

From corresponding stacks, we now generate some 3D patches. As a general rule, use a patch size that is a power of two along XYZT, or at least divisible by 8.  
Typically, the more training stacks you have, the more patches you should use. By default, patches are sampled from non-background regions (i.e. that are above a relative threshold), see the documentation of `create_patches` for details.
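Both rules of thumb can be made concrete. The first function checks whether each patch dimension is a power of two or divisible by 8. The second picks random patch positions and keeps only those whose maximum target intensity exceeds a fraction of a high percentile of the whole stack. This is a simplified illustration of the background-rejection idea, not CSBDeep's implementation (see the `create_patches` documentation for the actual filter). The helper names and the `thresh=0.4` / `percentile=99.9` values are assumptions.

```python
import numpy as np

def is_good_patch_size(patch_size):
    """True if every dimension is a power of two or divisible by 8."""
    return all(((n & (n - 1)) == 0) or (n % 8 == 0) for n in patch_size)

def sample_foreground_patches(target, patch_size, n_patches, thresh=0.4,
                              percentile=99.9, max_tries=100_000, seed=0):
    """Simplified sketch: draw random patch corners and keep those whose patch
    maximum in `target` exceeds thresh * (percentile of target).
    May return fewer than n_patches if max_tries is exhausted."""
    rng = np.random.default_rng(seed)
    cutoff = thresh * np.percentile(target, percentile)
    corners = []
    for _ in range(max_tries):
        if len(corners) == n_patches:
            break
        start = [int(rng.integers(0, s - p + 1)) for s, p in zip(target.shape, patch_size)]
        region = tuple(slice(a, a + p) for a, p in zip(start, patch_size))
        if target[region].max() > cutoff:   # patch contains foreground
            corners.append(tuple(start))
    return corners

print(is_good_patch_size((16, 64, 64)))  # True
print(is_good_patch_size((12, 64, 64)))  # False

# toy target: a bright object in the corner of an otherwise empty volume
target = np.zeros((32, 128, 128), dtype=np.float32)
target[0:8, 0:16, 0:16] = 1.0
corners = sample_foreground_patches(target, (16, 64, 64), n_patches=10)
```

Rejecting empty patches keeps the network from spending most of its updates on background, which carries little information for restoration.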

Note that the values `(X, Y, XY_axes)` returned by `create_patches` are not to be confused with the image axes X and Y.  
By convention, the variable name `X` (or `x`) refers to an input variable for a machine learning model, whereas `Y` (or `y`) indicates an output variable.

In [None]:
X, Y, XY_axes = create_patches (
    raw_data            = raw_data,
    patch_size          = (16,64,64),
    n_patches_per_image = 1024,
    save_file           = 'data/tribolium/my_training_data.npz',
)

In [None]:
assert X.shape == Y.shape
print("shape of X,Y =", X.shape)
print("axes  of X,Y =", XY_axes)
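The axes string printed above uses one letter per array dimension. Later, the notebook calls CSBDeep's `axes_dict` to turn such a string into a lookup from axis letter to array index, and uses it to find the channel axis `C`. Below is a minimal pure-Python re-implementation of that idea, not the library's own code. It uses the hypothetical name `axes_index` and assumes the axis letters `STCZYX`.

```python
def axes_index(axes, allowed="STCZYX"):
    """Map each allowed axis letter to its position in `axes`, or None if absent."""
    axes = axes.upper()
    for a in axes:
        if a not in allowed:
            raise ValueError(f"unknown axis {a!r}")
    if len(set(axes)) != len(axes):
        raise ValueError("duplicate axis letter")
    return {a: (axes.index(a) if a in axes else None) for a in allowed}

d = axes_index("SCZYX")
print(d["C"], d["Z"], d["T"])  # 1 2 None
```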

### Show

This shows the maximum projection of some of the generated patch pairs (odd rows: *source*, even rows: *target*).

In [None]:
for i in range(2):
    plt.figure(figsize=(16,4))
    sl = slice(8*i, 8*(i+1)), 0
    plot_some(X[sl],Y[sl],title_list=[np.arange(sl[0].start,sl[0].stop)])
    plt.show()
None;

### Questions:
1. Where is the training data located? 
2. How does the data have to be stored so that CSBDeep will find and load it correctly?

<hr style="height:2px;">

## Part 2: Training the network


### Load Training data

Load the patches generated in Part 1 and use 10% of them as validation data.

In [None]:
(X,Y), (X_val,Y_val), axes = load_training_data('data/tribolium/my_training_data.npz', validation_split=0.1, verbose=True)

c = axes_dict(axes)['C']
n_channel_in, n_channel_out = X.shape[c], Y.shape[c]
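`validation_split=0.1` holds out 10% of the patches for validation. Below is a plain NumPy sketch of such a hold-out. It shuffles the pairs explicitly so the held-out patches are a random subset. The exact splitting rule inside `load_training_data` may differ, and `train_val_split` is a hypothetical helper.

```python
import numpy as np

def train_val_split(X, Y, validation_split=0.1, seed=0):
    """Shuffle input/target pairs consistently and hold out a fraction for validation."""
    assert len(X) == len(Y)
    n = len(X)
    n_val = int(round(n * validation_split))
    if not 0 < n_val < n:
        raise ValueError("validation_split leaves an empty train or validation set")
    perm = np.random.default_rng(seed).permutation(n)
    X, Y = X[perm], Y[perm]  # same permutation keeps each input with its target
    return (X[:-n_val], Y[:-n_val]), (X[-n_val:], Y[-n_val:])

# toy patches with the same count as above (1024) but tiny spatial size
X_demo = np.random.rand(1024, 1, 4, 8, 8).astype(np.float32)
(X_tr, Y_tr), (X_v, Y_v) = train_val_split(X_demo, 2 * X_demo)
print(len(X_tr), len(X_v))  # 922 102
```

The validation patches are never used for parameter updates. As a result, `val_loss` tells you whether the model generalizes or merely memorizes the training patches.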

In [None]:
plt.figure(figsize=(12,5))
plot_some(X_val[:5],Y_val[:5])
plt.suptitle('5 example validation patches (top row: source, bottom row: target)');

### Configure the CARE model
Before we construct the actual CARE model, we have to define its configuration via a `Config` object, which includes 
* parameters of the underlying neural network,
* the learning rate,
* the number of parameter updates per epoch,
* the loss function, and
* whether the model is probabilistic or not.

The defaults should be sensible in many cases, so a change should only be necessary if the training process fails.  

<span style="color:red;font-weight:bold;">Important</span>: Note that for this notebook we use a very small number of update steps per epoch for immediate feedback, whereas this number should be increased considerably (e.g. `train_steps_per_epoch=400`) to obtain a well-trained model.
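To see why 10 steps per epoch is so little, note that each step processes one batch. An epoch therefore covers `train_steps_per_epoch × train_batch_size` patches. The arithmetic below assumes a batch size of 16; check the `train_batch_size` entry printed by `vars(config)`. It also assumes about 922 training patches remain after the 10% validation split of 1024 patches.

```python
n_train_patches = 1024 - round(1024 * 0.1)  # 922 patches left after the validation split
batch_size = 16                              # assumed; see config.train_batch_size

for steps_per_epoch in (10, 400):
    patches_per_epoch = steps_per_epoch * batch_size
    passes = patches_per_epoch / n_train_patches
    print(f"{steps_per_epoch:>3} steps/epoch -> {patches_per_epoch:>4} patches "
          f"(~{passes:.2f} passes over the training set)")
```

With these numbers, 10 steps per epoch sees only about a sixth of the training patches. By contrast, 400 steps covers them almost seven times per epoch.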

In [None]:
config = Config(axes, n_channel_in, n_channel_out, train_steps_per_epoch=10, train_epochs=100)
vars(config)

We now create a CARE model with the chosen configuration:

In [None]:
model = CARE(config, 'my_CARE_model', basedir='models')

### Training

Training the model will likely take some time. We recommend monitoring the progress with [TensorBoard](https://www.tensorflow.org/programmers_guide/summaries_and_tensorboard), which allows you to inspect the losses during training.
Furthermore, you can look at the predictions for some of the validation images, which can be helpful to recognize problems early on.

We start TensorBoard inside the notebook to monitor the training (if you prefer, you can also launch it outside the notebook by changing the `%` to `!`).

In [None]:
%tensorboard --logdir models/my_CARE_model --host localhost

In [None]:
history = model.train(X,Y, validation_data=(X_val,Y_val))

Plot the final training history (also available in TensorBoard during training):

In [None]:
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history,['loss','val_loss'],['mse','val_mse','mae','val_mae']);
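The `history.history` object returned by Keras is a plain dict of per-epoch lists. You can use it to check where the validation loss bottomed out, for example to see whether it was still decreasing when training stopped. The sketch below runs on a hand-made dict. Its values are toy numbers, not results from this notebook, and `best_epoch` is a hypothetical helper.

```python
import numpy as np

def best_epoch(history_dict, key="val_loss"):
    """Return (1-based epoch, value) at which `key` reaches its minimum."""
    values = np.asarray(history_dict[key])
    i = int(np.argmin(values))
    return i + 1, float(values[i])

# toy stand-in for history.history
toy_history = {"loss": [0.9, 0.6, 0.5, 0.45], "val_loss": [0.95, 0.7, 0.62, 0.66]}
print(best_epoch(toy_history))  # (3, 0.62)
```

A training loss that keeps falling while the validation loss rises, as in this toy example, is the usual sign of overfitting.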

### Evaluation
Example results for validation images.

In [None]:
plt.figure(figsize=(12,7))
_P = model.keras_model.predict(X_val[:5])
if config.probabilistic:
    _P = _P[...,:(_P.shape[-1]//2)]
plot_some(X_val[:5],Y_val[:5],_P,pmax=99.5)
plt.suptitle('5 example validation patches\n'      
             'top row: input (source),  '          
             'middle row: target (ground truth),  '
             'bottom row: predicted from source');
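When `config.probabilistic` is true, the cell above keeps only the first half of the output channels. The slicing implies that the network stacks per-pixel means in the first half of the last (channel) axis and a second set of per-pixel parameters, assumed here to be a scale, in the other half. Under that assumption, both parts can be separated as follows. `split_probabilistic` is a hypothetical helper.

```python
import numpy as np

def split_probabilistic(pred):
    """Split a channels-last prediction into (mean, scale) halves."""
    n = pred.shape[-1]
    if n % 2 != 0:
        raise ValueError("expected an even number of output channels")
    return pred[..., : n // 2], pred[..., n // 2 :]

pred = np.random.rand(5, 16, 64, 64, 2).astype(np.float32)  # stand-in for model output
mean, scale = split_probabilistic(pred)
print(mean.shape, scale.shape)  # (5, 16, 64, 64, 1) (5, 16, 64, 64, 1)
```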

### Questions:
1. Where are trained models stored? What models are being stored, how do they differ?
2. How does the name of the saved models get specified?
3. How can you influence the number of training steps per epoch? What did you use?

<hr style="height:2px;">

## Part 3: Prediction

Plot the test stack pair and define its image axes, which will be needed later for CARE prediction.

In [None]:
y_test = imread('data/tribolium/test/GT/nGFP_0.1_0.2_0.5_20_14_late.tif')
x_test = imread('data/tribolium/test/low/nGFP_0.1_0.2_0.5_20_14_late.tif')

axes = 'ZYX'
print('image size =', x_test.shape)
print('image axes =', axes)

plt.figure(figsize=(16,10))
plot_some(np.stack([x_test,y_test]),
          title_list=[['low (maximum projection)','GT (maximum projection)']], 
          pmin=2,pmax=99.8);


### Load CARE model

Load the trained model (located in base directory `models` with name `my_CARE_model`) from disk.  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_CARE_model', basedir='models')

### Apply CARE network to raw image
Predict the restored image (the image will be successively split into smaller tiles if there are memory issues).

In [None]:
%%time
restored = model.predict(x_test, axes, n_tiles=(1,2,1))
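With `axes = 'ZYX'`, `n_tiles=(1,2,1)` asks CARE to process the stack in two pieces along Y to save memory. The simplified sketch below shows only the split, predict, and stitch steps. It omits the overlapping borders that CSBDeep adds around each tile; for a convolutional network, those borders keep pixels near tile edges from losing context and producing seams. Treat it as an illustration, not a replacement for `model.predict`. `predict_fn` is a hypothetical stand-in for the network.

```python
import numpy as np

def predict_tiled(img, predict_fn, n_tiles):
    """Split `img` into n_tiles[i] chunks along each axis, apply predict_fn to
    each chunk and reassemble. No overlap/halo, unlike CSBDeep's tiling."""
    assert len(n_tiles) == img.ndim
    out = np.empty_like(img, dtype=np.float32)
    # chunk boundaries along each axis, e.g. [0, 15, 30] for 30 pixels in 2 tiles
    edges = [np.linspace(0, s, n + 1).astype(int) for s, n in zip(img.shape, n_tiles)]
    for idx in np.ndindex(*n_tiles):
        region = tuple(slice(e[i], e[i + 1]) for e, i in zip(edges, idx))
        out[region] = predict_fn(img[region])
    return out

img = np.random.rand(8, 30, 20).astype(np.float32)
result = predict_tiled(img, lambda t: 2 * t, n_tiles=(1, 2, 1))  # pointwise toy "network"
```

For a purely pointwise function like the toy one above, tiling is exact. A real network looks at a neighborhood around each pixel, which is why the overlap matters.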

### Save restored image

Save the restored image stack as an ImageJ-compatible TIFF image, i.e. the image can be opened in ImageJ/Fiji with correct axes semantics.

In [None]:
Path('results').mkdir(exist_ok=True)
save_tiff_imagej_compatible('results/%s_nGFP_0.1_0.2_0.5_20_14_late.tif' % model.name, restored, axes)
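`save_tiff_imagej_compatible` takes care of the axis metadata for you. For comparison, here is a sketch of writing a `ZYX` float32 stack as an ImageJ hyperstack with `tifffile` directly. It assumes a reasonably recent `tifffile` that provides `imwrite`, and the output filename is hypothetical.

```python
import numpy as np
import tifffile

stack = np.random.rand(16, 64, 64).astype(np.float32)  # stand-in for `restored`
# imagej=True writes ImageJ hyperstack metadata; "axes" tells it which dimension is Z
tifffile.imwrite("example_restored.tif", stack, imagej=True, metadata={"axes": "ZYX"})
reloaded = tifffile.imread("example_restored.tif")
print(reloaded.shape, reloaded.dtype)
```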

### Visualize results
Plot the test stack pair and the predicted restored stack (middle).

In [None]:
plt.figure(figsize=(16,10))
plot_some(np.stack([x_test,restored,y_test]),
          title_list=[['low (maximum projection)','CARE (maximum projection)','GT (maximum projection)']], 
          pmin=2,pmax=99.8);
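Beyond visual inspection, you can put a number on the result. One option is the peak signal-to-noise ratio (PSNR) of the low-SNR input and of the CARE output, each measured against the GT. CARE outputs are not guaranteed to share the GT's intensity scale, so this sketch percentile-normalizes every image first. The 2nd/99.8th percentile choice is an assumption, and the helper names are hypothetical. The runnable part uses toy data; the commented lines show how it would apply to the notebook's arrays.

```python
import numpy as np

def normalize(img, pmin=2, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile to 0 and the pmax-th percentile to 1 (no clipping)."""
    lo, hi = np.percentile(img, [pmin, pmax])
    return (img - lo) / (hi - lo + eps)

def psnr(gt, pred, data_range=1.0):
    """Peak signal-to-noise ratio in dB (higher is better)."""
    mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)

# with the notebook's arrays:
# gt = normalize(y_test)
# print('low :', psnr(gt, normalize(x_test)))
# print('CARE:', psnr(gt, normalize(restored)))

# toy check: less noise should give a higher PSNR
rng = np.random.default_rng(0)
gt = rng.random((8, 32, 32))
noisy = gt + rng.normal(0, 0.3, gt.shape)
less_noisy = gt + rng.normal(0, 0.05, gt.shape)
print(psnr(gt, noisy) < psnr(gt, less_noisy))  # True
```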

<hr style="height:2px;">

### Bonus Questions:
Feel free to skip these.
1. How would you load an existing CARE network and continue to train it? (This is _not_ done in this notebook)
2. How would you go about saving the new model separately, under a new name?



<hr style="height:2px;">

# Congratulations!
__You have reached the first checkpoint of this exercise! Please mark your progress on slack and move on to the next part!__