# Explore Onnx GNN Compilation

In [1]:
# System imports
import os
import sys
import collections.abc as container_abcs
from pprint import pprint as pp
from time import time as tt

# External imports
import matplotlib.pyplot as plt
import matplotlib.colors
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.data import DataLoader
import torch.nn as nn
# from torch_scatter import scatter_add
import torch.nn.functional as F

import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.display import clear_output
from IPython.display import HTML, display
import onnxruntime

%matplotlib inline

sys.path.append("..")

# Get rid of RuntimeWarnings, gross
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

# from gnn.checkpoint_agnn import CheckpointedResAGNN

device = "cuda" if torch.cuda.is_available() else "cpu"

OSError: libcublas.so.11: cannot open shared object file: No such file or directory

## Load Data

In [2]:
input_dir= "/global/cscratch1/sd/danieltm/ExaTrkX/trackml-codalab/embedding_processed/1_pt_cut_endcaps_unweighted_augmented/train"

In [3]:
num_events = 10
all_events = os.listdir(input_dir)
loaded_events = [torch.load(os.path.join(input_dir,event)) for event in all_events[:num_events]]

## Model Definitions

In [4]:
def make_mlp(
    input_size,
    sizes,
    hidden_activation="ReLU",
    output_activation="ReLU",
    layer_norm=False,
):
    """Construct an MLP with specified fully-connected layers."""
    hidden_activation = getattr(nn, hidden_activation)
    if output_activation is not None:
        output_activation = getattr(nn, output_activation)
    layers = []
    n_layers = len(sizes)
    sizes = [input_size] + sizes
    # Hidden layers
    for i in range(n_layers - 1):
        layers.append(nn.Linear(sizes[i], sizes[i + 1]))
        if layer_norm:
            layers.append(nn.LayerNorm(sizes[i + 1]))
        layers.append(hidden_activation())
    # Final layer
    layers.append(nn.Linear(sizes[-2], sizes[-1]))
    if output_activation is not None:
        if layer_norm:
            layers.append(nn.LayerNorm(sizes[-1]))
        layers.append(output_activation())
    return nn.Sequential(*layers)

In [5]:
def scatter_add_attention(encoded_nodes, encoded_edges, edge_list):
    start, end = edge_list[0], edge_list[1]

    src = encoded_nodes[end]*encoded_edges
    index = start.unsqueeze(-1)
    in_messages = torch.zeros(encoded_nodes.shape, dtype=src.dtype, device=encoded_nodes.device).scatter_add(0, index.repeat((1,src.shape[1])), src) 

    src = encoded_nodes[start]*encoded_edges
    index = end.unsqueeze(-1)
    out_messages = torch.zeros(encoded_nodes.shape, dtype=src.dtype, device=encoded_nodes.device).scatter_add(0, index.repeat((1,src.shape[1])), src) 
    
    aggr_nodes = in_messages + out_messages
    
    return aggr_nodes

In [10]:
class ResAGNN(nn.Module):
    def __init__(self, hparams):
        super(ResAGNN, self).__init__()
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """
        
        self.hparams = hparams
        
        # Setup input network
        self.node_encoder = make_mlp(
            hparams["in_channels"],
            [hparams["hidden"]],
            output_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
        )

        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(
            2 * (hparams["in_channels"] + hparams["hidden"]),
            [hparams["hidden"]] * hparams["nb_edge_layer"] + [1],
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

        # The node network computes new node features
        self.node_network = make_mlp(
            (hparams["in_channels"] + hparams["hidden"]) * 2,
            [hparams["hidden"]] * hparams["nb_node_layer"],
            layer_norm=hparams["layernorm"],
            output_activation=None,
            hidden_activation=hparams["hidden_activation"],
        )

    def forward(self, x, edge_index):

        # Encode the graph features into the hidden space
        input_x = x
        x = self.node_encoder(x)
        x = torch.cat([x, input_x], dim=-1)

        start, end = edge_index[0], edge_index[1]

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            # Previous hidden state
            x0 = x

            # Compute new edge score
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = self.edge_network(edge_inputs)
            e = torch.sigmoid(e)

            # Sum weighted node features coming into each node
            #             weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
            #             weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

            weighted_messages = scatter_add_attention(x, e, edge_index)

            # Compute new node features
            #             node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
            node_inputs = torch.cat([x, weighted_messages], dim=1)
            x = self.node_network(node_inputs)

            # Residual connection
            x = torch.cat([x, input_x], dim=-1)
            x = x + x0

        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return self.edge_network(clf_inputs).squeeze(-1)


Ignore this testing model:

In [6]:
class ResAGNN(nn.Module):
    def __init__(self):
        super(ResAGNN, self).__init__()
        """
        Initialise the Lightning Module that can scan over different GNN training regimes
        """
        print("init")
        

        self.test_network =  torch.nn.Sequential(
                                torch.nn.Linear(3, 64),
                                torch.nn.ReLU(),
                                torch.nn.Linear(64, 1),
                            )

    def forward(self, x):

        # Encode the graph features into the hidden space
        #         input_x = x
        new_x = self.test_network(x)
        #         x = torch.cat([x, input_x], dim=1)

        #         start, end = edge_index[0], edge_index[1]

        return new_x
        
    
        # Loop over iterations of edge and node networks
#         for i in range(8):
#             # Previous hidden state
#             x0 = x

#             # Compute new edge score
#             edge_inputs = torch.cat([x[start], x[end]], dim=1)
        
        
#             e = self.edge_network(edge_inputs)
#             e = torch.sigmoid(e)

        
#             # Sum weighted node features coming into each node
#             #             weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
#             #             weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

#             weighted_messages = scatter_add_attention(x, e, edge_index)

#             # Compute new node features
#             #             node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
#             node_inputs = torch.cat([x, weighted_messages], dim=1)
#             x = self.node_network(node_inputs)

#             # Residual connection
#             x = torch.cat([x, input_x], dim=-1)
#             x = x + x0

#         # Compute final edge scores; use original edge directions only
#         clf_inputs = torch.cat([x[start], x[end]], dim=1)
# return self.edge_network(clf_inputs).squeeze(-1)


## Load Model

In [3]:
hparams = checkpoint["hyper_parameters"]
state_dict = checkpoint["state_dict"]

In [4]:
torch.save(hparams, "hyper_parameters.ckpt")
torch.save(state_dict, "state_dict.ckpt")

In [8]:
hparams = torch.load("hyper_parameters.ckpt")
state_dict = torch.load("state_dict.ckpt")

In [11]:
model = ResAGNN(hparams)

In [12]:
model.load_state_dict(state_dict)

<All keys matched successfully>

In [15]:
example_data = loaded_events[0]
input_data = (example_data.x, example_data.edge_index)

## Onnx Testing

In [16]:
ONNX_FILE_PATH = "ResAGNN_model.onnx"
dynamic_axes = {"nodes": [0, 1], "edge_index": [0, 1]}
torch.onnx.export(model, input_data, ONNX_FILE_PATH, input_names=["nodes", "edge_index"], opset_version=12,
                  output_names=["output"], export_params=True, dynamic_axes=dynamic_axes)

In [17]:
session = onnxruntime.InferenceSession(ONNX_FILE_PATH, None)

In [18]:
input_data_list = [(event.x, event.edge_index) for event in loaded_events]

# VOILA!

In [19]:
session.run(None, {"nodes": input_data_list[0][0].numpy(), "edge_index": input_data_list[0][1].numpy()})[0]

array([ 0.9123555, -0.6301682,  0.9521093, ...,  4.3359084,  3.2423496,
        1.5909792], dtype=float32)

In [20]:
session.run(None, {"nodes": input_data_list[1][0].numpy(), "edge_index": input_data_list[1][1].numpy()})[0]

array([ 0.91810256, -3.100629  , -0.5200292 , ..., -2.6808546 ,
        4.421003  ,  4.1069803 ], dtype=float32)

## Torchscript Testing

In [27]:
traced_script_module = torch.jit.trace(model, input_data)

In [59]:
traced_script_module(input_data[0], input_data[1])

tensor([-0.2641, -0.2798, -0.2481,  ..., -0.4477, -0.0890, -0.5004],
       grad_fn=<SqueezeBackward1>)

In [22]:
script_module = torch.jit.script(ResAGNN())

init


In [23]:
# out= script_module(input_data[0], input_data[1])
out= script_module(input_data[0])

In [24]:
out

tensor([[ 0.1020],
        [ 0.1171],
        [ 0.1397],
        ...,
        [ 0.3792],
        [-0.0710],
        [-0.0020]], grad_fn=<AddmmBackward>)

In [32]:
ONNX_FILE_PATH = "ResAGNN_script_model.onnx"
dynamic_axes = {"nodes": [0, 1]}
# torch.onnx.export(script_module, input_data, ONNX_FILE_PATH, input_names=["nodes", "edge_index"],  opset_version=12,
torch.onnx.export(script_module, input_data[0], ONNX_FILE_PATH, input_names=["nodes"],  opset_version=12, verbose=True,
                  export_params=True, example_outputs=out, dynamic_axes=dynamic_axes)

graph(%nodes : Float(*, *, strides=[3, 1], requires_grad=0, device=cpu),
      %test_network.0.weight : Float(64, 3, strides=[3, 1], requires_grad=0, device=cpu),
      %test_network.0.bias : Float(64, strides=[1], requires_grad=0, device=cpu),
      %test_network.2.weight : Float(1, 64, strides=[64, 1], requires_grad=0, device=cpu),
      %test_network.2.bias : Float(1, strides=[1], requires_grad=0, device=cpu)):
  %5 : Float(*, 64, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%nodes, %test_network.0.weight, %test_network.0.bias) # /global/homes/d/danieltm/.conda/envs/exatrkx-test/lib/python3.7/site-packages/torch/nn/functional.py:1847:11
  %6 : Float(*, 64, device=cpu) = onnx::Relu(%5) # /global/homes/d/danieltm/.conda/envs/exatrkx-test/lib/python3.7/site-packages/torch/nn/functional.py:1298:17
  %7 : Float(*, 1, strides=[1, 1], requires_grad=1, device=cpu) = onnx::Gemm[alpha=1., beta=1., transB=1](%6, %test_network.2.weight, %test_network.2.bias) # /global/homes/d/danieltm/

  'Automatically generated names will be applied to each dynamic axes of input {}'.format(key))


In [33]:
ONNX_FILE_PATH = "ResAGNN_script_model.onnx"
session = onnxruntime.InferenceSession(ONNX_FILE_PATH)

In [34]:
input_data_list = [(event.x, event.edge_index) for event in loaded_events]

In [35]:
session.get_inputs()[0].name

'nodes'

In [36]:
session.run(None, {"nodes": input_data_list[0][0].numpy()})[0]

array([[ 0.10199904],
       [ 0.11714503],
       [ 0.13971089],
       ...,
       [ 0.37922925],
       [-0.07096032],
       [-0.00202517]], dtype=float32)

In [37]:
session.run(None, {"nodes": input_data_list[1][0].numpy()})[0]

array([[ 0.12115289],
       [ 0.06801388],
       [ 0.12111214],
       ...,
       [-0.06109324],
       [-0.06108928],
       [ 0.10260198]], dtype=float32)

In [74]:
session.run(None, {"nodes": input_data_list[0][0].numpy(), "edge_index": input_data_list[0][1].numpy()})[0]

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:nodes

In [75]:
session.run(None, {"nodes": input_data_list[1][0].numpy(), "edge_index": input_data_list[1][1].numpy()})[0]

InvalidArgument: [ONNXRuntimeError] : 2 : INVALID_ARGUMENT : Invalid Feed Input Name:nodes