In [6]:
import argparse
import time
import os
from queue import Empty, Queue

import numpy as np
from tqdm import tqdm

import tritonclient.grpc as grpcclient
from tritonclient.utils import triton_to_np_dtype

In [31]:
def onnx_callback(result, error=None) -> None:
    """Parse one asynchronous inference response and enqueue it.

    Meant to be handed to ``async_infer`` as its ``callback``. On success a
    ``(np_output, request_id)`` tuple is put on the module-level
    ``callback_q``.

    Parameters
    ----------
    result
        Inference result object. It must provide ``get_response()`` (whose
        return value exposes ``id`` and ``outputs``) and ``as_numpy(name)``.
    error
        Error reported for the request, or ``None`` when it succeeded.

    Notes
    -----
    Failures are reported with ``print`` and never re-raised. Nothing is
    enqueued for a failed request, so a consumer that calls
    ``callback_q.get()`` without a timeout will wait indefinitely for it.
    """
    try:
        # surface a server/client side failure as an ordinary exception
        if error is not None:
            raise RuntimeError(str(error))

        # read the request id and the output names from the server response
        # (public accessor; avoids reaching into the private ``_result``)
        server_response = result.get_response()
        request_id = str(server_response.id)

        # only the first output is consumed, and only column 1 of it is kept,
        # which yields one value per row of the returned array
        output_name = server_response.outputs[0].name
        np_output = result.as_numpy(output_name)[:, 1]

        # send the parsed output to downstream processes; ``callback_q`` is
        # resolved as a global at call time
        callback_q.put((np_output, request_id))
        print(type(result))

    except Exception as ex:
        print("Exception in callback")
        template = "An exception of type {0} occurred. Arguments:\n{1!r}"
        message = template.format(type(ex).__name__, ex.args)
        print(message)

In [32]:
# --- Run configuration ---------------------------------------------------
model = "test-onnx-512"   # model name passed to the server on every request
batch_size = 512          # leading dimension of each request tensor
gpu_node = "john108"      # host part of the server URL

# --- gRPC client ---------------------------------------------------------
triton_client = grpcclient.InferenceServerClient(url=f"{gpu_node}:8001")

# --- Request input / output descriptors ----------------------------------
# The random tensor is only used here for its shape, which declares the
# dimensions of the "input_1" tensor.
dummy_data = np.random.normal(size=(batch_size, 2048, 1)).astype(np.float32)
inputs = grpcclient.InferInput(
    "input_1",
    dummy_data.shape,
    datatype="FP32",
)
output = grpcclient.InferRequestedOutput("dense_4")

In [10]:
# Stack of request payloads: the first axis is iterated over when requests are
# sent (a single entry here); each entry has shape (batch_size, 2048, 1).
data = (
    np.random.normal(size=(1, batch_size, 2048, 1))
    .astype(np.float32)
)

In [11]:
# Successful-inference count for the first entry (index 0) of the server's
# model statistics.
# NOTE(review): index 0 is presumably the model under test; confirm if the
# server hosts more than one model.
triton_client.get_inference_statistics().model_stats[0].inference_stats.success.count

0

In [16]:
# Same statistics query as the previous cell, repeated so the two
# successful-inference counts can be compared.
# NOTE(review): index 0 is presumably the model under test; confirm if the
# server hosts more than one model.
triton_client.get_inference_statistics().model_stats[0].inference_stats.success.count

2454

In [35]:
# Fresh queue for this run. ``onnx_callback`` resolves ``callback_q`` as a
# global when it is invoked, so rebinding the name here is what routes the
# responses to the collection loop below.
callback_q = Queue()

# Upper bound (seconds) to wait for any single response. ``onnx_callback``
# enqueues nothing when a request fails, so an unbounded ``get()`` would block
# this cell forever in that case. Tune as needed for slower models.
response_timeout_s = 60.0

t1 = time.time()
print("Start time: {}".format(t1))

# Submit one asynchronous request per entry along the first axis of ``data``.
for j in range(len(data)):
    request_id = str(j)

    inputs.set_data_from_numpy(data[j])

    triton_client.async_infer(
        model_name=model,
        inputs=[inputs],
        outputs=[output],
        request_id=request_id,
        callback=onnx_callback,
    )
    # short pause between consecutive submissions
    time.sleep(0.001)

# Collect one response per submitted request, in arrival order.
all_responses = []
for count in range(1, len(data) + 1):
    print(count)
    try:
        response = callback_q.get(timeout=response_timeout_s)
    except Empty:
        # A callback failed (see its printed message) or the server is stuck;
        # stop waiting instead of hanging the kernel.
        print(
            f"Timed out after {response_timeout_s}s waiting for response "
            f"{count} of {len(data)}; received {len(all_responses)}."
        )
        break
    all_responses.append(response)
print(f"End time: {time.time()}")

Start time: 1696229180.0811746
1
<class 'tritonclient.grpc._infer_result.InferResult'>
End time: 1696229180.2585588


In [34]:
# Display the parsed output array (element 0 of the ``(np_output, request_id)``
# tuple) from the first collected response.
all_responses[0][0]

array([-12.881743 , -15.351374 , -16.866516 , -12.397059 , -12.99683  ,
       -18.571247 , -12.5456295,  -9.547202 , -14.405883 , -11.336063 ,
       -16.526608 , -13.485931 , -16.486914 , -17.400597 ,  -9.912657 ,
       -11.725847 , -11.096227 , -13.250407 , -16.69548  , -12.62651  ,
       -17.144135 , -12.980102 , -16.960503 , -16.325174 , -16.211586 ,
       -12.854264 , -10.66183  , -15.986105 ,  -9.316241 , -10.747484 ,
       -10.816374 , -16.467562 , -18.605772 ,  -9.400908 , -10.682304 ,
       -16.381891 , -11.888888 , -13.241409 , -16.291039 , -15.457211 ,
       -15.052286 , -15.159737 , -16.507519 , -16.069826 , -11.732653 ,
        -8.788711 , -16.336948 , -15.306169 , -15.403146 , -14.043018 ,
       -11.602609 , -15.618595 , -14.665495 , -13.103614 , -14.81268  ,
       -17.299099 , -14.457716 , -14.481176 , -12.831978 , -15.330777 ,
       -16.518269 , -12.720187 , -10.322359 , -14.518945 , -14.785755 ,
       -15.759824 , -13.998498 , -15.682806 , -15.376196 , -14.0