In [1]:
import numpy as np
import tritonclient.grpc as grpcclient

from PIL import Image

In [2]:
# Address of the Triton server's gRPC endpoint; change it here to target another server.
TRITON_URL = "localhost:8001"
triton_client = grpcclient.InferenceServerClient(url=TRITON_URL, verbose=False)

In [3]:
def test_infer(
    model_name,
    input0_data,
    input_name="input",
    output_name="output",
    datatype="FP16",
    client=None,
):
    """Send one synchronous inference request to a Triton model.

    Args:
        model_name: Name of the model served by Triton.
        input0_data: numpy array sent as the single input tensor. Its dtype is
            expected to agree with ``datatype`` (the default "FP16" pairs with
            ``np.float16``).
        input_name: Name of the model's input tensor.
        output_name: Name of the output tensor to request.
        datatype: Triton datatype string declared for the input tensor.
        client: gRPC client to use; when None, the notebook-level
            ``triton_client`` is used.

    Returns:
        The result object returned by ``client.infer``; read the tensor with
        ``.as_numpy(output_name)``.
    """
    if client is None:
        client = triton_client

    infer_input = grpcclient.InferInput(input_name, list(input0_data.shape), datatype)
    infer_input.set_data_from_numpy(input0_data)
    requested_output = grpcclient.InferRequestedOutput(output_name)

    return client.infer(model_name, [infer_input], outputs=[requested_output])

In [4]:
# Input images, resolved relative to the notebook's working directory.
IMAGE_PATHS = ["../data/img0.jpg"]


def load_image_chw(path):
    """Load an image as an RGB float array in CHW layout scaled to [0, 1].

    Args:
        path: Path of the image file.

    Returns:
        ndarray of shape (3, H, W) holding pixel values divided by 255.
    """
    # The context manager closes the file handle opened by Image.open.
    with Image.open(path) as img:
        rgb = np.array(img.convert("RGB"))
    # HWC -> CHW, then scale [0, 255] to [0, 1].
    return rgb.transpose(2, 0, 1) / 255


images = np.stack([load_image_chw(p) for p in IMAGE_PATHS])

In [5]:
# Name of the deployed Triton model that test_infer sends requests to.
model_name: str = "ganx2_tensorrt"

In [6]:
# Cast to float16 because test_infer declares the input tensor as "FP16".
# np.asarray converts in a single step, without an intermediate copy of images.
input0_data = np.asarray(images, dtype=np.float16)

In [7]:
# Run a single inference request; the raw Triton result is unpacked in the next cells.
results = test_infer(model_name=model_name, input0_data=input0_data)

In [8]:
# Extract the "output" tensor from the Triton response as a numpy array.
# No comparison against reference values happens here.
output0_data = results.as_numpy("output")
# as_numpy returns None when the named output is missing from the response.
assert output0_data is not None, "Triton response has no tensor named 'output'"

In [9]:
# Batch/channel/height/width dimensions of the raw model output.
np.shape(output0_data)

(1, 3, 480, 640)

In [10]:
# Layout after moving the channel axis of the first output to the end (HWC).
np.moveaxis(output0_data[0] * 255, 0, -1).astype(np.uint8).shape

(480, 640, 3)

In [11]:
# Shape of the clipped uint8 HWC image. Image.fromarray needs (H, W, 3) for RGB,
# and clipping keeps out-of-range model values from being mangled by the cast.
clipped_hwc_shape = np.clip(output0_data[0] * 255, 0, 255).transpose(1, 2, 0).astype(np.uint8).shape
assert clipped_hwc_shape[-1] == 3, "expected a trailing RGB channel axis"
clipped_hwc_shape

(480, 640, 3)

In [12]:
# Save the first output as an RGB image.
# The output is CHW (3, H, W), so transpose(1, 2, 0) gives the HWC layout that
# Image.fromarray expects; transpose(2, 1, 0) would swap height and width.
# Scaled values can fall outside [0, 255] (e.g. 261+ in the raw dump), and a plain
# float -> uint8 cast does not saturate, so clip before casting.
output_image_hwc = np.clip(output0_data[0] * 255, 0, 255).transpose(1, 2, 0).astype(np.uint8)
im = Image.fromarray(output_image_hwc)
im.save("your_file1.jpeg")

In [14]:
# Inspect the raw output on a 0-255 scale; any value outside [0, 255] would be
# corrupted by an unclipped uint8 cast.
255 * output0_data

array([[[[ 89.9 ,  90.75,  86.3 , ..., 233.1 , 236.6 , 241.1 ],
         [ 74.9 ,  72.5 ,  72.9 , ..., 232.5 , 236.2 , 233.5 ],
         [ 52.03,  51.22,  54.4 , ..., 229.9 , 231.1 , 233.4 ],
         ...,
         [112.75, 111.5 , 111.94, ...,  22.17,  22.94,  21.7 ],
         [107.44, 113.  , 109.44, ...,  22.5 ,  22.16,  20.89],
         [114.1 , 112.25, 112.2 , ...,  22.83,  18.75,  21.1 ]],

        [[ 90.9 ,  94.5 ,  90.4 , ..., 244.2 , 253.1 , 251.6 ],
         [ 82.44,  78.  ,  77.44, ..., 242.8 , 244.1 , 245.  ],
         [ 55.53,  56.94,  58.8 , ..., 242.1 , 242.1 , 245.  ],
         ...,
         [114.25, 112.  , 113.56, ...,  18.94,  19.5 ,  18.55],
         [109.1 , 113.94, 110.9 , ...,  18.23,  18.52,  18.7 ],
         [113.56, 113.5 , 114.44, ...,  17.83,  16.55,  16.64]],

        [[103.7 , 105.75, 103.9 , ..., 261.  , 265.8 , 267.2 ],
         [ 89.06,  88.8 ,  88.94, ..., 259.  , 261.8 , 257.8 ],
         [ 65.9 ,  67.8 ,  69.56, ..., 257.  , 258.  , 261.8 ],
        