# Train a Block simplicial complex neural networks (BScNets)

In this notebook, we will create and train Block simplicial complex neural networks, as proposed in the paper by [Yuzhou Chen, Yulia R Gel and H Vincent Poor. BScNets: Block simplicial complex neural networks. Proceedings of the AAAI Conference on Artificial Intelligence. 2022]. 

We train the model to perform Cora benchmark dataset. 

The equations of one layer of this neural network are given by:

🟥 $\quad m_{y \rightarrow x}^{(r \rightarrow r)} = (H_r)\_{xy} \cdot h_{{y}}^{t, (r)} \cdot \Theta^{t,(r \rightarrow r)}$

🟥 $\quad m_{y \rightarrow x}^{(r \rightarrow r')} = (G_{r \rightarrow r'})\_{xy} \cdot h^{t, (r)}\_y \cdot \Theta^{t,(r \rightarrow r')}$

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r)} = (G{r' \rightarrow r})\_{xy} \cdot h_y^{t,(r')} \cdot \Theta^{t,(r' \rightarrow r)}$

🟥 $\quad m_{y \rightarrow x}^{(r' \rightarrow r')}  = (H_{r'})\_{xy} \cdot h_{{y}}^{t,(r')} \cdot \Theta^{t,(r' \rightarrow r')}$

🟧 $\quad m_x^{(r' \rightarrow r)} = \sum_{y \in \mathcal{N}\_\uparrow(x)} m_{y \rightarrow x}^{(r' \rightarrow r)}$

🟧 $\quad m_x^{(r \rightarrow r')}  = \sum_{y \in \mathcal{N}\_\downarrow(x)} m_{y \rightarrow x}^{(r \rightarrow r')}$

🟧 $\quad m_x^{(r \rightarrow r)}  = \sum_{y \in (\mathcal{L}\_\uparrow+\mathcal{L}\_\downarrow)(x)} m_{y \rightarrow x}^{(r \rightarrow r)}$

🟧 $\quad m_x^{(r' \rightarrow r')}  = \sum_{y \in (\mathcal{L}\_\uparrow+\mathcal{L}\_\downarrow)(x)} m_{y \rightarrow x}^{(r' \rightarrow r')}$

🟩 $\quad m_x^{(r)} = m_x^{(r \rightarrow r)}+ m_x^{(r' \rightarrow r)}$

🟩 $\quad m_x^{(r')}  = m_x^{(r' \rightarrow r')} + m_x^{(r \rightarrow r')}$

🟦 $\quad h^{t+1, (r)}\_x  = \sigma(m_x^{(r)})$

🟦 $\quad h^{t+1, (r')}\_x = \sigma(m_x^{(r')})$


Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [1]:
import torch
import numpy as np

from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph
import torch_geometric

# from topomodelx.nn.simplicial.hsn_layer import HSNLayer
import os
import torch

# import loaddatas as lds
import torch.nn.functional as F
import numpy as np
import topomodelx.nn.simplicial.bScNet_layer as bScLayer
from sklearn.metrics import roc_auc_score, average_precision_score
from torch.nn.init import xavier_normal_ as xavier
import torch_geometric.transforms as T

# Loading Cora Dataset

In [2]:
# dataset = graph.karate_club(complex_type="simplicial")
# print(dataset)
# Cora Dataset
dataset = torch_geometric.datasets.Planetoid(
    root="tmp/Cora", name="Cora", transform=T.NormalizeFeatures()
)

data = dataset[0]

# Checking Cora Dataset

In [3]:
# Check Cora data
print(dataset.num_classes)
print(dataset.name)
print(data)
print(data.edge_index[1][1])

7
Cora
Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
tensor(1862)


# Initilizing Model

In [4]:
print(data)
model, data = locals()["bScLayer"].call(
    data, dataset.name, data.x.size(1), dataset.num_classes
)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])


  adj = nx.adjacency_matrix(g)
  nx.incidence_matrix(G, nodelist=V, edgelist=E, oriented=True).todense()
  L0u = B1.T @ B1  # B1 @ D3_n @ B1.T @ inv(D2_1)


In [5]:
def weights_init(m):
    if isinstance(m, torch.nn.Linear):
        xavier(m.weight)
        if not m.bias is None:
            torch.nn.init.constant_(m.bias, 0)

In [6]:
model.apply(weights_init)

optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=0)
best_val_acc = test_acc_same = test_acc_diff = test_acc = 0.0
best_val_roc = test_roc_same = test_roc_diff = test_roc = 0.0
best_val_loss = np.inf
# train and val/test
wait_step = 0

# Training and Testing

In [7]:
wait_total = 200
total_epochs = 6


def train():
    model.train()
    optimizer.zero_grad()
    emb = model.g_encode(data).clone()
    x, y = model.s_encode(data, emb)  # emb from encode's, i.e., Gconv's output
    loss = F.binary_cross_entropy(x, y)
    loss.backward()
    optimizer.step()
    return x


def test():
    model.eval()
    accs = []
    emb = model.g_encode(data)
    for type in ["val", "test"]:
        pred, y = model.s_encode(data, emb, type=type)
        pred, y = pred.cpu(), y.cpu()
        if type == "val":
            accs.append(F.binary_cross_entropy(pred, y))
            pred = pred.data.numpy()
            roc = roc_auc_score(y, pred)
            accs.append(roc)
            acc = average_precision_score(y, pred)
            accs.append(acc)
        else:
            pred = pred.data.numpy()
            roc = roc_auc_score(y, pred)
            accs.append(roc)
            acc = average_precision_score(y, pred)
            accs.append(acc)
    return accs


def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True


# train and test
for epoch in range(1, total_epochs + 1):
    print("epoch is:", epoch)
    pred = train()
    val_loss, val_roc, val_acc, tmp_test_roc, tmp_test_acc = test()
    if val_roc >= best_val_roc:
        test_acc = tmp_test_acc
        test_roc = tmp_test_roc
        best_val_acc = val_acc
        best_val_roc = val_roc
        best_val_loss = val_loss
        wait_step = 0
    else:
        wait_step += 1
        if wait_step == wait_total:
            print(
                "Early stop! Min loss: ",
                best_val_loss,
                ", Max accuracy: ",
                best_val_acc,
                ", Max roc: ",
                best_val_roc,
            )
            break
    print(best_val_roc)
# del model
# del data
# print result

# pipeline_acc[Conv_method][data_cnt] = test_acc
# pipeline_roc[Conv_method][data_cnt] = test_roc

# log = 'Epoch: ' + str(
#    total_epochs) + ', dataset name: ' + d_name + ', Method: ' + Conv_method + ' Test pr: {:.4f}, roc: {:.4f} \n'
# print((log.format(pipeline_acc[Conv_method][data_cnt], pipeline_roc[Conv_method][data_cnt])))

epoch is: 1
0.6203863002211973
epoch is: 2
0.6203863002211973
epoch is: 3
0.6203863002211973
epoch is: 4
0.6203863002211973
epoch is: 5
0.6203863002211973
epoch is: 6
0.6203863002211973
epoch is: 7
0.6203863002211973
epoch is: 8
0.6203863002211973
epoch is: 9
0.6203863002211973
epoch is: 10
0.6203863002211973
epoch is: 11
0.6543393716838468
epoch is: 12
0.6946753603492895
epoch is: 13
0.7057207708655612
epoch is: 14
0.7089881305208982
epoch is: 15
0.7089881305208982
epoch is: 16
0.7089881305208982
epoch is: 17
0.7089881305208982
epoch is: 18
0.7089881305208982
epoch is: 19
0.7089881305208982
epoch is: 20
0.7201058277552083
