# denoiSplit: joint splitting and unsupervised denoising
In this notebook, we tackle the problem of joint splitting and unsupervised denoising, which has a use case in fluorescence microscopy. Technically, given a noisy image $x$, the goal is to predict two images $c_1$ and $c_2$ such that $x = c_1 + c_2 + n$, where $n$ is the noise in $x$. In other words, we have a superimposed image $x$ and we want to predict denoised estimates of its constituent images $c_1$ and $c_2$. Note that the network is trained on noisy data only, so the denoising is done in an unsupervised manner.

For this, we will use [denoiSplit](https://arxiv.org/pdf/2403.11854.pdf), a recently developed approach for this task. In this notebook we train denoiSplit and later evaluate it on one validation frame. The overall schema for denoiSplit is shown below:
<img src="imgs/teaser.png" alt="drawing" width="800"/>


Here, we look at the CCPs (clathrin-coated pits) vs. ER (endoplasmic reticulum) task, one of the tasks tackled by denoiSplit. Its data is generated from the [BioSR](https://figshare.com/articles/dataset/BioSR/13264793) dataset.

1) First, we will load both CCPs and ER images. <br>
2) We'll add synthetic Poisson and Gaussian noise to them. This simulates the noise that typically occurs in light microscopy.<br>
3) Each noisy CCPs image will be added to each corresponding ER image, making a superimposed image, $x$. <br>
4) A VSE network will be trained to take $x$ as input and return the split, denoised CCPs and ER images.<br>
5) You'll inspect the results, then re-run the notebook with different noise levels and model hyper-parameters to see how performance changes.
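
Steps 1 to 3 can be sketched on toy data with NumPy. This is only an illustration under stated assumptions: the two random arrays stand in for a CCPs frame and an ER frame, the helper `add_noise` and its parameters `poisson_factor` and `gaussian_std` are hypothetical, and the way the denoiSplit data loader applies noise and combines the channels may differ.

```python
import numpy as np

rng = np.random.default_rng(0)


def add_noise(img, poisson_factor, gaussian_std, rng):
    """Hypothetical helper: Poisson (shot) noise followed by additive Gaussian (read) noise."""
    # Scale down, sample photon counts, scale back up.
    # A larger factor means fewer simulated photons and therefore more noise.
    shot = rng.poisson(img / poisson_factor) * poisson_factor
    read = rng.normal(0.0, gaussian_std, size=img.shape)
    return shot + read


# Toy stand-ins for one CCPs frame and one ER frame (non-negative intensities).
c1 = rng.uniform(0, 50_000, size=(64, 64))
c2 = rng.uniform(0, 50_000, size=(64, 64))

noisy_c1 = add_noise(c1, poisson_factor=1000, gaussian_std=5000, rng=rng)
noisy_c2 = add_noise(c2, poisson_factor=1000, gaussian_std=5000, rng=rng)

# Superimposed input x = c1 + c2 + n, where n collects the noise of both channels.
x = noisy_c1 + noisy_c2
n = x - (c1 + c2)

print(x.shape, float(n.mean()), float(n.std()))
```

The network only ever sees `x` (and the noisy channels as training targets); the clean `c1` and `c2` are what it is asked to recover.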

<div class="alert alert-danger">
Set your Python kernel to <code>05_image_restoration</code>
</div>

## Set directories 
In the next cell, we set the working directory and the directory for the TensorBoard logs.

In [None]:
import os

work_dir = "."
tensorboard_log_dir = os.path.join(work_dir, "tensorboard_logs")
os.makedirs(tensorboard_log_dir, exist_ok=True)

In [None]:
import sys
sys.path.append('./denoisplit')

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches

from denoisplit.data_loader.vanilla_dloader import MultiChDloader
from denoisplit.analysis.plot_utils import clean_ax
from denoisplit.configs.biosr_config import get_config
from denoisplit.training import create_dataset
from denoisplit.nets.model_utils import create_model
from denoisplit.core.metric_monitor import MetricMonitor
from denoisplit.scripts.run import get_mean_std_dict_for_model
from denoisplit.core.data_split_type import DataSplitType
from denoisplit.scripts.evaluate import get_highsnr_data
from denoisplit.analysis.mmse_prediction import get_dset_predictions
from denoisplit.data_loader.patch_index_manager import GridAlignement
from denoisplit.scripts.evaluate import avg_range_inv_psnr, compute_multiscale_ssim

<a id='things-to-try'></a>
<div class="alert alert-block alert-warning"><h3>
    Several Things to try:</h3> 
    <ol>
        <li>Run once with unchanged config to see the performance. </li>
        <li>Increase the noise (double the Gaussian noise?) and see how performance degrades.</li>
        <ol style="text-indent: 25px;">
        <li>Recap: Poisson and Gaussian are the two most prominent pixelwise independent noise sources. Here, we incorporate both. Note that the larger the noise, the harder the task becomes.</li> 
        </ol>
        <li> Increase the max_epochs, if you want to get better performance. </li>
        <li> For faster training (at the cost of performance), reduce the number of hierarchy levels and/or the channel count by modifying <em>config.model.z_dims</em>.</li> 
        <li> First we train the model to split CCPs and ER channels. Later you can try to split other channels, e.g. F-actin and ER. You'll be able to see that this is a substantially harder task. </li>
    </ol>
</div>


## Config 

Here we'll load the data and set model hyper-parameters.
To create the dataset, we'll load two sets of images: CCPs (clathrin-coated pits) and ER (endoplasmic reticulum).
Each CCPs image will be added to an ER image, with noise added on top.

The level of noise is determined by `config.data.poisson_noise_factor` and `config.data.synthetic_gaussian_scale`.
The former simulates photon shot noise, which is more destructive on lower intensity signals.
The latter simulates electronic read noise, which has a constant variance for all signal intensities.
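
To see why shot noise hurts dim signals more than bright ones, the sketch below compares the signal-to-noise ratio (SNR) under each noise type at a few intensities. The values of `factor` and `sigma` and the scaling convention are assumptions chosen for illustration; they are not taken from the denoiSplit code.

```python
import numpy as np

rng = np.random.default_rng(0)
factor = 1000.0  # assumed Poisson scaling, for illustration only
sigma = 5000.0   # assumed Gaussian standard deviation, for illustration only
n_samples = 100_000

shot_std, read_std = {}, {}
for signal in (5_000.0, 50_000.0, 500_000.0):
    # Shot noise: the variance grows with the signal (here var = factor * signal).
    shot = rng.poisson(signal / factor, size=n_samples) * factor
    # Read noise: the variance is the same whatever the signal.
    read = signal + rng.normal(0.0, sigma, size=n_samples)
    shot_std[signal] = float(shot.std())
    read_std[signal] = float(read.std())
    print(
        f"signal={signal:>9.0f}  "
        f"shot std={shot_std[signal]:8.1f} (SNR {signal / shot_std[signal]:5.1f})  "
        f"read std={read_std[signal]:8.1f} (SNR {signal / read_std[signal]:5.1f})"
    )
```

Under these assumptions the absolute shot noise grows with intensity, but only with its square root, so the SNR is worst for the dimmest signal.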


`config.data.poisson_noise_factor` (float): the intensity of the Poisson (shot) noise.

`config.data.synthetic_gaussian_scale` (float): the intensity of the Gaussian (readout) noise.

`config.model.z_dims` (list(int)): Determines the depth of our network. The number of entries is the number of levels. The value of each entry is the number of hidden dimensions at each level.

`config.training.lr` (float): The learning rate.

`config.training.max_epochs` (int): Number of training epochs. Increase for better performance, decrease for shorter training time.

`config.training.batch_size` (int): Training batch size. Increasing this will require more memory. Performance may improve, but bigger batches aren't always better.

`config.training.num_workers` (int): Number of subprocesses to use for data loading. The best value depends on your machine, mainly on the number of available CPU cores.


In [None]:
datapath = "./../data/"

# load the default config.
config = get_config()

config.data.ch1_fname = 'ER/GT_all.mrc'
config.data.ch2_fname = 'CCPs/GT_all.mrc'
# Change the noise level
config.data.poisson_noise_factor = (
    1000  # 1000 is the default value. noise increases with the value.
)
config.data.synthetic_gaussian_scale = (
    5000  # 5000 is the default value. noise increases with the value.
)

# change the number of hierarchy levels.
config.model.z_dims = [128, 128, 128, 128]

# change the training parameters
config.training.lr = 3e-3
config.training.max_epochs = 10
config.training.batch_size = 8
config.training.num_workers = 4

config.workdir = "."

## Create the dataset and PyTorch dataloaders. 

In [None]:
print(config)

In [None]:
train_dset = MultiChDloader(config.data,
                            datapath,
                            datasplit_type=DataSplitType.Train,
                            val_fraction=config.training.val_fraction,
                            test_fraction=config.training.test_fraction,
                            normalized_input=config.data.normalized_input,
                            use_one_mu_std=config.data.use_one_mu_std,
                            enable_rotation_aug=config.data.train_aug_rotate
                            )
val_dset = MultiChDloader(config.data,
                datapath,
                datasplit_type=DataSplitType.Val,
                val_fraction=config.training.val_fraction,
                test_fraction=config.training.test_fraction,
                normalized_input=config.data.normalized_input,
                use_one_mu_std=config.data.use_one_mu_std,
                enable_rotation_aug=False,  # No rotation aug on validation
                max_val=train_dset.get_max_val(),
                )


In [None]:
mean_dict, std_dict = train_dset.compute_mean_std()
train_dset.set_mean_std(mean_dict, std_dict)
val_dset.set_mean_std(mean_dict, std_dict)

mean_dict, std_dict = get_mean_std_dict_for_model(config, train_dset)

## Inspecting the training data generated using the above config.
<div class="alert alert-block alert-warning">
If you want to change the noise, change the config first, re-run the dataset cells above, and then run the following cell again to see how the noise in the training data changes.
</div>


In [None]:
val_dset.set_img_sz(800, 64)
inp, tar = val_dset[0]
_,ax = plt.subplots(1,3, figsize=(15,5))
ax[0].imshow(inp[0], cmap='magma')
ax[0].set_title('Input')
ax[1].imshow(tar[0], cmap='magma')
ax[1].set_title('Channel 1')
ax[2].imshow(tar[1], cmap='magma')
ax[2].set_title('Channel 2')

val_dset.set_img_sz(config.data.image_size, config.data.image_size)

## Define the dataloaders

In [None]:
batch_size = config.training.batch_size
train_dloader = DataLoader(
    train_dset,
    pin_memory=False,
    num_workers=config.training.num_workers,
    shuffle=True,
    batch_size=batch_size,
)
val_dloader = DataLoader(
    val_dset,
    pin_memory=False,
    num_workers=config.training.num_workers,
    shuffle=False,
    batch_size=batch_size,
)

## Create the model.
Here, we instantiate the [denoiSplit model](https://arxiv.org/pdf/2403.11854.pdf). For simplicity, we have disabled the noise model. Enabling it would require two extra steps: first training a denoiser, and then creating a noise model from the noisy data and the corresponding denoised predictions.


In [None]:
model = create_model(config, mean_dict, std_dict)
model = model.cuda()

## Start training

In [None]:
logger = TensorBoardLogger(tensorboard_log_dir, name="", version="", default_hp_metric=False)
trainer = pl.Trainer(
    max_epochs=config.training.max_epochs,
    gradient_clip_val=(
        None
        if not model.automatic_optimization
        else config.training.grad_clip_norm_value
    ),
    logger=logger,
    precision=config.training.precision,
)
trainer.fit(model, train_dloader, val_dloader)

## Evaluate the model

In [None]:
model.eval()
_ = model.cuda()
eval_frame_idx = 0
# reducing the data, just for speed
val_dset.reduce_data(t_list=[eval_frame_idx])
mmse_count = 10
overlapping_padding_kwargs = {
    "mode": config.data.get("padding_mode", "constant"),
}
if overlapping_padding_kwargs["mode"] == "constant":
    overlapping_padding_kwargs["constant_values"] = config.data.get("padding_value", 0)
val_dset.set_img_sz(
    128,
    32,
    grid_alignment=GridAlignement.Center,
    overlapping_padding_kwargs=overlapping_padding_kwargs,
)

# MMSE prediction
pred_tiled, rec_loss, logvar_tiled, patch_psnr_tuple, pred_std_tiled = (
    get_dset_predictions(
        model,
        val_dset,
        batch_size,
        num_workers=config.training.num_workers,
        mmse_count=mmse_count,
        model_type=config.model.model_type,
    )
)

# One sample prediction
pred1_tiled, *_ = get_dset_predictions(
    model,
    val_dset,
    batch_size,
    num_workers=config.training.num_workers,
    mmse_count=1,
    model_type=config.model.model_type,
)
# A second, independent one-sample prediction
pred2_tiled, *_ = get_dset_predictions(
    model,
    val_dset,
    batch_size,
    num_workers=config.training.num_workers,
    mmse_count=1,
    model_type=config.model.model_type,
)

## Stitching predictions

In [None]:
from denoisplit.analysis.stitch_prediction import stitch_predictions

pred = stitch_predictions(pred_tiled, val_dset)


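# Ignore the constant-valued pixels at the right/bottom boundary.
def count_ignored_pixels():
    """Return the width of the constant (zero-std) border at the bottom-right of `pred`."""
    ignored_pixels = 1
    while (
        pred[
            0,
            -ignored_pixels:,
            -ignored_pixels:,
        ].std()
        == 0
    ):
        ignored_pixels += 1
    ignored_pixels -= 1
    return ignored_pixels


actual_ignored_pixels = count_ignored_pixels()
# Note: the negative slices here and in the cells below assume actual_ignored_pixels > 0.
pred = pred[:, :-actual_ignored_pixels, :-actual_ignored_pixels]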
pred1 = stitch_predictions(pred1_tiled, val_dset)[
    :, :-actual_ignored_pixels, :-actual_ignored_pixels
]
pred2 = stitch_predictions(pred2_tiled, val_dset)[
    :, :-actual_ignored_pixels, :-actual_ignored_pixels
]

## Get the ground truth

In [None]:
highres_data = get_highsnr_data(config, datapath, DataSplitType.Val)

highres_data = highres_data[
    eval_frame_idx : eval_frame_idx + 1,
    :-actual_ignored_pixels,
    :-actual_ignored_pixels,
]

noisy_data = val_dset._noise_data[..., 1:] + val_dset._data
noisy_data = noisy_data[..., :-actual_ignored_pixels, :-actual_ignored_pixels, :]
model_input = np.mean(noisy_data, axis=-1)


<div class="alert alert-block alert-success"><h1>Checkpoint 1: Model trained</h1>
</div>

# Qualitative performance on a random crop
denoiSplit is capable of sampling from a learned posterior.
Here we show the full input frame and a randomly cropped input (300×300 pixels),
the noisy target channels, two corresponding prediction samples, the difference between the two samples (S2 − S1),
the MMSE prediction, and the corresponding crop of the otherwise unused high-SNR microscopy data.
The MMSE prediction is computed by averaging 10 samples.
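
The effect of averaging samples can be illustrated on synthetic data. In this sketch, which is independent of denoiSplit, each "posterior sample" is modelled as the true image plus independent Gaussian perturbations; that is an assumption made purely for illustration. The mean over samples plays the role of the MMSE estimate, and the per-pixel standard deviation across samples gives a rough picture of where the samples disagree.

```python
import numpy as np

rng = np.random.default_rng(0)
truth = rng.uniform(0.0, 1.0, size=(32, 32))

# Assumption: samples scatter around the truth with independent noise.
mmse_count = 10
samples = truth + rng.normal(0.0, 0.2, size=(mmse_count, 32, 32))

mmse_estimate = samples.mean(axis=0)  # average of the samples
sample_std = samples.std(axis=0)      # per-pixel disagreement between samples
diff = samples[1] - samples[0]        # difference between two individual samples


def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((a - b) ** 2))


print("MSE of one sample :", mse(samples[0], truth))
print("MSE of the average:", mse(mmse_estimate, truth))
```

In this toy setting the averaged estimate is closer to the truth than any single sample, while individual samples retain the sharper, more varied detail.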

In [None]:
def add_str(ax_, txt):
    """
    Add a text box with the given string to the upper-left corner of the axes.
    """
    textstr = txt
    props = dict(boxstyle="round", facecolor="gray", alpha=0.5)
    # place a text box in upper left in axes coords
    ax_.text(
        0.05,
        0.95,
        textstr,
        transform=ax_.transAxes,
        fontsize=11,
        verticalalignment="top",
        bbox=props,
        color="white",
    )


ncols = 7
nrows = 2
sz = 300
hs = np.random.randint(0, highres_data.shape[1] - sz)
ws = np.random.randint(0, highres_data.shape[2] - sz)
_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))
ax[0, 0].imshow(model_input[0], cmap="magma")

rect = patches.Rectangle((ws, hs), sz, sz, linewidth=1, edgecolor="r", facecolor="none")
ax[0, 0].add_patch(rect)
ax[1, 0].imshow(model_input[0, hs : hs + sz, ws : ws + sz], cmap="magma")
add_str(ax[0, 0], "Full Input Frame")
add_str(ax[1, 0], "Random Input Crop")

ax[0, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
ax[1, 1].imshow(noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

ax[0, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
ax[1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

ax[0, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
ax[1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

diff = pred2 - pred1
ax[0, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap="coolwarm")
ax[1, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 1], cmap="coolwarm")

ax[0, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
ax[1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")


ax[0, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
ax[1, 6].imshow(highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")
plt.subplots_adjust(wspace=0.02, hspace=0.02)
ax[0, 0].set_title("Model Input", size=13)
ax[0, 1].set_title("Target", size=13)
ax[0, 2].set_title("Sample 1 (S1)", size=13)
ax[0, 3].set_title("Sample 2 (S2)", size=13)
ax[0, 4].set_title('"S2" - "S1"', size=13)
ax[0, 5].set_title(f"Prediction MMSE({mmse_count})", size=13)
ax[0, 6].set_title("High SNR Reality", size=13)

twinx = ax[0, 6].twinx()
twinx.set_ylabel("Channel 1", size=13)
clean_ax(twinx)
twinx = ax[1, 6].twinx()
twinx.set_ylabel("Channel 2", size=13)
clean_ax(twinx)
clean_ax(ax)

# Qualitative performance on multiple random crops


In [None]:
nimgs = 3
ncols = 7
nrows = 2 * nimgs
sz = 300
_, ax = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))

for img_idx in range(nimgs):
    hs = np.random.randint(0, highres_data.shape[1] - sz)
    ws = np.random.randint(0, highres_data.shape[2] - sz)
    ax[2 * img_idx, 0].imshow(model_input[0], cmap="magma")

    rect = patches.Rectangle(
        (ws, hs), sz, sz, linewidth=1, edgecolor="r", facecolor="none"
    )
    ax[2 * img_idx, 0].add_patch(rect)
    ax[2 * img_idx + 1, 0].imshow(
        model_input[0, hs : hs + sz, ws : ws + sz], cmap="magma"
    )
    add_str(ax[2 * img_idx, 0], "Full Input Frame")
    add_str(ax[2 * img_idx + 1, 0], "Random Input Crop")

    ax[2 * img_idx, 1].imshow(
        noisy_data[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma"
    )
    ax[2 * img_idx + 1, 1].imshow(
        noisy_data[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma"
    )

    ax[2 * img_idx, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
    ax[2 * img_idx + 1, 2].imshow(pred1[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

    ax[2 * img_idx, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
    ax[2 * img_idx + 1, 3].imshow(pred2[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

    diff = pred2 - pred1
    ax[2 * img_idx, 4].imshow(diff[0, hs : hs + sz, ws : ws + sz, 0], cmap="coolwarm")
    ax[2 * img_idx + 1, 4].imshow(
        diff[0, hs : hs + sz, ws : ws + sz, 1], cmap="coolwarm"
    )

    ax[2 * img_idx, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma")
    ax[2 * img_idx + 1, 5].imshow(pred[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma")

    ax[2 * img_idx, 6].imshow(
        highres_data[0, hs : hs + sz, ws : ws + sz, 0], cmap="magma"
    )
    ax[2 * img_idx + 1, 6].imshow(
        highres_data[0, hs : hs + sz, ws : ws + sz, 1], cmap="magma"
    )

    twinx = ax[2 * img_idx, 6].twinx()
    twinx.set_ylabel("Channel 1", size=15)
    clean_ax(twinx)

    twinx = ax[2 * img_idx + 1, 6].twinx()
    twinx.set_ylabel("Channel 2", size=15)
    clean_ax(twinx)

ax[0, 0].set_title("Model Input", size=15)
ax[0, 1].set_title("Target", size=15)
ax[0, 2].set_title("Sample 1 (S1)", size=15)
ax[0, 3].set_title("Sample 2 (S2)", size=15)
ax[0, 4].set_title('"S2" - "S1"', size=15)
ax[0, 5].set_title(f"Prediction MMSE({mmse_count})", size=15)
ax[0, 6].set_title("High SNR Reality", size=15)

clean_ax(ax)
plt.subplots_adjust(wspace=0.02, hspace=0.02)
# plt.tight_layout()

<div class="alert alert-block alert-warning">
    <h3>Questions:</h3>
    1) When is it relatively easy to split the two structures from the input?<br>
    2) Why might you see the grid-like artifacts and what can be done to mitigate this?<br>
</div>


## Quantitative performance
We evaluate on two metrics, multi-scale SSIM (MS-SSIM) and PSNR.

MS-SSIM computes SSIM at multiple scales and combines the results. It is reminiscent of multi-scale processing in the early visual system.

PSNR is the peak signal-to-noise ratio. It is one of the most widely used metrics for measuring the quality of image reconstruction.
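
For reference, a plain PSNR can be computed in a few lines of NumPy. This sketch uses the textbook definition with the peak taken from the ground-truth intensity range. It is not the variant computed by `avg_range_inv_psnr`, which, judging by its name, is range-invariant; the function name `psnr` and the toy data are only for illustration.

```python
import numpy as np


def psnr(gt, pred):
    """Textbook PSNR in dB, using the ground-truth intensity range as the peak."""
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    mse = np.mean((gt - pred) ** 2)
    if mse == 0:
        return float("inf")  # identical images
    data_range = gt.max() - gt.min()
    return float(10.0 * np.log10(data_range**2 / mse))


rng = np.random.default_rng(0)
gt = rng.uniform(0.0, 1.0, size=(64, 64))
slightly_noisy = gt + rng.normal(0.0, 0.01, size=gt.shape)
very_noisy = gt + rng.normal(0.0, 0.1, size=gt.shape)

print(f"{psnr(gt, slightly_noisy):.1f} dB vs {psnr(gt, very_noisy):.1f} dB")
```

Higher values mean a reconstruction closer to the ground truth; because the scale is logarithmic, a tenfold increase in noise standard deviation costs about 20 dB.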

In [None]:
mean_tar = mean_dict["target"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)
std_tar = std_dict["target"].cpu().numpy().squeeze().reshape(1, 1, 1, 2)
pred_unnorm = pred * std_tar + mean_tar

psnr_list = [
    avg_range_inv_psnr(highres_data[..., i].copy(), pred_unnorm[..., i].copy())
    for i in range(highres_data.shape[-1])
]
ssim_list = compute_multiscale_ssim(highres_data.copy(), pred_unnorm.copy())
print("Metric: Ch1\t Ch2")
print(f"PSNR  : {psnr_list[0]:.2f}\t {psnr_list[1]:.2f}")
print(f"MS-SSIM  : {ssim_list[0]:.3f}\t {ssim_list[1]:.3f}")

<div class="alert alert-block alert-success"><h1>Checkpoint 2: Try one of the "Several things to try"</h1>

</div>

Click [here](#things-to-try) to go back to the relevant section.

<hr style="height:2px;"><div class="alert alert-block alert-success"><h1>End of the exercise</h1>
</div>