In [1]:
	
import numpy as np
import toponetx as tnx
import torch
from torch.utils.data import DataLoader, Dataset

from topomodelx.nn.combinatorial.hmc import HMC

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [4]:
class SHRECDataset(Dataset):
    """Class for the SHREC 2016 dataset.

    Parameters
    ----------
    data : npz file
        npz file containing the SHREC 2016 data.
    """

    def __init__(self, data) -> None:
        self.complexes = [cc.to_combinatorial_complex() for cc in data["complexes"]]
        self.x_0 = data["node_feat"]
        self.x_1 = data["edge_feat"]
        self.x_2 = data["face_feat"]
        self.y = data["label"]
        self.a0, self.a1, self.coa2, self.b1, self.b2 = self._get_neighborhood_matrix()

    def _get_neighborhood_matrix(self) -> list[list[torch.sparse.Tensor], ...]:
        """Neighborhood matrices for each combinatorial complex in the dataset.

        Following the Higher Order Attention Model for Mesh Classification message passing scheme, this method computes the necessary neighborhood matrices
        for each combinatorial complex in the dataset. This method computes:

        - Adjacency matrices for each 0-cell in the dataset.
        - Adjacency matrices for each 1-cell in the dataset.
        - Coadjacency matrices for each 2-cell in the dataset.
        - Incidence matrices from 1-cells to 0-cells for each 1-cell in the dataset.
        - Incidence matrices from 2-cells to 1-cells for each 2-cell in the dataset.

        Returns
        -------
        a0 : list of torch.sparse.FloatTensor
            Adjacency matrices for each 0-cell in the dataset.
        a1 : list of torch.sparse.FloatTensor
            Adjacency matrices for each 1-cell in the dataset.
        coa2 : list of torch.sparse.FloatTensor
            Coadjacency matrices for each 2-cell in the dataset.
        b1 : list of torch.sparse.FloatTensor
            Incidence matrices from 1-cells to 0-cells for each 1-cell in the dataset.
        b2 : list of torch.sparse.FloatTensor
            Incidence matrices from 2-cells to 1-cells for each 2-cell in the dataset.
        """

        a0 = []
        a1 = []
        coa2 = []
        b1 = []
        b2 = []

        for cc in self.complexes:
            a0.append(torch.from_numpy(cc.adjacency_matrix(0, 1).todense()).to_sparse())
            a1.append(torch.from_numpy(cc.adjacency_matrix(1, 2).todense()).to_sparse())

            B = cc.incidence_matrix(rank=1, to_rank=2)
            A = B.T @ B
            A.setdiag(0)
            coa2.append(torch.from_numpy(A.todense()).to_sparse())

            b1.append(torch.from_numpy(cc.incidence_matrix(0, 1).todense()).to_sparse())
            b2.append(torch.from_numpy(cc.incidence_matrix(1, 2).todense()).to_sparse())

        return a0, a1, coa2, b1, b2

    def num_classes(self) -> int:
        """Returns the number of classes in the dataset.

        Returns
        -------
        int
            Number of classes in the dataset.
        """
        return len(np.unique(self.y))

    def channels_dim(self) -> tuple[int, int, int]:
        """Returns the number of channels for each input signal.

        Returns
        -------
        tuple of int
            Number of channels for each input signal.
        """
        return [self.x_0[0].shape[1], self.x_1[0].shape[1], self.x_2[0].shape[1]]

    def __len__(self) -> int:
        """Returns the number of elements in the dataset.

        Returns
        -------
        int
            Number of elements in the dataset.
        """
        return len(self.complexes)

    def __getitem__(self, idx) -> tuple[torch.Tensor, ...]:
        """Returns the idx-th element in the dataset.

        Parameters
        ----------
        idx : int
            Index of the element to return.

        Returns
        -------
        tuple of torch.Tensor
            Tuple containing the idx-th element in the dataset, including the input signals on nodes, edges and faces, the neighborhood matrices and the label.
        """
        return (
            self.x_0[idx],
            self.x_1[idx],
            self.x_2[idx],
            self.a0[idx],
            self.a1[idx],
            self.coa2[idx],
            self.b1[idx],
            self.b2[idx],
            self.y[idx],
        )

In [5]:
shrec_training, shrec_testing = tnx.datasets.shrec_16()


In [10]:
import numpy as np
import matplotlib.pyplot as plt
import toponetx as tnx
from mpl_toolkits.mplot3d import Axes3D

# Load the SHREC 2016 dataset
print("Loading SHREC 2016 dataset...")
shrec_training, shrec_testing = tnx.datasets.shrec_16()

# Basic dataset information
print("\n=== SHREC Dataset Overview ===")
print(f"Training set size: {len(shrec_training['complexes'])}")
print(f"Testing set size: {len(shrec_testing['complexes'])}")

# Examining the first complex
print("\n=== First Training Complex ===")
first_complex = shrec_training['complexes'][0]
print(f"Type: {type(first_complex)}")

# Basic complex statistics for SimplicialComplex
print("\n=== Complex Structure ===")
# For SimplicialComplex, we access the simplices by dimension
print(f"Number of 0-simplices (nodes): {len(list(first_complex.skeleton(0)))}")
print(f"Number of 1-simplices (edges): {len(list(first_complex.skeleton(1))) - len(list(first_complex.skeleton(0)))}")
print(f"Number of 2-simplices (faces): {len(list(first_complex.skeleton(2))) - len(list(first_complex.skeleton(1)))}")

# Display feature information
print("\n=== Feature Information ===")
print(f"Node features shape: {shrec_training['node_feat'][0].shape}")
print(f"Edge features shape: {shrec_training['edge_feat'][0].shape}")
print(f"Face features shape: {shrec_training['face_feat'][0].shape}")

# Display label information
print("\n=== Label Information ===")
print(f"First complex label: {shrec_training['label'][0]}")
unique_labels = np.unique(shrec_training['label'])
print(f"Unique labels: {unique_labels}")
print(f"Label distribution: {[(label, np.sum(shrec_training['label'] == label)) for label in unique_labels]}")



# Example of extracting cell features
print("\n=== Example Cell Features ===")
node_feats = shrec_training['node_feat'][0]
edge_feats = shrec_training['edge_feat'][0]
face_feats = shrec_training['face_feat'][0]

print(f"First node features: {node_feats[0]}")
print(f"First edge features: {edge_feats[0]}")
print(f"First face features: {face_feats[0]}")

# Show first few entries in each dictionary with their corresponding indices
print("\n=== Key-Index Mappings ===")
try:
    # Convert to combinatorial complex for index mapping
    cc = first_complex.to_combinatorial_complex()
    
    # Get incidence matrices with indices
    row, column, B1 = cc.incidence_matrix(0, 1, index=True)
    row1, column1, B2 = cc.incidence_matrix(1, 2, index=True)
    
    # Print first few entries of each dictionary
    print("Rank 0 cell indices (first 5):")
    for i, (cell, idx) in enumerate(row.items()):
        if i >= 5: break
        print(f"  {idx}: {cell}")
    
    print("\nRank 1 cell indices (first 5):")
    for i, (cell, idx) in enumerate(column.items()):
        if i >= 5: break
        print(f"  {idx}: {cell}")
    
    print("\nRank 2 cell indices (first 5):")
    for i, (cell, idx) in enumerate(column1.items()):
        if i >= 5: break
        print(f"  {idx}: {cell}")
except Exception as e:
    print(f"Could not convert to combinatorial complex: {e}")
    
    # Alternative: just show node positions
    print("Node positions (first 5):")
    nodes = list(first_complex.nodes)
    for i, node in enumerate(nodes[:5]):
        print(f"  Node {i}: {node}")
    
    # Show edges
    print("\nEdges (first 5):")
    edges = [e for e in first_complex.skeleton(1) if len(e) == 2]
    for i, edge in enumerate(edges[:5]):
        print(f"  Edge {i}: {edge}")
    
    # Show faces
    print("\nFaces (first 5):")
    faces = [f for f in first_complex.skeleton(2) if len(f) == 3]
    for i, face in enumerate(faces[:5]):
        print(f"  Face {i}: {face}")

Loading SHREC 2016 dataset...

=== SHREC Dataset Overview ===
Training set size: 480
Testing set size: 120

=== First Training Complex ===
Type: <class 'toponetx.classes.simplicial_complex.SimplicialComplex'>

=== Complex Structure ===
Number of 0-simplices (nodes): 252
Number of 1-simplices (edges): 498
Number of 2-simplices (faces): -250

=== Feature Information ===
Node features shape: (252, 6)
Edge features shape: (750, 10)
Face features shape: (500, 7)

=== Label Information ===
First complex label: 0
Unique labels: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29]
Label distribution: [(0, 16), (1, 16), (2, 16), (3, 16), (4, 16), (5, 16), (6, 16), (7, 16), (8, 16), (9, 16), (10, 16), (11, 16), (12, 16), (13, 16), (14, 16), (15, 16), (16, 16), (17, 16), (18, 16), (19, 16), (20, 16), (21, 16), (22, 16), (23, 16), (24, 16), (25, 16), (26, 16), (27, 16), (28, 16), (29, 16)]

=== Example Cell Features ===
First node features: [ 0.567542    0.5

In [11]:
training_dataset = SHRECDataset(shrec_training)
training_dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True)

In [12]:
testing_dataset = SHRECDataset(shrec_testing)
testing_dataloader = DataLoader(testing_dataset, batch_size=1, shuffle=True)

In [13]:
class Trainer:
    """Trainer for the HOANMeshClassifier.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train.
    training_dataloader : torch.utils.data.DataLoader
        The dataloader for the training set.
    testing_dataloader : torch.utils.data.DataLoader
        The dataloader for the testing set.
    learning_rate : float
        The learning rate for the Adam optimizer.
    device : torch.device
        The device to use for training.
    """

    def __init__(
        self, model, training_dataloader, testing_dataloader, learning_rate, device
    ) -> None:
        self.model = model.to(device)
        self.training_dataloader = training_dataloader
        self.testing_dataloader = testing_dataloader
        self.device = device
        self.crit = torch.nn.CrossEntropyLoss()
        self.opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _to_device(self, x) -> list[torch.Tensor]:
        """Converts tensors to the correct type and moves them to the device.

        Parameters
        ----------
        x : List[torch.Tensor]
            List of tensors to convert.
        Returns
        -------
        List[torch.Tensor]
            List of converted tensors to float type and moved to the device.
        """

        return [el[0].float().to(self.device) for el in x]

    def train(self, num_epochs=500, test_interval=25) -> None:
        """Trains the model for the specified number of epochs.

        Parameters
        ----------
        num_epochs : int
            Number of epochs to train.
        test_interval : int
            Interval between testing epochs.
        """
        for epoch_i in range(num_epochs):
            training_accuracy, epoch_loss = self._train_epoch()
            print(
                f"Epoch: {epoch_i} loss: {epoch_loss:.4f} Train_acc: {training_accuracy:.4f}",
                flush=True,
            )
            if (epoch_i + 1) % test_interval == 0:
                test_accuracy = self.validate()
                print(f"Test_acc: {test_accuracy:.4f}", flush=True)

    def _train_epoch(self) -> tuple[float, float]:
        """Trains the model for one epoch.

        Returns
        -------
        training_accuracy : float
            The mean training accuracy for the epoch.
        epoch_loss : float
            The mean loss for the epoch.
        """
        training_samples = len(self.training_dataloader.dataset)
        total_loss = 0
        correct = 0
        self.model.train()
        for sample in self.training_dataloader:
            (
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            ) = self._to_device(sample[:-1])

            self.opt.zero_grad()

            y_hat = self.model.forward(
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            )

            y = sample[-1][0].long().to(self.device)
            total_loss += self._compute_loss_and_update(y_hat, y)
            correct += (y_hat.argmax() == y).sum().item()

        training_accuracy = correct / training_samples
        epoch_loss = total_loss / training_samples

        return training_accuracy, epoch_loss

    def _compute_loss_and_update(self, y_hat, y) -> float:
        """Computes the loss, performs backpropagation, and updates the model's parameters.

        Parameters
        ----------
        y_hat : torch.Tensor
            The output of the model.
        y : torch.Tensor
            The ground truth.

        Returns
        -------
        loss: float
            The loss value.
        """

        loss = self.crit(y_hat, y)
        loss.backward()
        self.opt.step()
        return loss.item()

    def validate(self) -> float:
        """Validates the model using the testing dataloader.

        Returns
        -------
        test_accuracy : float
            The mean testing accuracy.
        """
        correct = 0
        self.model.eval()
        test_samples = len(self.testing_dataloader.dataset)
        with torch.no_grad():
            for sample in self.testing_dataloader:
                (
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                ) = self._to_device(sample[:-1])

                y_hat = self.model(
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                )
                y = sample[-1][0].long().to(self.device)
                correct += (y_hat.argmax() == y).sum().item()
            return correct / test_samples

In [14]:
class Network(torch.nn.Module):
    def __init__(
        self,
        channels_per_layer,
        negative_slope=0.2,
        num_classes=2,
    ):
        super().__init__()
        self.base_model = HMC(
            channels_per_layer,
            negative_slope,
        )
        self.l0 = torch.nn.Linear(channels_per_layer[-1][2][0], num_classes)
        self.l1 = torch.nn.Linear(channels_per_layer[-1][2][1], num_classes)
        self.l2 = torch.nn.Linear(channels_per_layer[-1][2][2], num_classes)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ):
        x_0, x_1, x_2 = self.base_model(
            x_0,
            x_1,
            x_2,
            neighborhood_0_to_0,
            neighborhood_1_to_1,
            neighborhood_2_to_2,
            neighborhood_0_to_1,
            neighborhood_1_to_2,
        )
        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)

        # Sum all the elements in the dimension zero
        x_0 = torch.nanmean(x_0, dim=0)
        x_1 = torch.nanmean(x_1, dim=0)
        x_2 = torch.nanmean(x_2, dim=0)

        return x_0 + x_1 + x_2

In [15]:
in_channels = training_dataset.channels_dim()
intermediate_channels = [60, 60, 60]
final_channels = [60, 60, 60]

channels_per_layer = [[in_channels, intermediate_channels, final_channels]]
# defube HOAN mesh classifier
model = Network(
    channels_per_layer, negative_slope=0.2, num_classes=training_dataset.num_classes()
)

# If GPU's are available, we will make use of them. Otherwise, this will run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = Trainer(model, training_dataloader, testing_dataloader, 0.001, device)