In [2]:
!bash -c python -m pip install tritonclient[all]

zsh:1: no matches found: tritonclient[all]


In [24]:
import os
import logging

import tritonclient.http as httpclient
import numpy as np
from torchvision import datasets

In [20]:
def get_triton_client(protocol, triton_dns="localhost", concurrency=20):
    """Create a Triton Inference Server client for the requested protocol.

    Args:
        protocol: "http" (port 8000) or "grpc" (port 8001); case-insensitive.
        triton_dns: Host name of the Triton server. Defaults to "localhost".
        concurrency: Concurrency passed to the HTTP client; it should be large
            enough to handle the number of requests. Ignored for gRPC.

    Returns:
        A ``tritonclient.http.InferenceServerClient`` for "http", or a
        ``tritonclient.grpc.InferenceServerClient`` for "grpc".

    Raises:
        ValueError: If ``triton_dns`` is empty or ``protocol`` is unsupported.
    """
    logging.info("Retrieving triton client")
    if not triton_dns:
        raise ValueError("Cannot find triton_dns")

    protocol = protocol.lower()
    ports = {"http": 8000, "grpc": 8001}
    if protocol not in ports:
        raise ValueError(
            f"Unsupported protocol {protocol!r}; expected 'http' or 'grpc'"
        )
    triton_uri = f"{triton_dns}:{ports[protocol]}"

    logging.info("Using protocol=%s triton_uri=%s", protocol, triton_uri)
    logging.info("Connecting to %s", triton_uri)

    if protocol == "grpc":
        # The gRPC endpoint needs the gRPC client; an HTTP client pointed at
        # the gRPC port cannot talk to it. Imported lazily so HTTP-only use
        # does not require the gRPC extras.
        import tritonclient.grpc as grpcclient

        return grpcclient.InferenceServerClient(url=triton_uri)

    return httpclient.InferenceServerClient(url=triton_uri, concurrency=concurrency)

In [21]:
# Model (and version) to query on the Triton server.
# NOTE(review): despite the name, the labels this model returns
# (Pullover, Bag, Shirt, ...) are Fashion-MNIST classes — confirm the dataset.
model_name, model_version = "mnist_classifier", "1"

# Transport used for every request in this notebook.
protocol = "http"
triton_client = get_triton_client(protocol)

In [30]:
# Synthetic input: random single-channel 28x28 float32 "images" in [0, 1).
# A seeded generator keeps the batch — and therefore the predictions shown
# below — reproducible across Restart & Run All.
SEED = 0
N_IMAGES = 10
IMAGE_SHAPE = (1, 28, 28)

rng = np.random.default_rng(SEED)
batch = [rng.random(IMAGE_SHAPE, dtype=np.float32) for _ in range(N_IMAGES)]

In [41]:
# Tensor names / datatype the request uses for the model's input and output.
input_name, output_name = "input", "output"
dtype = "FP32"
# Ask Triton for the top-`classes` results per image (returned as
# b"score:index:label" strings, as seen in the output further down).
classes = 10

# Stack the per-image arrays along a new leading batch axis.
batched_image_data = np.stack(batch, axis=0)

image_input = httpclient.InferInput(input_name, batched_image_data.shape, dtype)
image_input.set_data_from_numpy(batched_image_data)
inputs = [image_input]

requested_output = httpclient.InferRequestedOutput(output_name, class_count=classes)
outputs = [requested_output]

response = triton_client.infer(
    model_name,
    inputs,
    request_id="0",
    model_version=model_version,
    outputs=outputs,
)

In [42]:
# Response header: request id, model name/version and, per output, its
# datatype, shape and binary payload size.
r = response.get_response()
r

{'id': '0',
 'model_name': 'mnist_classifier',
 'model_version': '1',
 'outputs': [{'name': 'output',
   'datatype': 'BYTES',
   'shape': [10, 10],
   'parameters': {'binary_data_size': 2209}}]}

In [43]:
# Decode the requested output tensor: one row per image, each entry a
# b"score:index:label" string, highest score first.
response.as_numpy(output_name)

array([[b'2.812821:2:Pullover', b'2.759900:8:Bag', b'2.749728:6:Shirt',
        b'1.851528:4:Coat', b'1.432543:0:T-shirt/top',
        b'-0.364798:5:Sandal', b'-0.570288:3:Dress',
        b'-1.213886:9:Ankle boot', b'-3.550432:1:Trouser',
        b'-4.199965:7:Sneaker'],
       [b'2.732077:8:Bag', b'2.630997:6:Shirt', b'2.358886:2:Pullover',
        b'1.651615:4:Coat', b'1.338836:0:T-shirt/top',
        b'-0.317795:5:Sandal', b'-0.349168:3:Dress',
        b'-1.143678:9:Ankle boot', b'-3.352515:1:Trouser',
        b'-3.939103:7:Sneaker'],
       [b'2.632139:6:Shirt', b'2.507267:8:Bag', b'2.492306:2:Pullover',
        b'1.664160:4:Coat', b'1.505023:0:T-shirt/top',
        b'-0.164413:5:Sandal', b'-0.605954:3:Dress',
        b'-0.899673:9:Ankle boot', b'-3.433556:1:Trouser',
        b'-3.884909:7:Sneaker'],
       [b'2.765271:8:Bag', b'2.387483:6:Shirt', b'2.213422:2:Pullover',
        b'1.350260:4:Coat', b'1.322876:0:T-shirt/top',
        b'-0.024314:5:Sandal', b'-0.578251:3:Dress',
    

In [None]:
def get_model_metadata_config(model_name, model_version, triton_client, protocol):
    """Fetch a model's metadata and configuration from a Triton server.

    Used to make sure the model matches our requirements and to read the
    properties needed for preprocessing.

    Args:
        model_name: Name of the model on the server.
        model_version: Version of the model (as passed to the client).
        triton_client: A Triton ``InferenceServerClient`` (HTTP or gRPC).
        protocol: "http" or "grpc" (case-insensitive); controls whether the
            config is unwrapped from a gRPC response envelope.

    Returns:
        Tuple ``(model_metadata, model_config)``. For HTTP both are returned
        exactly as the client produced them; for gRPC the config is taken
        from the response's ``.config`` field when present.
    """
    logging.info(
        f"Retrieve model info of model {model_name} and version {model_version}"
    )
    model_metadata = triton_client.get_model_metadata(
        model_name=model_name, model_version=model_version
    )
    logging.info(f"Found: {model_metadata}")

    logging.info(
        f"Retrieve model config of model {model_name} and version {model_version}"
    )
    model_config = triton_client.get_model_config(
        model_name=model_name, model_version=model_version
    )
    logging.info(f"Found: {model_config}")

    # The gRPC client wraps the config in a response message whose `.config`
    # field holds the actual model config. getattr keeps this safe if a client
    # whose response has no `.config` is passed with protocol="grpc".
    if protocol.lower() == "grpc":
        model_config = getattr(model_config, "config", model_config)

    return model_metadata, model_config