In [4]:
# import packages 
import numpy as np
import open3d as o3d
from io import StringIO
import copy
from scipy import stats, ndimage
import sys
import os
from skimage import io

# Make the parent of the notebook's working directory importable; presumably
# this is the repository root containing the `src` package imported below.
root_dir = os.path.join(os.getcwd(), '..')
# Guard so that re-running this cell does not append a duplicate entry.
if root_dir not in sys.path:
    sys.path.append(root_dir)

from src.aux_pcd_functions  import pcd_to_image, image_to_pcd

In [5]:
def draw_registration_result(source, target):
    """
    Display two point clouds together in one Open3D window.

    Each cloud is deep-copied before being painted, so the caller's objects
    keep their own colors. The source copy is painted yellow/orange
    (RGB [1, 0.706, 0]) and the target copy blue (RGB [0, 0.651, 0.929]).

    Parameters
    ----------
    source : point cloud
        Cloud drawn in yellow/orange.
    target : point cloud
        Cloud drawn in blue.

    Returns
    -------
    None.

    """
    palette = ((source, [1, 0.706, 0]), (target, [0, 0.651, 0.929]))
    painted = []
    for cloud, rgb in palette:
        cloud_copy = copy.deepcopy(cloud)
        cloud_copy.paint_uniform_color(rgb)
        painted.append(cloud_copy)
    o3d.visualization.draw_geometries(painted)

def _intensity_to_gray(values, gain=10.0):
    """
    Map per-point intensities to an (N, 3) array of equal R, G, B values.

    The mapping is ``gain * (v - min(v)) / max(v)``, computed in float64 so
    that integer intensity arrays (e.g. uint16 volumes) cannot overflow in
    the multiplication. A divisor of 1 is used when ``max(v)`` is 0, so an
    all-zero array gives black instead of NaN colors. The result is clipped
    to [0, 1], the range Open3D expects for colors; with a gain > 1, brighter
    points saturate to white (presumably to boost contrast when picking).

    Parameters
    ----------
    values : array-like
        One intensity per point.
    gain : float, optional
        Contrast multiplier (default 10).

    Returns
    -------
    numpy array of shape (N, 3)
        Grayscale colors, one row per input value.
    """
    values = np.asarray(values, dtype=np.float64)
    peak = np.max(values)
    if peak == 0:
        peak = 1.0
    gray = np.clip(gain * (values - np.min(values)) / peak, 0.0, 1.0)
    return np.column_stack([gray, gray, gray])


def _manual_registration(source, source_values, target, target_values, color_gain=10.0):
    """
    Register ``source`` onto ``target`` from point pairs the user picks.

    A grayscale copy of each cloud (colored from its intensity values) is
    shown in an interactive picking window: first the source, then the
    target. The i-th point picked on the source is paired with the i-th
    point picked on the target. Open3D's point-to-point estimator (with
    scaling enabled) then computes the transformation from those pairs,
    using the original clouds' point positions only.

    Parameters
    ----------
    source : pcd
        Point cloud to be mapped onto ``target``.
    source_values : numpy array
        Intensity values used only to color the picking view; presumably one
        per point of ``source`` in point order (TODO confirm).
    target : pcd
        Reference point cloud.
    target_values : numpy array
        As ``source_values``, for ``target``.
    color_gain : float, optional
        Contrast multiplier for the grayscale picking view (default 10).

    Returns
    -------
    transformation : numpy array
        The transformation returned by Open3D's ``compute_transformation``
        (a 4x4 homogeneous matrix) mapping the source cloud onto the target.

    Raises
    ------
    AssertionError
        If fewer than 3 points are picked on either cloud, or if the two
        clouds receive different numbers of picks.
    """
    # Work on copies so the caller's clouds keep their own colors.
    source_temp = copy.deepcopy(source)
    target_temp = copy.deepcopy(target)
    source_temp.colors = o3d.utility.Vector3dVector(_intensity_to_gray(source_values, color_gain))
    target_temp.colors = o3d.utility.Vector3dVector(_intensity_to_gray(target_values, color_gain))

    # Pick points on both clouds; pairs are matched by picking order.
    picked_id_source = pick_points(source_temp)
    picked_id_target = pick_points(target_temp)

    n_src, n_tgt = len(picked_id_source), len(picked_id_target)
    assert n_src >= 3 and n_tgt >= 3, (
        f"pick at least 3 points on each cloud (got {n_src} on source, {n_tgt} on target)"
    )
    assert n_src == n_tgt, (
        f"source and target need the same number of picks (got {n_src} vs {n_tgt})"
    )

    # Correspondences are point indices: column 0 = source, column 1 = target.
    corr = np.column_stack([picked_id_source, picked_id_target]).astype(np.int32)

    # Picking used the recolored copies, but indices are identical, so the
    # estimate is computed on the original clouds.
    p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True)
    transformation = p2p.compute_transformation(source, target, o3d.utility.Vector2iVector(corr))

    return transformation


def pick_points(pcd):
    """
    Let the user pick points on ``pcd`` in an interactive Open3D window.

    Python-level output printed while the window is open is discarded. The
    previous ``sys.stdout`` is restored, and the window is destroyed, even
    if the interactive session raises or is interrupted.

    NOTE(review): redirecting ``sys.stdout`` only captures Python-level
    writes; messages Open3D's native code writes straight to the process
    stdout may still appear — confirm.

    Parameters
    ----------
    pcd : point cloud object
        Cloud to display for picking.

    Returns
    -------
    list of int
        Indices into ``pcd``'s points of the picked points, as returned by
        ``VisualizerWithEditing.get_picked_points()`` (callers treat them as
        ordered picks).
    """
    import contextlib

    # The context manager restores sys.stdout even on exceptions, so a failed
    # or interrupted pick cannot leave the kernel's output silenced.
    with contextlib.redirect_stdout(StringIO()):
        vis = o3d.visualization.VisualizerWithEditing()
        vis.create_window()
        try:
            vis.add_geometry(pcd)
            vis.run()  # blocks while the user picks points
        finally:
            vis.destroy_window()
    return vis.get_picked_points()

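### Test registration without color (the transformation is estimated from point positions only; colors are used just to display the clouds while picking)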

In [8]:
# Test registration: load the reference and one preprocessed source volume,
# register the source onto the reference by hand, and resample the result.

# Input files (relative to the notebook's working directory).
reference_fly_filename = "../../data_2/References_and_masks/C1_Reference_iso.tif"
abdomen_mask_file = "../../data_2/References_and_masks/Reference_abdomen_mask_iso.tif"
source_file_name = "../../data_2/02_preprocessed/Preprocessed_C2-D0_En_pn_male_01_20220704.tif"

# Abdomen_Mask is loaded here but not used in this cell.
Reference_Image, Abdomen_Mask, Source_Image = (
    io.imread(path)
    for path in (reference_fly_filename, abdomen_mask_file, source_file_name)
)

# Build point clouds plus per-point values (presumably intensities) for both volumes.
source, source_values = image_to_pcd(Source_Image)
target, target_values = image_to_pcd(Reference_Image)

# Pick corresponding points, then move `source` onto the target in place and show both.
transformation = _manual_registration(source, source_values, target, target_values)
source.transform(transformation)
draw_registration_result(source, target)

# Convert the registered cloud back to an image with the reference image's shape.
registered_source_image = pcd_to_image(source, source_values, Reference_Image.shape)

In [9]:
import napari
# Inspect the registered source volume in an interactive napari viewer.
viewer = napari.view_image(registered_source_image)

INFO - 2023-04-20 18:52:58,507 - acceleratesupport - No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'


In [10]:
# Open a second napari viewer on the original (unregistered) source volume for comparison.
# NOTE: this rebinds `viewer`, so the handle to the previous cell's viewer is lost.
viewer = napari.view_image(Source_Image)