<a href="https://colab.research.google.com/github/Wolfffff/gpuhackathon-sleap/blob/main/triton/RealTimeSLEAP_Triton.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture

# Fetch the Triton v2.10.0 client release bundle (Ubuntu 20.04 build).
!wget https://github.com/triton-inference-server/server/releases/download/v2.10.0/v2.10.0_ubuntu2004.clients.tar.gz
# Extract only the Python client wheel from the bundle; --strip-components 1
# drops the leading "python/" directory so the wheel lands in the working directory,
# which is where the install cell expects to find it.
!tar -zxvf v2.10.0_ubuntu2004.clients.tar.gz python/tritonclient-2.10.0-py3-none-manylinux1_x86_64.whl --strip-components 1
# Helper module that the import cell pulls parse_model, UserData and requestGenerator from.
!wget https://raw.githubusercontent.com/Wolfffff/gpuhackathon-sleap/main/triton/triton_utils.py
# Example video, downloaded into ./data (-P sets the target directory).
# "%40" is the percent-encoded "@"; the video_path used later spells the file name
# with a literal "@", so wget is expected to save it decoded -- TODO confirm.
!wget -P data https://storage.googleapis.com/sleap-data/reference/flies13/190719_090330_wt_18159206_rig1.2%4015000-17560.mp4

In [2]:
%%capture

!pip install AttrDict
!pip install nvidia-pyindex
!pip install tritonclient-2.10.0-py3-none-manylinux1_x86_64.whl[all]

In [3]:
from triton_utils import parse_model
from triton_utils import UserData
from triton_utils import requestGenerator

import numpy as np
import cv2
import tensorflow as tf
import tritonclient.grpc as grpcclient

In [4]:
def read_frames(video_path, fidxs=None, grayscale=True):
    """Read frames from a video file.

    Args:
        video_path: Path to MP4
        fidxs: Iterable of frame indices or None to read all frames (default: None)
        grayscale: Keep only the first channel of each decoded image (default: True)

    Returns:
        Loaded images in array of shape (n_frames, height, width, channels) and dtype uint8.
        channels is 1 when grayscale is True.

    Raises:
        OSError: If the video cannot be opened.
        ValueError: If a requested frame cannot be decoded, or if no frames were
            requested (np.stack needs at least one array).
    """
    vr = cv2.VideoCapture(video_path)
    try:
        if not vr.isOpened():
            raise OSError("Could not open video: %s" % (video_path,))

        if fidxs is None:
            # The frame count is reported as a float; use integer indices.
            fidxs = range(int(vr.get(cv2.CAP_PROP_FRAME_COUNT)))

        frames = []
        for fidx in fidxs:
            # Seek explicitly so arbitrary (non-sequential) indices work.
            vr.set(cv2.CAP_PROP_POS_FRAMES, fidx)
            ok, img = vr.read()
            if not ok or img is None:
                # Fail with a clear message instead of a TypeError on None below.
                raise ValueError(
                    "Could not read frame %s from %s" % (fidx, video_path)
                )
            if grayscale:
                # Index with a list to keep the trailing channel axis (size 1).
                img = img[:, :, [0]]
            frames.append(img)
    finally:
        # Always free the decoder/file handle, including on errors.
        vr.release()

    return np.stack(frames, axis=0)

In [5]:
# The Triton server's localhost port is exposed through ngrok -- this is the main
# restricting factor here. ngrok seems to be exceptionally slow, but it is probably
# fine for this example...
video_path = "data/190719_090330_wt_18159206_rig1.2@15000-17560.mp4"

# This handle is only needed to query the frame count, so release it right away
# instead of keeping the video file open for the rest of the session.
cap = cv2.VideoCapture(video_path)
frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)  # reported as a float
cap.release()

# Server endpoint and the model to query.
# NOTE(review): the ngrok address is hard-coded -- confirm it is still valid before running.
triton_url = '4.tcp.ngrok.io:12904'
model_name = "centroid_savedmodel"
model_version="1"
protocol = 'grpc'
triton_client = grpcclient.InferenceServerClient(url=triton_url)

# Running count of inference requests sent; query_triton uses it as the request id.
sent_count = 0

In [6]:
# Ask the server to describe the deployed model.
model_metadata = triton_client.get_model_metadata(
    model_name=model_name,
    model_version=model_version,
)

# Only the `.config` attribute of the configuration reply is used downstream.
model_config = triton_client.get_model_config(
    model_name=model_name,
    model_version=model_version,
).config

# parse_model (from triton_utils) yields the values unpacked below.
(
    max_batch_size,
    input_name,
    output_name,
    c,
    h,
    w,
    format,
    dtype,
) = parse_model(model_metadata, model_config)

# parse_model yields a single output name; gather the name of every output
# listed in the metadata so that multi-output models are handled too.
output_names = [output.name for output in model_metadata.outputs]

In [7]:
# Pair each request produced by requestGenerator with an inference call
def query_triton(frame):
    """Send a frame to the Triton server and collect the inference responses.

    Reads the module-level ``triton_client``, ``input_name``, ``output_names``,
    ``dtype``, ``protocol``, ``model_name`` and ``model_version``. Advances the
    module-level ``sent_count`` by one per request; its value is used as the
    request id.

    Args:
        frame: Input passed through to ``requestGenerator``.

    Returns:
        List with one ``triton_client.infer`` result per generated request.
    """
    # Only the request counter is rebound here; every other global is read-only.
    global sent_count

    responses = []
    # The loop targets are local on purpose: binding the yielded model name and
    # version to the module-level ``model_name`` / ``model_version`` would silently
    # overwrite the notebook's configuration on every call.
    for inputs, outputs, request_model_name, request_model_version in requestGenerator(
            frame, input_name, output_names, dtype, protocol, model_name, model_version):
        responses.append(triton_client.infer(request_model_name,
                                             inputs,
                                             request_id=str(sent_count),
                                             model_version=request_model_version,
                                             outputs=outputs))
        sent_count += 1

    return responses

In [8]:
%matplotlib inline
import matplotlib.pyplot as plt
from ipywidgets import interactive
import time

def plot_frame(fidx=0):
    """Display one video frame with the server's confidence map drawn over it.

    Reads the frame at index ``fidx`` from the module-level ``video_path``,
    resizes it to 512x512, sends it to Triton via ``query_triton`` and overlays
    the first output of the first response on the frame. Prints the frame index
    before processing and the elapsed wall-clock time afterwards.

    Args:
        fidx: Index of the frame to show (default: 0).
    """
    print ('Processing frame %d' % fidx)
    t_start = time.time()

    # Load the single requested frame and convert it to a 512x512 float32 array.
    raw = read_frames(video_path=video_path, fidxs=[fidx])
    frame = tf.image.resize(raw.astype('float32'), size=(512, 512)).numpy()

    # Base layer: the frame itself.
    plt.figure(figsize=(8, 8))
    plt.imshow(frame.squeeze(), cmap="gray")

    # Send to Triton server for inference!
    responses = query_triton(frame)
    first_output = responses[0].as_numpy(output_names[0])
    # First entry along axis 0 (presumably the batch axis), channel 0.
    conf_map = first_output[0][:, :, 0]

    # Overlay: stretch the confidence map over the same 512x512 extent as the frame.
    plt.imshow(
        conf_map,
        cmap=plt.cm.viridis,
        alpha=.7,
        interpolation='bilinear',
        extent=(0, 512, 0, 512),
        origin='lower',
    )
    plt.show()

    # Log timing
    print ('\nFrame processed in %.2f s.' % (time.time() - t_start))

# Build the slider-driven viewer: one slider position per frame of the video.
max_frame_idx = int(frame_count) - 1
interactive_plot = interactive(plot_frame, fidx=(0, max_frame_idx, 1))

# The last child of the widget is given a fixed height of 512px.
output_area = interactive_plot.children[-1]
output_area.layout.height = "512px"

# Last expression of the cell: renders the widget.
interactive_plot

interactive(children=(IntSlider(value=0, description='fidx', max=2559), Output(layout=Layout(height='512px')))…