In [2]:
import pyvirtualdisplay
import trimesh
import my_code.diffusion_training_sign_corr.data_loading as data_loading
import yaml
import json
from tqdm import tqdm
import torch
import numpy as np
import trimesh.scene
import trimesh.scene.lighting
import plotly.express as px
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import os
import networks.diffusion_network as diffusion_network
import yaml
import my_code.sign_canonicalization.training as sign_training


import PIL.Image
          
          
# Shared trimesh scene that the visualization cell fills with meshes and shows.
scene = trimesh.scene.Scene()
      

In [36]:
dataset_name = 'SMAL_cat_pair'

# Test split, loaded lazily (preload=False). Eigenvectors are requested because
# the sign-prediction cell slices data['evecs'].
# NOTE(review): the positional 128 is presumably the number of eigenvectors, and
# centering='bbox' presumably centers each shape on its bounding box — confirm
# against data_loading.get_val_dataset.
single_dataset, pair_dataset = data_loading.get_val_dataset(
    dataset_name,
    'test',
    128,
    preload=False,
    return_evecs=True,
    centering='bbox',
)

In [None]:
exp_name = 'signNet_32_SMAL_isoRemesh_0.2_0.8'
checkpoint_name = '39440.pth'

# Root of the sign-net experiment folders. Set SIGN_NET_EXP_ROOT to run this
# notebook outside the original cluster account.
exp_root = os.environ.get(
    'SIGN_NET_EXP_ROOT',
    '/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_net',
)
exp_dir = f'{exp_root}/{exp_name}'

# NOTE(review): FullLoader is kept (rather than safe_load) in case the saved
# config contains python tags such as tuples; only load configs you trust.
with open(f'{exp_dir}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

start_dim = config['start_dim']
feature_dim = config['feature_dim']
evecs_per_support = config['evecs_per_support']

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# The network is only used for inference (under torch.no_grad) in this notebook,
# so put any dropout / normalization layers into inference mode.
net = diffusion_network.DiffusionNet(**config['net_params']).to(device).eval()

# map_location: a checkpoint saved from a GPU run cannot be loaded on a CPU-only
# machine without remapping its tensors to the selected device.
net.load_state_dict(torch.load(f'{exp_dir}/{checkpoint_name}', map_location=device))


In [None]:





shape_idx = 10
data = single_dataset[shape_idx]

##############################################
# Set the variables
##############################################

# Reorient the shape: new (x, y, z) = old (z, -y, x).
# Built out-of-place on purpose: on CPU `.to(device)` is a no-op and `unsqueeze`
# returns a view, so writing into it in place would also rewrite data['verts']
# (and re-running the cell could flip the orientation back if that tensor is reused).
verts_raw = data['verts'].unsqueeze(0).to(device)
axis_sign = torch.tensor([1.0, -1.0, 1.0], dtype=verts_raw.dtype, device=verts_raw.device)
verts = verts_raw[:, :, [2, 1, 0]] * axis_sign

faces = data['faces'].unsqueeze(0).to(device)

# Eigenvector columns the sign net operates on: [start_dim, start_dim + feature_dim).
evecs_orig = data['evecs'].unsqueeze(0)[
    :, :, config['start_dim']:config['start_dim'] + config['feature_dim']
].to(device)

with_mass = bool(config.get('with_mass', False))

# Dense diagonal mass matrix built from data['mass'].
# NOTE(review): assumes data['mass'] is a 1-D per-vertex tensor — confirm.
mass_mat = torch.diag_embed(data['mass'].unsqueeze(0)).to(device) if with_mass else None

# predict the sign change
# NOTE(review): mass / L / evals / evecs / gradX / gradY are passed without
# `.to(device)`, exactly as before — presumably predict_sign_change moves them.
with torch.no_grad():
    sign_pred_0, supp_vec_0, prod_0 = sign_training.predict_sign_change(
        net, verts, faces, evecs_orig,
        mass_mat=mass_mat, input_type=net.input_type,
        evecs_per_support=config['evecs_per_support'],
        mass=data['mass'].unsqueeze(0), L=data['L'].unsqueeze(0),
        evals=data['evals'].unsqueeze(0), evecs=data['evecs'].unsqueeze(0),
        gradX=data['gradX'].unsqueeze(0), gradY=data['gradY'].unsqueeze(0)
    )

# L2-normalize every support vector (mass-weighted when the config uses mass).
# `supp_vec_norm` and `evecs_cond` are defined in both branches because the
# visualization cell reads `supp_vec_norm` unconditionally.
if with_mass:
    print('Using mass')
    supp_vec_weighted = supp_vec_0[0].transpose(0, 1) @ mass_mat[0]
else:
    # Identity mass matrix: plain L2 normalization of the support vectors.
    supp_vec_weighted = supp_vec_0[0].transpose(0, 1)

supp_vec_norm = torch.nn.functional.normalize(supp_vec_weighted, p=2, dim=1)
evecs_cond = supp_vec_norm @ evecs_orig[0]
supp_vec_norm = supp_vec_norm.transpose(0, 1).unsqueeze(0)
        

In [None]:
evec_id = 25


def zero_centered_colors(values, color_map='bwr'):
    """Per-vertex RGBA colors with 0 mapped to the middle of a diverging colormap.

    `trimesh.visual.color.interpolate` rescales its input by its min/max. Two
    sentinel values, -max|v| and +max|v|, are appended so that range is symmetric
    around zero, and their colors are dropped afterwards. The caller's data is
    never modified (the original approach overwrote the argmin/argmax entries).

    Args:
        values: 1-D array-like of per-vertex scalars.
        color_map: matplotlib colormap name.

    Returns:
        Array of RGBA colors, one row per input value.
    """
    values = np.asarray(values, dtype=np.float64).ravel()
    max_abs = float(np.abs(values).max()) if values.size else 0.0
    padded = np.concatenate([values, [-max_abs, max_abs]])
    return trimesh.visual.color.interpolate(padded, color_map)[:-2]


# Clone so nothing below can write back into supp_vec_norm / evecs_orig
# (on CPU, `.cpu()` returns the same storage rather than a copy).
supp_vec = supp_vec_norm[0, :, evec_id].detach().cpu().clone()
evec = evecs_orig[0, :, evec_id].detach().cpu().clone()

verts_np = verts[0].detach().cpu().numpy()
faces_np = faces[0].detach().cpu().numpy()
# Side-by-side placement of the second mesh.
# NOTE(review): assumes the bbox-centered shape is narrower than this offset.
mesh_offset = np.array([1.0, 0.0, 0.0])

# process=False: trimesh's default processing merges duplicate vertices and
# re-indexes faces, which would misalign the per-vertex color arrays below.
mesh1 = trimesh.Trimesh(verts_np, faces_np, process=False)
cmap1 = zero_centered_colors(supp_vec.numpy(), 'bwr')
cmap1_faces = trimesh.visual.color.vertex_to_face_color(cmap1, faces_np)
mesh1.visual.face_colors = cmap1_faces.clip(0, 255).astype(np.uint8)

mesh2 = trimesh.Trimesh(verts_np + mesh_offset, faces_np, process=False)
cmap2 = zero_centered_colors(evec.numpy(), 'bwr')
cmap2_faces = trimesh.visual.color.vertex_to_face_color(cmap2, faces_np)
mesh2.visual.face_colors = cmap2_faces.clip(0, 255).astype(np.uint8)

# Fresh scene on every run: clearing only the `scene.geometry` dict does not
# touch `scene.graph`, which would keep nodes for the removed meshes.
scene = trimesh.Scene()
scene.add_geometry(mesh1)
scene.add_geometry(mesh2)
scene.add_geometry(trimesh.creation.axis(axis_length=1))

scene.set_camera()
scene.show()
