In [23]:
import os

# Paths used throughout the notebook.
INPUT_VIDEO = "dataset/0a2d9b_0.mp4"  # video that is read frame by frame
CONFIG_FILE = "configs/default.yaml"  # YAML file with the run settings
TEMP_DIR = "tmp_video"  # parent folder of the per-run video output
OUTPUT_DIR = "clip_video"  # folder that receives the results text file

# File name of the input video, with and without its extension.
input_video_fn = os.path.basename(INPUT_VIDEO)
# splitext removes only the final extension, so a name with extra dots
# (e.g. "clip.v2.mp4") keeps its full stem instead of being cut at the first dot.
input_video_name = os.path.splitext(input_video_fn)[0]

# exist_ok=True keeps the cell re-runnable and avoids the check-then-create race.
for directory in (TEMP_DIR, OUTPUT_DIR):
    os.makedirs(directory, exist_ok=True)

In [24]:
# == Download pretrained X model weights ==
# !gdown --id "1P4mY0Yyd3PPTybgZkjMYhFri88nTmJX5"
# !gdown --id "11Zb0NN_Uu7JwUd9e6Nk8o2_EUfxWqsun"
# !gdown --id "1uSmhXzyV1Zvb4TJJCzpsZOIcw7CCJLxj"

In [25]:
import sys
import yaml
import time
import os.path as osp
from dotmap import DotMap

import cv2
import numpy as np
import torch
import torchvision

from loguru import logger

In [26]:
# Read config file
# The parsed YAML is wrapped in a DotMap so settings can be read as attributes
# (config.conf, config.device, ...).
# An explicit encoding avoids depending on the platform default when decoding.
with open(CONFIG_FILE, "r", encoding="utf-8") as stream:
    # safe_load returns None for an empty document; fall back to an empty mapping.
    config = DotMap(yaml.safe_load(stream) or {})

In [27]:
# Create save folder
# Each run writes its video into a sub-folder of TEMP_DIR named after the
# local time at which this cell was executed.
current_time = time.localtime()
timestamp = time.strftime("%Y_%m_%d_%H_%M_%S", current_time)
save_folder = osp.join(TEMP_DIR, timestamp)
save_path = osp.join(save_folder, input_video_fn)

# exist_ok=True replaces the racy "check, then create" pattern and keeps the
# cell safe to re-run within the same second.
os.makedirs(save_folder, exist_ok=True)

In [28]:
# Read input video
cap = cv2.VideoCapture(INPUT_VIDEO)
if not cap.isOpened():
    # Fail fast: without this check an unreadable path only shows up later as
    # zero-sized frame dimensions and an empty output video.
    raise IOError(f"Cannot open input video: {INPUT_VIDEO}")

# Get main video characteristic
width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)  # float
height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)  # float
fps = cap.get(cv2.CAP_PROP_FPS) # float

# Writer for the output video: same frame rate and frame size as the input,
# encoded with the "mp4v" FourCC and stored at save_path.
vid_writer = cv2.VideoWriter(
    save_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (int(width), int(height))
)

In [29]:
# Checkpoint file handed to ObjectDetection when the predictor is built.
PATH_TO_CHECKPOINT = "detection/models/pretrained/bytetrack_m_mot17.pth.tar"

In [30]:
from detection.models.yolo.yolox_m_mix_det import Exp
from detection import ObjectDetection

exp = Exp()

# Override the experiment defaults only for settings that are really present.
# A DotMap hands back a new empty DotMap when a missing key is read as an
# attribute, so `config.conf is not None` is True even for absent keys and the
# empty DotMap would be stored on `exp`. `.get()` returns None for missing
# keys (and for explicit YAML nulls), which is what the checks need.
conf_override = config.get("conf")
if conf_override is not None:
    exp.test_conf = conf_override

nms_override = config.get("nms")
if nms_override is not None:
    exp.nmsthre = nms_override

tsize_override = config.get("tsize")
if tsize_override is not None:
    # One configured value is used for both entries of the test size.
    exp.test_size = (tsize_override, tsize_override)

# Build the detector from the experiment, the checkpoint path and the
# configured device.
predictor = ObjectDetection(exp, PATH_TO_CHECKPOINT, config.device)

#### Tracking model

In [31]:
from tracking import ObjectTracking
from tracking.methods.bytetrack import BYTETracker

# Build the tracker: a BYTETracker configured with the run settings and the
# input video's frame rate, wrapped by ObjectTracking.
tracker = ObjectTracking(
    BYTETracker(config, frame_rate=fps),
    config,
)

#### Video processing

In [32]:
from utils.timer import Timer
from visualize import plot_tracking

In [33]:
# Run detection + tracking on every frame of the input video, optionally
# writing the annotated frames to the output video and the tracker results to
# a text file in OUTPUT_DIR.
timer = Timer()
frame_id = 0

try:
    while True:
        ret_val, frame = cap.read()
        if not ret_val:
            # No more frames (or the read failed): stop processing.
            break

        # Log only after a frame was actually read, so the progress line never
        # reports a frame index that does not exist.
        if frame_id % 20 == 0:
            logger.info('Processing frame {} ({:.2f} fps)'.format(frame_id, 1. / max(1e-5, timer.average_time)))

        # Object detection part
        # NOTE(review): the timer is passed into inference and stopped below with
        # toc(); presumably inference starts it — confirm in ObjectDetection.
        outputs, img_info = predictor.inference(frame, timer)
        if outputs[0] is not None:
            # Tracking part
            online_tlwhs, online_ids, online_scores = tracker.update(frame_id, outputs[0], [img_info['height'], img_info['width']], exp.test_size)
            timer.toc()
            # Predict part

            # Visualization part
            # Same guard as in the progress log: avoids a ZeroDivisionError when
            # the average time is still zero.
            online_im = plot_tracking(img_info['raw_img'], online_tlwhs, online_ids, frame_id=frame_id + 1, fps=1. / max(1e-5, timer.average_time))
        else:
            # Nothing detected: keep the raw frame unchanged.
            timer.toc()
            online_im = img_info['raw_img']

        if config.save_result:
            vid_writer.write(online_im)
        frame_id += 1
finally:
    # Always free the capture handle, even if processing a frame raised.
    cap.release()

if config.save_result:
    # Results file is named after the run timestamp.
    res_file = osp.join(OUTPUT_DIR, f"{timestamp}.txt")
    with open(res_file, 'w') as f:
        f.writelines(tracker.results)
    logger.info(f"save results to {res_file}")

2022-09-10 02:09:03.249 | INFO     | __main__:<module>:6 - Processing frame 0 (100000.00 fps)
2022-09-10 02:09:06.576 | INFO     | __main__:<module>:6 - Processing frame 20 (13.53 fps)
2022-09-10 02:09:09.736 | INFO     | __main__:<module>:6 - Processing frame 40 (13.62 fps)
2022-09-10 02:09:12.885 | INFO     | __main__:<module>:6 - Processing frame 60 (13.65 fps)
2022-09-10 02:09:16.013 | INFO     | __main__:<module>:6 - Processing frame 80 (13.65 fps)
2022-09-10 02:09:19.147 | INFO     | __main__:<module>:6 - Processing frame 100 (13.66 fps)
2022-09-10 02:09:22.301 | INFO     | __main__:<module>:6 - Processing frame 120 (13.65 fps)
2022-09-10 02:09:25.430 | INFO     | __main__:<module>:6 - Processing frame 140 (13.66 fps)
2022-09-10 02:09:28.670 | INFO     | __main__:<module>:6 - Processing frame 160 (13.64 fps)
2022-09-10 02:09:32.008 | INFO     | __main__:<module>:6 - Processing frame 180 (13.65 fps)
2022-09-10 02:09:35.499 | INFO     | __main__:<module>:6 - Processing frame 200 (1

In [34]:
# Release the VideoWriter so OpenCV closes the output file at save_path.
vid_writer.release()