In [3]:
import numpy as np
import json
import gzip
from scipy.sparse import coo_matrix
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
# DEHNN layer
class DEHNNLayer(nn.Module):
    """One message-passing layer over a directed hypergraph with virtual nodes.

    With node features of width ``node_in_features`` (N) and edge features of
    width ``edge_in_features`` (E), ``forward`` returns node features of width
    E and edge features of width ``2 * N``.

    Args:
        node_in_features: Width of every incoming node feature vector.
        edge_in_features: Width of every incoming edge feature vector.
    """

    def __init__(self, node_in_features, edge_in_features):
        super(DEHNNLayer, self).__init__()
        # Edge -> node message (applied to each incident edge's features).
        self.node_mlp1 = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        # Sink node -> edge message.
        self.edge_mlp2 = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        # Applied to concat(driver feature, aggregated sink messages).
        self.edge_mlp3 = nn.Sequential(
            nn.Linear(2 * node_in_features, 2 * node_in_features),
            nn.ReLU()
        )

        # Node -> virtual node -> single higher virtual node -> back down.
        self.node_to_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        self.virtual_to_higher_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, edge_in_features),
            nn.ReLU()
        )
        self.higher_virtual_to_virtual_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        self.virtual_to_node_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )

        # Learnable defaults used when a driver, the sinks, the incident edges
        # or the members of a virtual node are missing.
        self.default_driver = nn.Parameter(torch.zeros(node_in_features))
        self.default_sink_agg = nn.Parameter(torch.zeros(node_in_features))
        self.default_edge_agg = nn.Parameter(torch.zeros(edge_in_features))
        self.default_virtual_node = nn.Parameter(torch.zeros(node_in_features))

        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform init for every Linear weight; Linear biases are zeroed.

        The ``default_*`` parameters are not touched here and keep their
        zero initialisation.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, node_features, edge_features, hypergraph):
        """Run one round of node, edge and virtual-node updates.

        Args:
            node_features: Mapping ``node -> 1-D tensor`` of width N.
            edge_features: Mapping ``edge -> 1-D tensor`` of width E.
            hypergraph: Object exposing ``nodes``, ``edges``,
                ``num_virtual_nodes``, ``get_incident_edges(node)``,
                ``get_driver_and_sinks(edge)`` and ``get_virtual_node(node)``.

        Returns:
            ``(updated_node_features, updated_edge_features)``: dicts keyed by
            node / edge holding tensors of width E and ``2 * N`` respectively.

        Raises:
            KeyError: If a node is assigned to a virtual node outside
                ``range(hypergraph.num_virtual_nodes)``.
        """
        # Node update: sum of transformed features of all incident edges.
        updated_node_features = {}
        for node in hypergraph.nodes:
            incident_edges = hypergraph.get_incident_edges(node)
            if incident_edges:
                agg_features = torch.stack(
                    [self.node_mlp1(edge_features[edge]) for edge in incident_edges]
                ).sum(dim=0)
            else:
                agg_features = self.default_edge_agg
            updated_node_features[node] = agg_features

        # Edge update: concat(driver feature, sum of transformed sink features).
        updated_edge_features = {}
        for edge in hypergraph.edges:
            driver, sinks = hypergraph.get_driver_and_sinks(edge)

            driver_feature = node_features[driver] if driver is not None else self.default_driver

            if sinks:
                sink_agg = torch.stack(
                    [self.edge_mlp2(node_features[sink]) for sink in sinks]
                ).sum(dim=0)
            else:
                sink_agg = self.default_sink_agg

            concatenated = torch.cat([driver_feature, sink_agg])
            updated_edge_features[edge] = self.edge_mlp3(concatenated)

        # Group nodes by virtual node in a single pass (instead of rescanning
        # every node once per virtual node). Node order inside each group
        # follows hypergraph.nodes, so the summation order is unchanged.
        num_virtual_nodes = hypergraph.num_virtual_nodes
        members = {virtual_node: [] for virtual_node in range(num_virtual_nodes)}
        virtual_of = {}
        for node in hypergraph.nodes:
            virtual_node = hypergraph.get_virtual_node(node)
            virtual_of[node] = virtual_node
            if virtual_node in members:
                members[virtual_node].append(node)

        # Virtual node aggregation
        virtual_node_agg = {}
        for virtual_node, assigned_nodes in members.items():
            if assigned_nodes:
                agg_features = torch.stack(
                    [self.node_to_virtual_mlp(node_features[node]) for node in assigned_nodes]
                ).sum(dim=0)
            else:
                agg_features = self.default_virtual_node
            virtual_node_agg[virtual_node] = agg_features

        higher_virtual_feature = torch.stack(
            [self.virtual_to_higher_virtual_mlp(feature) for feature in virtual_node_agg.values()]
        ).sum(dim=0)

        # Every virtual node receives the same message from the single higher
        # virtual node, so both downward MLPs are evaluated once and shared.
        broadcast_feature = self.virtual_to_node_mlp(
            self.higher_virtual_to_virtual_mlp(higher_virtual_feature)
        )
        propagated_node_messages = {
            virtual_node: broadcast_feature for virtual_node in range(num_virtual_nodes)
        }

        for node in hypergraph.nodes:
            # Out-of-place add: the stored value may be the nn.Parameter
            # `default_edge_agg` itself, and `+=` on a leaf parameter either
            # raises under autograd or silently overwrites the parameter.
            updated_node_features[node] = (
                updated_node_features[node] + propagated_node_messages[virtual_of[node]]
            )

        return updated_node_features, updated_edge_features

In [5]:
# DEHNN model
class DEHNN(nn.Module):
    """Stack of ``DEHNNLayer`` blocks followed by a per-node linear classifier.

    Args:
        num_layers: Number of ``DEHNNLayer`` blocks to stack.
        node_in_features: Width of the input node feature vectors.
        edge_in_features: Width of the input edge feature vectors.
        num_classes: Number of logits produced per node (defaults to 2).
    """

    def __init__(self, num_layers, node_in_features, edge_in_features, num_classes=2):
        super(DEHNN, self).__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.layers = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(DEHNNLayer(node_in_features, edge_in_features))
            # A DEHNNLayer built with widths (node: N, edge: E) emits
            # (node: E, edge: 2N); track the widths seen by the next layer.
            node_in_features, edge_in_features = edge_in_features, 2 * node_in_features

        # node_in_features now holds the node width produced by the last layer.
        # Kept inside nn.Sequential so parameter names stay `output_layer.0.*`.
        self.output_layer = nn.Sequential(
            nn.Linear(node_in_features, num_classes)
        )

    def forward(self, node_features, edge_features, hypergraph):
        """Apply every layer in order, then classify each node.

        Args:
            node_features: Mapping ``node -> 1-D tensor``.
            edge_features: Mapping ``edge -> 1-D tensor``.
            hypergraph: Structure object passed through to every layer; its
                ``nodes`` attribute fixes the row order of the output.

        Returns:
            Tensor with one row of ``num_classes`` logits per entry of
            ``hypergraph.nodes``, in that order.
        """
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, hypergraph)

        final_node_features = torch.stack(
            [node_features[node] for node in hypergraph.nodes], dim=0
        )
        return self.output_layer(final_node_features)

In [6]:
# Basic Hypergraph Implementation
class Hypergraph:
    """Minimal directed hypergraph: each edge has one driver and a list of sinks.

    Args:
        nodes: Iterable of node identifiers.
        edges: Iterable of edge identifiers.
        driver_sink_map: Mapping ``edge -> (driver, sinks)``; ``driver`` may be
            ``None`` and ``sinks`` is an iterable of node identifiers.
        node_to_virtual_map: Indexable ``node -> virtual node id``.
        num_virtual_nodes: Number of virtual nodes.

    Note:
        The node -> incident-edges index is built on the first call to
        ``get_incident_edges`` and then reused, so node identifiers must be
        hashable and ``edges`` / ``driver_sink_map`` should not be mutated
        after that first call.
    """

    def __init__(self, nodes, edges, driver_sink_map, node_to_virtual_map, num_virtual_nodes):
        self.nodes = nodes
        self.edges = edges
        self.driver_sink_map = driver_sink_map
        self.node_to_virtual_map = node_to_virtual_map
        self.num_virtual_nodes = num_virtual_nodes
        # Lazily built cache: node -> list of incident edges.
        self._incident_edges = None

    def _build_incidence_index(self):
        """Map every node to the edges it drives or sinks, in ``self.edges`` order."""
        index = {}
        for edge in self.edges:
            entry = self.driver_sink_map[edge]
            driver, sinks = entry[0], entry[1]
            seen = set()
            for node in (driver, *sinks):
                # A node that is both driver and sink, or a repeated sink,
                # still lists the edge only once.
                if node in seen:
                    continue
                seen.add(node)
                index.setdefault(node, []).append(edge)
        return index

    def get_incident_edges(self, node):
        """Return a new list of the edges for which ``node`` is the driver or a sink."""
        if self._incident_edges is None:
            self._incident_edges = self._build_incidence_index()
        # Copy so callers cannot corrupt the cached index.
        return list(self._incident_edges.get(node, ()))

    def get_driver_and_sinks(self, edge):
        """Return the ``(driver, sinks)`` entry stored for ``edge``."""
        return self.driver_sink_map[edge]

    def get_virtual_node(self, node):
        """Return the virtual node id that ``node`` is assigned to."""
        return self.node_to_virtual_map[node]

In [8]:
# --- Configuration -----------------------------------------------------------
file_indices = range(1, 9)
clean_data_dir = 'clean_data/'  # relative to the notebook's working directory
# NOTE(review): 2 assumes every partition id in *.partition.npy is 0 or 1; any
# other id makes the model raise KeyError — confirm against the data.
num_virtual_nodes = 2
epochs = 10
seed = 0

# Seed before the model is built so weight initialisation is repeatable.
torch.manual_seed(seed)

# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DEHNN(num_layers=2, node_in_features=14, edge_in_features=1).to(device)

# Training configuration
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()


def load_design(index):
    """Load one design from ``clean_data_dir`` and prepare it for the model.

    Args:
        index: File index used as the prefix of the data files.

    Returns:
        ``(node_features, edge_features, hypergraph, target)`` with the feature
        dicts converted to float tensors on ``device`` and ``target`` built
        from the values of the congestion pickle.

    Raises:
        FileNotFoundError: If any of the expected files is missing.
    """
    prefix = f'{clean_data_dir}{index}'

    def read_pickle(stem):
        # SECURITY: pickle.load can execute arbitrary code; only use on
        # trusted, locally produced files.
        with open(f'{prefix}.{stem}.pkl', 'rb') as f:
            return pickle.load(f)

    driver_sink_map = read_pickle('driver_sink_map')
    raw_node_features = read_pickle('node_features')
    raw_edge_features = read_pickle('net_features')
    congestion = read_pickle('congestion')
    partition = np.load(f'{prefix}.partition.npy')

    # Preprocess data
    node_features = {k: torch.tensor(v).float().to(device) for k, v in raw_node_features.items()}
    edge_features = {k: torch.tensor(v).float().to(device) for k, v in raw_edge_features.items()}

    # NOTE(review): assumes nodes/edges are keyed 0..n-1 in the pickles — confirm.
    nodes = list(range(len(node_features)))
    edges = list(range(len(edge_features)))
    hypergraph = Hypergraph(nodes, edges, driver_sink_map, partition, num_virtual_nodes)

    # NOTE(review): relies on the congestion dict iterating in node order
    # 0..n-1 so labels line up with the model's output rows — confirm.
    target = torch.tensor(list(congestion.values())).to(device)

    return node_features, edge_features, hypergraph, target


# Designs do not change between epochs, so each one is read and converted
# only once. All loaded designs stay resident on `device`.
design_cache = {}

model.train()
for epoch in range(epochs):
    epoch_loss = 0.0  # Accumulate loss over all datasets for each epoch

    for i in file_indices:
        print(i)
        if i not in design_cache:
            design_cache[i] = load_design(i)
        node_features, edge_features, hypergraph, target = design_cache[i]

        # Forward pass
        optimizer.zero_grad()
        output = model(node_features, edge_features, hypergraph)

        # Compute loss
        loss = criterion(output, target)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Accumulate loss
        epoch_loss += loss.item()

    # Print epoch loss
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss:.4f}')

1


FileNotFoundError: [Errno 2] No such file or directory: 'clean_data/1.driver_sink_map.pkl'