In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys

import cv2
import numpy as np
from bytetracker import BYTETracker
from bytetracker.basetrack import BaseTrack
from tqdm import tqdm

# Make the parent directory importable *before* importing the project-local
# packages below (`lib`, `trackreid`); appending it after those imports would have
# no effect on them. The guard keeps re-runs of this cell from growing sys.path.
if ".." not in sys.path:
    sys.path.append("..")

from lib.bbox.utils import rescale_bbox, xy_center_to_xyxy
from lib.sequence import Sequence
from trackreid.args.reid_args import OUTPUT_POSITIONS
from trackreid.reid_processor import ReidProcessor


# Real-life data

In [None]:
DATA_PATH = "../data"
DETECTION_PATH = f"{DATA_PATH}/detections"
FRAME_PATH = f"{DATA_PATH}/frames"
VIDEO_OUTPUT_PATH = "private"

# One entry per sequence sub-directory of DETECTION_PATH. `os.listdir` returns
# entries in arbitrary order and also lists stray files (e.g. `.DS_Store`), so keep
# only directories and sort them for a reproducible processing order.
SEQUENCES = sorted(
    entry
    for entry in os.listdir(DETECTION_PATH)
    if os.path.isdir(os.path.join(DETECTION_PATH, entry))
)


In [None]:
def _list_sorted_files(directory):
    """Return the sorted paths of the non-hidden entries of ``directory``.

    Entries whose name starts with ``"."`` (e.g. ``.DS_Store``) are skipped so that
    OS/editor artefacts are never treated as frames or detection files.
    """
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if not name.startswith(".")
    )


def get_sequence_frames(sequence, root=None):
    """Return the sorted frame file paths of ``sequence``.

    Args:
        sequence: name of the sequence sub-directory.
        root: directory holding one sub-directory per sequence; defaults to
            ``FRAME_PATH``.

    Returns:
        list[str]: paths sorted lexicographically. This assumes file names sort in
        temporal order (e.g. zero-padded frame indices) — TODO confirm.
    """
    base = FRAME_PATH if root is None else root
    return _list_sorted_files(os.path.join(base, sequence))


def get_sequence_detections(sequence, root=None):
    """Return the sorted detection file paths of ``sequence``.

    Args:
        sequence: name of the sequence sub-directory.
        root: directory holding one sub-directory per sequence; defaults to
            ``DETECTION_PATH``.

    Returns:
        list[str]: paths sorted lexicographically (same ordering assumption as
        ``get_sequence_frames``).
    """
    base = DETECTION_PATH if root is None else root
    return _list_sorted_files(os.path.join(base, sequence))

In [None]:
class DetectionHandler():
    """Convert raw per-frame detections into the row layout fed to the tracker.

    Input rows are read as ``[class, bbox (4 values), confidence, ...]``. The bbox is
    passed through ``xy_center_to_xyxy`` (presumably centre-format -> corner-format)
    and then ``rescale_bbox``. Output rows are always
    ``[bbox (4 values), confidence, class]``, i.e. exactly ``OUTPUT_COLUMNS`` wide.
    """

    # Width of each processed detection row: 4 bbox values + confidence + class.
    OUTPUT_COLUMNS = 6

    def __init__(self, image_shape) -> None:
        # Forwarded unchanged to `rescale_bbox` for every detection.
        self.image_shape = image_shape

    def process(self, detection_output):
        """Reformat one frame's detections.

        Args:
            detection_output: numpy array of detections, either a single 1-D row or
                a 2-D array with one detection per row (at least 6 columns).

        Returns:
            The input unchanged when it is empty; otherwise a float array of shape
            ``(n_detections, OUTPUT_COLUMNS)``.

        Raises:
            ValueError: if a non-empty input does not have at least 6 columns.
        """
        if not detection_output.size:
            # Nothing detected in this frame: hand the empty array back as-is.
            return detection_output

        # A single detection may arrive as a 1-D row; treat it as a 1-row table.
        detections = np.atleast_2d(detection_output)
        if detections.ndim != 2 or detections.shape[1] < self.OUTPUT_COLUMNS:
            raise ValueError(
                f"expected detections with at least {self.OUTPUT_COLUMNS} columns, "
                f"got shape {detection_output.shape}"
            )

        # Sized from OUTPUT_COLUMNS (not the input width) so extra input columns
        # never leak into the tracker input as spurious zero columns.
        processed_detection = np.zeros((detections.shape[0], self.OUTPUT_COLUMNS))
        for idx, detection in enumerate(detections):
            xyxy_bbox = xy_center_to_xyxy(detection[1:5])
            processed_detection[idx, :4] = rescale_bbox(xyxy_bbox, self.image_shape)
            processed_detection[idx, 4] = detection[5]  # confidence
            processed_detection[idx, 5] = detection[0]  # class id

        return processed_detection


In [None]:
class TrackingHandler():
    """Thin adapter around a tracker exposing ``update(detections, frame_id=...)``.

    Frames with no detections skip the tracker and the empty array is returned
    unchanged. Otherwise the tracker's output is normalised so that a lone tracked
    object comes back as a one-row 2-D array.
    """

    def __init__(self, tracker) -> None:
        self.tracker = tracker

    def update(self, detection_outputs, frame_id):
        """Track one frame's detections and return the normalised tracker output."""
        # NOTE(review): empty frames never reach `self.tracker.update`, so the
        # tracker does not observe them — confirm this is intended for track ageing.
        if detection_outputs.size == 0:
            return detection_outputs

        tracker_input = self._pre_process(detection_outputs)
        raw_tracks = self.tracker.update(tracker_input, frame_id=frame_id)
        return self._post_process(raw_tracks)

    def _pre_process(self, detection_outputs: np.ndarray):
        """Hook for adapting detections before tracking; currently the identity."""
        return detection_outputs

    def _post_process(self, tracked_objects: np.ndarray):
        """Promote a non-empty 1-D tracker result to a single-row 2-D array."""
        if tracked_objects.size and tracked_objects.ndim == 1:
            return tracked_objects[np.newaxis, :]
        return tracked_objects

In [None]:
def bounding_box_distance(obj1, obj2):
    """Euclidean distance between the centres of two objects' bounding boxes.

    Each object must expose ``metadata.bbox`` indexable at 0..3, read as
    ``(x1, y1, x2, y2)``; a box's centre is the midpoint of those two corners.
    """
    def _centre(box):
        return (box[0] + box[2]) / 2, (box[1] + box[3]) / 2

    first_x, first_y = _centre(obj1.metadata.bbox)
    second_x, second_y = _centre(obj2.metadata.bbox)
    dx = first_x - second_x
    dy = first_y - second_y
    return np.sqrt(dx**2 + dy**2)

# TODO : discard by zone
def select_by_category(obj1, obj2):
    """Return 1 when both objects have equal ``category`` attributes, else 0."""
    return int(obj1.category == obj2.category)

In [None]:
# Frame size (width, height) of the output video; the same values are passed to
# DetectionHandler as `image_shape`.
VIDEO_FRAME_SIZE = (2560, 1440)
VIDEO_FPS = 20.0
# NOTE(review): the tracker below is configured with frame_rate=30 while the video
# is written at 20 fps — confirm which matches the source footage.


def draw_reid_results(frame, reid_results):
    """Draw each re-identified object's box and label onto ``frame`` in place.

    ``frame`` must already be in BGR order (what cv2.VideoWriter expects), so the
    BGR colours below render as intended: green for category 0, red otherwise.
    Returns the same ``frame`` for convenience.
    """
    for res in reid_results:
        object_id = int(res[OUTPUT_POSITIONS["object_id"]])
        x1, y1, x2, y2 = map(int, res[OUTPUT_POSITIONS["bbox"]])
        class_id = int(res[OUTPUT_POSITIONS["category"]])
        tracker_id = int(res[OUTPUT_POSITIONS["tracker_id"]])
        mean_confidence = float(res[OUTPUT_POSITIONS["mean_confidence"]])
        color = (0, 0, 255) if class_id else (0, 255, 0)  # BGR: red for class != 0, green for class 0
        cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
        cv2.putText(frame, f"{object_id} ({tracker_id}) : {round(mean_confidence,2)}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, color, 2)
    return frame


os.makedirs(VIDEO_OUTPUT_PATH, exist_ok=True)

for sequence in SEQUENCES:
    test_sequence = Sequence(get_sequence_frames(sequence))
    # Reset BaseTrack's class-level `_count` (presumably the track-id counter) so
    # each sequence starts from a clean state.
    BaseTrack._count = 0

    video_path = os.path.join(VIDEO_OUTPUT_PATH, f"{sequence}.mp4")
    fourcc = cv2.VideoWriter_fourcc(*'avc1')  # or use 'x264'
    out = cv2.VideoWriter(video_path, fourcc, VIDEO_FPS, VIDEO_FRAME_SIZE)
    # OpenCV does not raise when the codec/container is unavailable; writes would
    # silently be dropped, so fail loudly instead.
    if not out.isOpened():
        raise RuntimeError(f"Could not open video writer for {video_path} (codec 'avc1')")

    detection_handler = DetectionHandler(image_shape=list(VIDEO_FRAME_SIZE))
    tracking_handler = TrackingHandler(tracker=BYTETracker(track_thresh=0.3, track_buffer=5, match_thresh=0.85, frame_rate=30))
    reid_processor = ReidProcessor(filter_confidence_threshold=0.1,
                                   filter_time_threshold=5,
                                   cost_function=bounding_box_distance,
                                   # cost_function_threshold=500,  # max cost to rematch 2 objects
                                   selection_function=select_by_category,
                                   max_attempt_to_match=5,
                                   max_frames_to_rematch=500)

    try:
        for frame_id, (frame, detection) in enumerate(tqdm(test_sequence), start=1):
            # Swap channels *before* drawing (the swap is symmetric; frames are
            # presumably RGB — TODO confirm) so boxes are drawn in the BGR space
            # the writer uses and keep their intended colours.
            frame = cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR)

            processed_detections = detection_handler.process(detection)
            processed_tracked = tracking_handler.update(processed_detections, frame_id)
            reid_results = reid_processor.update(processed_tracked, frame_id)

            draw_reid_results(frame, reid_results)
            out.write(frame)
    finally:
        # Always release the writer so the container is finalised even on error.
        out.release()

    print(sequence, len(reid_processor.seen_objects),reid_processor.nb_corrections)
    print(reid_processor.seen_objects)
