# ByteTrack Inference with Amazon SageMaker

This notebook demonstrates how to create an endpoint for real-time inference with the trained ByteTrack (YOLOX) model. We first deploy the trained model to SageMaker in [BYOS](https://sagemaker-examples.readthedocs.io/en/latest/sagemaker-script-mode/sagemaker-script-mode.html) mode with a custom inference script, then run inference on each frame of a video by invoking the endpoint. The inference results are saved to a local directory.
SageMaker provides prebuilt containers for frameworks such as Scikit-learn, PyTorch, and XGBoost. This example uses the PyTorch prebuilt container by defining a `PyTorchModel` instance.

## 1. SageMaker Initialization 
First, upgrade the SageMaker Python SDK to the latest version. If your notebook already uses the latest SageMaker 2.x API, you may skip the next cell.

In [None]:
! pip install --upgrade pip
! python3 -m pip install --upgrade sagemaker
! pip install cython_bbox

### Import libraries and get execution role 

In [None]:
import boto3
import json
import time
import numpy as np

import sagemaker
from sagemaker import get_execution_role
from sagemaker.pytorch.model import PyTorchModel

role = (
    get_execution_role()
)  # provide a pre-existing role ARN as an alternative to creating a new role
print(f"SageMaker Execution Role:{role}")

client = boto3.client('sts')
account = client.get_caller_identity()['Account']
print(f'AWS account:{account}')

session = boto3.session.Session()
aws_region = session.region_name
print(f"AWS region:{aws_region}")

## 2. Deploy YOLOX model

You need to complete the training job in [bytetrack-training.ipynb](bytetrack-training.ipynb) before running the following steps. Script Mode in SageMaker lets you control the training and inference process without creating and maintaining your own Docker containers. Because the model is a PyTorch model, we use the AWS-provided PyTorch container and pass in our own inference code. On your behalf, the SageMaker Python SDK packages the entry point script (sagemaker-serving/code/inference.py), uploads it to S3, and sets two environment variables (`SAGEMAKER_PROGRAM` and `SAGEMAKER_SUBMIT_DIRECTORY`) that are read at runtime to load the custom inference functions from the entry point script.

### Get the S3 path for the model trained in [bytetrack-training.ipynb](bytetrack-training.ipynb)

In [None]:
%store -r s3_model_uri

Inside inference.py, we define four functions: `model_fn`, `input_fn`, `predict_fn`, and `output_fn`. These handlers are loaded and executed automatically at runtime, and their arguments are predefined by the SageMaker prebuilt containers. The `model_fn` handler loads the model from the artifacts downloaded from the S3 path, the `input_fn` handler pre-processes the image passed in the request, the `predict_fn` handler runs the model's forward pass, and the `output_fn` handler post-processes the inference results before they are returned.
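
As a rough illustration of the handler contract described above, the sketch below shows the four functions with the argument order the SageMaker PyTorch serving container calls them with. The bodies are placeholders, not the contents of this project's inference.py: `DummyModel` and its fixed detection are hypothetical stand-ins so that the call order can be tried locally without model weights or a GPU.

```python
import json


class DummyModel:
    """Hypothetical stand-in for the real YOLOX network (assumption)."""

    def __call__(self, image_bytes):
        # Pretend every image yields one detection: [x1, y1, x2, y2, score].
        return [[10.0, 20.0, 110.0, 220.0, 0.9]]


def model_fn(model_dir):
    # A real handler would load the weights stored under model_dir.
    return DummyModel()


def input_fn(request_body, request_content_type):
    # A real handler would decode and resize the image here.
    if request_content_type != "application/x-image":
        raise ValueError(f"Unsupported content type: {request_content_type}")
    return bytes(request_body)


def predict_fn(input_data, model):
    # A real handler would run the network's forward pass here.
    return model(input_data)


def output_fn(prediction, response_content_type):
    # Serialize the detections so the client can json.loads() the response.
    return json.dumps([prediction])


# Local dry run in the same order the container invokes the handlers.
model = model_fn("/opt/ml/model")
data = input_fn(b"fake-image-bytes", "application/x-image")
body = output_fn(predict_fn(data, model), "application/json")
print(body)
```

Wrapping the prediction in a one-element list mirrors how the client code later in this notebook reads `outputs[0]`; the actual response layout is determined by the project's own `output_fn`.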

In [None]:
pytorch_model = PyTorchModel(
    model_data=s3_model_uri,
    role=role,
    source_dir="sagemaker-serving/code",
    entry_point="inference.py",
    framework_version="1.7.1",
    py_version="py3",
)

### Define the endpoint name and deploy the model

In [None]:
endpoint_name = "<endpoint name>"  # replace with your own endpoint name
pytorch_model.deploy(
    initial_instance_count=1,
    instance_type="ml.p3.2xlarge",
    endpoint_name=endpoint_name
)
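
By default `deploy` blocks until the endpoint is ready, but when you reconnect to an endpoint created in an earlier session it can help to confirm its status first. The sketch below polls `describe_endpoint` on a boto3 SageMaker client; the helper name `wait_until_in_service` and its `poll_seconds` and `max_polls` parameters are hypothetical and not part of the SDK.

```python
import time


def wait_until_in_service(sm_client, endpoint_name, poll_seconds=30, max_polls=60):
    """Poll describe_endpoint until the endpoint reports InService (hypothetical helper)."""
    for _ in range(max_polls):
        response = sm_client.describe_endpoint(EndpointName=endpoint_name)
        status = response["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        time.sleep(poll_seconds)
    raise TimeoutError(f"Endpoint {endpoint_name} not InService after {max_polls} polls")


# Example usage (requires AWS credentials, so it is left commented out):
# import boto3
# wait_until_in_service(boto3.client("sagemaker"), endpoint_name)
```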

## 3. Run Multi-Object Tracking with YOLOX

We deployed the YOLOX model to a SageMaker endpoint and now run the tracking task on the client side. We use the tracking scripts provided by [ByteTrack](https://github.com/ifzhang/ByteTrack/tree/main/yolox/tracker).

In [None]:
%%writefile download_tracking.sh
git clone --filter=blob:none --no-checkout --depth 1 --sparse https://github.com/ifzhang/ByteTrack.git && \
cd ByteTrack && \
git sparse-checkout set yolox && \
git checkout && \
cd ..
cp -r ByteTrack/yolox yolox
cp container-batch-inference/byte_tracker.py yolox/tracker/
sudo rm -r ByteTrack

In [None]:
!bash download_tracking.sh

In [None]:
!pip install lap

This cell defines a function that reads a video frame by frame, sends each frame to the inference endpoint deployed in the previous step, updates the tracker with the returned detections, and writes the annotated video and a tracking log to a local directory (`save_folder`).

In [None]:
from yolox.tracker.byte_tracker import BYTETracker
import cv2
import time
from yolox.tracking_utils.timer import Timer
import os.path as osp
import os
import torch
from yolox.utils.visualize import plot_tracking

sm_runtime = boto3.Session().client("sagemaker-runtime")

def imageflow_demo(endpoint_name, video_path="", save_folder=""):
    cap = cv2.VideoCapture(video_path)
    
    width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
    height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
    
    print(f"width: {width}, height: {height}")
    
    fps = cap.get(cv2.CAP_PROP_FPS)
    
    os.makedirs(save_folder, exist_ok=True)
    save_path = osp.join(save_folder, video_path.split("/")[-1])
    
    print(f"video save_path is {save_path}")
    vid_writer = cv2.VideoWriter(
        save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
    )
    
    aspect_ratio_thresh = 1.6
    min_box_area = 10
    tracker = BYTETracker(
        frame_rate=30,
        track_thresh=0.5,
        track_buffer=30,
        mot20=False,
        match_thresh=0.8
    )
    timer = Timer()
    frame_id = 0
    results = []
    while True:
        ret_val, frame = cap.read()
        if ret_val:
            cv2.imwrite(f'datasets/frame_{frame_id}.png', frame)
            with open(f"datasets/frame_{frame_id}.png", "rb") as f:
                payload = f.read()
            
            timer.tic()
            response = sm_runtime.invoke_endpoint(
                EndpointName=endpoint_name, ContentType="application/x-image", Body=payload
            )
            outputs = json.loads(response["Body"].read().decode())
            
            if outputs[0] is not None:
                online_targets = tracker.update(torch.as_tensor(outputs[0]), [height, width], (800, 1440))
                online_tlwhs = []
                online_ids = []
                online_scores = []
                for t in online_targets:
                    tlwh = t.tlwh
                    tid = t.track_id
                    vertical = tlwh[2] / tlwh[3] > aspect_ratio_thresh
                    if tlwh[2] * tlwh[3] > min_box_area and not vertical:
                        online_tlwhs.append(tlwh)
                        online_ids.append(tid)
                        online_scores.append(t.score)
                        results.append(
                            f"{frame_id},{tid},{tlwh[0]:.2f},{tlwh[1]:.2f},{tlwh[2]:.2f},{tlwh[3]:.2f},{t.score:.2f},-1,-1,-1\n"
                        )
                timer.toc()
                online_im = plot_tracking(
                    frame, online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / max(1e-5, timer.average_time)
                )
            else:
                timer.toc()
                online_im = frame
            if frame_id % 20 == 0:
                print('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            
            vid_writer.write(online_im)
        else:
            break
        frame_id += 1

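    # Release the capture and the writer so the output video is finalized on disk.
    cap.release()
    vid_writer.release()

    res_file = osp.join(save_folder, "log.txt")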
    with open(res_file, 'w') as f:
        f.writelines(results)
    print(f"save results to {res_file}")

### Download a video from a public resource

In [None]:
!mkdir -p datasets
!wget https://raw.githubusercontent.com/ifzhang/FairMOT/master/videos/MOT16-03.mp4 -O datasets/MOT16-03.mp4

In [None]:
video_path = "datasets/MOT16-03.mp4"
save_folder = "track_res"
imageflow_demo(endpoint_name, video_path, save_folder)
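
Each row that `imageflow_demo` writes to log.txt has the form `frame_id,track_id,x,y,w,h,score,-1,-1,-1`. As a minimal sketch of how the log could be read back, the helper below groups rows by track ID. The name `parse_tracking_log` is hypothetical and the sample rows are made-up values for illustration, not real tracking output.

```python
from collections import defaultdict


def parse_tracking_log(lines):
    """Group log rows by track id: {track_id: [(frame_id, x, y, w, h, score), ...]}."""
    tracks = defaultdict(list)
    for line in lines:
        line = line.strip()
        if not line:
            continue
        fields = line.split(",")
        frame_id, track_id = int(fields[0]), int(fields[1])
        # Fields 2..6 hold the top-left corner, width, height, and score.
        x, y, w, h, score = (float(v) for v in fields[2:7])
        tracks[track_id].append((frame_id, x, y, w, h, score))
    return dict(tracks)


# Illustrative rows only (assumed values, same layout as log.txt).
sample = [
    "0,1,100.00,50.00,40.00,90.00,0.91,-1,-1,-1\n",
    "1,1,102.00,51.00,40.00,90.00,0.90,-1,-1,-1\n",
    "1,2,300.00,80.00,35.00,85.00,0.77,-1,-1,-1\n",
]
tracks = parse_tracking_log(sample)
print({track_id: len(rows) for track_id, rows in tracks.items()})
```

To read the real results, pass an open file handle instead of the sample list, for example `parse_tracking_log(open("track_res/log.txt"))`.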

---
## Inference speed comparison

<table>
    <tr><th>Instance</th><th>FPS</th></tr>
    <tr><td>ml.g4dn.2xlarge</td><td>3.6</td></tr>
    <tr><td>ml.p3.2xlarge</td><td>5</td></tr>
</table>
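
To put the FPS figures in per-frame terms, the short calculation below converts them to latency. The timer in `imageflow_demo` starts before the endpoint call and stops after the tracker update, so these numbers cover the request round trip plus client-side tracking, not the model's forward pass alone.

```python
# FPS values copied from the table above.
measured_fps = {"ml.g4dn.2xlarge": 3.6, "ml.p3.2xlarge": 5.0}

latency_ms = {instance: 1000.0 / fps for instance, fps in measured_fps.items()}
for instance, ms in latency_ms.items():
    print(f"{instance}: {ms:.0f} ms per frame")

speedup = measured_fps["ml.p3.2xlarge"] / measured_fps["ml.g4dn.2xlarge"]
print(f"ml.p3.2xlarge vs ml.g4dn.2xlarge: {speedup:.2f}x")
```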