In [2]:
import rootutils

rootutils.setup_root("./", indicator=".project-root", pythonpath=True)

%load_ext autoreload
%autoreload 2

import topomodelx
import toponetx
import topobenchmarkx

In [3]:
import torch
import numpy as np
from torch_geometric.utils import to_undirected
import torch_geometric.datasets as geom_datasets

from topomodelx.nn.hypergraph.allset import AllSet

# Seed the global PyTorch and NumPy random number generators with a fixed value.
torch.manual_seed(seed=0)
np.random.seed(seed=0)

In [4]:
# Use CUDA when PyTorch reports it as available, otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cpu


In [5]:
# Load the Planetoid "cora" dataset, using tmp/ as its root directory.
cora = geom_datasets.Planetoid(root="tmp/", name="cora")
# Take the first graph by indexing the dataset instead of reading the internal
# `.data` storage attribute, which recent PyTorch Geometric versions warn
# against accessing directly.
data = cora[0]

# Features, labels and connectivity used by the cells below.
x_0s = data.x
y = data.y
edge_index = data.edge_index

# Boolean masks used to index predictions/labels for each split.
train_mask = data.train_mask
val_mask = data.val_mask
test_mask = data.test_mask



In [9]:
# Display the dataset's `x` attribute (assumed to be the node feature matrix — TODO confirm).
cora.x

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [5]:
# Symmetrize the edge list so every edge is present in both directions; this
# makes "edges starting at `node`" equivalent to "all edges touching `node`".
edge_index = to_undirected(edge_index)

# Collect the one-hop neighborhood of every node (one candidate hyperedge per node).
one_hop_neighborhoods = []
for node in range(data.num_nodes):
    # Read from the symmetrized `edge_index` computed above. Reading
    # `data.edge_index` here would silently discard the result of
    # `to_undirected`.
    neighbors = edge_index[1, edge_index[0] == node]

    # Store the neighbors as a NumPy array.
    one_hop_neighborhoods.append(neighbors.numpy())

# Drop duplicate hyperedges while keeping first-seen order.
unique_hyperedges = set()
hyperedges = []
for neighborhood in one_hop_neighborhoods:
    # A sorted tuple is a hashable, order-independent key for the node set.
    neighborhood = tuple(sorted(neighborhood))
    if neighborhood not in unique_hyperedges:
        hyperedges.append(list(neighborhood))
        unique_hyperedges.add(neighborhood)

In [6]:
# Dense node-by-hyperedge incidence matrix: entry (row, col) is 1 when node
# `row` is a member of hyperedge `col`.
max_edges = len(hyperedges)
incidence_1 = np.zeros((x_0s.shape[0], max_edges))
for col, neighborhood in enumerate(hyperedges):
    for row in neighborhood:
        incidence_1[row, col] = 1

# Sanity checks: every column (hyperedge) and every row (node) must be non-empty.
assert (incidence_1.sum(0) > 0).all(), "Some hyperedges are empty"
assert (incidence_1.sum(1) > 0).all(), "Some nodes are not in any hyperedges"

# Convert to a float32 sparse COO tensor (float32 is what the legacy
# `torch.Tensor(ndarray)` constructor produced under the default dtype).
incidence_1 = torch.tensor(incidence_1, dtype=torch.float32).to_sparse_coo()

In [7]:
class Network(torch.nn.Module):
    """Wrap an AllSet base model with a linear readout layer.

    Parameters
    ----------
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.
    out_channels : int
        Dimension of the output features.
    task_level : str, default="graph"
        Level of the task. With "graph", the base model's first output is
        max-pooled over dim 0 before the readout; with any other value
        (e.g. "node") the readout is applied without pooling.
    **kwargs : dict
        Additional keyword arguments forwarded to the ``AllSet`` base model.
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()

        # Define the model
        self.base_model = AllSet(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )

        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, x_0, incidence_1):
        """Run the base model and apply the linear readout.

        Parameters
        ----------
        x_0 : torch.Tensor
            Input features passed to the base model (presumably one row per
            node — confirm against ``AllSet``).
        incidence_1 : torch.Tensor
            Incidence matrix passed to the base model.

        Returns
        -------
        torch.Tensor
            Output of the linear readout, computed on the pooled features
            when ``self.out_pool`` is set and on the unpooled features
            otherwise.
        """
        # The base model returns a pair; only its first element is used here.
        x_0, _ = self.base_model(x_0, incidence_1)

        # Max-pool over dim 0 for graph-level tasks.
        x = x_0.max(dim=0).values if self.out_pool else x_0

        return self.linear(x)

In [8]:
# --- Base model hyperparameters ---
in_channels = x_0s.size(1)
hidden_channels = 128
n_layers = 1
mlp_num_layers = 1

# --- Readout hyperparameters ---
# One output channel per distinct label value.
out_channels = torch.unique(y).size(0)
if out_channels == 1:
    task_level = "graph"
else:
    task_level = "node"


# Build the network first, then move it to the selected device.
model = Network(
    in_channels,
    hidden_channels,
    out_channels,
    task_level=task_level,
    n_layers=n_layers,
    mlp_num_layers=mlp_num_layers,
)
model = model.to(device)

# --- Optimizer and loss ---
opt = torch.optim.Adam(model.parameters(), lr=0.01)

# Categorical cross-entropy loss
loss_fn = torch.nn.CrossEntropyLoss()


# Accuracy
def acc_fn(y, y_hat):
    """Return the fraction of positions at which ``y`` equals ``y_hat``."""
    return (y == y_hat).float().mean()

In [9]:
# Cast features, incidence matrix and labels to the dtypes used below and move
# them to the target device. `clone().detach()` copies an existing tensor
# without the UserWarning that `torch.tensor(existing_tensor)` emits.
x_0s = x_0s.clone().detach().float().to(device)
incidence_1 = incidence_1.float().to(device)
y = y.clone().detach().long().to(device)

  x_0s = torch.tensor(x_0s)
  torch.tensor(y, dtype=torch.long).to(device),


In [10]:
test_interval = 2
num_epochs = 30

# Training loss of every epoch so far; the reported "Train_loss" is the running
# mean of this list, not the loss of the current epoch alone.
epoch_loss = []
for epoch_i in range(1, num_epochs + 1):
    model.train()

    opt.zero_grad()

    # Forward pass on all rows; only the rows selected by train_mask
    # contribute to the loss.
    y_hat = model(x_0s, incidence_1)
    loss = loss_fn(y_hat[train_mask], y[train_mask])

    loss.backward()
    opt.step()
    epoch_loss.append(loss.item())

    if epoch_i % test_interval == 0:
        model.eval()
        # Evaluation never calls backward(), so skip building the autograd graph.
        with torch.no_grad():
            y_hat = model(x_0s, incidence_1)

            print(f"Epoch: {epoch_i} ")
            print(
                f"Train_loss: {np.mean(epoch_loss):.4f}, acc: {acc_fn(y_hat[train_mask].argmax(1), y[train_mask]):.4f}",
                flush=True,
            )

            loss = loss_fn(y_hat[val_mask], y[val_mask])

            print(
                f"Val_loss: {loss:.4f}, Val_acc: {acc_fn(y_hat[val_mask].argmax(1), y[val_mask]):.4f}",
                flush=True,
            )

            loss = loss_fn(y_hat[test_mask], y[test_mask])
            print(
                f"Test_loss: {loss:.4f}, Test_acc: {acc_fn(y_hat[test_mask].argmax(1), y[test_mask]):.4f}",
                flush=True,
            )

Epoch: 2 
Train_loss: 1.9459, acc: 0.2071
Val_loss: 1.9381, Val_acc: 0.1420
Test_loss: 1.9343, Test_acc: 0.1570
Epoch: 4 
Train_loss: 1.8995, acc: 0.3286
Val_loss: 1.8095, Val_acc: 0.2860
Test_loss: 1.7797, Test_acc: 0.3200
Epoch: 6 
Train_loss: 1.7463, acc: 0.5143
Val_loss: 1.7505, Val_acc: 0.4220
Test_loss: 1.7044, Test_acc: 0.4670
Epoch: 8 
Train_loss: 1.6252, acc: 0.5214
Val_loss: 2.0630, Val_acc: 0.3660
Test_loss: 1.9248, Test_acc: 0.4040
Epoch: 10 
Train_loss: 1.4922, acc: 0.6857
Val_loss: 1.5380, Val_acc: 0.5160
Test_loss: 1.3973, Test_acc: 0.5380
Epoch: 12 
Train_loss: 1.3399, acc: 0.8357
Val_loss: 1.5469, Val_acc: 0.5640
Test_loss: 1.4999, Test_acc: 0.5830
Epoch: 14 
Train_loss: 1.2132, acc: 0.9429
Val_loss: 1.8685, Val_acc: 0.6640
Test_loss: 1.7860, Test_acc: 0.6990
Epoch: 16 
Train_loss: 1.0973, acc: 0.9929
Val_loss: 3.4106, Val_acc: 0.6280
Test_loss: 3.3016, Test_acc: 0.6390
Epoch: 18 
Train_loss: 1.0341, acc: 1.0000
Val_loss: 2.2763, Val_acc: 0.7040
Test_loss: 2.3472, Test