# <nobr>Micro$\mathbb{S}$plit</nobr>: Semantic Unmixing of Fluorescent Microscopy Data

In this notebook, we tackle the problem of joint splitting and unsupervised denoising, which has a use case in fluorescence microscopy. From a technical perspective, given a noisy image $x$, the goal is to predict two images $c_1$ and $c_2$ such that $x = c_1 + c_2 + n$, where $n$ is the noise in $x$. In other words, we have a superimposed image $x$ and we want to predict denoised estimates of its constituent images $c_1$ and $c_2$. Note that the network is trained with noisy data, and the denoising is done in an unsupervised manner.

For this, we will use [denoiSplit](https://arxiv.org/pdf/2403.11854.pdf), a recently developed approach for this task. In this notebook we train denoiSplit and later evaluate it on one validation frame. The overall schema for denoiSplit is shown below:
<img src="imgs/teaser.png" alt="drawing" width="800"/>

<div class="alert alert-danger">
Set your python kernel to <code>05_image_restoration</code>
</div>

In [None]:
from functools import partial
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from careamics.lightning import VAEModule
from pytorch_lightning import Trainer
from torch.utils.data import DataLoader

from microsplit_reproducibility.configs.data.custom_dataset_2D import get_data_configs
from microsplit_reproducibility.configs.factory import (
    create_algorithm_config,
    get_likelihood_config,
    get_loss_config,
    get_model_config,
    get_optimizer_config,
    get_training_config,
    get_lr_scheduler_config,
)
from microsplit_reproducibility.configs.parameters.custom_dataset_2D import get_microsplit_parameters
from microsplit_reproducibility.datasets import create_train_val_datasets
from microsplit_reproducibility.utils.callbacks import get_callbacks
from microsplit_reproducibility.utils.io import load_checkpoint_path
from microsplit_reproducibility.utils.utils import plot_input_patches
from microsplit_reproducibility.notebook_utils.custom_dataset_2D import (
    get_unnormalized_predictions,
    get_target,
    get_input,
    full_frame_evaluation,
    load_pretrained_model
)

from utils import get_train_val_data, compute_metrics

%matplotlib inline

In [None]:
assert torch.cuda.is_available()
torch.set_float32_matmul_precision('medium')

# **Exercise 1**: Training MicroSplit
Training is done in a supervised way: for every input patch, we have the corresponding target patches (one per structure), which we use to train MicroSplit.
Besides the primary input patch, we also feed Lateral Contextualization (LC) inputs to MicroSplit. We introduced LC inputs in [μSplit: efficient image decomposition for microscopy data](https://openaccess.thecvf.com/content/ICCV2023/papers/Ashesh_uSplit_Image_Decomposition_for_Fluorescence_Microscopy_ICCV_2023_paper.pdf), where they enabled the network to capture the global spatial context around the input patch.
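
A simplified sketch of how LC inputs can be assembled, assuming (as we understand the μSplit design) that each additional input is a progressively larger crop centered on the primary patch and downsampled back to the patch size. The function `make_lc_inputs` is hypothetical and is not the `microsplit_reproducibility` implementation; reflection padding and average pooling are illustrative choices.

```python
import numpy as np


def make_lc_inputs(image: np.ndarray, y: int, x: int, patch: int = 64, levels: int = 3) -> np.ndarray:
    """Hypothetical LC stack of shape (levels, patch, patch).

    Level k covers a (patch * 2**k)-wide window centered on the primary patch
    whose top-left corner is (y, x), average-pooled back to patch x patch.
    """
    max_size = patch * 2 ** (levels - 1)
    pad = (max_size - patch) // 2
    # Reflection padding keeps the largest windows in bounds near the image border.
    padded = np.pad(image, pad, mode="reflect")
    cy, cx = y + pad + patch // 2, x + pad + patch // 2  # patch center in padded coordinates
    stack = []
    for k in range(levels):
        size = patch * 2 ** k
        window = padded[cy - size // 2 : cy + size // 2, cx - size // 2 : cx + size // 2]
        f = 2 ** k  # downsampling factor for this level
        stack.append(window.reshape(patch, f, patch, f).mean(axis=(1, 3)))
    return np.stack(stack)
```

Level 0 is the primary patch itself; each further level trades resolution for a wider field of view while keeping the same input size.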

To enable unsupervised denoising, we integrated the KL loss formulation and noise models from our previous work [denoiSplit: a method for joint microscopy image splitting and unsupervised denoising](https://eccv.ecva.net/virtual/2024/poster/2538).

The loss function for MicroSplit is a weighted average of the denoiSplit loss and the μSplit loss. Each of these losses has two terms: a KL divergence loss and a likelihood loss. For more details, please refer to the respective papers.
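
As a rough illustration of how the two objectives can be combined, here is a sketch in which each loss is a negative log-likelihood plus a KL term, and the total is a convex combination controlled by a weight. The names `vae_loss`, `combined_loss`, `weight_denoisplit`, and `beta` are hypothetical; the actual likelihoods (noise-model based for denoiSplit, Gaussian for μSplit) and weights are set by the configs created later in this notebook.

```python
def vae_loss(likelihood_nll: float, kl: float, beta: float = 1.0) -> float:
    """One VAE-style objective: negative log-likelihood plus a beta-weighted KL divergence."""
    return likelihood_nll + beta * kl


def combined_loss(denoisplit_nll: float, denoisplit_kl: float,
                  musplit_nll: float, musplit_kl: float,
                  weight_denoisplit: float) -> float:
    """Hypothetical weighted average of the denoiSplit and muSplit objectives."""
    if not 0.0 <= weight_denoisplit <= 1.0:
        raise ValueError("weight_denoisplit must be in [0, 1]")
    l_denoisplit = vae_loss(denoisplit_nll, denoisplit_kl)
    l_musplit = vae_loss(musplit_nll, musplit_kl)
    return weight_denoisplit * l_denoisplit + (1.0 - weight_denoisplit) * l_musplit


print(combined_loss(1.0, 0.5, 2.0, 1.0, weight_denoisplit=0.5))  # 2.25
```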

## 1.1. Data Preparation

Since the channel unmixing capabilities of <nobr>Micro$\mathbb{S}$plit</nobr> are trained in a supervised way, we must feed it *(i)* input images that contain all the selected structures, and *(ii)* separate channels that show each of these structures individually. As previously mentioned, the mixed input image is obtained synthetically by superimposing (summing) these separate channels.
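
The synthetic mixing described above amounts to a pixel-wise sum over the target channels. A minimal NumPy sketch, assuming the channels are stacked along the last axis as `(H, W, C)` (the layout the visualization code later in this notebook indexes into); `make_superimposed_input` is a hypothetical helper:

```python
import numpy as np


def make_superimposed_input(targets: np.ndarray) -> np.ndarray:
    """Sum the separate structure channels (last axis) into one mixed input image."""
    return targets.sum(axis=-1)


rng = np.random.default_rng(0)
targets = rng.poisson(lam=5.0, size=(128, 128, 2)).astype(np.float32)  # two toy channels
mixed = make_superimposed_input(targets)
print(mixed.shape)  # (128, 128)
```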

In this exercise, we will train a <nobr>Micro$\mathbb{S}$plit</nobr> network for unmixing superimposed channels from a dataset imaged at the National Facility for Light Imaging at Human Technopole.

This dataset contains four labeled structures: 
1. Cell Nuclei,
1. Microtubules,
1. Nuclear Membrane,
1. Centromeres/Kinetochores.

Additionally, this dataset offers acquisitions taken with different exposure times **(2, 20, 500 ms)**. Hence, the data is available at various [signal-to-noise ratios](https://en.wikipedia.org/wiki/Signal-to-noise_ratio#:~:text=Signal%2Dto%2Dnoise%20ratio%20(,power%2C%20often%20expressed%20in%20decibels.)) (SNR). Shorter exposure times entail the collection of fewer photons, leading to a higher relative level of *Poisson shot noise* and, therefore, a lower SNR.
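
Why fewer photons means lower SNR: for Poisson-distributed photon counts with mean $\lambda$, the standard deviation is $\sqrt{\lambda}$, so $\mathrm{SNR} = \lambda / \sqrt{\lambda} = \sqrt{\lambda}$. The sketch below assumes, for illustration only, that the expected photon count scales linearly with exposure time; `photons_per_ms` is a made-up value, not a property of this dataset.

```python
import numpy as np


def empirical_snr(mean_photons: float, n: int = 200_000, seed: int = 0) -> float:
    """Estimate the SNR (mean / std) of Poisson shot noise at a given expected photon count."""
    rng = np.random.default_rng(seed)
    counts = rng.poisson(mean_photons, size=n)
    return counts.mean() / counts.std()


photons_per_ms = 0.5  # hypothetical photon rate, for illustration only
for exposure_ms in (2, 20, 500):
    lam = photons_per_ms * exposure_ms
    # empirical SNR next to the theoretical sqrt(lambda)
    print(exposure_ms, round(empirical_snr(lam), 2), round(float(np.sqrt(lam)), 2))
```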

<div class="alert alert-info"><h4><b>Task 1.1.</b></h4>

In the following, you will be prompted to select:
1. The labeled structures to unmix;
2. The exposure time (and, thus, the SNR) of the input superimposed images.

Observe that:
- The more structures you pick to unmix, the more difficult the task becomes. Therefore, we suggest starting with the easier 2-structure unmixing and trying 3- and 4-structure unmixing later on.
- The lower the SNR of the data you choose to train <nobr>Micro$\mathbb{S}$plit</nobr> with, the more important its unsupervised denoising becomes.

You can play with these parameters and check MicroSplit's performance with different combinations.
</div>

A few notes:
- MicroSplit is trained on `(64, 64)` patches, mainly because of GPU memory constraints. Training on full images would require reducing the batch size, which, in our experiments, has been shown to hurt model performance.
- ...

In [None]:
# pick structures and exposure time
STRUCTURES = ["Nuclei", "Microtubules"] # choose among "Nuclei", "Microtubules", "NucMembranes", "Centromeres"
EXPOSURE_TIME = 500 # in ms, choose among 2, 20, 500 ms

assert EXPOSURE_TIME in [2, 20, 500], "Exposure time must be one of [2, 20, 500] ms"
assert all([
    s in ["Nuclei", "Microtubules", "NucMembranes", "Centromeres"] for s in STRUCTURES
]), "Invalid structure selected. Choose among 'Nuclei', 'Microtubules', 'NucMembranes', 'Centromeres'."

Custom function for loading the data:

In [None]:
load_data_func = partial(get_train_val_data, structures=STRUCTURES)
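
`functools.partial` pre-binds the `structures` keyword, so the dataset factory can later call `load_data_func` without knowing which structures were selected. A toy illustration with a hypothetical loader (`toy_loader` and its arguments are made up; the full signature of `get_train_val_data` is not shown here):

```python
from functools import partial


def toy_loader(datapath, datasplit_type, structures):
    """Hypothetical stand-in for a data-loading function."""
    return f"{datapath}:{datasplit_type}:{'+'.join(structures)}"


load = partial(toy_loader, structures=["Nuclei", "Microtubules"])
print(load("data/500ms", "train"))  # data/500ms:train:Nuclei+Microtubules
```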

In [None]:
datapath = Path(f"/group/jug/federico/data/MBL_course/{EXPOSURE_TIME}ms") # FIXME

In [None]:
train_data_config, val_data_config, test_data_config = get_data_configs(
    image_size=(64, 64),
    num_channels=len(STRUCTURES),
) # TODO: multiscale count hardcoded in here!!! Define experiment parameters before

In [None]:
# create the dataset
train_dset, val_dset, test_dset, data_stats = create_train_val_datasets(
    datapath=datapath,
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
    load_data_func=load_data_func,
)

In [None]:
train_dloader = DataLoader(
    train_dset,
    batch_size=32,
    num_workers=3,
    shuffle=True,
)
val_dloader = DataLoader(
    val_dset,
    batch_size=32,
    num_workers=3,
    shuffle=False,
)

<div class="alert alert-info"><h4><b>Check some training data patches!</b></h4>

***Tip:*** the following function shows a few samples of the prepared training data. If you don't like what you see (empty or noisy patches), execute the cell again: different randomly chosen patches will be shown!

</div>

<div class="alert alert-warning"><h4><b>Question 1.1.</b></h4>

- Can you tell what is the role of the different patches shown below?
- What do the input patches show? Why are there multiple inputs?
- Why do we need targets? How do we use such targets? 

*Answers*
- The first columns contain, respectively, the superimposed input patch and the additional input patches for Lateral Contextualization (LC). The remaining columns show the target unmixed patches.
- Input patches represent the image obtained by superimposing (mixing) the signal coming from different labeled structures. The additional LC inputs are used to enhance the field of view and, hence, the semantic context processed by the network.
- For this task we need unmixing targets as we are doing Supervised Learning.

</div>

In [None]:
plot_input_patches(dataset=train_dset, num_channels=len(STRUCTURES), num_samples=3, patch_size=64)

<div class="alert alert-warning"><h4><b>Question 1.1.bis</b></h4>

Below are 2 examples of superimposed labeled structures with the corresponding ground truths.
1. Which one do you think is harder to unmix? Why?
2. In your opinion, which features of the input data would make unmixing more difficult?

*Answers*
1. (b), because it shows more morphologically similar structures. MicroSplit is a content-aware method, i.e., it extracts semantic information regarding morphology, shape, brightness, etc., from the input data. Since structurally similar signals share many semantic features, the unmixing task becomes more challenging.
2. Semantic similarity between labeled structures, differences in brightness/intensity between labeled structures, colocalization, ...
</div>

<div class="alert alert-success"><h2><b>Checkpoint 1: Data Preparation</b></h2>
</div>

<hr style="height:2px;">

## 1.2. Setup <nobr>Micro$\mathbb{S}$plit</nobr> for training
In this section, we create all the configs for the upcoming model initialization and training run. Configs group all related parameters (architecture, training, loss, etc.) in one place and automatically validate them, preventing the user from entering invalid combinations.
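
As a toy illustration of the kind of validation a config performs, here is a standard-library dataclass that rejects invalid values at construction time. `ToyTrainingConfig`, its fields, and the even-patch-size rule are hypothetical; to our understanding, CAREamics builds its configs on Pydantic models, which validate in a similar spirit but with a different API.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTrainingConfig:
    """Hypothetical config: values are checked as soon as the object is created."""
    batch_size: int
    num_epochs: int
    patch_size: int = 64

    def __post_init__(self):
        if self.batch_size <= 0 or self.num_epochs <= 0:
            raise ValueError("batch_size and num_epochs must be positive")
        if self.patch_size % 2 != 0:
            raise ValueError("patch_size must be even")


ok = ToyTrainingConfig(batch_size=32, num_epochs=10)
try:
    ToyTrainingConfig(batch_size=0, num_epochs=10)
except ValueError as err:
    print("Rejected:", err)
```

Failing early, at config creation, is cheaper than discovering a bad combination halfway through a training run.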

Notice that <nobr>Micro$\mathbb{S}$plit</nobr> is implemented in the CAREamics library, so the API is quite similar to the one you (perhaps) saw for Noise2Void.

TODO: add break down of parameters

In [None]:
NM_PATH = f"/group/jug/federico/data/MBL_course/noise_models/{EXPOSURE_TIME}ms"
"""The path to the noise models to load, if any."""

In [None]:
BATCH_SIZE = 32
"""The batch size for training."""
EPOCHS = 10
"""The number of epochs to train the network."""

In [None]:
# setting up MicroSplit parametrization
experiment_params = get_microsplit_parameters(
    algorithm="denoisplit",
    img_size=(64, 64),
    batch_size=BATCH_SIZE,
    num_epochs=EPOCHS,
    multiscale_count=3,
    noise_model_path=NM_PATH,
    target_channels=len(STRUCTURES),
) # TODO: use SplittingParameters class instead, check which parameters to expose
# add data statistics that will be used for data standardization
experiment_params["data_stats"] = data_stats

In [None]:
# setting up training losses and model config (using default parameters)
loss_config = get_loss_config(**experiment_params)
model_config = get_model_config(**experiment_params)
gaussian_lik_config, noise_model_config, nm_lik_config = get_likelihood_config(
    **experiment_params
)
training_config = get_training_config(**experiment_params)

# setting up learning rate scheduler and optimizer (using default parameters)
lr_scheduler_config = get_lr_scheduler_config(**experiment_params)
optimizer_config = get_optimizer_config(**experiment_params)

# finally, assemble the full set of experiment configurations...
experiment_config = create_algorithm_config(
    algorithm=experiment_params["algorithm"],
    loss_config=loss_config,
    model_config=model_config,
    gaussian_lik_config=gaussian_lik_config,
    nm_config=noise_model_config,
    nm_lik_config=nm_lik_config,
    lr_scheduler_config=lr_scheduler_config,
    optimizer_config=optimizer_config,
)

In [None]:
model = VAEModule(algorithm_config=experiment_config)

## 1.3. Train MicroSplit model

In this section, we will train our MicroSplit model using `lightning`.

In [None]:
# TODO: exercise on callbacks?

In [None]:
# create the Trainer
trainer = Trainer(
    max_epochs=training_config.num_epochs,
    accelerator="gpu",
    enable_progress_bar=True,
    callbacks=get_callbacks("./checkpoints/"),
    precision=training_config.precision,
    gradient_clip_val=training_config.gradient_clip_val,
    gradient_clip_algorithm=training_config.gradient_clip_algorithm,
)

In [None]:
# start the training
trainer.fit(
    model=model,
    train_dataloaders=train_dloader,
    val_dataloaders=val_dloader,
)

## 1.4 Visualize predictions on validation data

In order to check that the training process has been successful, we check MicroSplit predictions on the validation set.

<div class="alert alert-warning"><h4><b>Question 1.4.</b></h4>

A proper evaluation, including predictions on multiple images and the computation of performance metrics, will be performed later on the test data.
Do you remember the limitations of evaluating a model's performance on the validation set instead?

</div>

Before proceeding with the evaluation, let's focus once more on how <nobr>Micro$\mathbb{S}$plit</nobr> works.

As we mentioned, <nobr>Micro$\mathbb{S}$plit</nobr> uses a modified version of a Hierarchical Variational Autoencoder (HVAE), similar to COSDD and other models you encountered during the course. Given an input patch, this architecture can generate multiple outputs. Technically, this happens by sampling multiple different *latent vectors* in the latent space. In mathematical terms, we say that "*<nobr>Micro$\mathbb{S}$plit</nobr> is learning a full posterior of possible solutions*".

This is a cool feature that makes our variational models pretty powerful and handy! Indeed, averaging multiple samples (predictions) generally yields smoother, more consistent predictions (in other terms, it partly averages out potential "hallucinations" of the network). Moreover, by computing the pixel-wise standard deviation over multiple samples, we can obtain a preliminary estimate of the (data) uncertainty in the model's predictions.
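
The averaging and the uncertainty estimate described above reduce to a mean and a standard deviation over the sample axis. A minimal NumPy sketch, where `samples` is a hypothetical stack of `MMSE_COUNT` predictions for one patch with shape `(n_samples, H, W, C)`, filled here with toy random values:

```python
from typing import Tuple

import numpy as np


def mmse_and_std(samples: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Pixel-wise mean (MMSE estimate) and standard deviation over the sample axis."""
    return samples.mean(axis=0), samples.std(axis=0)


rng = np.random.default_rng(0)
samples = rng.normal(loc=1.0, scale=0.2, size=(10, 64, 64, 2))  # 10 toy posterior samples
mmse, std = mmse_and_std(samples)
print(mmse.shape, std.shape)  # (64, 64, 2) (64, 64, 2)
```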

In this framework, the parameter `mmse_count` (int) determines the number of samples (predictions) generated for any given input patch. A larger value yields smoother predictions and also limits recurring issues such as *tiling artefacts*. However, it also increases the time and cost of the computation. Generally, a value above 5 is enough to get decently smooth predicted frames. For reference, in our papers we often use values of 50 to get the best results.
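
Why do returns diminish as `mmse_count` grows? If the samples were independent draws with per-pixel spread $\sigma$, the spread of their average would shrink as $\sigma/\sqrt{N}$. The toy experiment below assumes exactly such i.i.d. Gaussian samples (real posterior samples are not guaranteed to behave this way) and compares the empirical spread of the $N$-sample mean with $1/\sqrt{N}$:

```python
import numpy as np


def spread_of_mean(n_samples: int, sigma: float = 1.0, trials: int = 20_000, seed: int = 0) -> float:
    """Std of the average of n_samples i.i.d. Gaussian draws, estimated over many trials."""
    rng = np.random.default_rng(seed)
    draws = rng.normal(0.0, sigma, size=(trials, n_samples))
    return float(draws.mean(axis=1).std())


for n in (1, 5, 10, 50):
    print(n, round(spread_of_mean(n), 3), round(1 / np.sqrt(n), 3))
```

Under this assumption, going from 10 to 50 samples costs 5 times more compute but reduces the residual spread only by a factor of about 2.2.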

In [None]:
MMSE_COUNT = 10
"""The number of MMSE samples to use for the splitting predictions."""

In [None]:
# Reduce the validation dataset to a single frame (index 0) to speed up prediction
val_dset.reduce_data([0])

# Get patch predictions for the validation dataset + stitching into full images + de-normalization
stitched_predictions, _, _ = get_unnormalized_predictions(
    model, val_dset, mmse_count=MMSE_COUNT, num_workers=0, batch_size=8
)

In [None]:
# get the target and input from the validation dataset for visualization purposes
tar = get_target(val_dset)
inp = get_input(val_dset).sum(-1)

In [None]:
_, ax = plt.subplots(2, 2, figsize=(20, 20))
ax[0, 0].imshow(tar[0, ..., 0], cmap="gray")
ax[0, 0].set_title("Input ch1")
ax[0, 1].imshow(tar[0, ..., 1], cmap="gray")
ax[0, 1].set_title("Input ch2")
ax[1, 0].imshow(stitched_predictions[0, ..., 0], cmap="gray")
ax[1, 0].set_title("Prediction ch1")
ax[1, 1].imshow(stitched_predictions[0, ..., 1], cmap="gray")
ax[1, 1].set_title("Prediction ch2")

In [None]:
frame_idx = 0
assert frame_idx < len(stitched_predictions), f"Frame index {frame_idx} out of bounds. Max index is {len(stitched_predictions) - 1}."

full_frame_evaluation(stitched_predictions[frame_idx], tar[frame_idx], inp[frame_idx])

<div class="alert alert-success"><h2><b>Checkpoint 2: Model Training</b></h2>
</div>

<hr style="height:2px;">

# **Exercise 2**: Evaluating MicroSplit performance

So far, you have trained MicroSplit and had a first qualitative evaluation on the validation set. However, at this point of the course you should be familiar with the idea that a proper evaluation should be carried out on a held-out test set, which the model has not seen during any part of the training process. In this section, we perform the evaluation on the test set, which includes a further qualitative inspection of the predicted images and a quantitative evaluation using adequate metrics to measure the model's performance.

Recall that, for this task, we cannot feed the entire image to <nobr>Micro$\mathbb{S}$plit</nobr> on a standard GPU. Hence, we process smaller chunks of the full image, which we have so far called **patches**. At training time, these patches are usually obtained as random crops from the full input images, since random cropping acts as a kind of ***data augmentation***. At test time, however, we want our predictions to cover the full images, so we need a more "organized" strategy to obtain the patches. One option is to divide the full frames into an ordered grid of patches. In our paper, we call this process ***tiling*** and the individual crops ***tiles***, to differentiate them from the patches we use for training.

A recurring issue in ***tiled prediction*** is the possible presence of so-called ***tiling artefacts***, which originate from inconsistencies and mismatches at the borders of neighboring tiles (see (c) - No padding in the figure below). This problem can be alleviated by ***padding*** the input tile and later discarding the padded area when stitching the predictions. The idea is to introduce some overlap between neighboring tiles to obtain a smoother transition between them. Common padding strategies are:
- ***Outer padding***: the patch (tile) size used for training (e.g., `(64, 64)`) is padded to a larger size. Then, the padded area is discarded during stitching.
- ***Inner padding***: the patch (tile) size used for training (e.g., `(64, 64)`) is used as input for prediction. Then, only the inner part of it is kept during stitching.

In our work we use ***Inner padding*** as it preserves the same field of view the network has seen during training and empirically provides better performance on our task (see (b)).
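
A simplified sketch of inner-padding tiled prediction, assuming a single-channel 2D image, `(64, 64)` tiles, and a keep-region of size `inner` centered in each tile. `tiled_predict_inner` and `predict_fn` (a stand-in for the network) are hypothetical; reflection-padding the full image so that border tiles also get context is our illustrative choice, not necessarily what the library does.

```python
import numpy as np


def tiled_predict_inner(image, predict_fn, tile=64, inner=32):
    """Predict on overlapping tile x tile windows and keep only the central inner x inner block of each."""
    assert (tile - inner) % 2 == 0, "tile and inner must have the same parity"
    margin = (tile - inner) // 2
    h, w = image.shape
    # Extra padding so the grid of inner blocks covers the whole image.
    pad_h, pad_w = (-h) % inner, (-w) % inner
    padded = np.pad(image, ((margin, margin + pad_h), (margin, margin + pad_w)), mode="reflect")
    out = np.zeros((h + pad_h, w + pad_w), dtype=float)
    for y in range(0, h + pad_h, inner):
        for x in range(0, w + pad_w, inner):
            pred = predict_fn(padded[y : y + tile, x : x + tile])  # full-size tile, as in training
            out[y : y + inner, x : x + inner] = pred[margin : margin + inner, margin : margin + inner]
    return out[:h, :w]
```

With an identity `predict_fn`, the stitched output reproduces the input exactly, which is a quick sanity check that every pixel is covered once. Setting `inner=tile` turns this into plain non-overlapping tiling.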


![InnerAndOuterPadding.png](attachment:dbe65ad8-8c38-45bd-8301-a7e574a84cf2.png)

## 2.1. Compute MicroSplit predictions on the test set

<div class="alert alert-info"><h4><b>(Optional) Task 2.1.1: Load checkpoint</b></h4>

In case you had any trouble while executing the notebook (disconnection, dead kernel, ...), you can avoid retraining MicroSplit from scratch and load a saved checkpoint instead.
Here you can choose between a model pre-trained by us or the one you trained previously, if any.

Similarly, if you are not satisfied with your trained model, you can try the one pre-trained by us. However, we strongly suggest you first try with yours to identify pitfalls, and then resort to ours to check how close you got.

</div>

In [None]:
# Recursively search for .ckpt files in 'checkpoints' folder
ckpt_folder = Path("./checkpoints")
ckpt_folders = set()
for file in ckpt_folder.rglob("*.ckpt"):
    ckpt_folders.add(file.parent)
ckpt_folders = sorted(ckpt_folders)


def list_available_model_checkpoint_folders():
    print("These models you have trained have been found:")
    if len(ckpt_folders) == 0:
        print(" ❌ None!")
    else:
        for file in ckpt_folders:
            print(" 🟢", file)

In [None]:
list_available_model_checkpoint_folders()

In [None]:
ckpt_folder = "checkpoints"
selected_ckpt = load_checkpoint_path(str(ckpt_folder), best=True)
print("✅ Selected model checkpoint:", selected_ckpt)

In [None]:
load_pretrained_model(model, selected_ckpt)

<div class="alert alert-info"><h4><b>Task 2.1.2: Get test set predictions</b></h4>

Here we reuse the `get_unnormalized_predictions` function you saw before to get the unmixed predicted images for the test set.
You will have to:
- Set the `MMSE_COUNT` parameter, finding an appropriate trade-off between prediction quality (remember the tiling artefacts we discussed above) and computation time. Given our time constraints, a reasonable range to try is `[2, 20]`.
- Set the `INNER_TILE_SIZE` parameter, trying different values for inner padding. Notice that a smaller `INNER_TILE_SIZE` entails larger padding/overlap between neighboring patches and, hence, more predictions to compute. A reasonable range to try is `[16, 64]`, where `64` means that no padding is done (recall that we used a patch size of `64`).

</div>
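
To get a feeling for the cost trade-off before running the next cells, here is a back-of-the-envelope estimate. It assumes one tile per `INNER_TILE_SIZE`-sized block of the image and one forward pass per tile per MMSE sample; `num_forward_passes` is a hypothetical helper and the `2048 x 2048` image size is a made-up example, not the size of this dataset.

```python
import math


def num_forward_passes(height: int, width: int, inner_tile: int, mmse_count: int) -> int:
    """Tiles needed to cover the image with inner_tile-sized blocks, times the number of MMSE samples."""
    n_tiles = math.ceil(height / inner_tile) * math.ceil(width / inner_tile)
    return n_tiles * mmse_count


for inner in (64, 32, 16):
    print(inner, num_forward_passes(2048, 2048, inner, mmse_count=10))
```

Under these assumptions, halving `INNER_TILE_SIZE` quadruples the number of forward passes, while the cost grows only linearly with `MMSE_COUNT`.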


In [None]:
MMSE_COUNT = 10
"""The number of MMSE samples to use for the splitting predictions."""
INNER_TILE_SIZE = 32
"""The inner tile size considered for the predictions."""

In [None]:
stitched_predictions, _, stitched_stds = (
    get_unnormalized_predictions(
        model,
        test_dset,
        mmse_count=MMSE_COUNT,
        grid_size=INNER_TILE_SIZE,
        num_workers=4,
        batch_size=8,
    )
)

***NOTE***: you might have noticed that the function also returns `stitched_stds`. These are the pixel-wise standard deviations over the `MMSE_COUNT` samples for each image (yes, these have also been stitched back into full images)!

<div class="alert alert-success"><h2><b>Checkpoint 3: Test set predictions</b></h2>
</div>

<hr style="height:2px;">

## 2.2. Qualitative evaluation of MicroSplit predictions

In this section, you will be provided with tools to interactively inspect the predicted unmixed images from the test set, to get a preliminary qualitative evaluation and spot potential issues.

<div class="alert alert-info"><h4><b>Task 2.2: Look for defects in the obtained predictions</b></h4>

Previously, we discussed how noise, the number of labeled structures, and the morphological similarity between labeled structures can influence the complexity of the unmixing task. Depending on these factors, you might see some defects in your predicted unmixed images. In addition, we mentioned that tiled prediction can cause the so-called tiling artefacts.

In this section, your task is to:
1. identify these defects (if any).
2. determine what is the likely source (e.g., tiling artefact, unmixing failure, ...).

You will be provided with functions to visualize (i) full images, (ii) random smaller crops, and (iii) custom crops.

</div>

#### (i) Full image visualization

In [None]:
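# The predictions above are for the test set, so take targets and inputs from the test set too
tar = get_target(test_dset)
inp = get_input(test_dset).sum(-1)

frame_idx = 0 # Change this index to visualize different frames
assert frame_idx < len(stitched_predictions), f"Frame index {frame_idx} out of bounds. Max index is {len(stitched_predictions) - 1}."

full_frame_evaluation(stitched_predictions[frame_idx], tar[frame_idx], inp[frame_idx])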

#### (ii) Random crops visualization

In [None]:

from microsplit_reproducibility.utils.utils import clean_ax
from microsplit_reproducibility.notebook_utils.HT_LIF24 import (
    pick_random_patches_with_content,
)

img_sz = 128  # size of the random crops to display
# random crop locations (top-left corners) that contain some signal
rand_locations = pick_random_patches_with_content(tar, img_sz)

ncols = 5
nrows = min(len(rand_locations), 5)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 3, nrows * 3))

for i, (h_start, w_start) in enumerate(rand_locations[:nrows]):
    ax[i, 0].imshow(inp[0, h_start : h_start + img_sz, w_start : w_start + img_sz])
    for j in range(ncols // 2):
        vmin = stitched_predictions[..., j].min()
        vmax = stitched_predictions[..., j].max()
        ax[i, 2 * j + 1].imshow(
            tar[0, h_start : h_start + img_sz, w_start : w_start + img_sz, j],
            vmin=vmin,
            vmax=vmax,
        )
        ax[i, 2 * j + 2].imshow(
            stitched_predictions[
                0, h_start : h_start + img_sz, w_start : w_start + img_sz, j
            ],
            vmin=vmin,
            vmax=vmax,
        )

ax[0, 0].set_title("Primary Input")
for i in range(2):  # 2 channel splitting
    ax[0, 2 * i + 1].set_title(f"Target Channel {i+1}")
    ax[0, 2 * i + 2].set_title(f"Predicted Channel {i+1}")

# reduce the spacing between the subplots
plt.subplots_adjust(wspace=0.03, hspace=0.03)
clean_ax(ax)

#### (iii) Custom crop visualization

In [None]:
# --- Pick coordinates of upper-left corner and crop size ---
y_start = 750
x_start = 750
crop_size = 512
#--------------
# stitched_predictions has shape (N, H, W, C)
assert y_start + crop_size <= stitched_predictions.shape[1], f"y_start + crop_size exceeds image height, which is {stitched_predictions.shape[1]}"
assert x_start + crop_size <= stitched_predictions.shape[2], f"x_start + crop_size exceeds image width, which is {stitched_predictions.shape[2]}"

ncols = 3
nrows = 2
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols * 5, nrows * 5))
ax[0, 0].imshow(inp[0, y_start : y_start + crop_size, x_start : x_start + crop_size])
for i in range(ncols - 1):
    vmin = stitched_predictions[..., i].min()
    vmax = stitched_predictions[..., i].max()
    ax[0, i + 1].imshow(
        tar[0, y_start : y_start + crop_size, x_start : x_start + crop_size, i],
        vmin=vmin,
        vmax=vmax,
    )
    ax[1, i + 1].imshow(
        stitched_predictions[
            0, y_start : y_start + crop_size, x_start : x_start + crop_size, i
        ],
        vmin=vmin,
        vmax=vmax,
    )

# disable the axis for ax[1,0]
ax[1, 0].axis("off")
ax[0, 0].set_title("Input")
ax[0, 1].set_title("Channel 1")
ax[0, 2].set_title("Channel 2")
# set y labels on the right for ax[0,2]
ax[0, 2].yaxis.set_label_position("right")
ax[0, 2].set_ylabel("Target")

ax[1, 2].yaxis.set_label_position("right")
ax[1, 2].set_ylabel("Predicted")

print("Here the crop you selected:")

<div class="alert alert-warning"><h4><b>Question 2.2.</b></h4>

Can you propose any idea about how to get rid of the current issues in the predictions? Take into account the things we mentioned during the course so far...

</div>

<div class="alert alert-warning"><h4><b>Bonus Question</b></h4>

In this and other exercises we spoke of "tiling artefacts". These are generally due to a mismatch between the predictions of adjacent tiles/patches. In the context of CNNs and, specifically, VAE-based models, can you think of reasons for this effect?

*Hint1*: for CNNs, think about how convolution works at the image borders... <br>
*Hint2*: for VAEs, reflect on the sampling happening in the latent space...

*Answers*
Limited receptive field of CNNs (pixels near a tile border see padding instead of their true neighbors), intensity mismatches due to different latent samples for adjacent tiles, ...

</div>
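
To make Hint 1 concrete: with zero padding, a convolution computed on a tile differs from the same convolution computed on the full image near the tile borders, because border pixels of the tile see zeros instead of their true neighbors. A 1D NumPy sketch with a toy box filter (not MicroSplit's actual layers):

```python
import numpy as np

signal = np.arange(1.0, 21.0)  # toy "full image" row
kernel = np.ones(5) / 5.0      # 5-tap box filter (radius 2)

# Filter applied to the full row, then cropped to the tile region...
full = np.convolve(signal, kernel, mode="same")
# ...versus the same filter applied to a 10-pixel tile on its own (implicit zero padding).
tile = np.convolve(signal[5:15], kernel, mode="same")

diff = np.abs(full[5:15] - tile)
print(np.round(diff, 2))
```

Only the two pixels at each end of the tile (the kernel radius) differ; the interior matches exactly. Deep networks stack many such layers, so the affected border band grows with the receptive field.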

## 2.3. Quantitative evaluation of MicroSplit predictions

In this section, you will perform a quantitative evaluation of MicroSplit's unmixing performance using the provided function to compute metrics. In image restoration, several metrics are commonly used to quantitatively assess the quality of a model's predictions. Different metrics focus on different aspects and provide different insights: some evaluate the ***pixel-wise similarity*** between images, while others focus on higher-order features (e.g., brightness, contrast, structure) and, hence, are said to evaluate the ***perceptual similarity*** of images. Some commonly used metrics are:
- ***Pixel-wise similarity***: `Peak Signal-to-Noise Ratio (PSNR)`, `Pearson's Correlation Coefficient`.
- ***Perceptual similarity***: `Structural similarity index measure (SSIM)` with its multi-scale variant `(MS-SSIM)`, and our variant for microscopy `MicroSSIM` (paper: [link](https://arxiv.org/abs/2408.08747)) with its multi-scale variant `(MicroMS3IM)`, `Learned Perceptual Image Patch Similarity (LPIPS)`, `Fréchet Inception Distance (FID)`.
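
The pixel-wise metrics are simple to write down. A minimal NumPy sketch of PSNR and Pearson's correlation for a single channel; taking the data range from the ground truth (`gt.max() - gt.min()`) is an assumption here, and `compute_metrics` from `utils` may normalize differently.

```python
from typing import Optional

import numpy as np


def psnr(gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None) -> float:
    """Peak Signal-to-Noise Ratio in dB: 10 * log10(data_range**2 / MSE)."""
    if data_range is None:
        data_range = float(gt.max() - gt.min())
    mse = np.mean((gt.astype(np.float64) - pred.astype(np.float64)) ** 2)
    return float(10.0 * np.log10(data_range**2 / mse))


def pearson(gt: np.ndarray, pred: np.ndarray) -> float:
    """Pearson's correlation coefficient between the flattened images."""
    return float(np.corrcoef(gt.ravel(), pred.ravel())[0, 1])
```

Note that Pearson's correlation is invariant to a global affine rescaling of the prediction, whereas PSNR penalizes any intensity offset; this is one reason the two can disagree.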

<div class="alert alert-info"><h4><b>Task 2.3: Compute metrics</b></h4>

Here, your task is to select appropriate metrics to use for the quantitative evaluation among the available ones.

*Hint*: there are no absolutely good or bad metrics. All the metrics are useful! The key is to understand *what they are telling you*.

</div>

In [None]:
# Comment out the metrics you don't want to compute
METRICS = [
    "PSNR",
    "Pearson",
    "SSIM",
    "MS-SSIM",
    "MicroSSIM",
    "MicroMS3IM",
    "LPIPS",
]

**NOTE**: as ground truth reference for computing the metrics, we will use the high-SNR images obtained with long exposure (500ms).

In [None]:
_, _, gt_test_dset, _ = create_train_val_datasets(
    datapath=Path(f"/group/jug/federico/data/MBL_course/500ms"),
    train_config=train_data_config,
    val_config=val_data_config,
    test_config=val_data_config,
    load_data_func=load_data_func,
)
gt_target = gt_test_dset._data[..., :-1]

In [None]:
print("Metric, followed by values for each channel")
_ = compute_metrics(gt_target, stitched_predictions, metrics=METRICS)

<div class="alert alert-warning"><h4><b>Question 2.3.</b></h4>

- Do you spot inconsistencies between your qualitative judgement and the computed metrics? Did you expect something different?
- Which metrics are the most informative/interpretable?

</div>

<div class="alert alert-success"><h2><b>Checkpoint 4: Qualitative and Quantitative evaluation</b></h2>
</div>

<hr style="height:2px;">