# Train a Simplicial Complex Autoencoder (SCA) with Coadjacency Message Passing Scheme (CMPS)


🟥 $\quad m_{y \rightarrow x}^{(r \rightarrow r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r)},att(h_{x}^{t, (r)}, h_{y}^{t, (r)}),x,y,{\Theta^t}) \qquad \text{where } r'' < r < r'$

🟥 $\quad m_{y \rightarrow x}^{(r'' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r'')},att(h_{x}^{t, (r)}, h_{y}^{t, (r'')}),x,y,{\Theta^t})$

🟧 $\quad m_x^{(r \rightarrow r)}  = AGG_{y \in \mathcal{L}\_\downarrow(x)} m_{y \rightarrow x}^{(r \rightarrow r)}$

🟧 $\quad m_x^{(r'' \rightarrow r)} = AGG_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(r'' \rightarrow r)}$

🟩 $\quad m_x^{(r)}  = \text{AGG}\_{\mathcal{N}\_k \in \mathcal{N}}(m_x^{(k)})$

🟦 $\quad h_{x}^{t+1, (r)} = U(h_x^{t, (r)}, m_{x}^{(r)})$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [131]:
import torch
import numpy as np

import toponetx.datasets as datasets
from sklearn.model_selection import train_test_split

from topomodelx.nn.simplicial.sca_cmps_layer import SCACMPSLayer
from topomodelx.base.aggregation import Aggregation

If GPUs are available we will make use of them. Otherwise, we will use CPU.

In [132]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a simplicial complex.

We also retrieve:
- input signals `x_0`, `x_1`, and `x_2` on the nodes (0-cells), edges (1-cells), and faces (2-cells) for each complex: these will be the model's inputs,
- a scalar classification label `y` associated to the simplicial complex.

In [133]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
scs = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [134]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Preparing the inputs to test each message passing scheme:

#### Coadjacency Message Passing Scheme (CMPS):
This will require features from faces, and edges again, but outputs features on the edges. The first neighborhood matrix will be the level 2 lower Laplacian, $L_{\downarrow, 2}$, and the second neighborhood matrix will be the transpose of the incidence matrix of the faces, $B_{2}^T$.

In [135]:
down_lap1_list = []
down_lap2_list = []
incidence1_t_list = []
incidence2_t_list = []

for sc in scs:
    down_lap1 = sc.down_laplacian_matrix(rank=1)
    down_lap2 = sc.down_laplacian_matrix(rank=2)
    incidence_1t = sc.incidence_matrix(rank=1).T
    incidence_2t = sc.incidence_matrix(rank=2).T

    down_lap1 = torch.from_numpy(down_lap1.todense()).to_sparse()
    down_lap2 = torch.from_numpy(down_lap2.todense()).to_sparse()
    incidence_1t = torch.from_numpy(incidence_1t.todense()).to_sparse()
    incidence_2t = torch.from_numpy(incidence_2t.todense()).to_sparse()

    down_lap1_list.append(down_lap1)
    down_lap2_list.append(down_lap2)
    incidence1_t_list.append(incidence_1t)
    incidence2_t_list.append(incidence_2t)

# Create the Neural Networks

Using the SCACMPSLayer class, we create a neural network with a modifiable number of layers each following the CMPS at each level.

In [136]:
channels_list = [x_0s[0].shape[-1], x_1s[0].shape[-1], x_2s[0].shape[-1]]

complex_dim = 3

In [137]:
x_0s.shape

(100,)

In [138]:
class AMPSSCA(torch.nn.Module):
    """SCA with AMPS.

    Parameters
    ----------
    channels_list: list[int]
        Dimension of features on each node, edge, simplex, tetahedron,... respectively
    complex_dimension: int
        Highest dimension of simplicial complex feature being trained on.
    num_classes : int
        Dimension to which the complex embeddings will be projected.
    att : bool
        Whether to use attention.
    """

    def __init__(
        self,
        channels_list,
        complex_dim,
        num_classes,
        n_layers=2,
        att=False,
    ):
        super().__init__()
        self.n_layers = n_layers
        self.channels_list = channels_list
        self.num_classes = num_classes

        layers = []
        for i in range(n_layers):
            layers.append(SCACMPSLayer(channels_list, complex_dim, att))

        self.layers = torch.nn.ModuleList(layers)
        self.lin0 = torch.nn.Linear(channels_list[0], num_classes)
        self.lin1 = torch.nn.Linear(channels_list[1], num_classes)
        self.lin2 = torch.nn.Linear(channels_list[2], num_classes)
        self.aggr = Aggregation(
            aggr_func="mean",
            update_func="sigmoid",
        )

    def forward(self, x_list, down_lap_list, incidencet_list):
        """Forward computation through layers, then linear layers, then avg pooling.

        Parameters
        ----------
        x_list: list[torch.Tensor]
            List of tensor inputs for each dimension of the complex (nodes, edges, etc.).
        down_lap_list: list[torch.Tensor]
            List of the down laplacian matrix for each dimension in the complex starting at edges.
        incidencet_list: list[torch.Tensor]
            List of the transpose incidence matrices for the edges and faces.

        Returns
        -------
        _ : tensor, shape = [1]
            Label assigned to whole complex.
        """
        for i in range(self.n_layers):
            x_list = self.layers[i](x_list, down_lap_list, incidencet_list)

        x_0 = self.lin0(x_list[0])
        x_1 = self.lin1(x_list[1])
        x_2 = self.lin2(x_list[2])

        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0

        x_2f = torch.flatten(two_dimensional_cells_mean)
        x_1f = torch.flatten(one_dimensional_cells_mean)
        x_0f = torch.flatten(zero_dimensional_cells_mean)

        return x_0f + x_1f + x_2f

# Train and Test Split

In [139]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

down_lap1_train, down_lap1_test = train_test_split(
    down_lap1_list, test_size=test_size, shuffle=False
)
down_lap2_train, down_lap2_test = train_test_split(
    down_lap2_list, test_size=test_size, shuffle=False
)
incidence1_t_train, incidence1_t_test = train_test_split(
    incidence1_t_list, test_size=test_size, shuffle=False
)
incidence2_t_train, incidence2_t_test = train_test_split(
    incidence2_t_list, test_size=test_size, shuffle=False
)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
y_train.shape

(80,)

# Training and Testing Model
Because the SCA implementation in [HZPMC22]_ was used for clustering, we did the same in one dimension to train on the int classification labels provided my the shrec16 dataset. 

In [142]:
model = AMPSSCA(
    channels_list=channels_list,
    complex_dim=complex_dim,
    num_classes=1,
    n_layers=3,
    att=False,
)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

In [143]:
test_interval = 1
num_epochs = 6

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_0, x_1, x_2, down_lap1, down_lap2, incidence_1t, incidence_2t, y in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        down_lap1_train,
        down_lap2_train,
        incidence1_t_train,
        incidence2_t_train,
        y_train,
    ):
        x_0, x_1, x_2, y = (
            torch.tensor(x_0).float().to(device),
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        x_list = [x_0, x_1, x_2]

        down_lap1 = down_lap1.float().to(device)
        down_lap2 = down_lap2.float().to(device)
        down_lap_list = [down_lap1, down_lap2]

        incidence_1t = incidence_1t.float().to(device)
        incidence_2t = incidence_2t.float().to(device)
        incidence_t_list = [incidence_1t, incidence_2t]

        opt.zero_grad()
        y_hat = model(x_list, down_lap_list, incidence_t_list)
        loss = loss_fn(y_hat.flatten(), y)
        loss.backward()

        opt.step()
        epoch_loss.append(loss.item())

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss)}",
        flush=True,
    )

    if epoch_i % test_interval == 0:
        correct_count = 0
        with torch.no_grad():
            for (
                x_0,
                x_1,
                x_2,
                down_lap1,
                down_lap2,
                incidence_1t,
                incidence_2t,
                y,
            ) in zip(
                x_0_test,
                x_1_test,
                x_2_test,
                down_lap1_test,
                down_lap2_test,
                incidence1_t_test,
                incidence2_t_test,
                y_test,
            ):
                x_0, x_1, x_2, y = (
                    torch.tensor(x_0).float().to(device),
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                x_list = [x_0, x_1, x_2]

                down_lap1 = down_lap1.float().to(device)
                down_lap2 = down_lap2.float().to(device)
                down_lap_list = [down_lap1, down_lap2]

                incidence_1t = incidence_1t.float().to(device)
                incidence_2t = incidence_2t.float().to(device)
                incidence_t_list = [incidence_1t, incidence_2t]

                y_hat = model(x_list, down_lap_list, incidence_t_list)
                test_loss = loss_fn(y_hat, y)

                if round(y_hat.item()) == round(y.item()):
                    correct_count += 1

            print(f"Test_loss: {test_loss}", flush=True)

Epoch: 1 loss: 210.15074526998214
Test_loss: 310.42108154296875
Accuracy: 0.0
Epoch: 2 loss: 117.16418675570749
Test_loss: 182.73768615722656
Accuracy: 0.0
Epoch: 3 loss: 86.16661446671351
Test_loss: 124.76749420166016
Accuracy: 0.15
Epoch: 4 loss: 77.8968961050734
Test_loss: 98.735595703125
Accuracy: 0.0
Epoch: 5 loss: 75.84987897076644
Test_loss: 86.7345199584961
Accuracy: 0.0
Epoch: 6 loss: 75.18321903655306
Test_loss: 80.809326171875
Accuracy: 0.0
