# Exercise 1. Training COSDD<br>
In this section, we will train a COSDD model to remove row-correlated and signal-dependent imaging noise.
You will load noisy data and examine the noise for spatial correlation, then initialise a model and monitor its training.
Finally, you'll use the model to denoise the data.

COSDD is a Ladder VAE with an autoregressive decoder, a type of deep generative model. Deep generative models are trained with the objective of capturing all the structures and characteristics present in a dataset, i.e., modelling the dataset. In our case the dataset will be a collection of noisy microscopy images. 

When COSDD is trained to model noisy images, it exploits differences between the structure of imaging noise and the structure of the clean signal to separate them, capturing each with different components of the model. Specifically, the noise will be captured by the autoregressive decoder and the signal will be captured by the VAE's latent variables. We can then feed an image into the model and sample a latent variable, which will describe the image's clean signal content. This latent variable is then fed through a second network, which was trained alongside the main VAE, to reveal an estimate of the denoised image.

<div class="alert alert-danger">
Set your python kernel to <code>05_image_restoration</code>
</div>

In [None]:
import os

import torch
import tifffile
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from COSDD import utils
from COSDD.models.lvae import LadderVAE
from COSDD.models.pixelcnn import PixelCNN
from COSDD.models.s_decoder import SDecoder
from COSDD.models.unet import UNet
from COSDD.models.hub import Hub

%matplotlib inline

In [None]:
assert torch.cuda.is_available()

### 1.1. Load data

In this example we will be using the Mito Confocal dataset, provided by: <br>
Hagen, G.M., Bendesky, J., Machado, R., Nguyen, T.A., Kumar, T. and Ventura, J., 2021. Fluorescence microscopy datasets for training deep neural networks. GigaScience, 10(5), p.giab032.

You will have tried denoising this in the final section of the N2V exercise and hopefully noticed the horizontal artifacts. In this exercise, we'll train a model that can handle spatially correlated noise and won't leave behind artifacts.

<div class="alert alert-info">

### Task 1.1.

The low signal-to-noise ratio data that we will be using in this exercise has been downloaded and stored as a tiff file at `./../data/mito-confocal-lowsnr.tif`. 

In the following cell, you'll load it and get it into a format suitable for training the denoiser.

1. Use the function `tifffile.imread` to load the data as a numpy array.
2. Then use np.newaxis to add a channel axis. *Hint* The data is a stack of 2D images, so the channel axis should be the second dimension (dimension 1 if we start counting from zero).
3. Next, use `torch.from_numpy` to convert it into a pytorch tensor.
4. Lastly, convert the datatype to `torch.float32`.

COSDD can handle 1-, 2- and 3-dimensional data, as long as it's loaded as a PyTorch tensor with a batch and channel dimension. For 1D data, it should have dimensions [Number of images, Channels, X], for 2D data: [Number of images, Channels, Y, X] and for 3D: [Number of images, Channels, Z, Y, X]. This applies even if the data has only one channel.
</div>

In [None]:
# load the data
low_snr = ...
low_snr = ...
low_snr = ...
low_snr = ...

assert [*low_snr.size()] == [79, 1, 1024, 1024], "Incorrect dimensions"
assert low_snr.dtype == torch.float32, "Incorrect data type"

### 1.2. Examine spatial correlation of the noise

COSDD can be applied to noise that is correlated along rows or columns of pixels (or not spatially correlated at all). 
However, it cannot be applied to noise that is correlated along rows *and* columns of pixels.
Noise2Void, on the other hand, is designed for noise that is not spatially correlated at all.

When we say that the noise is spatially correlated, we mean that knowing the value of the noise in one pixel tells us something about the noise in other (usually nearby) pixels.
Specifically, positive correlation between two pixels tells us that if the noise value in one pixel is high, the noise value in the other pixel is likely to be high too.
Similarly, if one is low, the other is likely to be low.
Negative correlation between pixels means that a low noise intensity in one pixel is more likely if the intensity in the other is high, and vice versa.

To examine an image's spatial correlation, we can create an autocorrelation plot.
The plot has two axes, horizontal lag and vertical lag, and shows the correlation between pairs of pixels separated by a given horizontal and vertical lag.
For example, if the square at a horizontal lag of 3 and a vertical lag of 6 is red, it means that if we picked any pixel in the image, then counted 3 pixels to the right and 6 pixels down, that pair of pixels would be positively correlated.
Correlation is symmetric, so the same is true if we counted left or up.
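As a rough illustration of what such a plot measures, here is a minimal NumPy sketch of a normalised autocorrelation over non-negative lags. It is not `utils.autocorrelation` (we don't assume anything about that function's output layout); `autocorrelation_sketch` is a hypothetical helper, and the synthetic noise is averaged along rows only, so it is correlated along x.

```python
import numpy as np


def autocorrelation_sketch(patch: np.ndarray, max_lag: int) -> np.ndarray:
    """Normalised autocorrelation of a 2D patch for lags 0..max_lag-1.

    Entry [dy, dx] is the correlation between pixels separated by dy rows
    (vertical lag) and dx columns (horizontal lag).
    """
    x = patch.astype(np.float64)
    x = x - x.mean()
    var = np.mean(x * x)
    h, w = x.shape
    ac = np.empty((max_lag, max_lag))
    for dy in range(max_lag):
        for dx in range(max_lag):
            # Pair every pixel with the one dy rows down and dx columns to the right
            ac[dy, dx] = np.mean(x[: h - dy, : w - dx] * x[dy:, dx:]) / var
    return ac


# Synthetic row-correlated noise: white noise averaged over 5 neighbouring pixels in each row
rng = np.random.default_rng(0)
white = rng.normal(size=(256, 256 + 4))
row_noise = sum(white[:, k : k + 256] for k in range(5)) / 5

ac = autocorrelation_sketch(row_noise, max_lag=8)
print(np.round(ac[:2, :6], 2))
```

With horizontal lag on the x axis, this noise shows non-zero correlation only along the top row (zero vertical lag), falling off roughly as 1, 0.8, 0.6, 0.4, 0.2 as the horizontal lag grows.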

<div class="alert alert-warning">

### Question 1.1.

Below are three autocorrelation plots. They show how the noise is spatially correlated in three different examples of noise.
Identify which noise examples could be removed by:<br>
(a) COSDD<br>
(b) Noise2Void<br>
(c) neither
</div>

<img src="resources/ac-question.png"/>

<div class="alert alert-info">

### Task 1.2.

Now we will create an autocorrelation plot of the data we loaded.
To do this, we need a sample of pure noise.
This can be a patch of `low_snr` with no signal. 
Adjust the values for `image_index`, `top`, `bottom`, `left` and `right` to explore slices of the data and identify a suitable dark patch.
Once you have found one, enter its crop in the following cell to define `dark_patch`, which is passed to `utils.autocorrelation`; the cell after that plots the result.

*Hint: The bigger the dark patch, the more accurate our estimate of the spatial autocorrelation will be.*
</div>

In [None]:
vmin = np.percentile(low_snr, 1)
vmax = np.percentile(low_snr, 99)

In [None]:
### Explore slices of the data here
image_index = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (image_index, 0, slice(top, bottom), slice(left, right))

plt.figure(figsize=(10, 10))
plt.imshow(low_snr[crop], vmin=vmin, vmax=vmax)
plt.show()

In [None]:
### Define the crop of the dark image patch here
dark_image_index = ...
dark_top = ...
dark_bottom = ...
dark_left = ...
dark_right = ...

dark_crop = (dark_image_index, 0, slice(dark_top, dark_bottom), slice(dark_left, dark_right))
dark_patch = low_snr[dark_crop]

noise_ac = utils.autocorrelation(dark_patch, max_lag=25)

In [None]:
# Plot the autocorrelation
plt.figure()
plt.imshow(noise_ac, cmap="seismic", vmin=-1, vmax=1)
plt.colorbar()
plt.title("Autocorrelation of the noise")
plt.xlabel("Horizontal lag")
plt.ylabel("Vertical lag")
plt.show()

In this plot, all of the squares should be white, except for the top row. The autocorrelation of the square at (0, 0) will always be 1.0, as a pixel's value will always be perfectly correlated with itself. We define this type of noise as correlated along the x axis.

To remove this type of noise, the autoregressive decoder of our VAE must have a receptive field spanning the x axis.
Note that if the data contained spatially *un*correlated noise, we could still remove it: a decoder whose receptive field spans the x axis can also model noise with no spatial correlation, and the extra receptive field simply goes unused.

<div class="alert alert-success">

## Checkpoint 1
Now that we're familiar with our data, we'll train a COSDD model to denoise it.

</div>

### 1.3. Create training and validation dataloaders

The data will be fed to the model by two dataloaders, `train_loader` and `val_loader`, for the training and validation set respectively. <br>
In this example, 90% of images will be used for training and the remaining 10% for validation.

`real_batch_size` (int): Number of images passed through the network at a time. <br>
`n_grad_batches` (int): Number of batches to pass through the network before updating parameters.<br>
`crop_size` (tuple(int)): The size of randomly cropped patches. Should be smaller than the dimensions of your images.<br>
`train_split` (0 < float < 1): Fraction of images to be used in the training set, with the remainder used for the validation set.
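To make `crop_size` concrete, here is a minimal PyTorch sketch of drawing one random crop from the trailing spatial dimensions of an image. `random_crop` is a hypothetical helper, not the implementation of `utils.RandomCrop`. For the settings used here, each 1024×1024 image holds (1024·1024) // (256·256) = 16 non-overlapping 256×256 patches, which is the `n_iters` value the next cell computes.

```python
from typing import Optional, Sequence

import torch


def random_crop(
    image: torch.Tensor,
    crop_size: Sequence[int],
    generator: Optional[torch.Generator] = None,
) -> torch.Tensor:
    """Randomly crop the last len(crop_size) dimensions of `image` (e.g. [C, Y, X])."""
    spatial = image.shape[-len(crop_size):]
    # Pick a random start along each spatial axis so the crop fits inside the image
    starts = [
        int(torch.randint(0, size - crop + 1, (1,), generator=generator))
        for size, crop in zip(spatial, crop_size)
    ]
    index = (...,) + tuple(slice(s, s + c) for s, c in zip(starts, crop_size))
    return image[index]


g = torch.Generator().manual_seed(0)
image = torch.arange(1024 * 1024, dtype=torch.float32).reshape(1, 1024, 1024)
patch = random_crop(image, (256, 256), generator=g)
print(patch.shape)  # torch.Size([1, 256, 256])
print((1024 * 1024) // (256 * 256))  # 16
```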


In [None]:
real_batch_size = 4
n_grad_batches = 4
print(f"Effective batch size: {real_batch_size * n_grad_batches}")
crop_size = (256, 256)
train_split = 0.9

n_iters = np.prod(low_snr.shape[2:]) // np.prod(crop_size)
transform = utils.RandomCrop(crop_size)

dataset = utils.TrainDataset(low_snr, n_iters=n_iters, transform=transform)
train_set, val_set = torch.utils.data.random_split(dataset, [train_split, 1-train_split])
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=real_batch_size, shuffle=True, pin_memory=True, num_workers=7,
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=real_batch_size, shuffle=False, pin_memory=True, num_workers=7,
)

### 1.4. Create the model

The model we will train to denoise consists of four modules, the fourth being the optional Direct Denoiser, which we can train if we want to speed up inference. Each module is listed below with an explanation of its hyperparameters, followed by the `hub` that combines them.

<img src="resources/explainer.png"/>

**a)** COSDD is a Variational Autoencoder (solid arrows) trained to model the distribution of noisy images $\mathbf{x}$.
The autoregressive (AR) decoder models the noise component of the images, while the latent variable models only the clean signal component $\mathbf{s}$.
In a second step (dashed arrows), the *signal decoder* is trained to map latent variables into image space, producing an estimate of the signal underlying $\mathbf{x}$.
**b)** To ensure that the decoder models only the imaging noise and the latent variables capture only the signal, the AR decoder's receptive field is modified.
In a full AR receptive field, each output pixel (red) is a function of all input pixels located above and to the left (blue). In our decoder's row-based AR receptive field, each output pixel is a function of input pixels located in the same row, which corresponds to the row-correlated structure of imaging noise.
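The row-based receptive field in panel b) can be sketched with an ordinary 2D convolution whose kernel is one pixel tall and which only looks at preceding pixels in the same row. This is a minimal, hypothetical illustration, not the COSDD `PixelCNN` code: `row_causal_conv` is our own name, and we assume left-to-right ordering with the current pixel excluded, as in the first layer of a PixelCNN-style model.

```python
import torch
import torch.nn.functional as F


def row_causal_conv(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    """Row-wise causal convolution.

    x: [N, C_in, Y, X]; weight: [C_out, C_in, 1, k]. Output pixel (y, j) depends
    only on input pixels (y, j-k) ... (y, j-1): same row, strictly to the left.
    """
    k = weight.shape[-1]
    # Pad k columns on the left only; a kernel height of 1 means no mixing across rows
    x = F.pad(x, (k, 0, 0, 0))
    # The convolution yields X + 1 columns and output j sees original columns j-k..j-1,
    # so keeping the first X aligns each output with pixels 0..X-1
    return F.conv2d(x, weight)[..., :-1]


torch.manual_seed(0)
x = torch.randn(1, 1, 16, 32)
w = torch.arange(1.0, 6.0).reshape(1, 1, 1, 5)  # fixed, non-zero kernel of length 5
base = row_causal_conv(x, w)

# Perturb a single input pixel and see which outputs change
x_perturbed = x.clone()
x_perturbed[0, 0, 5, 10] += 1.0
changed = (row_causal_conv(x_perturbed, w) - base).abs() > 1e-3
rows, cols = torch.nonzero(changed[0, 0], as_tuple=True)
print(sorted(set(rows.tolist())), cols.tolist())  # [5] [11, 12, 13, 14, 15]
```

Only outputs in the same row, and to the right of the perturbed pixel, respond, which is the receptive field shape that lets the decoder model noise correlated along x.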

`dimensions` (int): The dimensionality of the data. Can be 1, 2, or 3.

`lvae` The ladder variational autoencoder that will output latent variables.<br>
* `s_code_channels` (int): Number of channels in the output latent variable.
* `n_layers` (int): Number of levels in the Ladder VAE.
* `z_dims` (list(int)): List with the number of latent space dimensions at each level of the hierarchy. The list starts from the input/output level and works down.
* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy: 1 to downsample, 0 not to.

`ar_decoder` The autoregressive decoder that will decode latent variables into a distribution over the input.<br>
* `kernel_size` (int): Length of 1D convolutional kernels.
* `noise_direction` (str): Axis along which noise is correlated: `"x"`, `"y"` or `"z"`. This needs to match the orientation of the noise structures we revealed in the autocorrelation plot in Task 1.2.
* `n_filters` (int): Number of feature channels.
* `n_gaussians` (int): Number of components in Gaussian mixture used to model data.
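To illustrate what `n_gaussians` controls, here is a minimal sketch of the negative log-likelihood of a pixel value under a mixture of K Gaussians, the kind of quantity that the `reconstruction_loss` described in Section 1.5 measures. How COSDD's `PixelCNN` actually parameterises its mixture is not shown here; `gmm_nll` and its argument layout are assumptions. The comparison at the end also shows why this loss can go below zero: a probability density can exceed 1 when the predicted spread is small.

```python
import math

import torch


def gmm_nll(x, logits, means, log_stds):
    """Per-pixel negative log-likelihood under a K-component Gaussian mixture.

    x has shape [...]; logits, means and log_stds have shape [K, ...].
    """
    log_weights = torch.log_softmax(logits, dim=0)  # mixture weights sum to 1
    log_probs = (
        -0.5 * ((x - means) / log_stds.exp()) ** 2
        - log_stds
        - 0.5 * math.log(2 * math.pi)
    )
    # -log sum_k w_k N(x | mu_k, sigma_k), computed stably in log space
    return -torch.logsumexp(log_weights + log_probs, dim=0)


K = 4
x = torch.zeros(8, 8)
logits = torch.zeros(K, 8, 8)
means = torch.zeros(K, 8, 8)
wide = gmm_nll(x, logits, means, torch.zeros(K, 8, 8))            # sigma = 1
narrow = gmm_nll(x, logits, means, torch.full((K, 8, 8), -3.0))   # sigma = e^-3
print(wide.mean().item(), narrow.mean().item())
```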

`s_decoder` A decoder that will map the latent variables into image space, giving us a denoised image. <br>
* `n_filters` (int): The number of feature channels.<br>

`direct_denoiser` The U-Net that can optionally be trained to predict the MMSE or MMAE of the denoised images. This will slow training slightly but massively speed up inference, and is worthwhile if you have an inference dataset in the gigabytes. See [this paper](https://arxiv.org/abs/2310.18116). Enable or disable the direct denoiser by setting `use_direct_denoiser` to `True` or `False`.
* `n_filters` (int): Feature channels at each level of the UNet. Defaults to `s_code_channels`.
* `n_layers` (int): Number of levels in the UNet. Defaults to the number of levels in the `LadderVAE`.
* `downsampling` (list(int)): Binary list of whether to downsample at each level of the hierarchy: 1 to downsample, 0 not to. Also defaults to match the `LadderVAE`.
* `loss_fn` (str): Whether to use the `"L1"` loss, which predicts the pixel-wise median (MMAE) of the denoised images, or the `"L2"` loss, which predicts their mean (MMSE).
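The pairing of loss and estimate in `loss_fn` can be checked numerically: the constant that minimises mean squared error over a set of values is their mean, while the constant minimising mean absolute error is their median. The sketch below uses a skewed, purely synthetic set of values as a stand-in for many denoised samples of one pixel.

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic, skewed stand-in for many sampled values of a single pixel
values = rng.exponential(scale=1.0, size=1001)

# Brute-force search over candidate constant predictions
candidates = np.linspace(0.0, 3.0, 1501)
errors = values[None, :] - candidates[:, None]
best_l2 = candidates[np.argmin(np.mean(errors**2, axis=1))]
best_l1 = candidates[np.argmin(np.mean(np.abs(errors), axis=1))]

print(f"L2 minimiser {best_l2:.3f} vs mean   {values.mean():.3f}")
print(f"L1 minimiser {best_l1:.3f} vs median {np.median(values):.3f}")
```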

`hub` The hub that will unify and train the above modules.
* `n_grad_batches` (int): Number of batches to accumulate gradients for before updating weights of all models. If the real batch or random crop size has been reduced to lower memory consumption, increase this value for the effective batch size to stay the same.
* `checkpointed` (bool): Whether to use activation checkpointing during training. This reduces memory consumption but increases training time. 
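`n_grad_batches` is a form of gradient accumulation. The generic PyTorch sketch below uses a toy linear model rather than `Hub`'s own training step (which we assume behaves equivalently) to check that accumulating scaled gradients over `n_grad_batches` batches of `real_batch_size` images gives the same gradient as a single batch of `real_batch_size * n_grad_batches` images.

```python
import torch

torch.manual_seed(0)
real_batch_size, n_grad_batches = 4, 4
model = torch.nn.Linear(8, 1)
loss_fn = torch.nn.MSELoss()  # averages over the batch
x = torch.randn(real_batch_size * n_grad_batches, 8)
y = torch.randn(real_batch_size * n_grad_batches, 1)

# (a) Gradient accumulation: n_grad_batches small backward passes before one update
model.zero_grad()
for xb, yb in zip(x.split(real_batch_size), y.split(real_batch_size)):
    # Dividing by n_grad_batches makes the summed gradients equal the effective-batch mean
    (loss_fn(model(xb), yb) / n_grad_batches).backward()
accumulated = model.weight.grad.clone()

# (b) One backward pass over the full effective batch
model.zero_grad()
loss_fn(model(x), y).backward()
full_batch = model.weight.grad.clone()

print(torch.allclose(accumulated, full_batch, atol=1e-6))
```

This is why reducing `real_batch_size` or the crop size to save memory can be offset by increasing `n_grad_batches`.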

<div class="alert alert-info">

### Task 1.3.

Most hyperparameters have been set to recommended values for a small model. The three that have been left blank are `dimensions`, `noise_direction` under the `ar_decoder`, and `use_direct_denoiser`. Use the above description of what each hyperparameter means to determine the best value for each of these.

*Hint: In this notebook we're using 2D data*<br>
*Hint: enabling the Direct Denoiser will give us additional results to look at in the next notebook.*

</div>

In [None]:
dimensions = ... ### Insert a value here
s_code_channels = 32

n_layers = 6
z_dims = [s_code_channels // 2] * n_layers
downsampling = [1] * n_layers
lvae = LadderVAE(
    colour_channels=low_snr.shape[1],
    img_size=crop_size,
    s_code_channels=s_code_channels,
    n_filters=s_code_channels,
    z_dims=z_dims,
    downsampling=downsampling,
    dimensions=dimensions,
)

ar_decoder = PixelCNN(
    colour_channels=low_snr.shape[1],
    s_code_channels=s_code_channels,
    kernel_size=5,
    noise_direction=...,  ### Insert a value here
    n_filters=32,
    n_layers=4,
    n_gaussians=4,
    dimensions=dimensions,
)

s_decoder = SDecoder(
    colour_channels=low_snr.shape[1],
    s_code_channels=s_code_channels,
    n_filters=s_code_channels,
    dimensions=dimensions,
)

use_direct_denoiser = ...  ### Insert a value here
if use_direct_denoiser:
    direct_denoiser = UNet(
        colour_channels=low_snr.shape[1],
        n_filters=s_code_channels,
        n_layers=n_layers,
        downsampling=downsampling,
        loss_fn="L2",
        dimensions=dimensions,
    )
else:
    direct_denoiser = None

hub = Hub(
    vae=lvae,
    ar_decoder=ar_decoder,
    s_decoder=s_decoder,
    direct_denoiser=direct_denoiser,
    data_mean=low_snr.mean(),
    data_std=low_snr.std(),
    n_grad_batches=n_grad_batches,
    checkpointed=True,
)

### 1.5. Train the model

<div class="alert alert-info">

### Task 1.4.

Open Tensorboard (check Task 3 in 01_CARE) to monitor training.
This model is unlike the previous two because it has more than one loss curve.
The cell below describes how to interpret each one.
</div>

#### Tensorboard metrics

In the SCALARS tab, there will be 4 metrics to track (5 if direct denoiser is enabled). These are:<br>
1. `kl_loss` The Kullback-Leibler divergence between the VAE's approximate posterior and its prior. This can be thought of as a measure of how much information about the input image is going into the VAE's latent variables. We want information about the input's underlying clean signal to go into the latent variables, so this metric shouldn't go all the way to zero. Instead, it can typically go either up or down during training before plateauing.<br>
2. `reconstruction_loss` The negative log-likelihood of the AR decoder's predicted distribution given the input data. This is how accurately the AR decoder is able to predict the input. This value can go below zero and should decrease throughout training before plateauing.<br>
3. `elbo` The Evidence Lower Bound (strictly, its negative), which is the total loss of the main VAE. This is the sum of the KL and reconstruction losses and should decrease throughout training before plateauing.<br>
4. `sd_loss` The mean squared error between the noisy image and the image predicted by the signal decoder. This metric should steadily decrease towards zero without ever reaching it. Sometimes the loss will not go down for the first few epochs because its input (produced by the VAE) is rapidly changing. This is expected, and the loss should start to decrease once the VAE stabilises. <br>
5. `dd_loss` The mean squared error between the output of the direct denoiser and the clean images predicted by the signal decoder. This will only be present if `use_direct_denoiser` is set to `True`. The metric should steadily decrease towards zero without ever reaching it, but may be unstable at the start of training as its targets (produced by the signal decoder) are rapidly changing.
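For intuition about `kl_loss`, consider the simplest case: a diagonal Gaussian posterior compared with a standard normal prior, for which the KL divergence has a closed form. COSDD's Ladder VAE uses a hierarchy of learned priors, so this is a simplification, and `gaussian_kl` is a hypothetical helper. A KL of zero means the posterior equals the prior regardless of the input, i.e. no information reaches the latent variables.

```python
import torch
from torch.distributions import Normal, kl_divergence


def gaussian_kl(mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
    """KL( N(mu, exp(log_var)) || N(0, 1) ), summed over latent dimensions."""
    return 0.5 * torch.sum(mu**2 + log_var.exp() - 1.0 - log_var)


mu = torch.tensor([0.0, 1.0, -0.5])
log_var = torch.tensor([0.0, 0.0, -1.0])
closed_form = gaussian_kl(mu, log_var)
# Cross-check against PyTorch's built-in KL between Normal distributions
reference = kl_divergence(Normal(mu, (0.5 * log_var).exp()), Normal(0.0, 1.0)).sum()
print(closed_form.item(), reference.item())
```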

There will also be an IMAGES tab. This shows noisy input images from the validation set and some outputs. These will be two randomly sampled denoised images (sample 1 and sample 2), the average of ten denoised images (mmse) and if the direct denoiser is enabled, its output (direct estimate).

If noise has not been fully removed from the output images, try increasing the `n_gaussians` argument of the AR decoder. This will give it more flexibility to model complex noise characteristics. However, setting the value too high can lead to unstable training. Typically, values from 3 to 5 work best.

Note that the trainer is set to train for only 10 minutes in this example. Remove the line with `max_time` to train fully.

<div class="alert alert-info">

### Task 1.5.

Now the model is ready to start training. Give the model a sensible name by setting `model_name` to a string, then run the following cells.

The `max_time` parameter in the cell below means we'll only train the model for 10 minutes, just to get an idea of what to expect. In the future, to remove the time restriction, the `max_time` parameter can be set to `None`.
</div>

`model_name` (str): Should be set to something appropriate so that the trained parameters can be used later for inference.<br>
`max_epochs` (int): The number of training epochs.<br>
`patience` (int): If the validation loss has plateaued for this many epochs, training will stop.
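The `patience` rule can be sketched in a few lines of plain Python. This is a simplified model of patience-based early stopping on a loss that should decrease, not Lightning's `EarlyStopping` code; `early_stop_epoch` is a hypothetical helper, and it treats any improvement, however small, as progress.

```python
from typing import Optional, Sequence


def early_stop_epoch(val_losses: Sequence[float], patience: int) -> Optional[int]:
    """Return the epoch at which training would stop, or None if it never does."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


losses = [5.0, 4.0, 3.0, 3.1, 3.2, 3.05, 3.3]
print(early_stop_epoch(losses, patience=3))  # 5
print(early_stop_epoch(losses, patience=5))  # None
```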

In [None]:
model_name = ...  ### Insert a value here
checkpoint_path = os.path.join("checkpoints", model_name)
logger = TensorBoardLogger(checkpoint_path)

max_epochs = 1000
max_time = "00:00:10:00"
patience = 100

trainer = pl.Trainer(
    logger=logger,
    accelerator="gpu",
    devices=1,
    max_epochs=max_epochs,
    max_time=max_time,  # Remove this time limit to train the model fully
    log_every_n_steps=len(train_set) // (n_grad_batches * real_batch_size),
    callbacks=[EarlyStopping(patience=patience, monitor="val/elbo")],
)

In [None]:
trainer.fit(hub, train_loader, val_loader)
trainer.save_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))
torch.cuda.empty_cache()

<div class="alert alert-success">

## Checkpoint 2
We've now trained a COSDD model to denoise our data. Continue to the next part to use it to get some results.

</div>

# Exercise 2. Inference with COSDD

### 2.1. Load test data
The images that we want to denoise are loaded here. These are the same images we used for training, but we'll only load the first `n_test_images` (5 here) to speed up inference.

In [None]:
lowsnr_path = "./../data/mito-confocal-lowsnr.tif"
n_test_images = 5
# load the data
test_set = tifffile.imread(lowsnr_path)
test_set = test_set[:n_test_images, np.newaxis]
test_set = torch.from_numpy(test_set)
test_set = test_set.to(torch.float32)

As with training, data should be a `torch.Tensor` with dimensions: [Number of images, Channels, Z | Y | X] with data type float32.
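The same layout rule covers 1D, 2D and 3D data. As a minimal sketch, assuming a single-channel NumPy stack (`to_cosdd_tensor` is a hypothetical helper), the channel axis is inserted at dimension 1 and the values are cast to float32:

```python
import numpy as np
import torch


def to_cosdd_tensor(stack: np.ndarray) -> torch.Tensor:
    """Turn a single-channel stack [N, X], [N, Y, X] or [N, Z, Y, X] into a
    float32 tensor with a channel axis: [N, 1, ...]."""
    if stack.ndim not in (2, 3, 4):
        raise ValueError(f"expected N plus 1-3 spatial dims, got shape {stack.shape}")
    # Cast in NumPy first so integer types such as uint16 never reach torch.from_numpy
    return torch.from_numpy(np.ascontiguousarray(stack, dtype=np.float32)).unsqueeze(1)


print(to_cosdd_tensor(np.zeros((5, 64), dtype=np.uint16)).shape)         # torch.Size([5, 1, 64])
print(to_cosdd_tensor(np.zeros((5, 32, 32), dtype=np.uint16)).shape)     # torch.Size([5, 1, 32, 32])
print(to_cosdd_tensor(np.zeros((5, 8, 32, 32), dtype=np.uint16)).shape)  # torch.Size([5, 1, 8, 32, 32])
```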

### 2.2. Create prediction dataloader

`predict_batch_size` (int): Number of denoised images to produce at a time.

In [None]:
predict_batch_size = 1

predict_set = utils.PredictDataset(test_set)
predict_loader = torch.utils.data.DataLoader(
    predict_set,
    batch_size=predict_batch_size,
    shuffle=False,
    pin_memory=True,
)

### 2.3. Load trained model

<div class="alert alert-info">

### Task 2.1.

Our model was only trained for 10 minutes. This is long enough to get some denoising results, but a model trained for longer would do better. In the cell below, load the trained model by recalling the value you gave for `model_name`. Then proceed through the notebook to see how well it performs.

Once you reach the end of the notebook, return to this cell and uncomment line 4 (`checkpoint_path = "checkpoints/mito-confocal-pretrained"`) to load a model that has been trained for 3.5 hours, then run the notebook again to see how much difference the extra training time makes.
</div>

In [None]:
model_name = ...   ### Insert a string here
checkpoint_path = os.path.join("checkpoints", model_name)

# checkpoint_path = "checkpoints/mito-confocal-pretrained" ### Once you reach the bottom of the notebook, return here and uncomment this line to see the pretrained model

hub = Hub.load_from_checkpoint(os.path.join(checkpoint_path, "final_model.ckpt"))

predictor = pl.Trainer(
    accelerator="gpu",
    devices=1,
    enable_progress_bar=False,
    enable_checkpointing=False,
    logger=False,
)

### 2.4. Denoise
In this section, we will look at how COSDD does inference. <br>

The model denoises images randomly, giving us a different output each time. First, we will compare seven randomly sampled denoised images for the same noisy image. Then, we will produce a single consensus estimate by averaging 100 randomly sampled denoised images. Finally, if the direct denoiser was trained in the previous step, we will see how it can be used to estimate this average in a single pass.

### 2.4.1 Random sampling 
First, we will denoise each image seven times and look at the difference between each estimate. The output of the model is stored in the `samples` variable. This has dimensions [Number of images, Sample index, Channels, Z | Y | X] where different denoised samples for the same image are stored along sample index.
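The bookkeeping in the sampling cells comes down to `torch.stack` along a new sample axis. The sketch below uses random tensors as stand-ins for the per-run outputs of `predictor.predict` (which the cells concatenate into one [Number of images, Channels, Y, X] tensor per run), and also shows the mean over the sample axis used for the MMSE in Section 2.4.2.

```python
import torch

n_images, n_channels, height, width = 5, 1, 64, 64
n_samples = 7

# Stand-ins for n_samples independent denoising runs over the same images
runs = [torch.rand(n_images, n_channels, height, width) for _ in range(n_samples)]

samples = torch.stack(runs, dim=1)  # [images, samples, channels, Y, X]
mmse = samples.mean(dim=1)          # average over the sample axis

print(samples.shape)  # torch.Size([5, 7, 1, 64, 64])
print(mmse.shape)     # torch.Size([5, 1, 64, 64])
```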

In [None]:
use_direct_denoiser = False
n_samples = 7

hub.direct_pred = use_direct_denoiser
samples = []
for _ in tqdm(range(n_samples)):
    out = predictor.predict(hub, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1).half()

<div class="alert alert-info">

### Task 2.2.

Here, we'll look at the original noisy image and the seven denoised estimates. Change the value of `img_idx` to look at different images and change the values of `top`, `bottom`, `left` and `right` to adjust the crop. Use this section to really explore the results. Compare high-intensity regions to low-intensity regions, zoom in and out, and spot the differences between the samples.
</div>

In [None]:
vmin = np.percentile(test_set.numpy(), 1)
vmax = np.percentile(test_set.numpy(), 99)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(2, 4, figsize=(16, 8))
ax[0, 0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0, 0].set_title("Input")
for i in range(n_samples):
    ax[(i + 1) // 4, (i + 1) % 4].imshow(
        samples[img_idx][i][crop], vmin=vmin, vmax=vmax
    )
    ax[(i + 1) // 4, (i + 1) % 4].set_title(f"Sample {i+1}")

plt.show()

The seven sampled denoised images have subtle differences that express the uncertainty involved in this denoising problem.

### 2.4.2 MMSE estimate

In the next cell, we sample many denoised images and average them to obtain the minimum mean squared error (MMSE) estimate. The averaged images will be stored in the `MMSEs` variable, which has the same dimensions as `test_set`.

<div class="alert alert-info">

### Task 2.3.
Set `n_samples` to 100 to average 100 images, or to a different value to average a different number. Then visually inspect the results. Examine how the MMSE result differs from the random sample.
</div>

In [None]:
use_direct_denoiser = False
n_samples = ...   ### Insert an integer here

hub.direct_pred = use_direct_denoiser

samples = []
for _ in tqdm(range(n_samples)):
    out = predictor.predict(hub, predict_loader)
    out = torch.cat(out, dim=0)
    samples.append(out)

samples = torch.stack(samples, dim=1).half()
MMSEs = torch.mean(samples, dim=1)

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(samples[img_idx][0][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Sample")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")

plt.show()

The MMSE will usually be closer to the reference than an individual sample and would score a higher PSNR, although it will also be blurrier.
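A toy NumPy experiment shows why averaging raises PSNR. It assumes, purely for illustration, that samples scatter independently and without bias around a known clean image; real COSDD samples also differ in structure, which is one reason the real MMSE can look blurrier. `psnr` is the standard definition written out here, not a function from this exercise's code.

```python
import numpy as np


def psnr(reference: np.ndarray, estimate: np.ndarray, data_range: float) -> float:
    """Peak signal-to-noise ratio in dB."""
    mse = np.mean((reference - estimate) ** 2)
    return float(10 * np.log10(data_range**2 / mse))


rng = np.random.default_rng(0)
clean = rng.uniform(0.0, 1.0, size=(64, 64))
# Toy "samples": the clean image plus independent per-sample deviations
samples = clean[None] + rng.normal(scale=0.1, size=(100, 64, 64))
mmse = samples.mean(axis=0)

print(f"single sample: {psnr(clean, samples[0], 1.0):.1f} dB")  # roughly 20 dB
print(f"MMSE of 100:   {psnr(clean, mmse, 1.0):.1f} dB")        # roughly 40 dB
```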

### 2.4.3 Direct denoising
Sampling 100 images and averaging them is very time-consuming. If the direct denoiser was trained in a previous step, it can be used to directly output what the average denoised image would be for a given noisy image.

<div class="alert alert-info">

### Task 2.4.

Did you enable the direct denoiser when training your model in Task 1.3? If so, set `use_direct_denoiser` to `True` to use the Direct Denoiser for inference. If not, go back to section 2.3 to load the pretrained model and return here.

Notice how much quicker the direct denoiser is than generating the MMSE results. Visually inspect and explore the results in the same way as before, and notice how similar the direct and MMSE estimates are.
</div>

In [None]:
use_direct_denoiser = ...   ### Insert a boolean here
hub.direct_pred = use_direct_denoiser

direct = predictor.predict(hub, predict_loader)
direct = torch.cat(direct, dim=0).half()

In [None]:
img_idx = 0
top = 0
bottom = 1024
left = 0
right = 1024

crop = (0, slice(top, bottom), slice(left, right))

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(test_set[img_idx][crop], vmin=vmin, vmax=vmax)
ax[0].set_title("Input")
ax[1].imshow(direct[img_idx][crop], vmin=vmin, vmax=vmax)
ax[1].set_title("Direct")
ax[2].imshow(MMSEs[img_idx][crop], vmin=vmin, vmax=vmax)
ax[2].set_title("MMSE")

plt.show()

### 2.5. Incorrect receptive field

We've now trained a model and used it to remove structured noise from our data. Before moving on to the next notebook, we'll look at what happens when a COSDD model is trained without considering the noise structures present.

COSDD is able to separate imaging noise from clean signal because its autoregressive decoder has a receptive field that spans pixels containing correlated noise, i.e., the row or column of pixels. If its receptive field did not contain pixels with correlated noise, it would not be able to model them and they would be captured by the VAE's latent variables. To demonstrate this, the image below shows a Direct and MMSE estimate of a denoised image where the autoregressive decoder's receptive field was incorrectly set to vertical, leaving it unable to model horizontal noise.

<img src="./resources/penicillium_ynm.png">

<div class="alert alert-success">

## Checkpoint 3

We've completed the process of training and applying a COSDD model for denoising, but there's still more it can do. Optionally continue to the bonus notebook, bonus-exercise-generation.ipynb, to see how the model of the data can be used to generate new clean and noisy images.

Otherwise, continue to 04_DenoiSplit.

</div>