<hr style="height:2px;">

# Demo: Apply trained CARE model for isotropic reconstruction of Zebrafish retina

This notebook demonstrates applying a CARE model for an isotropic reconstruction task, assuming that training was already completed via [2_training.ipynb](2_training.ipynb).  
The trained model is assumed to be located in the folder `models` with the name `my_model`.

**Note:** The CARE model is here applied to the same image that the model was trained on.  
Of course, in practice one would typically use it to restore images that the model hasn't seen during training.

More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

```python
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from tifffile import imread
from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import save_tiff_imagej_compatible
from csbdeep.models import IsotropicCARE
from dexp.datasets.zarr_dataset import ZDataset
```

<hr style="height:2px;">

# Load data

The input is read from a DEXP `ZDataset` stored as a zipped Zarr archive; nothing is downloaded here.  
Set `data_path` to the location of your data and `channel` to the channel that should be restored.
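To dry-run the cells below without the real archive, a small in-memory stand-in can mimic the handful of `ZDataset` methods this notebook calls. `InMemoryDataset` is a hypothetical helper, not part of DEXP, and its simplified `add_channel` signature is an assumption. It also assumes, as the code below does, that each channel is a 4D `TZYX` array, so `get_stack(channel, t)` returns one 3D `ZYX` stack.

```python
import numpy as np


class InMemoryDataset:
    """Hypothetical stand-in for the few ZDataset methods used in this notebook.

    Each channel is stored as a 4D NumPy array with axes TZYX.
    """

    def __init__(self):
        self._channels = {}

    def add_channel(self, name, shape, dtype):
        # Allocate a zero-filled TZYX array for the new channel
        self._channels[name] = np.zeros(shape, dtype=dtype)

    def shape(self, name):
        return self._channels[name].shape

    def nb_timepoints(self, name):
        # The first axis is time
        return self._channels[name].shape[0]

    def get_stack(self, name, t):
        # Return the 3D ZYX stack of time point t
        return self._channels[name][t]

    def write_stack(self, name, t, stack):
        self._channels[name][t] = stack


# Usage: a toy channel with 3 time points of 8x16x16 stacks
ds = InMemoryDataset()
ds.add_channel('DendraRed', (3, 8, 16, 16), dtype=np.uint16)
ds.write_stack('DendraRed', 1, np.full((8, 16, 16), 7, dtype=np.uint16))
print(ds.nb_timepoints('DendraRed'), ds.get_stack('DendraRed', 1).shape)  # 3 (8, 16, 16)
```

Keeping the method names identical to the ones used in the prediction loop means the loop can be exercised on toy data before it is pointed at a multi-gigabyte archive.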

```python
# Path to the input dataset (zipped Zarr archive); adjust to your setup
data_path = '/mnt/hd1/_dorado/2021/March/03262021_PhotoM/stabilized.croped.fused.deconv.zarr.zip'
# data_path = '/home/jordao/Softwares/interactive-tracker/croped.fused.deconv.zarr.zip'
dataset = ZDataset(data_path)
# Channel to restore
channel = 'DendraRed'
```

<hr style="height:2px;">

# Raw 3D image stack with low axial resolution

We load a single time point, plot XY and XZ slices of the stack, and define the Z subsampling factor, which is needed later for prediction.
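For a single-channel `ZYX` stack, an XY slice is simply `x[z]`, while an XZ slice fixes a Y index and keeps the Z and X axes. The helpers below make that indexing explicit and add a percentile-based display normalization similar in spirit to the `pmin`/`pmax` arguments of `plot_some`. The function names are hypothetical, and the sketch assumes the stack is a NumPy array with axes `ZYX`.

```python
import numpy as np


def xy_slices(stack_zyx, z_indices):
    """Return XY slices at the given Z indices, shape (len(z_indices), Y, X)."""
    return stack_zyx[list(z_indices)]


def xz_slices(stack_zyx, y_indices):
    """Return XZ slices at the given Y indices, shape (len(y_indices), Z, X)."""
    # stack[:, ys] has shape (Z, len(ys), X); move the slice axis to the front
    return np.moveaxis(stack_zyx[:, list(y_indices)], 1, 0)


def percentile_normalize(img, pmin=2, pmax=99.8, eps=1e-20):
    """Map the pmin..pmax percentile range of img to 0..1 and clip the rest."""
    lo, hi = np.percentile(img, [pmin, pmax])
    return np.clip((img - lo) / (hi - lo + eps), 0, 1)


# Usage on a toy stack with 4 Z slices of 6x5 pixels
stack = np.arange(4 * 6 * 5, dtype=np.float32).reshape(4, 6, 5)
print(xy_slices(stack, [0, 2]).shape)  # (2, 6, 5)
print(xz_slices(stack, [1, 3]).shape)  # (2, 4, 5)
```

When an XZ slice is shown with `imshow`, passing `aspect=subsample` makes each Z row `subsample` times taller than it is wide, so the displayed slice keeps the physical proportions of the anisotropic stack.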

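```python
# Load a single time point as a 3D ZYX stack
x = dataset.get_stack(channel, 200)
# Z subsampling factor: axial (Z) voxel spacing divided by lateral (XY) pixel spacing
subsample = 1.97 / 0.485
print('image size         =', x.shape)
print('Z subsample factor =', subsample)

# XY slices at two Z positions
plt.figure(figsize=(16,15))
plot_some(x[[130,150]],
          title_list=[['XY slice','XY slice']],
          pmin=2,pmax=99.8);

# XZ slices at two Y positions; aspect=subsample restores the physical Z proportions
plt.figure(figsize=(16,15))
plot_some(np.moveaxis(x[:,[350,450]],1,0),
          title_list=[['XZ slice','XZ slice']],
          pmin=2,pmax=99.8, aspect=subsample);
```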


<hr style="height:2px;">

# Isotropic CARE model

Load the trained model (located in the base directory `models` under the name `my_model`) from disk.  
The configuration was saved during training and is automatically loaded when `IsotropicCARE` is initialized with `config=None`.
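Since `config=None` only works when a configuration was saved during training, it can help to check the model folder before constructing `IsotropicCARE`. This sketch assumes the folder layout `basedir/name/config.json` and that the saved weight files start with `weights`; the helper name `check_model_dir` is hypothetical.

```python
import tempfile
from pathlib import Path


def check_model_dir(basedir, name):
    """Return the weight files in basedir/name, raising if config.json is missing."""
    model_dir = Path(basedir) / name
    if not (model_dir / 'config.json').is_file():
        raise FileNotFoundError('no config.json in {}; was training completed?'.format(model_dir))
    # List the saved weight files (their exact names depend on the training run)
    return sorted(p.name for p in model_dir.glob('weights*'))


# Usage on a throwaway folder that imitates a trained model directory
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / 'my_model'
    model_dir.mkdir()
    (model_dir / 'config.json').write_text('{}')
    (model_dir / 'weights_best.h5').touch()
    found = check_model_dir(tmp, 'my_model')

print(found)  # ['weights_best.h5']
```

In the notebook itself this would be called as `check_model_dir('models', 'my_model')` right before the model is loaded.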

```python
# Load the model configuration and weights saved during training
model_name = 'my_model'
model = IsotropicCARE(config=None, name=model_name, basedir='models')
```

## Apply CARE network to raw image

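Predict the isotropic reconstruction for every time point and write it to a new Zarr dataset; this can take a while. If there are memory issues, reduce the `batch_size` argument of `model.predict` (e.g. `model.predict(stack, axes, subsample, batch_size=4)`).  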
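Conceptually, isotropic reconstruction first interpolates the stack along Z by the subsampling factor and then lets a 2D network restore the lateral-axial slices. The sketch below illustrates that idea with plain NumPy: linear upsampling along Z, a placeholder `restore_2d` applied to XZ and YZ slices, and a simple average of the two views. It is not CSBDeep's implementation; `restore_2d`, the interpolation scheme, and the averaging are assumptions made for illustration.

```python
import numpy as np


def upsample_z_linear(stack_zyx, factor):
    """Linearly interpolate a ZYX stack along Z to round(Z * factor) slices."""
    n_z = stack_zyx.shape[0]
    n_out = int(round(n_z * factor))
    # Positions of the output slices expressed in input-slice coordinates
    z_new = np.linspace(0, n_z - 1, n_out)
    z0 = np.floor(z_new).astype(int)
    z1 = np.minimum(z0 + 1, n_z - 1)
    w = (z_new - z0)[:, None, None]
    return (1 - w) * stack_zyx[z0] + w * stack_zyx[z1]


def isotropic_sketch(stack_zyx, factor, restore_2d):
    """Upsample along Z, restore XZ and YZ slices with restore_2d, average both views."""
    up = upsample_z_linear(stack_zyx.astype(np.float32), factor)
    # XZ view: one 2D (Z, X) slice per Y index, restacked along the Y axis
    view_xz = np.stack([restore_2d(up[:, y, :]) for y in range(up.shape[1])], axis=1)
    # YZ view: one 2D (Z, Y) slice per X index, restacked along the X axis
    view_yz = np.stack([restore_2d(up[:, :, x]) for x in range(up.shape[2])], axis=2)
    return 0.5 * (view_xz + view_yz)


# Usage with an identity "network": the result equals the interpolated stack
stack = np.random.default_rng(0).random((5, 8, 6)).astype(np.float32)
result = isotropic_sketch(stack, 1.97 / 0.485, restore_2d=lambda s: s)
print(result.shape)  # (20, 8, 6)
```

With a real network in place of the identity, the two views generally differ, which is why some form of fusion is needed; averaging is simply the most basic choice.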

**Important:** You need to supply the subsampling factor, which must be the same as used during [training data generation](1_datagen.ipynb).
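Because the subsampling factor determines how many Z slices the restored stack gets, it also determines how much space each output time point needs. The estimate below assumes the output has `round(Z * subsample)` slices stored as 4-byte floats; compare it with `restored.shape` from a real prediction, since the exact rounding is an assumption here. The helper name and the `(100, 512, 512)` input shape are hypothetical.

```python
def estimate_output(shape_zyx, factor, bytes_per_voxel=4):
    """Estimate the (Z', Y, X) shape and size in bytes of one restored stack."""
    z, y, x = shape_zyx
    z_out = int(round(z * factor))
    return (z_out, y, x), z_out * y * x * bytes_per_voxel


subsample = 1.97 / 0.485
out_shape, n_bytes = estimate_output((100, 512, 512), subsample)
print(out_shape, '{:.1f} MiB'.format(n_bytes / 2**20))  # (406, 512, 512) 406.0 MiB
```

Multiplying the per-stack size by the number of time points gives a rough size for the whole output channel before compression.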

**Note**: *Out of memory* problems during `model.predict` can also indicate that the GPU is used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).

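```python
%%time

# Output dataset for the restored stacks
output_path = '/mnt/hd1/stabilized.zarr.zip'
output_ds = ZDataset(output_path, mode='a')

axes = 'ZYX'
nb_timepoints = dataset.nb_timepoints(channel)

for t in range(nb_timepoints):
    stack = dataset.get_stack(channel, t)
    restored = model.predict(stack, axes, subsample)
    if t == 0:
        # The restored stack has more Z slices than the input and a floating-point dtype,
        # so the output channel is created from the shape and dtype of the first result
        output_ds.add_channel(channel, (nb_timepoints,) + restored.shape, dtype=restored.dtype)
    output_ds.write_stack(channel, t, restored)

# Close the output dataset so the zip archive is written completely
output_ds.close()

print('input  size =', stack.shape)
print('output size =', restored.shape)
```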