In [None]:
import os
import numpy as np
import pickle
import torch
import torch.nn
from torch_geometric.data import Dataset
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from utils import *
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm import tqdm

import sys
sys.path.insert(1, 'data/')

from pyg_dataset import NetlistDataset

sys.path.append("models/layers/")
from models.model import GNN_node

In [None]:
# Before running this cell, place every design directory under "data/cross_design_data/".
dataset = NetlistDataset(
    data_dir="data/cross_design_data/",
    load_pe=True,
    pl=True,
    processed=False,
)

In [None]:
# Convert each homogeneous `Data` sample into a `HeteroData` graph with two node
# types, 'node' (design instances) and 'net', as consumed by the model below.
# The dataset items are NOT modified: the net-side edge indices are re-based on
# cloned tensors, so re-running this cell can never shift them a second time.
h_dataset = []
for data in tqdm(dataset):
    # The pipeline treats the number of node_congestion rows as the instance count.
    num_instances = data.node_congestion.shape[0]

    # Row 1 of each node->net edge index holds net ids offset by num_instances;
    # subtract the offset so net ids start at 0 (on copies, not in place).
    edge_index_source_to_net = data.edge_index_source_to_net.clone()
    edge_index_source_to_net[1] -= num_instances
    edge_index_sink_to_net = data.edge_index_sink_to_net.clone()
    edge_index_sink_to_net[1] -= num_instances

    h_data = HeteroData()
    h_data['node'].x = data.node_features
    h_data['node'].y = data.node_congestion

    h_data['net'].x = data.net_features
    h_data['net'].y = data.net_hpwl

    # Sink->net edges carry GCN-normalised weights; source->net edges are unweighted.
    # NOTE(review): gcn_norm builds one degree vector indexed by the target (net) id
    # and applies it to both endpoints, so instance ids index into net degrees on this
    # bipartite edge index — confirm this matches how the model was trained.
    h_data['node', 'as_a_sink_of', 'net'].edge_index, h_data['node', 'as_a_sink_of', 'net'].edge_weight = gcn_norm(edge_index_sink_to_net, add_self_loops=False)
    h_data['node', 'as_a_source_of', 'net'].edge_index = edge_index_source_to_net

    # Copied through unchanged; presumably virtual-node bookkeeping for the model — confirm.
    h_data.batch = data.batch
    h_data.num_vn = data.num_vn
    h_data.num_instances = num_instances
    h_dataset.append(h_data)

In [None]:
# Use the GPU when one is available; fall back to CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Architecture reference (for training from scratch instead of loading the checkpoint):
#model = GNN_node(4, 32, 8, 1, node_dim = data.node_features.shape[1], net_dim = data.net_features.shape[1], vn=True).to(device)

# The checkpoint appears to be a whole pickled model object (it is called directly
# below), not a state_dict. Unpickling can execute arbitrary code, so only load trusted files.
# NOTE(review): torch>=2.6 defaults torch.load(weights_only=True), which rejects
# full-module pickles; pass weights_only=False there if loading fails.
# map_location remaps tensors saved on another device (e.g. a GPU checkpoint on a CPU host).
model = torch.load("best_dehnn_model.pt", map_location=device).to(device)

In [None]:
# Evaluate the pretrained model on every design. Designs that run out of memory
# are skipped (and counted) instead of aborting the whole run; any other error
# propagates so real bugs are not hidden.
model.eval()  # evaluation mode: disables dropout / uses running batch-norm statistics

total_test_acc = 0
total_test_net_l1 = 0
all_test_idx = 0  # designs evaluated successfully
num_skipped = 0   # designs skipped because of out-of-memory errors

with torch.no_grad():  # inference only: no autograd graph, far lower memory use
    for h_data in tqdm(h_dataset):
        try:
            node_representation, net_representation = model(h_data, device)
            test_acc = compute_accuracy(node_representation, h_data['node'].y.to(device))
            test_net_l1 = torch.nn.functional.l1_loss(
                net_representation.flatten(), h_data['net'].y.to(device)
            ).item()
        except RuntimeError as err:
            # torch.cuda.OutOfMemoryError subclasses RuntimeError; re-raise anything else.
            message = str(err)
            if "out of memory" not in message and "not enough memory" not in message:
                raise
            print("OOM")
            num_skipped += 1
            torch.cuda.empty_cache()  # release cached blocks before the next design
            continue

        total_test_acc += test_acc
        total_test_net_l1 += test_net_l1
        all_test_idx += 1

print(f"evaluated {all_test_idx} designs, skipped {num_skipped} (OOM)")

In [None]:
# Report the mean node-congestion accuracy and mean net-HPWL L1 error over the
# designs that were evaluated; guard against every design having been skipped.
if all_test_idx == 0:
    print("No design was evaluated successfully; no averages to report.")
else:
    print(total_test_acc/all_test_idx, total_test_net_l1/all_test_idx)