In [1]:
import math
%load_ext autoreload
%autoreload 2

In [40]:
from mantra.simplicial import SimplicialDataset
import torch
from torch import nn
import torchmetrics
import torchvision.transforms as transforms
from mantra.transforms import SimplicialComplexTransform
from mantra.dataloaders import SimplicialDataLoader
from topomodelx.nn.simplicial.scnn import SCNN
from torch_geometric.nn import pool
import lightning as L
from typing import Literal
from torch.utils.data import random_split
import math
from mantra.utils import transfer_simplicial_complex_batch_to_device

In [3]:
from mantra.transforms import SimplicialComplexDegreeTransform, SimplicialComplexEdgeCoadjacencyDegreeTransform, \
    SimplicialComplexEdgeAdjacencyDegreeTransform, SimplicialComplexTriangleCoadjacencyDegreeTransform, \
    OrientableToClassSimplicialComplexTransform, DimTwoHodgeLaplacianSimplicialComplexTransform, \
    DimOneHodgeLaplacianDownSimplicialComplexTransform, DimOneHodgeLaplacianUpSimplicialComplexTransform, \
    DimZeroHodgeLaplacianSimplicialComplexTransform

tr = transforms.Compose(
            [SimplicialComplexTransform(), 
             SimplicialComplexDegreeTransform(),
             SimplicialComplexEdgeCoadjacencyDegreeTransform(),
             SimplicialComplexEdgeAdjacencyDegreeTransform(),
             SimplicialComplexTriangleCoadjacencyDegreeTransform(),
             DimZeroHodgeLaplacianSimplicialComplexTransform(),
             DimOneHodgeLaplacianUpSimplicialComplexTransform(),
             DimOneHodgeLaplacianDownSimplicialComplexTransform(),
             DimTwoHodgeLaplacianSimplicialComplexTransform(),
             OrientableToClassSimplicialComplexTransform()]
        )

In [4]:
dataset = SimplicialDataset(root="./data", transform=tr)

## Train the Neural Network
We specify the model with our pre-made neighborhood structures and specify an optimizer.

In [5]:
rank = 1  # simplex level. We'll use the features of the rank-simplices.
conv_order_down = 2 # TODO: No idea of what this parameter does
conv_order_up = 2 # TODO: No idea of what this parameter does
hidden_channels = 4
out_channels = 1  # num classes
num_layers = 3
# Check the rank has an appropriate value.
assert 0 <= rank <= 2, "rank must be 0, 1 or 2."
# select the simplex level
if rank == 0:
    conv_order_down = 0
# configure parameters
in_channels = dataset[0].x[rank].shape[1]

In [57]:
class SCNNNetwork(L.LightningModule):
    def __init__(self, rank, in_channels, hidden_channels, out_channels,
                 conv_order_down, conv_order_up, n_layers=3):
        super().__init__()
        self.rank = rank
        self.base_model = SCNN(in_channels=in_channels,
                               hidden_channels=hidden_channels,
                               conv_order_down=conv_order_down, 
                               conv_order_up=conv_order_up, 
                               n_layers=n_layers)
        self.liner_readout = torch.nn.Linear(hidden_channels, out_channels)
        # Accuracy metrics
        self.training_accuracy = torchmetrics.classification.BinaryAccuracy()
        self.validation_accuracy = torchmetrics.classification.BinaryAccuracy()
        self.test_accuracy = torchmetrics.classification.BinaryAccuracy()
    
    def forward(self, x, laplacian_down, laplacian_up, signal_belongings):
        x = self.base_model(x, laplacian_down, laplacian_up)
        x = self.liner_readout(x)
        x_mean = pool.global_mean_pool(x, signal_belongings)
        x_mean[torch.isnan(x_mean)] = 0
        return x_mean
    
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        return transfer_simplicial_complex_batch_to_device(batch, device, dataloader_idx)
        
    def general_step(self, batch, batch_idx, step: Literal['train', 'test', 'validation']):
        s_complexes, signal_belongings, batch_len = batch
        x = s_complexes.signals[self.rank]
        if rank == 0:
            laplacian_down = None
            laplacian_up = s_complexes.neighborhood_matrices[f'0_laplacian']
        elif rank == 1:
            laplacian_down = s_complexes.neighborhood_matrices[f'1_laplacian_down']
            laplacian_up = s_complexes.neighborhood_matrices[f'1_laplacian_up']
        elif rank == 2:
            laplacian_down = s_complexes.neighborhood_matrices[f'2_laplacian']
            laplacian_up = None
        else:
            raise ValueError("rank must be 0, 1 or 2.")
        y = s_complexes.other_features['y'].float()
        signal_belongings = signal_belongings[self.rank]
        x_hat = self(x, laplacian_down, laplacian_up, signal_belongings)
        # Squeeze x_hat to match the shape of y
        x_hat = x_hat.squeeze()
        loss = nn.functional.binary_cross_entropy_with_logits(x_hat, y)
        self.log('train_loss', loss, prog_bar=True, batch_size=batch_len, on_step=False, on_epoch=True)
        self.log_accuracies(x_hat, y, batch_len, step)
        return loss
    
    def log_accuracies(self, x_hat, y, batch_len, step: Literal['train', 'test', 'validation']):
        # Apply the sigmoid function to x_hat to get the probabilities
        x_hat = torch.sigmoid(x_hat)
        if step == 'train':
            self.training_accuracy(x_hat, y)
            self.log('train_accuracy', self.training_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)
        elif step == 'test':
            self.test_accuracy(x_hat, y)
            self.log('test_accuracy', self.test_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)
        elif step == 'validation':
            self.validation_accuracy(x_hat, y)
            self.log('validation_accuracy', self.validation_accuracy, prog_bar=True, on_step=False, on_epoch=True, batch_size=batch_len)
    
    def test_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, 'test')
    
    def validation_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, 'validation')
    
    def training_step(self, batch, batch_idx):
        return self.general_step(batch, batch_idx, 'train')
        
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
        return optimizer
        

In [58]:
model = SCNNNetwork(rank=rank,
                    in_channels=in_channels,
                    hidden_channels=hidden_channels,
                    out_channels=out_channels,
                    conv_order_down=conv_order_down,
                    conv_order_up=conv_order_up,
                    n_layers=num_layers)

loss_fn = torch.nn.MSELoss()

In [59]:
# Split the dataset
test_percentage = 0.2
batch_size = 16
test_len = math.floor(len(dataset) * test_percentage)
train_ds, test_ds = random_split(dataset, [len(dataset) - test_len, test_len])
train_dl = SimplicialDataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_dl = SimplicialDataLoader(test_ds, batch_size=batch_size, shuffle=False)

In [60]:
# Use CPU acceleration: SCCNN does not support GPU acceleration because it creates matrices not placed in the device of the network.
trainer = L.Trainer(max_epochs=10, accelerator='cpu')
trainer.fit(model, train_dl, test_dl)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name                | Type           | Params
-------------------------------------------------------
0 | base_model          | SCNN           | 200   
1 | liner_readout       | Linear         | 5     
2 | training_accuracy   | BinaryAccuracy | 0     
3 | validation_accuracy | BinaryAccuracy | 0     
4 | test_accuracy       | BinaryAccuracy | 0     
-------------------------------------------------------
205       Trainable params
0         Non-trainable params
205       Total params
0.001     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.
