# In [None]:
import cv2
from ultralytics import YOLO
import numpy as np
import torch
import time
import csv
from types import SimpleNamespace

# ByteTrack import
try:
    from yolox.tracker.byte_tracker import BYTETracker, STrack
except ImportError:
    raise ImportError("ByteTrack not found. Make sure yolox is installed and in PYTHONPATH.")

def draw_box(frame, bbox, label, color=(0, 255, 0), thickness=2, font_scale=0.6):
    """Draw a rectangle and an optional text label onto ``frame`` in place.

    Args:
        frame: Image to draw on; OpenCV modifies it in place.
        bbox: Box corners ``(x1, y1, x2, y2)``; each value is truncated to int.
        label: Text drawn at the box's top-left corner; skipped when empty/None.
        color: Color tuple passed to OpenCV for both the box and the label.
        thickness: Rectangle line thickness in pixels.
        font_scale: Scale factor for the label font.
    """
    x1, y1, x2, y2 = map(int, bbox)
    cv2.rectangle(frame, (x1, y1), (x2, y2), color, thickness)
    if not label:
        return

    font = cv2.FONT_HERSHEY_SIMPLEX
    (_, text_height), _ = cv2.getTextSize(label, font, font_scale, 1)
    # putText anchors at the text's bottom-left corner. Prefer placing the label
    # just above the box; if that would push it past the top edge of the image,
    # place it just inside the box instead so it stays visible.
    text_y = y1 - 5
    if text_y < text_height:
        text_y = y1 + text_height + 5
    cv2.putText(frame, label, (x1, text_y), font, font_scale, color, 1)

def _box_iou(box_a, box_b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes (0.0 if degenerate)."""
    inter_w = min(box_a[2], box_b[2]) - max(box_a[0], box_b[0])
    inter_h = min(box_a[3], box_b[3]) - max(box_a[1], box_b[1])
    inter = max(0.0, inter_w) * max(0.0, inter_h)
    area_a = max(0.0, box_a[2] - box_a[0]) * max(0.0, box_a[3] - box_a[1])
    area_b = max(0.0, box_b[2] - box_b[0]) * max(0.0, box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return float(inter / union) if union > 0 else 0.0

def _yolo_detections(model, frame, confidence_threshold, iou_threshold):
    """Run the YOLO ``model`` on one frame and return an ``(N, 5)`` array of ``[x1, y1, x2, y2, conf]`` rows.

    Returns an empty ``(0, 5)`` array when the model reports no boxes.
    """
    results = model(frame, conf=confidence_threshold, iou=iou_threshold, verbose=False)
    boxes = results[0].boxes if results else None
    if boxes is None or len(boxes) == 0:
        return np.empty((0, 5))
    # Pull all boxes/confidences off the device in one go instead of per box.
    xyxy = boxes.xyxy.cpu().numpy()
    conf = boxes.conf.cpu().numpy().reshape(-1, 1)
    return np.hstack([xyxy, conf]).astype(np.float64)

def run_ant_tracking(video_path, model_path, confidence_threshold=0.20, iou_threshold=0.4,
                     output_video_path="output_ant_trackingXXXXF.mp4",
                     output_csv_path="ant_tracking_resultsXXXXF.csv"):
    """Detect objects with YOLO, track them with ByteTrack, and save annotated results.

    Each frame of ``video_path`` is run through the YOLO model, and the detections
    are fed to a ``BYTETracker``. Annotated frames are written to
    ``output_video_path``, and one CSV row per tracked target per frame is written
    to ``output_csv_path``.

    Overlay legend (OpenCV BGR colors):
        blue  -- every YOLO detection, labelled with its confidence;
        green -- ByteTrack targets, labelled with their track ID;
        red   -- detections that overlap no track with IoU >= 0.5 ("Unmatched").

    Args:
        video_path: Path of the input video.
        model_path: Path of the weights passed to ``YOLO(...)``.
        confidence_threshold: Passed to YOLO as ``conf``.
        iou_threshold: Passed to YOLO as ``iou``.
        output_video_path: Destination of the annotated video.
        output_csv_path: Destination of the CSV (columns: frame, track_id,
            center_x, center_y, confidence). Frames are numbered from 1.

    Raises:
        IOError: If the input video cannot be opened or the output video writer
            cannot be created.
    """
    print(f"Loading YOLOv8 model: {model_path}...")
    model = YOLO(model_path)
    print("Model loaded.")

    # ByteTrack config. NOTE(review): frame_rate is fixed at 100 rather than read
    # from the video -- confirm this is the intended tracker timing.
    tracker_args = SimpleNamespace(
        track_thresh=0.08,
        track_buffer=1500,
        match_thresh=0.8,
        aspect_ratio_thresh=10,
        min_box_area=1,
        mot20=False,
    )
    tracker = BYTETracker(tracker_args, frame_rate=100)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video file {video_path}")

    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        print(f"Video info: {width}x{height}, {fps:.2f} FPS, {total_frames} frames")

        # Output frame rate is capped at 15; if the capture reports a
        # non-positive FPS, use 15 rather than handing the writer a rate of 0.
        out_fps = min(fps, 15) if fps > 0 else 15
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, out_fps, (width, height))
        if not out.isOpened():
            raise IOError(f"Could not open video writer for {output_video_path}")

        with open(output_csv_path, mode="w", newline="") as csv_file:
            csv_writer = csv.writer(csv_file)
            csv_writer.writerow(["frame", "track_id", "center_x", "center_y", "confidence"])

            frame_count = 0
            start_time = time.time()
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame_count += 1
                print(f"Processing frame {frame_count}...")

                detections = _yolo_detections(model, frame, confidence_threshold, iou_threshold)

                # Visualize all YOLO detections (blue)
                for det in detections:
                    draw_box(frame, det[:4], f"{det[4]:.2f}", color=(255, 0, 0), thickness=1)

                # Pass a copy so the tracker cannot alter the boxes reused below.
                online_targets = tracker.update(detections.copy(), [height, width], (height, width))

                # Visualize tracks (green) and log them to the CSV.
                tracked_boxes = []
                for target in online_targets:
                    left, top, w, h = (float(v) for v in target.tlwh)
                    track_id = target.track_id
                    box = [left, top, left + w, top + h]
                    tracked_boxes.append(box)

                    # Centre from the unrounded box so the CSV keeps sub-pixel precision.
                    cx = left + w / 2
                    cy = top + h / 2
                    conf = getattr(target, "score", -1)

                    draw_box(frame, box, f"ID: {track_id}", color=(0, 255, 0), thickness=2)
                    csv_writer.writerow(
                        [frame_count, track_id, round(cx, 2), round(cy, 2), round(conf, 3)]
                    )

                # Highlight detections that no track overlaps (red). IoU compares
                # whole boxes, not just one corner.
                for det in detections:
                    if not any(_box_iou(det[:4], tb) >= 0.5 for tb in tracked_boxes):
                        draw_box(frame, det[:4], "Unmatched", color=(0, 0, 255), thickness=1)

                # Overlay FPS and frame counter.
                elapsed_time = time.time() - start_time
                current_fps = frame_count / elapsed_time if elapsed_time > 0 else 0.0
                cv2.putText(frame, f"FPS: {current_fps:.2f}", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
                cv2.putText(frame, f"Frame: {frame_count}/{total_frames}", (10, 70),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)

                out.write(frame)
    finally:
        # Release capture/writer even if detection or tracking raised mid-video.
        cap.release()
        if out is not None:
            out.release()

    print(f"✅ Tracking complete. Results saved to '{output_video_path}' and '{output_csv_path}'.")

if __name__ == "__main__":
    # Script entry point: track the sample slo-mo recording using the weights
    # under runs/detect/train2 (paths are relative to the working directory).
    run_ant_tracking(
        'Videos with marking behavior/S3160001_(slo-mo(120fps)_recruitment_p1).MP4',
        'runs/detect/train2/weights/best.pt',
        confidence_threshold=0.20,
        iou_threshold=0.4,
    )