In [1]:
# Live inference testing -> no overlays; each frame is saved with a lane label -> live state tracking

In [None]:
#!/usr/bin/env python3
# Live capture + ultra-fast single-image pipeline (no batching)
# ESC to stop. Captures a cropped region, processes each frame immediately, prints per-frame timings.

import os, time, math, subprocess
import cv2, numpy as np
from mss import mss
import pyautogui
from pynput import keyboard
import torch
from ultralytics import YOLO
from pathlib import Path

# Lane state, written by the keyboard callback and read by the capture loop.
MIN_LANE = 0
MAX_LANE = 2
lane = 1        # current lane index, starts in the middle
running = True  # cleared by ESC to end the capture loop


def on_press(key):
    """Keyboard callback: arrow keys move the tracked lane, ESC requests a stop.

    The lane index is clamped to [MIN_LANE, MAX_LANE]. Returns False on ESC
    (after clearing `running`) and None for every other key. Any exception
    raised while handling the key is printed instead of propagated.
    """
    global lane, running
    try:
        if key == keyboard.Key.esc:
            print("ESC pressed — exiting")
            running = False
            return False
        if key == keyboard.Key.left:
            lane = max(MIN_LANE, lane - 1)
            print(f"Moved Left → Lane {lane}")
        elif key == keyboard.Key.right:
            lane = min(MAX_LANE, lane + 1)
            print(f"Moved Right → Lane {lane}")
    except Exception as e:
        print(f"Error: {e}")

# =======================
# Capture / UI bootstrap
# =======================
# Start the keyboard listener once
listener = keyboard.Listener(on_press=on_press)
listener.start()

# Parsec to front (macOS). Best-effort: a failure is reported but never fatal.
try:
    subprocess.run(["osascript", "-e", 'tell application "Parsec" to activate'], check=False)
    time.sleep(0.4)  # short pause before the click below
except Exception as e:
    print(f"Warning: could not activate Parsec: {e}")

# Choose crop + click based on ad layout
advertisement = True
if advertisement:
    # NOTE(review): top is 77 here while the height is computed from 75
    # (the else-branch uses 75 for both) — confirm which value is intended.
    snap_coords = (644, 77, (1149-644), (981-75))  # (left, top, width, height)
    start_click = (1030, 900)
else:
    snap_coords = (483, 75, (988-483), (981-75))
    start_click = (870, 895)

# Click "Start". Best-effort as well; report instead of silently ignoring.
# NOTE(review): this click runs before the model is loaded and warmed further
# down in this cell — confirm the game should start that early.
try:
    pyautogui.click(start_click)
except Exception as e:
    print(f"Warning: could not click Start at {start_click}: {e}")

# Initialize capture
sct = mss()

# =======================
# Model / Pipeline config
# =======================
home    = os.path.expanduser("~")
weights = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"

# Detector settings passed to model.predict
RAIL_ID  = 9     # class id whose masks are treated as rails
IMG_SIZE = 512
CONF     = 0.30
IOU      = 0.45
MAX_DET  = 30

# Colour/region filter for "green rails"
TARGET_COLORS_RGB = [
    (119, 104, 67),
    (81, 42, 45),
]
TOLERANCE         = 20.0   # max colour distance (compared squared)
MIN_REGION_SIZE   = 30     # min component area, pixels
MIN_REGION_HEIGHT = 150    # min component bounding-box height, pixels

# Heat map / triangle extraction
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
MIN_DARK_FRACTION   = 0.15

# Probe ray: total length and step between samples (pixels)
SAMPLE_UP_PX = 180
RAY_STEP_PX  = 20

# Jake lane anchor points, indexed by lane (0/1/2), and bearing targets
LANE_LEFT, LANE_MID, LANE_RIGHT = (300, 1340), (490, 1340), (680, 1340)
LANE_POINTS = (LANE_LEFT, LANE_MID, LANE_RIGHT)

LANE_TARGET_DEG = dict(left=-10.7, mid=+1.5, right=+15.0)

# Bend degrees applied to probe rays
BEND_LEFT_STATE_RIGHT_DEG = -20.0  # N1
BEND_MID_STATE_RIGHT_DEG  = -20.0  # N2
BEND_MID_STATE_LEFT_DEG   = +20.0  # N3
BEND_RIGHT_STATE_LEFT_DEG = +20.0  # N4

# =======================
# System / backends
# =======================
cv2.setUseOptimized(True)
try:
    # Use all but one CPU core for OpenCV, but never fewer than one.
    _cv_threads = (os.cpu_count() or 1) - 1
    cv2.setNumThreads(max(1, _cv_threads))
except Exception:
    pass

# Pick the inference device: CUDA (with half precision), then MPS, then CPU.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass
elif _mps_backend and _mps_backend.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

# =======================
# Model (singleton, warmed)
# =======================
model = YOLO(weights)
try:
    model.fuse()
except Exception:
    pass

# Warm-up: one throwaway prediction on an all-zero IMG_SIZE x IMG_SIZE frame.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(
    _dummy,
    task="segment", imgsz=IMG_SIZE, device=device,
    conf=CONF, iou=IOU, max_det=MAX_DET,
    half=half, verbose=False,
)

# =======================
# Precomputed tables
# =======================
# Target colours reordered RGB -> BGR, as float32 rows
TARGETS_BGR_F32 = np.array(
    [(b, g, r) for (r, g, b) in TARGET_COLORS_RGB],
    dtype=np.float32,
)
TOL2          = TOLERANCE * TOLERANCE  # squared tolerance, compared against squared distances
MORPH_OPEN_SE = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9))

# Class-id buckets used by the probe classifier
DANGER_RED  = {1, 6, 7, 11}
WARN_YELLOW = {2, 3, 4, 5, 8}
BOOTS_PINK  = {0}

def lane_name_from_point(p):
    """Map a lane anchor point to "left" / "mid" / "right".

    Points that match none of the three anchors are reported as "mid".
    """
    for name, anchor in (("left", LANE_LEFT), ("mid", LANE_MID), ("right", LANE_RIGHT)):
        if p == anchor:
            return name
    return "mid"

def _clampi(v, lo, hi):
    return lo if v < lo else (hi if v > hi else v)

# =======================
# Pipeline helpers
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask):
    """Return a bool (H, W) mask of rail pixels that match the target colours.

    Work is restricted to the bounding box of `rail_mask`. A pixel is a
    candidate when it lies inside the rail mask and its squared BGR distance
    to any row of TARGETS_BGR_F32 is at most TOL2. Candidates are grouped into
    8-connected components; only components with area >= MIN_REGION_SIZE and
    bounding-box height >= MIN_REGION_HEIGHT survive. An all-False mask is
    returned when the rail mask is empty or no candidate exists.
    """
    H, W = rail_mask.shape
    empty = np.zeros((H, W), dtype=bool)
    if not rail_mask.any():
        return empty

    rail_u8 = rail_mask.view(dtype=np.uint8) * 255
    x0, y0, box_w, box_h = cv2.boundingRect(rail_u8)
    rows, cols = slice(y0, y0 + box_h), slice(x0, x0 + box_w)

    roi_f32 = img_bgr[rows, cols].astype(np.float32, copy=False)
    in_rail = rail_u8[rows, cols].astype(bool)

    # Squared distance of every ROI pixel to every target colour: (h, w, n_targets)
    delta = roi_f32[:, :, None, :] - TARGETS_BGR_F32[None, None, :, :]
    near_target = ((delta * delta).sum(-1) <= TOL2).any(-1)

    candidates = np.logical_and(near_target, in_rail)
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(candidates.astype(np.uint8), 8)
    if n_labels <= 1:
        return empty

    # Per-label keep flag; index 0 (background) stays False.
    keep_lut = np.zeros(n_labels, dtype=bool)
    keep_lut[1:] = (
        (stats[1:, cv2.CC_STAT_AREA] >= MIN_REGION_SIZE)
        & (stats[1:, cv2.CC_STAT_HEIGHT] >= MIN_REGION_HEIGHT)
    )

    result = np.zeros((H, W), dtype=bool)
    result[rows, cols] = keep_lut[labels]
    return result

def red_vs_green_score(red_mask, green_mask):
    """Combine two masks into a uint8 heat score (high = red dominates locally).

    Both masks are box-blurred with a HEAT_BLUR_KSIZE square kernel. The
    red-minus-green difference is divided by twice its largest absolute value
    and shifted by 0.5, then scaled to 0..255.
    """
    ksize = (HEAT_BLUR_KSIZE, HEAT_BLUR_KSIZE)
    red_blur   = cv2.blur(red_mask.astype(np.float32, copy=False), ksize)
    green_blur = cv2.blur(green_mask.astype(np.float32, copy=False), ksize)

    delta = red_blur - green_blur
    peak = float(np.max(np.abs(delta))) + 1e-6  # epsilon keeps the division defined
    scaled = (delta / (2.0 * peak) + 0.5) * 255.0
    return np.clip(scaled, 0, 255.0).astype(np.uint8, copy=False)

def purple_triangles(score, H):
    """Find the top point of every large high-score blob in `score`.

    Pixels with score >= RED_SCORE_THRESH are kept, the top EXCLUDE_TOP_FRAC
    and bottom EXCLUDE_BOTTOM_FRAC of the H rows are blanked, and the result
    is morphologically opened with MORPH_OPEN_SE. A connected component
    qualifies when its area is at least MIN_DARK_RED_AREA and at least
    MIN_DARK_FRACTION of all remaining pixels.

    Returns (tris, best): `tris` is a list of (x, y) points, one per
    qualifying component, where y is the component's top row and x the mean
    column of the component on that row; `best` is the entry with the
    smallest y. Returns ([], None) when nothing qualifies.
    """
    hot = (score >= RED_SCORE_THRESH).astype(np.uint8, copy=False)

    # Blank the excluded bands at the top and bottom of the frame.
    top_rows = int(H * EXCLUDE_TOP_FRAC)
    bottom_rows = int(H * EXCLUDE_BOTTOM_FRAC)
    if top_rows:
        hot[:top_rows, :] = 0
    if bottom_rows:
        hot[-bottom_rows:, :] = 0
    hot = cv2.morphologyEx(hot, cv2.MORPH_OPEN, MORPH_OPEN_SE, iterations=1)

    hot_total = int(hot.sum())
    if hot_total == 0:
        return [], None
    min_share = int(np.ceil(MIN_DARK_FRACTION * hot_total))

    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(hot, 8)
    if n_labels <= 1:
        return [], None

    apexes = []
    for label in range(1, n_labels):
        area = stats[label, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < min_share:
            continue
        # The component's top row comes from its bounding box; only that row is scanned.
        top_row = int(stats[label, cv2.CC_STAT_TOP])
        cols = np.where(labels[top_row] == label)[0]
        if cols.size == 0:
            continue
        apexes.append((int(cols.mean()), top_row))

    if not apexes:
        return [], None
    return apexes, min(apexes, key=lambda pt: pt[1])

def signed_degrees_from_vertical(dx, dy):
    """Signed angle in degrees between the offset (dx, dy) and straight up (-y).

    Offsets with dx > 0 give negative angles, dx < 0 positive ones.
    A zero offset yields 0.0.
    """
    if (dx, dy) == (0, 0):
        return 0.0
    angle_rad = math.atan2(dx, -dy)
    return -math.degrees(angle_rad)

def select_triangle_by_bearing(tri_positions, jx, jy, target_deg, min_dy=6):
    """Pick the triangle whose bearing from (jx, jy) is closest to `target_deg`.

    Triangles that are not more than `min_dy` pixels above the reference
    point (smaller y) are ignored. On ties the earliest triangle wins.

    Returns (index, bearing_deg, abs_error), or (-1, None, None) when no
    triangle qualifies.
    """
    chosen = (-1, None, None)
    for i, (xt, yt) in enumerate(tri_positions):
        rise = yt - jy
        if rise >= -min_dy:  # not far enough above the reference point
            continue
        bearing = signed_degrees_from_vertical(xt - jx, rise)
        miss = abs(bearing - target_deg)
        if chosen[2] is None or miss < chosen[2]:
            chosen = (i, bearing, miss)
    return chosen

def _precompute_trig():
    """Build {angle_deg: (dx, dy)} unit-ray directions for 0.0 and every bend angle.

    "Up" is -y, so 0 degrees maps to (0, -1).
    """
    bend_angles = {
        0.0,
        BEND_LEFT_STATE_RIGHT_DEG,
        BEND_MID_STATE_RIGHT_DEG,
        BEND_MID_STATE_LEFT_DEG,
        BEND_RIGHT_STATE_LEFT_DEG,
    }
    directions = {}
    for deg in sorted(bend_angles):
        rad = math.radians(deg)
        directions[deg] = (math.sin(rad), -math.cos(rad))
    return directions
TRIG_TABLE = _precompute_trig()

def pick_bend_angle(jake_point, xt, x_ref, idx, best_idx):
    """Choose the probe-ray bend (degrees) for triangle `idx` located at x = `xt`.

    The selected triangle (`idx == best_idx`) is always probed straight (0.0).
    With Jake on LANE_LEFT only triangles right of `x_ref` are bent; on
    LANE_RIGHT only triangles left of `x_ref`; for any other point both
    sides are bent with the mid-state angles.
    """
    if idx == best_idx:
        return 0.0

    right_of_ref = xt > x_ref
    left_of_ref = xt < x_ref

    if jake_point == LANE_LEFT:
        return BEND_LEFT_STATE_RIGHT_DEG if right_of_ref else 0.0
    if jake_point == LANE_RIGHT:
        return BEND_RIGHT_STATE_LEFT_DEG if left_of_ref else 0.0
    if right_of_ref:
        return BEND_MID_STATE_RIGHT_DEG
    if left_of_ref:
        return BEND_MID_STATE_LEFT_DEG
    return 0.0

def classify_triangles_at_sample_curved(
    tri_positions, masks_np, classes_np, H, W,
    jake_point, x_ref, best_idx, sample_px=SAMPLE_UP_PX, step_px=RAY_STEP_PX
):
    """Walk a probe ray from each triangle and report the first thing it hits.

    For every triangle a ray direction is taken from TRIG_TABLE using
    pick_bend_angle. The ray is sampled every `step_px` pixels, for
    max(1, sample_px // step_px) steps. Sample coordinates are clamped to the
    frame and mapped into mask resolution. At each step the masks are checked
    in priority order DANGER_RED, WARN_YELLOW, BOOTS_PINK (mask value > 0.5);
    the first hit ends the walk.

    Returns one BGR tuple per triangle: red, yellow or pink for the
    respective bucket, green when nothing was hit. Returns [] when masks,
    classes or triangles are missing.
    """
    if masks_np is None or classes_np is None or len(tri_positions) == 0:
        return []

    mask_h, mask_w = masks_np.shape[1], masks_np.shape[2]
    sx = (mask_w - 1) / max(1, (W - 1))
    sy = (mask_h - 1) / max(1, (H - 1))

    def _members(bucket):
        # Indices of the masks whose class id belongs to `bucket`.
        return [i for i, c in enumerate(classes_np) if int(c) in bucket]

    # (mask indices, result colour) in priority order.
    buckets = (
        (_members(DANGER_RED),  (0, 0, 255)),
        (_members(WARN_YELLOW), (0, 255, 255)),
        (_members(BOOTS_PINK),  (203, 192, 255)),
    )
    n_steps = max(1, sample_px // max(1, step_px))

    def _first_hit(x0, y0, dx, dy):
        for k in range(1, n_steps + 1):
            dist = k * step_px
            px = _clampi(int(round(x0 + dx * dist)), 0, W-1)
            py = _clampi(int(round(y0 + dy * dist)), 0, H-1)
            mx = _clampi(int(round(px * sx)), 0, mask_w-1)
            my = _clampi(int(round(py * sy)), 0, mask_h-1)
            for members, colour in buckets:
                for i in members:
                    if masks_np[i][my, mx] > 0.5:
                        return colour
        return (0, 255, 0)

    colours = []
    for idx, (x0, y0) in enumerate(tri_positions):
        theta = pick_bend_angle(jake_point, x0, x_ref, idx, best_idx)
        dx, dy = TRIG_TABLE[theta]
        colours.append(_first_hit(x0, y0, dx, dy))
    return colours

def process_frame_post(frame_bgr, yolo_res, jake_point):
    """Run the CPU post-processing for one frame and report counts and timings.

    Returns (triangle_count, mask_count, to_cpu_ms, post_ms). The triangle
    count and post_ms are 0 when the result has no masks, no classes, or no
    mask of class RAIL_ID.
    """
    H, W = frame_bgr.shape[:2]
    masks = yolo_res.masks
    if masks is None:
        return 0, 0, 0.0, 0.0

    # Move masks and class ids to NumPy, timing the transfer.
    started = time.perf_counter()
    masks_np = masks.data.detach().cpu().numpy()  # [n,h,w]
    cls_src = getattr(masks, "cls", None)
    if cls_src is None:
        cls_src = yolo_res.boxes.cls
    classes_np = cls_src.detach().cpu().numpy().astype(int)
    to_cpu_ms = (time.perf_counter() - started) * 1000.0

    mask_count = int(masks_np.shape[0])
    if mask_count == 0 or classes_np.size == 0:
        return 0, mask_count, to_cpu_ms, 0.0

    is_rail = (classes_np == RAIL_ID)
    if not is_rail.any():
        return 0, mask_count, to_cpu_ms, 0.0

    post_started = time.perf_counter()

    # Union of all rail masks, resized to the frame.
    # NOTE(review): the mask grid is stretched straight onto the frame; if the
    # model pads its input, this mapping may be slightly off — confirm.
    rail_union = np.any(masks_np[is_rail].astype(bool, copy=False), axis=0)
    rail_mask = cv2.resize(
        rail_union.astype(np.uint8, copy=False), (W, H), interpolation=cv2.INTER_NEAREST
    ).astype(bool, copy=False)

    green = highlight_rails_mask_only_fast(frame_bgr, rail_mask)
    red = np.logical_and(rail_mask, np.logical_not(green))
    tri_positions, _ = purple_triangles(red_vs_green_score(red, green), H)

    # Choose Jake's triangle by bearing and derive the bend reference column.
    lane_name = lane_name_from_point(jake_point)
    xj, yj = jake_point
    best_idx, _, _ = select_triangle_by_bearing(
        tri_positions, xj, yj, LANE_TARGET_DEG[lane_name], min_dy=6
    )
    if lane_name == "mid" and 0 <= best_idx < len(tri_positions):
        x_ref = tri_positions[best_idx][0]
    else:
        x_ref = xj

    # Probe classification; the result is discarded here, the call is kept for timing.
    _ = classify_triangles_at_sample_curved(
        tri_positions, masks_np, classes_np, H, W, jake_point, x_ref, best_idx,
        SAMPLE_UP_PX, RAY_STEP_PX
    )
    post_ms = (time.perf_counter() - post_started) * 1000.0

    return len(tri_positions), mask_count, to_cpu_ms, post_ms

def process_image_bgr(img_bgr, name, jake_point):
    """Segment one in-memory BGR frame, post-process it and print a timing line.

    Does nothing when `img_bgr` is None. `name` only labels the printed line.
    """
    if img_bgr is None:
        return

    # Nothing is read from disk in live mode; the field is printed as 0.0.
    read_ms = 0.0

    infer_started = time.perf_counter()
    results = model.predict(
        [img_bgr],
        task="segment", imgsz=IMG_SIZE, device=device,
        conf=CONF, iou=IOU, max_det=MAX_DET,
        half=half, verbose=False, batch=1,
    )
    infer_ms = (time.perf_counter() - infer_started) * 1000.0

    tri_count, mask_count, to_cpu_ms, post_ms = process_frame_post(img_bgr, results[0], jake_point)
    proc_ms = infer_ms + to_cpu_ms + post_ms

    timing_line = (
        f"[live] {name}  "
        f"read {read_ms:.1f} | infer {infer_ms:.1f} | "
        f"to_cpu {to_cpu_ms:.1f} | post {post_ms:.1f} | "
        f"masks {mask_count} | triangles {tri_count} "
        f"=> proc {proc_ms:.1f} ms"
    )
    print(timing_line)

# =======================
# Live loop
# =======================

# Output folder for saved frames
out_dir = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "live_run"
out_dir.mkdir(parents=True, exist_ok=True)

prev_ts = time.time()
frame_idx = 0

# The capture region does not change while the loop runs, so build it once.
left, top, width, height = snap_coords
capture_region = {"left": left, "top": top, "width": width, "height": height}

try:
    while running:
        # Grab screen region
        raw = sct.grab(capture_region)
        frame_bgr = np.array(raw)[:, :, :3]  # BGRA -> BGR

        # Determine JAKE_POINT for this frame from current lane (0/1/2)
        jake_point = LANE_POINTS[lane]

        # Process immediately (no batching); the timing line covers inference + post only
        frame_idx += 1
        process_image_bgr(frame_bgr, name=f"frame_{frame_idx:05d}", jake_point=jake_point)

        # Save a copy with JAKE_POINT text at top-left (does not affect inference).
        # This disk write happens on every frame and is not included in the printed timings.
        annotated = frame_bgr.copy()
        jp_name = lane_name_from_point(jake_point).upper()
        cv2.putText(annotated, f"JAKE_POINT: {jp_name}",
                    (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 0), 2, cv2.LINE_AA)
        out_path = out_dir / f"live_{frame_idx:05d}.jpg"
        if not cv2.imwrite(str(out_path), annotated):
            # imwrite reports failure through its return value; surface it.
            print(f"Warning: could not write {out_path}")

        # (Optional) inter-frame delta print — comment out if noisy
        now = time.time()
        # print(f"Δ between frames: {now - prev_ts:.3f}s")
        prev_ts = now
finally:
    # Cleanup: always stop the keyboard listener and release the capture
    # handle, including when the loop ends through an exception or a kernel
    # interrupt rather than ESC.
    listener.stop()
    listener.join()
    sct.close()

print("Script halted.")


In [2]:
#Overlay saves and movement logic implementation

In [3]:
#!/usr/bin/env python3
# Live overlays + lane-aware curved sampling (optimized postproc)
# • Parsec focus + auto click
# • mss live capture of a crop region
# • Arrow keys switch lane (0/1/2) -> JAKE_POINT updates per frame
# • Full overlay rendering + per-frame save
# • Prints compact timing per frame

import os, time, math, subprocess, glob
import cv2, torch, numpy as np
from pathlib import Path
from mss import mss
import pyautogui
from pynput import keyboard
from ultralytics import YOLO

# =======================
# Config
# =======================
home    = os.path.expanduser("~")
weights = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"

# SAVE HERE (created on first run, including missing parents)
out_dir = Path(home).joinpath("Documents", "GitHub", "Ai-plays-SubwaySurfers", "out_live_overlays")
out_dir.mkdir(parents=True, exist_ok=True)

# Crop + click (set by ad layout); snap_coords is (left, top, width, height)
advertisement = True
if advertisement:
    snap_coords = (644, 77, 1149 - 644, 981 - 75)
    start_click = (1030, 900)
else:
    snap_coords = (483, 75, 988 - 483, 981 - 75)
    start_click = (870, 895)

# Detector settings passed to model.predict
RAIL_ID  = 9
IMG_SIZE = 512
CONF     = 0.30
IOU      = 0.45
MAX_DET  = 30

# Colour/region filter
TARGET_COLORS_RGB = [
    (119, 104, 67),
    (81, 42, 45),
]
TOLERANCE         = 20.0
MIN_REGION_SIZE   = 30
MIN_REGION_HEIGHT = 150

# Heat/triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
MIN_DARK_FRACTION   = 0.15
TRI_SIZE_PX         = 18

# Sampling ray: total length and step between samples (pixels)
SAMPLE_UP_PX = 180
RAY_STEP_PX  = 20

# ===== Bend degrees (tune here) =====
BEND_LEFT_STATE_RIGHT_DEG = -20.0  # N1
BEND_MID_STATE_RIGHT_DEG  = -20.0  # N2
BEND_MID_STATE_LEFT_DEG   = +20.0  # N3
BEND_RIGHT_STATE_LEFT_DEG = +20.0  # N4

# Colours (BGR)
COLOR_GREEN  = (0, 255, 0)
COLOR_PINK   = (203, 192, 255)
COLOR_YELLOW = (0, 255, 255)
COLOR_RED    = (0, 0, 255)
COLOR_WHITE  = (255, 255, 255)
COLOR_CYAN   = (255, 255, 0)
COLOR_BLACK  = (0, 0, 0)

# =======================
# Jake lane points + dynamic JAKE_POINT
# =======================
LANE_LEFT, LANE_MID, LANE_RIGHT = (300, 1340), (490, 1340), (680, 1340)
LANE_POINTS = (LANE_LEFT, LANE_MID, LANE_RIGHT)  # index by lane (0,1,2)
JAKE_POINT  = LANE_MID  # will be overwritten each frame from 'lane'

LANE_TARGET_DEG = dict(left=-10.7, mid=+1.5, right=+15.0)

def lane_name_from_point(p):
    """Map a lane anchor point to "left" / "mid" / "right".

    Points that match none of the three anchors are reported as "mid".
    """
    for name, anchor in (("left", LANE_LEFT), ("mid", LANE_MID), ("right", LANE_RIGHT)):
        if p == anchor:
            return name
    return "mid"

# =======================
# Lane/keyboard state
# =======================
MIN_LANE = 0
MAX_LANE = 2
lane = 1        # current lane index, starts in the middle
running = True  # cleared by ESC


def on_press(key):
    """Keyboard callback: arrow keys move the tracked lane, ESC requests a stop.

    The lane index is clamped to [MIN_LANE, MAX_LANE]. Returns False on ESC
    (after clearing `running`) and None for every other key. Any exception
    raised while handling the key is printed instead of propagated.
    """
    global lane, running
    try:
        if key == keyboard.Key.esc:
            print("ESC pressed — exiting")
            running = False
            return False
        if key == keyboard.Key.left:
            lane = max(MIN_LANE, lane - 1)
            print(f"Moved Left → Lane {lane}")
        elif key == keyboard.Key.right:
            lane = min(MAX_LANE, lane + 1)
            print(f"Moved Right → Lane {lane}")
    except Exception as e:
        print(f"Error: {e}")

# =======================
# System/Backends
# =======================
cv2.setUseOptimized(True)
try:
    # Use all but one CPU core for OpenCV, but never fewer than one.
    _cv_threads = (os.cpu_count() or 1) - 1
    cv2.setNumThreads(max(1, _cv_threads))
except Exception:
    pass

# Pick the inference device: CUDA (with half precision), then MPS, then CPU.
_mps_backend = getattr(torch.backends, "mps", None)
if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass
elif _mps_backend and _mps_backend.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

# =======================
# Model
# =======================
model = YOLO(weights)
try:
    model.fuse()
except Exception:
    pass

# Warm-up: one throwaway prediction on an all-zero IMG_SIZE x IMG_SIZE frame.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(
    _dummy,
    task="segment", imgsz=IMG_SIZE, device=device,
    conf=CONF, iou=IOU, max_det=MAX_DET,
    half=half, verbose=False,
)

# =======================
# Precomputed
# =======================
# Target colours reordered RGB -> BGR, as float32 rows
TARGETS_BGR_F32 = np.array(
    [(b, g, r) for (r, g, b) in TARGET_COLORS_RGB],
    dtype=np.float32,
)
TOL2 = TOLERANCE * TOLERANCE  # squared tolerance, compared against squared distances

# Class buckets for probe classification
DANGER_RED  = {1, 6, 7, 11}
WARN_YELLOW = {2, 3, 4, 5, 8}
BOOTS_PINK  = {0}

# Overlay colour per class id
CLASS_COLOURS = {
    0:  (255, 255, 0),
    1:  (192, 192, 192),
    2:  (0, 128, 255),
    3:  (0, 255, 0),
    4:  (255, 0, 255),
    5:  (0, 255, 255),
    6:  (255, 128, 0),
    7:  (128, 0, 255),
    8:  (0, 0, 128),
    9:  (0, 0, 255),
    10: (128, 128, 0),
    11: (255, 255, 102),
}
# Display label per class id
LABELS = {
    0:  "BOOTS",
    1:  "GREYTRAIN",
    2:  "HIGHBARRIER1",
    3:  "JUMP",
    4:  "LOWBARRIER1",
    5:  "LOWBARRIER2",
    6:  "ORANGETRAIN",
    7:  "PILLAR",
    8:  "RAMP",
    9:  "RAILS",
    10: "SIDEWALK",
    11: "YELLOWTRAIN",
}

# ====== tiny helpers ======
def _clampi(v, lo, hi):
    return lo if v < lo else (hi if v > hi else v)

# =======================
# Parsec to front + click Start (non-blocking failures)
# =======================
# Both steps are best-effort: a failure is reported but never aborts the cell.
try:
    subprocess.run(["osascript", "-e", 'tell application "Parsec" to activate'], check=False)
    time.sleep(0.4)  # short pause before the click below
except Exception as e:
    print(f"Warning: could not activate Parsec: {e}")

try:
    pyautogui.click(start_click)
except Exception as e:
    print(f"Warning: could not click Start at {start_click}: {e}")

# =======================
# Fast rails green finder
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask):
    """Return a bool (H, W) mask of rail pixels that match the target colours.

    Work is restricted to the bounding box of `rail_mask`. A pixel is a
    candidate when it lies inside the rail mask and its squared BGR distance
    to any row of TARGETS_BGR_F32 is at most TOL2. Candidates are grouped into
    8-connected components; only components with area >= MIN_REGION_SIZE and
    bounding-box height >= MIN_REGION_HEIGHT survive. An all-False mask is
    returned when the rail mask is empty or no candidate exists.
    """
    H, W = rail_mask.shape
    empty = np.zeros((H, W), dtype=bool)
    if not rail_mask.any():
        return empty

    rail_u8 = rail_mask.view(dtype=np.uint8) * 255
    x0, y0, box_w, box_h = cv2.boundingRect(rail_u8)
    rows, cols = slice(y0, y0 + box_h), slice(x0, x0 + box_w)

    roi_f32 = img_bgr[rows, cols].astype(np.float32, copy=False)
    in_rail = rail_u8[rows, cols].astype(bool)

    # Squared distance of every ROI pixel to every target colour: (h, w, n_targets)
    delta = roi_f32[:, :, None, :] - TARGETS_BGR_F32[None, None, :, :]
    near_target = ((delta * delta).sum(-1) <= TOL2).any(-1)

    candidates = np.logical_and(near_target, in_rail)
    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(candidates.astype(np.uint8), 8)
    if n_labels <= 1:
        return empty

    # Per-label keep flag; index 0 (background) stays False.
    keep_lut = np.zeros(n_labels, dtype=bool)
    keep_lut[1:] = (
        (stats[1:, cv2.CC_STAT_AREA] >= MIN_REGION_SIZE)
        & (stats[1:, cv2.CC_STAT_HEIGHT] >= MIN_REGION_HEIGHT)
    )

    result = np.zeros((H, W), dtype=bool)
    result[rows, cols] = keep_lut[labels]
    return result

def red_vs_green_score(red_mask, green_mask):
    """Combine two masks into a uint8 heat score (high = red dominates locally).

    Both masks are box-blurred with a HEAT_BLUR_KSIZE square kernel. The
    red-minus-green difference is divided by twice its largest absolute value
    and shifted by 0.5, then scaled to 0..255.
    """
    ksize = (HEAT_BLUR_KSIZE, HEAT_BLUR_KSIZE)
    red_blur   = cv2.blur(red_mask.astype(np.float32, copy=False), ksize)
    green_blur = cv2.blur(green_mask.astype(np.float32, copy=False), ksize)

    delta = red_blur - green_blur
    peak = float(np.max(np.abs(delta))) + 1e-6  # epsilon keeps the division defined
    scaled = (delta / (2.0 * peak) + 0.5) * 255.0
    return np.clip(scaled, 0, 255.0).astype(np.uint8, copy=False)

def purple_triangles(score, H):
    """Find the top point of every large high-score blob in `score`.

    Pixels with score >= RED_SCORE_THRESH are kept, the top EXCLUDE_TOP_FRAC
    and bottom EXCLUDE_BOTTOM_FRAC of the H rows are blanked, and the result
    is morphologically opened with a 5x9 rectangular element. A connected
    component qualifies when its area is at least MIN_DARK_RED_AREA and at
    least MIN_DARK_FRACTION of all remaining pixels.

    Returns (tris, best): `tris` is a list of (x, y) points, one per
    qualifying component, where y is the component's top row and x the mean
    column of the component on that row; `best` is the entry with the
    smallest y. Returns ([], None) when nothing qualifies.
    """
    hot = (score >= RED_SCORE_THRESH).astype(np.uint8, copy=False)

    # Blank the excluded bands at the top and bottom of the frame.
    top_rows = int(H * EXCLUDE_TOP_FRAC)
    bottom_rows = int(H * EXCLUDE_BOTTOM_FRAC)
    if top_rows:
        hot[:top_rows, :] = 0
    if bottom_rows:
        hot[-bottom_rows:, :] = 0

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9))
    hot = cv2.morphologyEx(hot, cv2.MORPH_OPEN, kernel, iterations=1)

    hot_total = int(hot.sum())
    if hot_total == 0:
        return [], None
    min_share = int(np.ceil(MIN_DARK_FRACTION * hot_total))

    n_labels, labels, stats, _ = cv2.connectedComponentsWithStats(hot, 8)
    if n_labels <= 1:
        return [], None

    apexes = []
    for label in range(1, n_labels):
        area = stats[label, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < min_share:
            continue
        # The component's top row comes from its bounding box; only that row is scanned.
        top_row = int(stats[label, cv2.CC_STAT_TOP])
        cols = np.where(labels[top_row] == label)[0]
        if cols.size == 0:
            continue
        apexes.append((int(cols.mean()), top_row))

    if not apexes:
        return [], None
    return apexes, min(apexes, key=lambda pt: pt[1])

# ===== Bearing-based Jake triangle selection =====
def signed_degrees_from_vertical(dx, dy):
    """Signed angle in degrees between the offset (dx, dy) and straight up (-y).

    Offsets with dx > 0 give negative angles, dx < 0 positive ones.
    A zero offset yields 0.0.
    """
    if (dx, dy) == (0, 0):
        return 0.0
    angle_rad = math.atan2(dx, -dy)
    return -math.degrees(angle_rad)

def select_triangle_by_bearing(tri_positions, jx, jy, target_deg, min_dy=6):
    """Pick the triangle whose bearing from (jx, jy) is closest to `target_deg`.

    Triangles that are not more than `min_dy` pixels above the reference
    point (smaller y) are ignored. On ties the earliest triangle wins.

    Returns (index, bearing_deg, abs_error), or (-1, None, None) when no
    triangle qualifies.
    """
    chosen = (-1, None, None)
    for i, (xt, yt) in enumerate(tri_positions):
        rise = yt - jy
        if rise >= -min_dy:  # not far enough above the reference point
            continue
        bearing = signed_degrees_from_vertical(xt - jx, rise)
        miss = abs(bearing - target_deg)
        if chosen[2] is None or miss < chosen[2]:
            chosen = (i, bearing, miss)
    return chosen

# ===== Lane-aware curved sampling (precompute sin/cos) =====
def _precompute_trig():
    """Build {angle_deg: (dx, dy)} unit-ray directions for 0.0 and every bend angle.

    "Up" is -y, so 0 degrees maps to (0, -1).
    """
    bend_angles = {
        0.0,
        BEND_LEFT_STATE_RIGHT_DEG,
        BEND_MID_STATE_RIGHT_DEG,
        BEND_MID_STATE_LEFT_DEG,
        BEND_RIGHT_STATE_LEFT_DEG,
    }
    directions = {}
    for deg in sorted(bend_angles):
        rad = math.radians(deg)
        directions[deg] = (math.sin(rad), -math.cos(rad))
    return directions
TRIG_TABLE = _precompute_trig()

def pick_bend_angle(jake_point, xt, x_ref, idx, best_idx):
    """Choose the probe-ray bend (degrees) for triangle `idx` located at x = `xt`.

    The selected triangle (`idx == best_idx`) is always probed straight (0.0).
    With Jake on LANE_LEFT only triangles right of `x_ref` are bent; on
    LANE_RIGHT only triangles left of `x_ref`; for any other point both
    sides are bent with the mid-state angles.
    """
    if idx == best_idx:
        return 0.0

    right_of_ref = xt > x_ref
    left_of_ref = xt < x_ref

    if jake_point == LANE_LEFT:
        return BEND_LEFT_STATE_RIGHT_DEG if right_of_ref else 0.0
    if jake_point == LANE_RIGHT:
        return BEND_RIGHT_STATE_LEFT_DEG if left_of_ref else 0.0
    if right_of_ref:
        return BEND_MID_STATE_RIGHT_DEG
    if left_of_ref:
        return BEND_MID_STATE_LEFT_DEG
    return 0.0

# --------- walk-the-ray classifier (first-hit wins) ----------
def classify_triangles_at_sample_curved(
    tri_positions, masks_np, classes_np, H, W,
    jake_point, x_ref, best_idx, sample_px=SAMPLE_UP_PX, step_px=RAY_STEP_PX
):
    """Walk a probe ray from each triangle; report the first hit and the ray drawn.

    For every triangle a ray direction is taken from TRIG_TABLE using
    pick_bend_angle. The ray is sampled every `step_px` pixels, for
    max(1, sample_px // step_px) steps. Sample coordinates are clamped to the
    frame and mapped into mask resolution. At each step the masks are checked
    in priority order DANGER_RED, WARN_YELLOW, BOOTS_PINK (mask value > 0.5);
    the first hit ends the walk.

    Returns (colours, rays), both with one entry per triangle:
      colours: COLOR_RED / COLOR_YELLOW / COLOR_PINK for the bucket that was
               hit, COLOR_GREEN when nothing was hit.
      rays:    ((x0, y0), (x1, y1), theta) where (x1, y1) is the clamped point
               `sample_px` pixels along the ray and theta the bend in degrees.
    Returns ([], []) when masks, classes or triangles are missing.
    """
    if masks_np is None or classes_np is None or len(tri_positions) == 0:
        return [], []

    mask_h, mask_w = masks_np.shape[1], masks_np.shape[2]
    sx = (mask_w - 1) / max(1, (W - 1))
    sy = (mask_h - 1) / max(1, (H - 1))

    def _members(bucket):
        # Indices of the masks whose class id belongs to `bucket`.
        return [i for i, c in enumerate(classes_np) if int(c) in bucket]

    # (mask indices, result colour) in priority order.
    buckets = (
        (_members(DANGER_RED),  COLOR_RED),
        (_members(WARN_YELLOW), COLOR_YELLOW),
        (_members(BOOTS_PINK),  COLOR_PINK),
    )
    n_steps = max(1, sample_px // max(1, step_px))

    def _first_hit(x0, y0, dx, dy):
        for k in range(1, n_steps + 1):
            dist = k * step_px
            px = _clampi(int(round(x0 + dx * dist)), 0, W-1)
            py = _clampi(int(round(y0 + dy * dist)), 0, H-1)
            mx = _clampi(int(round(px * sx)), 0, mask_w-1)
            my = _clampi(int(round(py * sy)), 0, mask_h-1)
            for members, colour in buckets:
                for i in members:
                    if masks_np[i][my, mx] > 0.5:
                        return colour
        return COLOR_GREEN

    colours, rays = [], []
    for idx, (x0, y0) in enumerate(tri_positions):
        theta = pick_bend_angle(jake_point, x0, x_ref, idx, best_idx)
        dx, dy = TRIG_TABLE[theta]
        colours.append(_first_hit(x0, y0, dx, dy))

        # Ray end point for drawing: full sample length, clamped to the frame.
        end_x = _clampi(int(round(x0 + dx * sample_px)), 0, W-1)
        end_y = _clampi(int(round(y0 + dy * sample_px)), 0, H-1)
        rays.append(((int(x0), int(y0)), (end_x, end_y), float(theta)))

    return colours, rays
# -----------------------------------------------------------------------

# =======================
# Frame post-processing
# =======================
def process_frame_post(frame_bgr, yolo_res, jake_point):
    """Run the CPU post-processing for one frame and return everything the overlay needs.

    Returns a 15-tuple:
        (tri_best, tri_count, mask_count, to_cpu_ms, post_ms,
         masks_np, classes_np, rail_mask, green,
         tri_positions, tri_colours, tri_rays,
         best_idx, best_deg, x_ref)

    When the result has no masks, no classes, or no mask of class RAIL_ID,
    the tuple carries None / 0 / [] placeholders from `tri_best` onwards,
    apart from the mask count, transfer time and raw mask/class arrays
    gathered so far.
    """
    H, W = frame_bgr.shape[:2]

    def _bail(mask_count, to_cpu_ms, masks_np, classes_np):
        # Early-exit tuple with the same layout as the full return value.
        return (None, 0, mask_count, to_cpu_ms, 0.0, masks_np, classes_np,
                None, None, [], [], [], None, None, None)

    masks = yolo_res.masks
    if masks is None:
        return _bail(0, 0.0, None, None)

    # Move masks and class ids to NumPy, timing the transfer.
    started = time.perf_counter()
    masks_np = masks.data.detach().cpu().numpy()  # [n,h,w]
    cls_src = getattr(masks, "cls", None)
    if cls_src is None:
        cls_src = yolo_res.boxes.cls
    classes_np = cls_src.detach().cpu().numpy().astype(int)
    to_cpu_ms = (time.perf_counter() - started) * 1000.0

    mask_count = int(masks_np.shape[0])
    if mask_count == 0 or classes_np.size == 0:
        return _bail(mask_count, to_cpu_ms, masks_np, classes_np)

    is_rail = (classes_np == RAIL_ID)
    if not is_rail.any():
        return _bail(mask_count, to_cpu_ms, masks_np, classes_np)

    post_started = time.perf_counter()

    # Union of all rail masks, resized to the frame.
    rail_union = np.any(masks_np[is_rail].astype(bool, copy=False), axis=0)
    rail_mask = cv2.resize(
        rail_union.astype(np.uint8, copy=False), (W, H), interpolation=cv2.INTER_NEAREST
    ).astype(bool, copy=False)

    green = highlight_rails_mask_only_fast(frame_bgr, rail_mask)
    red = np.logical_and(rail_mask, np.logical_not(green))
    tri_positions, tri_best = purple_triangles(red_vs_green_score(red, green), H)

    # Jake triangle by bearing
    lane_name = lane_name_from_point(jake_point)
    xj, yj = jake_point
    best_idx, best_deg, _ = select_triangle_by_bearing(
        tri_positions, xj, yj, LANE_TARGET_DEG[lane_name], min_dy=6
    )

    # x_ref for bending: the matched triangle's column in the mid lane, Jake's x otherwise.
    has_match = (best_idx is not None) and (0 <= best_idx < len(tri_positions))
    x_ref = tri_positions[best_idx][0] if (lane_name == "mid" and has_match) else xj

    tri_colours, tri_rays = classify_triangles_at_sample_curved(
        tri_positions, masks_np, classes_np, H, W, jake_point, x_ref, best_idx,
        SAMPLE_UP_PX, RAY_STEP_PX
    )
    post_ms = (time.perf_counter() - post_started) * 1000.0

    return (tri_best, len(tri_positions), mask_count, to_cpu_ms, post_ms,
            masks_np, classes_np, rail_mask, green,
            tri_positions, tri_colours, tri_rays,
            best_idx, best_deg, x_ref)

# =======================
# Viz helpers
# =======================
def _colour_for_point(x, y, masks_np, classes_np, H, W):
    """Return the BGR colour for frame pixel (x, y) based on the masks covering it.

    The pixel is mapped from frame size (W, H) into mask resolution and every
    mask is checked there (value > 0.5). Priority is the same as in
    classify_triangles_at_sample_curved: any DANGER_RED class gives COLOR_RED,
    otherwise any WARN_YELLOW class gives COLOR_YELLOW, otherwise any
    BOOTS_PINK class gives COLOR_PINK, otherwise COLOR_GREEN.

    Returns COLOR_GREEN when masks or classes are missing or empty.
    """
    if masks_np is None or classes_np is None or masks_np.size == 0:
        return COLOR_GREEN

    mh, mw = masks_np.shape[1], masks_np.shape[2]
    sx = (mw - 1) / max(1, (W - 1))
    sy = (mh - 1) / max(1, (H - 1))
    mx = _clampi(int(round(x * sx)), 0, mw-1)
    my = _clampi(int(round(y * sy)), 0, mh-1)

    # Collect every class covering the pixel. Stopping at the first covering
    # mask would let an overlapping RAILS/SIDEWALK mask hide a train or
    # barrier on the same pixel, disagreeing with the probe classifier.
    classes_here = {int(c) for m, c in zip(masks_np, classes_np) if m[my, mx] > 0.5}

    if not classes_here.isdisjoint(DANGER_RED):
        return COLOR_RED
    if not classes_here.isdisjoint(WARN_YELLOW):
        return COLOR_YELLOW
    if not classes_here.isdisjoint(BOOTS_PINK):
        return COLOR_PINK
    return COLOR_GREEN

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=COLOR_RED):
    """Draw a filled triangle with a 1-px black outline onto `img` (in place).

    The apex is at (x, y); the base lies int(size * 1.2) rows further down
    (larger y) and spans x - size .. x + size.
    """
    drop = int(size * 1.2)
    apex, base_left, base_right = (x, y), (x - size, y + drop), (x + size, y + drop)
    outline = np.array([apex, base_left, base_right], np.int32)
    cv2.fillConvexPoly(img, outline, colour)
    cv2.polylines(img, [outline.reshape(-1, 1, 2)], True, COLOR_BLACK, 1, cv2.LINE_AA)

def triangle_pts(x, y, size=TRI_SIZE_PX):
    """Return the three triangle vertices as a (3, 2) int32 array.

    Order: apex (x, y), then the two base corners int(size * 1.2) rows
    further down at x - size and x + size.
    """
    drop = int(size * 1.2)
    apex, base_left, base_right = (x, y), (x - size, y + drop), (x + size, y + drop)
    return np.array([apex, base_left, base_right], np.int32)

def _put_outlined_text(img, text, org, scale=0.55):
    """Draw ``text`` at ``org`` in white over a thicker black stroke for legibility."""
    cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, COLOR_BLACK, 2, cv2.LINE_AA)
    cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, (255,255,255), 1, cv2.LINE_AA)


def render_overlays(frame_bgr, masks_np, classes_np, rail_mask, green_mask,
                    tri_positions, tri_colours, tri_rays, best_idx, best_deg, x_ref, jake_point):
    """Build the debug visualisation for one frame and return it as a new image.

    Layers are painted in this order onto a copy of ``frame_bgr``:
      1. every segmentation mask, alpha-blended in its class colour and labelled
         at the mask centroid;
      2. ``rail_mask`` as a 30% tint and ``green_mask`` as solid (0, 255, 0);
      3. a vertical "scout" line rising SAMPLE_UP_PX pixels from each triangle,
         coloured per pixel by ``_colour_for_point``;
      4. a line from ``jake_point`` to every triangle with its signed angle
         from vertical (the ``best_idx`` one in COLOR_CYAN);
      5. the sampling rays in ``tri_rays`` with their angle;
      6. the triangles themselves, the selected one outlined and tagged;
      7. the current lane name in the top-left corner.

    Args:
        frame_bgr: source image; it is not modified.
        masks_np, classes_np: per-instance masks and class ids, iterated in
            parallel; either may be None. Masks whose shape differs from the
            frame are resized with nearest-neighbour interpolation.
        rail_mask, green_mask: optional index masks applied directly to the
            frame (presumably boolean, frame-sized — confirm against caller).
        tri_positions: sequence of (x, y) triangle apex positions, or None.
        tri_colours: colours paired positionally with ``tri_positions``, or None.
        tri_rays: sequence of (p0, p1, theta) ray segments, or None.
        best_idx: index of the selected triangle in ``tri_positions``, or None.
        best_deg, x_ref: accepted for call-site compatibility; not used here.
        jake_point: (x, y) origin of the starburst lines.

    Returns:
        A new image with all overlays drawn.
    """
    out = frame_bgr.copy()
    H, W = out.shape[:2]
    alpha = 0.45

    # Treat missing sequences as empty so every layer below can iterate safely.
    tri_positions = [] if tri_positions is None else tri_positions
    tri_colours = [] if tri_colours is None else tri_colours
    tri_rays = [] if tri_rays is None else tri_rays

    if masks_np is not None and classes_np is not None and masks_np.size:
        for m, c in zip(masks_np, classes_np):
            # Index only with a real boolean mask: a float/uint8 mask that already
            # matches the frame size would otherwise be used as integer indices
            # (or rejected). 0.5 is the same threshold _colour_for_point applies.
            m_full = m if m.dtype == np.bool_ else (m > 0.5)
            if m_full.shape != (H, W):
                m_full = cv2.resize(m_full.astype(np.uint8), (W, H), interpolation=cv2.INTER_NEAREST).astype(bool)
            color = CLASS_COLOURS.get(int(c), (255,255,255))
            out[m_full] = (np.array(color, dtype=np.uint8) * alpha + out[m_full] * (1 - alpha)).astype(np.uint8)
            ys, xs = np.where(m_full)
            if xs.size:
                # Label at the mask centroid, nudged left and kept inside the frame.
                xc, yc = int(xs.mean()), int(ys.mean())
                label = LABELS.get(int(c), f"C{int(c)}")
                _put_outlined_text(out, label, (max(5, xc-40), max(20, yc)), scale=0.6)

    if rail_mask is not None:
        tint = out.copy()
        tint[rail_mask] = (0, 0, 255)
        out = cv2.addWeighted(tint, 0.30, out, 0.70, 0)
    if green_mask is not None:
        out[green_mask] = (0, 255, 0)

    # tiny scout lines (viz only)
    for (x, y) in tri_positions:
        # Clamp to the frame like the starburst layer does; an unclamped x/y
        # raises IndexError when too large and silently wraps when negative.
        x = _clampi(int(x), 0, W-1)
        y = _clampi(int(y), 0, H-1)
        y_end = max(0, y - SAMPLE_UP_PX)
        for yy in range(y, y_end - 1, -1):
            out[yy, x] = _colour_for_point(x, yy, masks_np, classes_np, H, W)

    # starburst to Jake
    xj, yj = jake_point
    for idx, (xt, yt) in enumerate(tri_positions):
        xt = _clampi(int(xt), 0, W-1); yt = _clampi(int(yt), 0, H-1)
        deg_signed = signed_degrees_from_vertical(xt - xj, yt - yj)
        cv2.line(out, (xj, yj), (xt, yt),
                 COLOR_CYAN if idx == best_idx else COLOR_WHITE, 2, cv2.LINE_AA)
        mx = (xj + xt) // 2; my = (yj + yt) // 2
        _put_outlined_text(out, f"{deg_signed:.1f}°", (mx, my))

    # curved sampling rays (viz)
    for (p0, p1, theta) in tri_rays:
        cv2.line(out, p0, p1, (255,255,255), 2, cv2.LINE_AA)
        mx = (p0[0] + p1[0]) // 2; my = (p0[1] + p1[1]) // 2
        _put_outlined_text(out, f"{theta:+.1f}°", (mx, my))

    for (x, y), col in zip(tri_positions, tri_colours):
        draw_triangle(out, int(x), int(y), colour=col)

    lane_name = lane_name_from_point(jake_point)
    target_deg = LANE_TARGET_DEG[lane_name]
    if best_idx is not None and 0 <= best_idx < len(tri_positions):
        xt, yt = tri_positions[best_idx]
        pts = triangle_pts(int(xt), int(yt), size=TRI_SIZE_PX)
        cv2.polylines(out, [pts.reshape(-1,1,2)], True, COLOR_CYAN, 3, cv2.LINE_AA)
        tag = f"JAKE_TRI ({lane_name}: target {target_deg:.1f}°)"
        _put_outlined_text(out, tag, (max(5, int(xt)-70), max(20, int(yt)-16)))

    # top-left JAKE_POINT state
    cv2.putText(out, f"JAKE_POINT: {lane_name.upper()}",
                (10, 24), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,255,0), 2, cv2.LINE_AA)

    return out

# =======================
# Live loop
# =======================
# Keyboard listener: on_press updates the module-level `lane` and `running`.
# NOTE(review): if a listener was already started earlier in this cell, this
# creates a second one and every key press is handled twice — confirm.
listener = keyboard.Listener(on_press=on_press)
listener.start()

sct = mss()
frame_idx = 0

try:
    while running:
        # Screen grab
        left, top, width, height = snap_coords
        raw = sct.grab({"left": left, "top": top, "width": width, "height": height})
        frame_bgr = np.array(raw)[:, :, :3]  # BGRA -> BGR

        # Dynamic JAKE_POINT from current lane (O(1))
        JAKE_POINT = LANE_POINTS[lane]

        # Inference
        t0_inf = time.perf_counter()
        res_list = model.predict(
            [frame_bgr], task="segment", imgsz=IMG_SIZE, device=device,
            conf=CONF, iou=IOU, verbose=False, half=half, max_det=MAX_DET, batch=1
        )
        infer_ms = (time.perf_counter() - t0_inf) * 1000.0
        yres = res_list[0]

        # Postproc
        (tri_best_xy, tri_count, mask_count, to_cpu_ms, post_ms,
         masks_np, classes_np, rail_mask, green_mask, tri_positions, tri_colours,
         tri_rays, best_idx, best_deg, x_ref) = process_frame_post(frame_bgr, yres, JAKE_POINT)

        # Reported total covers inference + transfer + postproc only;
        # rendering and the JPEG write below are not included.
        proc_ms = infer_ms + to_cpu_ms + post_ms

        # Render + save
        overlay = render_overlays(frame_bgr, masks_np, classes_np, rail_mask, green_mask,
                                  tri_positions, tri_colours, tri_rays, best_idx, best_deg, x_ref, JAKE_POINT)
        frame_idx += 1
        out_path = out_dir / f"live_overlay_{frame_idx:05d}.jpg"
        # imwrite reports failure through its return value rather than raising.
        if not cv2.imwrite(str(out_path), overlay):
            print(f"[live {frame_idx:05d}] WARNING: could not write {out_path}")

        # Print compact timing
        print(f"[live {frame_idx:05d}] "
              f"infer {infer_ms:.1f} | to_cpu {to_cpu_ms:.1f} | post {post_ms:.1f} | "
              f"masks {mask_count} | triangles {tri_count} => proc {proc_ms:.1f} ms")
finally:
    # Cleanup runs on every exit path, including a kernel interrupt
    # (KeyboardInterrupt), so the listener thread and the capture handle are
    # released and a re-run of the cell does not stack another listener.
    listener.stop()
    listener.join()
    sct.close()

print("Script halted.")


YOLO11n-seg summary (fused): 113 layers, 2,836,908 parameters, 0 gradients, 10.2 GFLOPs
[live 00001] infer 700.4 | to_cpu 1.3 | post 169.2 | masks 1 | triangles 1 => proc 870.9 ms
[live 00002] infer 184.1 | to_cpu 197.3 | post 139.2 | masks 5 | triangles 1 => proc 520.6 ms
[live 00003] infer 149.7 | to_cpu 1.0 | post 146.2 | masks 6 | triangles 1 => proc 297.0 ms
[live 00004] infer 84.5 | to_cpu 1.1 | post 144.9 | masks 5 | triangles 1 => proc 230.5 ms
Moved Right → Lane 2
[live 00005] infer 143.6 | to_cpu 0.8 | post 157.1 | masks 4 | triangles 3 => proc 301.6 ms
[live 00006] infer 78.1 | to_cpu 1.0 | post 148.7 | masks 5 | triangles 2 => proc 227.9 ms
[live 00007] infer 75.2 | to_cpu 1.0 | post 148.8 | masks 4 | triangles 2 => proc 225.0 ms
[live 00008] infer 43.7 | to_cpu 1.0 | post 144.2 | masks 4 | triangles 2 => proc 188.9 ms
[live 00009] infer 183.2 | to_cpu 0.9 | post 145.8 | masks 3 | triangles 2 => proc 329.9 ms
[live 00010] infer 43.5 | to_cpu 0.9 | post 114.1 | masks 4 | tri

KeyboardInterrupt: 