In [None]:
import sys
from google.colab import drive
# Mount Google Drive inside the Colab filesystem so files stored there can be read and written.
drive.mount('/content/gdrive/')
# Add the ByteTrack folder on Drive to the import path
# (presumably where the `yolox` package imported in later cells lives -- confirm).
sys.path.append('/content/gdrive/MyDrive/Colab Notebooks/ByteTrack/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


In [None]:
cd '/content/gdrive/MyDrive/Colab Notebooks/ByteTrack/'

/content/gdrive/MyDrive/Colab Notebooks/ByteTrack


In [None]:
# !pip3 install -r requirements.txt

In [None]:
import os
import cv2
import torch
from torch import nn
import numpy as np

from yolox.exp import get_exp
from yolox.models.network_blocks import SiLU
from yolox.utils import replace_module

import onnx
from onnxsim import simplify

import argparse
import onnxruntime
from yolox.data.data_augment import preproc as preprocess
from yolox.utils import mkdir, multiclass_nms, demo_postprocess, vis
from yolox.utils.visualize import plot_tracking
from yolox.tracker.byte_tracker import BYTETracker
from yolox.tracking_utils.timer import Timer

In [None]:
# Set to False to skip the export and reuse an ONNX file that already exists on disk.
save_onnx_model = True

if save_onnx_model:
    # Experiment description for the "s" model; the commented line selects the "x" variant instead.
    # exp = get_exp('exps/example/mot/yolox_x_mix_det.py', None)
    exp = get_exp('exps/example/mot/yolox_s_mix_det.py', None)

    # Build the network on CPU and switch it to evaluation mode before exporting.
    model = exp.get_model().to(torch.device("cpu"))
    model.eval()

    # NOTE(review): torch.load unpickles the file -- only load checkpoints from a trusted source.
    # ckpt = torch.load('pretrained/bytetrack_x_mot17.pth.tar', map_location="cpu")
    ckpt = torch.load('pretrained/bytetrack_s_mot17.pth.tar', map_location="cpu")

    model.load_state_dict(ckpt["model"])
    # Swap every torch nn.SiLU for the yolox SiLU block
    # (presumably an export-friendly implementation -- confirm against yolox.models.network_blocks).
    model = replace_module(model, nn.SiLU, SiLU)
    # presumably makes the head emit raw predictions so that demo_postprocess decodes them
    # at inference time -- TODO confirm against the YOLOX head implementation.
    model.head.decode_in_inference = False
    # Random tensor with the experiment's test resolution; only its shape matters for tracing.
    dummy_input = torch.randn(1, 3, exp.test_size[0], exp.test_size[1])
    print(dummy_input.shape)

    onnx_path = "./onnx/model_s.onnx"
    # The export cannot write into a directory that does not exist, so create it first.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)

    # NOTE(review): torch.onnx._export is a private entry point; consider the public
    # torch.onnx.export once it is confirmed to behave the same on the installed torch version.
    torch.onnx._export(model, dummy_input, onnx_path,
                       input_names=['images'], output_names=['output'],
                       opset_version=11)

    # Reload the exported graph, simplify it with onnx-simplifier and overwrite the file
    # only if the simplified model passed the simplifier's own check.
    onnx_model = onnx.load(onnx_path)
    model_simp, check = simplify(onnx_model)
    assert check, "Simplified ONNX model could not be validated"
    onnx.save(model_simp, onnx_path)

torch.Size([1, 3, 608, 1088])


In [None]:
# Optional sanity check (disabled): load an exported ONNX model with OpenCV's DNN module.
# onnx_path = "./onnx/model.onnx"
# net = cv2.dnn.readNetFromONNX(onnx_path)

# image = cv2.imread('./assets/test.jpg')
# blob = cv2.dnn.blobFromImage(image, 1.0 / 255, (224, 224),(0, 0, 0), swapRB=True, crop=False)
# net.setInput(blob)
# preds = net.forward()
# print ("Predicted", preds.shape)

In [None]:
def make_parser():
    """Build the command-line parser for the ONNX Runtime tracking demo.

    Returns:
        argparse.ArgumentParser: parser exposing the model/video/output paths,
        the detection thresholds, the network input shape and the tracking
        options. Every option has a default, so parsing an empty argument
        list yields a complete namespace.
    """
    parser = argparse.ArgumentParser("onnxruntime inference sample")

    # --- model, input and output locations ---
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="onnx/model_s.onnx",
        help="Input your onnx model.",
    )
    parser.add_argument(
        "-i",
        "--video_path",
        type=str,
        default='videos/palace.mp4',
        help="Path to your input video.",
    )
    parser.add_argument(
        "-o",
        "--output_dir",
        type=str,
        default='demo_output',
        help="Path to your output directory.",
    )

    # --- detection post-processing ---
    parser.add_argument(
        "-s",
        "--score_thr",
        type=float,
        default=0.1,
        help="Score threshold to filter the result.",
    )
    parser.add_argument(
        "-n",
        "--nms_thr",
        type=float,
        default=0.7,
        help="NMS threshold.",
    )
    parser.add_argument(
        "--input_shape",
        type=str,
        default="608,1088",
        help="Specify an input shape for inference.",
    )
    parser.add_argument(
        "--with_p6",
        action="store_true",
        help="Whether your model uses p6 in FPN/PAN.",
    )

    # tracking args
    parser.add_argument("--track_thresh", type=float, default=0.5, help="tracking confidence threshold")
    parser.add_argument("--track_buffer", type=int, default=30, help="the frames for keep lost tracks")
    parser.add_argument("--match_thresh", type=float, default=0.8, help="matching threshold for tracking")
    # The dashed flag is stored as `min_box_area` on the parsed namespace.
    parser.add_argument('--min-box-area', type=float, default=10, help='filter out tiny boxes')
    parser.add_argument("--mot20", dest="mot20", default=False, action="store_true", help="test mot20.")
    return parser

class Predictor(object):
    """Detector wrapper that runs an ONNX model through ONNX Runtime.

    Args:
        args: parsed options; this class reads ``model`` (path of the ONNX
            file), ``input_shape`` (comma-separated string of integers),
            ``with_p6``, ``nms_thr`` and ``score_thr``.
    """

    def __init__(self, args):
        # Normalisation constants handed to `preprocess` together with each frame.
        self.rgb_means = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)
        self.args = args
        self.session = onnxruntime.InferenceSession(args.model)
        # "608,1088" -> (608, 1088); presumably (height, width) -- confirm against `preprocess`.
        self.input_shape = tuple(map(int, args.input_shape.split(',')))

    def inference(self, ori_img, timer):
        """Detect objects in one frame.

        Args:
            ori_img: image array whose first two dimensions are height and width.
            timer: object with a ``tic()`` method; it is started right before
                the ONNX session runs and is NOT stopped here -- the caller is
                expected to stop it.

        Returns:
            tuple: ``(dets, img_info)``. ``dets`` is the NMS output without its
            last column (presumably the class index -- confirm against
            ``multiclass_nms``), or an empty ``(0, 5)`` array when NMS yields
            no detection. ``img_info`` is a dict with the keys ``id``,
            ``height``, ``width``, ``raw_img`` and ``ratio``.
        """
        img_info = {"id": 0}
        height, width = ori_img.shape[:2]
        img_info["height"] = height
        img_info["width"] = width
        img_info["raw_img"] = ori_img

        img, ratio = preprocess(ori_img, self.input_shape, self.rgb_means, self.std)
        img_info["ratio"] = ratio
        # Add a leading batch axis before feeding the single image to the session.
        ort_inputs = {self.session.get_inputs()[0].name: img[None, :, :, :]}
        timer.tic()
        output = self.session.run(None, ort_inputs)
        predictions = demo_postprocess(output[0], self.input_shape, p6=self.args.with_p6)[0]

        boxes = predictions[:, :4]
        # Column 4 scales the remaining columns
        # (presumably objectness times per-class scores -- confirm).
        scores = predictions[:, 4:5] * predictions[:, 5:]

        # Convert centre/size boxes (cx, cy, w, h) to corner form (x1, y1, x2, y2).
        boxes_xyxy = np.ones_like(boxes)
        boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
        boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
        boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
        boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
        # Undo the scaling applied by `preprocess`
        # (presumably maps boxes back to original-image pixels -- confirm).
        boxes_xyxy /= ratio
        dets = multiclass_nms(boxes_xyxy, scores, nms_thr=self.args.nms_thr, score_thr=self.args.score_thr)
        if dets is None:
            # The YOLOX NMS helper returns None when no box passes the thresholds;
            # slicing None would raise, so hand back an empty detection array instead.
            return np.zeros((0, 5), dtype=boxes_xyxy.dtype), img_info
        return dets[:, :-1], img_info


def imageflow_demo(predictor, args):
    """Run detection and tracking over a video and write an annotated copy.

    Args:
        predictor: object whose ``inference(frame, timer)`` returns
            ``(detections, img_info)`` and starts ``timer``.
        args: parsed options; this function reads ``video_path``,
            ``output_dir`` and ``min_box_area`` and passes ``args`` on to
            ``BYTETracker``.

    The annotated video is written to ``args.output_dir`` under the input
    file's base name. Nothing is returned. The capture and the writer are
    released even if processing fails part-way, so the output file is closed.
    """
    cap = cv2.VideoCapture(args.video_path)
    vid_writer = None
    try:
        width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
        height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
        fps = cap.get(cv2.CAP_PROP_FPS)
        save_folder = args.output_dir
        os.makedirs(save_folder, exist_ok=True)
        save_path = os.path.join(save_folder, os.path.basename(args.video_path))
        print(f"video save_path is {save_path}")
        vid_writer = cv2.VideoWriter(
            save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
        )
        # NOTE(review): frame_rate is fixed at 30 regardless of the video's own fps -- confirm intended.
        tracker = BYTETracker(args, frame_rate=30)
        timer = Timer()
        frame_id = 0
        while True:
            if frame_id % 20 == 0:
                print('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))
            ret_val, frame = cap.read()
            if not ret_val:
                # End of stream (or the video could not be read).
                break

            # `inference` starts the timer; it is stopped below after the tracker update.
            outputs, img_info = predictor.inference(frame, timer)
            frame_size = [img_info['height'], img_info['width']]
            online_targets = tracker.update(outputs, frame_size, frame_size)
            online_tlwhs = []
            online_ids = []
            for t in online_targets:
                tlwh = t.tlwh
                # Skip boxes whose width/height ratio exceeds 1.6 and boxes that are too small
                # (assumes tlwh is (x, y, width, height) -- confirm against the tracker).
                too_wide = tlwh[2] / tlwh[3] > 1.6
                if tlwh[2] * tlwh[3] > args.min_box_area and not too_wide:
                    online_tlwhs.append(tlwh)
                    online_ids.append(t.track_id)
            timer.toc()
            # Same guard as the progress print above, so a zero average time cannot divide by zero.
            online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1,
                                      fps=1. / max(1e-5, timer.average_time))
            vid_writer.write(online_im)
            ch = cv2.waitKey(1)
            if ch == 27 or ch == ord("q") or ch == ord("Q"):
                break
            frame_id += 1
    finally:
        # Release the reader and finalize the output file even on errors or early exit.
        cap.release()
        if vid_writer is not None:
            vid_writer.release()

In [None]:
# The original cell blanked sys.argv before parsing; passing an explicit empty
# argument list instead gives every option its default from make_parser()
# without touching sys.argv or removing `sys` from the notebook namespace.
args = make_parser().parse_args([])
args

Namespace(input_shape='608,1088', match_thresh=0.8, min_box_area=10, model='onnx/model_s.onnx', mot20=False, nms_thr=0.7, output_dir='demo_output', score_thr=0.1, track_buffer=30, track_thresh=0.5, video_path='videos/palace.mp4', with_p6=False)

In [None]:
# Build the ONNX Runtime predictor from the parsed options, then run detection
# and tracking over args.video_path; the annotated video is written under args.output_dir.
predictor = Predictor(args)
imageflow_demo(predictor, args)

video save_path is demo_output/palace.mp4
Processing frame 0 (100000.00 fps)
Processing frame 20 (1.31 fps)
Processing frame 40 (1.25 fps)
Processing frame 60 (1.28 fps)
Processing frame 80 (1.26 fps)
Processing frame 100 (1.28 fps)
Processing frame 120 (1.29 fps)
Processing frame 140 (1.26 fps)
Processing frame 160 (1.27 fps)
Processing frame 180 (1.26 fps)
Processing frame 200 (1.27 fps)
Processing frame 220 (1.28 fps)
Processing frame 240 (1.28 fps)
Processing frame 260 (1.28 fps)
Processing frame 280 (1.28 fps)
Processing frame 300 (1.28 fps)
Processing frame 320 (1.28 fps)
