In [1]:
#!/usr/bin/env python3
"""
Benchmark inference + analysis time on saved frames (no UI, no clicks, no overlays).

- Loads images from ./frames (fallback: ./alpha/frames).
- Runs YOLO segmentation and the same post-processing you use at runtime.
- Prints per-frame timings and a summary (mean/min/max and p50/p90/p99, plus FPS).

Usage:
  python benchmark_frames.py
"""

import os, time, math, glob, statistics
from pathlib import Path

import cv2
import numpy as np
import torch
from ultralytics import YOLO

# =======================
# Config
# =======================
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"

RAIL_ID    = 9
IMG_SIZE   = 512
CONF, IOU  = 0.30, 0.45
MAX_DET    = 30

# Color/region filter for "green rails"
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat/triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
MIN_DARK_FRACTION   = 0.15
TRI_SIZE_PX         = 18

# Sampling ray length
SAMPLE_UP_PX        = 200
RAY_STEP_PX         = 20

# Lane anchor points (only used for angles and band math)
LANE_LEFT   = (300, 1340)
LANE_MID    = (490, 1340)
LANE_RIGHT  = (680, 1340)
LANE_POINTS = (LANE_LEFT, LANE_MID, LANE_RIGHT)
LANE_TARGET_DEG = {"left": -10.7, "mid": +1.5, "right": +15.0}

# Class buckets
DANGER_RED   = {1, 6, 7, 11}
WARN_YELLOW  = {2, 3, 4, 5, 8}
BOOTS_PINK   = {0}

LABELS = {
    0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
    5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
    10:"SIDEWALK",11:"YELLOWTRAIN"
}

# --- tunnel wall color gate (HSV) ---
LOWBARRIER1_ID   = 4
ORANGETRAIN_ID   = 6
WALL_STRIP_PX    = 20
WALL_MATCH_FRAC  = 0.40
# HSV thresholds (OpenCV H: 0–179). Broad orange range; tune as needed.
WALL_ORANGE_LO = np.array([5,  80,  60], dtype=np.uint8)
WALL_ORANGE_HI = np.array([35, 255, 255], dtype=np.uint8)

# ====== tiny helpers ======
def _clampi(v, lo, hi):
    return lo if v < lo else (hi if v > hi else v)

def lane_name_from_point(p):
    if p == LANE_LEFT:  return "left"
    if p == LANE_MID:   return "mid"
    if p == LANE_RIGHT: return "right"
    return "mid"

# =======================
# Fast rails green finder
# =======================
TARGETS_BGR_F32 = np.array([(r,g,b)[::-1] for (r,g,b) in TARGET_COLORS_RGB], dtype=np.float32)
TOL2            = TOLERANCE * TOLERANCE

def highlight_rails_mask_only_fast(img_bgr, rail_mask):
    H, W = rail_mask.shape
    if not rail_mask.any():
        return np.zeros((H, W), dtype=bool)

    rail_u8 = rail_mask.view(dtype=np.uint8) * 255
    x, y, w, h = cv2.boundingRect(rail_u8)
    img_roi  = img_bgr[y:y+h, x:x+w]
    mask_roi = rail_u8[y:y+h, x:x+w]

    img_f = img_roi.astype(np.float32, copy=False)
    diff  = img_f[:, :, None, :] - TARGETS_BGR_F32[None, None, :, :]
    dist2 = (diff * diff).sum(-1)
    colour_hit = (dist2 <= TOL2).any(-1)

    combined = np.logical_and(colour_hit, mask_roi.astype(bool))

    comp = combined.astype(np.uint8)
    n, lbls, stats, _ = cv2.connectedComponentsWithStats(comp, 8)
    if n <= 1: return np.zeros((H, W), dtype=bool)

    good = np.zeros_like(combined)
    areas = stats[1:, cv2.CC_STAT_AREA]
    hs    = stats[1:, cv2.CC_STAT_HEIGHT]
    keep  = np.where((areas >= MIN_REGION_SIZE) & (hs >= MIN_REGION_HEIGHT))[0] + 1
    for k in keep: good[lbls == k] = True

    full = np.zeros((H, W), dtype=bool)
    full[y:y+h, x:x+w] = good
    return full

def red_vs_green_score(red_mask, green_mask):
    k = (HEAT_BLUR_KSIZE, HEAT_BLUR_KSIZE)
    r = cv2.blur(red_mask.astype(np.float32, copy=False), k)
    g = cv2.blur(green_mask.astype(np.float32, copy=False), k)
    diff = r - g
    amax = float(np.max(np.abs(diff))) + 1e-6
    norm = (diff / (2.0 * amax) + 0.5)
    return np.clip(norm * 255.0, 0, 255.0).astype(np.uint8, copy=False)

def purple_triangles(score, H):
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    dark = (score >= RED_SCORE_THRESH).astype(np.uint8, copy=False)
    if top_ex: dark[:top_ex, :] = 0
    if bot_ex: dark[-bot_ex:, :] = 0

    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)), iterations=1
    )
    total_dark = int(dark.sum())
    if total_dark == 0: return [], None

    frac_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark))
    n_lbl, lbls, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
    if n_lbl <= 1: return [], None

    tris = []
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area >= MIN_DARK_RED_AREA and area >= frac_thresh:
            ys, xs = np.where(lbls == lbl)
            if ys.size == 0: continue
            y_top = ys.min()
            x_mid = int(xs[ys == y_top].mean())
            tris.append((x_mid, int(y_top)))

    if not tris: return [], None
    best = min(tris, key=lambda xy: xy[1])
    return tris, best

# ===== Bearing-based Jake triangle selection =====
def signed_degrees_from_vertical(dx, dy):
    if dx == 0 and dy == 0: return 0.0
    return -math.degrees(math.atan2(dx, -dy))

def select_triangle_by_bearing(tri_positions, jx, jy, target_deg, min_dy=6):
    best_i, best_deg, best_err = -1, None, None
    for i, (xt, yt) in enumerate(tri_positions):
        dy = yt - jy
        if dy >= -min_dy:
            continue
        deg = signed_degrees_from_vertical(xt - jx, dy)
        err = abs(deg - target_deg)
        if (best_err is None) or (err < best_err):
            best_i, best_deg, best_err = i, deg, err
    return best_i, best_deg, best_err

# ===== Lane-aware curved sampling (precompute sin/cos) =====
BEND_LEFT_STATE_RIGHT_DEG  = -20.0
BEND_MID_STATE_RIGHT_DEG   = -20.0
BEND_MID_STATE_LEFT_DEG    = +20.0
BEND_RIGHT_STATE_LEFT_DEG  = +20.0

def _precompute_trig():
    angles = sorted(set([0.0,
        BEND_LEFT_STATE_RIGHT_DEG,
        BEND_MID_STATE_RIGHT_DEG,
        BEND_MID_STATE_LEFT_DEG,
        BEND_RIGHT_STATE_LEFT_DEG
    ]))
    table = {}
    for a in angles:
        r = math.radians(a)
        table[a] = (math.sin(r), -math.cos(r))  # (dx, dy) for unit ray (up = -y)
    return table

TRIG_TABLE = _precompute_trig()

def pick_bend_angle(jake_point, xt, x_ref, idx, best_idx):
    if idx == best_idx:
        return 0.0
    if jake_point == LANE_LEFT:
        return BEND_LEFT_STATE_RIGHT_DEG if xt > x_ref else 0.0
    if jake_point == LANE_RIGHT:
        return BEND_RIGHT_STATE_LEFT_DEG if xt < x_ref else 0.0
    if xt > x_ref: return BEND_MID_STATE_RIGHT_DEG
    if xt < x_ref: return BEND_MID_STATE_LEFT_DEG
    return 0.0

def classify_triangles_at_sample_curved(
    tri_positions, masks_np, classes_np, H, W,
    jake_point, x_ref, best_idx, sample_px=SAMPLE_UP_PX, step_px=RAY_STEP_PX
):
    if masks_np is None or classes_np is None or len(tri_positions) == 0:
        return [], [], [], []

    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    red_idx    = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    yellow_idx = [i for i, c in enumerate(classes_np) if int(c) in WARN_YELLOW]
    boots_idx  = [i for i, c in enumerate(classes_np) if int(c) in BOOTS_PINK]

    colours, rays, hit_class_ids, hit_distances_px = [], [], [], []
    max_k = max(1, sample_px // max(1, step_px))

    for idx, (x0, y0) in enumerate(tri_positions):
        theta = pick_bend_angle(jake_point, x0, x_ref, idx, best_idx)
        dx1, dy1 = TRIG_TABLE[theta]

        hit_cls = None
        hit_dist_px = None

        found = False
        for k in range(1, max_k + 1):
            t  = k * step_px
            xs = _clampi(int(round(x0 + dx1 * t)), 0, W-1)
            ys = _clampi(int(round(y0 + dy1 * t)), 0, H-1)
            mx = _clampi(int(round(xs * sx)), 0, mw-1)
            my = _clampi(int(round(ys * sy)), 0, mh-1)

            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_cls = int(classes_np[i]); hit_dist_px = float(t); found = True; break
            if found: break
            for i in yellow_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_cls = int(classes_np[i]); hit_dist_px = float(t); found = True; break
            if found: break
            for i in boots_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_cls = int(classes_np[i]); hit_dist_px = float(t); found = True; break
            if found: break

        colours.append(None)  # unused in timing run
        rays.append(((int(x0), int(y0)), (0, 0), float(theta)))  # minimal structure
        hit_class_ids.append(hit_cls)
        hit_distances_px.append(hit_dist_px)

    return colours, rays, hit_class_ids, hit_distances_px

# =======================
# Promotion logic (LOWBARRIER1 -> ORANGETRAIN when orange wall behind)
# =======================
def promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np,
                                 strip_px=WALL_STRIP_PX, frac_thresh=WALL_MATCH_FRAC):
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return classes_np

    H, W = frame_bgr.shape[:2]
    hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
    wall_u8 = cv2.inRange(hsv, WALL_ORANGE_LO, WALL_ORANGE_HI)

    for i, cls in enumerate(classes_np):
        if int(cls) != LOWBARRIER1_ID:
            continue

        m = masks_np[i]
        if m.shape != (H, W):
            m_full = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
        else:
            m_full = m.astype(bool, copy=False)

        ys, xs = np.where(m_full)
        if xs.size == 0:
            continue

        x0, x1 = xs.min(), xs.max()
        y0, _  = ys.min(), ys.max()

        yb0 = max(0, y0 - strip_px)
        yb1 = y0
        if yb1 <= yb0:
            continue

        strip = wall_u8[yb0:yb1, x0:x1+1]
        if strip.size == 0:
            continue

        frac = float(cv2.countNonZero(strip)) / strip.size
        if frac >= frac_thresh:
            classes_np[i] = ORANGETRAIN_ID

    return classes_np

# =======================
# Hit probing helpers used in analysis
# =======================
def first_red_hit_y(pos, masks_np, classes_np, H, W, band_px=6, step_px=5, max_up=SAMPLE_UP_PX):
    if masks_np is None or masks_np.size == 0: return None
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    red_idx = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    if not red_idx: return None

    x0, y0 = int(pos[0]), int(pos[1])
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)

    for t in range(step_px, max_up + 1, step_px):
        y = _clampi(y0 - t, 0, H-1)
        for dx in range(-band_px, band_px + 1):
            x = _clampi(x0 + dx, 0, W-1)
            mx = _clampi(int(round(x * sx)), 0, mw-1)
            my = _clampi(int(round(y * sy)), 0, mh-1)
            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    return y
    return None

# =======================
# Frame post-processing (timing-oriented, no printing)
# =======================
def process_frame_post(frame_bgr, yolo_res, jake_point):
    """
    Returns:
      tri_best_xy, tri_count, mask_count, to_cpu_ms, post_ms,
      masks_np, classes_np, rail_mask, green_mask,
      tri_positions, tri_colours, tri_rays,
      best_idx, best_deg, x_ref,
      tri_hit_classes, tri_summary
    """
    H, W = frame_bgr.shape[:2]
    if yolo_res.masks is None:
        return (None, 0, 0, 0.0, 0.0, None, None, None, None,
                [], [], [], None, None, None, [], [])

    t0 = time.perf_counter()
    masks_np = yolo_res.masks.data.detach().cpu().numpy()
    if hasattr(yolo_res.masks, "cls") and yolo_res.masks.cls is not None:
        classes_np = yolo_res.masks.cls.detach().cpu().numpy().astype(int)
    else:
        classes_np = yolo_res.boxes.cls.detach().cpu().numpy().astype(int)
    to_cpu_ms = (time.perf_counter() - t0) * 1000.0

    mask_count = int(masks_np.shape[0])
    if mask_count == 0 or classes_np.size == 0:
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    classes_np = promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np)

    rail_sel = (classes_np == RAIL_ID)
    if not np.any(rail_sel):
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    t1 = time.perf_counter()

    rail_masks = masks_np[rail_sel].astype(bool, copy=False)
    union = np.any(rail_masks, axis=0).astype(np.uint8, copy=False)
    rail_mask = cv2.resize(union, (W, H), interpolation=cv2.INTER_NEAREST).astype(bool, copy=False)

    green = highlight_rails_mask_only_fast(frame_bgr, rail_mask)
    red   = np.logical_and(rail_mask, np.logical_not(green))
    score = red_vs_green_score(red, green)
    tri_positions, tri_best = purple_triangles(score, H)

    # Choose Jake triangle (bearing)
    lane_name = lane_name_from_point(jake_point)
    target_deg = LANE_TARGET_DEG[lane_name]
    xj, yj = jake_point
    best_idx, best_deg, _ = select_triangle_by_bearing(tri_positions, xj, yj, target_deg, min_dy=6)

    # x_ref for bending
    x_ref = tri_positions[best_idx][0] if (lane_name == "mid" and best_idx is not None and 0 <= best_idx < len(tri_positions)) else xj

    tri_colours, tri_rays, tri_hit_classes, tri_hit_dists = classify_triangles_at_sample_curved(
        tri_positions, masks_np, classes_np, H, W, jake_point, x_ref, best_idx,
        SAMPLE_UP_PX, RAY_STEP_PX
    )

    post_ms = (time.perf_counter() - t1) * 1000.0

    # Minimal summary (useful later if you want to analyze decisions offline)
    tri_summary = []
    for i, (x, y) in enumerate(tri_positions):
        cid = tri_hit_classes[i] if i < len(tri_hit_classes) else None
        hdist = tri_hit_dists[i] if i < len(tri_hit_dists) else None
        tri_summary.append({
            "pos": (int(x), int(y)),
            "hit_class": None if cid is None else int(cid),
            "hit_label": None if cid is None else LABELS.get(int(cid), f"C{int(cid)}"),
            "hit_dist_px": None if hdist is None else float(hdist),
            "is_jake": (i == best_idx)
        })

    return (tri_best, len(tri_positions), mask_count, to_cpu_ms, post_ms,
            masks_np, classes_np, rail_mask, green,
            tri_positions, tri_colours, tri_rays,
            best_idx, best_deg, x_ref,
            tri_hit_classes, tri_summary)

# =======================
# Main benchmark
# =======================
def main():
    # Locate frames
    root = Path.cwd()
    frames_dir = root / "frames"
    if not frames_dir.exists():
        alt = root / "alpha" / "frames"
        if alt.exists():
            frames_dir = alt
    if not frames_dir.exists():
        raise SystemExit("No ./frames or ./alpha/frames directory found.")

    img_paths = sorted(
        [p for ext in ("*.png","*.jpg","*.jpeg","*.bmp") for p in frames_dir.glob(ext)]
    )
    if not img_paths:
        raise SystemExit(f"No images in {frames_dir}")

    # Backend
    cv2.setUseOptimized(True)
    try: cv2.setNumThreads(max(1, (os.cpu_count() or 1) - 1))
    except Exception: pass

    if torch.cuda.is_available():
        device, half = 0, True
        torch.backends.cudnn.benchmark = True
        try: torch.set_float32_matmul_precision('high')
        except Exception: pass
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device, half = "mps", False
    else:
        device, half = "cpu", False

    # Model
    model = YOLO(weights)
    try: model.fuse()
    except Exception: pass

    # Warmup
    _dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
    _ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                      conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET)

    # Timing accumulators
    infer_ms_list = []
    tocpu_ms_list = []
    post_ms_list  = []
    total_ms_list = []

    # Assume we start mid-lane for the bearing math
    JAKE_POINT = LANE_MID

    print(f"Benchmarking {len(img_paths)} frames from: {frames_dir}\n")
    print(f"{'frame':>16}  {'infer(ms)':>10}  {'toCPU(ms)':>10}  {'post(ms)':>9}  {'total(ms)':>10}")

    for p in img_paths:
        img = cv2.imread(str(p), cv2.IMREAD_COLOR)
        if img is None:
            continue

        t0 = time.perf_counter()
        res_list = model.predict(
            [img], task="segment", imgsz=IMG_SIZE, device=device,
            conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET, batch=1
        )
        infer_ms = (time.perf_counter() - t0) * 1000.0
        yres = res_list[0]

        # Postproc (returns component timings)
        (_, _, _, to_cpu_ms, post_ms,
         _, _, _, _,
         _, _, _,
         _, _, _,
         _, _) = process_frame_post(img, yres, JAKE_POINT)

        total_ms = infer_ms + to_cpu_ms + post_ms

        infer_ms_list.append(infer_ms)
        tocpu_ms_list.append(to_cpu_ms)
        post_ms_list.append(post_ms)
        total_ms_list.append(total_ms)

        print(f"{p.name:>16}  {infer_ms:10.1f}  {to_cpu_ms:10.1f}  {post_ms:9.1f}  {total_ms:10.1f}")

    # Summary
    def q(arr, qv):  # percentile helper
        arr_sorted = sorted(arr)
        idx = max(0, min(len(arr_sorted)-1, int(round((qv/100.0)*(len(arr_sorted)-1)))))
        return arr_sorted[idx]

    mean_total = statistics.fmean(total_ms_list)
    fps_mean   = 1000.0 / mean_total if mean_total > 0 else 0.0

    print("\n=== Summary ===")
    print(f"Frames: {len(total_ms_list)}")
    print(f"Infer  : mean={statistics.fmean(infer_ms_list):.1f} ms")
    print(f"toCPU  : mean={statistics.fmean(tocpu_ms_list):.1f} ms")
    print(f"Post   : mean={statistics.fmean(post_ms_list):.1f} ms")
    print(f"Total  : mean={mean_total:.1f} ms  |  FPS≈{fps_mean:.2f}")
    print(f"Total  : min={min(total_ms_list):.1f}  p50={q(total_ms_list,50):.1f}  "
          f"p90={q(total_ms_list,90):.1f}  p99={q(total_ms_list,99):.1f}  max={max(total_ms_list):.1f} ms")

if __name__ == "__main__":
    main()


YOLO11n-seg summary (fused): 113 layers, 2,836,908 parameters, 0 gradients, 10.2 GFLOPs
Benchmarking 312 frames from: /Users/marcus/Documents/GitHub/Ai-plays-SubwaySurfers/frames

           frame   infer(ms)   toCPU(ms)   post(ms)   total(ms)
 frame_00000.png      1673.4         1.9        0.0      1675.2
 frame_00001.png       514.1         4.3      129.1       647.4
 frame_00002.png       121.8         1.9      110.4       234.2
 frame_00003.png        76.3         1.2      142.4       219.8
 frame_00004.png       189.6         1.1      175.7       366.5
 frame_00005.png       177.6         1.0      180.8       359.4
 frame_00006.png        74.8         1.0      156.6       232.4
 frame_00007.png       106.9         1.6      173.0       281.4
 frame_00008.png        85.1         1.3      174.0       260.4
 frame_00009.png       118.6         1.1      204.2       324.0
 frame_00010.png        38.1         0.9      155.7       194.8
 frame_00011.png        73.6         0.8      155.4 

KeyboardInterrupt: 

In [None]:
#!/usr/bin/env python3
# Live overlays + lane-aware curved sampling (optimized postproc)
# • Parsec focus + auto click
# • mss live capture of a crop region
# • Arrow keys switch lane (0/1/2) -> JAKE_POINT updates per frame
# • Full overlay rendering + per-frame save
# • Prints compact timing per frame
# • RETURNS per frame: tri_positions, best_idx, tri_hit_classes, tri_summary (for movement logic)

import os, time, math, subprocess, statistics
import cv2, torch, numpy as np
from pathlib import Path
from mss import mss
import pyautogui
from pynput import keyboard
from ultralytics import YOLO
from threading import Timer
from threading import Thread

# -------------------
# Timing accumulators
# -------------------
infer_ms_list  = []
tocpu_ms_list  = []
post_ms_list   = []
total_ms_list  = []
loop_dt_ms_list = []

t_run_start = time.perf_counter()

# --- swallow AI-generated keypresses in the listener for a short window ---
SYNTHETIC_SUPPRESS_S = 0.15  # 150 ms is plenty
_synth_block_until = 0.0     # simple, explicit init avoids IDE warnings

try:
    _synth_block_until
except NameError:
    _synth_block_until = 0.0

# ======================= Quick supreesion to prevent instant bailouts =======================

# --- allows 0.5s of movement, then mute for 2.5s, then restore ---
# Save originals
__press_orig   = pyautogui.press
__keyDown_orig = pyautogui.keyDown
__keyUp_orig   = pyautogui.keyUp
__hotkey_orig  = pyautogui.hotkey

# near your other globals, after imports
MOVEMENT_ENABLED = True

def __mute_keys():
    global MOVEMENT_ENABLED
    MOVEMENT_ENABLED = False
    pyautogui.press  = lambda *a, **k: None
    pyautogui.keyDown = lambda *a, **k: None
    pyautogui.keyUp   = lambda *a, **k: None
    pyautogui.hotkey  = lambda *a, **k: None
    print("[BOOT] movement muted")

def __unmute_keys():
    global MOVEMENT_ENABLED
    MOVEMENT_ENABLED = True
    pyautogui.press   = __press_orig
    pyautogui.keyDown = __keyDown_orig
    pyautogui.keyUp   = __keyUp_orig
    pyautogui.hotkey  = __hotkey_orig
    print("[BOOT] movement unmuted")


# Allow movement immediately; after 0.5s, mute; after 3.0s total, unmute
Timer(0.5, __mute_keys).start()
Timer(4.0, __unmute_keys).start()


# =======================
# Config
# =======================
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"

# SAVE HERE
out_dir    = Path(home) / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "out_live_overlays"
out_dir.mkdir(parents=True, exist_ok=True)

# Crop + click (set by ad layout)
advertisement = True
if advertisement:
    snap_coords = (644, 77, (1149-644), (981-75))  # (left, top, width, height)
    start_click = (1030, 900)
else:
    snap_coords = (483, 75, (988-483), (981-75))
    start_click = (870, 895)

RAIL_ID    = 9
IMG_SIZE   = 512
CONF, IOU  = 0.30, 0.45
MAX_DET    = 30

# Color/region filter
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat/triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
MIN_DARK_FRACTION   = 0.15
TRI_SIZE_PX         = 18

# Sampling ray length
SAMPLE_UP_PX        = 200
RAY_STEP_PX         = 20   # walk the ray every 20 px

# ===== Bend degrees (tune here) =====
BEND_LEFT_STATE_RIGHT_DEG  = -20.0  # N1
BEND_MID_STATE_RIGHT_DEG   = -20.0  # N2
BEND_MID_STATE_LEFT_DEG    = +20.0  # N3
BEND_RIGHT_STATE_LEFT_DEG  = +20.0  # N4

# Colours (BGR)
COLOR_GREEN  = (0, 255, 0)
COLOR_PINK   = (203, 192, 255)
COLOR_YELLOW = (0, 255, 255)
COLOR_RED    = (0, 0, 255)
COLOR_WHITE  = (255, 255, 255)
COLOR_CYAN   = (255, 255, 0)
COLOR_BLACK  = (0, 0, 0)

# =======================
# Jake lane points + dynamic JAKE_POINT
# =======================
LANE_LEFT   = (300, 1340)
LANE_MID    = (490, 1340)
LANE_RIGHT  = (680, 1340)
LANE_POINTS = (LANE_LEFT, LANE_MID, LANE_RIGHT)  # index by lane (0,1,2)
JAKE_POINT  = LANE_MID  # will be set each frame from 'lane'

LANE_TARGET_DEG = {"left": -10.7, "mid": +1.5, "right": +15.0}

def lane_name_from_point(p):
    if p == LANE_LEFT:  return "left"
    if p == LANE_MID:   return "mid"
    if p == LANE_RIGHT: return "right"
    return "mid"


# ===== Movement logic (modular) HELPER FUNCTIOSNS==============================================================================================================

# --- tunnel wall color gate (HSV) ---
LOWBARRIER1_ID   = 4
ORANGETRAIN_ID   = 6
WALL_STRIP_PX    = 20           # vertical strip height checked just above the barrier
WALL_MATCH_FRAC  = 0.40         # % of “wall” pixels required to relabel
WALL_ORANGE_LO = np.array([5,  80,  60], dtype=np.uint8)   # H,S,V (lo)
WALL_ORANGE_HI = np.array([35, 255, 255], dtype=np.uint8)  # H,S,V (hi)


def promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np,
                                 strip_px=WALL_STRIP_PX, frac_thresh=WALL_MATCH_FRAC):
    """
    If a LOWBARRIER1 has an orange 'tunnel wall' strip right behind it,
    relabel that instance to ORANGETRAIN (treated as RED).
    """
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return classes_np

    H, W = frame_bgr.shape[:2]
    hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
    wall_u8 = cv2.inRange(hsv, WALL_ORANGE_LO, WALL_ORANGE_HI)  # 0/255

    # iterate only over LOWBARRIER1 instances
    for i, cls in enumerate(classes_np):
        if int(cls) != LOWBARRIER1_ID:
            continue

        m = masks_np[i]
        # upsample to frame size if needed
        if m.shape != (H, W):
            m_full = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
        else:
            m_full = m.astype(bool, copy=False)

        ys, xs = np.where(m_full)
        if xs.size == 0:
            continue

        x0, x1 = xs.min(), xs.max()
        y0, _  = ys.min(), ys.max()

        # check a strip immediately above the barrier (toward smaller y)
        yb0 = max(0, y0 - strip_px)
        yb1 = y0
        if yb1 <= yb0:
            continue

        strip = wall_u8[yb0:yb1, x0:x1+1]
        if strip.size == 0:
            continue

        frac = float(cv2.countNonZero(strip)) / strip.size
        if frac >= frac_thresh:
            classes_np[i] = ORANGETRAIN_ID  # promote to a RED class

    return classes_np


# extra classes/sets
WARN_FOR_MOVE = {2, 3, 4, 5, 8}      # yellow set that should try to sidestep if a green exists
JUMP_SET      = {3, 5, 10}           # Jump, LowBarrier2, Sidewalk
DUCK_SET      = {2, 4}               # HighBarrier1, LowBarrier1

# action keys (change if your emulator uses different binds)
JUMP_KEY = "up"
DUCK_KEY = "down"

# --- "white-ish" lane probe (5x5 box counts) ---
# tune these if your Jake sprite/board highlight isn't pure white
WHITE_MIN = np.array([220, 220, 220], dtype=np.uint8)  # BGR lower bound
WHITE_MAX = np.array([255, 255, 255], dtype=np.uint8)  # BGR upper bound
BOX_RAD   = 2  # 5x5 => radius 2

def _count_white_around(img_bgr, pt, box_rad=BOX_RAD):
    H, W = img_bgr.shape[:2]
    x, y = pt
    x0 = max(0, x - box_rad); x1 = min(W, x + box_rad + 1)
    y0 = max(0, y - box_rad); y1 = min(H, y + box_rad + 1)
    roi = img_bgr[y0:y1, x0:x1]
    if roi.size == 0:
        return 0
    mask = cv2.inRange(roi, WHITE_MIN, WHITE_MAX)
    return int(cv2.countNonZero(mask))

def _detect_lane_by_whiteness(img_bgr):
    # returns lane index 0/1/2 chosen by the largest white count;
    # if all zero, returns None to keep previous lane
    counts = [
        _count_white_around(img_bgr, LANE_LEFT),
        _count_white_around(img_bgr, LANE_MID),
        _count_white_around(img_bgr, LANE_RIGHT),
    ]
    best_idx = int(np.argmax(counts))
    return best_idx if counts[best_idx] > 0 else None




# action cooldown so we don't spam jump/duck
try:
    last_action_ts
except NameError:
    last_action_ts = 0.0
ACTION_COOLDOWN_S = 0.5

# distance threshold (pixels) from Jake to triangle apex for action decisions
ACTION_DIST_PX = 30

def _is_warn(cls_id: int | None) -> bool:
    return (cls_id is not None) and (int(cls_id) in WARN_FOR_MOVE)

def _dist_px(jx: int, jy: int, tx: int, ty: int) -> float:
    return math.hypot(tx - jx, ty - jy)

def _pick_best_green(cands, jx: int):
    """Choose the closest triangle with hit_class == None (no hit along ray)."""
    greens = [c for c in cands if c["hit_class"] is None]
    if not greens:
        return None
    greens = [c for c in greens if c["pos"][0] != jx] or greens
    return min(greens, key=lambda c: abs(c["pos"][0] - jx))

def _schedule(fn, *args, **kwargs):
    Thread(target=fn, args=args, kwargs=kwargs, daemon=True).start()

def _do_jump_then_duck(delay_s: float = 0.50):
    pyautogui.press(JUMP_KEY)
    time.sleep(delay_s)
    pyautogui.press(DUCK_KEY)

def _try_jump_then_duck():
    if not MOVEMENT_ENABLED:
        return
    global last_action_ts
    now = time.perf_counter()
    if now - last_action_ts >= ACTION_COOLDOWN_S:
        last_action_ts = now
        _schedule(_do_jump_then_duck, 0.20)

MIN_GREEN_AHEAD_PX = 400
def _filter_green_far(cands, jake_band_y: int, min_ahead_px: int = MIN_GREEN_AHEAD_PX):
    """Keep only green triangles that are at least `min_ahead_px` above Jake's y band."""
    out = []
    for c in cands:
        _, yt = c["pos"]
        if (jake_band_y - yt) >= min_ahead_px:  # keep if ≥ 400 px ahead
            out.append(c)
    return out

def first_red_hit_y(pos, masks_np, classes_np, H, W, band_px=6, step_px=5, max_up=SAMPLE_UP_PX):
    """Return the screen y of the first RED pixel straight above `pos`, or None."""
    if masks_np is None or masks_np.size == 0: return None
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    red_idx = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    if not red_idx: return None

    x0, y0 = int(pos[0]), int(pos[1])
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)

    for t in range(step_px, max_up + 1, step_px):
        y = _clampi(y0 - t, 0, H-1)
        for dx in range(-band_px, band_px + 1):
            x = _clampi(x0 + dx, 0, W-1)
            mx = _clampi(int(round(x * sx)), 0, mw-1)
            my = _clampi(int(round(y * sy)), 0, mh-1)
            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    return y
    return None

def first_hit_y(pos, masks_np, classes_np, H, W, class_set, band_px=6, step_px=5, max_up=SAMPLE_UP_PX):
    """Return the screen y of the first pixel (straight up) whose class ∈ class_set."""
    if masks_np is None or masks_np.size == 0: return None
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    idxs = [i for i, c in enumerate(classes_np) if int(c) in class_set]
    if not idxs: return None

    x0, y0 = int(pos[0]), int(pos[1])
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)

    for t in range(step_px, max_up + 1, step_px):
        y = _clampi(y0 - t, 0, H-1)
        for dx in range(-band_px, band_px + 1):
            x = _clampi(x0 + dx, 0, W-1)
            mx = _clampi(int(round(x * sx)), 0, mw-1)
            my = _clampi(int(round(y * sy)), 0, mh-1)
            for i in idxs:
                if masks_np[i][my, mx] > 0.5:
                    return y
    return None


# Only step from RED into a YELLOW lane if its triangle is far enough ahead
MIN_YELLOW_AHEAD_PX = 400
def _filter_yellow_far(cands, jake_band_y: int, min_ahead_px: int = MIN_YELLOW_AHEAD_PX):
    """Keep only yellow triangles that are at least `min_ahead_px` above Jake's y band."""
    out = []
    for c in cands:
        _, yt = c["pos"]
        if (jake_band_y - yt) >= min_ahead_px:
            out.append(c)
    return out


def _try_duck():
    if not MOVEMENT_ENABLED:
        return
    global last_action_ts
    now = time.perf_counter()
    if now - last_action_ts >= ACTION_COOLDOWN_S:
        last_action_ts = now
        _schedule(pyautogui.press, DUCK_KEY)
try:
    last_move_ts
except NameError:
    last_move_ts = 0.0

MOVE_COOLDOWN_S = 0.10  # 100 ms

def _is_danger(cls_id: int | None) -> bool:
    return (cls_id is not None) and (int(cls_id) in DANGER_RED)

def _is_safe(cls_id: int | None) -> bool:
    return not _is_danger(cls_id)

def _filter_by_lane(cands, jx: int, lane_idx: int):
    """Prune triangles based on current lane:
       - lane 0 (left): drop triangles with x < jx
       - lane 2 (right): drop triangles with x > jx
       - lane 1 (mid): keep all
    """
    if lane_idx == 0:
        return [c for c in cands if c["pos"][0] >= jx]
    if lane_idx == 2:
        return [c for c in cands if c["pos"][0] <= jx]
    return cands

def _pick_best_safe_triangle(cands, jx: int):
    """Prefer triangles with hit_class == None; otherwise any non-danger.
       Break ties by smallest |x - jx|.
    """
    if not cands:
        return None
    none_hits  = [c for c in cands if c["hit_class"] is None]
    safe_hits  = [c for c in cands if c["hit_class"] is not None and _is_safe(c["hit_class"])]
    pool = none_hits if none_hits else safe_hits
    if not pool:
        return None
    # exclude triangles exactly aligned with Jake in x (no direction)
    pool = [c for c in pool if c["pos"][0] != jx] or pool
    return min(pool, key=lambda c: abs(c["pos"][0] - jx))

def _issue_move_towards_x(jx: int, tx: int):
    global lane, last_move_ts, _synth_block_until
    if not MOVEMENT_ENABLED:
        return

    now = time.perf_counter()
    if now - last_move_ts < MOVE_COOLDOWN_S:
        return

    if tx < jx and lane > MIN_LANE:
        _synth_block_until = time.monotonic() + SYNTHETIC_SUPPRESS_S
        pyautogui.press('left')
        lane = max(MIN_LANE, lane - 1)
        print(f"[AI MOVE] left -> Lane {lane}")
        last_move_ts = now

    elif tx > jx and lane < MAX_LANE:
        _synth_block_until = time.monotonic() + SYNTHETIC_SUPPRESS_S
        pyautogui.press('right')
        lane = min(MAX_LANE, lane + 1)
        print(f"[AI MOVE] right -> Lane {lane}")
        last_move_ts = now
    else:
        print('WE ARE COOKED')


#============================================================================================================================================


# =======================
# Lane/keyboard state
# =======================
lane = 1
MIN_LANE = 0
MAX_LANE = 2
running = True

# ===== Debounce / cooldown =====
COOLDOWN_MS = 20
_last_press_ts = 0.0  # monotonic seconds

def on_press(key):
    global lane, running, _last_press_ts, _synth_block_until
    now = time.monotonic()

    # swallow AI-generated lane key events during the suppression window
    if key in (keyboard.Key.left, keyboard.Key.right) and now < _synth_block_until:
        return

    if key != keyboard.Key.esc and (now - _last_press_ts) * 1000.0 < COOLDOWN_MS:
        return

    try:
        if key == keyboard.Key.left:
            lane = max(MIN_LANE, lane - 1)
            _last_press_ts = now
            print(f"Moved Left into → Lane {lane}")

        elif key == keyboard.Key.right:
            lane = min(MAX_LANE, lane + 1)
            _last_press_ts = now
            print(f"Moved Right into → Lane {lane}")

        elif key == keyboard.Key.esc:
            running = False
            return False
    except Exception as e:
        print(f"Error: {e}")


# =======================
# System/Backends
# =======================
cv2.setUseOptimized(True)
try: cv2.setNumThreads(max(1, (os.cpu_count() or 1) - 1))
except Exception: pass

if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    try: torch.set_float32_matmul_precision('high')
    except Exception: pass
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

# =======================
# Model
# =======================
model = YOLO(weights)
try: model.fuse()
except Exception: pass

# warmup
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                  conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET)

# =======================
# Precomputed
# =======================
TARGETS_BGR_F32 = np.array([(r,g,b)[::-1] for (r,g,b) in TARGET_COLORS_RGB], dtype=np.float32)
TOL2            = TOLERANCE * TOLERANCE

# Class buckets for probe classification
DANGER_RED   = {1, 6, 7, 11}
WARN_YELLOW  = {2, 3, 4, 5, 8}
BOOTS_PINK   = {0}

CLASS_COLOURS = {
    0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
    4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
    8:(0,0,128),9:(0,0,255),10:(128,128,0),11:(255,255,102)
}
LABELS = {
    0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
    5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
    10:"SIDEWALK",11:"YELLOWTRAIN"
}

# ====== tiny helpers ======
def _clampi(v, lo, hi):
    return lo if v < lo else (hi if v > hi else v)

def _fmt_px(v):
    return f"{v:.1f}px" if v is not None else "n/a"

# =======================
# Parsec to front + click Start (non-blocking failures)
# =======================
try:
    subprocess.run(["osascript", "-e", 'tell application "Parsec" to activate'], check=False)
    time.sleep(0.4)
except Exception:
    pass

try:
    pyautogui.click(start_click)
except Exception:
    pass

# =======================
# Fast rails green finder
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask):
    H, W = rail_mask.shape
    if not rail_mask.any():
        return np.zeros((H, W), dtype=bool)

    rail_u8 = rail_mask.view(dtype=np.uint8) * 255
    x, y, w, h = cv2.boundingRect(rail_u8)
    img_roi  = img_bgr[y:y+h, x:x+w]
    mask_roi = rail_u8[y:y+h, x:x+w]

    img_f = img_roi.astype(np.float32, copy=False)
    diff  = img_f[:, :, None, :] - TARGETS_BGR_F32[None, None, :, :]
    dist2 = (diff * diff).sum(-1)
    colour_hit = (dist2 <= TOL2).any(-1)

    combined = np.logical_and(colour_hit, mask_roi.astype(bool))

    comp = combined.astype(np.uint8)
    n, lbls, stats, _ = cv2.connectedComponentsWithStats(comp, 8)
    if n <= 1: return np.zeros((H, W), dtype=bool)

    good = np.zeros_like(combined)
    areas = stats[1:, cv2.CC_STAT_AREA]
    hs    = stats[1:, cv2.CC_STAT_HEIGHT]
    keep  = np.where((areas >= MIN_REGION_SIZE) & (hs >= MIN_REGION_HEIGHT))[0] + 1
    for k in keep: good[lbls == k] = True

    full = np.zeros((H, W), dtype=bool)
    full[y:y+h, x:x+w] = good
    return full

def red_vs_green_score(red_mask, green_mask):
    k = (HEAT_BLUR_KSIZE, HEAT_BLUR_KSIZE)
    r = cv2.blur(red_mask.astype(np.float32, copy=False), k)
    g = cv2.blur(green_mask.astype(np.float32, copy=False), k)
    diff = r - g
    amax = float(np.max(np.abs(diff))) + 1e-6
    norm = (diff / (2.0 * amax) + 0.5)
    return np.clip(norm * 255.0, 0, 255.0).astype(np.uint8, copy=False)

def purple_triangles(score, H):
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    dark = (score >= RED_SCORE_THRESH).astype(np.uint8, copy=False)
    if top_ex: dark[:top_ex, :] = 0
    if bot_ex: dark[-bot_ex:, :] = 0

    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)), iterations=1
    )
    total_dark = int(dark.sum())
    if total_dark == 0: return [], None

    frac_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark))
    n_lbl, lbls, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
    if n_lbl <= 1: return [], None

    tris = []
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area >= MIN_DARK_RED_AREA and area >= frac_thresh:
            ys, xs = np.where(lbls == lbl)
            if ys.size == 0: continue
            y_top = ys.min()
            x_mid = int(xs[ys == y_top].mean())
            tris.append((x_mid, int(y_top)))

    if not tris: return [], None
    best = min(tris, key=lambda xy: xy[1])
    return tris, best

# ===== Bearing-based Jake triangle selection =====
def signed_degrees_from_vertical(dx, dy):
    if dx == 0 and dy == 0: return 0.0
    return -math.degrees(math.atan2(dx, -dy))

def select_triangle_by_bearing(tri_positions, jx, jy, target_deg, min_dy=6):
    best_i, best_deg, best_err = -1, None, None
    for i, (xt, yt) in enumerate(tri_positions):
        dy = yt - jy
        if dy >= -min_dy:  # must be above Jake
            continue
        deg = signed_degrees_from_vertical(xt - jx, dy)
        err = abs(deg - target_deg)
        if (best_err is None) or (err < best_err):
            best_i, best_deg, best_err = i, deg, err
    return best_i, best_deg, best_err

# ===== Lane-aware curved sampling (precompute sin/cos) =====
def _precompute_trig():
    angles = sorted(set([0.0,
        BEND_LEFT_STATE_RIGHT_DEG,
        BEND_MID_STATE_RIGHT_DEG,
        BEND_MID_STATE_LEFT_DEG,
        BEND_RIGHT_STATE_LEFT_DEG
    ]))
    table = {}
    for a in angles:
        r = math.radians(a)
        table[a] = (math.sin(r), -math.cos(r))  # (dx, dy) for unit ray (up = -y)
    return table
TRIG_TABLE = _precompute_trig()

def pick_bend_angle(jake_point, xt, x_ref, idx, best_idx):
    if idx == best_idx:
        return 0.0
    if jake_point == LANE_LEFT:
        return BEND_LEFT_STATE_RIGHT_DEG if xt > x_ref else 0.0
    if jake_point == LANE_RIGHT:
        return BEND_RIGHT_STATE_LEFT_DEG if xt < x_ref else 0.0
    if xt > x_ref: return BEND_MID_STATE_RIGHT_DEG
    if xt < x_ref: return BEND_MID_STATE_LEFT_DEG
    return 0.0

# --------- walk-the-ray classifier (first-hit wins) ----------
def classify_triangles_at_sample_curved(
    tri_positions, masks_np, classes_np, H, W,
    jake_point, x_ref, best_idx, sample_px=SAMPLE_UP_PX, step_px=RAY_STEP_PX
):
    if masks_np is None or classes_np is None or len(tri_positions) == 0:
        return [], [], [], []  # colours, rays, hit_class_ids, hit_distances_px

    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    red_idx    = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    yellow_idx = [i for i, c in enumerate(classes_np) if int(c) in WARN_YELLOW]
    boots_idx  = [i for i, c in enumerate(classes_np) if int(c) in BOOTS_PINK]

    colours, rays, hit_class_ids, hit_distances_px = [], [], [], []
    max_k = max(1, sample_px // max(1, step_px))

    for idx, (x0, y0) in enumerate(tri_positions):
        theta = pick_bend_angle(jake_point, x0, x_ref, idx, best_idx)
        dx1, dy1 = TRIG_TABLE[theta]

        hit_colour = COLOR_GREEN
        hit_cls = None
        hit_dist_px = None

        found = False
        for k in range(1, max_k + 1):
            t  = k * step_px
            xs = _clampi(int(round(x0 + dx1 * t)), 0, W-1)
            ys = _clampi(int(round(y0 + dy1 * t)), 0, H-1)
            mx = _clampi(int(round(xs * sx)), 0, mw-1)
            my = _clampi(int(round(ys * sy)), 0, mh-1)

            # RED first (so if red exists at a point, we record red distance)
            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_RED
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break
            # then YELLOW
            for i in yellow_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_YELLOW
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break
            # then BOOTS
            for i in boots_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_PINK
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break

        x1 = _clampi(int(round(x0 + dx1 * sample_px)), 0, W-1)
        y1 = _clampi(int(round(y0 + dy1 * sample_px)), 0, H-1)

        colours.append(hit_colour)
        rays.append(((int(x0), int(y0)), (x1, y1), float(theta)))
        hit_class_ids.append(hit_cls)
        hit_distances_px.append(hit_dist_px)

    return colours, rays, hit_class_ids, hit_distances_px

# -----------------------------------------------------------------------

# =======================
# Frame post-processing
# =======================
def process_frame_post(frame_bgr, yolo_res, jake_point):
    """
    Returns (…)
      tri_best_xy, tri_count, mask_count, to_cpu_ms, post_ms,
      masks_np, classes_np, rail_mask, green_mask,
      tri_positions, tri_colours, tri_rays,
      best_idx, best_deg, x_ref,
      tri_hit_classes, tri_summary
    """
    H, W = frame_bgr.shape[:2]
    if yolo_res.masks is None:
        return (None, 0, 0, 0.0, 0.0, None, None, None, None,
                [], [], [], None, None, None, [], [])

    t0 = time.perf_counter()
    masks_np = yolo_res.masks.data.detach().cpu().numpy()  # [n,h,w]
    if hasattr(yolo_res.masks, "cls") and yolo_res.masks.cls is not None:
        classes_np = yolo_res.masks.cls.detach().cpu().numpy().astype(int)
    else:
        classes_np = yolo_res.boxes.cls.detach().cpu().numpy().astype(int)

    to_cpu_ms = (time.perf_counter() - t0) * 1000.0
    mask_count = int(masks_np.shape[0])
    if mask_count == 0 or classes_np.size == 0:
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    classes_np = promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np)

    rail_sel = (classes_np == RAIL_ID)
    if not np.any(rail_sel):
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    t1 = time.perf_counter()
    rail_masks = masks_np[rail_sel].astype(bool, copy=False)
    union = np.any(rail_masks, axis=0).astype(np.uint8, copy=False)
    rail_mask = cv2.resize(union, (W, H), interpolation=cv2.INTER_NEAREST).astype(bool, copy=False)

    green = highlight_rails_mask_only_fast(frame_bgr, rail_mask)
    red   = np.logical_and(rail_mask, np.logical_not(green))
    score = red_vs_green_score(red, green)
    tri_positions, tri_best = purple_triangles(score, H)

    # Jake triangle by bearing
    lane_name = lane_name_from_point(jake_point)
    target_deg = LANE_TARGET_DEG[lane_name]
    xj, yj = jake_point
    best_idx, best_deg, _ = select_triangle_by_bearing(tri_positions, xj, yj, target_deg, min_dy=6)

    # x_ref for bending
    x_ref = tri_positions[best_idx][0] if (lane_name == "mid" and best_idx is not None and 0 <= best_idx < len(tri_positions)) else xj

    tri_colours, tri_rays, tri_hit_classes, tri_hit_dists = classify_triangles_at_sample_curved(
        tri_positions, masks_np, classes_np, H, W, jake_point, x_ref, best_idx,
        SAMPLE_UP_PX, RAY_STEP_PX
    )

    post_ms = (time.perf_counter() - t1) * 1000.0

    # Minimal summary (useful later if you want to analyze decisions offline)
    tri_summary = []
    for i, (x, y) in enumerate(tri_positions):
        cid = tri_hit_classes[i] if i < len(tri_hit_classes) else None
        hdist = tri_hit_dists[i] if i < len(tri_hit_dists) else None
        tri_summary.append({
            "pos": (int(x), int(y)),
            "hit_class": None if cid is None else int(cid),
            "hit_label": None if cid is None else LABELS.get(int(cid), f"C{int(cid)}"),
            "hit_dist_px": None if hdist is None else float(hdist),
            "is_jake": (i == best_idx)
        })

    return (tri_best, len(tri_positions), mask_count, to_cpu_ms, post_ms,
            masks_np, classes_np, rail_mask, green,
            tri_positions, tri_colours, tri_rays,
            best_idx, best_deg, x_ref,
            tri_hit_classes, tri_summary)

# =======================
# Main benchmark
# =======================
def main():
    # Locate frames
    root = Path.cwd()
    frames_dir = root / "frames"
    if not frames_dir.exists():
        alt = root / "alpha" / "frames"
        if alt.exists():
            frames_dir = alt
    if not frames_dir.exists():
        raise SystemExit("No ./frames or ./alpha/frames directory found.")

    img_paths = sorted(
        [p for ext in ("*.png","*.jpg","*.jpeg","*.bmp") for p in frames_dir.glob(ext)]
    )
    if not img_paths:
        raise SystemExit(f"No images in {frames_dir}")

    # Backend
    cv2.setUseOptimized(True)
    try: cv2.setNumThreads(max(1, (os.cpu_count() or 1) - 1))
    except Exception: pass

    if torch.cuda.is_available():
        device, half = 0, True
        torch.backends.cudnn.benchmark = True
        try: torch.set_float32_matmul_precision('high')
        except Exception: pass
    elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        device, half = "mps", False
    else:
        device, half = "cpu", False

    # Model
    model = YOLO(weights)
    try: model.fuse()
    except Exception: pass

    # Warmup
    _dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
    _ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                      conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET)

    # Timing accumulators
    infer_ms_list = []
    tocpu_ms_list = []
    post_ms_list  = []
    total_ms_list = []

    # Assume we start mid-lane for the bearing math
    JAKE_POINT = LANE_MID

    print(f"Benchmarking {len(img_paths)} frames from: {frames_dir}\n")
    print(f"{'frame':>16}  {'infer(ms)':>10}  {'toCPU(ms)':>10}  {'post(ms)':>9}  {'total(ms)':>10}")

    for p in img_paths:
        img = cv2.imread(str(p), cv2.IMREAD_COLOR)
        if img is None:
            continue

        t0 = time.perf_counter()
        res_list = model.predict(
            [img], task="segment", imgsz=IMG_SIZE, device=device,
            conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET, batch=1
        )
        infer_ms = (time.perf_counter() - t0) * 1000.0
        yres = res_list[0]

        # Postproc (returns component timings)
        (_, _, _, to_cpu_ms, post_ms,
         _, _, _, _,
         _, _, _,
         _, _, _,
         _, _) = process_frame_post(img, yres, JAKE_POINT)

        total_ms = infer_ms + to_cpu_ms + post_ms

        infer_ms_list.append(infer_ms)
        tocpu_ms_list.append(to_cpu_ms)
        post_ms_list.append(post_ms)
        total_ms_list.append(total_ms)

        print(f"{p.name:>16}  {infer_ms:10.1f}  {to_cpu_ms:10.1f}  {post_ms:9.1f}  {total_ms:10.1f}")

    # Summary
    def q(arr, qv):  # percentile helper
        arr_sorted = sorted(arr)
        idx = max(0, min(len(arr_sorted)-1, int(round((qv/100.0)*(len(arr_sorted)-1)))))
        return arr_sorted[idx]

    mean_total = statistics.fmean(total_ms_list)
    fps_mean   = 1000.0 / mean_total if mean_total > 0 else 0.0

    print("\n=== Summary ===")
    print(f"Frames: {len(total_ms_list)}")
    print(f"Infer  : mean={statistics.fmean(infer_ms_list):.1f} ms")
    print(f"toCPU  : mean={statistics.fmean(tocpu_ms_list):.1f} ms")
    print(f"Post   : mean={statistics.fmean(post_ms_list):.1f} ms")
    print(f"Total  : mean={mean_total:.1f} ms  |  FPS≈{fps_mean:.2f}")
    print(f"Total  : min={min(total_ms_list):.1f}  p50={q(total_ms_list,50):.1f}  "
          f"p90={q(total_ms_list,90):.1f}  p99={q(total_ms_list,99):.1f}  max={max(total_ms_list):.1f} ms")

if __name__ == "__main__":
    main()



YOLO11n-seg summary (fused): 113 layers, 2,836,908 parameters, 0 gradients, 10.2 GFLOPs
[BOOT] movement muted
YOLO11n-seg summary (fused): 113 layers, 2,836,908 parameters, 0 gradients, 10.2 GFLOPs
Benchmarking 42 frames from: /Users/marcus/Documents/GitHub/Ai-plays-SubwaySurfers/frames

           frame   infer(ms)   toCPU(ms)   post(ms)   total(ms)
[BOOT] movement unmuted
       00001.png       645.3         1.0      143.8       790.2
       00002.png       144.1         2.0       80.8       226.9
       00003.png        66.4         0.7      142.9       210.0
       00004.png       204.6         1.1      141.7       347.4
       00005.png        59.5         1.1      110.3       170.8
       00006.png       133.7       179.1      138.1       450.9
       00007.png        63.5         1.1      109.0       173.6


In [1]:
#!/usr/bin/env python3
# Live overlays + lane-aware curved sampling (optimized postproc)
# • Parsec focus + auto click
# • mss live capture of a crop region
# • Arrow keys switch lane (0/1/2) -> JAKE_POINT updates per frame
# • Full overlay rendering + per-frame save
# • Prints compact timing per frame
# • RETURNS per frame: tri_positions, best_idx, tri_hit_classes, tri_summary (for movement logic)

import os, time, math, subprocess
import cv2, torch, numpy as np
from pathlib import Path
from mss import mss
import pyautogui
from pynput import keyboard
from ultralytics import YOLO
from threading import Timer
from threading import Thread
import time
start_time = time.perf_counter()



# --- swallow AI-generated keypresses in the listener for a short window ---
SYNTHETIC_SUPPRESS_S = 0.15  # 150 ms is plenty
_synth_block_until = 0.0     # simple, explicit init avoids IDE warnings

try:
    _synth_block_until
except NameError:
    _synth_block_until = 0.0

# ======================= Quick supreesion to prevent instant bailouts =======================

# --- allows 0.5s of movement, then mute for 2.5s, then restore ---
# Save originals
__press_orig   = pyautogui.press
__keyDown_orig = pyautogui.keyDown
__keyUp_orig   = pyautogui.keyUp
__hotkey_orig  = pyautogui.hotkey

# near your other globals, after imports
MOVEMENT_ENABLED = True

def __mute_keys():
    global MOVEMENT_ENABLED
    MOVEMENT_ENABLED = False
    pyautogui.press  = lambda *a, **k: None
    pyautogui.keyDown = lambda *a, **k: None
    pyautogui.keyUp   = lambda *a, **k: None
    pyautogui.hotkey  = lambda *a, **k: None
    print("[BOOT] movement muted")

def __unmute_keys():
    global MOVEMENT_ENABLED
    MOVEMENT_ENABLED = True
    pyautogui.press   = __press_orig
    pyautogui.keyDown = __keyDown_orig
    pyautogui.keyUp   = __keyUp_orig
    pyautogui.hotkey  = __hotkey_orig
    print("[BOOT] movement unmuted")


# Allow movement immediately; after 0.5s, mute; after 3.0s total, unmute
Timer(0.5, __mute_keys).start()
Timer(4.0, __unmute_keys).start()


# =======================
# Config
# =======================
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"

# SAVE HERE
out_dir    = Path(home) / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "out_live_overlays"
out_dir.mkdir(parents=True, exist_ok=True)

# Crop + click (set by ad layout)
advertisement = True
if advertisement:
    snap_coords = (644, 77, (1149-644), (981-75))  # (left, top, width, height)
    start_click = (1030, 900)
else:
    snap_coords = (483, 75, (988-483), (981-75))
    start_click = (870, 895)

RAIL_ID    = 9
IMG_SIZE   = 512
CONF, IOU  = 0.30, 0.45
MAX_DET    = 30

# Color/region filter
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat/triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
MIN_DARK_FRACTION   = 0.15
TRI_SIZE_PX         = 18

# Sampling ray length
SAMPLE_UP_PX        = 200
RAY_STEP_PX         = 20   # walk the ray every 20 px

# ===== Bend degrees (tune here) =====
BEND_LEFT_STATE_RIGHT_DEG  = -20.0  # N1
BEND_MID_STATE_RIGHT_DEG   = -20.0  # N2
BEND_MID_STATE_LEFT_DEG    = +20.0  # N3
BEND_RIGHT_STATE_LEFT_DEG  = +20.0  # N4

# Colours (BGR)
COLOR_GREEN  = (0, 255, 0)
COLOR_PINK   = (203, 192, 255)
COLOR_YELLOW = (0, 255, 255)
COLOR_RED    = (0, 0, 255)
COLOR_WHITE  = (255, 255, 255)
COLOR_CYAN   = (255, 255, 0)
COLOR_BLACK  = (0, 0, 0)

# =======================
# Jake lane points + dynamic JAKE_POINT
# =======================
LANE_LEFT   = (300, 1340)
LANE_MID    = (490, 1340)
LANE_RIGHT  = (680, 1340)
LANE_POINTS = (LANE_LEFT, LANE_MID, LANE_RIGHT)  # index by lane (0,1,2)
JAKE_POINT  = LANE_MID  # will be set each frame from 'lane'

LANE_TARGET_DEG = {"left": -10.7, "mid": +1.5, "right": +15.0}

def lane_name_from_point(p):
    if p == LANE_LEFT:  return "left"
    if p == LANE_MID:   return "mid"
    if p == LANE_RIGHT: return "right"
    return "mid"


# ===== Movement logic (modular) HELPER FUNCTIOSNS==============================================================================================================


# --- inference pause gate (after imports/globals) ---
PAUSE_AFTER_MOVE_S = 0.40

try:
    PAUSE_UNTIL
except NameError:
    PAUSE_UNTIL = 0.0  # monotonic timestamp

def pause_inference(sec: float = PAUSE_AFTER_MOVE_S):
    """Freeze the main loop for `sec` seconds from NOW."""
    global PAUSE_UNTIL
    PAUSE_UNTIL = time.monotonic() + sec


# One-shot gating
try:
    IMPACT_TOKEN
except NameError:
    IMPACT_TOKEN = None  # (lane, class_id)

def _fire_action_key(key: str, token_snapshot):
    global IMPACT_TOKEN
    if MOVEMENT_ENABLED:
        pyautogui.press(key)
        print(f"[TIMER FIRE] pressed {key}")
    # allow instant re-arm for the next identical obstacle
    if IMPACT_TOKEN == token_snapshot:
        IMPACT_TOKEN = None


# Only arm timer when distance is strictly inside this window (px)
IMPACT_MIN_PX = 100
IMPACT_MAX_PX = 650

# ===== Impact delay lookup (distance px -> seconds) =====
# Fill these with your *monotone ascending* distances (px) and corresponding delays (seconds).
# Example placeholders; REPLACE with your numbers:
LUT_PX = np.array([100, 125, 150, 175, 200, 225, 250, 275, 300, 325, 350, 375, 400, 425, 450, 475, 500, 525, 550, 575, 600, 625, 650, 675, 700, 725, 750, 775, 800], dtype=float)

SHORTEN_S = 0.35 #Shortnen by 100ms
# Safety clamps so Timer never explodes or becomes a no-op
MIN_DELAY_S = 0.03   # 30 ms
MAX_DELAY_S = 2.00   # 2 s

LUT_S = np.clip(np.array([0.0259, 0.0303, 0.0353, 0.0412, 0.0481, 0.0561, 0.0655, 0.0765, 0.0893, 0.1042, 0.1216, 0.1419, 0.1656, 0.1933, 0.2256, 0.2633, 0.3073, 0.3586, 0.4185, 0.4885, 0.5701, 0.6654, 0.7765, 0.9063, 1.0578, 1.2345, 1.4408, 1.6815, 1.9625], dtype=float) - SHORTEN_S, MIN_DELAY_S, MAX_DELAY_S)


def first_mask_hit_starburst_then_ray_for_set(
    jake_point, tri_pos, theta_deg, masks_np, classes_np, H, W,
    allowed_classes, up2_px=SAMPLE_UP_PX, step_px=2
):
    """
    Same as first_mask_hit_starburst_then_ray, but only counts hits whose class ∈ allowed_classes.
    Returns (dist_px_from_jake, (x_hit, y_hit), class_id) or (None, None, None).
    """
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return (None, None, None)

    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    allowed = set(int(c) for c in allowed_classes)
    idxs = [i for i, c in enumerate(classes_np) if int(c) in allowed]
    if not idxs:
        return (None, None, None)

    def _hit_at(xs, ys):
        mx = _clampi(int(round(xs * sx)), 0, mw-1)
        my = _clampi(int(round(ys * sy)), 0, mh-1)
        for i in idxs:
            if masks_np[i][my, mx] > 0.5:
                return int(classes_np[i])
        return None

    x0, y0 = map(int, jake_point)
    x1, y1 = map(int, tri_pos)
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)
    x1 = _clampi(x1, 0, W-1); y1 = _clampi(y1, 0, H-1)

    dx = x1 - x0; dy = y1 - y0
    seg1_len = max(1e-6, math.hypot(dx, dy))
    n1 = max(1, int(seg1_len // max(1, step_px)))
    for k in range(1, n1 + 1):
        t = min(1.0, (k * step_px) / seg1_len)
        xs = _clampi(int(round(x0 + dx * t)), 0, W-1)
        ys = _clampi(int(round(y0 + dy * t)), 0, H-1)
        cls_hit = _hit_at(xs, ys)
        if cls_hit is not None:
            dist = math.hypot(xs - x0, ys - y0)
            return (float(dist), (int(xs), int(ys)), int(cls_hit))

    dxr, dyr = TRIG_TABLE.get(theta_deg, (0.0, -1.0))
    n2 = max(1, int(up2_px // max(1, step_px)))
    for k in range(1, n2 + 1):
        t = k * step_px
        xs = _clampi(int(round(x1 + dxr * t)), 0, W-1)
        ys = _clampi(int(round(y1 + dyr * t)), 0, H-1)
        cls_hit = _hit_at(xs, ys)
        if cls_hit is not None:
            dist = seg1_len + math.hypot(xs - x1, ys - y1)
            return (float(dist), (int(xs), int(ys)), int(cls_hit))

    return (None, None, None)


# ===== Impact-timer overhaul (single-triangle action) =====
# Classes to act on (exact mapping)
IMPACT_CLASSES = {2, 3, 4, 5}  # 2:HIGHBARRIER1, 3:JUMP, 4:LOWBARRIER1, 5:LOWBARRIER2
ACTION_BY_CLASS = {3: "up", 5: "up", 4: "down", 2: "down"}  # per spec

# Global timer handle (overwritten when re-arming)
try:
    IMPACT_TIMER
except NameError:
    IMPACT_TIMER = None


def _cancel_impact_timer(reason=None):
    global IMPACT_TIMER
    if IMPACT_TIMER is not None and getattr(IMPACT_TIMER, "is_alive", lambda: False)():
        print("[TIMER] cancelled" + (f" ({reason})" if reason else ""))
        try:
            IMPACT_TIMER.cancel()
        except Exception:
            pass
    IMPACT_TIMER = None



def _impact_delay_seconds(dist_px: float) -> float:
    """
    O(1) lookup + linear interpolation from a monotone table (px -> seconds).
    - Dist is cropped to [IMPACT_MIN_PX, IMPACT_MAX_PX] to preserve your windowing.
    - Result is clamped to [MIN_DELAY_S, MAX_DELAY_S] for Timer safety.
    """
    if not math.isfinite(dist_px):
        return MIN_DELAY_S

    # Respect your arming window; crop inside it so behavior matches old gating.
    d = max(IMPACT_MIN_PX, min(float(dist_px), IMPACT_MAX_PX))

    # Interpolate within the table’s range
    lo = float(LUT_PX[0]); hi = float(LUT_PX[-1])
    d_clamped = max(lo, min(d, hi))

    delay = float(np.interp(d_clamped, LUT_PX, LUT_S))
    # Final safety clamp
    return max(MIN_DELAY_S, min(delay, MAX_DELAY_S))


def _arm_impact_timer(dist_px: float, cls_id: int):
    """
    Overwrite-or-set the global timer if dist is in (400, 800) px and class is in IMPACT_CLASSES.
    Prints whether we armed a NEW timer or UPDATED (overwrote) an existing one.
    """
    if cls_id not in IMPACT_CLASSES:
        return

    if not (IMPACT_MIN_PX < dist_px < IMPACT_MAX_PX):
        # Optional debug: show why we didn't arm
        print(f"[TIMER] skip: dist {dist_px:.1f}px outside ({IMPACT_MIN_PX},{IMPACT_MAX_PX}) for {LABELS.get(int(cls_id), cls_id)}")
        return

    key = ACTION_BY_CLASS.get(int(cls_id))
    if not key:
        return

    delay_s = _impact_delay_seconds(dist_px)

    if not math.isfinite(delay_s) or delay_s <= 0.0:
        print(f"[TIMER] skip: invalid delay {delay_s} for dist={dist_px:.1f}px, cls={LABELS.get(int(cls_id), cls_id)}")
        return

    # detect whether we are overwriting a live timer
    # detect whether we are overwriting a live timer
    global IMPACT_TIMER, IMPACT_TOKEN
    was_live = (IMPACT_TIMER is not None and getattr(IMPACT_TIMER, "is_alive", lambda: False)())

    _cancel_impact_timer()  # overwrite existing timer if any

    # --- REPLACEMENT: arm with a token so post-fire re-arm is instant ---
    new_token = (lane, int(cls_id))            # <-- build token for THIS arm
    from threading import Timer
    IMPACT_TIMER = Timer(delay_s, _fire_action_key, args=(key, new_token))
    IMPACT_TIMER.daemon = True
    IMPACT_TIMER.start()

    IMPACT_TOKEN = new_token                   # remember what we armed
    status = "updated" if was_live else "armed"
    print(f"[TIMER] {status}: key={key} in {delay_s:.3f}s  (dist={dist_px:.1f}px, cls={LABELS.get(int(cls_id), cls_id)})")


def first_mask_hit_starburst_then_ray(
    jake_point, tri_pos, theta_deg, masks_np, classes_np, H, W,
    up2_px=SAMPLE_UP_PX, step_px=2, exclude_classes=(RAIL_ID,), danger_only=False
):
    """
    Follow the path used in viz:
      1) straight line JAKE_POINT -> tri_pos,
      2) then continue from tri_pos along the angled probe ray (theta_deg) for up to `up2_px`.
    Return (dist_px_from_jake, (x_hit, y_hit), class_id) for the first mask hit
    (skipping rails or restricted to DANGER_RED if danger_only=True). If none, return (None, None, None).
    """
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return (None, None, None)

    # choose class indices to test
    if danger_only:
        test_idxs = [i for i,c in enumerate(classes_np) if int(c) in DANGER_RED]
    else:
        test_idxs = [i for i,c in enumerate(classes_np) if int(c) not in exclude_classes]
    if not test_idxs:
        return (None, None, None)

    # scale factors from frame to mask grid
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    def _hit_at(xs, ys):
        mx = _clampi(int(round(xs * sx)), 0, mw-1)
        my = _clampi(int(round(ys * sy)), 0, mh-1)
        for i in test_idxs:
            if masks_np[i][my, mx] > 0.5:
                return int(classes_np[i])
        return None

    # ---- segment 1: JAKE_POINT -> triangle apex
    x0, y0 = map(int, jake_point)
    x1, y1 = map(int, tri_pos)
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)
    x1 = _clampi(x1, 0, W-1); y1 = _clampi(y1, 0, H-1)

    dx = x1 - x0; dy = y1 - y0
    seg1_len = max(1e-6, math.hypot(dx, dy))
    n1 = max(1, int(seg1_len // max(1, step_px)))
    for k in range(1, n1 + 1):
        t = min(1.0, (k * step_px) / seg1_len)
        xs = _clampi(int(round(x0 + dx * t)), 0, W-1)
        ys = _clampi(int(round(y0 + dy * t)), 0, H-1)
        cls_hit = _hit_at(xs, ys)
        if cls_hit is not None:
            dist = math.hypot(xs - x0, ys - y0)  # Euclidean from Jake
            return (float(dist), (int(xs), int(ys)), int(cls_hit))

    # ---- segment 2: continue from triangle along angled probe (same as classify rays)
    dxr, dyr = TRIG_TABLE.get(theta_deg, (0.0, -1.0))  # default straight up
    n2 = max(1, int(up2_px // max(1, step_px)))
    for k in range(1, n2 + 1):
        t = k * step_px
        xs = _clampi(int(round(x1 + dxr * t)), 0, W-1)
        ys = _clampi(int(round(y1 + dyr * t)), 0, H-1)
        cls_hit = _hit_at(xs, ys)
        if cls_hit is not None:
            dist = seg1_len + math.hypot(xs - x1, ys - y1)  # piecewise length from Jake
            return (float(dist), (int(xs), int(ys)), int(cls_hit))

    return (None, None, None)


def first_mask_hit_along_segment(jake_point, tri_pos, masks_np, classes_np,
                                 H, W, exclude_classes=(RAIL_ID,), step_px=1):
    """
    Walk the straight segment from JAKE_POINT -> Jake's triangle apex.
    Return (distance_px, (x_hit, y_hit), class_id) for the first mask hit,
    skipping any classes in `exclude_classes`. If none, return (None, None, None).
    """
    if masks_np is None or masks_np.size == 0 or classes_np is None or len(tri_pos) != 2:
        return (None, None, None)

    x0, y0 = map(int, jake_point)
    x1, y1 = map(int, tri_pos)
    # clamp
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)
    x1 = _clampi(x1, 0, W-1); y1 = _clampi(y1, 0, H-1)

    dx = x1 - x0
    dy = y1 - y0
    seg_len = math.hypot(dx, dy)
    if seg_len < 1e-6:
        return (None, None, None)

    # map-to-mask scale factors
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    # prebuild indices of classes we actually test (skip excluded, e.g., RAIL_ID)
    test_idxs = [i for i, c in enumerate(classes_np) if int(c) not in exclude_classes]
    if not test_idxs:
        return (None, None, None)

    # step along the line; start at step 1 so we don't immediately "hit" Jake's pixel
    n_steps = max(1, int(seg_len // max(1, step_px)))
    for k in range(1, n_steps + 1):
        t = (k * step_px) / seg_len
        if t > 1.0: t = 1.0
        xs = _clampi(int(round(x0 + dx * t)), 0, W-1)
        ys = _clampi(int(round(y0 + dy * t)), 0, H-1)

        mx = _clampi(int(round(xs * sx)), 0, mw-1)
        my = _clampi(int(round(ys * sy)), 0, mh-1)

        for i in test_idxs:
            # masks_np is float in [0,1]; >0.5 treated as hit (consistent with rest of code)
            if masks_np[i][my, mx] > 0.5:
                # distance in pixels along the segment to this sample
                dist_px = math.hypot(xs - x0, ys - y0)
                return (float(dist_px), (int(xs), int(ys)), int(classes_np[i]))

    return (None, None, None)



# --- tunnel wall color gate (HSV) ---
LOWBARRIER1_ID   = 4
ORANGETRAIN_ID   = 6
WALL_STRIP_PX    = 10           # vertical strip height checked just above the barrier
WALL_MATCH_FRAC  = 0.95         # % of “wall” pixels required to relabel
WALL_ORANGE_LO = np.array([5,  80,  60], dtype=np.uint8)   # H,S,V (lo)
WALL_ORANGE_HI = np.array([35, 255, 255], dtype=np.uint8)  # H,S,V (hi)


def promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np,
                                 strip_px=WALL_STRIP_PX, frac_thresh=WALL_MATCH_FRAC):
    """
    If a LOWBARRIER1 has an orange 'tunnel wall' strip right behind it,
    relabel that instance to ORANGETRAIN (treated as RED).
    """
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return classes_np

    H, W = frame_bgr.shape[:2]
    hsv = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2HSV)
    wall_u8 = cv2.inRange(hsv, WALL_ORANGE_LO, WALL_ORANGE_HI)  # 0/255

    # iterate only over LOWBARRIER1 instances
    for i, cls in enumerate(classes_np):
        if int(cls) != LOWBARRIER1_ID:
            continue

        m = masks_np[i]
        # upsample to frame size if needed
        if m.shape != (H, W):
            m_full = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
        else:
            m_full = m.astype(bool, copy=False)

        ys, xs = np.where(m_full)
        if xs.size == 0:
            continue

        x0, x1 = xs.min(), xs.max()
        y0, _  = ys.min(), ys.max()

        # check a strip immediately above the barrier (toward smaller y)
        yb0 = max(0, y0 - strip_px)
        yb1 = y0
        if yb1 <= yb0:
            continue

        strip = wall_u8[yb0:yb1, x0:x1+1]
        if strip.size == 0:
            continue

        frac = float(cv2.countNonZero(strip)) / strip.size
        if frac >= frac_thresh:
            classes_np[i] = ORANGETRAIN_ID  # promote to a RED class

    return classes_np


# extra classes/sets
WARN_FOR_MOVE = {2, 3, 4, 5, 8}      # yellow set that should try to sidestep if a green exists
JUMP_SET      = {3, 5, 10}           # Jump, LowBarrier2, Sidewalk
DUCK_SET      = {2, 4}               # HighBarrier1, LowBarrier1

# action keys (change if your emulator uses different binds)
JUMP_KEY = "up"
DUCK_KEY = "down"

# --- "white-ish" lane probe (5x5 box counts) ---
# tune these if your Jake sprite/board highlight isn't pure white
WHITE_MIN = np.array([220, 220, 220], dtype=np.uint8)  # BGR lower bound
WHITE_MAX = np.array([255, 255, 255], dtype=np.uint8)  # BGR upper bound
BOX_RAD   = 2  # 5x5 => radius 2

def _count_white_around(img_bgr, pt, box_rad=BOX_RAD):
    H, W = img_bgr.shape[:2]
    x, y = pt
    x0 = max(0, x - box_rad); x1 = min(W, x + box_rad + 1)
    y0 = max(0, y - box_rad); y1 = min(H, y + box_rad + 1)
    roi = img_bgr[y0:y1, x0:x1]
    if roi.size == 0:
        return 0
    mask = cv2.inRange(roi, WHITE_MIN, WHITE_MAX)
    return int(cv2.countNonZero(mask))

def _detect_lane_by_whiteness(img_bgr):
    # returns lane index 0/1/2 chosen by the largest white count;
    # if all zero, returns None to keep previous lane
    counts = [
        _count_white_around(img_bgr, LANE_LEFT),
        _count_white_around(img_bgr, LANE_MID),
        _count_white_around(img_bgr, LANE_RIGHT),
    ]
    best_idx = int(np.argmax(counts))
    return best_idx if counts[best_idx] > 0 else None


# action cooldown so we don't spam jump/duck
try:
    last_action_ts
except NameError:
    last_action_ts = 0.0
ACTION_COOLDOWN_S = 0.0

# distance threshold (pixels) from Jake to triangle apex for action decisions
ACTION_DIST_PX = 30

def _is_warn(cls_id: int | None) -> bool:
    return (cls_id is not None) and (int(cls_id) in WARN_FOR_MOVE)

def _schedule(fn, *args, **kwargs):
    Thread(target=fn, args=args, kwargs=kwargs, daemon=True).start()

MIN_GREEN_AHEAD_PX = 400
def _filter_green_far(cands, jake_band_y: int, min_ahead_px: int = MIN_GREEN_AHEAD_PX):
    """Keep only green triangles that are at least `min_ahead_px` above Jake's y band."""
    out = []
    for c in cands:
        _, yt = c["pos"]
        if (jake_band_y - yt) >= min_ahead_px:  # keep if ≥ 400 px ahead
            out.append(c)
    return out

def first_red_hit_y(pos, masks_np, classes_np, H, W, band_px=6, step_px=5, max_up=SAMPLE_UP_PX):
    """Return the screen y of the first RED pixel straight above `pos`, or None."""
    if masks_np is None or masks_np.size == 0: return None
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    red_idx = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    if not red_idx: return None

    x0, y0 = int(pos[0]), int(pos[1])
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)

    for t in range(step_px, max_up + 1, step_px):
        y = _clampi(y0 - t, 0, H-1)
        for dx in range(-band_px, band_px + 1):
            x = _clampi(x0 + dx, 0, W-1)
            mx = _clampi(int(round(x * sx)), 0, mw-1)
            my = _clampi(int(round(y * sy)), 0, mh-1)
            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    return y
    return None

def first_hit_y(pos, masks_np, classes_np, H, W, class_set, band_px=6, step_px=5, max_up=SAMPLE_UP_PX):
    """Return the screen y of the first pixel (straight up) whose class ∈ class_set."""
    if masks_np is None or masks_np.size == 0: return None
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    idxs = [i for i, c in enumerate(classes_np) if int(c) in class_set]
    if not idxs: return None

    x0, y0 = int(pos[0]), int(pos[1])
    x0 = _clampi(x0, 0, W-1); y0 = _clampi(y0, 0, H-1)

    for t in range(step_px, max_up + 1, step_px):
        y = _clampi(y0 - t, 0, H-1)
        for dx in range(-band_px, band_px + 1):
            x = _clampi(x0 + dx, 0, W-1)
            mx = _clampi(int(round(x * sx)), 0, mw-1)
            my = _clampi(int(round(y * sy)), 0, mh-1)
            for i in idxs:
                if masks_np[i][my, mx] > 0.5:
                    return y
    return None


# Only step from RED into a YELLOW lane if its triangle is far enough ahead
MIN_YELLOW_AHEAD_PX = 400
def _filter_yellow_far(cands, jake_band_y: int, min_ahead_px: int = MIN_YELLOW_AHEAD_PX):
    """Keep only yellow triangles that are at least `min_ahead_px` above Jake's y band."""
    out = []
    for c in cands:
        _, yt = c["pos"]
        if (jake_band_y - yt) >= min_ahead_px:
            out.append(c)
    return out


def _try_duck():
    if not MOVEMENT_ENABLED:
        return
    global last_action_ts
    now = time.perf_counter()
    if now - last_action_ts >= ACTION_COOLDOWN_S:
        last_action_ts = now
        _schedule(pyautogui.press, DUCK_KEY)
try:
    last_move_ts
except NameError:
    last_move_ts = 0.0

MOVE_COOLDOWN_S = 0.10  # 100 ms

def _is_danger(cls_id: int | None) -> bool:
    return (cls_id is not None) and (int(cls_id) in DANGER_RED)

def _is_safe(cls_id: int | None) -> bool:
    return not _is_danger(cls_id)

def _filter_by_lane(cands, jx: int, lane_idx: int):
    """Prune triangles based on current lane:
       - lane 0 (left): drop triangles with x < jx
       - lane 2 (right): drop triangles with x > jx
       - lane 1 (mid): keep all
    """
    if lane_idx == 0:
        return [c for c in cands if c["pos"][0] >= jx]
    if lane_idx == 2:
        return [c for c in cands if c["pos"][0] <= jx]
    return cands

def _pick_best_safe_triangle(cands, jx: int):
    """Prefer triangles with hit_class == None; otherwise any non-danger.
       Break ties by smallest |x - jx|.
    """
    if not cands:
        return None
    none_hits  = [c for c in cands if c["hit_class"] is None]
    safe_hits  = [c for c in cands if c["hit_class"] is not None and _is_safe(c["hit_class"])]
    pool = none_hits if none_hits else safe_hits
    if not pool:
        return None
    # exclude triangles exactly aligned with Jake in x (no direction)
    pool = [c for c in pool if c["pos"][0] != jx] or pool
    return min(pool, key=lambda c: abs(c["pos"][0] - jx))

def _issue_move_towards_x(jx: int, tx: int):
    global lane, last_move_ts, _synth_block_until
    if not MOVEMENT_ENABLED:
        return

    now = time.perf_counter()
    if now - last_move_ts < MOVE_COOLDOWN_S:
        return

    if tx < jx and lane > MIN_LANE:
        pause_inference()  # 360ms freeze to avoid mid-lane frames
        _synth_block_until = time.monotonic() + SYNTHETIC_SUPPRESS_S
        pyautogui.press('left')
        lane = max(MIN_LANE, lane - 1)
        print(f"[AI MOVE] left -> Lane {lane}")
        last_move_ts = now

    elif tx > jx and lane < MAX_LANE:
        pause_inference()  # 360ms freeze to avoid mid-lane frames
        _synth_block_until = time.monotonic() + SYNTHETIC_SUPPRESS_S
        pyautogui.press('right')
        lane = min(MAX_LANE, lane + 1)
        print(f"[AI MOVE] right -> Lane {lane}")
        last_move_ts = now

    else:
        print('WE ARE COOKED')

#============================================================================================================================================


# =======================
# Lane/keyboard state
# =======================
lane = 1
MIN_LANE = 0
MAX_LANE = 2
running = True

# ===== Debounce / cooldown =====
COOLDOWN_MS = 20
_last_press_ts = 0.0  # monotonic seconds

def on_press(key):
    global lane, running, _last_press_ts, _synth_block_until
    now = time.monotonic()

    # swallow AI-generated lane key events during the suppression window
    if key in (keyboard.Key.left, keyboard.Key.right) and now < _synth_block_until:
        return

    if key != keyboard.Key.esc and (now - _last_press_ts) * 1000.0 < COOLDOWN_MS:
        return

    try:
        if key == keyboard.Key.left:
            lane = max(MIN_LANE, lane - 1)
            _last_press_ts = now
            print(f"Moved Left into → Lane {lane}")

        elif key == keyboard.Key.right:
            lane = min(MAX_LANE, lane + 1)
            _last_press_ts = now
            print(f"Moved Right into → Lane {lane}")

        elif key == keyboard.Key.esc:
            running = False
            return False
    except Exception as e:
        print(f"Error: {e}")


# =======================
# System/Backends
# =======================
cv2.setUseOptimized(True)
try: cv2.setNumThreads(max(1, (os.cpu_count() or 1) - 1))
except Exception: pass

if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    try: torch.set_float32_matmul_precision('high')
    except Exception: pass
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

# =======================
# Model
# =======================
model = YOLO(weights)
try: model.fuse()
except Exception: pass

# warmup
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                  conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET)

# =======================
# Precomputed
# =======================
TARGETS_BGR_F32 = np.array([(r,g,b)[::-1] for (r,g,b) in TARGET_COLORS_RGB], dtype=np.float32)
TOL2            = TOLERANCE * TOLERANCE

# Class buckets for probe classification
DANGER_RED   = {1, 6, 7, 11}
WARN_YELLOW  = {2, 3, 4, 5, 8}
BOOTS_PINK   = {0}

CLASS_COLOURS = {
    0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
    4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
    8:(0,0,128),9:(0,0,255),10:(128,128,0),11:(255,255,102)
}
LABELS = {
    0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
    5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
    10:"SIDEWALK",11:"YELLOWTRAIN"
}

# ====== tiny helpers ======
def _clampi(v, lo, hi):
    return lo if v < lo else (hi if v > hi else v)

def _fmt_px(v):
    return f"{v:.1f}px" if v is not None else "n/a"

# =======================
# Parsec to front + click Start (non-blocking failures)
# =======================
try:
    subprocess.run(["osascript", "-e", 'tell application "Parsec" to activate'], check=False)
    time.sleep(0.4)
except Exception:
    pass

try:
    pyautogui.click(start_click)
except Exception:
    pass

# =======================
# Fast rails green finder
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask):
    H, W = rail_mask.shape
    if not rail_mask.any():
        return np.zeros((H, W), dtype=bool)

    rail_u8 = rail_mask.view(dtype=np.uint8) * 255
    x, y, w, h = cv2.boundingRect(rail_u8)
    img_roi  = img_bgr[y:y+h, x:x+w]
    mask_roi = rail_u8[y:y+h, x:x+w]

    img_f = img_roi.astype(np.float32, copy=False)
    diff  = img_f[:, :, None, :] - TARGETS_BGR_F32[None, None, :, :]
    dist2 = (diff * diff).sum(-1)
    colour_hit = (dist2 <= TOL2).any(-1)

    combined = np.logical_and(colour_hit, mask_roi.astype(bool))

    comp = combined.astype(np.uint8)
    n, lbls, stats, _ = cv2.connectedComponentsWithStats(comp, 8)
    if n <= 1: return np.zeros((H, W), dtype=bool)

    good = np.zeros_like(combined)
    areas = stats[1:, cv2.CC_STAT_AREA]
    hs    = stats[1:, cv2.CC_STAT_HEIGHT]
    keep  = np.where((areas >= MIN_REGION_SIZE) & (hs >= MIN_REGION_HEIGHT))[0] + 1
    for k in keep: good[lbls == k] = True

    full = np.zeros((H, W), dtype=bool)
    full[y:y+h, x:x+w] = good
    return full

def red_vs_green_score(red_mask, green_mask):
    k = (HEAT_BLUR_KSIZE, HEAT_BLUR_KSIZE)
    r = cv2.blur(red_mask.astype(np.float32, copy=False), k)
    g = cv2.blur(green_mask.astype(np.float32, copy=False), k)
    diff = r - g
    amax = float(np.max(np.abs(diff))) + 1e-6
    norm = (diff / (2.0 * amax) + 0.5)
    return np.clip(norm * 255.0, 0, 255.0).astype(np.uint8, copy=False)

def purple_triangles(score, H):
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    dark = (score >= RED_SCORE_THRESH).astype(np.uint8, copy=False)
    if top_ex: dark[:top_ex, :] = 0
    if bot_ex: dark[-bot_ex:, :] = 0

    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)), iterations=1
    )
    total_dark = int(dark.sum())
    if total_dark == 0: return [], None

    frac_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark))
    n_lbl, lbls, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
    if n_lbl <= 1: return [], None

    tris = []
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area >= MIN_DARK_RED_AREA and area >= frac_thresh:
            ys, xs = np.where(lbls == lbl)
            if ys.size == 0: continue
            y_top = ys.min()
            x_mid = int(xs[ys == y_top].mean())
            tris.append((x_mid, int(y_top)))

    if not tris: return [], None
    best = min(tris, key=lambda xy: xy[1])
    return tris, best

# ===== Bearing-based Jake triangle selection =====
def signed_degrees_from_vertical(dx, dy):
    if dx == 0 and dy == 0: return 0.0
    return -math.degrees(math.atan2(dx, -dy))

def select_triangle_by_bearing(tri_positions, jx, jy, target_deg, min_dy=6):
    best_i, best_deg, best_err = -1, None, None
    for i, (xt, yt) in enumerate(tri_positions):
        dy = yt - jy
        if dy >= -min_dy:  # must be above Jake
            continue
        deg = signed_degrees_from_vertical(xt - jx, dy)
        err = abs(deg - target_deg)
        if (best_err is None) or (err < best_err):
            best_i, best_deg, best_err = i, deg, err
    return best_i, best_deg, best_err

# ===== Lane-aware curved sampling (precompute sin/cos) =====
def _precompute_trig():
    angles = sorted(set([0.0,
        BEND_LEFT_STATE_RIGHT_DEG,
        BEND_MID_STATE_RIGHT_DEG,
        BEND_MID_STATE_LEFT_DEG,
        BEND_RIGHT_STATE_LEFT_DEG
    ]))
    table = {}
    for a in angles:
        r = math.radians(a)
        table[a] = (math.sin(r), -math.cos(r))  # (dx, dy) for unit ray (up = -y)
    return table
TRIG_TABLE = _precompute_trig()

def pick_bend_angle(jake_point, xt, x_ref, idx, best_idx):
    if idx == best_idx:
        return 0.0
    if jake_point == LANE_LEFT:
        return BEND_LEFT_STATE_RIGHT_DEG if xt > x_ref else 0.0
    if jake_point == LANE_RIGHT:
        return BEND_RIGHT_STATE_LEFT_DEG if xt < x_ref else 0.0
    if xt > x_ref: return BEND_MID_STATE_RIGHT_DEG
    if xt < x_ref: return BEND_MID_STATE_LEFT_DEG
    return 0.0

# --------- walk-the-ray classifier (first-hit wins) ----------
def classify_triangles_at_sample_curved(
    tri_positions, masks_np, classes_np, H, W,
    jake_point, x_ref, best_idx, sample_px=SAMPLE_UP_PX, step_px=RAY_STEP_PX
):
    if masks_np is None or classes_np is None or len(tri_positions) == 0:
        return [], [], [], []  # colours, rays, hit_class_ids, hit_distances_px

    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))

    red_idx    = [i for i, c in enumerate(classes_np) if int(c) in DANGER_RED]
    yellow_idx = [i for i, c in enumerate(classes_np) if int(c) in WARN_YELLOW]
    boots_idx  = [i for i, c in enumerate(classes_np) if int(c) in BOOTS_PINK]

    colours, rays, hit_class_ids, hit_distances_px = [], [], [], []
    max_k = max(1, sample_px // max(1, step_px))

    for idx, (x0, y0) in enumerate(tri_positions):
        theta = pick_bend_angle(jake_point, x0, x_ref, idx, best_idx)
        dx1, dy1 = TRIG_TABLE[theta]

        hit_colour = COLOR_GREEN
        hit_cls = None
        hit_dist_px = None

        found = False
        for k in range(1, max_k + 1):
            t  = k * step_px
            xs = _clampi(int(round(x0 + dx1 * t)), 0, W-1)
            ys = _clampi(int(round(y0 + dy1 * t)), 0, H-1)
            mx = _clampi(int(round(xs * sx)), 0, mw-1)
            my = _clampi(int(round(ys * sy)), 0, mh-1)

            # RED first (so if red exists at a point, we record red distance)
            for i in red_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_RED
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break
            # then YELLOW
            for i in yellow_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_YELLOW
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break
            # then BOOTS
            for i in boots_idx:
                if masks_np[i][my, mx] > 0.5:
                    hit_colour = COLOR_PINK
                    hit_cls = int(classes_np[i])
                    hit_dist_px = float(t)
                    found = True
                    break
            if found: break

        x1 = _clampi(int(round(x0 + dx1 * sample_px)), 0, W-1)
        y1 = _clampi(int(round(y0 + dy1 * sample_px)), 0, H-1)

        colours.append(hit_colour)
        rays.append(((int(x0), int(y0)), (x1, y1), float(theta)))
        hit_class_ids.append(hit_cls)
        hit_distances_px.append(hit_dist_px)

    return colours, rays, hit_class_ids, hit_distances_px

# -----------------------------------------------------------------------

# =======================
# Frame post-processing
# =======================
def process_frame_post(frame_bgr, yolo_res, jake_point):
    """
    Returns (…)
      tri_best_xy, tri_count, mask_count, to_cpu_ms, post_ms,
      masks_np, classes_np, rail_mask, green_mask,
      tri_positions, tri_colours, tri_rays,
      best_idx, best_deg, x_ref,
      tri_hit_classes, tri_summary
    """
    H, W = frame_bgr.shape[:2]
    if yolo_res.masks is None:
        return (None, 0, 0, 0.0, 0.0, None, None, None, None,
                [], [], [], None, None, None, [], [])

    t0 = time.perf_counter()
    masks_np = yolo_res.masks.data.detach().cpu().numpy()  # [n,h,w]
    if hasattr(yolo_res.masks, "cls") and yolo_res.masks.cls is not None:
        classes_np = yolo_res.masks.cls.detach().cpu().numpy().astype(int)
    else:
        classes_np = yolo_res.boxes.cls.detach().cpu().numpy().astype(int)

    to_cpu_ms = (time.perf_counter() - t0) * 1000.0
    mask_count = int(masks_np.shape[0])
    if mask_count == 0 or classes_np.size == 0:
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    classes_np = promote_lowbarrier_when_wall(frame_bgr, masks_np, classes_np)

    rail_sel = (classes_np == RAIL_ID)
    if not np.any(rail_sel):
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np, None, None,
                [], [], [], None, None, None, [], [])

    t1 = time.perf_counter()
    rail_masks = masks_np[rail_sel].astype(bool, copy=False)
    union = np.any(rail_masks, axis=0).astype(np.uint8, copy=False)
    rail_mask = cv2.resize(union, (W, H), interpolation=cv2.INTER_NEAREST).astype(bool, copy=False)

    green = highlight_rails_mask_only_fast(frame_bgr, rail_mask)
    red   = np.logical_and(rail_mask, np.logical_not(green))
    score = red_vs_green_score(red, green)
    tri_positions, tri_best = purple_triangles(score, H)

    # Jake triangle by bearing
    lane_name = lane_name_from_point(jake_point)
    target_deg = LANE_TARGET_DEG[lane_name]
    xj, yj = jake_point

    MIN_AHEAD_FROM_JAKE_PX = 120  # tune (e.g., 80–160)
    tri_positions = [p for p in tri_positions if (yj - p[1]) >= MIN_AHEAD_FROM_JAKE_PX]

    # Filter triangles by absolute angle from vertical (≤ 45°) and at least 6px above Jake
    ANGLE_MAX_DEG = 45.0
    MIN_DY_ABOVE  = 100

    def _angle_ok(p):
        xt, yt = p
        dy = yt - yj
        if dy >= -MIN_DY_ABOVE:     # must be above Jake
            return False
        deg = signed_degrees_from_vertical(xt - xj, dy)
        return abs(deg) <= ANGLE_MAX_DEG

    tri_positions = [p for p in tri_positions if _angle_ok(p)]

    best_idx, best_deg, _ = select_triangle_by_bearing(tri_positions, xj, yj, target_deg, min_dy=6)

    # x_ref for bending
    if lane_name == "mid" and (best_idx is not None) and (0 <= best_idx < len(tri_positions)):
        x_ref = tri_positions[best_idx][0]
    else:
        x_ref = xj

    tri_colours, tri_rays, tri_hit_classes, tri_hit_dists = classify_triangles_at_sample_curved(
        tri_positions, masks_np, classes_np, H, W, jake_point, x_ref, best_idx,
        SAMPLE_UP_PX, RAY_STEP_PX
    )

    if tri_positions and any(ty >= (H) for _, ty in tri_positions):
    # compute rail_grad/edge_dist and run the loop

        # --- edge-danger override: triangles too close to rail edges in bottom half ---
        EDGE_PAD_PX = 50

        # distance-to-rail-edge map (in pixels)
        rail_grad = cv2.morphologyEx(rail_mask.astype(np.uint8), cv2.MORPH_GRADIENT,
                                    cv2.getStructuringElement(cv2.MORPH_RECT, (3,3)))
        # distanceTransform gives distance to the nearest 0; make edges=0, elsewhere=255
        edge_bg = (rail_grad == 0).astype(np.uint8) * 255
        edge_dist = cv2.distanceTransform(edge_bg, cv2.DIST_L2, 5)

        # mark triangles as danger if they're within EDGE_PAD_PX of an edge
        # AND they sit in the bottom half of the screen (y >= H//2)
        for i, (tx, ty) in enumerate(tri_positions):
            if ty >= (H // 2) and edge_dist[int(ty), int(tx)] <= EDGE_PAD_PX:
                tri_colours[i]      = COLOR_RED          # show as red in the overlay
                tri_hit_classes[i]  = 1                  # any class in DANGER_RED; 1=GREYTRAIN works
                tri_hit_dists[i]    = 0.0                # optional: treat as immediate


    post_ms = (time.perf_counter() - t1) * 1000.0

    # Minimal movement-friendly summary (pos, hit_class id/label, is_jake)
    tri_summary = []
    for i, (x, y) in enumerate(tri_positions):
        cid = tri_hit_classes[i] if i < len(tri_hit_classes) else None
        hdist = tri_hit_dists[i] if i < len(tri_hit_dists) else None
        tri_summary.append({
            "pos": (int(x), int(y)),
            "hit_class": None if cid is None else int(cid),
            "hit_label": None if cid is None else LABELS.get(int(cid), f"C{int(cid)}"),
            "hit_dist_px": None if hdist is None else float(hdist),
            "is_jake": (i == best_idx)
        })


    #PATHING LOGIC HERE# =================================================================================================================================================================
    # ===== PATHING / ACTION LOGIC =================================================
    jake_tri = next((t for t in tri_summary if t.get("is_jake")), None)
    if jake_tri:
        jx, jy = jake_tri["pos"]
        jake_hit = jake_tri["hit_class"]

        # For movement logging: distance to the obstacle ahead of Jake

        y_hit_log = first_red_hit_y(jake_tri["pos"], masks_np, classes_np, H, W, band_px=6, step_px=5)
        obstacle_dist_px = (jy - y_hit_log) if y_hit_log is not None else None

        # "Yellow in Jake's lane" == Jake's own triangle has a WARN_YELLOW class.
        jake_cls = jake_tri.get("hit_class", None)
        if (jake_cls is not None) and (int(jake_cls) in WARN_YELLOW) and (int(jake_cls) in IMPACT_CLASSES):
            # theta actually used for Jake’s ray (matches overlay)
            theta_deg = float(tri_rays[best_idx][2]) if (best_idx is not None and 0 <= best_idx < len(tri_rays)) else 0.0
            allowed_set = JUMP_SET if int(jake_cls) in JUMP_SET else DUCK_SET

            dist_px, _, _ = first_mask_hit_starburst_then_ray_for_set(
                jake_point=JAKE_POINT,
                tri_pos=jake_tri["pos"],
                theta_deg=theta_deg,
                masks_np=masks_np, classes_np=classes_np, H=H, W=W,
                allowed_classes=allowed_set,
                up2_px=SAMPLE_UP_PX, step_px=2
            )

            # ---- token: only lane + class ----
            new_token = (lane, int(jake_cls))

            global IMPACT_TOKEN
            # Only arm if no timer for this token yet
            if IMPACT_TOKEN is None:
                if dist_px is not None and (IMPACT_MIN_PX < dist_px < IMPACT_MAX_PX):
                    _arm_impact_timer(float(dist_px), int(jake_cls))
                    IMPACT_TOKEN = new_token
                    print(f"[TIMER] lock token {IMPACT_TOKEN}")
                # else: don’t arm; wait for next frame when it enters the window

            else:
                if IMPACT_TOKEN == new_token:
                    # If the previous timer already fired (no longer alive), allow immediate re-arm
                    if not (IMPACT_TIMER and getattr(IMPACT_TIMER, "is_alive", lambda: False)()):
                        _arm_impact_timer(float(dist_px), int(jake_cls))
                        IMPACT_TOKEN = new_token
                    # else: keep the existing live timer

                    # Same situation → do nothing (no cancel, no re-arm), even if dist jitters/out of window
                    pass
                else:
                    # Situation changed (lane or class) → cancel old and arm once for new (if in window)
                    _cancel_impact_timer("token change")
                    if dist_px is not None and (IMPACT_MIN_PX < dist_px < IMPACT_MAX_PX):
                        _arm_impact_timer(float(dist_px), int(jake_cls))
                        IMPACT_TOKEN = new_token
                        print(f"[TIMER] lock token {IMPACT_TOKEN} (replaced)")
                    else:
                        IMPACT_TOKEN = None  # no valid new timer yet
        else:
            # Jake’s triangle not yellow/impact anymore → cancel & unlock
            if IMPACT_TOKEN is not None:
                _cancel_impact_timer("no longer impact in Jake lane")
                IMPACT_TOKEN = None


        # --- 2) Lateral pathing decisions (policy: GREEN first) --------------------
        # Build reusable candidate pools (excluding Jake's current triangle)
        greens  = [t for t in tri_summary if t["hit_class"] is None]
        yellows = [t for t in tri_summary if (t["hit_class"] is not None and int(t["hit_class"]) in WARN_FOR_MOVE)]
        reds    = [t for t in tri_summary if (t["hit_class"] is not None and int(t["hit_class"]) in DANGER_RED)]

        # Lane-based pruning
        greens  = _filter_by_lane(greens,  jx, lane)
        yellows = _filter_by_lane(yellows, jx, lane)
        reds    = _filter_by_lane(reds,    jx, lane)

        # Only consider yellow if it's far enough ahead of the Jake band (e.g., 400px)
        jake_band_y   = jake_point[1]  # 1340 with your lane points
        yellows_far   = _filter_yellow_far(yellows, jake_band_y)  # uses MIN_YELLOW_AHEAD_PX
        greens_far  = _filter_green_far(greens, jake_band_y)


        def _nearest_x(cands):
            return min(cands, key=lambda c: abs(c["pos"][0] - jx)) if cands else None

        # Score for "least-bad red": prefer the ray that hits red furthest away.
        # If 'hit_dist_px' isn't present in tri_summary, fall back to apex distance.
        jake_band_y = jake_point[1]

        def _red_score(c):
            y_hit = first_red_hit_y(c["pos"], masks_np, classes_np, H, W, band_px=6, step_px=5)
            if y_hit is None:
                return (float('inf'), -abs(c["pos"][0] - jx))  # no red in range = strictly better
            ahead_px = jake_band_y - y_hit  # larger = farther ahead
            return (ahead_px, -abs(c["pos"][0] - jx))


        # If Jake is already GREEN, stay put.
        if jake_hit is None:
            pass

        # RED ahead: GREEN -> (far) YELLOW -> least-bad RED (all-red fallback)
        elif _is_danger(jake_hit):
            tgt = _nearest_x(greens)
            if tgt is not None:
                if MOVEMENT_ENABLED:
                    print(f"[MOVE] RED ahead → GREEN: obstacle={LABELS.get(int(jake_hit), str(jake_hit))}, dist={_fmt_px(obstacle_dist_px)}; target_x={tgt['pos'][0]}")
                _issue_move_towards_x(jx, tgt["pos"][0])
            else:
                tgt = _nearest_x(yellows_far)   # only yellows ≥ threshold above the band
                if tgt is not None:
                    if MOVEMENT_ENABLED:
                        ahead_px = jake_band_y - tgt["pos"][1]
                        print(f"[MOVE] RED ahead → YELLOW (far): obstacle={LABELS.get(int(jake_hit), str(jake_hit))}, dist={_fmt_px(obstacle_dist_px)}; yellow_ahead={int(ahead_px)}px (≥{MIN_YELLOW_AHEAD_PX})")
                    _issue_move_towards_x(jx, tgt["pos"][0])
                else:
                    if reds:
                        # When choosing best_red:
                        best_red = max(reds, key=_red_score)
                        tx = best_red["pos"][0]
                        if tx != jx:                      # avoid needless move if staying is best
                            _issue_move_towards_x(jx, tx)

                    # else: boxed in → no lateral move this frame

        # YELLOW ahead: try GREEN; if none, rely on countermeasures (jump/duck)
        elif _is_warn(jake_hit):
            tgt = _nearest_x(greens_far)  # only consider far-enough greens
            if tgt is not None:
                _issue_move_towards_x(jx, tgt["pos"][0])
    # else: no safe far green → hold lane; jump/duck handled above

            # else: no green → hold lane; jumps/ducks already handled above
# ============================================================================

# ============================================================================
#END
# ============================================================================

    return (tri_best, len(tri_positions), mask_count, to_cpu_ms, post_ms,
            masks_np, classes_np, rail_mask, green,
            tri_positions, tri_colours, tri_rays,
            best_idx, best_deg, x_ref,
            tri_hit_classes, tri_summary)

# =======================
# Viz helpers
# =======================
def _colour_for_point(x, y, masks_np, classes_np, H, W):
    if masks_np is None or classes_np is None or masks_np.size == 0: return COLOR_GREEN
    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1)); sy = (mh - 1) / max(1, (H - 1))
    mx = _clampi(int(round(x * sx)), 0, mw-1)
    my = _clampi(int(round(y * sy)), 0, mh-1)
    cls_here = None
    for m, c in zip(masks_np, classes_np):
        if m[my, mx] > 0.5: cls_here = int(c); break
    if cls_here in DANGER_RED:   return COLOR_RED
    if cls_here in WARN_YELLOW:  return COLOR_YELLOW
    if cls_here in BOOTS_PINK:   return COLOR_PINK
    return COLOR_GREEN

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=COLOR_RED):
    h = int(size * 1.2)
    pts = np.array([[x, y], [x-size, y+h], [x+size, y+h]], np.int32)
    cv2.fillConvexPoly(img, pts, colour)
    cv2.polylines(img, [pts.reshape(-1,1,2)], True, COLOR_BLACK, 1, cv2.LINE_AA)

def triangle_pts(x, y, size=TRI_SIZE_PX):
    h = int(size * 1.2)
    return np.array([[x, y], [x-size, y+h], [x+size, y+h]], np.int32)

def render_overlays(frame_bgr, masks_np, classes_np, rail_mask, green_mask,
                    tri_positions, tri_colours, tri_rays, best_idx, best_deg, x_ref, jake_point):
    out = frame_bgr.copy()
    H, W = out.shape[:2]
    alpha = 0.45

    if masks_np is not None and classes_np is not None and masks_np.size:
        for m, c in zip(masks_np, classes_np):
            m_full = m
            if m.shape != (H, W):
                m_full = cv2.resize(m.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
            color = CLASS_COLOURS.get(int(c), (255,255,255))
            out[m_full] = (np.array(color, dtype=np.uint8) * alpha + out[m_full] * (1 - alpha)).astype(np.uint8)
            ys, xs = np.where(m_full)
            if xs.size:
                xc, yc = int(xs.mean()), int(ys.mean())
                label = LABELS.get(int(c), f"C{int(c)}")
                cv2.putText(out, label, (max(5, xc-40), max(20, yc)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_BLACK, 2, cv2.LINE_AA)
                cv2.putText(out, label, (max(5, xc-40), max(20, yc)),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

    if rail_mask is not None:
        tint = out.copy()
        tint[rail_mask] = (0, 0, 255)
        out = cv2.addWeighted(tint, 0.30, out, 0.70, 0)
    if green_mask is not None:
        out[green_mask] = (0, 255, 0)

    # tiny scout lines (viz only)
    if tri_positions:
        for (x, y) in tri_positions:
            y_end = max(0, y - SAMPLE_UP_PX)
            for yy in range(y, y_end - 1, -1):
                out[yy, x] = _colour_for_point(x, yy, masks_np, classes_np, H, W)

    # starburst to Jake
    xj, yj = jake_point
    for idx, (xt, yt) in enumerate(tri_positions):
        xt = _clampi(int(xt), 0, W-1); yt = _clampi(int(yt), 0, H-1)
        deg_signed = signed_degrees_from_vertical(xt - xj, yt - yj)
        cv2.line(out, (xj, yj), (xt, yt),
                 COLOR_CYAN if idx == best_idx else COLOR_WHITE, 2, cv2.LINE_AA)
        mx = (xj + xt) // 2; my = (yj + yt) // 2

    # curved sampling rays (viz)
    for (p0, p1, theta) in tri_rays:
        cv2.line(out, p0, p1, (255,255,255), 2, cv2.LINE_AA)
        mx = (p0[0] + p1[0]) // 2; my = (p0[1] + p1[1]) // 2

    for (x, y), col in zip(tri_positions, tri_colours):
        draw_triangle(out, int(x), int(y), colour=col)

    lane_name = lane_name_from_point(jake_point)
    target_deg = LANE_TARGET_DEG[lane_name]
    if best_idx is not None and 0 <= best_idx < len(tri_positions):
        xt, yt = tri_positions[best_idx]

        # theta used for that triangle in classify_triangles_at_sample_curved
        theta_deg = tri_rays[best_idx][2] if best_idx < len(tri_rays) else 0.0

        dist_px, hit_xy, hit_cls = first_mask_hit_starburst_then_ray(
            jake_point=jake_point,
            tri_pos=(int(xt), int(yt)),
            theta_deg=float(theta_deg),
            masks_np=masks_np, classes_np=classes_np, H=H, W=W,
            up2_px=SAMPLE_UP_PX, step_px=2,
            exclude_classes=(RAIL_ID,),   # skip rails
            danger_only=False             # set True to only consider DANGER_RED
        )

        # draw starburst segment (cyan + thicker)
        xj, yj = jake_point
        cv2.line(out, (xj, yj), (int(xt), int(yt)), COLOR_CYAN, 3, cv2.LINE_AA)

        # draw the angled continuation for viz (cyan + thicker)
        dxr, dyr = TRIG_TABLE.get(float(theta_deg), (0.0, -1.0))
        xe = _clampi(int(round(xt + dxr * SAMPLE_UP_PX)), 0, W-1)
        ye = _clampi(int(round(yt + dyr * SAMPLE_UP_PX)), 0, H-1)
        cv2.line(out, (int(xt), int(yt)), (xe, ye), COLOR_CYAN, 3, cv2.LINE_AA)

        # OPTIONAL: cyan outline around the best triangle so it pops
        cv2.polylines(out, [triangle_pts(int(xt), int(yt)).reshape(-1,1,2)], True, COLOR_CYAN, 3, cv2.LINE_AA)


        if hit_xy is not None:
            cv2.circle(out, hit_xy, 6, (0, 0, 0), -1, cv2.LINE_AA)
            cv2.circle(out, hit_xy, 4, (255, 255, 255), -1, cv2.LINE_AA)

        dist_text = "∞" if dist_px is None else f"{dist_px:.1f}px"
        if hit_cls is not None:
            dist_text += f" → {LABELS.get(hit_cls, str(hit_cls))}"
        midx = (xj + int(xt)) // 2
        midy = (yj + int(yt)) // 2 - 10
        cv2.putText(out, f"Jake→tri (ray) first-hit: {dist_text}",
                    (max(5, midx - 160), max(24, midy)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, COLOR_BLACK, 2, cv2.LINE_AA)
        cv2.putText(out, f"Jake→tri (ray) first-hit: {dist_text}",
                    (max(5, midx - 160), max(24, midy)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)
        
        # TIME KEEPING

        elapsed_ms = (time.perf_counter() - start_time) * 1000.0
        time_str = f"{elapsed_ms:.3f} ms"

        # position in bottom-right corner
        (text_w, text_h), _ = cv2.getTextSize(time_str, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 1)
        x_pos = out.shape[1] - text_w - 10  # 10px from right edge
        y_pos = out.shape[0] - 10           # 10px from bottom edge

        # draw text
        cv2.putText(out, time_str, (x_pos, y_pos),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
        cv2.putText(out, time_str, (x_pos, y_pos),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

        # -------------------------------------------------------------------------------


    # top-left JAKE_POINT state
    cv2.putText(out, f"JAKE_POINT: {lane_name.upper()}",
                (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,255,0), 2, cv2.LINE_AA)

    return out

# =======================
# Live loop
# =======================   

listener = keyboard.Listener(on_press=on_press)
listener.start()

sct = mss()
frame_idx = 0

from mss import mss

if advertisement:
    CHECK_X, CHECK_Y = 1030, 900
else:
    CHECK_X, CHECK_Y = 870, 895

#===========================================Resource consuption monitoring===========================================

import psutil
import os
import subprocess, threading, re
import subprocess
import threading
import re

process = psutil.Process(os.getpid())

def print_system_usage():
    cpu_percent = psutil.cpu_percent(interval=None)
    mem_info = process.memory_info()
    rss_mb = mem_info.rss / (1024 ** 2)  # Resident memory in MB
    print(f"[SYS] CPU: {cpu_percent:.1f}%  |  RAM: {rss_mb:.1f} MB")

    import subprocess, threading, re

def stream_mps_gpu_stats():
    # requires sudo; run your script with:  sudo python your_script.py
    cmd = ["powermetrics", "--samplers", "gpu_power", "-i", "200"]
    p = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1)
    busy_re = re.compile(r"GPU Busy\s*=\s*(\d+)%")
    power_re = re.compile(r"GPU Power\s*=\s*([\d\.]+)\s*W")
    for line in p.stdout:
        m1 = busy_re.search(line); m2 = power_re.search(line)
        if m1 or m2:
            busy = m1.group(1) if m1 else "?"
            power = m2.group(1) if m2 else "?"
            print(f"[GPU] Busy {busy}% | Power {power} W")

# call once before your while-loop:
threading.Thread(target=stream_mps_gpu_stats, daemon=True).start()

#=====================================================================================================================

# =======================

save_frames = False
power_metrics = True

# =======================

while running:
    frame_start_time = time.perf_counter()

    _now = time.monotonic()
    if _now < PAUSE_UNTIL:
        time.sleep(PAUSE_UNTIL - _now)
        continue

    frame_idx += 1
    print('/n')
    print(f'===================================== Operating on frame {frame_idx} =====================================')

    # --- Screen grab ---
    t0_grab = time.perf_counter()
    left, top, width, height = snap_coords
    raw = sct.grab({"left": left, "top": top, "width": width, "height": height})
    frame_bgr = np.array(raw)[:, :, :3]  # BGRA -> BGR
    grab_ms = (time.perf_counter() - t0_grab) * 1000.0

    # --- ABSOLUTE SCREEN pixel check ---
    t0_check = time.perf_counter()
    arr = np.array(sct.grab({"left": CHECK_X, "top": CHECK_Y, "width": 1, "height": 1}))
    b, g, r, a = arr[0, 0]
    TOL = 20
    target = (61, 156, 93)

    pixel_check_ms = (time.perf_counter() - t0_check) * 1000.0

    # --- Lane detection ---
    t0_lane = time.perf_counter()
    _detected = _detect_lane_by_whiteness(frame_bgr)
    if _detected is not None:
        lane = _detected  # 0/1/2
    JAKE_POINT = LANE_POINTS[lane]
    lane_ms = (time.perf_counter() - t0_lane) * 1000.0

    # --- Inference ---
    t0_inf = time.perf_counter()
    res_list = model.predict(
        [frame_bgr], task="segment", imgsz=IMG_SIZE, device=device,
        conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET, batch=1
    )
    infer_ms = (time.perf_counter() - t0_inf) * 1000.0
    yres = res_list[0]

    # --- Postproc ---
    t0_post = time.perf_counter()
    (tri_best_xy, tri_count, mask_count, to_cpu_ms, post_ms,
     masks_np, classes_np, rail_mask, green_mask, tri_positions, tri_colours,
     tri_rays, best_idx, best_deg, x_ref,
     tri_hit_classes, tri_summary) = process_frame_post(frame_bgr, yres, JAKE_POINT)
    postproc_ms = (time.perf_counter() - t0_post) * 1000.0

    total_proc_ms = grab_ms + pixel_check_ms + lane_ms + infer_ms + postproc_ms

    if save_frames:
        elapsed_no_post = time.perf_counter() - frame_start_time
        print(f"Frame {frame_idx} WITHOUT POSTPROC WAS: {elapsed_no_post * 1000:.2f} ms")

    if save_frames:
        t0_overlay = time.perf_counter()
        overlay = render_overlays(frame_bgr, masks_np, classes_np, rail_mask, green_mask,
                                  tri_positions, tri_colours, tri_rays, best_idx, best_deg, x_ref, JAKE_POINT)
        out_path = out_dir / f"live_overlay_{frame_idx:05d}.jpg"
        cv2.imwrite(str(out_path), overlay)
        overlay_ms = (time.perf_counter() - t0_overlay) * 1000.0
    else:
        overlay_ms = 0.0

    total_elapsed_ms = (time.perf_counter() - frame_start_time) * 1000.0

    if power_metrics:
        print_system_usage()


    # --- Timing summary ---
    print(
        f"[TIMINGS] Grab={grab_ms:.2f} ms | PixelChk={pixel_check_ms:.2f} ms | "
        f"LaneDet={lane_ms:.2f} ms | Inference={infer_ms:.2f} ms | "
        f"Postproc={postproc_ms:.2f} ms | Overlay={overlay_ms:.2f} ms | "
        f"TOTAL={total_elapsed_ms:.2f} ms"
    )

# Cleanup
listener.join()
print("Script halted.")


YOLO11n-seg summary (fused): 113 layers, 2,836,908 parameters, 0 gradients, 10.2 GFLOPs
[BOOT] movement muted
/n
[SYS] CPU: 23.6%  |  RAM: 480.3 MB
[TIMINGS] Grab=82.75 ms | PixelChk=28.88 ms | LaneDet=0.24 ms | Inference=657.43 ms | Postproc=165.52 ms | Overlay=0.00 ms | TOTAL=937.34 ms
/n
[SYS] CPU: 41.2%  |  RAM: 556.2 MB
[TIMINGS] Grab=50.97 ms | PixelChk=29.02 ms | LaneDet=0.08 ms | Inference=42.60 ms | Postproc=125.46 ms | Overlay=0.00 ms | TOTAL=248.30 ms
/n
[BOOT] movement unmuted
[SYS] CPU: 52.5%  |  RAM: 528.9 MB
[TIMINGS] Grab=36.85 ms | PixelChk=24.61 ms | LaneDet=0.04 ms | Inference=38.20 ms | Postproc=128.72 ms | Overlay=0.00 ms | TOTAL=228.86 ms
/n
[SYS] CPU: 44.1%  |  RAM: 544.8 MB
[TIMINGS] Grab=23.32 ms | PixelChk=16.72 ms | LaneDet=0.05 ms | Inference=65.48 ms | Postproc=135.21 ms | Overlay=0.00 ms | TOTAL=240.92 ms
/n
[SYS] CPU: 35.6%  |  RAM: 570.9 MB
[TIMINGS] Grab=30.75 ms | PixelChk=19.81 ms | LaneDet=0.05 ms | Inference=38.06 ms | Postproc=130.32 ms | Overlay=0

In [1]:
sudo powermetrics --samplers gpu_power -i 200

SyntaxError: invalid syntax (1559998765.py, line 1)

In [10]:
# fast_capture.py
# High-FPS, low-latency screenshot capture (capture-only).
# - Crops at source (mss region) to avoid post-crop cost
# - Two-slot ring buffer (always read the freshest frame)
# - Zero-copy NumPy view over mss bytes (BGRA) to minimize overhead
# - Esc to quit (demo mode)
#
# Usage (as module):
#   grabber = FastGrabber(left=100, top=100, width=640, height=360)
#   ts, frame_bgra = grabber.latest()   # Non-blocking; returns None until first frame
#   ...
#   grabber.stop()
#
# Run file directly to see a minimal FPS print demo (no UI).

import time
import threading
from collections import deque
from typing import Optional, Tuple

import numpy as np
try:
    import mss  # pip install mss
except Exception as e:
    raise SystemExit("Please `pip install mss` first") from e


class FastGrabber:
    """
    Fast screen grabber that continuously captures a cropped region and exposes the latest frame.

    Frame format: BGRA uint8, shape = (H, W, 4) as a zero-copy NumPy view over mss bytes.
    NOTE: The returned array is READ-ONLY (backed by an immutable bytes buffer). Copy only if needed.
    """

    def __init__(self, left: int, top: int, width: int, height: int, maxlen: int = 2, throttle_hz: Optional[int] = None):
        """
        :param left, top, width, height: capture region (pixels) — crop AT SOURCE for speed
        :param maxlen: ring buffer length (2 recommended)
        :param throttle_hz: optional soft cap on capture rate; None for 'as fast as possible'
        """
        self.region = {"left": int(left), "top": int(top), "width": int(width), "height": int(height)}
        self._sct = mss.mss()
        self._buf = deque(maxlen=maxlen)  # (ts, np.ndarray BGRA view)
        self._running = threading.Event()
        self._running.set()
        self._throttle_dt = (1.0 / throttle_hz) if throttle_hz else 0.0

        self._thr = threading.Thread(target=self._loop, name="FastGrabber", daemon=True)
        self._thr.start()

    def _grab_once(self) -> Tuple[float, np.ndarray]:
        ts = time.perf_counter()
        shot = self._sct.grab(self.region)  # Cropped capture at source
        # Zero-copy view over raw bytes (BGRA)
        arr = np.frombuffer(shot.raw, dtype=np.uint8)
        arr = arr.reshape((shot.height, shot.width, 4))
        return ts, arr  # BGRA view (read-only)

    def _loop(self):
        next_deadline = 0.0
        while self._running.is_set():
            ts, frame = self._grab_once()
            self._buf.append((ts, frame))
            if self._throttle_dt > 0.0:
                next_deadline += self._throttle_dt
                sleep_s = max(0.0, next_deadline - time.perf_counter())
                if sleep_s:
                    time.sleep(sleep_s)

    def latest(self) -> Optional[Tuple[float, np.ndarray]]:
        """Return the most recent (ts, frame_bgra) or None if not yet available. Non-blocking."""
        return self._buf[-1] if self._buf else None

    def stop(self):
        self._running.clear()
        if self._thr.is_alive():
            self._thr.join(timeout=1.0)
        try:
            self._sct.close()
        except Exception:
            pass


# ---------- Minimal demo (no display) ----------
if __name__ == "__main__":
    import sys

    # Example region (edit to your game area)
    LEFT, TOP, WIDTH, HEIGHT = 0, 0, 1010, 1812


    grabber = FastGrabber(LEFT, TOP, WIDTH, HEIGHT, throttle_hz=None)  # as fast as possible

    print("Capturing... (press Esc to quit)")
    # Simple Esc listener without extra deps
    try:
        import termios, tty, select
        def kbhit():
            dr, _, _ = select.select([sys.stdin], [], [], 0)
            return bool(dr)
        old = termios.tcgetattr(sys.stdin)
        tty.setcbreak(sys.stdin.fileno())
        esc_pressed = False
    except Exception:
        # Fallback: Ctrl+C to quit on platforms without termios (e.g., Windows)
        old = None
        esc_pressed = False

    frames = 0
    t0 = time.perf_counter()
    last_print = t0

    try:
        while True:
            item = grabber.latest()
            if item is not None:
                frames += 1
            now = time.perf_counter()
            if now - last_print >= 1.0:
                fps = frames / (now - t0)
                print(f"[FastGrabber] ~{fps:5.1f} FPS   (region {WIDTH}x{HEIGHT})")
                last_print = now

            # Esc to exit (POSIX)
            if old is not None and kbhit():
                ch = sys.stdin.read(1)
                if ch == '\x1b':  # ESC
                    esc_pressed = True
            if esc_pressed:
                break

            # tiny yield to keep the loop responsive without throttling capture thread
            time.sleep(0.001)
    except KeyboardInterrupt:
        pass
    finally:
        if old is not None:
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old)
        grabber.stop()
        print("Stopped.")


Capturing... (press Esc to quit)
[FastGrabber] ~529.6 FPS   (region 1010x1812)
[FastGrabber] ~555.1 FPS   (region 1010x1812)
[FastGrabber] ~567.5 FPS   (region 1010x1812)
[FastGrabber] ~570.5 FPS   (region 1010x1812)
Stopped.


In [13]:
#!/usr/bin/env python3
# fast_capture.py
# High-FPS, low-latency screenshot capture (capture-only) with real capture FPS
# and turnaround (age) stats.
#
# - Crops at source (mss region)
# - Two-slot ring buffer (always read freshest frame; drops stale ones)
# - Zero-copy NumPy view over mss bytes (BGRA)
# - Prints true CAPTURE FPS and TURNAROUND median/p95/max once per second
# - Esc to quit (POSIX & Windows), or Ctrl+C
#
# Usage as module:
#   grabber = FastGrabber(left=100, top=100, width=640, height=360)
#   item = grabber.latest_with_age()  # -> (ts, frame_bgra, age_ms) or None
#   ...
#   grabber.stop()

import time
import threading
import statistics
from collections import deque
from typing import Optional, Tuple

import numpy as np
try:
    import mss  # pip install mss
except Exception as e:
    raise SystemExit("Please `pip install mss` first") from e


class FastGrabber:
    """
    Fast screen grabber that continuously captures a cropped region and exposes the latest frame.

    Frame format: BGRA uint8, shape = (H, W, 4) as a zero-copy NumPy view over mss bytes.
    NOTE: The returned array is READ-ONLY (backed by an immutable bytes buffer). Copy only if needed.
    """

    def __init__(
        self,
        left: int,
        top: int,
        width: int,
        height: int,
        maxlen: int = 2,
        throttle_hz: Optional[int] = None,
    ):
        """
        :param left, top, width, height: capture region (pixels) — crop AT SOURCE for speed
        :param maxlen: ring buffer length (2 recommended)
        :param throttle_hz: optional soft cap on capture rate; None = as fast as possible
        """
        self.region = {"left": int(left), "top": int(top), "width": int(width), "height": int(height)}
        self._sct = mss.mss()
        self._buf = deque(maxlen=maxlen)  # (ts, np.ndarray BGRA view)
        self._running = threading.Event()
        self._running.set()
        self._throttle_dt = (1.0 / throttle_hz) if throttle_hz else 0.0

        # stats
        self._cap_count = 0                # total captured frames
        self._t0 = time.perf_counter()     # start time for stats

        self._thr = threading.Thread(target=self._loop, name="FastGrabber", daemon=True)
        self._thr.start()

    def _grab_once(self) -> Tuple[float, np.ndarray]:
        ts = time.perf_counter()
        shot = self._sct.grab(self.region)  # Cropped capture at source
        # Zero-copy view over raw bytes (BGRA)
        arr = np.frombuffer(shot.raw, dtype=np.uint8).reshape((shot.height, shot.width, 4))
        return ts, arr  # BGRA view (read-only)

    def _loop(self):
        next_deadline = 0.0
        while self._running.is_set():
            ts, frame = self._grab_once()
            self._buf.append((ts, frame))
            self._cap_count += 1
            if self._throttle_dt > 0.0:
                next_deadline += self._throttle_dt
                sleep_s = max(0.0, next_deadline - time.perf_counter())
                if sleep_s:
                    time.sleep(sleep_s)

    def latest(self) -> Optional[Tuple[float, np.ndarray]]:
        """Return the most recent (ts, frame_bgra) or None if not yet available. Non-blocking."""
        return self._buf[-1] if self._buf else None

    def latest_with_age(self) -> Optional[Tuple[float, np.ndarray, float]]:
        """Return the most recent (ts, frame_bgra, age_ms) or None if not yet available."""
        item = self.latest()
        if not item:
            return None
        ts, frame = item
        age_ms = (time.perf_counter() - ts) * 1000.0
        return ts, frame, age_ms

    def capture_stats(self) -> Tuple[int, float, float]:
        """Return (total_captures, elapsed_s, avg_fps_since_start)."""
        elapsed = max(1e-9, time.perf_counter() - self._t0)
        return self._cap_count, elapsed, self._cap_count / elapsed

    def stop(self):
        self._running.clear()
        if self._thr.is_alive():
            self._thr.join(timeout=1.0)
        try:
            self._sct.close()
        except Exception:
            pass


# ---------- Minimal demo ----------
def _posix_esc_support():
    """Return (kbhit, readch, cleanup) functions for POSIX, or (None, None, None) on failure."""
    try:
        import sys, termios, tty, select
        def kbhit():
            dr, _, _ = select.select([sys.stdin], [], [], 0)
            return bool(dr)
        def readch():
            return sys.stdin.read(1)
        old = termios.tcgetattr(sys.stdin)
        tty.setcbreak(sys.stdin.fileno())
        def cleanup():
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old)
        return kbhit, readch, cleanup
    except Exception:
        return None, None, None


def _windows_esc_support():
    """Return (kbhit, readch, cleanup) functions for Windows, or (None, None, None) on failure."""
    try:
        import msvcrt
        def kbhit():
            return msvcrt.kbhit()
        def readch():
            return msvcrt.getch().decode(errors='ignore')
        def cleanup():
            return
        return kbhit, readch, cleanup
    except Exception:
        return None, None, None


if __name__ == "__main__":
    # Example region (edit to your game area)
    LEFT, TOP, WIDTH, HEIGHT = 0, 0, 1010, 1812

    grabber = FastGrabber(LEFT, TOP, WIDTH, HEIGHT, throttle_hz=None)  # as fast as possible
    print("Capturing... (press Esc to quit)")

    # Cross-platform Esc support
    kbhit, readch, cleanup = _posix_esc_support()
    if kbhit is None:
        kbhit, readch, cleanup = _windows_esc_support()
    esc_pressed = False

    # For stats
    last_seen_ts = None
    ages_ms: list[float] = []
    last_tick = time.perf_counter()
    last_cap_count = 0

    try:
        while True:
            info = grabber.latest_with_age()
            if info:
                ts, frame_bgra, age_ms = info
                if ts != last_seen_ts:
                    last_seen_ts = ts
                    ages_ms.append(age_ms)

            now = time.perf_counter()
            if now - last_tick >= 1.0:
                # True capture FPS from producer thread
                total, _, _ = grabber.capture_stats()
                cap_fps_inst = (total - last_cap_count) / (now - last_tick)
                last_cap_count = total

                if ages_ms:
                    ages_sorted = sorted(ages_ms)
                    med = statistics.median(ages_sorted)
                    p95 = ages_sorted[int(0.95 * (len(ages_sorted) - 1))]
                    worst = ages_sorted[-1]
                else:
                    med = p95 = worst = float('nan')

                print(
                    f"[FastGrabber] CAP {cap_fps_inst:6.1f} fps | "
                    f"TURNAROUND  median {med:6.2f} ms  p95 {p95:6.2f} ms  max {worst:6.2f} ms "
                    f"(region {WIDTH}x{HEIGHT})"
                )

                ages_ms.clear()
                last_tick = now

            # Esc to exit
            if kbhit and kbhit():
                ch = readch()
                if ch == '\x1b':  # ESC
                    esc_pressed = True
            if esc_pressed:
                break

            # tiny yield to keep loop responsive
            time.sleep(0.001)

    except KeyboardInterrupt:
        pass
    finally:
        if cleanup:
            cleanup()
        grabber.stop()
        print("Stopped.")


Capturing... (press Esc to quit)
[FastGrabber] CAP   18.9 fps | TURNAROUND  median  47.11 ms  p95  77.43 ms  max  79.17 ms (region 1010x1812)
[FastGrabber] CAP   21.9 fps | TURNAROUND  median  44.00 ms  p95  48.08 ms  max  49.48 ms (region 1010x1812)
[FastGrabber] CAP   22.8 fps | TURNAROUND  median  43.08 ms  p95  51.73 ms  max  76.01 ms (region 1010x1812)
[FastGrabber] CAP   23.9 fps | TURNAROUND  median  42.68 ms  p95  45.88 ms  max  47.73 ms (region 1010x1812)
[FastGrabber] CAP   22.0 fps | TURNAROUND  median  43.79 ms  p95  53.80 ms  max  69.35 ms (region 1010x1812)
[FastGrabber] CAP   23.0 fps | TURNAROUND  median  42.67 ms  p95  46.67 ms  max  59.20 ms (region 1010x1812)
[FastGrabber] CAP   21.0 fps | TURNAROUND  median  42.16 ms  p95 105.07 ms  max 114.14 ms (region 1010x1812)
[FastGrabber] CAP   20.9 fps | TURNAROUND  median  42.31 ms  p95  72.46 ms  max  78.84 ms (region 1010x1812)
[FastGrabber] CAP   23.0 fps | TURNAROUND  median  43.21 ms  p95  50.51 ms  max  69.42 ms (regi

In [1]:
#!/usr/bin/env python3
# ffmpeg_capture.py
# High-FPS capture of a large ROI via ffmpeg/avfoundation → rawvideo pipe.
# Same public API style as your FastGrabber: latest_with_age(), capture_stats().
#
# Why faster? AVFoundation's screen pipeline is optimized & can output NV12 (1.5 B/px)
# which cuts bandwidth vs Quartz BGRA (4 B/px). We crop at source, and stream frames.

import os
import sys
import time
import threading
import subprocess
import statistics
from collections import deque
from typing import Optional, Tuple

import numpy as np
import cv2  # pip install opencv-python

class FfmpegGrabber:
    """
    Spawn ffmpeg to capture a screen region and stream frames via stdout as rawvideo.
    Works on macOS with -f avfoundation. Use BGRA (no convert) or NV12 (smaller bandwidth).

    latest_with_age() -> (ts, frame, age_ms)
      - if pix_fmt='bgra' and return_bgr=True: returns BGR uint8 (H,W,3)
      - if pix_fmt='bgra' and return_bgr=False: returns BGRA uint8 (H,W,4) view
      - if pix_fmt='nv12' and return_bgr=True: returns BGR converted from NV12
      - if pix_fmt='nv12' and return_bgr=False: returns the NV12 buffer (H*3//2, W) uint8
    """

    def __init__(
        self,
        left: int,
        top: int,
        width: int,
        height: int,
        framerate: int = 120,
        pix_fmt: str = "nv12",          # 'nv12' (fast) or 'bgra' (simple)
        input_device: str = "1:none",   # <<< CHANGE THIS for your screen (see notes below)
        return_bgr: bool = True,        # convert to BGR in Python (cost ~1–2 ms @ 1080p NV12)
        ring_len: int = 2,
    ):
        """
        input_device examples (macOS):
          - List devices:  ffmpeg -f avfoundation -list_devices true -i ""
          - Primary display is often "1:none" or "2:none", or "Capture screen 0"
          - You can also try: input_device='Capture screen 0'
        """
        self.left, self.top, self.width, self.height = map(int, (left, top, width, height))
        self.framerate = int(framerate)
        self.pix_fmt = pix_fmt.lower()
        assert self.pix_fmt in ("bgra", "nv12")
        self.input_device = input_device
        self.return_bgr = return_bgr

        # bytes per frame
        if self.pix_fmt == "bgra":
            self.bytes_per_frame = self.width * self.height * 4
        else:  # NV12: 1.5 bytes per pixel
            self.bytes_per_frame = self.width * self.height * 3 // 2

        self._buf = deque(maxlen=ring_len)   # (ts, frame_obj)
        self._running = threading.Event()
        self._running.set()
        self._cap_count = 0
        self._t0 = time.perf_counter()

        self.proc = self._start_ffmpeg()
        self._thr = threading.Thread(target=self._reader, name="FfmpegGrabber", daemon=True)
        self._thr.start()

    def _start_ffmpeg(self) -> subprocess.Popen:
        # Build ffmpeg command
        # -f avfoundation: macOS screen capture
        # -framerate: desired capture fps
        # -i "<device>": see __init__ docstring for finding the right string
        # -vf crop=WxH:x:y crops at source
        # -pix_fmt: 'bgra' or 'nv12' (NV12 drastically reduces bandwidth)
        # -f rawvideo - : write raw frames to stdout
        vf = f"crop={self.width}:{self.height}:{self.left}:{self.top}"
        cmd = [
            "ffmpeg",
            "-hide_banner", "-loglevel", "error",
            "-f", "avfoundation",
            "-framerate", str(self.framerate),
            "-i", self.input_device,
            "-vf", vf,
            "-pix_fmt", self.pix_fmt,
            "-an", "-sn",
            "-f", "rawvideo",
            "-"
        ]
        # If your ffmpeg expects a different device string style, try:
        #   self.input_device = "Capture screen 0"
        # or use the numeric index you see in -list_devices

        return subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            bufsize=0  # unbuffered
        )

    def _read_exact(self, n: int) -> bytes:
        """Read exactly n bytes from ffmpeg stdout."""
        data = bytearray(n)
        view = memoryview(data)
        read = 0
        while read < n and self._running.is_set():
            chunk = self.proc.stdout.read(n - read)
            if not chunk:
                break
            view[read:read + len(chunk)] = chunk
            read += len(chunk)
        if read != n:
            return b""
        return bytes(data)

    def _reader(self):
        while self._running.is_set():
            ts = time.perf_counter()
            raw = self._read_exact(self.bytes_per_frame)
            if not raw:
                # Try to detect process death
                if self.proc.poll() is not None:
                    break
                # Otherwise, continue; might be a transient read issue
                continue

            # Interpret the raw frame depending on pix_fmt
            if self.pix_fmt == "bgra":
                frame = np.frombuffer(raw, dtype=np.uint8).reshape(self.height, self.width, 4)
                if self.return_bgr:
                    # Drop alpha; this is just a view, copy if you'll mutate
                    frame_out = frame[..., :3]
                else:
                    frame_out = frame  # BGRA view
            else:
                # NV12: pack as (H*3/2, W) then cvtColor
                yuv = np.frombuffer(raw, dtype=np.uint8).reshape(self.height * 3 // 2, self.width)
                if self.return_bgr:
                    frame_out = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR_NV12)
                else:
                    frame_out = yuv  # NV12 buffer (H*3/2, W)

            self._buf.append((ts, frame_out))
            self._cap_count += 1

        self._running.clear()

    def latest(self) -> Optional[Tuple[float, np.ndarray]]:
        return self._buf[-1] if self._buf else None

    def latest_with_age(self) -> Optional[Tuple[float, np.ndarray, float]]:
        item = self.latest()
        if not item:
            return None
        ts, frame = item
        age_ms = (time.perf_counter() - ts) * 1000.0
        return ts, frame, age_ms

    def capture_stats(self):
        elapsed = max(1e-9, time.perf_counter() - self._t0)
        return self._cap_count, elapsed, self._cap_count / elapsed

    def stop(self):
        self._running.clear()
        try:
            if self.proc and self.proc.poll() is None:
                self.proc.terminate()
                try:
                    self.proc.wait(timeout=1.0)
                except subprocess.TimeoutExpired:
                    self.proc.kill()
        finally:
            try:
                if self.proc and self.proc.stdout:
                    self.proc.stdout.close()
            except Exception:
                pass


# ---------- Minimal demo (prints CAP fps + TURNAROUND stats) ----------
def _posix_esc_support():
    try:
        import termios, tty, select
        def kbhit():
            dr, _, _ = select.select([sys.stdin], [], [], 0)
            return bool(dr)
        def readch():
            return sys.stdin.read(1)
        old = termios.tcgetattr(sys.stdin)
        tty.setcbreak(sys.stdin.fileno())
        def cleanup():
            termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old)
        return kbhit, readch, cleanup
    except Exception:
        return None, None, None

def _windows_esc_support():
    try:
        import msvcrt
        def kbhit():
            return msvcrt.kbhit()
        def readch():
            return msvcrt.getch().decode(errors='ignore')
        def cleanup():
            return
        return kbhit, readch, cleanup
    except Exception:
        return None, None, None


if __name__ == "__main__":
    # === Set your capture region (keep your full size) ===
    LEFT, TOP, WIDTH, HEIGHT = 0, 0, 1010, 1812

    # 1) Find your screen device:
    #    ffmpeg -f avfoundation -list_devices true -i ""
    #    Use the "Capture screen X" entry or numeric index like "1:none"
    INPUT_DEVICE = "1:none"   # <— change this for your machine

    # 2) Choose pixel format: 'nv12' (faster) or 'bgra' (simple)
    PIX_FMT = "nv12"
    RETURN_BGR = True

    grab = FfmpegGrabber(
        LEFT, TOP, WIDTH, HEIGHT,
        framerate=120,
        pix_fmt=PIX_FMT,
        input_device=INPUT_DEVICE,
        return_bgr=RETURN_BGR,
        ring_len=2
    )
    print("Capturing via ffmpeg... (press Esc to quit)")

    # Cross-platform ESC
    kbhit, readch, cleanup = _posix_esc_support()
    if kbhit is None:
        kbhit, readch, cleanup = _windows_esc_support()
    esc = False

    last_seen_ts = None
    ages_ms = []
    last_tick = time.perf_counter()
    last_cap_count = 0

    try:
        while True:
            info = grab.latest_with_age()
            if info:
                ts, frame, age_ms = info
                if ts != last_seen_ts:
                    last_seen_ts = ts
                    ages_ms.append(age_ms)

            now = time.perf_counter()
            if now - last_tick >= 1.0:
                total, _, _ = grab.capture_stats()
                cap_fps_inst = (total - last_cap_count) / (now - last_tick)
                last_cap_count = total

                if ages_ms:
                    ages_sorted = sorted(ages_ms)
                    med = statistics.median(ages_sorted)
                    p95 = ages_sorted[int(0.95 * (len(ages_sorted) - 1))]
                    worst = ages_sorted[-1]
                else:
                    med = p95 = worst = float('nan')

                print(f"[FfmpegGrabber] CAP {cap_fps_inst:6.1f} fps | "
                      f"TURNAROUND median {med:6.2f} ms  p95 {p95:6.2f} ms  max {worst:6.2f} ms "
                      f"(region {WIDTH}x{HEIGHT}, fmt={PIX_FMT})")

                ages_ms.clear()
                last_tick = now

            if kbhit and kbhit():
                ch = readch()
                if ch == '\x1b':
                    esc = True
            if esc:
                break

            time.sleep(0.001)

    except KeyboardInterrupt:
        pass
    finally:
        if cleanup:
            cleanup()
        grab.stop()
        print("Stopped.")


Capturing via ffmpeg... (press Esc to quit)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
[FfmpegGrabber] CAP    0.0 fps | TURNAROUND median    nan ms  p95    nan ms  max    nan ms (region 1010x1812, fmt=nv12)
Stopped.


In [3]:
#!/usr/bin/env python3
# avfoundation_nv12_capture.py
# High-FPS, low-latency screen region capture on macOS using ffmpeg/avfoundation → NV12 rawvideo.
# Prints true CAPTURE FPS (producer) and TURNAROUND latency (age when consumed).
#
# Usage examples:
#   python avfoundation_nv12_capture.py --list-devices
#   python avfoundation_nv12_capture.py --device "1:none" -x 0 -y 0 -W 1010 -H 1812 -r 120
#   python avfoundation_nv12_capture.py --device "Capture screen 0" -x 0 -y 0 -W 1010 -H 1812 --show
#
# Notes:
# - To find your device string:  ffmpeg -f avfoundation -list_devices true -i ""
# - Default pixel format is NV12 (1.5 B/px). You can switch to BGRA (4 B/px) with --pix-fmt bgra.

import argparse
import os
import sys
import time
import threading
import subprocess
import statistics
from collections import deque
from typing import Optional, Tuple

import numpy as np

try:
    import cv2  # only needed if --show or --return-bgr
except Exception:
    cv2 = None


class FfmpegGrabber:
    """
    Spawn ffmpeg to capture a screen region and stream frames via stdout as rawvideo.

    latest_with_age() -> (ts, frame, age_ms)
      - If pix_fmt='nv12' and return_bgr=True: frame is BGR uint8 (H,W,3)
      - If pix_fmt='nv12' and return_bgr=False: frame is NV12 buffer (H*3//2, W)
      - If pix_fmt='bgra' and return_bgr=True: frame is BGR (drops alpha)
      - If pix_fmt='bgra' and return_bgr=False: frame is BGRA (H,W,4)
    """

    def __init__(
        self,
        left: int,
        top: int,
        width: int,
        height: int,
        framerate: int = 120,
        pix_fmt: str = "nv12",             # 'nv12' (fast) or 'bgra' (simple)
        input_device: str = "1:none",      # e.g. "1:none" or "Capture screen 0"
        return_bgr: bool = False,          # convert to BGR in Python (cost ~1–2 ms @ ~2MP for NV12)
        ring_len: int = 2,
        ffmpeg_path: str = "ffmpeg",
        loglevel: str = "error",
    ):
        self.left, self.top, self.width, self.height = map(int, (left, top, width, height))
        self.framerate = int(framerate)
        self.pix_fmt = pix_fmt.lower()
        assert self.pix_fmt in ("nv12", "bgra")
        self.input_device = input_device
        self.return_bgr = return_bgr
        self.ffmpeg_path = ffmpeg_path
        self.loglevel = loglevel

        if self.return_bgr and cv2 is None:
            raise RuntimeError("--return-bgr or --show requires opencv-python (cv2)")

        # bytes per frame
        if self.pix_fmt == "bgra":
            self.bytes_per_frame = self.width * self.height * 4
        else:  # NV12: 1.5 bytes per pixel
            self.bytes_per_frame = self.width * self.height * 3 // 2

        self._buf = deque(maxlen=ring_len)   # (ts, frame_obj)
        self._running = threading.Event()
        self._running.set()
        self._cap_count = 0
        self._t0 = time.perf_counter()

        self.proc = self._start_ffmpeg()
        self._thr = threading.Thread(target=self._reader, name="FfmpegGrabber", daemon=True)
        self._thr.start()

    def _start_ffmpeg(self) -> subprocess.Popen:
        vf = f"crop={self.width}:{self.height}:{self.left}:{self.top}"
        cmd = [
            self.ffmpeg_path,
            "-hide_banner",
            "-loglevel", self.loglevel,
            "-f", "avfoundation",
            "-framerate", str(self.framerate),
            "-i", self.input_device,
            "-vf", vf,
            "-pix_fmt", self.pix_fmt,
            "-an", "-sn",
            "-f", "rawvideo",
            "-"
        ]
        try:
            proc = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                bufsize=0  # unbuffered
            )
        except FileNotFoundError as e:
            raise SystemExit("ffmpeg executable not found. Install with `brew install ffmpeg`.") from e
        return proc

    def _read_exact(self, n: int) -> bytes:
        """Read exactly n bytes from ffmpeg stdout."""
        out = bytearray(n)
        mv = memoryview(out)
        got = 0
        stdout = self.proc.stdout
        while got < n and self._running.is_set():
            chunk = stdout.read(n - got)
            if not chunk:
                break
            mv[got:got+len(chunk)] = chunk
            got += len(chunk)
        if got != n:
            return b""
        return bytes(out)

    def _reader(self):
        stderr_reader = threading.Thread(target=self._drain_stderr, daemon=True)
        stderr_reader.start()
        while self._running.is_set():
            ts = time.perf_counter()
            raw = self._read_exact(self.bytes_per_frame)
            if not raw:
                if self.proc.poll() is not None:
                    # process ended
                    break
                # transient short read; continue
                continue

            if self.pix_fmt == "bgra":
                frame_bgra = np.frombuffer(raw, dtype=np.uint8).reshape(self.height, self.width, 4)
                if self.return_bgr:
                    frame_out = frame_bgra[..., :3]  # drop alpha (view)
                else:
                    frame_out = frame_bgra
            else:
                # NV12
                yuv = np.frombuffer(raw, dtype=np.uint8).reshape(self.height * 3 // 2, self.width)
                if self.return_bgr:
                    frame_out = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR_NV12)
                else:
                    frame_out = yuv

            self._buf.append((ts, frame_out))
            self._cap_count += 1

        self._running.clear()

    def _drain_stderr(self):
        # Helpful for capturing permission or device errors
        if not self.proc or not self.proc.stderr:
            return
        for line in self.proc.stderr:
            try:
                sys.stderr.write(line.decode(errors="ignore"))
            except Exception:
                pass

    def latest(self) -> Optional[Tuple[float, np.ndarray]]:
        return self._buf[-1] if self._buf else None

    def latest_with_age(self) -> Optional[Tuple[float, np.ndarray, float]]:
        item = self.latest()
        if not item:
            return None
        ts, frame = item
        age_ms = (time.perf_counter() - ts) * 1000.0
        return ts, frame, age_ms

    def capture_stats(self):
        elapsed = max(1e-9, time.perf_counter() - self._t0)
        return self._cap_count, elapsed, self._cap_count / elapsed

    def stop(self):
        self._running.clear()
        try:
            if self.proc and self.proc.poll() is None:
                self.proc.terminate()
                try:
                    self.proc.wait(timeout=1.0)
                except subprocess.TimeoutExpired:
                    self.proc.kill()
        finally:
            try:
                if self.proc and self.proc.stdout:
                    self.proc.stdout.close()
            except Exception:
                pass
            try:
                if self.proc and self.proc.stderr:
                    self.proc.stderr.close()
            except Exception:
                pass


def list_devices(ffmpeg_path: str = "ffmpeg"):
    print("Listing AVFoundation devices...\n", flush=True)
    try:
        # This command always exits non-zero; we just want its stderr
        proc = subprocess.Popen(
            [ffmpeg_path, "-f", "avfoundation", "-list_devices", "true", "-i", ""],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        _, err = proc.communicate(timeout=5)
        sys.stdout.write(err.decode(errors="ignore"))
    except FileNotFoundError:
        print("ffmpeg not found. Install with `brew install ffmpeg`.")
    except subprocess.TimeoutExpired:
        print("ffmpeg list devices timed out.")


def main():
    ap = argparse.ArgumentParser(description="AVFoundation NV12 region capture via ffmpeg (macOS).")
    ap.add_argument("--list-devices", action="store_true", help="List AVFoundation devices and exit.")
    ap.add_argument("--device", type=str, default="1:none", help='Input device, e.g. "1:none" or "Capture screen 0".')
    ap.add_argument("-x", type=int, default=0, help="Left (pixels)")
    ap.add_argument("-y", type=int, default=0, help="Top (pixels)")
    ap.add_argument("-W", type=int, default=1010, help="Width (pixels)")
    ap.add_argument("-H", type=int, default=1812, help="Height (pixels)")
    ap.add_argument("-r", "--fps", type=int, default=120, help="Target capture frame rate")
    ap.add_argument("--pix-fmt", choices=["nv12", "bgra"], default="nv12", help="Pixel format to stream.")
    ap.add_argument("--return-bgr", action="store_true", help="Convert to BGR in Python (NV12→BGR or drop A).")
    ap.add_argument("--show", action="store_true", help="Preview window (implies --return-bgr).")
    ap.add_argument("--ffmpeg", type=str, default="ffmpeg", help="Path to ffmpeg binary.")
    args = ap.parse_args()

    if args.list_devices:
        list_devices(args.ffmpeg)
        return

    if args.show:
        args.return_bgr = True
        if cv2 is None:
            raise SystemExit("--show requires opencv-python (pip install opencv-python)")

    grab = FfmpegGrabber(
        left=args.x, top=args.y, width=args.W, height=args.H,
        framerate=args.fps, pix_fmt=args.pix_fmt, input_device=args.device,
        return_bgr=args.return_bgr, ring_len=2, ffmpeg_path=args.ffmpeg,
    )

    print(f"Capturing via ffmpeg/avfoundation (fmt={args.pix_fmt}, "
          f"device={args.device})...  Press Ctrl+C to quit.\n")

    last_seen_ts = None
    ages_ms = []
    last_tick = time.perf_counter()
    last_cap_count = 0

    try:
        while True:
            info = grab.latest_with_age()
            if info:
                ts, frame, age_ms = info
                if ts != last_seen_ts:
                    last_seen_ts = ts
                    ages_ms.append(age_ms)

                    if args.show:
                        cv2.imshow("Preview (BGR)", frame)
                        # ~120 Hz UI cap
                        if cv2.waitKey(1) == 27:  # Esc
                            break

            now = time.perf_counter()
            if now - last_tick >= 1.0:
                total, _, _ = grab.capture_stats()
                cap_fps_inst = (total - last_cap_count) / (now - last_tick)
                last_cap_count = total

                if ages_ms:
                    ages_sorted = sorted(ages_ms)
                    med = statistics.median(ages_sorted)
                    p95 = ages_sorted[int(0.95 * (len(ages_sorted) - 1))]
                    worst = ages_sorted[-1]
                else:
                    med = p95 = worst = float('nan')

                print(f"[AVF] CAP {cap_fps_inst:6.1f} fps | "
                      f"TURNAROUND median {med:6.2f} ms  p95 {p95:6.2f} ms  max {worst:6.2f} ms "
                      f"(region {args.W}x{args.H}, fmt={args.pix_fmt})")

                ages_ms.clear()
                last_tick = now

    except KeyboardInterrupt:
        pass
    finally:
        grab.stop()
        if args.show and cv2 is not None:
            try:
                cv2.destroyAllWindows()
            except Exception:
                pass
        print("Stopped.")


if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] [--list-devices] [--device DEVICE] [-x X]
                             [-y Y] [-W W] [-H H] [-r FPS]
                             [--pix-fmt {nv12,bgra}] [--return-bgr] [--show]
                             [--ffmpeg FFMPEG]
ipykernel_launcher.py: error: ambiguous option: --f=/Users/marcus/Library/Jupyter/runtime/kernel-v3f2b5e4280663f99227882b893a119c4f4175d9f2.json could match --fps, --ffmpeg


SystemExit: 2

In [6]:
import time, statistics, mss
left, top, width, height = 0, 0, 1010, 1812
times = []
with mss.mss() as sct:
    for _ in range(10): sct.grab({"left":left,"top":top,"width":width,"height":height})  # warmup
    for _ in range(120):
        t0 = time.perf_counter()
        _ = sct.grab({"left":left,"top":top,"width":width,"height":height})
        times.append((time.perf_counter()-t0)*1000)
print("median:", statistics.median(times), "ms  p95:", sorted(times)[int(0.95*(len(times)-1))], "ms")


median: 38.70979161001742 ms  p95: 47.56199987605214 ms


In [7]:
#!/usr/bin/env python3
# fast_capture_bgr0.py
# High-FPS, low-latency screen region capture on macOS using ffmpeg/avfoundation.
# Optimized for the screen’s native pixel format (bgr0) to avoid costly color conversions.
# Prints true CAPTURE FPS (producer rate) and TURNAROUND (age when consumed) once per second.
#
# Requires: macOS + ffmpeg (`brew install ffmpeg`)
#
# Example:
#   python fast_capture_bgr0.py --device "Capture screen 0" -x 0 -y 0 -W 1010 -H 1812 -r 120
#   # If the name form fails, try a numeric form:
#   python fast_capture_bgr0.py --device "1:none" -x 0 -y 0 -W 1010 -H 1812 -r 120
#
# Notes:
# - We request bgr0 (native for screen capture) to avoid colorspace conversion overhead.
# - We disable buffering and pacing: -fflags nobuffer -flags low_delay -vsync 0.
# - AVFoundation can’t crop at source; ffmpeg applies -vf crop after capture (still fast enough here).
# - If you truly need 3-channel BGR, we drop A in Python via a view (no copy).

import argparse
import sys
import time
import threading
import subprocess
import statistics
from collections import deque
from typing import Optional, Tuple

import numpy as np


def list_avfoundation_devices(ffmpeg_path: str = "ffmpeg"):
    """Print AVFoundation devices (stderr) and return the raw listing text."""
    try:
        p = subprocess.run(
            [ffmpeg_path, "-f", "avfoundation", "-list_devices", "true", "-i", ""],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            check=False,
        )
    except FileNotFoundError:
        raise SystemExit("ffmpeg not found. Install with `brew install ffmpeg`.")
    text = p.stderr.decode(errors="ignore")
    # Print a cleaned list
    print("\n=== AVFoundation devices ===")
    for line in text.splitlines():
        if "AVFoundation video devices" in line or "AVFoundation audio devices" in line or "] [" in line:
            print(line)
    print("============================\n")
    return text


class AVFGrabber:
    """
    AVFoundation screen grabber via ffmpeg, streaming rawvideo frames.

    latest_with_age() -> (ts, frame, age_ms)
      - If pix_fmt='bgr0' and return_bgr=False: frame is BGR0 (H,W,4) view
      - If pix_fmt='bgr0' and return_bgr=True:  frame is BGR  (H,W,3) view (alpha dropped)
      - If pix_fmt='nv12' and return_bgr=False: frame is NV12 buffer (H*3//2, W)
      - If pix_fmt='nv12' and return_bgr=True:  frame is BGR (converted; needs opencv-python)
    """

    def __init__(
        self,
        left: int,
        top: int,
        width: int,
        height: int,
        framerate: int = 120,
        device: str = "Capture screen 0",
        pix_fmt: str = "bgr0",            # 'bgr0' recommended for speed (native); 'nv12' optional
        return_bgr: bool = False,         # drop alpha (bgr0→bgr) or convert (nv12→bgr)
        ring_len: int = 2,
        ffmpeg_path: str = "ffmpeg",
        loglevel: str = "error",
    ):
        self.left, self.top, self.width, self.height = map(int, (left, top, width, height))
        self.fps = int(framerate)
        self.device = device
        self.pix_fmt = pix_fmt.lower()
        assert self.pix_fmt in ("bgr0", "nv12"), "pix_fmt must be 'bgr0' or 'nv12'"
        self.return_bgr = return_bgr
        self.ffmpeg = ffmpeg_path
        self.loglevel = loglevel

        if self.pix_fmt == "bgr0":
            self._bytes_per_frame = self.width * self.height * 4
        else:
            self._bytes_per_frame = self.width * self.height * 3 // 2  # NV12

        if self.pix_fmt == "nv12" and self.return_bgr:
            try:
                import cv2  # noqa: F401
            except Exception:
                raise SystemExit("nv12→bgr conversion requires `pip install opencv-python`.")

        self._buf = deque(maxlen=ring_len)  # (ts, frame)
        self._running = threading.Event()
        self._running.set()
        self._cap_count = 0
        self._t0 = time.perf_counter()

        self.proc = self._start_ffmpeg()
        self._thr = threading.Thread(target=self._reader, name="AVFGrabber", daemon=True)
        self._thr.start()

        # Drain stderr so ffmpeg can't block
        self._stderr_thr = threading.Thread(target=self._drain_stderr, daemon=True)
        self._stderr_thr.start()

    def _start_ffmpeg(self) -> subprocess.Popen:
        vf = f"crop={self.width}:{self.height}:{self.left}:{self.top}"
        cmd = [
            self.ffmpeg,
            "-hide_banner", "-loglevel", self.loglevel,
            "-fflags", "nobuffer",
            "-flags", "low_delay",
            "-f", "avfoundation",
            "-framerate", str(self.fps),
            "-i", self.device,
            "-vf", vf,
            "-pix_fmt", self.pix_fmt,    # request native bgr0 for screen capture
            "-vsync", "0",               # do not repace/duplicate frames
            "-an", "-sn",
            "-f", "rawvideo", "-"
        ]
        try:
            return subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, bufsize=0)
        except FileNotFoundError:
            raise SystemExit("ffmpeg not found. Install with `brew install ffmpeg`.")

    def _read_exact(self, n: int) -> bytes:
        out = bytearray(n)
        mv = memoryview(out)
        got = 0
        r = self.proc.stdout
        while got < n and self._running.is_set():
            chunk = r.read(n - got)
            if not chunk:
                break
            mv[got:got + len(chunk)] = chunk
            got += len(chunk)
        return bytes(out) if got == n else b""

    def _reader(self):
        if self.pix_fmt == "nv12" and self.return_bgr:
            import cv2  # local import only when needed

        while self._running.is_set():
            ts = time.perf_counter()
            raw = self._read_exact(self._bytes_per_frame)
            if not raw:
                if self.proc.poll() is not None:
                    break
                continue

            if self.pix_fmt == "bgr0":
                bgr0 = np.frombuffer(raw, np.uint8).reshape(self.height, self.width, 4)
                frame = bgr0[..., :3] if self.return_bgr else bgr0  # drop alpha via view if requested
            else:
                # NV12 buffer shape (H*3/2, W)
                yuv = np.frombuffer(raw, np.uint8).reshape(self.height * 3 // 2, self.width)
                if self.return_bgr:
                    import cv2
                    frame = cv2.cvtColor(yuv, cv2.COLOR_YUV2BGR_NV12)
                else:
                    frame = yuv

            self._buf.append((ts, frame))
            self._cap_count += 1

        self._running.clear()

    def _drain_stderr(self):
        if not self.proc or not self.proc.stderr:
            return
        for line in self.proc.stderr:
            # Uncomment to see ffmpeg messages:
            # sys.stderr.write(line.decode(errors="ignore"))
            pass

    # ---- Public API (matches your previous FastGrabber) ----
    def latest(self) -> Optional[Tuple[float, np.ndarray]]:
        return self._buf[-1] if self._buf else None

    def latest_with_age(self) -> Optional[Tuple[float, np.ndarray, float]]:
        item = self.latest()
        if not item:
            return None
        ts, frame = item
        age_ms = (time.perf_counter() - ts) * 1000.0
        return ts, frame, age_ms

    def capture_stats(self):
        elapsed = max(1e-9, time.perf_counter() - self._t0)
        return self._cap_count, elapsed, self._cap_count / elapsed

    def stop(self):
        self._running.clear()
        try:
            if self.proc and self.proc.poll() is None:
                self.proc.terminate()
                try:
                    self.proc.wait(timeout=1.0)
                except subprocess.TimeoutExpired:
                    self.proc.kill()
        finally:
            for s in ("stdout", "stderr"):
                try:
                    getattr(self.proc, s).close()
                except Exception:
                    pass


def main():
    ap = argparse.ArgumentParser(description="Fast screen region capture via AVFoundation (bgr0).")
    ap.add_argument("--list-devices", action="store_true", help="List AVFoundation devices and exit.")
    ap.add_argument("--device", type=str, default="Capture screen 0",
                    help='AVFoundation input (e.g., "Capture screen 0" or "1:none").')
    ap.add_argument("-x", type=int, default=0, help="Left")
    ap.add_argument("-y", type=int, default=0, help="Top")
    ap.add_argument("-W", type=int, default=1010, help="Width")
    ap.add_argument("-H", type=int, default=1812, help="Height")
    ap.add_argument("-r", "--fps", type=int, default=120, help="Requested capture FPS")
    ap.add_argument("--pix-fmt", choices=["bgr0", "nv12"], default="bgr0", help="Pixel format.")
    ap.add_argument("--return-bgr", action="store_true",
                    help="Return 3-channel BGR (drops alpha for bgr0, converts for nv12).")
    args = ap.parse_args()

    if args.list_devices:
        list_avfoundation_devices()
        sys.exit(0)

    grab = AVFGrabber(
        left=args.x, top=args.y, width=args.W, height=args.H,
        framerate=args.fps, device=args.device,
        pix_fmt=args.pix_fmt, return_bgr=args.return_bgr,
        ring_len=2, ffmpeg_path="ffmpeg", loglevel="error"
    )

    print(f"Capturing via AVFoundation (device={args.device}, fmt={args.pix_fmt}) — Ctrl+C to stop")
    last_seen_ts = None
    ages_ms = []
    last_tick = time.perf_counter()
    last_cap_count = 0

    try:
        while True:
            info = grab.latest_with_age()
            if info:
                ts, frame, age_ms = info
                if ts != last_seen_ts:
                    last_seen_ts = ts
                    ages_ms.append(age_ms)

            now = time.perf_counter()
            if now - last_tick >= 1.0:
                total, _, _ = grab.capture_stats()
                cap_fps_inst = (total - last_cap_count) / (now - last_tick)
                last_cap_count = total

                if ages_ms:
                    ages_sorted = sorted(ages_ms)
                    med = statistics.median(ages_sorted)
                    p95 = ages_sorted[int(0.95 * (len(ages_sorted) - 1))]
                    worst = ages_sorted[-1]
                    ages_ms.clear()
                else:
                    med = p95 = worst = float('nan')

                print(f"[AVF bgr0] CAP {cap_fps_inst:6.1f} fps | "
                      f"TURNAROUND median {med:6.2f} ms  p95 {p95:6.2f} ms  max {worst:6.2f} ms "
                      f"(region {args.W}x{args.H})")
                last_tick = now
    except KeyboardInterrupt:
        pass
    finally:
        grab.stop()
        print("Stopped.")


if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] [--list-devices] [--device DEVICE] [-x X]
                             [-y Y] [-W W] [-H H] [-r FPS]
                             [--pix-fmt {bgr0,nv12}] [--return-bgr]
ipykernel_launcher.py: error: argument -r/--fps: invalid int value: '/Users/marcus/Library/Jupyter/runtime/kernel-v3f2b5e4280663f99227882b893a119c4f4175d9f2.json'


SystemExit: 2