In [1]:
import math
import numpy as np
import open3d as o3d
from sklearn.cluster import MeanShift
import trimesh
from trimesh import Trimesh
from utils import (
    trimesh_to_o3d,
    o3d_to_trimesh,
    smooth_normals,
    calc_local_spacing,
    find_cyl_neighbours,
    calc_local_spacing,
    compute_overlap_set,
    trilateral_shift,
    get_o3d_colours_from_trimesh,
)

### Load meshes and extract data

In [2]:
# Load the two overlapping scans as trimesh objects (scan 1 first, scan 2 second).
mesh_paths = ('meshes/bottle_1.ply', 'meshes/bottle_2.ply')
mesh1, mesh2 = (trimesh.load_mesh(path) for path in mesh_paths)

# Open3D copies, used below for vertex-normal computation and triangle access.
mesh1_o3d, mesh2_o3d = (trimesh_to_o3d(scan) for scan in (mesh1, mesh2))

In [3]:
# Per-scan vertex positions and their stacked concatenation (scan 1 rows first, then scan 2).
pointclouds = tuple(np.asarray(scan.vertices) for scan in (mesh1, mesh2))
points1, points2 = pointclouds
points = np.vstack(pointclouds)

# Per-vertex colours as stored by trimesh, stacked in the same order as `points`.
colours1, colours2 = mesh1.visual.vertex_colors, mesh2.visual.vertex_colors
colours = np.concat([colours1, colours2])

# KD-tree over all stacked points; Open3D's KDTreeFlann takes one point per column, hence `.T`.
tree = o3d.geometry.KDTreeFlann(points.T)

for scan_mesh in (mesh1_o3d, mesh2_o3d):
    scan_mesh.compute_vertex_normals()

normals1, normals2 = (np.asarray(scan.vertex_normals) for scan in (mesh1_o3d, mesh2_o3d))
normals = np.concat([normals1, normals2], axis=0)

# Scan label per stacked point (0.0 for scan 1, 1.0 for scan 2), stored as float64.
scan_ids = np.concat([np.full(len(cloud), idx, dtype=np.float64) for idx, cloud in enumerate(pointclouds)])

In [4]:
# Show both raw scans together, coloured with their original vertex colours.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
pcd.colors = get_o3d_colours_from_trimesh(colours)
o3d.visualization.draw_geometries([pcd])

### Smooth normals and calculate local properties

In [5]:
# Smooth the stacked normals with the project helper; the parameters are passed through unchanged.
normal_smoothing = dict(k=8, T=0.7, n_iters=5)
normals = smooth_normals(points, normals, **normal_smoothing)

In [6]:
# Per-scan local point spacing and density from the project helper.
local_spacing_1, local_density_1 = calc_local_spacing(mesh1.vertices, np.asarray(mesh1_o3d.vertices))
local_spacing_2, local_density_2 = calc_local_spacing(mesh2.vertices, np.asarray(mesh2_o3d.vertices))
local_spacings = (local_spacing_1, local_spacing_2)
local_spacing = np.concat(local_spacings)

# Densities, stacked in the same order as `points`. (Previously this concatenated the
# spacings by mistake, so `local_density` was just a copy of `local_spacing`.)
local_density = np.concat((local_density_1, local_density_2))

# Mean over scans of each scan's mean spacing: every scan is weighted equally,
# regardless of how many points it has.
global_avg_spacing = np.mean([np.mean(scan_spacing) for scan_spacing in local_spacings])

### Calculate overlapping region

In [7]:
# Scale parameters for the overlap test; both are reused by trilateral_shift and
# find_cyl_neighbours further down.
h_alpha, r_alpha = 2.5, 2
overlap_inputs = (points, normals, local_spacing, scan_ids, h_alpha, r_alpha, tree)
overlap_idx, overlap_mask = compute_overlap_set(*overlap_inputs)

In [19]:
# Triangles of BOTH scans in the stacked vertex indexing used by `points` / `overlap_mask`.
# Scan-2 indices are offset by the number of scan-1 vertices. (Previously only scan 1's
# triangles were used, so all of scan 2's non-overlapping surface was lost from the fused mesh.)
tris = np.vstack([
    np.asarray(mesh1_o3d.triangles),
    np.asarray(mesh2_o3d.triangles) + len(points1),
])

# Keep every triangle that has at least one vertex outside the overlap region.
tris_to_keep = tris[~np.all(overlap_mask[tris], axis=1)]

mesh = Trimesh(points, tris_to_keep)
# mesh.show()

### Perform trilateral point shifting

In [9]:
# Number of trilateral point-shifting passes.
n_trilat_iters = 3

# Work on a copy. The raw `points` array is reused later (visualisation, index ranges),
# so it must not be modified through an alias if trilateral_shift writes in place.
trilat_shifted_pts = points.copy()
for _ in range(n_trilat_iters):
    # NOTE(review): `tree` indexes the unshifted `points`; confirm trilateral_shift expects that.
    trilat_shifted_pts = trilateral_shift(trilat_shifted_pts, normals, local_spacing, local_density, overlap_idx, tree, r_alpha, h_alpha)

In [10]:
# Visualise the point cloud after trilateral shifting, keeping the original per-vertex colours.
pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(trilat_shifted_pts))
pcd.colors = get_o3d_colours_from_trimesh(colours)
o3d.visualization.draw_geometries([pcd])

### Merge nearby clusters of points

In [11]:
# Greedy clustering of the shifted overlap points. Repeatedly pick a random unassigned
# overlap point, gather its still-unassigned cylindrical neighbours, and replace them
# with one Gaussian-weighted merged point (position, colour, normal).

overlap_idx = np.flatnonzero(overlap_mask)                            # (M,) global indices of overlap points
trilat_shifted_overlap_pts = trilat_shifted_pts[overlap_idx]          # (M, 3)
trilat_shifted_overlap_tree = o3d.geometry.KDTreeFlann(               # KD-tree over the overlap points only
    trilat_shifted_overlap_pts.T)

delta = np.sqrt(2) / 2
sigma = delta * global_avg_spacing                                    # Gaussian width of the merge weights

# Seeded generator so the random seed order, and everything downstream, is reproducible.
merge_seed = 0
rng = np.random.default_rng(merge_seed)

# merge_mapping[i] = index of the merged point that global point i was folded into; -1 = unassigned.
merge_mapping = -np.ones(len(points), dtype=int)

merged_pnts = []
merged_cols = []
merged_nrms = []

while True:
    # Compute the unassigned set once per iteration; it doubles as the termination test.
    free_local_idx = np.flatnonzero(merge_mapping[overlap_idx] < 0)  # local indices in [0, M)
    if free_local_idx.size == 0:
        break
    id_local = rng.choice(free_local_idx)

    point  = trilat_shifted_overlap_pts[id_local]                    # (3,)
    normal = normals[overlap_idx[id_local]]                          # (3,)

    nbr_local, d2 = find_cyl_neighbours(
        point,
        normal,
        global_avg_spacing,
        h_alpha,
        delta,
        trilat_shifted_overlap_pts,
        trilat_shifted_overlap_tree,
        self_idx=None)

    nbr_global = overlap_idx[nbr_local]

    # Only points not already claimed by an earlier cluster may join this one.
    mask = merge_mapping[nbr_global] < 0
    nbr_global = nbr_global[mask]
    d2 = d2[mask]

    if len(nbr_global) == 0:
        raise RuntimeError("Point neighbourhood is unexpectedly empty!")

    w = np.exp(-d2 / (2 * sigma ** 2))
    w /= w.sum()

    merged_id = len(merged_pnts)
    merge_mapping[nbr_global] = merged_id

    merged_pnts.append(w @ trilat_shifted_pts[nbr_global])
    merged_cols.append(w @ colours[nbr_global])

    merged_nrm = w @ normals[nbr_global]
    nrm_len = np.linalg.norm(merged_nrm)
    # Opposing normals can cancel out; fall back to the seed normal instead of producing NaNs.
    merged_nrms.append(merged_nrm / nrm_len if nrm_len > 0 else normal)

# Stack the results; keep well-shaped empty arrays when there was no overlap at all.
merged_pnts = np.vstack(merged_pnts) if merged_pnts else np.empty((0, 3))
merged_cols = np.vstack(merged_cols) if merged_cols else np.empty((0, colours.shape[1]))
merged_nrms = np.vstack(merged_nrms) if merged_nrms else np.empty((0, 3))

In [12]:
# Merged overlap points get constant colour value 200; original points that were never
# merged (merge_mapping == -1) get value 0.
unmerged = merge_mapping == -1

pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(merged_pnts))
pcd.colors = get_o3d_colours_from_trimesh(np.full_like(merged_cols, 200))

pcd2 = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points[unmerged]))
pcd2.colors = get_o3d_colours_from_trimesh(np.zeros((np.count_nonzero(unmerged), 4)))

o3d.visualization.draw_geometries([pcd, pcd2])

### Remesh the overlapping region

In [13]:
# Plan for remeshing the overlap region:
# 1. Find the non-overlapping "border" vertices that share a triangle with an overlapping vertex.
# 2. Mesh the union of the merged overlap point cloud and those border vertices.
# 3. Take the union of that mesh with the remaining non-overlapping mesh regions.

In [None]:

# Triangles of both scans in the stacked vertex indexing; scan-2 indices are offset by len(points1).
tris1 = np.asarray(mesh1_o3d.triangles)
tris2 = np.asarray(mesh2_o3d.triangles)

tris2_shifted = tris2 + len(points1)
all_tris = np.vstack([tris1, tris2_shifted])

# Boolean mask of overlap vertices.
overlap_v = np.zeros(len(points), dtype=bool)
overlap_v[overlap_idx] = True

# All vertices of any triangle touching the overlap: overlap vertices plus their one-ring.
tri_has_overlap_any = overlap_v[all_tris].any(axis=1)
overlap_any_idx = np.unique(all_tris[tri_has_overlap_any])

# Border = touched vertices that are not overlap vertices themselves. Derived from
# overlap_v rather than `merge_mapping != -1`. The two are equivalent on a fresh run,
# because the merge loop assigns every overlap vertex, but overlap_v keeps this cell
# correct if merge_mapping has since been modified in place and the cell is re-run.
border_mask = np.zeros(len(points), dtype=bool)
border_mask[overlap_any_idx] = True
border_mask &= ~overlap_v

# Free = neither overlap nor border.
nonoverlap_nonborder_mask = ~overlap_v & ~border_mask

n_overlap, n_border, n_free = len(merged_pnts), border_mask.sum(), nonoverlap_nonborder_mask.sum()
# new_* layout:          [ Overlapping | Border with overlap              | Not overlapping or border (free) ]
new_points  = np.concat([merged_pnts, trilat_shifted_pts[border_mask], trilat_shifted_pts[nonoverlap_nonborder_mask]])
new_colours = np.concat([merged_cols, colours[border_mask],            colours[nonoverlap_nonborder_mask]])
new_normals = np.concat([merged_nrms, normals[border_mask],            normals[nonoverlap_nonborder_mask]])

# Visual check of the partition: overlap / border / free drawn with colour values 230 / 160 / 0.
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points)
c = np.concat([230*np.ones((n_overlap,4)), 160*np.ones((n_border,4)), 0*np.ones((n_free,4))])
pcd.colors = get_o3d_colours_from_trimesh(c)

o3d.visualization.draw_geometries([pcd])

In [86]:
# Map every original (stacked) vertex index to its index in `new_points`, whose layout is
# [merged overlap points | border points | free points].
# Copy first: assigning through an alias would overwrite merge_mapping in place, leaving no
# -1 entries. Earlier cells that test `merge_mapping == -1` would then break on re-run.
mapping = merge_mapping.copy()  # overlap vertices already map to their merged point

border_idx_from = np.flatnonzero(border_mask)
border_idx_to   = n_overlap + np.arange(n_border)

free_idx_from   = np.flatnonzero(nonoverlap_nonborder_mask)
free_idx_to     = n_overlap + n_border + np.arange(n_free)

mapping[border_idx_from] = border_idx_to
mapping[free_idx_from]   = free_idx_to

# Overlap, border and free vertices partition all vertices, so nothing may remain unmapped.
assert (mapping >= 0).all(), "every original vertex must map into new_points"

In [99]:
# Remesh the overlap + border points (the first n_overlap + n_border rows of new_points)
# with ball pivoting, then splice those triangles together with the kept original triangles.
n_remesh = n_overlap + n_border
remesh_normals = new_normals[:n_remesh]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points[:n_remesh])
pcd.colors = get_o3d_colours_from_trimesh(new_colours[:n_remesh])

# An array slice is never None, so the old `is not None` test always took the first branch.
# Use the supplied normals only when they are actually usable (no NaN/inf).
if np.isfinite(remesh_normals).all():
    pcd.normals = o3d.utility.Vector3dVector(remesh_normals)
else:
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
                            radius=2.5 * global_avg_spacing,  # ~10–30 neighbours
                            max_nn=30))
    pcd.orient_normals_consistent_tangent_plane(k=30)

ball_r   = 1.5 * global_avg_spacing
radii    = o3d.utility.DoubleVector([ball_r, ball_r * 2.0])

overlap_border_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
           pcd, radii)

# Triangle-level cleanup only. The triangles are re-used below as indices into new_points,
# which assumes BPA vertex i is pcd point i. remove_duplicated_vertices() could renumber
# vertices and break that correspondence; the fused mesh is deduplicated afterwards anyway.
overlap_border_mesh.remove_duplicated_triangles()
overlap_border_mesh.remove_degenerate_triangles()
overlap_border_mesh.remove_non_manifold_edges()
overlap_border_mesh.compute_vertex_normals()

overlap_border_mesh.vertex_colors = pcd.colors

# o3d.visualization.draw_geometries([overlap_border_mesh])

overlap_border_triangles = np.asarray(overlap_border_mesh.triangles)
fused_mesh_triangles = np.concat([overlap_border_triangles, mapping[tris_to_keep]], axis=0)
fused_mesh = o3d.geometry.TriangleMesh(vertices=o3d.utility.Vector3dVector(new_points),
                                       triangles=o3d.utility.Vector3iVector(fused_mesh_triangles))

# Attach colours BEFORE cleanup. remove_duplicated_vertices() then compacts them together
# with the vertices; assigning them afterwards could leave the colour/vertex counts mismatched.
fused_mesh.vertex_colors = get_o3d_colours_from_trimesh(new_colours)

fused_mesh.remove_duplicated_vertices()
fused_mesh.remove_duplicated_triangles()
fused_mesh.remove_degenerate_triangles()
fused_mesh.remove_non_manifold_edges()
fused_mesh.compute_vertex_normals()

o3d.visualization.draw_geometries([fused_mesh])

In [24]:
# Standalone ball-pivoting remesh of only the merged overlap points, for inspection.
# (numpy / open3d are already imported in the first cell.)
pcd = o3d.geometry.PointCloud()
pcd.points  = o3d.utility.Vector3dVector(merged_pnts)
pcd.colors = get_o3d_colours_from_trimesh(merged_cols)

# Use the merged normals when they line up with the points and contain no NaN/inf;
# otherwise estimate and orient normals from the points themselves.
if merged_nrms is not None and len(merged_nrms) == len(merged_pnts) and np.isfinite(merged_nrms).all():
    pcd.normals = o3d.utility.Vector3dVector(merged_nrms)
else:
    pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(
                            radius=2.5 * global_avg_spacing,  # ~10–30 neighbours
                            max_nn=30))
    pcd.orient_normals_consistent_tangent_plane(k=30)

ball_r   = 1.5 * global_avg_spacing
radii    = o3d.utility.DoubleVector([ball_r, ball_r * 2.0])

mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
           pcd, radii)

# Attach colours BEFORE cleanup so remove_duplicated_vertices() keeps them aligned with the
# vertices; assigning them afterwards could mismatch the colour and vertex counts.
mesh.vertex_colors = pcd.colors

mesh.remove_duplicated_vertices()
mesh.remove_duplicated_triangles()
mesh.remove_degenerate_triangles()
mesh.remove_non_manifold_edges()
mesh.compute_vertex_normals()

# o3d.io.write_triangle_mesh("merged_mesh.ply", mesh, write_ascii=False)
o3d.visualization.draw_geometries([mesh])

### Mean-shift clustering (abandoned experiment: did not work well; kept for reference)

In [53]:
# NOTE(review): abandoned experiment, kept commented out for reference only.
# If re-enabled: `trilat_shifted_tree` is not defined anywhere in this notebook and must be
# built first (presumably a KDTreeFlann over `trilat_shifted_pts.T`, mirroring the
# overlap-tree construction in the merge cell). `delta` is also already defined there
# with the same value.
# from tqdm import tqdm

# eps = 0.01
# delta = np.sqrt(2)/2
# mclust_shifted_points = np.copy(trilat_shifted_pts)
# mclust_shifted_normals = np.copy(normals)
# for i, point in tqdm(enumerate(trilat_shifted_pts), total=len(trilat_shifted_pts), desc="Mean-Shift Clustering"):
#     normal = normals[i]
#     shifted_point = point.copy()
#     shifted_normal = normal.copy()
#     it_count = 0
#     while True:
#         it_count += 1
#         nbr, d2 = find_cyl_neighbours(shifted_point,
#                                       shifted_normal,
#                                       global_avg_spacing,
#                                       h_alpha,
#                                       delta,
#                                       trilat_shifted_pts,
#                                       trilat_shifted_tree,
#                                       i)
        
#         if len(nbr) == 0:
#             break
#         if it_count == 5000:
#             print(nbr, d2)
        
#         nbr_points = trilat_shifted_pts[nbr]
#         nbr_normals  = normals[nbr]

#         hv = np.max(np.linalg.norm(nbr_points - shifted_point, axis=1))
#         hn = np.max(np.linalg.norm(nbr_normals - shifted_normal, axis=1))

#         hv = max(hv, 0.05*global_avg_spacing)
#         hn = max(hn, 0.05*global_avg_spacing)

#         diff_p   = nbr_points - shifted_point
#         dist2_p  = np.sum(diff_p*diff_p, axis=1)

#         diff_n   = nbr_normals - shifted_normal
#         dist2_n  = np.sum(diff_n*diff_n, axis=1)

#         w = np.exp(-dist2_p / (2*hv**2)) * np.exp(-dist2_n / (2*hn**2))
#         w_sum    = w.sum()

#         delta_p  = (w[:, None] * diff_p).sum(axis=0) / w_sum
#         delta_n  = (w[:, None] * diff_n).sum(axis=0) / w_sum

#         shifted_point  += delta_p
#         shifted_normal += delta_n
#         shifted_normal /= np.linalg.norm(shifted_normal)
        
#         if np.linalg.norm(delta_p) <= eps * global_avg_spacing:
#             break
        
#     mclust_shifted_points[i] = shifted_point
#     mclust_shifted_normals[i] = shifted_normal

In [54]:
# pcd = o3d.geometry.PointCloud()
# pcd.points = o3d.utility.Vector3dVector(mclust_shifted_points)
# pcd.colors = get_o3d_colours_from_trimesh(colours)

# # o3d.visualization.draw_geometries([pcd])