In [3]:
!pip install torch-geometric



In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
import torchvision.transforms as T
import numpy as np
import scipy.sparse
from torch_geometric.utils import dense_to_sparse
from torch_geometric.nn import HypergraphConv, AttentionalAggregation
from torch.utils.data import DataLoader
import networkx as nx
import torch_geometric.utils as pyg_utils
from networkx.algorithms.community import greedy_modularity_communities

dataset prep

In [17]:
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
    ])
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

print(train_dataset.data.shape)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

(50000, 32, 32, 3)


convert PyG edge_index to networkx graph & compute graph metrics

In [18]:
def edge_index_to_nx(edge_index):
  edge_index = edge_index.cpu().numpy()
  G = nx.Graph()
  for i, j in edge_index.T:
    G.add_edge(int(i), int(j))
  return G

def graph_metrics(edge_index):
    G = edge_index_to_nx(edge_index)

    # Number of nodes & edges
    num_nodes = G.number_of_nodes()
    num_edges = G.number_of_edges()

    # Avg node degree, min degree, max degree
    degrees = [d for n, d in G.degree()]
    avg_degree = sum(degrees)/num_nodes
    max_degree = max(degrees)
    min_degree = min(degrees)

    # Clustering coefficient
    clustering_coeff = nx.average_clustering(G)

    # Path-based metrics
    connected_components = [G.subgraph(c).copy() for c in nx.connected_components(G)]

    # Largest connected component (LCC)
    lcc = max(connected_components, key=lambda x: x.number_of_nodes())
    lcc_metrics = {
        'LCC_avg_path_length': nx.average_shortest_path_length(lcc),
        'LCC_diameter': nx.diameter(lcc),
        'LCC_radius': nx.radius(lcc),
        'LCC_eccentricity': nx.eccentricity(lcc),
        'LCC_center': nx.center(lcc)
    }

   # Average across all connected components (AvgCC)
    avg_path_lengths = []
    diameters = []
    radii = []
    for c in connected_components:
        avg_path_lengths.append(nx.average_shortest_path_length(c))
        diameters.append(nx.diameter(c))
        radii.append(nx.radius(c))
    avg_cc_metrics = {
        'avg_path_length': sum(avg_path_lengths)/len(avg_path_lengths),
        'avg_diameter': sum(diameters)/len(diameters),
        'avg_radius': sum(radii)/len(radii)
    }

    # Centrality measures
    betweenness_centrality = nx.betweenness_centrality(G)  # compute on full graph
    closeness_centrality = nx.closeness_centrality(G)      # compute on full graph

    # Modularity
    try:
        communities = list(greedy_modularity_communities(G))
        modularity = nx.algorithms.community.modularity(G, communities)
    except Exception as e:
        modularity = None
        print("Could not compute modularity:", e)

    # Combine all metrics
    metrics = {
        'num_nodes': num_nodes,
        'num_edges': num_edges,
        'min_degree': min_degree,
        'max_degree': max_degree,
        'betweenness_centrality': betweenness_centrality,
        'closeness_centrality': closeness_centrality,
        'modularity': modularity,

        'avg_degree': avg_degree,
        'avg_clustering_coefficient': clustering_coeff
    }
    metrics.update(avg_cc_metrics)
    metrics.update(lcc_metrics)

    return metrics

def print_graph_metrics(metrics):
    for key, value in metrics.items():
        print(f"{key}: {value}")



image -> hypergraph

In [19]:
def image_to_hypergraph(images, k_spatial=4, k_feature=4):
    """
    Convert a batch of images to hypergraph format.
    Combines spatial + feature kNN hyperedges.
    """
    batch_node_feats = []
    batch_edge_index = []
    batch_map = []
    node_offset = 0

    for b, img in enumerate(images):
        # img: [C,H,W] -> patches
        C, H, W = img.shape
        patch_size = 8  # 8x8 patches for CIFAR-10
        patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        patches = patches.contiguous().view(C, -1, patch_size, patch_size)
        patches = patches.view(patches.size(1), -1)  # flatten each patch to vector

        num_nodes = patches.size(0)
        node_feats = patches

        # Spatial hyperedges: connect each patch to k nearest spatial neighbors
        spatial_edges = []
        coords = np.array([[i // (W // patch_size), i % (W // patch_size)] for i in range(num_nodes)])
        for i in range(num_nodes):
            dists = np.sum((coords - coords[i])**2, axis=1)
            nn_idx = np.argsort(dists)[1:k_spatial+1]
            for j in nn_idx:
                spatial_edges.append([i, j])

        # Feature hyperedges: connect each patch to k nearest neighbors in feature space
        feats = node_feats.cpu().numpy()
        feature_edges = []
        for i in range(num_nodes):
            dists = np.sum((feats - feats[i])**2, axis=1)
            nn_idx = np.argsort(dists)[1:k_feature+1]
            for j in nn_idx:
                feature_edges.append([i, j])

        all_edges = np.array(spatial_edges + feature_edges).T
        edge_index = torch.tensor(all_edges, dtype=torch.long)

        batch_node_feats.append(node_feats)
        batch_edge_index.append(edge_index + node_offset)
        batch_map.append(torch.full((num_nodes,), b, dtype=torch.long))

        node_offset += num_nodes

    x = torch.cat(batch_node_feats, dim=0).float()
    edge_index = torch.cat(batch_edge_index, dim=1)
    batch_map = torch.cat(batch_map)
    return x, edge_index, batch_map

In [20]:
class HyperVigClassifier(nn.Module):
  def __init__(self, in_channels, hidden, num_classes):
    super().__init__()
    self.conv1 = HypergraphConv(in_channels, hidden)
    self.conv2 = HypergraphConv(hidden, hidden)
    self.conv3 = HypergraphConv(hidden, hidden)
    self.pool = AttentionalAggregation(gate_nn=nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)))
    self.classifier = nn.Linear(hidden, num_classes)

  def forward(self, x, edge_index, batch_map):
    x = self.conv1(x, edge_index)
    x = F.relu(x)
    x = self.conv2(x, edge_index)
    x = F.relu(x)
    x = self.conv3(x, edge_index)
    x = F.relu(x)
    out = self.pool(x, batch_map) #attention pooling
    out = self.classifier(out)
    return out

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HyperVigClassifier(in_channels=3*8*8, hidden=256, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(30):
  model.train()
  total_loss = 0
  correct = 0
  total = 0
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    node_feats, edge_index, batch_map = image_to_hypergraph(images)
    node_feats, edge_index, batch_map = node_feats.to(device), edge_index.to(device), batch_map.to(device)
    optimizer.zero_grad()
    outputs = model(node_feats, edge_index, batch_map)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item() * images.size(0)
    _, predicted = outputs.max(1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  print(f"Epoch {epoch+1}, Loss: {total_loss/total}, Accuracy: {correct/total}")

Epoch 1, Loss: 2.088416355819702, Accuracy: 0.23224
Epoch 2, Loss: 1.9834194818878175, Accuracy: 0.27412
Epoch 3, Loss: 1.9389868724441528, Accuracy: 0.29322
Epoch 4, Loss: 1.910102783355713, Accuracy: 0.30436
Epoch 5, Loss: 1.8843680082702636, Accuracy: 0.31322
Epoch 6, Loss: 1.8610284884643555, Accuracy: 0.32314
Epoch 7, Loss: 1.8422677359771729, Accuracy: 0.32924
Epoch 8, Loss: 1.8205385348129273, Accuracy: 0.3364
Epoch 9, Loss: 1.801078766975403, Accuracy: 0.34776
Epoch 10, Loss: 1.7878248051071166, Accuracy: 0.35226
Epoch 11, Loss: 1.7628574675750732, Accuracy: 0.36102
Epoch 12, Loss: 1.7463323606872558, Accuracy: 0.3656
Epoch 13, Loss: 1.7256840154266357, Accuracy: 0.37608
Epoch 14, Loss: 1.7043174792671203, Accuracy: 0.38514
Epoch 15, Loss: 1.689911383934021, Accuracy: 0.38966
Epoch 16, Loss: 1.6714112515449524, Accuracy: 0.39566
Epoch 17, Loss: 1.649961003074646, Accuracy: 0.40186
Epoch 18, Loss: 1.6301556049537658, Accuracy: 0.4124
Epoch 19, Loss: 1.610771968421936, Accuracy: 

**Metrics:**

LCC - Largest Connected Component

**Averaged metrics are averaged over all connected components

In [81]:
metrics = graph_metrics(edge_index)
print_graph_metrics(metrics)

num_nodes: 256
num_edges: 969
min_degree: 6
max_degree: 15
betweenness_centrality: {0: 7.71962328238382e-06, 3: 0.0005257921191223646, 7: 0.0005113821556619148, 11: 0.0007163810406052185, 12: 4.717547561456779e-05, 1: 0.0, 2: 4.734702279862076e-05, 13: 0.0001957353370044431, 4: 1.0292831043178425e-05, 6: 1.0292831043178425e-05, 5: 2.916302128900554e-05, 8: 1.8012454325562245e-05, 9: 1.3895321908290876e-05, 14: 3.259396497006501e-05, 15: 1.8870190245827115e-05, 10: 7.71962328238382e-06, 16: 0.0, 17: 0.0, 18: 0.00022644228294992527, 19: 0.00022644228294992527, 20: 0.00022644228294992527, 24: 0.001595388811692656, 21: 2.058566208635685e-05, 25: 0.0, 26: 0.00010880992817074336, 27: 0.00018674136321195143, 30: 0.0005175823610284009, 22: 0.00023305910290625435, 23: 0.00025879118051420045, 28: 0.0002220310696457061, 29: 9.263547938860583e-05, 31: 0.0004389157237698229, 32: 0.0, 34: 0.0013586536976995522, 36: 0.0, 40: 0.0, 44: 0.0, 33: 0.0005038340795635838, 35: 0.0015516442797591477, 37: 0.00