# Visualization of the Normalization Process

This notebook contains a visualization of the normalization process of this repository. The first part of the notebook contains code for detailed visualization of one example. The second part contains code for a less detailed visualization of five examples for comparison. 

### Imports

In [14]:
import os
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import torch
from dotenv import load_dotenv
from cellpose import models
from sklearn.preprocessing import minmax_scale
from scipy.spatial import cKDTree
from tqdm import tqdm
from joblib import Parallel, delayed

from utils.data import Dataset
from utils.component_analysis import SamplePCA, ImagePCA
from utils.surface_estimation import get_masks, fit_polynomial_background, fit_gpr_surface

### Load Dataset

Load the data by creating a dataset. Specify the dataset folder paths either in a .env file, as described in README.md, or by writing them manually in the cell below.

In [15]:
def load_dataset_paths():
    """ Get dataset paths from the environment after loading the .env file.

    Every environment variable whose name starts with ``DATASET_PATH_`` contributes
    one path. Surrounding whitespace and single or double quotes are stripped from
    each value. Empty values are skipped, because ``Path('')`` would silently refer to
    the current working directory.

    Returns:
        list[Path]: One path per non-empty ``DATASET_PATH_*`` variable, in the order
        the variables appear in ``os.environ``.
    """
    load_dotenv()
    dataset_paths = []

    for key, value in os.environ.items():
        if not key.startswith("DATASET_PATH_"):
            continue
        # Quotes can survive when the variable is exported from a shell instead of .env
        cleaned_value = value.strip().strip("'\"")
        if cleaned_value:
            dataset_paths.append(Path(cleaned_value))

    return dataset_paths

# Get dataset paths from .env file
dataset_paths = load_dataset_paths()

# Alternatively, manually write the correct paths in the following line: 
# dataset_paths = [Path('C:/.../toy1/'), Path('C:/.../toy2/')]

# Fail early with an actionable message if no dataset folders were configured,
# instead of failing later when a random index is drawn from an empty dataset.
if not dataset_paths:
    raise RuntimeError(
        "No dataset paths found. Define DATASET_PATH_* variables in a .env file "
        "(see README.md) or set dataset_paths manually in this cell."
    )

# Create dataset
dataset = Dataset(dataset_paths)

## Part 1 - Detailed Visualization of an Example

This part visualizes a single sample in detail. First, draw a random sample from the dataset and visualize the channel average of each modality.

In [None]:
# Select a random sample; the printed index lets you revisit the same example later
idx = np.random.randint(0, len(dataset))
data = dataset[idx]
print(idx)

# Extract reflected, scattered, and transmitted images
# (each is indexed per channel below and in the report cell, e.g. image_R[c] is one 2D channel)
sample = data['sample']
image_R = sample[0]
image_S = sample[1]
image_T = sample[2]


def rescale_to_unit_range(image):
    """ Linearly rescale an image to [0, 1] using its global minimum and maximum.

    ``sklearn.preprocessing.minmax_scale`` scales every column independently, which
    distorts a 2D image. This helper preserves relative intensities across the
    whole image. A constant image maps to all zeros instead of NaN.
    """
    image = np.asarray(image, dtype=float)
    low, high = np.min(image), np.max(image)
    span = high - low
    if span > 0:
        return (image - low) / span
    return np.zeros_like(image)


# Show modality averages (mean over the channel axis of each modality)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, modality_image, modality_name in zip(axs, (image_R, image_S, image_T), ("Reflected", "Scattered", "Transmitted")):
    ax.imshow(rescale_to_unit_range(modality_image.mean(0)), cmap='gray', vmin=0, vmax=1)
    ax.set_title(f"{modality_name} Channel Average")
    ax.axis(False)
plt.suptitle("Channel Averages", fontsize=24)
plt.tight_layout()
plt.show()

To better see variations across the image, we can visualize the first three principal components of each modality of the example as an RGB image. The average intensity shows little variation, whereas the principal components reveal it much more clearly.

In [None]:
# Compute principal components of the whole sample
# (the [0] drops a leading axis of the SamplePCA.transform output; the result is indexed by modality below)
pca = SamplePCA(n_components=3)
pca.fit(sample)
sample_principal = pca.transform(sample)[0]


def rescale_components(components):
    """ Rescale each component of a (component, H, W) stack independently to [0, 1].

    This lets the three components be shown as the R, G, B channels of one image.
    A component with zero range is shifted to 0 instead of producing NaN/inf.
    """
    low = np.min(components, (1, 2))[:, None, None]
    high = np.max(components, (1, 2))[:, None, None]
    span = np.where(high > low, high - low, 1)
    return (components - low) / span


# Extract modalities and rescale each component to range [0,1]
principal_image_R = rescale_components(sample_principal[0])
principal_image_S = rescale_components(sample_principal[1])
principal_image_T = rescale_components(sample_principal[2])

# Plot results: components 1-3 become the R, G, B channels (channels-last for imshow)
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
for ax, principal_image, modality_name in zip(
    axs,
    (principal_image_R, principal_image_S, principal_image_T),
    ("Reflected", "Scattered", "Transmitted"),
):
    ax.imshow(principal_image.transpose(1, 2, 0))
    ax.set_title(f"{modality_name} Principal Components")
    ax.axis(False)
plt.suptitle("Principal Components", fontsize=24)
plt.tight_layout()
plt.show()

### Construction of Binary Masks

Construct binary masks indicating the background and cell pixels of the example, using Cellpose for cell segmentation.

In [None]:
# Segment cells with the Cellpose 'cyto3' model, running on the GPU only when CUDA is available
use_gpu = torch.cuda.is_available()
segmentation_model = models.Cellpose(model_type='cyto3', gpu=use_gpu)

# Binary background / cell masks for the example (plot=True presumably also displays them)
background_mask, cell_mask = get_masks(
    sample,
    segmentation_model,
    plot=True,
)

### Normalization Through Surface Estimations

Now fit a polynomial surface to the background pixels indicated by the background mask, and a probabilistic surface to the local average intensity magnitude of the cell pixels indicated by the cell mask.

In [None]:
# Step 1 - polynomial background surface fitted on background pixels;
# returns the background-corrected sample and the fitted surface
background_corrected_sample, fitted_background_surface = fit_polynomial_background(
    sample,
    background_mask,
    verbose=True,
)

# Step 2 - surface of the local cell intensity magnitude, fitted on the corrected sample;
# returns the normalized sample and the fitted surface
normalized_sample, fitted_cell_surface = fit_gpr_surface(
    background_corrected_sample,
    cell_mask,
    verbose=True,
)

### Visualization

To visualize the process, display the principal components at each intermediate stage during the normalization process. 

In [None]:
# Split every stage of the pipeline into its reflected (R), scattered (S) and
# transmitted (T) modalities, i.e. entries 0, 1 and 2 along the first axis.
image_R, image_S, image_T = sample[0], sample[1], sample[2]
fitted_background_surface_R, fitted_background_surface_S, fitted_background_surface_T = (
    fitted_background_surface[0], fitted_background_surface[1], fitted_background_surface[2]
)
background_corrected_image_R, background_corrected_image_S, background_corrected_image_T = (
    background_corrected_sample[0], background_corrected_sample[1], background_corrected_sample[2]
)
fitted_cell_surface_R, fitted_cell_surface_S, fitted_cell_surface_T = (
    fitted_cell_surface[0], fitted_cell_surface[1], fitted_cell_surface[2]
)
normalized_image_R, normalized_image_S, normalized_image_T = (
    normalized_sample[0], normalized_sample[1], normalized_sample[2]
)

# One 3-component PCA model per modality, fitted on the ORIGINAL images only,
# so that every later stage is projected onto the same components.
pca_R, pca_S, pca_T = ImagePCA(n_components=3), ImagePCA(n_components=3), ImagePCA(n_components=3)
pca_R.fit(image_R)
pca_S.fit(image_S)
pca_T.fit(image_T)

def compute_rescale_principal_components(image, pca, rescale_img=None, rescale=True, verbose=False):
    """ Project a modality onto its principal components and optionally min-max rescale them.

    Args:
        image: Modality image passed to ``pca.transform``. The first element of the
            result (index ``[0]``) is used as the component stack, reduced per
            component over its last two (spatial) axes.
        pca: Fitted PCA model exposing ``transform``.
        rescale_img: Optional reference image. When given, each component is rescaled
            with the minimum/maximum of the reference's components instead of its own.
            This puts e.g. a fitted surface on the same scale as the image it was fitted to.
        rescale: If False, return the raw projected components.
        verbose: If True, print the per-component minima and maxima used for rescaling.

    Returns:
        The projected components, rescaled per component when ``rescale`` is True.
        A component whose reference range is zero is only shifted (divided by 1)
        instead of producing NaN/inf.
    """
    principal_image = pca.transform(image)[0]
    if not rescale:
        return principal_image

    reference = principal_image if rescale_img is None else pca.transform(rescale_img)[0]
    low = np.min(reference, (1, 2))
    high = np.max(reference, (1, 2))
    if verbose:
        print(low, high)
    # Guard against division by zero for constant components
    span = np.where(high > low, high - low, 1)
    return (principal_image - low[:, None, None]) / span[:, None, None]

# Project the background-corrected and the normalized modalities onto each modality's
# own PCA basis; each result is rescaled by its own component ranges (no reference image).
(principal_background_corrected_image_R,
 principal_background_corrected_image_S,
 principal_background_corrected_image_T) = (
    compute_rescale_principal_components(stage_image, modality_pca, None)
    for stage_image, modality_pca in (
        (background_corrected_image_R, pca_R),
        (background_corrected_image_S, pca_S),
        (background_corrected_image_T, pca_T),
    )
)
(principal_normalized_image_R,
 principal_normalized_image_S,
 principal_normalized_image_T) = (
    compute_rescale_principal_components(stage_image, modality_pca, None)
    for stage_image, modality_pca in (
        (normalized_image_R, pca_R),
        (normalized_image_S, pca_S),
        (normalized_image_T, pca_T),
    )
)

def plot_surface_channels(image, pca, rescale_img, title, mask=None, rescale=True, zlim=(0, 1)):
    """ Plot the first three principal components of a modality as 3D surfaces.

    Each component is drawn as a surface over the pixel grid (x = column, y = row).
    Panels 1-3 use the Reds/Greens/Blues colormaps to match the RGB composites shown elsewhere.

    Args:
        image: Modality image passed to ``compute_rescale_principal_components``.
        pca: Fitted PCA model used to project ``image``.
        rescale_img: Optional reference image whose component range is used for
            rescaling (``None`` rescales by ``image`` itself).
        title: Figure super-title.
        mask: Optional (H, W) binary mask; pixels where ``mask != 1`` are set to NaN
            and therefore not drawn.
        rescale: Whether to min-max rescale the components.
        zlim: Limits of the intensity axis; defaults to ``(0, 1)``.
    """
    components = compute_rescale_principal_components(image, pca, rescale_img, rescale)

    # Pixel coordinate grid for the surfaces
    H, W = components.shape[1:]
    x, y = np.meshgrid(np.arange(W), np.arange(H))

    fig = plt.figure(figsize=(15, 6))
    for k, cmap in enumerate(('Reds', 'Greens', 'Blues')):
        surface = components[k] if mask is None else np.where(mask == 1, components[k], np.nan)
        ax = fig.add_subplot(1, 3, k + 1, projection='3d')
        ax.plot_surface(x, y, surface, cmap=cmap, edgecolor='none')
        ax.set_title(f'Principal Component {k + 1}')
        ax.set_xlabel('X')
        ax.set_xlim(0, W - 1)
        ax.set_ylabel('Y')
        ax.set_ylim(H - 1, 0)  # reversed so row 0 is at the far edge, like an image
        ax.set_zlabel('Intensity')
        ax.set_zlim(zlim)
    plt.suptitle(title, fontsize=24)
    plt.tight_layout()
    plt.show()

MODALITY_NAMES = ("Reflected", "Scattered", "Transmitted")


def show_principal_image(principal_image, title):
    """ Show a (3, H, W) stack of rescaled principal components as one RGB image. """
    plt.imshow(principal_image.transpose(1, 2, 0))
    plt.axis(False)
    plt.title(title)
    plt.show()


def plot_modality_normalization(name, image, pca, fitted_background, background_corrected,
                                fitted_cell, normalized, principal_original,
                                principal_background_corrected, principal_normalized,
                                background_mask, cell_mask):
    """ Plot every intermediate stage of the normalization of one modality.

    Args:
        name: Modality name used in titles ("Reflected", "Scattered", ...).
        image: Original modality image.
        pca: PCA model fitted on ``image``.
        fitted_background: Fitted background surface of this modality.
        background_corrected: Background-corrected modality image.
        fitted_cell: Fitted cell intensity-magnitude surface of this modality.
        normalized: Fully normalized modality image.
        principal_original, principal_background_corrected, principal_normalized:
            Rescaled principal-component stacks shown as RGB images.
        background_mask, cell_mask: Binary masks selecting background / cell pixels.
    """
    print(f"{name} Image Normalization")
    show_principal_image(principal_original, f'{name} Image Principal Components')
    plot_surface_channels(image, pca, None, f"Original {name} Image Principal Components")
    plot_surface_channels(image, pca, None, f"{name} Image Isolated Background Principal Components", background_mask)
    # Fitted surfaces are rescaled with the range of the data they were fitted to
    plot_surface_channels(fitted_background, pca, image, f"{name} Image Fitted Background Principal Components")
    plot_surface_channels(background_corrected, pca, None, f"Background Normalized {name} Image Isolated Background Principal Components", background_mask)
    show_principal_image(principal_background_corrected, f'Background Normalized {name} Image Principal Components')
    plot_surface_channels(background_corrected, pca, None, f"Background Normalized {name} Image Principal Components")
    plot_surface_channels(abs(background_corrected), pca, None, f"{name} Image Isolated Cells Principal Components", cell_mask)
    print(f"{name} fitted cell surface, per-channel minimum:", np.min(fitted_cell, (1, 2)))
    plot_surface_channels(fitted_cell, pca, background_corrected, f"{name} Image Fitted Cell Principal Components")
    plot_surface_channels(normalized, pca, None, f"Normalized {name} Image Principal Components")
    show_principal_image(principal_normalized, f'Normalized {name} Image Principal Components')


def plot_modality_row(principal_images, suptitle):
    """ Show the RGB principal-component composites of the three modalities side by side. """
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    for ax, principal_image, modality_name in zip(axs, principal_images, MODALITY_NAMES):
        ax.imshow(principal_image.transpose(1, 2, 0))
        ax.set_title(f"{modality_name} Principal Components")
        ax.axis(False)
    plt.suptitle(suptitle, fontsize=24)
    plt.tight_layout()
    plt.show()


# Plot the normalization process of the reflected, scattered and transmitted modes in turn
for modality_stages in zip(
    MODALITY_NAMES,
    (image_R, image_S, image_T),
    (pca_R, pca_S, pca_T),
    (fitted_background_surface_R, fitted_background_surface_S, fitted_background_surface_T),
    (background_corrected_image_R, background_corrected_image_S, background_corrected_image_T),
    (fitted_cell_surface_R, fitted_cell_surface_S, fitted_cell_surface_T),
    (normalized_image_R, normalized_image_S, normalized_image_T),
    (principal_image_R, principal_image_S, principal_image_T),
    (principal_background_corrected_image_R, principal_background_corrected_image_S, principal_background_corrected_image_T),
    (principal_normalized_image_R, principal_normalized_image_S, principal_normalized_image_T),
):
    plot_modality_normalization(*modality_stages, background_mask=background_mask, cell_mask=cell_mask)

print("Gathered Results")
plot_modality_row((principal_image_R, principal_image_S, principal_image_T),
                  "Original Image Principal Components")
plot_modality_row((principal_background_corrected_image_R, principal_background_corrected_image_S, principal_background_corrected_image_T),
                  "Background Normalized Image Principal Components")
plot_modality_row((principal_normalized_image_R, principal_normalized_image_S, principal_normalized_image_T),
                  "Normalized Image Principal Components")


Display detailed figures of an example channel of the reflectance mode. These are also included in the report. 

In [None]:
# Figures for report

# Extract an example reflectance mode channel (c is the channel index shown in the report)
c = 1
example_channel = image_R[c]
example_channel_background_surface = fitted_background_surface_R[c]
example_channel_background_normalized = background_corrected_image_R[c]
example_channel_cell_surface = fitted_cell_surface_R[c]
example_channel_normalized = normalized_image_R[c]

# Pixel coordinate grids shared by all 3D surface plots of the report
H, W = example_channel.shape
x, y = np.meshgrid(np.arange(W), np.arange(H))

# Plot parameters
aspect_ratio = [1, H/W, 0.5]
fontsize_ticks = 8
fontsize_axislabels = 9
fontsize_titles = 12


def format_report_surface_axis(ax, title, zlim, panel_label=None, panel_label_y=0.0):
    """ Apply the shared report styling to a 3D surface axis.

    Reads the notebook-level ``W``, ``H``, ``aspect_ratio`` and ``fontsize_*`` values
    defined above.

    Args:
        ax: 3D axis to style.
        title: Title drawn inside the axis area (axes coordinates).
        zlim: Limits of the intensity axis.
        panel_label: Optional panel tag such as ``'a)'``, drawn on this axis.
        panel_label_y: Vertical position of the panel tag in axes coordinates.
    """
    ax.text2D(0.5, 0.91, title, transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
    ax.set_xlabel('x', fontsize=fontsize_axislabels, labelpad=0)
    ax.set_xlim(0, W-1)
    ax.set_ylabel('y', fontsize=fontsize_axislabels, labelpad=0)
    ax.set_ylim(H-1, 0)
    # z label drawn as free text so it stays horizontal next to the axis
    ax.text2D(1.05, 0.73, 'intensity', transform=ax.transAxes, ha='center', va='bottom', rotation=0, fontsize=fontsize_axislabels)
    ax.set_zlim(zlim)
    ax.tick_params(axis='both', which='major', labelsize=fontsize_ticks, pad=0)
    ax.set_box_aspect(aspect_ratio)
    if panel_label is not None:
        ax.text2D(0.5, panel_label_y, panel_label, transform=ax.transAxes, ha='center', fontsize=fontsize_axislabels)


# 6 detailed plots of the background normalization process
fig = plt.figure(figsize=(10, 15))

ax = fig.add_subplot(321, projection='3d')
ax.plot_surface(x, y, example_channel, cmap='gray', edgecolor='none')
format_report_surface_axis(ax, 'Original Image Channel', [0, 1], 'a)')

ax = fig.add_subplot(322, projection='3d')
# Non-background pixels are NaN and therefore not drawn
example_channel_masked = np.where(background_mask == 1, example_channel, np.nan)
ax.plot_surface(x, y, example_channel_masked, cmap='gray', edgecolor='none')
format_report_surface_axis(ax, 'Isolated Background Pixels', [0, 1], 'b)')

ax = fig.add_subplot(323, projection='3d')
ax.plot_surface(x, y, example_channel_background_surface, cmap='gray', edgecolor='none')
format_report_surface_axis(ax, 'Fitted Background Surface', [0, 1], 'c)')

ax = fig.add_subplot(324, projection='3d')
example_channel_background_normalized_masked = np.where(background_mask == 1, example_channel_background_normalized, np.nan)
ax.plot_surface(x, y, example_channel_background_normalized_masked, cmap='gray', edgecolor='none')
format_report_surface_axis(ax, 'Corrected Background Pixels', [-1, 1], 'd)')

ax = fig.add_subplot(325, projection='3d')
ax.plot_surface(x, y, example_channel_background_normalized, cmap='gray', edgecolor='none')
format_report_surface_axis(ax, 'Background-Corrected Image Channel', [-1, 1], 'e)')

ax = fig.add_subplot(326)
ax.imshow(example_channel_background_normalized, cmap='gray')
ax.text(0.5, 1.05, 'Background-Corrected Image Channel', transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
ax.axis(False)
ax.text(0.5, -0.166, 'f)', transform=ax.transAxes, ha='center', fontsize=fontsize_axislabels)

plt.subplots_adjust(hspace=0)
plt.show()

# 4 detailed plots of the background normalization process
fig = plt.figure(figsize=(5, 20))

ax1 = fig.add_subplot(411, projection='3d')
ax1.plot_surface(x, y, example_channel, cmap='gray', edgecolor='none')
format_report_surface_axis(ax1, 'Original Image Channel', [0, 1], 'a)', panel_label_y=0.03)

ax2 = fig.add_subplot(412, projection='3d')
ax2.plot_surface(x, y, example_channel_masked, cmap='gray', edgecolor='none')
format_report_surface_axis(ax2, 'Isolated Background Pixels', [0, 1], 'b)', panel_label_y=0.03)

ax3 = fig.add_subplot(413, projection='3d')
ax3.plot_surface(x, y, example_channel_background_surface, cmap='gray', edgecolor='none')
format_report_surface_axis(ax3, 'Fitted Background Surface', [0, 1], 'c)', panel_label_y=0.03)

ax4 = fig.add_subplot(414)
ax4.imshow(example_channel_background_normalized, cmap='gray')
ax4.text(0.5, 1.05, 'Background-Corrected Image Channel', transform=ax4.transAxes, ha='center', fontsize=fontsize_titles)
ax4.axis(False)
ax4.text(0.5, -0.08, 'd)', transform=ax4.transAxes, ha='center', fontsize=fontsize_axislabels)

plt.subplots_adjust(hspace=-0.1)
plt.show()

# Compute sampled values of local average cell intensity magnitude for visualization
def sample_local_average_magnitude(
        channel: np.ndarray,
        mask: np.ndarray,
        sample_fraction: float,
        neighborhood_size: int = 50000,
    ) -> tuple[np.ndarray, np.ndarray]:
    """ Samples local average cell intensity magnitudes from channel.

    A random subset of cell pixels is drawn: ``10 * sample_fraction`` of them, at least 1.
    A random subset of all pixels (``sample_fraction`` of them) is drawn as query points.
    For each query point, the value is the mean absolute intensity of its nearest
    sampled cell pixels. The number of neighbours is ``neighborhood_size`` scaled by
    the same cell sampling rate, clamped to ``[1, number of sampled cell pixels]``.
    Neighbours are found with a KD-tree. Points at exactly the same distance as the
    last neighbour may be chosen differently than a full sort would choose them.

    Args:
        channel: A one channel image (H, W) from which to sample local average intensity magnitudes.
        mask: Binary mask of shape (H, W), where 1 indicates cell pixels.
        sample_fraction: The fraction of values to sample.
        neighborhood_size: Neighbourhood size in the original (unsampled) image.

    Returns:
        tuple - Sampled points, shape (num_samples, 2) with (x, y) columns, and the
        corresponding local average cell intensity magnitudes, shape (num_samples,).

    Raises:
        ValueError: If ``mask`` contains no cell pixels.
    """
    # Get channel magnitudes instead of intensity
    magnitude = np.abs(np.asarray(channel))
    cell_flags = np.asarray(mask).ravel() == 1

    # Extract x, y coordinates and values for all pixels
    H, W = magnitude.shape
    x, y = np.meshgrid(np.arange(W), np.arange(H))
    x = x.ravel()
    y = y.ravel()
    z = magnitude.ravel()

    # Filter by mask
    x_mask = x[cell_flags]
    y_mask = y[cell_flags]
    z_mask = z[cell_flags]
    num_cells = len(x_mask)
    if num_cells == 0:
        raise ValueError("mask contains no cell pixels (no entries equal to 1)")

    # Sample cell pixels to compute the local average from (10 times the sample fraction).
    # Clamp so that sampling without replacement never asks for more than exists.
    num_cell_samples = min(max(int(sample_fraction * 10 * num_cells), 1), num_cells)
    idx = np.random.choice(num_cells, size=num_cell_samples, replace=False)
    coord_cell_sample = np.column_stack((x_mask[idx], y_mask[idx]))
    z_cell_sample = z_mask[idx]

    # Sample points to compute local average at
    num_samples = min(int(sample_fraction * len(x)), len(x))
    idx = np.random.choice(len(x), size=num_samples, replace=False)
    coord_sample = np.column_stack((x[idx], y[idx]))
    if num_samples == 0:
        return coord_sample, np.empty(0)

    # Neighbourhood size after sampling, at least one neighbour and at most all samples
    num_neighbors = min(max(int(neighborhood_size * sample_fraction * 10), 1), num_cell_samples)

    # k-nearest-neighbour lookup; reshape also covers k == 1, where query returns a 1D array
    _, nearest_idx = cKDTree(coord_cell_sample).query(coord_sample, k=num_neighbors)
    nearest_idx = np.reshape(nearest_idx, (num_samples, num_neighbors))
    sampled_values = z_cell_sample[nearest_idx].mean(axis=1)
    return coord_sample, sampled_values

sampled_points, sampled_values = sample_local_average_magnitude(example_channel_background_normalized, cell_mask, 0.001)

# Cell-pixel magnitudes before (panel b) and after (panel e) normalization;
# pixels outside the cell mask are NaN so the surface plot leaves them out.
cell_magnitude_before = np.where(cell_mask == 1, abs(example_channel_background_normalized), np.nan)
example_channel_masked = np.where(cell_mask == 1, abs(example_channel_normalized), np.nan)

# 6 detailed plots of the cell intensity magnitude normalization process
fig = plt.figure(figsize=(10, 15))

# (subplot, title, z-limits, panel tag, gridded surface - None means the sampled local averages)
magnitude_surface_panels = [
    (321, 'Background-Corrected Image Channel', [-0.5, 0.5], 'a)', example_channel_background_normalized),
    (322, 'Intensity Magnitude of Isolated Cell Pixels', [0, 1], 'b)', cell_magnitude_before),
    (323, 'Local Average Intensity Magnitude', [0, 1], 'c)', None),
    (324, 'Estimated Intensity Magnitude Surface', [0, 1], 'd)', example_channel_cell_surface),
    (325, 'Normalized Intensity Magnitude of Isolated Cell Pixels', [-1, 2], 'e)', example_channel_masked),
]

for position, panel_title, z_limits, panel_tag, surface in magnitude_surface_panels:
    ax = fig.add_subplot(position, projection='3d')
    if surface is None:
        # Scattered samples are triangulated rather than drawn on the pixel grid
        ax.plot_trisurf(sampled_points[:,0], sampled_points[:,1], sampled_values, cmap='gray', edgecolor='none')
    else:
        ax.plot_surface(x, y, surface, cmap='gray', edgecolor='none')
    ax.text2D(0.5, 0.91, panel_title, transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
    ax.set_xlabel('x', fontsize=fontsize_axislabels, labelpad=0)
    ax.set_xlim(0, W-1)
    ax.set_ylabel('y', fontsize=fontsize_axislabels, labelpad=0)
    ax.set_ylim(H-1, 0)
    ax.text2D(1.05, 0.73, 'intensity', transform=ax.transAxes, ha='center', va='bottom', rotation=0, fontsize=fontsize_axislabels)
    ax.set_zlim(z_limits)
    ax.tick_params(axis='both', which='major', labelsize=fontsize_ticks, pad=0)
    ax.set_box_aspect(aspect_ratio)
    ax.text2D(0.5, 0, panel_tag, transform=ax.transAxes, ha='center', fontsize=fontsize_axislabels)

# Panel (f): the normalized channel as a flat grayscale image
ax = fig.add_subplot(326)
ax.imshow(example_channel_normalized, cmap='gray')
ax.text(0.5, 1.05, 'Normalized Image Channel', transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
ax.axis(False)
ax.text(0.5, -0.166, 'f)', transform=ax.transAxes, ha='center', fontsize=fontsize_axislabels)

plt.subplots_adjust(hspace=0)
plt.show()

# 4 detailed plots of the cell intensity magnitude normalization process


def _configure_surface_axis(axis, title, zlim):
    """Apply the shared styling of one 3D panel in the 4-panel figure.

    Adds the panel title and a 2D 'intensity' label (both in axes coordinates),
    labels and limits x/y (with y reversed from H-1 to 0), fixes the z-range to
    ``zlim``, and applies the common tick size and box aspect.
    """
    axis.text2D(0.5, 0.91, title, transform=axis.transAxes, ha='center', fontsize=fontsize_titles)
    axis.set_xlabel('x', fontsize=fontsize_axislabels, labelpad=0)
    axis.set_xlim(0, W-1)
    axis.set_ylabel('y', fontsize=fontsize_axislabels, labelpad=0)
    axis.set_ylim(H-1, 0)
    axis.text2D(1.05, 0.73, 'intensity', transform=axis.transAxes, ha='center', va='bottom', rotation=0, fontsize=fontsize_axislabels)
    axis.set_zlim(zlim)
    axis.tick_params(axis='both', which='major', labelsize=fontsize_ticks, pad=0)
    axis.set_box_aspect(aspect_ratio)


fig = plt.figure(figsize=(5, 20))

# a) background-corrected channel (signed values, symmetric z-range)
ax1 = fig.add_subplot(411, projection='3d')
ax1.plot_surface(x, y, example_channel_background_normalized, cmap='gray', edgecolor='none')
_configure_surface_axis(ax1, 'Background-Corrected Image Channel', [-0.5, 0.5])

# b) |background-corrected| on cell pixels only; all other pixels become NaN
ax2 = fig.add_subplot(412, projection='3d')
example_channel_masked = np.where(cell_mask == 1, abs(example_channel_background_normalized), np.nan)
ax2.plot_surface(x, y, example_channel_masked, cmap='gray', edgecolor='none')
_configure_surface_axis(ax2, 'Intensity Magnitude of Isolated Cell Pixels', [0, 1])

# c) surface estimated from the isolated cell magnitudes
ax3 = fig.add_subplot(413, projection='3d')
ax3.plot_surface(x, y, example_channel_cell_surface, cmap='gray', edgecolor='none')
_configure_surface_axis(ax3, 'Estimated Intensity Magnitude Surface', [0, 1])

# d) final normalized channel as a flat image
ax4 = fig.add_subplot(414)
ax4.imshow(example_channel_normalized, cmap='gray')
ax4.text(0.5, 1.05, 'Normalized Image Channel', transform=ax4.transAxes, ha='center', fontsize=fontsize_titles)
ax4.axis(False)

# Panel tags are all added as ax4 artists but positioned in each panel's own axes
# coordinates. Presumably this keeps them drawn last so the overlapping panels
# (hspace < 0 below) do not hide them.
for tag, tagged_axis, tag_y in (('a)', ax1, 0.03), ('b)', ax2, 0.03), ('c)', ax3, 0.03), ('d)', ax4, -0.08)):
    ax4.text(0.5, tag_y, tag, transform=tagged_axis.transAxes, ha='center', fontsize=fontsize_axislabels)

plt.subplots_adjust(hspace=-0.1)
plt.show()

# Plot the one-channel average image
# Sum the three modalities' means over axis 0 (presumably the spectral channels).
# The scattered mean enters as (1 - mean), i.e. inverted; presumably this matches
# the contrast polarity of R and T. TODO confirm.
combined_sum = image_R.mean(0) + (1 - image_S.mean(0)) + image_T.mean(0)
# Rescale the whole image to [0, 1] with one global min/max.
# minmax_scale treats the columns of a 2D array as separate features. Passing the
# image directly would stretch every image column independently and create
# vertical banding, so the image is flattened to a single feature first.
combined_image = minmax_scale(combined_sum.reshape(-1, 1)).reshape(combined_sum.shape)
fig = plt.figure(figsize=(5,5))
ax = fig.add_subplot(111)
ax.imshow(combined_image, cmap='gray')
ax.text(0.5, 1.05, 'One-Channel Average Image', transform=ax.transAxes, ha='center', fontsize=fontsize_titles)
ax.axis(False)
plt.show()

def _show_channel_figure(image, title, **imshow_kwargs):
    """Render one 2D array in its own 5x5 grayscale figure.

    The title is placed as centred text just above the axes, and the axes are hidden.
    Extra keyword arguments (e.g. vmin/vmax) are forwarded to ``imshow``.
    Returns the created (figure, axes) pair.
    """
    figure = plt.figure(figsize=(5,5))
    axis = figure.add_subplot(111)
    axis.imshow(image, cmap='gray', **imshow_kwargs)
    axis.text(0.5, 1.05, title, transform=axis.transAxes, ha='center', fontsize=fontsize_titles)
    axis.axis(False)
    plt.show()
    return figure, axis


# Plot the original image channel
fig, ax = _show_channel_figure(example_channel, 'Example Image Channel')

# The masked views reuse the original channel's intensity range, so zeroed pixels
# appear black relative to the unmasked image rather than being re-stretched
channel_range = dict(vmin=example_channel.min(), vmax=example_channel.max())

# Plot the original image channel with the background mask (non-background pixels zeroed)
background_image = example_channel.copy()
background_image[background_mask == 0] = 0
fig, ax = _show_channel_figure(background_image, 'Background Mask', **channel_range)

# Plot the original image channel with the cell mask (non-cell pixels zeroed)
cell_image = example_channel.copy()
cell_image[cell_mask == 0] = 0
fig, ax = _show_channel_figure(cell_image, 'Cell Mask', **channel_range)

# Plot the background-corrected image channel
fig, ax = _show_channel_figure(example_channel_background_normalized, 'Background-Corrected Image Channel')

# Plot the normalized image channel
fig, ax = _show_channel_figure(example_channel_normalized, 'Normalized Image Channel')

To check how well the normalization worked, we fit a new PCA to the normalized sample; its principal components make any remaining variance easier to see.

In [None]:
def _rescale_components(components):
    """Min-max rescale each component (axis 0) of a stacked component array to [0, 1].

    Minimum and maximum are taken over axes (1, 2) for each component separately.
    Returns (rescaled, per-component minimum, per-component maximum).
    """
    lowest = np.min(components, (1,2))
    highest = np.max(components, (1,2))
    rescaled = (components - lowest[:, None, None]) / (highest - lowest)[:, None, None]
    return rescaled, lowest, highest


# Fit one PCA per modality to the normalized sample. All models are created first,
# then all fitted, then all applied, preserving the original call order.
norm_pca_R = ImagePCA()
norm_pca_S = ImagePCA()
norm_pca_T = ImagePCA()
norm_pca_R.fit(normalized_image_R)
norm_pca_S.fit(normalized_image_S)
norm_pca_T.fit(normalized_image_T)
# transform(...)[0] keeps the first entry of the result (presumably the single image). TODO confirm
principal_components_normalized_image_R = norm_pca_R.transform(normalized_image_R)[0]
principal_components_normalized_image_S = norm_pca_S.transform(normalized_image_S)[0]
principal_components_normalized_image_T = norm_pca_T.transform(normalized_image_T)[0]

# Rescale images to range [0,1], one component at a time
principal_components_normalized_image_R, norm_min_R, norm_max_R = _rescale_components(principal_components_normalized_image_R)
principal_components_normalized_image_S, norm_min_S, norm_max_S = _rescale_components(principal_components_normalized_image_S)
principal_components_normalized_image_T, norm_min_T, norm_max_T = _rescale_components(principal_components_normalized_image_T)

# Plot results: components moved to the last axis so imshow renders them as colour channels
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
panels = (
    (principal_components_normalized_image_R, "Reflected Principal Components"),
    (principal_components_normalized_image_S, "Scattered Principal Components"),
    (principal_components_normalized_image_T, "Transmitted Principal Components"),
)
for ax, (components, title) in zip(axs, panels):
    ax.imshow(components.transpose(1, 2, 0))
    ax.set_title(title)
    ax.axis(False)
plt.suptitle("Normalized Image Principal Components", fontsize=24)
plt.tight_layout()
plt.show()

## Part 2 - Comparison of Multiple Examples

This part visualizes the normalization process for several randomly selected samples side by side, so the results can be compared across samples.

### Normalize 5 Random Samples

In [None]:
# Define the number of elements to compare
n_samples = 5
# Draw distinct dataset indices. np.random.randint samples with replacement and could
# pick the same element twice, which defeats a side-by-side comparison.
# np.random.choice raises ValueError if the dataset has fewer than n_samples elements.
idxs = np.random.choice(len(dataset), n_samples, replace=False)

# Store results in lists for later comparison (one list per modality and processing stage)
image_R_list = []
image_S_list = []
image_T_list = []

fitted_background_R_list = []
fitted_background_S_list = []
fitted_background_T_list = []

background_normalized_image_R_list = []
background_normalized_image_S_list = []
background_normalized_image_T_list = []

fitted_cell_plane_R_list = []
fitted_cell_plane_S_list = []
fitted_cell_plane_T_list = []

normalized_image_R_list = []
normalized_image_S_list = []
normalized_image_T_list = []

# Define CellPose segmentation model
segmentation_model = models.Cellpose(model_type='cyto3', gpu=torch.cuda.is_available(), device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'))

# Process samples by randomly chosen index in idxs.
# tqdm wraps idxs (which has a length) rather than enumerate(...), so the bar knows its
# total. tqdm.write prints without breaking the progress bar.
for i, idx in enumerate(tqdm(idxs)):
    tqdm.write(f"Image {i+1}/{n_samples}: {idx}")

    # Retrieve the dataset element
    data = dataset[idx]

    # Extract the sample
    sample = data['sample']

    # Compute background and cell binary masks
    background_mask, cell_mask = get_masks(sample, segmentation_model, plot=False, verbose=True)

    # Background normalization
    background_corrected_sample, fitted_background_surface = fit_polynomial_background(sample, background_mask, verbose=True)

    # Cell intensity magnitude normalization
    normalized_sample, fitted_cell_surface = fit_gpr_surface(background_corrected_sample, cell_mask, verbose=True)

    # Extract sample modalities (index 0/1/2 = reflected/scattered/transmitted)
    image_R = sample[0]
    image_S = sample[1]
    image_T = sample[2]

    fitted_background_surface_R = fitted_background_surface[0]
    fitted_background_surface_S = fitted_background_surface[1]
    fitted_background_surface_T = fitted_background_surface[2]

    background_corrected_image_R = background_corrected_sample[0]
    background_corrected_image_S = background_corrected_sample[1]
    background_corrected_image_T = background_corrected_sample[2]

    fitted_cell_surface_R = fitted_cell_surface[0]
    fitted_cell_surface_S = fitted_cell_surface[1]
    fitted_cell_surface_T = fitted_cell_surface[2]

    normalized_image_R = normalized_sample[0]
    normalized_image_S = normalized_sample[1]
    normalized_image_T = normalized_sample[2]

    # Store results
    image_R_list.append(image_R)
    image_S_list.append(image_S)
    image_T_list.append(image_T)

    fitted_background_R_list.append(fitted_background_surface_R)
    fitted_background_S_list.append(fitted_background_surface_S)
    fitted_background_T_list.append(fitted_background_surface_T)

    background_normalized_image_R_list.append(background_corrected_image_R)
    background_normalized_image_S_list.append(background_corrected_image_S)
    background_normalized_image_T_list.append(background_corrected_image_T)

    fitted_cell_plane_R_list.append(fitted_cell_surface_R)
    fitted_cell_plane_S_list.append(fitted_cell_surface_S)
    fitted_cell_plane_T_list.append(fitted_cell_surface_T)

    normalized_image_R_list.append(normalized_image_R)
    normalized_image_S_list.append(normalized_image_S)
    normalized_image_T_list.append(normalized_image_T)

# Reshape into numpy arrays indexed [modality, sample, ...]
images = np.stack([np.stack(image_R_list),
                   np.stack(image_S_list),
                   np.stack(image_T_list)])

fitted_backgrounds = np.stack([np.stack(fitted_background_R_list),
                               np.stack(fitted_background_S_list),
                               np.stack(fitted_background_T_list)])

background_normalized_images = np.stack([np.stack(background_normalized_image_R_list),
                                         np.stack(background_normalized_image_S_list),
                                         np.stack(background_normalized_image_T_list)])

fitted_cell_planes = np.stack([np.stack(fitted_cell_plane_R_list),
                               np.stack(fitted_cell_plane_S_list),
                               np.stack(fitted_cell_plane_T_list)])

normalized_images = np.stack([np.stack(normalized_image_R_list),
                              np.stack(normalized_image_S_list),
                              np.stack(normalized_image_T_list)])

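Before plotting the results, we compute three principal components per modality, so each multi-channel image can be shown compactly as an RGB image. The PCA fitted on the original images is also applied to the fitted backgrounds, background-corrected images, and estimated cell surfaces, which keeps them in the same projection. The normalized images get their own PCA.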

In [None]:
def _project_modalities(stacked, pca_models):
    """Project each modality of a [modality, sample, ...] array with its own PCA model.

    Returns the stacked results, indexed [modality, sample, component, H, W].
    """
    return np.stack([pca.transform(modality) for (modality, pca) in zip(stacked, pca_models)])


def _rescale_principal(principal, lo, hi):
    """Min-max rescale a [modality, sample, component, H, W] array.

    ``lo`` and ``hi`` have shape (modality, component); each (modality, component)
    pair is mapped from [lo, hi] to [0, 1] across all samples and pixels.
    """
    return (principal - lo[:, None, :, None, None]) / (hi - lo)[:, None, :, None, None]


# Initialize PCA for each channel
pca_R = ImagePCA(n_components=3)
pca_S = ImagePCA(n_components=3)
pca_T = ImagePCA(n_components=3)

# Fit PCA to each image
pca_R.fit(images[0])
pca_S.fit(images[1])
pca_T.fit(images[2])
pcas = [pca_R, pca_S, pca_T]

# Project every intermediate stage with the PCA of the original images so they share one basis.
# (Normalized images get a dedicated PCA below; projecting them here as well would be
# overwritten and has been dropped.)
principal_images = _project_modalities(images, pcas)
principal_fitted_backgrounds = _project_modalities(fitted_backgrounds, pcas)
principal_background_normalized_images = _project_modalities(background_normalized_images, pcas)
principal_fitted_cell_planes = _project_modalities(fitted_cell_planes, pcas)
print(principal_images.shape)

# Per-(modality, component) bounds over all samples and pixels
min_principal = np.min(principal_images, (1,3,4))
max_principal = np.max(principal_images, (1,3,4))
min_bg_norm_principal = np.min(principal_background_normalized_images, (1,3,4))
max_bg_norm_principal = np.max(principal_background_normalized_images, (1,3,4))

# Rescale principal components to range [0,1]
# Fitted backgrounds share the original images' bounds so both are on one scale.
principal_images = _rescale_principal(principal_images, min_principal, max_principal)
principal_fitted_backgrounds = _rescale_principal(principal_fitted_backgrounds, min_principal, max_principal)
principal_background_normalized_images = _rescale_principal(principal_background_normalized_images, min_bg_norm_principal, max_bg_norm_principal)
# Cell surfaces use the background-corrected bounds, with the upper bound widened to the
# largest absolute value (the surfaces are fitted to intensity magnitudes).
principal_fitted_cell_planes = _rescale_principal(principal_fitted_cell_planes, min_bg_norm_principal, np.maximum(max_bg_norm_principal, abs(min_bg_norm_principal)))

# Initialize PCA for each channel (dedicated to the normalized images)
norm_pca_R = ImagePCA(n_components=3)
norm_pca_S = ImagePCA(n_components=3)
norm_pca_T = ImagePCA(n_components=3)

# Fit PCA to each image
norm_pca_R.fit(normalized_images[0])
norm_pca_S.fit(normalized_images[1])
norm_pca_T.fit(normalized_images[2])

# Compute principal components of results
principal_normalized_images = _project_modalities(normalized_images, [norm_pca_R, norm_pca_S, norm_pca_T])

# Compute min and max values after normalization
min_norm_principal = np.min(principal_normalized_images, (1,3,4))
max_norm_principal = np.max(principal_normalized_images, (1,3,4))

# Rescale principal components to range [0,1]
principal_normalized_images = _rescale_principal(principal_normalized_images, min_norm_principal, max_norm_principal)

### Visualization

Visualize each stage of the normalization process for all selected samples side by side for comparison.

In [None]:
# Create grid for X, Y coordinates
H, W = images.shape[3:]
x, y = np.meshgrid(np.arange(W), np.arange(H))

# Plot parameters
plane_alpha = 0.5
data_alpha = 0.05
fontsize_ticks = 8
fontsize_axislabels = 9
aspect_ratio = [1, H/W, 0.5]

# Column titles (modality order R, S, T) and colour maps for components 0, 1, 2
modality_titles = ["Reflected Principal Components",
                   "Scattered Principal Components",
                   "Transmitted Principal Components"]
component_cmaps = ['Reds', 'Greens', 'Blues']


def plot_principal_grid(principal, suptitle):
    """Show the first three principal components of every sample as RGB images.

    ``principal`` is indexed [modality, sample, component, H, W] with values in [0, 1].
    The grid has one row per sample and one column per modality (R, S, T).
    ``squeeze=False`` keeps ``axs`` 2D, so a single sample also works.
    Returns (fig, axs).
    """
    n_rows = principal.shape[1]
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 20), squeeze=False)
    for i in range(n_rows):
        for m in range(3):
            axs[i, m].imshow(principal[m, i].transpose(1, 2, 0))
    for m, title in enumerate(modality_titles):
        axs[0, m].set_title(title)
    for ax in axs.flatten():
        ax.axis(False)
    fig.suptitle(suptitle, fontsize=24)
    fig.tight_layout(rect=[0, 0, 1, 0.98])
    plt.show()
    return fig, axs


def plot_principal_surfaces(principal, suptitle):
    """Plot the three principal components of every sample as overlaid 3D surfaces.

    ``principal`` is indexed [modality, sample, component, H, W] with values in [0, 1].
    Components 0/1/2 are drawn with the Reds/Greens/Blues colour maps at ``plane_alpha``.
    Uses the module-level grid and style settings defined above
    (x, y, H, W, fontsize_*, aspect_ratio). Returns (fig, axs).
    """
    n_rows = principal.shape[1]
    fig, axs = plt.subplots(n_rows, 3, figsize=(15, 20), subplot_kw={'projection': '3d'}, squeeze=False)
    for i in range(n_rows):
        for m in range(3):
            ax = axs[i, m]
            for c, cmap in enumerate(component_cmaps):
                ax.plot_surface(x, y, principal[m, i, c], cmap=cmap, edgecolor='none', alpha=plane_alpha)
            ax.set_xlabel('x', fontsize=fontsize_axislabels, labelpad=0)
            ax.set_xlim(0, W-1)
            ax.set_ylabel('y', fontsize=fontsize_axislabels, labelpad=0)
            ax.set_ylim(H-1, 0)
            ax.set_zlim([0, 1])
            ax.tick_params(axis='both', which='major', labelsize=fontsize_ticks, pad=0)
            ax.set_box_aspect(aspect_ratio)
    for m, title in enumerate(modality_titles):
        axs[0, m].set_title(title)
    fig.suptitle(suptitle, fontsize=24)
    fig.tight_layout(rect=[0, 0, 1, 0.98])
    fig.subplots_adjust(hspace=0, wspace=0)
    plt.show()
    return fig, axs


# Plot original samples
fig, axs = plot_principal_grid(principal_images, "Original Image Principal Components")

# Plot fitted background surfaces in 2D and 3D
fig, axs = plot_principal_grid(principal_fitted_backgrounds, "Fitted Background Surface Principal Components")
fig, axs = plot_principal_surfaces(principal_fitted_backgrounds, "Fitted Background Surface Principal Components")

# Plot background-corrected samples
fig, axs = plot_principal_grid(principal_background_normalized_images, "Background-Corrected Image Principal Components")

# Plot estimated cell intensity magnitude surfaces in 2D and 3D
fig, axs = plot_principal_grid(principal_fitted_cell_planes, "Estimated Intensity Magnitude Surface Principal Components")
fig, axs = plot_principal_surfaces(principal_fitted_cell_planes, "Estimated Intensity Magnitude Surface Principal Components")

# Plot normalized samples
fig, axs = plot_principal_grid(principal_normalized_images, "Normalized Image Principal Components")