<hr style="height:2px;">

# Train a Noise2Void network

Both the CARE network and the Noise2Noise network you trained in parts 1 and 2 require that you acquire additional data for the purpose of denoising: for CARE we used a paired high-SNR acquisition, and for Noise2Noise we had paired noisy acquisitions.
We will now train a Noise2Void network from single noisy images.

This notebook uses a single image from the SEM data from the Noise2Noise notebook, but, as you'll see in Task 3.1, if you brought your own raw data you should adapt the notebook to use that instead.

We now use the Noise2Void ([n2v](https://github.com/juglab/n2v)) library instead of csbdeep/care, but don't worry, they're pretty similar. Make sure you have the right kernel (`n2v`) selected.

---
<div class="alert alert-block alert-info"><h4>
    TASK 3.1</h4>
    <p>
This notebook uses a single image from the SEM data from the Noise2Noise notebook.

If you brought your own raw data, use that instead!
The only requirement is that the noise in your data is pixel-independent and zero-mean. If you're unsure whether your data fulfills that requirement, or you don't yet understand why it is necessary, ask one of us to discuss!

If you don't have suitable data of your own, feel free to find some online or ask your fellow course participants. However, you can also stick with the SEM data provided here and compare the results to what you achieved with Noise2Noise in the previous part.
    </p>
</div>
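
If you want a quick, informal check before asking, one heuristic (our own suggestion, not part of the n2v library) is to take a region that contains only background and measure the correlation between horizontally and vertically adjacent pixels. For pixel-independent noise both values should be close to zero. Clearly positive values hint at spatially correlated noise, which the network can partly predict from the neighbors and therefore may fail to remove. The sketch below demonstrates this on synthetic data; `neighbor_correlation` is a hypothetical helper.

```python
import numpy as np

def neighbor_correlation(region):
    """Pearson correlation between horizontally and vertically adjacent
    pixels of a (Y, X) region assumed to contain only noise on a flat
    background. np.corrcoef removes the mean itself."""
    r = region.astype(np.float64)
    corr_x = np.corrcoef(r[:, :-1].ravel(), r[:, 1:].ravel())[0, 1]
    corr_y = np.corrcoef(r[:-1, :].ravel(), r[1:, :].ravel())[0, 1]
    return corr_x, corr_y

rng = np.random.default_rng(0)

# Synthetic pixel-independent noise on a flat background
independent = 100 + rng.normal(0, 5, size=(256, 256))

# Synthetic horizontally correlated noise: average each value with its right neighbour
white = rng.normal(0, 5, size=(256, 257))
correlated = 100 + 0.5 * (white[:, :-1] + white[:, 1:])

print(neighbor_correlation(independent))  # both values close to 0
print(neighbor_correlation(correlated))   # x close to 0.5, y close to 0
```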

---

In [None]:
# We import all our dependencies.
from n2v.models import N2VConfig, N2V
import numpy as np
from csbdeep.utils import plot_history
from n2v.utils.n2v_utils import manipulate_val_data
from n2v.internals.N2V_DataGenerator import N2V_DataGenerator
from matplotlib import pyplot as plt
import urllib
import os
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
from tifffile import imread
import zipfile
%load_ext tensorboard

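# Disable HTTPS certificate verification for downloads (a workaround for
# certificate errors on some systems; it makes connections less secure)
import ssl
ssl._create_default_https_context = ssl._create_unverified_context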

<hr style="height:2px;">

## Part 1: Prepare data
We downloaded the data during the Noise2Noise exercise, but let's make sure it's there!

In [None]:
# if you get an error go to the Noise2Noise notebook and download the data
assert os.path.exists("data/SEM/train/train.tif")
assert os.path.exists("data/SEM/test/test.tif")

We create an `N2V_DataGenerator` object to help load data and extract patches for training and validation.

In [None]:
datagen = N2V_DataGenerator()

The data generator provides two methods for loading data: `load_imgs_from_directory` and `load_imgs`. Let's look at the docstring of `load_imgs_from_directory` to figure out how to use it.

In [None]:
?N2V_DataGenerator.load_imgs_from_directory

We need to pass in the directory containing our image files (`"data/SEM/train"`). Our image matches the default filter (`"*.tif"`), so we do not need to specify that. However, our tif file is a stack of several images, so we need to specify `"TYX"` as `dims`.
If you're using your own data, adapt this part to match your use case. You can also have a look at `load_imgs`, or load your images manually.
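
If you load your images manually (for example with `tifffile.imread`), you will need to bring each array into the `SYXC` layout (sample, Y, X, channel) that `load_imgs_from_directory` returns. Below is a minimal NumPy sketch of that conversion for 2D data. It assumes the axes are a subset of `S`/`T`, `Y`, `X`, `C` in that order, and it treats a leading `T` axis as the sample axis. `to_syxc` is a hypothetical helper, not part of n2v.

```python
import numpy as np

def to_syxc(img, dims):
    """Convert a 2D image array with axes `dims` to the SYXC layout.
    A leading T axis is treated as the sample axis S (an assumption)."""
    supported = {"YX", "YXC", "TYX", "TYXC", "SYX", "SYXC"}
    if dims not in supported:
        raise ValueError(f"unsupported dims: {dims!r}")
    if img.ndim != len(dims):
        raise ValueError(f"array has {img.ndim} axes but dims is {dims!r}")
    if "C" not in dims:
        img = img[..., np.newaxis]        # add a trailing channel axis
    if dims[0] not in "ST":
        img = img[np.newaxis, ...]        # add a leading sample axis
    return img

stack = np.zeros((7, 128, 128))           # e.g. a TYX stack of 7 frames
print(to_syxc(stack, "TYX").shape)        # (7, 128, 128, 1)
print(to_syxc(stack[0], "YX").shape)      # (1, 128, 128, 1)
```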

In [None]:
imgs = datagen.load_imgs_from_directory("data/SEM/train", dims="TYX")
print(f"Loaded {len(imgs)} images.")
print(f"First image has shape {imgs[0].shape}")

The method returned a list of images; as per the docstring, the dimensions of each are `"SYXC"`. However, we only want to use one of the images here, since Noise2Void is designed to work with just a single acquisition of the sample. Let's use the first image with $1\mu s$ scan time.
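
Note that the next cell slices with `2:3` rather than indexing with `2`. A range keeps the leading sample axis, so every list element stays in the `SYXC` layout. Here is a quick illustration on a dummy array whose shape is made up for the example:

```python
import numpy as np

dummy = np.zeros((7, 64, 64, 1))   # dummy SYXC array with 7 samples

print(dummy[2].shape)     # (64, 64, 1)    -> sample axis dropped
print(dummy[2:3].shape)   # (1, 64, 64, 1) -> still SYXC, one sample
```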

In [None]:
imgs = [img[2:3,:,:,:] for img in imgs]
print(f"First image has shape {imgs[0].shape}")

For generating patches, the data generator provides the methods `generate_patches` and `generate_patches_from_list`. As before, let's have a quick look at the docstring.

In [None]:
?N2V_DataGenerator.generate_patches_from_list

In [None]:
patches = datagen.generate_patches_from_list(imgs, shape=(96,96))

In [None]:
# split into training and validation
n_train = int(round(.9 * patches.shape[0]))
X, X_val = patches[:n_train,...], patches[n_train:,...]
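
Conceptually, the two cells above cut the image into 96×96 patches and then keep the first 90% for training and the rest for validation. The sketch below reproduces that idea with plain NumPy under two assumptions: the tiles do not overlap, and there is no data augmentation. The real `generate_patches_from_list` may also augment the patches (e.g. by rotating or flipping them) and may order them differently. Treat `tile_patches` as a hypothetical simplification.

```python
import numpy as np

def tile_patches(img, shape=(96, 96)):
    """Cut an SYXC array into non-overlapping patches of `shape`.
    Border pixels that do not fill a complete patch are dropped."""
    s, y, x, c = img.shape
    py, px = shape
    ny, nx = y // py, x // px
    cropped = img[:, :ny * py, :nx * px, :]
    tiles = cropped.reshape(s, ny, py, nx, px, c)
    tiles = tiles.transpose(0, 1, 3, 2, 4, 5)       # -> S, ny, nx, py, px, C
    return tiles.reshape(-1, py, px, c)

demo_img = np.random.default_rng(0).normal(size=(1, 500, 400, 1))
demo_patches = tile_patches(demo_img)                    # 5 * 4 = 20 patches
n_train_demo = int(round(.9 * demo_patches.shape[0]))    # same 90/10 split as above
X_demo, X_val_demo = demo_patches[:n_train_demo], demo_patches[n_train_demo:]
print(demo_patches.shape, X_demo.shape, X_val_demo.shape)
# (20, 96, 96, 1) (18, 96, 96, 1) (2, 96, 96, 1)
```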

As per usual, let's look at a training and validation patch to make sure everything looks okay.

In [None]:
plt.figure(figsize=(14,7))
plt.subplot(1,2,1)
plt.imshow(X[np.random.randint(X.shape[0]),...,0], cmap="gray_r")
plt.title("Training patch")
plt.subplot(1,2,2)
plt.imshow(X_val[np.random.randint(X_val.shape[0]),...,0], cmap="gray_r")
plt.title("Validation patch")

<hr style="height:2px;">

## Part 2: Configure and train the Noise2Void Network

Noise2Void comes with a special config object, in which we store network-architecture and training-specific parameters. See the docstring of the <code>N2VConfig</code> constructor for a description of all parameters.

When creating the config-object, we provide the training data <code>X</code>. From <code>X</code> we extract <code>mean</code> and <code>std</code> that will be used to normalize all data before it is processed by the network. We also extract the dimensionality and number of channels from <code>X</code>.
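
The normalization described above is a standard score. The sketch below assumes the statistics are computed per channel over all training patches, and that the same values map the network output back to the original intensity range. How n2v stores and applies them internally may differ in detail. The helper names are hypothetical.

```python
import numpy as np

def channel_stats(X):
    """Per-channel mean and std of an SYXC training array."""
    axes = tuple(range(X.ndim - 1))      # every axis except the channel axis
    return X.mean(axis=axes), X.std(axis=axes)

def normalize(x, mean, std):
    return (x - mean) / std

def denormalize(x, mean, std):
    return x * std + mean

rng = np.random.default_rng(0)
X_stats_demo = rng.normal(loc=500.0, scale=40.0, size=(16, 64, 64, 1))
mean, std = channel_stats(X_stats_demo)
X_norm = normalize(X_stats_demo, mean, std)
print(X_norm.mean(), X_norm.std())                               # ~0 and ~1
print(np.allclose(denormalize(X_norm, mean, std), X_stats_demo))  # True
```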

Compared to supervised training (i.e. traditional CARE), we recommend using N2V with a larger <code>train_batch_size</code> and with <code>batch_norm</code> enabled.
To keep the network from learning the identity, we have to manipulate the input pixels during training. For this we have the parameter <code>n2v_manipulator</code>, with default value <code>'uniform_withCP'</code>. Most pixel manipulators compute the replacement value based on a neighborhood, whose size we can control with <code>n2v_neighborhood_radius</code>.

Other pixel manipulators:
* <code>normal_withoutCP</code>: samples the neighborhood according to a normal (Gaussian) distribution, but without the center pixel
* <code>normal_additive</code>: adds a random number to the original pixel value. The random number is sampled from a Gaussian distribution with zero mean and sigma = <code>n2v_neighborhood_radius</code>
* <code>normal_fitted</code>: uses a random value from a Gaussian distribution with mean equal to the mean of the neighborhood and standard deviation equal to the standard deviation of the neighborhood
* <code>identity</code>: performs no pixel manipulation
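
Here is a minimal NumPy sketch of a `uniform_withCP`-style manipulation. For each selected pixel, it picks a random position within a square window of radius `n2v_neighborhood_radius` around that pixel and copies in the value found there. The centre pixel itself may be picked. The loss is then computed only at the manipulated positions, because the network never saw their original values. The square window, the border clipping and the function names are assumptions for illustration. The n2v implementation may differ, for example in how positions are chosen or how borders are handled.

```python
import numpy as np

def manipulate_uniform_with_cp(patch, coords, radius, rng):
    """Return a copy of a (Y, X) patch in which each pixel at `coords` is
    replaced by a value drawn uniformly from its (2*radius+1)^2 window.
    The window may include the centre pixel itself ("withCP")."""
    out = patch.copy()
    h, w = patch.shape
    for y, x in coords:
        # random offset inside the window, clipped at the patch border
        sy = np.clip(y + rng.integers(-radius, radius + 1), 0, h - 1)
        sx = np.clip(x + rng.integers(-radius, radius + 1), 0, w - 1)
        out[y, x] = patch[sy, sx]   # always read from the unmodified patch
    return out

def masked_mse(prediction, target, coords):
    """Mean squared error evaluated only at the manipulated pixels."""
    ys, xs = coords[:, 0], coords[:, 1]
    return np.mean((prediction[ys, xs] - target[ys, xs]) ** 2)

rng = np.random.default_rng(0)
noisy = rng.normal(size=(64, 64))
coords = np.stack([rng.integers(0, 64, 8), rng.integers(0, 64, 8)], axis=1)
net_input = manipulate_uniform_with_cp(noisy, coords, radius=5, rng=rng)

# Pixels that were not selected are left untouched
mask = np.zeros(noisy.shape, dtype=bool)
mask[coords[:, 0], coords[:, 1]] = True
print(np.array_equal(net_input[~mask], noisy[~mask]))   # True

# The loss compares a prediction with the original noisy values at the
# manipulated positions only; the manipulated input stands in for a prediction
print(masked_mse(net_input, noisy, coords))
```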

For faster training, multiple pixels per input patch can be manipulated. In our experiments we manipulated about 0.198% of the input pixels per patch, which for a patch size of 64 by 64 pixels corresponds to about 8 pixels. This fraction can be tuned via <code>n2v_perc_pix</code>.
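
The arithmetic behind these numbers is shown in the small sketch below. It assumes that <code>n2v_perc_pix</code> is given in percent, as the 0.198% above suggests, and that the count is rounded to the nearest integer. `n_manipulated_pixels` is a hypothetical helper.

```python
import math

def n_manipulated_pixels(patch_shape, perc_pix):
    """Approximate number of manipulated pixels per patch, with `perc_pix`
    given in percent."""
    return round(math.prod(patch_shape) * perc_pix / 100)

print(n_manipulated_pixels((64, 64), 0.198))   # 4096 * 0.00198 ≈ 8.1  -> 8
print(n_manipulated_pixels((96, 96), 0.198))   # 9216 * 0.00198 ≈ 18.2 -> 18
```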

For Noise2Void training it is possible to pass arbitrarily large patches to the training method. From these patches, random subpatches of size <code>n2v_patch_shape</code> are extracted during training. The default patch shape is (64, 64).
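
The sketch below shows how a random subpatch of `n2v_patch_shape` can be cut from a larger training patch. It uses a hypothetical helper that assumes uniformly random crop positions and is not n2v's internal code. For the 96×96 patches from Part 1 and the default (64, 64) shape, each crop can start at any of 33×33 positions (96 - 64 + 1 = 33 per axis).

```python
import numpy as np

def random_subpatch(patch, shape, rng):
    """Crop a random subpatch of `shape` from a (Y, X, C) patch."""
    py, px = shape
    y0 = rng.integers(0, patch.shape[0] - py + 1)
    x0 = rng.integers(0, patch.shape[1] - px + 1)
    return patch[y0:y0 + py, x0:x0 + px, ...]

rng = np.random.default_rng(0)
demo_patch = rng.normal(size=(96, 96, 1))
sub = random_subpatch(demo_patch, (64, 64), rng)
print(sub.shape)   # (64, 64, 1)
```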

In the past we experienced bleedthrough artifacts between channels if training was terminated too early. To counter bleedthrough, we added the `single_net_per_channel` option, which is turned on by default. Under the hood, a separate U-Net is created for each channel and trained independently, which removes the possibility of bleedthrough. <br/>
Essentially, the network gets multiplied by the number of channels, which increases the memory requirements. If your GPU runs out of memory, you can always split the channels manually and train a network for each channel one after another.
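
Splitting a multi-channel `SYXC` array into single-channel arrays takes one slice per channel. A sketch with a made-up 2-channel array:

```python
import numpy as np

X_multi = np.random.default_rng(0).normal(size=(10, 64, 64, 2))  # dummy SYXC data

# Slice with c:c+1 so each array keeps its trailing channel axis (SYXC, C=1)
per_channel = [X_multi[..., c:c + 1] for c in range(X_multi.shape[-1])]
print([x.shape for x in per_channel])   # [(10, 64, 64, 1), (10, 64, 64, 1)]
```

Each of these arrays could then get its own `N2VConfig`/`N2V` pair and be trained one after another, so that only one network occupies GPU memory at a time.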

---
<div class="alert alert-block alert-info"><h4>
    TASK 3.2</h4>
    <p>
As suggested, look at the docstring of <code>N2VConfig</code>, then generate a configuration for your Noise2Void network and choose a name to identify your model by.
    </p>
</div>

In [None]:
?N2VConfig

In [None]:
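###TODO###
config = N2VConfig(
    X,  # training data: used to compute mean/std and infer dimensionality and channels
    # TODO: set further parameters, see the docstring above
)
model_name = ""  # TODO: choose a name for your model
vars(config)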

---

In [None]:
#initialize the model
model = N2V(config, model_name, basedir="models")

Now let's train the model and monitor the progress in TensorBoard.
Adapt the command below as you did before.

In [None]:
%tensorboard --logdir=models

In [None]:
history = model.train(X, X_val)

In [None]:
print(sorted(list(history.history.keys())))
plt.figure(figsize=(16,5))
plot_history(history, ["loss", "val_loss"])

<hr style="height:2px;">

## Part 3: Prediction

As with CARE, a previously trained model is loaded by creating a new `N2V` object without providing a `config`.

In [None]:
model = N2V(config=None, name=model_name, basedir="models")

Let's load a $1\mu s$ scan-time test image and denoise it using our network. As before, we'll use the high-SNR image to make a quantitative comparison. If you're using your own data and don't have an equivalent, you can ignore that part.

In [None]:
test_stack = imread("data/SEM/test/test.tif")  # read the stack only once
test_img = test_stack[2,...]            # first 1us scan
test_img_highSNR = test_stack[-1,...]   # high-SNR image (average of 4x5us scans)
print(f"Loaded test image with shape {test_img.shape}")

In [None]:
test_denoised = model.predict(test_img, axes="YX", n_tiles=(2,1))

Let's look at the results.

In [None]:
plt.figure(figsize=(30,30))
plt.subplot(2,3,1)
plt.imshow(test_img, cmap="gray_r")
plt.title("Noisy test image")
plt.subplot(2,3,4)
plt.imshow(test_img[2000:2200,500:700], cmap="gray_r")
plt.subplot(2,3,2)
plt.imshow(test_denoised, cmap="gray_r")
plt.title("Denoised test image")
plt.subplot(2,3,5)
plt.imshow(test_denoised[2000:2200,500:700], cmap="gray_r")
plt.subplot(2,3,3)
plt.imshow(test_img_highSNR, cmap="gray_r")
plt.title("High SNR image (4x5us)")
plt.subplot(2,3,6)
plt.imshow(test_img_highSNR[2000:2200,500:700], cmap="gray_r")
plt.show()

---
<div class="alert alert-block alert-info"><h4>
    TASK 3.3</h4>
    <p>

If you're using the SEM data (or happen to have a high-SNR version of the image you predicted from), compare the structural similarity index and the peak signal-to-noise ratio (with respect to the high-SNR image) of the noisy input image and of the predicted image.
    </p>
</div>

In [None]:
###TODO###
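ssi_input = None  # TODO: SSIM between test_img and test_img_highSNR
ssi_restored = None  # TODO: SSIM between test_denoised and test_img_highSNR
print(f"Structural similarity index (higher is better) wrt average of 4x5us images: \n"
      f"Input: {ssi_input} \n"
      f"Prediction: {ssi_restored}")
psnr_input = None  # TODO: PSNR between test_img and test_img_highSNR
psnr_restored = None  # TODO: PSNR between test_denoised and test_img_highSNR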
print(f"Peak signal-to-noise ratio (higher is better) wrt average of 4x5us images:\n"
      f"Input: {psnr_input} \n"
      f"Prediction: {psnr_restored}")

---
<hr style="height:2px;">
<div class="alert alert-block alert-success"><h1>
    Congratulations!</h1>
    <p>
    <b>You have reached the third checkpoint of this exercise! Please mark your progress on slack!</b>
    </p>
    <p>
    If there's still time, check out the bonus exercise.
    </p>
</div>