In [1]:
import copy
from tqdm import tqdm
import numpy as np
import open3d as o3d
import matplotlib.pyplot as plt

from superprimitive_fusion.mesh_fusion_utils import (
    smooth_normals,
    calc_local_spacing,
    compute_overlap_set_cached,
    smooth_overlap_set_cached,
    precompute_cyl_neighbours,
    update_weights,
    normal_shift_smooth,
    find_boundary_edges,
    topological_trim,
    merge_nearby_clusters,
    compact_by_faces,
)
from superprimitive_fusion.scanner import (
    virtual_mesh_scan,
    mesh_depth_image,
    generate_rgbd_noise,
    clean_mesh_and_remap_weights,
)
from superprimitive_fusion.utils import (
    bake_uv_to_vertex_colours,
    polar2cartesian,
    distinct_colours,
)
from superprimitive_fusion.mesh_fusion import (
    fuse_meshes,
)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Ground-truth meshes to scan: (folder under MESH_ROOT, mesh filename).
names = (
    ('mustard-bottle', 'mustard-bottle.obj'),
    ('table', 'table.obj'),
    # ('power-drill', 'power-drill.obj'),
    # ('bleach', 'bleach.obj'),
    # ('pitcher', 'pitcher.obj'),
    # ('mug', 'mug.obj'),
    # ('extra-large-clamp', 'extra-large-clamp-leaning.obj'),
)

MESH_ROOT = "../data/posed-meshes"

gt_meshes = dict()
for foldername, filename in names:
    print(f'Getting the {foldername}')

    mesh_path = f"{MESH_ROOT}/{foldername}/{filename}"
    gt_mesh = o3d.io.read_triangle_mesh(mesh_path, enable_post_processing=True)
    # Open3D only logs a warning and returns an empty mesh when the file is
    # missing or unreadable; stop here instead of producing empty scans later.
    if not gt_mesh.has_vertices():
        raise OSError(f"Failed to load a non-empty mesh from {mesh_path}")

    bake_uv_to_vertex_colours(gt_mesh)
    gt_mesh.compute_vertex_normals()

    gt_meshes[foldername] = gt_mesh

gt_mesh_list = list(gt_meshes.values())

Getting the mustard-bottle
Getting the table


In [3]:
# Mean centre of the scanned objects. The table is excluded, so the cameras
# (which use this as `look_at`) aim at the objects rather than the tabletop.
centres = np.vstack([
    mesh.get_center()
    for mesh_name, mesh in gt_meshes.items()
    if mesh_name != 'table'
])

obj_centre = centres.mean(axis=0)

In [4]:
# Scan configuration.
cam_centre_offset = np.array([0, 0, 0.2])  # added to every camera position in the scan loop
look_at = obj_centre                       # cameras aim at the mean object centre

# NOTE(review): check which of these the scan loop actually passes to
# virtual_mesh_scan -- e.g. it uses k=10 rather than this `k`.
width_px: int = 360
height_px: int = 240
fov: float = 70.0
k: float = 3.5
max_normal_angle_deg = None

In [5]:
# Virtually scan every ground-truth mesh from a few camera azimuths.
# Each entry of `scans` is a list of (mesh, weights) pairs, one per object.
scans = []
# N = 10
# for theta in tqdm(np.linspace(0,360/N * (N-1), N), desc='Scanning'):
# for theta in np.linspace(30,180, N):
for theta in [70, 110]:
    object_meshes, object_weights = virtual_mesh_scan(
        meshlist=gt_mesh_list,
        cam_centre=cam_centre_offset + polar2cartesian(0.8, 30, theta),
        look_at=look_at,
        k=10,  # NOTE(review): differs from the `k` defined in the config cell
        max_normal_angle_deg=max_normal_angle_deg,
        linear_depth_sigma=5e-4,
        quadrt_depth_sigma=1e-4,
        sigma_floor=0,
        bias_k1=0.025,
    )
    scans.append(list(zip(object_meshes, object_weights)))

meshlists = [[mesh for mesh, _weights in scan] for scan in scans]
# Flatten with plain Python: wrapping Open3D objects in np.array() raises
# for ragged nested lists and only yields an object array to undo again.
meshes = [mesh for meshlist in meshlists for mesh in meshlist]

for mesh in meshes:
    mesh.compute_vertex_normals()
o3d.visualization.draw_geometries(meshes)

  z = ((points - cam_centre) @ L).clip(min=0.0) # (H,W)


In [6]:
# Incrementally fuse each new scan into the running result, object by object.
fused_scan = scans[0]
N = len(scans[0])
# Objects are paired across scans by index, so every scan must have N of them.
assert len({len(scan) for scan in scans}) == 1


def _has_no_vertices(mesh):
    """Return True when `mesh` carries zero vertices."""
    return len(np.asarray(mesh.vertices)) == 0


for t in range(1, len(scans)):
    next_fused_scan = []
    for obj_id in range(N):
        mesh1, weights1 = fused_scan[obj_id]
        mesh2, weights2 = scans[t][obj_id]

        if _has_no_vertices(mesh1):
            # Running result is empty (possibly both are): take the new scan.
            print(f'Scan {t-1} obj {obj_id} is empty')
            obj_mesh, obj_weights = mesh2, weights2
        elif _has_no_vertices(mesh2):
            # New scan is empty: keep the running result unchanged.
            print(f'Scan {t} obj {obj_id} is empty')
            obj_mesh, obj_weights = mesh1, weights1
        else:
            print(f'Fusing obj {obj_id} scans {t-1,t}')
            obj_mesh, obj_weights = fuse_meshes(
                mesh1=mesh1,
                weights1=weights1,
                mesh2=mesh2,
                weights2=weights2,
                h_alpha=2.5,
                r_alpha=2.0,
                nrm_shift_iters=4,
                nrm_smth_iters=1,
                shift_all=False,
                fill_holes=False,
            )
        next_fused_scan.append((obj_mesh, obj_weights))

    fused_scan = next_fused_scan

Fusing obj 0 scans (0, 1)
Fusing obj 1 scans (0, 1)


In [7]:
# Drop the weights and display only the fused meshes.
fused_mesh_result = [mesh for mesh, _weights in fused_scan]
o3d.visualization.draw_geometries(fused_mesh_result)

In [8]:
# Step through fuse_meshes by hand for object 1 of scans 0 and 1.
(mesh1, weights1), (mesh2, weights2) = scans[0][1], scans[1][1]

# h_alpha, r_alpha, nrm_shift_iters, nrm_smth_iters, shift_all and fill_holes
# share names with the fuse_meshes keyword arguments used in the fusion loop.
# NOTE(review): nrm_shift_iters is 2 here but 4 in that loop.
h_alpha: float = 2.5
r_alpha: float = 2.0
nrm_shift_iters: int = 2
nrm_smth_iters: int = 1
sigma_theta: float = 0.2
normal_diff_thresh: float = 45.0    # degrees
tau_max: float|None = None          # None -> no upper clip on weights
shift_all: bool = False
fill_holes: bool = False

In [9]:
# DEBUG: overlay the two input scans (scan 0 red, scan 1 green).
def _painted_copy(mesh, rgb):
    """Deep-copy `mesh` and paint the copy one colour, leaving `mesh` untouched."""
    painted = copy.deepcopy(mesh)
    painted.paint_uniform_color(rgb)
    return painted

m1 = _painted_copy(mesh1, [1, 0, 0])
m2 = _painted_copy(mesh2, [0, 1, 0])
o3d.visualization.draw_geometries([m1, m2])

In [28]:
# ---------------------------------------------------------------------
# Raw geometry & attribute extraction
# ---------------------------------------------------------------------
points1 = np.asarray(mesh1.vertices)
points2 = np.asarray(mesh2.vertices)

pointclouds = (points1, points2)
points = np.vstack(pointclouds)

# Weights are one scalar per vertex.
assert weights1.ndim == 1 and len(weights1) == len(points1)
assert weights2.ndim == 1 and len(weights2) == len(points2)
weights = np.concatenate((weights1, weights2))

# Colours must also be one RGB row per vertex; otherwise (e.g. a mesh without
# vertex colours) the concatenation silently misaligns colours and points.
colours1 = np.asarray(mesh1.vertex_colors)
colours2 = np.asarray(mesh2.vertex_colors)
assert colours1.shape == points1.shape and colours2.shape == points2.shape, \
    "both meshes need one RGB colour per vertex"
colours = np.concatenate([colours1, colours2], axis=0)

# KDTreeFlann takes a (dim, n) matrix, hence the transpose.
kd_tree = o3d.geometry.KDTreeFlann(points.T)

In [11]:
# ---------------------------------------------------------------------
# Normals
# ---------------------------------------------------------------------
for scan_mesh in (mesh1, mesh2):
    scan_mesh.compute_vertex_normals()

normals1, normals2 = (np.asarray(m.vertex_normals) for m in (mesh1, mesh2))
normals = np.concatenate([normals1, normals2], axis=0)

# Scan index (0 for mesh1, 1 for mesh2) of every row of `points`.
scan_ids = np.repeat(np.arange(len(pointclouds)), [len(pts) for pts in pointclouds])

normals = smooth_normals(points, normals, tree=kd_tree, k=8, T=0.7, n_iters=nrm_smth_iters)

In [12]:
# ---------------------------------------------------------------------
# Local geometric properties
# ---------------------------------------------------------------------
(local_spacing_1, local_density_1), (local_spacing_2, local_density_2) = (
    calc_local_spacing(pts, pts, tree=kd_tree) for pts in (points1, points2)
)
local_spacings = (local_spacing_1, local_spacing_2)

local_spacing = np.concatenate(local_spacings)
local_density = np.concatenate((local_density_1, local_density_2))

# Mean of the per-scan mean spacings: each scan counts equally, whatever its size.
per_scan_mean_spacing = [np.sum(ls) * (1 / len(ls)) for ls in local_spacings]
global_avg_spacing = np.sum(per_scan_mean_spacing) * (1 / len(local_spacings))

In [13]:
# ---------------------------------------------------------------------
# Overlap detection
# ---------------------------------------------------------------------
nbr_cache = precompute_cyl_neighbours(points, normals, local_spacing, r_alpha, h_alpha, kd_tree)

overlap_idx, overlap_mask = compute_overlap_set_cached(scan_ids, nbr_cache)
# Two smoothing passes: first with the default threshold, then with p_thresh=1.
for smoothing_kwargs in ({}, {'p_thresh': 1}):
    overlap_idx, overlap_mask = smooth_overlap_set_cached(overlap_mask, nbr_cache, **smoothing_kwargs)

In [101]:
# DEBUG: colour every point by (scan, overlap) membership.
#   red  = scan 1, outside overlap    green   = scan 0, outside overlap
#   blue = scan 1, in overlap         magenta = scan 0, in overlap
is_scan0 = scan_ids == 0
# The four categories are mutually exclusive, so encode them directly as 0..3
# instead of building four masks and taking an argmax.
category = is_scan0.astype(int) + 2 * overlap_mask.astype(int)
# Open3D expects RGB colours in [0, 1].
debug_colours = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 0.0, 1.0],
])
row_colours = debug_colours[category]

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(points)
pcd.colors = o3d.utility.Vector3dVector(row_colours)

o3d.visualization.draw_geometries([pcd])

In [14]:
# ---------------------------------------------------------------------
# Find overlap boundary edges
# ---------------------------------------------------------------------
# Stack both meshes' faces, offsetting mesh2's vertex indices past mesh1's so
# they index the combined `points` array.
index_offset = len(points1)
tris1 = np.asarray(mesh1.triangles)
tris2 = np.asarray(mesh2.triangles) + index_offset
all_tris = np.concatenate([tris1, tris2], axis=0)

# Faces with at least one vertex outside the overlap zone.
tri_fully_overlapping = overlap_mask[all_tris].all(axis=1)
nonoverlap_tris = all_tris[~tri_fully_overlapping]
boundary_edges = find_boundary_edges(nonoverlap_tris)

In [15]:
def update_weights(
        points:             np.ndarray,
        normals:            np.ndarray,
        weights:            np.ndarray,
        overlap_mask:       np.ndarray,
        scan_ids:           np.ndarray,
        nbr_cache:          np.ndarray,
        normal_diff_thresh: float       =45.0,
        huber_delta:        float|None  =1.345,
        tau_max:            float|None  =None,
        bilateral:          bool        =False,
    ) -> np.ndarray:

    weights_out = weights.copy()

    unique_scan_ids = np.unique(scan_ids)
    
    scan_ids_to_update = unique_scan_ids if bilateral else [unique_scan_ids.max()]

    for id in scan_ids_to_update:
        this_scan_mask = scan_ids == id
        overlap_measurement_mask = this_scan_mask & overlap_mask
        for i in np.flatnonzero(overlap_measurement_mask):
            z   = points[i]
            n_z = normals[i]
            
            nbrs = nbr_cache[i]
            ixs  = nbrs[~this_scan_mask[nbrs]] # ids of non-measurement neighbours
            
            # Non-measurement neighbours and their normals
            xs   = points[ixs]
            n_xs = normals[ixs]

            # Calculate z<->x compatibility based on angle between normals
            cos_th = np.cos(np.deg2rad(normal_diff_thresh))
            dots   = (n_xs @ n_z)
            mask = dots >= cos_th
            
            if not np.any(mask):
                continue # Skip this measurement if there are no compatible current values
            
            ixs  = ixs[mask]
            n_xs = n_xs[mask]
            xs   = xs[mask]
            
            # Calculate point-to-plane residuals (z to x's plane)
            rs = np.einsum('ij,ij->i', n_xs, xs - z)
            
            # Pull out precision (== weight == inverse covariance) for the points
            tau_z   = weights[i]
            sigma_z = 1.0 / np.sqrt(max(1e-12, tau_z))

            # Huber loss robustness measure to decrease the effect of outliers
            t = np.abs(rs) / sigma_z
            keep = t <= 3.0 # Ignore massively outlying associations
            if not np.any(keep):
                continue
            
            w = np.ones_like(t)
            if huber_delta is not None:
                far = t > huber_delta
                w[far] = huber_delta / t[far]
            
            # Calculate likelihoods and responsibilities
            ls = np.exp(-0.5 * (t**2)) * w
            resp = ls / (ls.sum() + 1e-12)
            
            # Update weights of non-new points x based on Bayesian update using new points z
            w_xs = weights_out[ixs]
            weights_out[ixs] = w_xs + resp * tau_z
        
    if tau_max is not None:
        weights_out = np.clip(weights_out, None, tau_max)
        
    return weights_out

In [16]:
# ---------------------------------------------------------------------
# Update weights
# ---------------------------------------------------------------------
# bilateral=True: each scan's overlap points also update the other scan.
updated_weights = update_weights(
    points, normals, weights, overlap_mask, scan_ids, nbr_cache,
    normal_diff_thresh=normal_diff_thresh, huber_delta=1.345,
    tau_max=tau_max, bilateral=True,
)

In [17]:
# DEBUG: colour points by updated weight, green (lowest) -> red (highest).
lo, hi = updated_weights.min(), updated_weights.max()
span = hi - lo
# Min-max normalise to [0, 1]; dividing by `max` alone (not max - min) does
# not reach 1. Guard the all-equal case.
wts_nrm = (updated_weights - lo) / span if span > 0 else np.zeros_like(updated_weights)
high_colour = np.array([1.0, 0.0, 0.0])
low_colour = np.array([0.0, 1.0, 0.0])
debug_colours = high_colour * wts_nrm[:, None] + low_colour * (1 - wts_nrm[:, None])

upd_wt_pcd = o3d.geometry.PointCloud()
upd_wt_pcd.points = o3d.utility.Vector3dVector(points)
upd_wt_pcd.colors = o3d.utility.Vector3dVector(debug_colours)

o3d.visualization.draw_geometries([upd_wt_pcd])

In [18]:
# ---------------------------------------------------------------------
# Multilateral point shifting along normals
# ---------------------------------------------------------------------
# NOTE(review): this runs 3 iterations more than `nrm_shift_iters` and passes
# the original `weights` (not `updated_weights`) -- confirm both are intended.
extra_shift_iters = 3
n_shift_iters_total = nrm_shift_iters + extra_shift_iters

normal_shifted_points = points.copy()
for _shift_iter in range(n_shift_iters_total):
    # Each pass smooths the previous pass's output.
    normal_shifted_points = normal_shift_smooth(
        normal_shifted_points, normals, weights,
        local_spacing, local_density,
        overlap_idx, nbr_cache,
        r_alpha, h_alpha,
        sigma_theta, normal_diff_thresh,
        shift_all,
    )

# Rebuild the KD-tree so later neighbour queries see the shifted positions.
kd_tree = o3d.geometry.KDTreeFlann(normal_shifted_points.T)

In [19]:
def density_aware_radii(pcd: o3d.geometry.PointCloud, k=10, percentiles=(1, 5, 20, 50, 80, 95, 99)):
    """Ball-pivoting radii from the distribution of k-th neighbour distances.

    Args:
        pcd: point cloud to analyse.
        k: which nearest neighbour (excluding the point itself) to measure.
        percentiles: percentiles of the k-NN distance distribution to return.

    Returns:
        ndarray with one distance per requested percentile.
    """
    pts = np.asarray(pcd.points)
    kdt = o3d.geometry.KDTreeFlann(pcd)
    dk = np.empty(len(pts))
    for i, p in enumerate(pts):
        _, _, d2 = kdt.search_knn_vector_3d(p, k + 1)  # includes the point itself
        dk[i] = np.sqrt(d2[-1])                        # k-th neighbour distance
    return np.percentile(dk, percentiles)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(normal_shifted_points)

radii = o3d.utility.DoubleVector(density_aware_radii(pcd, k=10))
# Ball pivoting requires normals: estimate and orient them on the shifted points.
pcd.estimate_normals(o3d.geometry.KDTreeSearchParamKNN(40))
pcd.orient_normals_consistent_tangent_plane(50)

shifted_point_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
    pcd, radii
)

shifted_point_mesh.remove_duplicated_triangles()
shifted_point_mesh.remove_degenerate_triangles()
shifted_point_mesh.remove_non_manifold_edges()
shifted_point_mesh.compute_vertex_normals()

TriangleMesh with 84192 points and 167139 triangles.

In [84]:
# Inspect the ball-pivoted mesh of the normal-shifted points.
shifted_mesh_view = [shifted_point_mesh]
o3d.visualization.draw_geometries(shifted_mesh_view)

In [20]:
# DEBUG: colour shifted points by how far normal shifting moved them
# (rank-based: cyan = unshifted / least shifted, magenta = most shifted).
shift = np.linalg.norm(points - normal_shifted_points, axis=1)
noshift_mask = shift == 0
shift_rankings = np.argsort(np.argsort(shift[~noshift_mask]))
rankings = np.zeros_like(shift)
rankings[~noshift_mask] = shift_rankings
# Guard against 0/0 when fewer than two points moved.
nrm_rankings = rankings / max(rankings.max(), 1)

cyan = np.array([0.0, 1.0, 1.0])
magenta = np.array([1.0, 0.0, 1.0])
shift_colours = (1 - nrm_rankings[:, None]) * cyan + nrm_rankings[:, None] * magenta
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(normal_shifted_points)
pcd.colors = o3d.utility.Vector3dVector(shift_colours)
o3d.visualization.draw_geometries([pcd])

In [73]:
# ---------------------------------------------------------------------
# Merge nearby clusters
# ---------------------------------------------------------------------
merge_kwargs = dict(
    normal_shifted_points=normal_shifted_points,
    normals=normals,
    weights=updated_weights,
    colours=colours,
    overlap_mask=overlap_mask,
    overlap_idx=overlap_idx,
    global_avg_spacing=global_avg_spacing,
    h_alpha=h_alpha,
    tau_max=tau_max,
    tree=kd_tree,
)
merged_out = merge_nearby_clusters(**merge_kwargs)

# 5-tuple: per-vertex cluster mapping (-1 appears to mean "no cluster", judging
# by how it is used when classifying vertices), then the clustered overlap
# points, colours, normals and weights.
(
    cluster_mapping,
    clustered_overlap_pnts,
    clustered_overlap_cols,
    clustered_overlap_nrms,
    clustered_overlap_wts,
) = merged_out

In [74]:
# ---------------------------------------------------------------------
# Classify vertices
# ---------------------------------------------------------------------
# Border = vertex of a triangle touching the overlap zone, but not itself in it.
tri_has_overlap_any = overlap_mask[all_tris].any(axis=1)
overlap_any_idx = np.unique(all_tris[tri_has_overlap_any])

border_mask = np.zeros(len(points), dtype=bool)
border_mask[overlap_any_idx] = True
border_mask[overlap_mask] = False

# Free = not assigned to any cluster and not on the border.
nonoverlap_nonborder_mask = (cluster_mapping == -1) & ~border_mask

n_overlap = len(clustered_overlap_pnts)
n_border = border_mask.sum()
n_free = nonoverlap_nonborder_mask.sum()


def _stack_fused(overlap_rows, per_vertex):
    """Lay out rows as [overlap clusters | border vertices | free vertices]."""
    return np.concatenate(
        [overlap_rows, per_vertex[border_mask], per_vertex[nonoverlap_nonborder_mask]],
        axis=0,
    )


new_points = _stack_fused(clustered_overlap_pnts, normal_shifted_points)
new_colours = np.clip(_stack_fused(clustered_overlap_cols, colours), 0, 1)
new_normals = _stack_fused(clustered_overlap_nrms, normals)
# NOTE(review): border/free rows take the original `weights`, while the
# clustered rows come from `updated_weights` -- confirm this is intended.
new_weights = _stack_fused(clustered_overlap_wts, weights)

In [75]:
# DEBUG: fused points coloured by category
# (red = overlap clusters, green = border, blue = free).
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points)
# Open3D expects RGB colours in [0, 1].
debug_colours = np.concatenate([
    np.full((n_overlap, 3), [1.0, 0.0, 0.0]),
    np.full((n_border, 3),  [0.0, 1.0, 0.0]),
    np.full((n_free, 3),    [0.0, 0.0, 1.0]),
])
pcd.colors = o3d.utility.Vector3dVector(debug_colours)

o3d.visualization.draw_geometries([pcd])

In [76]:
# ---------------------------------------------------------------------
# Complete mapping
# ---------------------------------------------------------------------
# new_points is laid out as [overlap clusters | border | free]; map each
# original border/free vertex index to its row in that array.
border_idx_from = np.flatnonzero(border_mask)
border_idx_to = n_overlap + np.arange(n_border)

free_idx_from = np.flatnonzero(nonoverlap_nonborder_mask)
free_idx_to = n_overlap + n_border + np.arange(n_free)

mapping = cluster_mapping.copy()
for idx_from, idx_to in ((border_idx_from, border_idx_to), (free_idx_from, free_idx_to)):
    mapping[idx_from] = idx_to

In [91]:
# ---------------------------------------------------------------------
# Mesh the overlap zone
# ---------------------------------------------------------------------
def density_aware_radii(pcd: o3d.geometry.PointCloud, percentiles=(5, 15, 30, 70, 85, 95), k=10):
    """Ball-pivoting radii from the distribution of k-th neighbour distances.

    Args:
        pcd: point cloud to analyse.
        percentiles: percentiles of the k-NN distance distribution to return.
        k: which nearest neighbour (excluding the point itself) to measure.

    Returns:
        ndarray with one distance per requested percentile.
    """
    pts = np.asarray(pcd.points)
    kdt = o3d.geometry.KDTreeFlann(pcd)
    dk = np.empty(len(pts))
    for i, p in enumerate(pts):
        _, _, d2 = kdt.search_knn_vector_3d(p, k + 1)  # includes the point itself
        dk[i] = np.sqrt(d2[-1])                        # k-th neighbour distance
    return np.percentile(dk, percentiles)

pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points[:n_overlap])

# Use the fused cluster normals when they are valid; otherwise estimate them.
# Slicing an ndarray never yields None, so test validity explicitly.
overlap_normals = new_normals[:n_overlap]
normals_usable = bool(
    np.all(np.isfinite(overlap_normals))
    and np.all(np.linalg.norm(overlap_normals, axis=1) > 0)
)
if normals_usable:
    pcd.normals = o3d.utility.Vector3dVector(overlap_normals)
else:
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(
            radius=2.5 * global_avg_spacing, max_nn=30
        )
    )
    pcd.orient_normals_consistent_tangent_plane(k=30)

radii = o3d.utility.DoubleVector(density_aware_radii(pcd, k=10))

overlap_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
    pcd, radii
)

overlap_mesh.remove_duplicated_triangles()
overlap_mesh.remove_degenerate_triangles()
overlap_mesh.remove_non_manifold_edges()
overlap_mesh.compute_vertex_normals()

TriangleMesh with 28226 points and 55807 triangles.

In [92]:
# DEMO: overlap mesh over the fused points, coloured by category
# (red = overlap clusters, green = border, blue = free).
pcd = o3d.geometry.PointCloud()
pcd.points = o3d.utility.Vector3dVector(new_points)
# Do not reuse `colours`: it holds the per-vertex colours consumed by the
# merge/classify cells, which would break if re-run after this one.
# Open3D expects RGB colours in [0, 1].
debug_colours = np.concatenate([
    np.full((n_overlap, 3), [1.0, 0.0, 0.0]),
    np.full((n_border, 3),  [0.0, 1.0, 0.0]),
    np.full((n_free, 3),    [0.0, 0.0, 1.0]),
])
pcd.colors = o3d.utility.Vector3dVector(debug_colours)

o3d.visualization.draw_geometries([overlap_mesh, pcd])

In [93]:
# ---------------------------------------------------------------------
# Trim overlap mesh
# ---------------------------------------------------------------------
mapped_boundary_edges = mapping[boundary_edges]
# Keep boundary edges whose endpoints are both valid overlap-mesh vertices.
# Vertices never reassigned keep cluster_mapping's -1, which would pass the
# upper-bound test alone and wrap around to the last vertex.
n_overlap_mesh_vertices = len(overlap_mesh.vertices)
edge_in_overlap_mesh = np.all(
    (mapped_boundary_edges >= 0) & (mapped_boundary_edges < n_overlap_mesh_vertices),
    axis=1,
)
relevant_boundary_edges = mapped_boundary_edges[edge_in_overlap_mesh]

# Number of connected components of the overlap mesh.
Ncc = len(overlap_mesh.cluster_connected_triangles()[1])

trimmed_overlap_mesh = topological_trim(
    overlap_mesh, relevant_boundary_edges, k=Ncc,
)

In [94]:
# DEMO: trimmed overlap mesh together with the category-coloured points.
trimmed_view = [trimmed_overlap_mesh, pcd]
o3d.visualization.draw_geometries(trimmed_view)

In [95]:
# ---------------------------------------------------------------------
# Concatenate trimmed overlap mesh and nonoverlap meshes
# ---------------------------------------------------------------------
trimmed_overlap_tris = np.asarray(trimmed_overlap_mesh.triangles)
mapped_nonoverlap_tris = mapping[nonoverlap_tris]
# Drop faces that reference an unmapped (-1) vertex; a negative index would
# silently wrap to the last row of new_points.
mapped_nonoverlap_tris = mapped_nonoverlap_tris[np.all(mapped_nonoverlap_tris >= 0, axis=1)]
fused_mesh_triangles = np.concatenate(
    [trimmed_overlap_tris, mapped_nonoverlap_tris], axis=0
)

V0, F0, (C0, N0, W0), used_mask, remap = compact_by_faces(
    new_points, fused_mesh_triangles, [new_colours, new_normals, new_weights]
)

fused_mesh = o3d.geometry.TriangleMesh(
    vertices=o3d.utility.Vector3dVector(V0),
    triangles=o3d.utility.Vector3iVector(F0),
)

fused_mesh.vertex_colors = o3d.utility.Vector3dVector(np.clip(C0, 0, 1))

In [96]:
# ---------------------------------------------------------------------
# Clean up fused mesh
# ---------------------------------------------------------------------
# NOTE(review): vertex-level clean-up (unreferenced/duplicated vertices) is
# skipped here, presumably because compact_by_faces already compacts vertices.
fused_mesh.remove_duplicated_triangles()
fused_mesh.remove_degenerate_triangles()
fused_mesh.remove_non_manifold_edges()

fused_mesh.compute_vertex_normals()

if fill_holes:
    # Attach the reason to the exception instead of printing it separately.
    raise NotImplementedError('Hole-filling is not working at the moment.')

In [97]:
# Final fused mesh.
fused_view = [fused_mesh]
o3d.visualization.draw_geometries(fused_view)