In [None]:
import trimesh

# Shared scene; the plotting cells below clear it and add meshes to it via plot_p2p_map.
scene = trimesh.Scene()

In [None]:
import numpy as np
import trimesh.scene
import trimesh.scene.lighting

def interpolate_colors(values, cmap, dtype=np.uint8):
    """Map scalar values to RGBA colors using a colormap callable.

    Counterpart of ``trimesh.visual.color.interpolate`` for the case where the
    colormap is an object (e.g. a ``matplotlib.colors.ListedColormap``) rather
    than a colormap name.

    Args:
        values: array-like of scalars; it is flattened before use.
        cmap: callable mapping an array of floats in [0, 1] to colors.
        dtype: dtype forwarded to ``trimesh.visual.color.to_rgba``.

    Returns:
        The colors produced by ``cmap``, converted by ``trimesh.visual.color.to_rgba``.
    """
    # make input always float
    values = np.asanyarray(values, dtype=np.float64).ravel()

    # np.ptp (function) instead of ndarray.ptp: the method was removed in NumPy 2.0
    value_range = np.ptp(values)

    # scale values to 0.0 - 1.0; a constant input has zero range and would
    # otherwise produce NaN (0 / 0), so map it to 0.0 instead
    if value_range > 0:
        scaled = (values - values.min()) / value_range
    else:
        scaled = np.zeros_like(values)

    colors = cmap(scaled)
    # convert to 0-255 RGBA
    rgba = trimesh.visual.color.to_rgba(colors, dtype=dtype)

    return rgba


def _as_numpy(x):
    """Return ``x`` as a numpy array; torch tensors are detached and moved to CPU first."""
    if hasattr(x, 'detach'):
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def plot_p2p_map(scene, verts_x, faces_x, verts_y, faces_y, p2p, axes_color_gradient=(0, 1),
                 base_cmap='jet'):
    """Add two meshes to ``scene``, coloring the second one through a point-to-point map.

    Shape X is colored by a gradient over the selected coordinate axes. Vertex ``j`` of
    shape Y gets the color of X-vertex ``p2p[j]``, so matching regions share colors.
    Shape X is shifted by +1 along the first coordinate so the two shapes appear side
    by side. Both meshes get 3 iterations of Taubin smoothing before being added.

    Args:
        scene: ``trimesh.Scene`` to add the meshes to. Its ``lights`` list is replaced.
        verts_x, faces_x: vertices (N_x, 3) and faces of shape X (torch tensors or arrays).
        verts_y, faces_y: vertices (N_y, 3) and faces of shape Y (torch tensors or arrays).
        p2p: N_y integer indices into ``verts_x`` (tensor, array or list).
        axes_color_gradient: list/tuple of coordinate axes (0, 1, 2) whose min-max
            normalized values are summed to form the scalar field mapped to colors.
        base_cmap: colormap name passed to ``trimesh.visual.color.interpolate``, or a
            colormap callable used through ``interpolate_colors``.

    Returns:
        The same ``scene``, with the two meshes added.
    """
    # assert axes_color_gradient is a list or tuple
    assert isinstance(axes_color_gradient, (list, tuple)), "axes_color_gradient must be a list or tuple"

    # work in numpy so the function does not depend on torch being imported elsewhere
    verts_x = _as_numpy(verts_x).astype(np.float64)
    verts_y = _as_numpy(verts_y).astype(np.float64)
    faces_x = _as_numpy(faces_x)
    faces_y = _as_numpy(faces_y)
    p2p = _as_numpy(p2p).astype(np.int64)

    assert verts_y.shape[0] == len(p2p), f"verts_y {verts_y.shape} and p2p {p2p.shape} must have the same length"

    ##################################################
    # color gradient
    ##################################################

    # per-axis min-max normalization of shape X to [0, 1]; a flat axis (zero extent)
    # would divide by zero, so its extent is replaced by 1 (normalized values stay 0)
    coords_min = verts_x.min(axis=0)
    extent = verts_x.max(axis=0) - coords_min
    extent = np.where(extent > 0, extent, 1.0)
    coords_x_norm = (verts_x - coords_min) / extent

    # scalar field = sum of the selected normalized axes (repeated axes count repeatedly)
    coords_interpolated = coords_x_norm[:, list(axes_color_gradient)].sum(axis=1)

    if isinstance(base_cmap, str):
        colors_x = trimesh.visual.color.interpolate(coords_interpolated, base_cmap)
    else:
        colors_x = interpolate_colors(coords_interpolated, base_cmap)
    colors_x = np.asarray(colors_x).clip(0, 255)

    # transfer colors from X to Y through the map: Y-vertex j takes X-vertex p2p[j]'s color
    colors_y = colors_x[p2p]

    ##################################################
    # material
    ##################################################

    # diffuse material
    material = trimesh.visual.material.SimpleMaterial(
        image=None,
        diffuse=[245] * 4,
        smooth=True
    )

    ##################################################
    # Lights
    ##################################################

    # replace the scene lights with a single white directional light
    scene.lights = [
        trimesh.scene.lighting.DirectionalLight(
            color=[1.0, 1.0, 1.0]
        ),
    ]

    ##################################################
    # add the meshes
    ##################################################

    # Vertex colors are passed to the constructor so trimesh updates them together
    # with the vertices during processing/validation. Assigning them afterwards by
    # position would misalign colors if any vertices were merged or removed.

    # 1: shape X, shifted along the first axis so it does not overlap shape Y
    mesh1 = trimesh.Trimesh(vertices=verts_x + np.array([1, 0, 0]), faces=faces_x,
                            vertex_colors=colors_x, validate=True)
    mesh1.visual.material = material

    # 2: shape Y, colored through the p2p map
    mesh2 = trimesh.Trimesh(vertices=verts_y, faces=faces_y,
                            vertex_colors=colors_y, validate=True)
    mesh2.visual.material = material

    trimesh.smoothing.filter_taubin(mesh1, iterations=3)
    trimesh.smoothing.filter_taubin(mesh2, iterations=3)

    scene.add_geometry(mesh1)
    scene.add_geometry(mesh2)

    return scene

In [None]:
import my_code.diffusion_training_sign_corr.data_loading as data_loading
import yaml
import json
from tqdm import tqdm



# Load the DT4D intra-class pair validation data, 'test' split, through the project
# loader. It returns a per-shape dataset and a pair dataset; the pair dataset's items
# have 'first' / 'second' entries, as used in the cells below.
# NOTE(review): 128 is passed positionally — presumably the number of eigenvectors; confirm.
single_dataset, pair_dataset = data_loading.get_val_dataset(
    'DT4D_intra_pair',
    'test',
    128,
    preload=False,
    return_evecs=True,
    centering='bbox',
)

In [None]:
# Quick inspection: id of the second shape of pair 252.
pair_dataset[252]['second']['id']

In [None]:
# Quick inspection: mesh file of shape 29 in the single-shape dataset.
single_dataset.off_files[29]

In [None]:
# Load the saved per-pair evaluation results (p2p maps and geodesic errors) of the
# single_template_remeshed checkpoint on DT4D_intra_pair / test.
# NOTE(review): the cells below index p2p_saved with the same index as pair_dataset,
# which assumes the entries are stored in pair_dataset order — confirm.
PAIRWISE_RESULTS_PATH = '/lustre/mlnvme/data/s94zalek_hpc-shape_matching/ddpm_checkpoints/single_template_remeshed/eval/checkpoint_99.pt/DT4D_intra_pair-test/no_smoothing/2024-11-10_21-20-05/pairwise_results.json'

with open(PAIRWISE_RESULTS_PATH, 'r') as f:
    p2p_saved = json.load(f)

# Show only the number of entries: each entry holds full p2p maps, so echoing the
# whole list floods the notebook output.
len(p2p_saved)

In [None]:
# Fields stored for each saved pair result.
p2p_saved[0].keys()

In [None]:

# Show the 10 largest median geodesic errors. The ranking is computed here so this
# cell does not depend on the following cell having been run first.
import torch

geo_err_list = torch.tensor([p2p_saved[i]['geo_err_median_pairzo'] for i in range(len(p2p_saved))])
idxs_geo_err = torch.argsort(geo_err_list, descending=True)
print(geo_err_list[idxs_geo_err[:10]])

In [None]:
import torch

# Rank pairs by median geodesic error, worst first.
geo_err_list = torch.tensor([p2p_saved[i]['geo_err_median_pairzo'] for i in range(len(p2p_saved))])
idxs_geo_err = torch.argsort(geo_err_list, descending=True)

# Position in the worst-first ranking of the pair to inspect (0 = worst pair).
WORST_PAIR_RANK = 15

# Convert the 0-d tensor to a plain int so pair_dataset / p2p_saved are indexed with
# a Python int rather than relying on how their __getitem__ treats a tensor.
indx = int(idxs_geo_err[WORST_PAIR_RANK])
data_i = pair_dataset[indx]
p2p_i = p2p_saved[indx]
p2p_pairzo = torch.tensor(p2p_i['p2p_median_pairzo'])


mesh_1 = trimesh.Trimesh(data_i['first']['verts'], data_i['first']['faces'])
mesh_2 = trimesh.Trimesh(data_i['second']['verts'], data_i['second']['faces'])


print(p2p_i['geo_err_median_pairzo'])

In [None]:
import my_code.utils.plotting_utils as plotting_utils
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors


# Sample Plotly's Jet scale into a matplotlib colormap with SAMPLES entries.
SAMPLES = 200
ice = px.colors.sample_colorscale(
    # px.colors.cyclical.IceFire,
    # px.colors.cyclical.HSV,
    px.colors.sequential.Jet,
    SAMPLES)
rgb = [px.colors.unconvert_from_RGB_255(px.colors.unlabel_rgb(c)) for c in ice]

# NOTE: the colormap is named 'Ice' but is built from the Jet scale above.
cmap = mcolors.ListedColormap(rgb, name='Ice', N=SAMPLES)


# Remove previously plotted meshes. Scene.delete_geometry also removes the
# scene-graph nodes that reference them; clearing scene.geometry directly leaves
# those nodes behind, and they then point at newly added meshes that reuse the names.
scene.delete_geometry(list(scene.geometry.keys()))

plot_p2p_map(
    scene,
    data_i['first']['verts'], data_i['first']['faces'],
    data_i['second']['verts'], data_i['second']['faces'],
    p2p_pairzo,
    axes_color_gradient=[0, 1],
    base_cmap=cmap
)

# TODO: camera setup, unfinished (the center value is incomplete):
# scene.set_camera(
#     angles=[15, 0, 0],
#     distance=0.0,
#     center=[0, , 1.5],
#     resolution=[1920, 1080]
# )

scene.show()

In [None]:
import itertools

# All unordered pairs of indices from range(3): [(0, 1), (0, 2), (1, 2)]
# (presumably candidate values for axes_color_gradient).
list(itertools.combinations(range(3), 2))

In [None]:
import metrics.geodist_metric as geodist_metric
from utils.shape_util import compute_geodesic_distmat

# Geodesic distance matrix on the first shape of the current pair (data_i), wrapped
# in a torch tensor; it is used as dist_x in the geodesic-error cells below.
# NOTE(review): compute_geodesic_distmat is a project helper — presumably returns an
# (N, N) array over the first shape's vertices; confirm.
dist_x = torch.tensor(
    compute_geodesic_distmat(data_i['first']['verts'].numpy(), data_i['first']['faces'].numpy())    
)

In [None]:
# Geodesic error of p2p_pairzo measured with dist_x, using the 'corr' index arrays
# of both shapes (presumably ground-truth correspondences).
# NOTE(review): if the SCAPE p2p cell further down has been run, p2p_pairzo holds
# that map instead of the DT4D one selected above.
corr_first = data_i['first']['corr']
corr_second = data_i['second']['corr']

geo_err = geodist_metric.calculate_geodesic_error(
    dist_x, corr_first.cpu(), corr_second.cpu(), p2p_pairzo, return_mean=True
)
# scaled by 100 (presumably to a percentage — confirm against geodist_metric)
geo_err * 100

In [None]:
# Shape check of the geodesic distance matrix.
dist_x.shape

In [None]:
# Shape check of the p2p map.
p2p_pairzo.shape

In [None]:
# Inspect the correspondence index arrays of both shapes.
print(corr_first.shape, corr_second.shape)
print(corr_first, corr_second)

In [None]:
# Vertex array shapes of the two shapes in the current pair.
data_i['first']['verts'].shape, data_i['second']['verts'].shape

In [None]:
# Manual check: mean of dist_x between the vertex that p2p_pairzo assigns to each
# corr_second vertex and the matching corr_first vertex. This value is not multiplied
# by 100, and it may differ from geo_err above if calculate_geodesic_error normalizes.
dist_x[p2p_pairzo[corr_second], corr_first].mean()

In [None]:
# Load a .mat file from the Spatially-and-Spectrally-Consistent-Deep-Functional-Maps
# baseline results: results/C/C_tr_reg_080_tr_reg_081.mat, presumably the functional
# map C for the pair tr_reg_080 / tr_reg_081 (its 'C' entry is inspected in the next cell).


import scipy.io as sio

mat = sio.loadmat('/home/s94zalek_hpc/baselines/Spatially-and-Spectrally-Consistent-Deep-Functional-Maps/data/results/C/C_tr_reg_080_tr_reg_081.mat')
mat

In [None]:
# Shape of the 'C' entry of the loaded .mat file.
mat['C'].shape

In [None]:
# Load a p2p map from the same baseline's results (SCAPE_a/p2p_21/0_1.txt).
# NOTE(review): this overwrites p2p_pairzo from the DT4D pair selected earlier;
# re-running the DT4D plotting / geodesic-error cells after this one will use this
# SCAPE map instead.
# np.loadtxt returns floats, which are cast to int32 indices below. Whether the file
# is 0- or 1-based is not visible here — confirm before indexing with it.
import numpy as np

p2p_pairzo = np.loadtxt('/home/s94zalek_hpc/baselines/Spatially-and-Spectrally-Consistent-Deep-Functional-Maps/data/results/SCAPE_a/p2p_21/0_1.txt')
p2p_pairzo = torch.tensor(p2p_pairzo, dtype=torch.int32)

In [9]:
import torch

# p2p map stored under 'p2p_median_second' for entry 594 of p2p_saved. The next cell
# plots it for pair_dataset[594] with the template as verts_y of plot_p2p_map, so it
# must hold one index into the pair's second shape per template vertex.
p2p_median_second = torch.tensor(p2p_saved[594]['p2p_median_second'])

In [None]:
import my_code.utils.plotting_utils as plotting_utils
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

# NOTE: this reassigns data_i (previously the pair selected by geodesic-error rank);
# cells above that use data_i will see pair 594 if re-run after this one.
data_i = pair_dataset[594]

template = '/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/remeshed/template.off'

# process=False keeps the template's vertex order exactly as stored in the file.
template_mesh = trimesh.load(template, process=False)

# Sample Plotly's Jet scale into a matplotlib colormap with SAMPLES entries.
SAMPLES = 200
ice = px.colors.sample_colorscale(
    # px.colors.cyclical.IceFire,
    # px.colors.cyclical.HSV,
    px.colors.sequential.Jet,
    SAMPLES)
rgb = [px.colors.unconvert_from_RGB_255(px.colors.unlabel_rgb(c)) for c in ice]

# NOTE: the colormap is named 'Ice' but is built from the Jet scale above.
cmap = mcolors.ListedColormap(rgb, name='Ice', N=SAMPLES)


# Remove previously plotted meshes. Scene.delete_geometry also removes the
# scene-graph nodes that reference them; clearing scene.geometry directly leaves
# those nodes behind, and they then point at newly added meshes that reuse the names.
scene.delete_geometry(list(scene.geometry.keys()))

# Color the pair's second shape by position and transfer the colors to the
# template through p2p_median_second.
plot_p2p_map(
    scene,
    # data_i['first']['verts'], data_i['first']['faces'],
    data_i['second']['verts'], data_i['second']['faces'],
    torch.tensor(template_mesh.vertices), torch.tensor(template_mesh.faces),
    p2p_median_second,
    axes_color_gradient=[0, 1],
    base_cmap=cmap
)

scene.show()