In [6]:
import sqlite3
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
import my_code.sign_canonicalization.test_sign_correction as test_sign_correction

import accelerate

from utils.shape_util import compute_geodesic_distmat
from my_code.utils.median_p2p_map import get_median_p2p_map



# Forget every progress bar tqdm still tracks in this kernel.
# NOTE(review): `_instances` is a private tqdm attribute; presumably this is here so
# that bars left over from an interrupted / re-run cell do not garble new ones — confirm.
tqdm._instances.clear()

In [7]:
class Arguments:
    """Fixed set of evaluation settings, exposed as plain instance attributes.

    The object is read as ``args.<name>`` further down in the notebook.
    """

    def __init__(self):
        settings = {
            # trained run and checkpoint to evaluate
            'experiment_name': 'pair_5_xy_64_64_128_128',
            'checkpoint_name': 'epoch_99',
            # dataset and split the evaluation runs on
            'dataset_name': 'DT4D_intra_pair',
            'split': 'test',
            # NOTE(review): presumably the number of sampling runs that get averaged — confirm
            'num_iters_avg': 50,
            # smoothing options forwarded to the remeshing step
            'smoothing_type': 'taubin',
            'smoothing_iter': 5,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)


args = Arguments()

In [None]:
##########################################
# Configuration
##########################################
# Loads the experiment config, the trained conditional U-Net, the sign-correction
# network and the test dataset (plus a remeshed/smoothed copy of it), then creates
# the folders the evaluation writes its logs and figures to.

experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

# One device for every network in this cell; same selection rule as the
# "Single stage" cell, so the two cells cannot disagree.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### config
# NOTE(review): absolute, user-specific path — the notebook only runs on that machine.
exp_base_folder = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    # NOTE(review): FullLoader is acceptable for a config we wrote ourselves; if the
    # file only holds plain scalars/lists/dicts, yaml.safe_load would be the safer call.
    config = yaml.load(f, Loader=yaml.FullLoader)


### model
model = DiagConditionedUnet(config["model_params"]).to(device)

checkpoint_path = f"{exp_base_folder}/checkpoints/{checkpoint_name}"
if config.get("accelerate"):
    # runs trained through accelerate store the weights as a safetensors file
    # inside a checkpoint directory
    accelerate.load_checkpoint_in_model(model, f"{checkpoint_path}/model.safetensors")
else:
    # map_location lets a checkpoint saved on a GPU load on whatever device we have
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))

# no-op when the weights are already on `device`; kept so the model's placement
# does not depend on how the checkpoint loader handled the tensors
model = model.to(device)

# NOTE(review): neither `model` nor `sign_corr_net` is switched to eval() here;
# confirm that the sampling / prediction helpers do that themselves.


# algorithm
# smooth the single dataset
# for each mesh, correct the first evecs, get the conditioning

# for each pair
# sample the model with conditioning
# zoomout using corrected evecs



### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    ).to(device)

sign_corr_net.load_state_dict(torch.load(
        f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth',
        map_location=device,
        ))

### sample the model
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2',
                                clip_sample=True)

### test dataset
dataset_name = args.dataset_name
split = args.split

# NOTE(review): the positional 200 is presumably the number of eigenvectors to load,
# matching num_evecs=200 in the remeshing call below — confirm against get_val_dataset.
single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True
    )
# sign_corr_net.cache_dir = single_dataset.lb_cache_dir

# remeshed + smoothed copy of the single-shape dataset
single_dataset_remeshed = test_sign_correction.remesh_dataset(
    dataset=single_dataset,
    name=dataset_name,
    remesh_targetlen=1,
    smoothing_type=args.smoothing_type,
    smoothing_iter=args.smoothing_iter,
    num_evecs=200,
)


# number of eigenvectors the diffusion model works with (its sample size)
num_evecs = config["model_params"]["sample_size"]

##########################################
# Logging
##########################################

log_dir = f'{exp_base_folder}/eval/{checkpoint_name}/{dataset_name}-{split}/{args.smoothing_type}-{args.smoothing_iter}'
os.makedirs(log_dir, exist_ok=True)

fig_dir = f'{log_dir}/figs'
os.makedirs(fig_dir, exist_ok=True)

log_file_name = f'{log_dir}/log_smooth_{args.smoothing_type}_{args.smoothing_iter}.txt'

In [None]:
##########################################
# Single stage
##########################################
# For every shape of the remeshed dataset:
#   1. predict the sign of each of the first `num_evecs` eigenvectors and flip them,
#   2. project the corrected eigenvectors onto the network's support vectors,
#   3. assemble the conditioning tensor requested in config["conditioning_types"],
#   4. carry the corrected signs over to the eigenvectors of the original mesh,
#   5. store both results in single_dataset.additional_data[i].

# Loop-invariant settings, looked up once instead of on every iteration.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
with_mass = config["sign_net"]["with_mass"]
k_eig = config["sign_net"]["net_params"]["k_eig"]
conditioning_types = config["conditioning_types"]

data_range = tqdm(range(len(single_dataset_remeshed)), desc='Calculating conditioning, correcting evecs')

# NOTE: the index is deliberately called `i`; a later cell reads it after the loop.
for i in data_range:

    data = single_dataset_remeshed[i]

    # add a leading batch dimension of size 1 for the network
    verts_second = data['verts'].unsqueeze(0).to(device)
    faces_second = data['faces'].unsqueeze(0).to(device)

    evecs_second = data['evecs'][:, :num_evecs].unsqueeze(0).to(device)
    evals_second = data['evals'][:num_evecs]

    # corr_second = data['corr']

    # Dense diagonal mass matrix; built once and reused for the support product below.
    if with_mass:
        mass_mat_second = torch.diag_embed(
            data['mass'].unsqueeze(0)
            ).to(device)
    else:
        mass_mat_second = None

    # predict the sign change
    with torch.no_grad():
        sign_pred_second, support_vector_norm_second, _ = predict_sign_change(
            sign_corr_net, verts_second, faces_second, evecs_second,
            mass_mat=mass_mat_second, input_type=sign_corr_net.input_type,
            # mass=None, L=None, evals=None, evecs=None, gradX=None, gradY=None
            mass=data['mass'].unsqueeze(0), L=data['L'].unsqueeze(0),
            evals=data['evals'][:k_eig].unsqueeze(0),
            evecs=data['evecs'][:, :k_eig].unsqueeze(0),
            gradX=data['gradX'].unsqueeze(0), gradY=data['gradY'].unsqueeze(0)
            )

    # correct the evecs: multiply by the sign (+1 / -1) of the prediction
    # NOTE(review): a prediction of exactly 0 would zero its eigenvector and make the
    # normalisation below divide by 0 — presumably never happens in practice.
    evecs_second_corrected = evecs_second.cpu()[0] * torch.sign(sign_pred_second).cpu()
    evecs_second_corrected_norm = evecs_second_corrected / torch.norm(evecs_second_corrected, dim=0, keepdim=True)

    # product with support
    support_second_t = support_vector_norm_second[0].cpu().transpose(0, 1)
    if with_mass:
        # weight the support vectors by the mass matrix and L2-normalise each row
        # before projecting the eigenvectors onto them
        evecs_cond_second = torch.nn.functional.normalize(
            support_second_t @ mass_mat_second[0].cpu(),
            p=2, dim=1) \
                @ evecs_second_corrected_norm
    else:
        evecs_cond_second = support_second_t @ evecs_second_corrected_norm


    ###############################################
    # Conditioning
    ###############################################

    # Start empty and append one slice along dim 0 per requested conditioning type.
    conditioning = torch.tensor([])

    if 'evals' in conditioning_types:
        # eigenvalues on the diagonal of a square matrix
        # (named evals_diag so the builtin `eval` is not shadowed in the notebook)
        evals_diag = torch.diag_embed(evals_second.unsqueeze(0))
        conditioning = torch.cat((conditioning, evals_diag), 0)

    if 'evals_inv' in conditioning_types:
        evals_inv = 1 / evals_second.unsqueeze(0)
        # replace elements > 1 with 1 (this also catches +inf from a zero eigenvalue)
        evals_inv[evals_inv > 1] = 1
        evals_inv_diag = torch.diag_embed(evals_inv)
        conditioning = torch.cat((conditioning, evals_inv_diag), 0)

    if 'evecs' in conditioning_types:
        conditioning = torch.cat((conditioning,
                                    evecs_cond_second.unsqueeze(0)), 0)

    ###############################################
    # Correct the original evecs
    ###############################################

    data_orig = single_dataset[i]
    evecs_second_orig = data_orig['evecs'][:, :num_evecs]

    # Compare each original eigenvector with its corrected remeshed counterpart,
    # sampled at the rows given by corr_orig_to_remeshed; the sign of the diagonal
    # entry tells whether the original eigenvector has to be flipped.
    # (evecs_second_corrected is already on the CPU, no transfer needed.)
    prod_evecs_orig_remesh_corrected = evecs_second_orig.transpose(0, 1) @ evecs_second_corrected[data['corr_orig_to_remeshed']]

    evecs_orig_signs = torch.sign(torch.diagonal(prod_evecs_orig_remesh_corrected, dim1=0, dim2=1))
    evecs_second_corrected_orig = evecs_second_orig * evecs_orig_signs

    # corrected leading eigenvectors followed by the untouched remaining ones
    evecs_second_orig_zo = torch.cat(
        [evecs_second_corrected_orig,
            data_orig['evecs'][:, num_evecs:]], 1)

    ###############################################
    # Save the data
    ###############################################

    single_dataset.additional_data[i]['evecs_zo'] = evecs_second_orig_zo
    single_dataset.additional_data[i]['conditioning'] = conditioning
    

In [None]:
# Sanity check on one remeshed sample: shapes of the original vertices, the remeshed
# vertices and the original-to-remeshed correspondence.
# `i` is whatever index the "Single stage" loop finished on, so this cell must run after it.
# The sample is fetched once instead of indexing the dataset three times.
inspected_sample = single_dataset_remeshed[i]
inspected_sample['verts_orig'].shape, inspected_sample['verts'].shape, inspected_sample['corr_orig_to_remeshed'].shape