In [1]:
import numpy as np

from geomfum.dataset import NotebooksDataset
from geomfum.dfm.forward_functional_map import ForwardFunctionalMap
from geomfum.shape import TriangleMesh
from geomfum.descriptor.learned import LearnedDescriptor

from geomfum.dfm.losses import LossManager

import torch

In [2]:
#load
dataset = NotebooksDataset()

mesh_a = TriangleMesh.from_file(dataset.get_filename("cat-00"))
mesh_b = TriangleMesh.from_file(dataset.get_filename("lion-00"))


In [3]:
mesh_a.laplacian.find_spectrum(spectrum_size=40, set_as_basis=True)
mesh_b.laplacian.find_spectrum(spectrum_size=40, set_as_basis=True)

mesh_a.basis.use_k = 30
mesh_b.basis.use_k = 30


In [4]:
#convert shapes to tensors and preproces
DEVICE="cuda"
mesh_a.to_torch(DEVICE)
mesh_b.to_torch(DEVICE)

mesh_a.basis.to_torch(DEVICE)
mesh_b.basis.to_torch(DEVICE)

In [5]:
descr = LearnedDescriptor.from_registry(which='diffusion_net',device=DEVICE)
forward_map = ForwardFunctionalMap(0.0001,1)


In [None]:
# Define the loss manager
loss_config = {
    "Orthonormality": 1.0,
    #"Laplacian_Commutativity": 0.001
}
loss_manager = LossManager(loss_config)
optimizer = torch.optim.Adam(descr.model.parameters(), lr=0.001)
evals_x = mesh_a.basis.vals[None]
evals_y = mesh_b.basis.vals[None]

# Training Loop
for epoch in range(100):
    print(epoch)
    optimizer.zero_grad()
    
    # Forward pass
    feat_a = descr(mesh_a)
    feat_b = descr(mesh_b)

    # Compute functional maps
    Cxy = forward_map(mesh_a, mesh_b, feat_a, feat_b)
    
    # Compute loss (LossManager automatically selects the right inputs)
    loss, loss_details = loss_manager.compute_loss(Cxy=Cxy, evals_x=evals_x, evals_y=evals_y)
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    
    if (epoch + 1) % 10 == 0:
        print(f'Epoch [{epoch + 1}/{100}], Loss: {loss.item():.4f}, Breakdown: {loss_details}')

0
1
2
3
4
5
6
7
8
9
Epoch [10/100], Loss: 2023.8610, Breakdown: {'Orthonormality': 2023.8609619140625}
10
11
12
13
14
Decomposition failed; adding eps


In [None]:
mesh_a.basis.vecs