<hr style="height:2px;">

# Demo: Probabilistic prediction for denoising of 2D toy data

### Notes 

- Assumes that training was already completed via [training.ipynb](training.ipynb).
- Here, the trained CARE network is applied to the same image that the model was trained on (data generated via [datagen.ipynb](datagen.ipynb)).  
Of course, in practice one would typically use it to restore images that the model has not seen during training.
- Documentation available: http://csbdeep.bioimagecomputing.com/doc/

<hr style="height:2px;">
# Setup 

In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import os
from tifffile import imread
from csbdeep.models import CARE
from csbdeep.predict import PercentileNormalizer, PadAndCropResizer
from csbdeep.plot_utils import plot_some

<hr style="height:2px;">

# Download example data (if necessary)

In [None]:
from csbdeep.utils import download_and_extract_zip_file
download_and_extract_zip_file(
    url = 'https://cloud.mpi-cbg.de/index.php/s/OnQFwcB0vCASBMu/download',
    provides = ('toy_data.npz',)
)

In [None]:
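from csbdeep.train import load_data
(X,Y), (X_val,Y_val) = load_data('toy_data.npz', validation_split=0.1)
# drop the trailing (channel) axis to obtain plain 2D images
X_val, Y_val = X_val[...,0], Y_val[...,0]
# the training data is not needed for prediction
del X,Y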

<hr style="height:2px;">

# Input image and associated ground truth

In [None]:
y = Y_val[2]
x = X_val[2]
print('image size =', x.shape)

plt.figure(figsize=(15,10))
plot_some(np.stack([x,y]),title_list=[['input','target (GT)']], pmin=2,pmax=99.8);

<hr style="height:2px;">

# Model

Load the trained model (located in folder `my_model`) from disk.  
The configuration was saved during training and is loaded automatically when `CARE` is initialized with `config=None`.

In [None]:
model = CARE(config=None, name='my_model')
model.load_weights()

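- Select an appropriate normalization for the input image.
- Choose how to resize the image so that the model can be applied (the network requires image dimensions divisible by a certain factor; `PadAndCropResizer` pads the input accordingly and crops the prediction back to the original size).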
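
The two steps above can be sketched in plain NumPy. This is a minimal illustration under stated assumptions, not CSBDeep's implementation: `percentile_normalize`, `pad_to_multiple`, and `crop_to_shape` are hypothetical helpers, and `div_by=4` is an assumed divisor (for a U-Net it is typically a power of two tied to the network depth). The sketch pads only at the end of each axis and does not clip normalized values; `PercentileNormalizer` and `PadAndCropResizer` may behave differently in these details.

```python
import numpy as np

def percentile_normalize(x, pmin=3, pmax=99.8, eps=1e-20):
    """Hypothetical helper: map the pmin-th percentile of x to 0 and the pmax-th to 1 (no clipping)."""
    lo, hi = np.percentile(x, pmin), np.percentile(x, pmax)
    return (x - lo) / (hi - lo + eps)

def pad_to_multiple(x, div_by):
    """Hypothetical helper: reflect-pad the end of each axis so its length is divisible by div_by."""
    pads = [(0, (-s) % div_by) for s in x.shape]
    return np.pad(x, pads, mode='reflect'), x.shape

def crop_to_shape(y, shape):
    """Hypothetical helper: undo pad_to_multiple by cropping back to the original shape."""
    return y[tuple(slice(0, s) for s in shape)]

x_demo = np.random.default_rng(0).random((101, 77))   # synthetic stand-in image
xn = percentile_normalize(x_demo)
xp, orig_shape = pad_to_multiple(xn, div_by=4)
print(xp.shape)                                # (104, 80)
print(crop_to_shape(xp, orig_shape).shape)     # (101, 77)
```

In the real pipeline, the model would be applied to the padded, normalized image, and the crop would be applied to its output.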

In [None]:
normalizer = PercentileNormalizer(3,99.8)
resizer = PadAndCropResizer()

## Typical CARE prediction

If you are only interested in a restored image, you can predict it as in the non-probabilistic case.  
Note, however, that for a probabilistic network this returns the expected (i.e., mean) restored image.
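
For a probabilistic network, each pixel of the prediction is a distribution rather than a single value, and the "expected restored image" is the per-pixel mean of that distribution. The sketch below assumes a per-pixel Laplace distribution (to our understanding, the model CSBDeep uses for probabilistic CARE; please confirm in the documentation), whose mean equals its location parameter. The `loc` and `scale` arrays are made-up values for illustration; averaging many samples recovers `loc`.

```python
import numpy as np

rng = np.random.default_rng(42)

# Made-up per-pixel Laplace parameters for a tiny 2x2 "image"
loc = np.array([[0.2, 0.5], [0.8, 1.0]])
scale = np.array([[0.05, 0.1], [0.1, 0.2]])

# Draw many independent samples for every pixel (loc/scale broadcast over the first axis)
draws = rng.laplace(loc, scale, size=(100_000,) + loc.shape)

# The per-pixel sample mean approaches loc, i.e. the expected value of the distribution
max_dev = np.abs(draws.mean(axis=0) - loc).max()
print(max_dev)
```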

In [None]:
restored = model.predict(x, normalizer, resizer)

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([x,restored]), title_list=[['input','expected restored image']], pmin=2,pmax=99.8);

## Probabilistic prediction
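The maps plotted below summarize the per-pixel distribution. Assuming a Laplace distribution with location μ and scale b, standard identities give mean μ, variance 2b², and differential entropy 1 + ln(2b) (in nats). The hypothetical helper `laplace_stats` only restates these textbook formulas; it is not how `prob.variance()` or `prob.entropy()` are implemented internally.

```python
import numpy as np

def laplace_stats(loc, scale):
    """Hypothetical helper: per-pixel statistics of independent Laplace(loc, scale) distributions."""
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    return {
        'mean': loc,                        # expected value
        'scale': scale,                     # spread parameter b
        'variance': 2 * scale**2,           # Var = 2 b^2
        'entropy': 1 + np.log(2 * scale),   # differential entropy in nats
    }

stats = laplace_stats(loc=[[0.2, 0.5]], scale=[[0.1, 0.5]])
print(stats['variance'], stats['entropy'])
```

Under this assumption, variance and entropy are both increasing functions of the scale, so all three maps rank pixels identically and differ only in their intensity mapping. Note also that differential entropy becomes negative for b < 1/(2e), so negative values in the entropy map are expected.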

In [None]:
prob = model.predict_probabilistic(x, normalizer, resizer)

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([prob.mean(),prob.scale()]), title_list=[['mean','scale']], pmin=2,pmax=99.8);

In [None]:
plt.figure(figsize=(15,10))
plot_some(np.stack([prob.variance(),prob.entropy()]), title_list=[['variance','entropy']], pmin=2,pmax=99.8);

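Drawing samples, as `prob.sampling_generator()` does in the next cell, can be sketched as a Python generator. `laplace_sampling_generator` is a hypothetical stand-in, not CSBDeep's code: it assumes independent per-pixel Laplace distributions and uses inverse-CDF sampling. Because every pixel is drawn independently in this sketch, individual samples look noisy even where the mean image is smooth.

```python
import numpy as np

def laplace_sampling_generator(loc, scale, n=None, seed=None):
    """Hypothetical sketch: yield per-pixel Laplace(loc, scale) samples; runs forever if n is None."""
    rng = np.random.default_rng(seed)
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    count = 0
    while n is None or count < n:
        # Inverse-CDF sampling with u uniform on [-0.5, 0.5)
        u = rng.uniform(-0.5, 0.5, size=loc.shape)
        # Guard against log(0) in the (practically impossible) case |u| == 0.5
        tail = np.maximum(1 - 2 * np.abs(u), np.finfo(float).tiny)
        yield loc - scale * np.sign(u) * np.log(tail)
        count += 1

# Usage: stack 50 draws, mirroring the notebook cell that follows
demo_gen = laplace_sampling_generator(np.zeros((64, 64)), np.full((64, 64), 0.1), seed=0)
demo_samples = np.stack([next(demo_gen) for _ in range(50)])
print(demo_samples.shape)  # (50, 64, 64)
```
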
In [None]:
gen = prob.sampling_generator()
# draw 50 samples from the predicted distribution
samples = np.stack([next(gen) for _ in range(50)])

# show the first three samples
plt.figure(figsize=(15,10))
plot_some(samples[:3], pmin=2,pmax=99.8);

In [None]:
from matplotlib import animation
from IPython.display import HTML

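fig = plt.figure(figsize=(8,8))
# fixed color range for all frames so that intensities are comparable across samples
clim = np.percentile(samples,2), np.percentile(samples,99.8)
im = plt.imshow(samples[0], cmap='viridis', clim=clim)
plt.close()  # prevent the static first frame from being displayed below the cell

def updatefig(j):
    # frame j shows the j-th sample from the predicted distribution
    im.set_array(samples[j])
    return [im]

# one frame per sample, 50 ms between frames, rendered as an interactive JavaScript widget
anim = animation.FuncAnimation(fig, updatefig, frames=len(samples), interval=50)
HTML(anim.to_jshtml())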