# How to learn features for functional maps

In this notebook, we show how to use deep functional maps to learn features for 3D shape matching.

In [3]:
# Set the geomstats backend to PyTorch (this must happen before geomfum is imported)
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import torch
from torch.utils.data import random_split

from geomfum.convert import P2pFromFmConverter
from geomfum.dataset.torch import PairsDataset, ShapeDataset
from geomfum.descriptor.learned import FeatureExtractor
from geomfum.descriptor.spectral import WaveKernelSignature
from geomfum.forward_functional_map import ForwardFunctionalMap
from geomfum.learning.losses import (
    BijectivityLoss,
    GeodesicError,
    LaplacianCommutativityLoss,
    LossManager,
    OrthonormalityLoss,
)
from geomfum.learning.models import FMNet
from geomfum.learning.trainer import DeepFunctionalMapTrainer

First, we define our model. A model can be assembled from a feature extractor and a forward functional map module; we also provide some classic architectures, such as the Functional Map Network (FMNet).

In [4]:
# Build the model
fmap_module = ForwardFunctionalMap(1e3, 1, True)

feature_extractor = FeatureExtractor.from_registry(
    which="diffusionnet",
    device="cuda",
    k_eig=200,
    in_channels=128,
    descriptor=WaveKernelSignature(n_domain=128),
)

functional_map_model = FMNet(
    feature_extractor=feature_extractor,
    fmap_module=fmap_module,
    converter=P2pFromFmConverter(),
)
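For orientation, the forward step of a deep functional map network is commonly written as a regularized least-squares problem. The formula below is the standard form from the functional-maps literature, given here as an assumption about what `ForwardFunctionalMap` computes rather than a description of its code. The symbols are not taken from this notebook: $A_1, A_2$ are the spectral coefficients of the learned features on the two shapes, $\Lambda_1, \Lambda_2$ are the diagonal matrices of Laplacian eigenvalues, and $\lambda$ is a regularization weight. How the positional arguments `1e3, 1, True` map onto these quantities should be checked against the geomfum documentation.

```latex
C_{12} = \arg\min_{C} \; \lVert C A_1 - A_2 \rVert_F^2
         + \lambda \, \lVert C \Lambda_1 - \Lambda_2 C \rVert_F^2
```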


Then, we instantiate the training dataset. \
In our dataset class, boolean flags specify what kind of objects we expect in the dataset.\
In the dataset folder, we always expect the shapes to be stored in a 'shapes' folder. \
If we have access to template ground-truth correspondences, we can set correspondences=True; in this case we expect a folder called 'corr'.
We can set spectral=True to compute spectral quantities and distances=True to compute distances. Computing distances is expensive, so we suggest doing it only for the test dataset.

The following code downloads the FAUST dataset.

In [8]:
from urllib.request import urlretrieve

faust_url = "https://raw.githubusercontent.com/JM-data/PyFuncMap/4bde4484c3e93bff925a6a82da29fa79d6862f4b/FAUST_shapes_off/"
shape_files = [f"tr_reg_{i:03d}.off" for i in range(100)]
for fname in shape_files:
    url = faust_url + fname
    out_path = os.path.join("../../../datasets/faust/shapes/", fname)
    os.makedirs(os.path.dirname(out_path), exist_ok=True)
    urlretrieve(url, out_path)
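Re-running the cell above fetches all 100 files again. A minimal sketch of a variant that skips files already on disk is shown below; `download_missing` and its injectable `fetch` parameter are hypothetical helpers introduced here for illustration, not part of geomfum.

```python
import os
from urllib.request import urlretrieve


def download_missing(base_url, filenames, out_dir, fetch=urlretrieve):
    """Download each file into out_dir unless it already exists.

    `fetch` is a hypothetical injectable callable (url, path), so the
    helper can be exercised without network access.
    Returns the list of file names that were actually fetched.
    """
    os.makedirs(out_dir, exist_ok=True)
    fetched = []
    for fname in filenames:
        out_path = os.path.join(out_dir, fname)
        if os.path.exists(out_path):
            continue  # already on disk, skip the request
        fetch(base_url + fname, out_path)
        fetched.append(fname)
    return fetched
```

With the names from the cell above, the call would be `download_missing(faust_url, shape_files, "../../../datasets/faust/shapes/")`.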


In [10]:
TRAIN_SET_PATH = "../../../datasets/faust/"
dataset = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=True,
    correspondences=True,
    device="cuda",
    k=30,
)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size

train_shapes, validation_shapes = random_split(dataset, [train_size, val_size])


train_dataset = PairsDataset(
    train_shapes,
    pair_mode="all",
)

validation_dataset = PairsDataset(
    validation_shapes,
    pair_mode="all",
)
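The number of pairs grows quadratically with the number of shapes, which drives the training cost. The sketch below assumes that `pair_mode="all"` enumerates every ordered pair of distinct shapes. This is an assumption, not a statement about the library's implementation, but it is consistent with the 6320 batches reported in the training log at the end of this notebook (80 training shapes, 80 × 79 = 6320). `all_ordered_pairs` is a hypothetical helper.

```python
from itertools import permutations


def all_ordered_pairs(n_shapes):
    """Index pairs (i, j) with i != j, in both directions."""
    return list(permutations(range(n_shapes), 2))


n_total = 100                 # FAUST registrations downloaded above
n_train = int(0.8 * n_total)  # 80, as in the split above
n_val = n_total - n_train     # 20

print(len(all_ordered_pairs(n_train)))  # 6320
print(len(all_ordered_pairs(n_val)))    # 380
```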


Distance computation is sometimes useful only at validation time, so we suggest the following trick: build the training dataset without distances and the validation dataset with them.

In [None]:
from torch.utils.data import Subset


TRAIN_SET_PATH = "../../../datasets/faust/"
train_shapes = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=False,
    correspondences=False,
    device="cuda",
    k=30,
)

val_shapes = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=True,
    correspondences=True,
    device="cuda",
    k=30,
)


train_dataset = PairsDataset(
    train_shapes,
    pair_mode="all",
)

validation_dataset = PairsDataset(
    val_shapes,
    pair_mode="all",
)
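As written, both datasets in the cell above cover every shape in the folder, so training and validation pairs overlap. One way to keep them disjoint is to draw a single shuffled index split and apply it to both datasets. The sketch below only computes the indices, using the standard library; `split_indices` and the seed value are hypothetical choices made for illustration.

```python
import random


def split_indices(n_items, train_fraction=0.8, seed=0):
    """Return disjoint (train_idx, val_idx) lists covering range(n_items)."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # seeded so the split is reproducible
    n_train = int(train_fraction * n_items)
    return indices[:n_train], indices[n_train:]


train_idx, val_idx = split_indices(100)
print(len(train_idx), len(val_idx))  # 80 20
```

The two lists can then be passed to `Subset(train_shapes, train_idx)` and `Subset(val_shapes, val_idx)` before building the `PairsDataset` objects, assuming both `ShapeDataset` instances list the shapes in the same order.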


Then, we instantiate the optimizer.

In [11]:
optimizer = torch.optim.Adam(functional_map_model.parameters(), lr=1e-3)

Now we define the losses. Again, we can define our own losses, but we provide some classic functional map energies, such as the orthonormality loss. 
\
For evaluation, we can reuse the training losses or compute the geodesic error of the estimated correspondences.\
Note that this error makes sense only if we have access to ground-truth correspondences or if the shapes share the same triangulation.
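For reference, the three training energies used in the next cell are usually written as follows in the functional-maps literature. These are the textbook forms, stated as an assumption; the geomfum classes may normalize or weight the terms differently. Here $C_{12}$ and $C_{21}$ denote the functional maps estimated in the two directions, $\Lambda_1, \Lambda_2$ the diagonal matrices of Laplacian eigenvalues of the two shapes, and $I$ the identity.

```latex
\begin{aligned}
E_{\text{orth}} &= \lVert C_{12}^{\top} C_{12} - I \rVert_F^2, \\
E_{\text{bij}}  &= \lVert C_{12} C_{21} - I \rVert_F^2, \\
E_{\text{lap}}  &= \lVert C_{12} \Lambda_1 - \Lambda_2 C_{12} \rVert_F^2 .
\end{aligned}
```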

In [12]:
# Define the training losses
losses = [
    OrthonormalityLoss(weight=1.0),
    BijectivityLoss(weight=1.0),
    LaplacianCommutativityLoss(weight=1e-3),
]
loss_manager = LossManager(losses)

# Define the validation loss
losses = [
    GeodesicError(),
]

val_loss_manager = LossManager(losses)
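To make the validation metric concrete, here is a minimal sketch of a mean geodesic error on a toy example. It is an illustration under stated assumptions, not the implementation of `GeodesicError`: the function name, the argument layout, and the direction of the point-to-point maps are hypothetical, and many benchmarks additionally normalize by the square root of the target surface area.

```python
def mean_geodesic_error(dist, p2p_pred, p2p_gt):
    """Average geodesic distance between predicted and true matches.

    dist     : square matrix (list of lists) of geodesic distances on the target shape
    p2p_pred : p2p_pred[i] is the predicted target vertex for source vertex i
    p2p_gt   : p2p_gt[i] is the ground-truth target vertex for source vertex i
    """
    if len(p2p_pred) != len(p2p_gt):
        raise ValueError("prediction and ground truth must have the same length")
    errors = [dist[p][g] for p, g in zip(p2p_pred, p2p_gt)]
    return sum(errors) / len(errors)


# Toy target shape: three vertices on a line at positions 0, 1 and 3.
dist = [
    [0.0, 1.0, 3.0],
    [1.0, 0.0, 2.0],
    [3.0, 2.0, 0.0],
]
print(mean_geodesic_error(dist, [0, 1, 2], [0, 1, 2]))  # 0.0
print(mean_geodesic_error(dist, [1, 1, 0], [0, 1, 2]))  # (1 + 0 + 3) / 3
```

This also shows why the metric needs the distance matrix and the ground-truth correspondences, which is what `distances=True` and `correspondences=True` request for the validation dataset.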

For simplicity, we provide a trainer that takes the model, the losses, the training and validation datasets, and the optimizer as input and manages the training loop.

In [13]:
trainer = DeepFunctionalMapTrainer(
    model=functional_map_model,
    train_loss_manager=loss_manager,
    val_loss_manager=val_loss_manager,
    train_set=train_dataset,
    val_set=validation_dataset,
    optimizer=optimizer,
    device="cuda",
    epochs=10,
)

In [None]:
trainer.train()

INFO: Epoch [1/10] - Training
Epoch 1/10 (Train):   0%|          | 3/6320 [01:13<42:46:10, 24.37s/batch, Loss=301.4008]
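The progress bar gives enough information to estimate the cost of a full run. The sketch below takes the batch count and the seconds per batch from the log above; the rate was measured after only three batches, so the result is a rough estimate. `format_hms` is a hypothetical helper.

```python
def format_hms(seconds):
    """Format a duration in seconds as H:MM:SS."""
    total = int(round(seconds))
    hours, rem = divmod(total, 3600)
    minutes, secs = divmod(rem, 60)
    return f"{hours}:{minutes:02d}:{secs:02d}"


n_batches = 6320       # from the progress bar
sec_per_batch = 24.37  # from the progress bar (early, noisy estimate)
epochs = 10            # as passed to the trainer

epoch_seconds = n_batches * sec_per_batch
print(format_hms(epoch_seconds))           # 42:46:58 for one epoch
print(format_hms(epoch_seconds * epochs))  # 427:49:44 for the full run
```

At this rate a single epoch takes almost two days, so it may be worth training on fewer pairs or shapes while experimenting.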