## Finding the correspondences between fiducial particle positions in EM and LM images

This notebook demonstrates the automatic detection of correspondences between fiducial particles in EM and LM images, followed by computation of a displacement field and warping of the LM image accordingly.

The main steps of the algorithm are:
- **1. Loading fiducial particle locations** - Fiducial particle coordinates are loaded from an .xml file for both EM (target) and LM (source) images.
- **2. Rescaling LM coordinates** - The LM fiducial coordinates are rescaled to match the coordinate system of the EM image.
- **3. Multilevel registration** - The algorithm performs a two-stage registration: first rigid, then non-rigid. It automatically identifies corresponding fiducial particles between the two modalities, assigns them matching IDs, and saves the updated information in the .xml file.
- **4. Warping the LM image** - Using the displacement field computed from the matched fiducial points, the LM image is warped to align with the EM image.

As an alternative, the updated .xml file (step 3) containing the automatically detected correspondences can be imported into the ec-CLEM software, allowing EM and LM image registration without manual annotation.

Load the necessary Python libraries:

In [None]:
import os
import open3d as o3d
import numpy as np
import pandas as pd
from pathlib import Path
import bigfish.stack as stack
import bigfish.plot as plot
from probreg import cpd
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
from utils import xml_to_dataframe, dataframe_to_xml, dataframe_to_pointcloud, dataframe_to_xml_
from utils import visualize_result_nparray, clean_correspondences, print_transformations, chamfer_distance
from utils import convert_to_pcd, save_correspondences_in2df, save_correspondences_in1df
from utils_displacement_field import expand_displacement_field, extrapolate_displacement_field, visualize_extended_field
from utils_displacement_field import calculate_displacement_vectors, plot_and_save_overlay_images, warp_image

#### 1. Loading fiducial particle locations

The locations of fiducial particles in EM (target) and LM (source) images were automatically detected using the following notebooks:

- Detect_fiducial_particles_in_EM.ipynb
- Detect_fiducial_particles_in_LM.ipynb

These detections were saved in various formats. For further processing, we load the .xml files into Pandas DataFrames. 

In [None]:
input_folder = 'E:/DATA/AI4Life_Pr26/20240805_Trial_data_fiducial_particles/240723_JB294_CLEM-AI4life_sample1/pos1/'
output_folder = Path(input_folder) / 'output'

# target_xml_clust = "target_clusters.xml"
target_xml_clust = "Ground_truth_fiducials_EM_only_clusters.xml"  # Ground truth fiducials
target_path_xml_clust = output_folder / target_xml_clust
print(target_path_xml_clust.exists())

# source_xml_clust = 'source_clusters.xml'
source_xml_clust = "Ground_truth_fiducials_LM_only_clusters.xml"  # Ground truth fiducials
source_path_xml_clust = output_folder / source_xml_clust
print(source_path_xml_clust.exists())

target_df_clust = xml_to_dataframe(target_path_xml_clust)
#print(target_df_clust)
# "small": the LM image has a lower resolution than the EM (target) image, so these coordinates are scaled up later
source_df_small_clust = xml_to_dataframe(source_path_xml_clust)
#print(source_df_small_clust)

# Create the output folder if it does not exist
output_folder.mkdir(exist_ok=True)

#### 2. Rescaling LM coordinates

Since EM and LM images often differ in size (LM images are typically much smaller), the coordinates of fiducial points detected in the LM image must be rescaled to the EM image coordinate system. Without this rescaling, the two point sets live in incompatible spatial references and registration becomes difficult or inaccurate. The scale factors are computed per axis as the ratio of the EM and LM image dimensions, which assumes that both images cover the same field of view.
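A minimal sketch of the per-axis rescaling, using hypothetical helper names (`scale_factors`, `rescale_points`) and made-up image sizes rather than the notebook's data. The point it illustrates is the axis convention: `shape[0]` counts rows (y) and `shape[1]` counts columns (x), so the row ratio must be applied to `pos_y` and the column ratio to `pos_x`.

```python
import numpy as np

def scale_factors(em_shape, lm_shape):
    """Return (scale_y, scale_x) mapping LM pixel coordinates to EM pixel coordinates.

    Assumes both images show the same field of view. shape[0] is the number of
    rows (y) and shape[1] the number of columns (x); a trailing channel axis is ignored.
    """
    scale_y = em_shape[0] / lm_shape[0]
    scale_x = em_shape[1] / lm_shape[1]
    return scale_y, scale_x

def rescale_points(xy, scale_y, scale_x):
    """Scale an (N, 2) array of (x, y) LM coordinates into EM pixel units."""
    xy = np.asarray(xy, dtype=float)
    return xy * np.array([scale_x, scale_y])  # x by the column ratio, y by the row ratio

# Toy example with made-up sizes: EM is 2000 x 3000, LM is 100 x 200 with 3 channels
scale_y, scale_x = scale_factors((2000, 3000), (100, 200, 3))
lm_points = np.array([[10.0, 5.0], [100.0, 50.0]])  # (x, y) in LM pixels
em_points = rescale_points(lm_points, scale_y, scale_x)
print(scale_y, scale_x)  # 20.0 15.0
```

Because the two ratios generally differ (as in the notebook, where the x and y scales are about 23.0 and 24.9), mixing up the axes would distort the point set before registration even starts.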

In [None]:
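def find_the_scale(EM_shape, LM_shape):
    """Return (scale_y, scale_x): the EM/LM size ratio along rows (y) and columns (x)."""
    scale_y = EM_shape[0] / LM_shape[0]  # shape[0] = number of rows (y)
    scale_x = EM_shape[1] / LM_shape[1]  # shape[1] = number of columns (x)
    return scale_y, scale_x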

# load the EM and LM images
EM_image_path = os.path.join(input_folder, "240726_JB295_HEK293_CLEM_LAMP1-488_Particles-555_grid4_pos1_bin4_EM.tif")
LM_image_path  = os.path.join(input_folder, "240726_JB295_HEK293_CLEM_LAMP1-488_Particles-555_grid4_pos1_LM.tif")

EMimage = stack.read_image(EM_image_path)
LMimage_small = stack.read_image(LM_image_path)

# Scaling factors between the two images along each axis
scale_y, scale_x = find_the_scale(EMimage.shape, LMimage_small.shape)

print("Scale x: ", scale_x)  # Scale x:  22.966165413533833
print("Scale y: ", scale_y)  # Scale y:  24.873456790123456

# Rescale the LM point positions to EM pixel coordinates
source_df_clust = source_df_small_clust.copy()
source_df_clust['pos_x'] = source_df_small_clust['pos_x'] * scale_x
source_df_clust['pos_y'] = source_df_small_clust['pos_y'] * scale_y

# Resize the LM image to the EM image size so that it matches the rescaled point coordinates
LMimage = stack.resize_image(LMimage_small, EMimage.shape, method='bilinear')
#EMimage_small = stack.resize_image(EMimage, LMimage_small.shape, method='bilinear')

Plotting the point coordinates on top of the corresponding image allows us to visually verify their correct placement and ensure the coordinates were loaded and scaled properly.

In [None]:
#plot.plot_detection(LMimage_small[:,:,1], source_small, contrast=True)
target_clust = target_df_clust[['pos_x', 'pos_y']].to_numpy()
source_clust = source_df_clust[['pos_x', 'pos_y']].to_numpy()

# plot_detection expects (y, x) coordinates, hence the column swap [:, [1, 0]];
# radius=3*scale_y draws circles of about 3 LM pixels, expressed in EM pixel units
plot.plot_detection(LMimage[:, :, 1], source_clust[:, [1, 0]], shape="circle", radius=3*scale_y,
                    color="red", linewidth=1, fill=False, contrast=True)
plot.plot_detection(EMimage, target_clust[:, [1, 0]], radius=3*scale_y, contrast=False)

print(source_clust.shape)
print(target_clust.shape)

The fiducial points are converted from Pandas DataFrames into point clouds and saved as .ply files for further processing.
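`dataframe_to_pointcloud` comes from the notebook's `utils` module, whose source is not shown here. As a rough illustration of what a .ply point cloud of 2D fiducials contains, the hypothetical helper below writes `(x, y)` coordinates to a minimal ASCII PLY file using only NumPy, embedding the points in 3D with `z = 0`. Whether the notebook's helper uses the same layout is an assumption.

```python
import os
import tempfile
import numpy as np

def points_to_ply(xy, path):
    """Write (x, y) points to a minimal ASCII PLY file with z = 0 (hypothetical helper)."""
    xy = np.asarray(xy, dtype=float)
    xyz = np.column_stack([xy, np.zeros(len(xy))])  # embed the 2D points in the z = 0 plane
    header = "\n".join([
        "ply",
        "format ascii 1.0",
        f"element vertex {len(xyz)}",
        "property float x",
        "property float y",
        "property float z",
        "end_header",
    ])
    with open(path, "w") as f:
        f.write(header + "\n")
        np.savetxt(f, xyz, fmt="%.6f")  # one "x y z" line per vertex
    return xyz

# Example with made-up coordinates, written to the system temp folder
ply_path = os.path.join(tempfile.gettempdir(), "example_fiducials.ply")
xyz = points_to_ply([[120.5, 80.0], [300.0, 410.25]], ply_path)
with open(ply_path) as f:
    lines = f.read().splitlines()
```

Storing the points with an explicit `z = 0` keeps them compatible with tools that expect 3D point clouds, while the registration itself effectively stays in 2D.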

In [None]:
target_clusters = dataframe_to_pointcloud(target_df_clust, "target_clusters.ply")  #target_pcd
source_clusters = dataframe_to_pointcloud(source_df_clust, "source_clusters.ply")  #source_pcd

# Visualize the point cloud
#o3d.visualization.draw_geometries([target_clusters])
#o3d.visualization.draw_geometries([source_clusters])

Next, load the coordinates of the individual fiducial particles within the regions (the larger point set used for non-rigid registration) and rescale the LM coordinates in the same way as above.

In [None]:
# target_xml_regi = "target_regions.xml"
target_xml_regi = "Ground_truth_fiducials_EM.xml"  # Ground truth fiducials
target_path_xml_regi = Path(input_folder) / 'output' / target_xml_regi
print(target_path_xml_regi.exists())

# source_xml_regi = 'source_regions.xml'
source_xml_regi = 'Ground_truth_fiducials_LM.xml'  # Ground truth fiducials
source_path_xml_regi = Path(input_folder) / 'output' / source_xml_regi
print(source_path_xml_regi.exists())

target_df_regi = xml_to_dataframe(target_path_xml_regi)
print(target_df_regi)
# "small": LM coordinates at the lower LM resolution, scaled up below
source_df_small_regi = xml_to_dataframe(source_path_xml_regi)
print(source_df_small_regi)

# Rescale the LM point positions to EM pixel coordinates
source_df_regi = source_df_small_regi.copy()
source_df_regi['pos_x'] = source_df_small_regi['pos_x'] * scale_x
source_df_regi['pos_y'] = source_df_small_regi['pos_y'] * scale_y

target_regions = dataframe_to_pointcloud(target_df_regi, "target_regions.ply")  # target_pcd
source_regions = dataframe_to_pointcloud(source_df_regi, "source_regions.ply")  # source_pcd

#### 3. Multilevel registration

The multilevel registration uses Coherent Point Drift (CPD) in two stages:

- **Rigid registration**: This initial step uses a smaller set of points, the coordinates of the fiducial particle clusters, to align the images through translation, rotation, and scaling.
- **Non-rigid registration**: This stage uses a larger set of points, the positions of the individual fiducial particles within the previously identified regions, after the rigid transformation has been applied to them. This allows finer, local adjustments that account for non-linear distortions between the images.
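The two stages fit different transformation models. The sketch below only *applies* them, following the standard CPD formulation: the rigid stage estimates a similarity transform `s R x + t`, and the non-rigid stage displaces each source point by `G W`, where `G` is a Gaussian kernel matrix over the source points and `W` a matrix of weights. The helper names are hypothetical, the numbers are made up, and probreg's internal kernel parameterization may differ from the `beta` convention used here; the expectation-maximization fitting that estimates these parameters is not shown.

```python
import numpy as np

def apply_rigid(points, R, t, s=1.0):
    """Similarity transform of the rigid stage: x -> s * R @ x + t (points as rows)."""
    return s * points @ R.T + t

def gaussian_kernel(points, beta):
    """Kernel matrix G_ij = exp(-||p_i - p_j||^2 / (2 * beta^2))."""
    diff = points[:, None, :] - points[None, :, :]
    return np.exp(-np.sum(diff ** 2, axis=-1) / (2.0 * beta ** 2))

def apply_nonrigid(points, W, beta):
    """Non-rigid stage: each point moves by a kernel-weighted sum, x_i -> x_i + (G @ W)_i."""
    return points + gaussian_kernel(points, beta) @ W

# Toy 2D example (made-up numbers)
pts = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
theta = np.pi / 2
R = np.array([[np.cos(theta), -np.sin(theta)],
              [np.sin(theta),  np.cos(theta)]])
rigid = apply_rigid(pts, R, t=np.array([2.0, 0.0]), s=2.0)
# (1, 0) rotated by 90 degrees is (0, 1); scaled by 2 -> (0, 2); shifted by t -> (2, 2)

W = np.zeros_like(pts)
W[0] = [0.1, 0.0]  # only the kernel centred on the first point carries a weight
warped = apply_nonrigid(pts, W, beta=1.0)
# the first point moves by 0.1; its neighbours move by 0.1 * exp(-d^2 / 2)
```

The Gaussian kernel is what makes the non-rigid displacement smooth: nearby points receive similar displacements, and the influence of each weight fades with distance at a rate set by `beta`.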

In [None]:

'''
# Load the point clouds and set the scale
source = o3d.io.read_point_cloud("source.ply") # ('source.ply') ('source_all.ply')
target = o3d.io.read_point_cloud("target.ply") # ('target.ply') ('target_all.ply')
source_all = o3d.io.read_point_cloud("source_all.ply") # ('source.ply') ('source_all.ply')
target_all = o3d.io.read_point_cloud("target_all.ply") # ('target.ply') ('target_all.ply')


# Convert to numpy arrays and subscale the points
source_points = np.asarray(source.points)/scale         # Subscale the points, so the physical distance between them is not too large
target_points = np.asarray(target.points)/scale          # Subscale the points, so the physical distance between them is not too large
source_points_all = np.asarray(source_all.points)/scale         # Subscale the points, so the physical distance between them is not too large
target_points_all = np.asarray(target_all.points)/scale          # Subscale the points, so the physical distance between them is not too large

'''
scale = 1000

# Convert to numpy arrays and divide the coordinates by `scale`, so that the distances
# between points are not too large (CPD and the correspondence threshold work in these units)
source_points = np.asarray(source_clusters.points) / scale
target_points = np.asarray(target_clusters.points) / scale
source_points_all = np.asarray(source_regions.points) / scale
target_points_all = np.asarray(target_regions.points) / scale

# ------------------------------------------------------------------------------------------
# 1st - registration by Coherent Point Drift (CPD) - rigid
tf_param_rigid, _, _ = cpd.registration_cpd(source_points, target_points, tf_type_name='rigid', maxiter=1000, tol=1e-5)
source_points_res2 = tf_param_rigid.transform(source_points)
source_points_all_res2 = tf_param_rigid.transform(source_points_all)

#visualize_result_nparray(source_points, target_points, source_points_res2, "Rigid CPD")
print_transformations(tf_param_rigid, "Rigid CPD Transformation:")
chamfer_distance(target_points, source_points_res2, "Chamfer distance 1st - Rigid CPD")
chamfer_distance(target_points_all, source_points_all_res2, "all Chamfer distance 1st - Rigid CPD")

# ------------------------------------------------------------------------------------------
# 2nd - registration by Coherent Point Drift (CPD) - nonrigid 
tf_param_nonrigid, _, _ = cpd.registration_cpd(source_points_res2, target_points, tf_type_name='nonrigid', maxiter=1000, tol=1e-5)
source_points_res3 = tf_param_nonrigid.transform(source_points_res2)

tf_param_nonrigid_all, _, _ = cpd.registration_cpd(source_points_all_res2, target_points_all, tf_type_name='nonrigid', maxiter=1000, tol=1e-5)
source_points_all_res3 = tf_param_nonrigid_all.transform(source_points_all_res2)

#visualize_result_nparray(source_points_res2, target_points, source_points_res3, "Nonrigid CPD")
print("Non-rigid CPD Transformation:")
print(tf_param_nonrigid.g)  # Gaussian kernel matrix G over the source points
print(tf_param_nonrigid.w)  # Weight matrix W; the displacement of the source points is G @ W

chamfer_distance(target_points, source_points_res3, "Chamfer distance 2nd - Nonrigid CPD")
chamfer_distance(target_points_all, source_points_all_res3, "all Chamfer distance 2nd - Nonrigid CPD")
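`chamfer_distance` is imported from `utils` and its exact definition is not shown. Definitions of the Chamfer distance vary (sum or mean, squared or plain distances); the sketch below uses one common symmetric variant, the mean nearest-neighbour distance from each set to the other, added together, computed with SciPy's `cKDTree`. Lower values after each registration stage mean the point sets lie closer together; since the points were divided by `scale`, the values are in those downscaled units.

```python
import numpy as np
from scipy.spatial import cKDTree

def chamfer(a, b):
    """Symmetric Chamfer distance: mean NN distance a -> b plus mean NN distance b -> a."""
    d_ab, _ = cKDTree(b).query(a)  # for each point in a, distance to its closest point in b
    d_ba, _ = cKDTree(a).query(b)  # for each point in b, distance to its closest point in a
    return d_ab.mean() + d_ba.mean()

# Toy example with made-up points: b has an extra point that a does not explain
a = np.array([[0.0, 0.0], [1.0, 0.0]])
b = np.array([[0.0, 0.5], [1.0, 0.0], [3.0, 0.0]])
print(chamfer(a, b))  # approx. 1.0833 (0.25 + 2.5 / 3)
```

Because the measure is symmetric, unmatched points on either side (for example a fiducial detected only in EM) raise the value, so it should not be expected to reach zero even for a perfect alignment.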

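Correspondences between the registered point sets are found with Open3D. With the identity as the initial transformation, `evaluate_registration` pairs each CPD-registered source point with its nearest target point within `threshold`; point-to-point ICP is also run to report any remaining rigid misalignment. This works on the full set of points as well as on the clusters. The matched points are then stored in pandas DataFrames in which corresponding points share the same `id`.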
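The matching step can be sketched without Open3D. The hypothetical helpers below pair each (already registered) source point with its nearest target point within `threshold`, then keep at most one source point per target point, similar to the pairs `evaluate_registration` reports and the repeated matches that `clean_correspondences` removes. Keeping the *closest* source point is an assumed rule, since the `utils` implementation is not shown. Matched pairs then receive the same `id` in two DataFrames. Note that `threshold = 0.02` is expressed in the coordinates divided by `scale`, i.e. 20 in EM pixel coordinates when `scale = 1000`.

```python
import numpy as np
import pandas as pd
from scipy.spatial import cKDTree

def match_points(source, target, threshold):
    """Return sorted (source_index, target_index) pairs: nearest neighbour within
    `threshold`, keeping only the closest source point when several hit the same target."""
    dist, idx = cKDTree(target).query(source, distance_upper_bound=threshold)
    best = {}  # target index -> (source index, distance)
    for s, (d, t) in enumerate(zip(dist, idx)):
        if np.isinf(d):  # no target point within threshold
            continue
        t = int(t)
        if t not in best or d < best[t][1]:
            best[t] = (s, d)
    return sorted((s, t) for t, (s, _) in best.items())

def pairs_to_dataframes(source, target, pairs):
    """Build source/target DataFrames in which matched points share the same `id`."""
    src_idx = [s for s, _ in pairs]
    tgt_idx = [t for _, t in pairs]
    ids = np.arange(len(pairs))
    src_df = pd.DataFrame({"id": ids, "pos_x": source[src_idx, 0], "pos_y": source[src_idx, 1]})
    tgt_df = pd.DataFrame({"id": ids, "pos_x": target[tgt_idx, 0], "pos_y": target[tgt_idx, 1]})
    return src_df, tgt_df

# Toy example in downscaled units (made-up coordinates)
source = np.array([[0.0, 0.0], [0.012, 0.0], [1.0, 1.0]])
target = np.array([[0.005, 0.0], [1.0, 1.5], [1.0, 1.01]])
pairs = match_points(source, target, threshold=0.02)
src_df, tgt_df = pairs_to_dataframes(source, target, pairs)
print(pairs)  # [(0, 0), (2, 2)]
```

Here source points 0 and 1 both fall within the threshold of target point 0; only the closer one (source 0) is kept, and target point 1 stays unmatched because no source point lies within 0.02 of it.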

In [None]:
threshold = 0.02  # maximum correspondence distance in the downscaled units (0.02 * scale = 20 in EM pixel coordinates)
# Identity as the initial transformation: the source points have already been aligned by CPD
# (source_points_res3), so the evaluation and ICP start from the CPD result
trans_init = np.eye(4)

sour = convert_to_pcd(source_points_res3)
targ = convert_to_pcd(target_points)
sour_all = convert_to_pcd(source_points_all_res3)
targ_all = convert_to_pcd(target_points_all)

evaluation = o3d.pipelines.registration.evaluate_registration(sour, targ, threshold, trans_init)
evaluation_all = o3d.pipelines.registration.evaluate_registration(sour_all, targ_all, threshold, trans_init)
print("Evaluation: ", evaluation)
print("Correspondence set: ")
print(np.asarray(evaluation.correspondence_set))

print("Apply point-to-point ICP")
reg_p2p = o3d.pipelines.registration.registration_icp(
    sour, targ, threshold, trans_init,
    o3d.pipelines.registration.TransformationEstimationPointToPoint())
print(reg_p2p)
print("Transformation is:")
print(reg_p2p.transformation)

#evaluation = o3d.pipelines.registration.evaluate_registration(sour, targ, threshold, reg_p2p.transformation)
#evaluation_all = o3d.pipelines.registration.evaluate_registration(sour_all, targ_all, threshold, reg_p2p.transformation)
#print("Evaluation: ", evaluation)
print("Correspondence set: ")
print(np.asarray(evaluation_all.correspondence_set))


correspondences = clean_correspondences(np.asarray(evaluation.correspondence_set))
correspondences_all = clean_correspondences(np.asarray(evaluation_all.correspondence_set))
print("Correspondences: ", correspondences_all)

dff2s,dff2t = save_correspondences_in2df(source_points, target_points, correspondences)

# In the raw correspondence set the same point can appear in several pairs
# (typically one target point matched to several source points); clean_correspondences removes these repeats.

#orig_source_df = save_correspondences_in1df(np.asarray(orig_source.points), np.asarray(orig_target.points), correspondences)
orig_source_df, orig_target_df = save_correspondences_in2df((source_points)/[scale_y,scale_x,1]*scale, 
                                                            target_points*scale, correspondences)

orig_source_all_df, orig_target_all_df = save_correspondences_in2df((source_points_all)/[scale_y,scale_x,1]*scale, 
                                                            target_points_all*scale, correspondences_all)

df = save_correspondences_in1df(source_points_all, target_points_all, correspondences_all)

dataframe_to_xml(orig_source_df, 'original_source_points_2205.xml')
dataframe_to_xml(orig_target_df, 'original_target_points_2205.xml')
dataframe_to_xml(orig_source_all_df, 'original_source_points_all_2205.xml')
dataframe_to_xml(orig_target_all_df, 'original_target_points_all_2205.xml')

In [None]:
'''
source_database = orig_source_all_df
target_database = orig_target_all_df

pointsT = np.column_stack((target_database['pos_y'], target_database['pos_x'])) #np.column_stack((target_database['pos_y'], target_database['pos_x']))
#pointsT_swapped = [[x, y] for y, x in pointsT]
plot.plot_detection(EMimage, pointsT, shape="circle", radius = 3*scale_y, color = "red", linewidth = 1, fill=False, contrast=True) 

pointsS = np.column_stack((source_database['pos_y']*scale_x, source_database['pos_x']*scale_y)) #np.column_stack((target_database['pos_y'], target_database['pos_x']))
#pointsT_swapped = [[x, y] for y, x in pointsT]
plot.plot_detection(LMimage[:,:,1], pointsS, shape="circle", radius = 3*scale_y, color = "red", linewidth = 1, fill=False, contrast=True) 
'''


#### 4. Warping the LM image

The LM image is warped using a displacement field derived from point cloud correspondences between LM and EM images. 

Since the initial displacement field is defined only within the convex hull of the matched points, it is extrapolated to cover the entire image domain to enable full-image warping.
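`expand_displacement_field` and `extrapolate_displacement_field` come from `utils_displacement_field`, whose source is not shown. The sketch below illustrates the idea with `scipy.interpolate.griddata` (already imported at the top of the notebook): linear interpolation between the matched points, which yields NaN outside their convex hull, followed by filling those pixels with the displacement of the nearest matched point. Nearest-neighbour filling is only one simple extrapolation choice and may differ from the notebook's helper; the helper name and the (row, column) point order are assumptions.

```python
import numpy as np
from scipy.interpolate import griddata

def dense_displacement_field(points, displacements, shape):
    """Interpolate sparse displacements onto a full image grid (hypothetical helper).

    points:        (N, 2) array of (row, col) positions of the matched fiducials
    displacements: (N, 2) array of (d_row, d_col) displacements at those positions
    shape:         image shape; only the first two entries (rows, cols) are used
    Returns an array of shape (rows, cols, 2).
    """
    rows, cols = np.mgrid[0:shape[0], 0:shape[1]]
    field = np.empty((shape[0], shape[1], 2))
    for k in range(2):
        # Linear interpolation: defined inside the convex hull, NaN outside it
        inside = griddata(points, displacements[:, k], (rows, cols), method="linear")
        # Nearest-neighbour values, used to fill the pixels outside the hull
        nearest = griddata(points, displacements[:, k], (rows, cols), method="nearest")
        field[..., k] = np.where(np.isnan(inside), nearest, inside)
    return field

# Toy example: four matched points on a 10 x 10 grid (made-up values)
pts = np.array([[2.0, 2.0], [2.0, 7.0], [7.0, 2.0], [7.0, 7.0]])
disp = np.column_stack([pts[:, 0], -pts[:, 1]])  # d_row = row, d_col = -col
field = dense_displacement_field(pts, disp, (10, 10))
```

Inside the hull, pixel (4, 5) gets the linearly interpolated displacement (4, -5); outside it, pixel (0, 0) takes the value (2, -2) of its nearest matched point at (2, 2), so the field is defined everywhere and the whole image can be warped.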

In [None]:
#source_database = orig_source_all_df
#target_database = orig_target_all_df
#points = np.column_stack((target_database['pos_y'], target_database['pos_x']))  #points = np.column_stack((source_database['pos_x']*scale_x, source_database['pos_y']*scale_y))

#displacements = calculate_displacement_vectors(source_database, target_database, scale_x, scale_y)

#displacement_field = expand_displacement_field(points, displacements, EMimage.shape)

#visualize_extended_field(displacement_field, points)


source_database = orig_source_all_df
target_database = orig_target_all_df
points = np.column_stack((target_database['pos_x'], target_database['pos_y']))  #points = np.column_stack((source_database['pos_x']*scale_x, source_database['pos_y']*scale_y))

displacements = calculate_displacement_vectors(source_database, target_database, scale_x, scale_y)

points_swapped = [[y, x] for x, y in points]  # reorder each (pos_x, pos_y) pair to (pos_y, pos_x), i.e. (row, column)
displacement_field = expand_displacement_field(points_swapped, displacements, EMimage.shape)

visualize_extended_field(displacement_field, points)

In [None]:
extrapolated_field = extrapolate_displacement_field(displacement_field)
visualize_extended_field(extrapolated_field, points)

In [None]:
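# float32 instead of float16: float16 keeps only about 3 significant digits, which loses
# sub-pixel precision for displacements of hundreds of pixels
warped_LMimage = warp_image(LMimage, extrapolated_field.astype(np.float32))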
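`warp_image` is also part of `utils_displacement_field` and is not shown. A common way to apply a dense displacement field is backward warping with `scipy.ndimage.map_coordinates`: each output pixel samples the input image at its own position plus the displacement. This sketch assumes that sign convention (the field points from output pixels to input positions); if the notebook's field is defined in the opposite direction, the sign would need to be flipped or the field inverted. The helper name `warp_with_field` is hypothetical.

```python
import numpy as np
from scipy.ndimage import map_coordinates

def warp_with_field(image, field, order=1):
    """Backward warp: out[r, c] = image[r + field[r, c, 0], c + field[r, c, 1]].

    `image` is (rows, cols) or (rows, cols, channels); `field` is (rows, cols, 2).
    """
    rows, cols = np.mgrid[0:field.shape[0], 0:field.shape[1]].astype(float)
    coords = np.stack([rows + field[..., 0], cols + field[..., 1]])  # sampling positions
    if image.ndim == 2:
        return map_coordinates(image, coords, order=order, mode="nearest")
    # Warp each channel separately with the same sampling positions
    return np.stack([map_coordinates(image[..., ch], coords, order=order, mode="nearest")
                     for ch in range(image.shape[-1])], axis=-1)

# Toy example: one bright pixel, constant field sampling one column to the right
img = np.zeros((5, 5))
img[2, 2] = 1.0
shift = np.zeros((5, 5, 2))
shift[..., 1] = 1.0
out = warp_with_field(img, shift)  # the bright pixel appears at (2, 1)
out_rgb = warp_with_field(np.dstack([img, img, img]), shift)
```

Backward warping is usually preferred over pushing pixels forward, because every output pixel receives exactly one interpolated value and no holes appear in the result.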

In [None]:
def plot_image_with_points(image, dataframe, scale_y):
    points = np.column_stack((dataframe['pos_y'], dataframe['pos_x']))
    plot.plot_detection(image, points, shape="circle", radius = 3*scale_y, color = "red", linewidth = 1, fill=False, contrast=True) 
   
#pointsT = np.column_stack((target_database['pos_y'], target_database['pos_x'])) #np.column_stack((target_database['pos_y'], target_database['pos_x']))
#plot.plot_detection(warped_LMimage, pointsT, shape="circle", radius = 3*scale_y, color = "red", linewidth = 1, fill=False, contrast=True) 


plot_image_with_points(warped_LMimage[:,:,1], target_database, scale_y)

In [None]:
plot_and_save_overlay_images(EMimage, warped_LMimage[:,:,1], 'overlay_EM_LM_1.png') 