In [1]:
!pip install torch-geometric

Collecting torch-geometric
  Downloading torch_geometric-2.7.0-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.7.0-py3-none-any.whl (1.3 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.3 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m38.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch-geometric
Successfully installed torch-geometric-2.7.0


In [2]:
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as T
import torch.nn.functional as F
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from torch_geometric.nn import HypergraphConv, AttentionalAggregation



In [3]:
# Per-channel normalisation constants (presumably CIFAR-10 training-set
# mean / std per RGB channel).
CIFAR10_MEAN = [0.4914, 0.4822, 0.4465]
CIFAR10_STD = [0.2470, 0.2435, 0.2616]

transform = T.Compose([
    T.ToTensor(),
    T.Normalize(mean=CIFAR10_MEAN, std=CIFAR10_STD),
])

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)

# Raw training array layout: (num_images, height, width, channels).
print(train_dataset.data.shape)

# Settings shared by both loaders; only the training split is shuffled.
loader_kwargs = dict(batch_size=128, num_workers=2)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

100%|██████████| 170M/170M [00:03<00:00, 48.3MB/s]


(50000, 32, 32, 3)


# Image → hypergraph construction

In [4]:
def image_to_adaptive_hypergraph(images, k_spatial, k_feature=None, patch_size=8):
    """
    Convert a batch of images into one disjoint-union hypergraph of patches.

    Every image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches in row-major order; each patch is a node whose feature vector is
    its pixels flattened channel-first. For every node one *spatial*
    hyperedge is built (the node plus its ``k_spatial`` nearest patches on the
    patch grid) and, when ``k_feature`` is a positive int, one *feature*
    hyperedge (the node plus its ``k_feature`` nearest nodes in pixel space).
    If ``k_feature`` is None or <= 0 only spatial hyperedges are built.

    Args:
        images: [B, C, H, W] image tensor.
        k_spatial: neighbours per spatial hyperedge (clamped to num_nodes - 1).
        k_feature: neighbours per feature hyperedge, or None to skip them.
        patch_size: side of the square patches (default 8).

    Returns:
        x: [B * N, C * patch_size**2] float node features.
        hyperedge_index: [2, num_incidences] long tensor; row 0 holds global
            node ids, row 1 global hyperedge ids. Per image, the N spatial
            hyperedges come first, then (optionally) the N feature ones.
        batch_map: [B * N] long tensor giving each node's image index.
    """
    device = images.device
    batch_size, _, height, width = images.shape
    num_patches_side = width // patch_size
    num_nodes = (height // patch_size) * num_patches_side

    # [B, C, H, W] -> [B, C, rows, cols, p, p] -> [B, rows, cols, C, p, p]
    # -> [B, N, C*p*p]: row-major patches, pixels flattened channel-first.
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    node_feats = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch_size, num_nodes, -1)

    # (row, col) position of each patch on the grid; identical for all images.
    idx = torch.arange(num_nodes, device=device)
    coords = torch.stack([idx // num_patches_side, idx % num_patches_side], dim=1).float()

    groups = [(torch.cdist(coords, coords, p=2).expand(batch_size, -1, -1), k_spatial)]
    if k_feature is not None and k_feature > 0:
        groups.append((torch.cdist(node_feats, node_feats, p=2), k_feature))

    node_base = (torch.arange(batch_size, device=device) * num_nodes).view(-1, 1, 1)
    node_chunks, edge_chunks = [], []
    edges_per_image = 0
    for dists, k in groups:
        # One hyperedge per centre node: the centre plus its k nearest nodes
        # (the centre sits at distance 0, so it is normally selected).
        # Clamp so k >= num_nodes does not make topk fail.
        k_eff = min(k + 1, num_nodes)
        nn_idx = torch.topk(dists, k=k_eff, dim=-1, largest=False).indices  # [B, N, k_eff]
        node_chunks.append((nn_idx + node_base).reshape(batch_size, -1))
        centre_ids = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
        edge_chunks.append((centre_ids + edges_per_image).expand(batch_size, -1))
        edges_per_image += num_nodes
    edge_base = (torch.arange(batch_size, device=device) * edges_per_image).view(-1, 1)

    x = node_feats.reshape(batch_size * num_nodes, -1).float()
    hyperedge_index = torch.stack([
        torch.cat(node_chunks, dim=1).reshape(-1),
        (torch.cat(edge_chunks, dim=1) + edge_base).reshape(-1),
    ])
    batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

    return x, hyperedge_index, batch_map


class AdaptiveHypergraphBlock(nn.Module):
    """
    Residual hypergraph layer: a HypergraphConv sub-layer, then an FFN sub-layer.

    Each sub-layer's output is LayerNorm-ed, dropped out and added back to its
    input. The hypergraph (edge_index / edge_weight) is supplied on every call,
    so the same block works with whatever neighbourhood size the caller built.
    """
    def __init__(self, hidden_dim, dropout=0.3, ffn_expansion=2):
        super().__init__()
        ffn_width = hidden_dim * ffn_expansion

        # Convolution sub-layer. Registration order is kept as-is so seeded
        # initialisation and the state_dict layout stay the same.
        self.conv = HypergraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-5)

    def forward(self, x, edge_index, edge_weight):
        """Apply the conv and FFN sub-layers, each wrapped in a residual add."""
        conv_branch = F.relu(self.norm1(self.conv(x, edge_index, edge_weight)))
        x = x + self.dropout(conv_branch)
        ffn_branch = self.norm2(self.ffn(x))
        return x + self.dropout(ffn_branch)


class AdaptiveHyperVigClassifier(nn.Module):
    """
    Patch-hypergraph image classifier whose hypergraph is rebuilt before every layer.

    Images are cut into patches by ``image_to_adaptive_hypergraph`` (one node
    per patch), projected to ``hidden`` dims, and passed through one
    ``AdaptiveHypergraphBlock`` per entry of ``k_schedule``. Before each block
    the hypergraph is rebuilt from the current node embeddings: spatial kNN
    hyperedges (size ``k + 1``) in every layer, plus feature-space kNN
    hyperedges of the same size from the third layer on. A per-layer
    ``MultiHeadHyperedgeAttention`` supplies the hyperedge weights. Node
    embeddings are pooled per image with ``AttentionalAggregation`` and
    classified by a two-layer MLP.

    Args:
        in_channels: length of a flattened patch (channels * patch_size**2).
        hidden: node embedding width.
        num_classes: number of output logits.
        k_schedule: neighbourhood size per layer; its length sets the depth.
        dropout: dropout rate inside the hypergraph blocks.
    """
    def __init__(self, in_channels, hidden, num_classes, k_schedule=(12, 10, 8, 6, 4, 4), dropout=0.3):
        super().__init__()
        self.hidden = hidden
        # Store a private list copy: avoids aliasing the caller's list and the
        # shared-mutable-default pitfall of a list literal default.
        self.k_schedule = list(k_schedule)

        # Project flattened patch pixels to the hidden width.
        self.input_proj = nn.Linear(in_channels, hidden)

        # One multi-head hyperedge-attention module per layer.
        self.edge_attns = nn.ModuleList([
            MultiHeadHyperedgeAttention(in_dim=hidden, hidden=64, num_heads=8)
            for _ in self.k_schedule
        ])

        # One hypergraph block per layer.
        self.blocks = nn.ModuleList([
            AdaptiveHypergraphBlock(hidden, dropout=dropout)
            for _ in self.k_schedule
        ])

        # Attentional pooling of node embeddings, grouped by batch_map.
        self.pool = AttentionalAggregation(
            gate_nn=nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        )

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for every nn.Linear (zero bias); LayerNorms reset to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, images, batch_map_input=None):
        """
        Args:
            images: [B, C, H, W] image batch.
            batch_map_input: unused; kept for backward compatibility (the
                node-to-image map is recomputed from ``images``).

        Returns:
            [B, num_classes] logits.
        """
        batch_size = images.size(0)

        # Only node features and the node->image map are needed here (the
        # hypergraph is rebuilt per layer below), so request the cheapest graph.
        x_init, _, batch_map = image_to_adaptive_hypergraph(images, k_spatial=0)
        x = self.input_proj(x_init)
        patches_per_image = x.size(0) // batch_size

        for layer_idx, (k_val, block, edge_attn) in enumerate(
            zip(self.k_schedule, self.blocks, self.edge_attns)
        ):
            # Feature-space hyperedges only from the third layer on, once the
            # embeddings carry more than raw pixel information.
            k_feature = k_val if layer_idx >= 2 else None

            _, edge_index, _ = self.build_hypergraph_from_features(
                x.view(batch_size, patches_per_image, -1),
                k_spatial=k_val,
                k_feature=k_feature,
            )
            edge_weight = edge_attn(x, edge_index)
            x = block(x, edge_index, edge_weight)

        out = self.pool(x, batch_map)
        return self.classifier(out)

    def build_hypergraph_from_features(self, x_batched, k_spatial, k_feature=None, num_patches_side=None):
        """
        Build a kNN hypergraph over a batch of per-image node sets.

        Args:
            x_batched: [B, N, D] node features, nodes in row-major patch order.
            k_spatial: neighbours per spatial hyperedge (clamped to N - 1).
            k_feature: neighbours per feature hyperedge; None or <= 0 disables
                feature hyperedges.
            num_patches_side: width of the patch grid; defaults to isqrt(N),
                i.e. a square grid.

        Returns:
            (x_flat [B*N, D], edge_index [2, M], batch_map [B*N]). Row 0 of
            edge_index holds global node ids, row 1 global hyperedge ids; per
            image the N spatial hyperedges come first, then the N feature ones.
        """
        import math

        batch_size, num_nodes, feat_dim = x_batched.shape
        device = x_batched.device
        if num_patches_side is None:
            num_patches_side = math.isqrt(num_nodes)

        idx = torch.arange(num_nodes, device=device)
        coords = torch.stack([idx // num_patches_side, idx % num_patches_side], dim=1).float()
        groups = [(torch.cdist(coords, coords, p=2).expand(batch_size, -1, -1), k_spatial)]
        if k_feature is not None and k_feature > 0:
            # topk indices carry no gradient, so keep the O(B*N^2) distance
            # computation out of the autograd graph.
            x_detached = x_batched.detach()
            groups.append((torch.cdist(x_detached, x_detached, p=2), k_feature))

        node_base = (torch.arange(batch_size, device=device) * num_nodes).view(-1, 1, 1)
        node_chunks, edge_chunks = [], []
        edges_per_image = 0
        for dists, k in groups:
            # Centre + its k nearest nodes; clamp so topk never asks for more
            # than N elements.
            k_eff = min(k + 1, num_nodes)
            nn_idx = torch.topk(dists, k=k_eff, dim=-1, largest=False).indices  # [B, N, k_eff]
            node_chunks.append((nn_idx + node_base).reshape(batch_size, -1))
            centre_ids = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
            edge_chunks.append((centre_ids + edges_per_image).expand(batch_size, -1))
            edges_per_image += num_nodes
        edge_base = (torch.arange(batch_size, device=device) * edges_per_image).view(-1, 1)

        edge_index = torch.stack([
            torch.cat(node_chunks, dim=1).reshape(-1),
            (torch.cat(edge_chunks, dim=1) + edge_base).reshape(-1),
        ])
        batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

        return x_batched.reshape(-1, feat_dim), edge_index, batch_map

In [5]:
def image_to_adaptive_hypergraph(images, k_spatial, k_feature=None, patch_size=8):
    """
    Convert a batch of images into one disjoint-union hypergraph of patches.

    Every image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches in row-major order; each patch is a node whose feature vector is
    its pixels flattened channel-first. For every node one *spatial*
    hyperedge is built (the node plus its ``k_spatial`` nearest patches on the
    patch grid) and, when ``k_feature`` is a positive int, one *feature*
    hyperedge (the node plus its ``k_feature`` nearest nodes in pixel space).
    If ``k_feature`` is None or <= 0 only spatial hyperedges are built.

    Args:
        images: [B, C, H, W] image tensor.
        k_spatial: neighbours per spatial hyperedge (clamped to num_nodes - 1).
        k_feature: neighbours per feature hyperedge, or None to skip them.
        patch_size: side of the square patches (default 8).

    Returns:
        x: [B * N, C * patch_size**2] float node features.
        hyperedge_index: [2, num_incidences] long tensor; row 0 holds global
            node ids, row 1 global hyperedge ids. Per image, the N spatial
            hyperedges come first, then (optionally) the N feature ones.
        batch_map: [B * N] long tensor giving each node's image index.
    """
    device = images.device
    batch_size, _, height, width = images.shape
    num_patches_side = width // patch_size
    num_nodes = (height // patch_size) * num_patches_side

    # [B, C, H, W] -> [B, C, rows, cols, p, p] -> [B, rows, cols, C, p, p]
    # -> [B, N, C*p*p]: row-major patches, pixels flattened channel-first.
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    node_feats = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch_size, num_nodes, -1)

    # (row, col) position of each patch on the grid; identical for all images.
    idx = torch.arange(num_nodes, device=device)
    coords = torch.stack([idx // num_patches_side, idx % num_patches_side], dim=1).float()

    groups = [(torch.cdist(coords, coords, p=2).expand(batch_size, -1, -1), k_spatial)]
    if k_feature is not None and k_feature > 0:
        groups.append((torch.cdist(node_feats, node_feats, p=2), k_feature))

    node_base = (torch.arange(batch_size, device=device) * num_nodes).view(-1, 1, 1)
    node_chunks, edge_chunks = [], []
    edges_per_image = 0
    for dists, k in groups:
        # One hyperedge per centre node: the centre plus its k nearest nodes
        # (the centre sits at distance 0, so it is normally selected).
        # Clamp so k >= num_nodes does not make topk fail.
        k_eff = min(k + 1, num_nodes)
        nn_idx = torch.topk(dists, k=k_eff, dim=-1, largest=False).indices  # [B, N, k_eff]
        node_chunks.append((nn_idx + node_base).reshape(batch_size, -1))
        centre_ids = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
        edge_chunks.append((centre_ids + edges_per_image).expand(batch_size, -1))
        edges_per_image += num_nodes
    edge_base = (torch.arange(batch_size, device=device) * edges_per_image).view(-1, 1)

    x = node_feats.reshape(batch_size * num_nodes, -1).float()
    hyperedge_index = torch.stack([
        torch.cat(node_chunks, dim=1).reshape(-1),
        (torch.cat(edge_chunks, dim=1) + edge_base).reshape(-1),
    ])
    batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

    return x, hyperedge_index, batch_map


class AdaptiveHypergraphBlock(nn.Module):
    """
    Residual hypergraph layer: a HypergraphConv sub-layer, then an FFN sub-layer.

    Each sub-layer's output is LayerNorm-ed, dropped out and added back to its
    input. The hypergraph (edge_index / edge_weight) is supplied on every call,
    so the same block works with whatever neighbourhood size the caller built.
    """
    def __init__(self, hidden_dim, dropout=0.3, ffn_expansion=2):
        super().__init__()
        ffn_width = hidden_dim * ffn_expansion

        # Convolution sub-layer. Registration order is kept as-is so seeded
        # initialisation and the state_dict layout stay the same.
        self.conv = HypergraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-5)

    def forward(self, x, edge_index, edge_weight):
        """Apply the conv and FFN sub-layers, each wrapped in a residual add."""
        conv_branch = F.relu(self.norm1(self.conv(x, edge_index, edge_weight)))
        x = x + self.dropout(conv_branch)
        ffn_branch = self.norm2(self.ffn(x))
        return x + self.dropout(ffn_branch)


class AdaptiveHyperVigClassifier(nn.Module):
    """
    Patch-hypergraph image classifier whose hypergraph is rebuilt before every layer.

    Images are cut into patches by ``image_to_adaptive_hypergraph`` (one node
    per patch), projected to ``hidden`` dims, and passed through one
    ``AdaptiveHypergraphBlock`` per entry of ``k_schedule``. Before each block
    the hypergraph is rebuilt from the current node embeddings: spatial kNN
    hyperedges (size ``k + 1``) in every layer, plus feature-space kNN
    hyperedges of the same size from the third layer on. A per-layer
    ``MultiHeadHyperedgeAttention`` supplies the hyperedge weights. Node
    embeddings are pooled per image with ``AttentionalAggregation`` and
    classified by a two-layer MLP.

    Args:
        in_channels: length of a flattened patch (channels * patch_size**2).
        hidden: node embedding width.
        num_classes: number of output logits.
        k_schedule: neighbourhood size per layer; its length sets the depth.
        dropout: dropout rate inside the hypergraph blocks.
    """
    def __init__(self, in_channels, hidden, num_classes, k_schedule=(12, 10, 8, 6, 4, 4), dropout=0.3):
        super().__init__()
        self.hidden = hidden
        # Store a private list copy: avoids aliasing the caller's list and the
        # shared-mutable-default pitfall of a list literal default.
        self.k_schedule = list(k_schedule)

        # Project flattened patch pixels to the hidden width.
        self.input_proj = nn.Linear(in_channels, hidden)

        # One multi-head hyperedge-attention module per layer.
        self.edge_attns = nn.ModuleList([
            MultiHeadHyperedgeAttention(in_dim=hidden, hidden=64, num_heads=8)
            for _ in self.k_schedule
        ])

        # One hypergraph block per layer.
        self.blocks = nn.ModuleList([
            AdaptiveHypergraphBlock(hidden, dropout=dropout)
            for _ in self.k_schedule
        ])

        # Attentional pooling of node embeddings, grouped by batch_map.
        self.pool = AttentionalAggregation(
            gate_nn=nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        )

        # Classification head.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform for every nn.Linear (zero bias); LayerNorms reset to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, images, batch_map_input=None):
        """
        Args:
            images: [B, C, H, W] image batch.
            batch_map_input: unused; kept for backward compatibility (the
                node-to-image map is recomputed from ``images``).

        Returns:
            [B, num_classes] logits.
        """
        batch_size = images.size(0)

        # Only node features and the node->image map are needed here (the
        # hypergraph is rebuilt per layer below), so request the cheapest graph.
        x_init, _, batch_map = image_to_adaptive_hypergraph(images, k_spatial=0)
        x = self.input_proj(x_init)
        patches_per_image = x.size(0) // batch_size

        for layer_idx, (k_val, block, edge_attn) in enumerate(
            zip(self.k_schedule, self.blocks, self.edge_attns)
        ):
            # Feature-space hyperedges only from the third layer on, once the
            # embeddings carry more than raw pixel information.
            k_feature = k_val if layer_idx >= 2 else None

            _, edge_index, _ = self.build_hypergraph_from_features(
                x.view(batch_size, patches_per_image, -1),
                k_spatial=k_val,
                k_feature=k_feature,
            )
            edge_weight = edge_attn(x, edge_index)
            x = block(x, edge_index, edge_weight)

        out = self.pool(x, batch_map)
        return self.classifier(out)

    def build_hypergraph_from_features(self, x_batched, k_spatial, k_feature=None, num_patches_side=None):
        """
        Build a kNN hypergraph over a batch of per-image node sets.

        Args:
            x_batched: [B, N, D] node features, nodes in row-major patch order.
            k_spatial: neighbours per spatial hyperedge (clamped to N - 1).
            k_feature: neighbours per feature hyperedge; None or <= 0 disables
                feature hyperedges.
            num_patches_side: width of the patch grid; defaults to isqrt(N),
                i.e. a square grid.

        Returns:
            (x_flat [B*N, D], edge_index [2, M], batch_map [B*N]). Row 0 of
            edge_index holds global node ids, row 1 global hyperedge ids; per
            image the N spatial hyperedges come first, then the N feature ones.
        """
        import math

        batch_size, num_nodes, feat_dim = x_batched.shape
        device = x_batched.device
        if num_patches_side is None:
            num_patches_side = math.isqrt(num_nodes)

        idx = torch.arange(num_nodes, device=device)
        coords = torch.stack([idx // num_patches_side, idx % num_patches_side], dim=1).float()
        groups = [(torch.cdist(coords, coords, p=2).expand(batch_size, -1, -1), k_spatial)]
        if k_feature is not None and k_feature > 0:
            # topk indices carry no gradient, so keep the O(B*N^2) distance
            # computation out of the autograd graph.
            x_detached = x_batched.detach()
            groups.append((torch.cdist(x_detached, x_detached, p=2), k_feature))

        node_base = (torch.arange(batch_size, device=device) * num_nodes).view(-1, 1, 1)
        node_chunks, edge_chunks = [], []
        edges_per_image = 0
        for dists, k in groups:
            # Centre + its k nearest nodes; clamp so topk never asks for more
            # than N elements.
            k_eff = min(k + 1, num_nodes)
            nn_idx = torch.topk(dists, k=k_eff, dim=-1, largest=False).indices  # [B, N, k_eff]
            node_chunks.append((nn_idx + node_base).reshape(batch_size, -1))
            centre_ids = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
            edge_chunks.append((centre_ids + edges_per_image).expand(batch_size, -1))
            edges_per_image += num_nodes
        edge_base = (torch.arange(batch_size, device=device) * edges_per_image).view(-1, 1)

        edge_index = torch.stack([
            torch.cat(node_chunks, dim=1).reshape(-1),
            (torch.cat(edge_chunks, dim=1) + edge_base).reshape(-1),
        ])
        batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

        return x_batched.reshape(-1, feat_dim), edge_index, batch_map

In [6]:
class MultiHeadHyperedgeAttention(nn.Module):
    """
    Score hyperedges with an ensemble of small independent MLP heads.

    Each hyperedge is summarised by the mean of its member node features.
    Every head maps that summary to one logit, which is clipped to [-5, 5] and
    passed through a sigmoid. The head outputs are averaged and affinely
    rescaled into [0.1, 1.0], so no hyperedge weight is ever exactly zero.

    Args:
        in_dim: node feature width.
        hidden: total hidden width split across heads
            (each head gets ``hidden // num_heads`` units).
        num_heads: number of scoring heads.
    """
    def __init__(self, in_dim, hidden=64, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = hidden // num_heads

        def make_head():
            # Linear -> ReLU -> Dropout -> Linear producing a single logit.
            return nn.Sequential(
                nn.Linear(in_dim, self.head_dim),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(self.head_dim, 1),
            )

        self.attention_heads = nn.ModuleList(make_head() for _ in range(num_heads))

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear in every head."""
        linear_layers = (
            layer
            for head in self.attention_heads
            for layer in head.modules()
            if isinstance(layer, nn.Linear)
        )
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x, edge_index):
        """
        Args:
            x: [num_nodes, feature_dim] node features.
            edge_index: [2, num_incidences]; row 0 node ids, row 1 hyperedge ids.

        Returns:
            [max_hyperedge_id + 1] weights in [0.1, 1.0].
        """
        members, edges = edge_index
        n_edges = edges.max().item() + 1

        # Mean of member features per hyperedge: index_add_ sums / member counts.
        totals = torch.zeros(n_edges, x.size(1), device=x.device, dtype=x.dtype)
        totals.index_add_(0, edges, x[members])
        sizes = torch.zeros(n_edges, device=x.device, dtype=x.dtype)
        sizes.index_add_(0, edges, torch.ones_like(members, dtype=x.dtype))
        edge_repr = totals / sizes.unsqueeze(1).clamp(min=1)

        per_head = torch.stack([
            torch.sigmoid(torch.clamp(head(edge_repr).squeeze(-1), min=-5, max=5))
            for head in self.attention_heads
        ])
        return per_head.mean(0) * 0.9 + 0.1

In [None]:
def image_to_true_hypergraph(images, k_spatial=4, k_feature=4, patch_size=8):
    """
    Convert a batch of images into a patch hypergraph with spatial AND feature hyperedges.

    Every image is cut into non-overlapping ``patch_size`` x ``patch_size``
    patches in row-major order; each patch is a node whose feature vector is
    its pixels flattened channel-first. For every node two hyperedges are
    built: the node plus its ``k_spatial`` nearest patches on the patch grid,
    and the node plus its ``k_feature`` nearest nodes in pixel space.

    Args:
        images: [B, C, H, W] image tensor.
        k_spatial: neighbours per spatial hyperedge (clamped to num_nodes - 1).
        k_feature: neighbours per feature hyperedge (clamped to num_nodes - 1).
        patch_size: side of the square patches (default 8).

    Returns:
        x: [B * N, C * patch_size**2] float node features.
        hyperedge_index: [2, num_incidences] long tensor in PyG hypergraph
            format (row 0 global node ids, row 1 global hyperedge ids). Per
            image, the N spatial hyperedges come first, then the N feature ones.
        batch_map: [B * N] long tensor giving each node's image index.
    """
    device = images.device
    batch_size, _, height, width = images.shape
    num_patches_side = width // patch_size
    num_nodes = (height // patch_size) * num_patches_side

    # [B, C, H, W] -> [B, C, rows, cols, p, p] -> [B, rows, cols, C, p, p]
    # -> [B, N, C*p*p]: row-major patches, pixels flattened channel-first.
    patches = images.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
    node_feats = patches.permute(0, 2, 3, 1, 4, 5).reshape(batch_size, num_nodes, -1)

    # (row, col) position of each patch on the grid; identical for all images.
    idx = torch.arange(num_nodes, device=device)
    coords = torch.stack([idx // num_patches_side, idx % num_patches_side], dim=1).float()

    groups = [
        (torch.cdist(coords, coords, p=2).expand(batch_size, -1, -1), k_spatial),
        (torch.cdist(node_feats, node_feats, p=2), k_feature),
    ]

    node_base = (torch.arange(batch_size, device=device) * num_nodes).view(-1, 1, 1)
    node_chunks, edge_chunks = [], []
    edges_per_image = 0
    for dists, k in groups:
        # One hyperedge per centre node: the centre plus its k nearest nodes.
        # Clamp so k >= num_nodes does not make topk fail.
        k_eff = min(k + 1, num_nodes)
        nn_idx = torch.topk(dists, k=k_eff, dim=-1, largest=False).indices  # [B, N, k_eff]
        node_chunks.append((nn_idx + node_base).reshape(batch_size, -1))
        centre_ids = torch.arange(num_nodes, device=device).repeat_interleave(k_eff)
        edge_chunks.append((centre_ids + edges_per_image).expand(batch_size, -1))
        edges_per_image += num_nodes
    edge_base = (torch.arange(batch_size, device=device) * edges_per_image).view(-1, 1)

    x = node_feats.reshape(batch_size * num_nodes, -1).float()
    hyperedge_index = torch.stack([
        torch.cat(node_chunks, dim=1).reshape(-1),
        (torch.cat(edge_chunks, dim=1) + edge_base).reshape(-1),
    ])
    batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_nodes)

    return x, hyperedge_index, batch_map

In [None]:
class HyperedgeAttention(nn.Module):
    """
    Single-MLP hyperedge scorer meant to run inside the model's forward pass.

    A hyperedge is represented by the mean of its member node features. An
    MLP maps that to one logit, which is clipped to [-5, 5], passed through a
    sigmoid, and rescaled into [0.1, 1.0] so no weight is exactly zero.
    """
    def __init__(self, in_dim, hidden=64):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every nn.Linear."""
        for layer in filter(lambda mod: isinstance(mod, nn.Linear), self.modules()):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x, edge_index):
        """
        Compute one weight per hyperedge from its mean member features.

        Args:
            x: [num_nodes, feature_dim] node features.
            edge_index: [2, num_incidences]; row 0 node ids, row 1 hyperedge ids.

        Returns:
            [max_hyperedge_id + 1] weights in [0.1, 1.0].
        """
        members, edges = edge_index
        n_edges = edges.max().item() + 1

        # Mean of member features per hyperedge, via index_add_ sums / counts.
        totals = torch.zeros(n_edges, x.size(1), device=x.device, dtype=x.dtype)
        totals.index_add_(0, edges, x[members])
        sizes = torch.zeros(n_edges, device=x.device, dtype=x.dtype)
        sizes.index_add_(0, edges, torch.ones_like(members, dtype=x.dtype))
        means = totals / sizes.unsqueeze(1).clamp(min=1)

        logits = torch.clamp(self.mlp(means).squeeze(-1), min=-5, max=5)
        # Affine rescale keeps every weight strictly positive.
        return torch.sigmoid(logits) * 0.9 + 0.1


In [None]:
class HypergraphBlock(nn.Module):
    """
    Residual hypergraph layer: a HypergraphConv sub-layer, then an FFN sub-layer.

    Each sub-layer's output is LayerNorm-ed, dropped out and added back to its
    input (the conv branch additionally passes through ReLU).
    """
    def __init__(self, hidden_dim, dropout=0.3, ffn_expansion=2):
        super().__init__()
        ffn_width = hidden_dim * ffn_expansion

        # Convolution sub-layer. Registration order is kept as-is so seeded
        # initialisation and the state_dict layout stay the same.
        self.conv = HypergraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer.
        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-5)

    def forward(self, x, edge_index, edge_weight):
        """Apply the conv and FFN sub-layers, each wrapped in a residual add."""
        conv_branch = F.relu(self.norm1(self.conv(x, edge_index, edge_weight)))
        x = x + self.dropout(conv_branch)
        ffn_branch = self.norm2(self.ffn(x))
        return x + self.dropout(ffn_branch)

In [7]:
class HyperedgeToHyperedgeLayer(nn.Module):
    """
    Re-weight hyperedges after one round of hyperedge-to-hyperedge messaging.

    Each hyperedge is summarised as the mean of its member node features. Two
    distinct hyperedges are neighbours in the *dual graph* when they share at
    least one node. Every hyperedge receives ``0.1 * hyperedge_mlp(summary)``
    from each neighbour. The L2 norm of each updated summary, divided by the
    largest such norm and mapped affinely, becomes the new weight in [0.1, 1.0].
    """
    def __init__(self, hidden_dim):
        super().__init__()
        # Message function applied to a neighbouring hyperedge's summary.
        self.hyperedge_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, x, edge_index, edge_weights):
        """
        Args:
            x: [num_nodes, hidden] node features.
            edge_index: [2, num_incidences]; row 0 node ids, row 1 hyperedge ids.
            edge_weights: [num_hyperedges] incoming weights.
                NOTE(review): not used — the returned weights depend only on
                ``x`` and the topology. Confirm this is intended.

        Returns:
            [num_hyperedges] updated weights in [0.1, 1.0].
        """
        node_idx, hyperedge_idx = edge_index
        num_hyperedges = hyperedge_idx.max().item() + 1

        # Mean-aggregate member node features per hyperedge.
        hyperedge_feats = torch.zeros(num_hyperedges, x.size(1), device=x.device, dtype=x.dtype)
        hyperedge_feats.index_add_(0, hyperedge_idx, x[node_idx])
        counts = torch.zeros(num_hyperedges, device=x.device, dtype=x.dtype)
        counts.index_add_(0, hyperedge_idx, torch.ones_like(node_idx, dtype=x.dtype))
        hyperedge_feats = hyperedge_feats / counts.unsqueeze(1).clamp(min=1)

        src, dst = self.build_dual_graph(edge_index, num_hyperedges)

        # hyperedge_mlp has no stochastic layers, so evaluating it once per
        # hyperedge and gathering by ``src`` gives the same messages as a
        # per-dual-edge Python loop, in a single batched call. index_add sums
        # every message arriving at the same ``dst``.
        messages = self.hyperedge_mlp(hyperedge_feats)[src]
        hyperedge_feats_updated = hyperedge_feats.index_add(0, dst, 0.1 * messages)

        norms = hyperedge_feats_updated.norm(dim=1)
        # Clamp the divisor so all-zero summaries give weight 0.1 instead of NaN.
        norms = norms / norms.max().clamp(min=torch.finfo(norms.dtype).eps)
        return norms * 0.9 + 0.1  # Scale to [0.1, 1.0]

    def build_dual_graph(self, edge_index, num_hyperedges):
        """
        Connect every ordered pair of distinct hyperedges that share a node.

        Computed as the off-diagonal sparsity pattern of H^T H, where H is the
        sparse node x hyperedge incidence matrix. That avoids Python loops over
        nodes and pairs.

        Returns:
            [2, num_dual_edges] long tensor holding both directions of every
            pair. Edge order is unspecified.
        """
        node_idx, hyperedge_idx = edge_index
        device = edge_index.device
        if node_idx.numel() == 0:
            return torch.empty((2, 0), dtype=torch.long, device=device)

        num_nodes = int(node_idx.max()) + 1
        incidence = torch.sparse_coo_tensor(
            torch.stack([node_idx, hyperedge_idx]),
            torch.ones(node_idx.numel(), device=device),
            size=(num_nodes, num_hyperedges),
        ).coalesce()
        co_membership = torch.sparse.mm(incidence.t().coalesce(), incidence).coalesce()
        src, dst = co_membership.indices()
        keep = src != dst  # drop self-pairs (a hyperedge trivially shares nodes with itself)
        return torch.stack([src[keep], dst[keep]])


class HypergraphBlockWithH2H(nn.Module):
    """
    Residual hypergraph block whose edge weights are first refined by hyperedge messaging.

    ``h2h`` produces new per-hyperedge weights, which the HypergraphConv
    sub-layer then uses. Both the conv and FFN sub-layers are LayerNorm-ed,
    dropped out and added back to their inputs.
    """
    def __init__(self, hidden_dim, dropout=0.3, ffn_expansion=2):
        super().__init__()
        ffn_width = hidden_dim * ffn_expansion

        # Registration order is kept as-is so seeded initialisation and the
        # state_dict layout stay the same.
        self.conv = HypergraphConv(hidden_dim, hidden_dim)
        self.norm1 = nn.LayerNorm(hidden_dim, eps=1e-5)
        self.dropout = nn.Dropout(dropout)

        # Hyperedges exchange information before the convolution runs.
        self.h2h = HyperedgeToHyperedgeLayer(hidden_dim)

        self.ffn = nn.Sequential(
            nn.Linear(hidden_dim, ffn_width),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_width, hidden_dim),
        )
        self.norm2 = nn.LayerNorm(hidden_dim, eps=1e-5)

    def forward(self, x, edge_index, edge_weight):
        """Refine edge weights via h2h, then apply residual conv and FFN sub-layers."""
        refined_weight = self.h2h(x, edge_index, edge_weight)
        conv_branch = F.relu(self.norm1(self.conv(x, edge_index, refined_weight)))
        x = x + self.dropout(conv_branch)
        ffn_branch = self.norm2(self.ffn(x))
        return x + self.dropout(ffn_branch)

In [10]:
class LearnablePatchExtractor(nn.Module):
    """Predict per-image patch centres and crop patches there by bilinear resampling.

    A small MLP maps the flattened image to ``num_patches`` (x, y) centres, squashed
    to [0, 1] by a sigmoid. A ``patch_size`` x ``patch_size`` patch is then resampled
    around every centre with a single ``F.grid_sample`` call. The sampling grid is
    built with differentiable tensor arithmetic, so the downstream loss
    back-propagates through the coordinates into ``localization_net``.

    Coordinate convention: a centre ``c`` in [0, 1] corresponds to the continuous
    pixel position ``c * img_size``, where pixel ``i`` spans ``[i, i + 1)``. Patches
    sample pixel centres, and ``grid_sample`` is called with ``align_corners=False``,
    which matches the ``pos / img_size * 2 - 1`` normalisation used here.

    Args:
        num_patches: number of patches predicted per image.
        patch_size: side length (pixels) of each square patch.
        img_size: side length of the square input images; the first Linear layer
            expects ``3 * img_size * img_size`` inputs.
    """

    def __init__(self, num_patches=16, patch_size=8, img_size=32):
        super().__init__()
        self.num_patches = num_patches
        self.patch_size = patch_size
        self.img_size = img_size

        # MLP: flattened RGB image -> one pre-sigmoid (x, y) pair per patch.
        self.localization_net = nn.Sequential(
            nn.Linear(3 * img_size * img_size, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, num_patches * 2),  # (x, y) for each patch
        )

        self._init_grid()

    def _init_grid(self):
        """Make the initial prediction a uniform grid of patch centres.

        The final layer's weight is zeroed and its bias is set to the *logit* of the
        normalised grid coordinates. After the sigmoid in ``forward``, the initial
        centres therefore land exactly on the grid for every input (the usual
        spatial-transformer initialisation). If ``num_patches`` is not a perfect
        square, the first ``num_patches`` cells of the smallest enclosing square
        grid are used.
        """
        per_side = int(np.ceil(np.sqrt(self.num_patches)))
        centres = [
            ((i + 0.5) / per_side, (j + 0.5) / per_side)
            for i in range(per_side)
            for j in range(per_side)
        ][: self.num_patches]
        target = torch.tensor(centres, dtype=torch.float32).flatten()

        last = self.localization_net[-1]
        with torch.no_grad():
            last.weight.zero_()
            # forward() applies sigmoid, so store the inverse (logit) of the target.
            last.bias.copy_(torch.logit(target))

    def forward(self, images):
        """
        images: [batch, 3, img_size, img_size]
        Returns: patches [batch, num_patches, 3, patch_size, patch_size] (contiguous)
                 coords  [batch, num_patches, 2], (x, y) centres in [0, 1]
        """
        batch_size, channels = images.shape[:2]
        size = self.patch_size

        logits = self.localization_net(images.reshape(batch_size, -1))
        coords = torch.sigmoid(logits.view(batch_size, self.num_patches, 2))

        grid = self._batched_patch_grid(coords)  # [B, P, S, S, 2]
        # One grid_sample for all patches: stack the P patch grids along output height.
        sampled = F.grid_sample(
            images,
            grid.view(batch_size, self.num_patches * size, size, 2),
            mode='bilinear',
            padding_mode='border',
            align_corners=False,
        )  # [B, C, P*S, S]
        patches = (
            sampled.view(batch_size, channels, self.num_patches, size, size)
            .permute(0, 2, 1, 3, 4)
            .contiguous()
        )
        return patches, coords

    def _batched_patch_grid(self, coords):
        """Sampling grids [B, P, S, S, 2] (last dim = normalised (x, y)) for centres in [0, 1]."""
        size = self.patch_size
        # Pixel-centre offsets relative to the patch centre, e.g. size=4 -> [-1.5, -0.5, 0.5, 1.5].
        offsets = torch.arange(size, device=coords.device, dtype=coords.dtype) - (size - 1) / 2
        centres = coords * self.img_size  # continuous pixel positions
        xs = (centres[..., 0:1] + offsets) / self.img_size * 2 - 1  # [B, P, S]
        ys = (centres[..., 1:2] + offsets) / self.img_size * 2 - 1  # [B, P, S]
        grid_x = xs.unsqueeze(-2).expand(-1, -1, size, -1)  # varies along columns
        grid_y = ys.unsqueeze(-1).expand(-1, -1, -1, size)  # varies along rows
        return torch.stack([grid_x, grid_y], dim=-1)

    def create_patch_grid(self, center_x, center_y, size):
        """Sampling grid [size, size, 2] for one patch around a pixel-space centre.

        Uses the same pixel-centre / ``align_corners=False`` convention as ``forward``
        and keeps gradients with respect to tensor-valued centres.
        """
        offsets = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
        if isinstance(center_x, torch.Tensor):
            offsets = offsets.to(device=center_x.device, dtype=center_x.dtype)
        xs = (center_x + offsets) / self.img_size * 2 - 1
        ys = (center_y + offsets) / self.img_size * 2 - 1
        grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
        return torch.stack([grid_x, grid_y], dim=-1)


class LearnablePatchHyperViG(nn.Module):
    """Hypergraph classifier over patches whose locations are predicted per image.

    Pipeline: ``LearnablePatchExtractor`` -> flatten patches -> one hyperedge per
    patch (the patch plus its nearest patches by predicted centre) -> ``input_proj``
    -> hyperedge weights from ``MultiHeadHyperedgeAttention`` -> ``num_blocks`` x
    ``HypergraphBlockWithH2H`` -> attentional pooling per image -> MLP classifier.

    Args:
        hidden: node feature width used throughout.
        num_classes: number of output logits.
        num_patches: patches (nodes) per image.
        dropout: dropout used inside every hypergraph block.
        patch_size: side of each square RGB patch; sizes ``input_proj``.
        num_blocks: number of hypergraph blocks.
        k_neighbors: each hyperedge holds a patch plus up to this many nearest patches
            (capped at ``num_patches - 1``).
    """

    def __init__(self, hidden, num_classes, num_patches=16, dropout=0.3,
                 patch_size=8, num_blocks=6, k_neighbors=8):
        super().__init__()
        self.k_neighbors = k_neighbors

        # Learnable patch extractor
        self.patch_extractor = LearnablePatchExtractor(num_patches=num_patches, patch_size=patch_size)

        self.input_proj = nn.Linear(3 * patch_size * patch_size, hidden)
        self.edge_attn = MultiHeadHyperedgeAttention(in_dim=hidden, hidden=64, num_heads=8)

        self.blocks = nn.ModuleList([
            HypergraphBlockWithH2H(hidden, dropout=dropout)
            for _ in range(num_blocks)
        ])

        self.pool = AttentionalAggregation(
            gate_nn=nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every nn.Linear (zero bias) and reset every LayerNorm to identity.

        The blanket loop also reaches the patch extractor's localization head and
        would erase its grid-shaped initialisation, so that init is re-applied last.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
        self.patch_extractor._init_grid()

    def forward(self, images):
        """Classify a batch of images.

        Returns:
            (logits [batch, num_classes], coords [batch, num_patches, 2]). coords are
            returned so callers can regularise or visualise patch placement.
        """
        patches, coords = self.patch_extractor(images)  # [batch, num_patches, 3, S, S]

        batch_size, num_patches = patches.shape[:2]
        # reshape (not view) so a non-contiguous extractor output is accepted too
        patches_flat = patches.reshape(batch_size * num_patches, -1)

        # Hyperedges come from distances between the predicted patch centres.
        x, edge_index, batch_map = self.build_hypergraph_from_coords(
            patches_flat, coords, batch_size
        )

        x = self.input_proj(x)
        edge_weight = self.edge_attn(x, edge_index)

        for block in self.blocks:
            x = block(x, edge_index, edge_weight)

        out = self.pool(x, batch_map)
        return self.classifier(out), coords

    def build_hypergraph_from_coords(self, patches_flat, coords, batch_size):
        """Build kNN hyperedges over patch centres for the whole batch at once.

        Node ``b * P + i`` is patch ``i`` of image ``b``. Hyperedge ``b * P + i`` contains
        that node and its nearest patches (Euclidean distance between centres), up to
        ``k_neighbors`` of them. Neighbour selection is not differentiated.

        Returns:
            (patches_flat, edge_index [2, nnz] of (node, hyperedge) pairs, batch_map [B * P]).
        """
        num_patches = coords.size(1)
        device = patches_flat.device
        # +1 for the patch itself; never ask topk for more entries than exist.
        group_size = min(self.k_neighbors + 1, num_patches)

        with torch.no_grad():
            centres = coords[:batch_size].detach()
            dists = torch.cdist(centres, centres, p=2)  # [B, P, P]
            nn_idx = torch.topk(dists, k=group_size, dim=-1, largest=False).indices  # [B, P, G]

        node_offsets = torch.arange(batch_size, device=nn_idx.device).view(-1, 1, 1) * num_patches
        node_ids = (nn_idx + node_offsets).reshape(-1)
        hyperedge_ids = torch.arange(
            batch_size * num_patches, device=nn_idx.device
        ).repeat_interleave(group_size)

        edge_index = torch.stack([node_ids, hyperedge_ids]).to(device=device, dtype=torch.long)
        batch_map = torch.arange(batch_size, device=device).repeat_interleave(num_patches)

        return patches_flat, edge_index, batch_map

In [None]:
class HyperVigClassifier(nn.Module):
    """Hypergraph classifier over precomputed node features.

    input_proj -> hyperedge weights from MultiHeadHyperedgeAttention ->
    ``num_blocks`` x HypergraphBlock (each with its own FFN) -> AttentionalAggregation
    per graph (grouped by ``batch_map``) -> MLP head producing ``num_classes`` logits.

    Args:
        in_channels: width of the incoming node features.
        hidden: width used after ``input_proj``.
        num_classes: number of output logits.
        num_blocks: number of hypergraph blocks.
        dropout: dropout passed to every hypergraph block.
    """

    def __init__(self, in_channels, hidden, num_classes, num_blocks=3, dropout=0.3):
        super().__init__()
        self.hidden = hidden
        self.input_proj = nn.Linear(in_channels, hidden)
        self.edge_attn = MultiHeadHyperedgeAttention(in_dim=hidden, hidden=64, num_heads=8)

        # Multiple hypergraph blocks (each with its own FFN)
        self.blocks = nn.ModuleList([
            HypergraphBlock(hidden, dropout=dropout)
            for _ in range(num_blocks)
        ])

        self.pool = AttentionalAggregation(
            gate_nn=nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Linear(hidden, 1)
            )
        )

        self.classifier = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden, num_classes)
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-init every nn.Linear (zero bias) and reset every LayerNorm to identity."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, edge_index, batch_map):
        """Return logits [batch_size, num_classes].

        x: node features [num_nodes, in_channels]. edge_index: presumably a [2, nnz]
        (node, hyperedge) incidence list, as produced by the notebook's hypergraph
        builders (confirm against the builder in use). batch_map: graph id per node.
        """
        x = self.input_proj(x)  # [num_nodes, hidden]

        # Hyperedge weights are computed in the forward pass so gradients reach edge_attn.
        hyperedge_weight = self.edge_attn(x, edge_index)

        for block in self.blocks:
            x = block(x, edge_index, hyperedge_weight)

        out = self.pool(x, batch_map)  # [batch_size, hidden]
        return self.classifier(out)  # [batch_size, num_classes]

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, graph_builder=None):
    """Run one training epoch of a graph-input model; return (mean loss, accuracy).

    Each image batch is converted to a hypergraph by ``graph_builder`` (default:
    ``image_to_true_hypergraph``). The builder must return
    ``(node_feats, edge_index, batch_map)``, and the model is called as
    ``model(node_feats, edge_index, batch_map)``. Gradients are clipped to a global
    L2 norm of 1.0 before every optimizer step.

    Returns:
        (loss, acc): sample-weighted mean loss and top-1 accuracy over the epoch.
    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    build_graph = image_to_true_hypergraph if graph_builder is None else graph_builder
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        node_feats, edge_index, batch_map = build_graph(images)

        optimizer.zero_grad()
        outputs = model(node_feats, edge_index, batch_map)
        loss = criterion(outputs, labels)
        loss.backward()

        # Gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_n = labels.size(0)
        loss_sum += loss.item() * batch_n  # criterion is a per-batch mean
        correct += (outputs.argmax(1) == labels).sum().item()
        seen += batch_n

    if seen == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return loss_sum / seen, correct / seen

def validate(model, loader, criterion, device, graph_builder=None):
    """Evaluate a graph-input model without gradients; return (mean loss, accuracy).

    Each image batch is converted to a hypergraph by ``graph_builder`` (default:
    ``image_to_true_hypergraph``). The builder must return
    ``(node_feats, edge_index, batch_map)``. The model is put in eval mode.

    Returns:
        (loss, acc): sample-weighted mean loss and top-1 accuracy over ``loader``.
    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    build_graph = image_to_true_hypergraph if graph_builder is None else graph_builder
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            node_feats, edge_index, batch_map = build_graph(images)

            outputs = model(node_feats, edge_index, batch_map)
            loss = criterion(outputs, labels)

            batch_n = labels.size(0)
            loss_sum += loss.item() * batch_n  # criterion is a per-batch mean
            correct += (outputs.argmax(1) == labels).sum().item()
            seen += batch_n

    if seen == 0:
        raise ValueError("validate: loader yielded no samples")
    return loss_sum / seen, correct / seen

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

EPOCHS = 120
COORD_DIVERSITY_WEIGHT = 0.001  # weight of the patch-spread regulariser
torch.manual_seed(0)  # reproducible initialisation

model = LearnablePatchHyperViG(
    hidden=384,
    num_classes=10,
    num_patches=16,
    dropout=0.2
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.002, weight_decay=0.005)
# Cosine schedule over the whole run; it must be stepped once per epoch (below).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)
criterion = nn.CrossEntropyLoss()

for epoch in range(EPOCHS):
    epoch_lr = optimizer.param_groups[0]['lr']

    model.train()
    train_loss, train_correct, train_seen = 0.0, 0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs, coords = model(images)  # coords: predicted patch centres in [0, 1]
        cls_loss = criterion(outputs, labels)
        # Subtracting the coordinate std rewards spread-out patch centres.
        loss = cls_loss - COORD_DIVERSITY_WEIGHT * coords.std()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        batch_n = labels.size(0)
        train_loss += cls_loss.item() * batch_n
        train_correct += (outputs.argmax(1) == labels).sum().item()
        train_seen += batch_n

    scheduler.step()

    model.eval()
    test_correct, test_seen = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs, _ = model(images)
            test_correct += (outputs.argmax(1) == labels).sum().item()
            test_seen += labels.size(0)

    print(
        f"epoch {epoch + 1:3d}/{EPOCHS} | lr {epoch_lr:.2e} | "
        f"train loss {train_loss / train_seen:.4f} acc {train_correct / train_seen:.4f} | "
        f"test acc {test_correct / test_seen:.4f}"
    )

Using device: cuda
