In [None]:
!pip install cellpose==2.2 --quiet

In [None]:
import segment_cells
import PCR99a

from scipy.interpolate import RegularGridInterpolator
from scipy.spatial import cKDTree
from skimage.measure import regionprops, label
from cellpose import io, models, plot

import matplotlib.pyplot as plt
import cv2
import numpy as np

In [None]:
# ---------------------------------------------------------------------------
# Inputs — fill these in before running the notebook.
# ---------------------------------------------------------------------------

# Image files
oct_image_path = ""   # OCT volume, loaded with cellpose `io.imread`
hist_image_path = ""  # histology image, loaded with both cv2.imread and cellpose `io.imread`

# Landmarks for the rough alignment.
# NOTE(review): later cells index these as 2-D arrays (e.g. `[:, 0]`, `[:, [1, 0]]`),
# so they presumably need to be (N, 2)/(N, 3) numpy arrays rather than lists — confirm.
gt_coords_xy_mm = []
hist_coords_xyz_px = []
oct_view_corners_xy_mm = []

# Volume geometry
zero_slice_num = 0
xy_downsample_factor = 1  # multiplies the y/x of OCT cell centroids after segmentation

# Cell segmentation
cell_to_cell_alignment = True
avg_cell_diameter_px = 10

# Optional binary mask, same height/width as the histology image, selecting which
# histology cells are kept; None disables the filter.
hist_valid_cell_mask_path = None

# Initial pair matching
initial_pairs_dist_threshold = 40  # distance cut-off, in z-weighted units
z_weight = 1.5                     # factor applied to z before the nearest-neighbour search

In [None]:
# Colab only: mount Google Drive at /content/drive so files stored there can be
# read by the cells below.
from google.colab import drive

drive.mount('/content/drive')

## Rough Alignment

In [None]:
# Rough alignment: warp the histology image onto the OCT en-face canvas using the
# landmark correspondences gt_coords_xy_mm <-> hist_coords_xyz_px.
#
# NOTE(review): `warp_image` is not imported anywhere in this notebook; it is
# presumably a project module like `segment_cells` — add `import warp_image` to the
# imports cell once confirmed.

# Accept lists or arrays for the inputs.
oct_corners_xy_mm = np.asarray(oct_view_corners_xy_mm, dtype=float)
hist_landmarks_px = np.asarray(hist_coords_xyz_px)

# OCT origin = the corner with the smallest x; ties are broken by the smallest y.
min_x = np.min(oct_corners_xy_mm[:, 0])
candidates = oct_corners_xy_mm[oct_corners_xy_mm[:, 0] == min_x]
oct_origin_mm = candidates[np.argmin(candidates[:, 1])]

# Landmarks relative to the OCT origin, converted mm -> um. A float copy is shifted so
# that `gt_coords_xy_mm` itself is never modified and re-running this cell does not
# apply the shift twice.
gt_xy_shifted_mm = np.array(gt_coords_xy_mm, dtype=float)
gt_xy_shifted_mm[:, :2] -= oct_origin_mm[:2]
gt_coords_xy_um = gt_xy_shifted_mm * 1000

# Output canvas size (rows, cols) = extent of the landmarks in y and x.
oct_view_dims_yx = np.array([
    gt_coords_xy_um[:, 1].max() - gt_coords_xy_um[:, 1].min(),
    gt_coords_xy_um[:, 0].max() - gt_coords_xy_um[:, 0].min(),
])

hist_image = cv2.imread(hist_image_path)
if hist_image is None:  # cv2.imread signals a missing/unreadable file with None
    raise FileNotFoundError(f"Could not read histology image: {hist_image_path!r}")

# Landmarks are passed as (row, col), hence the [1, 0] column swap.
warped_hist_image = warp_image.bspline_warp_image(
    hist_image,
    gt_coords_xy_um[:, [1, 0]],
    hist_landmarks_px[:, [1, 0]],
    order=3,
    output_shape=oct_view_dims_yx,
)

# Clip every channel to the 8-bit range in one call (order=3 interpolation can
# presumably overshoot), then write an explicit uint8 image so the PNG encoder
# receives a supported depth.
warped_hist_image = np.clip(warped_hist_image, 0, 255)
cv2.imwrite('rough_aligned_hist_image.png', np.rint(warped_hist_image).astype(np.uint8))

# TODO: the next cell uses `zmin_plane`, `zmean_plane` and `zmax_plane` (2-D surface
# maps indexed [row, col]); they are not defined anywhere above and must be computed
# or loaded before running it.

## Cell-to-Cell Alignment

In [None]:
# Cell-to-cell alignment, step 1: segment cells in both modalities and collect their
# centroids as (x, y, z) point sets for the matching cell below.
#
# NOTE(review): `zmin_plane`, `zmean_plane` and `zmax_plane` are not defined in any
# earlier cell (they are only mentioned in a comment in the rough-alignment cell), so
# this cell fails on a fresh kernel until they are computed or loaded.

# --- OCT cells ----------------------------------------------------------------------
oct_xy_vol = io.imread(oct_image_path)

oct_masks, oct_flows = segment_cells.segment_cells(oct_xy_vol, avg_cell_diameter_px, flow_threshold=0.85, cellprob_threshold=-8, keep_cells='dark', gpu=True, normalization="clahe")

# One centroid per labelled region, columns (z, y, x). The reshape keeps a 2-D array
# even when no cells are found, so the column indexing below does not fail.
oct_centroids = np.array(
    [r.centroid for r in regionprops(oct_masks)], dtype=float
).reshape(-1, oct_masks.ndim)
# Rescale y/x by the downsampling factor and offset z by the shallowest surface plane
# (presumably the segmented volume starts there — confirm).
oct_centroids[:, 1:3] *= xy_downsample_factor
oct_centroids[:, 0] += max(0, int(zmin_plane.min()))

# Keep only cells lying strictly between the zmin/zmax surfaces at their (row, col).
oct_planes = oct_centroids[:, 0].astype(int)
oct_rows = oct_centroids[:, 1].astype(int)
oct_cols = oct_centroids[:, 2].astype(int)
within_bounds = (
    (oct_planes > zmin_plane[oct_rows, oct_cols])
    & (oct_planes < zmax_plane[oct_rows, oct_cols])
)
oct_centroids_xyz = oct_centroids[within_bounds][:, [2, 1, 0]]

np.savetxt('/content/filtered_oct_centroids_xyz.csv', oct_centroids_xyz, delimiter=',')  # (n,3)

# --- Histology cells ----------------------------------------------------------------
# NOTE(review): this segments the image at `hist_image_path`, not the rough-aligned
# 'rough_aligned_hist_image.png' written by the previous cell, yet the resulting
# coordinates are matched directly against OCT centroids — confirm the intended input.
hist_img = io.imread(hist_image_path)

hist_masks, hist_flows = segment_cells.segment_cells(hist_img, avg_cell_diameter_px, flow_threshold=0.7, cellprob_threshold=-5, keep_cells='dark', gpu=True, normalization="global")

# Centroids are (row, col) for a 2-D label image; the reshape keeps a 2-D array when
# no cells are found.
hist_centroids_rc = np.array(
    [r.centroid for r in regionprops(hist_masks)], dtype=float
).reshape(-1, hist_masks.ndim)
hist_rows = hist_centroids_rc[:, 0]
hist_cols = hist_centroids_rc[:, 1]

# z of each histology cell = mean surface height, linearly interpolated at its
# (row, col); fill_value=None extrapolates outside the grid instead of erroring.
interp = RegularGridInterpolator(
    (np.arange(zmean_plane.shape[0]), np.arange(zmean_plane.shape[1])),
    zmean_plane, method='linear', bounds_error=False, fill_value=None
)
hist_z = interp(np.column_stack([hist_rows, hist_cols])) if len(hist_rows) else np.empty(0)
hist_centroids_xyz = np.column_stack([hist_cols, hist_rows, hist_z])  # (n, 3) as x, y, z

if hist_valid_cell_mask_path is not None:
    valid_cell_mask = cv2.imread(hist_valid_cell_mask_path, cv2.IMREAD_GRAYSCALE)
    if valid_cell_mask is None:  # cv2.imread returns None for missing/unreadable files
        raise FileNotFoundError(f"Could not read valid-cell mask: {hist_valid_cell_mask_path!r}")
    valid_cell_mask = (valid_cell_mask > 0).astype(np.uint8)
    assert valid_cell_mask.shape[:2] == hist_img.shape[:2], "Valid cell mask shape does not match histology image shape"

    # Keep a cell only if the mask is set at its (row, col) pixel.
    mask_rows = hist_centroids_xyz[:, 1].astype(int)
    mask_cols = hist_centroids_xyz[:, 0].astype(int)
    hist_centroids_xyz = hist_centroids_xyz[valid_cell_mask[mask_rows, mask_cols] > 0]

np.savetxt('/content/filtered_hist_centroids_xyz.csv', hist_centroids_xyz, delimiter=',')  # (n,3)

In [None]:
# Matching initial pairs: pair every OCT centroid with each of its k nearest histology
# centroids whose distance is below `initial_pairs_dist_threshold`. Distances are
# measured after multiplying z by `z_weight`; the saved coordinates are unscaled. One
# OCT cell can appear in several pairs.

k = 20

A_T = np.asarray(oct_centroids_xyz, dtype=float).reshape(-1, 3)  # (n, 3) OCT x, y, z
B_T = np.asarray(hist_centroids_xyz, dtype=float).reshape(-1, 3)  # (m, 3) histology x, y, z

# Weight z relative to x/y for the Euclidean neighbour search.
A_scaled = A_T.copy()
B_scaled = B_T.copy()
A_scaled[:, 2] *= z_weight
B_scaled[:, 2] *= z_weight

if len(A_scaled) and len(B_scaled):
    tree = cKDTree(B_scaled)
    # Never ask for more neighbours than there are points: cKDTree pads missing
    # neighbours with index m (out of range for B_T) and distance inf.
    dists, indices = tree.query(A_scaled, k=min(k, len(B_scaled)))
    # With k == 1 the results are 1-D; restore the (n, k) layout.
    dists = dists.reshape(len(A_scaled), -1)
    indices = indices.reshape(len(A_scaled), -1)
    # Row-major nonzero order = for each OCT point, its neighbours nearest first.
    a_idx, nbr_idx = np.nonzero(dists < initial_pairs_dist_threshold)  # threshold in scaled units
    b_idx = indices[a_idx, nbr_idx]
else:
    a_idx = b_idx = np.empty(0, dtype=int)

A_matched = A_T[a_idx].T  # (3, n_pairs)
B_matched = B_T[b_idx].T  # (3, n_pairs)
matched_pairs = np.vstack((A_matched, B_matched))  # (6, n_pairs)

# Save
np.savetxt('/content/centroids_matched_pairs.csv', matched_pairs, delimiter=',')

print(f"{A_matched.shape[1]} initial point pairs created")