In [None]:
import torch
import numpy as np
import os
import shutil
from tqdm import tqdm
import yaml

import sys
import os

# models
from my_code.models.diag_conditional import DiagConditionedUnet
from diffusers import DDPMScheduler

import my_code.datasets.template_dataset as template_dataset

import my_code.diffusion_training_sign_corr.data_loading as data_loading

import networks.diffusion_network as diffusion_network
import matplotlib.pyplot as plt
import my_code.utils.plotting_utils as plotting_utils
import utils.fmap_util as fmap_util
import metrics.geodist_metric as geodist_metric
from my_code.sign_canonicalization.training import predict_sign_change
import argparse
from pyFM_fork.pyFM.refine.zoomout import zoomout_refine
import my_code.utils.zoomout_custom as zoomout_custom
from utils.shape_util import compute_geodesic_distmat
from my_code.diffusion_training_sign_corr.test.test_diffusion_cond import select_p2p_map_dirichlet, log_to_database, parse_args
import accelerate
import my_code.sign_canonicalization.test_sign_correction as test_sign_correction
import networks.fmap_network as fmap_network
from my_code.utils.median_p2p_map import dirichlet_energy


In [32]:
# apply standard scaling and PCA
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import seaborn as sns
import pandas as pd


def apply_pca(input_data, title, pca_components=3, use_scaler=True, show_ratio=True, show_pairplot=True, color_by='body_type'):
    """Flatten each sample, run PCA on the result and plot diagnostics.

    Args:
        input_data: array-like whose first axis indexes samples; every sample
            is flattened to one feature vector via ``reshape(n_samples, -1)``.
        title: prefix used in the title of the explained-variance plot.
        pca_components: number of leading principal components shown in the
            pair plot.
        use_scaler: if True, standardize the flattened features with
            ``StandardScaler`` before fitting the PCA.
        show_ratio: if True, plot the explained variance ratio per component.
        show_pairplot: if True, draw a seaborn pair plot of the leading
            ``pca_components`` components.
        color_by: ``'body_type'`` colors sample ``i`` by ``i // 10``,
            ``'pose'`` colors it by ``i % 10``.

    Raises:
        ValueError: if ``show_pairplot`` is set and ``color_by`` is not one of
            the supported values, or ``pca_components`` exceeds the number of
            components that were fitted.
    """
    n_samples = input_data.shape[0]
    flattened = input_data.reshape(n_samples, -1)

    if use_scaler:
        input_data_scaled = StandardScaler().fit_transform(flattened)
    else:
        input_data_scaled = flattened

    # A PCA cannot have more components than samples or features, so the
    # usual 32 is clamped instead of letting sklearn fail on small inputs.
    n_fit = min(32, n_samples, input_data_scaled.shape[1])
    pca = PCA(n_components=n_fit)
    input_data_pca = pca.fit_transform(input_data_scaled)

    if show_ratio:
        _, ax = plt.subplots(1, 1, figsize=(5, 4))

        # plot explained variance
        ax.plot(pca.explained_variance_ratio_, '.-')
        ax.set_title(f'{title}: explained variance ratio')

    if show_pairplot:
        if pca_components > n_fit:
            raise ValueError(
                f'pca_components={pca_components} exceeds the {n_fit} fitted PCA components'
            )

        pca_df = pd.DataFrame(
            input_data_pca[:, :pca_components],
            columns=[f'PCA_{i}' for i in range(pca_components)],
        )

        # The coloring relies purely on the sample index: blocks of 10
        # consecutive samples share a 'body_type' label, and the position
        # inside such a block is the 'pose' label.
        sample_ids = range(input_data_pca.shape[0])
        if color_by == 'body_type':
            pca_df['color_by'] = [i // 10 for i in sample_ids]
        elif color_by == 'pose':
            pca_df['color_by'] = [i % 10 for i in sample_ids]
        else:
            raise ValueError(f'color_by={color_by} not supported')

        sns.pairplot(pca_df, diag_kind='kde', hue='color_by', palette='tab10')

    if show_ratio or show_pairplot:
        plt.show()


In [9]:
class Arguments:
    """Plain container holding the run settings read by the following cells."""

    def __init__(self):
        # which trained experiment / checkpoint to use
        self.experiment_name = 'single_template_remeshed'
        self.checkpoint_name = 'checkpoint_99.pt'

        # which dataset and split to evaluate on
        self.dataset_name = 'FAUST_orig'
        self.split = 'train'

        # sampling / selection settings
        self.num_iters_avg = 64
        self.num_samples_median = 10
        self.confidence_threshold = 0.3

        # smoothing is switched off
        self.smoothing_type = None
        self.smoothing_iter = None

        self.zoomout_num_evecs_template = -1

        # self.reduced_dim = 32  (left disabled)

In [10]:
args = Arguments()

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# configuration
experiment_name = args.experiment_name
checkpoint_name = args.checkpoint_name

### config
# NOTE(review): absolute, user-specific path; the notebook only runs where
# this directory exists.
exp_base_folder = f'/home/s94zalek_hpc/shape_matching/my_code/experiments/ddpm/{experiment_name}'
with open(f'{exp_base_folder}/config.yaml', 'r') as f:
    # NOTE(review): FullLoader accepts more than plain data; if the config
    # only holds scalars, lists and dicts, yaml.safe_load is the safer choice.
    config = yaml.load(f, Loader=yaml.FullLoader)


### Sign correction network
sign_corr_net = diffusion_network.DiffusionNet(
    **config["sign_net"]["net_params"]
    )
sign_net_weights_path = f'{config["sign_net"]["net_path"]}/{config["sign_net"]["n_iter"]}.pth'
# map_location remaps the stored tensors onto `device`, so weights saved from
# a GPU run can still be loaded when the notebook falls back to the CPU.
sign_corr_net.load_state_dict(torch.load(sign_net_weights_path, map_location=device))
sign_corr_net.to(device)


### test dataset
dataset_name = args.dataset_name
split = args.split

# NOTE(review): the positional 200 is forwarded as-is; its meaning is defined
# by data_loading.get_val_dataset - confirm there.
single_dataset, test_dataset = data_loading.get_val_dataset(
    dataset_name, split, 200, preload=False, return_evecs=True, centering='bbox',
    )
# the sign network uses the same cache directory as the single-shape dataset
sign_corr_net.cache_dir = single_dataset.lb_cache_dir

# number of eigenvectors used per shape in the cells below
num_evecs = config["model_params"]["sample_size"]


In [None]:
import pyshot


Cxy_corr_list = []

shot_list = []

# number of groups the SHOT descriptor columns are averaged into
num_shot_bins = 32


def _predict_sign_and_support(shape_data, verts, faces, evecs_trunc):
    """Run the sign-correction network on one shape.

    Args:
        shape_data: dataset entry of the shape; 'mass', 'L', 'evals', 'evecs',
            'gradX' and 'gradY' are read from it.
        verts, faces: batched (leading dimension 1) geometry on `device`.
        evecs_trunc: batched eigenvectors on `device`, truncated to `num_evecs`.

    Returns:
        The first two outputs of `predict_sign_change`: the predicted signs
        and the normalized support vectors.
    """
    sign_cfg = config["sign_net"]
    k_eig = sign_cfg["net_params"]["k_eig"]

    if sign_cfg["with_mass"]:
        mass_mat = torch.diag_embed(
            shape_data['mass'].unsqueeze(0)
            ).to(device)
    else:
        mass_mat = None

    with torch.no_grad():
        sign_pred, support_vector_norm, _ = predict_sign_change(
            sign_corr_net, verts, faces, evecs_trunc,
            mass_mat=mass_mat, input_type=sign_corr_net.input_type,
            evecs_per_support=sign_cfg["evecs_per_support"],

            mass=shape_data['mass'].unsqueeze(0), L=shape_data['L'].unsqueeze(0),
            evals=shape_data['evals'][:k_eig].unsqueeze(0),
            evecs=shape_data['evecs'][:, :k_eig].unsqueeze(0),
            gradX=shape_data['gradX'].unsqueeze(0), gradY=shape_data['gradY'].unsqueeze(0)
            )

    return sign_pred, support_vector_norm


for i in tqdm(range(len(test_dataset)), desc='Calculating fmaps to template, evec signs'):

    data = test_dataset[i]

    verts_first = data['first']['verts'].unsqueeze(0).to(device)
    verts_second = data['second']['verts'].unsqueeze(0).to(device)

    faces_first = data['first']['faces'].unsqueeze(0).to(device)
    faces_second = data['second']['faces'].unsqueeze(0).to(device)

    evecs_first = data['first']['evecs'][:, :num_evecs].unsqueeze(0).to(device)
    evecs_second = data['second']['evecs'][:, :num_evecs].unsqueeze(0).to(device)

    corr_first = data['first']['corr']
    corr_second = data['second']['corr']

    # predict the sign change
    sign_pred_first, support_vector_norm_first = _predict_sign_and_support(
        data['first'], verts_first, faces_first, evecs_first
        )
    sign_pred_second, support_vector_norm_second = _predict_sign_and_support(
        data['second'], verts_second, faces_second, evecs_second
        )

    # correct the evecs: multiply each eigenvector by the sign of its prediction
    evecs_first_corrected = evecs_first.cpu()[0] * torch.sign(sign_pred_first).cpu()
    evecs_second_corrected = evecs_second.cpu()[0] * torch.sign(sign_pred_second).cpu()

    # functional map: least-squares solution C of
    #   evecs_second_corrected[corr_second] @ C = evecs_first_corrected[corr_first]
    Cxy_corr = torch.linalg.lstsq(
        evecs_second_corrected[corr_second].to(device),
        evecs_first_corrected[corr_first].to(device),
        ).solution.cpu()

    Cxy_corr_list.append(Cxy_corr)

    # shot descriptors
    shot_descrs_second = torch.tensor(pyshot.get_descriptors(
        verts_second.cpu().numpy().astype(np.float64)[0],
        faces_second.cpu().numpy().astype(np.int64)[0],
        radius=100,
        local_rf_radius=100,
        # The following parameters are optional
        min_neighbors=3,
        n_bins=20,
        double_volumes_sectors=True,
        use_interpolation=True,
        use_normalization=True,
    ), dtype=torch.float32)

    # Average consecutive descriptor columns in groups of `bin_size`, e.g. a
    # [6890, 352] descriptor becomes [6890, 32] with groups of 11 columns.
    assert shot_descrs_second.shape[1] % num_shot_bins == 0, \
        'SHOT descriptor length must be divisible by the number of bins'
    bin_size = shot_descrs_second.shape[1] // num_shot_bins

    # Row-major reshape puts columns [b*bin_size, (b+1)*bin_size) into group b,
    # so the mean over the last axis is the per-group average. No inner Python
    # loop is needed, which also avoids rebinding the outer index `i`.
    shot_descrs_second_grouped = shot_descrs_second.reshape(
        shot_descrs_second.shape[0], num_shot_bins, bin_size
        ).mean(dim=2)

    # normalize the descriptors (unit L2 norm along the vertex axis)
    shot_descrs_second_grouped = torch.nn.functional.normalize(shot_descrs_second_grouped, p=2, dim=0)

    # project them onto the support vectors
    shot_descrs_second_grouped_proj = shot_descrs_second_grouped.T.to(device) @ support_vector_norm_second[0]

    shot_list.append(shot_descrs_second_grouped_proj.cpu())


# stack the per-shape results along a new leading sample axis
Cxy_corr = torch.stack(Cxy_corr_list, dim=0)
shot_list = torch.stack(shot_list, dim=0)

In [None]:
# sanity check: columns per group when the SHOT descriptor of the last
# processed shape is split into 32 groups (should be a whole number)
shot_descrs_second.shape[1] / 32

In [None]:
import matplotlib.pyplot as plt

# show up to 5 randomly chosen entries of shot_list as images
num_shown = min(5, len(shot_list))

# squeeze=False keeps `axs` two-dimensional even when a single panel is drawn
fig, axs = plt.subplots(1, num_shown, figsize=(20, 4), squeeze=False)

# sample without replacement so that no entry is shown twice
rand_idx = np.random.choice(len(shot_list), num_shown, replace=False)

for ax, idx in zip(axs[0], rand_idx):
    ax.imshow(shot_list[idx].numpy())
    ax.set_title(f'{idx}')

plt.show()

In [None]:
# PCA diagnostics of the stacked functional maps (default coloring)
apply_pca(Cxy_corr, title='Cxy_corr')

In [None]:
# PCA diagnostics of the projected SHOT descriptors, colored by pose index
apply_pca(
    shot_list,
    title='shot_list',
    pca_components=4,
    color_by='pose',
)

In [None]:
# 352 descriptor columns split into 32 groups -> 11 columns per group
352 / 32