# Triton & ONNX
- Runs inference against a model served by a Triton inference server
- Runs inference locally from an ONNX model file using ONNX Runtime

In [24]:
from typing import Optional, List, Dict

import numpy as np
import scipy

import tritonclient.grpc as triton_grpc
import tritonclient.http as triton_http

from tqdm import tqdm

In [25]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
# from https://github.com/lgray/hgg-coffea/blob/triton-bdts/src/hgg_coffea/tools/chained_quantile.py
class wrapped_triton:
    """Callable wrapper that sends one inference request to a Triton server.

    The endpoint is given as a URL of the form
    ``triton+<protocol>://<address>/<model>/<version>``, where ``<protocol>``
    must be ``grpc`` or ``http`` by the time the instance is called.
    """

    def __init__(
        self,
        model_url: str,
        output_name: str = "softmax",
    ) -> None:
        """Split ``model_url`` into its parts and store them.

        Args:
            model_url: Endpoint in the form
                ``triton+<protocol>://<address>/<model>/<version>``.
            output_name: Name of the model output tensor to request. Defaults
                to ``"softmax"``, the name that used to be hard-coded.

        Raises:
            ValueError: If ``model_url`` does not split into exactly the
                expected number of parts.
        """
        fullprotocol, location = model_url.split("://")
        _, protocol = fullprotocol.split("+")
        address, model, version = location.split("/")

        self._protocol = protocol
        self._address = address
        self._model = model
        self._version = version
        self._output_name = output_name

    def __call__(self, input_dict: Dict[str, np.ndarray]) -> np.ndarray:
        """Run the model on ``input_dict`` and return the requested output.

        Args:
            input_dict: Maps each model input name to its array. Every input
                is declared to the server as ``"FP32"``.
                NOTE(review): arrays are presumably expected to be float32
                already; no cast is performed here.

        Returns:
            The requested output tensor, as given by ``as_numpy``.

        Raises:
            ValueError: If the protocol parsed from the URL is neither
                ``grpc`` nor ``http``.
        """
        if self._protocol == "grpc":
            triton_protocol = triton_grpc
            client = triton_grpc.InferenceServerClient(url=self._address, verbose=False)
        elif self._protocol == "http":
            triton_protocol = triton_http
            client = triton_http.InferenceServerClient(
                url=self._address,
                verbose=False,
                concurrency=12,
            )
        else:
            raise ValueError(f"{self._protocol} does not encode a valid protocol (grpc or http)")

        # A fresh client is opened for every call, so it must be closed again;
        # otherwise each call leaks a connection to the server.
        try:
            inputs = []
            for name, array in input_dict.items():
                infer_input = triton_protocol.InferInput(name, array.shape, "FP32")
                infer_input.set_data_from_numpy(array)
                inputs.append(infer_input)

            requested_output = triton_protocol.InferRequestedOutput(self._output_name)

            result = client.infer(
                self._model,
                model_version=self._version,
                inputs=inputs,
                outputs=[requested_output],
            )

            # Convert before the client is closed in ``finally``.
            return result.as_numpy(self._output_name)
        finally:
            client.close()

In [46]:
# Seed the global NumPy RNG so the random inputs below are reproducible.
np.random.seed(42)

# Size of the leading (batch) axis of every input array.
batch_size = 128
# Earlier settings, kept for reference:
# pfs = 100
# svs = 7
# Size of the last axis of the pf_* and sv_* arrays respectively.
pfs, svs = 128, 10

# Earlier input layout, kept for reference:
# input_dict = {
#     "pf_points": np.random.rand(batch_size, 2, pfs).astype("float32"),
#     "pf_features": np.random.rand(batch_size, 19, pfs).astype("float32"),
#     "pf_mask": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
#     "sv_points": np.random.rand(batch_size, 2, svs).astype("float32"),
#     "sv_features": np.random.rand(batch_size, 11, svs).astype("float32"),
#     "sv_mask": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
# }

# Fill the dictionary entry by entry. The order of the RNG calls is kept
# exactly as before, so the seeded values do not change.
input_dict = {}
input_dict["pf_features"] = np.random.rand(batch_size, 25, pfs).astype(np.float32)
input_dict["pf_vectors"] = np.random.rand(batch_size, 4, pfs).astype(np.float32)
# Masks are 1.0 where the uniform draw exceeds 0.2 and 0.0 elsewhere.
input_dict["pf_mask"] = np.greater(np.random.rand(batch_size, 1, pfs), 0.2).astype(np.float32)
input_dict["sv_features"] = np.random.rand(batch_size, 11, svs).astype(np.float32)
input_dict["sv_vectors"] = np.random.rand(batch_size, 4, svs).astype(np.float32)
input_dict["sv_mask"] = np.greater(np.random.rand(batch_size, 1, svs), 0.2).astype(np.float32)

# input_dict = {
#     "pf_points__0": np.random.rand(batch_size, 2, pfs).astype("float32"),
#     "pf_features__1": np.random.rand(batch_size, 19, pfs).astype("float32"),
#     "pf_mask__2": (np.random.rand(batch_size, 1, pfs) > 0.2).astype("float32"),
#     "sv_points__3": np.random.rand(batch_size, 2, svs).astype("float32"),
#     "sv_features__4": np.random.rand(batch_size, 11, svs).astype("float32"),
#     "sv_mask__5": (np.random.rand(batch_size, 1, svs) > 0.2).astype("float32"),
# }

In [47]:
# Endpoints used previously, kept for reference:
# model_url = "triton+grpc://ailab01.fnal.gov:8001/particlenet_hww/1"
# model_url = "triton+grpc://prp-gpu-1.t2.ucsd.edu:8001/particlenet_hww/1"
# model_url = "triton+grpc://67.58.49.52:8001/ak8_MD_vminclv2ParT_manual_fixwrap/1"
# model_url = "triton+grpc://67.58.49.48:8001/2023May30_ak8_MD_inclv8_part_2reg_manual/1"

# Endpoint currently in use.
model_url = "triton+grpc://67.58.49.48:8001/ak8_MD_vminclv2ParT_manual_fixwrap_all_nodes/1"

triton_model = wrapped_triton(model_url)

# Send a single request; the tqdm bar shows how long the call took.
for trial in tqdm(range(1)):
    output = triton_model(input_dict)

print(output)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.63s/it]

[[ 2.3253001e-03  7.6664449e-03  8.5004615e-03 ... -1.3894490e+00
   1.9106083e+00  1.2435826e+00]
 [ 3.2445844e-03  6.0990350e-03  4.9772575e-03 ... -1.4686842e+00
   2.4426298e+00  5.2210975e-01]
 [ 3.6650272e-03  1.2284490e-02  6.0200943e-03 ... -1.1954683e+00
   1.8408061e+00  1.3436289e+00]
 ...
 [ 2.9861033e-03  4.1847778e-03  2.4435457e-03 ... -1.3185942e+00
   2.0695624e+00  1.6338797e+00]
 [ 3.0086911e-04  1.0955615e-03  7.7940477e-04 ... -1.1281847e+00
   2.0996377e+00  9.3868500e-01]
 [ 2.0276865e-03  4.6882359e-03  2.7664499e-03 ... -1.6536552e+00
   2.1289117e+00  1.6567315e+00]]





In [None]:
# Display the array returned by the last `triton_model(input_dict)` call.
output

In [25]:
# Shape of the first row of the Triton output.
output[0].shape

(166,)

# Run the model using ONNX Runtime

In [31]:
import onnx
import onnxruntime as ort

# Array sizes for this run: leading (batch) axis, then the last axis of the
# pf_* and sv_* inputs.
batch_size = 20
pfs, svs = 128, 10

# Build the inputs entry by entry, in the same order as the RNG calls before.
input_dict = {}
input_dict["pf_features"] = np.random.rand(batch_size, 25, pfs).astype(np.float32)
input_dict["pf_vectors"] = np.random.rand(batch_size, 4, pfs).astype(np.float32)
# Masks are 1.0 where the uniform draw exceeds 0.2 and 0.0 elsewhere.
input_dict["pf_mask"] = np.greater(np.random.rand(batch_size, 1, pfs), 0.2).astype(np.float32)
input_dict["sv_features"] = np.random.rand(batch_size, 11, svs).astype(np.float32)
input_dict["sv_vectors"] = np.random.rand(batch_size, 4, svs).astype(np.float32)
input_dict["sv_mask"] = np.greater(np.random.rand(batch_size, 1, svs), 0.2).astype(np.float32)

# One path, used both for the structural check and for the runtime session.
onnx_model_path = "/Users/fmokhtar/projects/weaver-core-dev/ak8_MD_vminclv2ParT_manual_fixwrap/1/model.onnx"

# Load the model and run the ONNX checker on it before running inference.
onnx_model = onnx.load(onnx_model_path)
onnx.checker.check_model(onnx_model)

# Passing None as the first argument of run() requests all model outputs.
ort_sess = ort.InferenceSession(onnx_model_path)
outputs = ort_sess.run(None, input_dict)
print(outputs)

[array([[4.86524962e-03, 4.97966399e-03, 1.74506661e-03, 1.17051192e-02,
        6.81609195e-03, 5.58614638e-03, 2.07647705e-03, 8.02133349e-04,
        1.28070323e-03, 9.03823646e-04, 8.59912427e-04, 1.36731251e-03,
        5.05559612e-04, 1.48216175e-04, 1.17055039e-04, 7.71649793e-05,
        4.24696133e-02, 3.44527699e-02, 1.47807617e-02, 2.23821755e-02,
        1.27572479e-04, 6.49630329e-06, 1.83118209e-06, 5.21351621e-02,
        3.65206562e-02, 1.85362101e-01, 8.91993344e-02, 2.09410697e-01,
        2.67438330e-02, 3.88626419e-02, 1.29468903e-01, 4.75959405e-02,
        6.96596084e-03, 5.73533494e-03, 2.12198449e-03, 9.53700114e-03,
        2.28361320e-03],
       [2.21705972e-03, 4.84761875e-03, 4.28972626e-03, 4.10481263e-03,
        6.57369848e-03, 1.64161697e-02, 6.01862092e-04, 4.78196889e-04,
        1.61183663e-02, 1.34124737e-02, 1.45546772e-04, 2.76738923e-04,
        8.05218797e-03, 2.30695214e-03, 1.15670846e-05, 9.05202251e-06,
        5.45549653e-02, 8.37783888e-02

In [32]:
# Shape of the first element of the list returned by `ort_sess.run`.
outputs[0].shape

(20, 37)