In [None]:

# Hypergraph over Pretrained Network Features (CIFAR-10 + HGNN)

# ===============================
# STEP 0: INSTALL DEPENDENCIES
# ===============================
!pip install torch torchvision torch-geometric dhg matplotlib --quiet

# ===============================
# STEP 1: IMPORTS
# ===============================
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import dhg
import dhg.nn.pyg as dhg_nn
import matplotlib.pyplot as plt
import networkx as nx

# ===============================
# STEP 2: SETUP DEVICE
# ===============================
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ===============================
# STEP 3: LOAD CIFAR10 & PREPROCESS
# ===============================
# The ResNet-18 checkpoint used below was trained on ImageNet images that were
# converted to [0, 1] tensors and then standardised with these per-channel
# statistics. Skipping the standardisation shifts every activation of the
# pretrained network and degrades the extracted features.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
NUM_SAMPLES = 500  # size of the quick-demo subset

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ResNet's native input resolution
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Download the full CIFAR-10 test split, but keep only the first NUM_SAMPLES
# images; shuffle=False makes the subset deterministic across runs.
test_set = CIFAR10(root="./data", train=False, transform=transform, download=True)
test_loader = DataLoader(test_set, batch_size=NUM_SAMPLES, shuffle=False)
images, labels = next(iter(test_loader))
labels = labels.to(device)

# ===============================
# STEP 4: PRETRAINED RESNET FEATURE EXTRACTOR
# ===============================
# `pretrained=True` is deprecated since torchvision 0.13. The `weights=` enum
# is the supported way to request the same ImageNet checkpoint.
resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
resnet.fc = nn.Identity()  # Remove final classification layer -> 512-d pooled embedding
resnet = resnet.to(device)
resnet.eval()  # eval-mode BatchNorm uses running stats, so chunking below gives the same features

FEATURE_BATCH_SIZE = 100
with torch.no_grad():
    # Push the images through the backbone in chunks. A single 500x3x224x224
    # forward pass keeps every intermediate activation of the whole batch
    # alive at once, which can exhaust GPU memory.
    features = torch.cat(
        [resnet(chunk.to(device)) for chunk in images.split(FEATURE_BATCH_SIZE)]
    )  # [num_images, 512]

# ===============================
# STEP 5: BUILD HYPERGRAPH (cosine kNN)
# ===============================
def build_incidence(features, k=5):
    """Build one cosine-kNN hyperedge per vertex.

    Hyperedge ``i`` contains vertex ``i`` itself plus the ``k`` other vertices
    whose feature rows have the highest cosine similarity to row ``i``. When
    fewer than ``k`` other vertices exist, all of them are included.

    Args:
        features: 2-D tensor or array-like of shape ``(n, d)``; row ``i`` is
            the feature vector of vertex ``i``. It may live on any device.
        k: Number of neighbours per hyperedge, excluding the vertex itself.
            Negative values are treated as 0.

    Returns:
        A list of ``n`` sets of plain Python ``int`` vertex indices. An empty
        input gives an empty list.
    """
    x = torch.as_tensor(features).detach().float()
    n = x.shape[0]
    if n == 0:
        return []
    # After row-normalisation the Gram matrix is the cosine-similarity matrix
    # (all-zero rows stay zero, matching sklearn's cosine_similarity).
    x = F.normalize(x, dim=1)
    sims = x @ x.T
    # Mask the diagonal and add vertex i explicitly. Taking the top k+1 and
    # hoping self ranks first breaks on duplicated rows or similarity ties.
    sims.fill_diagonal_(float("-inf"))
    num_neighbours = max(0, min(k, n - 1))
    neighbours = sims.topk(num_neighbours, dim=1).indices.tolist()
    return [{i, *row} for i, row in enumerate(neighbours)]

incidence_sets = build_incidence(features, k=5)  # one hyperedge per vertex
hg = dhg.Hypergraph(
    num_v=features.shape[0],
    e_list=incidence_sets,
)
hg = hg.to(device)

# ===============================
# STEP 6: HYPERGRAPH MODEL
# ===============================
class HGNN_Classifier(nn.Module):
    """Two-layer hypergraph convolutional network that scores each vertex.

    Layout: HGNNConv(in_dim -> hidden_dim) -> ReLU -> HGNNConv(hidden_dim -> out_dim).

    NOTE(review): in DeepHypergraph's ``dhg.nn.HGNNConv``, layers not built
    with ``is_last=True`` apply their own activation and dropout. If
    ``dhg_nn.HGNNConv`` behaves the same way, the output layer is being
    activated/dropped out. Confirm against the installed dhg version.
    """

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # Construction order is kept (hgc1 then hgc2) so parameter init draws match.
        self.hgc1 = dhg_nn.HGNNConv(in_dim, hidden_dim)
        self.hgc2 = dhg_nn.HGNNConv(hidden_dim, out_dim)

    def forward(self, x, hg):
        """Return ``out_dim`` scores per vertex.

        Args:
            x: Vertex feature matrix, presumably ``(num_v, in_dim)``; TODO confirm.
            hg: Hypergraph structure passed unchanged to both convolutions.
        """
        hidden = F.relu(self.hgc1(x, hg))
        return self.hgc2(hidden, hg)

# 512-d ResNet-18 embeddings -> 256 hidden units -> 10 CIFAR-10 classes.
model = HGNN_Classifier(in_dim=512, hidden_dim=256, out_dim=10).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Inputs are fixed during training; detaching guarantees no autograd link
# back into the feature extractor.
labels = labels.detach()
features = features.detach()

# ===============================
# STEP 7: TRAINING LOOP
# ===============================
# Transductive setup: every vertex stays in the hypergraph, but only a random
# 80% of the labels drive the loss. The remaining labels are held out, so the
# printed test accuracy measures generalisation rather than fit to the very
# labels being optimised.
epochs = 30
train_fraction = 0.8
num_labeled = labels.shape[0]
split_gen = torch.Generator().manual_seed(0)  # fixed split -> reproducible runs
perm = torch.randperm(num_labeled, generator=split_gen).to(labels.device)
num_train = int(train_fraction * num_labeled)
train_idx, test_idx = perm[:num_train], perm[num_train:]

for epoch in range(epochs):
    model.train()
    out = model(features, hg)
    loss = loss_fn(out[train_idx], labels[train_idx])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # `out` was computed before optimizer.step(), so re-score with the updated
    # weights in eval mode (disables any dropout) and without tracking grads.
    model.eval()
    with torch.no_grad():
        pred = model(features, hg).argmax(dim=1)
    train_acc = (pred[train_idx] == labels[train_idx]).float().mean().item()
    test_acc = (pred[test_idx] == labels[test_idx]).float().mean().item()

    print(
        f"Epoch {epoch+1:02d} | Loss: {loss.item():.4f} | "
        f"Train Acc: {train_acc*100:.2f}% | Test Acc: {test_acc*100:.2f}%"
    )

# ===============================
# STEP 8: VISUALIZE HYPERGRAPH STRUCTURE
# ===============================
def plot_hypergraph(hg, num_nodes=30):
    """Draw the first few hyperedges as a bipartite (star-expansion) graph.

    Each hyperedge ``i`` becomes a node ``e{i}`` linked to every vertex node
    ``v{j}`` it contains. Hyperedge nodes and vertex nodes get different
    colours so the two sides of the bipartite drawing are distinguishable.

    Args:
        hg: Hypergraph whose ``e_list`` attribute is read as a sequence of
            vertex-index collections. NOTE(review): confirm the installed
            ``dhg`` version exposes ``e_list``.
        num_nodes: Maximum number of *hyperedges* drawn (the first
            ``num_nodes`` entries of ``hg.e_list``), despite its name. How
            many vertex nodes appear depends on which vertices those
            hyperedges touch. The name is kept for backward compatibility.
    """
    G = nx.Graph()
    edge_names = set()
    for i, e in enumerate(hg.e_list[:num_nodes]):  # Only visualize subset
        edge_name = f"e{i}"
        G.add_node(edge_name)  # keeps empty hyperedges visible too
        edge_names.add(edge_name)
        for node in e:
            G.add_edge(edge_name, f"v{node}")

    # nx.draw colours nodes in G.nodes order, so build the list in that order.
    node_colors = ['salmon' if name in edge_names else 'skyblue' for name in G.nodes]

    plt.figure(figsize=(10, 6))
    nx.draw(G, with_labels=True, node_color=node_colors, node_size=500, edge_color='gray')
    plt.title("Visualizing Hypergraph Structure (Edges e* connected to Nodes v*)")
    plt.show()


plot_hypergraph(hg, num_nodes=20)
