In [None]:
import numpy as np

from geomfum.dataset import NotebooksDataset
from geomfum.dfm.forward_functional_map import ForwardFunctionalMap
from geomfum.shape import TriangleMesh
from geomfum.descriptor.learned import LearnedDescriptor

from geomfum.dfm.losses import LossManager
from geomfum.dfm.dataset import ShapeDataset


import torch

import torch
import itertools
import numpy as np
import os
from torch.utils.data import DataLoader
from tqdm import tqdm


In [None]:
# Function to get all possible pairs of meshes from the dataset
def get_all_pairs(dataset):
    """Enumerate every unordered pair of shape files in ``dataset``.

    Parameters
    ----------
    dataset : object
        Any object exposing a ``shape_files`` sequence.

    Returns
    -------
    list of tuple
        All 2-element combinations of ``dataset.shape_files``, in the order
        produced by ``itertools.combinations`` (no self-pairs and no reversed
        duplicates). Empty when fewer than two files are present.
    """
    files = dataset.shape_files
    return [(first, second) for first, second in itertools.combinations(files, 2)]


In [None]:
import torch
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

# --- Configuration ----------------------------------------------------------
# All tunable values live here so the rest of the cell has no magic numbers.
shape_dir = '../../../datasets/shrec_r/'
DEVICE = "cuda"
SEED = 0                 # seeds the train/test split and model initialisation
TRAIN_FRACTION = 0.8     # fraction of pairs used for training
NUM_EPOCHS = 100
ACCUM_STEPS = 4          # optimizer step once per ACCUM_STEPS pairs (gradient accumulation)
LOG_EVERY = 10           # print a per-pair loss line every LOG_EVERY pairs
CHECKPOINT_EVERY = 50    # save descriptor weights every CHECKPOINT_EVERY epochs
LEARNING_RATE = 1e-4
loss_config = {"Orthonormality": 1.0}

# Seed torch's global RNG so model initialisation is reproducible across runs.
torch.manual_seed(SEED)

# --- Data -------------------------------------------------------------------
# pair_mode may also be 'random' or other modes supported by ShapeDataset.
dataset = ShapeDataset(shape_dir, pair_mode='all', device=DEVICE)

# Split into train/test with a seeded generator so the split is reproducible.
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SEED),
)

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# --- Model, loss, optimizer -------------------------------------------------
descr = LearnedDescriptor.from_registry(which='diffusion_net', cache_dir=shape_dir + 'diffusion/', device=DEVICE)
# NOTE(review): the meaning of the positional args (0.0001, 1) is defined by
# ForwardFunctionalMap, not visible here — confirm before tuning.
forward_map = ForwardFunctionalMap(0.0001, 1)

loss_manager = LossManager(loss_config)
optimizer = torch.optim.Adam(descr.model.parameters(), lr=LEARNING_RATE)


def compute_pair_loss(source, target):
    """Return ``(loss, loss_details)`` for one source/target shape pair.

    Computes learned descriptors for both shapes with ``descr``, builds the
    functional map with ``forward_map`` and scores it with
    ``loss_manager.compute_loss``. Shared by the training and test loops so
    both evaluate a pair in exactly the same way.
    """
    feat_a = descr(source)
    feat_b = descr(target)
    Cxy = forward_map(source, target, feat_a, feat_b)
    return loss_manager.compute_loss(Cxy=Cxy)


# --- Training loop ----------------------------------------------------------
num_train_batches = len(train_dataloader)
for epoch in range(NUM_EPOCHS):
    print(f"Epoch {epoch + 1}/{NUM_EPOCHS}")
    # Assumes descr.model is a torch.nn.Module (it exposes parameters() and
    # state_dict()); train() re-enables training-mode layers after evaluation.
    descr.model.train()
    running_loss = 0.0
    optimizer.zero_grad()

    for batch_idx, (source, target) in enumerate(tqdm(train_dataloader)):
        loss, loss_details = compute_pair_loss(source, target)

        # Scale so the accumulated gradient is a mean over the group, matching a
        # larger batch, instead of a sum whose size depends on ACCUM_STEPS.
        (loss / ACCUM_STEPS).backward()

        # Step every ACCUM_STEPS pairs AND on the last pair of the epoch, so a
        # trailing partial group is applied instead of being discarded by the
        # next epoch's zero_grad() (or never applied after the final epoch).
        is_last_batch = batch_idx + 1 == num_train_batches
        if (batch_idx + 1) % ACCUM_STEPS == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

        loss_value = loss.item()
        running_loss += loss_value

        if (batch_idx + 1) % LOG_EVERY == 0:
            print(f'Processed pair {batch_idx + 1}/{num_train_batches} - Loss: {loss_value:.4f}, Breakdown: {loss_details}')

    # max(..., 1) avoids ZeroDivisionError when the training split is empty.
    avg_loss = running_loss / max(num_train_batches, 1)
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Average Loss: {avg_loss:.4f}")

    if (epoch + 1) % CHECKPOINT_EVERY == 0:
        torch.save(descr.model.state_dict(), f'checkpoint_epoch_{epoch + 1}.pth')

# --- Testing loop -----------------------------------------------------------
print("Testing...")
# Switch training-mode-only layers (e.g. dropout, if present) to inference behaviour.
descr.model.eval()
test_loss = 0.0
num_test_batches = len(test_dataloader)
with torch.no_grad():
    for batch_idx, (source, target) in enumerate(tqdm(test_dataloader)):
        loss, loss_details = compute_pair_loss(source, target)

        loss_value = loss.item()
        test_loss += loss_value

        if (batch_idx + 1) % LOG_EVERY == 0:
            print(f'Tested pair {batch_idx + 1}/{num_test_batches} - Loss: {loss_value:.4f}, Breakdown: {loss_details}')

# max(..., 1) avoids ZeroDivisionError when the test split is empty.
avg_test_loss = test_loss / max(num_test_batches, 1)
print(f"Average Test Loss: {avg_test_loss:.4f}")