In [1]:
import pymeshlab
import torch
import trimesh
import numpy as np
import utils.fmap_util as fmap_util

# Module-level MeshLab mesh container; `remesh_partial` (defined in a later cell) adds its
# working mesh to this object and clears it again when it is done.
ms = pymeshlab.MeshSet()
# Shared trimesh scene; the visualisation cells below clear it, refill it and call `scene.show()`.
scene = trimesh.Scene()

In [None]:
# Load the SHREC16 "null" david shape as the working mesh for the cells below.
# NOTE(review): process=False presumably keeps trimesh from merging/reordering vertices — confirm.
mesh_input = trimesh.load_mesh('/home/s94zalek_hpc/shape_matching/data/SHREC16_test/null/off/david.off',
                               validate=True, process=False)

# Alternative input (disabled): a SURREAL training shape.
# mesh_input = trimesh.load_mesh('/home/s94zalek_hpc/shape_matching/data_sign_training/train/SURREAL/off/0001.off')

# Show only the freshly loaded mesh in the shared scene.
scene.geometry.clear()
scene.add_geometry(mesh_input)
scene.show()

In [38]:
import utils.shape_util as shape_util

# Geodesic distance matrix of the input mesh, computed from its vertices and faces.
# It is passed as the distance matrix `D` to the farthest point sampling in the next cell.
# NOTE(review): presumably a dense (n_vertices, n_vertices) array — confirm in shape_util.
geo_dist_input = shape_util.compute_geodesic_distmat(
    mesh_input.vertices,
    mesh_input.faces
    )

In [None]:
import pyFM_fork.pyFM.mesh.geometry as pyfm_geom

# Farthest point sampling of k=25 samples driven by the geodesic distance matrix.
# The result is used below to index `mesh_input.vertices`, i.e. it holds vertex indices.
idxs_farthest = pyfm_geom.farthest_point_sampling_distmat(
    D=geo_dist_input,
    k=25,
)
# Bare expression: display the sampled indices as the cell output.
idxs_farthest

In [None]:
# Redraw the scene from scratch: the input mesh in white plus one green marker per sampled vertex.
scene.geometry.clear()

# A single RGBA value is assigned as the colour of the whole mesh.
mesh_input.visual.vertex_colors = np.array([255, 255, 255, 255], dtype=np.uint8)

for sample_idx in idxs_farthest:
    # Unit-origin sphere of radius 0.03, moved onto the sampled vertex position.
    marker = trimesh.creation.uv_sphere(radius=0.03)
    marker.vertices += mesh_input.vertices[sample_idx]
    marker.visual.vertex_colors = [0, 255, 0, 255]

    scene.add_geometry(marker)

# The mesh itself is added after the markers, then the scene is displayed.
scene.add_geometry(mesh_input)
scene.show()

In [200]:
def remesh_partial(
    verts,
    faces,
    n_remesh_iters,
    fraction_to_select,
    n_seed_samples,
    weighted_by,
    remove_selection: bool,
    ):
    """Cut a partial shape out of a mesh by growing a face selection from random seed faces.

    The mesh is loaded into the module-level MeshSet ``ms``, optionally remeshed
    isotropically, and ``n_seed_samples`` seed faces are drawn at random. The selection
    is then dilated until it covers at least ``fraction_to_select`` of the mesh
    (measured by area or by face count), or until 100 dilation steps have been made.

    Args:
        verts: vertex array handed to ``pymeshlab.Mesh``
            (presumably (V, 3) floats — TODO confirm against callers).
        faces: face index array handed to ``pymeshlab.Mesh``
            (presumably (F, 3) ints — TODO confirm against callers).
        n_remesh_iters: number of isotropic explicit remeshing iterations;
            remeshing is skipped when this is not positive.
        fraction_to_select: fraction of the total area / face count at which the
            dilation stops.
        n_seed_samples: number of seed faces to draw (drawn with replacement, so
            duplicates are possible).
        weighted_by: ``'area'`` draws seeds with probability proportional to face
            area and measures the selection by area; ``'face_count'`` draws seeds
            uniformly and measures the selection by number of faces.
        remove_selection: if True, the selected faces are deleted and the largest
            remaining connected component is returned; if False, the selected
            faces are moved to a new mesh and the mesh with the fewest vertices
            in the set is returned.

    Returns:
        Tuple ``(v_qec, f_qec)``: vertices as a float32 tensor and faces as an
        int32 tensor of the resulting partial mesh.

    Raises:
        AssertionError: if ``weighted_by`` is not ``'area'`` or ``'face_count'``.

    Notes:
        Seeds are drawn from the global ``np.random`` state. The component picking
        below indexes meshes by position in ``ms``, so ``ms`` must be empty on
        entry; it is always cleared on exit, including when an error is raised.
    """

    assert weighted_by in ['area', 'face_count']

    try:
        ms.add_mesh(pymeshlab.Mesh(verts, faces))

        # isotropic remeshing
        if n_remesh_iters > 0:
            ms.meshing_isotropic_explicit_remeshing(
                iterations=n_remesh_iters,
            )

        # mesh after remeshing
        v_r = ms.current_mesh().vertex_matrix()
        f_r = ms.current_mesh().face_matrix()

        if weighted_by == 'area':
            # face areas of the (remeshed) mesh
            mesh_r = trimesh.Trimesh(v_r, f_r, process=False)
            area_faces = mesh_r.area_faces
            total_area_faces = area_faces.sum()

            # choose random faces, with probability proportional to their area
            rand_idxs = np.random.choice(len(area_faces), size=n_seed_samples,
                                         p=area_faces / total_area_faces)
            target_amount = total_area_faces * fraction_to_select

        elif weighted_by == 'face_count':
            # choose random faces uniformly
            rand_idxs = np.random.randint(0, len(f_r), size=n_seed_samples)
            target_amount = len(f_r) * fraction_to_select

        # select exactly the seed faces: "(fi == a) || (fi == b) || ..."
        ms.set_selection_none()
        query_str = ' || '.join(f'(fi == {rand_idx})' for rand_idx in rand_idxs)
        ms.compute_selection_by_condition_per_face(
            condselect=query_str
        )

        # grow the selection by dilatation; at most 100 steps
        for _ in range(100):
            selection_mask = ms.current_mesh().face_selection_array()

            # stopping criterion: enough area / enough faces selected
            if weighted_by == 'area':
                selected_amount = area_faces[selection_mask].sum()
            else:
                selected_amount = np.count_nonzero(selection_mask)

            if selected_amount >= target_amount:
                break

            ms.apply_selection_dilatation()

        if remove_selection:
            # remove the selected faces, then split what is left into components
            ms.meshing_remove_selected_vertices_and_faces()
            ms.generate_splitting_by_connected_components()
        else:
            # move the selected faces into a mesh of their own
            ms.generate_from_selected_faces()

        # number of vertices of every mesh currently in the set
        n_vertices_list = [
            ms.mesh(i).vertex_matrix().shape[0] for i in range(ms.mesh_number())
        ]
        print(n_vertices_list)

        # mesh indices sorted by vertex count, ascending
        idx_sorted = np.argsort(n_vertices_list)

        if remove_selection:
            # largest connected component:
            # 2nd from the end, last one is the full mesh
            partial_idx = idx_sorted[-2]
        else:
            # NOTE(review): this takes the mesh with the fewest vertices; for large
            # fractions that may not be the selected part — confirm the intent.
            partial_idx = idx_sorted[0]

        mesh_partial = ms.mesh(int(partial_idx))

        # torch.tensor copies the data, so the result survives clearing the MeshSet
        v_qec = torch.tensor(
            mesh_partial.vertex_matrix(), dtype=torch.float32
        )
        f_qec = torch.tensor(
            mesh_partial.face_matrix(), dtype=torch.int32
        )

        return v_qec, f_qec

    finally:
        # Always leave the shared MeshSet empty: leftovers from a failed call would
        # shift the mesh indices that the next call relies on.
        ms.clear()

In [None]:
import importlib

# `remesh` has to be bound before it can be reloaded. Import it here so this cell also
# works on a fresh kernel, where the cells further down that import it have not run yet.
import my_code.sign_canonicalization.remesh as remesh

# Pick up edits made to the module on disk without restarting the kernel.
importlib.reload(remesh)

In [174]:
# Replace the working mesh with SURREAL training shape 0007; this rebinds the `mesh_input`
# used by every cell below.
mesh_input = trimesh.load_mesh('/home/s94zalek_hpc/shape_matching/data_sign_training/train/SURREAL/off/0007.off', process=False)

# Preview of the loaded mesh (disabled):
# scene.geometry.clear()

# scene.add_geometry(mesh_input)

# scene.show()

In [None]:
import my_code.sign_canonicalization.remesh as remesh

# The call below uses the packaged `remesh.remesh_partial`; the disabled line is the
# variant that calls the notebook-local definition instead.
# v_output, f_output = remesh_partial(
    
# No remeshing (n_remesh_iters=0); grow a selection from 25 area-weighted seed faces up to
# a 0.7 fraction and remove it from the mesh.
v_output, f_output = remesh.remesh_partial(
    verts=mesh_input.vertices,
    faces=mesh_input.faces,
    n_remesh_iters=0,
    fraction_to_select=0.7,
    n_seed_samples=25,
    weighted_by='area',
    remove_selection=True,
)

scene.geometry.clear()

# Wrap the returned vertices/faces in a trimesh object for display; `mesh_output` is also
# inspected by the next cell.
mesh_output = trimesh.Trimesh(v_output, f_output, process=False)

scene.add_geometry(mesh_output)

scene.show()

In [None]:
# Number of connected components in the edge graph of the partial mesh produced above.
len(trimesh.graph.connected_components(mesh_output.edges))

In [None]:
import os

base_folder = '/home/s94zalek_hpc/shape_matching/data_sign_training/train/test_partial_0.8/off'

n_to_show = 10

# select up to n_to_show random meshes from base_folder and show them

scene.geometry.clear()

mesh_files = os.listdir(base_folder)

# Sampling without replacement raises a ValueError when more items are requested than
# exist, so cap the request at the number of files actually present in the folder.
files_to_show = np.random.choice(
    mesh_files,
    size=min(n_to_show, len(mesh_files)),
    replace=False,
)

print(files_to_show)

for i, file in enumerate(files_to_show):
    mesh = trimesh.load_mesh(os.path.join(base_folder, file), process=False)

    # shift the i-th mesh by i units along x so the meshes sit side by side
    mesh.vertices += np.array([i, 0, 0])
    scene.add_geometry(mesh)

scene.show()

# Get correspondences to the input

In [None]:
import my_code.sign_canonicalization.remesh as remesh
import my_code.utils.plotting_utils as plotting_utils

# The call below uses the packaged `remesh.remesh_partial`; the disabled line is the
# variant that calls the notebook-local definition instead.
# v_output, f_output = remesh_partial(
    
# 10 remeshing iterations, then remove a 0.3 fraction grown from 25 area-weighted seed faces.
v_output, f_output = remesh.remesh_partial(
    verts=mesh_input.vertices,
    faces=mesh_input.faces,
    n_remesh_iters=10,
    fraction_to_select=0.3,
    n_seed_samples=25,
    weighted_by='area',
    remove_selection=True,
)

# Nearest-neighbour query between the input vertex positions and the partial mesh's vertex
# positions. NOTE(review): presumably yields, for each vertex of `v_output`, the index of
# the closest input vertex — confirm direction and shape against fmap_util.nn_query.
corr = fmap_util.nn_query(
    torch.tensor(mesh_input.vertices, dtype=torch.float32), 
    v_output,
    )

scene.geometry.clear()

# Draw both shapes coloured through the correspondence `corr`.
# NOTE(review): the input vertices are passed here without dtype=torch.float32, unlike in
# the nn_query call above — confirm plot_p2p_map does not care about the dtype mismatch.
plotting_utils.plot_p2p_map(
    scene,
    torch.tensor(mesh_input.vertices), torch.tensor(mesh_input.faces),
    v_output, f_output,
    
    
    corr,
    axes_color_gradient=[0, 1],
    base_cmap='hsv'
)

scene.show()

In [None]:
# Inspect the size of the nearest-neighbour correspondence computed in the previous cell.
corr.shape

In [2]:
# Augmentation settings; handed to the dataset constructor below via `augmentations=`.
augmentations = {
    "remesh": {
        # settings for the isotropic remeshing augmentation
        "isotropic": {
            "n_remesh_iters": 10,
            "remesh_targetlen": 1,
            "simplify_strength_min": 0.2,
            "simplify_strength_max": 0.8,
        },
        # settings for the partial-shape augmentation
        "partial": {
            "probability": 0.75,
            "n_remesh_iters": 10,
            "fraction_to_select_min": 0.25,
            "fraction_to_select_max": 0.75,
            "n_seed_samples": [1, 5, 25],
            "weighted_by": "area",
        },
    },
}

In [3]:
import my_code.datasets.surreal_dataset_3dc as surreal_dataset_3dc

# Template correspondence indices read from disk as int32 and shifted down by one
# (presumably converting 1-based indices to 0-based — TODO confirm).
template_corr_shifted = np.loadtxt(
    '/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/remeshed/corr.txt',
    dtype=np.int32,
) - 1

# SURREAL template dataset with 128 eigenvectors, memory-mapped shapes, no Laplace-Beltrami
# cache directory, and the augmentation settings defined in the previous cell.
dataset_template = surreal_dataset_3dc.TemplateSurrealDataset3DC(
    shape_path='/lustre/mlnvme/data/s94zalek_hpc-shape_matching/mmap_datas_surreal_train.pth',
    num_evecs=128,
    cache_lb_dir=None,
    return_evecs=True,
    return_fmap=True,
    mmap=True,
    augmentations=augmentations,
    template_path='/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/remeshed/template.off',
    template_corr=template_corr_shifted,
)

In [None]:
import my_code.diffusion_training_sign_corr.test.test_diffusion_pair_template_unified as test_diffusion

# Functional-map solver built with its default arguments; a later cell calls
# `fmnet.compute_functional_map(...)` on it.
fmnet = test_diffusion.RegularizedFMNet(
) 

In [None]:
import my_code.utils.plotting_utils as plotting_utils

scene.geometry.clear()

# Disabled earlier variant that plotted the template against the remeshed output:
# plotting_utils.plot_p2p_map(
#     scene,
#     data_template['first']['verts'], data_template['first']['faces'],
#     v_output, f_output,

#     data_template['first']['corr'][corr],
#     axes_color_gradient=[0, 1],
#     base_cmap='hsv'
# )

# `data` is otherwise only assigned by a cell further down the notebook, so on a fresh
# kernel this cell failed with a NameError. Load the sample here when it is missing; an
# already loaded `data` is kept so the plot shows the sample the session is working with.
if 'data' not in globals():
    data = dataset_template[10]

# Plot the 'first' and 'second' shapes of the sample, coloured through data['first']['corr'].
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],

    data['first']['corr'],
    axes_color_gradient=[0, 1],
    base_cmap='hsv'
)

scene.show()

In [None]:
import matplotlib.pyplot as plt

# Draw sample 10 from the dataset and unpack the spectral quantities of both shapes.
data = dataset_template[10]

evecs_first, evecs_second = data['first']['evecs'], data['second']['evecs']
evals_first, evals_second = data['first']['evals'], data['second']['evals']

# NOTE(review): the dataset above was built with num_evecs=128, so slicing to 200 is
# presumably a no-op — confirm which basis size is intended.
num_evecs = 200

p2p_first, p2p_second = data['first']['corr'], data['second']['corr']

evecs_trans_first = data['first']['evecs_trans']
evecs_trans_second = data['second']['evecs_trans']

# Regularised functional maps in both directions: the solver gets the transposed
# eigenvectors restricted to the corresponding vertices plus the eigenvalues, each with a
# leading batch dimension; the first batch entry is taken and transposed.
Cxy_reg = fmnet.compute_functional_map(
    evecs_trans_second[:num_evecs, p2p_second].unsqueeze(0),
    evecs_trans_first[:num_evecs, p2p_first].unsqueeze(0),
    evals_second[:num_evecs].unsqueeze(0),
    evals_first[:num_evecs].unsqueeze(0),
)[0].T

Cyx_reg = fmnet.compute_functional_map(
    evecs_trans_first[:num_evecs, p2p_first].unsqueeze(0),
    evecs_trans_second[:num_evecs, p2p_second].unsqueeze(0),
    evals_first[:num_evecs].unsqueeze(0),
    evals_second[:num_evecs].unsqueeze(0),
)[0].T

# Plain least-squares maps between the eigenvectors sampled at the corresponding vertices.
Cxy_lstsq = torch.linalg.lstsq(evecs_second[p2p_second], evecs_first[p2p_first]).solution
Cyx_lstsq = torch.linalg.lstsq(evecs_first[p2p_first], evecs_second[p2p_second]).solution

print(Cxy_lstsq.min(), Cxy_lstsq.max(), Cyx_lstsq.min(), Cyx_lstsq.max())

# Bounds forwarded to plot_Cxy for every panel.
l = 0
h = 64

fig, axs = plt.subplots(1, 4, figsize=(16, 4))

# One panel per map, left to right: Cxy lstsq, Cxy regularised, Cyx lstsq, Cyx regularised.
# NOTE(review): all four panels are passed the same 'before' string — confirm that is intended.
for ax, fmap in zip(axs, (Cxy_lstsq, Cxy_reg, Cyx_lstsq, Cyx_reg)):
    plotting_utils.plot_Cxy(fig, ax,
                            fmap,
                            'before', l, h, show_grid=False, show_colorbar=False)

plt.show()

In [None]:
# Sizes of the correspondence arrays of both shapes in the current sample. The sample dict
# is bound to `data`; no `data_template` variable is defined anywhere in the notebook.
data['second']['corr'].shape, data['first']['corr'].shape