# Train a Dynamic Hypergraph Neural Network (DHGNN)

In this notebook, we will create and train a two-step message passing network in the hypergraph domain. We will use a benchmark dataset, MUTAG (from the TUDataset), to train the model to perform binary classification at the level of the h.

### The Neural Network:

The equations of one layer of this neural network are given by:

A convolution from edges to edges using a cohomology message passing scheme:

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = M^t_{\mathcal{C}}(h_{x}^{t,(r)}, h_y^{t,(r')}, x, y)$ 

🟧 $\quad m_x^{(r' \rightarrow r)}  = AGG_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$ 

🟩 $\quad m_x^{(r)} = m_x^{(r' \rightarrow r)}$ 

🟦 $\quad h_{x}^{t+1,(r)} = U^{t,(r)}(h_{x}^{t,(r)}, m_{x}^{(r)})$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

### The Task:

We train this model to perform entire complex classification on [`MUTAG` from the TUDataset](https://paperswithcode.com/dataset/mutag). This dataset contains:
- 188 samples of chemical compounds represented as graphs,
- with 7 discrete node features.

The task is to predict the mutagenicity of each compound on Salmonella typhimurium.

In [1]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split

from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from toponetx import SimplicialComplex
from topomodelx.nn.hypergraph.dhgnn_layer import DHGNNLayer

If GPU's are available, we will make use of them. Otherwise, this will run on CPU.

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


# Pre-processing

## Import data ##

The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a hypergraph.

We will also retrieve:
- input signal on the edges for each of these hypergraphs, as that will be what we feed the model in input
- the binary label associated to the hypergraph

In [3]:
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)
dataset = dataset[:20]
hg_list = []
x_1_list = []
y_list = []
for graph in dataset:
    hg = SimplicialComplex(to_networkx(graph)).to_hypergraph()
    hg_list.append(hg)
    x_1 = torch.chunk(graph.edge_attr, 2, dim=0)[1]
    x_1_list.append(x_1)
    y_list.append(int(graph.y))

## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on each simplicial complex. In the case of this architecture, we need the boundary matrix (or incidence matrix) $B_1$ with shape $n_\text{nodes} \times n_\text{edges}$.

In [4]:
incidence_1_list = []
for hg in hg_list:
    incidence_1 = hg.incidence_matrix()
    incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
    incidence_1_list.append(incidence_1)

# Create the Neural Network

Using the DHGNNLayer class, we create a neural network with a single layer.

In [5]:
channels_edge = x_1_list[0].shape[1]
channels_node = dataset[0].x.shape[1]

In [6]:
class DHGNNNN(torch.nn.Module):
    """Neural network implementation of Template for hypergraph classification.

    Parameters
    ---------
    channels_edge : int
        Dimension of edge features
    channels_node : int
        Dimension of node features

    """

    def __init__(self, channels_edge, channels_node):
        super().__init__()
        self.layer = DHGNNLayer(
                    in_channels=channels_edge,
                    intermediate_channels=channels_node
                )

    def forward(self, x_1, incidence_1):
        """Forward computation through layers, then linear layer, then global max pooling.

        Parameters
        ---------
        x_1 : tensor
            shape = [n_edges, channels_edge]
            Edge features.

        incidence_1 : tensor
            shape = [n_nodes, n_edges]
            Boundary matrix of rank 1.

        Returns
        --------
        _ : tensor
            shape = [1]
            Label assigned to whole complex.
        """
        x_1 = self.layer(x_1, incidence_1)
        #pooled_x = torch.max(x_1, dim=0)[0]
        #return torch.sigmoid(self.linear(pooled_x))[0]
        return x_1

# Train the Neural Network

We specify the model, the loss, and an optimizer.

In [7]:
model = DHGNNNN(channels_edge, channels_node)
model = model.to(device)
crit = torch.nn.BCELoss()
opt = torch.optim.Adam(model.parameters(), lr=0.1)

Split the dataset into train and test sets.

In [8]:
test_size = 0.2
x_1_train, x_1_test = train_test_split(x_1_list, test_size=test_size, shuffle=False)
incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(y_list, test_size=test_size, shuffle=False)

The following cell performs the training, looping over the network for a low amount of epochs. We keep training minimal for the purpose of rapid testing.

In [9]:
test_interval = 2
num_epochs = 5
threshold_probability_positive_class = 0.5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    num_samples = 0
    correct = 0
    model.train()
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train):
        x_1, incidence_1, y = x_1.float().to(device), incidence_1.float().to(device), torch.tensor(y, dtype=torch.float).to(device)
        opt.zero_grad()
        y_hat = model(x_1, incidence_1)
        loss = crit(y_hat, y)
        correct += ((y_hat > threshold_probability_positive_class) == y.bool()).sum().item()
        num_samples += 1
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())
    train_acc = correct / num_samples
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {train_acc:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            num_samples = 0
            correct = 0
            for x_1, incidence_1, y in zip(x_1_test, incidence_1_test, y_test):
                x_1, incidence_1, y = x_1.float().to(device), incidence_1.float().to(device), torch.tensor(y, dtype=torch.float).to(device)
                y_hat = model(x_1, incidence_1)
                correct += ((y_hat > threshold_probability_positive_class) == y.bool()).sum().item()
                num_samples += 1
            test_acc = correct / num_samples
            print(f"Test_acc: {test_acc:.4f}", flush=True)

Epoch: 1 loss: 0.6087 Train_acc: 0.6875
Epoch: 2 loss: 0.5492 Train_acc: 0.8125
Test_acc: 0.7500
Epoch: 3 loss: 0.5097 Train_acc: 0.7500
Epoch: 4 loss: 0.4999 Train_acc: 0.7500
Test_acc: 0.7500
Epoch: 5 loss: 0.4986 Train_acc: 0.7500
