In [1]:
import trimesh

# Single trimesh scene shared by the visualisation cells further down; each of
# those cells clears `scene.geometry` before drawing into it.
scene = trimesh.Scene()

In [None]:
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
from utils.shape_util import compute_geodesic_distmat
from my_code.diffusion_training_sign_corr.test.test_diffusion_cond import select_p2p_map_dirichlet, log_to_database, parse_args
import accelerate
import my_code.sign_canonicalization.test_sign_correction as test_sign_correction
import networks.fmap_network as fmap_network
from my_code.utils.median_p2p_map import dirichlet_energy

# Drop every progress bar tqdm is still tracking (presumably bars left over from
# earlier or interrupted cells in this kernel) so new bars render cleanly.
# NOTE(review): `_instances` is a private tqdm attribute and may change between versions.
tqdm._instances.clear()

class RegularizedFMNet(torch.nn.Module):
    """Compute the functional map matrix representation in DPFM.

    Every row ``i`` of the map is obtained independently from the regularized
    least-squares problem

        C[i] = argmin_c ||c A - B[i]||^2 + lmbda * sum_j D[i, j] * c_j^2

    whose closed form is ``c^T = (A A^T + lmbda * diag(D[i]))^{-1} A B[i]^T``.
    ``D`` is the mask returned by ``fmap_network.get_mask`` (presumably the DPFM
    resolvent mask).

    Args:
        lmbda: weight of the mask regularizer.
        resolvant_gamma: exponent forwarded to ``fmap_network.get_mask``.
        bidirectional: stored on the instance only; ``compute_functional_map``
            does not read it.
    """

    def __init__(self, lmbda=0.01, resolvant_gamma=0.5, bidirectional=False):
        super().__init__()
        self.lmbda = lmbda
        self.resolvant_gamma = resolvant_gamma
        self.bidirectional = bidirectional

    @staticmethod
    def _solve_rows(A, B, D, lmbda, num_rows=None):
        """Solve the row-wise regularized systems for a precomputed mask ``D``.

        Args:
            A: [B, K, C] coefficients of the source descriptors.
            B: [B, K, C] coefficients of the target descriptors.
            D: [B, K, K] regularization mask; row ``i`` weights row ``i`` of the map.
            lmbda: regularization weight.
            num_rows: number of map rows to solve (defaults to ``D.shape[1]``).

        Returns:
            [B, num_rows, K] functional map.
        """
        if num_rows is None:
            num_rows = D.shape[1]

        A_t = A.transpose(1, 2)  # [B, C, K]
        A_A_t = torch.bmm(A, A_t)  # [B, K, K]
        B_A_t = torch.bmm(B, A_t)  # [B, K, K]

        rows = []
        for i in range(num_rows):
            # diag_embed builds the [B, K, K] diagonal for the whole batch at once
            system = A_A_t + lmbda * torch.diag_embed(D[:, i, :])
            rhs = B_A_t[:, i, :].unsqueeze(-1)  # [B, K, 1]
            # solve() is cheaper and numerically safer than inverse() followed by bmm
            rows.append(torch.linalg.solve(system, rhs).transpose(1, 2))  # [B, 1, K]

        return torch.cat(rows, dim=1)

    def compute_functional_map(self, A, B, evals_x, evals_y):
        """Compute the regularized functional map taking ``A`` coefficients to ``B``.

        Args:
            A: [B, K, C] spectral coefficients, e.g. ``evecs_trans_x @ feat_x``.
            B: [B, K, C] spectral coefficients, e.g. ``evecs_trans_y @ feat_y``.
            evals_x: eigenvalues indexed as [batch, K]; ``shape[1]`` sets the row count.
            evals_y: eigenvalues indexed as [batch, K].

        Returns:
            Cxy: [B, K, K] functional map.
        """
        D = fmap_network.get_mask(evals_x, evals_y, self.resolvant_gamma)  # [B, K, K]
        return self._solve_rows(A, B, D, self.lmbda, num_rows=evals_x.shape[1])
    



def get_geo_error(
    p2p_first, p2p_second,
    evecs_first, evecs_second,
    corr_first, corr_second,
    num_evecs, apply_zoomout,
    dist_x,
    regularized=False,
    evecs_trans_first=None, evecs_trans_second=None,
    evals_first=None, evals_second=None,
    return_p2p=False, return_Cxy=False,
    A2=None, fmnet=None
    ):
    """Fit a functional map from a sparse correspondence and score the induced p2p map.

    The map is fitted on the vertex pairs ``(p2p_first[j], p2p_second[j])``:
    either by least squares on the first ``num_evecs`` eigenvectors, or (when
    ``regularized``) with ``fmnet.compute_functional_map`` on the matching columns
    of ``evecs_trans_*`` and the first ``num_evecs`` eigenvalues. It is optionally
    refined with ``zoomout_custom.zoomout`` up to ``evecs_first.shape[1]``
    eigenvectors, converted with ``fmap_util.fmap2pointmap``, and scored with
    ``geodist_metric.calculate_geodesic_error(dist_x, corr_first, corr_second, p2p,
    return_mean=True)``.

    Args:
        p2p_first, p2p_second: vertex indices into the first / second shape.
        evecs_first, evecs_second: eigenvector matrices (vertices x eigenvectors).
        corr_first, corr_second: correspondences forwarded to the geodesic metric.
        num_evecs: spectral size of the initial map.
        apply_zoomout: refine the map with zoomout before conversion.
        dist_x: distance matrix forwarded to the geodesic metric.
        regularized: use ``fmnet`` instead of plain least squares.
        evecs_trans_first, evecs_trans_second, evals_first, evals_second:
            required when ``regularized`` is True.
        return_p2p, return_Cxy: also return the p2p map / functional map.
        A2: forwarded to ``zoomout_custom.zoomout``.
        fmnet: a ``RegularizedFMNet``-like object; required when ``regularized``.

    Returns:
        ``geo_err * 100`` when neither ``return_p2p`` nor ``return_Cxy`` is set;
        otherwise a list ``[geo_err * 100, p2p?, Cxy?]`` holding the requested
        extras in that order.

    Raises:
        ValueError: if ``regularized`` is True but a required argument is None.
    """
    if regularized:
        required = {
            'fmnet': fmnet,
            'evecs_trans_first': evecs_trans_first,
            'evecs_trans_second': evecs_trans_second,
            'evals_first': evals_first,
            'evals_second': evals_second,
        }
        missing = [name for name, value in required.items() if value is None]
        if missing:
            raise ValueError(
                f'regularized=True requires {", ".join(missing)} to be provided'
            )

        # fmnet maps second-shape coefficients to first-shape ones; the transpose
        # matches the layout of the least-squares branch below
        Cxy = fmnet.compute_functional_map(
            evecs_trans_second[:num_evecs, p2p_second].unsqueeze(0),
            evecs_trans_first[:num_evecs, p2p_first].unsqueeze(0),
            evals_second[:num_evecs].unsqueeze(0),
            evals_first[:num_evecs].unsqueeze(0),
        )[0].T
    else:
        Cxy = torch.linalg.lstsq(
            evecs_second[:, :num_evecs][p2p_second],
            evecs_first[:, :num_evecs][p2p_first]
            ).solution

    if apply_zoomout:
        Cxy = zoomout_custom.zoomout(
            FM_12=Cxy,
            evects1=evecs_first,
            evects2=evecs_second,
            nit=evecs_first.shape[1] - num_evecs, step=1,
            A2=A2
        )
        # after zoomout the map spans every available eigenvector
        num_evecs = evecs_first.shape[1]

    p2p = fmap_util.fmap2pointmap(
        C12=Cxy,
        evecs_x=evecs_first[:, :num_evecs],
        evecs_y=evecs_second[:, :num_evecs],
        ).cpu()

    geo_err = geodist_metric.calculate_geodesic_error(
        dist_x, corr_first.cpu(), corr_second.cpu(), p2p, return_mean=True
    )

    if not return_p2p and not return_Cxy:
        return geo_err * 100

    payload = [geo_err * 100]
    if return_p2p:
        payload.append(p2p)
    if return_Cxy:
        payload.append(Cxy)
    return payload


def filter_p2p_by_confidence(
        p2p_first, p2p_second,
        confidence_scores_first, confidence_scores_second,
        confidence_threshold, log_file_name
    ):
    """Keep only correspondences whose two confidence scores are both below a threshold.

    Lower scores count as more confident: point ``j`` is kept when
    ``confidence_scores_first[j] < threshold`` and
    ``confidence_scores_second[j] < threshold``. If fewer than 5% of the points
    pass, the threshold is relaxed in steps of 0.05 until at least 5% pass, or
    until it exceeds every finite score (further relaxation could not admit more
    points, e.g. when the rest are NaN/inf). Each threshold increase and the final
    number of kept points are appended to ``log_file_name``.

    Returns:
        ``(p2p_first, p2p_second)`` restricted to the kept points.

    Raises:
        AssertionError: if the two maps differ in length or no point is kept.
    """
    assert p2p_first.shape[0] == p2p_second.shape[0]

    def _select(threshold):
        # select points with BOTH confidence scores below the threshold
        return (confidence_scores_first < threshold) & (confidence_scores_second < threshold)

    valid_points = _select(confidence_threshold)
    min_valid = 0.05 * len(valid_points)

    with open(log_file_name, 'a') as f:

        if valid_points.sum() < min_valid:
            # largest finite score over both maps (NaN/+inf can never pass);
            # once the threshold exceeds it, relaxing further cannot help
            max_finite_score = max(
                float(torch.nan_to_num(
                    torch.as_tensor(scores, dtype=torch.float64),
                    nan=float('-inf'), posinf=float('-inf'),
                ).max())
                for scores in (confidence_scores_first, confidence_scores_second)
            )

            while valid_points.sum() < min_valid and confidence_threshold <= max_finite_score:
                confidence_threshold += 0.05
                valid_points = _select(confidence_threshold)

                f.write(f'Increasing confidence threshold: {confidence_threshold}\n')

        n_valid = int(valid_points.sum())
        f.write(f'Valid points: {n_valid}\n')
        assert n_valid > 0, "No valid points found"

    p2p_first = p2p_first[valid_points]
    p2p_second = p2p_second[valid_points]

    return p2p_first, p2p_second

In [3]:
def _predict_signs(sign_corr_net, shape, verts, faces, evecs, mass_mat, config):
    """Run ``predict_sign_change`` once for one shape and return its three outputs."""
    k_eig = config["sign_net"]["net_params"]["k_eig"]
    return predict_sign_change(
        sign_corr_net, verts, faces, evecs,
        mass_mat=mass_mat, input_type=sign_corr_net.input_type,
        evecs_per_support=config["sign_net"]["evecs_per_support"],

        mass=shape['mass'].unsqueeze(0), L=shape['L'].unsqueeze(0),
        evals=shape['evals'][:k_eig].unsqueeze(0),
        evecs=shape['evecs'][:, :k_eig].unsqueeze(0),
        gradX=shape['gradX'].unsqueeze(0), gradY=shape['gradY'].unsqueeze(0)
        )


def _sign_corrected_conditioning(evecs, signs, support_vector_norm, mass_mat):
    """Project sign-corrected, column-normalised eigenvectors onto the predicted supports.

    ``evecs`` carries a leading batch dim of 1 (only ``[0]`` is used) and ``signs``
    is the CPU ``torch.sign`` of the network prediction. When ``mass_mat`` is given,
    the supports are first multiplied by it and L2-normalised along dim 1.
    """
    evecs_corrected = evecs.cpu()[0] * signs
    evecs_corrected_norm = evecs_corrected / torch.norm(evecs_corrected, dim=0, keepdim=True)

    support = support_vector_norm[0].cpu().transpose(0, 1)
    if mass_mat is not None:
        support = torch.nn.functional.normalize(support @ mass_mat[0].cpu(), p=2, dim=1)
    return support @ evecs_corrected_norm


def get_fmaps_evec_signs(
        data, model,
        noise_scheduler, config, args,
        template_shape, sign_corr_net
    ):
    """Sample ``args.num_iters_avg`` functional maps for a (template, shape) pair.

    ``template_shape`` is the "first" shape and ``data`` the "second". In every
    iteration the sign-correction network predicts eigenvector signs and support
    vectors for both shapes; the sign-corrected eigenvectors projected onto the
    supports form the diffusion conditioning. All iterations are then denoised
    together as one batch with ``noise_scheduler``.

    Returns:
        x_sampled: [num_iters_avg, 1, S, S] sampled maps, S = ``sample_size``.
        evecs_first_signs_list, evecs_second_signs_list: stacked ``torch.sign`` of
            the per-iteration sign predictions for the first / second shape.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_evecs = config["model_params"]["sample_size"]

    verts_first = template_shape['verts'].unsqueeze(0).to(device)
    verts_second = data['verts'].unsqueeze(0).to(device)

    faces_first = template_shape['faces'].unsqueeze(0).to(device)
    faces_second = data['faces'].unsqueeze(0).to(device)

    evecs_first = template_shape['evecs'][:, :num_evecs].unsqueeze(0).to(device)
    evecs_second = data['evecs'][:, :num_evecs].unsqueeze(0).to(device)

    # loop-invariant: built once and reused for prediction and conditioning
    if config["sign_net"]["with_mass"]:
        mass_mat_first = torch.diag_embed(template_shape['mass'].unsqueeze(0)).to(device)
        mass_mat_second = torch.diag_embed(data['mass'].unsqueeze(0)).to(device)
    else:
        mass_mat_first = None
        mass_mat_second = None

    ###############################################
    # get conditioning and signs num_iters_avg times
    ###############################################

    evecs_cond_first_list, evecs_cond_second_list = [], []
    evecs_first_signs_list, evecs_second_signs_list = [], []

    for _ in range(args.num_iters_avg):

        with torch.no_grad():
            sign_pred_first, support_vector_norm_first, _ = _predict_signs(
                sign_corr_net, template_shape, verts_first, faces_first,
                evecs_first, mass_mat_first, config)
            sign_pred_second, support_vector_norm_second, _ = _predict_signs(
                sign_corr_net, data, verts_second, faces_second,
                evecs_second, mass_mat_second, config)

        signs_first = torch.sign(sign_pred_first).cpu()
        signs_second = torch.sign(sign_pred_second).cpu()

        evecs_cond_first_list.append(_sign_corrected_conditioning(
            evecs_first, signs_first, support_vector_norm_first, mass_mat_first))
        evecs_cond_second_list.append(_sign_corrected_conditioning(
            evecs_second, signs_second, support_vector_norm_second, mass_mat_second))
        evecs_first_signs_list.append(signs_first)
        evecs_second_signs_list.append(signs_second)

    ###############################################
    # Conditioning: channel 0 = first shape, channel 1 = second shape
    ###############################################

    conditioning = torch.stack(
        (torch.stack(evecs_cond_first_list), torch.stack(evecs_cond_second_list)),
        dim=1)

    ###############################################
    # Sample the model
    ###############################################

    # The DDPM reverse process starts from standard Gaussian noise (the prior the
    # forward process converges to), not from uniform U[0, 1) noise.
    x_sampled = torch.randn(args.num_iters_avg, 1, num_evecs, num_evecs).to(device)
    y = conditioning.to(device)

    for t in noise_scheduler.timesteps:
        with torch.no_grad():
            residual = model(x_sampled, t, conditioning=y).sample
        x_sampled = noise_scheduler.step(residual, t, x_sampled).prev_sample

    return x_sampled, torch.stack(evecs_first_signs_list), torch.stack(evecs_second_signs_list)


def get_p2p_maps_template(
        data,
        C_yx_est_i, evecs_first_signs_i, evecs_second_signs_i,
        template_shape, args, log_file_name, config
    ):
    """Convert sampled functional maps into point-to-point maps and select one.

    For each of the ``args.num_iters_avg`` samples, the eigenvectors of
    ``template_shape`` ("first") and ``data`` ("second") are sign-corrected with
    the matching entries of ``evecs_first_signs_i`` / ``evecs_second_signs_i``, and
    ``C_yx_est_i[k][0]`` is converted with ``fmap_util.fmap2pointmap``. The stacked
    maps are passed to ``select_p2p_map_dirichlet`` together with the template's
    Laplacian ``L`` and the geodesic distance matrix of ``data``. A short summary
    is printed to stdout; ``log_file_name`` is only created if missing.

    Returns:
        ``(p2p_est, p2p_dirichlet, p2p_median, confidence_scores, dist_second)``,
        where ``p2p_est`` stacks the per-sample maps and ``dist_second`` is the
        geodesic distance matrix of ``data`` as a torch tensor.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_evecs = config["model_params"]["sample_size"]

    # Only touch the log file (output now goes to stdout); the context manager
    # guarantees the handle is closed even if a later step raises.
    with open(log_file_name, 'a', buffering=1):
        pass

    verts_second = data['verts']
    faces_second = data['faces']

    evecs_first = template_shape['evecs'][:, :num_evecs]
    evecs_second = data['evecs'][:, :num_evecs]

    dist_second = torch.tensor(
        compute_geodesic_distmat(
            verts_second.numpy(),
            faces_second.numpy())
    )

    ##########################################################
    # Convert fmaps to p2p maps to template
    # (no zoomout, no Dirichlet-energy condition at this stage)
    ##########################################################

    p2p_est = []
    for k in range(args.num_iters_avg):
        evecs_first_corrected = evecs_first * evecs_first_signs_i[k]
        evecs_second_corrected = evecs_second * evecs_second_signs_i[k]

        p2p_est_k = fmap_util.fmap2pointmap(
            C12=C_yx_est_i[k][0].to(device),
            evecs_x=evecs_second_corrected.to(device),
            evecs_y=evecs_first_corrected.to(device),
            ).cpu()
        p2p_est.append(p2p_est_k)

    p2p_est = torch.stack(p2p_est)

    ##########################################################
    # p2p map selection
    ##########################################################

    p2p_dirichlet, p2p_median, confidence_scores, dirichlet_energy_list = select_p2p_map_dirichlet(
        p2p_est,
        verts_second,
        template_shape['L'],
        dist_second,
        num_samples_median=args.num_samples_median
        )

    print('Template stage')
    print(f'Dirichlet energy: {dirichlet_energy_list}')
    print(f'Confidence scores: {confidence_scores}')
    print(f'Mean confidence score: {confidence_scores.mean():.3f}')
    print(f'Median confidence score: {confidence_scores.median():.3f}')

    return p2p_est, p2p_dirichlet, p2p_median, confidence_scores, dist_second


In [4]:
class Arguments:
    """Hard-coded stand-in for the evaluation script's command-line arguments.

    The notebook uses this instead of ``parse_args()``. Keyword arguments
    override individual defaults, e.g. ``Arguments(num_iters_avg=8)``.

    Raises:
        TypeError: if a keyword does not name a known argument.
    """

    def __init__(self, **overrides):
        self.experiment_name = 'partial_0.8_5k_32_2'
        self.checkpoint_name = 'epoch_99'

        self.dataset_name = 'SHREC16_holes_pair'
        self.split = 'test'

        # number of sign-prediction / diffusion samples drawn per shape
        self.num_iters_avg = 32
        # forwarded to select_p2p_map_dirichlet
        self.num_samples_median = 4
        self.confidence_threshold = 0.3

        # None -> logs go to the 'no_smoothing' test folder
        self.smoothing_type = None
        self.smoothing_iter = None

        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f'Unknown argument(s): {unknown}')
        for name, value in overrides.items():
            setattr(self, name, value)

In [5]:

# args = parse_args()  # CLI variant used when this evaluation runs as a script

args = Arguments()

# configuration
experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

### config
# NOTE(review): absolute, user-specific paths below — consider a configurable root directory.
exp_base_folder = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DiagConditionedUnet(config["model_params"])

if config.get("accelerate"):
    accelerate.load_checkpoint_in_model(model, f"{exp_base_folder}/checkpoints/{checkpoint_name}/model.safetensors")
else:
    # map_location lets GPU-saved checkpoints load on CPU-only machines
    model.load_state_dict(torch.load(
        f"{exp_base_folder}/checkpoints/{checkpoint_name}", map_location=device
        ))

model.to(device)

### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    )
sign_corr_net.load_state_dict(torch.load(
        f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth',
        map_location=device
        ))
sign_corr_net.to(device)


### noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True)

# fmap network
fmnet = RegularizedFMNet(lmbda=0.01, resolvant_gamma=0.5)


### test dataset
dataset_name = args.dataset_name
split = args.split

single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True
    )

num_evecs = config["model_params"]["sample_size"]

##########################################
# Template
##########################################

template_dir = f'/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/{config["sign_net"]["template_type"]}'
template_shape = template_dataset.get_template(
    num_evecs=200,
    centering='bbox',
    template_path=f'{template_dir}/template.off',
    # corr.txt is 1-based; shift to 0-based indices
    template_corr=np.loadtxt(f'{template_dir}/corr.txt', dtype=np.int32) - 1
    )

##########################################
# Logging
##########################################

if args.smoothing_type is not None:
    test_name = f'{args.smoothing_type}-{args.smoothing_iter}'
else:
    test_name = 'no_smoothing'

log_dir = f'{exp_base_folder}/eval/{checkpoint_name}/{dataset_name}-{split}/{test_name}'
os.makedirs(log_dir, exist_ok=True)

fig_dir = f'{log_dir}/figs'
os.makedirs(fig_dir, exist_ok=True)

log_file_name = f'{log_dir}/log_{test_name}.txt'



In [99]:
from my_code.datasets.surreal_dataset_3dc import TemplateSurrealDataset3DC

# Remeshing augmentation settings handed to the dataset.
isotropic_remesh_params = dict(
    n_remesh_iters=10,
    remesh_targetlen=1,
    simplify_strength_min=0.2,
    simplify_strength_max=0.8,
)
partial_remesh_params = dict(
    probability=1,
    n_remesh_iters=10,
    fraction_to_select_min=0.35,
    fraction_to_select_max=0.65,
    n_seed_samples=[1, 5, 25],
    weighted_by="area",
)
augmentations = {
    "remesh": {
        "isotropic": isotropic_remesh_params,
        "partial": partial_remesh_params,
    },
}

remeshed_template_dir = '/home/s94zalek_hpc/shape_matching/data/SURREAL_full/template/remeshed'

# NOTE(review): this replaces the `test_dataset` built in the setup cell.
test_dataset = TemplateSurrealDataset3DC(
    shape_path='/lustre/mlnvme/data/s94zalek_hpc-shape_matching/mmap_datas_surreal_train.pth',
    num_evecs=128,
    cache_lb_dir=None,
    return_evecs=True,
    return_fmap=False,
    mmap=True,
    augmentations=augmentations,
    template_path=remeshed_template_dir + '/template.off',
    # corr.txt is 1-based; shift to 0-based indices
    template_corr=np.loadtxt(remeshed_template_dir + '/corr.txt', dtype=np.int32) - 1,
)

In [75]:
# Draw one pair and move its spectral quantities to the device.
data = test_dataset[10]
data_first, data_second = data['first'], data['second']

evecs_first = data_first['evecs'].to(device)
evecs_second = data_second['evecs'].to(device)

evecs_trans_first = data_first['evecs_trans'].to(device)
evecs_trans_second = data_second['evecs_trans'].to(device)

evals_first = data_first['evals'][:num_evecs].to(device)
evals_second = data_second['evals'][:num_evecs].to(device)

corr_first = data_first['corr'].to(device)
corr_second = data_second['corr'].to(device)

mass_second = data_second['mass'].to(device)

###############################################
# Functional maps: second mesh, with the first mesh acting as the template
###############################################

Cxy_second_list, evecs_first_signs_list_second, evecs_second_signs_list_second = get_fmaps_evec_signs(
    data_second, model, noise_scheduler, config, args, data_first, sign_corr_net
)
Cyx_second_list = Cxy_second_list.transpose(2, 3)

(p2p_est_second, p2p_dirichlet_second, p2p_median_second,
 confidence_scores_second, dist_y) = get_p2p_maps_template(
    data_second,
    Cyx_second_list, evecs_first_signs_list_second, evecs_second_signs_list_second,
    data_first, args, log_file_name, config,
)

# Reverse direction: swap the two shapes and their sign lists, untransposed maps.
(p2p_est_second_rev, p2p_dirichlet_second_rev, p2p_median_second_rev,
 confidence_scores_second_rev, _) = get_p2p_maps_template(
    data_first,
    Cxy_second_list, evecs_second_signs_list_second, evecs_first_signs_list_second,
    data_second, args, log_file_name, config,
)
    

In [None]:
# Dirichlet-selected map, drawn with the second mesh as first argument pair.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    data['first']['verts'], data['first']['faces'],
    p2p_dirichlet_second.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [None]:
# Reverse-direction Dirichlet-selected map, first mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_dirichlet_second_rev.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [77]:
# The first shape's `corr` array drawn as a p2p map, first mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    data['first']['corr'].cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [136]:
# Re-create the regularized solver with a weaker regularizer (lmbda=0.001 instead
# of the 0.01 used in the setup cell) for the ground-truth comparisons below.
fmnet = RegularizedFMNet(lmbda=0.001, resolvant_gamma=0.5)

In [137]:
# Re-draw sample 10 from `test_dataset`.
# NOTE(review): the dataset is configured with remeshing/partial augmentations; if these
# are sampled on every __getitem__, this pair differs from the earlier draw — confirm.
data = test_dataset[10] 

In [138]:
num_evecs_test = 64


def _evecs_at_corr(shape, k):
    """First `k` eigenvectors of `shape`, rows restricted to its `corr` vertices."""
    return shape['evecs'][:, :k][shape['corr']]


def _lstsq_fmap(src, dst, k):
    """Least-squares X with `_evecs_at_corr(src) @ X ~ _evecs_at_corr(dst)`; solved on `device`, returned on CPU."""
    return torch.linalg.lstsq(
        _evecs_at_corr(src, k).to(device),
        _evecs_at_corr(dst, k).to(device),
    ).solution.to('cpu')


C_gt_xy = _lstsq_fmap(data['second'], data['first'], num_evecs_test)
C_gt_yx = _lstsq_fmap(data['first'], data['second'], num_evecs_test)

evecs_first_test = data['first']['evecs'][:, :num_evecs_test]
evecs_second_test = data['second']['evecs'][:, :num_evecs_test]

p2p_xy = fmap_util.fmap2pointmap(
    C12=C_gt_xy, evecs_x=evecs_first_test, evecs_y=evecs_second_test,
).cpu()
p2p_yx = fmap_util.fmap2pointmap(
    C12=C_gt_yx, evecs_x=evecs_second_test, evecs_y=evecs_first_test,
).cpu()

In [139]:
def _regularized_gt_fmap(src, dst, k):
    """`fmnet` map between the `corr`-restricted spectral coefficients of `src` and `dst`, transposed."""
    return fmnet.compute_functional_map(
        src['evecs_trans'][:k, src['corr']].unsqueeze(0),
        dst['evecs_trans'][:k, dst['corr']].unsqueeze(0),
        src['evals'][:k].unsqueeze(0),
        dst['evals'][:k].unsqueeze(0),
    )[0].T


C_gt_xy_reg = _regularized_gt_fmap(data['second'], data['first'], num_evecs_test)
C_gt_yx_reg = _regularized_gt_fmap(data['first'], data['second'], num_evecs_test)

evecs_first_test = data['first']['evecs'][:, :num_evecs_test]
evecs_second_test = data['second']['evecs'][:, :num_evecs_test]

p2p_xy_reg = fmap_util.fmap2pointmap(
    C12=C_gt_xy_reg, evecs_x=evecs_first_test, evecs_y=evecs_second_test,
).cpu()
p2p_yx_reg = fmap_util.fmap2pointmap(
    C12=C_gt_yx_reg, evecs_x=evecs_second_test, evecs_y=evecs_first_test,
).cpu()

In [None]:
# Side-by-side view of the four ground-truth maps: lstsq (xy, yx), regularized (xy, yx).
l, h = 0, 64

fig, axs = plt.subplots(1, 4, figsize=(16, 4))
for panel_ax, fmap_matrix in zip(axs, (C_gt_xy, C_gt_yx, C_gt_xy_reg, C_gt_yx_reg)):
    plotting_utils.plot_Cxy(fig, panel_ax, fmap_matrix.cpu(),
                            'fmnet', l, h, show_grid=False, show_colorbar=False)
plt.show()

In [None]:
# Least-squares ground-truth map `p2p_xy`, first mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_xy.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [127]:
# Least-squares ground-truth map `p2p_yx`, second mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    data['first']['verts'], data['first']['faces'],
    p2p_yx.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [141]:
# Regularized ground-truth map `p2p_xy_reg`, first mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['first']['verts'], data['first']['faces'],
    data['second']['verts'], data['second']['faces'],
    p2p_xy_reg.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [142]:
# Regularized ground-truth map `p2p_yx_reg`, second mesh drawn first.
scene.geometry.clear()
plotting_utils.plot_p2p_map(
    scene,
    data['second']['verts'], data['second']['faces'],
    data['first']['verts'], data['first']['faces'],
    p2p_yx_reg.cpu(),
    axes_color_gradient=[0, 1],
    base_cmap='hsv',
)
scene.show()

In [135]:
# Shape check for mass-weighting the transposed eigenvectors by broadcasting.
first_evecs = data['first']['evecs']
first_mass_row = data['first']['mass'][None]
print(first_evecs.shape, first_mass_row.shape)

(first_evecs.T * first_mass_row).shape