In [None]:
import trimesh
import numpy as np

# Single trimesh scene shared by the visualization cells below that use it;
# each of them clears `scene.geometry` before adding new meshes.
scene = trimesh.Scene()

In [None]:
import networks.diffusion_network as diffusion_network
from tqdm import tqdm
import my_code.sign_canonicalization.training as sign_training
import torch
import my_code.diffusion_training_sign_corr.data_loading as data_loading
import yaml

# Experiment whose config and checkpoint are loaded; alternatives kept for quick switching.
exp_name = 'signNet_remeshed_4b_mass_10_0.2_0.8' 
# exp_name = 'signNet_isoRemesh_targetlen_3'
# exp_name = 'signNet_anisRemesh' 
# exp_name = 'signNet_FAUST_a'

# NOTE(review): absolute, user-specific path — consider making it configurable.
exp_dir = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/sign_net/{exp_name}'

with open(f'{exp_dir}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)
    
start_dim = config['start_dim']

feature_dim = config['feature_dim']
evecs_per_support = config['evecs_per_support']


device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = diffusion_network.DiffusionNet(
    **config['net_params']
    ).to(device)

input_type = config['net_params']['input_type']

# map_location lets a checkpoint that was saved on a GPU be loaded on a CPU-only machine.
net.load_state_dict(torch.load(f'{exp_dir}/{config["n_iter"]}.pth', map_location=device))
# This notebook only runs inference: switch layers such as dropout / batch-norm
# (if the network has any) to evaluation behaviour.
net.eval()

test_dataset = data_loading.get_val_dataset(
    'SHREC19', 'test', 128, canonicalize_fmap=None, preload=False, return_evecs=True
    )[0]     
# test_dataset = data_loading.get_val_dataset(
#     'FAUST_orig', 'test', 128, canonicalize_fmap=None, preload=False, return_evecs=True
#     )[0]  

In [None]:
import my_code.sign_canonicalization.remesh as remesh
import torch
import my_code.datasets.preprocessing as preprocessing

# Sign-correspondence evaluation: for each test shape, remesh + smooth it, flip the
# selected eigenvectors with two independent random sign patterns, predict the signs
# with the network and count eigenvectors whose predicted *relative* sign is wrong.

# Set to a dataset index to evaluate only that shape (its results stay in the globals
# train_shape / supp_vec_0 / sign_correct used by the cells below); None = all shapes.
debug_idx = 19

tqdm._instances.clear()  # drop stale progress bars left over from previous runs of this cell

with_mass = config['with_mass']
n_epochs = 1
shape_indices = [debug_idx] if debug_idx is not None else list(range(len(test_dataset)))
iterator = tqdm(total=len(shape_indices) * n_epochs)
# number of wrongly predicted relative signs, one entry per evaluated shape
incorrect_signs_list = torch.tensor([], dtype=torch.long)


def _random_signs(dim):
    """Return a (1, dim) float tensor on `device` with entries drawn uniformly from {-1, +1}."""
    signs = torch.randint(0, 2, (dim,)).float().to(device)
    signs[signs == 0] = -1
    return signs.unsqueeze(0)


def _predict_signs(shape, verts, faces, evecs_flipped, mass_mat):
    """Run the sign network on `evecs_flipped` (no grad); returns predict_sign_change's 3-tuple."""
    # NOTE(review): the spectral operators below stay on the CPU while verts/faces/evecs
    # are on `device` — presumably predict_sign_change handles the transfer; confirm.
    with torch.no_grad():
        return sign_training.predict_sign_change(
            net, verts, faces, evecs_flipped,
            mass_mat=mass_mat, input_type=net.input_type,
            mass=shape['mass'].unsqueeze(0), L=shape['L'].unsqueeze(0),
            evals=shape['evals'].unsqueeze(0), evecs=shape['evecs'].unsqueeze(0),
            gradX=shape['gradX'].unsqueeze(0), gradY=shape['gradY'].unsqueeze(0),
            )


for _ in range(n_epochs):
    for curr_idx in shape_indices:

        ##############################################
        # Select a shape, remesh and smooth it
        ##############################################

        train_shape_orig = test_dataset[curr_idx]

        verts_orig = train_shape_orig['verts']
        faces_orig = train_shape_orig['faces']

        verts, faces = remesh.remesh_simplify_iso(
            verts_orig,
            faces_orig,
            n_remesh_iters=10,
            remesh_targetlen=1,
            simplify_strength=1,
        )

        mesh_anis_remeshed = trimesh.Trimesh(verts, faces)
        # Taubin smoothing, applied to mesh_anis_remeshed in place
        trimesh.smoothing.filter_taubin(mesh_anis_remeshed, lamb=0.5, iterations=5)

        train_shape = {
            'verts': torch.tensor(mesh_anis_remeshed.vertices).float(),
            'faces': torch.tensor(mesh_anis_remeshed.faces).int(),
        }
        train_shape = preprocessing.get_spectral_ops(train_shape, num_evecs=128, cache_dir=None)

        verts = train_shape['verts'].unsqueeze(0).to(device)
        faces = train_shape['faces'].unsqueeze(0).to(device)

        # eigenvectors [start_dim, start_dim + feature_dim) are the ones whose signs are predicted
        evecs_orig = train_shape['evecs'].unsqueeze(0)[:, :, start_dim:start_dim+feature_dim].to(device)

        if with_mass:
            mass_mat = torch.diag_embed(train_shape['mass'].unsqueeze(0)).to(device)
        else:
            mass_mat = None

        ##############################################
        # Two independent random sign flips of the same eigenvectors
        ##############################################

        sign_gt_0 = _random_signs(feature_dim)
        evecs_flip_0 = evecs_orig * sign_gt_0
        sign_pred_0, supp_vec_0, _ = _predict_signs(train_shape, verts, faces, evecs_flip_0, mass_mat)

        sign_gt_1 = _random_signs(feature_dim)
        evecs_flip_1 = evecs_orig * sign_gt_1
        sign_pred_1, supp_vec_1, _ = _predict_signs(train_shape, verts, faces, evecs_flip_1, mass_mat)

        ##############################################
        # Compare predicted and ground-truth relative signs
        ##############################################

        sign_diff_gt = sign_gt_1 * sign_gt_0
        sign_diff_pred = sign_pred_1 * sign_pred_0

        # negative entries mark eigenvectors whose relative sign was predicted wrongly
        sign_correct = sign_diff_pred.sign() * sign_diff_gt.sign()
        count_incorrect_signs = (sign_correct < 0).int().sum()

        # .item() moves the (possibly CUDA) scalar to a Python int before collecting it on the CPU
        incorrect_signs_list = torch.cat([incorrect_signs_list, torch.tensor([count_incorrect_signs.item()])])

        iterator.set_description(f'Mean incorrect signs {incorrect_signs_list.float().mean():.2f} / {feature_dim}, max {incorrect_signs_list.max()}')
        iterator.update(1)

iterator.close()


In [None]:
# Per-eigenvector check for the last evaluated shape: 1 = relative sign predicted
# correctly, -1 = wrong, 0 = a prediction was exactly zero.
print(sign_correct)
# Positions (within the evaluated eigenvector slice) of the wrong signs; assumes batch size 1.
print('incorrect indices:', (sign_correct.flatten() < 0).nonzero().flatten().tolist())

In [None]:
import trimesh.visual

# Visualize the support vectors predicted for shape 0: one mesh copy per support
# vector (shifted along x); vertices whose |support| exceeds the threshold are blue.
SUPPORT_THRESHOLD = 0.015

scene.geometry.clear()

verts = train_shape['verts'].cpu().numpy()
faces = train_shape['faces'].cpu().numpy()

# supp_vec_0 must belong to the same mesh as train_shape (it does not after the cell
# that reloads test_dataset[19]); fail with a clear message instead of an IndexError.
assert supp_vec_0.shape[1] == verts.shape[0], (
    f'supp_vec_0 has {supp_vec_0.shape[1]} vertices, train_shape has {verts.shape[0]}')

# presumably evecs_per_support consecutive columns per support vector; the first of each group is shown
n_supports = supp_vec_0.shape[-1] // evecs_per_support
for i in range(n_supports):
    support = supp_vec_0[0, :, i * evecs_per_support].cpu().numpy()

    # opaque white RGBA by default; zero red/green (-> blue) where the support is large
    cmap = np.full((verts.shape[0], 4), 255, dtype=np.uint8)
    cmap[np.abs(support) > SUPPORT_THRESHOLD, :2] = 0

    mesh = trimesh.Trimesh(vertices=verts + np.array([i, 0, 0]), faces=faces, vertex_colors=cmap)
    scene.add_geometry(mesh)

scene.show()

In [None]:
scene.geometry.clear()

verts = train_shape['verts'].cpu().numpy()
faces = train_shape['faces'].cpu().numpy()

# Same support-vector view as the previous cell, but the RGBA colours are passed to
# trimesh as floats in [0, 1] instead of being scaled to 0-255.
num_groups = supp_vec_0.shape[-1] // evecs_per_support
for group in range(num_groups):
    column = supp_vec_0[0, :, group * evecs_per_support]
    is_large = column.cpu().abs() > 0.015

    cmap = np.ones((verts.shape[0], 4))
    cmap[is_large, :2] = 0  # red/green off -> blue where the support is large

    shifted_verts = verts + np.array([group, 0, 0])
    mesh = trimesh.Trimesh(vertices=shifted_verts, faces=faces, vertex_colors=cmap)
    scene.add_geometry(mesh)

scene.show()

In [None]:
# Switch to the original (not remeshed) dataset shape 19 for the descriptor cells below.
# NOTE: this overwrites the remeshed `train_shape` produced by the evaluation loop.
train_shape = test_dataset[19]

In [None]:
import utils.geometry_util as geometry_util

# Add a leading batch dimension of 1 to the spectral quantities of train_shape,
# compute the WKS / HKS descriptors, then take batch element 0.
evals_b = train_shape['evals'].unsqueeze(0)
evecs_b = train_shape['evecs'].unsqueeze(0)
mass_b = train_shape['mass'].unsqueeze(0)

wks_orig = geometry_util.compute_wks_autoscale(evals_b, evecs_b, mass_b)[0]
hks_orig = geometry_util.compute_hks_autoscale(evals_b, evecs_b, 16)[0]

In [None]:
import matplotlib.pyplot as plt

# Profile of one HKS channel (column 6) over all vertices of train_shape.
fig, ax = plt.subplots()
ax.plot(hks_orig[:, 6].cpu().numpy())
ax.set(title='HKS channel 6', xlabel='vertex index', ylabel='HKS value')
plt.show()

In [None]:
scene.geometry.clear()

verts = train_shape['verts'].cpu().numpy()
faces = train_shape['faces'].cpu().numpy()

# HKS channels 1..9, one mesh copy per channel laid out along x, coloured blue-white-red.
for offset, channel in enumerate(range(1, 10)):
    shifted_verts = verts + np.array([offset, 0, 0])
    mesh = trimesh.Trimesh(vertices=shifted_verts, faces=faces)
    channel_values = hks_orig[:, channel].cpu().numpy()
    mesh.visual.vertex_colors = trimesh.visual.color.interpolate(channel_values, 'bwr')
    scene.add_geometry(mesh)

scene.show()

In [None]:
scene.geometry.clear()

verts = train_shape['verts'].cpu().numpy()
faces = train_shape['faces'].cpu().numpy()

# Every 10th WKS channel (1, 11, ..., 81), one mesh copy per channel laid out along x.
for column, channel in enumerate(range(1, 82, 10)):
    offset = np.array([column, 0, 0])
    mesh = trimesh.Trimesh(vertices=verts + offset, faces=faces)
    mesh.visual.vertex_colors = trimesh.visual.color.interpolate(
        wks_orig[:, channel].cpu().numpy(), 'bwr')
    scene.add_geometry(mesh)

scene.show()

In [None]:
import sys

# Make the local pyFM fork importable. The membership check keeps repeated runs of
# this cell from appending the same entry to sys.path again.
# NOTE(review): absolute, user-specific path.
PYFM_FORK_DIR = '/home/s94zalek_hpc/shape_matching/pyFM_fork'
if PYFM_FORK_DIR not in sys.path:
    sys.path.append(PYFM_FORK_DIR)

In [None]:
# Load and display SHREC19_r shape 20.off.
# NOTE(review): the dataset cells above inspect test_dataset[19]; confirm whether that
# index corresponds to this (presumably 1-based) file name or to 19.off.
mesh = trimesh.load('/home/s94zalek_hpc/shape_matching/data/SHREC19_r/off/20.off')
mesh.show()

In [None]:
from pyFM.mesh import TriMesh

# Wrap the loaded mesh's vertices/faces in pyFM's TriMesh (processed and used for WKS below).
mesh_pyfm = TriMesh(mesh.vertices, mesh.faces)

In [None]:
from pyFM.signatures.WKS_functions import mesh_WKS

# pyFM reference: 128 eigenpairs (intrinsic=False) and a WKS with 64 energies.
N_EIGS_PYFM = 128
N_WKS_PYFM = 64

mesh_pyfm.process(k=N_EIGS_PYFM, intrinsic=False, verbose=True)
wks_pyfm = mesh_WKS(mesh_pyfm, N_WKS_PYFM)

In [None]:
# Keep pyFM's eigen-decomposition for the comparison with the robust Laplacian below.
evecs_pyfm, evals_pyfm = mesh_pyfm.eigenvectors, mesh_pyfm.eigenvalues

In [None]:
scene.geometry.clear()

# Geometry of the mesh loaded from the .off file (the one wrapped by pyFM above).
verts = np.array(mesh.vertices)
faces = np.array(mesh.faces)

# pyFM WKS channels 1..9, one mesh copy per channel laid out along x.
# The display meshes use their own name so the loaded `mesh` is not replaced by a
# shifted copy (which would corrupt re-runs of cells that read `mesh`).
for i, idx in enumerate(range(1, 10, 1)):
    mesh_i = trimesh.Trimesh(vertices=verts + np.array([i, 0, 0]), faces=faces)
    cmap = trimesh.visual.color.interpolate(wks_pyfm[:, idx], 'bwr')
    mesh_i.visual.vertex_colors = cmap
    scene.add_geometry(mesh_i)

scene.show()

# Compare the cotan Laplacian with the robust Laplacian

In [None]:
import robust_laplacian
import potpourri3d as pp3d
import scipy.sparse.linalg as sla
import scipy.sparse

# Eigen-decomposition of the robust Laplacian of (verts, faces), solved with
# shift-invert eigsh around sigma=eps against the lumped mass matrix.
eps = 1e-8
n_eigs = 128  # same number of eigenpairs as requested from pyFM

# Alternative operator: cotan Laplacian with lumped vertex areas (potpourri3d).
# L = pp3d.cotan_laplacian(verts, faces, denom_eps=1e-10)
# massvec_np = pp3d.vertex_areas(verts, faces)
# massvec_np += eps * np.mean(massvec_np)

L, M = robust_laplacian.mesh_laplacian(verts, faces)
massvec_np = M.diagonal()

# Small diagonal shift, presumably so the (otherwise singular) Laplacian is well-posed
# for the shift-invert factorization.
L_eigsh = (L + eps * scipy.sparse.identity(L.shape[0])).tocsc()
Mmat = scipy.sparse.diags(massvec_np)

evals_np, evecs_np = sla.eigsh(L_eigsh, k=n_eigs, M=Mmat, sigma=eps)
# Clip off any eigenvalues that end up slightly negative due to numerical error
evals_np = np.clip(evals_np, a_min=0., a_max=float('inf'))

In [None]:
# Sanity checks before the elementwise comparison with pyFM's basis further down:
# one eigenvector entry per vertex, and the same (vertices x eigenpairs) layout as pyFM.
assert evecs_np.shape[0] == verts.shape[0], (evecs_np.shape, verts.shape)
assert evecs_np.shape == evecs_pyfm.shape, (evecs_np.shape, evecs_pyfm.shape)
verts.shape, evecs_np.shape

In [None]:
import utils.geometry_util as geometry_util

# WKS from the robust-Laplacian eigenpairs (no mass vector is passed to auto_wks).
# NOTE(review): pyFM's WKS above used 64 energies while 128 is passed here; if both are
# energy counts, channel indices of the two visualizations do not correspond — confirm.
evals_t = torch.tensor(evals_np)
evecs_t = torch.tensor(evecs_np)
wks_lapl = geometry_util.auto_wks(evals_t, evecs_t, 128)


scene.geometry.clear()

# WKS channels 1..9, one mesh copy per channel laid out along x.
for shift, channel in enumerate(range(1, 10)):
    mesh_i = trimesh.Trimesh(vertices=verts + np.array([shift, 0, 0]), faces=faces)
    mesh_i.visual.vertex_colors = trimesh.visual.color.interpolate(wks_lapl[:, channel], 'bwr')
    scene.add_geometry(mesh_i)

scene.show()

In [None]:
scene.geometry.clear()

# Eigenvectors 5..9 side by side: pyFM basis at x=0, robust-Laplacian basis at x=1,
# one row per eigenvector (shifted along -y). Eigenvector signs are arbitrary, so a
# pair may appear with inverted colours.
# Display meshes get their own names so the pyFM TriMesh `mesh_pyfm` is not overwritten.
for i in range(5, 10):
    vis_pyfm = trimesh.Trimesh(vertices=verts + np.array([0, -i, 0]), faces=faces)
    vis_pyfm.visual.vertex_colors = trimesh.visual.color.interpolate(evecs_pyfm[:, i], 'bwr')
    scene.add_geometry(vis_pyfm)

    vis_lapl = trimesh.Trimesh(vertices=verts + np.array([1, -i, 0]), faces=faces)
    vis_lapl.visual.vertex_colors = trimesh.visual.color.interpolate(evecs_np[:, i], 'bwr')
    scene.add_geometry(vis_lapl)

scene.show()

In [None]:
# Per-eigenvector discrepancy between the pyFM and robust-Laplacian bases.
# The inner abs removes the arbitrary per-eigenvector sign; the outer abs keeps positive
# and negative per-vertex differences from cancelling each other in the sum.
# NOTE(review): eigenvectors of (near-)repeated eigenvalues may be mixed or reordered
# between the two solvers, and the two mass matrices may differ, so some mismatch is expected.
evecs_abs_diff = np.sum(np.abs(np.abs(evecs_pyfm) - np.abs(evecs_np)), axis=0)

fig, ax = plt.subplots()
ax.plot(evecs_abs_diff)
ax.set(title='pyFM vs robust Laplacian eigenvectors',
       xlabel='eigenvector index', ylabel='sum over vertices of ||v_pyfm| - |v_robust||')
plt.show()
print(evecs_abs_diff)