# FMNet  
## Required packages
`conda install numpy scipy tqdm plyfile trimesh joblib torchvision`  
`conda install -c pytorch-lts pytorch`  
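`conda install -c conda-forge "scikit-learn<1.2"` (preprocess.py imports `sklearn.utils.graph.graph_shortest_path`, which was removed in scikit-learn 1.2)  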
`conda install -c conda-forge networkx`  

## preprocess.py

In [None]:
# stdlib
import argparse
from pathlib import Path
# 3p
from joblib import Parallel, delayed
from tqdm import tqdm
import numpy as np
import scipy.io as sio
from scipy.sparse import csr_matrix
from sklearn import neighbors
from sklearn.utils.graph import graph_shortest_path
import trimesh
import networkx as nx
# project
import utils.shot.shot as shot
from utils.io import read_mesh
from utils.laplace_decomposition import laplace_decomposition

# SHOT's hyperparameters
# Both radii are passed to shot.compute() in process_mesh, which is called after the
# vertices have been divided by their sqrt-area value, so they are expressed in those
# rescaled units.
NORMAL_R = 0.1  # presumably the normal-estimation radius (inferred from the name; confirm in utils.shot)
SHOT_R = 0.1  # presumably the SHOT descriptor support radius (inferred from the name; confirm in utils.shot)
KNN = 20  # NOTE(review): not referenced by any code visible in this file


def compute_geodesic_matrix(verts, faces, NN):
    """Approximate all-pairs geodesic distances on a triangle mesh.

    Geodesics are approximated by shortest paths on the mesh edge graph, where
    every edge is weighted by its Euclidean length.

    Arguments:
        verts {np.ndarray} -- vertex positions. Shape: num_vertices x 3.
        faces {np.ndarray} -- triangle vertex indices. Shape: num_faces x 3.
        NN {int} -- kept for backward compatibility; no longer used. Edge lengths
                    used to be looked up in a k-nearest-neighbour graph, which
                    silently dropped every mesh edge whose endpoints were not among
                    each other's NN nearest neighbours. They are now computed
                    directly from the mesh.

    Returns:
        np.ndarray -- dense num_vertices x num_vertices matrix of shortest-path
                      distances (``inf`` between vertices that are not connected).
    """
    from scipy.sparse.csgraph import shortest_path

    verts = np.asarray(verts, dtype=np.float64)
    faces = np.asarray(faces, dtype=np.int64).reshape(-1, 3)
    n_verts = verts.shape[0]

    # Each triangle contributes its three edges. Sorting the endpoints lets
    # np.unique merge an edge shared by two triangles into a single entry
    # (csr_matrix would otherwise sum the duplicate weights).
    edges = np.concatenate([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
    edges = np.unique(np.sort(edges, axis=1), axis=0)
    lengths = np.linalg.norm(verts[edges[:, 0]] - verts[edges[:, 1]], axis=1)

    graph = csr_matrix((lengths, (edges[:, 0], edges[:, 1])), shape=(n_verts, n_verts))
    # directed=False: each stored (i, j) edge may be traversed in both directions
    return shortest_path(graph, method="D", directed=False)


def process_mesh(mesh, corres_root, save_dir, args):
    """Preprocess one mesh file and save it as ``save_dir/<mesh stem>.mat``.

    The vertices are centred and then divided by the sqrt-area value returned by a
    first ``laplace_decomposition`` call. The decomposition is recomputed on the
    rescaled mesh, and the geodesic distance matrix and SHOT descriptors are added.
    When ``corres_root`` is given, the matching ``<stem>.vts`` correspondence file
    is stored as well.

    Arguments:
        mesh {pathlib.Path} -- mesh file to read.
        corres_root {pathlib.Path or None} -- directory holding ``.vts`` files, or None.
        save_dir {pathlib.Path} -- output directory.
        args {argparse.Namespace} -- reads ``num_eigen`` and ``nn``.

    Saved keys: pos, faces, evals, evecs, evecs_trans, geod_dist, feat (and vts).
    """
    stem = mesh.stem
    vertices, faces = read_mesh(mesh)
    vertices -= np.mean(vertices, axis=0)  # centre the shape on its vertex mean

    # The first decomposition is only needed for its scale factor.
    _, _, _, old_sqrt_area = laplace_decomposition(vertices, faces, args.num_eigen)
    vertices /= old_sqrt_area

    # The second decomposition, on the rescaled shape, is the one that gets saved.
    evals, evecs, evecs_trans, sqrt_area = laplace_decomposition(vertices, faces, args.num_eigen)
    print(f"shape {mesh.stem} ==> old sqrt area: {old_sqrt_area :.8f} | new sqrt area: {sqrt_area :.8f}")

    geodesic = compute_geodesic_matrix(vertices, faces, args.nn)
    shot_descriptors = shot.compute(vertices, NORMAL_R, SHOT_R).reshape(-1, 352)

    to_save = {
        "pos": vertices,
        "faces": faces,
        "evals": evals,
        "evecs": evecs,
        "evecs_trans": evecs_trans,
        "geod_dist": geodesic,
        "feat": shot_descriptors,
    }
    # ground-truth correspondences, only when the dataset provides them
    if corres_root is not None:
        to_save["vts"] = np.loadtxt(corres_root / f"{stem}.vts", dtype=np.int32)

    sio.savemat(save_dir / f"{stem}.mat", to_save)


def main(args):
    """Preprocess every mesh file in ``<dataroot>/shapes`` in parallel.

    Results go to ``args.save_dir``, which is created if missing. When a
    ``<dataroot>/correspondences`` directory exists, it is handed to
    ``process_mesh`` so each shape's ``.vts`` file is saved alongside it.

    Arguments:
        args {argparse.Namespace} -- reads ``save_dir``, ``dataroot`` and ``njobs``
                                     (and is forwarded to ``process_mesh``).
    """
    save_root = Path(args.save_dir)
    save_root.mkdir(parents=True, exist_ok=True)

    # argparse yields a plain str for --dataroot unless a type is given, so wrap
    # it in Path *before* using the `/` operator (str / str raises TypeError).
    data_root = Path(args.dataroot)
    meshes_root = data_root / "shapes"
    corres_dir = data_root / "correspondences"
    corres_root = corres_dir if corres_dir.is_dir() else None

    # deterministic order; sub-directories are not meshes
    meshes = sorted(p for p in meshes_root.iterdir() if p.is_file())
    _ = Parallel(n_jobs=args.njobs)(delayed(process_mesh)(mesh, corres_root, save_root, args)
                                    for mesh in tqdm(meshes))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="""Preprocess data for FMNet training.
                       Compute Laplacian eigen decomposition, shot features, and geodesic distance for each shape."""
    )
    # type=Path so main() can build sub-paths with `/`; argparse also applies
    # `type` to string defaults, so the defaults become Path objects too.
    parser.add_argument('-d', '--dataroot', required=False, type=Path,
                        default="../data/faust/raw", help='root directory of the dataset')
    parser.add_argument('-sd', '--save-dir', required=False, type=Path,
                        default="../data/faust/processed", help='root directory to save the processed dataset')
    parser.add_argument("-ne", "--num-eigen", type=int, default=100, help="number of eigenvectors kept.")
    parser.add_argument("-nj", "--njobs", type=int, default=-2, help="Number of parallel processes to use.")
    parser.add_argument("--nn", type=int, default=20,
                        help="Number of Neighbor to consider when computing geodesic matrix.")

    args = parser.parse_args()
    main(args)


## faust_dataset.py

In [None]:
# stdlib
from os import listdir
from os.path import isfile, join
from itertools import permutations
# 3p
import numpy as np
import scipy.io as sio
import torch
from torch.utils.data import Dataset


class FAUSTDataset(Dataset):
    """Preprocessed FAUST shapes served as ordered pairs (x, y).

    Every ordered pair (i, j), i != j, of the files in ``root`` is one sample, so a
    directory of N shapes yields N * (N - 1) samples. Each shape is re-indexed by
    its ``vts`` correspondence indices before being returned.

    Arguments:
        root {str} -- directory of ``.mat`` files (as written by preprocess.py).
        dim_basis {int} -- number of eigenvectors kept (default: {100})
        transform {callable} -- optional transform applied to each shape's
                                (feat, evecs, evecs_trans, dist) tuple (default: {None})
    """
    def __init__(self, root, dim_basis=100, transform=None):
        self.root = root
        self.dim_basis = dim_basis
        self.transform = transform
        # sorted so that pair indices do not depend on the filesystem's listing order
        self.samples = [join(root, f) for f in sorted(listdir(root)) if isfile(join(root, f))]
        self.combinations = list(permutations(range(len(self.samples)), 2))

    def loader(self, path):
        """Load one preprocessed shape.

        Returns:
            feat: num_vertices * n_features
            evecs: num_vertices * n_basis
            evecs_trans: n_basis * num_vertices
            dist: num_vertices * num_vertices
            vts: 1-D LongTensor of vertex indices (all vertices if the file has no 'vts')
        """
        mat = sio.loadmat(path)
        feat = torch.Tensor(mat['feat']).float()
        evecs = torch.Tensor(mat['evecs'])[:, :self.dim_basis].float()
        evecs_trans = torch.Tensor(mat['evecs_trans'])[:self.dim_basis, :].float()
        dist = torch.Tensor(mat['geod_dist']).float()
        if 'vts' in mat:
            # savemat stores 1-D arrays as 1 x N matrices; flatten back to 1-D,
            # otherwise indexing with vts adds a spurious leading dimension.
            # NOTE(review): indices are used as 0-based here; if the .vts files are
            # 1-based (MATLAB convention) they need a `- 1` — confirm for the dataset.
            vts = torch.from_numpy(np.asarray(mat['vts']).reshape(-1).astype(np.int64))
        else:
            vts = torch.arange(feat.size(0))
        return feat, evecs, evecs_trans, dist, vts

    def _load_aligned(self, path):
        """Load one shape and re-index all its quantities by its vts indices."""
        feat, evecs, evecs_trans, dist, vts = self.loader(path)
        return feat[vts], evecs[vts], evecs_trans[:, vts], dist[vts][:, vts]

    def __len__(self):
        return len(self.combinations)

    def __getitem__(self, index):
        idx1, idx2 = self.combinations[index]
        shape_x = self._load_aligned(self.samples[idx1])
        shape_y = self._load_aligned(self.samples[idx2])

        if self.transform is not None:
            # After vts re-indexing, vertex k of x corresponds to vertex k of y, and the
            # loss relies on that. A random transform (e.g. vertex subsampling) must
            # therefore draw the same indices for both shapes, so replay the RNG state.
            np_state, torch_state = np.random.get_state(), torch.get_rng_state()
            shape_x = self.transform(shape_x)
            np.random.set_state(np_state)
            torch.set_rng_state(torch_state)
            shape_y = self.transform(shape_y)

        feat_x, evecs_x, evecs_trans_x, dist_x = shape_x
        feat_y, evecs_y, evecs_trans_y, dist_y = shape_y
        return [feat_x, evecs_x, evecs_trans_x, dist_x, feat_y, evecs_y, evecs_trans_y, dist_y]


class RandomSampling:
    """Randomly subsample a fixed number of vertices from one shape.

    The input and output are the tuple (feat, evecs, evecs_trans, dist). Indices come
    from ``np.random.choice`` with its default ``replace=True``, so a vertex may
    be picked more than once.
    """
    def __init__(self, num_vertices):
        self.num_vertices = num_vertices

    def __call__(self, sample):
        feat, evecs, evecs_trans, dist = sample
        picked = np.random.choice(feat.size(0), self.num_vertices)
        return (
            feat[picked, :],
            evecs[picked, :],
            evecs_trans[:, picked],     # evecs_trans is indexed on its vertex (column) axis
            dist[picked, :][:, picked],  # keep the rows and columns of the picked vertices
        )


## loss.py

In [None]:
# 3p
import torch
import torch.nn as nn


class SoftErrorLoss(nn.Module):
    """Soft error loss as defined in the FMNet paper.

    For every batch element, the correspondence matrix is weighted element-wise by
    the geodesic distances. The Frobenius norm of that product is taken, and the
    norms are averaged over the batch.
    """

    def forward(self, P, geodesic_dist):
        """Compute the soft error loss.

        Arguments:
            P {torch.Tensor} -- soft correspondence matrix. Shape: batch_size x num_vertices x num_vertices.
            geodesic_dist {torch.Tensor} -- geodesic distances on Y. Shape: batch_size x num_vertices x num_vertices.

        Returns:
            torch.Tensor -- scalar: mean over the batch of the per-sample norms.
        """
        weighted = P * geodesic_dist
        per_sample = weighted.pow(2).sum(dim=(1, 2)).sqrt()
        return per_sample.mean()


## model.py

In [None]:
# 3p
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """One residual block as presented in the FMNet paper.

    Structure: Linear -> BatchNorm -> ReLU -> Linear -> BatchNorm, added to a
    shortcut, then ReLU. The shortcut is the identity, or a Linear projection when
    ``in_dim != out_dim``. Inputs are batch x num_vertices x features.
    """
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.bn1 = nn.BatchNorm1d(out_dim, eps=1e-3, momentum=1e-3)
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.bn2 = nn.BatchNorm1d(out_dim, eps=1e-3, momentum=1e-3)
        # The shortcut needs a projection only when the widths differ. It is wrapped in
        # Sequential to keep the "projection.0.*" state-dict keys. The ResNet paper also
        # suggests a BatchNorm on the shortcut, which the original FMNet does not use.
        self.projection = nn.Sequential(nn.Linear(in_dim, out_dim)) if in_dim != out_dim else None

    @staticmethod
    def _feature_norm(bn, x):
        """Apply BatchNorm1d over the last (feature) axis of a 3-D tensor."""
        return bn(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, x):
        hidden = F.relu(self._feature_norm(self.bn1, self.fc1(x)))
        residual = self._feature_norm(self.bn2, self.fc2(hidden))
        shortcut = x if self.projection is None else self.projection(x)
        return F.relu(residual + shortcut)


class RefineNet(nn.Module):
    """The refine net of FMNet.

    It is a stack of ``n_residual_blocks`` width-preserving residual blocks. It takes
    hand-crafted descriptors and outputs learned descriptors of the same width,
    intended for the correspondence task.
    """
    def __init__(self, n_residual_blocks=7, in_dim=352):
        super().__init__()
        blocks = [ResidualBlock(in_dim, in_dim) for _ in range(n_residual_blocks)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Run the descriptors through every residual block in turn.

        Arguments:
            x {torch.Tensor} -- input hand-crafted descriptor. Shape: batch-size x num-vertices x num-features

        Returns:
            torch.Tensor -- learned descriptor. Shape: batch-size x num-vertices x num-features
        """
        return self.model(x)


class SoftcorNet(nn.Module):
    """Compute the functional map C and the soft correspondence matrix P.

    The module has no learnable parameters.
    """
    def __init__(self):
        super().__init__()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y, evecs_trans_x, evecs_trans_y):
        """One pass in soft core net.

        Arguments:
            feat_x {Torch.Tensor} -- learned feature 1. Shape: batch-size x num-vertices x num-features
            feat_y {Torch.Tensor} -- learned feature 2. Shape: batch-size x num-vertices x num-features
            evecs_x {Torch.Tensor} -- eigen vectors decomposition of shape 1. Shape: batch-size x num-vertices x num-eigenvectors
            evecs_y {Torch.Tensor} -- eigen vectors decomposition of shape 2. Shape: batch-size x num-vertices x num-eigenvectors
            evecs_trans_x: {Torch.Tensor} -- inverse eigen vectors decomposition of shape 1. defined as evecs_x.t() @ mass_matrix.
                                             Shape: batch-size x num-eigenvectors x num-vertices
            evecs_trans_y: {Torch.Tensor} -- inverse eigen vectors decomposition of shape 2. defined as evecs_y.t() @ mass_matrix.
                                             Shape: batch-size x num-eigenvectors x num-vertices

        Returns:
            Torch.Tensor -- soft correspondence matrix. Shape: batch_size x num_vertices x num_vertices.
            Torch.Tensor -- Functional map matrix. Shape: batch_size x num-eigenvectors x num-eigenvectors.
        """
        # spectral coefficients of the descriptors: batch x num-features x num-eigenvectors
        F_hat = torch.bmm(evecs_trans_x, feat_x).transpose(1, 2)
        G_hat = torch.bmm(evecs_trans_y, feat_y).transpose(1, 2)

        # Least-squares solution X of F_hat @ X = G_hat via the normal equations
        # (F^T F) X = F^T G. Solved batched with linalg.solve instead of forming an
        # explicit inverse, which is less accurate for ill-conditioned F^T F.
        # The functional map is C = X^T.
        F_hat_t = F_hat.transpose(1, 2)
        X = torch.linalg.solve(torch.bmm(F_hat_t, F_hat), torch.bmm(F_hat_t, G_hat))
        C = X.transpose(1, 2)

        # soft correspondence matrix P; each column is L2-normalised over dim 1
        P = torch.abs(torch.bmm(torch.bmm(evecs_y, C), evecs_trans_x))
        P = F.normalize(P, 2, dim=1)
        return P, C


class FMNet(nn.Module):
    """FMNet: descriptor refinement followed by the soft functional-map correspondence layer."""
    def __init__(self, n_residual_blocks=7, in_dim=352):
        """Build the refine net and the soft correspondence layer.

        Keyword Arguments:
            n_residual_blocks {int} -- number of resnet blocks in FMNet (default: {7})
            in_dim {int} -- Input features dimension (default SHOT descriptor) (default: {352})
        """
        super().__init__()
        self.refine_net = RefineNet(n_residual_blocks, in_dim)
        self.softcor = SoftcorNet()

    def forward(self, feat_x, feat_y, evecs_x, evecs_y, evecs_trans_x, evecs_trans_y):
        """One pass in FMNet.

        Arguments:
            feat_x {Torch.Tensor} -- hand crafted feature 1. Shape: batch-size x num-vertices x num-features
            feat_y {Torch.Tensor} -- hand crafted feature 2. Shape: batch-size x num-vertices x num-features
            evecs_x {Torch.Tensor} -- eigen vectors decomposition of shape 1. Shape: batch-size x num-vertices x num-eigenvectors
            evecs_y {Torch.Tensor} -- eigen vectors decomposition of shape 2. Shape: batch-size x num-vertices x num-eigenvectors
            evecs_trans_x {Torch.Tensor} -- inverse eigen vectors of shape 1. Shape: batch-size x num-eigenvectors x num-vertices
            evecs_trans_y {Torch.Tensor} -- inverse eigen vectors of shape 2. Shape: batch-size x num-eigenvectors x num-vertices

        Returns:
            Torch.Tensor -- soft correspondence matrix. Shape: batch_size x num_vertices x num_vertices.
            Torch.Tensor -- matrix representation of functional correspondence.
                            Shape: batch_size x num-eigenvectors x num-eigenvectors.
        """
        # the same refine net (same weights) processes both shapes
        refined_x = self.refine_net(feat_x)
        refined_y = self.refine_net(feat_y)
        soft_corr, fmap = self.softcor(refined_x, refined_y, evecs_x, evecs_y, evecs_trans_x, evecs_trans_y)
        return soft_corr, fmap


## train.py

In [None]:
# stdlib
import argparse
import os
# 3p
import torch
from torchvision import transforms
# project
from model import FMNet
from faust_dataset import FAUSTDataset, RandomSampling
from loss import SoftErrorLoss


def train_fmnet(args):
    """Train FMNet on shape pairs from FAUSTDataset and save periodic checkpoints.

    Epochs are counted from 1. A checkpoint ``epoch{N}.pth`` is written to
    ``args.save_dir`` after every epoch N that is a multiple of
    ``args.checkpoint_interval``.

    Arguments:
        args {argparse.Namespace} -- training options (see the CLI parser).
    """
    if torch.cuda.is_available() and not args.no_cuda:
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    print(f"Using device: {device}")

    os.makedirs(args.save_dir, exist_ok=True)

    # create dataset
    print("creating dataset")
    composed = transforms.Compose([RandomSampling(args.n_vertices)])
    dataset = FAUSTDataset(args.dataroot, args.dim_basis, transform=composed)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.n_cpu)
    # create model
    print("creating model")
    fmnet = FMNet(n_residual_blocks=args.num_blocks, in_dim=352).to(device)  # number of features of SHOT descriptor
    optimizer = torch.optim.Adam(fmnet.parameters(), lr=args.lr, betas=(args.b1, args.b2))
    criterion = SoftErrorLoss().to(device)

    # Training loop
    print("start training")
    iterations = 0
    for epoch in range(1, args.n_epochs + 1):
        fmnet.train()
        for i, data in enumerate(dataloader):
            data = [x.to(device) for x in data]
            feat_x, evecs_x, evecs_trans_x, dist_x, feat_y, evecs_y, evecs_trans_y, dist_y = data

            # do iteration
            optimizer.zero_grad()
            P, _ = fmnet(feat_x, feat_y, evecs_x, evecs_y, evecs_trans_x, evecs_trans_y)
            loss = criterion(P, dist_y)
            loss.backward()
            optimizer.step()

            # log (.item() prints the number rather than the tensor repr)
            iterations += 1
            if iterations % args.log_interval == 0:
                print(f"#epoch:{epoch}, #batch:{i + 1}, #iteration:{iterations}, loss:{loss.item()}")

        # `epoch` is already 1-based: testing `epoch + 1` saved after epochs
        # 4, 9, ... and never after the final epoch.
        if epoch % args.checkpoint_interval == 0:
            torch.save(fmnet.state_dict(), os.path.join(args.save_dir, 'epoch{}.pth'.format(epoch)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Launch the training of FMNet model."
    )
    parser.add_argument("--lr", type=float, default=1e-3, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("-bs", "--batch-size", type=int, default=8, help="size of the batches")
    parser.add_argument("--n-epochs", type=int, default=50, help="number of epochs of training")
    parser.add_argument('--dim-basis', type=int, default=40,
                        help='number of eigenvectors used for representation.')
    parser.add_argument("-nv", "--n-vertices", type=int, default=1500, help="Number of vertices used per shape")
    parser.add_argument("-nb", "--num-blocks", type=int, default=7, help="number of resnet blocks")
    parser.add_argument('-d', '--dataroot', required=False, default="../data/faust/processed",
                        help='root directory of the dataset')
    # save_dir is where train_fmnet writes the model checkpoints
    parser.add_argument('--save-dir', required=False, default="../data/save/",
                        help='directory where model checkpoints are saved')
    parser.add_argument("--n-cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
    parser.add_argument('--no-cuda', action='store_true', help='Disable GPU computation')
    parser.add_argument("--checkpoint-interval", type=int, default=5, help="interval between model checkpoints")
    parser.add_argument("--log-interval", type=int, default=1, help="interval between logging train information")

    args = parser.parse_args()
    train_fmnet(args)
