In [1]:
import cv2
import torch
import numpy as np

from PIL import Image
from ultralytics import YOLO


# Load the trained YOLO weights once; `detect_large` below reads this module-level `model`.
model=YOLO("/mnt/data/weights/all_data_v11/yolov11_all_data_02122024_5/weights/best.pt")
# Print the total number of model parameters, expressed in millions.
print(f"{sum(p.numel() for p in model.parameters())/1e6} million parameters")

# results=model.track("data/box.mov",conf=0.1,show=True)


25.324358 million parameters


# Objective
Assume we want to track the **main objects** in a video and draw their bounding boxes in **every single frame**. There are two issues to solve:

1. Detect only the **large objects** in the video
2. If an object is not detected in a certain frame (for example, due to occlusion), its bounding box **will not be drawn** by the usual `model.track` command

### 1. Detect Large Objects
Idea: an object counts as large when its width or its height exceeds a fixed fraction (`ratio`) of the image width or height, respectively.

In [10]:

def detect_large(image_array,ratio=1/8):
    """
    Detect only large objects in a single image.

    A detection is kept when its box width exceeds ``ratio`` times the image
    width OR its box height exceeds ``ratio`` times the image height.

    Parameters
    ----------
    image_array : array whose first two dimensions are (height, width)
        The frame passed to the module-level YOLO ``model``.
    ratio : float, default 1/8
        Fraction of the image width/height a box must exceed to count as large.

    Returns
    -------
    np.ndarray of shape (N, 6)
        One row ``[x, y, w, h, conf, cls]`` per kept detection, taken from the
        model's ``boxes.xywh``, ``boxes.conf`` and ``boxes.cls``. When nothing
        is kept the array has shape ``(0, 6)``, so callers can always slice
        columns.
    """
    height,width=image_array.shape[:2]

    # prediction with yolo model
    dets=model.predict(image_array,conf=0.5,classes=[1])
    dets=dets[0]            # only 1 input image -> dets = list of length 1

    # reshape so that the "no detections" case is still 2-D / 1-D as expected
    xywhs=np.asarray(dets.boxes.xywh.cpu().numpy()).reshape(-1,4)
    confs=np.asarray(dets.boxes.conf.cpu().numpy()).reshape(-1)
    clss=np.asarray(dets.boxes.cls.cpu().numpy()).reshape(-1)

    # Compare width against the width threshold and height against the height
    # threshold. (The previous height test multiplied both sides by `ratio`,
    # which reduced it to `h > image height` and so never selected anything.)
    is_large=(xywhs[:,2]>width*ratio) | (xywhs[:,3]>height*ratio)

    return np.column_stack([xywhs[is_large],confs[is_large],clss[is_large]])



## 2. Track with ByteTracker
We aim to draw bounding boxes in every single frame. In frames where the model does not detect an object, we reuse the bounding box from the most recent frame in which that object was detected. We need 2 dictionaries:

1. Dictionary 1: {track_id: bbox} 
2. Dictionary 2: {track_id: frame_number}, where frame_number is the most recent frame in which the track_id was detected

In [11]:
from argparse import Namespace
from ultralytics.trackers.byte_tracker import BYTETracker

# input to BYTETracker has attributes xywh,conf,cls
class Result:
    """Plain container for one frame's detections, passed to ``tracker.update``.

    Exposes exactly three attributes: ``xywh`` (boxes), ``conf`` (scores)
    and ``cls`` (class ids), stored as given.
    """

    def __init__(self, xywh, conf, cls):
        # Keep the three fields together, unchanged.
        self.xywh, self.conf, self.cls = xywh, conf, cls

def _draw_tracks(frame, history):
    """Draw every remembered track's box and id label onto `frame` (in place)."""
    for track_id, bbox in history.items():
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
        cv2.putText(frame, str(track_id), (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)


def inference(video_path):
    """
    Track large objects through a video and write an annotated copy.

    Each readable frame is passed to `detect_large`; the detections are fed to
    a BYTETracker, and the last known box of every track seen within the last
    `buffer` frames is drawn. Every readable frame is written to the output,
    including frames where there is nothing to feed the tracker or where the
    tracker returns no tracks.

    Parameters
    ----------
    video_path : str
        Path of the input video. The output is written next to it, with the
        extension replaced by "_tracked.mp4".

    Returns
    -------
    None
    """
    import os  # local import: only needed to build the output path

    cap = cv2.VideoCapture(video_path)
    out = None
    try:
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return

        # video features
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Original width : {width} | Original height :{height} | num_frames : {num_frames}")

        # create codec for writing output frame
        # splitext instead of [:-4] so extensions of any length are handled
        out_path = os.path.splitext(video_path)[0] + "_tracked.mp4"
        # NOTE(review): the output is written at half the source fps while every
        # frame is kept, so playback is slowed down -- confirm this is intended.
        out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps//2, (width, height))

        # dictionaries
        history = {}            # dictionary {track_id: bbox}
        frame_history = {}      # dictionary {track_id: frame_number} with frame_number = most recent frame track_id is detected
        buffer = 30             # maximum number of frames to retain information

        MIN_BOXES = 2           # minimum number of boxes to draw in every frame
        last_min_results = None # last result which detects at least MIN_BOXES boxes

        # activate BYTETracker
        args = Namespace(track_high_thresh=0.5, track_low_thresh=0.1,
                         match_thresh=0.7, new_track_thresh=0.6,
                         track_buffer=buffer, fuse_score=True)
        tracker = BYTETracker(args, frame_rate=fps)

        # loop through frames
        frame_number = 0
        while frame_number < num_frames:
            ret, frame = cap.read()

            if not ret:
                print(f"Failed to read frame {frame_number}. Skipping...")
                frame_number += 1
                continue

            results = detect_large(frame)

            # with too few boxes, fall back to the last result that had enough
            if len(results) < MIN_BOXES:
                results = last_min_results
            else:
                last_min_results = results

            # Only update the tracker when there is something to feed it. The
            # frame is still annotated and written below either way, so the
            # output keeps one frame per input frame and frame_number stays
            # aligned with the video.
            if results is not None and len(results) > 0:
                tracker_inputs = Result(results[:, :4], results[:, 4], results[:, 5])

                # track the frame: tracked_frame=[x1, y1, x2, y2, track_id, score, class, idx]
                tracked_frame = tracker.update(tracker_inputs)

                # Update the history with current track_id and bbox
                # (iterating rows is a no-op when the tracker returns nothing)
                for row in tracked_frame:
                    track_id = int(row[4])   # int so the label reads "3", not "3.0"
                    history[track_id] = row[:4]
                    frame_history[track_id] = frame_number

            # remove track_ids and bboxes that were not seen in the last `buffer` frames
            for track_id in list(history.keys()):
                if frame_history[track_id] < frame_number - buffer:
                    del history[track_id]
                    del frame_history[track_id]

            # draw the last known box of every remembered track, even when the
            # current frame produced no tracks
            _draw_tracks(frame, history)

            out.write(frame)
            frame_number += 1
    finally:
        # release the handles even if detection or tracking raised
        cap.release()
        if out is not None:
            out.release()
        cv2.destroyAllWindows()

# Run the tracking pipeline on the demo video; `inference` writes its annotated
# output to a "_tracked.mp4" file derived from this path.
video_path = "/mnt/data/demo_video/surg_sponge.avi"
inference(video_path)

Original width : 2048 | Original height :2048 | num_frames : 597

0: 2048x2048 (no detections), 98.1ms
Speed: 21.0ms preprocess, 98.1ms inference, 1.4ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 (no detections), 92.7ms
Speed: 15.3ms preprocess, 92.7ms inference, 1.8ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 (no detections), 88.9ms
Speed: 21.0ms preprocess, 88.9ms inference, 1.2ms postprocess per image at shape (1, 3, 2048, 2048)



0: 2048x2048 1 sponge, 87.4ms
Speed: 14.5ms preprocess, 87.4ms inference, 2.0ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.3ms
Speed: 20.9ms preprocess, 87.3ms inference, 1.6ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.4ms
Speed: 14.5ms preprocess, 87.4ms inference, 1.9ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.5ms
Speed: 20.8ms preprocess, 87.5ms inference, 1.8ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.6ms
Speed: 14.3ms preprocess, 87.6ms inference, 2.2ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 (no detections), 87.0ms
Speed: 21.1ms preprocess, 87.0ms inference, 1.1ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.4ms
Speed: 14.4ms preprocess, 87.4ms inference, 2.1ms postprocess per image at shape (1, 3, 2048, 2048)

0: 2048x2048 1 sponge, 87.2ms
Speed: 20.2ms preprocess, 87.2ms inference, 1.