<a href="https://colab.research.google.com/github/WinetraubLab/coregister-xy/blob/main/coregister_xy_cell_to_cell.ipynb" target="_blank">
<img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/></a>
<a href="https://github.com/WinetraubLab/coregister-xy/blob/main/coregister_xy_cell_to_cell.ipynb" target="_blank">
  <img src="https://img.shields.io/badge/view%20in-GitHub-blue" alt="View in GitHub"/>
</a>

# Overview
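Cell-to-cell alignment of a single 2D fluorescent/histology image to a 3D OCT volume. A rough alignment is first fit from user-provided landmark pairs (fluorescent image px ↔ photobleach-template mm); cells are then segmented in both modalities, and nearby cell centroids are paired as candidate matches.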

In [None]:
!pip install cellpose==2.2 --quiet
!git clone https://github.com/WinetraubLab/coregister-xy.git
%cd coregister-xy

In [None]:
# Imports
import cv2
import matplotlib.pyplot as plt
import numpy as np
from cellpose import io, models, plot
from scipy.interpolate import RegularGridInterpolator
from scipy.spatial import cKDTree
from skimage.measure import regionprops, label
from skimage.transform import downscale_local_mean

from cell_to_cell_registration import segment_cells
from cell_to_cell_registration import PCR99a
from plane.fit_plane_elastic import FitPlaneElastic

from google.colab import drive
# Mount Google Drive at /content/drive (input files are presumably stored there).
drive.mount(mountpoint='/content/drive')

In [None]:
# @title Inputs
# User-tunable parameters read by the cells below. Fill in the empty paths and
# landmark lists before running the notebook.

# Image paths
oct_volume_image_path = ""      # Path to the OCT volume; indexed below as [z, y, x]
hist_image_path = ""            # Path to the fluorescent/histology image

# For rough alignment
histology_coords_px = []        # Landmark coordinates in px from the fluorescent image
oct_coords_mm = []              # Matching landmark coordinates in mm from the photobleach template (same order as histology_coords_px)
smoothing = 0.5                 # 0 for fully elastic, 1 for fully affine
x_range_mm = [0, 1]             # OCT view range in x (mm)
y_range_mm = [0, 1]             # OCT view range in y (mm)

# Cell-level alignment: set to False to stop after the rough alignment
do_cell_to_cell_alignment = True

# For segmentation
# Extra depth added above/below the surface estimated by the rough fit.
# NOTE(review): this value is used directly as a count of OCT z-slices, i.e. it assumes 1 slice == 1 um — confirm.
z_buffer_um = 20                # Desired amount of extra space above/below the rough fit estimated from the above provided coordinate pairs.
barcode_focus_oct_slice = 120   # Voxel depth/slice number in OCT volume where photobleach barcode z=0.
oct_xy_downscale_factor = 2     # Optional: integer x/y downsampling of the OCT volume to speed up segmentation (1 = none)
avg_cell_diameter_px = 10       # Expected average diameter of the cells to segment, in the raw volume (no downsampling)
# Optional: path to a binary mask image (same size as the histology image) marking where
# histology cell segmentations are kept (nonzero) or discarded (zero).
hist_valid_cell_segmentation_path = None

# For point matching
# Max distance for a candidate pair, measured in the matching search space
# (x, y in px; z multiplied by z_weight). Roughly the accuracy of the rough alignment.
initial_pairs_dist_threshold = 40
# Multiplier applied to z before the nearest-neighbour search; increase this value the
# more confident you are in your z depth estimations from barcodes.
z_weight = 1.5

## Rough Alignment

In [None]:
# Read the histology image here so this cell does not depend on the later segmentation cell.
hist_img = io.imread(hist_image_path)

# Output grid resolution for the rough alignment (1e-3 mm = 1 um per pixel).
pixel_size_mm = 1e-3

# Fit the histology -> OCT mapping from the landmark pairs
# (smoothing: 0 = fully elastic, 1 = fully affine).
fp = FitPlaneElastic.from_points(histology_coords_px, oct_coords_mm, smoothing=smoothing)
rough_align_2d_image = fp.image_to_physical_z_projection(hist_img, x_range_mm, y_range_mm, pixel_size_mm)
# Surface depth sampled on the same grid, converted from mm to um.
surface_z_um = fp.get_image_surface_z(x_range_mm, y_range_mm, pixel_size_mm) * 1000

fig, ax = plt.subplots()
im = ax.imshow(surface_z_um)
fig.colorbar(im, ax=ax, label="Surface z (um)")
ax.set_title("Surface Depth Estimated from Photobleach Barcodes")
ax.set_xlabel("x (px)")
ax.set_ylabel("y (px)")
plt.show()

## Cell-to-Cell Alignment

# @title Segmentation
if do_cell_to_cell_alignment:
    # ---------------- 3D OCT segmentation ----------------
    # Volume axes are (z, y, x); downsample only x/y to speed up segmentation.
    oct_vol_xy = io.imread(oct_volume_image_path)
    oct_vol_xy = downscale_local_mean(oct_vol_xy, (1, oct_xy_downscale_factor, oct_xy_downscale_factor))

    # Restrict segmentation to a slab of slices around the estimated surface.
    # z values are relative to the barcode focus slice (barcode z=0).
    # NOTE(review): this treats 1 OCT z-slice as 1 um — confirm for your scanner.
    # Bounds are integer slice indices, clamped to the volume (float slice bounds raise TypeError).
    n_slices = oct_vol_xy.shape[0]
    start_slice = max(0, barcode_focus_oct_slice + int(np.floor(surface_z_um.min() - z_buffer_um)))
    end_slice = min(n_slices, barcode_focus_oct_slice + int(np.ceil(surface_z_um.max() + z_buffer_um)))
    assert end_slice > start_slice, "Estimated surface z-range lies outside the OCT volume; check barcode_focus_oct_slice"
    lower_z_bound = start_slice - barcode_focus_oct_slice  # barcode-relative z of the slab's first slice
    upper_z_bound = end_slice - barcode_focus_oct_slice
    oct_vol_xy_to_segment = oct_vol_xy[start_slice:end_slice]

    # Segment the slab. avg_cell_diameter_px is given in raw (un-downsampled) pixels,
    # so rescale it to the downsampled x/y grid.
    oct_masks, oct_flows = segment_cells.segment_cells(
        oct_vol_xy_to_segment, avg_cell_diameter_px / oct_xy_downscale_factor,
        flow_threshold=0.85, cellprob_threshold=-8, keep_cells='dark', gpu=True, normalization="clahe")

    # Cell centroids as (z, y, x) in slab voxel coordinates; reshape keeps (0, 3) when no cells are found.
    oct_centroids = np.array([r.centroid for r in regionprops(oct_masks)], dtype=float).reshape(-1, 3)
    oct_centroids[:, 1:] *= oct_xy_downscale_factor  # back to raw OCT x/y pixels
    oct_centroids[:, 0] += lower_z_bound             # slab-relative -> barcode-relative z (frame of surface_z_um)

    z = oct_centroids[:, 0]
    y = oct_centroids[:, 1]
    x = oct_centroids[:, 2]
    yi = y.astype(int)
    xi = x.astype(int)
    # NOTE(review): surface_z_um is indexed with raw OCT pixel coordinates, which assumes the OCT
    # x/y grid coincides with the rough-alignment grid — confirm.
    # Centroids outside the surface map have no depth estimate and are dropped.
    in_map = (yi >= 0) & (yi < surface_z_um.shape[0]) & (xi >= 0) & (xi < surface_z_um.shape[1])
    zmin_plane = surface_z_um - z_buffer_um
    zmax_plane = surface_z_um + z_buffer_um
    within_bounds = np.zeros(len(z), dtype=bool)
    zi = z[in_map].astype(int)
    within_bounds[in_map] = (zi > zmin_plane[yi[in_map], xi[in_map]]) & (zi < zmax_plane[yi[in_map], xi[in_map]])
    oct_centroids_xyz = oct_centroids[within_bounds][:, [2, 1, 0]]

    # Save point
    np.savetxt('/content/filtered_oct_centroids_xyz.csv', oct_centroids_xyz, delimiter=',')  # (n,3)

    # ---------------- 2D histology segmentation ----------------
    hist_img = io.imread(hist_image_path)
    hist_masks, hist_flows = segment_cells.segment_cells(hist_img, avg_cell_diameter_px, flow_threshold=0.7, cellprob_threshold=-5, keep_cells='dark', gpu=True, normalization="global")

    # regionprops centroids are (row, col) = (y, x).
    hist_centroids_yx = np.array([r.centroid for r in regionprops(hist_masks)], dtype=float).reshape(-1, 2)

    # Give each histology centroid the surface depth interpolated at its (y, x);
    # fill_value=None makes the interpolator extrapolate outside the grid.
    # NOTE(review): histology centroids are in raw histology pixels while surface_z_um is on the
    # rough-alignment grid; unless these coincide, segmenting rough_align_2d_image may be intended — confirm.
    interp = RegularGridInterpolator(
        (np.arange(surface_z_um.shape[0]), np.arange(surface_z_um.shape[1])),
        surface_z_um, method='linear', bounds_error=False, fill_value=None
    )
    hist_z_coords = interp(hist_centroids_yx) if len(hist_centroids_yx) else np.empty(0)
    hist_centroids_xyz = np.column_stack((hist_centroids_yx[:, 1], hist_centroids_yx[:, 0], hist_z_coords))

    # Optionally discard centroids that fall on zero pixels of a user-supplied mask.
    if hist_valid_cell_segmentation_path is not None:
        valid_cell_mask = cv2.imread(hist_valid_cell_segmentation_path, cv2.IMREAD_GRAYSCALE)
        assert valid_cell_mask is not None, f"Could not read valid cell mask: {hist_valid_cell_segmentation_path}"
        valid_cell_mask = (valid_cell_mask > 0).astype(np.uint8)
        assert valid_cell_mask.shape[:2] == hist_img.shape[:2], "Valid cell mask shape does not match histology image shape"

        x = hist_centroids_xyz[:, 0].astype(int)
        y = hist_centroids_xyz[:, 1].astype(int)

        hist_centroids_xyz = hist_centroids_xyz[valid_cell_mask[y, x] > 0]

    # Save point
    np.savetxt('/content/filtered_hist_centroids_xyz.csv', hist_centroids_xyz, delimiter=',')  # (n,3)

In [None]:
# @title Matching initial pairs
if do_cell_to_cell_alignment:
    # Pair every OCT centroid with each of its k nearest histology centroids that lies within
    # initial_pairs_dist_threshold. One OCT cell can appear in several candidate pairs.
    k = 20  # max candidate partners per OCT centroid

    A_T = np.asarray(oct_centroids_xyz, dtype=float).reshape(-1, 3)  # (n, 3) OCT x, y, z
    B_T = np.asarray(hist_centroids_xyz, dtype=float).reshape(-1, 3)  # (m, 3) histology x, y, z

    # Scale z so depth disagreement counts z_weight times as much as x/y in the distance.
    A_scaled = A_T.copy()
    B_scaled = B_T.copy()
    A_scaled[:, 2] *= z_weight
    B_scaled[:, 2] *= z_weight

    if len(A_scaled) == 0 or len(B_scaled) == 0:
        a_idx = b_idx = np.empty(0, dtype=int)
    else:
        # Never request more neighbours than exist: with k > m, cKDTree pads the
        # results with the out-of-range index m (distance inf).
        n_neighbors = min(k, len(B_scaled))
        tree = cKDTree(B_scaled)
        dists, indices = tree.query(A_scaled, k=n_neighbors)
        # query() drops the neighbour axis when k == 1; restore (n, k) shape.
        dists = dists.reshape(len(A_scaled), -1)
        indices = indices.reshape(len(A_scaled), -1)
        # Row-major nonzero keeps pairs ordered by (OCT point, neighbour rank).
        a_idx, nbr = np.nonzero(dists < initial_pairs_dist_threshold)
        b_idx = indices[a_idx, nbr]

    A_matched = A_T[a_idx].T  # (3, p), unscaled OCT coordinates
    B_matched = B_T[b_idx].T  # (3, p), unscaled histology coordinates
    matched_pairs = np.vstack((A_matched, B_matched))  # (6, p). OCT then hist

    # Save
    np.savetxt('/content/centroids_matched_pairs.csv', matched_pairs, delimiter=',')

    print(f"{A_matched.shape[1]} initial point pairs created")