In [None]:
import os
import sys
import cv2
import json
import time
import math
import numpy as np
from pathlib import Path
from dataclasses import dataclass
from ultralytics import YOLO

# Walk up from the working directory to the first ancestor containing `utils/`
# (the project root) so that `utils.*` can be imported below.
p = Path.cwd()
while p != p.parent and not (p / "utils").exists():
    p = p.parent

if (p / "utils").exists():
    # Move (rather than duplicate) the root to the front of sys.path so that
    # re-running this cell does not keep growing sys.path.
    if str(p) in sys.path:
        sys.path.remove(str(p))
    sys.path.insert(0, str(p))
else:
    # Reached the filesystem root without finding the project: leave sys.path
    # untouched instead of prepending the root directory.
    print(f"[WARN] no 'utils' directory found above {Path.cwd()}; sys.path unchanged")

from utils.file_dialog_utils import pick_video_cv2

In [None]:
# --- USER CONFIGURATION ---

# Inputs
# pick_video_cv2 returns a pair; only the second element (the selected source) is used.
_, VIDEO_SOURCE = pick_video_cv2(title="Select Roulette Video")

# Machine-specific default paths. Set the environment variables to run on another
# machine without editing the notebook.
GEOMETRY_JSON = os.environ.get(
    "ROULETTE_GEOMETRY_JSON",
    r"C:\Users\Gabriel\Documents\Dissertation\Code\notebooks\pipeline_attempts\roulette_cv4\wheel_geometry.json",
)  # produced by setup.ipynb
YOLO_WEIGHTS = os.environ.get(
    "ROULETTE_YOLO_WEIGHTS",
    r"C:\Users\Gabriel\Documents\Dissertation\Code\models\yolo\RD2_Model.pt",
)  # trained YOLO weights (Ultralytics)

# Detection settings
BALL_CLASS_ID = 0       # YOLO class index for the ball
CONF_THRES    = 0.40    # confidence threshold passed to model.predict
IOU_THRES     = 0.70    # IoU threshold passed to model.predict

# Landing logic
LANDING_SECONDS = 1.5   # how long (continuous) ball must remain in halo
RESET_ON_LOST   = True  # if detection is lost, reset timer
REARM_SECONDS   = 1.0   # must be out-of-halo (or missing) this long before next landing can be detected

# Screenshot burst after landing
LANDING_SHOT_COUNT        = 5    # total screenshots per landing (including first frame)
LANDING_SHOT_DURATION_SEC = 1.5  # spread screenshots across this duration

# Outputs
OUTPUT_DIR           = "inference_output"
SAVE_ANNOTATED_VIDEO = True
ANNOTATED_VIDEO_PATH = os.path.join(OUTPUT_DIR, "annotated.mp4")
RESULTS_JSONL_PATH   = os.path.join(OUTPUT_DIR, "results_events.jsonl")

# Display (set False for headless runs)
SHOW_WINDOW  = True
WINDOW_NAME  = "Halo + YOLO Inference"

# Fail fast on settings that would make the (long) inference pass meaningless.
if not (0.0 <= CONF_THRES <= 1.0) or not (0.0 <= IOU_THRES <= 1.0):
    raise ValueError(f"CONF_THRES and IOU_THRES must be in [0, 1] (got {CONF_THRES}, {IOU_THRES})")
if LANDING_SECONDS <= 0:
    # With 0 the landing condition would be met on the very first armed frame.
    raise ValueError(f"LANDING_SECONDS must be > 0 (got {LANDING_SECONDS})")
if REARM_SECONDS < 0 or LANDING_SHOT_DURATION_SEC < 0:
    raise ValueError(
        f"REARM_SECONDS and LANDING_SHOT_DURATION_SEC must be >= 0 "
        f"(got {REARM_SECONDS}, {LANDING_SHOT_DURATION_SEC})"
    )

In [None]:
# --- OUTPUT DIRECTORY SETUP ---

# Create the output folder (and any missing parents) up front so later cells can write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
print("OUTPUT_DIR:", Path(OUTPUT_DIR).resolve())

In [None]:
# --- HELPER CLASSES AND FUNCTIONS ---

@dataclass
class HaloParams:
    """Halo geometry: the ring between two concentric, equally rotated ellipses.

    Positions and semi-axes are in frame pixel coordinates; ``angle_deg`` is the
    rotation shared by both ellipses, in degrees.
    """

    # Shared centre of both ellipses.
    cx: float
    cy: float
    # Semi-axes of the outer boundary.
    outer_rx: float
    outer_ry: float
    # Semi-axes of the inner boundary (points inside it are not in the halo).
    inner_rx: float
    inner_ry: float
    # Rotation of both ellipses, in degrees.
    angle_deg: float

def load_halo_from_geometry_json(path: str) -> HaloParams:
    """Build the halo from the wheel-geometry JSON (produced by setup.ipynb).

    Expected keys:
        ellipse.cx, ellipse.cy   -- centre of the wheel ellipse
        ellipse.rx, ellipse.ry   -- its semi-axes
        ellipse.rotation_deg     -- its rotation in degrees
        halo.outer_inward_pct    -- fraction the ellipse is shrunk by for the outer boundary
        halo.inner_inward_pct    -- fraction the ellipse is shrunk by for the inner boundary

    Each boundary's semi-axes are ``r * (1 - inward_pct)``, so the inner fraction
    must be larger than the outer one for the halo to be a non-empty ring.

    Raises:
        KeyError:   a required key is missing from the JSON.
        ValueError: the outer boundary has non-positive radii, or the inner
                    boundary is not strictly smaller than the outer one (the
                    halo would be empty and no landing could ever be detected).
    """
    with open(path, "r", encoding="utf-8") as f:
        geom = json.load(f)

    ellipse = geom["ellipse"]
    halo_cfg = geom["halo"]

    cx = float(ellipse["cx"])
    cy = float(ellipse["cy"])
    rx = float(ellipse["rx"])
    ry = float(ellipse["ry"])
    angle_deg = float(ellipse["rotation_deg"])

    outer_inward = float(halo_cfg["outer_inward_pct"])
    inner_inward = float(halo_cfg["inner_inward_pct"])

    outer_rx = rx * (1.0 - outer_inward)
    outer_ry = ry * (1.0 - outer_inward)

    inner_rx = rx * (1.0 - inner_inward)
    inner_ry = ry * (1.0 - inner_inward)

    # Misconfigured geometry otherwise yields a halo no point can be inside,
    # which makes the inference loop silently never report a landing.
    if outer_rx <= 0 or outer_ry <= 0:
        raise ValueError(
            f"{path}: outer halo radii must be positive (got {outer_rx:.3f}, {outer_ry:.3f}); "
            f"check ellipse.rx/ry and halo.outer_inward_pct"
        )
    if inner_rx >= outer_rx or inner_ry >= outer_ry:
        raise ValueError(
            f"{path}: inner halo boundary must lie inside the outer one "
            f"(inner_inward_pct={inner_inward} must exceed outer_inward_pct={outer_inward})"
        )

    return HaloParams(
        cx=cx, cy=cy,
        outer_rx=outer_rx, outer_ry=outer_ry,
        inner_rx=inner_rx, inner_ry=inner_ry,
        angle_deg=angle_deg
    )

def point_in_rotated_ellipse(px: float, py: float, cx: float, cy: float, rx: float, ry: float, angle_deg: float) -> bool:
    """Test whether (px, py) is inside, or on the boundary of, a rotated ellipse.

    The ellipse is centred at (cx, cy), has semi-axes rx and ry, and is rotated by
    ``angle_deg`` degrees. An ellipse with a non-positive semi-axis contains no points.
    """
    # Guard clause: degenerate ellipses enclose nothing.
    if rx <= 0 or ry <= 0:
        return False

    theta = math.radians(angle_deg)
    c, s = math.cos(theta), math.sin(theta)

    # Translate to the ellipse centre, then undo its rotation so the axes line up.
    ox, oy = px - cx, py - cy
    u = c * ox + s * oy
    v = c * oy - s * ox

    return u * u / (rx * rx) + v * v / (ry * ry) <= 1.0

def point_in_halo(px: float, py: float, halo: HaloParams) -> bool:
    """Return True if (px, py) lies in the ring between the halo's outer and inner ellipses.

    Points on or inside the inner ellipse are excluded. Both ellipses share the
    halo's centre and rotation.
    """
    outer_hit = point_in_rotated_ellipse(
        px, py, halo.cx, halo.cy, halo.outer_rx, halo.outer_ry, halo.angle_deg
    )
    if not outer_hit:
        return False

    inner_hit = point_in_rotated_ellipse(
        px, py, halo.cx, halo.cy, halo.inner_rx, halo.inner_ry, halo.angle_deg
    )
    return not inner_hit

def draw_halo(frame_bgr: np.ndarray, halo: HaloParams) -> np.ndarray:
    """Draw the halo boundaries on ``frame_bgr`` in place and return the same array.

    The outer boundary is drawn in cyan and the inner one in yellow (BGR colours),
    2 px thick. A boundary with a non-positive semi-axis is skipped. Such a
    boundary occurs when an inward fraction exceeds 1, and it is treated as empty
    by the halo membership test. Without the skip, cv2.ellipse would reject the
    negative axis and abort the inference loop.
    """
    center = (round(halo.cx), round(halo.cy))
    boundaries = (
        (halo.outer_rx, halo.outer_ry, (255, 255, 0)),  # outer: cyan
        (halo.inner_rx, halo.inner_ry, (0, 255, 255)),  # inner: yellow
    )
    for rx, ry, color in boundaries:
        if rx <= 0 or ry <= 0:
            continue
        # cv2.ellipse needs integer pixel coordinates; round rather than truncate.
        cv2.ellipse(frame_bgr, center, (round(rx), round(ry)), halo.angle_deg, 0, 360, color, 2)
    return frame_bgr

# Halo used by the inference loop for landing detection and drawing.
halo = load_halo_from_geometry_json(path=GEOMETRY_JSON)

In [None]:
# --- SCREENSHOT UTILITY ---

def save_landing_screenshot(
    frame_bgr,
    spin_id=None,
    frame_idx=None,
    time_s=None,
    shot_idx=None
):
    """
    Saves a clean screenshot of a landing frame into OUTPUT_DIR/screenshots.

    The filename is built from whichever optional identifiers are given,
    followed by a UTC timestamp, e.g.
    ``spin_3_shot_1_f450_15.000s_20240101_120000.png``.

    Parameters:
        frame_bgr  : raw OpenCV BGR frame (before drawing overlays)
        spin_id    : optional spin identifier
        frame_idx  : frame index at landing
        time_s     : timestamp in seconds
        shot_idx   : optional screenshot index within the landing burst

    Returns:
        filename : name (not full path) of the saved screenshot file, or None
                   if OpenCV reported that the image could not be written.
    """

    # Build filename (UTC wall-clock time, second resolution)
    timestamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())

    spin_part = f"spin_{spin_id}_" if spin_id is not None else ""
    shot_part = f"shot_{shot_idx}_" if shot_idx is not None else ""
    frame_part = f"f{frame_idx}_" if frame_idx is not None else ""
    time_part = f"{time_s:.3f}s_" if time_s is not None else ""

    filename = f"{spin_part}{shot_part}{frame_part}{time_part}{timestamp}.png"

    screenshots_dir = os.path.join(OUTPUT_DIR, "screenshots")
    os.makedirs(screenshots_dir, exist_ok=True)
    save_path = os.path.join(screenshots_dir, filename)

    # Save raw frame (NO overlays). cv2.imwrite can report failure by returning
    # False instead of raising; don't log a filename that was never written.
    if not cv2.imwrite(save_path, frame_bgr):
        print(f"[WARN] Failed to write screenshot: {save_path}")
        return None

    return filename

In [None]:
# --- MAIN INFERENCE LOOP ---

# Load model
model = YOLO(YOLO_WEIGHTS)

# Open source: file path OR camera index OR RTSP url
cap = cv2.VideoCapture(VIDEO_SOURCE)
if not cap.isOpened():
    raise RuntimeError(f"Failed to open VIDEO_SOURCE: {VIDEO_SOURCE}")

# Some sources report 0 (or a non-finite value) for FPS. Fall back to 30 so the
# frame-based landing/re-arm timers still advance sensibly.
fps = cap.get(cv2.CAP_PROP_FPS)
if not fps or not math.isfinite(fps) or fps <= 1e-3:
    fps = 30.0
frame_dt = 1.0 / fps

w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

writer = None
if SAVE_ANNOTATED_VIDEO:
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(ANNOTATED_VIDEO_PATH, fourcc, fps, (w, h))
    if not writer.isOpened():
        # VideoWriter does not raise when it cannot open (e.g. unavailable codec,
        # or a 0x0 frame size reported by the source), and then no video is written.
        print(
            f"[WARN] Could not open video writer for {ANNOTATED_VIDEO_PATH} "
            f"({w}x{h} @ {fps:.2f} fps); annotated video disabled."
        )
        writer.release()
        writer = None

if SHOW_WINDOW:
    cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    cv2.resizeWindow(WINDOW_NAME, 1280, 720)

# Landing state machine: `armed` gates detection; `landed` drives the on-screen banner.
armed = True                        # can we trigger landing?
landed = False
time_in_halo, time_outside = 0.0, 0.0   # accumulated seconds inside / outside the halo
spin_counter = 0                    # landings detected so far (used as spin_id)
pending_screenshots = []            # burst screenshots scheduled but not yet captured

# For saving events
def append_event(event: dict, path: str = RESULTS_JSONL_PATH):
    """Append ``event`` as a single JSON line to the JSONL log at ``path``.

    The default path is RESULTS_JSONL_PATH, captured when this function is defined.
    """
    with open(path, mode="a", encoding="utf-8") as log_file:
        print(json.dumps(event), file=log_file)

last_ball = None  # latest best ball detection ({"conf", "center"}), reused in event logs
frame_idx = -1    # incremented before use in the loop, so the first frame read is index 0

# Tolerance for timer comparisons: summing frame_dt (e.g. 1/30) many times can
# land a hair below an exact threshold and delay the trigger by one frame.
_TIMER_EPS = 1e-9

try:
    while True:
        ok, frame = cap.read()
        if not ok:
            break  # end of file, or the source stopped delivering frames

        frame_clean = frame.copy()  # for saving screenshots without overlays

        frame_idx += 1
        t_s = frame_idx / fps  # derived from the frame index, not wall-clock time

        # Run YOLO (device=0 selects the first CUDA GPU in Ultralytics; use "cpu" without CUDA)
        results = model.predict(
            source=frame,
            conf=CONF_THRES,
            iou=IOU_THRES,
            device=0,
            verbose=False
        )
        r = results[0]
        display = r.plot()  # draws boxes/labels

        # pick best ball detection (highest conf)
        best = None
        if r.boxes is not None and len(r.boxes) > 0:
            for box in r.boxes:
                cls_id = int(box.cls[0].item())
                if cls_id != BALL_CLASS_ID:
                    continue
                conf = float(box.conf[0].item())
                if best is None or conf > best["conf"]:
                    x1, y1, x2, y2 = box.xyxy[0].tolist()
                    best = {"conf": conf, "xyxy": [x1, y1, x2, y2]}

        ball_detected = best is not None
        in_halo = False  # a missing ball counts as outside the halo
        if ball_detected:
            x1, y1, x2, y2 = best["xyxy"]
            bx = (x1 + x2) / 2.0
            by = (y1 + y2) / 2.0
            in_halo = point_in_halo(bx, by, halo)

            # Draw ball center
            cv2.circle(display, (int(bx), int(by)), 4, (0, 0, 255), -1)
            last_ball = {"conf": best["conf"], "center": [bx, by]}

        if armed:
            if in_halo:
                time_in_halo += frame_dt
                time_outside = 0.0
            elif ball_detected or RESET_ON_LOST:
                # Ball seen outside the halo, or lost while RESET_ON_LOST is True:
                # the continuous dwell must start over.
                time_in_halo = 0.0
                time_outside += frame_dt
            # else: ball lost and RESET_ON_LOST is False, so pause the dwell timer
            # (neither grow nor reset it) until the ball is detected again.

            if time_in_halo + _TIMER_EPS >= LANDING_SECONDS:
                landed = True
                armed = False
                time_outside = 0.0

                landing_frame_idx = frame_idx
                landing_time_s = t_s

                spin_counter += 1

                # Schedule the screenshot burst, spread evenly over the configured duration.
                shot_count = max(1, int(LANDING_SHOT_COUNT))
                duration_frames = max(0, int(round(LANDING_SHOT_DURATION_SEC * fps)))
                if shot_count == 1 or duration_frames == 0:
                    shot_offsets = [0]
                else:
                    shot_offsets = np.round(
                        np.linspace(0, duration_frames, num=shot_count)
                    ).astype(int).tolist()

                for shot_idx, offset in enumerate(shot_offsets, start=1):
                    pending_screenshots.append({
                        "spin_id": spin_counter,
                        "shot_idx": shot_idx,
                        "target_frame_idx": landing_frame_idx + int(offset),
                    })

                event = {
                    "event": "landing",
                    "spin_id": spin_counter,
                    "frame_idx": landing_frame_idx,
                    "time_s": landing_time_s,
                    "ball": last_ball,
                    "fps": fps,
                    "screenshot_count": shot_count,
                    "screenshot_duration_sec": LANDING_SHOT_DURATION_SEC,
                    "geometry_json": GEOMETRY_JSON,
                    "video_source": VIDEO_SOURCE,
                    "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                }
                append_event(event)
                print(f"[LANDED] frame={landing_frame_idx} t={landing_time_s:.3f}s (saved to {RESULTS_JSONL_PATH})")
        else:
            # not armed: wait until ball is outside (or missing) for long enough.
            # A missing ball always counts toward re-arming, regardless of RESET_ON_LOST.
            if not in_halo:
                time_outside += frame_dt
                if time_outside + _TIMER_EPS >= REARM_SECONDS:
                    armed = True
                    landed = False
                    time_in_halo = 0.0
                    time_outside = 0.0
            else:
                time_outside = 0.0

        if pending_screenshots:
            remaining = []
            for shot in pending_screenshots:
                if frame_idx >= shot["target_frame_idx"]:
                    screenshot_name = save_landing_screenshot(
                        frame_clean,
                        spin_id=shot["spin_id"],
                        frame_idx=frame_idx,
                        time_s=t_s,
                        shot_idx=shot["shot_idx"]
                    )

                    shot_event = {
                        "event": "screenshot",
                        "spin_id": shot["spin_id"],
                        "shot_idx": shot["shot_idx"],
                        "frame_idx": frame_idx,
                        "time_s": t_s,
                        # may come from an earlier frame if the ball is not detected in this one
                        "ball": last_ball,
                        "screenshot": screenshot_name,
                        "fps": fps,
                        "geometry_json": GEOMETRY_JSON,
                        "video_source": VIDEO_SOURCE,
                        "timestamp_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
                    }
                    append_event(shot_event)
                else:
                    remaining.append(shot)
            pending_screenshots = remaining

        # Draw halo boundaries + status
        draw_halo(display, halo)
        status = f"in_halo_time={time_in_halo:.2f}s / {LANDING_SECONDS:.2f}s | landed={landed}"
        cv2.putText(display, status, (20, 40), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2)

        if landed:
            cv2.putText(display, "BALL HAS LANDED", (20, 90), cv2.FONT_HERSHEY_SIMPLEX, 1.2, (0, 0, 255), 3)

        # Output
        if writer is not None:
            writer.write(display)

        if SHOW_WINDOW:
            cv2.imshow(WINDOW_NAME, display)
            key = cv2.waitKey(1) & 0xFF
            if key in (ord("q"), 27):  # q or ESC
                break

finally:
    cap.release()
    if writer is not None:
        writer.release()
    if SHOW_WINDOW:
        cv2.destroyAllWindows()
        cv2.waitKey(1)  # let the GUI event loop actually close the window

if pending_screenshots:
    print(
        f"[WARN] {len(pending_screenshots)} scheduled screenshot(s) were not captured "
        f"because the stream ended or was stopped first."
    )

print("\nAnnotated video:", ANNOTATED_VIDEO_PATH if writer is not None else "(disabled)")
print("Events:", RESULTS_JSONL_PATH)


## Live stream later (camera / RTSP)

When you're ready to switch `VIDEO_SOURCE` from a file to a live stream, replace the `pick_video_cv2(...)` assignment in the configuration cell with one of:

- USB webcam: `VIDEO_SOURCE = 0`
- RTSP: `VIDEO_SOURCE = "rtsp://user:pass@host:554/stream1"`

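The inference loop passes `VIDEO_SOURCE` straight to `cv2.VideoCapture(VIDEO_SOURCE)`, so no other code changes should be needed. Note that event times (`time_s`) are computed from the frame index and the FPS reported by the source (30 FPS is assumed if none is reported), so for live streams they approximate, rather than measure, elapsed time.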
