In [1]:
import trimesh

# Shared trimesh scene; the visualization cells below clear and redraw it.
scene = trimesh.Scene()

In [2]:
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
from utils.shape_util import compute_geodesic_distmat
from my_code.diffusion_training_sign_corr.test.test_diffusion_cond import select_p2p_map_dirichlet, log_to_database, parse_args
import accelerate
import my_code.sign_canonicalization.test_sign_correction as test_sign_correction
import networks.fmap_network as fmap_network
from my_code.utils.median_p2p_map import dirichlet_energy

# Clear tqdm's (private) registry of progress bars — presumably to avoid
# stale/duplicated bars when notebook cells are re-run.
tqdm._instances.clear()

class RegularizedFMNet(torch.nn.Module):
    """Compute the functional map matrix representation in DPFM.

    For every row ``i`` of the map (and every batch element ``b``) it solves

        (A A^T + lmbda * diag(D[b, i, :])) c = (B A^T)[b, i, :]

    and stores ``c`` as ``Cxy[b, i, :]``. ``D`` is the mask returned by
    ``fmap_network.get_mask(evals_x, evals_y, resolvant_gamma)``.

    Args:
        lmbda: weight of the mask regularizer.
        resolvant_gamma: forwarded to ``fmap_network.get_mask``.
        bidirectional: stored for interface compatibility; it is not used by
            ``compute_functional_map``.
    """

    def __init__(self, lmbda=0.01, resolvant_gamma=0.5, bidirectional=False):
        super(RegularizedFMNet, self).__init__()
        self.lmbda = lmbda
        self.resolvant_gamma = resolvant_gamma
        self.bidirectional = bidirectional

    def compute_functional_map(self, A, B, evals_x, evals_y):
        """Solve the regularized system row by row (batched).

        Args:
            A: [B, K, C] spectral coefficients of the source descriptors
               (e.g. ``evecs_trans_x @ feat_x``).
            B: [B, K, C] spectral coefficients of the target descriptors.
            evals_x, evals_y: eigenvalues passed to ``get_mask``; the number of
               rows solved is ``evals_x.shape[1]``.

        Returns:
            Cxy: [B, K, K] with ``Cxy[b, i, :]`` the solution for row ``i``.
        """
        D = fmap_network.get_mask(evals_x, evals_y, self.resolvant_gamma)  # [B, K, K]

        A_t = A.transpose(1, 2)  # [B, C, K]
        A_A_t = torch.bmm(A, A_t)  # [B, K, K]
        B_A_t = torch.bmm(B, A_t)  # [B, K, K]

        n_rows = evals_x.shape[1]

        # One K x K system per row i, all solved in a single batched call:
        # lhs[b, i] = A A^T + lmbda * diag(D[b, i, :])          -> [B, n_rows, K, K]
        lhs = A_A_t.unsqueeze(1) + self.lmbda * torch.diag_embed(D[:, :n_rows])
        # rhs[b, i] = (B A^T)[b, i, :] as a column vector         -> [B, n_rows, K, 1]
        rhs = B_A_t[:, :n_rows].unsqueeze(-1)

        # linalg.solve avoids forming an explicit inverse (better conditioned).
        Cxy = torch.linalg.solve(lhs, rhs).squeeze(-1)  # [B, n_rows, K]

        return Cxy
    



def get_geo_error(
    p2p_first, p2p_second,
    evecs_first, evecs_second,
    corr_first, corr_second,
    num_evecs, apply_zoomout,
    dist_x,
    regularized=False,
    evecs_trans_first=None, evecs_trans_second=None,
    evals_first=None, evals_second=None,
    return_p2p=False, return_Cxy=False,
    A2=None, fmnet=None
    ):
    """Fit a functional map from matched vertices and report its geodesic error.

    The vertex pairs ``(p2p_first[j], p2p_second[j])`` define the fit:
    either a least-squares solve on the first ``num_evecs`` eigenvectors, or
    (``regularized=True``) ``fmnet.compute_functional_map`` on the projected
    coefficients/eigenvalues. When ``apply_zoomout`` is set the map is refined
    with ``zoomout_custom.zoomout`` up to the full basis of ``evecs_first``.
    The map is converted with ``fmap_util.fmap2pointmap`` and scored with
    ``geodist_metric.calculate_geodesic_error`` on ``dist_x``.

    Returns:
        The mean geodesic error times 100 when neither ``return_p2p`` nor
        ``return_Cxy`` is set; otherwise a list starting with that value,
        followed by the p2p map (if ``return_p2p``) and then ``Cxy``
        (if ``return_Cxy``).
    """
    k = num_evecs

    if not regularized:
        # plain least squares between matched spectral embeddings
        basis_second = evecs_second[:, :k][p2p_second]
        basis_first = evecs_first[:, :k][p2p_first]
        Cxy = torch.linalg.lstsq(basis_second, basis_first).solution
    else:
        coeffs_second = evecs_trans_second[:k, p2p_second].unsqueeze(0)
        coeffs_first = evecs_trans_first[:k, p2p_first].unsqueeze(0)
        Cxy = fmnet.compute_functional_map(
            coeffs_second,
            coeffs_first,
            evals_second[:k].unsqueeze(0),
            evals_first[:k].unsqueeze(0),
        )[0].T

    if apply_zoomout:
        Cxy = zoomout_custom.zoomout(
            FM_12=Cxy,
            evects1=evecs_first,
            evects2=evecs_second,
            nit=evecs_first.shape[1] - k, step=1,
            A2=A2
        )
        # after zoomout the map spans the whole basis of evecs_first
        k = evecs_first.shape[1]

    p2p = fmap_util.fmap2pointmap(
        C12=Cxy,
        evecs_x=evecs_first[:, :k],
        evecs_y=evecs_second[:, :k],
    ).cpu()

    geo_err_pct = geodist_metric.calculate_geodesic_error(
        dist_x, corr_first.cpu(), corr_second.cpu(), p2p, return_mean=True
    ) * 100

    extras = [value for flag, value in ((return_p2p, p2p), (return_Cxy, Cxy)) if flag]
    if not extras:
        return geo_err_pct
    return [geo_err_pct] + extras


def _max_finite_score(scores):
    """Largest finite entry of ``scores`` (tensor or array-like); -inf if none."""
    scores = torch.as_tensor(scores)
    finite = scores[torch.isfinite(scores)]
    return finite.max().item() if finite.numel() > 0 else float('-inf')


def filter_p2p_by_confidence(
        p2p_first, p2p_second,
        confidence_scores_first, confidence_scores_second,
        confidence_threshold, log_file_name
    ):
    """Keep only the p2p entries whose two confidence scores are below a threshold.

    If fewer than 5% of the points pass, the threshold is raised in steps of
    0.05 (each step is appended to ``log_file_name``) until 5% pass or the
    threshold exceeds every finite score, after which raising it further
    could not admit more points.

    Args:
        p2p_first, p2p_second: maps of equal length to be filtered jointly.
        confidence_scores_first, confidence_scores_second: per-point scores;
            lower is treated as more confident.
        confidence_threshold: initial (exclusive) upper bound on both scores.
        log_file_name: text file the threshold changes are appended to.

    Returns:
        ``(p2p_first[valid], p2p_second[valid])``.

    Raises:
        AssertionError: if the maps differ in length or no point is valid.
    """
    assert p2p_first.shape[0] == p2p_second.shape[0]

    def _valid_mask(threshold):
        # select points with both confidence scores below threshold
        return (confidence_scores_first < threshold) & (confidence_scores_second < threshold)

    valid_points = _valid_mask(confidence_threshold)
    min_valid = 0.05 * len(valid_points)

    # NaN/inf scores never satisfy `<`, so once the threshold is above every
    # finite score no further point can become valid. Without this bound the
    # loop below never terminates when >95% of the scores are non-finite.
    max_score = max(_max_finite_score(confidence_scores_first),
                    _max_finite_score(confidence_scores_second))

    with open(log_file_name, 'a') as f:

        while valid_points.sum() < min_valid and confidence_threshold <= max_score:
            confidence_threshold += 0.05
            valid_points = _valid_mask(confidence_threshold)

            f.write(f'Increasing confidence threshold: {confidence_threshold}\n')
        f.write(f'Valid points: {valid_points.sum()}\n')
        assert valid_points.sum() > 0, "No valid points found"

    return p2p_first[valid_points], p2p_second[valid_points]

In [3]:
def _predict_evec_signs(shape, verts, faces, evecs, mass_mat, sign_corr_net, config):
    """Run ``predict_sign_change`` once on one shape.

    Returns ``(sign_pred, support_vector_norm)``; the third output is dropped.
    """
    k_eig = config["sign_net"]["net_params"]["k_eig"]
    with torch.no_grad():
        sign_pred, support_vector_norm, _ = predict_sign_change(
            sign_corr_net, verts, faces, evecs,
            mass_mat=mass_mat, input_type=sign_corr_net.input_type,
            evecs_per_support=config["sign_net"]["evecs_per_support"],

            mass=shape['mass'].unsqueeze(0), L=shape['L'].unsqueeze(0),
            evals=shape['evals'][:k_eig].unsqueeze(0),
            evecs=shape['evecs'][:, :k_eig].unsqueeze(0),
            gradX=shape['gradX'].unsqueeze(0), gradY=shape['gradY'].unsqueeze(0)
            )
    return sign_pred, support_vector_norm


def _evec_conditioning(evecs, sign_pred, support_vector_norm, mass_mat):
    """Sign-correct ``evecs`` and project the supports onto them.

    Returns ``(conditioning, signs)``, both on the CPU. When ``mass_mat`` is
    given, the supports are mass-weighted and L2-normalized per row first.
    """
    signs = torch.sign(sign_pred).cpu()

    evecs_corrected = evecs.cpu()[0] * signs
    evecs_corrected_norm = evecs_corrected / torch.norm(evecs_corrected, dim=0, keepdim=True)

    support = support_vector_norm[0].cpu().transpose(0, 1)
    if mass_mat is not None:
        support = torch.nn.functional.normalize(support @ mass_mat[0].cpu(), p=2, dim=1)

    return support @ evecs_corrected_norm, signs


def get_fmaps_evec_signs(
        data, model,
        noise_scheduler, config, args,
        template_shape, sign_corr_net
    ):
    """Sample ``args.num_iters_avg`` functional maps between template and ``data``.

    Each iteration predicts eigenvector signs for both shapes with
    ``sign_corr_net`` and builds the conditioning. Then all maps are sampled
    jointly with the conditional diffusion ``model``.

    Returns:
        x_sampled: [num_iters_avg, 1, K, K] sampled maps
            (K = ``config["model_params"]["sample_size"]``).
        evecs_first_signs_list: stacked sign vectors of the template.
        evecs_second_signs_list: stacked sign vectors of ``data``.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_evecs = config["model_params"]["sample_size"]

    verts_first = template_shape['verts'].unsqueeze(0).to(device)
    verts_second = data['verts'].unsqueeze(0).to(device)

    faces_first = template_shape['faces'].unsqueeze(0).to(device)
    faces_second = data['faces'].unsqueeze(0).to(device)

    evecs_first = template_shape['evecs'][:, :num_evecs].unsqueeze(0).to(device)
    evecs_second = data['evecs'][:, :num_evecs].unsqueeze(0).to(device)

    # Mass matrices are loop-invariant: build them once (None disables them).
    if config["sign_net"]["with_mass"]:
        mass_mat_first = torch.diag_embed(template_shape['mass'].unsqueeze(0)).to(device)
        mass_mat_second = torch.diag_embed(data['mass'].unsqueeze(0)).to(device)
    else:
        mass_mat_first = None
        mass_mat_second = None

    ###############################################
    # get conditioning and signs num_iters_avg times
    ###############################################

    evecs_cond_first_list = []
    evecs_cond_second_list = []
    evecs_first_signs_list = []
    evecs_second_signs_list = []

    for _ in range(args.num_iters_avg):

        # predict the sign change (both shapes first, as before)
        sign_pred_first, support_first = _predict_evec_signs(
            template_shape, verts_first, faces_first, evecs_first,
            mass_mat_first, sign_corr_net, config)
        sign_pred_second, support_second = _predict_evec_signs(
            data, verts_second, faces_second, evecs_second,
            mass_mat_second, sign_corr_net, config)

        # correct the evecs and take the product with the supports
        evecs_cond_first, signs_first = _evec_conditioning(
            evecs_first, sign_pred_first, support_first, mass_mat_first)
        evecs_cond_second, signs_second = _evec_conditioning(
            evecs_second, sign_pred_second, support_second, mass_mat_second)

        evecs_cond_first_list.append(evecs_cond_first)
        evecs_cond_second_list.append(evecs_cond_second)
        evecs_first_signs_list.append(signs_first)
        evecs_second_signs_list.append(signs_second)

    evecs_cond_first_list = torch.stack(evecs_cond_first_list)
    evecs_cond_second_list = torch.stack(evecs_cond_second_list)
    evecs_first_signs_list = torch.stack(evecs_first_signs_list)
    evecs_second_signs_list = torch.stack(evecs_second_signs_list)

    ###############################################
    # Conditioning
    ###############################################

    conditioning = torch.cat(
        (evecs_cond_first_list.unsqueeze(1), evecs_cond_second_list.unsqueeze(1)),
        1)

    ###############################################
    # Sample the model
    ###############################################

    # The DDPM reverse process starts from standard Gaussian noise
    # x_T ~ N(0, I); torch.rand (uniform on [0, 1)) is not a valid start.
    x_sampled = torch.randn(args.num_iters_avg, 1,
                            config["model_params"]["sample_size"],
                            config["model_params"]["sample_size"]).to(device)
    y = conditioning.to(device)

    # Sampling loop
    for t in noise_scheduler.timesteps:

        # Get model pred
        with torch.no_grad():
            residual = model(x_sampled, t, conditioning=y).sample

        # Update sample with step
        x_sampled = noise_scheduler.step(residual, t, x_sampled).prev_sample

    return x_sampled, evecs_first_signs_list, evecs_second_signs_list


def get_p2p_maps_template(
        data,
        C_yx_est_i, evecs_first_signs_i, evecs_second_signs_i,
        template_shape, args, log_file_name, config
    ):
    """Convert sampled functional maps into p2p maps and select one.

    ``data`` plays the role of the "second" shape and ``template_shape`` of
    the "first" one. For each sample ``k``, ``C_yx_est_i[k][0]`` is passed as
    ``C12`` to ``fmap2pointmap`` with ``evecs_x`` from ``data`` and
    ``evecs_y`` from ``template_shape`` (both sign-corrected). It is
    optionally ZoomOut-refined when ``args.zoomout_num_evecs_template > 0``.

    Args:
        data: shape dict with 'verts', 'faces', 'evecs'.
        C_yx_est_i: per-sample functional maps.
        evecs_first_signs_i, evecs_second_signs_i: per-sample sign vectors
            for the template and ``data`` eigenvectors.
        template_shape: shape dict with 'evecs', 'L', 'mass'.
        args: uses num_iters_avg, zoomout_num_evecs_template, num_samples_median.
        log_file_name: kept for interface compatibility; nothing is logged.
        config: uses config["model_params"]["sample_size"].

    Returns:
        (p2p_est, p2p_dirichlet, p2p_median, confidence_scores,
         dist_second, dirichlet_energy_list)
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_evecs = config["model_params"]["sample_size"]

    verts_second = data['verts']
    faces_second = data['faces']

    # Keep the full bases: ZoomOut needs the columns beyond num_evecs.
    evecs_first_full = template_shape['evecs']
    evecs_second_full = data['evecs']
    evecs_first = evecs_first_full[:, :num_evecs]
    evecs_second = evecs_second_full[:, :num_evecs]

    dist_second = torch.tensor(
        compute_geodesic_distmat(
            verts_second.numpy(),
            faces_second.numpy())
    )

    # ZoomOut target size (None or <= 0 disables it), capped by the
    # number of eigenvectors actually available on both shapes.
    zo_num_evecs = args.zoomout_num_evecs_template
    if zo_num_evecs is not None and zo_num_evecs > 0:
        zo_num_evecs = min(zo_num_evecs, evecs_first_full.shape[1], evecs_second_full.shape[1])
    use_zoomout = zo_num_evecs is not None and zo_num_evecs > 0 and num_evecs < zo_num_evecs

    ##########################################################
    # Convert fmaps to p2p maps to template
    ##########################################################

    p2p_est = []

    for k in range(args.num_iters_avg):

        evecs_first_corrected = evecs_first * evecs_first_signs_i[k]
        evecs_second_corrected = evecs_second * evecs_second_signs_i[k]
        Cyx_est_k = C_yx_est_i[k][0].cpu()

        if use_zoomout:
            # sign-corrected low band + the higher eigenvectors (the sign
            # vectors only cover the first num_evecs columns)
            evecs_first_zo = torch.cat(
                [evecs_first_corrected, evecs_first_full[:, num_evecs:zo_num_evecs]],
                dim=1
            ).to(device)
            evecs_second_zo = torch.cat(
                [evecs_second_corrected, evecs_second_full[:, num_evecs:zo_num_evecs]],
                dim=1
            ).to(device)

            Cyx_zo_k = zoomout_custom.zoomout(
                FM_12=Cyx_est_k.to(device),
                evects1=evecs_second_zo,
                evects2=evecs_first_zo,
                nit=zo_num_evecs - num_evecs, step=1,
                A2=template_shape['mass'].to(device),
            )
            p2p_est_k = fmap_util.fmap2pointmap(
                C12=Cyx_zo_k,
                evecs_x=evecs_second_zo,
                evecs_y=evecs_first_zo,
                ).cpu()
        else:
            p2p_est_k = fmap_util.fmap2pointmap(
                C12=Cyx_est_k.to(device),
                evecs_x=evecs_second_corrected.to(device),
                evecs_y=evecs_first_corrected.to(device),
                ).cpu()

        p2p_est.append(p2p_est_k)

    p2p_est = torch.stack(p2p_est)

    ##########################################################
    # p2p map selection
    ##########################################################

    p2p_dirichlet, p2p_median, confidence_scores, dirichlet_energy_list = select_p2p_map_dirichlet(
        p2p_est,
        verts_second,
        template_shape['L'],
        dist_second,
        num_samples_median=args.num_samples_median
        )

    return p2p_est, p2p_dirichlet, p2p_median, confidence_scores, dist_second, dirichlet_energy_list


In [4]:
class Arguments:
    """Hard-coded stand-in for ``parse_args()`` when running as a notebook."""

    def __init__(self):
        # --- experiment / checkpoint to evaluate ---
        self.experiment_name = (
            'partial_0.8_5k_xyz_32_1_lambda_0.01_anisRemesh_holes_bbox_partial_0.8_xy'
        )
        self.checkpoint_name = 'epoch_99'

        # --- evaluation data ---
        self.dataset_name = 'SHREC16_holes_pair_noSingle'
        self.split = 'test'

        # --- sampling and map selection ---
        self.num_iters_avg = 64
        self.num_samples_median = 10
        self.confidence_threshold = 0.3

        # --- optional smoothing (None disables it) ---
        self.smoothing_type = None
        self.smoothing_iter = None

        # --- ZoomOut size at the template stage (<= 0 disables it) ---
        self.zoomout_num_evecs_template = -1

In [5]:

# args = parse_args()

args = Arguments()

# configuration
experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

### config
exp_base_folder = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DiagConditionedUnet(config["model_params"])

if "accelerate" in config and config["accelerate"]:
    accelerate.load_checkpoint_in_model(model, f"{exp_base_folder}/checkpoints/{checkpoint_name}/model.safetensors")
else:
    # map_location lets checkpoints saved on a GPU load on CPU-only machines
    model.load_state_dict(torch.load(
        f"{exp_base_folder}/checkpoints/{checkpoint_name}", map_location=device
        ))

model.to(device)

### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    )
sign_corr_net.load_state_dict(torch.load(
        f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth',
        map_location=device
        ))
sign_corr_net.to(device)

# NOTE(review): neither `model` nor `sign_corr_net` is switched to .eval().
# If they contain dropout layers, sign predictions and samples are stochastic
# across the num_iters_avg repetitions — confirm this is intended.


### noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True)

# fmap network
fmnet = RegularizedFMNet(lmbda=config["sign_net"]["regularization_lambda"], resolvant_gamma=0.5)


### test dataset
dataset_name = args.dataset_name
split = args.split

single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True, centering='bbox'
    )
# sign_corr_net.cache_dir = single_dataset.lb_cache_dir

num_evecs = config["model_params"]["sample_size"]

##########################################
# Template
##########################################

template_dir = f'/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/{config["sign_net"]["template_type"]}'

template_shape = template_dataset.get_template(
    num_evecs=200,
    centering='bbox',
    template_path=f'{template_dir}/template.off',
    # corr.txt is 1-based; convert to 0-based indices
    template_corr=np.loadtxt(f'{template_dir}/corr.txt', dtype=np.int32) - 1
    )

##########################################
# Logging
##########################################

if args.smoothing_type is not None:
    test_name = f'{args.smoothing_type}-{args.smoothing_iter}'
else:
    test_name = 'no_smoothing'

log_dir = f'{exp_base_folder}/eval/{checkpoint_name}/{dataset_name}-{split}/{test_name}'
os.makedirs(log_dir, exist_ok=True)

fig_dir = f'{log_dir}/figs'
os.makedirs(fig_dir, exist_ok=True)

log_file_name = f'{log_dir}/log_{test_name}.txt'



In [6]:

##########################################
# 1.1: Template stage, get the functional maps and signs of evecs
##########################################

# data_range_2 = range(len(test_dataset))
# data_range_2 = range(10)
# NOTE: debugging override — only this pair is evaluated.
data_range_2 = [5]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# These three lists were only filled by the old pairwise least-squares
# evaluation (get_geo_error / filter_p2p_by_confidence). That code sat behind
# an unconditional `continue` and never ran, so it has been removed. The lists
# stay defined (and empty) so cells referencing them keep working.
geo_errs_gt = []
geo_errs_median_filtered = []
geo_errs_median_filtered_noZo = []

geo_errs_pairzo = []
geo_errs_dirichlet = []
geo_errs_median = []
geo_errs_dirichlet_pairzo = []
geo_errs_median_pairzo = []


def split_fmap_directions(C_sampled, fmap_direction):
    """Return ``(Cxy, Cyx)`` for sampled maps of shape [S, 1, K, K].

    ``C_sampled`` is Cxy when ``fmap_direction == 'xy'`` and Cyx otherwise;
    the other direction is its transpose over the last two axes.
    """
    if fmap_direction == 'xy':
        return C_sampled, C_sampled.transpose(2, 3)
    return C_sampled.transpose(2, 3), C_sampled


def geo_err_percent(p2p, dist, corr_first, corr_second):
    """Mean geodesic error of ``p2p`` (on ``dist``), times 100."""
    return geodist_metric.calculate_geodesic_error(
        dist, corr_first, corr_second, p2p, return_mean=True
    ) * 100


for i in tqdm(data_range_2, desc='Calculating pair fmaps'):

    data = test_dataset[i]

    # Per-pair tensors; not used by the pipeline in this cell but kept as
    # notebook globals for inspection in later cells.
    evecs_first = data['first']['evecs'][:, :].to(device)
    evecs_second = data['second']['evecs'][:, :].to(device)

    evecs_trans_first = data['first']['evecs_trans'][:, :].to(device)
    evecs_trans_second = data['second']['evecs_trans'][:, :].to(device)

    evals_first = data['first']['evals'][:num_evecs].to(device)
    evals_second = data['second']['evals'][:num_evecs].to(device)

    mass_second = data['second']['mass'].to(device)

    corr_first = data['first']['corr'].to(device)
    corr_second = data['second']['corr'].to(device)

    ###############################################
    # Functional maps: first mesh <-> template
    ###############################################

    C_sampled_first_list, evecs_first_signs_list_first, evecs_second_signs_list_first = get_fmaps_evec_signs(
        data['first'], model,
        noise_scheduler, config, args,
        template_shape, sign_corr_net
    )
    Cxy_first_list, Cyx_first_list = split_fmap_directions(C_sampled_first_list, config["fmap_direction"])

    p2p_est_first, p2p_dirichlet_first, p2p_median_first, confidence_scores_first, dist_x, dirichlet_est_first = get_p2p_maps_template(
        data['first'],
        Cyx_first_list, evecs_first_signs_list_first, evecs_second_signs_list_first,
        template_shape, args, log_file_name, config
    )
    p2p_est_first_rev, p2p_dirichlet_first_rev, p2p_median_first_rev, _, _, dirichlet_est_first_rev = get_p2p_maps_template(
        template_shape,
        Cxy_first_list, evecs_second_signs_list_first, evecs_first_signs_list_first,
        data['first'], args, log_file_name, config
    )

    ###############################################
    # Functional maps: second mesh <-> template
    ###############################################

    C_sampled_second_list, evecs_first_signs_list_second, evecs_second_signs_list_second = get_fmaps_evec_signs(
        data['second'], model,
        noise_scheduler, config, args,
        template_shape, sign_corr_net
    )
    Cxy_second_list, Cyx_second_list = split_fmap_directions(C_sampled_second_list, config["fmap_direction"])

    p2p_est_second, p2p_dirichlet_second, p2p_median_second, confidence_scores_second, dist_y, dirichlet_est_second = get_p2p_maps_template(
        data['second'],
        Cyx_second_list, evecs_first_signs_list_second, evecs_second_signs_list_second,
        template_shape, args, log_file_name, config
    )
    p2p_est_second_rev, p2p_dirichlet_second_rev, p2p_median_second_rev, confidence_scores_second_rev, _, dirichlet_est_second_rev = get_p2p_maps_template(
        template_shape,
        Cxy_second_list, evecs_second_signs_list_second, evecs_first_signs_list_second,
        data['second'], args, log_file_name, config
    )

    corr_first = corr_first.cpu()
    corr_second = corr_second.cpu()

    ###############################################
    # Pairwise maps composed through the template
    ###############################################

    # per sample k: p2p_est_first[k] indexed by p2p_est_second_rev[k]
    p2p_est_pairzo = torch.stack([
        p2p_est_first[k][p2p_est_second_rev[k]].cpu()
        for k in range(args.num_iters_avg)
    ])
    geo_err_est_pairzo = torch.tensor([
        geo_err_percent(p2p_k, dist_x, corr_first, corr_second)
        for p2p_k in p2p_est_pairzo
    ])

    p2p_est_dirichlet = p2p_dirichlet_first[p2p_dirichlet_second_rev].cpu()
    geo_err_est_dirichlet = geo_err_percent(p2p_est_dirichlet, dist_x, corr_first, corr_second)

    p2p_est_median = p2p_median_first[p2p_median_second_rev].cpu()
    geo_err_est_median = geo_err_percent(p2p_est_median, dist_x, corr_first, corr_second)

    # Dirichlet / median selection at the pairwise stage
    p2p_dirichlet_pairzo, p2p_median_pairzo, confidence_scores, dirichlet_energy_list = select_p2p_map_dirichlet(
        p2p_est_pairzo,
        data['first']['verts'],
        data['second']['L'],
        dist_x,
        num_samples_median=args.num_samples_median
        )

    geo_err_dirichlet_pairzo = geo_err_percent(p2p_dirichlet_pairzo, dist_x, corr_first, corr_second)
    geo_err_median_pairzo = geo_err_percent(p2p_median_pairzo, dist_x, corr_first, corr_second)

    ###############################################
    # Logging
    ###############################################

    if "id" in data["first"] and "id" in data["second"]:
        print(f'{i}: {data["first"]["id"]}, {data["second"]["id"]}')
    else:
        # print "name" instead of "id"
        print(f'{i}: {data["first"]["name"]}, {data["second"]["name"]}')

    print(f'Geo error est pairzo: {geo_err_est_pairzo}')
    print(f'Geo error est pairzo mean: {geo_err_est_pairzo.mean():.2f}')
    print(f'Geo error est dirichlet: {geo_err_est_dirichlet:.2f}')
    print(f'Geo error est median: {geo_err_est_median:.2f}')
    print(f'Geo error dirichlet pairzo: {geo_err_dirichlet_pairzo:.2f}')
    print(f'Geo error median pairzo: {geo_err_median_pairzo:.2f}')
    print('-----------------------------------')

    geo_errs_pairzo.append(geo_err_est_pairzo.mean())
    geo_errs_dirichlet.append(geo_err_est_dirichlet)
    geo_errs_median.append(geo_err_est_median)
    geo_errs_dirichlet_pairzo.append(geo_err_dirichlet_pairzo)
    geo_errs_median_pairzo.append(geo_err_median_pairzo)


geo_errs_gt = torch.tensor(geo_errs_gt)
geo_errs_pairzo = torch.tensor(geo_errs_pairzo)
geo_errs_dirichlet = torch.tensor(geo_errs_dirichlet)
geo_errs_median = torch.tensor(geo_errs_median)
geo_errs_median_filtered = torch.tensor(geo_errs_median_filtered)
geo_errs_median_filtered_noZo = torch.tensor(geo_errs_median_filtered_noZo)
geo_errs_dirichlet_pairzo = torch.tensor(geo_errs_dirichlet_pairzo)
geo_errs_median_pairzo = torch.tensor(geo_errs_median_pairzo)
   

In [7]:
# Least-squares fmaps fitted on the annotated correspondences of the current pair:
# lstsq(A, B) solves A @ X ~= B, here with A/B the eigenvectors restricted to corr.
evecs_first_corr = data['first']['evecs'][data['first']['corr']]
evecs_second_corr = data['second']['evecs'][data['second']['corr']]

C_gt_xy_lstsq = torch.linalg.lstsq(evecs_second_corr, evecs_first_corr).solution
C_gt_yx_lstsq = torch.linalg.lstsq(evecs_first_corr, evecs_second_corr).solution

# index range [l, h) of the fmap shown in every panel
l, h = 0, 64

fmap_panels = [
    Cxy_first_list[0][0].cpu(),
    Cxy_second_list[0][0].cpu(),
    Cyx_second_list[0][0].cpu(),
    C_gt_xy_lstsq,
]

fig, axs = plt.subplots(1, len(fmap_panels), figsize=(16, 4))
for panel_ax, panel_fmap in zip(axs, fmap_panels):
    plotting_utils.plot_Cxy(fig, panel_ax, panel_fmap,
                            'fmnet', l, h, show_grid=False, show_colorbar=False)
plt.show()

In [8]:
# Test-set sample inspected by the visualisation cells below.
# NOTE(review): point maps computed earlier (p2p_*) belong to whatever `data` was when they
# were computed; plotting them on this sample's meshes assumes it is the same pair — confirm.
SAMPLE_IDX = 9
data = test_dataset[SAMPLE_IDX]

In [None]:
scene.geometry.clear()

# Show `p2p_dirichlet_pairzo` between the first and the second shape of `data`.
shape_first, shape_second = data['first'], data['second']
plotting_utils.plot_p2p_map(
    scene,
    shape_first['verts'], shape_first['faces'],
    shape_second['verts'], shape_second['faces'],
    p2p_dirichlet_pairzo.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [None]:
# Unaveraged geodesic error of `p2p_dirichlet_pairzo` (one value per correspondence), x100.
geo_err_dirichlet_full = 100 * geodist_metric.calculate_geodesic_error(
    dist_x, corr_first, corr_second, p2p_dirichlet_pairzo, return_mean=False
)
print(geo_err_dirichlet_full)

In [None]:
# Number of correspondences whose error exceeds the mean error.
mean_geo_err_dirichlet = geo_err_dirichlet_full.mean()
(geo_err_dirichlet_full > mean_geo_err_dirichlet).sum()

In [None]:
# Colour the second shape: white everywhere, red where the Dirichlet-selected map has an
# above-mean geodesic error.
mesh_second = trimesh.Trimesh(data['second']['verts'].cpu(), data['second']['faces'].cpu())

# `geo_err_dirichlet_full` has one entry per annotated correspondence, not one per vertex.
# NOTE(review): entry i is assumed to belong to vertex corr_second[i] of the second shape
# (the map is evaluated as dist_x[corr_first, p2p[corr_second]]) — confirm against geodist_metric.
# Indexing the vertex colours with the raw error mask was only correct when corr_second is the
# identity permutation.
geo_err = torch.as_tensor(geo_err_dirichlet_full).cpu()
above_mean = geo_err > geo_err.mean()
high_err_verts = torch.as_tensor(corr_second).cpu()[above_mean].numpy()

vertex_colors = np.full((len(mesh_second.vertices), 4), 255, dtype=np.uint8)
vertex_colors[high_err_verts] = [255, 0, 0, 255]
mesh_second.visual.vertex_colors = vertex_colors

mesh_second.show()



In [16]:
scene.geometry.clear()

# Show `p2p_dirichlet_first` between the first shape of `data` and the template.
first_verts, first_faces = data['first']['verts'], data['first']['faces']
template_verts, template_faces = template_shape['verts'], template_shape['faces']

plotting_utils.plot_p2p_map(
    scene,
    first_verts, first_faces,
    template_verts, template_faces,
    p2p_dirichlet_first.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [21]:
scene.geometry.clear()

# Show `p2p_dirichlet_second` between the second shape of `data` and the template,
# colour gradient taken along axes 0 and 2.
shape_second = data['second']
plotting_utils.plot_p2p_map(
    scene,
    shape_second['verts'], shape_second['faces'],
    template_shape['verts'], template_shape['faces'],
    p2p_dirichlet_second.cpu(),
    axes_color_gradient=[0, 2],
    base_cmap='hsv',
)

scene.show()

In [18]:
scene.geometry.clear()

# Reverse direction of the previous plot: template first, then the second shape.
p2p_to_show = p2p_dirichlet_second_rev.cpu()
plotting_utils.plot_p2p_map(
    scene,
    template_shape['verts'], template_shape['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_to_show,
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [None]:
# Shape of the map obtained by composing the two template-based Dirichlet maps.
p2p_dirichlet_composed = p2p_dirichlet_first[p2p_dirichlet_second_rev]
p2p_dirichlet_composed.shape

In [62]:
scene.geometry.clear()

# Generated sample to inspect.
k = 17

# Compose p2p_est_first[k] with p2p_est_second_rev[k] into a first/second map.
p2p_composed_12 = p2p_est_first[k][p2p_est_second_rev[k]].cpu()

plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_composed_12,
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [63]:
scene.geometry.clear()

# Generated sample to inspect.
k = 17

# Compose p2p_est_second[k] with p2p_est_first_rev[k]: the mirror of the previous cell.
p2p_composed_21 = p2p_est_second[k][p2p_est_first_rev[k]].cpu()

plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    data['first']['verts'], data['first']['faces'],
    p2p_composed_21,
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [None]:
# Round-trip evaluation through the template: for every generated sample k, compose the two
# template-based maps into a direct map between the pair and measure its mean geodesic
# error on the annotated correspondences (x100, same scaling as the other errors).
corr_first = corr_first.cpu()
corr_second = corr_second.cpu()

err_12_list = []
err_21_list = []

for k in range(args.num_iters_avg):
    # NOTE(review): map_12 is assumed to have one entry per vertex of the second shape pointing
    # into the first (it is plotted first->second like p2p_dirichlet_pairzo, which is scored with
    # dist_x/corr_first/corr_second); map_21 is the mirror composition — confirm.
    map_12 = p2p_est_first[k][p2p_est_second_rev[k]].cpu()
    map_21 = p2p_est_second[k][p2p_est_first_rev[k]].cpu()

    # Compare the ground-truth partner corr_first[i] with where the map sends corr_second[i].
    # (Previously the full map was used without [corr_second], mismatching the correspondences.)
    err_12 = dist_x[corr_first, map_12[corr_second]].mean() * 100
    # map_21 runs in the opposite direction, so the shapes swap roles and the error is measured
    # with the second shape's distance matrix (as in the reversed get_geo_error call with dist_y).
    err_21 = dist_y[corr_second, map_21[corr_first]].mean() * 100

    err_12_list.append(err_12)
    err_21_list.append(err_21)

err_12_list = torch.stack(err_12_list)
err_21_list = torch.stack(err_21_list)

print(err_12_list)
print(err_21_list)

In [None]:
scene.geometry.clear()

# Put the template and the second shape side by side, each moved into its own
# principal-inertia frame so their orientations can be compared.
mesh_second = trimesh.Trimesh(vertices=data['second']['verts'].cpu(), faces=data['second']['faces'].cpu())
mesh_first = trimesh.Trimesh(vertices=template_shape['verts'].cpu(), faces=template_shape['faces'].cpu())

for aligned_mesh in (mesh_first, mesh_second):
    aligned_mesh.apply_transform(aligned_mesh.principal_inertia_transform)
    scene.add_geometry(aligned_mesh)

scene.show()

In [32]:
scene.geometry.clear()

# Generated sample whose functional maps are converted to point maps.
k = 24


def _signed_evecs(shape, signs):
    # First `num_evecs` eigenvectors of `shape`, multiplied by the sample's sign vector,
    # moved to `device`. A fresh tensor is built on every call.
    return (shape['evecs'][:, :num_evecs] * signs).to(device)


p2p_est_lstsq = fmap_util.fmap2pointmap(
    C12=Cyx_second_list[k][0].to(device),
    evecs_x=_signed_evecs(data['second'], evecs_second_signs_list_second[k]),
    evecs_y=_signed_evecs(template_shape, evecs_first_signs_list_second[k]),
).cpu()

p2p_est_rev_lstsq = fmap_util.fmap2pointmap(
    C12=Cxy_second_list[k][0].to(device),
    evecs_x=_signed_evecs(template_shape, evecs_first_signs_list_second[k]),
    evecs_y=_signed_evecs(data['second'], evecs_second_signs_list_second[k]),
).cpu()

plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    template_shape['verts'], template_shape['faces'],
    p2p_est_lstsq.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [45]:
# Per-template-vertex L1 distance between the second shape's signed eigenbasis pushed
# through Cyx^T (gathered with p2p_est_lstsq) and the template's signed eigenbasis.
signed_evecs_second = (data['second']['evecs'][:, :num_evecs] * evecs_second_signs_list_second[k]).to(device)
signed_evecs_template = (template_shape['evecs'][:, :num_evecs] * evecs_first_signs_list_second[k]).to(device)

projected_second = torch.matmul(signed_evecs_second, Cyx_second_list[k][0].to(device).t())
residual = projected_second[p2p_est_lstsq] - signed_evecs_template

projection_diff = residual.abs().sum(dim=1).cpu()



In [47]:
# Cumulative, density-normalised histogram (empirical CDF) of the per-template-vertex
# L1 projection difference computed in the previous cell.
fig, ax = plt.subplots()
ax.hist(projection_diff, bins=100, cumulative=True, density=True)
ax.set_xlabel('L1 projection difference per template vertex')
ax.set_ylabel('Cumulative fraction of vertices')
ax.set_title('Projection difference (cumulative histogram)')
plt.show()

In [20]:
scene.geometry.clear()

# Same map as before (`p2p_dirichlet_pairzo`, first shape then second), but with the
# colour gradient taken along axes 0 and 2.
first_mesh_data = (data['first']['verts'], data['first']['faces'])
second_mesh_data = (data['second']['verts'], data['second']['faces'])

plotting_utils.plot_p2p_map(
    scene,
    *first_mesh_data,
    *second_mesh_data,
    p2p_dirichlet_pairzo.cpu(),
    axes_color_gradient=[0, 2],
    base_cmap='hsv',
)

scene.show()

In [21]:
mass_first = data['first']['mass'].to(device)

# Call get_geo_error with the second shape's arguments before the first's, using dist_y and the
# first shape's mass as A2; only the returned point-to-point map is kept.
# NOTE(review): meaning of the positional `True` flag is defined by get_geo_error — confirm.
_, p2p_est_dirichlet_rev = get_geo_error(
    p2p_dirichlet_second,
    p2p_dirichlet_first,
    evecs_second,
    evecs_first,
    corr_second,
    corr_first,
    num_evecs,
    True,
    dist_y,
    A2=mass_first,
    return_p2p=True,
)

In [23]:
scene.geometry.clear()

# Show the map returned by the reversed get_geo_error call: second shape first, then the first.
reversed_p2p = p2p_est_dirichlet_rev.cpu()
plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    data['first']['verts'], data['first']['faces'],
    reversed_p2p,
    axes_color_gradient=[0, 2],
    base_cmap='hsv',
)

scene.show()

# Orientation preservation

In [50]:
import os
import sys

# Location of the pyFM fork; set the PYFM_FORK_DIR environment variable to override the
# machine-specific default instead of editing this cell.
PYFM_FORK_DIR = os.environ.get('PYFM_FORK_DIR', '/home/s94zalek_hpc/shape_matching/pyFM_fork')
# Guard so re-running the cell does not append the same entry to sys.path again.
if PYFM_FORK_DIR not in sys.path:
    sys.path.append(PYFM_FORK_DIR)

from pyFM_fork.pyFM.mesh import TriMesh
# NOTE(review): TriMesh and FunctionalMapping are imported through two different roots
# (`pyFM_fork.pyFM` vs `pyFM` via the appended path); if these resolve to two copies of the
# package, class-identity checks inside pyFM could fail — confirm.
from pyFM.functional import FunctionalMapping

N_EIGS_PROCESS = 150  # `k` passed to TriMesh.process for both meshes
N_EV_FMAP = 32        # per-shape basis size passed to FunctionalMapping.preprocess

mesh_template = TriMesh(template_shape['verts'].cpu(), template_shape['faces'].cpu()).process(k=N_EIGS_PROCESS, intrinsic=True)
mesh_second = TriMesh(data['second']['verts'].cpu(), data['second']['faces'].cpu()).process(k=N_EIGS_PROCESS, intrinsic=True)

fmapping = FunctionalMapping(mesh_template, mesh_second)
fmapping.preprocess(n_ev=(N_EV_FMAP, N_EV_FMAP))

In [55]:
def compute_orientation_op(mesh1, mesh2, descr1, descr2, k, reversing=False, normalize=False):
    """
    Compute orientation preserving or reversing operators associated to each descriptor.

    For every descriptor column f_i, the operator ``mesh.orientation_op(mesh.gradient(f_i))``
    is projected onto the first ``k`` eigenvectors Phi of each mesh as ``Phi^T A Op Phi``.

    Parameters
    ---------------------------------
    mesh1, mesh2 :
        Mesh objects exposing ``eigenvectors`` (n, >= k), mass matrix ``A``,
        ``gradient(f, normalize=...)`` and ``orientation_op(gradf)`` (e.g. a pyFM
        ``TriMesh`` after ``process``).
    descr1, descr2 : array-like
        Descriptors on mesh1 / mesh2, shape (n, n_descr), or (n,) for a single
        descriptor. numpy arrays and torch tensors (any device) are accepted.
    k : int
        Number of eigenvectors used for the reduced operators on both meshes.
    reversing : bool
        whether to return operators associated to orientation inversion instead
                of orientation preservation (return the opposite of the second operator)
    normalize : bool
        whether to normalize the gradient on each face. Might improve results
                according to the authors

    Returns
    ---------------------------------
    list_op : list
        (n_descr,) where term i contains (D1,D2), both of size (k,k), which
        represent operators supposed to commute.

    Raises
    ---------------------------------
    ValueError
        if descr1 and descr2 do not have the same number of descriptor columns.
    """

    def _as_columns(descr):
        # pyFM works on numpy arrays; torch tensors (e.g. vertex coordinates) are
        # detached and copied to host memory first.
        if isinstance(descr, torch.Tensor):
            descr = descr.detach().cpu().numpy()
        descr = np.asarray(descr)
        # A single descriptor given as a vector is treated as one column.
        return descr[:, None] if descr.ndim == 1 else descr

    descr1 = _as_columns(descr1)
    descr2 = _as_columns(descr2)

    n_descr = descr1.shape[1]
    if descr2.shape[1] != n_descr:
        raise ValueError(
            f'descr1 and descr2 must have the same number of descriptors, '
            f'got {n_descr} and {descr2.shape[1]}'
        )

    evecs1 = mesh1.eigenvectors[:, :k]
    evecs2 = mesh2.eigenvectors[:, :k]

    # Phi^T A acts as the inverse of Phi for an A-orthonormal basis (projection onto the basis).
    pinv1 = evecs1.T @ mesh1.A  # (k, n1)
    pinv2 = evecs2.T @ mesh2.A  # (k, n2)

    list_op = []
    for i in range(n_descr):
        grad1 = mesh1.gradient(descr1[:, i], normalize=normalize)
        grad2 = mesh2.gradient(descr2[:, i], normalize=normalize)

        op1 = pinv1 @ mesh1.orientation_op(grad1) @ evecs1
        op2 = pinv2 @ mesh2.orientation_op(grad2) @ evecs2

        # Orientation reversal flips the sign of the operator on the second mesh only.
        list_op.append((op1, -op2 if reversing else op2))

    return list_op

In [73]:
# Orientation operators using the vertex coordinates of each shape as descriptors
# (one operator pair per coordinate column), reduced to a 32-dimensional basis.
orientation_op = compute_orientation_op(
    mesh1=mesh_template,
    mesh2=mesh_second,
    descr1=template_shape['verts'],
    descr2=data['second']['verts'],
    k=32,
)

In [76]:
def evaluate_fmap(Cxy, orientation_op):
    """Sum, over all operator pairs (D1, D2), of the squared entries of Cxy @ D1 - D2 @ Cxy.

    A lower value means Cxy better commutes with the given orientation operators.
    """
    return sum(
        ((Cxy @ pair[0] - pair[1] @ Cxy) ** 2).sum()
        for pair in orientation_op
    )

# Orientation-commutativity score of every generated functional map (lower = more consistent).
fmap_scores = torch.tensor([
    evaluate_fmap(Cxy_second_list[j][0].cpu().numpy(), orientation_op)
    for j in range(Cxy_second_list.shape[0])
])
print(fmap_scores)
print(geo_err_est_pairzo)
print(geo_err_est_pairzo)

In [77]:
# Inspect `dirichlet_est_second_rev` (computed outside this chunk; presumably one Dirichlet
# energy per generated sample) — the next cell selects its 5 smallest entries.
dirichlet_est_second_rev

In [78]:
# Boolean mask over samples marking the 5 smallest entries of `dirichlet_est_second_rev`.
n_lowest = 5
_, idx = dirichlet_est_second_rev.topk(n_lowest, largest=False)

mask = torch.zeros(len(dirichlet_est_second_rev), dtype=torch.bool)
mask[idx] = True
mask

In [None]:
# Geodesic error vs. orientation score per sample; samples selected by `mask` are red.
geo_err_sorted_idx = torch.argsort(geo_err_est_pairzo)
colors = ['red' if m else 'blue' for m in mask[geo_err_sorted_idx]]

fig, ax = plt.subplots()
ax.scatter(fmap_scores[geo_err_sorted_idx], geo_err_est_pairzo[geo_err_sorted_idx], c=colors)
ax.set_xscale('log')
ax.set_ylabel('Geo error')
ax.set_xlabel('Fmap score')
plt.show()

In [None]:
scene.geometry.clear()

# Sample whose functional map has the lowest orientation score in `fmap_scores`.
# Computed here so this cell does not depend on a `min_score` variable left over in the
# kernel (it raised NameError on a fresh kernel).
# NOTE(review): fmap_scores is built from Cxy_second_list while p2p_est_pairzo holds the
# pairwise maps; this assumes both are indexed by the same sample index — confirm.
min_score = torch.argmin(fmap_scores)

plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_est_pairzo[min_score].cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

In [None]:
# Inspect the stacked functional maps (indexed as Cxy_second_list[j][0] in the cells above).
# Cyx_second_list = C_sampled_second_list
Cxy_second_list
