In [10]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m41.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [11]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch_geometric.nn import HypergraphConv, AttentionalAggregation


In [12]:
transform = T.Compose([
    T.ToTensor(),
    T.Normalize(
        mean=[0.4914, 0.4822, 0.4465],
        std=[0.2470, 0.2435, 0.2616]
    )
    ])
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

print(train_dataset.data.shape)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

(50000, 32, 32, 3)


# image->hypergraph

In [16]:
def image_to_dynamic_hypergraph(images, k_spatial=4, k_feature=4):
    batch_node_feats = []
    batch_edge_index = []
    batch_map = []
    node_offset = 0

    for b, img in enumerate(images):
        # img: [C,H,W] -> patches
        C, H, W = img.shape
        patch_size = 8  # 8x8 patches for CIFAR-10
        patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
        patches = patches.contiguous().view(C, -1, patch_size, patch_size)
        patches = patches.view(patches.size(1), -1)  # flatten each patch to vector

        num_nodes = patches.size(0)
        node_feats = patches

        # Spatial eges (static)
        spatial_edges = []
        coords = np.array([[i // (W // patch_size), i % (W // patch_size)] for i in range(num_nodes)])
        for i in range(num_nodes):
            dists = np.sum((coords - coords[i])**2, axis=1)
            nn_idx = np.argsort(dists)[1:k_spatial+1]
            for j in nn_idx:
                spatial_edges.append([i, j])

        # Feature hyperedges (dynamic, computed on GPU)
        feats = node_feats
        feats = feats.to(images.device)
        dists_feat = torch.cdist(feats.float(), feats.float(), p=2)
        feature_edges = []
        for i in range(num_nodes):
            nn_idx = torch.topk(dists_feat[i], k=k_feature+1, largest=False).indices[1:]
            for j in nn_idx:
                feature_edges.append([i, j])

        all_edges = torch.tensor(spatial_edges + feature_edges, dtype=torch.long, device=images.device).T
        batch_node_feats.append(node_feats)
        batch_edge_index.append(all_edges + node_offset)
        batch_map.append(torch.full((num_nodes,), b, dtype=torch.long, device=images.device))

        node_offset += num_nodes

    x = torch.cat(batch_node_feats, dim=0).float()
    edge_index = torch.cat(batch_edge_index, dim=1)
    batch_map = torch.cat(batch_map)
    return x, edge_index, batch_map

In [14]:
class HyperVigClassifier(nn.Module):
  def __init__(self, in_channels, hidden, num_classes):
    super().__init__()
    self.conv1 = HypergraphConv(in_channels, hidden)
    self.conv2 = HypergraphConv(hidden, hidden)
    self.conv3 = HypergraphConv(hidden, hidden)
    self.pool = AttentionalAggregation(gate_nn=nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(), nn.Linear(hidden, 1)))
    self.classifier = nn.Linear(hidden, num_classes)

  def forward(self, x, edge_index, batch_map):
    x = self.conv1(x, edge_index)
    x = F.relu(x)
    x = self.conv2(x, edge_index)
    x = F.relu(x)
    x = self.conv3(x, edge_index)
    x = F.relu(x)
    out = self.pool(x, batch_map) #attention pooling
    out = self.classifier(out)
    return out

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = HyperVigClassifier(in_channels=3*8*8, hidden=256, num_classes=10).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

for epoch in range(30):
  model.train()
  total_loss = 0
  correct = 0
  total = 0
  for images, labels in train_loader:
    images, labels = images.to(device), labels.to(device)
    node_feats, edge_index, batch_map = image_to_dynamic_hypergraph(images)
    node_feats, edge_index, batch_map = node_feats.to(device), edge_index.to(device), batch_map.to(device)
    optimizer.zero_grad()
    outputs = model(node_feats, edge_index, batch_map)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    total_loss += loss.item() * images.size(0)
    _, predicted = outputs.max(1)
    total += labels.size(0)
    correct += (predicted == labels).sum().item()

  print(f"Epoch {epoch+1}, Loss: {total_loss/total}, Accuracy: {correct/total}")

Epoch 1, Loss: 2.081924643859863, Accuracy: 0.23656
Epoch 2, Loss: 1.9771157432556152, Accuracy: 0.275
Epoch 3, Loss: 1.9334857147216797, Accuracy: 0.29718
Epoch 4, Loss: 1.8997729174041749, Accuracy: 0.30732
Epoch 5, Loss: 1.8751634169769287, Accuracy: 0.31582
Epoch 6, Loss: 1.856643161430359, Accuracy: 0.32562
