# Train a Simplicial Complex Autoencoder (SCA)


🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(r \rightarrow r' \rightarrow r)}  = M(h_{x}^{t, (r)}, h_{y}^{t, (r)}, att(h_{x}^{t, (r)}, h_{y}^{t, (r)}), x, y, \Theta^t) \qquad \text{where } r'' < r < r'$

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(r' \rightarrow r)} = M(h_{x}^{t, (r)}, h_{y}^{t, (r')}, att(h_{x}^{t, (r)}, h_{y}^{t, (r')}), x, y, \Theta^t)$

🟧 $\quad m_x^{(r \rightarrow r' \rightarrow r)}  = \text{AGG}\_{y \in \mathcal{L}\_\uparrow(x)} m_{y \rightarrow \{z\} \rightarrow x}^{(r \rightarrow r' \rightarrow r)}$

🟧 $\quad m_x^{(r' \rightarrow r)} = \text{AGG}\_{y \in \mathcal{C}(x)} m_{y \rightarrow \{z\} \rightarrow x}^{(r' \rightarrow r)}$

🟩 $\quad m_x^{(r)}  = \text{AGG}\_{\mathcal{N}\_k \in \mathcal{N}}(m_x^{(k)})$

🟦 $\quad h_{x}^{t+1, (r)} = U(h_x^{t, (r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
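As a rough illustration of the four steps above (message 🟥, within-neighborhood aggregation 🟧, between-neighborhood aggregation 🟩, update 🟦), here is a minimal dense sketch. It is not the `SCALayer` implementation: the function name `two_neighborhood_step`, the weight matrices, the sum aggregation, and the ReLU update are assumptions chosen for illustration, and attention is omitted.

```python
import torch


def two_neighborhood_step(h_r, h_rp, n_same, n_cross, w_same, w_cross):
    """One illustrative message passing step on rank-r cells (hypothetical helper).

    h_r     : [n_r, c_r]    features on rank-r cells
    h_rp    : [n_rp, c_rp]  features on rank-r' cells
    n_same  : [n_r, n_r]    neighborhood matrix from rank r to rank r
    n_cross : [n_r, n_rp]   neighborhood matrix from rank r' to rank r
    """
    # Message and within-neighborhood aggregation: the matrix product sums
    # the linearly transformed features of all neighbors y of each cell x.
    m_same = n_same @ (h_r @ w_same)
    m_cross = n_cross @ (h_rp @ w_cross)
    # Between-neighborhood aggregation: sum over the two neighborhoods.
    m = m_same + m_cross
    # Update: assumed here to be a nonlinearity on the aggregated message
    # (the previous state h_r is ignored in this simplified version).
    return torch.relu(m)


torch.manual_seed(0)
h_edges = torch.randn(3, 4)  # 3 edges, 4 channels
h_faces = torch.randn(1, 5)  # 1 face, 5 channels
n_same = torch.ones(3, 3)    # toy edge-to-edge neighborhood
n_cross = torch.ones(3, 1)   # toy face-to-edge neighborhood
w_same = torch.randn(4, 2)
w_cross = torch.randn(5, 2)

out = two_neighborhood_step(h_edges, h_faces, n_same, n_cross, w_same, w_cross)
print(out.shape)
```

The output has one row per rank-r cell, because both neighborhood matrices map onto the rank-r cells.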

In [1]:
import torch
import numpy as np

from toponetx import SimplicialComplex
import toponetx.datasets as datasets
from sklearn.model_selection import train_test_split

from topomodelx.nn.simplicial.sca_layer import SCALayer

If GPUs are available we will make use of them. Otherwise, we will use CPU.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, shrec16, a benchmark dataset for 3D mesh classification. We then lift each graph into our domain of choice, a simplicial complex.

We also retrieve:
- input signals `x_0`, `x_1`, and `x_2` on the nodes (0-cells), edges (1-cells), and faces (2-cells) for each complex: these will be the model's inputs,
- a scalar classification label `y` associated with each simplicial complex.

In [3]:
shrec, _ = datasets.mesh.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
scs = shrec["complexes"]

Loading shrec 16 small dataset...

done!


In [4]:
i_complex = 6
print(
    f"The {i_complex}th simplicial complex has {x_0s[i_complex].shape[0]} nodes with features of dimension {x_0s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_1s[i_complex].shape[0]} edges with features of dimension {x_1s[i_complex].shape[1]}."
)
print(
    f"The {i_complex}th simplicial complex has {x_2s[i_complex].shape[0]} faces with features of dimension {x_2s[i_complex].shape[1]}."
)

The 6th simplicial complex has 252 nodes with features of dimension 6.
The 6th simplicial complex has 750 edges with features of dimension 10.
The 6th simplicial complex has 500 faces with features of dimension 7.


## Preparing the inputs to test each message passing scheme:

#### Adjacency Message Passing Scheme (AMPS):
This scheme takes features on the faces and edges and outputs features on the faces. The first neighborhood matrix is the level 1 upper Laplacian, $L_{\uparrow, 1}$, and the second is the incidence matrix of the faces, $B_2$.

#### Coadjacency Message Passing Scheme (CMPS):
This scheme also takes features on the faces and edges, but outputs features on the edges. The first neighborhood matrix is the level 2 lower Laplacian, $L_{\downarrow, 2}$, and the second is the transpose of the incidence matrix of the faces, $B_{2}^T$.

#### Homology and Cohomology Message Passing Scheme (HCMPS):
This scheme takes features on the faces, edges, and nodes and outputs features on the edges. The first neighborhood matrix is the transpose of the incidence matrix of the edges, $B_{1}^T$, and the second is the incidence matrix of the faces, $B_2$.
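To make the shapes of these neighborhood matrices concrete, the sketch below builds them by hand with NumPy for a toy complex of two triangles sharing an edge. It uses the standard definitions $L_{\uparrow, 1} = B_2 B_2^T$ and $L_{\downarrow, 2} = B_2^T B_2$. The orientation (sign) convention is an assumption and may differ from the one TopoNetX uses.

```python
import numpy as np

# Toy complex: nodes 0..3, faces (0,1,2) and (1,2,3).
# Edges, in column order: (0,1), (0,2), (1,2), (1,3), (2,3).

# B1: nodes x edges. Edge (i, j) with i < j has -1 at node i and +1 at node j.
B1 = np.array(
    [
        [-1, -1, 0, 0, 0],
        [1, 0, -1, -1, 0],
        [0, 1, 1, 0, -1],
        [0, 0, 0, 1, 1],
    ]
)

# B2: edges x faces. The boundary of face (a, b, c) is (b,c) - (a,c) + (a,b).
B2 = np.array(
    [
        [1, 0],
        [-1, 0],
        [1, 1],
        [0, -1],
        [0, 1],
    ]
)

L_up1 = B2 @ B2.T    # edges x edges, first matrix for AMPS
L_down2 = B2.T @ B2  # faces x faces, first matrix for CMPS

print("B1 @ B2 is zero:", np.all(B1 @ B2 == 0))
print("AMPS :", L_up1.shape, B2.shape)
print("CMPS :", L_down2.shape, B2.T.shape)
print("HCMPS:", B1.T.shape, B2.shape)
```

Under plain matrix multiplication, `L_up1 @ x_1`, `B2 @ x_2`, and `B1.T @ x_0` each have one row per edge, while `L_down2 @ x_2` and `B2.T @ x_1` have one row per face.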

In [5]:
up_lap1_list = []
down_lap2_list = []
incidence1_t_list = []
incidence2_list = []
incidence2_t_list = []

for sc in scs:
    up_lap1 = sc.up_laplacian_matrix(rank=1)
    down_lap2 = sc.down_laplacian_matrix(rank=2)
    incidence1_t = sc.incidence_matrix(rank=1).T
    incidence_2 = sc.incidence_matrix(rank=2)
    incidence_2_t = sc.incidence_matrix(rank=2).T
    up_lap1 = torch.from_numpy(up_lap1.todense()).to_sparse()
    down_lap2 = torch.from_numpy(down_lap2.todense()).to_sparse()
    incidence1_t = torch.from_numpy(incidence1_t.todense()).to_sparse()
    incidence_2 = torch.from_numpy(incidence_2.todense()).to_sparse()
    incidence_2_t = torch.from_numpy(incidence_2_t.todense()).to_sparse()

    up_lap1_list.append(up_lap1)
    down_lap2_list.append(down_lap2)
    incidence1_t_list.append(incidence1_t)
    incidence2_list.append(incidence_2)
    incidence2_t_list.append(incidence_2_t)
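The loop above converts each SciPy sparse matrix to a dense array before making it sparse again, which can be memory-hungry for large complexes. A possible alternative, sketched here, builds the torch sparse tensor directly from the COO indices and values. The helper name `scipy_to_torch_sparse` is hypothetical and not part of TopoNetX or TopoModelX.

```python
import numpy as np
import scipy.sparse as sp
import torch


def scipy_to_torch_sparse(matrix):
    """Convert a SciPy sparse matrix to a torch sparse COO tensor (hypothetical helper)."""
    coo = matrix.tocoo()
    # Stack row and column indices into the [2, nnz] layout torch expects.
    indices = torch.from_numpy(np.vstack((coo.row, coo.col))).long()
    values = torch.from_numpy(coo.data).float()
    return torch.sparse_coo_tensor(indices, values, coo.shape).coalesce()


toy = sp.csr_matrix(np.array([[1.0, 0.0], [-1.0, 0.0], [1.0, 2.0]]))
sparse_tensor = scipy_to_torch_sparse(toy)
print(sparse_tensor.to_dense())
```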

# Create the Neural Networks

Using the `SCALayer` class, we create a neural network for each scheme. The model below implements the AMPS variant: a single `SCALayer` followed by a linear classification layer.

In [6]:
# AMPS
in_channels_1a = x_1s.shape[-1]
in_channels_2a = x_2s.shape[-1]
out_channels_a = x_2s.shape[-1]

# CMPS
in_channels_1c = x_1s.shape[-1]
in_channels_2c = x_2s.shape[-1]
out_channels_c = x_1s.shape[-1]

# HCMPS
in_channels_1h = x_0s.shape[-1]
in_channels_2h = x_2s.shape[-1]
out_channels_h = x_1s.shape[-1]

In [7]:
class AMPSSCA(torch.nn.Module):
    """SCA with AMPS.

    Parameters
    ----------
    in_channels_1 : int
        Dimension of input features on edges.
    in_channels_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    att : bool
        Whether to use attention.
    """

    def __init__(
        self,
        in_channels_1,
        in_channels_2,
        num_classes,
        att=False,
    ):
        super().__init__()
        self.sca_layer = SCALayer(in_channels_1, in_channels_2, in_channels_2, att)
        self.lin_ = torch.nn.Linear(in_channels_2, num_classes)

    def forward(self, x_1, x_2, neighborhood_1, neighborhood_2):
        """Forward computation through the SCA layer, then a linear layer and softmax.

        Parameters
        ----------
        x_1 : torch.Tensor, shape = [n_edges, in_channels_1]
            Input features on the edges (1-cells).
        x_2 : torch.Tensor, shape = [n_faces, in_channels_2]
            Input features on the faces (2-cells).
        neighborhood_1 : torch.Tensor, shape = [n_edges, n_edges]
            First neighborhood matrix. For AMPS, the up Laplacian of rank 1.
        neighborhood_2 : torch.Tensor, shape = [n_edges, n_faces]
            Second neighborhood matrix. For AMPS, the incidence matrix of rank 2.

        Returns
        -------
        _ : torch.Tensor
            Softmax of the linear layer output.
            No pooling over cells is applied in this method.
        """
        x_2 = self.sca_layer(x_1, x_2, neighborhood_1, neighborhood_2)
        x_2 = self.lin_(x_2)
        return torch.softmax(x_2, dim=-1)

In [8]:
out_dim = max(ys)
y_list = []

for y in ys:
    y_one_hot = torch.zeros(1, out_dim)
    y_one_hot[0, y-1] = 1
    y_list.append(y_one_hot)
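The loop above writes label `y` into column `y - 1`, which assumes the labels run from 1 to `max(ys)`. If the labels were zero-based instead, `y = 0` would wrap around to the last column. The sketch below is one way to sidestep that assumption by shifting labels so the smallest one maps to column 0. The helper `one_hot_labels` is hypothetical, and it assumes the label range of the full dataset is what defines the classes.

```python
import torch
import torch.nn.functional as F


def one_hot_labels(labels):
    """One-hot encode integer labels, shifting so the smallest label maps to column 0."""
    labels = torch.as_tensor(labels, dtype=torch.long)
    shifted = labels - labels.min()
    num_classes = int(shifted.max()) + 1
    # F.one_hot returns integer tensors; cast to float for use as loss targets.
    return F.one_hot(shifted, num_classes=num_classes).float()


toy_labels = [0, 2, 1, 2]
encoded = one_hot_labels(toy_labels)
print(encoded.shape)
```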

In [18]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

up_lap1_train, up_lap1_test = train_test_split(
    up_lap1_list, test_size=test_size, shuffle=False
)
down_lap2_train, down_lap2_test = train_test_split(
    down_lap2_list, test_size=test_size, shuffle=False
)
incidence1_t_train, incidence1_t_test = train_test_split(
    incidence1_t_list, test_size=test_size, shuffle=False
)
incidence2_train, incidence2_test = train_test_split(
    incidence2_list, test_size=test_size, shuffle=False
)
incidence2_t_train, incidence2_t_test = train_test_split(
    incidence2_t_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)
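With `shuffle=False`, every call above cuts its list at the same position, which is what keeps features, neighborhood matrices, and labels aligned across the train and test sets. The same idea can be sketched with a single cut index applied to all lists at once. The helper `split_aligned` is hypothetical, and its rounding may differ by one sample from `train_test_split`.

```python
import math


def split_aligned(test_size, *sequences):
    """Split several equally long sequences at one shared index (hypothetical helper)."""
    n = len(sequences[0])
    if any(len(seq) != n for seq in sequences):
        raise ValueError("All sequences must have the same length.")
    n_test = math.ceil(test_size * n)
    cut = n - n_test
    # The last n_test items of every sequence form the test set.
    return [(seq[:cut], seq[cut:]) for seq in sequences]


features = ["f0", "f1", "f2", "f3", "f4"]
labels = [0, 1, 2, 3, 4]
(feat_train, feat_test), (lab_train, lab_test) = split_aligned(0.2, features, labels)
print(feat_train, feat_test)
print(lab_train, lab_test)
```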

In [22]:
model = AMPSSCA(in_channels_1a, in_channels_2a, num_classes=out_dim)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.1)
loss_fn = torch.nn.CrossEntropyLoss()

In [23]:
out_dim

29

In [24]:
test_interval = 2
num_epochs = 4
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x_1, x_2, up_lap1, incidence_2, y in zip(
        x_1_train, x_2_train, up_lap1_train, incidence2_train, y_train
    ):
        x_1, x_2, y = (
            torch.tensor(x_1).float().to(device),
            torch.tensor(x_2).float().to(device),
            torch.tensor(y).float().to(device),
        )
        up_lap1, incidence_2 = up_lap1.float().to(
            device
        ), incidence_2.float().to(device)
        opt.zero_grad()
        y_hat = model(x_1, x_2, up_lap1, incidence_2)
        loss = loss_fn(y_hat, y)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x_1, x_2, up_lap1, incidence_2, y in zip(
                x_1_test, x_2_test, up_lap1_test, incidence2_test, y_test
            ):
                x_1, x_2, y = (
                    torch.tensor(x_1).float().to(device),
                    torch.tensor(x_2).float().to(device),
                    torch.tensor(y).float().to(device),
                )
                up_lap1, incidence_2 = up_lap1.float().to(
                    device
                ), incidence_2.float().to(device)
                y_hat = model(x_1, x_2, up_lap1, incidence_2)
                test_loss = loss_fn(y_hat, y)
            print(f"Test_loss: {test_loss:.4f}", flush=True)

  torch.tensor(y).float().to(device),


ValueError: Expected input batch_size (750) to match target batch_size (1).
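The error arises because the model returns one row per cell (750 rows, which matches the edge count printed earlier), while the target `y` has a single row per complex. One possible fix, sketched below, is to average the per-cell outputs into a single row before computing the loss. Because `torch.nn.CrossEntropyLoss` expects unnormalized logits, the sketch also omits the softmax. The class `PooledHead` and the random stand-in for the layer output are assumptions for illustration, not part of TopoModelX.

```python
import torch


class PooledHead(torch.nn.Module):
    """Linear classifier followed by mean pooling over cells (hypothetical)."""

    def __init__(self, in_channels, num_classes):
        super().__init__()
        self.lin = torch.nn.Linear(in_channels, num_classes)

    def forward(self, x):
        logits_per_cell = self.lin(x)  # [n_cells, num_classes]
        # Average over cells to get one prediction for the whole complex.
        return logits_per_cell.mean(dim=0, keepdim=True)  # [1, num_classes]


torch.manual_seed(0)
cell_features = torch.randn(750, 7)  # stand-in for the SCA layer output
target = torch.zeros(1, 29)
target[0, 3] = 1.0                   # one-hot label, shaped like the entries of y_list

head = PooledHead(in_channels=7, num_classes=29)
y_hat = head(cell_features)
loss = torch.nn.CrossEntropyLoss()(y_hat, target)
print(y_hat.shape, loss.item())
```

In the notebook's model, the equivalent change would be to pool the output of `self.lin_` over the cell dimension in `forward` and return the pooled logits.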

In [29]:
x_1

tensor([[1.4278, 0.1538, 0.6280,  ..., 1.7742, 1.5595, 0.8965],
        [0.4095, 0.1731, 0.7004,  ..., 2.9332, 1.6330, 0.6524],
        [1.2717, 0.1721, 0.4799,  ..., 2.5734, 1.5595, 1.1941],
        ...,
        [0.8114, 0.1777, 0.8643,  ..., 1.7053, 0.7571, 1.4963],
        [0.4296, 0.2196, 1.1156,  ..., 1.8233, 2.0773, 0.6180],
        [0.8864, 0.2132, 0.5716,  ..., 2.3777, 1.4339, 2.0468]])