## Cell-to-Cell Alignment

In [None]:
# @title Segmentation
if do_cell_to_cell_alignment:
    # ------------------------------------------------------------------
    # OCT (3D) segmentation
    # ------------------------------------------------------------------
    # Load the OCT volume (axis 0 is z, axes 1/2 are y/x) and downscale only
    # in x/y before segmentation.
    oct_vol_xy = io.imread(oct_volume_image_path)
    oct_vol_xy = downscale_local_mean(
        oct_vol_xy, (1, oct_xy_downscale_factor, oct_xy_downscale_factor)
    )

    # z-band around the tissue surface. The bounds are cast to int because
    # surface_z_um may be float and numpy slice bounds must be integers.
    # NOTE(review): surface_z_um is used directly as a slice offset relative
    # to `barcode_focus_oct_slice`, i.e. this assumes one OCT z slice per
    # micrometer -- confirm.
    lower_z_bound = int(np.floor(max(0, surface_z_um.min() - z_buffer_um)))
    upper_z_bound = int(np.ceil(surface_z_um.max() + z_buffer_um))
    oct_vol_xy_to_segment = oct_vol_xy[
        barcode_focus_oct_slice + lower_z_bound: barcode_focus_oct_slice + upper_z_bound
    ]

    # Segment only the cropped z-band (not the full volume).
    oct_masks, oct_flows = segment_cells.segment_cells(
        oct_vol_xy_to_segment,
        avg_cell_diameter_px,
        flow_threshold=0.85,
        cellprob_threshold=-8,
        keep_cells='dark',
        gpu=True,
        normalization="clahe",
    )

    # Cell centers as (z, y, x). reshape keeps a (0, 3) array when no cells
    # are found so the column indexing below does not fail.
    oct_centroids = np.array(
        [r.centroid for r in regionprops(oct_masks)], dtype=float
    ).reshape(-1, 3)
    # Undo the x/y downscaling so x/y index the same grid as surface_z_um.
    oct_centroids[:, 1] *= oct_xy_downscale_factor
    oct_centroids[:, 2] *= oct_xy_downscale_factor
    # Crop-local z index 0 corresponds to slice
    # barcode_focus_oct_slice + lower_z_bound, i.e. z == lower_z_bound in the
    # frame surface_z_um is expressed in (the frame used to build the crop).
    oct_centroids[:, 0] += lower_z_bound

    # Keep only cells whose z lies strictly inside surface +/- z_buffer_um,
    # evaluated at each cell's own (y, x) location.
    z = oct_centroids[:, 0]
    y = oct_centroids[:, 1]
    x = oct_centroids[:, 2]
    zmin_plane = surface_z_um - z_buffer_um
    zmax_plane = surface_z_um + z_buffer_um
    zi, yi, xi = z.astype(int), y.astype(int), x.astype(int)
    within_bounds = (zi > zmin_plane[yi, xi]) & (zi < zmax_plane[yi, xi])
    oct_centroids_xyz = oct_centroids[within_bounds][:, [2, 1, 0]]

    # Save point
    np.savetxt('/content/filtered_oct_centroids_xyz.csv', oct_centroids_xyz, delimiter=',')  # (n,3)

    # ------------------------------------------------------------------
    # Histology (2D) segmentation
    # ------------------------------------------------------------------
    hist_img = io.imread(hist_image_path)
    hist_masks, hist_flows = segment_cells.segment_cells(
        hist_img,
        avg_cell_diameter_px,
        flow_threshold=0.7,
        cellprob_threshold=-5,
        keep_cells='dark',
        gpu=True,
        normalization="global",
    )

    # Cell centers as (row, col) == (y, x); (0, 2) when no cells are found.
    hist_centroids_yx = np.array(
        [r.centroid for r in regionprops(hist_masks)], dtype=float
    ).reshape(-1, 2)

    # Assign each histology cell the surface z interpolated at its (y, x).
    # fill_value=None extrapolates for points outside the surface grid.
    # NOTE(review): assumes the histology image is registered to the same
    # (row, col) grid as surface_z_um -- confirm.
    interp = RegularGridInterpolator(
        (np.arange(surface_z_um.shape[0]), np.arange(surface_z_um.shape[1])),
        surface_z_um, method='linear', bounds_error=False, fill_value=None
    )
    if len(hist_centroids_yx):
        hist_z_coords = interp(hist_centroids_yx)
    else:
        hist_z_coords = np.empty(0)
    hist_centroids_xyz = np.column_stack(
        (hist_centroids_yx[:, 1], hist_centroids_yx[:, 0], hist_z_coords)
    )

    # Optionally drop cells whose (x, y) falls outside a user-provided mask.
    if hist_valid_cell_mask_path is not None:
        valid_cell_mask = cv2.imread(hist_valid_cell_mask_path, cv2.IMREAD_GRAYSCALE)
        if valid_cell_mask is None:  # cv2.imread signals failure with None
            raise FileNotFoundError(
                f"Could not read valid cell mask: {hist_valid_cell_mask_path}"
            )
        valid_cell_mask = (valid_cell_mask > 0).astype(np.uint8)
        if valid_cell_mask.shape[:2] != hist_img.shape[:2]:
            raise ValueError("Valid cell mask shape does not match histology image shape")

        x = hist_centroids_xyz[:, 0].astype(int)
        y = hist_centroids_xyz[:, 1].astype(int)

        hist_centroids_xyz = hist_centroids_xyz[valid_cell_mask[y, x] > 0]

    # Save point
    np.savetxt('/content/filtered_hist_centroids_xyz.csv', hist_centroids_xyz, delimiter=',')  # (n,3)
