# Topological Deep Learning: A new direction for artificial intelligence with healthcare applications


This workshop takes the form of a notebook, whose goals are to:

* Briefly present what Topological Deep Learning (and Geometric Deep Learning) are, through explanations, definitions, visuals, etc.
* Present potential applications in healthcare and some of the work done so far in these emerging fields
* Run some code to get hands-on experience

## Author / Acknowledgment

This workshop was written by **Adrien Carrel** for the [3rd Big Data Machine Learning in Healthcare in Japan (2023)](https://datathon-japan.jp/2023tokyo/). He holds an MSc in Advanced Computing from Imperial College London, United Kingdom. He previously obtained an MEng in Applied Mathematics (Diplôme d'Ingénieur) from CentraleSupélec, France, and studied in Classes Préparatoires (MPSI/MP*) in France at Lycée Hoche (Versailles) and Lycée Pierre Corneille (Rouen).

Feel free to reach out to me using the links below!

Mail: a.carrel@hotmail.fr

Tel: +33 7 86 83 27 05

<a href="https://linkedin.com/in/adrien.carrel/" target="_blank"><img align="center" src="https://cdn.jsdelivr.net/npm/simple-icons@3.0.1/icons/linkedin.svg" alt="linkedin" height="39" width="52"/></a>
<a href="https://www.instagram.com/adrien.carrel" target="_blank"><img align="center" src="https://cdn.jsdelivr.net/npm/simple-icons@3.0.1/icons/instagram.svg" alt="instagram" height="39" width="52" /></a>
<a href="https://github.com/AdrienC21/" target="_blank"><img align="center" src="https://cdn.jsdelivr.net/npm/simple-icons@3.0.1/icons/github.svg" alt="github" height="39" width="52" /></a>

<img src="https://adriencarrel.com/images/avatar.jpg" alt="Drawing" style="width: 200px;"/>

Special thanks to Tolga Birdal, Mustafa Hajij, Nina Miolane and the other authors of the first two papers below. Some ideas were also inspired by the work of Li et al. (see below) and by the hands-on tutorial on graph neural networks written by Google. I would also like to thank all the authors of the papers mentioned in this notebook.

* [Papillon et al. : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031)

* [Hajij et al. : Topological Deep Learning: Going Beyond Graph Data (2022)](https://arxiv.org/abs/2206.00606)

* [Li et al. : Graph Representation Learning in Biomedicine and Healthcare (2022)](https://www.nature.com/articles/s41551-022-00942-x)

* [Adrien Carrel: Combinatorial Complex Score-based Diffusion Modeling through Stochastic Differential Equations (2023)](https://github.com/AdrienC21/CCSD)

## Topological/Geometric Deep Learning

**What is Deep Learning?**

Deep learning is a branch of machine learning that uses neural networks to process data and learn patterns, enabling computers to perform tasks like image recognition, language understanding, and decision-making. It's inspired by the brain's structure, using layers of interconnected nodes (neurons) to extract complex features from input data and make accurate predictions.

<img src="https://thegradient.pub/content/images/size/w1600/2019/02/1_1mpE6fsq5LNxH31xeTWi5w.jpeg" alt="Drawing" style="width: 300px;"/>

Source: [The Gradient](https://thegradient.pub/the-limitations-of-visual-deep-learning-and-how-we-might-fix-them/)

**What is Topology?**

Topology is a branch of mathematics that studies the properties of spaces and shapes, and when two of them can be considered equivalent under continuous deformations (shearing, stretching, bending, etc.), but not tearing, drilling or gluing! It captures these notions without relying on the concept of distance. The goal is to classify and compare different types of spaces and objects using properties such as connectivity, compactness, and continuity.

**Question 1:** What is a hole?

**Question 2:** How many holes does the human body have?

<img src="https://www.math.ens.psl.eu/~cemprin/formes.png" alt="Drawing" style="width: 400px;"/>

Source: [ENS Travaux dirigés (Printemps 2022) : Topologie algébrique](https://www.math.ens.psl.eu/~cemprin/enseignementP22.html)

<img src="https://i.insider.com/57f4f6e8dd08959f358b482f?width=500&format=jpeg&auto=webp" alt="Drawing" style="width: 400px;"/>

Source: [Johan Jarnestad/The Royal Swedish Academy of Sciences](https://www.nobelprize.org/nobel_prizes/physics/laureates/2016/popular-physicsprize2016.pdf)

**And so ... what are TDL and GDL?**

Topological Deep Learning (TDL) and Geometric Deep Learning (GDL) are two emerging subfields of machine learning and deep learning that incorporate topological and geometric information into the learning process. Such information is often referred to as an **inductive bias**. The idea is that including such information can improve the performance of the resulting model.

TDL combines (algebraic) topology with machine learning or deep learning. This generally consists of transforming the data into topological representations that capture its underlying structure. These representations are then fed to learning algorithms, often neural networks, which can improve their ability to handle complex and high-dimensional data.

The terms GDL and TDL are sometimes used interchangeably. However, GDL mostly refers to learning from unconventional data sources with different geometric structures, such as graphs, point clouds, meshes, and manifolds.

Overall, the aim of topological deep learning is to create neural network architectures that are invariant or equivariant to certain transformations, just as the original objects are. For graphs, the relevant transformations are permutations of the nodes. Mathematically, let $f$ be a scalar function, $F$ a vector function, $X$ a node feature matrix, $A$ an adjacency matrix, and $\mathcal{P}$ the set of all $n \times n$ permutation matrices, where $n$ is the number of nodes. Invariance and equivariance are defined by:

* $\forall P\in\mathcal{P}, f(PX,PAP^{T})=f(X,A)$

* $\forall P\in\mathcal{P}, F(PX,PAP^{T})=PF(X,A)$
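These two definitions can be checked numerically. The sketch below assumes, for illustration only, a toy layer $F(X, A) = (A + I)XW$ (each node sums its own and its neighbors' features, then applies a fixed linear map $W$) and a readout $f(X, A) = \sum_v F(X, A)_v$. Neither is a specific published architecture. It draws a random permutation matrix $P$ and tests both identities.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 5, 3
X = rng.normal(size=(n, d))                      # node feature matrix
A = np.triu(rng.integers(0, 2, size=(n, n)), 1)
A = A + A.T                                      # symmetric adjacency, no self-loops
W = rng.normal(size=(d, 2))                      # fixed (illustrative) weights

def F(X, A):
    """Equivariant toy layer: sum over each node's neighbours and itself, then a linear map."""
    return (A + np.eye(len(A))) @ X @ W

def f(X, A):
    """Invariant toy readout: sum the node outputs of F over all nodes."""
    return F(X, A).sum(axis=0)

P = np.eye(n)[rng.permutation(n)]                # random permutation matrix
equivariant = np.allclose(F(P @ X, P @ A @ P.T), P @ F(X, A))
invariant = np.allclose(f(P @ X, P @ A @ P.T), f(X, A))
print(equivariant, invariant)  # True True
```

Equivariance holds because $PAP^T + I = P(A+I)P^T$ and $P^TP = I$. Summing over nodes then discards the ordering, which makes the readout invariant.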

For reference, here is a nice paper on invariance and equivariance in graph neural networks: [Keriven, N. and Peyré, G. Universal invariant and equivariant graph neural networks. In Advances in Neural Information Processing Systems 32: Annual Conference on Neural Information Processing Systems 2019, NeurIPS 2019](https://papers.nips.cc/paper_files/paper/2019/hash/ea9268cb43f55d1d12380fb6ea5bf572-Abstract.html)

Categories of Geometric Deep Learning. Source: [Bronstein, Bruna, Cohen, Velickovic, Geometric Deep Learning: Grids, Groups, Graphs, Geodesics, and Gauges (2021)](https://arxiv.org/abs/2104.13478).

<img src="https://miro.medium.com/v2/resize:fit:1400/format:webp/1*J6Ipo8rqdjpsN3_9LafnQw.png" alt="Drawing" style="width: 600px;"/>

_Important remark:_ TDL/GDL generalize Deep Learning. Here are a few reasons:

* Many structures we encounter can be represented as grids (e.g. images) or, more generally, as graphs (see [Everything is Connected: Graph Neural Networks](https://arxiv.org/abs/2301.08210)). Since grids and sequences are special cases of graphs, such architectures generalize CNNs, RNNs, etc. and can be applied to a wider range of data types.

* They can handle graphs and other geometric structures, which often have sparse connections or missing data.

* Explicit connectivity aids interpretability: if such a model performs well, it can also tell us something about the geometry of the data.

* Most TDL architectures preserve invariance and equivariance by design, unlike regular neural network architectures.

## Applications

**Neuroscience research:** [Georgiadis K, Kalaganis FP, Oikonomou VP, Nikolopoulos S, Laskaris NA, Kompatsiaris I. RNeuMark: A Riemannian EEG Analysis Framework for Neuromarketing. Brain Inform. (2022)](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9481797/)

Exploit Riemannian geometry concepts to introduce a novel EEG-based decoder for detecting the consumers' preferences.

**Diagnosis of Glaucoma**: [Alexandre H. Thiéry, Fabian Braeu, Tin A. Tun, Tin Aung, Michaël J. A. Girard; Medical Application of Geometric Deep Learning for the Diagnosis of Glaucoma. Trans. Vis. Sci. Tech. 2023;12(2):23.](https://tvst.arvojournals.org/article.aspx?articleid=2785377)

Apply a geometric deep learning solution (PointNet) to provide a robust glaucoma diagnosis from a single OCT (optical coherence tomography) scan of the ONH (optic nerve head).

<img src="https://arvo.silverchair-cdn.com/arvo/content_public/journal/tvst/938623/m_i2164-2591-12-2-23-f1_1676453203.08458.png?Expires=1695180495&Signature=ubCEKrIBHzw2Lz9vCXZmrrihTgZiSLrgnQUT12Zof6fKV0INJZut6M-GHAtwZ~YHHKWpkzJBwtK9nfnb6oDCoY~SJGye3gtmofXIJRc-4hn1Pt~1BXd7RH4jct~y0WoVtBIN-gI8B7jukfJAu4ZNOzQkn5OA4lwTWcnkj~leDoquDISsuizuS80-0CMIly4YEng~nluOux2cSCxjH5Ys-hM~6dk2AcXseXZY~Ey8iD54jken8cKMxC4-oH9Jbbm~mGMRX3~XShnqCvpTYqagXJ6ASumO8tcOuoHVP-4zzXMsmhY2XYQut19GeEsIkQot8fnYv3toWAhxf5eB9Clj5A__&Key-Pair-Id=APKAIE5G5CRDK6RD3PGA" alt="Drawing" style="width: 300px;"/>

**Drug virus interactions:** [Das B, Kutsal M, Das R. A geometric deep learning model for display and prediction of potential drug-virus interactions against SARS-CoV-2.](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9400382/)

Simulate the interactions between molecules and viruses in order to predict if and how drugs will work.

<img src="https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9400382/bin/gr1_lrg.jpg" alt="Drawing" style="width: 300px;"/>

**Pathology prediction from chest X-rays**: [Gaurang A. Karwande. Geometric Deep Learning for Healthcare Applications (2023)](https://vtechworks.lib.vt.edu/bitstream/handle/10919/115361/Karwande_GA_T_2023.pdf?sequence=1&isAllowed=y)

Construct a graph representation of pairs of CXR (chest X-ray) images by exploiting the correlation among anatomical region features within each image, and the correlation between anatomical regions across the two images of the pair.

**Drug prediction, discovery, and molecular conformer generation**:

[Shen, C., Luo, J., & Xia, K. (2023). Molecular geometric deep learning.](https://arxiv.org/abs/2306.15065)

[Adrien Carrel. Combinatorial Complex Score-based Diffusion Modeling through Stochastic Differential Equations. (2023)](https://github.com/AdrienC21/CCSD)

[Jing, B., Corso, G., Chang, J., Barzilay, R., & Jaakkola, T. (2022). Torsional Diffusion for Molecular Conformer Generation.](https://arxiv.org/abs/2206.01729)

Model drugs or molecules as graphs or other topological structures and predict their intrinsic properties, such as their solubility in a particular fluid. Generate new molecules that follow a particular distribution.

<img src="https://adriencarrel.com/images/torsional.png" alt="Drawing" style="width: 300px;"/>

## Preliminary steps

Install some packages and download the QM9 dataset (qm9.csv) to your machine!

In [None]:
# Let's install PyTorch and TopoModelX
# For GPU support, replace "cpu" with "cu118" (or your CUDA version) in the URLs below
!pip install torch==2.0.1 --extra-index-url https://download.pytorch.org/whl/cpu
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.0.1+cpu.html
!pip install git+https://github.com/pyt-team/TopoModelX@ff0425a825311f9f80b6ca784bae128b11ccb827
!pip install rdkit

**RESTART YOUR KERNEL/RUNTIME** after the installation(s)!

General imports

In [None]:
import os
import json
import math
from typing import Optional, Tuple, List, Union, Dict, Any, Callable
from time import perf_counter

import torch
import torch_geometric
import random
import hypernetx
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from tqdm import tqdm
from torch.nn import Linear
from torch.nn.parameter import Parameter
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx
from rdkit import Chem
from rdkit.Chem import Draw


Load GPU if available

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")


Set the seed for reproducibility

In [None]:
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


## Introduction: Graph Neural Networks

### Preamble

Graph neural networks (GNNs) are a type of geometric deep learning model designed to process... surprise... graphs! They iteratively update the features of each node (or edge) in the graph by taking into account the features of its neighbors, i.e. the geometry of our data.

**Definition of an (Undirected) Graph:**

Let $S$ be a non-empty set. A graph on $S$ is a pair $\mathcal{G}=(S, \mathcal{E})$ where $\mathcal{E}\subseteq\{e\in\mathcal{P}(S) : |e|=2\}$ is a set of subsets of $S$ of **size 2**, called edges, and $\mathcal{P}(S)$ is the powerset of $S$. Elements of $S$ are called vertices.

When edges are instead stored as ordered pairs $(u, v)$ (as in an adjacency matrix), the graph is undirected if and only if for all $(u, v)\in \mathcal{E}$, $(v, u)\in \mathcal{E}$; otherwise it is directed.
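A minimal sketch of the symmetry condition in plain Python: an edge list of ordered pairs (the convention of adjacency matrices and PyG's `edge_index`) describes an undirected graph exactly when every pair $(u, v)$ also appears reversed. The helper name `is_symmetric_edge_list` is hypothetical.

```python
from typing import Iterable, Tuple

def is_symmetric_edge_list(edges: Iterable[Tuple[int, int]]) -> bool:
    """Return True if every ordered pair (u, v) also appears as (v, u)."""
    edge_set = set(edges)
    return all((v, u) in edge_set for (u, v) in edge_set)

# A triangle stored with both directions of every edge: undirected.
triangle = [(0, 1), (1, 0), (1, 2), (2, 1), (0, 2), (2, 0)]
# The same triangle with a single direction per edge: directed.
one_way = [(0, 1), (1, 2), (2, 0)]
print(is_symmetric_edge_list(triangle), is_symmetric_edge_list(one_way))  # True False
```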

------

<img src="https://adriencarrel.com/images/GNN.png" alt="Drawing" style="width: 500px;"/>

In this part, we will introduce the basics of GNNs using the [PyTorch](https://pytorch.org/) and [PyTorch Geometric (PyG)](https://github.com/rusty1s/pytorch_geometric) libraries. PyG is an extension of PyTorch that provides a variety of tools for working with graph data. We will start with **Zachary's Karate Club dataset**. It consists of 34 nodes, which represent members of a university karate club, and 78 edges, which represent connections between members who interacted outside of the club. During the study, a conflict arose between the administrator "John A" and the instructor "Mr. Hi" (these are pseudonyms), which led to the split of the club into two. Half of the members formed a new club around Mr. Hi, while members of the other faction found a new instructor or gave up karate. The PyTorch Geometric version of the dataset also provides labels for a 4-class classification task.

Source: ["An Information Flow Model for Conflict and Fission in Small Groups" by Wayne W. Zachary.](https://en.wikipedia.org/wiki/Zachary%27s_karate_club)

We will then train a GNN to detect communities in the Karate Club network, i.e. to predict which group each member will join. The GNN will learn to predict the community label of each node. The model will be trained on a subset of the nodes and evaluated on a test set made of the remaining nodes, using accuracy as the performance metric.

Mathematically, this is done by following a simple **neural message passing scheme**, where the node features $\mathbf{x}_v^{(\ell)}$ of all nodes $v \in S$ at layer $\ell\in\mathbb{N}$ of a graph $\mathcal{G} = (S, \mathcal{E})$ are iteratively updated by aggregating localized information from their neighbors $\mathcal{N}(v)$:

$$
\mathbf{x}_v^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_v^{(\ell)}, \left\{ \mathbf{x}_w^{(\ell)} : w \in \mathcal{N}(v) \right\} \right)
$$

where $\mathcal{N}(v)=\{ w\in S : (w, v)\in\mathcal{E} \}$ is the set of neighbors of $v$ and $f^{(\ell + 1)}_{\theta}$ is a learnable function.
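A minimal NumPy sketch of one message-passing step. For illustration only, it assumes that $f_\theta$ adds the node's own features to the sum of its neighbors' features, multiplies by a weight matrix `theta`, and applies a ReLU. The toy graph, the `neighbors` dictionary and `theta` are hypothetical.

```python
import numpy as np

# Toy undirected path graph 0-1-2-3, stored as N(v) for each node v
neighbors = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])  # x_v^(l), one row per node
theta = np.array([[1.0, -1.0], [0.5, 2.0]])  # hypothetical learnable weights (2 -> 2)

def message_passing_step(x: np.ndarray, neighbors: dict, theta: np.ndarray) -> np.ndarray:
    """One layer: x_v^(l+1) = ReLU((x_v + sum_{w in N(v)} x_w) @ theta)."""
    out = np.zeros((x.shape[0], theta.shape[1]))
    for v, nbrs in neighbors.items():
        aggregated = x[v] + x[nbrs].sum(axis=0)       # own features + sum over neighbours
        out[v] = np.maximum(aggregated @ theta, 0.0)  # linear map, then ReLU
    return out

x_next = message_passing_step(x, neighbors, theta)
print(x_next)
```

Because the aggregation is a sum, the result does not depend on the order in which neighbors are listed. This is the same permutation symmetry as in the invariance and equivariance definitions.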

---------

The graph neural network that will be implemented below will be made of multiple **Graph Convolutional Network layers (GCN)**: [Kipf, T. N., & Welling, M. (2016). Semi-Supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907). This layer is defined by the following operator:

$$
\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}
$$

where $\mathbf{W}^{(\ell + 1)}$ are learnable parameters represented as a matrix of size `[nb_output_features, nb_input_features]` and $c_{w,v}$ refers to a fixed normalization coefficient for each edge $(w,v)$.

This layer is already implemented in PyTorch Geometric: [`GCNConv`](https://pytorch-geometric.readthedocs.io/en/latest/modules/nn.html#torch_geometric.nn.conv.GCNConv). It takes as input the node feature matrix `x` and the graph connectivity `edge_index` created by PyTorch Geometric when we loaded our data. The tensor `edge_index` is stored in **COO** (coordinate) format, which is commonly used for sparse matrices: it has shape `[2, num_edges]` and each column is a (source, target) pair. Instead of storing the adjacency information in a dense matrix $A \in \{ 0, 1 \}^{|S| \times |S|}$, PyTorch Geometric represents graphs sparsely, holding only the coordinates of the non-zero entries of $A$. This reduces memory and computation from $O(|S|^2)$ to $O(|\mathcal{E}|)$, which matters for large, sparse graphs.
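To make the COO idea concrete, here is a small sketch that converts between a dense adjacency matrix and a `[2, num_edges]` index array in NumPy. It mirrors the layout of PyG's `edge_index` (row 0 holds source nodes, row 1 holds target nodes) but does not call PyG. The helper names are hypothetical.

```python
import numpy as np

def dense_to_coo(A: np.ndarray) -> np.ndarray:
    """Return a [2, num_nonzero] array with the (row, col) coordinates of non-zero entries."""
    rows, cols = np.nonzero(A)
    return np.stack([rows, cols])

def coo_to_dense(edge_index: np.ndarray, num_nodes: int) -> np.ndarray:
    """Rebuild a dense 0/1 adjacency matrix from COO coordinates."""
    A = np.zeros((num_nodes, num_nodes), dtype=int)
    A[edge_index[0], edge_index[1]] = 1
    return A

# Undirected path graph 0-1-2: each edge is stored in both directions.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]])
edge_index = dense_to_coo(A)
print(edge_index)  # [[0 1 1 2]
                   #  [1 0 2 1]]
print(edge_index.shape[1], "stored entries instead of", A.size)
```

For an undirected graph, each edge appears once per direction, so the number of stored entries is twice the number of undirected edges.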

In particular, in this implementation the normalization coefficients are $c_{w,v}=\sqrt{\hat{D}_{w,w}\hat{D}_{v,v}}$. Stacking the node features $\mathbf{x}_v^{(\ell)}$ as the rows of a matrix $X^{(\ell)}$, the layer reads (PyG also adds a learnable bias by default):

$$
X^{(\ell + 1)} =\hat{D}^{-\frac{1}{2}}\hat{A}\hat{D}^{-\frac{1}{2}} X^{(\ell)} \left(\mathbf{W}^{(\ell + 1)}\right)^{\top}
$$

where $\hat{A}=A+I$ is the adjacency matrix with self-loops added and $\hat{D}$ is the diagonal degree matrix defined by: $\forall u\in S, \hat{D}_{u,u}=\sum_{v\in S}\hat{A}_{u,v}$
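A NumPy sketch of the symmetric GCN normalization on a toy graph. It assumes no bias and a random weight matrix `W` of shape `[nb_output_features, nb_input_features]`. It computes the matrix form $\hat{D}^{-1/2}\hat{A}\hat{D}^{-1/2} X W^{\top}$ (node features stacked as rows of $X$) and checks it against the per-node sum with $c_{w,v}=\sqrt{\hat{D}_{w,w}\hat{D}_{v,v}}$. This is an illustration, not PyG's `GCNConv` code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy undirected graph with edges 0-1, 0-2, 2-3
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 1],
              [0, 0, 1, 0]], dtype=float)
X = rng.normal(size=(4, 3))  # node features, one row per node
W = rng.normal(size=(2, 3))  # weights, shape [nb_output_features, nb_input_features]

A_hat = A + np.eye(len(A))           # add self-loops
deg = A_hat.sum(axis=1)              # diagonal of D_hat
D_inv_sqrt = np.diag(deg ** -0.5)

# Matrix form: D^-1/2 (A + I) D^-1/2 X W^T
X_next = D_inv_sqrt @ A_hat @ D_inv_sqrt @ X @ W.T

# Per-node form: x_v' = W sum_{w in N(v) U {v}} x_w / sqrt(d_w * d_v)
X_loop = np.zeros_like(X_next)
for v in range(len(A)):
    for w in np.nonzero(A_hat[v])[0]:  # neighbours of v, including v itself
        X_loop[v] += W @ X[w] / np.sqrt(deg[w] * deg[v])

print(np.allclose(X_next, X_loop))  # True
```

The dense matrices here are only for readability. On real graphs, the same computation runs over the sparse `edge_index`.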

### Preliminary analysis

Let's import the dataset

In [None]:
dataset = KarateClub()
print(f"Dataset: {dataset}:")
print("======================")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")


Print information about our graph

In [None]:
data = dataset[0]  # Get the graph object.

print(data)
print(f"Type: {type(data)}")
print("==============================================================")

# Add test_mask: We train on 1 node per class only!
data.test_mask = ~data.train_mask
print(f"Test set size: {round(100 * data.test_mask.sum().item() / data.num_nodes, 3)}%")

# Gather some statistics about the dataset/graph.
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")
avg_node_degree = None  # TODO: compute the average number of edges per node
print(f"Average node degree: {avg_node_degree}")
print(f"Number of training nodes: {data.train_mask.sum()}")
print(f"Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}")
print(f"Has isolated nodes: {data.has_isolated_nodes()}")
print(f"Has self-loops: {data.has_self_loops()}")
is_undirected = False  # TODO: to modify, is_undirected should be equal to True if the graph is undirected
print(f"Is undirected: {is_undirected}")

Define utility functions

In [None]:
def visualize_graph(G: nx.Graph, color: torch.Tensor) -> None:
    """Visualize a networkx graph.

    Args:
        G (nx.Graph): A networkx graph.
        color (torch.Tensor): A color for the nodes.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    # TODO: complete the code to visualize the graph with the nodes colored according to the color tensor
    plt.show()


def visualize_embedding(
    h: torch.Tensor,
    color: torch.Tensor,
    epoch: Optional[int] = None,
    loss: Optional[torch.Tensor] = None,
) -> None:
    """Visualize a 2D embedding.

    Args:
        h (torch.Tensor): A 2D embedding.
        color (torch.Tensor): Colors for the nodes.
        epoch (Optional[int], optional): The current epoch. Defaults to None.
        loss (Optional[torch.Tensor], optional): The current loss. Defaults to None.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.title(f"Epoch: {epoch}, Loss: {loss.item():.4f}", fontsize=16)
    plt.xlabel("h0")
    plt.ylabel("h1")
    plt.show()


Visualize the Karate Club graph dataset

In [None]:
G = to_networkx(data, to_undirected=True)
# TODO: visualize the graph using your function and the label accessible using "data.y"


### Implementing/training Graph Neural Networks

Now, let's implement and train our neural network!

We define the network architecture as a class that inherits from `torch.nn.Module`:

In [None]:
class GCN(torch.nn.Module):
    """Our simple GCN model."""

    def __init__(
        self,
        num_features: int,
        hidden_dim: int = 4,
        embedding_dim: int = 2,
        num_classes: int = 4,
    ) -> None:
        """Initialize the model.

        Args:
            num_features (int): The number of input features.
            hidden_dim (int, optional): The hidden dimension. Defaults to 4.
            embedding_dim (int, optional): The embedding dimension. Best to set it to 2 for embeddings interpretability. Defaults to 2.
            num_classes (int, optional): The number of classes. Defaults to 4.
        """
        super().__init__()
        self.num_features = num_features
        self.hidden_dim = hidden_dim
        self.embedding_dim = embedding_dim
        self.num_classes = num_classes
        self.conv1 = GCNConv(
            self.num_features, self.hidden_dim
        )  # 34 input features for Karate
        self.conv2 = GCNConv(self.hidden_dim, self.hidden_dim)
        self.conv3 = GCNConv(self.hidden_dim, self.embedding_dim)
        self.classifier = Linear(
            self.embedding_dim, self.num_classes
        )  # 2 -> 4 for Karate

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the parameters."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.conv3.reset_parameters()
        self.classifier.reset_parameters()

    def forward(
        self, x: torch.Tensor, edge_index: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the model.

        Args:
            x (torch.Tensor): The input node features.
            edge_index (torch.Tensor): The graph connectivity.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The output logits for each node classification, and the embedding h.
        """
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()
        # Last layer h is the embedding

        # Apply a final (linear) classifier.

        # TODO: complete the code to apply the final classifier on the embedding h. The result should be stored in the variable "out".

        return out, h

    def __repr__(self) -> str:
        """Representation of the model."""
        return f"{self.__class__.__name__}(num_features={self.num_features}, hidden_dim={self.hidden_dim}, embedding_dim={self.embedding_dim}, num_classes={self.num_classes})"

In [None]:
model = GCN(
    num_features=dataset.num_features,
    hidden_dim=4,
    embedding_dim=2,
    num_classes=dataset.num_classes,
).to(device)
print(model)


Get the number of parameters of our model

In [None]:
def get_nb_parameters(model: torch.nn.Module) -> int:
    """Get the number of trainable parameters of a model.

    Args:
        model (torch.nn.Module): The model.

    Returns:
        int: The number of trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


print(f"Number of parameters GCN: {get_nb_parameters(model)}")


Plot the initial embeddings of the untrained model

The idea is that we can visualize the quality of the training of our model through the embeddings.

In [None]:
data = data.to(device)  # move the graph to the same device as the model
model.eval()
with torch.no_grad():
    _, h = model(data.x, data.edge_index)

print(f"Embedding shape: {list(h.shape)}")

visualize_embedding(h, color=data.y.cpu())


Without any training, it is random :/

Define a loss and an optimizer

In [None]:
learning_rate = 0.01  # TODO: try a different learning rate?

criterion = torch.nn.CrossEntropyLoss()  # cross entropy loss for classification.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer.


Let's train our model!

In [None]:
def train_one_epoch(
    model: torch.nn.Module, data: torch_geometric.data.data.Data
) -> Tuple[float, float, float, float, torch.Tensor]:
    """Train the model for 1 epoch.

    Args:
        model (torch.nn.Module): The model to train.
        data (torch_geometric.data.data.Data): The graph data.

    Returns:
        Tuple[float, float, float, float, torch.Tensor]: The losses on train and test, the accuracy on train and test, and the embedding h after one epoch/step.
    """
    model.train()
    optimizer.zero_grad()  # clear gradients.
    out, h = model(data.x, data.edge_index)  # perform a single forward pass.
    loss_train = criterion(
        out[data.train_mask], data.y[data.train_mask]
    )  # compute the loss on the training nodes.
    model.eval()
    with torch.no_grad():  # no gradients for test pass.
        loss_test = criterion(out[data.test_mask], data.y[data.test_mask])
        acc_train = (
            torch.sum(
                out[data.train_mask].argmax(dim=1) == data.y[data.train_mask]
            ).item()
            / data.train_mask.sum().item()
        )
        acc_test = (
            torch.sum(
                out[data.test_mask].argmax(dim=1) == data.y[data.test_mask]
            ).item()
            / data.test_mask.sum().item()
        )
    model.train()
    loss_train.backward()  # derive gradients.
    optimizer.step()  # update parameters based on gradients.
    return loss_train.item(), loss_test.item(), acc_train, acc_test, h


def train(
    model: torch.nn.Module,
    data: Union[
        torch_geometric.data.data.Data,
        Tuple[
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
            torch.Tensor,
        ],
    ],
    nb_epochs: int,
    train_func: Callable[
        [
            torch.nn.Module,
            Union[
                torch_geometric.data.data.Data,
                Tuple[
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                    torch.Tensor,
                ],
            ],
        ],
        Tuple[float, float, float, float, torch.Tensor],
    ],
) -> Tuple[List[float], List[float], List[float], List[float], List[torch.Tensor]]:
    """Train the model.

    Args:
        model (torch.nn.Module): The model to train.
        data (Union[torch_geometric.data.data.Data, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]): The graph data.
        nb_epochs (int): The number of epochs.
        train_func (Callable[[torch.nn.Module, Union[torch_geometric.data.data.Data, Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]]], Tuple[float, float, float, float, torch.Tensor]]): Function that trains the model for one epoch.

    Returns:
        Tuple[List[float], List[float], List[float], List[float], List[torch.Tensor]]: The losses on train and test, the accuracy on train and test, and the embeddings at each epoch.
    """
    top = perf_counter()  # calculate total training time
    model.train()
    losses_train = []
    losses_test = []
    accuracy_train = []
    accuracy_test = []
    embeddings = []
    for epoch in tqdm(range(1, nb_epochs + 1)):
        loss_train, loss_test, acc_train, acc_test, h = train_func(model, data)
        losses_train.append(loss_train)
        losses_test.append(loss_test)
        accuracy_train.append(acc_train)
        accuracy_test.append(acc_test)
        embeddings.append(h)
    print(f"Training finished! Total time: {round(perf_counter() - top, 3)}s.")
    return losses_train, losses_test, accuracy_train, accuracy_test, embeddings


Train the model for 400 epochs

In [None]:
nb_epochs = 400  # TODO: try a different number of epochs?
losses_train, losses_test, accuracy_train, accuracy_test, embeddings = train(
    model, data, nb_epochs, train_one_epoch
)


Plot the metrics over time

In [None]:
def plot_learning_curves(
    losses_train: List[float],
    losses_test: List[float],
    accuracy_train: List[float],
    accuracy_test: List[float],
) -> None:
    """Plot the learning curves.

    Args:
        losses_train (List[float]): The training losses.
        losses_test (List[float]): The test losses.
        accuracy_train (List[float]): The training accuracies.
        accuracy_test (List[float]): The test accuracies.
    """
    plt.figure(figsize=(16, 6))
    plt.subplot(121)
    plt.plot(losses_train, label="train")
    plt.plot(losses_test, label="test")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.subplot(122)
    # TODO: modify the code to plot the accuracy but in percentage!
    plt.plot(np.array(accuracy_train), label="train")
    plt.plot(np.array(accuracy_test), label="test")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy (%)")
    plt.legend()
    # TODO: modify the code to display the accuracies in percentage too!
    plt.suptitle(
        "Learning curves. Final train accuracy: {:.2f}%. Final test accuracy: {:.2f}%".format(
            accuracy_train[-1], accuracy_test[-1]
        )
    )
    plt.tight_layout()
    plt.show()

In [None]:
# TODO: plot the learning curves using the function above

Plot some of the embeddings during the training

In [None]:
def plot_multiple_embeddings(
    embeddings: List[torch.Tensor],
    losses_train: List[float],
    losses_test: List[float],
    epochs: List[int],
) -> None:
    """Plot the embeddings at each epoch of the list.

    Args:
        embeddings (List[torch.Tensor]): The embeddings at each epoch.
        losses_train (List[float]): The training losses.
        losses_test (List[float]): The test losses.
        epochs (List[int]): The epochs (starting from 1!).
    """
    for epoch in epochs:
        h = embeddings[epoch - 1].detach().cpu().numpy()
        plt.figure()  # one figure per selected epoch
        plt.scatter(
            h[:, 0],
            h[:, 1],
            c=data.y.cpu(),
            s=140,
            cmap="Set2",
        )
        plt.title(
            f"Epoch: {epoch}, loss_train: {losses_train[epoch-1]:.4f}, loss_test: {losses_test[epoch-1]:.4f}"
        )
        plt.xlabel("h0")
        plt.ylabel("h1")
        plt.show()


In [None]:
epochs = [1, 10, 50, 100, 200, 400]  # TODO: try different epochs?
plot_multiple_embeddings(embeddings, losses_train, losses_test, epochs)


For binary classification (label 1 for the members of Mr. Hi's faction, 0 for the others):

In [None]:
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0,  # nodes 0-9
     1, 1, 1, 1, 0, 0, 1, 1, 0, 1,  # nodes 10-19
     0, 1, 0, 0, 0, 0, 0, 0, 0, 0,  # nodes 20-29
     0, 0, 0, 0]                    # nodes 30-33
)

data = dataset[0]
data.y = torch.tensor(y, dtype=torch.long)
test_size = 0.2  # TODO: try a different test size?
y_train, y_test, ind_train, ind_test = train_test_split(
    y, np.arange(data.num_nodes), test_size=test_size, random_state=seed, stratify=y
)
data.train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.train_mask[ind_train] = True
data.test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
data.test_mask[ind_test] = True
data = data.to(device)  # move the graph to the same device as the model


In [None]:
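# Re-initialize a 2-class model and a fresh optimizer for the binary task
model = GCN(num_features=dataset.num_features, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
nb_epochs = 400  # TODO: try a different number of epochs?
losses_train, losses_test, accuracy_train, accuracy_test, embeddings = train(
    model, data, nb_epochs, train_one_epoch
)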


In [None]:
# TODO: plot the learning curves using the function above


## Second example: Train a Hypergraph Network with Hyperedge Neurons (HNHN)

Here are some other topological domains that we could use to try to improve our performance:

<img src="https://adriencarrel.com/images/topological_domains.png" alt="Drawing" style="width: 500px;"/>

For now, let's dive into a more complex Topological Deep Learning framework. We will **lift** our dataset into a hypergraph and train on it a Hypergraph Network with Hyperedge Neurons (HNHN), as introduced in the paper: [Dong et al. : HNHN: Hypergraph networks with hyperedge neurons (2020)](https://grlplus.github.io/papers/40.pdf). We will again apply it to the Karate Club dataset.

But first, what are hypergraphs?

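**Hypergraph:** Let $S$ be a non-empty set. A hypergraph on $S$ is a pair $(S, X)$, where $X \subseteq \mathcal{P}(S) \setminus \{\emptyset\}$ is a set of non-empty subsets of $S$ (i.e. elements of the powerset $\mathcal{P}(S)$), called hyperedges. Elements of $S$ are called vertices. A graph is the special case in which every hyperedge contains exactly two vertices.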
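To make the definition concrete, here is a small hypothetical hypergraph on $S = \{0, 1, 2, 3\}$ stored as a list of frozensets, together with its unsigned incidence matrix: the $|S| \times |X|$ matrix with a 1 wherever vertex $i$ belongs to hyperedge $j$. This is a plain NumPy sketch for illustration only; it is not how `toponetx` or `hypernetx` store hypergraphs internally.

```python
import numpy as np

# Vertex set S and hyperedges X (hypothetical toy example).
S = [0, 1, 2, 3]
X = [frozenset({0, 1}), frozenset({1, 2, 3}), frozenset({0, 3})]

# Each hyperedge must be a non-empty subset of S.
assert all(len(e) > 0 and e <= set(S) for e in X)


def incidence_matrix(vertices, hyperedges):
    """Unsigned incidence matrix of shape (n_vertices, n_hyperedges)."""
    index = {v: i for i, v in enumerate(vertices)}
    B1 = np.zeros((len(vertices), len(hyperedges)), dtype=int)
    for j, edge in enumerate(hyperedges):
        for v in edge:
            B1[index[v], j] = 1
    return B1


B1 = incidence_matrix(S, X)
print(B1)
# Column sums are the hyperedge sizes; row sums count the hyperedges containing each vertex.
print(B1.sum(axis=0), B1.sum(axis=1))
```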

-------

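The equations of one layer of the HNHN neural network are given below. Here $B_1$ is the incidence matrix, $h^{t,(0)}$ and $h^{t,(1)}$ are the node and hyperedge features at layer $t$, $\mathcal{B}(x)$ is the set of nodes contained in a hyperedge $x$, $\mathcal{C}(x)$ is the set of hyperedges containing a node $x$, $\Theta$ and $b$ are learnable weights and biases, $W^{(0)}$ and $W^{(1)}$ encode the cardinality-based normalization of the incidence matrix, and $\sigma$ is a nonlinearity (a sigmoid in the code below).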

Message passing:

🟥 $\quad m_{y \rightarrow x}^{(0 \rightarrow 1)} = \sigma((B_1^T \cdot W^{(0)})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0)} + b^{t,(0)})$

🟥 $\quad m_{y \rightarrow x}^{(1 \rightarrow 0)}  = \sigma((B_1 \cdot W^{(1)})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1)} + b^{t,(1)})$

Within-Neighborhood Aggregation:

🟧 $\quad m_x^{(0 \rightarrow 1)}  = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}$

🟧 $\quad m_x^{(1 \rightarrow 0)}  = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}$

Between-Neighborhood Aggregation:

🟩 $\quad m_x^{(0)}  = m_x^{(1 \rightarrow 0)}$

🟩 $\quad m_x^{(1)}  = m_x^{(0 \rightarrow 1)}$

Update:

🟦 $\quad h_x^{t+1,(0)} = m_x^{(0)}$

🟦 $\quad h_x^{t+1,(1)} = m_x^{(1)}$
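Because the within-neighborhood sums run over the nonzero entries of $B_1^T$ and $B_1$, each message/aggregation pair can be written as a single matrix product. The NumPy sketch below shows one layer in that form, assuming no cardinality normalization (so $B_1$ is used as is) and applying the sigmoid after the sum, as the `HNHNLayer` implementation further down does. The names `theta_0`, `theta_1`, `b_0`, `b_1` and the toy matrices are illustrative.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def hnhn_layer(x_0, x_1, B1, theta_0, theta_1, b_0, b_1):
    """One simplified HNHN layer in matrix form (no cardinality normalization).

    x_0: (n_nodes, c_node), x_1: (n_edges, c_edge), B1: (n_nodes, n_edges)
    theta_0: (c_node, c_edge) maps node features to hyperedge channels
    theta_1: (c_edge, c_node) maps hyperedge features to node channels
    """
    # Nodes -> hyperedges: each hyperedge sums the transformed features of its nodes.
    x_1_new = sigmoid(B1.T @ (x_0 @ theta_0) + b_0)
    # Hyperedges -> nodes: each node sums the transformed features of its hyperedges.
    x_0_new = sigmoid(B1 @ (x_1 @ theta_1) + b_1)
    return x_0_new, x_1_new


rng = np.random.default_rng(0)
B1 = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 0], [0, 1, 1]], dtype=float)
x_0 = rng.normal(size=(4, 2))  # 4 nodes, 2 channels
x_1 = rng.normal(size=(3, 2))  # 3 hyperedges, 2 channels
theta_0 = rng.normal(size=(2, 2))
theta_1 = rng.normal(size=(2, 2))
b_0 = np.zeros((1, 2))
b_1 = np.zeros((1, 2))

h_0, h_1 = hnhn_layer(x_0, x_1, B1, theta_0, theta_1, b_0, b_1)
print(h_0.shape, h_1.shape)  # (4, 2) (3, 2)
```

Both updates read the features of the previous layer, so node and hyperedge features are refreshed in parallel.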


### Define the models and utility functions

* Scatter functions adapted from `torch_scatter/scatter.py` in [torch_scatter](https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py)
* Message passing and HNHN architecture from TopoModelX: [Message passing](https://github.com/pyt-team/TopoModelX/blob/main/topomodelx/base/message_passing.py) and [HNHN]()
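`scatter_sum(src, index, dim=0)` adds row `i` of `src` into row `index[i]` of the output. `MessagePassing.forward` uses it in a gather, weight, scatter pipeline over the nonzero entries `(i, j)` of the neighborhood matrix. The NumPy sketch below uses `np.add.at` as a stand-in for the torch scatter and checks that this pipeline gives the same result as the product $A X$, which is why `Conv.forward` can rely on a sparse `torch.mm` instead. The matrices are toy values.

```python
import numpy as np

# Neighborhood matrix A (n_target x n_source) and source features X.
A = np.array([[0.0, 2.0, 1.0],
              [1.0, 0.0, 0.0]])
X = np.array([[1.0, 10.0],
              [2.0, 20.0],
              [3.0, 30.0]])

# COO view of A, similar to neighborhood.coalesce().indices() / .values() in torch.
target_index_i, source_index_j = np.nonzero(A)
values = A[target_index_i, source_index_j]

# 1. Gather: one message per nonzero entry, taken from the source cell.
messages = X[source_index_j]
# 2. Weight each message by its neighborhood value.
messages = values[:, None] * messages
# 3. Scatter-sum: accumulate the messages onto their target cells.
out = np.zeros((A.shape[0], X.shape[1]))
np.add.at(out, target_index_i, messages)

print(out)
print(np.allclose(out, A @ X))  # True
```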

In [None]:
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
    """Broadcasts `src` to the shape of `other`."""
    if dim < 0:
        dim = other.dim() + dim
    if src.dim() == 1:
        for _ in range(0, dim):
            src = src.unsqueeze(0)
    for _ in range(src.dim(), other.dim()):
        src = src.unsqueeze(-1)
    src = src.expand(other.size())
    return src


def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    """Add all values from the `src` tensor into `out` at the indices."""
    index = broadcast(index, src, dim)
    if out is None:
        size = list(src.size())
        if dim_size is not None:
            size[dim] = dim_size
        elif index.numel() == 0:
            size[dim] = 0
        else:
            size[dim] = int(index.max()) + 1
        out = torch.zeros(size, dtype=src.dtype, device=src.device)
        return out.scatter_add_(dim, index, src)
    else:
        return out.scatter_add_(dim, index, src)


def scatter_add(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    """Add all values from the `src` tensor into `out` at the indices."""
    return scatter_sum(src, index, dim, out, dim_size)


def scatter_mean(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: Optional[torch.Tensor] = None,
    dim_size: Optional[int] = None,
) -> torch.Tensor:
    """Compute the mean value of all values from the `src` tensor into `out`."""
    out = scatter_sum(src, index, dim, out, dim_size)
    dim_size = out.size(dim)

    index_dim = dim
    if index_dim < 0:
        index_dim = index_dim + src.dim()
    if index.dim() <= index_dim:
        index_dim = index.dim() - 1

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, index_dim, None, dim_size)
    count[count < 1] = 1
    count = broadcast(count, out, dim)
    if out.is_floating_point():
        out.true_divide_(count)
    else:
        out.div_(count, rounding_mode="floor")
    return out


SCATTER_DICT = {"sum": scatter_sum, "mean": scatter_mean, "add": scatter_sum}


def scatter(scatter):
    """Return the scatter function associated with the given name."""
    if isinstance(scatter, str) and scatter in SCATTER_DICT:
        return SCATTER_DICT[scatter]
    else:
        raise ValueError(
            f"scatter must be one of the strings: {list(SCATTER_DICT.keys())}"
        )


class MessagePassing(torch.nn.Module):
    """MessagePassing.

    This class defines message passing through a single neighborhood N,
    by decomposing it into 2 steps:

    1. 🟥 Create messages going from source cells to target cells through N.
    2. 🟧 Aggregate messages coming from different source cells onto each target cell.

    This class should not be instantiated directly, but rather inherited
    through subclasses that effectively define a message passing function.

    This class does not have trainable weights, but its subclasses should
    define these weights.

    Parameters
    ----------
    aggr_func : string
        Aggregation function to use.
    att : bool
        Whether to use attention.
    initialization : string
        Initialization method for the weights of the layer.

    References
    ----------
    .. [H23] Hajij, Zamzmi, Papamarkou, Miolane, Guzmán-Sáenz, Ramamurthy, Birdal, Dey,
        Mukherjee, Samaga, Livesay, Walters, Rosen, Schaub. Topological Deep Learning: Going Beyond Graph Data.
        (2023) https://arxiv.org/abs/2206.00606.

    .. [PSHM23] Papillon, Sanborn, Hajij, Miolane.
        Architectures of Topological Deep Learning: A Survey on Topological Neural Networks.
        (2023) https://arxiv.org/abs/2304.10031.
    """

    def __init__(
        self,
        aggr_func="sum",
        att=False,
        initialization="xavier_uniform",
    ):
        super().__init__()
        self.aggr_func = aggr_func
        self.att = att
        self.initialization = initialization
        assert initialization in ["xavier_uniform", "xavier_normal"]
        assert aggr_func in ["sum", "mean", "add"]

    def reset_parameters(self, gain=1.414):
        r"""Reset learnable parameters.

        Notes
        -----
        This function will be called by subclasses of
        MessagePassing that have trainable weights.

        Parameters
        ----------
        gain : float
            Gain for the weight initialization.
        """
        if self.initialization == "xavier_uniform":
            if self.weight is not None:
                torch.nn.init.xavier_uniform_(self.weight, gain=gain)
            if self.att:
                torch.nn.init.xavier_uniform_(self.att_weight.view(-1, 1), gain=gain)

        elif self.initialization == "xavier_normal":
            if self.weight is not None:
                torch.nn.init.xavier_normal_(self.weight, gain=gain)
            if self.att:
                torch.nn.init.xavier_normal_(self.att_weight.view(-1, 1), gain=gain)
        else:
            raise RuntimeError(
                "Initialization method not recognized. "
                "Should be either xavier_uniform or xavier_normal."
            )

    def message(self, x_source, x_target=None):
        """Construct message from source cells to target cells.

        🟥 This provides a default message function to the message passing scheme.

        Alternatively, users can subclass MessagePassing and overwrite
        the message method in order to replace it with their own message mechanism.

        Parameters
        ----------
        x_source : Tensor, shape=[..., n_source_cells, in_channels]
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        x_target : Tensor, shape=[..., n_target_cells, in_channels]
            Input features on target cells.
            Assumes that all target cells have the same rank s.
            Optional. If not provided, x_target is assumed to be x_source,
            i.e. source cells send messages to themselves.

        Returns
        -------
        _ : Tensor, shape=[..., n_source_cells, in_channels]
            Messages on source cells.
        """
        return x_source

    def attention(self, x_source, x_target=None):
        """Compute attention weights for messages.

        This provides a default attention function to the message passing scheme.

        Alternatively, users can subclass MessagePassing and overwrite
        the attention method in order to replace it with their own attention mechanism.

        Details in [H23]_, Definition of "Attention Higher-Order Message Passing".

        Parameters
        ----------
        x_source : torch.Tensor, shape=[n_source_cells, in_channels]
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        x_target : torch.Tensor, shape=[n_target_cells, in_channels]
            Input features on target cells.
            Assumes that all target cells have the same rank s.

        Returns
        -------
        _ : torch.Tensor, shape = [n_messages, 1]
            Attention weights: one scalar per message between a source and a target cell.
        """
        x_source_per_message = x_source[self.source_index_j]
        x_target_per_message = (
            x_source[self.target_index_i]
            if x_target is None
            else x_target[self.target_index_i]
        )

        x_source_target_per_message = torch.cat(
            [x_source_per_message, x_target_per_message], dim=1
        )

        return torch.nn.functional.elu(
            torch.matmul(x_source_target_per_message, self.att_weight)
        )

    def aggregate(self, x_message):
        """Aggregate messages on each target cell.

        A target cell receives messages from several source cells.
        This function aggregates these messages into a single output
        feature per target cell.

        🟧 This function corresponds to the within-neighborhood aggregation
        defined in [H23]_ and [PSHM23]_.

        Parameters
        ----------
        x_message : Tensor, shape=[..., n_messages, out_channels]
            Features associated with each message.
            One message is sent from a source cell to a target cell.

        Returns
        -------
        _ : Tensor, shape=[..., n_target_cells, out_channels]
            Output features on target cells.
            Each target cell aggregates messages from several source cells.
            Assumes that all target cells have the same rank s.
        """
        aggr = scatter(self.aggr_func)
        return aggr(x_message, self.target_index_i, 0)

    def forward(self, x_source, neighborhood, x_target=None):
        r"""Forward pass.

        This implements message passing for a given neighborhood:

        - from source cells with input features `x_source`,
        - via `neighborhood` defining where messages can pass,
        - to target cells with input features `x_target`.

        In practice, this will update the features on the target cells.

        If not provided, x_target is assumed to be x_source,
        i.e. source cells send messages to themselves.

        The message passing is decomposed into two steps:

        1. 🟥 Message: A message :math:`m_{y \rightarrow x}^{\left(r \rightarrow s\right)}`
        travels from a source cell :math:`y` of rank r to a target cell :math:`x` of rank s
        through a neighborhood of :math:`x`, denoted :math:`\mathcal{N} (x)`,
        via the message function :math:`M_\mathcal{N}`:

        .. math::
            m_{y \rightarrow x}^{\left(r \rightarrow s\right)}
                = M_{\mathcal{N}}\left(\mathbf{h}_x^{(s)}, \mathbf{h}_y^{(r)}, \Theta \right),

        where:

        - :math:`\mathbf{h}_y^{(r)}` are input features on the source cells, called `x_source`,
        - :math:`\mathbf{h}_x^{(s)}` are input features on the target cells, called `x_target`,
        - :math:`\Theta` are optional parameters (weights) of the message passing function.

        Optionally, attention can be applied to the message, such that:

        .. math::
            m_{y \rightarrow x}^{\left(r \rightarrow s\right)}
                \leftarrow att(\mathbf{h}_y^{(r)}, \mathbf{h}_x^{(s)}) . m_{y \rightarrow x}^{\left(r \rightarrow s\right)}

        2. 🟧 Aggregation: Messages are aggregated across source cells :math:`y` belonging to the
        neighborhood :math:`\mathcal{N}(x)`:

        .. math::
            m_x^{\left(r \rightarrow s\right)}
                = \text{AGG}_{y \in \mathcal{N}(x)} m_{y \rightarrow x}^{\left(r\rightarrow s\right)},

        resulting in the within-neighborhood aggregated message :math:`m_x^{\left(r \rightarrow s\right)}`.

        Details in [H23]_ and [PSHM23]_ "The Steps of Message Passing".

        Parameters
        ----------
        x_source : Tensor, shape=[..., n_source_cells, in_channels]
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        neighborhood : torch.sparse, shape=[n_target_cells, n_source_cells]
            Neighborhood matrix.
        x_target : Tensor, shape=[..., n_target_cells, in_channels]
            Input features on target cells.
            Assumes that all target cells have the same rank s.
            Optional. If not provided, x_target is assumed to be x_source,
            i.e. source cells send messages to themselves.

        Returns
        -------
        _ : Tensor, shape=[..., n_target_cells, out_channels]
            Output features on target cells.
            Assumes that all target cells have the same rank s.
        """
        neighborhood = neighborhood.coalesce()
        self.target_index_i, self.source_index_j = neighborhood.indices()
        neighborhood_values = neighborhood.values()

        x_message = self.message(x_source=x_source, x_target=x_target)
        x_message = x_message.index_select(-2, self.source_index_j)

        if self.att:
            attention_values = self.attention(x_source=x_source, x_target=x_target)
            neighborhood_values = torch.multiply(neighborhood_values, attention_values)

        x_message = neighborhood_values.view(-1, 1) * x_message
        return self.aggregate(x_message)


class Conv(MessagePassing):
    """Message passing: steps 1, 2, and 3.

    Builds the message passing route given by one neighborhood matrix.
    Includes an option for a x-specific update function.

    Parameters
    ----------
    in_channels : int
        Dimension of input features.
    out_channels : int
        Dimension of output features.
    aggr_norm : bool
        Whether to normalize the aggregated message by the neighborhood size.
    update_func : string
        Update method to apply to message.
    att : bool
        Whether to use attention.
        Optional, default: False.
    initialization : string
        Initialization method.
    with_linear_transform: bool
        Whether to apply a learnable linear transform.
        NB: if false in_channels has to be equal to out_channels
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        aggr_norm=False,
        update_func=None,
        att=False,
        initialization="xavier_uniform",
        with_linear_transform=True,
    ):
        super().__init__(
            att=att,
            initialization=initialization,
        )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.aggr_norm = aggr_norm
        self.update_func = update_func

        self.weight = (
            Parameter(torch.Tensor(self.in_channels, self.out_channels))
            if with_linear_transform
            else None
        )

        if not with_linear_transform and in_channels != out_channels:
            raise ValueError(
                "With `with_linear_transform=False`, in_channels has to be equal to out_channels"
            )
        if self.att:
            self.att_weight = Parameter(
                torch.Tensor(
                    2 * self.in_channels,
                )
            )

        self.reset_parameters()

    def update(self, x_message_on_target, x_target=None):
        """Update embeddings on each cell (step 4).

        Parameters
        ----------
        x_message_on_target : torch.Tensor, shape=[n_target_cells, out_channels]
            Output features on target cells.

        Returns
        -------
        _ : torch.Tensor, shape=[n_target_cells, out_channels]
            Updated output features on target cells.
        """
        if self.update_func == "sigmoid":
            return torch.sigmoid(x_message_on_target)
        if self.update_func == "relu":
            return torch.nn.functional.relu(x_message_on_target)

    def forward(self, x_source, neighborhood, x_target=None):
        """Forward pass.

        This implements message passing:
        - from source cells with input features `x_source`,
        - via `neighborhood` defining where messages can pass,
        - to target cells with input features `x_target`.

        In practice, this will update the features on the target cells.

        If not provided, x_target is assumed to be x_source,
        i.e. source cells send messages to themselves.

        Parameters
        ----------
        x_source : Tensor, shape=[..., n_source_cells, in_channels]
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        neighborhood : torch.sparse, shape=[n_target_cells, n_source_cells]
            Neighborhood matrix.
        x_target : Tensor, shape=[..., n_target_cells, in_channels]
            Input features on target cells.
            Assumes that all target cells have the same rank s.
            Optional. If not provided, x_target is assumed to be x_source,
            i.e. source cells send messages to themselves.

        Returns
        -------
        _ : Tensor, shape=[..., n_target_cells, out_channels]
            Output features on target cells.
            Assumes that all target cells have the same rank s.
        """
        if self.att:
            neighborhood = neighborhood.coalesce()
            self.target_index_i, self.source_index_j = neighborhood.indices()
            attention_values = self.attention(x_source, x_target)
            neighborhood = torch.sparse_coo_tensor(
                indices=neighborhood.indices(),
                values=attention_values * neighborhood.values(),
                size=neighborhood.shape,
            )
        if self.weight is not None:
            x_message = torch.mm(x_source, self.weight)
        else:
            x_message = x_source
        x_message_on_target = torch.mm(neighborhood, x_message)

        if self.aggr_norm:
            neighborhood_size = torch.sum(neighborhood.to_dense(), dim=1)
            x_message_on_target = torch.einsum(
                "i,ij->ij", 1 / neighborhood_size, x_message_on_target
            )

        if self.update_func is None:
            return x_message_on_target

        return self.update(x_message_on_target, x_target)


class HNHNLayer(torch.nn.Module):
    """Layer of a Hypergraph Networks with Hyperedge Neurons (HNHN).

    Implementation of a simplified version of the HNHN layer proposed in [DSB20]_.

    This layer is composed of two convolutional layers:
    1. A convolutional layer sending messages from edges to nodes.
    2. A convolutional layer sending messages from nodes to edges.
    The incidence matrices can be normalized using the node and edge cardinality.
    Two hyperparameters, alpha and beta, control the normalization strength.
    The convolutional layers support the training of a bias term.

    Notes
    -----
    This is the architecture proposed for node classification.

    References
    ----------
    .. [DSB20] Dong, Sawin, Bengio.
        HNHN: Hypergraph networks with hyperedge neurons.
        Graph Representation Learning and Beyond Workshop at ICML 2020
        https://grlplus.github.io/papers/40.pdf


    Parameters
    ----------
    channels_node : int
        Dimension of node features.
    channels_edge : int
        Dimension of edge features.
    incidence_1 : torch.sparse
        Incidence matrix mapping edges to nodes (B_1).
        shape=[n_nodes, n_edges]
    use_bias : bool
        Flag controlling whether to use a bias term in the convolution.
    use_normalized_incidence : bool
        Flag controlling whether to normalize the incidence matrices.
    alpha : float
        Scalar controlling the importance of edge cardinality.
    beta : float
        Scalar controlling the importance of node cardinality.
    bias_gain : float
        Gain for the bias initialization.
    bias_init : string ["xavier_uniform"|"xavier_normal"]
        Controls the bias initialization method.
    """

    def __init__(
        self,
        channels_node,
        channels_edge,
        incidence_1,
        use_bias=True,
        use_normalized_incidence=True,
        alpha=-1.5,
        beta=-0.5,
        bias_gain=1.414,
        bias_init="xavier_uniform",
    ):
        super().__init__()
        self.use_bias = use_bias
        self.bias_init = bias_init
        self.bias_gain = bias_gain
        self.use_normalized_incidence = use_normalized_incidence
        self.incidence_1 = incidence_1
        self.incidence_1_transpose = incidence_1.transpose(1, 0)
        self.channels_edge = channels_edge
        self.channels_node = channels_node
        self.conv_1_to_0 = Conv(
            in_channels=channels_edge,
            out_channels=channels_node,
            aggr_norm=False,
            update_func=None,
        )
        self.conv_0_to_1 = Conv(
            in_channels=channels_node,
            out_channels=channels_edge,
            aggr_norm=False,
            update_func=None,
        )
        if self.use_bias:
            self.bias_1_to_0 = Parameter(torch.Tensor(1, channels_node))
            self.bias_0_to_1 = Parameter(torch.Tensor(1, channels_edge))
            self.init_biases()
        if self.use_normalized_incidence:
            self.alpha = alpha
            self.beta = beta
            self.n_nodes, self.n_edges = self.incidence_1.shape
            self.compute_normalization_matrices()
            self.normalize_incidence_matrices()

    def compute_normalization_matrices(self):
        """Compute the normalization matrices for the incidence matrices."""
        B1 = self.incidence_1.to_dense()
        edge_cardinality = (B1.sum(0)) ** self.alpha
        node_cardinality = (B1.sum(1)) ** self.beta

        # Compute D0_left_alpha_inverse
        self.D0_left_alpha_inverse = torch.zeros(self.n_nodes, self.n_nodes)
        for i_node in range(self.n_nodes):
            self.D0_left_alpha_inverse[i_node, i_node] = 1 / (
                edge_cardinality[B1[i_node, :].bool()].sum()
            )

        # Compute D1_left_beta_inverse
        self.D1_left_beta_inverse = torch.zeros(self.n_edges, self.n_edges)
        for i_edge in range(self.n_edges):
            self.D1_left_beta_inverse[i_edge, i_edge] = 1 / (
                node_cardinality[B1[:, i_edge].bool()].sum()
            )

        # Compute D1_right_alpha
        self.D1_right_alpha = torch.diag(edge_cardinality)

        # Compute D0_right_beta
        self.D0_right_beta = torch.diag(node_cardinality)
        return

    def normalize_incidence_matrices(self):
        """Normalize the incidence matrices."""
        self.incidence_1 = (
            self.D0_left_alpha_inverse
            @ self.incidence_1.to_dense()
            @ self.D1_right_alpha
        ).to_sparse()
        self.incidence_1_transpose = (
            self.D1_left_beta_inverse
            @ self.incidence_1_transpose.to_dense()
            @ self.D0_right_beta
        ).to_sparse()
        return

    def init_biases(self):
        """Initialize the bias."""
        for bias in [self.bias_0_to_1, self.bias_1_to_0]:
            if self.bias_init == "xavier_uniform":
                torch.nn.init.xavier_uniform_(bias, gain=self.bias_gain)
            elif self.bias_init == "xavier_normal":
                torch.nn.init.xavier_normal_(bias, gain=self.bias_gain)

    def reset_parameters(self):
        """Reset learnable parameters."""
        self.conv_1_to_0.reset_parameters()
        self.conv_0_to_1.reset_parameters()
        if self.use_bias:
            self.init_biases()

    def forward(self, x_0, x_1):
        r"""Forward computation.

        The forward pass was initially proposed in [DSB20]_.
        Its equations are given in [TNN23]_ and graphically illustrated in [PSHM23]_.

        The equations of one layer of this neural network are given by:

        .. math::
            \begin{align*}
            &🟥 \quad m_{y \rightarrow x}^{(0 \rightarrow 1)} = \sigma((B_1^T \cdot W^{(0)})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0)} + b^{t,(0)})\\
            &🟥 \quad m_{y \rightarrow x}^{(1 \rightarrow 0)} = \sigma((B_1 \cdot W^{(1)})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1)} + b^{t,(1)})\\
            &🟧 \quad m_x^{(0 \rightarrow 1)} = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}\\
            &🟧 \quad m_x^{(1 \rightarrow 0)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}\\
            &🟩 \quad m_x^{(0)} = m_x^{(1 \rightarrow 0)}\\
            &🟩 \quad m_x^{(1)} = m_x^{(0 \rightarrow 1)}\\
            &🟦 \quad h_x^{t+1,(0)} = m_x^{(0)}\\
            &🟦 \quad h_x^{t+1,(1)} = m_x^{(1)}
            \end{align*}

        References
        ----------
        .. [DSB20] Dong, Sawin, Bengio.
            HNHN: Hypergraph networks with hyperedge neurons.
            Graph Representation Learning and Beyond Workshop at ICML 2020
            https://grlplus.github.io/papers/40.pdf
        .. [TNN23] Equations of Topological Neural Networks.
            https://github.com/awesome-tnns/awesome-tnns/
        .. [PSHM23] Papillon, Sanborn, Hajij, Miolane.
            Architectures of Topological Deep Learning: A Survey on Topological Neural Networks.
            (2023) https://arxiv.org/abs/2304.10031.

        Parameters
        ----------
        x_0 : torch.Tensor, shape=[n_nodes, channels_node]
            Input features on the hypernodes
        x_1 : torch.Tensor, shape=[n_edges, channels_edge]
            Input features on the hyperedges

        Returns
        -------
        x_0 : torch.Tensor, shape=[n_nodes, channels_node]
            Output features on the hypernodes
        x_1 : torch.Tensor, shape=[n_edges, channels_edge]
            Output features on the hyperedges
        """
        # Move incidence matrices to device
        self.incidence_1 = self.incidence_1.to(x_0.device)
        self.incidence_1_transpose = self.incidence_1_transpose.to(x_0.device)
        # Compute output hyperedge features
        x_1_tp1 = self.conv_0_to_1(x_0, self.incidence_1_transpose)  # nodes to edges
        if self.use_bias:
            x_1_tp1 += self.bias_0_to_1
        # Compute output hypernode features
        x_0_tp1 = self.conv_1_to_0(x_1, self.incidence_1)  # edges to nodes
        if self.use_bias:
            x_0_tp1 += self.bias_1_to_0
        return torch.sigmoid(x_0_tp1), torch.sigmoid(x_1_tp1)
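The two Python loops in `compute_normalization_matrices` compute, for each node, the sum of $|e|^{\alpha}$ over the hyperedges containing it and, for each hyperedge, the sum of $d_v^{\beta}$ over its nodes. The vectorized NumPy sketch below is our own rewrite of the same normalization (not part of TopoModelX). It makes the effect visible: every row of the normalized matrices sums to 1, so each cell receives a weighted average of its neighbors rather than a plain sum. Like the loop version, it assumes there is no isolated node and no empty hyperedge, otherwise it divides by zero.

```python
import numpy as np


def hnhn_normalize(B1, alpha=-1.5, beta=-0.5):
    """Return the normalized B1 (edges -> nodes) and B1^T (nodes -> edges).

    Vectorized equivalent of D0_left_alpha_inverse @ B1 @ D1_right_alpha and
    D1_left_beta_inverse @ B1^T @ D0_right_beta for a binary incidence matrix.
    """
    B1 = np.asarray(B1, dtype=float)
    edge_weight = B1.sum(axis=0) ** alpha  # |e|^alpha for each hyperedge
    node_weight = B1.sum(axis=1) ** beta  # d_v^beta for each node
    # Node v: weight |e|^alpha, divided by its sum over the hyperedges containing v.
    B1_norm = B1 * edge_weight[None, :] / (B1 @ edge_weight)[:, None]
    # Hyperedge e: weight d_v^beta, divided by its sum over the nodes of e.
    B1T_norm = B1.T * node_weight[None, :] / (B1.T @ node_weight)[:, None]
    return B1_norm, B1T_norm


B1 = np.array([[1, 0, 1], [1, 1, 0], [0, 1, 0], [0, 1, 1]])  # toy incidence matrix
B1_norm, B1T_norm = hnhn_normalize(B1)
print(np.allclose(B1_norm.sum(axis=1), 1.0))  # True
print(np.allclose(B1T_norm.sum(axis=1), 1.0))  # True
```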


### Preprocessing, preliminary data analysis

#### Import dataset

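First, we lift our graph dataset into the hypergraph domain: `toponetx` builds a simplicial complex from the Karate Club graph, which we then convert into a hypergraph.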

In [None]:
import toponetx.datasets.graph as graph


dataset_sim = graph.karate_club(complex_type="simplicial")  # import through toponetx
dataset_hyp = dataset_sim.to_hypergraph()  # convert simplicial complex to hypergraph
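A common way to lift a graph to a higher-order domain is the clique lifting, in which every clique of the graph becomes a cell. We assume (this sketch does not reproduce the `toponetx` source) that the simplicial Karate Club is built as such a clique complex, and that `to_hypergraph()` then turns each simplex into a hyperedge. The pure-Python sketch below applies the idea to a hypothetical five-node graph, restricted to cliques of at most three nodes for brevity.

```python
from itertools import combinations

# Hypothetical toy graph: two triangles sharing the edge (1, 2), plus a pendant edge.
edges = [(0, 1), (0, 2), (1, 2), (1, 3), (2, 3), (3, 4)]
adjacency = {}
for u, v in edges:
    adjacency.setdefault(u, set()).add(v)
    adjacency.setdefault(v, set()).add(u)


def clique_lift(adjacency, max_size=3):
    """Return every clique with 2..max_size nodes as a hyperedge (frozenset)."""
    nodes = sorted(adjacency)
    hyperedges = []
    for k in range(2, max_size + 1):
        for subset in combinations(nodes, k):
            # A subset is a clique if every pair of its nodes is adjacent.
            if all(b in adjacency[a] for a, b in combinations(subset, 2)):
                hyperedges.append(frozenset(subset))
    return hyperedges


hyperedges = clique_lift(adjacency)
print(sorted(sorted(e) for e in hyperedges))

# Keeping only the size-2 hyperedges recovers the original edges.
size_2 = {e for e in hyperedges if len(e) == 2}
print(size_2 == {frozenset(e) for e in edges})  # True
```

Keeping only the size-2 hyperedges is the kind of filtering `filter_hypergraph(..., min_size=2, max_size=2)` performs below.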


#### Visualize our new dataset

In [None]:
def visualize_hypergraph(hypergraph: hypernetx.classes.hypergraph.Hypergraph) -> None:
    """Visualize a hypergraph.

    Args:
        hypergraph (hypernetx.classes.hypergraph.Hypergraph): A hypergraph.
    """
    plt.figure(figsize=(7, 7))
    plt.xticks([])
    plt.yticks([])
    # TODO: draw the hypergraph using hypernetx
    plt.show()


def filter_hypergraph(
    hypergraph: hypernetx.classes.hypergraph.Hypergraph, min_size: int, max_size: int
) -> hypernetx.classes.hypergraph.Hypergraph:
    """Return a hypergraph keeping only the hyperedges with between min_size and max_size nodes (inclusive).

    Args:
        hypergraph (hypernetx.classes.hypergraph.Hypergraph): A hypergraph.
        min_size (int): The minimum size of the hyperedges to keep.
        max_size (int): The maximum size of the hyperedges to keep.

    Returns:
        hypernetx.classes.hypergraph.Hypergraph: A hypergraph.
    """
    return hypernetx.classes.hypergraph.Hypergraph(
        {
            hyperedge: list(hypergraph.edges[hyperedge])
            for hyperedge in list(hypergraph.edges)
            if (len(list(hypergraph.edges[hyperedge])) <= max_size)
            and (min_size <= len(list(hypergraph.edges[hyperedge])))
        }
    )

In [None]:
# TODO: visualize the hypergraph dataset_hyp using the function above


In [None]:
reduced_hypergraph = filter_hypergraph(dataset_hyp, min_size=2, max_size=2)
# TODO: visualize the hypergraph reduced_hypergraph using the function above


#### Define neighborhood structures

From the lifted dataset, we now retrieve the neighborhood structure used in the message passing. In our case, we only need the incidence (or boundary) matrix $B_1$ (also written $B_{0,1}$), of shape $(n_\text{nodes} \times n_\text{edges})$. Since it is computed from the simplicial complex `dataset_sim`, its columns are the edges of the original graph. We then convert it to a sparse torch tensor that will be fed into our model.

In [None]:
incidence_1 = dataset_sim.incidence_matrix(rank=1, signed=False)
incidence_1 = torch.from_numpy(incidence_1.todense()).to_sparse()
print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
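For a graph, where every hyperedge has exactly two nodes, the unsigned $B_1$ has two nonzeros per column; its row sums are the node degrees (the $d_v$ used in the HNHN normalization), and $B_1 B_1^T = D + A$, with $D$ the degree matrix and $A$ the adjacency matrix. A small NumPy sketch on a hypothetical four-node graph (not the Karate Club):

```python
import numpy as np

# Hypothetical toy graph with 4 nodes.
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
n_nodes, n_edges = 4, len(edges)

# Unsigned incidence matrix B1 (n_nodes x n_edges), as with signed=False.
B1 = np.zeros((n_nodes, n_edges), dtype=int)
for j, (u, v) in enumerate(edges):
    B1[u, j] = 1
    B1[v, j] = 1

degrees = B1.sum(axis=1)  # number of edges incident to each node
edge_sizes = B1.sum(axis=0)  # always 2 for graph edges
adjacency = B1 @ B1.T - np.diag(degrees)  # off-diagonal part of B1 B1^T

print(degrees)  # [2 2 3 1]
print(edge_sizes)  # [2 2 2 2]
print(adjacency)
```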


#### Import the features and labels

Our features will be:
- node features $X_0$ of shape: $n_{nodes} \times channels_{node}$
- edge features $X_1$ of shape: $n_{edges} \times channels_{edge}$

In our case, we have $channels_{node} = channels_{edge} = 2$. The features are derived from the eigenvectors of the Hodge Laplacians: each cell is described by its entries in two eigenvectors. Next, we retrieve the node features, the edge features, and the node labels $y$, which we one-hot encode.

In [None]:
# Node features
x_0 = []
for _, v in dataset_sim.get_simplex_attributes("node_feat").items():
    x_0.append(v)
x_0 = torch.tensor(np.stack(x_0), dtype=torch.float32).to(device)  # float32 to match the model weights
n_nodes, channels_node = x_0.shape
print(f"There are {n_nodes} nodes with features of dimension {channels_node}.")


In [None]:
# Edge features
x_1 = []
for _, v in dataset_sim.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1), dtype=torch.float32).to(device)  # float32 to match the model weights
n_edges, channels_edge = x_1.shape
print(f"There are {n_edges} edges with features of dimension {channels_edge}.")
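To illustrate how eigenvectors of Hodge Laplacians can serve as cell features, here is a NumPy sketch on a hypothetical 4-cycle. It has no triangles, so there are no 2-simplices and the $B_2 B_2^T$ term of $L_1$ vanishes. Which eigenvectors `toponetx` actually selects may differ; here we assume the two with the smallest eigenvalues.

```python
import numpy as np

# Hypothetical 4-cycle: no triangles, hence no 2-simplices.
edges = [(0, 1), (1, 2), (2, 3), (0, 3)]
n_nodes = 4

# Signed incidence matrix: -1 at the tail and +1 at the head of each oriented edge.
B1 = np.zeros((n_nodes, len(edges)))
for j, (u, v) in enumerate(edges):
    B1[u, j] = -1.0
    B1[v, j] = 1.0

L0 = B1 @ B1.T  # Hodge Laplacian on nodes = graph Laplacian D - A
L1 = B1.T @ B1  # Hodge Laplacian on edges (B2 B2^T is zero here)


def eigenvector_features(L, k=2):
    """Use the k eigenvectors with the smallest eigenvalues as k features per cell."""
    _, eigenvectors = np.linalg.eigh(L)  # eigh returns eigenvalues in ascending order
    return eigenvectors[:, :k]


x_0 = eigenvector_features(L0)  # shape (n_nodes, 2)
x_1 = eigenvector_features(L1)  # shape (n_edges, 2)
print(x_0.shape, x_1.shape)  # (4, 2) (4, 2)
```

Eigenvectors are only defined up to sign (and up to rotation within a repeated eigenvalue), so such features are not unique; this is a known caveat of spectral features.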


In [None]:
# Node labels
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)  # one binary label per node (34 nodes)
n_classes = len(np.unique(y))
y_1h = np.eye(n_classes)[y].astype(int)  # 1-hot representation
print(f"There are {y_1h.shape[0]} labels, one for each node.")


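We now split the dataset into a training set (80%) and a test set (20%). The two classes are balanced (17 nodes each), but with only 34 nodes a random split could still skew the class proportions, so we stratify the split on the labels.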

In [None]:
test_size = 0.2  # 20%
y_train, y_test, ind_train, ind_test = train_test_split(
    y_1h, np.arange(n_nodes), test_size=test_size, random_state=seed, stratify=y
)
y_train = torch.tensor(y_train, dtype=torch.float32).to(device)
y_test = torch.tensor(y_test, dtype=torch.float32).to(device)
y = torch.tensor(y, dtype=torch.int32).to(device)
print(
    f"Percentage of class-1 samples in the training set: {round(100 * torch.sum(y_train[:, 1]).item() / y_train.shape[0], 3)}%"
)
print(
    f"Percentage of class-1 samples in the test set: {round(100 * torch.sum(y_test[:, 1]).item() / y_test.shape[0], 3)}%"
)
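A quick NumPy check of the label vector confirms the class counts and shows which column of the one-hot matrix corresponds to class 1 (column `k` of `np.eye(n_classes)[y]` flags class `k`):

```python
import numpy as np

# The 34 node labels used above.
y = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
              1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
n_classes = len(np.unique(y))
y_1h = np.eye(n_classes)[y].astype(int)

print(np.bincount(y))  # [17 17]: the two classes are balanced
# Column 1 of the one-hot matrix counts the class-1 nodes.
print(y_1h[:, 1].sum())  # 17
print((y_1h.argmax(axis=1) == y).all())  # True
```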


### Create the Neural Network

Using the HNHNLayer class, we create a neural network for node classification. We will call it HNHNNetwork.

In [None]:
class HNHNNetwork(torch.nn.Module):
    """Hypergraph Networks with Hyperedge Neurons. Implementation for multiclass node classification."""

    def __init__(
        self,
        channels_node: int,
        channels_edge: int,
        incidence_1: torch.Tensor,
        n_classes: int,
        n_layers: int = 2,
    ) -> None:
        """Initialize the model.

        Args:
            channels_node (int): Dimension of node features.
            channels_edge (int): Dimension of edge features.
            incidence_1 (torch.Tensor): Sparse incidence matrix mapping edges to nodes (B_1).
                shape=[n_nodes, n_edges]
            n_classes (int): Number of classes.
            n_layers (int, optional): Number of HNHN message passing layers. Defaults to 2.
        """
        super().__init__()
        self.layers = torch.nn.ModuleList(
            [
                HNHNLayer(
                    channels_node=channels_node,
                    channels_edge=channels_edge,
                    incidence_1=incidence_1,
                )
                for _ in range(n_layers)
            ]
        )
        self.linear = torch.nn.Linear(
            channels_node, n_classes
        )  # final prediction layer

    def forward(
        self, x_0: torch.Tensor, x_1: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward computation.

        Args:
            x_0 (torch.Tensor): shape = [n_nodes, channels_node]
                    Hypernode features.
            x_1 (torch.Tensor): shape = [n_nodes, channels_edge]
                Hyperedge features.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The predicted node logits with shape = [n_nodes, n_classes],
                and The predicted node class with shape = [n_nodes].
        """
        for layer in self.layers:
            x_0, x_1 = layer(x_0, x_1)
        logits = self.linear(x_0)
        classes = torch.softmax(logits, -1).argmax(-1)
        return logits, classes
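To make the message passing inside each `HNHNLayer` more concrete, here is a simplified, hypothetical version of one HNHN step in plain PyTorch. It roughly follows the two-stage update of the HNHN paper: hyperedge features are computed from the nodes they contain, $X_1 = \sigma(B_1^\top X_0 W_0 + b_0)$, then node features are updated from their hyperedges, $X_0' = \sigma(B_1 X_1 W_1 + b_1)$. It omits the degree-based normalization, recomputes the hyperedge features from the nodes, and is not TopoModelX's `HNHNLayer` implementation; `SimpleHNHNStep` is an illustrative name.

```python
import torch


class SimpleHNHNStep(torch.nn.Module):
    """Hypothetical, simplified HNHN step: no incidence normalization, ReLU activations."""

    def __init__(self, channels_node: int, channels_edge: int) -> None:
        super().__init__()
        self.node_to_edge = torch.nn.Linear(channels_node, channels_edge)  # W_0, b_0
        self.edge_to_node = torch.nn.Linear(channels_edge, channels_node)  # W_1, b_1

    def forward(self, x_0: torch.Tensor, incidence_1: torch.Tensor):
        # nodes -> hyperedges: sum the features of the nodes in each hyperedge, then W_0, b_0
        x_1 = torch.relu(self.node_to_edge(incidence_1.T @ x_0))
        # hyperedges -> nodes: sum the features of the hyperedges of each node, then W_1, b_1
        x_0 = torch.relu(self.edge_to_node(incidence_1 @ x_1))
        return x_0, x_1


# toy hypergraph: 4 nodes, 2 hyperedges {0, 1, 3} and {1, 2, 3}
incidence_1 = torch.tensor(
    [[1.0, 0.0], [1.0, 1.0], [0.0, 1.0], [1.0, 1.0]]
)  # shape = [n_nodes, n_edges]
x_0 = torch.randn(4, 3)
new_x_0, new_x_1 = SimpleHNHNStep(channels_node=3, channels_edge=5)(x_0, incidence_1)
print(new_x_0.shape, new_x_1.shape)  # torch.Size([4, 3]) torch.Size([2, 5])
```

Stacking `n_layers` such steps and adding a final linear layer on the node features gives the same overall shape as `HNHNNetwork`.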


### Train the Neural Network

We initialize `HNHNNetwork` with our neighborhood structure (the incidence matrix `incidence_1`) and use the same optimizer (Adam) and loss (cross-entropy) as before.

In [None]:
n_layers = 2
learning_rate = 0.5 * 1e-2

model = HNHNNetwork(
    channels_node=channels_node,
    channels_edge=channels_edge,
    incidence_1=incidence_1,
    n_classes=n_classes,
    n_layers=n_layers,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()


Print the number of parameters of our model

In [None]:
print(f"Number of parameters HNHNNetwork: {get_nb_parameters(model)}")


Far fewer parameters!

Next, we train the model for 1000 epochs.
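A parameter count like the one printed above can be obtained by summing `numel()` over the trainable parameters. This is a sketch with a hypothetical helper name; the notebook's `get_nb_parameters` (defined earlier) may be implemented differently.

```python
import torch


def count_trainable_parameters(model: torch.nn.Module) -> int:
    # numel() is the number of scalar entries in each weight or bias tensor
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


layer = torch.nn.Linear(3, 2)  # 3 * 2 weights + 2 biases
print(count_trainable_parameters(layer))  # 8
```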

In [None]:
def train_HNHN_one_epoch(
    model: torch.nn.Module,
    data: Tuple[
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ],
) -> Tuple[float, float, float, float, torch.Tensor]:
    """Train the model for 1 epoch.

    Args:
        model (torch.nn.Module): The model to train.
        data (Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]): The graph data (x_0, x_1, y_train, y_test, ind_train, ind_test).

    Returns:
        Tuple[float, float, float, float, torch.Tensor]: The losses on train and test, the accuracies on train and test,
            and the input node features x_0 (HNHNNetwork.forward does not expose intermediate embeddings).
    """
    model.train()
    optimizer.zero_grad()  # clear gradients.
    x_0, x_1, y_train, y_test, ind_train, ind_test = data
    logits, classes = model(x_0, x_1)  # perform a single forward pass.
    loss_train = criterion(
        logits[ind_train], y_train
    )  # compute the loss on the training nodes.
    model.eval()
    with torch.no_grad():  # no gradients for test pass.
        loss_test = criterion(logits[ind_test], y_test)
        acc_train = torch.sum(classes[ind_train] == y_train.argmax(dim=1)).item() / len(
            ind_train
        )
        acc_test = torch.sum(classes[ind_test] == y_test.argmax(dim=1)).item() / len(
            ind_test
        )
    model.train()
    loss_train.backward()  # derive gradients.
    optimizer.step()  # update parameters based on gradients.
    return loss_train.item(), loss_test.item(), acc_train, acc_test, x_0
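`train_HNHN_one_epoch` feeds one-hot float targets (`y_train`) to `CrossEntropyLoss`. Since PyTorch 1.10, this loss also accepts class probabilities as targets, and for one-hot targets it gives the same value as integer class indices. A quick check on random logits:

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 2)
labels = torch.tensor([1, 0, 1, 1])
one_hot = torch.nn.functional.one_hot(labels, num_classes=2).float()

criterion = torch.nn.CrossEntropyLoss()
loss_from_indices = criterion(logits, labels)  # class-index targets
loss_from_one_hot = criterion(logits, one_hot)  # class-probability targets
print(torch.allclose(loss_from_indices, loss_from_one_hot))  # True
```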


In [None]:
nb_epochs = 1000  # TODO: change the number of epochs?
data = [
    x_0,
    x_1,
    y_train,
    y_test,
    ind_train,
    ind_test,
]
losses_train, losses_test, accuracy_train, accuracy_test, embeddings = train(
    model, data, nb_epochs, train_HNHN_one_epoch
)


Finally, we plot the training results

In [None]:
# TODO: plot the learning curves
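One possible way to fill in the TODO above is sketched below. `plot_learning_curves_clf` is a hypothetical helper; it only assumes that the lists returned by `train` hold one loss and one accuracy (a fraction in [0, 1]) per epoch.

```python
from typing import List

import matplotlib.pyplot as plt
import numpy as np


def plot_learning_curves_clf(
    losses_train: List[float],
    losses_test: List[float],
    accuracy_train: List[float],
    accuracy_test: List[float],
) -> plt.Figure:
    """Hypothetical helper: loss and accuracy (in %) per epoch for train and test."""
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(16, 6))
    ax_loss.plot(losses_train, label="train")
    ax_loss.plot(losses_test, label="test")
    ax_loss.set_xlabel("Epoch")
    ax_loss.set_ylabel("Loss")
    ax_loss.legend()
    # accuracies are fractions in [0, 1]; show them as percentages
    ax_acc.plot(100 * np.array(accuracy_train), label="train")
    ax_acc.plot(100 * np.array(accuracy_test), label="test")
    ax_acc.set_xlabel("Epoch")
    ax_acc.set_ylabel("Accuracy (%)")
    ax_acc.legend()
    fig.suptitle(
        f"Final train accuracy: {100 * accuracy_train[-1]:.1f}%. "
        f"Final test accuracy: {100 * accuracy_test[-1]:.1f}%"
    )
    fig.tight_layout()
    return fig
```

In the notebook, it would be called as `plot_learning_curves_clf(losses_train, losses_test, accuracy_train, accuracy_test)` followed by `plt.show()`.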


## Healthcare application

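Let's apply our models to QM9, a dataset of small organic molecules (at most 9 heavy atoms: C, N, O, F) annotated with quantum-chemical properties such as the dipole moment `mu` and the isotropic polarizability `alpha`.

We first define some utility functions.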

In [None]:
def mols_to_nx(mols: List[Chem.Mol]) -> List[nx.Graph]:
    """Converts a list of molecules to a list of networkx graphs.

    Args:
        mols (List[Chem.Mol]): list of molecules

    Returns:
        List[nx.Graph]: list of networkx graphs
    """
    nx_graphs = []
    for mol in mols:
        G = nx.Graph()

        for atom in mol.GetAtoms():
            G.add_node(atom.GetIdx(), label=atom.GetSymbol())
            # Other potential features:
            # atomic_num=atom.GetAtomicNum()
            # formal_charge=atom.GetFormalCharge()
            # chiral_tag=atom.GetChiralTag()
            # hybridization=atom.GetHybridization()
            # num_explicit_hs=atom.GetNumExplicitHs()
            # is_aromatic=atom.GetIsAromatic()

        for bond in mol.GetBonds():
            G.add_edge(
                bond.GetBeginAtomIdx(),
                bond.GetEndAtomIdx(),
                label=int(bond.GetBondTypeAsDouble()),
            )
            # Other potential feature:
            # bond_type=bond.GetBondType()

        nx_graphs.append(G)
    return nx_graphs


def pad_adjs(ori_adj: np.ndarray, node_number: int) -> np.ndarray:
    """Create padded adjacency matrices

    Args:
        ori_adj (np.ndarray): original adjacency matrix
        node_number (int): number of desired nodes

    Raises:
        ValueError: if the original adjacency matrix is larger than the desired number of nodes (we can't pad)

    Returns:
        np.ndarray: Padded adjacency matrix
    """
    if not (ori_adj.size):  # empty
        return np.zeros((node_number, node_number), dtype=np.float32)
    a = ori_adj
    ori_len = a.shape[-1]
    if ori_len == node_number:  # same shape
        return a
    if ori_len > node_number:
        raise ValueError(
            f"Original number of nodes {ori_len} is greater than the desired number of nodes after padding {node_number}"
        )
    # Pad
    a = np.concatenate([a, np.zeros([ori_len, node_number - ori_len])], axis=-1)
    a = np.concatenate([a, np.zeros([node_number - ori_len, node_number])], axis=0)
    return a


def graphs_to_tensor(graph_list: List[nx.Graph], max_node_num: int) -> torch.Tensor:
    """Convert a list of graphs to a tensor

    Args:
        graph_list (List[nx.Graph]): List of graphs to convert to adjacency matrices tensors
        max_node_num (int): max number of nodes in all the graphs

    Returns:
        torch.Tensor: Tensor of adjacency matrices
    """
    adjs_list = []

    for g in graph_list:
        assert isinstance(g, nx.Graph)
        node_list = list(g.nodes)  # fixes the row/column order of the adjacency matrix

        # convert to adj matrix
        adj = nx.to_numpy_array(g, nodelist=node_list)
        padded_adj = pad_adjs(adj, node_number=max_node_num)  # pad to max node number
        adjs_list.append(padded_adj)

    del graph_list

    adjs_np = np.asarray(adjs_list)  # concatenate the arrays
    del adjs_list

    adjs_tensor = torch.tensor(adjs_np, dtype=torch.float32)  # convert to tensor
    del adjs_np

    return adjs_tensor


def plot_molecules(mols: List[Chem.Mol], max_num: int = 16, shift: int = 100) -> None:
    """Plot max_num molecules in a square grid, titled with their SMILES.

    Args:
        mols (List[Chem.Mol]): List of molecules to plot.
        max_num (int, optional): number of molecules to plot. Defaults to 16.
        shift (int, optional): index of the first molecule to plot. Defaults to 100.
    """
    img_c = int(math.ceil(np.sqrt(max_num)))
    figure = plt.figure()

    for idx in range(max_num):
        mol = mols[idx + shift]

        assert isinstance(
            mol, Chem.Mol
        ), "elements should be molecules"  # check if we have a molecule

        ax = plt.subplot(img_c, img_c, idx + 1)
        mol_img = Draw.MolToImage(mol, size=(300, 300))
        ax.imshow(mol_img)
        title_str = f"{Chem.MolToSmiles(mol)}"
        ax.title.set_text(title_str)
        ax.set_axis_off()
    figure.suptitle("Plot of molecules")
    plt.show()


def convert_adjacency_to_edge_index(A: torch.Tensor) -> torch.Tensor:
    """Converts an adjacency matrix to edge indices for PyTorch Geometric.
    If the graph has no edges (edge_index.size(1) == 0), a single self-loop on node 0
    is returned instead, so that edge_index is never empty.

    Args:
        A (torch.Tensor): Adjacency matrix

    Returns:
        torch.Tensor: Edge indices, shape = [2, n_edges], dtype torch.long
    """
    A_ = A.to_dense()
    edge_index = torch.nonzero(A_, as_tuple=False).t().contiguous().long()
    if not (edge_index.size()[1]):
        edge_index = torch.tensor([[0], [0]], dtype=torch.long)
    return edge_index
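The two steps that turn a molecule into PyTorch Geometric input, zero-padding the adjacency matrix (`pad_adjs`) and listing its nonzero entries as `edge_index` (`convert_adjacency_to_edge_index`), can be seen on a toy 3-node chain padded to 4 nodes. `np.pad` is used here as a compact stand-in for `pad_adjs`; it produces the same matrix values.

```python
import numpy as np
import torch

# adjacency of a 3-node chain 0-1-2
adj = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=np.float32)
# zero-pad to 4 nodes: node 3 is a padding node with no bonds
padded = np.pad(adj, ((0, 1), (0, 1)))
# COO edge list: one column per nonzero entry, so each undirected bond appears twice
edge_index = torch.nonzero(torch.from_numpy(padded)).t().contiguous().long()
print(edge_index.tolist())  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```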


In [None]:
def load_qm9(
    folder: str = "./", nb_mol: int = 100, features: List[str] = ["mu", "alpha"]
) -> Tuple[
    List[torch_geometric.data.data.Data],
    List[torch_geometric.data.data.Data],
    List[Chem.Mol],
    List[Chem.Mol],
    torch.Tensor,
    torch.Tensor,
]:
    """Load the QM9 dataset.

    Args:
        folder (str, optional): path to the folder containing qm9.csv and valid_idx_qm9.json. Defaults to "./".
        nb_mol (int, optional): number of molecules to load. Defaults to 100.
        features (List[str], optional): target properties to load. Defaults to ["mu", "alpha"].

    Returns:
        Tuple[List[torch_geometric.data.data.Data], List[torch_geometric.data.data.Data], List[Chem.Mol], List[Chem.Mol], torch.Tensor, torch.Tensor]:
            PyTorch Geometric Data objects (one per molecule) for train and test,
            the train and test molecules, and the train and test targets.
    """
    data = pd.read_csv(os.path.join(folder, "qm9.csv"))
    data = data[["SMILES1"] + features][:nb_mol]

    max_node_num = 9  # max number of (heavy-atom) nodes in the QM9 dataset

    with open(os.path.join(folder, "valid_idx_qm9.json")) as f:
        test_idx = json.load(f)["valid_idxs"]

    # sorted index lists keep targets, molecules and graphs aligned in the same order
    test_idx = sorted({int(i) for i in test_idx if int(i) < len(data)})
    test_set = set(test_idx)
    train_idx = [i for i in range(len(data)) if i not in test_set]

    targets = data[features].values
    y_train = torch.tensor(targets[train_idx], dtype=torch.float32)
    y_test = torch.tensor(targets[test_idx], dtype=torch.float32)

    smiles = data["SMILES1"].tolist()
    train_mols = [Chem.MolFromSmiles(smiles[i]) for i in train_idx]
    test_mols = [Chem.MolFromSmiles(smiles[i]) for i in test_idx]

    def to_pyg(g: nx.Graph) -> torch_geometric.data.data.Data:
        # one-hot node identity features; padded (absent) nodes get an all-zero row
        mask = [1.0 if i in g.nodes else 0.0 for i in range(max_node_num)]
        return torch_geometric.data.data.Data(
            x=torch.diag(torch.tensor(mask, dtype=torch.float32)),
            edge_index=convert_adjacency_to_edge_index(
                graphs_to_tensor([g], max_node_num=max_node_num)[0]
            ),
        )

    train_data = [to_pyg(g) for g in mols_to_nx(train_mols)]
    test_data = [to_pyg(g) for g in mols_to_nx(test_mols)]
    return train_data, test_data, train_mols, test_mols, y_train, y_test


def print_top_qm9_and_features(folder: str = "./") -> None:
    """Print the first lines of QM9 and the list of all features.

    Args:
        folder (str, optional): path to the folder that contains qm9.csv. Defaults to "./".
    """
    data = pd.read_csv(os.path.join(folder, "qm9.csv"))
    print("Features QM9 dataset:")
    print(list(data.columns))
    print("First 5 rows:")
    print(data.head(5))
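In `load_qm9`, the node features of each molecule are `torch.diag` of a 0/1 mask over the `max_node_num = 9` padded slots: each real atom gets a one-hot vector encoding its index, and each padded slot gets an all-zero row. The sketch below reproduces this on a hypothetical molecule with 3 heavy atoms padded to 5 slots. These features encode only the atom's position in the graph, not its element (the `label` attribute set by `mols_to_nx`).

```python
import torch

max_node_num = 5
n_atoms = 3  # hypothetical molecule with 3 heavy atoms
mask = torch.tensor([1.0 if i < n_atoms else 0.0 for i in range(max_node_num)])
x = torch.diag(mask)  # [max_node_num, max_node_num] node feature matrix
print(x)
```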

In [None]:
print_top_qm9_and_features(folder="./")


Load the QM9 dataset

In [None]:
folder = "./"
nb_mol = 100  # TODO: try to increase the number of molecules
features = ["mu", "alpha"]  # TODO: try to add more features or some other features
train_data, test_data, train_mols, test_mols, y_train, y_test = load_qm9(
    folder=folder, nb_mol=nb_mol, features=features
)


Plot some molecules

In [None]:
plot_molecules(train_mols, max_num=16, shift=50)


In [None]:
max_node_number = 9
model = GCN(
    num_features=max_node_number,
    hidden_dim=4,
    embedding_dim=2,
    num_classes=len(features),
).to(device)


In [None]:
print(f"Number of parameters GCN: {get_nb_parameters(model)}")


In [None]:
learning_rate = 0.01  # TODO: change the learning rate?

criterion = torch.nn.MSELoss()  # mean squared error (MSE) loss for regression.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)  # Adam optimizer.
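`MSELoss` averages the squared errors of all targets, so a target with a larger numeric range dominates the loss; in QM9, the polarizability `alpha` typically takes much larger values than the dipole moment `mu`. One common remedy, sketched here as a hypothetical helper, is to standardize each target column with statistics computed on the training set only.

```python
import torch


def standardize_targets(y_train: torch.Tensor, y_test: torch.Tensor):
    """Z-score each target column using training statistics only."""
    mean = y_train.mean(dim=0, keepdim=True)
    std = y_train.std(dim=0, keepdim=True).clamp(min=1e-8)  # avoid division by zero
    return (y_train - mean) / std, (y_test - mean) / std, mean, std


# toy targets: the second column has a much larger scale than the first
y_tr = torch.tensor([[1.0, 10.0], [3.0, 30.0], [2.0, 20.0]])
y_te = torch.tensor([[2.5, 25.0]])
z_tr, z_te, mean, std = standardize_targets(y_tr, y_te)
print(z_tr.std(dim=0))  # both columns now have unit standard deviation
y_back = z_te * std + mean  # maps standardized values back to the original units
```

If adopted, the model would be trained on the standardized targets, and its predictions mapped back with `* std + mean` before reporting errors in physical units.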


Small modification for regression: to make a prediction for the entire graph (i.e. the whole molecule), we average the predictions made for each node!

This is done below by adding the line `out = out.mean(dim=0)`.

In [None]:
def train_one_epoch_reg(
    model: torch.nn.Module,
    data_train: List[torch_geometric.data.data.Data],
    data_test: List[torch_geometric.data.data.Data],
    y_train: torch.Tensor,
    y_test: torch.Tensor,
) -> Tuple[float, float, float, float, torch.Tensor]:
    """Train the model for 1 epoch.

    Args:
        model (torch.nn.Module): The model to train.
        data_train (List[torch_geometric.data.data.Data]): The train graphs.
        data_test (List[torch_geometric.data.data.Data]): The test graphs.
        y_train (torch.Tensor): The train targets, shape = [n_train, n_features].
        y_test (torch.Tensor): The test targets, shape = [n_test, n_features].

    Returns:
        Tuple[float, float, float, float, torch.Tensor]: The losses on train and test, the MSE on train and test,
            and the node embeddings h of the last test graph.
    """
    model.train()
    optimizer.zero_grad()  # clear gradients.
    avg_train_loss = torch.tensor(0.0, dtype=torch.float32)
    avg_train_mse = torch.tensor(0.0, dtype=torch.float32)
    for i, data in enumerate(data_train):
        out, h = model(data.x, data.edge_index)  # perform a single forward pass.
        out = out.mean(
            dim=0
        )  # average all the node predictions for the graph prediction!
        loss_train = criterion(out, y_train[i])  # compute the loss
        avg_train_loss += loss_train
        avg_train_mse += mean_squared_error(y_train[i], out.detach().numpy())
    avg_train_loss /= len(data_train)
    avg_train_mse /= len(data_train)
    model.eval()
    with torch.no_grad():  # no gradients for test pass.
        avg_test_loss = torch.tensor(0.0, dtype=torch.float32)
        avg_test_mse = torch.tensor(0.0, dtype=torch.float32)
        for i, data in enumerate(data_test):
            out, h = model(data.x, data.edge_index)
            out = out.mean(
                dim=0
            )  # average all the node predictions for the graph prediction!
            loss_test = criterion(out, y_test[i])
            avg_test_loss += loss_test
            avg_test_mse += mean_squared_error(y_test[i], out.detach().numpy())
        avg_test_loss /= len(data_test)
        avg_test_mse /= len(data_test)
    model.train()
    avg_train_loss.backward()  # derive gradients.
    optimizer.step()  # update parameters based on gradients.
    return avg_train_loss.item(), avg_test_loss.item(), avg_train_mse.item(), avg_test_mse.item(), h


def train_reg(
    model: torch.nn.Module,
    data_train: List[torch_geometric.data.data.Data],
    data_test: List[torch_geometric.data.data.Data],
    y_train: torch.Tensor,
    y_test: torch.Tensor,
    nb_epochs: int,
    train_func: Callable[
        [
            torch.nn.Module,
            List[torch_geometric.data.data.Data],
            List[torch_geometric.data.data.Data],
            torch.Tensor,
            torch.Tensor,
        ],
        Tuple[float, float, float, float, torch.Tensor],
    ],
) -> Tuple[List[float], List[float], List[float], List[float], List[torch.Tensor]]:
    """Train the model.

    Args:
        model (torch.nn.Module): The model to train.
        data_train (List[torch_geometric.data.data.Data]): The train graphs.
        data_test (List[torch_geometric.data.data.Data]): The test graphs.
        y_train (torch.Tensor): The train targets.
        y_test (torch.Tensor): The test targets.
        nb_epochs (int): The number of epochs.
        train_func (Callable): Function that trains the model for one epoch, with the same signature as train_one_epoch_reg.

    Returns:
        Tuple[List[float], List[float], List[float], List[float], List[torch.Tensor]]: The losses on train and test, the MSE on train and test, and the embeddings at each epoch.
    """
    top = perf_counter()  # calculate total training time
    model.train()
    losses_train = []
    losses_test = []
    mse_train = []
    mse_test = []
    embeddings = []
    for epoch in tqdm(range(1, nb_epochs + 1)):
        loss_train, loss_test, m_train, m_test, h = train_func(
            model, data_train, data_test, y_train, y_test
        )
        losses_train.append(loss_train)
        losses_test.append(loss_test)
        mse_train.append(m_train)
        mse_test.append(m_test)
        embeddings.append(h)
    print(f"Training finished! Total time: {round(perf_counter() - top, 3)}s.")
    return losses_train, losses_test, mse_train, mse_test, embeddings
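Because `load_qm9` pads every molecule to 9 nodes, `out.mean(dim=0)` also averages the predictions of padded nodes, which carry no chemical information. A possible refinement, sketched below, averages only over real atoms. It assumes that padded nodes are exactly the all-zero rows of `data.x`, which is how `load_qm9` builds them; `masked_mean_pool` is a hypothetical helper.

```python
import torch


def masked_mean_pool(out: torch.Tensor, node_mask: torch.Tensor) -> torch.Tensor:
    """Average node predictions over real (non-padded) nodes only.

    out: [n_nodes, n_targets] node-level predictions
    node_mask: [n_nodes] boolean, True for real nodes
    """
    weights = node_mask.float().unsqueeze(-1)  # [n_nodes, 1]
    # clamp avoids a division by zero for a graph without real nodes
    return (out * weights).sum(dim=0) / weights.sum().clamp(min=1.0)


out = torch.tensor([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]])  # last row: padded node
x = torch.diag(torch.tensor([1.0, 1.0, 0.0]))  # padded node has an all-zero feature row
node_mask = x.sum(dim=1) > 0
print(masked_mean_pool(out, node_mask))  # tensor([2., 3.])
print(out.mean(dim=0))  # plain mean, pulled towards the padded node's prediction
```

In `train_one_epoch_reg`, this would amount to replacing `out.mean(dim=0)` with `masked_mean_pool(out, data.x.sum(dim=1) > 0)`.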


In [None]:
nb_epochs = 300  # TODO: change the number of epochs?
losses_train, losses_test, mse_train, mse_test, embeddings = train_reg(
    model, train_data, test_data, y_train, y_test, nb_epochs, train_one_epoch_reg
)


In [None]:
def plot_learning_curves_reg(
    losses_train: List[float],
    losses_test: List[float],
    mse_train: List[float],
    mse_test: List[float],
) -> None:
    """Plot the learning curves.

    Args:
        losses_train (List[float]): The training losses.
        losses_test (List[float]): The test losses.
        mse_train (List[float]): The training mean squared errors.
        mse_test (List[float]): The test mean squared errors.
    """
    plt.figure(figsize=(16, 6))
    plt.subplot(121)
    plt.plot(losses_train, label="train")
    plt.plot(losses_test, label="test")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.subplot(122)
    plt.plot(mse_train, label="train")
    plt.plot(mse_test, label="test")
    plt.xlabel("Epoch")
    plt.ylabel("Mean squared error")
    plt.legend()
    plt.suptitle(
        "Learning curves. Final train mse: {:.2f}. Final test mse: {:.2f}".format(
            mse_train[-1], mse_test[-1]
        )
    )
    plt.tight_layout()
    plt.show()


In [None]:
plot_learning_curves_reg(losses_train, losses_test, mse_train, mse_test)


What do you think? How could the results be improved?