In [62]:
import trimesh

# Shared trimesh scene: the plotting cell further down clears its geometry, draws the
# point-to-point map into it via plotting_utils.plot_p2p_map, and then shows it.
scene = trimesh.Scene()

In [63]:
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
from utils.shape_util import compute_geodesic_distmat
from my_code.diffusion_training_sign_corr.test.test_diffusion_cond import select_p2p_map_dirichlet, log_to_database, parse_args
import accelerate

# Empty tqdm's registry of progress-bar instances (private `_instances` attribute);
# presumably done to get rid of stale bars left behind by re-running cells.
tqdm._instances.clear()

In [64]:
def get_geo_error(
    p2p_first, p2p_second,
    evecs_first, evecs_second,
    corr_first, corr_second,
    num_evecs, apply_zoomout,
    dist_x,
    regularized,
    evecs_trans_first=None, evecs_trans_second=None,
    evals_first=None, evals_second=None,
    return_p2p=False, return_Cxy=False,
    A2=None,
    fmap_solver=None,
    ):
    """Fit a functional map from paired vertex indices and report its geodesic error.

    The map is estimated from rows ``p2p_first`` of the first eigenbasis and rows
    ``p2p_second`` of the second eigenbasis. Entry ``j`` of the two index tensors
    is assumed to be a matching pair. The map can optionally be refined with
    ZoomOut. It is then converted to a point-to-point map and scored with
    ``geodist_metric.calculate_geodesic_error``.

    Args:
        p2p_first, p2p_second: equal-length vertex-index tensors on the two shapes.
        evecs_first, evecs_second: eigenbases, indexed as (vertex, eigenvector).
        corr_first, corr_second: ground-truth correspondences used for scoring.
        num_evecs: number of leading eigenvectors used to fit the map.
        apply_zoomout: if True, refine the map with ``zoomout_custom.zoomout`` up to
            all ``evecs_first.shape[1]`` eigenvectors.
        dist_x: geodesic distance matrix of the first shape.
        regularized: if True, fit the map with ``fmap_solver.compute_functional_map``
            using the projected bases and eigenvalues. Otherwise use ordinary least
            squares on the selected eigenvector rows.
        evecs_trans_first, evecs_trans_second, evals_first, evals_second:
            required when ``regularized`` is True, ignored otherwise.
        return_p2p, return_Cxy: also return the point map and/or the functional map.
        A2: passed through to ``zoomout_custom.zoomout``.
        fmap_solver: object with a ``compute_functional_map`` method. If None, a
            module-level ``fmnet`` is used. The notebook shown does not define one.

    Returns:
        The mean geodesic error multiplied by 100 when neither ``return_p2p`` nor
        ``return_Cxy`` is set. Otherwise a list ``[error, p2p (if requested),
        Cxy (if requested)]``.

    Raises:
        ValueError: ``regularized`` is True but a required input is None.
        NameError: ``regularized`` is True and no functional-map solver is available.
    """
    if regularized:
        required = {
            'evecs_trans_first': evecs_trans_first,
            'evecs_trans_second': evecs_trans_second,
            'evals_first': evals_first,
            'evals_second': evals_second,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                'regularized=True requires these arguments: ' + ', '.join(missing)
            )

        if fmap_solver is None:
            # Fall back to a module-level `fmnet` for callers that rely on that global.
            fmap_solver = globals().get('fmnet')
            if fmap_solver is None:
                raise NameError(
                    'regularized=True needs a functional-map solver: pass '
                    'fmap_solver=... or define a global `fmnet`'
                )

        Cxy = fmap_solver.compute_functional_map(
            evecs_trans_second[:num_evecs, p2p_second].unsqueeze(0),
            evecs_trans_first[:num_evecs, p2p_first].unsqueeze(0),
            evals_second[:num_evecs].unsqueeze(0),
            evals_first[:num_evecs].unsqueeze(0),
        )[0].T

    else:
        # Solve evecs_second[p2p_second] @ Cxy ~= evecs_first[p2p_first] in the least-squares sense.
        Cxy = torch.linalg.lstsq(
            evecs_second[:, :num_evecs][p2p_second],
            evecs_first[:, :num_evecs][p2p_first],
            ).solution

    if apply_zoomout:
        # Grow the map one eigenvector at a time until all columns of evecs_first are used.
        Cxy = zoomout_custom.zoomout(
            FM_12=Cxy,
            evects1=evecs_first,
            evects2=evecs_second,
            nit=evecs_first.shape[1] - num_evecs, step=1,
            A2=A2
        )
        num_evecs = evecs_first.shape[1]

    p2p = fmap_util.fmap2pointmap(
        C12=Cxy,
        evecs_x=evecs_first[:, :num_evecs],
        evecs_y=evecs_second[:, :num_evecs],
        ).cpu()

    geo_err = geodist_metric.calculate_geodesic_error(
        dist_x, corr_first.cpu(), corr_second.cpu(), p2p, return_mean=True
    )

    # The error is reported scaled by 100.
    scaled_err = geo_err * 100

    if not return_p2p and not return_Cxy:
        return scaled_err

    payload = [scaled_err]
    if return_p2p:
        payload.append(p2p)
    if return_Cxy:
        payload.append(Cxy)

    return payload


def filter_p2p_by_confidence(
    p2p_first, p2p_second,
    confidence_scores_first, confidence_scores_second,
    confidence_threshold, log_file_name,
    min_fraction=0.05, threshold_step=0.05,
    ):
    """Keep only the correspondences whose scores on both shapes are below a threshold.

    Lower scores count as better here. A point is kept when
    ``score_first < threshold`` and ``score_second < threshold``. If fewer than
    ``min_fraction`` of the points survive, the threshold is raised by
    ``threshold_step`` until enough points pass. Raising also stops once the
    threshold exceeds every finite score, because further increases cannot admit
    new points.

    Args:
        p2p_first, p2p_second: index arrays/tensors of equal length (one entry per point).
        confidence_scores_first, confidence_scores_second: per-point scores that
            support ``<``, ``abs`` and boolean indexing (e.g. numpy arrays or torch tensors).
        confidence_threshold: initial exclusive upper bound on the scores.
        log_file_name: file to which the threshold increases and the final count
            are appended. They are also printed to stdout.
        min_fraction: fraction of points that must survive before relaxation stops.
        threshold_step: amount added to the threshold at each relaxation step.

    Returns:
        ``(p2p_first, p2p_second)`` restricted to the kept points.

    Raises:
        AssertionError: the two p2p inputs differ in length, or no point survives.
    """
    assert p2p_first.shape[0] == p2p_second.shape[0]

    def _select(threshold):
        # A point is kept only when it scores below the threshold on BOTH shapes.
        return (confidence_scores_first < threshold) & (confidence_scores_second < threshold)

    def _max_finite(scores):
        # Largest finite score, or -inf if there is none. NaN/inf never pass a `<` test.
        finite = abs(scores) < float('inf')
        return float(scores[finite].max()) if finite.any() else float('-inf')

    # Beyond this value, raising the threshold cannot admit any new point. Capping the
    # loop here prevents endless relaxation when most scores are NaN or inf.
    threshold_cap = max(
        _max_finite(confidence_scores_first),
        _max_finite(confidence_scores_second),
    )

    valid_points = _select(confidence_threshold)

    with open(log_file_name, 'a') as f:

        while (valid_points.sum() < min_fraction * len(valid_points)
               and confidence_threshold <= threshold_cap):
            confidence_threshold += threshold_step
            valid_points = _select(confidence_threshold)

            print(f'Increasing confidence threshold: {confidence_threshold}\n')
            f.write(f'Increasing confidence threshold: {confidence_threshold}\n')

        print(f'Valid points: {valid_points.sum()}')
        f.write(f'Valid points: {int(valid_points.sum())}\n')
        assert valid_points.sum() > 0, "No valid points found"

    p2p_first = p2p_first[valid_points]
    p2p_second = p2p_second[valid_points]

    return p2p_first, p2p_second


In [65]:
class Arguments:
    """Plain attribute container holding the evaluation settings for this notebook."""

    def __init__(self):
        settings = {
            # which trained experiment / checkpoint to evaluate
            'experiment_name': 'single_64_2-2ev_64-128-128_remeshed_fixed',
            'checkpoint_name': 'epoch_99',
            # which dataset / split to evaluate on
            'dataset_name': 'SHREC19_r_pair',
            'split': 'test',
            # further settings (not read by any cell shown here)
            'num_iters_avg': 100,
            'num_samples_median': 8,
            'confidence_threshold': 0.3,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)

In [66]:
args = Arguments()

# configuration
experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

# Root of the project checkout; every absolute path below is built from it.
# TODO: read this from an environment variable or a relative path instead of
# hard-coding a user's home directory.
project_root = '/home/s94zalek_hpc/shape_matching'

### config
exp_base_folder = f'{project_root}/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
model = DiagConditionedUnet(config["model_params"]).to('cuda')

if config.get("accelerate", False):
    # Checkpoints saved through accelerate live in a folder as model.safetensors.
    accelerate.load_checkpoint_in_model(model, f"{exp_base_folder}/checkpoints/{checkpoint_name}/model.safetensors")
else:
    # map_location keeps loading working if the checkpoint was saved from another GPU.
    model.load_state_dict(torch.load(
        f"{exp_base_folder}/checkpoints/{checkpoint_name}", map_location='cuda'
        ))

# Make sure all weights end up on the GPU, whichever loading path was taken.
model = model.to('cuda')

### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    ).to('cuda')

sign_corr_net.load_state_dict(torch.load(
        f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth',
        map_location='cuda',
        ))


### sample the model
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True)


### test dataset
dataset_name = args.dataset_name
split = args.split

single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True
    )
sign_corr_net.cache_dir = single_dataset.lb_cache_dir


num_evecs = config["model_params"]["sample_size"]


##########################################
# Template
##########################################

template_dir = f'{project_root}/data/SURREAL_full/template/{config["sign_net"]["template_type"]}'

template_shape = template_dataset.get_template(
    num_evecs=single_dataset.num_evecs,
    centering='bbox',
    template_path=f'{template_dir}/template.off',
    # Indices are shifted by -1; the file presumably stores 1-based vertex ids.
    template_corr=np.loadtxt(f'{template_dir}/corr.txt', dtype=np.int32) - 1
    )

##########################################
# Logging
##########################################

log_dir = f'{exp_base_folder}/eval/{checkpoint_name}/{dataset_name}-{split}/no_smoothing'
os.makedirs(log_dir, exist_ok=True)

fig_dir = f'{log_dir}/figs'
os.makedirs(fig_dir, exist_ok=True)

log_file_name = f'{log_dir}/log.txt'

In [None]:

# Draw pairs from the single-shape dataset loaded above.
test_dataset.dataset = single_dataset

# Per-pair error accumulators; this cell only fills `geo_errs_gt`.
geo_errs_gt = []
geo_errs_corr_gt = []
geo_errs_pairzo = []
geo_errs_dirichlet = []
geo_errs_median = []
geo_errs_median_filtered = []

# Only pair 124 is processed. To evaluate the whole split use:
#   data_range_pair = tqdm(range(len(test_dataset)), desc='Calculating pair fmaps')
data_range_pair = tqdm([124])

# Loop-invariant, so choose the device once.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for i in data_range_pair:

    data = test_dataset[i]

    verts_first = data['first']['verts'].to(device)
    verts_second = data['second']['verts'].to(device)

    faces_first = data['first']['faces'].to(device)
    faces_second = data['second']['faces'].to(device)

    evecs_first = data['first']['evecs'].to(device)
    evecs_second = data['second']['evecs'].to(device)

    # Eigenvalues are truncated to the model's basis size and not moved to `device`.
    evals_first = data['first']['evals'][:num_evecs]
    evals_second = data['second']['evals'][:num_evecs]

    corr_first = data['first']['corr'].to(device)
    corr_second = data['second']['corr'].to(device)

    ###############################################
    # Functional maps
    ###############################################

    # The "_zo" bases used by later cells are the same eigenvectors as above.
    # Alias them instead of transferring the same data to the device a second time.
    evecs_first_zo = evecs_first
    evecs_second_zo = evecs_second

    dist_x = torch.tensor(
        compute_geodesic_distmat(data['first']['verts'].numpy(), data['first']['faces'].numpy())
    )
    # NOTE(review): dist_y is not read by any cell shown here, and a geodesic distance
    # matrix is expensive to compute. Drop it if nothing else needs it.
    dist_y = torch.tensor(
        compute_geodesic_distmat(data['second']['verts'].numpy(), data['second']['faces'].numpy())
    )

    mass_second = data['second']['mass'].to(device)

    ###############################################
    # Geodesic errors
    ###############################################

    # GT geo error: functional map fitted from the ground-truth correspondences.
    geo_err_gt = get_geo_error(
        corr_first, corr_second,
        evecs_first, evecs_second,
        corr_first, corr_second,
        num_evecs, False,
        dist_x, regularized=False,
        A2=mass_second
        )
    geo_errs_gt.append(geo_err_gt)

In [None]:
# Regularized GT functional map in the dataset eigenbasis, then a plot of the resulting
# point-to-point map. (plotting_utils is already imported in the imports cell.)
# NOTE(review): regularized=True needs a functional-map solver (`fmnet`) that no cell
# shown here defines. Make sure it exists before running this cell.
geo_err_k, p2p_k, Cxy_k = get_geo_error(
    corr_first, corr_second,
    evecs_first_zo, evecs_second_zo,
    corr_first, corr_second,
    num_evecs, False,
    dist_x, regularized=True,
    A2=mass_second,
    return_p2p=True,
    return_Cxy=True,
    evecs_trans_first=torch.linalg.pinv(evecs_first_zo),
    evecs_trans_second=torch.linalg.pinv(evecs_second_zo),
    # Put the eigenvalues on the eigenvectors' device rather than assuming CUDA.
    evals_first=data['first']['evals'].to(evecs_first_zo.device),
    evals_second=data['second']['evals'].to(evecs_second_zo.device),
    )
print(geo_err_k)

scene.geometry.clear()

plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_k,
    axes_color_gradient=[0],
    base_cmap='hsv'
)

scene.show()

In [72]:
import potpourri3d as pp3d
import scipy.sparse.linalg as sla
import scipy.sparse
import numpy as np

def get_evecs_cotan(verts, faces, k, return_evals=False):
    """Compute ``k`` eigenvectors of the cotangent Laplacian of a triangle mesh.

    Solves the generalized eigenproblem ``(L + eps*I) v = lambda * M v``. ``L`` is
    potpourri3d's cotan Laplacian and ``M`` is the diagonal vertex-area mass matrix.
    It uses ``scipy.sparse.linalg.eigsh`` in shift-invert mode around ``sigma=eps``,
    so the eigenpairs closest to zero are found. The eigenpairs are sorted by
    ascending eigenvalue.

    Args:
        verts: vertex positions (passed straight to potpourri3d).
        faces: triangle indices (passed straight to potpourri3d).
        k: number of eigenpairs to compute.
        return_evals: if True, also return the eigenvalues. They are eigenvalues of
            the slightly shifted problem above. Default False (original behavior).

    Returns:
        A float32 tensor of eigenvectors with one column per eigenpair. If
        ``return_evals`` is True, a tuple ``(evecs, evals)`` of float32 tensors.
    """
    eps = 1e-8

    L = pp3d.cotan_laplacian(verts, faces, denom_eps=1e-10)
    massvec = pp3d.vertex_areas(verts, faces)
    # Keep every vertex mass strictly positive.
    massvec += eps * np.mean(massvec)

    # A tiny diagonal shift keeps the shifted Laplacian nonsingular for shift-invert.
    L_eigsh = (L + eps * scipy.sparse.identity(L.shape[0])).tocsc()
    Mmat = scipy.sparse.diags(massvec)

    evals, evecs = sla.eigsh(L_eigsh, k=k, M=Mmat, sigma=eps)

    # The eigsh documentation does not promise an ordering, so sort explicitly so that
    # column j is the j-th lowest-frequency eigenvector.
    order = np.argsort(evals)
    evals = evals[order]
    evecs = evecs[:, order]

    evecs_tensor = torch.tensor(evecs, dtype=torch.float32)
    if return_evals:
        return evecs_tensor, torch.tensor(evals, dtype=torch.float32)
    return evecs_tensor

# Cotan-Laplacian eigenbases (200 eigenvectors each) for both shapes of the current pair, on the GPU.
evecs_first_cotan, evecs_second_cotan = [
    get_evecs_cotan(
        data[shape_key]['verts'].numpy(),
        data[shape_key]['faces'].numpy(),
        k=200,
    ).to('cuda')
    for shape_key in ('first', 'second')
]

In [None]:
# Least-squares GT functional map in the cotan-Laplacian eigenbasis.
# get_geo_error reads evecs_trans_* / evals_* only when regularized=True, so the two
# pseudo-inverses (SVDs of vertex x eigenvector matrices) are not computed here.
geo_err_k, p2p_k, Cxy_k = get_geo_error(
    corr_first, corr_second,
    evecs_first_cotan, evecs_second_cotan,
    corr_first, corr_second,
    num_evecs, False,
    dist_x, regularized=False,
    A2=mass_second,
    return_p2p=True,
    return_Cxy=True,
    )
print(geo_err_k)

In [78]:
import sys

sys.path.append('/home/s94zalek_hpc/shape_matching/pyFM_fork')

import pyFM
from pyFM.mesh.trimesh import TriMesh

In [None]:
# Wrap both shapes of the current pair as pyFM meshes, then process them.
# The next cell reads their eigenvectors.
mesh_1, mesh_2 = (
    TriMesh(data[shape_key]['verts'].numpy(), data[shape_key]['faces'].numpy())
    for shape_key in ('first', 'second')
)
for mesh in (mesh_1, mesh_2):
    mesh.process(intrinsic=True, robust=True)

In [91]:
# pyFM eigenvectors of both meshes as float32 tensors on the GPU.
evecs_first_pyfm, evecs_second_pyfm = (
    torch.tensor(pyfm_mesh.eigenvectors, dtype=torch.float32, device='cuda')
    for pyfm_mesh in (mesh_1, mesh_2)
)

In [None]:
# Regularized GT functional map in pyFM's eigenbasis, using every pyFM eigenvector.
num_evecs_pyfm = evecs_first_pyfm.shape[1]

geo_err_k, p2p_k, Cxy_k = get_geo_error(
    corr_first, corr_second,
    evecs_first_pyfm, evecs_second_pyfm,
    corr_first, corr_second,
    num_evecs_pyfm, False,
    dist_x, regularized=True,
    A2=mass_second,
    return_p2p=True,
    return_Cxy=True,
    evecs_trans_first=torch.linalg.pinv(evecs_first_pyfm),
    evecs_trans_second=torch.linalg.pinv(evecs_second_pyfm),
    # pyFM computed its own eigendecomposition. Pass its eigenvalues so that the
    # eigenvalues and the eigenvectors come from the same solve, instead of the
    # dataset's separately computed spectrum.
    evals_first=torch.tensor(mesh_1.eigenvalues, dtype=torch.float32, device=evecs_first_pyfm.device),
    evals_second=torch.tensor(mesh_2.eigenvalues, dtype=torch.float32, device=evecs_second_pyfm.device),
    )
print(geo_err_k)

In [None]:
# Scratch arithmetic. It looks like a count-weighted average; the denominator
# 408 = 399 + 9 suggests 399 items at 3 and 9 items at 8 — TODO confirm.
(3 * 399 + 9 * 8) / 408

In [None]:
from my_code.utils.median_p2p_map import dirichlet_energy

# Dirichlet energies, presumably of the estimated map vs. the GT correspondences.
# NOTE(review): the first call pairs verts_first with shape 2's Laplacian, while the
# second pairs verts_second with shape 1's Laplacian. Confirm both are intended.
(
    dirichlet_energy(p2p_k, verts_first.cpu(), data['second']['L']),
    dirichlet_energy(corr_second.cpu(), verts_second.cpu(), data['first']['L']),
    # dirichlet_energy(corr_second.cpu(), verts_first.cpu(), data['second']['L']),
    # dirichlet_energy(p2p_k, verts_second.cpu(), data['first']['L']),
)

In [None]:
# Show the second shape's GT correspondences next to the estimated point-to-point map.
(corr_second, p2p_k)