In [None]:
%pip install tritonclient[http] transformers torch

Make sure that your notebook is on the same VPC as your GKE cluster.

Inference results are available from two endpoints: the Flask /infer endpoint, and the Triton endpoint called directly through the Triton client.

In [None]:
# Populate these from the logs of the GKE-Provision-Deploy image, or substitute other known values
host = ""
flask_node_port = ""
triton_node_port = ""
payload = "Sandwiched between a second-hand bookstore and record shop in Cape Town's charmingly grungy suburb of Observatory is a blackboard reading 'Tapi Tapi -- Handcrafted, authentic African ice cream.' The parlor has become one of Cape Town's most talked about food establishments since opening in October 2020. And in its tiny kitchen, Jeff is creating ice cream flavors like no one else. Handwritten in black marker on the shiny kitchen counter are today's options: Salty kapenta dried fish (blitzed), toffee and scotch bonnet chile Sun-dried blackjack greens and caramel, Malted millet ,Hibiscus, cloves and anise. Using only flavors indigenous to the African continent, Guzha's ice cream has become the tool through which he is reframing the narrative around African food. 'This (is) ice cream for my identity, for other people's sake,' Jeff tells CNN. 'I think the (global) food story doesn't have much space for Africa ... unless we're looking at the generic idea of African food,' he adds. 'I'm not trying to appeal to the global universe -- I'm trying to help Black identities enjoy their culture on a more regular basis.'"

# Flask Inferencing

The Flask endpoint accepts inputs in the same format as Vertex Predictions: the request carries a list of inputs in the 'instances' property, and the response carries a list of predictions in the 'predictions' property.

The Flask endpoint performs tokenization on the server, which makes it convenient for experimentation, but it does not deliver maximum throughput.

In [None]:
%%time
import requests
flask_payload = { "instances": [payload] }
flask_endpoint = f'http://{host}:{flask_node_port}/infer'

response = requests.post(flask_endpoint, json=flask_payload)
response.json()
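The request and response envelopes described above can be sketched with two small helpers. This is a minimal sketch, not part of the deployed service: `build_instances_payload` and `extract_predictions` are hypothetical names introduced here, the sample strings are made up, and the response body is hand-written so that no network call is needed.

```python
def build_instances_payload(texts):
    """Wrap one or more input strings in the 'instances' envelope."""
    if isinstance(texts, str):
        texts = [texts]
    return {"instances": list(texts)}


def extract_predictions(response_json):
    """Return the 'predictions' list, raising a clear error if it is absent."""
    if "predictions" not in response_json:
        raise KeyError(f"no 'predictions' key in response: {sorted(response_json)}")
    return response_json["predictions"]


# Hand-written stand-ins for a real request/response pair (hypothetical data).
body = build_instances_payload("Tapi Tapi sells handcrafted African ice cream.")
fake_response = {"predictions": ["Tapi Tapi sells ice cream."]}
summaries = extract_predictions(fake_response)
```

With a live endpoint, `body` would be passed as the `json` argument to `requests.post`, and `extract_predictions(response.json())` would replace the hand-written response.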

# Triton Inferencing

The underlying Triton server can also be called directly. Using Triton directly requires management of the payload on the client-side:
* Encoding of the payload
* Conversion to Triton input tensors
* Decoding the response payload

Removing the server-side pre- and post-processing and reducing the number of network calls yields a higher-throughput endpoint.

Pre-load the tokenizer so that its vocabulary is downloaded ahead of time.

In [None]:
from transformers import AutoTokenizer
import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype
import numpy as np
import torch

tokenizer = AutoTokenizer.from_pretrained("t5-base")

Generate the tensor inputs by:

    - Tokenizing your payload

    - Setting your desired inference parameters; this example sets sequence_length, runtime_top_k, and max_output_len

In [None]:
%%time
# Tokenize the payload; the tokenizer returns PyTorch tensors
input_token = tokenizer(
    payload, return_tensors="pt", padding=True, truncation=True
)
input_ids = input_token.input_ids.numpy().astype(np.uint32)

# Number of non-padding tokens per sequence, shaped [batch, 1]
mem_seq_len = (
    torch.sum(input_token.attention_mask, dim=1).numpy().astype(np.uint32)
)
mem_seq_len = mem_seq_len.reshape([mem_seq_len.shape[0], 1])
# Maximum number of tokens to generate
max_output_len = np.array([[128]], dtype=np.uint32)
# top_k = 1 for every sequence in the batch
runtime_top_k = np.ones([input_ids.shape[0], 1], dtype=np.uint32)

inputs = [
    httpclient.InferInput(
        "input_ids", input_ids.shape, np_to_triton_dtype(input_ids.dtype)
    ),
    httpclient.InferInput(
        "sequence_length",
        mem_seq_len.shape,
        np_to_triton_dtype(mem_seq_len.dtype),
    ),
    httpclient.InferInput(
        "max_output_len",
        max_output_len.shape,
        np_to_triton_dtype(max_output_len.dtype),
    ),
    httpclient.InferInput(
        "runtime_top_k",
        runtime_top_k.shape,
        np_to_triton_dtype(runtime_top_k.dtype),
    ),
]
inputs[0].set_data_from_numpy(input_ids, False)
inputs[1].set_data_from_numpy(mem_seq_len, False)
inputs[2].set_data_from_numpy(max_output_len, False)
inputs[3].set_data_from_numpy(runtime_top_k, False)
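The parameter tensors above are built for a single payload. As a rough illustration of how their shapes relate to the batch size, the following NumPy-only sketch uses a toy attention mask in place of real tokenizer output. It assumes the backend expects one row per batch entry for `max_output_len` and `runtime_top_k`; that assumption is not confirmed by this notebook, which only sends a batch of one.

```python
import numpy as np

# Toy attention mask for a batch of two padded sequences
# (1 = real token, 0 = padding). Values are made up for illustration.
attention_mask = np.array([[1, 1, 1, 0, 0],
                           [1, 1, 1, 1, 1]])
batch_size = attention_mask.shape[0]

# Number of real tokens per row, as a [batch, 1] column of uint32.
sequence_length = attention_mask.sum(axis=1, keepdims=True).astype(np.uint32)

# One value per batch row (assumed layout: [batch, 1]).
max_output_len = np.full((batch_size, 1), 128, dtype=np.uint32)
runtime_top_k = np.ones((batch_size, 1), dtype=np.uint32)
```

Each array keeps the `uint32` dtype used in the cell above, so `np_to_triton_dtype` would map all three to the same Triton type.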

Create a client to the Triton server running on your cluster. Set `verbose=True` to see the encoded payload being sent to the server.

In [None]:
client = httpclient.InferenceServerClient(f"{host}:{triton_node_port}", verbose=True)
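Before sending inference requests, it can help to confirm that the server and model are reachable. The sketch below polls the client's `is_server_live()`, `is_server_ready()`, and `is_model_ready()` methods, which the Triton HTTP client provides. The helper name `wait_until_ready`, the retry count, and the delay are illustrative choices, not values taken from the deployment.

```python
import time


def wait_until_ready(client, model_name, retries=5, delay_s=2.0):
    """Poll until the server and the model report ready; return True on success."""
    for attempt in range(retries):
        try:
            if (client.is_server_live()
                    and client.is_server_ready()
                    and client.is_model_ready(model_name)):
                return True
        except Exception as exc:  # for example, connection refused while the pod starts
            print(f"attempt {attempt + 1} failed: {exc}")
        if attempt < retries - 1:
            time.sleep(delay_s)
    return False
```

For example, `wait_until_ready(client, "fastertransformer")` returns `False` if the host or node port is wrong, which surfaces the problem before the tensor payload is sent.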

Call client.infer() with the model name and the inputs. Note that the model name is hardcoded as 'fastertransformer' when deployed.

In [None]:
%%time
result = client.infer("fastertransformer", inputs)

The response is also returned as tensors, so it needs to be decoded.

In [None]:
%%time
ft_decoding_outputs = result.as_numpy("output_ids")
ft_decoding_seq_lens = result.as_numpy("sequence_length")
tokens = tokenizer.decode(
    ft_decoding_outputs[0][0][: ft_decoding_seq_lens[0][0]],
    skip_special_tokens=True,
)

print(tokens)
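The slicing in the decode step can be seen in isolation with toy arrays. This sketch assumes the `[batch, beam, sequence]` layout for `output_ids` and `[batch, beam]` for `sequence_length` that the indexing above implies. The token ids are arbitrary placeholders, and `valid_token_ids` is a hypothetical helper name.

```python
import numpy as np

# Toy stand-ins for result.as_numpy(...): batch of 1, beam width of 1,
# six output slots of which only the first four are meaningful.
output_ids = np.array([[[71, 1712, 19, 1, 0, 0]]], dtype=np.uint32)
sequence_length = np.array([[4]], dtype=np.uint32)


def valid_token_ids(output_ids, sequence_length, batch=0, beam=0):
    """Return the generated ids for one (batch, beam) pair, dropping trailing padding."""
    length = int(sequence_length[batch][beam])
    return output_ids[batch][beam][:length].tolist()


ids = valid_token_ids(output_ids, sequence_length)
```

Passing the resulting list to `tokenizer.decode(..., skip_special_tokens=True)` corresponds to the decode call in the cell above.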