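## Testing the S2GNN (Spectral GNN) Model on Peptides-func

This notebook trains a single spectral graph convolution (`SpectralGCNConv`, wrapped in `SimpleGNN`) on the LRGB Peptides-func multi-label benchmark and reports test metrics after every epoch.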

In [4]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m20.9 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [5]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.datasets import LRGBDataset
from torch_geometric.loader import DataLoader

# Load the Peptides-func dataset (LRGB) with its official train / val / test splits.
DATA_ROOT = "./data"
DATASET_NAME = "Peptides-func"
BATCH_SIZE = 100

dataset_train = LRGBDataset(root=DATA_ROOT, name=DATASET_NAME, split="train")
dataset_val = LRGBDataset(root=DATA_ROOT, name=DATASET_NAME, split="val")
dataset_test = LRGBDataset(root=DATA_ROOT, name=DATASET_NAME, split="test")

train_loader = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
# Validation loader so model selection does not have to peek at the test split.
val_loader = DataLoader(dataset_val, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

Downloading https://www.dropbox.com/s/ycsq37q8sxs1ou8/peptidesfunc.zip?dl=1
Extracting data/peptidesfunc.zip
Processing...
Processing train dataset: 100%|██████████| 10873/10873 [00:00<00:00, 25052.66it/s]
Processing val dataset: 100%|██████████| 2331/2331 [00:00<00:00, 26026.47it/s]
Processing test dataset: 100%|██████████| 2331/2331 [00:00<00:00, 32607.46it/s]
Done!


In [6]:

from torch_geometric.utils import to_dense_adj, to_networkx
import numpy as np
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import to_scipy_sparse_matrix
import scipy.sparse as sp

import torch.nn as nn

class SpectralFilter(nn.Module):
    """Learnable spectral response.

    A small MLP maps the ``k`` eigenvalues of each graph to ``k`` gains, one gain
    per eigenvalue position.

    Args:
        k: number of eigenvalues per graph (input and output width of the MLP).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, k, hidden_dim):
        super(SpectralFilter, self).__init__()
        self.mlp = nn.Sequential(
            nn.Linear(k, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, k)
        )

    def forward(self, laplacian_eigenvalues_list):
        """Stack the per-graph eigenvalue vectors and return a ``(num_graphs, k)`` gain matrix.

        Args:
            laplacian_eigenvalues_list: list of 1-D tensors, each of length ``k``.
        """
        # Follow the module's own parameters rather than a notebook-global `device`,
        # so the filter works wherever the model has been moved to.
        param_device = self.mlp[0].weight.device
        eigenvalues_batch = torch.stack(laplacian_eigenvalues_list).to(param_device)
        return self.mlp(eigenvalues_batch)


# Spectral-based Graph Convolution Layer
class SpectralGCNConv(torch.nn.Module):
    """Single spectral graph convolution producing one embedding per graph.

    For every graph in a PyG mini-batch the layer:
      1. computes ``k`` eigenpairs of the symmetric normalised Laplacian,
      2. projects linearly transformed node features onto those eigenvectors
         (graph Fourier transform),
      3. scales each spectral component by a learned gain ``g(lambda_i)``
         produced by ``SpectralFilter``,
      4. maps the filtered signal back to the nodes, mean-pools over nodes and
         applies a final linear layer.

    Args:
        in_channels: number of input node features.
        out_channels: size of the per-graph output embedding.
        k: number of Laplacian eigenpairs kept per graph (also the width of the
            intermediate feature projection).
    """

    def __init__(self, in_channels, out_channels, k=5):
        super(SpectralGCNConv, self).__init__()
        self.k = k  # number of Laplacian eigenpairs kept per graph
        self.fc = nn.Linear(k, out_channels)
        self.feature_transform_layer = torch.nn.Linear(in_channels, k)
        self.spectral_filter = SpectralFilter(k, k)

    def forward(self, batch):
        """Return a ``(num_graphs, out_channels)`` tensor for a PyG batch."""
        device = self.fc.weight.device

        # Step 1: per-graph eigen-decomposition (computed on CPU, then moved).
        laplacian_eigenvalues_list, laplacian_eigenvectors_list = self.compute_laplacian_eigen(batch)
        eigenvectors_list = [vecs.to(device) for vecs in laplacian_eigenvectors_list]

        # Step 2: per-graph node features projected to k channels, each (n_i, k).
        features_list = self.feature_transform(batch.to(device))

        # Step 3: graph Fourier transform, (k, n_i) @ (n_i, k) -> (k, k) per graph.
        # Rows index eigenvectors, columns index feature channels.
        x_hat = torch.stack(
            [vecs.T @ feats for feats, vecs in zip(features_list, eigenvectors_list)], dim=0
        )

        # Step 4: a spectral filter is diagonal in the eigenbasis, so gain i scales
        # row i of x_hat.  A bmm of x_hat with the gain vector would instead contract
        # the feature axis and only type-checks because both axes have size k.
        gains = self.spectral_filter(laplacian_eigenvalues_list).to(device)  # (num_graphs, k)
        y_hat = gains.unsqueeze(-1) * x_hat

        # Step 5: inverse transform back to the nodes, then mean-pool to one vector per graph.
        pooled = torch.stack(
            [(vecs @ filtered).mean(dim=0) for vecs, filtered in zip(eigenvectors_list, y_hat)], dim=0
        )
        return self.fc(pooled)

    def compute_laplacian_eigen(self, batch_data):
        """Compute the ``k`` largest eigenpairs of each graph's normalised Laplacian.

        Uses the exact, deterministic symmetric eigensolver ``torch.linalg.eigh``.
        The randomised ``torch.svd_lowrank`` approximation is not used.  For the
        symmetric PSD normalised Laplacian of an undirected graph, singular values
        equal eigenvalues, so this returns the same (largest) part of the spectrum.
        NOTE(review): spectral GNNs usually keep the *smallest* (low-frequency)
        eigenpairs; confirm which end of the spectrum is intended.

        Graphs with fewer than ``k`` nodes are zero-padded, so every graph yields
        exactly ``k`` eigenvalues and can be stacked.

        Args:
            batch_data: PyG batch providing ``edge_index`` and ``batch``.
                Edges are assumed to stay within a single graph, as in PyG batching.

        Returns:
            ``(eigenvalues_list, eigenvectors_list)``, with one entry per graph.
            Eigenvalues have shape ``(k,)`` and are sorted in descending order.
            Eigenvectors have shape ``(n_i, k)`` with matching column order.
            All tensors are float32 on the CPU.
        """
        edge_index = batch_data.edge_index.cpu()
        batch = batch_data.batch.cpu()
        num_graphs = int(batch.max().item()) + 1
        k = self.k

        # Global node id -> position inside its own graph (filled per graph below).
        local_index = torch.zeros_like(batch)
        edge_graph = batch[edge_index[0]]  # graph id of each edge's source node

        laplacian_eigenvalues_list = []
        laplacian_eigenvectors_list = []
        for graph_id in range(num_graphs):
            node_indices = (batch == graph_id).nonzero(as_tuple=True)[0]
            num_nodes = node_indices.numel()
            local_index[node_indices] = torch.arange(num_nodes)

            graph_edges = edge_index[:, edge_graph == graph_id]
            src = local_index[graph_edges[0]]
            dst = local_index[graph_edges[1]]
            # Dense adjacency; duplicate edges accumulate, as to_dense_adj does.
            adj = torch.zeros(num_nodes, num_nodes)
            adj.index_put_((src, dst), torch.ones(src.numel()), accumulate=True)

            eigenvalues = torch.zeros(k)
            eigenvectors = torch.zeros(num_nodes, k)
            kept = min(k, num_nodes)
            if kept > 0:
                # eigh returns ascending eigenvalues; flip so the largest come first.
                evals, evecs = torch.linalg.eigh(self._normalized_laplacian(adj))
                # Clamp round-off negatives: the normalised Laplacian is PSD.
                eigenvalues[:kept] = evals.flip(0)[:kept].clamp_min(0.0)
                eigenvectors[:, :kept] = evecs.flip(1)[:, :kept]

            laplacian_eigenvalues_list.append(eigenvalues)
            laplacian_eigenvectors_list.append(eigenvectors)

        return laplacian_eigenvalues_list, laplacian_eigenvectors_list

    @staticmethod
    def _normalized_laplacian(adj):
        """Return the symmetric normalised Laplacian ``I - D^-1/2 A D^-1/2`` of a dense adjacency.

        Intended to match ``scipy.sparse.csgraph.laplacian(adj, normed=True)``:
        self-loops are ignored, and isolated nodes get a zero diagonal entry.
        """
        adj = adj.clone()
        adj.fill_diagonal_(0.0)
        degree = adj.sum(dim=0)
        isolated = degree == 0
        # Isolated nodes use 1 instead of 0 to avoid dividing by zero.
        inv_sqrt_degree = torch.where(isolated, torch.ones_like(degree), degree.sqrt()).reciprocal()
        laplacian = -(inv_sqrt_degree[:, None] * adj * inv_sqrt_degree[None, :])
        laplacian.diagonal().copy_((~isolated).to(laplacian.dtype))
        return laplacian

    def feature_transform(self, batch_data):
        """Apply ``feature_transform_layer`` to each graph's node features.

        Returns a list with one ``(n_i, k)`` tensor per graph, in graph-id order.
        """
        return [
            self.feature_transform_layer(batch_data.x[batch_data.batch == graph_id])
            for graph_id in range(batch_data.num_graphs)
        ]

    def learnable_norm(self, y_hat):
        # Optional L2 normalisation along dim 1 (currently unused by forward).
        return torch.nn.functional.normalize(y_hat, p=2, dim=1)


# Define a simple GNN model with spectral convolution
class SimpleGNN(torch.nn.Module):
    """Graph-level classifier: one spectral convolution followed by a linear head.

    Args:
        in_channels: number of input node features.
        hidden_channels: size of the per-graph embedding from ``SpectralGCNConv``.
        out_channels: number of output logits per graph.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(SimpleGNN, self).__init__()
        self.spectral_conv = SpectralGCNConv(in_channels, hidden_channels)
        self.fc = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, batch):
        """Return ``(num_graphs, out_channels)`` logits for a PyG batch."""
        # spectral_conv already yields one row per graph.  Mean-pooling with one
        # group per row (as done before) is an identity op, so no pooling is needed.
        graph_embeddings = self.spectral_conv(batch)
        # Non-linearity so spectral_conv's final Linear and this head do not
        # collapse into a single affine map.
        return self.fc(torch.relu(graph_embeddings))

# Initialize model, optimizer, and loss function.
SEED = 0
LEARNING_RATE = 1e-3  # 0.1 is far above the usual Adam range and tends to make training diverge

torch.manual_seed(SEED)  # reproducible weight init and DataLoader shuffling
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hidden_layer_size = 64
model = SimpleGNN(dataset_train.num_node_features, hidden_layer_size, dataset_train.num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Peptides-func is multi-label: each of the 10 targets is an independent binary
# decision, so use sigmoid + binary cross-entropy rather than softmax cross-entropy.
criterion = torch.nn.BCEWithLogitsLoss()

# Training loop
def train(loader=None):
    """Run one optimisation epoch.

    Args:
        loader: DataLoader to iterate; defaults to the notebook's ``train_loader``.

    Returns:
        Mean training loss per graph over the epoch (0.0 for an empty loader).
    """
    loader = train_loader if loader is None else loader
    model.train()
    total_loss, total_graphs = 0.0, 0
    for data in loader:
        data = data.to(device)
        data.x = data.x.float()  # Ensure node features are float
        optimizer.zero_grad()
        out = model(data)
        # Targets must be float for probability / BCE-style losses.
        loss = criterion(out, data.y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * data.num_graphs
        total_graphs += data.num_graphs
    return total_loss / max(total_graphs, 1)

# Evaluation function for ROC-AUC score and Accuracy
def test(loader=None):
    """Evaluate the model on a multi-label split.

    Args:
        loader: DataLoader to evaluate; defaults to the notebook's ``test_loader``
            (pass ``val_loader`` for model selection).

    Returns:
        ``(roc_auc, accuracy)``.
        ``roc_auc`` is the macro-averaged ROC-AUC over the label columns.
        ``accuracy`` is the fraction of (graph, label) entries predicted correctly,
        thresholding logits at 0 (equivalently sigmoid at 0.5).
    """
    loader = test_loader if loader is None else loader
    model.eval()
    all_labels = []
    all_logits = []

    with torch.no_grad():
        for data in loader:
            data = data.to(device)
            data.x = data.x.float()  # Ensure node features are float
            out = model(data)
            all_labels.append(data.y.cpu().numpy())  # multi-hot targets
            all_logits.append(out.cpu().numpy())  # raw logits

    labels = np.concatenate(all_labels, axis=0)
    logits = np.concatenate(all_logits, axis=0)

    # Multilabel-indicator targets: sklearn averages per-label AUCs; the
    # `multi_class` option only applies to single-label multiclass input.
    roc_auc = roc_auc_score(labels, logits, average='macro')
    # Labels are multi-hot, not one-hot, so argmax-vs-argmax is ill-defined;
    # score each label's binary decision instead.
    accuracy = float(((logits > 0) == (labels > 0.5)).mean())

    return roc_auc, accuracy


In [None]:
# Train for a fixed number of epochs, reporting test metrics after each one.
NUM_EPOCHS = 10
for epoch in range(NUM_EPOCHS):
    train()
    roc_auc, accuracy = test()
    epoch_number = epoch + 1  # report epochs 1-based
    print(f"Epoch {epoch_number}, Test ROC-AUC: {roc_auc:.4f}, Accuracy: {accuracy:.4f}")