<hr style="height:2px;">

# Demo: Probabilistic CARE model for denoising of synthetic 2D data

This notebook demonstrates applying a probabilistic CARE model for a 2D denoising task, assuming that training was already completed via [1_training.ipynb](1_training.ipynb).  
The trained model is assumed to be located in the folder `models` with the name `my_model`.

More documentation is available at http://csbdeep.bioimagecomputing.com/doc/.

In [None]:
from __future__ import print_function, unicode_literals, absolute_import, division
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from csbdeep.utils import Path, download_and_extract_zip_file, plot_some
from csbdeep.io import load_training_data, save_tiff_imagej_compatible
from csbdeep.models import CARE

<hr style="height:2px;">

# Download example data

The example data should have been downloaded in [1_training.ipynb](1_training.ipynb).  
Just in case, we will download it here again if it's not already present.

In [None]:
download_and_extract_zip_file (
    url       = 'http://csbdeep.bioimagecomputing.com/example_data/synthetic_disks.zip',
    targetdir = 'data',
)

Load the validation images that were used during model training.

In [None]:
X_val, Y_val = load_training_data('data/synthetic_disks/data.npz', validation_split=0.1, verbose=True)[1]

We will apply the trained CARE model here to restore one validation image `x` (with associated ground truth `y`).

In [None]:
y = Y_val[2,...,0]
x = X_val[2,...,0]
axes = 'YX'
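The indexing `[2,...,0]` picks one image out of the validation stack and drops the channel axis. The sketch below assumes a `(sample, Y, X, channel)` layout with a single channel, which is consistent with `axes = 'YX'` after indexing. The array `X_val_demo` is a hypothetical random stand-in, not the real validation data.

```python
import numpy as np

# Hypothetical stand-in for X_val: 5 images of 128x128 pixels with one channel,
# assuming a (sample, Y, X, channel) layout.
X_val_demo = np.random.default_rng(0).random((5, 128, 128, 1), dtype=np.float32)

# Pick sample 2 and drop the trailing channel axis; '...' keeps the Y and X axes.
x_demo = X_val_demo[2, ..., 0]
print(x_demo.shape)  # (128, 128)

# Equivalent, more explicit form of the same selection.
assert np.array_equal(x_demo, X_val_demo[2, :, :, 0])
```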

<hr style="height:2px;">

# Input image and associated ground truth

Plot the test image pair.

In [None]:
print('image size =', x.shape)
print('image axes =', axes)

plt.figure(figsize=(16,10))
plot_some(np.stack([x,y]), title_list=[['input','target (GT)']]);

<hr style="height:2px;">

# CARE model

Load the trained model (located in the base directory `models` under the name `my_model`) from disk.  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_model', basedir='models')

## Typical CARE prediction

If you are only interested in a restored image, predict it as in the non-probabilistic case.  
For a probabilistic network, `model.predict` returns the expected restored image, i.e. the mean of the predicted per-pixel distributions.

**Note 1**: Since the synthetic image is already normalized, no additional normalization is needed (hence `normalizer=None`).

**Note 2**: *Out of memory* problems during `model.predict` often indicate that the GPU is used by another process. In particular, shut down the training notebook before running the prediction (you may need to restart this notebook).
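For real (non-synthetic) data you would normally normalize the input before prediction. The following is a minimal NumPy sketch of percentile-based normalization, which maps a low and a high intensity percentile to 0 and 1. The helper `percentile_normalize` and the percentile values `pmin=2.0`, `pmax=99.8` are illustrative assumptions, not the normalizer or settings CSBDeep uses internally.

```python
import numpy as np

def percentile_normalize(img, pmin=2.0, pmax=99.8, eps=1e-20):
    """Hypothetical helper: map the pmin/pmax percentiles of img to 0 and 1."""
    lo, hi = np.percentile(img, [pmin, pmax])
    # eps guards against division by zero for constant images
    return (img - lo) / (hi - lo + eps)

# Hypothetical raw image with arbitrary intensity offset and spread
raw = np.random.default_rng(1).normal(loc=500.0, scale=50.0, size=(64, 64))
normed = percentile_normalize(raw)
print(np.percentile(normed, [2.0, 99.8]))  # approximately [0, 1]
```

Because the mapping is affine, values outside the chosen percentiles are not clipped; they simply land slightly below 0 or above 1.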

In [None]:
restored = model.predict(x, axes, normalizer=None)

In [None]:
plt.figure(figsize=(16,10))
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']]);
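Why the "expected restored image" needs no extra computation: for a Laplace distribution with location $\mu$ and scale $b$, the mean (and also the median) is exactly $\mu$. The sketch below checks this per pixel with `scipy.stats.laplace`, using small hypothetical `loc`/`scale` arrays as a stand-in for the network outputs.

```python
import numpy as np
from scipy.stats import laplace

# Hypothetical per-pixel parameters for a tiny 2x3 "image".
loc = np.array([[0.1, 0.5, 0.9], [0.2, 0.4, 0.6]])
scale = np.array([[0.05, 0.1, 0.2], [0.01, 0.3, 0.02]])

# scipy broadcasts the array parameters, giving one distribution per pixel.
dist = laplace(loc=loc, scale=scale)
expected = dist.mean()

# For a Laplace distribution the mean (and median) equals the location parameter,
# independently of the scale.
assert np.allclose(expected, loc)
assert np.allclose(dist.median(), loc)
```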

### Save restored image

Save the restored image as an ImageJ-compatible TIFF file, i.e. one that opens in ImageJ/Fiji with the correct axes semantics.

In [None]:
Path('results').mkdir(exist_ok=True)
save_tiff_imagej_compatible('results/%s_validation_image.tif' % model.name, restored, axes)

<hr style="height:2px;">

# Probabilistic CARE prediction

We now predict the per-pixel Laplace distributions; `model.predict_probabilistic` returns an object for working with them.

In [None]:
restored_prob = model.predict_probabilistic(x, axes, normalizer=None)
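To get a feel for what such an object offers, a frozen `scipy.stats.laplace` distribution with array-valued parameters behaves similarly: one Laplace distribution per pixel, with vectorized `mean`, `var`, and `ppf`. This is only an analogy under assumed inputs; `mean_map` and `scale_map` are hypothetical random arrays, and the CSBDeep object is not a scipy distribution.

```python
import numpy as np
from scipy.stats import laplace

# Hypothetical network outputs: per-pixel location (mean) and scale maps.
rng = np.random.default_rng(2)
mean_map = rng.random((32, 32))
scale_map = 0.05 + 0.1 * rng.random((32, 32))

# A frozen scipy distribution with array parameters acts as one Laplace per pixel.
dist = laplace(loc=mean_map, scale=scale_map)

lower = dist.ppf(0.025)   # per-pixel 2.5% quantile
upper = dist.ppf(0.975)   # per-pixel 97.5% quantile
print(dist.mean().shape, dist.var().shape, lower.shape)  # (32, 32) (32, 32) (32, 32)
```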

Plot the *mean* and *scale* parameters of the per-pixel Laplace distributions.

In [None]:
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.mean(),restored_prob.scale()]), title_list=[['mean','scale']]);
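One way to read the *scale* map: for a Laplace distribution, the scale $b$ equals the mean absolute deviation, $\mathbb{E}|X - \mu| = b$. So a pixel with scale 0.1 is, on average, expected to deviate by 0.1 from its mean. A quick Monte Carlo check with NumPy's built-in Laplace sampler, for hypothetical values `mu` and `b` of a single pixel:

```python
import numpy as np

rng = np.random.default_rng(3)
mu, b = 0.4, 0.1  # hypothetical parameters of a single pixel

# Draw many samples from Laplace(mu, b).
draws = rng.laplace(loc=mu, scale=b, size=200_000)

# For a Laplace distribution, E|X - mu| = b: the scale is the mean absolute deviation.
mad = np.mean(np.abs(draws - mu))
print(round(mad, 3))  # close to 0.1
```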

Plot the *variance* and *entropy* of the per-pixel Laplace distributions.

In [None]:
plt.figure(figsize=(16,10))
plot_some(np.stack([restored_prob.var(),restored_prob.entropy()]), title_list=[['variance','entropy']]);
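Both quantities are closed-form functions of the scale $b$ alone: $\mathrm{Var} = 2b^2$ and the differential entropy (in nats) is $1 + \ln(2b)$. The sketch below checks these formulas against `scipy.stats.laplace` for a few hypothetical scale values.

```python
import numpy as np
from scipy.stats import laplace

scale_map = np.array([[0.01, 0.05], [0.1, 0.5]])  # hypothetical per-pixel scales
dist = laplace(loc=np.zeros_like(scale_map), scale=scale_map)

# Closed forms for Laplace(mu, b): Var = 2 b^2, differential entropy = 1 + ln(2 b).
var_closed = 2.0 * scale_map**2
ent_closed = 1.0 + np.log(2.0 * scale_map)

assert np.allclose(dist.var(), var_closed)
assert np.allclose(dist.entropy(), ent_closed)
print(ent_closed)  # negative for small scales: differential entropy can be < 0
```

Since both are monotonically increasing in $b$, the variance and entropy images show the same spatial structure as the scale image, only with a different intensity mapping.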

## Sampling restored images

Draw 50 samples from the distribution of restored images and plot the first 3.

In [None]:
samples = np.stack(tuple(restored_prob.sampling_generator(50)))

plt.figure(figsize=(16,5))
plot_some(samples[:3], pmin=0.1,pmax=99.9);
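Conceptually, each sample is an image in which every pixel is drawn independently from its own Laplace distribution. A minimal sketch of such a generator, assuming independent pixels; `laplace_sampling_generator` is a hypothetical helper, not the `sampling_generator` implementation:

```python
import numpy as np

def laplace_sampling_generator(loc, scale, n, seed=None):
    """Hypothetical stand-in: yield n images, each drawn pixel-wise from Laplace(loc, scale)."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        # loc and scale broadcast, so each pixel gets its own distribution
        yield rng.laplace(loc=loc, scale=scale)

loc = np.full((16, 16), 0.5)
scale = np.full((16, 16), 0.1)
samples_demo = np.stack(tuple(laplace_sampling_generator(loc, scale, 50, seed=4)))
print(samples_demo.shape)  # (50, 16, 16)

# The per-pixel average of many samples approaches the expected image (loc).
print(np.abs(samples_demo.mean(axis=0) - loc).max())
```

Averaging more samples brings the mean image closer to the expected restored image; the individual samples, in contrast, show the kind of variation the model considers plausible.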

Make an animation of the 50 samples.

In [None]:
from matplotlib import animation
from IPython.display import HTML

fig = plt.figure(figsize=(8,8))
im = plt.imshow(samples[0], vmin=np.percentile(samples,0.1), vmax=np.percentile(samples,99.9), cmap='magma')
plt.close()

def updatefig(j):
    im.set_array(samples[j])
    return [im]

anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=100)
HTML(anim.to_jshtml())

## Inspect the predicted distribution along a line profile with credible intervals

In [None]:
i = 61
line = restored_prob[i]
n = len(line)

plt.figure(figsize=(16,9))
plt.subplot(211)
plt.imshow(restored_prob.mean()[i-15:i+15], cmap='magma')
plt.plot(range(n),15*np.ones(n),'--w',linewidth=2)
plt.title('expected restored image')
plt.xlim(0,n-1); plt.axis('off')

plt.subplot(212)
q = 0.025
plt.fill_between(range(n), line.ppf(q), line.ppf(1-q), alpha=0.5, label='%.0f%% credible interval'%(100*(1-2*q)))
plt.plot(line.mean(),linewidth=3, label='expected restored image')
plt.plot(y[i],'--',linewidth=3, label='ground truth')
plt.plot(x[i],':',linewidth=1, label='input image')
plt.title('line profile')
plt.xlim(0,n-1); plt.legend(loc='lower right')

None;
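For a Laplace distribution the credible interval in the plot has a closed form: the central interval with mass $1-2q$ is $\mu \pm b\ln\frac{1}{2q}$, so for $q = 0.025$ (a 95% interval) the half-width is about $3b$. The sketch below computes this for a hypothetical three-pixel line profile and compares it with scipy's Laplace quantile function; `laplace_central_interval` is a hypothetical helper.

```python
import numpy as np
from scipy.stats import laplace

def laplace_central_interval(loc, scale, q=0.025):
    """Central interval with mass 1 - 2q: loc -/+ scale * ln(1 / (2q)), valid for 0 < q < 0.5."""
    half_width = scale * np.log(1.0 / (2.0 * q))
    return loc - half_width, loc + half_width

loc = np.array([0.2, 0.5, 0.8])      # hypothetical line profile of means
scale = np.array([0.02, 0.1, 0.05])  # hypothetical line profile of scales
lo, hi = laplace_central_interval(loc, scale)

# Agrees with scipy's quantile function for the 2.5% and 97.5% levels.
assert np.allclose(lo, laplace.ppf(0.025, loc=loc, scale=scale))
assert np.allclose(hi, laplace.ppf(0.975, loc=loc, scale=scale))
```

Because the half-width is proportional to $b$, the shaded band in the line profile is simply a rescaled view of the scale map along that row.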