In [27]:

import numpy as np

from trimesh import Trimesh, load

import meshplot as mp

from scipy.spatial import distance

def plot_mesh(myMesh,cmap=None):
    """Render one mesh with meshplot.

    myMesh: object exposing `.vertices` and `.faces` (e.g. a mesh returned by `load`).
    cmap:   optional colour data, forwarded unchanged as meshplot's `c` argument.
    """
    verts, tris = myMesh.vertices, myMesh.faces
    mp.plot(verts, tris, c=cmap)
    
def double_plot(myMesh1,myMesh2,cmap1=None,cmap2=None):
    """Show two meshes side by side using meshplot subplots.

    `myMesh1` goes in subplot slot s=[2, 2, 0] and `myMesh2` in s=[2, 2, 1];
    `cmap1` / `cmap2` are forwarded as each subplot's `c` colour argument.
    """
    first_panel = mp.subplot(
        myMesh1.vertices, myMesh1.faces, c=cmap1, s=[2, 2, 0]
    )
    # The second subplot attaches to the same figure via `data=`.
    mp.subplot(
        myMesh2.vertices, myMesh2.faces, c=cmap2, s=[2, 2, 1], data=first_panel
    )

def double_plot_clouds(cloud1, cloud2):
    """Show two point clouds side by side: `cloud1` coloured [0, 1, 0], `cloud2` [1, 0, 0]."""
    green = np.array([0, 1, 0])
    red = np.array([1, 0, 0])
    left = mp.subplot(cloud1, None, c=green, s=[2, 2, 0])
    mp.subplot(cloud2, None, c=red, s=[2, 2, 1], data=left)

def visu(vertices):
    """Map point coordinates to colour values in [0, 1] by per-axis min-max scaling.

    Each coordinate axis (axis 0 is reduced over, so rows are points) is scaled
    independently: the smallest value along an axis maps to 0, the largest to 1.

    Parameters
    ----------
    vertices : array-like, typically shape (n_points, 3)

    Returns
    -------
    np.ndarray with the same shape as `vertices`.  An axis along which every
    point has the same coordinate (zero extent, e.g. a planar patch or a single
    point) maps to 0 instead of producing NaN from a division by zero.
    """
    vertices = np.asarray(vertices)
    min_coord = np.min(vertices, axis=0, keepdims=True)
    max_coord = np.max(vertices, axis=0, keepdims=True)
    extent = max_coord - min_coord
    # Avoid 0/0 on flat axes; (v - min) is 0 there anyway, so the result is 0.
    extent = np.where(extent == 0, 1, extent)
    return (vertices - min_coord) / extent

def plot_correspondence_lines(mesh1, mesh2, mesh1_points, mesh2_points):
    """Draw two views of a point correspondence between `mesh1` and `mesh2`.

    View 1 shows `mesh1` with `mesh2_points`; view 2 shows `mesh2` with
    `mesh1_points`.  Both views also draw meshplot lines from `mesh1_points`
    to `mesh2_points`.  Points and mesh vertices are coloured by their
    normalised coordinates via `visu`.
    """
    def plot_mesh_and_lines(mesh, points):
        # Plot `points` themselves (not always `mesh1_points`) so that the
        # positions drawn and the colours computed refer to the same point set.
        p = mp.plot(points, None, c=visu(points))
        p.add_mesh(mesh.vertices, mesh.faces, c=visu(mesh.vertices))
        p.add_lines(mesh1_points, mesh2_points)

    plot_mesh_and_lines(mesh1, mesh2_points)
    plot_mesh_and_lines(mesh2, mesh1_points)

def merged_plot(mesh1, mesh2, shift=(0.0, 0.0, 0.0)):
    """Overlay `mesh2` (coloured [1, 0, 0]) on `mesh1` in a single meshplot viewer.

    shift : sequence added to every vertex of `mesh2` before drawing (e.g. a
            3-vector to pull the meshes apart).  A new array is drawn; `mesh2`
            itself is not modified.  Defaults to no offset.  The default is an
            immutable tuple so it can never be changed between calls.
    """
    offset = np.asarray(shift, dtype=float)
    viewer = mp.plot(mesh1.vertices, mesh1.faces)
    viewer.add_mesh(mesh2.vertices + offset, mesh2.faces, c=np.array([1, 0, 0]))

Load meshes

In [28]:
# `ref_mesh` is the reference; `bad_mesh` is the mesh examined against it.
ref_mesh, bad_mesh = (
    load('data/dataset/YellowToy01/yellow_push_toy_1_70000.obj'),
    load('data/dataset/YellowToy01/yellow_push_toy_3_70000.obj'),
)
double_plot(myMesh1=ref_mesh, myMesh2=bad_mesh)



HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

In [29]:
# Zero offset: both meshes are drawn in their original coordinates.
shift = [0.0] * 3
merged_plot(ref_mesh, bad_mesh, shift=shift)

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(-0.054015…

Find the `ref_mesh` vertices that lie within `BAD_DIST_THRESHOLD` of at least one `bad_mesh` vertex

In [98]:
BAD_DIST_THRESHOLD = 0.003  # max vertex-to-vertex distance (mesh units) counted as "close"
PATCH_SIZE = 500  # default number of vertices per chunk


def iter_patches(mesh: "Trimesh", patch_size: int = PATCH_SIZE):
    """Yield consecutive chunks of ``mesh.vertices``.

    Processing all vertices in one sweep (e.g. a full pairwise-distance matrix)
    can exhaust memory and kill the kernel, so callers can work chunk by chunk.

    Parameters
    ----------
    mesh : object exposing a ``vertices`` array
    patch_size : number of vertices per chunk (defaults to ``PATCH_SIZE``)

    Yields
    ------
    ``(start, end, mesh.vertices[start:end])`` for start = 0, patch_size, ...
    The last chunk may be shorter; nothing is yielded for a mesh with no vertices.

    Raises
    ------
    ValueError
        When iteration starts, if ``patch_size`` is less than 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}")
    n_vertices = len(mesh.vertices)
    for start in range(0, n_vertices, patch_size):
        end = min(start + patch_size, n_vertices)
        yield start, end, mesh.vertices[start:end]


def get_mask_near_mesh(mesh1: "Trimesh", mesh2: "Trimesh", threshold: float):
    """Boolean mask over ``mesh1.vertices``: True where a vertex lies strictly
    closer than ``threshold`` (Euclidean distance) to at least one ``mesh2`` vertex.

    A KD-tree over ``mesh2``'s vertices answers one nearest-neighbour query per
    ``mesh1`` vertex, avoiding (chunks of) the full pairwise distance matrix.

    Returns
    -------
    np.ndarray of bool with one entry per ``mesh1`` vertex.  All False when
    ``mesh2`` has no vertices.
    """
    from scipy.spatial import cKDTree

    points1 = np.asarray(mesh1.vertices, dtype=float)
    points2 = np.asarray(mesh2.vertices, dtype=float)
    # Size the mask from mesh1 itself, never from a global mesh.
    if len(points1) == 0 or len(points2) == 0:
        return np.zeros(len(points1), dtype=bool)
    # Neighbours farther than `threshold` come back as inf, which the strict
    # comparison below rejects.
    nearest_dist, _ = cKDTree(points2).query(
        points1, k=1, distance_upper_bound=threshold
    )
    return nearest_dist < threshold


# ref_mesh vertices that sit within BAD_DIST_THRESHOLD of some bad_mesh vertex.
used_v_mask_bad = get_mask_near_mesh(
    mesh1=ref_mesh, mesh2=bad_mesh, threshold=BAD_DIST_THRESHOLD
)


Remove the `ref_mesh` faces none of whose vertices are close to `bad_mesh` (a face is kept if at least one of its vertices passed the test above)

In [99]:
def get_trimesh_with_mask(src_mesh, vert_mask):
    """Build a new Trimesh from ``src_mesh`` keeping only faces that touch a selected vertex.

    A face is kept when at least one of its vertex indices is selected by
    ``vert_mask``.  The full vertex array of ``src_mesh`` is passed to the
    constructor; only the face list is filtered.
    """
    selected_ids = np.nonzero(vert_mask)
    touches_selected = np.isin(src_mesh.faces, selected_ids)
    keep_face = touches_selected.any(axis=1)
    return Trimesh(src_mesh.vertices, src_mesh.faces[keep_face])


# Faces of ref_mesh near bad_mesh, shown next to bad_mesh for comparison.
ref_mesh_bad = get_trimesh_with_mask(src_mesh=ref_mesh, vert_mask=used_v_mask_bad)
double_plot(myMesh1=ref_mesh_bad, myMesh2=bad_mesh)



HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

Remove the `ref_mesh` faces whose vertices are all farther than `CENTER_THRESHOLD` from the centroid of `bad_mesh`

In [97]:
CENTER_THRESHOLD = 0.8  # max distance (mesh units) from bad_mesh's centroid

def get_mask_near_point(mesh, point, threshold):
    """Boolean mask over ``mesh.vertices``: True where a vertex lies strictly
    closer than ``threshold`` (Euclidean distance) to ``point``.

    Parameters
    ----------
    mesh : object exposing a ``vertices`` array, assumed shape (n, d)
    point : array-like of length d (e.g. a mesh centroid)
    threshold : float

    Returns
    -------
    np.ndarray of bool with one entry per vertex.
    """
    vertices = np.asarray(mesh.vertices, dtype=float)
    if len(vertices) == 0:
        return np.zeros(0, dtype=bool)
    # axis=1 gives one distance per vertex; without it np.linalg.norm returns a
    # single scalar for the whole array, accepting or rejecting everything at once.
    dist = np.linalg.norm(vertices - np.asarray(point, dtype=float), axis=1)
    return dist < threshold

used_v_mask_centroid = get_mask_near_point(ref_mesh, bad_mesh.centroid, CENTER_THRESHOLD)

ref_mesh_centroid = get_trimesh_with_mask(ref_mesh, used_v_mask_centroid)
# Show the centroid-filtered mesh built in this cell next to bad_mesh.
double_plot(ref_mesh_centroid, bad_mesh)

HBox(children=(Output(), Output()))

HBox(children=(Output(), Output()))

Combine the two vertex masks (union: a vertex is kept if it passes either test)

In [None]:
# Union of the proximity mask and the centroid mask.
used_v_mask = used_v_mask_bad | used_v_mask_centroid

37082
18601
19246
