# Image translation (Virtual Staining) - Part 1

Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco

## Overview

In this exercise, we will predict fluorescence images of
nuclei and plasma membrane markers from quantitative phase images of cells,
i.e., we will _virtually stain_ the nuclei and plasma membrane
visible in the phase image.
This is an example of an image translation task.
We will apply spatial and intensity augmentations to train robust models
and evaluate their performance using a regression approach.

[![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755)
(Click on image to play video)

### Goals

#### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard.

  - Use an OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549);
  each FOV has 3 channels (phase, nuclei, and cell membrane).
  The nuclei were stained with DAPI and the cell membrane with Cellmask.
  - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html)
  and the high-content-screen (HCS) format.
  - Use [MONAI](https://monai.io/) to implement data augmentations.

#### Part 2: Train and evaluate the model to translate phase into fluorescence.
  - Train a 2D UNeXt2 model to predict nuclei and membrane from phase images.
  - Compare the performance of the trained model and a pre-trained model.
  - Evaluate the model using pixel-level and instance-level metrics.


Check out [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos),
our deep learning pipeline for training and deploying computer vision models
for image-based phenotyping, including the robust virtual staining of landmark organelles.
VisCy exploits recent advances in data and metadata formats
([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks,
[PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/).

### References

- [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf)
- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502)

<div class="alert alert-info">
The exercise is organized in 2 parts

<ul>
<li><b>Part 1</b> - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.</li>
<li><b>Part 2</b> - Train and evaluate the model to translate phase into fluorescence.</li>
</ul>

</div>

<div class="alert alert-danger">
Set your python kernel to <span style="color:black;">06_image_translation</span>
</div>

## Part 1: Log training data to tensorboard, start training a model.
---------
Learning goals:

- Load the OME-zarr dataset and examine the channels (A549).
- Configure and understand the data loader.
- Log some patches to tensorboard.
- Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase.
- Start training the model to predict nuclei and membrane from phase.

In [None]:
import os
from glob import glob
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchview
import torchvision
from cellpose import models
from iohub import open_ome_zarr
from iohub.reader import print_info
from lightning.pytorch import seed_everything
from lightning.pytorch.loggers import TensorBoardLogger
from natsort import natsorted
from numpy.typing import ArrayLike
from skimage import metrics  # for metrics.
# pytorch lightning wrapper for Tensorboard.
from skimage.color import label2rgb
from torch.utils.tensorboard import SummaryWriter  # for logging to tensorboard
from torchmetrics.functional import accuracy, dice, jaccard_index
from tqdm import tqdm
# HCSDataModule makes it easy to load data during training.
from viscy.data.hcs import HCSDataModule
from viscy.evaluation.evaluation_metrics import mean_average_precision
# Trainer class and UNet.
from viscy.light.engine import MixedLoss, VSUNet
from viscy.light.trainer import VSTrainer
# training augmentations
from viscy.transforms import (NormalizeSampled, RandAdjustContrastd,
                              RandAffined, RandGaussianNoised,
                              RandGaussianSmoothd, RandScaleIntensityd,
                              RandWeightedCropd)

In [None]:
# seed random number generators for reproducibility.
seed_everything(42, workers=True)

# Paths to data and log directory
top_dir = Path(
    "/mnt/efs/dlmbl/data/"
)  # If this fails, point top_dir to your data directory in the shared mount point inside /dlmbl/data

# Path to the training data
data_path = (
    top_dir / "06_image_translation/part1/training/a549_hoechst_cellmask_train_val.zarr"
)

# Path where we will save our training logs
training_top_dir = Path(f"{os.environ['HOME']}/data/")
# Create the training_top_dir directory if needed.
training_top_dir.mkdir(parents=True, exist_ok=True)
log_dir = training_top_dir / "06_image_translation/part1/logs/"
# Create the log directory if needed; TensorBoard is launched in the next cell.
log_dir.mkdir(parents=True, exist_ok=True)

if not data_path.exists():
    raise FileNotFoundError(
        f"Data not found at {data_path}. Please check the top_dir and data_path variables."
    )

The next cell starts tensorboard.

<div class="alert alert-warning">
If you launched jupyter lab from ssh terminal, add <code>--host &lt;your-server-name&gt;</code> to the tensorboard command below. <code>&lt;your-server-name&gt;</code> is the address of your compute node that ends in amazonaws.com.

</div>

In [None]:
# Function to find an available port
def find_free_port():
    import socket

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]


# Launch TensorBoard on the browser
def launch_tensorboard(log_dir):
    import subprocess

    port = find_free_port()
    tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}"
    process = subprocess.Popen(tensorboard_cmd, shell=True)
    print(
        f"TensorBoard started at http://localhost:{port}. \n"
        "If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL."
    )
    return process


# Launch tensorboard and click on the link to view the logs.
tensorboard_process = launch_tensorboard(log_dir)

<div class="alert alert-warning">
If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard. <br>
Take note of the port number that was assigned in the previous cell (i.e., <code>http://localhost:{port_number_assigned}</code>). <br>

Locate your VSCode terminal and select the <code>Ports</code> tab. <br>
<ul>
<li>Add a new port with the <code>port_number_assigned</code>
</ul>
Click on the link to view the tensorboard and it should open in your browser.
</div>

## Load OME-Zarr Dataset

There should be 34 FOVs in the dataset.

Each FOV consists of 3 channels of 2048x2048 images,
saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout)
specified by the Open Microscopy Environment Next Generation File Format
(OME-NGFF).

- The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.`
- These datasets only have 1 level in the pyramid (highest resolution) which is '0'.
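
To make the layout concrete, here is a minimal sketch of how the first four path components form the key of one image array, and how the remaining axes are indexed once the array is loaded. The helper names `hcs_array_key` and `parse_hcs_array_key` are hypothetical and are not part of iohub; they only illustrate the string structure used later in this notebook.

```python
# Hypothetical helpers (not part of iohub): they only illustrate how the
# row/col/field/pyramid_level components form the key of one image array.
ARRAY_AXES = ("timepoint", "channel", "z", "y", "x")  # axes inside each array


def hcs_array_key(row, col, field, pyramid_level=0):
    """Join the HCS path components into a key such as '0/0/9/0'."""
    return f"{row}/{col}/{field}/{pyramid_level}"


def parse_hcs_array_key(key):
    """Split a key back into its named components."""
    row, col, field, pyramid_level = key.split("/")
    return {
        "row": row,
        "col": col,
        "field": field,
        "pyramid_level": int(pyramid_level),
    }


key = hcs_array_key(row=0, col=0, field=9)
print(key)
print(parse_hcs_array_key(key))
print("axes of the array stored at this key:", ARRAY_AXES)
```

With a key like this, a single 2D image is selected by indexing the remaining axes, for example timepoint 0, channel `i`, and z-slice 0.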

<div class="alert alert-warning">
You can inspect the tree structure by using your terminal:
<code> iohub info -v "path-to-ome-zarr" </code>

<br>
More info on the CLI:
<code>iohub info --help </code> to see the help menu.
</div>

In [None]:
# This is the python function called by `iohub info` CLI command
print_info(data_path, verbose=True)

# Open and inspect the dataset.
dataset = open_ome_zarr(data_path)

<div class="alert alert-info">

### Task 1.1
Look at a couple different fields of view (FOVs) by changing the `field` variable.
Check the cell density, the cell morphologies, and fluorescence signal.
HINT: look at the HCS Plate format to see what are your options.
</div>

In [None]:
# Use the field and pyramid_level below to visualize data.
row = 0
col = 0
field = 9  # TODO: Change this to explore data.

# NOTE: this dataset only has one level
pyramid_level = 0

# `channel_names` is the metadata that is stored with data according to the OME-NGFF spec.
n_channels = len(dataset.channel_names)

image = dataset[f"{row}/{col}/{field}/{pyramid_level}"].numpy()
print(f"data shape: {image.shape}, FOV: {field}, pyramid level: {pyramid_level}")

figure, axes = plt.subplots(1, n_channels, figsize=(9, 3))

for i in range(n_channels):
    # Select timepoint 0, channel i, z-slice 0.
    channel_image = image[0, i, 0]
    # Adjust contrast to 0.5th and 99.5th percentile of pixel values.
    p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
    channel_image = np.clip(channel_image, p_low, p_high)
    axes[i].imshow(channel_image, cmap="gray")
    axes[i].axis("off")
    axes[i].set_title(dataset.channel_names[i])
plt.tight_layout()

## Explore the effects of augmentation on batch.

VisCy builds on top of PyTorch Lightning. PyTorch Lightning is a thin wrapper around PyTorch that allows rapid experimentation. It provides a [DataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) to handle loading and processing of data during training. VisCy provides a child class, `HCSDataModule`, to make it intuitive to access data stored in the HCS layout.

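The dataloader in `HCSDataModule` returns a batch of samples. As used in the code below, a `batch` is indexed like a dictionary (for example `batch["source"]`), and its tensors carry a leading batch dimension whose length equals the batch size. For each sample, the dictionary holds the following key-value pairs.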
- `source`: the input image, a tensor of size 1*1*Y*X
- `target`: the target image, a tensor of size 2*1*Y*X
- `index` : the tuple of (location of field in HCS layout, time, and z-slice) of the sample.

<div class="alert alert-info">

### Task 1.2
- Run the next cell to setup a logger for your augmentations.
- Set up the `HCSDataModule()` for training.
  - Configure the data module for the `"UNeXt2_2D"` architecture.
  - Configure the data module for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
  - Configure the data module for training. Hint: use `HCSDataModule.setup()`.
- Open your tensorboard and look at the `IMAGES` tab.

Note: If tensorboard is not showing images or the plots, try refreshing and using the "Images" tab.
</div>

In [None]:
# Define a function to write a batch to tensorboard log.
def log_batch_tensorboard(batch, batchno, writer, card_name):
    """
    Logs a batch of images to TensorBoard.

    Args:
        batch (dict): A dictionary containing the batch of images to be logged.
        batchno (int): The batch index, used as the step value in TensorBoard.
        writer (SummaryWriter): A TensorBoard SummaryWriter object.
        card_name (str): The name of the card to be displayed in TensorBoard.

    Returns:
        None
    """
    batch_phase = batch["source"][:, :, 0, :, :]  # batch_size x n_channels x Y x X tensor (z index 0).
    batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.
    batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.

    p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
    batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)

    p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
    batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)

    p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
    batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)

    [N, C, H, W] = batch_phase.shape
    interleaved_images = torch.zeros((3 * N, C, H, W), dtype=batch_phase.dtype)
    interleaved_images[0::3, :] = batch_phase
    interleaved_images[1::3, :] = batch_nuclei
    interleaved_images[2::3, :] = batch_membrane

    grid = torchvision.utils.make_grid(interleaved_images, nrow=3)

    # add the grid to tensorboard
    writer.add_image(card_name, grid, batchno)

# Define a function to visualize a batch on jupyter, in case tensorboard is finicky
def log_batch_jupyter(batch):
    """
    Logs a batch of images on jupyter using ipywidget.

    Args:
        batch (dict): A dictionary containing the batch of images to be logged.

    Returns:
        None
    """
    batch_phase = batch["source"][:, :, 0, :, :]  # batch_size x n_channels x Y x X tensor (z index 0).
    batch_size = batch_phase.shape[0]
    batch_membrane = batch["target"][:, 1, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.
    batch_nuclei = batch["target"][:, 0, 0, :, :].unsqueeze(
        1
    )  # batch_size x 1 x Y x X tensor.

    p1, p99 = np.percentile(batch_membrane, (0.1, 99.9))
    batch_membrane = np.clip((batch_membrane - p1) / (p99 - p1), 0, 1)

    p1, p99 = np.percentile(batch_nuclei, (0.1, 99.9))
    batch_nuclei = np.clip((batch_nuclei - p1) / (p99 - p1), 0, 1)

    p1, p99 = np.percentile(batch_phase, (0.1, 99.9))
    batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1)

    # plt.subplots creates its own figure, so no separate plt.figure() call is needed.
    fig, axes = plt.subplots(
        batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2)
    )
    [N, C, H, W] = batch_phase.shape
    for sample_id in range(batch_size):
        axes[sample_id, 0].imshow(batch_phase[sample_id, 0])
        axes[sample_id, 1].imshow(batch_nuclei[sample_id, 0])
        axes[sample_id, 2].imshow(batch_membrane[sample_id, 0])

        for i in range(n_channels):
            axes[sample_id, i].axis("off")
            axes[sample_id, i].set_title(dataset.channel_names[i])
    plt.tight_layout()
    plt.show()

In [None]:
# #######################
# ##### SOLUTION ########
# #######################

BATCH_SIZE = 4
# 4 is a perfectly reasonable batch size
# (batch size does not have to be a power of 2)
# See: https://sebastianraschka.com/blog/2022/batch-size-2.html

source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]

data_module = HCSDataModule(
    data_path,
    z_window_size=1,
    architecture="UNeXt2_2D",
    source_channel=source_channel,
    target_channel=target_channel,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=8,
    yx_patch_size=(256, 256),  # larger patch size makes it easy to see augmentations.
    augmentations=[],  # Turn off augmentation for now.
    normalizations=[],  # Turn off normalization for now.
)

# Setup the data_module to fit. HINT: data_module.setup()
data_module.setup("fit")

# Evaluate the data module
print(
    f"Samples in training set: {len(data_module.train_dataset)}, "
    f"samples in validation set: {len(data_module.val_dataset)}"
)
train_dataloader = data_module.train_dataloader()
# Instantiate the tensorboard SummaryWriter and log the first batch to tensorboard.
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
# Draw a batch and write to tensorboard.
batch = next(iter(train_dataloader))
log_batch_tensorboard(batch, 0, writer, "augmentation/none")
writer.close()

<div class="alert alert-warning">

### Questions
1. What are the two channels in the target image?
2. How many samples are in the training and validation set? What determined that split?

Note: If tensorboard is not showing images, try refreshing and using the "Images" tab.
</div>

If tensorboard is causing issues, you can visualize the batch directly in Jupyter/VSCode.

In [None]:
# Visualize in Jupyter
log_batch_jupyter(batch)

<div class="alert alert-warning">
<h3> Question for Task 1.3 </h3>
1. How do augmentations make the model more robust to imaging parameters or conditions
without having to acquire data for every possible condition? <br>
</div>

<div class="alert alert-info">

### Task 1.3
Add the following augmentations:
- Add affine augmentations: rotation by up to $\pi$ about the z-axis, up to 30% scaling in y and x,
shearing of up to 1%, and padding with zeros, applied with a probability of 80%.
- Add Gaussian noise with a mean of 0.0 and a standard deviation of 0.3 with a probability of 50%.

HINT: `RandAffined()` and `RandGaussianNoised()` are from
`viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). You can look at the docs by running `RandAffined?`.<br><br>
*Note these are MONAI transforms that have been redefined for VisCy.*
[Compare your choice of augmentations by downloading the pretrained models and config files](https://github.com/mehta-lab/VisCy/releases/download/v0.1.0/VisCy-0.1.0-VS-models.zip).
</div>
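
To build intuition for the `prob`, `mean`, and `std` parameters before writing the solution, the following is a simplified sketch of a probability-gated noise augmentation. The function `rand_gaussian_noise` is a hypothetical stand-in written for illustration; it is not the VisCy or MONAI implementation, and the library transform may differ in detail (for example, it may also randomize the noise level).

```python
import random


def rand_gaussian_noise(pixels, prob, mean, std, rng):
    """Hypothetical stand-in for a probabilistic noise transform.

    With probability `prob`, add Gaussian noise N(mean, std) to every pixel;
    otherwise return an unchanged copy of the input.
    """
    if rng.random() >= prob:
        return list(pixels)
    return [p + rng.gauss(mean, std) for p in pixels]


rng = random.Random(0)  # seeded for reproducibility
patch = [0.0] * 1000  # a flat, flattened "image patch"

# Apply the augmentation many times and count how often noise was added.
n_trials = 200
n_changed = sum(
    rand_gaussian_noise(patch, prob=0.5, mean=0.0, std=0.3, rng=rng) != patch
    for _ in range(n_trials)
)
print(f"{n_changed} of {n_trials} patches received noise")
```

Because the noise is redrawn each time a patch is sampled, the model rarely sees exactly the same input twice, which is one way augmentations stand in for acquiring data under many imaging conditions.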

In [None]:
# #######################
# ##### SOLUTION ########
# #######################
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]

augmentations = [
    RandWeightedCropd(
        keys=source_channel + target_channel,
        spatial_size=(1, 384, 384),
        num_samples=2,
        w_key=target_channel[0],
    ),
    RandAffined(
        keys=source_channel + target_channel,
        rotate_range=[3.14, 0.0, 0.0],
        scale_range=[0.0, 0.3, 0.3],
        prob=0.8,
        padding_mode="zeros",
        shear_range=[0.0, 0.01, 0.01],
    ),
    RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)),
    RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5),
    RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3),
    RandGaussianSmoothd(
        keys=source_channel,
        sigma_x=(0.25, 0.75),
        sigma_y=(0.25, 0.75),
        sigma_z=(0.0, 0.0),
        prob=0.5,
    ),
]

normalizations = [
    NormalizeSampled(
        keys=source_channel + target_channel,
        level="fov_statistics",
        subtrahend="mean",
        divisor="std",
    )
]

data_module.augmentations = augmentations

# Setup the data_module to fit. HINT: data_module.setup()
data_module.setup("fit")

# get the new data loader with augmentation turned on
augmented_train_dataloader = data_module.train_dataloader()

# Draw batches and write to tensorboard
writer = SummaryWriter(log_dir=f"{log_dir}/view_batch")
augmented_batch = next(iter(augmented_train_dataloader))
log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some")
writer.close()

<div class="alert alert-warning">
<h3> Question for Task 1.3 </h3>
1. Look at your tensorboard. Can you tell that the augmentations were applied to the sample batch? Compare the batch with and without augmentations. <br>
2. Are these augmentations good enough? What else would you add?
</div>

Visualize directly on Jupyter

In [None]:
log_batch_jupyter(augmented_batch)

## Train a 2D U-Net model to predict nuclei and membrane from phase.

## Constructing a 2D UNeXt2 using VisCy

<div class="alert alert-info">

### Task 1.4
- Run the next cell to instantiate the `UNeXt2_2D` model
  - Configure the network for the phase (source) to fluorescence cell nuclei and membrane (targets) regression task.
  - Call the VSUNet with the `"UNeXt2_2D"` architecture.
- Run the next cells to instantiate data module and trainer.
  - Add the source channel name and the target channel names
- Start the training <br>

<b> Note </b> <br>
See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) to learn more about the configuration.
</div>

In [None]:

# Here we are creating a 2D UNet.
GPU_ID = 0

BATCH_SIZE = 12
YX_PATCH_SIZE = (256, 256)

# Dictionary that specifies key parameters of the model.
# #######################
# ##### SOLUTION ########
# #######################
phase2fluor_config = dict(
    in_channels=1,
    out_channels=2,
    encoder_blocks=[3, 3, 9, 3],
    dims=[96, 192, 384, 768],
    decoder_conv_blocks=2,
    stem_kernel_size=(1, 2, 2),
    in_stack_depth=1,
    pretraining=False,
)

phase2fluor_model = VSUNet(
    architecture="UNeXt2_2D",  # 2D UNeXt2 architecture
    model_config=phase2fluor_config.copy(),
    loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5),
    schedule="WarmupCosine",
    lr=2e-5,
    log_batches_per_epoch=5,  # Number of batches per epoch from which samples are logged to tensorboard.
    freeze_encoder=False,
)

### Instantiate data module and trainer, test that we are setup to launch training.

In [None]:
# Selecting the source and target channel names from the dataset.
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]
# Setup the data module.
phase2fluor_2D_data = HCSDataModule(
    data_path,
    architecture="UNeXt2_2D",
    source_channel=source_channel,
    target_channel=target_channel,
    z_window_size=1,
    split_ratio=0.8,
    batch_size=BATCH_SIZE,
    num_workers=8,
    yx_patch_size=YX_PATCH_SIZE,
    augmentations=augmentations,
    normalizations=normalizations,
)
phase2fluor_2D_data.setup("fit")
# fast_dev_run runs a single batch of data through the model to check for errors.
trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True)

# trainer class takes the model and the data module as inputs.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)

## View model graph.

PyTorch uses dynamic graphs under the hood.
The graphs are constructed on the fly.
This is in contrast to static-graph frameworks such as TensorFlow 1.x,
where the graph is constructed before the training loop and remains static.
In other words, the graph of the network can change with every forward pass.
Therefore, we need to supply an input tensor to construct the graph.
The input tensor can be a random tensor of the correct shape and type.
We can also supply a real image from the dataset.
The latter is more useful for debugging.

<div class="alert alert-info">

### Task 1.5
Run the next cell to generate a graph representation of the model architecture.
</div>

In [None]:
# visualize graph of phase2fluor model as image.
model_graph_phase2fluor = torchview.draw_graph(
    phase2fluor_model,
    phase2fluor_2D_data.train_dataset[0]["source"][0].unsqueeze(dim=0),
    roll=True,
    depth=3,  # adjust depth to zoom in.
    device="cpu",
    # expand_nested=True,
)
# Print the image of the model.
model_graph_phase2fluor.visual_graph

<div class="alert alert-warning">

### Question:
Can you recognize the UNet structure and skip connections in this graph visualization?
</div>

<div class="alert alert-info">

<h3> Task 1.6 </h3>
Start training by running the following cell. Check the new logs on the tensorboard.
</div>

In [None]:
# Check if GPU is available
# You can check by typing `nvidia-smi`
GPU_ID = 0

n_samples = len(phase2fluor_2D_data.train_dataset)
steps_per_epoch = n_samples // BATCH_SIZE  # steps per epoch.
n_epochs = 25  # Set this to 25-30 or the number of epochs you want to train for.

trainer = VSTrainer(
    accelerator="gpu",
    devices=[GPU_ID],
    max_epochs=n_epochs,
    log_every_n_steps=steps_per_epoch // 2,
    # log losses and image samples 2 times per epoch.
    logger=TensorBoardLogger(
        save_dir=log_dir,
        # lightning trainer transparently saves logs and model checkpoints in this directory.
        name="phase2fluor",
        log_graph=True,
    ),
)
# Launch training and check that loss and images are being logged on tensorboard.
trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data)

<div class="alert alert-success">

<h2> Checkpoint 1 </h2>

While your model is training, let's think about the following questions:<br>
<ul>
<li>What is the information content of each channel in the dataset?</li>
<li>How would you use image translation models?</li>
<li>What can you try to improve the performance of each model?</li>
</ul>

Now the training has started,
we can come back after a while and evaluate the performance!

</div>

## Part 2: Assess your trained model

Now we will look at some performance metrics of the model trained in Part 1.
We typically evaluate model performance on held-out test data.
We will use the following metrics to evaluate the accuracy of regression of the model:

- [Pearson Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient).
- [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM).

You should also look at the validation samples on tensorboard
(hint: the experimental data in nuclei channel is imperfect.)
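
As a starting point for thinking about the first metric, here is a small pure-Python sketch of the Pearson correlation coefficient on toy intensity values. The function `pearson` and the toy lists are illustrative assumptions; the notebook itself computes this metric with `np.corrcoef` on the flattened images.

```python
from math import sqrt


def pearson(x, y):
    """Pearson correlation: covariance divided by the product of the spreads."""
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    spread_x = sqrt(sum((a - mean_x) ** 2 for a in x))
    spread_y = sqrt(sum((b - mean_y) ** 2 for b in y))
    return cov / (spread_x * spread_y)


# Toy "target" intensities and two toy "predictions".
target = [0.0, 1.0, 4.0, 9.0, 16.0]
rescaled = [2.0 * t + 5.0 for t in target]  # brighter and offset, same pattern
inverted = [-t for t in target]  # contrast-inverted pattern

print("target vs rescaled:", pearson(target, rescaled))
print("target vs inverted:", pearson(target, inverted))
```

The rescaled prediction correlates almost perfectly with the target even though its absolute intensities differ, while the inverted one is anti-correlated. This suggests that Pearson correlation rewards matching the spatial pattern of the fluorescence signal rather than its absolute brightness.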

<div class="alert alert-info">

<h3> Task 2.1 Define metrics </h3>

For each of the above metrics, write a brief definition of what they are and what they mean
for this image translation task. Use your favorite search engine and/or resources.

</div>

```
#######################
##### Todo ############
#######################

```

- Pearson Correlation:

- Structural similarity:

### Let's compute metrics directly and plot below.

<div class="alert alert-danger">
If you weren't able to train or training didn't complete, please run the following lines to load the latest checkpoint. <br>

```python
phase2fluor_model_ckpt = natsorted(glob(
   str(training_top_dir / "06_image_translation/part1/logs/phase2fluor/version*/checkpoints/*.ckpt")
))[-1]
```
<br>
NOTE: If your model didn't go past epoch 5, you lost your checkpoint, or you didn't train anything,
run the following:

```python
phase2fluor_model_ckpt = natsorted(glob(
 str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
))[-1]
```
</div>

In [None]:
# Setup the test data module.
test_data_path = top_dir / "06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr"
source_channel = ["Phase3D"]
target_channel = ["Nucl", "Mem"]

test_data = HCSDataModule(
    test_data_path,
    source_channel=source_channel,
    target_channel=target_channel,
    z_window_size=1,
    batch_size=1,
    num_workers=8,
    architecture="UNeXt2",
)
test_data.setup("test")

test_metrics = pd.DataFrame(
    columns=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"]
)

In [None]:
# Compute metrics directly and plot here.
def normalize_fov(input:ArrayLike):
    "Normalizing the fov with zero mean and unit variance"
    mean = np.mean(input)
    std = np.std(input)
    return (input - mean) / std

for i, sample in enumerate(tqdm(test_data.test_dataloader(), desc="Computing metrics per sample")):
    phase_image = sample["source"].to(phase2fluor_model.device)
    with torch.inference_mode():  # turn off gradient computation.
        predicted_image = phase2fluor_model(phase_image)

    target_image = (
        sample["target"].cpu().numpy().squeeze(0)
    )  # Squeezing batch dimension.
    predicted_image = predicted_image.cpu().numpy().squeeze(0)
    phase_image = phase_image.cpu().numpy().squeeze(0)
    target_mem = normalize_fov(target_image[1, 0, :, :])
    target_nuc = normalize_fov(target_image[0, 0, :, :])
    # slicing channel dimension, squeezing z-dimension.
    predicted_mem = normalize_fov(predicted_image[1, :, :, :].squeeze(0))
    predicted_nuc = normalize_fov(predicted_image[0, :, :, :].squeeze(0))

    # Compute SSIM and pearson correlation.
    ssim_nuc = metrics.structural_similarity(target_nuc, predicted_nuc, data_range=1)
    ssim_mem = metrics.structural_similarity(target_mem, predicted_mem, data_range=1)
    pearson_nuc = np.corrcoef(target_nuc.flatten(), predicted_nuc.flatten())[0, 1]
    pearson_mem = np.corrcoef(target_mem.flatten(), predicted_mem.flatten())[0, 1]

    test_metrics.loc[i] = {
        "pearson_nuc": pearson_nuc,
        "SSIM_nuc": ssim_nuc,
        "pearson_mem": pearson_mem,
        "SSIM_mem": ssim_mem,
    }

# Plot the following metrics
test_metrics.boxplot(
    column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"],
    rot=30,
)

In [None]:
# Adjust the image to the 0.5-99.5 percentile range.
def process_image(image):
    p_low, p_high = np.percentile(image, (0.5, 99.5))
    return np.clip(image, p_low, p_high)

# Plot the predicted image vs target image.
channel_titles = ["Phase", "Target Nuclei", "Target Membrane", "Predicted Nuclei", "Predicted Membrane"]
fig, axes = plt.subplots(5, 1, figsize=(20, 20))

# Plot the source, target, and predicted images for the first test sample.
for i, sample in enumerate(test_data.test_dataloader()):
    # Plot the phase image
    phase_image = sample["source"]
    channel_image = phase_image[0, 0, 0]
    p_low, p_high = np.percentile(channel_image, (0.5, 99.5))
    channel_image = np.clip(channel_image, p_low, p_high)
    axes[0].imshow(channel_image, cmap="gray")
    axes[0].axis("off")
    axes[0].set_title(channel_titles[0])

    with torch.inference_mode():  # turn off gradient computation.
        predicted_image = (
            phase2fluor_model(phase_image.to(phase2fluor_model.device))
            .cpu()
            .numpy()
            .squeeze(0)
        )

    target_image = sample["target"].cpu().numpy().squeeze(0)
    phase_raw = process_image(phase_image[0, 0, 0])
    predicted_nuclei = process_image(predicted_image[0,0])
    predicted_membrane = process_image(predicted_image[1,0])
    target_nuclei = process_image(target_image[0,0])
    target_membrane = process_image(target_image[1,0])
    # Concatenate all images side by side
    combined_image = np.concatenate(
        (phase_raw, predicted_nuclei, predicted_membrane, target_nuclei, target_membrane),
        axis=1
    )

    # Plot the target nuclei, target membrane, predicted nuclei, and predicted membrane.
    panels = [target_nuclei, target_membrane, predicted_nuclei, predicted_membrane]
    for ax, panel, title in zip(axes[1:], panels, channel_titles[1:]):
        ax.imshow(panel, cmap="gray")
        ax.axis("off")
        ax.set_title(title)

    plt.tight_layout()
    plt.show()
    break

<div class="alert alert-info">

<h3> Task 2.2 Loading the pretrained model VSCyto2D </h3>
Here we will compare your model with the VSCyto2D pretrained model by computing the pixel-based metrics and segmentation-based metrics.

<ul>
<li>When you ran <code>setup.sh</code>, you also downloaded the models to <code>/06_image_translation/part1/pretrained_models/VSCyto2D/*.ckpt</code></li>
<li>Load the <b>VSCyto2D</b> model checkpoint and the configuration file</li>
<li>Compute the pixel-based metrics and segmentation-based metrics between the model you trained and the pretrained model</li>
</ul>
<br>

</div>

In [None]:
# #######################
# ##### SOLUTION ########
# #######################

pretrained_model_ckpt = (
    top_dir / "06_image_translation/part1/pretrained_models/VSCyto2D/epoch=399-step=23200.ckpt"
)

phase2fluor_config = dict(
    in_channels=1,
    out_channels=2,
    encoder_blocks=[3, 3, 9, 3],
    dims=[96, 192, 384, 768],
    decoder_conv_blocks=2,
    stem_kernel_size=(1, 2, 2),
    in_stack_depth=1,
    pretraining=False,
)
# Load the model checkpoint
pretrained_phase2fluor = VSUNet.load_from_checkpoint(
    pretrained_model_ckpt,
    architecture="UNeXt2_2D",
    model_config = phase2fluor_config,
)
pretrained_phase2fluor.eval()

### Re-load your trained model
# NOTE: this assumes the latest checkpoint comes from your most recent training run
phase2fluor_model_ckpt = natsorted(glob(
    str(training_top_dir / "06_image_translation/part1/logs/phase2fluor/version*/checkpoints/*.ckpt")
))[-1]

# NOTE: If your model didn't go past epoch 5, you lost your checkpoint, or you didn't train anything,
# uncomment the next lines
#phase2fluor_model_ckpt = natsorted(glob(
#  str(top_dir/"06_image_translation/backup/phase2fluor/version_3/checkpoints/*.ckpt")
#))[-1]


phase2fluor_config = dict(
    in_channels=1,
    out_channels=2,
    encoder_blocks=[3, 3, 9, 3],
    dims=[96, 192, 384, 768],
    decoder_conv_blocks=2,
    stem_kernel_size=(1, 2, 2),
    in_stack_depth=1,
    pretraining=False,
)
# Load the model checkpoint
phase2fluor_model = VSUNet.load_from_checkpoint(
    phase2fluor_model_ckpt,
    architecture="UNeXt2_2D",
    model_config = phase2fluor_config,
    accelerator='gpu'
)
phase2fluor_model.eval()

<div class="alert alert-warning">
<h3> Question </h3> 
1. Can we evaluate a model's performance based on its segmentations?<br>
2. Look up IoU or Jaccard index, Dice coefficient, and AP metrics. LINK:https://metrics-reloaded.dkfz.de/metric-library <br>
We will compare the performance of your trained model with that of a pre-trained model using pixel-based metrics as above and
segmentation-based metrics (mAP@0.5, Dice, accuracy, and Jaccard index). <br>
</div>


- <b> IoU (Intersection over Union): </b> Also referred to as the Jaccard index, it quantifies the overlap between the target and predicted masks.
It is calculated as the number of pixels in the intersection of the target and predicted masks divided by the number of pixels in their union. <br>
- <b> Dice Coefficient:</b> A metric that measures the similarity between two sets.<br>
It is calculated as twice the number of pixels in the intersection of the target and predicted masks divided by the sum of the pixel counts of the two masks.<br>
- <b> mAP (mean Average Precision):</b> A metric used to evaluate object detection and instance segmentation models.
It is calculated as the average precision averaged across all classes and measures how accurately the model localizes objects. mAP@0.5 counts a predicted object as correct when its IoU with a target object is at least 0.5.
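
To make the IoU and Dice definitions concrete, here is a minimal sketch on two toy binary masks. The helper `iou_and_dice` and the toy arrays are hypothetical and are not part of the notebook, which uses library functions for these metrics.

```python
import numpy as np


def iou_and_dice(target_mask: np.ndarray, pred_mask: np.ndarray) -> tuple[float, float]:
    """Return (IoU, Dice) for two binary masks of the same shape."""
    target = target_mask.astype(bool)
    pred = pred_mask.astype(bool)
    intersection = np.logical_and(target, pred).sum()
    union = np.logical_or(target, pred).sum()
    total = target.sum() + pred.sum()
    # Two empty masks agree perfectly; this also avoids dividing by zero.
    if total == 0:
        return 1.0, 1.0
    return float(intersection / union), float(2 * intersection / total)


# Toy 4x4 masks (made-up data): the target covers 4 pixels, the prediction
# covers 6 pixels, and the two masks share 4 pixels.
toy_target = np.zeros((4, 4), dtype=bool)
toy_target[1:3, 1:3] = True
toy_pred = np.zeros((4, 4), dtype=bool)
toy_pred[1:3, 1:4] = True

toy_iou, toy_dice = iou_and_dice(toy_target, toy_pred)
print(f"IoU = {toy_iou:.3f}, Dice = {toy_dice:.3f}")  # IoU = 0.667, Dice = 0.800
```

For the same pair of masks the Dice coefficient is never smaller than the IoU, because Dice = 2 IoU / (1 + IoU). Keep this in mind when comparing the two columns of the metrics table.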


### Let's compute the metrics for the test dataset
Before you run the following code, make sure the pretrained model is loaded and the test data is ready.

The following code will compute:
- the pixel-based metrics (Pearson correlation, SSIM)
- the segmentation-based metrics (mAP@0.5, Dice, accuracy, Jaccard index)

#### Note:
- The segmentation-based metrics are computed using the cellpose stock `nuclei` model.
- The metrics will be stored in the `test_pixel_metrics` and `test_segmentation_metrics` dataframes.
- The segmentations will be stored in the `segmentation_store` zarr file.
- Analyze the code while it runs.
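
The evaluation loop in this notebook rescales targets and predictions to [0, 1] before computing the pixel-based metrics. The sketch below uses synthetic data to show that the Pearson correlation is unchanged by this rescaling. SSIM, in contrast, depends on the intensity range, which is why it is computed on rescaled images with `data_range=1`. The helpers `min_max_scale_demo` and `pearson` and the toy arrays are hypothetical stand-ins, not the notebook's own code.

```python
import numpy as np


def min_max_scale_demo(image: np.ndarray) -> np.ndarray:
    """Rescale an image to the range [0, 1]."""
    return (image - image.min()) / (image.max() - image.min())


def pearson(a: np.ndarray, b: np.ndarray) -> float:
    """Pearson correlation between two images, computed over all pixels."""
    return float(np.corrcoef(a.ravel(), b.ravel())[0, 1])


rng = np.random.default_rng(0)
# Synthetic stand-ins: a "target" in arbitrary camera units and a noisy
# "prediction" on a very different intensity scale.
toy_target = rng.uniform(100.0, 4000.0, size=(64, 64))
toy_prediction = 0.001 * toy_target + rng.normal(0.0, 0.5, size=(64, 64))

r_raw = pearson(toy_target, toy_prediction)
r_scaled = pearson(min_max_scale_demo(toy_target), min_max_scale_demo(toy_prediction))
print(f"Pearson (raw): {r_raw:.4f}, Pearson (min-max scaled): {r_scaled:.4f}")
```

The two printed values agree up to floating-point error because min-max scaling is a positive linear transform of the intensities.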

In [None]:
# Define the function to compute the cellpose segmentation
def cellpose_segmentation(
    prediction: ArrayLike, target: ArrayLike
) -> Tuple[torch.ShortTensor, torch.ShortTensor]:
    """Segment nuclei in the prediction and the target with the cellpose `nuclei` model."""
    # NOTE: these parameters are hardcoded for this notebook and the A549 dataset
    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    cp_nuc_kwargs = {
        "diameter": 65,
        "channels": [0, 0],
        "cellprob_threshold": 0.0,
    }
    # Only request the GPU when one is available
    cellpose_model = models.CellposeModel(
        gpu=use_gpu, model_type="nuclei", device=device
    )
    pred_label, _, _ = cellpose_model.eval(prediction, **cp_nuc_kwargs)
    target_label, _, _ = cellpose_model.eval(target, **cp_nuc_kwargs)

    # Convert the label images to 16-bit integer tensors for the metric functions
    pred_label = pred_label.astype(np.int32)
    target_label = target_label.astype(np.int32)
    pred_label = torch.ShortTensor(pred_label)
    target_label = torch.ShortTensor(target_label)

    return (pred_label, target_label)

In [None]:
# Setting the paths for the test data and the output segmentation
test_data_path = top_dir / "06_image_translation/part1/test/a549_hoechst_cellmask_test.zarr"
output_segmentation_path= training_top_dir /"06_image_translation/part1/pretrained_model_segmentations.zarr"

# Creating the dataframes to store the pixel and segmentation metrics
test_pixel_metrics = pd.DataFrame(
    columns=["model", "fov","pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"]
)
test_segmentation_metrics= pd.DataFrame(
    columns=["model", "fov","masks_per_fov","accuracy","dice","jaccard","mAP","mAP_50","mAP_75","mAR_100"]
)
# Opening the test dataset
test_dataset = open_ome_zarr(test_data_path)

# Creating an output store for the predictions and segmentations
segmentation_store = open_ome_zarr(output_segmentation_path,channel_names=['nuc_pred','mem_pred','nuc_labels'],mode='w',layout='hcs')

# Looking at the test dataset
print('Test dataset:')
test_dataset.print_tree()
channel_names = test_dataset.channel_names
print(f'Channel names: {channel_names}')

# Finding the channel indices for the corresponding channel names
phase_cidx = channel_names.index("Phase3D")
nuc_cidx = channel_names.index("Nucl")
mem_cidx =  channel_names.index("Mem")
nuc_label_cidx =  channel_names.index("nuclei_segmentation")

In [None]:
def min_max_scale(image:ArrayLike)->ArrayLike:
    "Normalizing the image using min-max scaling"
    min_val = image.min()
    max_val = image.max()
    return (image - min_val) / (max_val - min_val)

# Collect the test dataset positions (FOVs) to iterate over
positions = list(test_dataset.positions())
total_positions = len(positions)

# Initializing the progress bar with the total number of positions
with tqdm(total=total_positions, desc="Processing FOVs") as pbar:
    # Iterating through the test dataset positions
    for fov, pos in positions:
        T,C,Z,Y,X = pos.data.shape
        Z_slice = slice(Z//2,Z//2+1)
        # Getting the arrays and the center slices
        phase_image = pos.data[:,phase_cidx:phase_cidx+1,Z_slice]
        target_nucleus =  pos.data[0,nuc_cidx:nuc_cidx+1,Z_slice]
        target_membrane =  pos.data[0,mem_cidx:mem_cidx+1,Z_slice]
        target_nuc_label = pos.data[0,nuc_label_cidx:nuc_label_cidx+1,Z_slice]

        #normalize the phase
        phase_image = normalize_fov(phase_image)
        
        # Running the prediction for both models
        phase_image = torch.from_numpy(phase_image).type(torch.float32)
        phase_image = phase_image.to(phase2fluor_model.device)
        with torch.inference_mode():  # turn off gradient computation.
            predicted_image_phase2fluor = phase2fluor_model(phase_image)
            predicted_image_pretrained = pretrained_phase2fluor(phase_image)

        # Move the predictions to the CPU and drop the batch dimension
        predicted_image_phase2fluor = predicted_image_phase2fluor.cpu().numpy().squeeze(0)
        predicted_image_pretrained = predicted_image_pretrained.cpu().numpy().squeeze(0)
        phase_image = phase_image.cpu().numpy().squeeze(0)

        target_mem = min_max_scale(target_membrane[0,0])
        target_nuc = min_max_scale(target_nucleus[0,0])
    
        # Normalize the predictions of both models using min-max scaling
        predicted_mem_phase2fluor = min_max_scale(
            predicted_image_phase2fluor[1, :, :, :].squeeze(0)
        )
        predicted_nuc_phase2fluor = min_max_scale(
            predicted_image_phase2fluor[0, :, :, :].squeeze(0)
        )

        predicted_mem_pretrained = min_max_scale(
            predicted_image_pretrained[1, :, :, :].squeeze(0)
        )
        predicted_nuc_pretrained = min_max_scale(
            predicted_image_pretrained[0, :, :, :].squeeze(0)
        )

        #######  Pixel-based Metrics ############
        # Compute SSIM and Pearson correlation for phase2fluor_model
        print('Computing Pixel Metrics')
        ssim_nuc_phase2fluor = metrics.structural_similarity(
            target_nuc, predicted_nuc_phase2fluor, data_range=1
        )
        ssim_mem_phase2fluor = metrics.structural_similarity(
            target_mem, predicted_mem_phase2fluor, data_range=1
        )
        pearson_nuc_phase2fluor = np.corrcoef(
            target_nuc.flatten(), predicted_nuc_phase2fluor.flatten()
        )[0, 1]
        pearson_mem_phase2fluor = np.corrcoef(
            target_mem.flatten(), predicted_mem_phase2fluor.flatten()
        )[0, 1]

        test_pixel_metrics.loc[len(test_pixel_metrics)] = {
            "model": "phase2fluor",
            "fov":fov,
            "pearson_nuc": pearson_nuc_phase2fluor,
            "SSIM_nuc": ssim_nuc_phase2fluor,
            "pearson_mem": pearson_mem_phase2fluor,
            "SSIM_mem": ssim_mem_phase2fluor,
        }
        # Compute SSIM and Pearson correlation for pretrained_model
        ssim_nuc_pretrained = metrics.structural_similarity(
            target_nuc, predicted_nuc_pretrained, data_range=1
        )
        ssim_mem_pretrained = metrics.structural_similarity(
            target_mem, predicted_mem_pretrained, data_range=1
        )
        pearson_nuc_pretrained = np.corrcoef(
            target_nuc.flatten(), predicted_nuc_pretrained.flatten()
        )[0, 1]
        pearson_mem_pretrained = np.corrcoef(
            target_mem.flatten(), predicted_mem_pretrained.flatten()
        )[0, 1]

        test_pixel_metrics.loc[len(test_pixel_metrics)] = {
            "model": "pretrained_phase2fluor",
            "fov":fov,
            "pearson_nuc": pearson_nuc_pretrained,
            "SSIM_nuc": ssim_nuc_pretrained,
            "pearson_mem": pearson_mem_pretrained,
            "SSIM_mem": ssim_mem_pretrained,
        }

        ###### Segmentation based metrics #########
        # Segment the predicted and target nuclei images with cellpose
        print('Computing Segmentation Metrics')
        pred_label,target_label= cellpose_segmentation(predicted_nuc_phase2fluor,target_nucleus)
        # Binary labels
        pred_label_binary = pred_label > 0
        target_label_binary = target_label > 0

        # Use Coco metrics to get mean average precision
        coco_metrics = mean_average_precision(pred_label, target_label)
        # Count the unique labels in the prediction (includes the background label 0 when present)
        num_masks_fov = len(np.unique(pred_label))

        test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {
            "model": "phase2fluor",
            "fov":fov,
            "masks_per_fov": num_masks_fov,
            "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(),
            "dice":  dice(pred_label_binary, target_label_binary).item(),
            "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(),
            "mAP":coco_metrics["map"].item(),
            "mAP_50":coco_metrics["map_50"].item(),
            "mAP_75":coco_metrics["map_75"].item(),
            "mAR_100":coco_metrics["mar_100"].item()
        }

        pred_label,target_label= cellpose_segmentation(predicted_nuc_pretrained,target_nucleus)
        
        # Binary labels
        pred_label_binary = pred_label > 0
        target_label_binary = target_label > 0

        # Use Coco metrics to get mean average precision
        coco_metrics = mean_average_precision(pred_label, target_label)
        # Count the unique labels in the prediction (includes the background label 0 when present)
        num_masks_fov = len(np.unique(pred_label))

        test_segmentation_metrics.loc[len(test_segmentation_metrics)] = {
            "model": "phase2fluor_pretrained",
            "fov":fov,
            "masks_per_fov": num_masks_fov,
            "accuracy": accuracy(pred_label_binary, target_label_binary, task="binary").item(),
            "dice":  dice(pred_label_binary, target_label_binary).item(),
            "jaccard": jaccard_index(pred_label_binary, target_label_binary, task="binary").item(),
            "mAP":coco_metrics["map"].item(),
            "mAP_50":coco_metrics["map_50"].item(),
            "mAP_75":coco_metrics["map_75"].item(),
            "mAR_100":coco_metrics["mar_100"].item()
        }
        
        #Save the predictions and segmentations
        position = segmentation_store.create_position(*Path(fov).parts[-3:])
        output_array = np.zeros((T,3,1,Y,X),dtype=np.float32)
        output_array[0,0,0]=predicted_nuc_pretrained
        output_array[0,1,0]=predicted_mem_pretrained
        output_array[0,2,0]=np.array(pred_label)
        position.create_image("0",output_array)

        # Update the progress bar
        pbar.update(1)
    
# Close the OME-Zarr files
test_dataset.close()
segmentation_store.close()

In [None]:
# Save the test metrics dataframes to CSV files
pixel_metrics_path = training_top_dir/"06_image_translation/part1/VS_metrics_pixel_part_1.csv"
segmentation_metrics_path = training_top_dir/"06_image_translation/part1/VS_metrics_segments_part_1.csv"
test_pixel_metrics.to_csv(pixel_metrics_path)
test_segmentation_metrics.to_csv(segmentation_metrics_path)

<div class="alert alert-info">

<h3> Task 2.3 Compare the model's metrics </h3>
In the previous section, we computed the pixel-based metrics and segmentation-based metrics.
Now we will compare the performance of the model you trained with the pretrained model using boxplots.

After you plot the metrics answer the following:
<ul>
<li>What do these metrics tell us about the performance of the model?</li>
<li>How do you interpret the differences in the metrics between the models?</li>
<li>How does your model compare to the pretrained model? How can you improve it?</li>
</ul>
</div>
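
Boxplots show the spread of each metric across FOVs. A per-model summary table can complement them when answering the questions above. The sketch below is one way to build such a table with pandas `groupby`. The dataframe `toy_metrics` and its values are made up for illustration and are not results from this notebook.

```python
import pandas as pd

# Made-up values for illustration only; in the notebook the real values
# live in `test_pixel_metrics` and `test_segmentation_metrics`.
toy_metrics = pd.DataFrame(
    {
        "model": [
            "phase2fluor",
            "phase2fluor",
            "pretrained_phase2fluor",
            "pretrained_phase2fluor",
        ],
        "pearson_nuc": [0.70, 0.80, 0.85, 0.95],
        "SSIM_nuc": [0.50, 0.60, 0.65, 0.75],
    }
)

# Mean and median of every metric column, with one row per model.
summary = toy_metrics.groupby("model").agg(["mean", "median"])
print(summary.round(3))
```

To apply the same idea to the notebook's dataframes, select only the metric columns first, since `fov` is not numeric. Depending on how the columns were created, they may also need converting with `astype(float)`.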

In [None]:
# Boxplots of the pixel-based metrics, grouped by model
test_pixel_metrics.boxplot(
    by="model",
    column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"],
    rot=30,
    figsize=(8, 8),
)
plt.suptitle("Model Pixel Metrics")
plt.show()
# Boxplots of the segmentation-based metrics, grouped by model
test_segmentation_metrics.boxplot(
    by="model",
    column=["jaccard", "accuracy", "mAP_75","mAP_50"],
    rot=30,
    figsize=(8, 8),
)
plt.suptitle("Model Segmentation Metrics")
plt.show()

<div class="alert alert-info">

### Plotting the predictions and segmentations
Here we will plot the predictions and segmentations side by side for the pretrained and trained models.<br>
- How do your model, the pretrained model, and the ground truth compare?<br>
- How do the segmentations compare? <br>
Feel free to modify the crop size and the Y, X slicing to view different areas of the FOV.
</div>
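
A minimal sketch of such a side-by-side plot follows. It assumes you already have the images you want to compare as 2D NumPy arrays. The helpers `crop_yx` and `plot_side_by_side`, their parameters, and the toy image are hypothetical and are not part of VisCy or iohub.

```python
import numpy as np


def crop_yx(image: np.ndarray, y_start: int, x_start: int, crop_size: int) -> np.ndarray:
    """Return a square crop taken from the last two (Y, X) axes of an image."""
    return image[..., y_start : y_start + crop_size, x_start : x_start + crop_size]


def plot_side_by_side(panels: dict, y_start: int = 0, x_start: int = 0, crop_size: int = 256):
    """Show the same crop of each 2D image in `panels` (title -> array) in one row."""
    # Imported here so that the cropping helper can be used without matplotlib.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(1, len(panels), figsize=(4 * len(panels), 4), squeeze=False)
    for ax, (title, image) in zip(axes[0], panels.items()):
        ax.imshow(crop_yx(image, y_start, x_start, crop_size), cmap="gray")
        ax.set_title(title)
        ax.axis("off")
    fig.tight_layout()
    return fig


# Synthetic stand-in for a (Y, X) image; replace it with your own arrays.
toy_image = np.arange(512 * 512, dtype=np.float32).reshape(512, 512)
toy_crop = crop_yx(toy_image, y_start=100, x_start=200, crop_size=64)
print(toy_crop.shape)  # (64, 64)
```

In the notebook, the `panels` dictionary could hold, for example, the target nucleus image, the nucleus prediction of each model, and the corresponding label images. Changing `y_start`, `x_start`, and `crop_size` moves the view to a different region of the FOV.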

<div class="alert alert-success">

<h2>
🎉 The end of the notebook 🎉
Continue to Part 2: Image translation with generative models.
</h2>

Congratulations! You have trained an image translation model and evaluated its performance.
</div>