In [1]:
# import packages 
import pandas as pd
import os
from io import StringIO
from skimage import io, transform
import numpy as np
import open3d as o3d
import copy
import napari
from skimage import morphology
from skimage.measure import label, regionprops, block_reduce
from scipy import stats, ndimage
import matplotlib.pyplot as plt
from tifffile import imsave
from tqdm import tqdm
import sys
import os

root_dir = os.path.join(os.getcwd(), '..')
sys.path.append(root_dir)

from src.aux_pcd_functions  import pcd_to_image, image_to_pcd

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def draw_registration_result(source, target, transformation):
    """
    Show two point clouds together, each in a single uniform color, with
    ``transformation`` applied to the source cloud.

    Both inputs are deep-copied before being recolored/moved, so the
    caller's point clouds are left unchanged.

    Parameters
    ----------
    source : point cloud
        Cloud that gets transformed; drawn with RGB (1, 0.706, 0).
    target : point cloud
        Reference cloud; drawn with RGB (0, 0.651, 0.929).
    transformation : numpy 4x4 array
        Transformation matrix applied to (a copy of) the source point cloud.

    Returns
    -------
    None.

    """
    source_rgb = [1, 0.706, 0]
    target_rgb = [0, 0.651, 0.929]

    moved = copy.deepcopy(source)
    moved.paint_uniform_color(source_rgb)
    moved.transform(transformation)

    fixed = copy.deepcopy(target)
    fixed.paint_uniform_color(target_rgb)

    o3d.visualization.draw_geometries([moved, fixed])

def _intensity_to_gray(values):
    """
    Turn per-point intensity values into an (N, 3) gray-level color array.

    Keeps the original display formula ``10 * (v - min(v)) / max(v)``. The
    factor 10 brightens the picking view, so entries above 1 are presumably
    rendered saturated (assumed intentional). The arithmetic is done in
    float64 so that integer-typed intensities cannot wrap around when
    multiplied by 10. An all-zero input gives black instead of NaN.
    """
    values = np.asarray(values, dtype=np.float64)
    v_max = np.max(values)
    if v_max == 0:
        # Avoid 0/0 -> NaN; (values - min) is all zeros here anyway.
        v_max = 1.0
    gray = 10 * (values - np.min(values)) / v_max
    # Same (N, 3) layout as np.asarray([g, g, g]).T for 1-D input.
    return np.column_stack((gray, gray, gray))


def _manual_registration(source, source_values, target, target_values):
    """
    Interactively estimate the transformation that maps ``source`` onto ``target``.

    Each cloud is opened in its own picking window (via ``pick_points``),
    colored in gray levels derived from its intensity values. The user must
    pick at least 3 points in each window, the same number in both and in
    corresponding order. A point-to-point estimate *with scaling* is then
    computed from these index correspondences.

    Parameters
    ----------
    source : pcd
        Point cloud to be registered. It is not modified; only a copy is
        recolored for display.
    source_values : numpy array
        Per-point intensities of ``source``, used only to color the picking
        view (presumably 1-D, one value per point -- TODO confirm against
        image_to_pcd).
    target : pcd
        Reference point cloud. It is not modified.
    target_values : numpy array
        Per-point intensities of ``target``, used only for display.

    Returns
    -------
    transformation : numpy 4x4 array
        The similarity transform (rotation, translation, uniform scale) that
        maps the source point cloud onto the target point cloud.

    Raises
    ------
    AssertionError
        If fewer than 3 points were picked in either cloud, or if the two
        clouds received a different number of picks.

    """
    source_temp = copy.deepcopy(source)
    target_temp = copy.deepcopy(target)

    source_temp.colors = o3d.utility.Vector3dVector(_intensity_to_gray(source_values))
    target_temp.colors = o3d.utility.Vector3dVector(_intensity_to_gray(target_values))

    # Pick points from the two point clouds to build index correspondences.
    picked_id_source = pick_points(source_temp)
    picked_id_target = pick_points(target_temp)

    assert len(picked_id_source) >= 3 and len(picked_id_target) >= 3, \
        "pick at least 3 points in each point cloud"
    assert len(picked_id_source) == len(picked_id_target), \
        "pick the same number of points in both point clouds"

    # Row i pairs the i-th picked source index with the i-th picked target index.
    corr = np.column_stack((picked_id_source, picked_id_target)).astype(np.int32)

    # Estimate the transformation. The temp copies only differ in color,
    # so the original clouds give the same geometry.
    p2p = o3d.pipelines.registration.TransformationEstimationPointToPoint(with_scaling=True)
    transformation = p2p.compute_transformation(source, target, o3d.utility.Vector2iVector(corr))

    return transformation


def pick_points(pcd):
    """
    Open an interactive Open3D editing window on ``pcd`` and let the user
    select a series of points on it. The call blocks until the window is
    closed.

    (Per the Open3D interactive-visualization tutorial, points are picked
    with shift + left click.)

    Parameters
    ----------
    pcd : point cloud object

    Returns
    -------
    list of int
        Indices of the picked points in ``pcd``, in picking order. These are
        indices, not coordinates.

    """
    # Suppress the printed output from Open3D while picking points.
    # NOTE(review): only Python-level writes to sys.stdout are captured;
    # anything Open3D's C++ core writes directly to the process stdout may
    # still appear.
    stdout_old = sys.stdout
    sys.stdout = StringIO()
    try:
        # Create a Visualizer with editing (point picking) support.
        vis = o3d.visualization.VisualizerWithEditing()
        vis.create_window()
        try:
            vis.add_geometry(pcd)
            # The user picks points; returns when the window is closed.
            vis.run()
        finally:
            vis.destroy_window()
    finally:
        # Always restore stdout. Otherwise an error above would silently
        # swallow every later output in the notebook.
        sys.stdout = stdout_old
    return vis.get_picked_points()

### Test registration without color

In [None]:
# Test registration: load the reference fly and one preprocessed sample,
# convert both to point clouds and register the sample onto the reference
# by manually picking corresponding points.

# Read files and create point clouds
reference_fly_filename = "../../data_2/01_raw/multiple/C1-220528_D3-4xEn-Gal4_1.tif"
abdomen_mask_file = "../../data_2/References_and_masks/Reference_abdomen_mask_iso_thick.tif"

source_file_name = "../../data_2/02_preprocessed/Preprocessed_C1-D3_en_pnr_male_02_20221025.tif"
source_file_name_dsred = "../../data_2/02_preprocessed/Preprocessed_C2-D3_en_pnr_male_02_20221025.tif"

Reference_Image = io.imread(reference_fly_filename)
Abdomen_Mask = io.imread(abdomen_mask_file)
Source_Image = io.imread(source_file_name)
Source_dsred = io.imread(source_file_name_dsred)

source, source_values = image_to_pcd(Source_Image)
source_dsred, source_values_dsred = image_to_pcd(Source_dsred)

target, target_values = image_to_pcd(Reference_Image)

# _manual_registration returns the 4x4 transformation matrix, not a point
# cloud. Keep it in its own variable instead of overwriting `source`.
transformation = _manual_registration(source, source_values, target, target_values)

# Apply the transformation to copies, so the original clouds stay intact and
# the matrix can be reused.
# NOTE(review): applying it to the DsRed channel assumes both channels share
# the same image grid -- TODO confirm.
source_registered = copy.deepcopy(source)
source_registered.transform(transformation)

source_dsred_registered = copy.deepcopy(source_dsred)
source_dsred_registered.transform(transformation)