# Tutorial: Set up, create, and train a convolutional CCXN

In this notebook, we create and train a CCXN network, originally proposed in the paper by [Hajij et al.: Cell Complex Neural Networks (2020)](https://arxiv.org/pdf/2010.00743.pdf).

We will load a graph dataset (MUTAG, from the TUDataset collection) from the web, lift each graph to a cell complex, and train the model to classify these complexes.

The equations of one layer of this neural network are given by:

#### 1. A convolution from nodes to nodes using an adjacency message passing scheme (AMPS):

🟥 $\quad m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)} = M_{\mathcal{L}_\uparrow}^t(h_x^{t,(0)}, h_y^{t,(0)}, \Theta^{t,(y \rightarrow x)})$ 

🟧 $\quad m_x^{(0 \rightarrow 1 \rightarrow 0)} = AGG_{y \in \mathcal{L}_\uparrow(x)}(m_{y \rightarrow \{z\} \rightarrow x}^{(0 \rightarrow 1 \rightarrow 0)})$ 

🟩 $\quad m_x^{(0)} = m_x^{(0 \rightarrow 1 \rightarrow 0)}$ 

🟦 $\quad h_x^{t+1,(0)} = U^{t}(h_x^{t,(0)}, m_x^{(0)})$
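
To make the four steps concrete, here is a minimal pure-Python sketch on a three-node path graph. It assumes a linear message function, sum aggregation, and a ReLU update that ignores $h_x^{t,(0)}$. These choices, and the names `neighbors`, `theta`, `message`, `aggregate`, and `update`, are illustrative assumptions rather than the exact `CCXNLayer` implementation.

```python
# Path graph 0 - 1 - 2: neighbors[x] lists the nodes y that share an edge with x.
neighbors = {0: [1], 1: [0, 2], 2: [1]}
h = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 1.0]}  # node features h_x^{t,(0)}
theta = [[1.0, 0.0], [0.0, 1.0]]  # hypothetical weight matrix (identity)


def message(h_y, theta):
    """Red step: linear map of the neighbor's feature (assumed form of M)."""
    n_out = len(theta[0])
    return [sum(h_y[k] * theta[k][j] for k in range(len(h_y))) for j in range(n_out)]


def aggregate(messages):
    """Orange step: element-wise sum over the incoming messages."""
    return [sum(vals) for vals in zip(*messages)]


def update(m_x):
    """Blue step: ReLU of the aggregated message (assumed form of U)."""
    return [max(0.0, v) for v in m_x]


h_next = {
    x: update(aggregate([message(h[y], theta) for y in neighbors[x]]))
    for x in neighbors
}
print(h_next)  # {0: [0.0, 1.0], 1: [2.0, 1.0], 2: [0.0, 1.0]}
```

Node 1 has two neighbors, so its new feature is the sum of their features; the end nodes each copy the feature of node 1.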

#### 2. A convolution from edges to faces using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

The notation is defined in [Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).
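
For the edge-to-face step ($r' = 1$, $r = 2$), the sketch below assumes that the message from an edge to a face is the edge feature multiplied by the corresponding signed entry of the coboundary matrix, aggregated by summation. The single triangular face, its orientation signs, and the names `coboundary`, `h_edges`, and `face_message` are hypothetical.

```python
# One triangular face (id 0) bounded by edges 0, 1, 2.
# coboundary[face][edge] is a signed incidence entry (hypothetical orientation).
coboundary = {0: {0: 1.0, 1: 1.0, 2: -1.0}}
h_edges = {0: [1.0, 0.5], 1: [2.0, 0.5], 2: [4.0, 0.5]}  # edge features h_y^{t,(1)}


def face_message(sign, h_y):
    """Message from edge y to face x: the edge feature weighted by its sign."""
    return [sign * v for v in h_y]


def aggregate(messages):
    """Element-wise sum over the edges on the boundary of the face."""
    return [sum(vals) for vals in zip(*messages)]


m_faces = {
    face: aggregate([face_message(sign, h_edges[edge]) for edge, sign in edges.items()])
    for face, edges in coboundary.items()
}
print(m_faces)  # {0: [-1.0, 0.5]}
```

In this sketch the face message depends only on edge features, so no input signal on faces is needed.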

In [43]:
import torch
import numpy as np
import random
from sklearn.model_selection import train_test_split

from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx import CellComplex

from topomodelx.nn.cell.cxn_layer import CCXNLayer

If a GPU is available, we use it. Otherwise, the notebook runs on the CPU.

In [33]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
device = torch.device(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a cell complex.

We will also retrieve:
- the input signals on the nodes and edges of each cell complex, which are what we feed to the model
- the binary label associated with each cell complex

In [34]:
dataset = TUDataset(
    root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True, use_node_attr=True
)
dataset = dataset[:20]
cell_list = []
x_0_list = []
x_1_list = []
y_list = []
for graph in dataset:
    cellcomplex = CellComplex(to_networkx(graph))
    cell_list.append(cellcomplex)
    x_0 = graph.x  # torch.chunk(graph.x, 2, dim=0)[1]
    x_1 = graph.edge_attr  # torch.chunk(graph.edge_attr, 2, dim=0)[1]
    x_0_list.append(x_0)
    x_1_list.append(x_1)
    y_list.append(int(graph.y))

## Define neighborhood structures ##

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on each cell complex. This architecture needs the adjacency matrix $A_{\uparrow, 0}$ and the coboundary matrix $B_2^T$.

In [35]:
incidence_2_t_list = []
adjacency_0_list = []
for cellcomplex in cell_list:
    incidence_2_t = cellcomplex.incidence_matrix(rank=2).T
    adjacency_0 = cellcomplex.adjacency_matrix(rank=0)
    incidence_2_t = torch.from_numpy(incidence_2_t.todense()).to_sparse()
    adjacency_0 = torch.from_numpy(adjacency_0.todense()).to_sparse()
    incidence_2_t_list.append(incidence_2_t)
    adjacency_0_list.append(adjacency_0)

# Create the Neural Network

Using the CCXNLayer class, we create a neural network with stacked layers.

In [36]:
in_ch_0 = x_0_list[0].shape[-1]
in_ch_1 = x_1_list[0].shape[-1]
in_ch_2 = 5

In [37]:
class CCXN(torch.nn.Module):
    """Convolutional CCXN.

    Parameters
    ----------
    in_ch_0 : int
        Dimension of input features on nodes.
    in_ch_1 : int
        Dimension of input features on edges.
    in_ch_2 : int
        Dimension of input features on faces.
    num_classes : int
        Number of classes.
    n_layers : int
        Number of CCXN layers.
    att : bool
        Whether to use attention.
    """

    def __init__(self, in_ch_0, in_ch_1, in_ch_2, num_classes, n_layers=2, att=False):
        super().__init__()
        layers = []
        for _ in range(n_layers):
            layers.append(
                CCXNLayer(
                    in_channels_0=in_ch_0,
                    in_channels_1=in_ch_1,
                    in_channels_2=in_ch_2,
                    att=att,
                )
            )
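        # Register the layers so their parameters are optimized and moved by .to(device).
        self.layers = torch.nn.ModuleList(layers)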
        self.lin_0 = torch.nn.Linear(in_ch_0, num_classes)
        self.lin_1 = torch.nn.Linear(in_ch_1, num_classes)
        self.lin_2 = torch.nn.Linear(in_ch_2, num_classes)

    def forward(self, x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2):
        """Forward computation through layers, then linear layers, then avg pooling.

        Parameters
        ----------
        x_0 : tensor
            shape = [n_nodes, in_ch_0]
            Node features.

        x_1 : tensor
            shape = [n_edges, in_ch_1]
            Edge features.

        neighborhood_0_to_0 : tensor
            shape = [n_nodes, n_nodes]
            Adjacency matrix of rank 0 (up).

        neighborhood_1_to_2 : tensor
            shape = [n_faces, n_edges]
            Transpose of boundary matrix of rank 2.

        Returns
        -------
        _ : tensor
            shape = [num_classes]
            Logits assigned to the whole complex.
        """
        for layer in self.layers:
            x_0, x_1, x_2 = layer(x_0, x_1, neighborhood_0_to_0, neighborhood_1_to_2)
        x_0 = self.lin_0(x_0)
        x_1 = self.lin_1(x_1)
        x_2 = self.lin_2(x_2)
        # Mean-pool the logits within each rank, then sum the three ranks.
        y_hat = torch.mean(x_2, dim=0) + torch.mean(x_1, dim=0) + torch.mean(x_0, dim=0)
        return y_hat
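
The readout averages the per-cell logits within each rank and sums the three averages. The following pure-Python sketch mirrors that computation on nested lists; the helper names `mean_pool` and `readout` are hypothetical. It also makes one assumption that the class above does not: a rank with no cells contributes zeros. A plain mean over zero rows is undefined (`torch.mean` returns `nan` for an empty tensor), so a complex without faces may be worth handling explicitly.

```python
def mean_pool(rows, n_classes):
    """Column-wise mean of per-cell logits; zeros if the rank has no cells (assumption)."""
    if not rows:
        return [0.0] * n_classes
    return [sum(col) / len(rows) for col in zip(*rows)]


def readout(x_0, x_1, x_2, n_classes=2):
    """Sum the pooled logits of nodes, edges, and faces."""
    pooled = [mean_pool(x, n_classes) for x in (x_0, x_1, x_2)]
    return [sum(vals) for vals in zip(*pooled)]


x_0 = [[1.0, 0.0], [3.0, 2.0]]  # logits for two nodes
x_1 = [[0.5, 0.5]]              # logits for one edge
x_2 = []                        # no faces in this complex
y_hat = readout(x_0, x_1, x_2)
print(y_hat)  # [2.5, 1.5]
```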

# Train the Neural Network

We instantiate the model, the loss function, and the optimizer. We first train without the attention mechanism.

In [38]:
model = CCXN(in_ch_0, in_ch_1, in_ch_2, num_classes=2, n_layers=2)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)

Split the dataset into train and test sets.

In [39]:
x_0_train, x_0_test = train_test_split(x_0_list, test_size=0.2, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1_list, test_size=0.2, shuffle=False)
incidence_2_t_train, incidence_2_t_test = train_test_split(
    incidence_2_t_list, test_size=0.2, shuffle=False
)
adjacency_0_train, adjacency_0_test = train_test_split(
    adjacency_0_list, test_size=0.2, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=0.2, shuffle=False)
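
With `shuffle=False`, every call cuts its list at the same position, which keeps the five parallel lists aligned. The sketch below reproduces that contiguous cut in plain Python. `contiguous_split` is a hypothetical helper, and its rounding rule (test count rounded up) is assumed to match scikit-learn's behavior.

```python
import math


def contiguous_split(n_samples, test_size):
    """Return (train_indices, test_indices) for an unshuffled split."""
    n_test = math.ceil(test_size * n_samples)  # assumed rounding rule
    n_train = n_samples - n_test
    return list(range(n_train)), list(range(n_train, n_samples))


train_idx, test_idx = contiguous_split(n_samples=20, test_size=0.2)
print(len(train_idx), len(test_idx))  # 16 4

# Applying the same indices to every list keeps samples aligned.
labels = list(range(20))  # stand-in for y_list
y_train_demo = [labels[i] for i in train_idx]
y_test_demo = [labels[i] for i in test_idx]
```

For the 20 graphs kept earlier, this gives 16 training complexes and 4 test complexes.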

The following cell performs the training, looping over the training set for a small number of epochs. We keep training minimal for the purpose of rapid testing.

In [40]:
test_interval = 2
for epoch_i in range(1, 5):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):

        opt.zero_grad()

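        # Move inputs to the same device as the model.
        x_0, x_1 = x_0.float().to(device), x_1.float().to(device)
        adjacency_0 = adjacency_0.float().to(device)
        incidence_2_t = incidence_2_t.float().to(device)
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        y = torch.tensor(y, dtype=torch.long, device=device)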
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
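        model.eval()
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                # Move inputs to the same device as the model.
                x_0, x_1 = x_0.float().to(device), x_1.float().to(device)
                adjacency_0 = adjacency_0.float().to(device)
                incidence_2_t = incidence_2_t.float().to(device)
                y = torch.tensor(y, dtype=torch.long, device=device)
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)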

                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: nan Train_acc: 0.3750
Epoch: 2 loss: nan Train_acc: 0.3750
Test_acc: 0.5000
Epoch: 3 loss: nan Train_acc: 0.3750
Epoch: 4 loss: nan Train_acc: 0.3750
Test_acc: 0.5000
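
Each training step scores one complex: `CrossEntropyLoss` receives a vector of `num_classes` logits and a single class index. As a worked example, the sketch below computes the same quantity in plain Python, $-\log \mathrm{softmax}(\hat{y})_y$. The function name `cross_entropy` and the sample logits are hypothetical.

```python
import math


def cross_entropy(logits, label):
    """Negative log-softmax of the true class, computed via log-sum-exp."""
    top = max(logits)
    log_sum_exp = top + math.log(sum(math.exp(v - top) for v in logits))
    return log_sum_exp - logits[label]


# Two equal logits: the model is undecided, so the loss is log(2).
print(round(cross_entropy([0.0, 0.0], 1), 4))  # 0.6931

# A nan logit propagates to the loss.
print(cross_entropy([float("nan"), 0.0], 1))  # nan
```

A `nan` loss therefore points to `nan` values already present in the logits, for example from averaging over a rank that has no cells.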


Now we create a second neural network that uses the attention mechanism.

In [41]:
model = CCXN(in_ch_0, in_ch_1, in_ch_2, num_classes=2, n_layers=2, att=True)
model = model.to(device)
crit = torch.nn.CrossEntropyLoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)

We run the training for this neural network:

In [42]:
test_interval = 2
for epoch_i in range(1, 5):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
        x_0_train, x_1_train, incidence_2_t_train, adjacency_0_train, y_train
    ):

        opt.zero_grad()

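        # Move inputs to the same device as the model.
        x_0, x_1 = x_0.float().to(device), x_1.float().to(device)
        adjacency_0 = adjacency_0.float().to(device)
        incidence_2_t = incidence_2_t.float().to(device)
        y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)
        y = torch.tensor(y, dtype=torch.long, device=device)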
        loss = crit(y_hat, y)
        correct += (y_hat.argmax() == y).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
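        model.eval()
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_0, x_1, incidence_2_t, adjacency_0, y in zip(
                x_0_test, x_1_test, incidence_2_t_test, adjacency_0_test, y_test
            ):
                # Move inputs to the same device as the model.
                x_0, x_1 = x_0.float().to(device), x_1.float().to(device)
                adjacency_0 = adjacency_0.float().to(device)
                incidence_2_t = incidence_2_t.float().to(device)
                y = torch.tensor(y, dtype=torch.long, device=device)
                y_hat = model(x_0, x_1, adjacency_0, incidence_2_t)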

                correct += (y_hat.argmax() == y).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: nan Train_acc: 0.3750
Epoch: 2 loss: nan Train_acc: 0.3750
Test_acc: 0.5000
Epoch: 3 loss: nan Train_acc: 0.3750
Epoch: 4 loss: nan Train_acc: 0.3750
Test_acc: 0.5000
