Attempting to run T5 ORT model in Triton inference server #157

samiur opened this issue Dec 14, 2022 · 1 comment

samiur commented Dec 14, 2022

Hi there,

Thanks again for this library!

We're trying to convert a fine-tuned T5 model to ONNX and run it in Triton. We've managed to convert the model to ONNX and, following the T5 notebook guide, run it just fine in Python.

But trying to get it to run in Triton has been a challenge. In particular, we're not sure how to get past_key_values passed through in Triton. We have the decoder config as follows:

name: "t5-dec-if-node_onnx_model"
max_batch_size: 0
platform: "onnxruntime_onnx"
default_model_filename: "model.bin"
input [
    {
        name: "input_ids"
        data_type: TYPE_INT32
        dims: [ -1, -1 ]
    },
    {
        name: "encoder_hidden_states"
        data_type: TYPE_FP32
        dims: [ -1, -1, 2048 ]
    },
    {
        name: "enable_cache"
        data_type: TYPE_BOOL
        dims: [ 1 ]
    },
    {
        name: "past_key_values.0.decoder.key"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.decoder.value"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.encoder.key"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    {
        name: "past_key_values.0.encoder.value"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
    },
    ...
]
output [
    {
        name: "logits"
        data_type: TYPE_FP32
        dims: [ -1, -1, 32128 ]
    }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]

And when we do the following:

import numpy as np
import tritonclient.http

# input_ids, encoder_hidden_states and triton_client are built earlier
# (triton_client is a tritonclient.http.InferenceServerClient)
input_1 = tritonclient.http.InferInput(name="input_ids", shape=(1, 24), datatype="INT32")
input_2 = tritonclient.http.InferInput(name="encoder_hidden_states", shape=(1, 24, 2048), datatype="FP32")
input_3 = tritonclient.http.InferInput(name="enable_cache", shape=(1,), datatype="BOOL")

input_1.set_data_from_numpy(input_ids)
input_2.set_data_from_numpy(encoder_hidden_states)
input_3.set_data_from_numpy(np.asarray([True]))

result = triton_client.infer(
    model_name="t5-dec-if-node_onnx_model",
    inputs=[input_1, input_2, input_3],
    outputs=[tritonclient.http.InferRequestedOutput(name="logits", binary_data=False)],
)

We get this error:

InferenceServerException: [request id: <id_unknown>] expected 99 inputs but got 3 inputs for model 't5-dec-if-node_onnx_model'

Any idea how we can fix this?
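For reference, we assume the 99 expected inputs are the 3 regular inputs plus 24 layers × 4 past_key_values tensors. If every declared cache input has to be sent explicitly, the first decoding step would presumably look something like the untested sketch below (the layer count, the empty past shapes, and sending enable_cache=False on the first step are all our assumptions):

import numpy as np
import tritonclient.http

# Untested sketch: one zero-length past tensor per declared cache input, with shapes
# taken from the config above ([batch, 32 heads, past_seq_len, 64]); on the first
# decoding step past_seq_len is 0 and enable_cache is presumably False.
# input_1..input_3 and triton_client are the objects built in the snippet above.
num_layers = 24  # assumption: one past_key_values.<i>.* block per decoder layer
cache_inputs = []
for layer in range(num_layers):
    for suffix in ("decoder.key", "decoder.value", "encoder.key", "encoder.value"):
        name = f"past_key_values.{layer}.{suffix}"
        inp = tritonclient.http.InferInput(name=name, shape=(1, 32, 0, 64), datatype="FP32")
        inp.set_data_from_numpy(np.zeros((1, 32, 0, 64), dtype=np.float32))
        cache_inputs.append(inp)

input_3.set_data_from_numpy(np.asarray([False]))  # no cache available on the first step

result = triton_client.infer(
    model_name="t5-dec-if-node_onnx_model",
    inputs=[input_1, input_2, input_3] + cache_inputs,
    outputs=[tritonclient.http.InferRequestedOutput(name="logits", binary_data=False)],
)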

@ayoub-louati (Contributor) commented:

Hello,
Thanks for trying our library!
We are actually working on adding T5 officially to the convert script so that you can do the conversion with a one-line command. It will be added very soon (the ONNX conversion in particular, maybe in less than a week), but if you want, I can help you with the Triton configuration (it is a little bit complicated).
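
One thing that may be worth trying in the meantime (untested, and it only helps if the ONNX graph itself can run without the cache tensors being fed) is marking the past_key_values entries as optional in config.pbtxt, which recent Triton versions support, e.g.:

    {
        name: "past_key_values.0.decoder.key"
        data_type: TYPE_FP32
        dims: [ -1, 32, -1, 64 ]
        optional: true
    }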
