In [21]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt


In [22]:
from torch_geometric.data import Data
def get_karate_club_data():
    G = nx.karate_club_graph()
    num_nodes = G.number_of_nodes()
    x = torch.eye(num_nodes, dtype=torch.float)
    edge_index = torch.tensor(list(G.edges), dtype=torch.long).t().contiguous()
    edge_index = torch.cat([edge_index, edge_index[[1, 0]]], dim=1)  # undirected
 
    label_map = {'Mr. Hi': 0, 'Officer': 1}
    y = torch.tensor([label_map[G.nodes[i]['club']] for i in range(num_nodes)], dtype=torch.long)
 
    perm = torch.randperm(num_nodes)
    train_mask = torch.zeros(num_nodes, dtype=torch.bool)
    test_mask = torch.zeros(num_nodes, dtype=torch.bool)
    train_mask[perm[:20]] = True
    test_mask[perm[20:]] = True
 
    data = Data(x=x, edge_index=edge_index, y=y, train_mask=train_mask, test_mask=test_mask)
    return data

In [23]:
'''import torch
import networkx as nx
from torch.utils.data import Dataset

class KarateClubDataset(Dataset):
    def __init__(self):
        G = nx.karate_club_graph()
        self.num_nodes = G.number_of_nodes()
        self.x = torch.eye(self.num_nodes)  # One-hot encoding

        # Get edge list 
        edges = np.array(G.edges())
        edge_index = np.array(edges).T
        edge_index = np.concatenate((edge_index, edge_index[::-1]), axis=1)
        self.edge_index = torch.tensor(edge_index, dtype=torch.long)

        self.labels = torch.zeros(self.num_nodes, dtype=torch.long)
        for i in range(self.num_nodes):
            # Get the club label from networkx (0 for Mr. Hi's group, 1 for Officer's group)
            self.labels[i] = G.nodes[i]['club'] == 'Officer'

    def __len__(self):
        return 1  

    def __getitem__(self, idx):
        return self.x, self.edge_index'''


"import torch\nimport networkx as nx\nfrom torch.utils.data import Dataset\n\nclass KarateClubDataset(Dataset):\n    def __init__(self):\n        G = nx.karate_club_graph()\n        self.num_nodes = G.number_of_nodes()\n        self.x = torch.eye(self.num_nodes)  # One-hot encoding\n\n        # Get edge list \n        edges = np.array(G.edges())\n        edge_index = np.array(edges).T\n        edge_index = np.concatenate((edge_index, edge_index[::-1]), axis=1)\n        self.edge_index = torch.tensor(edge_index, dtype=torch.long)\n\n        self.labels = torch.zeros(self.num_nodes, dtype=torch.long)\n        for i in range(self.num_nodes):\n            # Get the club label from networkx (0 for Mr. Hi's group, 1 for Officer's group)\n            self.labels[i] = G.nodes[i]['club'] == 'Officer'\n\n    def __len__(self):\n        return 1  \n\n    def __getitem__(self, idx):\n        return self.x, self.edge_index"

In [24]:
dataset = KarateClubDataset()
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# For testing
for x, edge_index in loader:
    print("Node features shape:", x.shape)
    print("Edge index shape:", edge_index.shape)


NameError: name 'KarateClubDataset' is not defined

In [28]:
class SheafConv(torch.nn.Module):
    def __init__(self, num_features, hidden_channels, output_dim, stalk_dim=1):
        """
        SCN layer
        
            num_features: Number of input features
            hidden_channels: Number of hidden channels
            output_dim: Output dimension
            stalk_dim: Dimension k of stalks 
        """
        super().__init__()
        self.stalk_dim = stalk_dim
        
        # Linear transformation (Matrix A from equation)
        self.A0 = torch.nn.Parameter(torch.randn(num_features, hidden_channels))
        self.A1 = torch.nn.Parameter(torch.randn(hidden_channels, output_dim))
        
        # Stalkwise transformation (Matrix B from equation)
        if stalk_dim > 1:
            self.B0 = torch.nn.Parameter(torch.randn(stalk_dim, stalk_dim))
            self.B1 = torch.nn.Parameter(torch.randn(stalk_dim, stalk_dim))
        else:
            self.register_parameter('B0', None)
            self.register_parameter('B1', None)
    
    def construct_sheaf_laplacian(self, edge_indices, num_nodes):
        """
       sheaf Laplacian LF
        
            edge_indices: Tensor of shape [2, num_edges] with edge indices
            num_nodes: # Nodes in the graph
        """
        # standard graph Laplacian
        
        # adjacency matrix
        A = torch.zeros((num_nodes, num_nodes), device=edge_indices.device)
        A[edge_indices[0], edge_indices[1]] = 1
        
        # self-loops
        A += torch.eye(num_nodes, device=edge_indices.device)
        
        # degree matrix
        D = torch.diag(A.sum(1))
        
        # graph Laplacian
        L = D - A
        
        # If stalk dime > 1, block Laplacian
        if self.stalk_dim > 1:
            # block diagonal restriction maps (identity for this)
            restriction_maps = torch.eye(self.stalk_dim, device=edge_indices.device)
            
            # L to block Laplacian
            L_expanded = torch.zeros((num_nodes * self.stalk_dim, num_nodes * self.stalk_dim), 
                                     device=edge_indices.device)
            
            for i in range(num_nodes):
                for j in range(num_nodes):
                    if L[i, j] != 0:
                        # Place restriction map * L[i,j] in the corresponding block
                        block_i = i * self.stalk_dim
                        block_j = j * self.stalk_dim
                        L_expanded[block_i:block_i+self.stalk_dim, 
                                  block_j:block_j+self.stalk_dim] = L[i, j] * restriction_maps
            
            L = L_expanded
        
        # max degree for normalization
        d_max = A.sum(1).max().item()
        
        # diffusion operator DF = I - (1/d_max) * LF
        D_F = torch.eye(L.size(0), device=edge_indices.device) - (1.0 / d_max) * L
        
        return D_F
    
    def sheaf_diffusion(self, x, D_F):
        """
        Apply sheaf diffusion
    
            x: Node features tensor [num_nodes * stalk_dim, num_features] if stalk_dim > 1
               or [num_nodes, num_features] if stalk_dim = 1
            D_F: Sheaf diffusion operator
        """
        return D_F @ x
    
    def apply_stalkwise_linear(self, x, A, B, num_nodes):
        """
        stalkwise linear transformation (I ⊗ B)XA
        
            x: Features tensor
            A: Feature transformation matrix
            B: Stalk transformation matrix ( None if stalk_dim = 1)
            num_nodes: Number of nodes
        """
        if self.stalk_dim == 1:
            # apply A to features
            return x @ A
        else:
            # separate nodes and stalks
            x_reshaped = x.view(num_nodes, self.stalk_dim, -1)
            
            # Apply B (left multiplication)
            x_B = torch.bmm(B.unsqueeze(0).expand(num_nodes, -1, -1), x_reshaped)
            
            # Reshape back and apply A (right multiplication)
            x_flat = x_B.view(num_nodes * self.stalk_dim, -1)
            return x_flat @ A
    
    def forward(self, x, edge_index):
        num_nodes = x.size(0) // self.stalk_dim if self.stalk_dim > 1 else x.size(0)
        
        # sheaf diffusion operator
        D_F = self.construct_sheaf_laplacian(edge_index, num_nodes)
        
        # First layer: stalkwise linear transformation,sheaf diffusion, then ReLU
        h1 = self.apply_stalkwise_linear(x, self.A0, self.B0, num_nodes)
        h1_diffused = self.sheaf_diffusion(h1, D_F)
        h1_activated = h1_diffused.relu()
        
        # Second layer: stalkwise linear transformation, sheaf diffusion, then softmax
        h2 = self.apply_stalkwise_linear(h1_activated, self.A1, self.B1, num_nodes)
        h_out = self.sheaf_diffusion(h2, D_F)
        
        # If stalk_dim > 1, reshape to get [num_nodes, stalk_dim * output_dim]
        if self.stalk_dim > 1:
            h_out = h_out.view(num_nodes, self.stalk_dim, -1)
            h_out = h_out.reshape(num_nodes, -1)
        
        return h_out.softmax(dim=1)

In [29]:
import networkx as nx
data = get_karate_club_data()
model = SheafConv(num_features=data.num_node_features, hidden_channels=16, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()
 
for epoch in range(300):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
 
    loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
 
    model.eval()
    _, pred = out.max(dim=1)
    correct = int((pred[data.test_mask] == data.y[data.test_mask]).sum())
    acc = correct / int(data.test_mask.sum())

    if epoch % 10 == 0:
        print(f"Epoch {epoch:3d} | Loss: {loss.item():.4f} | Test Acc: {acc:.4f}")
        

Epoch   0 | Loss: 0.8928 | Test Acc: 0.2857
Epoch  10 | Loss: 0.7316 | Test Acc: 0.5000
Epoch  20 | Loss: 0.5778 | Test Acc: 0.6429
Epoch  30 | Loss: 0.4608 | Test Acc: 0.8571
Epoch  40 | Loss: 0.3805 | Test Acc: 0.8571
Epoch  50 | Loss: 0.3447 | Test Acc: 0.9286
Epoch  60 | Loss: 0.3294 | Test Acc: 0.9286
Epoch  70 | Loss: 0.3227 | Test Acc: 0.9286
Epoch  80 | Loss: 0.3194 | Test Acc: 0.9286
Epoch  90 | Loss: 0.3177 | Test Acc: 0.9286
Epoch 100 | Loss: 0.3167 | Test Acc: 0.9286
Epoch 110 | Loss: 0.3160 | Test Acc: 0.9286
Epoch 120 | Loss: 0.3156 | Test Acc: 0.9286
Epoch 130 | Loss: 0.3152 | Test Acc: 0.9286
Epoch 140 | Loss: 0.3149 | Test Acc: 0.9286
Epoch 150 | Loss: 0.3147 | Test Acc: 0.9286
Epoch 160 | Loss: 0.3146 | Test Acc: 0.9286
Epoch 170 | Loss: 0.3144 | Test Acc: 0.9286
Epoch 180 | Loss: 0.3143 | Test Acc: 0.9286
Epoch 190 | Loss: 0.3142 | Test Acc: 0.9286
Epoch 200 | Loss: 0.3141 | Test Acc: 0.9286
Epoch 210 | Loss: 0.3140 | Test Acc: 0.9286
Epoch 220 | Loss: 0.3139 | Test 