# Normalization Method Evaluation

This notebook contains code for the quantitative evaluation of the normalization method. The quantitative evaluation consists of computing the average signal-to-noise ratios (SNR) before and after normalization of each modality. 

Before running this notebook, the main pipeline must have been run on every dataset that is to be evaluated, so that normalized samples are available for comparison with the raw ones. 

### Imports

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch

from pathlib import Path
from dotenv import load_dotenv
from tqdm import tqdm

from sklearn.preprocessing import minmax_scale
from skimage.filters import threshold_otsu
from cellpose import models

from utils.data import NormalizedDataset

### Load dataset

In [4]:
def load_dataset_paths():
    """ Collect the dataset paths declared through environment variables.

    The .env file is loaded first; afterwards every environment variable whose name
    starts with "DATASET_PATH_" contributes one path, in the iteration order of
    os.environ.

    Returns:
        list - One Path per matching variable, with surrounding single quotes removed
               from the value.
    """
    load_dotenv()

    # One entry per DATASET_PATH_* variable; values may be wrapped in single quotes
    return [
        Path(raw_value.strip("'"))
        for name, raw_value in os.environ.items()
        if name.startswith("DATASET_PATH_")
    ]

# Collect the dataset paths from the DATASET_PATH_* variables (loaded from the .env file)
dataset_paths = load_dataset_paths()

# Alternatively, skip the .env file and write the correct paths manually by uncommenting
# and editing the following line:
# dataset_paths = [Path('C:/.../toy1/'), Path('C:/.../toy2/')]

# Create the dataset; the evaluation below reads the raw image from item["sample"] and
# the normalized image from item["sample_norm"]
dataset = NormalizedDataset(dataset_paths)

### Function to Compute SNR

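The signal is taken to be the pixels inside the cell mask produced by CellPose segmentation. The noise is taken to be the pixels inside a background mask obtained by low-pass filtering in the Fourier domain followed by Otsu thresholding. The SNR of a channel is the mean of its signal pixels divided by the standard deviation of its noise pixels. 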

In [26]:
def _rescale_to_unit_range(image: np.ndarray) -> np.ndarray:
    """ Min-max scales an image to [0, 1] using its global minimum and maximum. """
    # minmax_scale works per feature: on a 2D array it would rescale every image column
    # independently. Flattening first makes it use the extrema of the whole image.
    return minmax_scale(image.ravel()).reshape(image.shape)


def _lowpass_background_mask(image: np.ndarray, cutoff: float) -> np.ndarray:
    """ Returns a boolean background mask from a low-pass filtered, Otsu-thresholded image. """
    rows, cols = image.shape
    center_row, center_col = rows // 2, cols // 2

    # Circular pass band of radius `cutoff` around the centre of the shifted spectrum
    y, x = np.ogrid[:rows, :cols]
    pass_band = (x - center_col)**2 + (y - center_row)**2 <= cutoff**2

    spectrum = np.fft.fftshift(np.fft.fft2(image))
    filtered_image = np.fft.ifft2(np.fft.ifftshift(spectrum * pass_band)).real

    # Pixels above the Otsu threshold of the smoothed image are treated as background
    return filtered_image > threshold_otsu(filtered_image)


def _channel_statistics(
    channel: np.ndarray, 
    signal_mask: np.ndarray, 
    noise_mask: np.ndarray
) -> tuple[float, float, float, float, float]:
    """ Returns (SNR, signal mean, signal std, noise mean, noise std) of one channel. """
    signal_pixels = channel[signal_mask]
    noise_pixels = channel[noise_mask]

    # Signal statistics
    signal_mean = signal_pixels.mean()
    signal_std = signal_pixels.std()

    # Noise statistics
    noise_mean = noise_pixels.mean()
    noise_std = noise_pixels.std()

    # SNR: mean signal level relative to the spread of the background
    snr = signal_mean / noise_std

    return snr, signal_mean, signal_std, noise_mean, noise_std


def _show_mask(image: np.ndarray, mask: np.ndarray, isolated_title: str, figure_title: str) -> None:
    """ Plots an image next to the same image with everything outside the mask set to zero. """
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (image, "Combined Image"),
        (np.where(mask, image, 0), isolated_title),
    )
    for ax, (panel, title) in zip(axes, panels):
        ax.imshow(panel, cmap="gray", vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis(False)
    plt.suptitle(figure_title, fontsize=24)
    plt.tight_layout()
    plt.show()


def compute_snr(
    sample: np.ndarray, 
    segmentation_model, 
    plot: bool,
    *,
    cutoff: float = 500,
    diameter: float = 35,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """ Computes signal-to-noise ratio of each modality and channel of a sample.

    Every channel is min-max scaled to [0, 1] (modality index 1 is additionally inverted).
    The mean over all modalities and channels gives a combined image from which two masks
    are derived: a cell mask (non-zero labels of the CellPose segmentation) and a background
    mask (low-pass filtering in the Fourier domain followed by Otsu thresholding). For each
    channel the SNR is the mean of the pixels inside the cell mask divided by the standard
    deviation of the pixels inside the background mask.

    The input array is not modified; all rescaling happens on a float copy.
    
    Args:
        sample: A numpy array with the sample data, format (M, C, H, W),
                M - number of modalities, C - number of channels, H - image height, W - image width.
        segmentation_model: A CellPose segmentation model; only the first value returned by
                its eval method (the label mask) is used.
        plot: A boolean to indicate whether to plot the masks or not. 
        cutoff: Radius, in frequency bins, of the circular low-pass filter used for the
                background mask.
        diameter: Cell diameter passed on to the segmentation model.
    
    Returns:
        tuple - Five arrays of shape (M, C) and dtype float32 containing the SNR, signal mean,
                signal standard deviation, noise mean and noise standard deviation of each
                modality and channel. Entries are NaN where the corresponding mask is empty.
    """

    # Work on a float copy so that the caller's array stays untouched and integer inputs
    # are not truncated when the rescaled values are written back.
    sample = np.array(sample, dtype=float)
    n_modalities, n_channels = sample.shape[:2]

    # Rescale the channels of the sample to ensure consistent scale. 
    for m in range(n_modalities):
        for c in range(n_channels):
            scaled = _rescale_to_unit_range(sample[m, c])
            # Invert scattering mode (modality index 1) for consistency of computed values
            sample[m, c] = 1 - scaled if m == 1 else scaled

    # Get cell mask from CellPose segmentation. Indexing the result instead of unpacking it
    # keeps this independent of how many extra values eval returns.
    combined_image = _rescale_to_unit_range(sample.mean(axis=(0, 1)))
    labels = segmentation_model.eval(combined_image, channels=[0, 0], diameter=diameter)[0]
    cell_mask = np.asarray(labels) != 0

    # Get background mask by low-pass filtering and thresholding
    background_mask = _lowpass_background_mask(combined_image, cutoff)

    # Plot signal and noise masks
    if plot:
        _show_mask(combined_image, background_mask, "Isolated Background", "Background Mask")
        _show_mask(combined_image, cell_mask, "Isolated Cells", "Cell Mask")

    # One (M, C) table per statistic, stacked along the first axis in return order:
    # SNR, signal mean, signal std, noise mean, noise std
    statistics = np.zeros((5, n_modalities, n_channels), dtype=np.float32)
    for m in range(n_modalities):
        for c in range(n_channels):
            statistics[:, m, c] = _channel_statistics(sample[m, c], cell_mask, background_mask)

    snr, signal_mean, signal_std, noise_mean, noise_std = statistics
    return snr, signal_mean, signal_std, noise_mean, noise_std

### Compute SNR and Display Results

In [None]:
N = len(dataset)

# Every result array is indexed as (sample, modality, channel)
RESULT_SHAPE = (N, 3, 13)

# Axes to average over when printing: (0, 2) averages over samples and channels and gives
# one value per modality; use 0 to average over samples only and get one value per
# modality and channel.
AVERAGE_AXES = (0, 2)

# Allocate space for results of the raw samples
snr_values = np.zeros(RESULT_SHAPE)
signal_means = np.zeros(RESULT_SHAPE)
signal_stds = np.zeros(RESULT_SHAPE)
noise_means = np.zeros(RESULT_SHAPE)
noise_stds = np.zeros(RESULT_SHAPE)

# Allocate space for results of the normalized samples
norm_snr_values = np.zeros(RESULT_SHAPE)
norm_signal_means = np.zeros(RESULT_SHAPE)
norm_signal_stds = np.zeros(RESULT_SHAPE)
norm_noise_means = np.zeros(RESULT_SHAPE)
norm_noise_stds = np.zeros(RESULT_SHAPE)

# Segmentation model
segmentation_model = models.Cellpose(model_type='cyto3', gpu=torch.cuda.is_available())

# Iterate through dataset and compute SNR values
for idx in tqdm(range(N)):
    # Fetch the item once: the raw and the normalized sample then come from the same
    # dataset access, and the data is not loaded twice per iteration.
    item = dataset[idx]

    (snr_values[idx], signal_means[idx], signal_stds[idx],
     noise_means[idx], noise_stds[idx]) = compute_snr(item["sample"], segmentation_model, plot=False)

    (norm_snr_values[idx], norm_signal_means[idx], norm_signal_stds[idx],
     norm_noise_means[idx], norm_noise_stds[idx]) = compute_snr(item["sample_norm"], segmentation_model, plot=False)


def _print_comparison(title, raw, normalized, axis=AVERAGE_AXES):
    """ Prints the average of a statistic before (first line) and after (second line) normalization. """
    print(f"\n{title}:")
    print(raw.mean(axis=axis))
    print(normalized.mean(axis=axis))


# Print results
_print_comparison("SNR", snr_values, norm_snr_values)
_print_comparison("Signal mean", signal_means, norm_signal_means)
_print_comparison("Signal std", signal_stds, norm_signal_stds)
_print_comparison("Noise mean", noise_means, norm_noise_means)
_print_comparison("Noise std", noise_stds, norm_noise_stds)