In [1]:
import numpy as np
import open3d as o3d  # type: ignore
import pymeshfix # type: ignore
from scipy.spatial import cKDTree # type: ignore

from superprimitive_fusion.mesh_fusion_utils import (
    smooth_normals,
    calc_local_spacing,
    compute_overlap_set_cached,
    smooth_overlap_set_cached,
    precompute_cyl_neighbours,
    trilateral_shift_cached,
    find_boundary_edges,
    topological_trim,
    merge_nearby_clusters,
    get_mesh_components,
)
from superprimitive_fusion.mesh_fusion import (
    fuse_meshes,
    sanitise_mesh,
)
from superprimitive_fusion.scanner import (
    virtual_rgbd_scan,
    capture_spherical_scans,
)
from superprimitive_fusion.utils import (
    bake_uv_to_vertex_colours,
    polar2cartesian,
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Load the textured reference mesh and bake its UV texture into per-vertex colours.
mesh = o3d.io.read_triangle_mesh(
    "../data/power-drill/textured.obj",
    enable_post_processing=True,
)
bake_uv_to_vertex_colours(mesh)
mesh.compute_vertex_normals()

# Minimal oriented bounding box of the loaded mesh.
bb = mesh.get_minimal_oriented_bounding_box()
# NOTE(review): this averages the coordinates of the box's max corner, not the
# box extent, so it is only a size proxy when the mesh is roughly centred on
# the origin — confirm this is the intended notion of "scale".
max_corner = bb.get_max_bound()
scale = np.mean(max_corner)

In [3]:
# Simulate `num_views` virtual RGB-D scans of the mesh (sampler="fibonacci") with
# every error source set to zero, and keep the mesh produced for each view.
scan_settings = dict(
    mask_generator=None,
    num_views=6,
    radius=0.3,
    width_px=360,
    height_px=240,
    fov=70,
    # noise / error terms all disabled
    dropout_rate=0,
    depth_error_std=0.0,
    translation_error_std=0,
    rotation_error_std_degs=0,
    k=10,
    sampler="fibonacci",
)
scans = capture_spherical_scans(mesh=mesh, **scan_settings)
meshes = [single_scan["mesh"] for single_scan in scans]

o3d.visualization.draw_geometries(
    meshes,
    window_name="Virtual scan",
    front=[0.3, 1, 0],
    lookat=[0, 0, 0],
    up=[0, 0, 1],
    zoom=0.7,
)

In [4]:
# j = 6
# fused_mesh = scans[0]
# for i in range(1,j):
#     fused_mesh = fuse_meshes(
#         fused_mesh,
#         scans[i],
#         h_alpha=3,
#         trilat_iters=2,
#         shift_all=False,
#         fill_holes=True,
#     )

# o3d.visualization.draw_geometries(
#     [fused_mesh],
#     window_name="Virtual scan",
#     front=[0.3, 1, 0],
#     lookat=[0, 0, 0],
#     up=[0, 0, 1],
#     zoom=0.7,
# )

In [5]:
# show_connected_components(fused_mesh)

In [6]:
# The two scans fused step by step in the cells below.
mesh1, mesh2 = meshes[3], meshes[4]

# Fusion parameters.
h_alpha: float = 3.0      # used by precompute_cyl_neighbours, trilateral_shift_cached, merge_nearby_clusters
r_alpha: float = 2.0      # used by precompute_cyl_neighbours, trilateral_shift_cached
trilat_iters: int = 2     # number of trilateral shifting passes
nrm_smth_iters: int = 1   # n_iters for smooth_normals
shift_all: bool = False   # forwarded to trilateral_shift_cached

In [7]:
# Preview the two input scans side by side: mesh1 in red, mesh2 in green.
import copy

preview_meshes = []
for source_mesh, rgb in ((mesh1, [1, 0, 0]), (mesh2, [0, 1, 0])):
    tinted = copy.deepcopy(source_mesh)  # deep copy so the inputs keep their own colours
    tinted.paint_uniform_color(rgb)
    preview_meshes.append(tinted)
m1, m2 = preview_meshes
o3d.visualization.draw_geometries([m1, m2])

In [8]:
# ---------------------------------------------------------------------
# Raw geometry & attribute extraction
# ---------------------------------------------------------------------
# Vertices of both scans, stacked mesh1-first: global index i < len(points1)
# belongs to mesh1, the rest to mesh2.
points1, points2 = (np.asarray(src.vertices) for src in (mesh1, mesh2))
pointclouds = (points1, points2)
points = np.vstack(pointclouds)

# Per-vertex colours, stacked in the same order as `points`.
colours1, colours2 = mesh1.vertex_colors, mesh2.vertex_colors
colours = np.concatenate([colours1, colours2], axis=0)

# KD-tree over the combined cloud (KDTreeFlann takes one point per column).
kd_tree = o3d.geometry.KDTreeFlann(points.T)

In [9]:
# ---------------------------------------------------------------------
# Normals
# ---------------------------------------------------------------------
for scan_mesh in (mesh1, mesh2):
    scan_mesh.compute_vertex_normals()

normals1 = np.asarray(mesh1.vertex_normals)
normals2 = np.asarray(mesh2.vertex_normals)
normals = np.concatenate([normals1, normals2], axis=0)

# Scan label per vertex: 0 for mesh1's vertices, 1 for mesh2's (same order as `points`).
scan_ids = np.concatenate(
    [np.full(len(cloud), scan_no) for scan_no, cloud in enumerate(pointclouds)]
)

# Smooth the combined normals using the combined-cloud KD-tree.
normals = smooth_normals(
    points, normals, tree=kd_tree, k=8, T=0.7, n_iters=nrm_smth_iters
)

In [10]:
# ---------------------------------------------------------------------
# Local geometric properties
# ---------------------------------------------------------------------
# Per-vertex spacing and density for each scan (queried via the combined-cloud tree).
local_spacing_1, local_density_1 = calc_local_spacing(points1, points1, tree=kd_tree)
local_spacing_2, local_density_2 = calc_local_spacing(points2, points2, tree=kd_tree)
local_spacings = (local_spacing_1, local_spacing_2)

local_spacing = np.concatenate(local_spacings)
local_density = np.concatenate((local_density_1, local_density_2))

# Mean of the per-scan mean spacings: each scan is weighted equally, whatever its size.
per_scan_mean_spacing = [np.sum(ls) * (1 / len(ls)) for ls in local_spacings]
global_avg_spacing = np.sum(per_scan_mean_spacing) * (1 / len(local_spacings))

In [11]:
# ---------------------------------------------------------------------
# Overlap detection
# ---------------------------------------------------------------------
# Neighbourhoods are computed once and shared by the overlap steps below and
# by the trilateral shift later on.
nbr_cache = precompute_cyl_neighbours(
    points, normals, local_spacing, r_alpha, h_alpha, kd_tree
)

# Initial overlap estimate from the scan labels, followed by two smoothing
# passes (the second with p_thresh=1). Only the final (idx, mask) pair is kept.
raw_overlap_idx, raw_overlap_mask = compute_overlap_set_cached(scan_ids, nbr_cache)
smoothed_overlap_idx, smoothed_overlap_mask = smooth_overlap_set_cached(
    raw_overlap_mask, nbr_cache
)
overlap_idx, overlap_mask = smooth_overlap_set_cached(
    smoothed_overlap_mask, nbr_cache, p_thresh=1
)

In [12]:
# Debug point cloud: overlap vertices in red, all other vertices in green.
# Open3D point colours are RGB floats in [0, 1] (not 0-255).
debug_colours = np.array([
    [1.0, 0.0, 0.0],  # overlap
    [0.0, 1.0, 0.0],  # non-overlap
])
f0 = overlap_mask
f1 = ~overlap_mask
filters = np.column_stack((f0, f1))
idx = filters.argmax(axis=1)  # 0 for overlap vertices, 1 otherwise
row_colours = debug_colours[idx]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(row_colours)

# o3d.visualization.draw_geometries([pcd])

In [13]:
# Debug point cloud of scan membership x overlap status.
# Open3D point colours are RGB floats in [0, 1] (not 0-255).
#   red     = mesh2 (scan id 1), non-overlap
#   green   = mesh1 (scan id 0), non-overlap
#   blue    = mesh2 (scan id 1), overlap
#   magenta = mesh1 (scan id 0), overlap
# Descriptive mask names avoid clobbering the preview meshes m1/m2.
is_scan0 = scan_ids == 0
is_overlap = overlap_mask
debug_colours = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 1.0],
])
f0 = ~is_scan0 & ~is_overlap
f1 =  is_scan0 & ~is_overlap
f2 = ~is_scan0 &  is_overlap
f3 =  is_scan0 &  is_overlap
filters = np.column_stack((f0, f1, f2, f3))
idx = filters.argmax(axis=1)  # exactly one filter is True per vertex
row_colours = debug_colours[idx]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(row_colours)

o3d.visualization.draw_geometries([pcd])

In [14]:
# ---------------------------------------------------------------------
# Find overlap boundary edges
# ---------------------------------------------------------------------
# Triangles of both scans in the combined vertex index space: mesh2's
# indices are offset by the number of mesh1 vertices.
index_offset = len(points1)
tris1 = np.asarray(mesh1.triangles)
tris2 = np.asarray(mesh2.triangles) + index_offset
all_tris = np.vstack((tris1, tris2))

# Keep every triangle with at least one non-overlap vertex. The boundary
# edges of that set are used later to trim the overlap mesh.
tri_fully_overlapping = overlap_mask[all_tris].all(axis=1)
nonoverlap_tris = all_tris[~tri_fully_overlapping]
boundary_edges = find_boundary_edges(nonoverlap_tris)

In [15]:
# ---------------------------------------------------------------------
# Trilateral point shifting
# ---------------------------------------------------------------------
# Start from an untouched copy of the input points and apply `trilat_iters`
# passes of the cached trilateral shift.
trilat_shifted_pts = points.copy()
for _shift_pass in range(trilat_iters):
    trilat_shifted_pts = trilateral_shift_cached(
        trilat_shifted_pts,
        normals,
        local_spacing,
        local_density,
        overlap_idx,
        nbr_cache,
        r_alpha,
        h_alpha,
        shift_all,
    )

# Rebuild the KD-tree on the shifted positions for the following steps.
kd_tree = o3d.geometry.KDTreeFlann(trilat_shifted_pts.T)

In [16]:
# ---------------------------------------------------------------------
# Merge nearby clusters
# ---------------------------------------------------------------------
# Returns a per-vertex cluster mapping plus the clustered positions, colours
# and normals. Later cells treat cluster_mapping == -1 as "not in a cluster".
merge_inputs = dict(
    trilat_shifted_pts=trilat_shifted_pts,
    normals=normals,
    colours=colours,
    overlap_mask=overlap_mask,
    overlap_idx=overlap_idx,
    global_avg_spacing=global_avg_spacing,
    h_alpha=h_alpha,
    tree=kd_tree,
)
(
    cluster_mapping,
    clustered_overlap_pnts,
    clustered_overlap_cols,
    clustered_overlap_nrms,
) = merge_nearby_clusters(**merge_inputs)

In [17]:
# ---------------------------------------------------------------------
# Classify vertices
# ---------------------------------------------------------------------
# Output vertices come in three disjoint groups, stacked in this order:
#   overlap - clustered points returned by merge_nearby_clusters
#   border  - unclustered vertices of any triangle touching the overlap set
#   free    - every other unclustered vertex
tri_has_overlap_any = overlap_mask[all_tris].any(axis=1)
overlap_any_idx = np.unique(all_tris[tri_has_overlap_any])

unclustered = cluster_mapping == -1

border_mask = np.zeros(len(points), dtype=bool)
border_mask[overlap_any_idx] = True
border_mask &= unclustered

nonoverlap_nonborder_mask = unclustered & ~border_mask

n_overlap = len(clustered_overlap_pnts)
n_border = border_mask.sum()
n_free = nonoverlap_nonborder_mask.sum()


def _stack_groups(clustered, per_vertex, group_masks):
    """Concatenate cluster rows with the rows of `per_vertex` selected by each mask, in order."""
    return np.concatenate(
        [clustered] + [per_vertex[group] for group in group_masks],
        axis=0,
    )


_group_masks = (border_mask, nonoverlap_nonborder_mask)
new_points = _stack_groups(clustered_overlap_pnts, trilat_shifted_pts, _group_masks)
new_colours = _stack_groups(clustered_overlap_cols, colours, _group_masks)
new_normals = _stack_groups(clustered_overlap_nrms, normals, _group_masks)

new_colours = np.clip(new_colours, 0, 1)

In [18]:
# ---------------------------------------------------------------------
# Complete mapping
# ---------------------------------------------------------------------
# Fill in mapping (combined-input vertex index -> row of new_points) for the
# border and free vertices. Their rows follow the n_overlap cluster rows, in
# the same order they were stacked into new_points.
vertex_ids = np.arange(len(points))

border_idx_from = vertex_ids[border_mask]
border_idx_to = n_overlap + np.arange(n_border)

free_idx_from = vertex_ids[nonoverlap_nonborder_mask]
free_idx_to = n_overlap + n_border + np.arange(n_free)

mapping = cluster_mapping.copy()
mapping[border_idx_from] = border_idx_to
mapping[free_idx_from] = free_idx_to

In [19]:
# ---------------------------------------------------------------------
# Mesh the overlap zone
# ---------------------------------------------------------------------
# The overlap mesh is built from the first n_overlap rows of new_points.
# Later cells use its triangle indices directly as indices into new_points,
# so the vertex array produced here must keep that order and length.
overlap_points = new_points[:n_overlap]
overlap_normals = new_normals[:n_overlap]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(overlap_points)

# An array slice is never None, so check the normal values instead: use the
# clustered normals only when they are finite and non-zero, otherwise estimate
# and orient fresh ones.
normal_lengths = np.linalg.norm(overlap_normals, axis=1)
if np.isfinite(overlap_normals).all() and np.all(normal_lengths > 0):
    pcd.normals = o3d.utility.Vector3dVector(overlap_normals)
else:
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(
            radius=2.5 * global_avg_spacing, max_nn=30
        )
    )
    pcd.orient_normals_consistent_tangent_plane(k=30)

# Ball-pivoting radii as multiples of the median nearest-neighbour distance
# of the original (unshifted) combined cloud.
points_pcd = o3d.geometry.PointCloud(o3d.utility.Vector3dVector(points))
dists = points_pcd.compute_nearest_neighbor_distance()
s = np.median(dists)

radii = o3d.utility.DoubleVector([1.2 * s, 2 * s, 3 * s, 4 * s])

overlap_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
    pcd, radii
)

# remove_duplicated_vertices() is intentionally not called: it compacts the
# vertex array and renumbers triangles, which would break the index
# correspondence with new_points. The clean-ups below only remove triangles.
overlap_mesh.remove_duplicated_triangles()
overlap_mesh.remove_degenerate_triangles()
overlap_mesh.remove_non_manifold_edges()
overlap_mesh.compute_vertex_normals()

TriangleMesh with 2604 points and 4737 triangles.

In [20]:
# Debug view: overlap mesh plus every new vertex coloured by group
# (red = overlap clusters, green = border, blue = free).
# Uses its own name so the per-vertex `colours` array consumed by the
# clustering/classification cells is not overwritten (re-running those cells
# would otherwise silently pick up these debug colours). Open3D colours are
# floats in [0, 1].
vertex_group_colours = np.concatenate([
    np.full((n_overlap, 3), [1.0, 0.0, 0.0]),
    np.full((n_border, 3), [0.0, 1.0, 0.0]),
    np.full((n_free, 3), [0.0, 0.0, 1.0]),
])

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points)
pcd.colors = o3d.utility.Vector3dVector(vertex_group_colours)

o3d.visualization.draw_geometries([overlap_mesh, pcd])

In [21]:
# Non-overlap triangles re-indexed into new_points through `mapping`.
new_nonoverlap_tris = mapping[nonoverlap_tris]

nonoverlap_mesh = o3d.geometry.TriangleMesh()
nonoverlap_mesh.vertices = o3d.utility.Vector3dVector(new_points)
nonoverlap_mesh.triangles = o3d.utility.Vector3iVector(new_nonoverlap_tris)

# o3d.visualization.draw_geometries([nonoverlap_mesh, pcd])

In [22]:
# o3d.visualization.draw_geometries([mesh1, mesh2, pcd])

In [23]:
# ---------------------------------------------------------------------
# Trim overlap mesh
# ---------------------------------------------------------------------
# The trim relies on overlap_mesh vertex i being row i of new_points
# (for i < n_overlap). Fail loudly if the overlap mesh was re-indexed
# (e.g. by removing duplicated vertices); otherwise the mapped boundary edges
# below would silently refer to the wrong vertices.
n_overlap_mesh_vertices = len(overlap_mesh.vertices)
assert n_overlap_mesh_vertices == n_overlap, (
    f"overlap_mesh has {n_overlap_mesh_vertices} vertices but {n_overlap} "
    "overlap points were meshed; its indices no longer match new_points"
)

# Boundary edges in new_points indices; keep those whose two endpoints are
# both overlap-mesh vertices.
mapped_boundary_edges = mapping[boundary_edges]
relevant_boundary_edges = mapped_boundary_edges[
    np.all(mapped_boundary_edges < n_overlap_mesh_vertices, axis=1)
]

# Number of connected triangle clusters in the overlap mesh, passed as k.
k = len(overlap_mesh.cluster_connected_triangles()[1])

trimmed_overlap_mesh = topological_trim(
    overlap_mesh, relevant_boundary_edges, k=k,
)

# o3d.visualization.draw_geometries([trimmed_overlap_mesh, pcd])

In [24]:
# Fused triangle set = trimmed overlap triangles (indexed into new_points)
# followed by the non-overlap triangles re-indexed through `mapping`.
trimmed_overlap_tris = np.asarray(trimmed_overlap_mesh.triangles)
remapped_nonoverlap_tris = mapping[nonoverlap_tris]
fused_mesh_triangles = np.vstack((trimmed_overlap_tris, remapped_nonoverlap_tris))

fused_mesh = o3d.geometry.TriangleMesh()
fused_mesh.vertices = o3d.utility.Vector3dVector(new_points)
fused_mesh.triangles = o3d.utility.Vector3iVector(fused_mesh_triangles)
fused_mesh.vertex_colors = o3d.utility.Vector3dVector(new_colours)

# Clean-up. Removing unreferenced / duplicated vertices re-indexes fused_mesh,
# so afterwards its vertex array no longer lines up with new_points or
# fused_mesh_triangles.
fused_mesh.remove_unreferenced_vertices()
fused_mesh.remove_duplicated_triangles()
fused_mesh.remove_duplicated_vertices()
fused_mesh.remove_degenerate_triangles()
fused_mesh.remove_non_manifold_edges()
fused_mesh.compute_vertex_normals()

TriangleMesh with 19807 points and 38079 triangles.

In [26]:
# Show the fused mesh together with the current `pcd` debug cloud.
geometries_to_show = [fused_mesh, pcd]
o3d.visualization.draw_geometries(geometries_to_show)

In [None]:
# ---------------------------------------------------------------------
# Hole filling (pymeshfix) and colour transfer
# ---------------------------------------------------------------------
# new_points / fused_mesh_triangles / new_colours are mutually index-consistent.
# fused_mesh itself was re-indexed by its clean-up calls, so its triangles must
# not be paired with new_points.
src_vertices = np.ascontiguousarray(new_points, dtype=np.float64)
src_triangles = np.asarray(fused_mesh_triangles)
C0 = np.asarray(new_colours, dtype=np.float64)  # row i is the colour of new_points[i]

# Sanity checks on the arrays actually passed to sanitise_mesh.
if not np.isfinite(src_vertices).all():
    raise ValueError("Non-finite vertex coordinates")
if src_triangles.size and (
    src_triangles.min() < 0 or src_triangles.max() >= len(src_vertices)
):
    raise ValueError("Face indices out of range")

V0, F0 = sanitise_mesh(new_points, fused_mesh_triangles)

mf = pymeshfix.PyTMesh(False)
mf.load_array(V0, F0)

mf.fill_small_boundaries(nbe=10)

V1, F1 = mf.return_arrays()

# ---- colour transfer
# Query against the vertices C0 is aligned with (new_points). The sanitised V0
# may differ in order or length, so indices into V0 cannot index C0.
tree = cKDTree(src_vertices)
dist, idx = tree.query(V1, k=1)
C1 = C0[idx]

# Repaired vertices that do not coincide with a source vertex (e.g. ones
# created while filling holes) get an inverse-distance blend of the colours of
# their 3 nearest source vertices.
mask = dist > 1e-6
if mask.any():
    dist3, idx3 = tree.query(V1[mask], k=3)
    w = 1.0 / (dist3 + 1e-12)
    w /= w.sum(axis=1, keepdims=True)
    C1[mask] = (C0[idx3] * w[..., None]).sum(axis=1)

repaired = o3d.geometry.TriangleMesh()
repaired.vertices = o3d.utility.Vector3dVector(V1)
repaired.triangles = o3d.utility.Vector3iVector(F1)
repaired.vertex_colors = o3d.utility.Vector3dVector(C1)
repaired.compute_vertex_normals()

o3d.visualization.draw_geometries([repaired])