<h3>Imports</h3>

In [51]:
from ultralytics import YOLO
import cv2
from tqdm import tqdm

<h3>Model paths</h3>

In [52]:
# Detector used by predict_1 (first pipeline stage, run on the whole frame).
# NOTE(review): absolute, machine-specific path — must be edited to run this notebook elsewhere.
model_1 = YOLO("/home/theo/Documents/Unif/Master/ChimpRec/Code/Body_detection/YOLO_small/runs/detect/train9/weights/best.pt")
# Detector used by predict_2 (second pipeline stage, run on each cropped body).
# NOTE(review): absolute, machine-specific path — must be edited to run this notebook elsewhere.
model_2 = YOLO("/home/theo/Documents/Unif/Master/ChimpRec/Code/Face_detection/runs/detect/train3/weights/best.pt")

<h3>Video paths</h3>

In [53]:
# Input video read by the main loop.
# NOTE(review): absolute, machine-specific path — must be edited to run this notebook elsewhere.
video_path = "/home/theo/Documents/Unif/Master/Chimprec - Extra/videos/20241023 - 09h28.MP4"
# Destination of the annotated video; being relative, it resolves against the
# current working directory of the kernel.
output_path = "output.mp4"

<h2>First prediction in the pipeline: body detection</h2>

In [54]:
def predict_1(image, t_confidence=0.4):
    """Run the body detection model on `image` and keep confident boxes.

    Args:
        image: frame handed as-is to `model_1.predict`.
        t_confidence: minimum score (inclusive) for a detection to be kept.

    Returns:
        A tuple of `(x1, y1, x2, y2, score)` tuples, in the order the model
        reported them. Coordinates are left exactly as the model returned
        them (no integer conversion here).
    """
    # Only the first result is used: a single image is passed in.
    detections = model_1.predict(image, verbose=False)[0]

    kept = []
    for row in detections.boxes.data.tolist():
        # Each row has six values; the last one is not needed here.
        x1, y1, x2, y2, score, _ = row
        if score >= t_confidence:
            kept.append((x1, y1, x2, y2, score))
    return tuple(kept)

<h2>Second prediction in the pipeline: face detection</h2>

In [55]:
def predict_2(image, t_confidence=0.6):
    """Run the face detection model on `image` and return its best box.

    Args:
        image: image handed as-is to `model_2.predict` (the pipeline passes
            a cropped body here).
        t_confidence: minimum score (inclusive) the best detection must reach.

    Returns:
        `(x1, y1, x2, y2, score)` with integer coordinates for the
        highest-scoring detection, or `None` when nothing was detected or
        the best score is below `t_confidence`.
    """
    detections = []
    for result in model_2.predict(image, verbose=False):
        for row in result.boxes.data.tolist():
            # Each row has six values; the last one is not needed here.
            x1, y1, x2, y2, score, _ = row
            detections.append((int(x1), int(y1), int(x2), int(y2), score))

    if not detections:
        return None

    # Only the single most confident face is considered.
    top = max(detections, key=lambda det: det[-1])
    if top[-1] >= t_confidence:
        return top
    return None

<h2>Utility functions</h2>

In [56]:
def crop(image, bbox):
    """Return the region of `image` covered by `bbox`.

    Args:
        image: array indexable as `image[rows, cols]`.
        bbox: `(x1, y1, x2, y2, ...)`; only the first four values are used
            and each is truncated to an integer.

    Returns:
        The slice `image[y1:y2, x1:x2]` (a view for numpy arrays). It is
        empty when the box is degenerate or lies entirely at negative
        coordinates.
    """
    x1, y1, x2, y2, *_ = bbox
    # Clamp every bound to 0: a negative value used as a slice bound would be
    # interpreted as "count from the end" and select the wrong region.
    x1, y1, x2, y2 = (max(int(v), 0) for v in (x1, y1, x2, y2))
    return image[y1:y2, x1:x2]

def face_to_src(body_bbox, face_bbox):
    """Translate a face box from cropped-body coordinates to frame coordinates.

    Args:
        body_bbox: `(x1, y1, x2, y2, score)` of the body in the full frame.
        face_bbox: `(x1, y1, x2, y2, score)` of the face, relative to the
            top-left corner of the image returned by `crop` for `body_bbox`.

    Returns:
        `(x1, y1, x2, y2, score)` of the face in full-frame coordinates,
        with integer coordinates and the face score unchanged.
    """
    bx1, by1, _, _, _ = body_bbox
    fx1, fy1, fx2, fy2, score = face_bbox
    # The crop starts at the truncated body corner, clamped to 0, so the
    # offset must be computed the same way to stay aligned with the crop.
    off_x = max(int(bx1), 0)
    off_y = max(int(by1), 0)
    return (int(off_x + fx1), int(off_y + fy1), int(off_x + fx2), int(off_y + fy2), score)

def draw_bbox(image, color, bbox, label):
    """Draw `bbox` and a score caption on `image`, modifying it in place.

    Args:
        image: frame to draw on.
        color: colour tuple passed to OpenCV for the outline and the
            caption background.
        bbox: `(x1, y1, x2, y2, score)`; coordinates are truncated to ints.
        label: caption prefix; `"Face"` gets a larger text scale factor
            than any other label.

    Returns:
        The same `image` object, after drawing.
    """
    left, top, right, bottom, score = bbox
    left, top, right, bottom = map(int, (left, top, right, bottom))

    font = cv2.FONT_HERSHEY_COMPLEX
    if label == "Face":
        factor = 0.65
    else:
        factor = 0.3
    # Text size grows with the box (width + height), with a lower bound.
    box_extent = right - left + bottom - top
    font_scale = max(0.4, box_extent / 300) * factor

    cv2.rectangle(image, (left, top), (right, bottom), color, 4)

    caption = f"{label}: {score:.2f}"
    (text_w, text_h), _ = cv2.getTextSize(caption, font, font_scale, 1)

    # Filled caption background in the bottom-right corner of the box,
    # blended 50/50 with the frame so the picture stays visible underneath.
    backdrop = image.copy()
    cv2.rectangle(
        backdrop,
        (right - text_w - 10, bottom - text_h - 10),
        (right, bottom),
        color,
        -1,
    )
    cv2.addWeighted(backdrop, 0.5, image, 0.5, 0, image)

    cv2.putText(
        image,
        caption,
        (right - text_w - 5, bottom - 5),
        font,
        font_scale,
        (255,255,255),
        1,
    )
    return image

def predict_frame(image, body_bboxes=None, face_bboxes=None):
    """Detect (or reuse) body and face boxes for a frame and draw them on it.

    Args:
        image: frame to annotate; it is drawn on in place.
        body_bboxes: body boxes to reuse. When `None`, both body and face
            boxes are computed from `image` and any `face_bboxes` argument
            is ignored.
        face_bboxes: face boxes (full-frame coordinates) to reuse together
            with `body_bboxes`. `None` is treated as "no faces".

    Returns:
        `(image, body_bboxes, face_bboxes)` where `image` is the annotated
        input frame and the boxes are the ones that were drawn.
    """
    if body_bboxes is None:
        body_bboxes = predict_1(image)
        detected_faces = []
        for body_bbox in body_bboxes:
            body_img = crop(image, body_bbox)
            # A degenerate box yields an empty crop; skip it instead of
            # handing an empty array to the face model.
            if body_img.size == 0:
                continue
            face_bbox = predict_2(body_img)
            if face_bbox is not None:
                detected_faces.append(face_to_src(body_bbox, face_bbox))
        face_bboxes = tuple(detected_faces)
    elif face_bboxes is None:
        # Bodies were supplied without faces: draw bodies only.
        face_bboxes = ()

    for bbox in body_bboxes:
        draw_bbox(image, (254, 122, 51), bbox, "Body")
    for bbox in face_bboxes:
        draw_bbox(image, (66, 66, 255), bbox, "Face")

    return image, body_bboxes, face_bboxes

<h1>Main code</h1>

In [58]:
n = 3  # Run the detectors on one frame every n frames; reuse the boxes in between

# Only the first max_frames frames are processed; use
# int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) instead to process the whole video.
max_frames = 120

cap = cv2.VideoCapture(video_path)
out = None

body_bboxes, face_bboxes = None, None

try:
    # Fail loudly on a bad path instead of silently writing an empty video.
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")

    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = cap.get(cv2.CAP_PROP_FPS)

    out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (frame_width, frame_height))

    with tqdm(total=max_frames, desc="Processing frames") as pbar:
        for frame_idx in range(max_frames):
            ret, frame = cap.read()
            if not ret:
                # End of stream (or read error) before max_frames was reached.
                break

            if frame_idx % n == 0:
                # Fresh detection; keep the boxes for the following frames.
                annotated_frame, body_bboxes, face_bboxes = predict_frame(frame)
            else:
                # Skipped frame: redraw the most recent boxes without inference.
                annotated_frame, _, _ = predict_frame(frame, body_bboxes, face_bboxes)

            out.write(annotated_frame)
            pbar.update(1)
finally:
    # Release the capture and writer even if processing fails part-way, so
    # the output file is finalised. No GUI window is created in this
    # notebook, so there is nothing else to clean up.
    cap.release()
    if out is not None:
        out.release()

Processing frames: 100%|██████████| 120/120 [00:22<00:00,  5.31it/s]
