<hr style="height:2px;">

# Demo: Application of trained neural network for denoising of *Tribolium castaneum*

### Notes 

- Assumes that training was already completed via [training.ipynb](training.ipynb).
- Here, the trained CARE network is applied to the same image that the model was trained on (training data generated via [datagen.ipynb](datagen.ipynb)).  
In practice, one would typically use it to restore images that the model has not seen during training.
- Documentation is available at http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from tifffile import imread
from csbdeep.models import CARE
from csbdeep.predict import PercentileNormalizer, PadAndCropResizer
from csbdeep.plot_utils import plot_some

<hr style="height:2px;">

# Download example data (if necessary)

In [None]:
from csbdeep.utils import download_and_extract_zip_file
download_and_extract_zip_file(
    url = 'https://cloud.mpi-cbg.de/index.php/s/jKHFIS4isNwagMd/download',
    provides = ('raw_data/tribolium/%s/nGFP_0.1_0.2_0.5_20_13_late.tif'%d for d in ('GT','low'))
)

After downloading, the data should be organized like this:

    raw_data/tribolium
    ├── GT
    │   └── nGFP_0.1_0.2_0.5_20_13_late.tif
    └── low
        └── nGFP_0.1_0.2_0.5_20_13_late.tif
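If the download was interrupted, the cells below fail with a file-not-found error. A quick check of the layout shown above can be sketched with the standard library. `missing_files` is a hypothetical helper (not part of CSBDeep), and the paths are simply the ones listed in the tree.

```python
import os

# Paths from the directory tree above, relative to the notebook's working directory
ROOT = 'raw_data/tribolium'
FILENAME = 'nGFP_0.1_0.2_0.5_20_13_late.tif'

def missing_files(root=ROOT, subdirs=('GT', 'low'), filename=FILENAME):
    """Return the expected file paths that do not exist under `root` (hypothetical helper)."""
    expected = [os.path.join(root, d, filename) for d in subdirs]
    return [p for p in expected if not os.path.isfile(p)]

missing = missing_files()
if missing:
    print('Missing files:', missing)
else:
    print('All expected files are present.')
```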

<hr style="height:2px;">

# Raw low-SNR image and associated high-SNR ground truth

In [None]:
y = imread('raw_data/tribolium/GT/nGFP_0.1_0.2_0.5_20_13_late.tif')
x = imread('raw_data/tribolium/low/nGFP_0.1_0.2_0.5_20_13_late.tif')
axes = 'ZYX'
print('image size =', x.shape)
print('image axes =', axes)

plt.figure(figsize=(15, 10))
plot_some(np.stack([x, y]), title_list=[['low (maximum projection)', 'GT (maximum projection)']], pmin=2, pmax=99.8);
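The plots above show maximum intensity projections of the 3D stacks, displayed with an intensity range set by the `pmin`/`pmax` percentiles. A minimal NumPy sketch of those two steps follows. It assumes that `plot_some` projects along Z and rescales intensities using the given percentiles; `max_projection` and `display_range` are hypothetical helpers, not part of CSBDeep.

```python
import numpy as np

def max_projection(img, axes='ZYX'):
    """Maximum intensity projection along the Z axis, located via the axes string."""
    return img.max(axis=axes.index('Z'))

def display_range(img, pmin=2, pmax=99.8, eps=1e-20):
    """Map the pmin/pmax percentiles to 0/1 and clip, as an assumed display normalization."""
    lo, hi = np.percentile(img, (pmin, pmax))
    return np.clip((img - lo) / (hi - lo + eps), 0, 1)

# Small synthetic stack standing in for x (the real data is loaded with imread above)
stack = np.random.default_rng(0).poisson(5, size=(4, 8, 8)).astype(np.float32)
mip = display_range(max_projection(stack))
print(mip.shape, mip.min(), mip.max())
```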

<hr style="height:2px;">

# Model

Load the trained model (stored in the folder `my_model`) from disk.  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_model')
model.load_weights()

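- Select an appropriate normalization: `PercentileNormalizer(3, 99.8)` rescales the input based on its 3rd and 99.8th intensity percentiles.
- Choose how to resize the image so that the model can be applied to it: `PadAndCropResizer` pads the image to a size the network can process and crops the result back to the original size.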
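Percentile normalization puts the input on an intensity scale comparable to the normalized data used during training, independent of the absolute intensity values of the microscope image. The NumPy sketch below illustrates the idea, assuming that `PercentileNormalizer(3, 99.8)` maps the 3rd percentile of the input to 0 and the 99.8th percentile to 1 without clipping. The helper names are hypothetical. CSBDeep's normalizer can also map the prediction back to the intensity range of the input, which `percentile_denormalize` mimics.

```python
import numpy as np

def percentile_normalize(x, pmin=3, pmax=99.8, eps=1e-20):
    """Map the pmin-th percentile of x to 0 and the pmax-th percentile to 1 (no clipping).

    Returns the normalized array and the (mi, ma) statistics needed to undo it.
    """
    mi, ma = np.percentile(x, (pmin, pmax))
    return (x - mi) / (ma - mi + eps), (mi, ma)

def percentile_denormalize(y, mi, ma, eps=1e-20):
    """Map y back to the intensity range described by the input statistics (mi, ma)."""
    return y * (ma - mi + eps) + mi

img = np.linspace(0, 1000, 1001).astype(np.float32)
normed, (mi, ma) = percentile_normalize(img)
print(float(mi), float(ma))  # 3rd and 99.8th percentiles of img
print(np.allclose(percentile_denormalize(normed, mi, ma), img, atol=1e-3))
```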

In [None]:
normalizer = PercentileNormalizer(3,99.8)
resizer = PadAndCropResizer()
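The CARE U-Net repeatedly downsamples by a factor of two, so each spatial dimension of its input must be divisible by a fixed number (a power of two that depends on the network depth). A resizer like `PadAndCropResizer` pads the image up to such a size before prediction and crops the output back afterwards. The NumPy sketch below illustrates this; the divisor 4, the reflect padding split across both sides, and the helper names are assumptions for illustration, not CSBDeep's code.

```python
import numpy as np

def pad_to_multiple(x, div_by=4, axes=(0, 1, 2), mode='reflect'):
    """Pad `axes` of x (split between both sides) so their sizes are divisible by div_by."""
    pad = [(0, 0)] * x.ndim
    for ax in axes:  # Z, Y, X for a ZYX stack
        total = (-x.shape[ax]) % div_by  # amount missing to reach the next multiple
        pad[ax] = (total // 2, total - total // 2)
    return np.pad(x, pad, mode=mode), pad

def crop_back(y, pad):
    """Undo pad_to_multiple by slicing the padded borders away again."""
    slices = tuple(slice(before, y.shape[i] - after) for i, (before, after) in enumerate(pad))
    return y[slices]

img = np.random.default_rng(0).random((5, 30, 41))  # ZYX stack with awkward sizes
padded, pad = pad_to_multiple(img)
print(padded.shape)                                  # (8, 32, 44)
print(np.array_equal(crop_back(padded, pad), img))   # True
```

In the notebook, the network output plays the role of `padded` here: it is cropped back so that `restored` has the same shape as `x`.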

## Apply CARE network to raw image

Predict the restored image

In [None]:
%%time
restored = model.predict(x, axes, normalizer, resizer)

Set `n_tiles` explicitly to process the image in several smaller tiles. This lowers peak memory usage, which prevents out-of-memory errors and can speed up prediction:

In [None]:
%%time
restored = model.predict(x, axes, normalizer, resizer, n_tiles=4)
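Conceptually, tiling splits the input into pieces that are predicted one after another, so only one piece at a time has to fit in GPU memory. The sketch below shows the idea with non-overlapping tiles along Y. `predict_tiled` is a hypothetical helper and a simplified illustration, not CSBDeep's implementation: a practical implementation overlaps neighboring tiles and discards the overlap when stitching, to avoid seams where a convolutional network sees less context at tile borders.

```python
import numpy as np

def predict_tiled(x, predict_fn, n_tiles=4, axis=1):
    """Apply predict_fn to n_tiles chunks of x along `axis` and stitch the outputs together.

    Simplified: tiles do not overlap, and predict_fn is assumed to preserve the tile's shape.
    """
    tiles = np.array_split(x, n_tiles, axis=axis)
    return np.concatenate([predict_fn(t) for t in tiles], axis=axis)

# Stand-in "model": a pointwise function, so tiled and untiled results must agree exactly
x_demo = np.random.default_rng(1).random((5, 64, 64)).astype(np.float32)
out = predict_tiled(x_demo, lambda t: 2 * t, n_tiles=4)
print(out.shape, np.allclose(out, 2 * x_demo))
```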

<hr style="height:2px;">

# Raw low-SNR image and denoised image via CARE network

In [None]:
plt.figure(figsize=(15, 10))
plot_some(np.stack([x, restored]), pmin=2, pmax=99.8,
          title_list=[['low (maximum projection)', 'CARE (maximum projection)']]);
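Since the ground truth `y` is available here, the visual comparison can be complemented by a simple quantitative one. The sketch below computes the peak signal-to-noise ratio (PSNR) after percentile-normalizing each image; the helper names, the percentiles, and the use of PSNR are assumptions for illustration, and the synthetic arrays only stand in for `y`, `x` and `restored`.

```python
import numpy as np

def percentile_norm(img, pmin=0.1, pmax=99.9, eps=1e-20):
    """Rescale img so its pmin/pmax percentiles map to 0/1 (hypothetical helper)."""
    mi, ma = np.percentile(img, (pmin, pmax))
    return (np.asarray(img, dtype=np.float64) - mi) / (ma - mi + eps)

def psnr(gt, pred, data_range=None):
    """Peak signal-to-noise ratio in dB between a ground truth and a prediction."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    if data_range is None:
        data_range = gt.max() - gt.min()
    mse = np.mean((gt - pred) ** 2)
    return 10 * np.log10(data_range ** 2 / mse)

# Synthetic stand-ins for y (GT), x (low SNR) and restored
rng = np.random.default_rng(2)
gt = rng.random((4, 32, 32))
noisy = gt + rng.normal(scale=0.3, size=gt.shape)
denoised = gt + rng.normal(scale=0.05, size=gt.shape)
psnr_noisy = psnr(percentile_norm(gt), percentile_norm(noisy))
psnr_denoised = psnr(percentile_norm(gt), percentile_norm(denoised))
print(psnr_noisy, psnr_denoised)
```

With the notebook's arrays, the analogous comparison would be `psnr(percentile_norm(y), percentile_norm(restored))` versus the same call with `x`. Normalizing each image independently is a simplification, so the resulting numbers depend on the chosen percentiles.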