In [449]:
import os
import numpy as np
import pickle
import torch
import torch.nn
from torch_geometric.data import Dataset
from torch_geometric.data import Data, HeteroData
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from utils import *
import torch.nn as nn
import torch.optim as optim
import wandb
from tqdm import tqdm

import sys
sys.path.insert(1, 'data/')

from pyg_dataset import NetlistDataset

sys.path.append("models/layers/")
from models.model import GNN_node

In [450]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [451]:
import sys
sys.path.insert(1, 'data/')

In [452]:
from pyg_dataset import NetlistDataset

In [453]:
# Load the pre-processed netlist dataset, restricted to index 0 through `load_indices`.
dataset = NetlistDataset(
    data_dir="data/processed_datasets",
    load_pe=True,
    pl=True,
    processed=True,
    load_indices=[0],
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.46it/s]


In [454]:
# Convert every item of `dataset` into a HeteroData graph with two node types
# ('node' and 'net') and two node->net edge types (sink and source).
h_dataset = []
for data in tqdm(dataset):
    num_instances = data.node_congestion.shape[0]
    data.num_instances = num_instances

    # Row 1 of each edge index is shifted down by num_instances (presumably the net ids
    # are stored after the instance ids -- TODO confirm). The shift is applied to clones
    # so the dataset item itself is left untouched: re-running this cell must not
    # subtract the offset a second time if the dataset hands back the same objects.
    sink_to_net = data.edge_index_sink_to_net.clone()
    sink_to_net[1] -= num_instances
    source_to_net = data.edge_index_source_to_net.clone()
    source_to_net[1] -= num_instances

    h_data = HeteroData()
    h_data['node'].x = data.node_features
    h_data['node'].y = data.node_congestion

    h_data['net'].x = data.net_features
    h_data['net'].y = data.net_hpwl

    # Only the sink edges receive gcn_norm edge weights; source edges stay unweighted.
    # NOTE(review): gcn_norm is applied to the re-based bipartite index -- confirm the
    # resulting weights are the intended normalisation.
    sink_edge_index, sink_edge_weight = gcn_norm(sink_to_net, add_self_loops=False)
    h_data['node', 'as_a_sink_of', 'net'].edge_index = sink_edge_index
    h_data['node', 'as_a_sink_of', 'net'].edge_weight = sink_edge_weight
    h_data['node', 'as_a_source_of', 'net'].edge_index = source_to_net

    h_data.batch = data.batch
    h_data.num_vn = data.num_vn
    h_data.num_instances = num_instances
    h_dataset.append(h_data)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.02it/s]


In [455]:
sys.path.append("models/layers/")

In [456]:
from models.model import GNN_node

In [463]:
# Device string handed to `.to(...)` and to the model's forward call in later cells.
device = "cpu"

In [464]:
# Take item 0 of the dataset; the next cell reads its feature widths to size the model.
data = dataset[0]

In [465]:
# Instantiate the GNN with input widths taken from the loaded item's feature matrices.
# NOTE(review): the meaning of the four positional arguments is defined in
# models/model.py, which is not visible here -- confirm before changing them.
model = GNN_node(
    4,
    32,
    8,
    1,
    node_dim=data.node_features.shape[1],
    net_dim=data.net_features.shape[1],
    vn=True,
).to(device)

In [466]:
# Adam over all model parameters, plus one loss per prediction head:
# cross-entropy for the per-node output and mean-squared error for the per-net output.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion_net = torch.nn.MSELoss()
criterion_node = torch.nn.CrossEntropyLoss()

In [467]:
# Rebind `data` to the HeteroData graph. `h_data` is whatever the conversion loop
# assigned on its final iteration, i.e. the last graph appended to `h_dataset`.
data = h_data

In [468]:
# Full-batch training on the single graph in `data`.
# Targets are moved to the device once instead of four times per epoch.
node_targets = data['node'].y.to(device)
net_targets = data['net'].y.to(device)

for epoch in range(500):
    # --- optimisation step -------------------------------------------------
    model.train()
    optimizer.zero_grad()
    node_representation, net_representation = model(data, device)
    loss_node = criterion_node(node_representation, node_targets)
    loss_net = criterion_net(net_representation.flatten(), net_targets)
    # The net loss is scaled by 0.001 before being added to the node loss.
    loss = loss_node + 0.001*loss_net
    loss.backward()
    optimizer.step()

    # --- evaluation pass ---------------------------------------------------
    # Run under no_grad so no autograd graph is built or kept alive for tensors
    # that are only printed / inspected afterwards.
    # NOTE(review): this pass uses the same graph and targets as the training step,
    # so the printed "val" losses are training-set losses in eval mode.
    model.eval()
    with torch.no_grad():
        node_representation, net_representation = model(data, device)
        val_loss_node = criterion_node(node_representation, node_targets)
        val_loss_net = criterion_net(net_representation.flatten(), net_targets)
    print(val_loss_node.item(), val_loss_net.item())

RuntimeError: Attempting CUDA graph capture of step() for an instance of Adam but this instance was constructed with capturable=False.

In [100]:
# Persist the trained model to disk. The module object itself is passed (not
# model.state_dict()), so the whole object is serialised.
# NOTE(review): loading this file presumably requires the GNN_node class to be
# importable under the same path -- confirm before relocating the model code.
torch.save(model, "trained_dehnn.model")

In [144]:
def compute_accuracy(logits, targets):
    """Return the fraction of rows whose arg-max class equals the target label.

    Args:
        logits: 2-D tensor of per-class scores; the predicted class of each row
            is the arg-max along dim 1.
        targets: tensor holding one class label per row of `logits`. It is
            flattened before the comparison so that a column vector of shape
            (N, 1) is compared element-wise instead of broadcasting against the
            (N,) predictions into an N x N matrix.

    Returns:
        Accuracy as a Python float.

    Raises:
        ZeroDivisionError: if `targets` contains no elements.
    """
    predicted_classes = torch.argmax(logits, dim=1)
    flat_targets = targets.reshape(-1).long()
    correct_predictions = (predicted_classes.long() == flat_targets).sum().item()
    return correct_predictions / flat_targets.numel()

In [165]:
# Accuracy over the rows selected by `data.train_indices`.
# NOTE(review): by file order `data` was rebound to `h_data`, which the conversion loop
# does not give `train_indices` / `node_congestion` -- confirm which object is expected.
compute_accuracy(
    logits=node_representation[data.train_indices].cpu(),
    targets=data.node_congestion[data.train_indices],
)

0.8965819224032695

In [166]:
# Accuracy over the rows selected by `data.valid_indices`.
# NOTE(review): by file order `data` was rebound to `h_data`, which the conversion loop
# does not give `valid_indices` / `node_congestion` -- confirm which object is expected.
compute_accuracy(
    logits=node_representation[data.valid_indices].cpu(),
    targets=data.node_congestion[data.valid_indices],
)

0.8966282593171598

In [167]:
# Accuracy over the rows selected by `data.test_indices`.
# NOTE(review): by file order `data` was rebound to `h_data`, which the conversion loop
# does not give `test_indices` / `node_congestion` -- confirm which object is expected.
compute_accuracy(
    logits=node_representation[data.test_indices].cpu(),
    targets=data.node_congestion[data.test_indices],
)

0.8971165537606393

In [168]:
# Mean absolute error of the flattened per-net output against `data.net_hpwl`,
# restricted to the entries selected by `data.net_train_indices`.
torch.nn.functional.l1_loss(
    input=net_representation.flatten()[data.net_train_indices].cpu(),
    target=data.net_hpwl[data.net_train_indices],
)

tensor(7.6716, grad_fn=<MeanBackward0>)

In [169]:
# Mean absolute error of the flattened per-net output against `data.net_hpwl`,
# restricted to the entries selected by `data.net_valid_indices`.
torch.nn.functional.l1_loss(
    input=net_representation.flatten()[data.net_valid_indices].cpu(),
    target=data.net_hpwl[data.net_valid_indices],
)

tensor(7.7605, grad_fn=<MeanBackward0>)

In [170]:
# Mean absolute error of the flattened per-net output against `data.net_hpwl`,
# restricted to the entries selected by `data.net_test_indices`.
torch.nn.functional.l1_loss(
    input=net_representation.flatten()[data.net_test_indices].cpu(),
    target=data.net_hpwl[data.net_test_indices],
)

tensor(7.7583, grad_fn=<MeanBackward0>)

In [40]:
# Hard class prediction per node: index of the largest score along dim 1.
predicted_classes = node_representation.argmax(dim=1)

In [82]:
import matplotlib.pyplot as plt

In [84]:
# Columns 7 and 8 of the node feature matrix; the following cell splits them into the
# x / y coordinates used by the scatter plots.
# NOTE(review): what these two columns encode is not shown in this notebook -- confirm.
pos_lst = data.node_features[:, 7:9]

In [90]:
# Split the two position columns into detached 1-D CPU tensors for plotting.
x_lst, y_lst = (
    pos_lst[:, column].detach().cpu().flatten()
    for column in (0, 1)
)

In [None]:
# Side-by-side scatter of every node at its (x_lst, y_lst) position:
# left panel coloured by the stored congestion label, right panel by the arg-max prediction.
fig, axs = plt.subplots(1, 2, figsize=(20, 10))

axs[0].scatter(x_lst, y_lst, c=data.node_congestion.flatten().detach().cpu(), s=1)
axs[0].set(title="Node congestion label", xlabel="x", ylabel="y")

axs[1].scatter(x_lst, y_lst, c=predicted_classes.detach().cpu(), s=1)
axs[1].set(title="Predicted congestion class", xlabel="x", ylabel="y")

# Render explicitly so the cell does not echo the repr of the last scatter call.
plt.show()

In [209]:
from torch_geometric.nn import global_mean_pool, global_max_pool, global_add_pool