In [1]:
import os
import time
import copy

import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image
import pickle
from trimesh import Trimesh
import trimesh
import open3d as o3d

import image
from utils import get_integer_segments, plot_region_numbers, triangulate_segments, trimesh_to_o3d, o3d_to_trimesh
from fuse_meshes import mesh_fusion

In [2]:
def get_kf(kf, dataset='tum'):
    """Load and return the pickled record for keyframe ``kf``.

    The file read is ``../data/kfs_moge_tum/kf_<kf>.pkl`` when
    ``dataset == 'tum'`` and ``../data/kfs_moge_realsense/kf_<kf>.pkl`` for
    any other ``dataset`` value. ``kf`` is passed through ``int()`` before it
    is formatted into the filename. The path is relative to the current
    working directory.

    NOTE(review): ``pickle.load`` can execute arbitrary code. Only point this
    at trusted keyframe dumps.
    """
    if dataset == 'tum':
        subdir = 'kfs_moge_tum'
    else:
        subdir = 'kfs_moge_realsense'
    keyframe_file = f'../data/{subdir}/kf_{int(kf)}.pkl'
    with open(keyframe_file, 'rb') as handle:
        record = pickle.load(handle)
    return record

def get_SP(data=None, dataset=None, kf=None, SPID=None):
    """Build the Trimesh for one super-primitive (segment) of a keyframe.

    Call either with an already-loaded keyframe ``data`` plus ``SPID``, or
    with ``dataset``, ``kf`` and ``SPID``. In the second form the keyframe is
    loaded via ``get_kf``.

    Args:
        data: keyframe record (as returned by ``get_kf``). Reads
            ``data['state_dict']`` entries ``'sp_regions'``, ``'pointmap'``
            and ``'image_raw'``.
        dataset: forwarded to ``get_kf``; only used when ``data`` is None.
        kf: keyframe id forwarded to ``get_kf``; only used when ``data`` is None.
        SPID: 1-based segment index into the list returned by
            ``triangulate_segments`` (``SPID=1`` is the first entry).

    Returns:
        trimesh.Trimesh built from all pointmap vertices, the selected
        segment's faces, and uint8 RGBA vertex colours (alpha = 255).
        Trimesh's default processing is applied on construction.

    Raises:
        ValueError: if neither valid argument combination is given, or if
            ``SPID`` is outside ``1..number_of_segments``.
    """
    if data is None and kf is not None and dataset is not None and SPID is not None:
        data = get_kf(kf=kf, dataset=dataset)
    elif data is None or SPID is None:
        raise ValueError('Invalid input. Pass data and SPID, or dataset, kf and SPID.')

    state = data['state_dict']
    sp_regions = state['sp_regions'].cpu().numpy()

    # Flatten the pointmap into an (N, 3) vertex list.
    verts = state['pointmap'].cpu().reshape(-1, 3)

    # NOTE(review): assumes image_raw is channel-first (3, H, W) with values in [0, 1]
    # (hence permute(1, 2, 0) and *255) -- TODO confirm against the producer.
    colours = state['image_raw'].cpu().permute(1, 2, 0).reshape(-1, 3)
    colours_4chnl = np.hstack([colours, np.ones((colours.shape[0], 1))])  # append opaque alpha
    colours_255 = (colours_4chnl * 255).astype('uint8')

    integer_segments = get_integer_segments(sp_regions)

    # Materialise once so the result can be length-checked and indexed
    # regardless of what iterable the helper returns.
    tris = list(triangulate_segments(verts, integer_segments))

    # Guard the 1-based id. Without this, SPID=0 would index tris[-1] and
    # silently return the last segment.
    if not 1 <= SPID <= len(tris):
        raise ValueError(f'SPID must be in 1..{len(tris)} (1-based), got {SPID}.')

    # Only the requested segment's mesh is built; constructing a Trimesh
    # for every segment just to return one of them is wasted work.
    return Trimesh(vertices=verts, faces=tris[SPID - 1], vertex_colors=colours_255)

def plot_kf(data, figsize=(6,5)):
    """Plot the ``sp_regions`` segmentation stored in a keyframe record.

    Converts ``data['state_dict']['sp_regions']`` to a NumPy array and turns
    it into integer segment labels with ``get_integer_segments``. Those labels
    are passed, together with ``figsize``, to ``plot_region_numbers``.
    Returns None.
    """
    region_tensor = data['state_dict']['sp_regions']
    labels = get_integer_segments(region_tensor.cpu().numpy())
    plot_region_numbers(labels, figsize=figsize)

In [3]:
# for i in range(31):
#     kfid = 10*i + 1
#     data = get_kf(kfid, 'not_tum')
#     plot_kf(data)
#     print(f'^keyframe {kfid}')

In [4]:
mesh_path = '../data/mustard360/super-primitives/'


def _sp_file_index(filename):
    """Return the integer id from a ``<prefix>_<id>.<ext>`` filename, or None if it doesn't parse.

    Parsing matches the original sort key: the token after the first '_'
    and before the first '.'.
    """
    parts = filename.split('_')
    if len(parts) < 2:
        return None
    try:
        return int(parts[1].split('.')[0])
    except ValueError:
        return None


# os.listdir returns every entry. Keep only files whose name yields an id, so
# stray entries (e.g. '.DS_Store', sub-directories) don't crash the sort key.
mesh_filenames = [
    name for name in os.listdir(mesh_path)
    if os.path.isfile(os.path.join(mesh_path, name)) and _sp_file_index(name) is not None
]
mesh_filenames.sort(key=_sp_file_index)

meshes = [trimesh.load(os.path.join(mesh_path, name)) for name in mesh_filenames]
assert meshes, f'no super-primitive meshes found in {mesh_path}'

In [5]:
# mesh1 = get_SP(dataset='kfs_moge_realsense', kf=1, SPID=22)
# mesh2 = get_SP(dataset='kfs_moge_realsense', kf=11, SPID=15)
# mesh3 = get_SP(dataset='kfs_moge_realsense', kf=21, SPID=15)
# mesh4 = get_SP(dataset='kfs_moge_realsense', kf=31, SPID=1)
# mesh5 = get_SP(dataset='kfs_moge_realsense', kf=41, SPID=24)
# mesh6 = get_SP(dataset='kfs_moge_realsense', kf=51, SPID=17)
# mesh7 = get_SP(dataset='kfs_moge_realsense', kf=61, SPID=20)

# meshes = (mesh1, mesh2, mesh3, mesh4, mesh5, mesh6, mesh7)

# # meshA.export('chair_1_unposed.ply')

In [6]:
def trimesh_to_o3d_pcd(mesh: trimesh.Trimesh) -> o3d.geometry.PointCloud:
    """Convert a trimesh mesh's vertices and vertex colours into an Open3D point cloud.

    Faces are discarded; only vertex positions and RGB colours are kept (the
    alpha channel is dropped).

    Args:
        mesh: mesh whose ``visual.vertex_colors`` is available (colour visuals).

    Returns:
        o3d.geometry.PointCloud with one point per vertex and RGB colours in [0, 1].
    """
    pcd = o3d.geometry.PointCloud()
    # Vector3dVector is documented for float64 (n, 3) arrays; pass that dtype
    # explicitly rather than relying on implicit conversion.
    pcd.points = o3d.utility.Vector3dVector(np.asarray(mesh.vertices, dtype=np.float64))
    # trimesh vertex colours are uint8 RGBA in 0-255; Open3D expects float RGB in [0, 1].
    rgb = np.asarray(mesh.visual.vertex_colors[:, :3], dtype=np.float64) / 255.0
    pcd.colors = o3d.utility.Vector3dVector(rgb)
    return pcd

def _icp_radius_schedule(voxel: float, start_mult: float) -> list:
    """Return ICP correspondence radii voxel*start_mult, voxel*start_mult/2, ... while the multiplier is >= 1."""
    radii = []
    mult = start_mult
    while mult >= 1:
        radii.append(voxel * mult)
        mult /= 2
    return radii


def align_meshes(meshA: trimesh.Trimesh, meshB: trimesh.Trimesh,
                 voxel: float = 0.005,
                 lambda_geometric: float = 0.968,
                 max_iteration: int = 50,
                 start_multiplier: int = 4,
                 max_attempts: int = 4) -> np.ndarray:
    """Estimate the 4x4 transform that maps ``meshA`` onto ``meshB`` with multi-scale coloured ICP.

    Both meshes are converted to coloured point clouds, voxel-down-sampled and
    given estimated normals. Alignment starts from a pure translation that
    moves the source centroid onto the target centroid. Coloured ICP is then
    run coarse-to-fine over the correspondence radii
    ``voxel * m, voxel * m/2, ...`` (down to the last value >= ``voxel``),
    where ``m = start_multiplier * 2**attempt``. If Open3D reports
    "No correspondences found", the whole schedule restarts from the initial
    guess with the starting radius doubled, up to ``max_attempts`` times.

    Args:
        meshA: source mesh (the one being moved).
        meshB: target mesh.
        voxel: down-sampling voxel size in world units; also the finest ICP
            radius and half the normal-estimation radius.
        lambda_geometric: geometric vs. photometric weight for Open3D's
            coloured-ICP estimator.
        max_iteration: ICP iterations per radius.
        start_multiplier: coarsest radius on the first attempt, in voxels.
        max_attempts: number of schedules to try (each doubles the start radius).

    Returns:
        4x4 np.ndarray homogeneous transformation (source -> target).

    Raises:
        ValueError: on non-positive ``voxel``, ``start_multiplier < 1`` or
            ``max_attempts < 1``.
        RuntimeError: if every attempt fails with "No correspondences found";
            any other Open3D RuntimeError is re-raised unchanged.
    """
    if voxel <= 0:
        raise ValueError(f"voxel must be positive, got {voxel}")
    if start_multiplier < 1:
        raise ValueError(f"start_multiplier must be >= 1, got {start_multiplier}")
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

    source, target = map(trimesh_to_o3d_pcd, (meshA, meshB))

    # --- pre-processing: down-sample & normals (coloured ICP needs target normals) ---
    sd = source.voxel_down_sample(voxel)
    td = target.voxel_down_sample(voxel)
    for pcd in (sd, td):
        pcd.estimate_normals(
            o3d.geometry.KDTreeSearchParamHybrid(radius=voxel * 2, max_nn=30)
        )

    # Initial guess: translate the source centroid onto the target centroid.
    T_init = np.eye(4)
    T_init[:3, 3] = td.get_center() - sd.get_center()

    est = o3d.pipelines.registration.TransformationEstimationForColoredICP(
        lambda_geometric=lambda_geometric)
    crit = o3d.pipelines.registration.ICPConvergenceCriteria(max_iteration=max_iteration)

    start_mult = start_multiplier
    for attempt in range(max_attempts):
        start_mult = start_multiplier * (2 ** attempt)
        try:
            T = T_init  # every attempt restarts from the centroid guess
            for radius in _icp_radius_schedule(voxel, start_mult):
                result = o3d.pipelines.registration.registration_colored_icp(
                    sd, td,
                    max_correspondence_distance=radius,
                    init=T,
                    estimation_method=est,
                    criteria=crit)
                T = result.transformation
            return T
        except RuntimeError as e:
            if "No correspondences found" not in str(e):
                raise  # any other RuntimeError is unexpected
            print(f"[Attempt {attempt+1}] No matches with start radius "
                  f"{start_mult*voxel:.4f} - trying larger radius …")

    raise RuntimeError(f"Colored ICP failed up to start radius {start_mult} * voxel")

def show_alignment(T:np.ndarray, source:trimesh.Trimesh, target:trimesh.Trimesh, show_pcd:bool=False, paint_pcd:bool=True):
    """Open an Open3D viewer showing ``source`` transformed by ``T`` together with ``target``.

    Both inputs are deep-copied first, so the caller's meshes are untouched.

    Args:
        T: 4x4 transform applied to the source geometry only.
        source: mesh to be moved by ``T``.
        target: reference mesh, shown as-is.
        show_pcd: if True, display vertex point clouds (``trimesh_to_o3d_pcd``);
            otherwise display meshes (``trimesh_to_o3d``).
        paint_pcd: point-cloud mode only. Paints source orange and target blue
            instead of using the vertex colours.
    """
    src_copy, tgt_copy = copy.deepcopy(source), copy.deepcopy(target)

    if not show_pcd:
        src_geom, tgt_geom = trimesh_to_o3d(src_copy), trimesh_to_o3d(tgt_copy)
        src_geom.transform(T)
    else:
        src_geom, tgt_geom = trimesh_to_o3d_pcd(src_copy), trimesh_to_o3d_pcd(tgt_copy)
        src_geom.transform(T)
        if paint_pcd:
            src_geom.paint_uniform_color([1, 0.706, 0])      # orange
            tgt_geom.paint_uniform_color([0, 0.651, 0.929])  # blue

    o3d.visualization.draw_geometries(
        [src_geom, tgt_geom],
        zoom=0.6,
        front=[0.5, -0.3, -0.8],
        up=[0, -1, 0],
        point_show_normal=False,
    )

In [7]:
# Align each consecutive pair of super-primitive meshes and visualise the result.
# pairwise_transforms[i] is the transform estimated for meshes[i] -> meshes[i+1].
# meshA / meshB / T are left bound to the last pair for the follow-up code below.
pairwise_transforms = []
for meshA, meshB in zip(meshes, meshes[1:]):
    T = align_meshes(meshA, meshB)
    pairwise_transforms.append(T)

    show_alignment(T, meshA, meshB, show_pcd=True, paint_pcd=False)

# meshA_o3d, meshB_o3d = map(trimesh_to_o3d, (meshA, meshB))
# meshA_aligned_o3d = meshA_o3d.transform(T)
# o3d.visualization.draw_geometries([meshA_aligned_o3d, meshB_o3d])

# meshA_aligned = meshA.apply_transform(T)

# fused_mesh = mesh_fusion(
#     mesh1=meshA_aligned,
#     mesh2=meshB
# )
# o3d.visualization.draw_geometries([fused_mesh])

In [8]:
# fused_mesh = meshes[0]
# for i in range(1,len(meshes)):
#     T = align_meshes(fused_mesh, meshes[i])

#     fused_mesh = fused_mesh.apply_transform(T)

#     fused_mesh_o3d = trimesh_to_o3d(fused_mesh, copy=True)
#     # new_mesh_o3d = trimesh_to_o3d(meshes[i], copy=True)

#     # o3d.visualization.draw_geometries([fused_mesh_o3d, new_mesh_o3d])

#     fused_mesh_o3d = mesh_fusion(
#         mesh1=fused_mesh,
#         mesh2=meshes[i],
#         h_alpha=8,
#     )
#     print(f'showing iteration {i}')

#     o3d.visualization.draw_geometries([fused_mesh_o3d])

#     fused_mesh = o3d_to_trimesh(fused_mesh_o3d)

# o3d.visualization.draw_geometries([fused_mesh_o3d])