# How to learn features for functional maps

In this notebook, we show how to use deep functional maps to learn features for 3D shape matching.

In [None]:
# Select the PyTorch backend for geomstats. The environment variable is set
# before the geomfum imports below (presumably because the backend is read at
# import time -- confirm), so keep this assignment ahead of them.
# NOTE(review): the previous comment said this was "commented for github test",
# but the assignment is active.
import os

os.environ["GEOMSTATS_BACKEND"] = "pytorch"

import torch

from geomfum.convert import P2pFromFmConverter, SoftmaxNeighborFinder
from geomfum.dataset.torch import PairsDataset, ShapeDataset
from geomfum.descriptor.learned import FeatureExtractor
from geomfum.descriptor.spectral import WaveKernelSignature
from geomfum.forward_functional_map import ForwardFunctionalMap
from geomfum.learning.losses import (
    BijectivityLoss,
    FmapDescriptorsSupervisionLoss,
    GeodesicError,
    LossManager,
    OrthonormalityLoss,
)
from geomfum.learning.models import RobustFMNet
from geomfum.learning.trainer import DeepFunctionalMapTrainer

# Set to True to fetch the FAUST meshes in the download cell below.
DOWNLOAD_FAUST = False

First, we instantiate the training dataset. \
In our dataset class, we can set boolean flags to specify what kind of objects we expect in the dataset.\
In the dataset folder, we always expect the data to be stored in a 'shapes' folder. \
If we have access to template ground-truth correspondences, we can set correspondences=True; in this case we expect to have a folder called 'corr'.
We can set spectral=True if we want to compute spectral quantities, and distances=True if we want to compute distances. Computing distances is expensive, so we suggest doing so only for the test dataset.

The following code downloads the FAUST dataset.

In [None]:
if DOWNLOAD_FAUST:
    from urllib.request import urlretrieve

    faust_url = "https://raw.githubusercontent.com/JM-data/PyFuncMap/4bde4484c3e93bff925a6a82da29fa79d6862f4b/FAUST_shapes_off/"
    # Registrations 0-79 go to the training split, 80-99 to the test split.
    train_set = [f"tr_reg_{i:03d}.off" for i in range(80)]
    test_set = [f"tr_reg_{i:03d}.off" for i in range(80, 100)]

    # A single loop handles both splits; only the target folder and the file
    # list differ between them.
    for shapes_dir, fnames in (
        ("../../../datasets/faust/train_set/shapes/", train_set),
        ("../../../datasets/faust/test_set/shapes/", test_set),
    ):
        # The folder is the same for every file of a split: create it once.
        os.makedirs(shapes_dir, exist_ok=True)
        for fname in fnames:
            out_path = os.path.join(shapes_dir, fname)
            # Re-running the cell should not fetch meshes that are already there.
            if os.path.exists(out_path):
                continue
            # Download under a temporary name and rename afterwards, so an
            # interrupted transfer never leaves a truncated .off file that a
            # later run would mistake for a complete one.
            tmp_path = out_path + ".part"
            urlretrieve(faust_url + fname, tmp_path)
            os.replace(tmp_path, out_path)


Sometimes the distance computation is useful only at validation time, so we suggest the following trick: compute distances only for the validation shapes.

In [None]:
# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# cell does not fail on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE(review): these paths point at a "smal" dataset, while the download cell
# above writes the FAUST meshes under ../../../datasets/faust/ -- confirm which
# dataset is intended before running.
TRAIN_SET_PATH = "../../../datasets/smal/train_set/"
TEST_SET_PATH = "../../../datasets/smal/test_set/"

# Training shapes: spectral quantities and correspondences, but no distances
# (the text above describes the distance computation as expensive).
train_shapes = ShapeDataset(
    TRAIN_SET_PATH,
    spectral=True,
    distances=False,
    correspondences=True,
    device=DEVICE,
    k=200,  # presumably the spectral basis size; matches fmap_shape=(200, 200) used for the model -- confirm
)

# Validation shapes additionally request distances, which the geodesic error
# used for validation presumably relies on -- confirm.
val_shapes = ShapeDataset(
    TEST_SET_PATH,
    spectral=True,
    distances=True,
    correspondences=True,
    device=DEVICE,
    k=200,
)

# Build the pair datasets from each shape collection.
train_dataset = PairsDataset(train_shapes, pair_mode="all")
validation_dataset = PairsDataset(val_shapes, pair_mode="all")


Then, we set up the model.

In [None]:
# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# cell does not fail on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Learned per-vertex feature extractor, looked up by name in the registry.
feature_extractor = FeatureExtractor.from_registry(
    which="diffusionnet",
    device=DEVICE,
    in_channels=3,
    # Disabled alternative input: a wave kernel signature descriptor.
    # NOTE(review): in_channels presumably has to change if this is enabled.
    # descriptor=WaveKernelSignature(n_domain=128, k=200),
    out_channels=256,
    n_block=4,
)

# NOTE(review): the first three arguments are passed positionally as in the
# original; their meaning is defined by ForwardFunctionalMap's signature, which
# is not visible here. Naming them as keywords would be clearer -- confirm names.
fmap_module = ForwardFunctionalMap(1e3, 1, True, fmap_shape=(200, 200))

# Full model: feature extractor + functional map solver + a converter that
# turns functional maps into point-to-point maps.
functional_map_model = RobustFMNet(
    feature_extractor=feature_extractor,
    fmap_module=fmap_module,
    converter=P2pFromFmConverter(neighbor_finder=SoftmaxNeighborFinder()),
)


Then, we instantiate the optimizer.

In [None]:
# Adam over every parameter the model exposes, with a learning rate of 1e-3.
trainable_params = functional_map_model.parameters()
optimizer = torch.optim.Adam(trainable_params, lr=1e-3)

Now we define the losses that we will consider. Again, we can define our own losses; however, we provide some classic functional map energies, like the orthonormality loss. 
\
For evaluation, we can use the training losses, or we can compute the geodesic distance loss to evaluate the estimates.\
We note that this loss makes sense only if we have access to ground-truth correspondences or if the shapes share the same triangulation.

In [None]:
# Training objective: the loss terms handed to the training LossManager.
# Separate names are used for the training and validation lists so that neither
# is silently overwritten by the other.
train_losses = [
    OrthonormalityLoss(weight=1.0),
    BijectivityLoss(weight=1.0),
    FmapDescriptorsSupervisionLoss(weight=1.0),
    # Disabled terms; they are not imported at the top of the notebook and
    # must be imported before being enabled:
    # GroundTruthSupervisionLoss(weight=1.0),
    # LaplacianCommutativityLoss(weight=1e-3),
]
loss_manager = LossManager(train_losses)

# Validation metric: geodesic error only.
val_losses = [
    GeodesicError(),
]
val_loss_manager = LossManager(val_losses)

For simplicity, we have defined a trainer that takes as input the model, the losses, the train and validation datasets, and the optimizer, and manages the training loops.

In [None]:
# Use the GPU when one is visible and fall back to the CPU otherwise, so the
# cell does not fail on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Bundle model, losses, datasets and optimizer; the trainer is configured for
# 10 epochs.
trainer = DeepFunctionalMapTrainer(
    model=functional_map_model,
    train_loss_manager=loss_manager,
    val_loss_manager=val_loss_manager,
    train_set=train_dataset,
    val_set=validation_dataset,
    optimizer=optimizer,
    device=DEVICE,
    epochs=10,
)

In [None]:
# Run the training loop managed by the trainer (configured with epochs=10 above).
trainer.train()

In [None]:
# Evaluate the model; presumably uses the validation set and validation loss
# manager passed to the trainer -- confirm against DeepFunctionalMapTrainer.
trainer.validate()

Now we can save the learned feature extractor weights.

In [None]:
# Save only the feature extractor of the trained model to a local file.
# NOTE(review): the file name mentions "faust" and "epoch1", while the dataset
# paths above point at "smal" and the trainer runs 10 epochs -- confirm the
# intended name.
trainer.model.feature_extractor.save("./RobustFMNet_faust_unsup_epoch1.pth")
