# Train a Simplicial Attention Network (SAN)

In this notebook, we will create and train a High Skip Network in the simplicial complex domain, as proposed in the paper by [Hajij et. al : High Skip Networks: A Higher Order Generalization of Skip Connections (2022)](https://openreview.net/pdf?id=Sc8glB-k6e9). 

We train the model to perform binary node classification using the KarateClub benchmark dataset. 

The equations of one layer of this neural network are given by:

🟥 $\quad m_{{y \rightarrow z}}^{(0 \rightarrow 0)} = \sigma ((A_{\uparrow,0})_{xy} \cdot h^{t,(0)}_y \cdot \Theta^{t,(0)1})$    (level 1)

🟥 $\quad m_{z \rightarrow x}^{(0 \rightarrow 0)}  = (A_{\uparrow,0})_{xy} \cdot m_{y \rightarrow z}^{(0 \rightarrow 0)} \cdot \Theta^{t,(0)2}$    (level 2)

🟥 $\quad m_{{y \rightarrow z}}^{(0 \rightarrow 1)}  = \sigma((B_1^T)_{zy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow 1)})$    (level 1)

🟥 $\quad m_{z \rightarrow x)}^{(1 \rightarrow 0)}  = (B_1)_{xz} \cdot m_{z \rightarrow x}^{(0 \rightarrow 1)} \cdot \Theta^{t, (1 \rightarrow 0)}$    (level 2)

🟧 $\quad m_{x}^{(0 \rightarrow 0)}  = \sum_{z \in \mathcal{L}_\uparrow(x)} m_{z \rightarrow x}^{(0 \rightarrow 0)}$

🟧 $\quad m_{x}^{(1 \rightarrow 0)}  = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1 \rightarrow 0)}$

🟩 $\quad m_x^{(0)}  = m_x^{(0 \rightarrow 0)} + m_x^{(1 \rightarrow 0)}$

🟦 $\quad h_x^{t+1,(0)}  = I(m_x^{(0)})$

Where the notations are defined in [Papillon et al : Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023)](https://arxiv.org/abs/2304.10031).

In [127]:
import numpy as np
import torch
from torch.nn.parameter import Parameter
from toponetx import SimplicialComplex
import toponetx.datasets.graph as graph
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.simplicial.san_layer import SANLayer

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Pre-processing

## Import dataset ##

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a singular graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.

In [2]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)

Simplicial Complex with shape [34, 78, 45, 11, 2] and dimension 4


## Define neighborhood structures. ##

Now we retrieve the neighborhood structures (i.e. their representative matrices) that we will use to send messges on the domain. In this case, we need the boundary matrix (or incidence matrix) $B_1$ and the adjacency matrix $A_{\uparrow,0}$ on the nodes. For a santiy check, we show that the shape of the $B_1 = n_\text{nodes} \times n_\text{edges}$ and $A_{\uparrow,0} = n_\text{nodes} \times n_\text{nodes}$. We also convert the neighborhood structures to torch tensors.

In [3]:
# incidence_2 = dataset.incidence_matrix(rank=2)
# incidence_1 = dataset.incidence_matrix(rank=1)
# adjacency_0 = dataset.adjacency_matrix(rank=0)

# incidence_2 = torch.from_numpy(incidence_2.todense())#.to_sparse()
# incidence_1 = torch.from_numpy(incidence_1.todense())#.to_sparse()
# adjacency_0 = torch.from_numpy(adjacency_0.todense())#.to_sparse()

# print(f"The incidence matrix B2 has shape: {incidence_2.shape}.")
# print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
# print(f"The adjacency matrix A0 has shape: {adjacency_0.shape}.")

## Import signal ##

Since our task will be node classification, we must retrieve an input signal on the nodes. The signal will have shape $n_\text{nodes} \times$ in_channels, where in_channels is the dimension of each cell's feature. Here, we have in_channels = channels_nodes $ = 34$. This is because the Karate dataset encodes the identity of each of the 34 nodes as a one hot encoder.

In [4]:
# x_0 = []
# for _, v in dataset.get_simplex_attributes("node_feat").items():
#     x_0.append(v)
# x_0 = torch.tensor(np.stack(x_0))
# channels_nodes = x_0.shape[-1]
# print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")

x_1 = []
for k, v in dataset.get_simplex_attributes("edge_feat").items():
    x_1.append(v)
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")

# x_2 = []
# for k, v in dataset.get_simplex_attributes("face_feat").items():
#     x_2.append(v)
# x_2 = np.stack(x_2)
# print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")

There are 78 edges with features of dimension 2.


In [30]:
channels_nodes = x_1.shape[-1]
output_dim = 2
J = 3
att_slice = J * output_dim

## Define binary labels
We retrieve the labels associated to the nodes of each input simplex. In the KarateClub dataset, two social groups emerge. So we assign binary labels to the nodes indicating of which group they are a part.

We convert the binary labels into one-hot encoder form, and keep the first four nodes' true labels for the purpose of testing.

In [128]:
y = np.array(
    [
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        1,
        0,
        1,
        1,
        1,
        1,
        0,
        0,
        1,
        1,
        0,
        1,
        0,
        1,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
        0,
    ]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_test = y_true[:4]
y_train = y_true[-30:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

# Create the Neural Network

Using the HSNLayer class, we create a neural network with stacked layers. A linear layer at the end produces an output with shape $n_\text{nodes} \times 2$, so we can compare with our binary labels.

In [129]:
class SAN(torch.nn.Module):
    """Simplicial Attention Network Implementation for binary node classification.

    Parameters
    ---------
    channels : int
        Dimension of features
    n_layers : int
        Amount of message passing layers.

    """

    def __init__(
        self,
        channels_in,
        channels_out,
        J=2,
        J_har=5,  # approximation order for harmonic
        epsilon_har=1e-1,  # epsilon for harmonic, it takes into account the normalization
        n_layers=2,
    ):
        super().__init__()
        self.J_har = J_har
        self.epsilon_har = epsilon_har
        layers = []
        for _ in range(n_layers):
            layers.append(
                SANLayer(
                    channels_in=channels_in,
                    channels_out=channels_out,
                    J=J,
                )
            )
        self.linear = torch.nn.Linear(channels_out, 2)
        self.layers = layers

    def compute_projection_matrix(self, L):
        P = torch.eye(L.shape[0]) - self.epsilon_har * L
        for _ in range(self.J_har):
            P = P @ P  # approximate the limit
        return P

    def forward(self, x, Lup, Ldown):
        """Forward computation.

        Parameters
        ---------
        x : tensor
            shape = [n_nodes, channels_in]
            Node features.

        Lup : tensor
            shape = [n_nodes, n_edges]
            Boundary matrix of rank 1.

        Ld : tensor
            shape = [n_nodes, n_nodes]
            Adjacency matrix (up) of rank 0.

        P : tensor
            shape = [n_nodes, n_nodes]


        Returns
        --------
        _ : tensor
            shape = [n_nodes, 2]
            One-hot labels assigned to nodes.

        """
        # Compute the projection matrix for the harmonic component
        L = Lup + Ldown
        P = self.compute_projection_matrix(L)

        for layer in self.layers:
            x = layer(x, Lup, Ldown, P)
        return torch.sigmoid(self.linear(x))

# Train the Neural Network

We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [130]:
Ldown = torch.from_numpy(dataset.down_laplacian_matrix(rank=1).todense()).to_sparse()
Lup = torch.from_numpy(dataset.up_laplacian_matrix(rank=1).todense()).to_sparse()

# Ldown = torch.from_numpy(Ldown.todense())#.to_sparse()
# Lup = torch.from_numpy(Lup.todense())#.to_sparse()

The following cell performs the training, looping over the network for a low number of epochs.

In [131]:
dataset.incidence_matrix(rank=1)

<34x78 sparse matrix of type '<class 'numpy.float32'>'
	with 156 stored elements in COOrdinate format>

In [138]:
model = SAN(
    channels_in=channels_nodes,
    channels_out=channels_nodes,
    n_layers=3,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.4)
test_interval = 2
num_epochs = 5
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat = model(x_1, Lup=Lup, Ldown=Ldown)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[-len(y_train) :].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred == y_hat).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_test = model(x_1, Lup=Lup, Ldown=Ldown)
            y_pred_test = torch.sigmoid(y_hat_test).ge(0.5).float()
            test_accuracy = (
                torch.eq(y_pred_test[: len(y_test)], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)

Epoch: 1 loss: 0.7243 Train_acc: 0.0000
Epoch: 2 loss: 0.7127 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 3 loss: 0.7040 Train_acc: 0.0000
Epoch: 4 loss: 0.6983 Train_acc: 0.0000
Test_acc: 0.0000
Epoch: 5 loss: 0.6950 Train_acc: 0.0000


In [41]:
y_hat = model(x_1, Lup=Lup, Ldown=Ldown)

In [31]:
weight_irr = Parameter(torch.Tensor(J, channels_nodes, output_dim))
att_irr = Parameter(torch.Tensor(2 * att_slice, 1))

In [34]:
x_irr = torch.matmul(x_1, weight_irr).reshape(-1, J * output_dim)
x_irr.shape

torch.Size([78, 6])

In [36]:
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [39]:
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [44]:
(x_irr @ att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [79]:
e_irr = (x_irr @ att_irr[:att_slice, :]) + (x_irr @ att_irr[att_slice:, :]).T

In [97]:
(x_irr @ att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [102]:
(x_irr @ att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [98]:
torch.mm(x_irr, att_irr[:att_slice, :]).shape

torch.Size([78, 1])

In [101]:
torch.mm(x_irr, att_irr[att_slice:, :]).T.shape

torch.Size([1, 78])

In [95]:
alpha_irr = torch.sparse.softmax(e_irr.sparse_mask(Ldown), dim=1)

In [93]:
alpha_irr

tensor(indices=tensor([[ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([0.0417, 0.0417, 0.0417,  ..., 0.0455, 0.0455, 0.0455]),
       size=(78, 78), nnz=1134, layout=torch.sparse_coo,
       grad_fn=<SparseSoftmaxBackward0>)

In [110]:
alpha_exp_irr = alpha_irr.unsqueeze(0)
for p in range(J - 1):
    alpha_exp_irr = torch.cat(
        [alpha_exp_irr, torch.mm(alpha_exp_irr[p], alpha_irr).unsqueeze(0)], dim=0
    )

In [111]:
alpha_exp_irr

tensor(indices=tensor([[ 0,  0,  0,  ...,  2,  2,  2],
                       [ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([0.0417, 0.0417, 0.0417,  ..., 0.0376, 0.0368, 0.0476]),
       size=(3, 78, 78), nnz=10742, layout=torch.sparse_coo,
       grad_fn=<CatBackward0>)

In [115]:
x_irr = torch.matmul(x_1, weight_irr)
x_irr.shape

torch.Size([3, 78, 2])

In [126]:
torch.sum(torch.matmul(alpha_exp_irr.to_dense(), x_irr), dim=0).shape

torch.Size([78, 2])

In [78]:
A.sparse_mask(Ldown)

tensor(indices=tensor([[ 0,  0,  0,  ..., 77, 77, 77],
                       [ 0,  1,  2,  ..., 75, 76, 77]]),
       values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),
       size=(78, 78), nnz=1134, layout=torch.sparse_coo)

In [135]:
layer = SANLayer(channels_in=channels_nodes, channels_out=output_dim, J=J)

In [137]:
layer(x_1, Lup, Ldown, Lup).shape

torch.Size([78, 2])