In [1]:
# === STEP 1: Load Models ===
from ultralytics import YOLO
import torch

# Head detector: custom-trained YOLOv8 weights read from a file path relative to the working directory.
head_detector = YOLO("custom_head_yolov8n.pt")

# Gaze-LLE model together with its matching image preprocessing transform, fetched through torch.hub.
from torch.hub import load

gaze_model, gaze_transform = load('fkryan/gazelle', 'gazelle_dinov2_vitb14')
# nn.Module.eval() switches to inference mode and returns the module itself.
gaze_model = gaze_model.eval()


# === STEP 2: Detect Heads ===
def detect_heads_from_frame(frame):
    """Run the YOLO head detector on a single frame.

    Args:
        frame: image handed straight to ``head_detector.predict`` (this notebook
            passes a PIL image).

    Returns:
        List of ``(x1, y1, x2, y2)`` box-corner tuples, one per detection whose
        class id is 0, in the order the detector reported them.
    """
    detections = head_detector.predict(source=frame, stream=False, verbose=False)[0].boxes
    # Only class 0 is kept; corners are read only for boxes that pass the filter.
    return [
        tuple(box.xyxy[0].tolist())
        for box in detections
        if int(box.cls[0]) == 0
    ]


# === STEP 3: Run Gaze-LLE ===
def run_gaze_lle_on_frame(pil_frame, head_bboxes):
    """Predict one gaze heatmap per head box for a single frame using Gaze-LLE.

    Args:
        pil_frame: PIL image of the full frame; ``.size`` is read as ``(width, height)``.
        head_bboxes: sequence of ``(x1, y1, x2, y2)`` boxes in pixel coordinates.

    Returns:
        ``[]`` when ``head_bboxes`` is empty. Otherwise ``output['heatmap'][0]``,
        the heatmaps for this single image (shape ``[num_heads, H, W]`` per the
        original author's note; TODO confirm against the Gaze-LLE version in use).
    """
    # Bail out before any preprocessing or inference: no heads, nothing to predict.
    if not head_bboxes:
        return []

    width, height = pil_frame.size
    # Boxes are passed to the model normalised to [0, 1] by the frame size.
    norm_bboxes = [
        (x1 / width, y1 / height, x2 / width, y2 / height)
        for (x1, y1, x2, y2) in head_bboxes
    ]

    # Place the input on the same device as the model so a GPU-resident model also works.
    device = next(gaze_model.parameters()).device
    input_tensor = gaze_transform(pil_frame).unsqueeze(0).to(device)

    # 'bboxes' holds one list of boxes per image; this batch contains a single image.
    input_data = {'images': input_tensor, 'bboxes': [norm_bboxes]}
    with torch.no_grad():
        output = gaze_model(input_data)

    return output['heatmap'][0]  # shape: [num_heads, H, W]


# === STEP 4: Draw on Frame ===
from PIL import ImageDraw
import numpy as np

def draw_gaze_on_frame(pil_frame, head_bboxes, heatmaps,
                       box_color="red", gaze_color="blue", line_width=2, dot_radius=3):
    """Draw head boxes and gaze rays onto ``pil_frame`` (in place).

    For each head, the gaze target is the argmax cell of its heatmap, mapped to
    frame pixels at the centre of that cell. A line is drawn from the head-box
    centre to the target, and a filled dot marks the target.

    Args:
        pil_frame: PIL image to draw on; it is modified in place and returned.
        head_bboxes: ``(x1, y1, x2, y2)`` pixel boxes, in the same order as ``heatmaps``.
        heatmaps: per-head 2-D heatmaps exposing ``.detach().cpu().numpy()``
            (e.g. the tensor returned by ``run_gaze_lle_on_frame``), or ``[]``.
        box_color: outline colour of head boxes (default reproduces the original red).
        gaze_color: colour of the gaze line and target dot (default blue).
        line_width: stroke width for boxes and gaze lines, in pixels.
        dot_radius: radius of the gaze-target dot, in pixels.

    Returns:
        The same ``pil_frame`` object, annotated.
    """
    draw = ImageDraw.Draw(pil_frame)
    width, height = pil_frame.size

    # zip stops at the shorter input, so an empty heatmap list draws nothing.
    for (x1, y1, x2, y2), heatmap in zip(head_bboxes, heatmaps):
        head_center = ((x1 + x2) / 2, (y1 + y2) / 2)

        heatmap_np = heatmap.detach().cpu().numpy()
        hm_h, hm_w = heatmap_np.shape
        row, col = np.unravel_index(np.argmax(heatmap_np), heatmap_np.shape)
        # +0.5 maps the winning cell's centre (not its top-left corner) to pixels,
        # removing a systematic half-cell up/left bias in the drawn target.
        gaze_x = (col + 0.5) * width / hm_w
        gaze_y = (row + 0.5) * height / hm_h

        draw.rectangle([(x1, y1), (x2, y2)], outline=box_color, width=line_width)
        draw.line([head_center, (gaze_x, gaze_y)], fill=gaze_color, width=line_width)
        draw.ellipse(
            (gaze_x - dot_radius, gaze_y - dot_radius, gaze_x + dot_radius, gaze_y + dot_radius),
            fill=gaze_color,
        )

    return pil_frame


# === STEP 5: Process Video ===
import cv2
from PIL import Image

def _annotate_frame(frame_bgr):
    """Run head detection and Gaze-LLE on one BGR frame; return the annotated BGR frame.

    Frames with no detected heads are returned unchanged, so the output stays
    frame-aligned with the input video.
    """
    pil_frame = Image.fromarray(cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB))
    head_bboxes = detect_heads_from_frame(pil_frame)
    if not head_bboxes:
        return frame_bgr
    heatmaps = run_gaze_lle_on_frame(pil_frame, head_bboxes)
    result_frame = draw_gaze_on_frame(pil_frame, head_bboxes, heatmaps)
    return cv2.cvtColor(np.array(result_frame), cv2.COLOR_RGB2BGR)


def process_video(video_path, save_path=None, fps=None):
    """Annotate every frame of a video with head boxes and gaze rays.

    Args:
        video_path: path (or any source accepted by ``cv2.VideoCapture``) of the input video.
        save_path: if given, annotated frames are written to this file with the
            'mp4v' codec; otherwise they are shown in a window (press 'q' to stop).
        fps: output frame rate. Defaults to the source video's FPS, falling back
            to 20.0 when the container does not report a usable value.

    Raises:
        IOError: if the input cannot be opened, or the output writer cannot be created.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")

    if fps is None:
        source_fps = cap.get(cv2.CAP_PROP_FPS)
        # Some containers report 0 (or NaN); fall back to the previous fixed rate.
        fps = source_fps if source_fps and source_fps > 0 else 20.0

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = None
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Frames without heads are still emitted so the output is not shortened/desynced.
            result_bgr = _annotate_frame(frame)

            if save_path:
                if out is None:
                    # Writer is sized lazily from the first frame actually produced.
                    h, w = result_bgr.shape[:2]
                    out = cv2.VideoWriter(save_path, fourcc, fps, (w, h))
                    if not out.isOpened():
                        raise IOError(f"Could not open video writer: {save_path}")
                out.write(result_bgr)
            else:
                cv2.imshow("Gaze Estimation", result_bgr)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
    finally:
        # Always release resources so the output file is finalised even on errors.
        cap.release()
        if out is not None:
            out.release()
        # Only touch the GUI when windows were used; headless OpenCV builds raise here.
        if not save_path:
            cv2.destroyAllWindows()



Using cache found in C:\Users\User/.cache\torch\hub\fkryan_gazelle_main
  from .autonotebook import tqdm as notebook_tqdm
Using cache found in C:\Users\User/.cache\torch\hub\facebookresearch_dinov2_main


In [3]:
# Annotate input_gaze.mp4 with head boxes and gaze rays, writing the result to output_gaze.mp4.
process_video(video_path="input_gaze.mp4", save_path="output_gaze.mp4")

