In [None]:
import pyvirtualdisplay
import trimesh
import my_code.diffusion_training_sign_corr.data_loading as data_loading
import yaml
import json
from tqdm import tqdm
import torch
import numpy as np
import trimesh.scene
import trimesh.scene.lighting
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import PIL.Image


def interpolate_colors(values, cmap, dtype=np.uint8):
    """Map scalar values onto a colormap and return one RGBA color per value.

    The values are flattened, min-max normalized to [0, 1], passed through
    ``cmap`` and converted with ``trimesh.visual.color.to_rgba``.

    Args:
        values: Array-like of scalars (anything ``np.asanyarray`` accepts).
        cmap: Callable mapping floats in [0, 1] to colors, e.g. a matplotlib
            colormap.
        dtype: dtype forwarded to ``trimesh.visual.color.to_rgba``.

    Returns:
        The RGBA array produced by ``to_rgba``, one row per input value.
        An empty input yields an empty ``(0, 4)`` array of ``dtype``.
    """
    # make input always float
    values = np.asanyarray(values, dtype=np.float64).ravel()

    # min()/ptp() raise on empty arrays, so answer the trivial case directly
    if values.size == 0:
        return np.zeros((0, 4), dtype=dtype)

    # np.ptp() instead of the ndarray.ptp() method, which NumPy 2.0 removed
    value_range = np.ptp(values)
    if value_range == 0:
        # constant input: avoid 0/0 -> NaN and use the lowest color everywhere
        normalized = np.zeros_like(values)
    else:
        # scale values to 0.0 - 1.0
        normalized = (values - values.min()) / value_range

    colors = cmap(normalized)
    # convert to 0-255 RGBA
    rgba = trimesh.visual.color.to_rgba(colors, dtype=dtype)

    return rgba
    
    
def get_cmap(colorscale=None, samples=100, reverse=False):
    """Build a matplotlib ``ListedColormap`` by sampling a plotly colorscale.

    Colorscales this notebook has been switched between, per dataset:
        DT4D    -> px.colors.cyclical.Edge
        FAUST   -> px.colors.sequential.Jet (the default)
        SHREC19 -> px.colors.diverging.Picnic
        SCAPE   -> px.colors.cyclical.HSV
    Also tried: px.colors.cyclical.IceFire, px.colors.sequential.Blackbody,
    px.colors.sequential.Viridis.

    Args:
        colorscale: plotly colorscale to sample; ``None`` selects
            ``px.colors.sequential.Jet``.
        samples: Number of colors sampled from the colorscale; also the
            ``N`` of the returned colormap.
        reverse: If True, reverse the order of the sampled colors.

    Returns:
        ``matplotlib.colors.ListedColormap`` with ``samples`` entries.
    """
    if colorscale is None:
        colorscale = px.colors.sequential.Jet

    sampled = px.colors.sample_colorscale(colorscale, samples)
    # plotly yields 'rgb(r, g, b)' labels; strip the label and rescale from 0-255 to 0-1
    rgb = [px.colors.unconvert_from_RGB_255(px.colors.unlabel_rgb(c)) for c in sampled]

    if reverse:
        rgb = rgb[::-1]

    # the name 'Ice' is kept unchanged from the original code, whatever colorscale is sampled
    cmap = mcolors.ListedColormap(rgb, name='Ice', N=samples)

    return cmap


def render_mesh(scene, mesh, path, height=1080):
    """Render a single mesh off-screen and write the image to ``path``.

    The scene's geometry is cleared and replaced by ``mesh``, the camera is
    reset with ``scene.set_camera()``, and the image width is chosen so the
    image aspect ratio follows the mesh's x/y extent.

    Args:
        scene: trimesh scene used for rendering; its geometry is replaced.
        mesh: Mesh to render; ``mesh.vertices`` must be an (n, >=2) array.
        path: Output file path, opened in binary write mode.
        height: Image height in pixels (default 1080, the previous fixed value).

    Returns:
        The value returned by ``scene.save_image`` (the data written to ``path``).
    """
    scene.geometry.clear()
    scene.add_geometry(mesh)

    scene.set_camera()

    x_extent = mesh.vertices[:, 0].max() - mesh.vertices[:, 0].min()
    y_extent = mesh.vertices[:, 1].max() - mesh.vertices[:, 1].min()

    # A degenerate (flat) mesh would give a 0, inf or NaN ratio and break the
    # resolution computation; fall back to a square image in that case.
    if x_extent > 0 and y_extent > 0:
        proportion = x_extent / y_extent
    else:
        proportion = 1.0

    width = max(1, int(proportion * height))

    # save_image needs a display; provide a virtual one for headless machines
    with pyvirtualdisplay.Display(visible=False, size=(1920, 1080)):
        png = scene.save_image(resolution=(width, height), visible=True)

    with open(path, "wb") as f:
        f.write(png)

    return png

In [None]:
import trimesh

scene = trimesh.Scene()

dataset_name = 'SMAL_nocat_pair'

single_dataset, pair_dataset = data_loading.get_val_dataset(
    dataset_name, 'test', 128, preload=False, return_evecs=True, centering='bbox'
)

# Saved pairwise evaluation results, one JSON file per supported dataset.
PAIRWISE_RESULT_FILES = {
    'SHREC19_r_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_64_1-2ev_64-128-128_remeshed_fixed/eval/epoch_99/SHREC19_r_pair-test/no_smoothing/2024-11-04_22-27-59/pairwise_results.json',
    'DT4D_intra_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_template_remeshed/eval/checkpoint_99.pt/DT4D_intra_pair-test/no_smoothing/2024-11-10_21-20-05/pairwise_results.json',
    'DT4D_inter_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_template_remeshed/eval/checkpoint_99.pt/DT4D_inter_pair-test/no_smoothing/2024-11-10_21-20-05/pairwise_results.json',
    'FAUST_r_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_64_1-2ev_64-128-128_remeshed_fixed/eval/epoch_99/FAUST_r_pair-test/no_smoothing/2024-11-04_22-27-59/pairwise_results.json',
    'SCAPE_r_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_64_1-2ev_64-128-128_remeshed_fixed/eval/epoch_99/SCAPE_r_pair-test/no_smoothing/2024-11-04_22-27-59/pairwise_results.json',
    'SMAL_nocat_pair': '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_64_SMAL_nocat_64_SMAL_isoRemesh_0.2_0.8_nocat_1-2ev_64k/eval/epoch_99/SMAL_nocat_pair-test/no_smoothing/2025-01-24_16-01-31/pairwise_results.json',
}

# Fail loudly for an unknown dataset instead of leaving `file_name` undefined
# (or silently reusing a stale value from an earlier run of this cell).
if dataset_name not in PAIRWISE_RESULT_FILES:
    raise ValueError(
        f"No pairwise results file configured for dataset '{dataset_name}'; "
        f"known datasets: {sorted(PAIRWISE_RESULT_FILES)}"
    )
file_name = PAIRWISE_RESULT_FILES[dataset_name]

with open(file_name, 'r') as f:
    p2p_saved = json.load(f)


# Per-pair geodesic error, and the pair indices sorted from worst to best.
geo_err_list = torch.tensor([entry['geo_err_median_pairzo'] for entry in p2p_saved])
idxs_geo_err = torch.argsort(geo_err_list, descending=True)


base_path = f'/lustre/mlnvme/data/s94zalek_hpc-shape_matching/figures/p2p/{dataset_name}'

os.makedirs(f"{base_path}/single", exist_ok=True)
os.makedirs(f"{base_path}/combined", exist_ok=True)

cmap = get_cmap()

# NOTE: unseeded, so this selection of up to 200 pairs differs between runs.
random_order = torch.randperm(len(idxs_geo_err))[:200]

        

In [None]:


def get_colored_meshes(verts_x, faces_x, verts_y, faces_y, p2p, dataset_name, axes_color_gradient=(0, 1),
                 base_cmap='jet', rotations_x=None, rotations_y=None):
    """Build two vertex-colored meshes that visualize a point-to-point map.

    Mesh X is colored by a gradient over its (normalized) vertex coordinates
    along ``axes_color_gradient``; mesh Y receives, for each of its vertices
    ``i``, the color of vertex ``p2p[i]`` on X.

    Args:
        verts_x, faces_x: Vertices (torch tensor, (n_x, 3)) and faces of shape X.
        verts_y, faces_y: Vertices (torch tensor, (n_y, 3)) and faces of shape Y.
        p2p: Integer indices into X's vertices, one per vertex of Y.
        dataset_name: If it contains 'SMAL', the axes are remapped as
            (x, y, z) -> (z, -y, x) before anything else.
        axes_color_gradient: List/tuple of coordinate axes whose normalized
            values are summed to form the color gradient.
        base_cmap: Either a colormap name (str) handed to
            ``trimesh.visual.color.interpolate`` or a colormap callable
            handed to ``interpolate_colors``.
        rotations_x, rotations_y: Optional sequences of ``(angle, axis)``
            rotations about the origin, applied in order to X / Y before the
            mirroring step. ``None`` keeps the previous hard-coded view, which
            the original code labelled as tuned for ``indx = 279``.
            Other hand-tuned views that were used before:
              indx 38: X [(3*pi/8, y)],                 Y [(-2*pi/8, y)]
              indx 86: X [(3*pi/8, y), (-0.5*pi/8, x)], Y [(-2.5*pi/8, y)]

    Returns:
        ``(mesh1, mesh2)``: the colored trimesh meshes for X and Y.
    """
    assert isinstance(axes_color_gradient, (list, tuple)), "axes_color_gradient must be a list or tuple"
    assert verts_y.shape[0] == len(p2p), f"verts_y {verts_y.shape} and p2p {p2p.shape} must have the same length"

    if rotations_x is None:
        rotations_x = [(6 * np.pi / 8, [0, 1, 0])]
    if rotations_y is None:
        rotations_y = [(-0.8 * np.pi / 8, [0, 1, 0]), (-0.4 * np.pi / 8, [0, 0, 1])]

    if 'SMAL' in dataset_name:
        # Remap axes into NEW tensors; writing into the inputs would modify the
        # caller's data and swap the axes again on every repeated call.
        verts_x = torch.stack((verts_x[:, 2], -verts_x[:, 1], verts_x[:, 0]), dim=1)
        verts_y = torch.stack((verts_y[:, 2], -verts_y[:, 1], verts_y[:, 0]), dim=1)

    ##################################################
    # rotate to coordinate axes
    ##################################################

    mesh1 = trimesh.Trimesh(vertices=verts_x, faces=faces_x, process=False, validate=False)
    mesh2 = trimesh.Trimesh(vertices=verts_y, faces=faces_y, process=False, validate=False)

    origin = [0, 0, 0]
    for angle, axis in rotations_x:
        mesh1.apply_transform(trimesh.transformations.rotation_matrix(angle, axis, origin))
    for angle, axis in rotations_y:
        mesh2.apply_transform(trimesh.transformations.rotation_matrix(angle, axis, origin))

    # mirror both shapes across the plane through the origin with normal (1, 0, 0)
    mesh1.apply_transform(trimesh.transformations.reflection_matrix((0, 0, 0), (1, 0, 0)))
    mesh2.apply_transform(trimesh.transformations.reflection_matrix((0, 0, 0), (1, 0, 0)))

    verts_x = torch.tensor(mesh1.vertices, dtype=torch.float32)
    verts_y = torch.tensor(mesh2.vertices, dtype=torch.float32)

    ##################################################
    # color gradient
    ##################################################

    # min-max normalize every coordinate axis of X to [0, 1]
    mins = verts_x.min(dim=0).values
    extent = verts_x.max(dim=0).values - mins
    # a flat axis would divide by zero; dividing by 1 leaves it at 0 instead of NaN
    extent = torch.where(extent > 0, extent, torch.ones_like(extent))
    coords_x_norm = (verts_x - mins) / extent

    coords_interpolated = torch.zeros(verts_x.shape[0])
    for axis in axes_color_gradient:
        coords_interpolated += coords_x_norm[:, axis]

    if isinstance(base_cmap, str):
        colors_x = trimesh.visual.color.interpolate(coords_interpolated, base_cmap)
    else:
        colors_x = interpolate_colors(coords_interpolated, base_cmap)

    # index the NumPy color array with a NumPy index array rather than a tensor
    p2p_index = torch.as_tensor(p2p).cpu().numpy()
    colors_y = colors_x[p2p_index].clip(0, 255)

    ##################################################
    # add the meshes
    ##################################################

    # 1
    mesh1 = trimesh.Trimesh(vertices=verts_x, faces=faces_x, process=False, validate=False)
    mesh1.visual.vertex_colors = colors_x[:len(mesh1.vertices)].clip(0, 255)

    # 2
    mesh2 = trimesh.Trimesh(vertices=verts_y, faces=faces_y, process=False, validate=False)
    mesh2.visual.vertex_colors = colors_y[:len(mesh2.vertices)]

    # the reflection above flips the face orientation; let trimesh repair it
    trimesh.repair.fix_inversion(mesh1)
    trimesh.repair.fix_inversion(mesh2)

    return mesh1, mesh2
    

In [None]:
# Pair indices that were listed in this cell (their purpose is not recorded):
# 153, 154, 164, 170, 171, 181

indx = 280

data_i = pair_dataset[indx]
p2p_i = p2p_saved[indx]
p2p_pairzo = torch.tensor(p2p_i['p2p_median_pairzo'])

# geodesic error stored for the selected pair
print(p2p_i['geo_err_median_pairzo'])

scene.geometry.clear()

shape_first = data_i['first']
shape_second = data_i['second']

mesh1, mesh2 = get_colored_meshes(
    verts_x=shape_first['verts'],
    faces_x=shape_first['faces'],
    verts_y=shape_second['verts'],
    faces_y=shape_second['faces'],
    p2p=p2p_pairzo,
    dataset_name=dataset_name,
    axes_color_gradient=[0, 1],
    base_cmap=cmap,
)

# Shift the second mesh by one unit along x; only this mesh is added to the
# scene (the first mesh is currently not shown).
shift_along_x = trimesh.transformations.translation_matrix([1, 0, 0])
mesh2.apply_transform(shift_along_x)

scene.add_geometry(mesh2)

scene.set_camera()
scene.show()


In [None]:

# Render the second mesh of the selected pair to disk.
# (To render the first mesh as well, call render_mesh(scene, mesh1, ...) with
# the file name f"{indx}_1.png" in the same directory.)
output_dir = "/home/s94zalek_hpc/shape_matching/notebooks/rebuttal/smal_qualitative"

# render_mesh opens the target file directly, so the directory must exist first
os.makedirs(output_dir, exist_ok=True)

png2 = render_mesh(scene, mesh2,
                   f"{output_dir}/{indx}_2_1.png")
        
    

In [None]:
def _shape_name(off_path):
    """Return the file name of ``off_path`` up to (not including) its first dot."""
    return off_path.rsplit("/", 1)[-1].partition(".")[0]


# Print every pair as "<pair index>: <first shape name> - <second shape name>".
for i, combination in enumerate(pair_dataset.combinations):
    first_name = _shape_name(single_dataset.off_files[combination[0]])
    second_name = _shape_name(single_dataset.off_files[combination[1]])
    print(f'{i}: {first_name} - {second_name}')