### Imports

In [1]:
import os
import yaml
import argparse
import numpy as np
import torch

from hls4ml.utils.config import config_from_pyg_model
from hls4ml.converters import convert_from_pyg_model
from collections import OrderedDict
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, mean_absolute_error, mean_squared_error

# locals
from utils.models.interaction_network_pyg import InteractionNetwork
from model_wrappers import model_wrapper
from utils.data.dataset_pyg import GraphDataset
from utils.data.fix_graph_size import fix_graph_size



### PyTorch Model

In [2]:
# Rebuild the architecture the checkpoint was trained with, then load its weights.
# map_location="cpu" lets a checkpoint saved from a GPU session load on a CPU-only
# machine; eval() switches the model to inference mode before it is compared
# against the HLS model below.
MODEL_PATH = os.path.join("trained_models", "IN_pyg_small_add_source_to_target_40_state_dict.pt")
torch_model = InteractionNetwork(aggr="add", flow="source_to_target", hidden_size=40)
torch_model_dict = torch.load(MODEL_PATH, map_location="cpu")
torch_model.load_state_dict(torch_model_dict)
torch_model.eval()

# List every named submodule; the root module has the empty name and is skipped.
for name, submodule in torch_model.named_modules():
    if name != "":
        print(f"{name}: {submodule}")

R1: RelationalModel(
  (layers): Sequential(
    (0): Linear(in_features=10, out_features=40, bias=True)
    (1): ReLU()
    (2): Linear(in_features=40, out_features=40, bias=True)
    (3): ReLU()
    (4): Linear(in_features=40, out_features=4, bias=True)
  )
)
R1.layers: Sequential(
  (0): Linear(in_features=10, out_features=40, bias=True)
  (1): ReLU()
  (2): Linear(in_features=40, out_features=40, bias=True)
  (3): ReLU()
  (4): Linear(in_features=40, out_features=4, bias=True)
)
R1.layers.0: Linear(in_features=10, out_features=40, bias=True)
R1.layers.1: ReLU()
R1.layers.2: Linear(in_features=40, out_features=40, bias=True)
R1.layers.3: ReLU()
R1.layers.4: Linear(in_features=40, out_features=4, bias=True)
O: ObjectModel(
  (layers): Sequential(
    (0): Linear(in_features=7, out_features=40, bias=True)
    (1): ReLU()
    (2): Linear(in_features=40, out_features=40, bias=True)
    (3): ReLU()
    (4): Linear(in_features=40, out_features=3, bias=True)
  )
)
O.layers: Sequential(
  (

### HLS Model

In [3]:
# forward_dict: defines the order in which graph-blocks are called in the model's 'forward()' method
# Graph blocks in the order the model's forward() calls them, as
# (submodule name -> block type); passed to convert_from_pyg_model as
# `forward_dictionary`, so insertion order matters.
forward_dict = OrderedDict(
    [
        ("R1", "EdgeBlock"),
        ("O", "NodeBlock"),
        ("R2", "EdgeBlock"),
    ]
)

In [4]:
# Fixed graph dimensions the HLS model is generated for. n_node / n_edge are also
# passed to fix_graph_size as n_node_max / n_edge_max when loading data.
graph_dims = dict(
    n_node=28,   # max nodes per graph
    n_edge=37,   # max edges per graph
    node_dim=3,  # node feature dimension
    edge_dim=4,  # edge feature dimension
)

In [5]:
def build_hls_model(model, dims, forward_dictionary, default_precision, default_reuse_factor,
                    output_dir="test_GNN", default_index_precision="ap_uint<16>",
                    activate_final="sigmoid"):
    """Create an hls4ml config for ``model`` and convert it into an HLS model.

    Args:
        model: the PyTorch Geometric model to convert.
        dims: dict with ``n_node``, ``n_edge``, ``node_dim`` and ``edge_dim``.
        forward_dictionary: ordered mapping of submodule name -> graph-block type,
            in the order ``model.forward()`` calls them.
        default_precision: hls4ml precision string, e.g. ``"ap_fixed<16,8>"``.
        default_reuse_factor: reuse factor passed to ``config_from_pyg_model``.
        output_dir: output directory for the generated HLS project.
        default_index_precision: passed through as ``default_index_precision``.
        activate_final: final activation passed to the converter.

    Returns:
        ``(config, hls_model)``: the hls4ml config dict and the converted model.
    """
    config = config_from_pyg_model(model,
                                   default_precision=default_precision,
                                   default_index_precision=default_index_precision,
                                   default_reuse_factor=default_reuse_factor)
    hls_model = convert_from_pyg_model(model,
                                       n_edge=dims['n_edge'],
                                       n_node=dims['n_node'],
                                       edge_dim=dims['edge_dim'],
                                       node_dim=dims['node_dim'],
                                       forward_dictionary=forward_dictionary,
                                       activate_final=activate_final,
                                       output_dir=output_dir,
                                       hls_config=config)
    return config, hls_model


# First configuration tried: ap_fixed<16,8> with reuse factor 1.
# NOTE: the next cell reassigns `config` and `hls_model` with ap_fixed<32,16> and
# reuse factor 8, so only that model is compiled and evaluated below.
output_dir = "test_GNN"
config, hls_model = build_hls_model(torch_model, graph_dims, forward_dict,
                                    default_precision="ap_fixed<16,8>",
                                    default_reuse_factor=1,
                                    output_dir=output_dir)

In [6]:
# Configuration carried forward to compile() and predict() below:
# ap_fixed<32,16> precision with reuse factor 8.
output_dir = "test_GNN"
config = config_from_pyg_model(
    torch_model,
    default_precision="ap_fixed<32,16>",
    default_index_precision="ap_uint<16>",
    default_reuse_factor=8,
)
hls_model = convert_from_pyg_model(
    torch_model,
    n_edge=graph_dims["n_edge"],
    n_node=graph_dims["n_node"],
    edge_dim=graph_dims["edge_dim"],
    node_dim=graph_dims["node_dim"],
    forward_dictionary=forward_dict,
    activate_final="sigmoid",
    output_dir=output_dir,
    hls_config=config,
)

In [7]:
# Write the HLS project to output_dir (see "Writing HLS project" below) and compile
# it so hls_model.predict() can be called in the evaluation cell.
hls_model.compile()

Writing HLS project
Done


## Evaluation and prediction with `hls_model.predict(input)`

### Data

In [8]:
class data_wrapper:
    """Bundle one fixed-size graph for both the PyTorch model and the hls4ml model.

    Attributes:
        x, edge_attr, edge_index: torch tensors read by ``torch_model(data)``.
            ``edge_index`` is stored with dims 0 and 1 swapped relative to the
            ``edge_index`` argument.
        hls_data: ``[node_attr, edge_attr, edge_index]`` as C-contiguous NumPy
            arrays for ``hls_model.predict``. This ``edge_index`` has the same
            layout as the argument (transposed back) and is cast to float32.
        target: the ``target`` tensor as given.
        np_target: ``target`` as a 1-D NumPy array of length ``target.shape[0]``.
    """

    def __init__(self, node_attr, edge_attr, edge_index, target):
        self.x = node_attr
        self.edge_attr = edge_attr
        self.edge_index = edge_index.transpose(0, 1)

        # float32 rather than an int dtype: presumably hls4ml's predict() wants every
        # input in the same floating dtype. NOTE(review): confirm for the hls4ml version in use.
        self.hls_data = [
            self._to_contiguous_numpy(self.x),
            self._to_contiguous_numpy(self.edge_attr),
            self._to_contiguous_numpy(self.edge_index.transpose(0, 1), dtype=np.float32),
        ]

        self.target = target
        # Positional reshape: the `newshape=` keyword is deprecated since NumPy 2.1.
        self.np_target = target.detach().cpu().numpy().reshape(target.shape[0])

    @staticmethod
    def _to_contiguous_numpy(tensor, dtype=None):
        """Detach ``tensor``, move it to CPU, optionally cast, and return a C-contiguous array."""
        array = tensor.detach().cpu().numpy()
        if dtype is not None:
            array = array.astype(dtype)
        # Transposed tensors yield strided views; make the memory layout C-ordered.
        return np.ascontiguousarray(array)

def _write_testbench_input(graph, tb_dir):
    """Write ``graph``'s node attrs, edge attrs and edge index as one row to ``tb_dir/input_data.dat``."""
    node_attr = graph.x.detach().cpu().numpy()
    edge_attr = graph.edge_attr.detach().cpu().numpy()
    # Undo data_wrapper's transpose so the edge index has the same layout as in hls_data.
    edge_index = graph.edge_index.transpose(0, 1).detach().cpu().numpy().astype(np.int32)
    os.makedirs(tb_dir, exist_ok=True)
    input_data = np.concatenate([node_attr.reshape(1, -1), edge_attr.reshape(1, -1), edge_index.reshape(1, -1)], axis=1)
    np.savetxt(os.path.join(tb_dir, 'input_data.dat'), input_data, fmt='%f', delimiter=' ')


def load_graphs(graph_indir, graph_dims, n_graphs, tb_dir='tb_data'):
    """Load up to ``n_graphs`` graphs, fit them to ``graph_dims``, and dump a testbench input.

    Files are read in sorted filename order. ``os.listdir`` returns entries in
    arbitrary order, so sorting makes the selected graphs reproducible across
    runs and machines, and likewise the graph written to the testbench.

    Args:
        graph_indir: directory of graph files handed to ``GraphDataset``.
        graph_dims: dict providing ``n_node`` / ``n_edge``, used as
            ``n_node_max`` / ``n_edge_max`` for ``fix_graph_size``.
        n_graphs: number of dataset entries to consider. Graphs flagged as bad by
            ``fix_graph_size`` are dropped, so fewer may be returned.
        tb_dir: directory that receives ``input_data.dat`` for the HLS testbench.

    Returns:
        list of ``data_wrapper`` objects.

    Raises:
        ValueError: if no graph survives ``fix_graph_size``.
    """
    graph_files = np.array(sorted(os.path.join(graph_indir, graph_file)
                                  for graph_file in os.listdir(graph_indir)))
    dataset = GraphDataset(graph_files=graph_files)

    graphs = []
    for data in dataset[:n_graphs]:
        node_attr, edge_attr, edge_index, target, bad_graph = fix_graph_size(data.x, data.edge_attr, data.edge_index,
                                                                             data.y,
                                                                             n_node_max=graph_dims['n_node'],
                                                                             n_edge_max=graph_dims['n_edge'])
        if not bad_graph:
            graphs.append(data_wrapper(node_attr, edge_attr, edge_index, target))
    print(f"n_graphs: {len(graphs)}")

    if not graphs:
        raise ValueError(f"no graph in {graph_indir!r} fits n_node_max={graph_dims['n_node']}, "
                         f"n_edge_max={graph_dims['n_edge']}")

    print("writing test bench data for 1st graph")
    _write_testbench_input(graphs[0], tb_dir)

    return graphs


# Reuse `graph_dims` from the HLS conversion section. The graphs must be sized to
# exactly the n_node / n_edge the HLS model was generated for, and a second
# hand-written copy of these numbers could silently drift out of sync.
graph_indir = "trackml_data/processed_plus_pyg_small"
graphs = load_graphs(graph_indir, graph_dims, n_graphs=100)

n_graphs: 2
writing test bench data for 1st graph


In [9]:
# Compare the PyTorch output with the compiled hls4ml model on the first loaded graph.
data = graphs[0]
with torch.no_grad():  # inference only; no autograd graph is needed
    torch_pred = torch_model(data)
hls_pred = hls_model.predict(data.hls_data)
# Flatten both sides so the comparison does not depend on whether either
# prediction carries a trailing singleton dimension.
MSE = mean_squared_error(torch_pred.cpu().numpy().reshape(-1), np.asarray(hls_pred).reshape(-1))
print(f"MSE: {MSE}")

MSE: 7.784141189404181e-07
