<hr style="height:2px;">

# Demo: Probabilistic prediction for denoising of 2D toy data

### Notes 

- Assumes that training was already completed via [training.ipynb](training.ipynb).
- The trained CARE network is here applied to the same image that the model was trained on.  
Of course, in practice one would typically use it to restore images that the model hasn't seen during training.
- Documentation available: http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from tifffile import imread
from csbdeep.models import CARE
from csbdeep.predict import NoNormalizer
from csbdeep.plot_utils import plot_some

<hr style="height:2px;">

# Download example data (if necessary)

In [None]:
from csbdeep.utils import download_and_extract_zip_file
download_and_extract_zip_file(
    url = 'https://cloud.mpi-cbg.de/index.php/s/OnQFwcB0vCASBMu/download',
    provides = ('toy_data.npz',)
)

In [None]:
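from csbdeep.train import load_data
# Load the toy data; only the validation split is used below
(X,Y), (X_val,Y_val), axes = load_data('toy_data.npz', axes='SCYX', validation_split=0.1)
# Remove the singleton channel axis (the last axis) from the arrays and from the axes string
X_val, Y_val = X_val[...,0], Y_val[...,0]
axes = axes[:-1]
# The training data is not needed in this notebook
del X,Y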
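The axis bookkeeping above can be illustrated with dummy arrays. This is a minimal sketch that assumes, as the slicing `[...,0]` implies, that `load_data` returns arrays with the channel axis last (axes `'SYXC'`). The shapes are made up for illustration.

```python
import numpy as np

# Hypothetical stand-in for the validation data: 5 samples of 128x128 pixels, 1 channel
X_val = np.zeros((5, 128, 128, 1), dtype=np.float32)
axes = 'SYXC'  # assumed axes string, channel axis last

# Drop the singleton channel axis from the array and from the axes string
X_val = X_val[..., 0]
axes = axes[:-1]
print(X_val.shape, axes)   # (5, 128, 128) SYX

# Selecting a single sample also removes the leading sample axis 'S'
x = X_val[2]
print(x.shape, axes[1:])   # (128, 128) YX
```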

<hr style="height:2px;">

# Input image and associated ground truth

In [None]:
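# Pick one validation image and its ground truth
y = Y_val[2]
x = X_val[2]
# A single image has no sample axis 'S' (note: re-running this cell would strip another axis)
axes = axes[1:]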
print('image size =', x.shape)
print('image axes =', axes)

plt.figure(figsize=(15,10))
plot_some(np.stack([x,y]),title_list=[['input','target (GT)']], pmin=2,pmax=99.8);

<hr style="height:2px;">

# Model

Load the trained model (located in the folder `my_model`) from disk.  
The configuration was saved during training and is automatically loaded when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_model')
model.load_weights()

- Select an appropriate normalization. No normalization is needed here, since the chosen image is already normalized.
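If the input were not already normalized, a common choice is to rescale it by low and high intensity percentiles. The sketch below is a plain NumPy illustration of that idea. `percentile_normalize` is a hypothetical helper, and its default percentiles simply mirror the `pmin`/`pmax` values used for plotting in this notebook; they are not claimed to be csbdeep's defaults.

```python
import numpy as np

def percentile_normalize(img, pmin=2, pmax=99.8, eps=1e-20):
    """Hypothetical helper: map the pmin-th percentile to 0 and the pmax-th percentile to 1."""
    lo, hi = np.percentile(img, [pmin, pmax])
    # eps guards against division by zero for constant images
    return (img - lo) / (hi - lo + eps)

rng = np.random.default_rng(0)
raw = rng.normal(loc=500.0, scale=50.0, size=(64, 64))  # made-up raw intensities
norm = percentile_normalize(raw)
print(np.percentile(norm, 2), np.percentile(norm, 99.8))  # approximately 0 and 1
```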

In [None]:
normalizer = NoNormalizer()

## Typical CARE prediction

If you are only interested in a restored image, predict it as in the non-probabilistic case.  
For a probabilistic network, `model.predict` returns the expected restored image, i.e. the mean of the predicted per-pixel distribution.
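To make "expected restored image" concrete, the sketch below draws many samples from per-pixel Laplace distributions with made-up location and scale maps and checks that their average approaches the location map. It uses only NumPy's `Generator.laplace` and illustrates the statistics involved, not csbdeep's internals.

```python
import numpy as np

rng = np.random.default_rng(42)
loc = rng.uniform(0.0, 1.0, size=(16, 16))     # hypothetical per-pixel means
scale = rng.uniform(0.05, 0.2, size=(16, 16))  # hypothetical per-pixel scales

# Draw 10000 samples per pixel; loc and scale broadcast to shape (10000, 16, 16)
samples = rng.laplace(loc, scale, size=(10000,) + loc.shape)
expected = samples.mean(axis=0)

# The per-pixel sample average approaches the per-pixel mean
print(np.abs(expected - loc).max())
```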

In [None]:
restored = model.predict(x, axes, normalizer)

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']], pmin=2,pmax=99.8);

<hr style="height:2px;">

# Probabilistic prediction

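Predict a Laplace distribution for every pixel. The returned object provides per-pixel summaries such as the mean, scale, variance and entropy, and can also produce quantiles and samples.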
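For reference, these are the standard formulas for a Laplace distribution with location $\mu$ and scale $b$. Assuming csbdeep uses this usual parameterization, `prob.mean()` corresponds to $\mu$, `prob.scale()` to $b$, `prob.var()` to $2b^2$, `prob.entropy()` to the differential entropy $H$ (in nats), and `ppf` to the quantile function $F^{-1}$:

$$
\begin{aligned}
p(y \mid \mu, b) &= \frac{1}{2b}\exp\!\left(-\frac{|y-\mu|}{b}\right), \\
\mathbb{E}[y] &= \mu, \qquad \operatorname{Var}[y] = 2b^2, \qquad H = 1 + \ln(2b), \\
F^{-1}(q) &= \mu - b\,\operatorname{sgn}\!\left(q-\tfrac12\right)\ln\!\left(1 - 2\left|q-\tfrac12\right|\right).
\end{aligned}
$$

Since the variance is a fixed multiple of $b^2$ and the entropy grows with $\ln b$, the scale, variance and entropy maps all highlight the same uncertain regions, only on different intensity scales.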

In [None]:
prob = model.predict_probabilistic(x, axes, normalizer)

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([prob.mean(),prob.scale()]), title_list=[['mean','scale']], pmin=2,pmax=99.8);

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([prob.var(),prob.entropy()]), title_list=[['variance','entropy']], pmin=2,pmax=99.8);

## Sampling restored images

In [None]:
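# Generator that yields 50 sampled restored images
gen = prob.sampling_generator(50)
# np.stack requires a sequence (recent NumPy versions reject generators), so materialize it first
samples = np.stack(list(gen))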

plt.figure(figsize=(15,4))
plot_some(samples[:3], pmin=2,pmax=99.8);
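One common way to draw Laplace samples is inverse-transform sampling through the quantile function. The sketch below is a hypothetical stand-in for `sampling_generator` (we do not claim csbdeep implements it this way): it yields `n` images whose pixels are sampled independently, and the generator is turned into a list before `np.stack`, which expects a sequence.

```python
import numpy as np

def laplace_ppf(q, loc, scale):
    """Quantile function of the Laplace distribution (elementwise)."""
    return loc - scale * np.sign(q - 0.5) * np.log1p(-2.0 * np.abs(q - 0.5))

def sampling_generator_sketch(loc, scale, n, seed=None):
    """Hypothetical stand-in for prob.sampling_generator: yields n sample images."""
    rng = np.random.default_rng(seed)
    for _ in range(n):
        # One uniform draw per pixel, clipped away from 0 and 1 to avoid infinite quantiles
        u = np.clip(rng.uniform(size=loc.shape), 1e-12, 1 - 1e-12)
        yield laplace_ppf(u, loc, scale)

loc = np.zeros((8, 8))
scale = np.full((8, 8), 0.1)
samples = np.stack(list(sampling_generator_sketch(loc, scale, 50, seed=0)))
print(samples.shape)  # (50, 8, 8)
```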

### Sampling animation

In [None]:
from matplotlib import animation
from IPython.display import HTML

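fig = plt.figure(figsize=(8,8))
# Use a common intensity range for all frames so that they are comparable
clim = np.percentile(samples,2), np.percentile(samples,99.8)
im = plt.imshow(samples[0], clim=clim)
# Close the figure so that only the animation (not a static frame) is displayed
plt.close()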

def updatefig(j):
    im.set_array(samples[j])
    return [im]

anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=50)
HTML(anim.to_jshtml())

## Prediction along a line profile with credible intervals

In [None]:
# Image row used for the line profile
i = 61
# Probabilistic prediction restricted to row i
line = prob[i]
n = len(line.mean())

plt.figure(figsize=(16,9))
plt.subplot(211)
# Show 30 rows around row i; row i is at index 15 of this crop and is marked by the dashed line
plt.imshow(prob.mean()[i-15:i+15])
plt.plot(range(n),15*np.ones(n),'--r',linewidth=2)
plt.title('expected restored image')
plt.xlim(0,n-1); plt.axis('off')

plt.subplot(212)
# Band between the 2.5% and 97.5% quantiles, i.e. a 95% central credible interval
q = 0.025
plt.fill_between(range(n), line.ppf(q), line.ppf(1-q), alpha=0.5, label='%.0f%% credible interval'%(100*(1-2*q)))
plt.plot(line.mean(),linewidth=3, label='expected restored image')
plt.plot(y[i],'--',linewidth=3, label='ground truth')
plt.plot(x[i],':',linewidth=1, label='input image')
plt.title('line profile')
plt.xlim(0,n-1); plt.legend(loc='lower right')

None;
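With `q = 0.025` the shaded band spans the 2.5% and 97.5% quantiles. Assuming `ppf` is the Laplace quantile function with scale $b$, these quantiles are $\mu \pm b\ln 20 \approx \mu \pm 3.0\,b$. The sketch below uses a made-up line profile and checks that roughly 95% of Laplace samples fall inside this band.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 128
mu = np.sin(np.linspace(0, 4 * np.pi, n))  # hypothetical line profile of per-pixel means
b = np.full(n, 0.1)                        # hypothetical per-pixel scales

q = 0.025
# For a Laplace distribution: ppf(1-q) - mu = mu - ppf(q) = b * ln(1/(2q))
half_width = b * np.log(1 / (2 * q))
lower, upper = mu - half_width, mu + half_width

draws = rng.laplace(mu, b, size=(10000, n))
coverage = np.mean((draws >= lower) & (draws <= upper))
print(f'{coverage:.3f}')  # close to 0.95
```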