# Method Comparison for RBC Intensity Magnitude Surface Estimation

This notebook compares the implemented methods for estimating the red blood cell (RBC) intensity magnitude surface, which is the final step of the normalization process.

### Imports

In [None]:
import os
import torch
import random
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from dotenv import load_dotenv
from cellpose import models

from utils.data import Dataset
from utils.surface_estimation import get_masks, fit_polynomial_background, fit_rbc_surface
from utils.component_analysis import SamplePCA

### Load Dataset and Select a Random Sample

Load the data by creating a dataset. Specify the dataset folder paths either by creating a `.env` file as described in README.md or by writing them manually in the cell below.

In [None]:
def load_dataset_paths():
    """Collect the dataset directories declared in the environment.

    Calls ``load_dotenv()`` first (so entries from the .env file are
    available), then returns one ``Path`` per environment variable whose
    name begins with ``DATASET_PATH_``. Single quotes surrounding a value
    are stripped before it is converted to a ``Path``.

    Returns:
        list[Path]: the collected paths, in the iteration order of
        ``os.environ``.
    """
    load_dotenv()
    prefix = "DATASET_PATH_"
    return [
        Path(raw_value.strip("'"))
        for name, raw_value in os.environ.items()
        if name.startswith(prefix)
    ]


# Parameter to write out details during processing
verbose = True

# Get dataset paths from .env file
dataset_paths = load_dataset_paths()

# Alternatively, manually write the correct paths in the following line:
# dataset_paths = [Path('C:/.../toy1/'), Path('C:/.../toy2/')]

# Create dataset
dataset = Dataset(dataset_paths)

# Define segmentation model (here a pre-trained CellPose model); the GPU is
# requested only when CUDA is available
segmentation_model = models.Cellpose(model_type='cyto3', gpu=torch.cuda.is_available())

# Get a random sample from the dataset.
# random.randint(0, n) includes n itself, which is one past the last valid
# index; random.randrange(n) draws from 0..n-1 only.
idx = random.randrange(len(dataset))
print(idx)
data = dataset[idx]
sample = data["sample"]

Helper function to plot the principal components of the results.

In [None]:
def plot_example(sample, title):
    """Show the principal-component images of ``sample`` in a 1x3 figure.

    A ``SamplePCA`` with three components is fitted on ``sample``; the first
    element of its transform output is split along its leading axis into
    three images, which are min-max rescaled and drawn side by side under
    the panel titles "Reflected", "Scattered" and "Transmitted".

    Args:
        sample: input handed to ``SamplePCA.fit`` / ``SamplePCA.transform``
            (expected layout is defined by ``SamplePCA`` -- TODO confirm).
        title: text used as the figure's super-title.
    """
    def rescale_to_unit_range(image):
        """ Min-max scale every leading-axis slice of a 3-D array to [0, 1]. """
        lowest = np.min(image, axis=(1, 2), keepdims=True)
        highest = np.max(image, axis=(1, 2), keepdims=True)
        span = highest - lowest
        # A constant slice has zero span; divide by 1 instead so it maps to 0
        # rather than producing NaNs from 0/0.
        span = np.where(span > 0, span, 1)
        return (image - lowest) / span

    def get_principal_components(sample):
        """ Help function to compute and scale principal components. """
        pca = SamplePCA(n_components=3)
        pca.fit(sample)
        sample_principal = pca.transform(sample)[0]
        return tuple(rescale_to_unit_range(sample_principal[i]) for i in range(3))

    panel_titles = (
        "Reflected Principal Components",
        "Scattered Principal Components",
        "Transmitted Principal Components",
    )

    _, axs = plt.subplots(1, 3, figsize=(15, 5))
    principal_images = get_principal_components(sample)
    for ax, principal_image, panel_title in zip(axs, principal_images, panel_titles):
        # Move the leading axis last so imshow receives an (H, W, C) image
        ax.imshow(principal_image.transpose(1, 2, 0))
        ax.set_title(panel_title)
        ax.axis(False)
    plt.suptitle(title, fontsize=24)
    plt.tight_layout()
    plt.show()

## Run Normalization Pipeline

Run the steps of the normalization pipeline (binary mask construction, background correction, RBC surface estimation) and compare the results of the different methods for the last step.

### Binary Mask Construction

In [None]:
# Build the background and cell masks for the sample with the segmentation
# model; no plots are requested from get_masks (plot=False).
background_mask, cell_mask = get_masks(
    sample,
    segmentation_model,
    plot=False,
    verbose=verbose,
)

### Background Correction

In [None]:
# Fit the polynomial background using the background mask; only the first
# return value (the background-corrected sample) is kept.
bc_sample, _ = fit_polynomial_background(
    sample,
    background_mask,
    verbose=verbose,
)

### Linear Interpolation RBC Surface

In [None]:
# RBC surface estimated with the "linear" method on the background-corrected sample
norm_sample_linear, rbc_surface_linear = fit_rbc_surface(
    bc_sample,
    cell_mask,
    method="linear",
    verbose=verbose,
)

### Cubic Bivariate Spline RBC Surface

In [None]:
# RBC surface estimated with the "b-spline" method on the background-corrected sample
norm_sample_bsplines, rbc_surface_bsplines = fit_rbc_surface(
    bc_sample,
    cell_mask,
    method="b-spline",
    verbose=verbose,
)

### Gaussian Process Regression RBC Surface

In [None]:
# RBC surface estimated with the "gpr" method on the background-corrected sample
norm_sample_gpr, rbc_surface_gpr = fit_rbc_surface(
    bc_sample,
    cell_mask,
    method="gpr",
    verbose=verbose,
)

Visualize the normalized sample and the estimated RBC surface produced by each method for comparison.

In [None]:
# (data, figure title) pairs: normalized sample followed by its RBC surface,
# one pair of figures per estimation method
comparison_panels = [
    (norm_sample_linear, "Normalized Sample Linear"),
    (rbc_surface_linear, "RBC Surface Linear"),
    (norm_sample_bsplines, "Normalized Sample B-Spline"),
    (rbc_surface_bsplines, "RBC Surface B-Spline"),
    (norm_sample_gpr, "Normalized Sample GPR"),
    (rbc_surface_gpr, "RBC Surface GPR"),
]

for panel_data, panel_title in comparison_panels:
    plot_example(panel_data, panel_title)

Visualize a detailed view of the results (method comparison) for an example channel of the sample. 

In [None]:
# Channel to inspect, addressed as [modality_idx, channel_idx]
modality_idx = 0
channel_idx = 1
example_channel = sample[modality_idx, channel_idx]

# Pixel coordinate grids matching the channel's (H, W) shape
H, W = example_channel.shape
x = np.arange(W)
y = np.arange(H)
x, y = np.meshgrid(x, y)

# Shared styling for all 3D panels
aspect_ratio = [1, H/W, 0.5]
fontsize_ticks = 8
fontsize_axislabels = 9
fontsize_titles = 12

# One figure row per estimation method: (row title, estimated RBC surface)
surfaces_by_method = [
    ('Linear Interpolation', rbc_surface_linear),
    ('Cubic Bivariate Spline', rbc_surface_bsplines),
    ('Gaussian Process Regression', rbc_surface_gpr),
]
n_rows = len(surfaces_by_method)

fig = plt.figure(figsize=(10, 13))
for row, (method_title, rbc_surface) in enumerate(surfaces_by_method):
    surface_channel = rbc_surface[modality_idx, channel_idx]

    # Left column: 3D surface plot of the estimated intensity
    ax = fig.add_subplot(n_rows, 2, 2 * row + 1, projection='3d')
    ax.plot_surface(x, y, surface_channel, cmap='gray', edgecolor='none')
    # Row title, positioned in axes coordinates to the right of the 3D panel
    ax.text2D(1.3, 1, method_title, transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
    ax.set_xlabel('x', fontsize=fontsize_axislabels, labelpad=0)
    ax.set_xlim(0, W-1)
    ax.set_ylabel('y', fontsize=fontsize_axislabels, labelpad=0)
    # Limits given as (H-1, 0) invert the y axis -- presumably to match the
    # image orientation of the right-hand panel
    ax.set_ylim(H-1, 0)
    # The z label is drawn as 2D text at a fixed position instead of set_zlabel
    ax.text2D(1.05, 0.73, 'intensity', transform=ax.transAxes, ha='center', va='bottom', rotation=0, fontsize=fontsize_axislabels)
    ax.set_zlim([0, 1])
    ax.tick_params(axis='both', which='major', labelsize=fontsize_ticks, pad=0)
    ax.set_box_aspect(aspect_ratio)

    # Right column: the same surface as a grayscale image
    ax = fig.add_subplot(n_rows, 2, 2 * row + 2)
    ax.imshow(surface_channel, cmap="gray")
    ax.axis(False)

# End on plt.show() so the cell does not echo the return value of ax.axis(False)
plt.show()