## Imports

In [55]:
import numpy as np
import json
import gzip
from scipy.sparse import coo_matrix
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F

## Model Architecture

In [56]:
class DEHNNLayer(nn.Module):
    def __init__(self, node_in_features, edge_in_features):
        super(DEHNNLayer, self).__init__()
        self.node_mlp1 = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        self.edge_mlp2 = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        self.edge_mlp3 = nn.Sequential(
            nn.Linear(2 * node_in_features, 2 * node_in_features),
            nn.ReLU()
        )

        self.node_to_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, node_in_features),
            nn.ReLU()
        )
        self.virtual_to_higher_virtual_mlp = nn.Sequential(
            nn.Linear(node_in_features, edge_in_features),
            nn.ReLU()
        )
        self.higher_virtual_to_virtual_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )
        self.virtual_to_node_mlp = nn.Sequential(
            nn.Linear(edge_in_features, edge_in_features),
            nn.ReLU()
        )

        # Learnable defaults for missing driver or sink
        self.default_driver = nn.Parameter(torch.zeros(node_in_features))
        self.default_sink_agg = nn.Parameter(torch.zeros(node_in_features))
        self.default_edge_agg = nn.Parameter(torch.zeros(edge_in_features))
        self.default_virtual_node = nn.Parameter(torch.zeros(node_in_features))

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all parameters with Xavier uniform distribution."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, node_features, edge_features, hypergraph):
        # Node update
        updated_node_features = {}
        for node in hypergraph.nodes:
            incident_edges = hypergraph.get_incident_edges(node)
            if incident_edges:
                agg_features = torch.sum(torch.stack([self.node_mlp1(edge_features[edge]) for edge in incident_edges]), dim=0)
            else:
                agg_features = self.default_edge_agg
            updated_node_features[node] = agg_features

        # Edge update
        updated_edge_features = {}
        for edge in hypergraph.edges:
            driver, sinks = hypergraph.get_driver_and_sinks(edge)

            driver_feature = node_features[driver] if driver is not None else self.default_driver

            if sinks:
                sink_agg = torch.sum(torch.stack([self.edge_mlp2(node_features[sink]) for sink in sinks]), dim=0)
            else:
                sink_agg = self.default_sink_agg

            concatenated = torch.cat([driver_feature, sink_agg])
            updated_edge_features[edge] = self.edge_mlp3(concatenated)

        # Virtual node aggregation
        virtual_node_agg = {}
        for virtual_node in range(hypergraph.num_virtual_nodes):
            assigned_nodes = [node for node in hypergraph.nodes if hypergraph.get_virtual_node(node) == virtual_node]
            if assigned_nodes:
                agg_features = torch.sum(torch.stack([self.node_to_virtual_mlp(node_features[node]) for node in assigned_nodes]), dim=0)
            else:
                agg_features = self.default_virtual_node
            virtual_node_agg[virtual_node] = agg_features

        higher_virtual_feature = torch.sum(
            torch.stack([self.virtual_to_higher_virtual_mlp(virtual_node_agg[vn]) for vn in virtual_node_agg]), dim=0
        )

        propagated_virtual_node_features = {}
        for virtual_node in range(hypergraph.num_virtual_nodes):
            propagated_virtual_node_features[virtual_node] = self.higher_virtual_to_virtual_mlp(higher_virtual_feature)

        for node in hypergraph.nodes:
            virtual_node = hypergraph.get_virtual_node(node)
            propagated_feature = self.virtual_to_node_mlp(propagated_virtual_node_features[virtual_node])
            updated_node_features[node] += propagated_feature

        return updated_node_features, updated_edge_features


class DEHNN(nn.Module):
    def __init__(self, num_layers, node_in_features, edge_in_features):
        super(DEHNN, self).__init__()
        self.num_layers = num_layers
        self.layers = nn.ModuleList()
        
        for i in range(num_layers):
            self.layers.append(DEHNNLayer(node_in_features, edge_in_features))
            node_in_features, edge_in_features = edge_in_features, node_in_features
            edge_in_features *= 2

        edge_in_features = edge_in_features // 2
        self.output_layer = nn.Sequential(
            nn.Linear(node_in_features, 2)
        )

    def forward(self, node_features, edge_features, hypergraph):
        for layer in self.layers:
            node_features, edge_features = layer(node_features, edge_features, hypergraph)
        
        final_node_features = torch.stack([node_features[node] for node in hypergraph.nodes], dim=0)
        output = self.output_layer(final_node_features)
        return output
    
class Hypergraph:
    def __init__(self, nodes, edges, driver_sink_map, node_to_virtual_map, num_virtual_nodes):
        self.nodes = nodes
        self.edges = edges
        self.driver_sink_map = driver_sink_map
        self.node_to_virtual_map = node_to_virtual_map
        self.num_virtual_nodes = num_virtual_nodes

    def get_incident_edges(self, node):
        return [edge for edge in self.edges if node in self.driver_sink_map[edge][1] or node == self.driver_sink_map[edge][0]]

    def get_driver_and_sinks(self, edge):
        return self.driver_sink_map[edge]
    
    def get_virtual_node(self, node):
        return self.node_to_virtual_map[node]

## Loading Training and Validation Data

In [57]:
# loading training data and constructing hypergraph
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

clean_data_dir = '../data/chips/clean_data/'

with open(clean_data_dir + '1.driver_sink_map.pkl', 'rb') as f:
    train_driver_sink_map = pickle.load(f)

with open(clean_data_dir + '1.node_features.pkl', 'rb') as f:
    train_node_features = pickle.load(f)

with open(clean_data_dir + '1.net_features.pkl', 'rb') as f:
    train_edge_features = pickle.load(f)

with open(clean_data_dir + '1.congestion.pkl', 'rb') as f:
    train_congestion = pickle.load(f)

train_partition = np.load(clean_data_dir + '1.partition.npy')

train_node_features = {k: torch.tensor(v).float().to(device) for k, v in train_node_features.items()}
train_edge_features = {k: torch.tensor(v).float().to(device) for k, v in train_edge_features.items()}

train_nodes = list(range(len(train_node_features)))
train_edges = list(range(len(train_edge_features)))
train_hypergraph = Hypergraph(train_nodes, train_edges, train_driver_sink_map, train_partition, 2)

In [58]:
# loading validation data and constructing hypergraph
with open(clean_data_dir + '2.driver_sink_map.pkl', 'rb') as f:
    val_driver_sink_map = pickle.load(f)

with open(clean_data_dir + '2.node_features.pkl', 'rb') as f:
    val_node_features = pickle.load(f)

with open(clean_data_dir + '2.net_features.pkl', 'rb') as f:
    val_edge_features = pickle.load(f)

with open(clean_data_dir + '2.congestion.pkl', 'rb') as f:
    val_congestion = pickle.load(f)

val_partition = np.load(clean_data_dir + '2.partition.npy')

val_node_features = {k: torch.tensor(v).float().to(device) for k, v in val_node_features.items()}
val_edge_features = {k: torch.tensor(v).float().to(device) for k, v in val_edge_features.items()}

val_nodes = list(range(len(val_node_features)))
val_edges = list(range(len(val_edge_features)))
val_hypergraph = Hypergraph(val_nodes, val_edges, val_driver_sink_map, val_partition, 2)
val_targets = val_congestion

In [59]:
# checking class balance for validation and training
print(np.array(list(val_targets.values())).mean())
print(np.array(list(train_congestion.values())).mean())

0.0037834691501746217
0.0979251012145749


## Training and Validating Model

In [60]:
# Initialize DE-HNN model
model = DEHNN(num_layers=2, node_in_features=14, edge_in_features=1).to(device)
epochs = 10

# Optimizer and Loss Function
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()  # Cross-entropy loss for classification
#criterion = nn.BCEWithLogitsLoss(pos_weight = torch.tensor (5.0))
# Training and Validation Data
train_node_features = {k: v.to(device) for k, v in train_node_features.items()}
train_edge_features = {k: v.to(device) for k, v in train_edge_features.items()}
train_targets = torch.tensor(list(train_congestion.values())).long().to(device)

val_node_features = {k: v.to(device) for k, v in val_node_features.items()}
val_edge_features = {k: v.to(device) for k, v in val_edge_features.items()}
val_targets = torch.tensor(list(val_congestion.values())).long().to(device)

# Training and Validation Loop 
for epoch in range(epochs):
    # Training Phase
    model.train()
    optimizer.zero_grad()

    # Forward pass for training
    train_output = model(train_node_features, train_edge_features, train_hypergraph)
    
    # Compute training loss
    train_loss = criterion(train_output, train_targets)
    
    # Backward pass and optimization
    train_loss.backward()
    optimizer.step()
    
    # Validation Phase
    model.eval()
    with torch.no_grad():
        # Forward pass for validation
        val_output = model(val_node_features, val_edge_features, val_hypergraph)
        
        # Compute validation loss
        val_loss = criterion(val_output, val_targets)
        
        # Compute validation accuracy
        val_predictions = torch.argmax(val_output, dim=1)
        val_correct = (val_predictions == val_targets).sum().item()
        val_total = len(val_targets)
        val_accuracy = val_correct / val_total
    
    # Print training and validation metrics
    print(f"Epoch [{epoch+1}/{epochs}]")
    print(f"Train Loss: {train_loss.item():.4f}")
    print(f"Validation Loss: {val_loss.item():.4f}, Validation Accuracy: {val_accuracy:.4f}")


Epoch [1/10]
Train Loss: 5.1498
Validation Loss: 0.3153, Validation Accuracy: 0.8545
Epoch [2/10]
Train Loss: 4.1136
Validation Loss: 0.4343, Validation Accuracy: 0.7792
Epoch [3/10]
Train Loss: 3.1252
Validation Loss: 0.6416, Validation Accuracy: 0.6708
Epoch [4/10]
Train Loss: 2.2031
Validation Loss: 1.0277, Validation Accuracy: 0.4665
Epoch [5/10]
Train Loss: 1.3594
Validation Loss: 3.2409, Validation Accuracy: 0.0038
Epoch [6/10]
Train Loss: 3.1144
Validation Loss: 1.5024, Validation Accuracy: 0.1604
Epoch [7/10]
Train Loss: 0.8732
Validation Loss: 0.8880, Validation Accuracy: 0.4229
Epoch [8/10]
Train Loss: 0.9605
Validation Loss: 0.6206, Validation Accuracy: 0.5485
Epoch [9/10]
Train Loss: 1.0916
Validation Loss: 0.4599, Validation Accuracy: 0.6916
Epoch [10/10]
Train Loss: 1.2093
Validation Loss: 0.3601, Validation Accuracy: 0.8417


In [61]:
test_output = model(train_node_features, train_edge_features, train_hypergraph)

In [62]:
out = test_output.detach().cpu().numpy()
out = np.array([np.argmax(i) for i in out])
out

array([0, 0, 0, ..., 0, 0, 0])

In [63]:
# accuracy
np.mean(np.array(list(train_congestion.values())) == out)

0.8598178137651822