In [None]:
import copy
import itertools
import os
import sys

# Add ./pyFM/ to the module search path before the local imports below.
sys.path.append('./pyFM/')

import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.special
import seaborn as sns
import networkx as nx
import torch
from tqdm.auto import tqdm

import utils
import pyFM.FMN
import pyFM.spectral as spectral
from pyFM.mesh import TriMesh
from pyFM.FMN import FMN

# Output Paths

In [None]:
# Directory where the initial samples, pairwise point-to-point maps and initial
# functional maps are cached. Must be set before running the notebook.
init_maps_dir = None

# File into which the refined functional map network is pickled. Must be set before running.
network_file = None

# Load Collection

In [None]:
# Paths to every shape of the collection; fill in before running.
path_list = []

# Load and process each mesh. NOTE(review): k=110 presumably sets the number of
# spectral basis functions computed per mesh; confirm against TriMesh.process.
meshlist = []
for mesh_path in tqdm(path_list):
    meshlist.append(TriMesh(mesh_path).process(k=110, intrinsic=True))

n_meshes = len(meshlist)

# Initial Maps

In [None]:
# Number of random shape pairs (network edges) to keep.
n_chosen_pairs = 2500
# Size of the initial functional maps. Around 20 is usually recommended, but the
# double surface of these meshes requires about 50.
k_init = 50
# Number of samples per shape (from `extract_fps`) used to compute the initial maps.
n_subsample = 3000

# Number of unordered pairs among n_meshes shapes, i.e. binom(n_meshes, 2).
n_possible_pairs = n_meshes * (n_meshes - 1) // 2
print(f'{int(n_possible_pairs):d} possible pairs')

## Sample initial pairs

We sample random pairs of shapes; each chosen pair becomes an edge of the network graph.

In [None]:
# Fixed seed: the same pairs are chosen on every run, so the per-pair maps cached on
# disk below are reused instead of recomputed for a different random selection.
seed = 0
rng = np.random.default_rng(seed)

# All unordered pairs (i, j) with i < j, shuffled in place, then truncated.
all_pairs = list(itertools.combinations(range(n_meshes), 2))
rng.shuffle(all_pairs)
chosen_pairs = all_pairs[:n_chosen_pairs]

# Graph with one node per mesh and one edge per chosen pair.
G = nx.Graph()
G.add_nodes_from(range(n_meshes))
G.add_edges_from(chosen_pairs)

print(f'Is G connected ? {nx.is_connected(G)}')
print(f'G has {G.number_of_nodes()} nodes and {G.number_of_edges()} edges')
# nx.triangles gives, per node, the number of triangles through it; every triangle is
# counted once at each of its 3 vertices. This avoids enumerating all cliques of every size.
n_triangles = sum(nx.triangles(G).values()) // 3
print(f'G has {n_triangles} 3-cycles')

## Compute initial correspondences using rigid alignment and functional maps

Extract a subset of vertices on each shape; the initial maps are computed on these samples only.

In [None]:
# Samples (from `extract_fps`) on which the initial maps are computed. They are cached on
# disk because the cached pairwise maps below store indices into these samples, so the
# samples must be reloaded (not recomputed) whenever those maps are reused.
init_samples_file = os.path.join(init_maps_dir, "init_samples.p")

if os.path.isfile(init_samples_file):
    subsample_list = utils.load_pickle(init_samples_file)
    assert np.shape(subsample_list) == (n_meshes, n_subsample), \
        "Cached samples do not match n_meshes / n_subsample; delete the cache to recompute."
else:
    subsample_list = np.zeros((n_meshes, n_subsample), dtype=int)
    for i in tqdm(range(n_meshes)):
        subsample_list[i] = meshlist[i].extract_fps(n_subsample, geodesic=False, verbose=False)
    utils.save_pickle(init_samples_file, subsample_list)

Compute initial functional maps for all chosen pairs

In [None]:
# Parallel workers for the nearest-neighbour queries and ICP below.
n_jobs = 20

maps_dict = {}

for i, j in tqdm(chosen_pairs):
    # Cache files for the point-to-point maps between the two sample sets.
    map_file_21 = os.path.join(init_maps_dir, f'{j}_to_{i}')
    map_file_12 = os.path.join(init_maps_dir, f'{i}_to_{j}')

    fps1 = subsample_list[i]
    fps2 = subsample_list[j]

    if os.path.isfile(map_file_21) and os.path.isfile(map_file_12):
        # Both directions are cached. A partially written pair (only one file) is recomputed.
        # NOTE(review): presumably each map has one entry per sample (n_subsample,), indexing
        # into the other shape's samples; confirm against utils.knn_query_normals.
        p2p_21 = utils.load_ints(map_file_21)
        p2p_12 = utils.load_ints(map_file_12)

    else:
        mesh1 = meshlist[i]  # only read here, so no copy is needed
        mesh2 = copy.deepcopy(meshlist[j])  # rigidly moved below; keep meshlist untouched

        # Coarse correspondences from positions + normals, used to initialise ICP.
        p2p_21_init_sub = utils.knn_query_normals(mesh1.vertlist[fps1], mesh2.vertlist[fps2],
                                                  mesh1.vertex_normals[fps1], mesh2.vertex_normals[fps2],
                                                  k_base=20, n_jobs=n_jobs, verbose=False)

        # ICP on the samples gives a rigid transform (R, t), which is applied to mesh2.
        _, R, t = utils.icp_align(mesh2.vertlist[fps2], mesh1.vertlist[fps1],
                                  p2p_12=p2p_21_init_sub,
                                  return_params=True, n_jobs=n_jobs, epsilon=1e-4, verbose=False)

        mesh2.rotate(R)
        mesh2.translate(t)

        # Final correspondences on the aligned shapes, in both directions.
        p2p_21 = utils.knn_query_normals(mesh1.vertlist[fps1], mesh2.vertlist[fps2],
                                         mesh1.vertex_normals[fps1], mesh2.vertex_normals[fps2],
                                         k_base=50, n_jobs=n_jobs)

        p2p_12 = utils.knn_query_normals(mesh2.vertlist[fps2], mesh1.vertlist[fps1],
                                         mesh2.vertex_normals[fps2], mesh1.vertex_normals[fps1],
                                         k_base=50, n_jobs=n_jobs)

        utils.save_ints(map_file_21, p2p_21)
        utils.save_ints(map_file_12, p2p_12)

    # Convert both point-to-point maps into functional maps of size `dims=k_init`,
    # evaluated on the sample sets only.
    FM_12 = spectral.mesh_p2p_to_FM(p2p_21, meshlist[i], meshlist[j], dims=k_init, subsample=(fps1, fps2))
    FM_21 = spectral.mesh_p2p_to_FM(p2p_12, meshlist[j], meshlist[i], dims=k_init, subsample=(fps2, fps1))
    maps_dict[(i, j)] = FM_12.copy()
    maps_dict[(j, i)] = FM_21.copy()

print(f'{len(maps_dict)} maps computed')


utils.save_pickle(os.path.join(init_maps_dir, "init_FM.p"), maps_dict);

In [None]:
# maps_dict = utils.load_pickle(os.path.join(init_maps_dir, "init_FM.p"))

# Refine the initial maps 

## Build the functional map network 

In [None]:
# Build the functional map network from the meshes and the initial functional maps
# (`maps_dict` maps each ordered pair (i, j) to the initial map FM_ij).
network = FMN(meshlist, maps_dict=maps_dict)

In [None]:
# The following cells only use `network`, so drop the notebook-level names.
# NOTE(review): memory is only released if FMN keeps its own copies; confirm in pyFM.FMN.
del meshlist, maps_dict

## Select samples to work with in the network 

In [None]:
# Number of samples per mesh used during the network refinement.
sub_size = 2000
# One row of sample indices (from `extract_fps`) per mesh in the network.
subsample_list = np.zeros((len(network.meshlist), sub_size), dtype=int)
for mesh_idx, mesh in enumerate(tqdm(network.meshlist)):
    subsample_list[mesh_idx] = mesh.extract_fps(sub_size, geodesic=False, verbose=False)

## Refine the network 

In [None]:
# Final functional map size. NOTE(review): presumably matches k=110 used when processing
# the meshes; keep the two in sync.
k_final = 110
# Size increment per ZoomOut iteration; `nit` below is chosen so that
# k_init + nit * zo_step reaches k_final.
zo_step = 5
# Fraction of the network size M used for the CCLB (both in refinement and afterwards).
cclb_ratio = .8

czo_parameters = {
    'nit': (k_final - k_init) // zo_step,
    'step': zo_step,
    'cclb_ratio': cclb_ratio,
    'subsample': subsample_list,
    'isometric': False,
    'verbose': True,
    'use_ANN': False,
    'weight_type': 'iscm',
    'n_jobs': 15,
    'backend': 'gpu'
}


network.zoomout_refine(**czo_parameters)
# Checkpoint right after the expensive refinement step.
utils.save_pickle(network_file, network)
network.compute_W(M=k_final)
network.compute_CCLB(int(cclb_ratio * network.M), verbose=True)

# Extract Template 

In [None]:
# NOTE(review): `cclb_ev` was never defined in this notebook. This assumes FMN stores the
# CCLB eigenvalues in `cclb_eigenvalues` after `compute_CCLB`; confirm against pyFM.FMN.
cclb_ev = network.cclb_eigenvalues

deviation_from_id_a = np.zeros(network.n_meshes)
deviation_from_id_cr = np.zeros(network.n_meshes)
for i in range(network.n_meshes):
    CSD_a, CSD_c = network.get_CSD(i)
    # Frobenius distance of each CSD matrix to the identity.
    deviation_from_id_a[i] = np.linalg.norm(CSD_a - np.eye(CSD_a.shape[0]))
    # Same for the second CSD, with rows weighted by sqrt(eigenvalue) and normalized
    # by sqrt of the eigenvalue sum.
    deviation_from_id_cr[i] = np.linalg.norm(np.sqrt(cclb_ev)[:, None] * (CSD_c - np.eye(CSD_c.shape[0]))) / np.sqrt(cclb_ev.sum())


# Template candidates: the mesh minimizing each deviation measure.
deviation_from_id_a.argmin(), deviation_from_id_cr.argmin()

Choose one of the two templates (the minimizer of either deviation measure).

In [None]:
# Template = mesh with the smallest `deviation_from_id_cr`; use
# `deviation_from_id_a.argmin()` instead to pick the other candidate.
base_meshind = deviation_from_id_cr.argmin()

# Build Deformation Fields 

In [None]:
# Backend passed to the spectral nearest-neighbour search: 'gpu' or 'cpu'.
backend = 'gpu'
# Number of basis coefficients used to represent each displacement field.
k_displacement = 100

In [None]:

# Template mesh and its spectral embedding from the network.
mesh1 = copy.deepcopy(network.meshlist[base_meshind]) # TriMesh(network.meshlist[base_meshind].path, area_normalize=True).process(k=k_displacement,verbose=True)
LB_1 = network.get_LB(base_meshind, complete=True)  # (n_1',m)

# Honour `backend` instead of always moving tensors to CUDA.
# NOTE(review): assumes knn_query accepts CPU tensors when backend='cpu'; confirm.
device = 'cuda' if backend == 'gpu' else 'cpu'
# The template embedding does not change across iterations: upload it once.
emb_1 = torch.from_numpy(LB_1.astype(np.float32)).to(device)

displacements = np.zeros((network.n_meshes, 3 * mesh1.n_vertices))
displacements_red = np.zeros((network.n_meshes, 3 * k_displacement))

for meshind2 in tqdm(range(network.n_meshes)):
    if meshind2 == base_meshind:
        continue  # the template's own rows stay zero

    mesh2 = copy.deepcopy(network.meshlist[meshind2])
    LB_2 = network.get_LB(meshind2, complete=True)  # (n_2',m)

    # Template-to-mesh2 vertex map from nearest neighbours in the spectral embedding.
    emb_2 = torch.from_numpy(LB_2.astype(np.float32)).to(device)
    p2p_czo_12 = pyFM.FMN.knn_query(emb_2, emb_1, backend=backend).cpu().numpy()
    R, t = utils.rigid_alignment(mesh1.vertlist, mesh2.vertlist, p2p_12=p2p_czo_12,
                                 return_params=True, return_deformed=False, weights=mesh1.vertex_areas)

    # Rebuild mesh2 from the copied geometry and move it by the inverse rigid transform.
    mesh2 = TriMesh(mesh2.vertlist, mesh2.facelist)
    mesh2.translate(-t)
    mesh2.rotate(np.linalg.inv(R))

    # Per-vertex displacement of the template towards its matched points, projected on
    # the first k_displacement basis functions of the template.
    tau_czo = mesh1.project(mesh2.vertlist[p2p_czo_12] - mesh1.vertlist, k=k_displacement)

    displacements[meshind2] = mesh1.decode(tau_czo).flatten()
    displacements_red[meshind2] = tau_czo.flatten()

# Apply PCA 

In [None]:
from sklearn.linear_model import LogisticRegression

from sklearn.model_selection import train_test_split

from sklearn.decomposition import PCA

In [None]:
# Centre the reduced displacement coefficients and fit a 50-component PCA.
avg_disp = displacements_red.mean(axis=0)
pca_d = PCA(n_components=50)
emb_d_red = pca_d.fit_transform(displacements_red - avg_disp[None, :])

# Cumulative explained variance, starting from 0 % with zero components.
cumulative_variance = np.cumsum(np.concatenate([[0], pca_d.explained_variance_ratio_]))

fig, ax = plt.subplots(dpi=150)
ax.set_title('Explained variance ratio')
ax.set_xlabel('Number of components')
ax.set_ylabel('Explained variance (%)')
ax.plot(np.arange(1 + pca_d.n_components), 100 * cumulative_variance, marker=".");

# Apply logistic regression 

In [None]:
# One label per skull, presumably in the order of `path_list`; fill in before running.
raw_labels = []
labels = np.array(raw_labels)

In [None]:
# Unregularized logistic regression. `penalty=None` replaces the string 'none', which
# scikit-learn deprecated in 1.2 and removed in 1.4.
reglin1 = LogisticRegression(penalty=None, fit_intercept=True, max_iter=1000)