In [31]:
import trimesh

# Shared trimesh scene; later visualization cells clear it with
# `scene.geometry.clear()` and re-populate it.
scene = trimesh.Scene()

In [None]:
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
from utils.shape_util import compute_geodesic_distmat
from my_code.diffusion_training_sign_corr.test.test_diffusion_cond import select_p2p_map_dirichlet, log_to_database, parse_args
import accelerate

# Forget tqdm bars registered by earlier executions of this kernel
# (presumably to avoid duplicated / garbled progress bars on re-runs).
tqdm._instances.clear()

In [None]:
import importlib
# Re-import zoomout_custom so edits to the module on disk are picked up
# without restarting the kernel.
importlib.reload(zoomout_custom)

In [3]:
def get_geo_error(
    p2p_first, p2p_second,
    evecs_first, evecs_second,
    corr_first, corr_second,
    num_evecs, apply_zoomout,
    dist_x,
    regularized,
    evecs_trans_first=None, evecs_trans_second=None,
    evals_first=None, evals_second=None,
    return_p2p=False, return_Cxy=False,
    A2=None,
    fmap_net=None,
    ):
    """Geodesic error (in percent) of the map induced by a set of correspondences.

    A functional map ``Cxy`` is estimated from the matched rows
    ``p2p_first`` / ``p2p_second``, optionally refined with ZoomOut, converted
    to a point-to-point map and scored against the ground-truth correspondences.

    Args:
        p2p_first, p2p_second: equal-length index tensors; entry ``j`` pairs
            vertex ``p2p_first[j]`` of the first shape with vertex
            ``p2p_second[j]`` of the second shape.
        evecs_first, evecs_second: eigenvector matrices (vertices x eigenvectors).
            Only the first ``num_evecs`` columns are used unless ``apply_zoomout``.
        corr_first, corr_second: ground-truth correspondences for the metric.
        num_evecs: spectral size of the initial functional map.
        apply_zoomout: if True, refine ``Cxy`` with ``zoomout_custom.zoomout``
            (step 1) for ``evecs_first.shape[1] - num_evecs`` iterations and
            convert with the full eigenbases.
        dist_x: geodesic distance matrix passed to the metric.
        regularized: if True, estimate ``Cxy`` with ``fmap_net`` (regularized
            solve) instead of plain least squares. Requires
            ``evecs_trans_first/second`` and ``evals_first/second``.
        evecs_trans_first, evecs_trans_second: (pseudo-)inverse eigenbases,
            indexed by column with the p2p indices; only used if ``regularized``.
        evals_first, evals_second: eigenvalues; only used if ``regularized``.
        return_p2p, return_Cxy: additionally return the p2p map / the fmap.
        A2: forwarded to ``zoomout_custom.zoomout``.
        fmap_net: object exposing ``compute_functional_map``. Defaults to the
            notebook-level ``fmnet`` when None (backward-compatible behavior).

    Returns:
        ``geo_err * 100`` when neither ``return_p2p`` nor ``return_Cxy`` is
        set; otherwise a list ``[geo_err * 100, p2p (if requested),
        Cxy (if requested)]``.

    Raises:
        ValueError: if ``regularized`` is True but an eigenbasis transpose or
            an eigenvalue vector is missing.
    """
    if regularized:
        missing = [
            name for name, value in (
                ('evecs_trans_first', evecs_trans_first),
                ('evecs_trans_second', evecs_trans_second),
                ('evals_first', evals_first),
                ('evals_second', evals_second),
            )
            if value is None
        ]
        if missing:
            raise ValueError(
                f'regularized=True requires {", ".join(missing)}'
            )
        if fmap_net is None:
            # fall back to the notebook-global regularized FM solver
            fmap_net = fmnet
        Cxy = fmap_net.compute_functional_map(
            evecs_trans_second[:, p2p_second].unsqueeze(0),
            evecs_trans_first[:, p2p_first].unsqueeze(0),
            evals_second.unsqueeze(0),
            evals_first.unsqueeze(0),
        )[0].T
    else:
        # plain least squares on the matched rows:
        # evecs_second[p2p_second] @ Cxy ≈ evecs_first[p2p_first]
        Cxy = torch.linalg.lstsq(
            evecs_second[:, :num_evecs][p2p_second],
            evecs_first[:, :num_evecs][p2p_first]
            ).solution

    if apply_zoomout:
        # grow the map one eigenvector per iteration, presumably up to the
        # full eigenbasis size
        Cxy = zoomout_custom.zoomout(
            FM_12=Cxy,
            evects1=evecs_first,
            evects2=evecs_second,
            nit=evecs_first.shape[1] - num_evecs, step=1,
            A2=A2
        )
        num_evecs = evecs_first.shape[1]

    p2p = fmap_util.fmap2pointmap(
        C12=Cxy,
        evecs_x=evecs_first[:, :num_evecs],
        evecs_y=evecs_second[:, :num_evecs],
        ).cpu()

    geo_err = geodist_metric.calculate_geodesic_error(
        dist_x, corr_first.cpu(), corr_second.cpu(), p2p, return_mean=True
    )
    geo_err_percent = geo_err * 100

    if not (return_p2p or return_Cxy):
        return geo_err_percent

    payload = [geo_err_percent]
    if return_p2p:
        payload.append(p2p)
    if return_Cxy:
        payload.append(Cxy)
    return payload


def filter_p2p_by_confidence(
    p2p_first, p2p_second,
    confidence_scores_first, confidence_scores_second,
    confidence_threshold, log_file_name,
    min_valid_fraction=0.05, threshold_step=0.05,
    ):
    """Keep only correspondences whose scores on both shapes are below a threshold.

    Pair ``j`` is kept when ``confidence_scores_first[j]`` and
    ``confidence_scores_second[j]`` are both strictly below the threshold
    (i.e. the code treats lower scores as better). If fewer than
    ``min_valid_fraction`` of the pairs survive, the threshold is raised by
    ``threshold_step`` until enough pairs survive, or until no further pair
    can become valid (pairs with a NaN or +inf score never pass).
    Progress messages are printed and appended to ``log_file_name``.

    Args:
        p2p_first, p2p_second: equal-length index tensors of matched vertices.
        confidence_scores_first, confidence_scores_second: per-pair scores
            (tensors or array-likes of the same length as the mask).
        confidence_threshold: initial threshold.
        log_file_name: path of the log file (opened in append mode).
        min_valid_fraction: minimum fraction of pairs to keep (default 5%).
        threshold_step: increment applied while too few pairs survive.

    Returns:
        ``(p2p_first[valid], p2p_second[valid])``.

    Raises:
        ValueError: if the p2p maps have different lengths or no pair survives.
    """
    if p2p_first.shape[0] != p2p_second.shape[0]:
        raise ValueError(
            f'p2p maps differ in length: {p2p_first.shape[0]} vs {p2p_second.shape[0]}'
        )

    scores_first = torch.as_tensor(confidence_scores_first)
    scores_second = torch.as_tensor(confidence_scores_second)

    def _valid_mask(threshold):
        # select points with both confidence scores below threshold
        return (scores_first < threshold) & (scores_second < threshold)

    # A pair with a NaN or +inf score can never pass, whatever the threshold;
    # without this guard the raising loop below would never terminate.
    def _blocked(scores):
        return torch.isnan(scores) | (scores == float('inf'))
    can_pass = ~(_blocked(scores_first) | _blocked(scores_second))

    valid_points = _valid_mask(confidence_threshold)
    min_valid = min_valid_fraction * len(valid_points)

    with open(log_file_name, 'a') as log_file:

        def log(message):
            print(message)
            log_file.write(f'{message}\n')

        while (valid_points.sum() < min_valid
               and not bool((valid_points | ~can_pass).all())):
            confidence_threshold += threshold_step
            valid_points = _valid_mask(confidence_threshold)
            log(f'Increasing confidence threshold: {confidence_threshold}')

        n_valid = int(valid_points.sum())
        log(f'Valid points: {n_valid}')

    if n_valid == 0:
        raise ValueError('No valid points found')

    return p2p_first[valid_points], p2p_second[valid_points]


In [28]:
class Arguments:
    """Notebook stand-in for the evaluation script's command-line arguments.

    The defaults reproduce the original hard-coded settings. Any attribute can
    be overridden by keyword, e.g. ``Arguments(split='train')``; unknown names
    raise ``TypeError``.
    """

    def __init__(self, **overrides):
        self.experiment_name='single_64_2-2ev_64-128-128_remeshed_fixed'
        # sub-folder of <experiment>/checkpoints to load
        self.checkpoint_name='epoch_99'

        self.dataset_name='SHREC19_r_pair'
        self.split='test'

        # number of sign predictions / diffusion samples drawn per shape
        self.num_iters_avg=100
        self.num_samples_median=8
        # pairs are kept only if both scores are strictly below this value
        self.confidence_threshold=0.3

        unknown = set(overrides) - set(vars(self))
        if unknown:
            raise TypeError(f'Unknown argument(s): {", ".join(sorted(unknown))}')
        for name, value in overrides.items():
            setattr(self, name, value)

In [None]:
# Evaluation setup: read the experiment config, load the diffusion model and
# the sign-correction network, build the DDPM scheduler and the datasets.
args = Arguments()

# configuration
experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

### config
# NOTE(review): absolute, user-specific path; consider making the base folder configurable.
exp_base_folder = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    # FullLoader rather than safe_load: assumes config.yaml is trusted (our own experiment folder)
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
model = DiagConditionedUnet(config["model_params"]).to('cuda')

# accelerate-trained runs store <checkpoint>/model.safetensors; otherwise the
# checkpoint path itself is loaded with torch.load as a state dict.
if "accelerate" in config and config["accelerate"]:
    accelerate.load_checkpoint_in_model(model, f"{exp_base_folder}/checkpoints/{checkpoint_name}/model.safetensors")
else:
    model.load_state_dict(torch.load(f"{exp_base_folder}/checkpoints/{checkpoint_name}"))


# likely redundant (the model was built on 'cuda'); kept in case a loader moved weights
model = model.to('cuda')

### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    ).to('cuda')
    
# weights file: <net_path>/<n_iter>.pth
sign_corr_net.load_state_dict(torch.load(
        f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth'
        ))


### sample the model
# Reverse-diffusion scheduler; these settings presumably must match the ones
# used for training — TODO confirm against the training config.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True) 


### test dataset
dataset_name = args.dataset_name
split = args.split

# single_dataset: individual shapes; test_dataset: shape pairs (accessed via
# ['first'] / ['second'] below). The meaning of the positional 200 is not
# visible here — TODO confirm in data_loading.get_val_dataset.
single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True
    )
# sign net reuses the dataset's (presumably Laplace-Beltrami) cache directory
sign_corr_net.cache_dir = single_dataset.lb_cache_dir


# spectral size of the diffusion samples (fmaps are sample_size x sample_size)
num_evecs = config["model_params"]["sample_size"]


##########################################
# Template
##########################################

# Folder holding the template mesh and its correspondence file for the
# template type the sign net was trained with.
template_dir = f'/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/{config["sign_net"]["template_type"]}'

# the `- 1` shifts corr.txt entries to 0-based vertex indices
template_corr_indices = np.loadtxt(f'{template_dir}/corr.txt', dtype=np.int32) - 1

template_shape = template_dataset.get_template(
    num_evecs=single_dataset.num_evecs,
    centering='bbox',
    template_path=f'{template_dir}/template.off',
    template_corr=template_corr_indices,
)

##########################################
# Logging
##########################################

# Output locations for this evaluation run
log_dir = f'{exp_base_folder}/eval/{checkpoint_name}/{dataset_name}-{split}/no_smoothing'
fig_dir = f'{log_dir}/figs'
log_file_name = f'{log_dir}/log.txt'

for directory in (log_dir, fig_dir):
    os.makedirs(directory, exist_ok=True)


##########################################
# Template stage
##########################################

# For every shape: predict eigenvector signs with the sign net
# (args.num_iters_avg times), sample one template<->shape functional map per
# prediction from the diffusion model, convert each sample to a p2p map, then
# select consensus maps (Dirichlet / median) with confidence scores.
# Results go to single_dataset.additional_data[i].

# Full run: tqdm(range(len(single_dataset)), desc='Calculating fmaps to template')
data_range = tqdm([0, 43])

device = 'cuda' if torch.cuda.is_available() else 'cpu'
k_eig = config["sign_net"]["net_params"]["k_eig"]
evecs_per_support = config["sign_net"]["evecs_per_support"]
sample_size = config["model_params"]["sample_size"]


def _sign_net_operators(shape):
    """Precomputed operators of `shape`, batched (leading dim 1) for predict_sign_change."""
    return dict(
        mass=shape['mass'].unsqueeze(0), L=shape['L'].unsqueeze(0),
        evals=shape['evals'][:k_eig].unsqueeze(0),
        evecs=shape['evecs'][:, :k_eig].unsqueeze(0),
        gradX=shape['gradX'].unsqueeze(0), gradY=shape['gradY'].unsqueeze(0),
    )


def _sign_correct(evecs, sign_pred):
    """Flip eigenvector signs as predicted; return (corrected, column-normalized corrected)."""
    corrected = evecs.cpu()[0] * torch.sign(sign_pred).cpu()
    return corrected, corrected / torch.norm(corrected, dim=0, keepdim=True)


def _evecs_conditioning(support_vector_norm, mass_mat, evecs_corrected_norm):
    """Project normalized, sign-corrected evecs onto the predicted support vectors."""
    support_t = support_vector_norm[0].cpu().transpose(0, 1)
    if mass_mat is None:
        return support_t @ evecs_corrected_norm
    # mass-weighted variant: rows of support^T @ M are L2-normalized first
    return torch.nn.functional.normalize(
        support_t @ mass_mat[0].cpu(), p=2, dim=1
    ) @ evecs_corrected_norm


# the template is the same for every shape, so its operators are built once
template_ops = _sign_net_operators(template_shape)

for i in data_range:

    data = single_dataset[i]
    data_ops = _sign_net_operators(data)

    verts_first = template_shape['verts'].unsqueeze(0).to(device)
    verts_second = data['verts'].unsqueeze(0).to(device)

    faces_first = template_shape['faces'].unsqueeze(0).to(device)
    faces_second = data['faces'].unsqueeze(0).to(device)

    evecs_first = template_shape['evecs'][:, :num_evecs].unsqueeze(0).to(device)
    evecs_second = data['evecs'][:, :num_evecs].unsqueeze(0).to(device)

    evals_first = template_shape['evals'][:num_evecs]
    evals_second = data['evals'][:num_evecs]

    # diagonal mass matrices; also reused for the mass-weighted conditioning
    if config["sign_net"]["with_mass"]:
        mass_mat_first = torch.diag_embed(template_shape['mass'].unsqueeze(0)).to(device)
        mass_mat_second = torch.diag_embed(data['mass'].unsqueeze(0)).to(device)
    else:
        mass_mat_first = None
        mass_mat_second = None

    evecs_cond_first_list = []
    evecs_cond_second_list = []
    evecs_first_corrected_list = []
    evecs_second_corrected_list = []

    # One sign prediction per diffusion sample (repeated presumably because the
    # sign net output can vary between calls — TODO confirm).
    for _ in range(args.num_iters_avg):

        with torch.no_grad():
            sign_pred_first, support_vector_norm_first, _ = predict_sign_change(
                sign_corr_net, verts_first, faces_first, evecs_first,
                mass_mat=mass_mat_first, input_type=sign_corr_net.input_type,
                evecs_per_support=evecs_per_support,
                **template_ops
                )
            sign_pred_second, support_vector_norm_second, _ = predict_sign_change(
                sign_corr_net, verts_second, faces_second, evecs_second,
                mass_mat=mass_mat_second, input_type=sign_corr_net.input_type,
                evecs_per_support=evecs_per_support,
                **data_ops
                )

        evecs_first_corrected, evecs_first_corrected_norm = _sign_correct(evecs_first, sign_pred_first)
        evecs_second_corrected, evecs_second_corrected_norm = _sign_correct(evecs_second, sign_pred_second)

        evecs_cond_first_list.append(
            _evecs_conditioning(support_vector_norm_first, mass_mat_first, evecs_first_corrected_norm))
        evecs_cond_second_list.append(
            _evecs_conditioning(support_vector_norm_second, mass_mat_second, evecs_second_corrected_norm))
        evecs_first_corrected_list.append(evecs_first_corrected)
        evecs_second_corrected_list.append(evecs_second_corrected)

    evecs_cond_first_list = torch.stack(evecs_cond_first_list)
    evecs_cond_second_list = torch.stack(evecs_cond_second_list)
    evecs_first_corrected_list = torch.stack(evecs_first_corrected_list)
    evecs_second_corrected_list = torch.stack(evecs_second_corrected_list)

    ###############################################
    # Conditioning: channel 0 = template, channel 1 = shape
    ###############################################

    conditioning = torch.cat(
        (evecs_cond_first_list.unsqueeze(1), evecs_cond_second_list.unsqueeze(1)),
        1)

    ###############################################
    # Sample the model
    ###############################################

    # The DDPM reverse process starts from standard Gaussian noise x_T ~ N(0, I)
    # (the original used torch.rand, i.e. uniform [0, 1) noise).
    x_sampled = torch.randn(args.num_iters_avg, 1, sample_size, sample_size).to(device)
    y = conditioning.to(device)

    for t in noise_scheduler.timesteps:
        with torch.no_grad():
            residual = model(x_sampled, t, conditioning=y).sample
        x_sampled = noise_scheduler.step(residual, t, x_sampled).prev_sample

    ###############################################
    # Sampled fmaps -> p2p maps
    ###############################################

    p2p_est_samples = []
    for k in range(args.num_iters_avg):
        Cyx_est_k = x_sampled[k][0]
        p2p_est_k = fmap_util.fmap2pointmap(
            C12=Cyx_est_k,
            evecs_x=evecs_second_corrected_list[k].to(device),
            evecs_y=evecs_first_corrected_list[k].to(device),
            ).cpu()
        p2p_est_samples.append(p2p_est_k)

    single_dataset.additional_data[i]['p2p_est'] = torch.stack(p2p_est_samples)

    ##########################################################
    # p2p map selection
    ##########################################################

    dist_second = torch.tensor(
        compute_geodesic_distmat(
            verts_second[0].cpu().numpy(),
            faces_second[0].cpu().numpy())
    )

    p2p_dirichlet, p2p_median, confidence_scores, dirichlet_energy_list = select_p2p_map_dirichlet(
        single_dataset.additional_data[i]['p2p_est'],
        verts_second[0].cpu(),
        template_shape['L'],
        dist_second,
        num_samples_median=args.num_samples_median
        )

    single_dataset.additional_data[i]['p2p_dirichlet'] = p2p_dirichlet
    single_dataset.additional_data[i]['p2p_median'] = p2p_median
    single_dataset.additional_data[i]['confidence_scores_median'] = confidence_scores
    single_dataset.additional_data[i]['geo_dist'] = dist_second

    print(f'{i}, mean cs {confidence_scores.mean():.3f}, median {confidence_scores.median():.3f}')

    

In [None]:
# The template-stage loop does not store 'evecs_corrected_list', so on a fresh
# kernel that key is absent; fall back to the stacked per-sample p2p
# estimates ('p2p_est'), which the loop always stores.
sample_43 = single_dataset[43]
(sample_43['evecs_corrected_list'] if 'evecs_corrected_list' in sample_43 else sample_43['p2p_est']).shape

In [None]:
import matplotlib.pyplot as plt

# Distribution of the median-map confidence scores of shape 0
# (in the confidence filter, lower scores are kept)
fig, ax = plt.subplots()
ax.hist(single_dataset[0]['confidence_scores_median'], bins=20)
ax.set(title='Median p2p confidence scores (shape 0)',
       xlabel='confidence score', ylabel='count')
plt.show()

In [None]:
# import pickle

# for entry_name in ['p2p_dirichlet', 'p2p_median', 'confidence_scores_median']:
#     entry_list = [single_dataset.additional_data[i][entry_name] for i in range(len(single_dataset))]
#     entry_list = torch.stack(entry_list)
    
#     torch.save(entry_list, f'{log_dir}/{entry_name}.pt')    

In [None]:
# Load the pair once: the dataset is built with preload=False, so each
# indexing presumably re-reads both shapes.
pair_1 = test_dataset[1]
pair_1['first']['id'], pair_1['second']['id']

In [None]:
# Shape check: mass of shape 0 as a column vector, the form used for
# broadcasting in the `mass_second[..., None] * evecs` product further below.
single_dataset[0]['mass'][:, None].shape

In [None]:
##########################################
# Pairwise stage
##########################################
    
# Let the pair dataset read shapes from single_dataset, whose additional_data
# holds the template-stage results ('p2p_est', 'p2p_median', 'geo_dist', ...)
# read below — presumably how test_dataset fills 'first'/'second'; TODO confirm.
test_dataset.dataset = single_dataset

geo_errs_gt = []
geo_errs_corr_gt = []
geo_errs_pairzo = []
geo_errs_dirichlet = []
geo_errs_median = []
geo_errs_median_filtered = []

# Full run: tqdm(range(len(test_dataset)), desc='Calculating pair fmaps')
# The shapes of every evaluated pair must have been processed in the template
# stage, since their 'p2p_*' and 'geo_dist' entries are read below.
data_range_pair = tqdm([1])

device = 'cuda' if torch.cuda.is_available() else 'cpu'

for i in data_range_pair:

    data = test_dataset[i]

    verts_first = data['first']['verts'].to(device)
    verts_second = data['second']['verts'].to(device)

    faces_first = data['first']['faces'].to(device)
    faces_second = data['second']['faces'].to(device)

    evecs_first = data['first']['evecs'].to(device)
    evecs_second = data['second']['evecs'].to(device)

    evals_first = data['first']['evals'][:num_evecs]
    evals_second = data['second']['evals'][:num_evecs]

    corr_first = data['first']['corr'].to(device)
    corr_second = data['second']['corr'].to(device)

    ###############################################
    # Functional maps
    ###############################################

    # NOTE(review): the sign-corrected 'evecs_zo' bases are no longer stored by
    # the template stage, so the "_zo" bases are the plain eigenbases and
    # geo_err_corr_gt below should equal geo_err_gt.
    evecs_first_zo = data['first']['evecs'].to(device)
    evecs_second_zo = data['second']['evecs'].to(device)

    p2p_est_first = data['first']['p2p_est'].to(device)
    p2p_est_second = data['second']['p2p_est'].to(device)

    p2p_dirichlet_first = data['first']['p2p_dirichlet'].to(device)
    p2p_dirichlet_second = data['second']['p2p_dirichlet'].to(device)

    p2p_median_first = data['first']['p2p_median'].to(device)
    p2p_median_second = data['second']['p2p_median'].to(device)

    # geodesic distances cached by the template stage
    dist_x = data['first']['geo_dist']

    mass_second = data['second']['mass'].to(device)

    ###############################################
    # Geodesic errors
    ###############################################

    # GT geo error (maps built from the ground-truth correspondences)
    geo_err_gt = get_geo_error(
        corr_first, corr_second,
        evecs_first, evecs_second,
        corr_first, corr_second,
        num_evecs, False,
        dist_x, regularized=False,
        A2=mass_second
        )
    geo_err_corr_gt = get_geo_error(
        corr_first, corr_second,
        evecs_first_zo, evecs_second_zo,
        corr_first, corr_second,
        num_evecs, False,
        dist_x, regularized=False,
        A2=mass_second
        )

    # per-sample geo error with ZoomOut
    geo_err_est_pairzo = []
    for k in range(args.num_iters_avg):
        geo_err_est_pairzo.append(
            get_geo_error(
                p2p_est_first[k], p2p_est_second[k],
                evecs_first_zo, evecs_second_zo,
                corr_first, corr_second,
                num_evecs, True,
                dist_x, regularized=False,
                A2=mass_second
                ))
    geo_err_est_pairzo = torch.tensor(geo_err_est_pairzo)

    # dirichlet geo error
    geo_err_est_dirichlet = get_geo_error(
        p2p_dirichlet_first, p2p_dirichlet_second,
        evecs_first_zo, evecs_second_zo,
        corr_first, corr_second,
        num_evecs, True,
        dist_x, regularized=False,
        A2=mass_second
        )

    # median geo error
    geo_err_est_median = get_geo_error(
        p2p_median_first, p2p_median_second,
        evecs_first_zo, evecs_second_zo,
        corr_first, corr_second,
        num_evecs, True,
        dist_x, regularized=False,
        A2=mass_second
        )

    # median geo error restricted to confident correspondences
    p2p_median_first_filtered, p2p_median_second_filtered = filter_p2p_by_confidence(
        p2p_median_first, p2p_median_second,
        data['first']['confidence_scores_median'], data['second']['confidence_scores_median'],
        args.confidence_threshold, log_file_name
        )
    geo_err_est_median_filtered = get_geo_error(
        p2p_median_first_filtered, p2p_median_second_filtered,
        evecs_first_zo, evecs_second_zo,
        corr_first, corr_second,
        num_evecs, True,
        dist_x, regularized=False,
        A2=mass_second
        )

    print(f'{i}: {data["first"]["id"]}, {data["second"]["id"]}')
    print(f'Geo error GT: {geo_err_gt:.2f}')
    print(f'Geo error est pairzo: {geo_err_est_pairzo}')
    print(f'Geo error est pairzo mean: {geo_err_est_pairzo.mean():.2f}')
    print(f'Geo error est dirichlet: {geo_err_est_dirichlet:.2f}')
    print(f'Geo error est median: {geo_err_est_median:.2f}')
    print(f'Geo error est median filtered: {geo_err_est_median_filtered:.2f}')
    print('-----------------------------------\n')

    geo_errs_gt.append(geo_err_gt)
    geo_errs_corr_gt.append(geo_err_corr_gt)
    geo_errs_pairzo.append(geo_err_est_pairzo.mean())
    geo_errs_dirichlet.append(geo_err_est_dirichlet)
    geo_errs_median.append(geo_err_est_median)
    # previously commented out, which left geo_errs_median_filtered empty
    geo_errs_median_filtered.append(geo_err_est_median_filtered)


geo_errs_gt = torch.tensor(geo_errs_gt)
geo_errs_corr_gt = torch.tensor(geo_errs_corr_gt)
geo_errs_pairzo = torch.tensor(geo_errs_pairzo)
geo_errs_dirichlet = torch.tensor(geo_errs_dirichlet)
geo_errs_median = torch.tensor(geo_errs_median)
geo_errs_median_filtered = torch.tensor(geo_errs_median_filtered)
    
# data = [(
#     args.experiment_name,
#     args.checkpoint_name, 
#     'no', 
#     args.dataset_name,
#     args.split, 
#     # dirichlet
#     geo_errs_dirichlet.mean().item(),
#     # median p2p
#     geo_errs_median.mean().item(),
#     # zoomout
#     geo_errs_pairzo.mean().item(), geo_errs_pairzo.median().item(),
#     # pred
#     0, 0
#     ),]

# data = {
#     'experiment_name': args.experiment_name,
#     'checkpoint_name': args.checkpoint_name, 
#     'smoothing': 'no', 
#     'dataset_name': args.dataset_name,
#     'split': args.split,
    
#     'confidence_filtered': geo_errs_median_filtered.mean().item(),
    
#     'dirichlet': geo_errs_dirichlet.mean().item(),
#     'p2p_median': geo_errs_median.mean().item(),
    
#     'zoomout_mean': geo_errs_pairzo.mean().item(),
#     'zoomout_median': geo_errs_pairzo.median().item(),
#     }


In [None]:
import my_code.utils.plotting_utils as plotting_utils

# Sample with the lowest ZoomOut geodesic error (selected with GT — debugging only)
k = geo_err_est_pairzo.argmin()

# NOTE(review): the sample is selected on its ZoomOut error but re-evaluated
# here WITHOUT ZoomOut, so the printed error differs from the selection score.
geo_err_k, p2p_k = get_geo_error(
    p2p_first=p2p_est_first[k], p2p_second=p2p_est_second[k],
    evecs_first=evecs_first_zo, evecs_second=evecs_second_zo,
    corr_first=corr_first, corr_second=corr_second,
    num_evecs=num_evecs, apply_zoomout=False,
    dist_x=dist_x, regularized=False,
    A2=mass_second,
    return_p2p=True,
)
print(geo_err_k)

# Redraw the shared scene with the selected p2p map
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_k,
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [62]:
import trimesh

# interpolate an array of colors
import numpy as np

def plot_double_p2p_map(
    scene,
    verts_first, verts_second,
    faces_first, faces_second,
    p2p_first, p2p_second,
    offset=(1, 0, 0),
    cmap='jet',
):
    """Add two meshes to ``scene`` so that matched vertices share a colour.

    Vertex ``p2p_first[j]`` of the first mesh and vertex ``p2p_second[j]`` of
    the second mesh receive the same colour, obtained by mapping ``x + y`` of
    the first mesh's matched vertex through ``cmap``. Unmatched vertices stay
    white.

    Args:
        scene: ``trimesh.Scene`` the meshes are added to.
        verts_first, verts_second: vertex tensors of shape (V, 3).
        faces_first, faces_second: face index tensors.
        p2p_first, p2p_second: equal-length index tensors of matched vertices.
        offset: translation added to the second mesh so both are visible.
        cmap: colormap name passed to ``trimesh.visual.color.interpolate``.

    Returns:
        The same ``scene`` (for chaining).
    """
    idx_first = p2p_first.cpu().numpy()
    idx_second = p2p_second.cpu().numpy()
    verts_first_np = verts_first.cpu().numpy()
    verts_second_np = verts_second.cpu().numpy()

    # colour coordinate: x + y of the matched vertices on the first shape
    color_values = (verts_first_np[:, 1] + verts_first_np[:, 0])[idx_first]
    colors = trimesh.visual.color.interpolate(color_values, cmap)

    # process=False keeps vertex order and count unchanged; the default
    # processing may merge duplicate vertices, invalidating the p2p indices.
    mesh_1 = trimesh.Trimesh(
        vertices=verts_first_np, faces=faces_first.cpu().numpy(), process=False)
    mesh_2 = trimesh.Trimesh(
        vertices=verts_second_np + np.asarray(offset), faces=faces_second.cpu().numpy(), process=False)

    # Build the full RGBA array first (opaque white), then assign it once,
    # instead of mutating the array returned by the vertex_colors property.
    for mesh, idx in ((mesh_1, idx_first), (mesh_2, idx_second)):
        vertex_colors = np.full((len(mesh.vertices), 4), 255, dtype=np.uint8)
        vertex_colors[idx] = colors
        mesh.visual.vertex_colors = vertex_colors

    scene.add_geometry(mesh_1)
    scene.add_geometry(mesh_2)

    return scene

In [None]:
import trimesh

# Fresh scene showing the confidence-filtered median correspondences
scene = trimesh.Scene()
scene = plot_double_p2p_map(
    scene=scene,
    verts_first=verts_first, verts_second=verts_second,
    faces_first=faces_first, faces_second=faces_second,
    p2p_first=p2p_median_first_filtered, p2p_second=p2p_median_second_filtered,
)
scene.show()

# Fmap network

In [22]:
# import networks.fmap_network as fmap_network

# fmnet = fmap_network.RegularizedFMNet()

In [8]:
import networks.fmap_network as fmap_network
import torch.nn as nn


class RegularizedFMNet(nn.Module):
    """Regularized functional-map solver (the DPFM formulation).

    Row ``i`` of the map solves
    ``(A A^T + lmbda * diag(D[:, i, :])) c_i = (B A^T)[i]^T``,
    i.e. a mask-regularized least-squares fit of ``C A ≈ B``, where ``D`` comes
    from ``fmap_network.get_mask(evals_x, evals_y, resolvant_gamma)``.
    """

    def __init__(self, lmbda=0.01, resolvant_gamma=0.5, bidirectional=False):
        super().__init__()
        self.lmbda = lmbda
        self.resolvant_gamma = resolvant_gamma
        # stored for API compatibility; compute_functional_map does not use it
        self.bidirectional = bidirectional

    def compute_functional_map(self, A, B, evals_x, evals_y):
        """Solve for the regularized functional map.

        Args:
            A: [B, K, C] spectral coefficients on the source side.
            B: [B, K, C] spectral coefficients on the target side.
            evals_x, evals_y: eigenvalues passed to ``get_mask``; the row count
                is taken from ``evals_x.shape[1]``.

        Returns:
            Cxy: [B, K, K], stacked row by row.
        """
        D = fmap_network.get_mask(evals_x, evals_y, self.resolvant_gamma)  # [B, K, K]

        A_t = A.transpose(1, 2)      # [B, C, K]
        A_A_t = torch.bmm(A, A_t)    # [B, K, K]
        B_A_t = torch.bmm(B, A_t)    # [B, K, K]

        rows = []
        for i in range(evals_x.shape[1]):
            # batched diagonal matrix built from the i-th mask row: [B, K, K]
            D_i = torch.diag_embed(D[:, i, :])
            rhs = B_A_t[:, [i], :].transpose(1, 2)  # [B, K, 1]
            # solve() gives the same solution as inverse() @ rhs, more stably
            row = torch.linalg.solve(A_A_t + self.lmbda * D_i, rhs)  # [B, K, 1]
            rows.append(row.transpose(1, 2))

        Cxy = torch.cat(rows, dim=1)  # [B, K, K]

        return Cxy

fmnet = RegularizedFMNet()

In [24]:
# Four estimates of the functional map, from the confidence-filtered median
# correspondences (1-3) and from the GT correspondences (4).

# eigenbases truncated to the diffusion spectral size
evecs_first_k = evecs_first_zo[:, :num_evecs]
evecs_second_k = evecs_second_zo[:, :num_evecs]

# 1) regularized solve on pseudo-inverse coefficients
C_fmnet = fmnet.compute_functional_map(
    torch.pinverse(evecs_second_k)[:, p2p_median_second_filtered].unsqueeze(0),
    torch.pinverse(evecs_first_k)[:, p2p_median_first_filtered].unsqueeze(0),
    evals_second.cuda().unsqueeze(0),
    evals_first.cuda().unsqueeze(0),
)[0]

# 2) mass-weighted inner products of the matched eigenvector rows
Cxy_evecs = evecs_second_zo[p2p_median_second_filtered, :num_evecs].T @ (
    mass_second[p2p_median_second_filtered, None]
    * evecs_first_zo[p2p_median_first_filtered, :num_evecs]
)

# 3) least squares on the matched rows
Cxy_lstsq = torch.linalg.lstsq(
    evecs_second_k[p2p_median_second_filtered],
    evecs_first_k[p2p_median_first_filtered],
).solution

# 4) least squares on the ground-truth correspondences (reference)
Cxy_gt = torch.linalg.lstsq(
    evecs_second_k[corr_second],
    evecs_first_k[corr_first],
).solution

In [None]:
import my_code.utils.plotting_utils as plotting_utils
import matplotlib.pyplot as plt

# index bounds forwarded to plot_Cxy (presumably the displayed block range)
l = 0
h = 64

# (matrix, title) for each panel, left to right
panels = [
    (C_fmnet, 'fmnet'),
    (Cxy_evecs, 'evecs'),
    (Cxy_gt, 'GT'),
    (C_fmnet.T - Cxy_gt, 'fmnet - GT'),
    (Cxy_lstsq - Cxy_gt, 'lstsq - GT'),
]

fig, axs = plt.subplots(1, 5, figsize=(16, 4))
for ax, (matrix, title) in zip(axs, panels):
    plotting_utils.plot_Cxy(fig, ax, matrix.cpu(), title, l, h,
                            show_grid=False, show_colorbar=False)
plt.show()

In [None]:
# Geodesic error of the regularized (fmnet) functional map built from the
# filtered median p2p correspondences, without zoomout (apply_zoomout=False).
# Eigenvalues are moved to the eigenvectors' device rather than a hard-coded
# 'cuda'; torch.pinverse already returns its result on its input's device,
# so the basis pseudo-inverses need no explicit move.
geo_err, p2p_expanded, Cxy_est = get_geo_error(
    p2p_median_first_filtered, p2p_median_second_filtered,
    evecs_first_zo, evecs_second_zo,
    corr_first, corr_second,
    num_evecs, False,
    dist_x,
    regularized=True,
    evecs_trans_first=torch.pinverse(evecs_first_zo[:, :num_evecs]),
    evecs_trans_second=torch.pinverse(evecs_second_zo[:, :num_evecs]),
    evals_first=evals_first.to(evecs_first_zo.device),
    evals_second=evals_second.to(evecs_second_zo.device),
    return_p2p=True, return_Cxy=True,
    )
geo_err

In [None]:
import importlib

# Re-execute the plotting_utils module so edits made to it on disk are picked
# up without restarting the kernel; the reloaded module object is returned and
# echoed as the cell output.
importlib.reload(plotting_utils)

In [None]:
import my_code.utils.plotting_utils as plotting_utils

# Start from an empty trimesh scene, then draw the p2p map returned by
# get_geo_error between the 'first' and 'second' shapes held in `data`.
scene.geometry.clear()

shape_first = data['first']
shape_second = data['second']

plotting_utils.plot_p2p_map(
    scene,
    shape_first['verts'],
    shape_first['faces'],
    shape_second['verts'],
    shape_second['faces'],
    p2p_expanded,
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)

scene.show()

# Are the conditioning evecs identical across repeated sign-net predictions?

In [None]:

# Run the sign-correction net 100 times on the same (template, shape) pair and
# stack the resulting conditioning matrices, so the next cell can check whether
# they are identical from run to run.

# NOTE(review): currently unused — the per-sample loop over it is disabled.
data_range = tqdm([0, 11, 43])

data = single_dataset[0]

device = 'cuda' if torch.cuda.is_available() else 'cpu'

verts_first = template_shape['verts'].unsqueeze(0).to(device)
verts_second = data['verts'].unsqueeze(0).to(device)

faces_first = template_shape['faces'].unsqueeze(0).to(device)
faces_second = data['faces'].unsqueeze(0).to(device)

evecs_first = template_shape['evecs'][:, :num_evecs].unsqueeze(0).to(device)
evecs_second = data['evecs'][:, :num_evecs].unsqueeze(0).to(device)

# NOTE: overwrites the evals_first / evals_second used by earlier cells.
evals_first = template_shape['evals'][:num_evecs]
evals_second = data['evals'][:num_evecs]

# Dense diagonal mass matrices for predict_sign_change's `mass_mat` argument.
# They do not depend on the loop, so they are built once here (the previous
# version rebuilt identical copies on every iteration).
if config["sign_net"]["with_mass"]:
    mass_mat_first = torch.diag_embed(
        template_shape['mass'].unsqueeze(0)
        ).to(device)
    mass_mat_second = torch.diag_embed(
        data['mass'].unsqueeze(0)
        ).to(device)
else:
    mass_mat_first = None
    mass_mat_second = None

evecs_cond_first_list = []
evecs_cond_second_list = []

for _ in tqdm(range(100)):

    # predict the sign change (identical inputs on every iteration)
    with torch.no_grad():
        sign_pred_first, support_vector_norm_first, _ = predict_sign_change(
            sign_corr_net, verts_first, faces_first, evecs_first,
            mass_mat=mass_mat_first, input_type=sign_corr_net.input_type,
            evecs_per_support=config["sign_net"]["evecs_per_support"],
            mass=template_shape['mass'].unsqueeze(0), L=template_shape['L'].unsqueeze(0),
            evals=template_shape['evals'][:config["sign_net"]["net_params"]["k_eig"]].unsqueeze(0),
            evecs=template_shape['evecs'][:, :config["sign_net"]["net_params"]["k_eig"]].unsqueeze(0),
            gradX=template_shape['gradX'].unsqueeze(0), gradY=template_shape['gradY'].unsqueeze(0)
            )
        sign_pred_second, support_vector_norm_second, _ = predict_sign_change(
            sign_corr_net, verts_second, faces_second, evecs_second,
            mass_mat=mass_mat_second, input_type=sign_corr_net.input_type,
            evecs_per_support=config["sign_net"]["evecs_per_support"],
            mass=data['mass'].unsqueeze(0), L=data['L'].unsqueeze(0),
            evals=data['evals'][:config["sign_net"]["net_params"]["k_eig"]].unsqueeze(0),
            evecs=data['evecs'][:, :config["sign_net"]["net_params"]["k_eig"]].unsqueeze(0),
            gradX=data['gradX'].unsqueeze(0), gradY=data['gradY'].unsqueeze(0)
            )

    # flip eigenvector signs, then scale each column to unit L2 norm
    evecs_first_corrected = evecs_first.cpu()[0] * torch.sign(sign_pred_first).cpu()
    evecs_first_corrected_norm = evecs_first_corrected / torch.norm(evecs_first_corrected, dim=0, keepdim=True)

    evecs_second_corrected = evecs_second.cpu()[0] * torch.sign(sign_pred_second).cpu()
    evecs_second_corrected_norm = evecs_second_corrected / torch.norm(evecs_second_corrected, dim=0, keepdim=True)

    # product with support
    if config["sign_net"]["with_mass"]:
        # support^T @ diag(mass) == support^T * mass (column j scaled by
        # mass[j]); broadcasting gives the same values without copying and
        # multiplying by a dense N x N matrix on every iteration.
        evecs_cond_first = torch.nn.functional.normalize(
            support_vector_norm_first[0].cpu().transpose(0, 1)
                * template_shape['mass'].cpu(),
            p=2, dim=1) \
                @ evecs_first_corrected_norm

        evecs_cond_second = torch.nn.functional.normalize(
            support_vector_norm_second[0].cpu().transpose(0, 1)
                * data['mass'].cpu(),
            p=2, dim=1) \
                @ evecs_second_corrected_norm

    else:
        evecs_cond_first = support_vector_norm_first[0].cpu().transpose(0, 1) @ evecs_first_corrected_norm
        evecs_cond_second = support_vector_norm_second[0].cpu().transpose(0, 1) @ evecs_second_corrected_norm

    evecs_cond_first_list.append(evecs_cond_first)
    evecs_cond_second_list.append(evecs_cond_second)

# (num_runs, ...) tensors, one slice per repetition
evecs_cond_first_list = torch.stack(evecs_cond_first_list)
evecs_cond_second_list = torch.stack(evecs_cond_second_list)

In [None]:
import my_code.utils.plotting_utils as plotting_utils
import matplotlib.pyplot as plt

# Compare the conditioning matrices of three *distinct* repetitions. Drawing
# each index independently could select the same run twice (a trivially
# all-zero panel), and the population size is taken from the stacked tensor
# instead of a hard-coded 100 so it follows the repetition count.
num_runs = evecs_cond_second_list.shape[0]
assert num_runs >= 3, 'need at least 3 repetitions to compare'
idx_0, idx_1, idx_2 = torch.randperm(num_runs)[:3].tolist()

l = 0
h = 64

run_pairs = [(idx_0, idx_1), (idx_1, idx_2), (idx_0, idx_2)]

fig, axs = plt.subplots(1, len(run_pairs), figsize=(15, 4))

for pair_ax, (run_a, run_b) in zip(axs, run_pairs):
    plotting_utils.plot_Cxy(
        fig, pair_ax,
        (evecs_cond_second_list[run_a] - evecs_cond_second_list[run_b]).cpu(),
        f'{run_a} - {run_b}', l, h, show_grid=False, show_colorbar=False)
plt.show()