In [1]:
import math
import numpy as np
import open3d as o3d
from sklearn.cluster import MeanShift
import trimesh
from trimesh import Trimesh
from utils import (
    trimesh_to_o3d,
    o3d_to_trimesh,
    smooth_normals,
    calc_local_spacing,
    find_cyl_neighbours,
    calc_local_spacing,
    compute_overlap_set,
    trilateral_shift,
    get_o3d_colours_from_trimesh,
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


### Load meshes and extract data

In [2]:
# Read the two bottle scans from disk, then build an Open3D counterpart of
# each so that both the trimesh and the Open3D APIs are available downstream.
mesh_paths = ('meshes/bottle_1.ply', 'meshes/bottle_2.ply')
mesh1, mesh2 = (trimesh.load_mesh(path) for path in mesh_paths)

mesh1_o3d, mesh2_o3d = (trimesh_to_o3d(scan) for scan in (mesh1, mesh2))

In [3]:
# Stack the per-scan vertex data into single arrays. Rows of mesh1 come first,
# followed by rows of mesh2, so an index into `points` >= len(points1) refers
# to a vertex of the second scan.
points1 = np.asarray(mesh1.vertices)
points2 = np.asarray(mesh2.vertices)
pointclouds = (points1, points2)
points = np.vstack(pointclouds)

# Per-vertex colours in the same row order as `points`.
# np.concatenate is used instead of the np.concat alias because the alias only
# exists in NumPy >= 2.0; the two are otherwise the same function.
colours1 = mesh1.visual.vertex_colors
colours2 = mesh2.visual.vertex_colors
colours = np.concatenate([colours1, colours2])

# KD-tree over the stacked vertices; the array is transposed before being
# handed to Open3D.
tree = o3d.geometry.KDTreeFlann(points.T)

# Vertex normals are computed on the Open3D meshes (in place) and then read
# back as arrays in the same row order as `points`.
mesh1_o3d.compute_vertex_normals()
mesh2_o3d.compute_vertex_normals()

normals1 = np.asarray(mesh1_o3d.vertex_normals)
normals2 = np.asarray(mesh2_o3d.vertex_normals)
normals = np.concatenate([normals1, normals2], axis=0)

# Scan label per vertex: 0.0 for every mesh1 vertex, 1.0 for every mesh2
# vertex (kept as float, matching the previous np.ones(...) * i construction).
scan_ids = np.concatenate([
    np.full(len(these_points), i, dtype=float)
    for (i, these_points) in enumerate(pointclouds)
])

In [22]:
# Preview the stacked, unprocessed vertices of both scans as one coloured
# point cloud.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
pcd.colors = get_o3d_colours_from_trimesh(colours)

o3d.visualization.draw_geometries([pcd])

### Smooth normals and calculate local properties

In [5]:
# Smooth the stacked vertex normals.
# The input is rebuilt from the raw per-scan normals rather than read from
# `normals`, so re-running this cell does not smooth already-smoothed normals
# a second time. On a fresh top-to-bottom run the result is unchanged.
normals = smooth_normals(
    points,
    np.concatenate([normals1, normals2], axis=0),
    k=8,
    T=0.7,
    n_iters=5,
)

In [6]:
# Per-scan local spacing and local density, as returned by calc_local_spacing.
# NOTE(review): both arguments are the vertex array of the same mesh (once from
# trimesh, once from Open3D) -- confirm this matches calc_local_spacing's
# expected signature.
local_spacing_1, local_density_1 = calc_local_spacing(mesh1.vertices, np.asarray(mesh1_o3d.vertices))
local_spacing_2, local_density_2 = calc_local_spacing(mesh2.vertices, np.asarray(mesh2_o3d.vertices))

# Stack per-scan values in the same order as `points` (mesh1 first, then mesh2).
local_spacings = (local_spacing_1, local_spacing_2)
local_spacing = np.concatenate(local_spacings)

# Fix: this previously concatenated the *spacing* arrays, so `local_density`
# was just a copy of `local_spacing` and the real densities were never used.
local_density = np.concatenate((local_density_1, local_density_2))

# Global average spacing = mean over scans of each scan's mean local spacing
# (each scan contributes equally, regardless of its vertex count).
global_avg_spacing = np.mean([np.mean(scan_spacing) for scan_spacing in local_spacings])

### Calculate overlapping region

In [7]:
# Neighbourhood parameters forwarded to compute_overlap_set (and reused by the
# later shifting / merging steps).
# NOTE(review): presumably height and radius factors of the search cylinder,
# expressed relative to the local spacing -- confirm against utils.
h_alpha, r_alpha = 2.5, 2

overlap_idx, overlap_mask = compute_overlap_set(
    points,
    normals,
    local_spacing,
    scan_ids,
    h_alpha,
    r_alpha,
    tree,
)

In [8]:
# Faces of BOTH scans, re-indexed into the stacked `points` array: mesh2's
# faces refer to its own vertex list, so they are offset by len(points1).
# Fix: previously only mesh1's triangles were used even though `points` and
# `overlap_mask` cover both scans, so every mesh2 face was silently dropped.
tris = np.vstack([
    np.asarray(mesh1_o3d.triangles),
    np.asarray(mesh2_o3d.triangles) + len(points1),
])

# Drop faces whose three vertices are all in the overlap set; what remains are
# the faces with at most two overlapping vertices.
tris_to_keep = tris[~np.all(overlap_mask[tris], axis=1)]

mesh = Trimesh(points, tris_to_keep)
# mesh.show()

### Perform trilateral point shifting

In [9]:
# Apply the trilateral shift repeatedly, feeding each pass the output of the
# previous one. The first pass starts from the original stacked vertices.
# NOTE(review): `tree` was built from the unshifted `points` and is reused for
# every pass -- confirm that is what trilateral_shift expects.
n_shift_passes = 3
trilat_shifted_pts = points
for shift_pass in range(n_shift_passes):
    trilat_shifted_pts = trilateral_shift(
        trilat_shifted_pts,
        normals,
        local_spacing,
        local_density,
        overlap_idx,
        tree,
        r_alpha,
        h_alpha,
    )

In [23]:
# Preview the vertices after trilateral shifting, using the original
# per-vertex colours.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(trilat_shifted_pts))
pcd.colors = get_o3d_colours_from_trimesh(colours)

o3d.visualization.draw_geometries([pcd])

### Merge nearby clusters of points

In [39]:
# Greedy clustering of the shifted overlap points.
#
# Two index spaces are used below:
#   * "global" indices address rows of `points` / `trilat_shifted_pts`;
#   * "local" indices address rows of the overlap subset only.
# `overlap_idx[local]` converts local -> global.
overlap_idx = np.flatnonzero(overlap_mask)                            # (M,)
trilat_shifted_overlap_pts = trilat_shifted_pts[overlap_idx]          # (M, 3)
trilat_shifted_overlap_tree = o3d.geometry.KDTreeFlann(               # KD-tree over the overlap subset
    trilat_shifted_overlap_pts.T)

delta = np.sqrt(2) / 2
sigma = delta * global_avg_spacing  # std-dev of the Gaussian merge weights

# Seeded generator so the random seed-point order (and therefore the merged
# output) is reproducible across kernel restarts.
merge_rng = np.random.default_rng(0)

# merge_mapping[g] = id of the merged point that global vertex g was absorbed
# into, or -1 if g has not been merged (non-overlap vertices stay at -1).
merge_mapping = np.full(len(points), -1, dtype=int)

merged_pnts = []
merged_cols = []
merged_nrms = []

while True:
    # Overlap points that do not belong to any cluster yet (local indices).
    free_local_idx = np.flatnonzero(merge_mapping[overlap_idx] < 0)
    if free_local_idx.size == 0:
        break

    # Pick a random unmerged overlap point as the seed of the next cluster.
    id_local = merge_rng.choice(free_local_idx)                      # local id in [0, M)

    point  = trilat_shifted_overlap_pts[id_local]                    # (3,)
    normal = normals[overlap_idx[id_local]]                          # (3,)

    # Neighbours of the seed among the overlap points; indices returned are
    # local to the overlap subset. d2 is used below as a squared distance.
    nbr_local, d2 = find_cyl_neighbours(
        point,
        normal,
        global_avg_spacing,
        h_alpha,
        delta,
        trilat_shifted_overlap_pts,
        trilat_shifted_overlap_tree,
        self_idx=None)

    nbr_global = overlap_idx[nbr_local]

    # Only absorb neighbours that have not already been claimed by a cluster.
    mask = merge_mapping[nbr_global] < 0
    nbr_global = nbr_global[mask]
    d2 = d2[mask]

    # The seed itself is unmerged, so it is expected to be among the results;
    # an empty set would otherwise make this loop spin forever.
    if len(nbr_global) == 0:
        raise RuntimeError("Point neighbourhood is unexpectedly empty!")

    # Normalised Gaussian weights on the distance to the seed.
    w = np.exp(-d2 / (2 * sigma ** 2))
    w /= w.sum()

    merged_id = len(merged_pnts)
    merge_mapping[nbr_global] = merged_id

    # Weighted averages of position and colour over the cluster members.
    merged_pnts.append(w @ trilat_shifted_pts[nbr_global])
    merged_cols.append(w @ colours[nbr_global])

    # Weighted average normal, re-normalised. If the member normals cancel out
    # exactly, fall back to the seed's normal instead of dividing by zero.
    merged_nrm = w @ normals[nbr_global]
    merged_nrm_len = np.linalg.norm(merged_nrm)
    merged_nrms.append(merged_nrm / merged_nrm_len if merged_nrm_len > 0 else normal)

merged_pnts = np.vstack(merged_pnts)
merged_cols = np.vstack(merged_cols)
merged_nrms = np.vstack(merged_nrms)

In [40]:
# Preview the merged overlap points with their weighted-average colours.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(merged_pnts))
pcd.colors = get_o3d_colours_from_trimesh(merged_cols)

o3d.visualization.draw_geometries([pcd])

### Remesh the overlapping region

In [50]:
# Find non-overlapping vertices that are connected to overlapping vertices
# Mesh the union of the overlapping pointcloud and the points found in the previous step
# Take the union of this mesh and the non-overlapping mesh regions

In [33]:
tris1 = np.asarray(mesh1_o3d.triangles)
tris2 = np.asarray(mesh2_o3d.triangles)

# mesh2's faces index into its own vertex list; offset them so that both face
# sets index into the stacked `points` array.
tris2_shifted = tris2 + len(points1)
all_tris = np.vstack([tris1, tris2_shifted])

# Boolean overlap flag per stacked vertex.
overlap_v = np.zeros(len(points), dtype=bool)
overlap_v[overlap_idx] = True

# Faces with at least one vertex in the overlap set.
tri_has_overlap_any = overlap_v[all_tris].any(axis=1)
print(tri_has_overlap_any)

# All vertices (overlap AND non-overlap "border" vertices) of those faces,
# as indices into the stacked `points` array.
unmerged_overlap_and_border_idx = np.unique(
    all_tris[tri_has_overlap_any]
)

# Translate to ids of merged points. Border vertices were never merged and
# carry merge_mapping == -1; they must be dropped here, otherwise -1 would be
# used as an index and silently select the LAST merged point.
mapped_merge_ids = merge_mapping[unmerged_overlap_and_border_idx]
merged_overlap_and_border_idx = np.unique(
    mapped_merge_ids[mapped_merge_ids >= 0]
)

# Mask over the merged points (previously sized by len(points), which did not
# match the merged-id index space it was filled with).
merged_overlap_and_border_mask = np.zeros(len(merged_pnts), dtype=bool)
merged_overlap_and_border_mask[merged_overlap_and_border_idx] = True

region_pnts = merged_pnts[merged_overlap_and_border_idx]
region_cols = merged_cols[merged_overlap_and_border_idx]
region_nrms = merged_nrms[merged_overlap_and_border_idx]

pcd = o3d.geometry.PointCloud()
pcd.points  = o3d.utility.Vector3dVector(region_pnts)
pcd.colors = get_o3d_colours_from_trimesh(region_cols)

# Use the merged normals when they are usable (one finite, non-zero normal per
# point); otherwise estimate and orient them from the point cloud. The previous
# guard compared the number of points with the number of colours, which is
# always equal, so the fallback branch could never run.
region_nrms_usable = (
    len(region_nrms) == len(region_pnts)
    and np.isfinite(region_nrms).all()
    and np.all(np.linalg.norm(region_nrms, axis=1) > 0)
)
if region_nrms_usable:
    pcd.normals = o3d.utility.Vector3dVector(region_nrms)
else:
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
                            radius=2.5 * global_avg_spacing,  # ~10–30 neighbours
                            max_nn=30))
    pcd.orient_normals_consistent_tangent_plane(k=30)

# Ball-pivoting radii are tied to the average point spacing.
ball_r   = 1.5 * global_avg_spacing
radii    = o3d.utility.DoubleVector([ball_r, ball_r * 2.0])

fused_overlap_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
           pcd, radii)

# Standard mesh clean-up before display.
fused_overlap_mesh.remove_duplicated_vertices()
fused_overlap_mesh.remove_duplicated_triangles()
fused_overlap_mesh.remove_degenerate_triangles()
fused_overlap_mesh.remove_non_manifold_edges()
fused_overlap_mesh.compute_vertex_normals()

# NOTE(review): if the clean-up above removed any vertices, pcd.colors no
# longer lines up one-to-one with the mesh vertices -- confirm.
fused_overlap_mesh.vertex_colors = pcd.colors

o3d.visualization.draw_geometries([fused_overlap_mesh])




# merged_pts['connected to an overlapping vertex']

# border_pts   = points[border_vertex_idx]
# border_cols  = colours[border_vertex_idx]
# border_nrms  = normals[border_vertex_idx]

# union_pts  = np.vstack([merged_pnts, border_pts])
# union_cols = np.vstack([merged_cols, border_cols])
# union_nrms = np.vstack([merged_nrms, border_nrms])

# pcd_union           = o3d.geometry.PointCloud()
# pcd_union.points    = o3d.utility.Vector3dVector(union_pts)
# pcd_union.colors    = get_o3d_colours_from_trimesh(union_cols)
# pcd_union.normals   = o3d.utility.Vector3dVector(union_nrms)

# if np.isnan(union_nrms).any() or not np.all(np.linalg.norm(union_nrms, axis=1)):
#     print('Invalid normals; recalculating')
#     pcd_union.estimate_normals(
#         search_param=o3d.geometry.KDTreeSearchParamHybrid(
#             radius=2.5 * global_avg_spacing, max_nn=30)
#     )
#     pcd_union.orient_normals_consistent_tangent_plane(k=30)

# ball_r_union = 1.5 * global_avg_spacing
# radii_union  = o3d.utility.DoubleVector([ball_r_union, ball_r_union * 2.0])

# mesh_union = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
#     pcd_union, radii_union
# )

# mesh_union.remove_duplicated_vertices()
# mesh_union.remove_duplicated_triangles()
# mesh_union.remove_degenerate_triangles()
# mesh_union.remove_non_manifold_edges()
# mesh_union.compute_vertex_normals()

# mesh_union.vertex_colors = pcd_union.colors

# Show the merged points selected for re-meshing on their own.
points_to_plot = merged_pnts[merged_overlap_and_border_idx]
pcd_border           = o3d.geometry.PointCloud()
pcd_border.points    = o3d.utility.Vector3dVector(points_to_plot)
# Fix: the all-zero colours were previously assigned to `pcd` (the cloud from
# the meshing step) instead of `pcd_border`, leaving this cloud uncoloured and
# overwriting the other cloud's colours.
pcd_border.colors = get_o3d_colours_from_trimesh(np.zeros((len(points_to_plot),4)))

o3d.visualization.draw_geometries([pcd_border])
# o3d.visualization.draw_geometries([mesh_union, pcd_border])


# mesh_static = Trimesh(points, tris_to_keep, process=False)

# mesh_union_tm = o3d_to_trimesh(mesh_union, copy=True)


# stitched_tm = trimesh.util.concatenate(
#     [mesh_static, mesh_union_tm]
# )

# unique_f = stitched_tm.unique_faces()          # returns boolean mask
# stitched_tm.update_faces(unique_f)

# stitched_tm.merge_vertices()

# non_deg_f = stitched_tm.nondegenerate_faces()  # boolean mask
# stitched_tm.update_faces(non_deg_f)

# stitched_tm.remove_unreferenced_vertices()

# stitched_tm.process(validate=True)

# stitched_tm.visual.vertex_colors = np.asarray(stitched_tm.vertices) * 0  # init
# for old, new in stitched_tm.vertex_adjacency_graph.items():
#     stitched_tm.visual.vertex_colors[new] = colours[old]

# stitched_tm.show()
# trimesh.exchange.export.export_mesh(.)
# o3d.visualization.draw_geometries([trimesh_to_o3d(stitched_tm)])


[ True  True  True ...  True  True  True]


In [24]:
import numpy as np
import open3d as o3d

# Reconstruct a surface from ALL merged overlap points via ball pivoting.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(merged_pnts)
pcd.colors = get_o3d_colours_from_trimesh(merged_cols)

# Prefer the normals produced by the merge step; only when they are missing or
# their count does not match the points are normals estimated from the cloud.
have_matching_normals = merged_nrms is not None and len(merged_nrms) == len(merged_pnts)
if not have_matching_normals:
    normal_search = o3d.geometry.KDTreeSearchParamHybrid(
        radius=2.5 * global_avg_spacing,  # ~10–30 neighbours
        max_nn=30,
    )
    pcd.estimate_normals(search_param=normal_search)
    pcd.orient_normals_consistent_tangent_plane(k=30)
else:
    pcd.normals = o3d.utility.Vector3dVector(merged_nrms)

# Pivoting-ball radii: 1.5x and 3x the global average spacing.
ball_r = 1.5 * global_avg_spacing
radii = o3d.utility.DoubleVector([ball_r, ball_r * 2.0])

mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, radii)

# Clean-up passes, applied in this fixed order, then recompute vertex normals.
for cleanup_step in (
    mesh.remove_duplicated_vertices,
    mesh.remove_duplicated_triangles,
    mesh.remove_degenerate_triangles,
    mesh.remove_non_manifold_edges,
    mesh.compute_vertex_normals,
):
    cleanup_step()

mesh.vertex_colors = pcd.colors

# o3d.io.write_triangle_mesh("merged_mesh.ply", mesh, write_ascii=False)
o3d.visualization.draw_geometries([mesh])

### [Mean-Shift Clustering Implementation (didn't work well)]

In [53]:
# from tqdm import tqdm

# eps = 0.01
# delta = np.sqrt(2)/2
# mclust_shifted_points = np.copy(trilat_shifted_pts)
# mclust_shifted_normals = np.copy(normals)
# for i, point in tqdm(enumerate(trilat_shifted_pts), total=len(trilat_shifted_pts), desc="Mean-Shift Clustering"):
#     normal = normals[i]
#     shifted_point = point.copy()
#     shifted_normal = normal.copy()
#     it_count = 0
#     while True:
#         it_count += 1
#         nbr, d2 = find_cyl_neighbours(shifted_point,
#                                       shifted_normal,
#                                       global_avg_spacing,
#                                       h_alpha,
#                                       delta,
#                                       trilat_shifted_pts,
#                                       trilat_shifted_tree,
#                                       i)
        
#         if len(nbr) == 0:
#             break
#         if it_count == 5000:
#             print(nbr, d2)
        
#         nbr_points = trilat_shifted_pts[nbr]
#         nbr_normals  = normals[nbr]

#         hv = np.max(np.linalg.norm(nbr_points - shifted_point, axis=1))
#         hn = np.max(np.linalg.norm(nbr_normals - shifted_normal, axis=1))

#         hv = max(hv, 0.05*global_avg_spacing)
#         hn = max(hn, 0.05*global_avg_spacing)

#         diff_p   = nbr_points - shifted_point
#         dist2_p  = np.sum(diff_p*diff_p, axis=1)

#         diff_n   = nbr_normals - shifted_normal
#         dist2_n  = np.sum(diff_n*diff_n, axis=1)

#         w = np.exp(-dist2_p / (2*hv**2)) * np.exp(-dist2_n / (2*hn**2))
#         w_sum    = w.sum()

#         delta_p  = (w[:, None] * diff_p).sum(axis=0) / w_sum
#         delta_n  = (w[:, None] * diff_n).sum(axis=0) / w_sum

#         shifted_point  += delta_p
#         shifted_normal += delta_n
#         shifted_normal /= np.linalg.norm(shifted_normal)
        
#         if np.linalg.norm(delta_p) <= eps * global_avg_spacing:
#             break
        
#     mclust_shifted_points[i] = shifted_point
#     mclust_shifted_normals[i] = shifted_normal

In [54]:
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(mclust_shifted_points)
# pcd.colors = get_o3d_colours_from_trimesh(colours)

# # o3d.visualization.draw_geometries([pcd])