In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import torch.nn as nn
import torch.nn.functional as F
from scipy.sparse import coo_matrix

In [2]:
class SheafKarateClubDataset(Dataset):
    def __init__(self, stalk_dim=2):
        # original graph
        G = nx.karate_club_graph()
        self.num_nodes = G.number_of_nodes()
        self.stalk_dim = stalk_dim  # Dimensionality of the stalks (k in the formula)
        
        # edge list and adjacency matrix
        edges = list(G.edges())
        self.num_edges = len(edges)
        self.edge_index = torch.tensor([[u, v] for u, v in edges], dtype=torch.long).t()
        
        # Create adjacency matrix
        adj = nx.adjacency_matrix(G).todense()
        self.adj = torch.tensor(adj, dtype=torch.float)
        
        # Compute node degrees
        degrees = torch.tensor([G.degree(i) for i in range(self.num_nodes)], dtype=torch.float)
        self.max_degree = torch.max(degrees).item()
        
        node_features = torch.eye(self.num_nodes)  # One-hot encoding
        
        # sheaf features: X matrix will have dimensions (num_nodes * stalk_dim, num_features)
        self.x = torch.zeros((self.num_nodes * self.stalk_dim, self.num_nodes))
        for i in range(self.num_nodes):
            for j in range(self.stalk_dim):
                self.x[i * self.stalk_dim + j] = node_features[i]
        
        # Construct the sheaf Laplacian
        # block diagonal structure for the restriction maps
        self.restriction_maps = {}
        
        # Identity 
        identity = torch.eye(stalk_dim)
        
        # For each edge, create a restriction map
        for idx, (u, v) in enumerate(edges):
            #  restriction maps that are either identity or rotations
            angle = torch.rand(1).item() * 2 * np.pi
            rotation = torch.tensor([
                [np.cos(angle), -np.sin(angle)],
                [np.sin(angle), np.cos(angle)]
            ], dtype=torch.float)
            
            # Store restriction maps
            self.restriction_maps[(u, v)] = rotation
            self.restriction_maps[(v, u)] = rotation.transpose(0, 1)  # Transpose for opposite direction
        
        # sheaf Laplacian L_F
        # L_F[i*k:(i+1)*k, j*k:(j+1)*k] = degree[i] * Identity if i=j
        # L_F[i*k:(i+1)*k, j*k:(j+1)*k] = -restriction_maps[(i,j)] if (i,j) is an edge
        self.L_F = torch.zeros((self.num_nodes * stalk_dim, self.num_nodes * stalk_dim))
        
        # Diagonal blocks (degree * identity)
        for i in range(self.num_nodes):
            degree_i = degrees[i].item()
            self.L_F[i*stalk_dim:(i+1)*stalk_dim, i*stalk_dim:(i+1)*stalk_dim] = degree_i * identity
        
        # -restriction_maps
        for (u, v), restriction_map in self.restriction_maps.items():
            self.L_F[u*stalk_dim:(u+1)*stalk_dim, v*stalk_dim:(v+1)*stalk_dim] = -restriction_map
        
        # Normalize the Laplacian get diffusion operator D_F = I - (1/dmax) * L_F
        self.D_F = torch.eye(self.num_nodes * stalk_dim) - (1.0 / self.max_degree) * self.L_F
        
        # Get real labels
        self.labels = torch.zeros(self.num_nodes, dtype=torch.long)
        for i in range(self.num_nodes):
            self.labels[i] = G.nodes[i]['club'] == 'Officer'
    
    def __len__(self):
        return 1
    
    def __getitem__(self, idx):
        return self.x, self.D_F, self.labels, self.stalk_dim

In [3]:
dataset = SheafKarateClubDataset(stalk_dim=2)
loader = DataLoader(dataset, batch_size=1, shuffle=False)

In [4]:
for x, D_F, labels, stalk_dim in loader:
    print("Sheaf features shape:", x.shape)
    print("Diffusion operator shape:", D_F.shape)
    print("Labels shape:", labels.shape)
    print("Stalk dimension:", stalk_dim.item())
    print("Label distribution:", labels.squeeze().bincount())

Sheaf features shape: torch.Size([1, 68, 34])
Diffusion operator shape: torch.Size([1, 68, 68])
Labels shape: torch.Size([1, 34])
Stalk dimension: 2
Label distribution: tensor([17, 17])


In [14]:
class SheafConv(nn.Module):
    def __init__(self, in_features, out_features, stalk_dim):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.stalk_dim = stalk_dim
        
        # Learnable parameters:
        # A: Linear transformation for features (Ninfeat × Noutfeat matrix)
        self.A = nn.Parameter(torch.Tensor(in_features, out_features))
        
        # B: Stalk-wise transformation (k × k matrix)
        self.B = nn.Parameter(torch.Tensor(stalk_dim, stalk_dim))
        
        self.reset_parameters()
    
    def reset_parameters(self):
        nn.init.xavier_uniform_(self.A)
        nn.init.xavier_uniform_(self.B)
    
    def forward(self, X, D_F):
       
        num_nodes = X.size(0) // self.stalk_dim
        
        #  (I ⊗ B)X
        # Reshape X to separate nodes and stalk dimensions
        X_reshaped = X.view(num_nodes, self.stalk_dim, self.in_features)
        
        # Apply B to each node's stalk
        #  multiplying by the Kronecker product I ⊗ B
        X_B = torch.bmm(self.B.unsqueeze(0).expand(num_nodes, -1, -1), 
                        X_reshaped)
        
        # Reshape to original dimensions
        X_B = X_B.view(num_nodes * self.stalk_dim, self.in_features)
        
        #   X_B·A
        X_BA = torch.mm(X_B, self.A)
        
        # diffusion operator D_F·(X_BA), D_F = I -(1/dmax)*L_F 
        X_DBA = torch.mm(D_F, X_BA)
        
        # nonlinearity ρ
        # ReLU as nonlinearity
        output = F.relu(X_DBA)
        
        return output

In [6]:
class SheafNN(nn.Module):
    def __init__(self, in_features, hidden_dim, out_features, stalk_dim, num_nodes):
        super().__init__()
        self.stalk_dim = stalk_dim
        self.num_nodes = num_nodes
        
        # First sheaf convolutional layer
        self.sheaf_conv1 = SheafConv(in_features, hidden_dim, stalk_dim)
        
        # Second sheaf convolutional layer
        self.sheaf_conv2 = SheafConv(hidden_dim, hidden_dim, stalk_dim)
        
        # Output layer (stalk_dim * hidden_dim -> out_features)
        self.fc = nn.Linear(hidden_dim, out_features)
    
    def forward(self, x, D_F):
        # Apply first sheaf convolution
        x = self.sheaf_conv1(x, D_F)
        
        # Apply second sheaf convolution
        x = self.sheaf_conv2(x, D_F)
        
        # Pool features across stalks to get node representations
        # Reshape to separate nodes and stalk dimensions
        x = x.view(self.num_nodes, self.stalk_dim, -1)
        
        # Average across stalk dimensions
        x = torch.mean(x, dim=1)
        
        # App final linear layer for classification
        x = self.fc(x)
        
        # softmax 
        return F.softmax(x, dim=1)

In [7]:
for x, D_F, labels, stalk_dim in loader:
    x = x.squeeze(0)
    D_F = D_F.squeeze(0)
    labels = labels.squeeze(0)
    stalk_dim = stalk_dim.item()
    
    num_nodes = labels.size(0)
    in_features = x.size(1)  # Number of input features
    
    model = SheafNN(in_features=in_features, 
                    hidden_dim=16, 
                    out_features=2, 
                    stalk_dim=stalk_dim,
                    num_nodes=num_nodes)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

In [10]:
loss_history = []
accuracy_history = []
    
for epoch in range(200):
        # Forward pass
        model.train()
        optimizer.zero_grad()
        out = model(x, D_F)
        loss = loss_fn(out, labels)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Calculate accuracy
        _, predicted = torch.max(out.data, 1)
        correct = (predicted == labels).sum().item()
        accuracy = correct / labels.size(0)
        
       
        loss_history.append(loss.item())
        accuracy_history.append(accuracy)
        
        # Print 
        if epoch % 5 == 0:
            print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Accuracy: {accuracy:.4f}")

Epoch   0 | Loss: 0.6945 | Accuracy: 0.5000
Epoch   5 | Loss: 0.6608 | Accuracy: 0.6765
Epoch  10 | Loss: 0.5971 | Accuracy: 0.9118
Epoch  15 | Loss: 0.5004 | Accuracy: 0.9706
Epoch  20 | Loss: 0.4132 | Accuracy: 0.9706
Epoch  25 | Loss: 0.3705 | Accuracy: 0.9706
Epoch  30 | Loss: 0.3528 | Accuracy: 0.9706
Epoch  35 | Loss: 0.3423 | Accuracy: 0.9706
Epoch  40 | Loss: 0.3335 | Accuracy: 1.0000
Epoch  45 | Loss: 0.3276 | Accuracy: 1.0000
Epoch  50 | Loss: 0.3229 | Accuracy: 1.0000
Epoch  55 | Loss: 0.3193 | Accuracy: 1.0000
Epoch  60 | Loss: 0.3168 | Accuracy: 1.0000
Epoch  65 | Loss: 0.3154 | Accuracy: 1.0000
Epoch  70 | Loss: 0.3146 | Accuracy: 1.0000
Epoch  75 | Loss: 0.3141 | Accuracy: 1.0000
Epoch  80 | Loss: 0.3139 | Accuracy: 1.0000
Epoch  85 | Loss: 0.3137 | Accuracy: 1.0000
Epoch  90 | Loss: 0.3136 | Accuracy: 1.0000
Epoch  95 | Loss: 0.3136 | Accuracy: 1.0000
Epoch 100 | Loss: 0.3135 | Accuracy: 1.0000
Epoch 105 | Loss: 0.3135 | Accuracy: 1.0000
Epoch 110 | Loss: 0.3135 | Accur