<hr style="height:2px;">

# Demo: Application of trained neural network for isotropic reconstruction of *Danio rerio* retina

### Notes 

- Assumes that training was already completed via [training.ipynb](training.ipynb).
- Here, the trained CARE network is applied to the same image that the model was trained on (training data generated via [datagen.ipynb](datagen.ipynb)).  
In practice, one would typically use it to restore images that the model has not seen during training.
- Documentation available: http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from tifffile import imread
from csbdeep.models import IsotropicCARE
from csbdeep.predict import PercentileNormalizer, PadAndCropResizer
from csbdeep.plot_utils import plot_some

In [None]:
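# Optional: TensorFlow reserves all available GPU memory by default.
# Uncomment to limit it to 3/4 of the GPU if gputools is installed, so that both can share the GPU.
#try:
#    import gputools
#    from csbdeep.tf import limit_gpu_memory
#    limit_gpu_memory(fraction=3/4)
#except ImportError:
#    pass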

<hr style="height:2px;">

# Download example data (if necessary)

In [None]:
from csbdeep.utils import download_and_extract_zip_file
download_and_extract_zip_file(
    url = 'https://cloud.mpi-cbg.de/index.php/s/Vu0rN1G33z9hQa4/download',
    provides = ('raw_data/retina/cropped_farred_RFP_GFP_2109175_2color_sub_10.20.tif',)
)

The downloaded data should be organized as follows:

    raw_data/
    └── retina
        └── cropped_farred_RFP_GFP_2109175_2color_sub_10.20.tif

<hr style="height:2px;">

# Raw 3D image stack with low z resolution

In [None]:
x = imread('raw_data/retina/cropped_farred_RFP_GFP_2109175_2color_sub_10.20.tif')
print('image size =', x.shape)  # axes: (z, ch, y, x)

# move the channel axis to the front, (z, ch, y, x) -> (ch, z, y, x), to plot each channel separately
plt.figure(figsize=(15,15))
plot_some(np.moveaxis(x,1,0)[:,10],     # xy slices at z=10
          np.moveaxis(x,1,0)[:,:,64],   # xz slices at y=64
          title_list=[['xy slice, channel 0','xy slice, channel 1'],
                      ['xz slice, channel 0','xz slice, channel 1']],
          pmin=2, pmax=99.8);

In [None]:
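# (z, ch, y, x) -> (z, y, x, ch) -> rows y=50 and y=-50 -> (2, z, x, ch); channels last, so both channels share one panel
plt.figure(figsize=(15,15))
plot_some(np.moveaxis(np.moveaxis(x,1,-1)[:,[50,-50]],1,0),
          title_list=[['xz slice, y=50','xz slice, y=-50']], pmin=2, pmax=99.8);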
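
The plotting cells above rely on `np.moveaxis` and integer-array indexing to pull xy and xz slices out of a `(z, ch, y, x)` stack. A minimal sketch of the same shape bookkeeping on a dummy array; the sizes `(12, 2, 128, 128)` are made up for illustration, only the axis order matches the raw image, and the variable names are hypothetical (chosen so they do not overwrite `x`):

```python
import numpy as np

# Dummy stack with the same axis order as the raw image: (z, ch, y, x).
# The sizes are arbitrary and chosen only for illustration.
demo = np.zeros((12, 2, 128, 128), dtype=np.float32)

ch_first = np.moveaxis(demo, 1, 0)   # (ch, z, y, x)
xy = ch_first[:, 10]                 # xy slice at z=10, one per channel
xz = ch_first[:, :, 64]              # xz slice at y=64, one per channel

ch_last = np.moveaxis(demo, 1, -1)   # (z, y, x, ch)
rows = ch_last[:, [50, -50]]         # rows y=50 and y=-50 -> (z, 2, x, ch)
xz_pair = np.moveaxis(rows, 1, 0)    # (2, z, x, ch): one two-channel xz slice per row

print(xy.shape)       # (2, 128, 128)
print(xz.shape)       # (2, 12, 128)
print(xz_pair.shape)  # (2, 12, 128, 2)
```

The leading axis of each result is what `plot_some` spreads across the columns of a figure row.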

<hr style="height:2px;">

# Model

Load the trained model (located in folder `my_model`) from disk.  
The configuration was saved during training and is loaded automatically when `IsotropicCARE` is initialized with `config=None`.

In [None]:
model = IsotropicCARE(config=None, name='my_model')
model.load_weights()

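- Select an appropriate normalization: `PercentileNormalizer(1,99.8)` rescales the input intensities based on their 1st and 99.8th percentiles.
- Choose how to resize the image so that the model can be applied: `PadAndCropResizer` pads the input to a size the network can process and crops the prediction back to the original size afterwards.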
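
Percentile-based normalization maps a low and a high intensity percentile of the image to roughly 0 and 1, which makes the input robust to outliers such as hot pixels. A minimal numpy sketch of the idea; `percentile_normalize` is a hypothetical helper, not CSBDeep's implementation, and it does not clip values outside the percentile range:

```python
import numpy as np

def percentile_normalize(img, pmin=1.0, pmax=99.8, eps=1e-20):
    """Rescale img so its pmin-th percentile maps to 0 and its pmax-th percentile to 1.

    Simplified illustration only; values outside the percentile range are not clipped.
    """
    lo, hi = np.percentile(img, [pmin, pmax])
    return (img.astype(np.float32) - lo) / (hi - lo + eps)  # eps guards against a zero range

rng = np.random.default_rng(0)
demo = rng.gamma(shape=2.0, scale=100.0, size=(10, 64, 64))  # synthetic intensities
normed = percentile_normalize(demo, 1, 99.8)
print(np.percentile(normed, [1, 99.8]))  # close to [0, 1]
```

Because the mapping is affine and increasing, the percentiles of the normalized image land at 0 and 1 up to floating-point error.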

In [None]:
normalizer = PercentileNormalizer(1,99.8)
resizer = PadAndCropResizer()
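
A U-Net-style network with several downsampling steps typically needs spatial sizes divisible by some power of two, which is why a resizer is needed at all. A minimal numpy sketch of the pad-then-crop idea; the helpers are hypothetical, `div_by=8` is an assumed value rather than one read from the model, and this version pads only at the end of each axis, whereas `PadAndCropResizer` may distribute the padding differently:

```python
import numpy as np

def pad_to_multiple(img, div_by=8, axes=(-2, -1), mode='reflect'):
    """Pad the given axes of img at the end so their sizes are multiples of div_by.

    Returns the padded array and the original shape (needed for cropping)."""
    pad = [(0, 0)] * img.ndim
    for ax in axes:
        pad[ax] = (0, (-img.shape[ax]) % div_by)  # amount missing to the next multiple
    return np.pad(img, pad, mode=mode), img.shape

def crop_to_shape(img, shape):
    """Crop img back to shape, keeping the start of each axis (matches pad_to_multiple)."""
    return img[tuple(slice(0, s) for s in shape)]

demo = np.random.default_rng(0).random((3, 2, 61, 70))
padded, orig_shape = pad_to_multiple(demo, div_by=8)
print(padded.shape)  # (3, 2, 64, 72)
cropped = crop_to_shape(padded, orig_shape)  # a real pipeline would run the network in between
print(np.array_equal(cropped, demo))  # True
```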

## Apply CARE network to raw image

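Predict the restored image. `z_factor` is the upsampling factor for the z axis and must match the subsampling factor used when generating the training data; `z=0` and `channel=1` specify which axes of `x` are z and channel.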

In [None]:
%%time

z_factor = 10.2 # z upscaling factor
restored = model.predict(x, z_factor, normalizer, resizer, z=0, channel=1)

print('input  (z, ch, y, x) = ', x.shape)
print('output (z, ch, y, x) = ', restored.shape)
print()
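
Isotropic restoration turns a stack with `n` z-slices into one with roughly `n * z_factor` slices. A minimal numpy sketch of plain linear interpolation along z on a dummy stack shows the expected change in shape. It is only a naive baseline for intuition, not what `IsotropicCARE.predict` computes, and `upsample_z_linear` as well as its slice-position convention (output slice `k` sits at input position `k / factor`) are assumptions:

```python
import numpy as np

def upsample_z_linear(img, factor, axis=0):
    """Linearly interpolate img along `axis` to round(n * factor) samples."""
    img = np.moveaxis(np.asarray(img, dtype=np.float32), axis, 0)
    n = img.shape[0]
    n_out = int(round(n * factor))
    # position of each output slice in input-slice units (assumed convention: k / factor),
    # clipped so that slices beyond the last input slice repeat it
    pos = np.clip(np.arange(n_out) / factor, 0, n - 1)
    lo = np.floor(pos).astype(int)
    hi = np.minimum(lo + 1, n - 1)
    w = (pos - lo).reshape((-1,) + (1,) * (img.ndim - 1)).astype(np.float32)
    out = (1 - w) * img[lo] + w * img[hi]  # blend the two neighbouring input slices
    return np.moveaxis(out, 0, axis)

demo = np.random.default_rng(0).random((10, 2, 32, 32))  # dummy (z, ch, y, x) stack
up = upsample_z_linear(demo, 10.2, axis=0)
print(demo.shape, '->', up.shape)  # (10, 2, 32, 32) -> (102, 2, 32, 32)
```

Interpolation alone only adds slices; it cannot recover the axial detail lost to the low z resolution, which is the part the trained network is meant to restore.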

<hr style="height:2px;">

# Reconstructed image via CARE network

In [None]:
plt.figure(figsize=(15,15))
plot_some(np.moveaxis(restored,1,0)[:,100],   # xy slices at z=100 (near raw slice 10, since z is upsampled ~10.2x)
          np.moveaxis(restored,1,0)[:,:,64],  # xz slices at y=64
          title_list=[['xy slice, channel 0','xy slice, channel 1'],
                      ['xz slice, channel 0','xz slice, channel 1']],
          pmin=2, pmax=99.8);

In [None]:
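# (z, ch, y, x) -> (z, y, x, ch) -> rows y=50 and y=-50 -> (2, z, x, ch); channels last, so both channels share one panel
plt.figure(figsize=(15,15))
plot_some(np.moveaxis(np.moveaxis(restored,1,-1)[:,[50,-50]],1,0),
          title_list=[['xz slice, y=50','xz slice, y=-50']], pmin=2, pmax=99.8);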