In [None]:
#Hyperfast External printing 

In [None]:
#!/usr/bin/env python3
"""
High-utilization visualiser: pipelined prefetch → inference → postproc.

• Keeps GPU/MPS saturated: inference on frame N+1 overlaps CPU postproc on N
• Runs YOLO on GPU/MPS, RF-DETR on CPU by default (detected automatically)
• Preserves all logic: labels, Jake bbox, under-mask tint, triangle, heat
• Warmup frames executed but not timed; summary at exit
• SPACE to advance; q/ESC to quit
"""

import os, glob, sys, time, statistics, threading, queue
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from ultralytics import YOLO
from rfdetr import RFDETRBase

# ── disable antialias to avoid unsupported MPS op on some ops ─
# Re-running this cell (or a later cell that applies the same patch) must not
# stack wrapper on wrapper, so recover the pristine function from the marker
# attribute that a previous patch left on its wrapper.
_original_interpolate = getattr(F.interpolate, "_antialias_unpatched", F.interpolate)


def _interpolate_no_antialias(input, *args, **kwargs):
    """Call the real ``F.interpolate`` with ``antialias`` forced to ``False``.

    Any ``antialias`` value supplied by the caller is overridden; every other
    positional and keyword argument is forwarded untouched.
    """
    kwargs['antialias'] = False
    return _original_interpolate(input, *args, **kwargs)


# Marker used by the getattr() above to find the unpatched function again.
_interpolate_no_antialias._antialias_unpatched = _original_interpolate
F.interpolate = _interpolate_no_antialias

# =======================
# Config
# =======================
# Checkpoint locations, resolved relative to the current user's home directory.
home          = os.path.expanduser("~")
yolo_weights  = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
jake_weights  = f"{home}/downloads/weightsjake.pt"
# RAIL_ID is the class id treated as rails (id 9 is "RAILS" in obstacle_classes).
# IMG_SIZE / CONF / IOU are passed to yolo_model.predict as imgsz / conf / iou.
# ALPHA is the blend weight of the red rail tint on the right panel.
RAIL_ID       = 9
IMG_SIZE      = 512
CONF, IOU     = 0.30, 0.45
ALPHA         = 0.40

# Arguments for highlight_rails_mask_only_fast: target colours (RGB), the
# maximum colour distance, and the minimum area / height a connected region
# needs to be kept.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat map / warning triangles:
#   HEAT_BLUR_KSIZE      side of the square blur kernel used for the score
#   RED_SCORE_THRESH     score value from which a pixel counts as "dark red"
#   EXCLUDE_*_FRAC       fraction of rows at the top / bottom whose score is zeroed
#   MIN_DARK_RED_AREA    minimum component area (px) that gets a triangle
#   MIN_DARK_FRACTION    minimum share of the total dark area a component needs
#   TRI_SIZE_PX, PURPLE  default size and colour of draw_triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18
PURPLE              = (255, 0, 255)
MIN_DARK_FRACTION   = 0.15

# Display: window title, widest canvas shown before downscaling, and the
# colours used for text, Jake's box and the tint of the mask under Jake.
WIN_NAME        = "Left: labels | Right: rails | Bottom: heat (SPACE=next, q=quit)"
MAX_DISPLAY_W   = 1600
TEXT_CLR        = (255, 255, 255)
JAKE_BOX_CLR    = (0, 255, 0)
UNDER_TINT_BGR  = (255, 0, 255)

# Pipeline: frames with index below WARMUP_FRAMES are processed but not timed;
# PREFETCH and INFER_QUEUE are the capacities of the queues created in main.
WARMUP_FRAMES   = 10
PREFETCH        = 24    # bigger prefetch helps saturate I/O
INFER_QUEUE     = 4     # frames in-flight through inference

# CPU threading (OpenCV)
# One less than the CPU count (assumed 8 when unknown), but never below 1.
OPENCV_THREADS  = max(1, (os.cpu_count() or 8) - 1)

# =======================
# Device selection
# =======================
def _select_devices():
    """Return ``(yolo_device, half, det_on_cpu)`` for the best available backend.

    CUDA is preferred over MPS, and MPS over plain CPU. Half precision is
    only enabled on CUDA.
    """
    if torch.cuda.is_available():
        return 0, True, True
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        # keep RF-DETR on CPU to overlap with MPS
        return "mps", False, True
    # both CPU; still benefits from pipeline
    return "cpu", False, False


yolo_device, half, DET_ON_CPU = _select_devices()

# =======================
# Models
# =======================
# Segmentation model, loaded from the yolo_weights checkpoint.
yolo_model = YOLO(yolo_weights)
# fuse() is treated as an optional optimisation: if it raises, the error is
# deliberately ignored and the model is used as loaded.
try: yolo_model.fuse()
except Exception: pass

# RF-DETR detector used by InferenceWorker to find Jake's bounding box.
jake_model = RFDETRBase(pretrain_weights=jake_weights, num_classes=3)

# warmup YOLO
# One throw-away prediction on a black IMG_SIZE x IMG_SIZE image, using the
# same device / thresholds / precision as the real frames; the result is discarded.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = yolo_model.predict(_dummy, task="segment", imgsz=IMG_SIZE,
                       device=yolo_device, conf=CONF, iou=IOU,
                       verbose=False, half=half)

# OpenCV threads
# Best-effort: a failure to set the thread count is ignored.
try:
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# =======================
# Utils
# =======================
# Class id -> display name; the id of each name is its position in the tuple.
obstacle_classes = dict(enumerate((
    "BOOTS", "GREYTRAIN", "HIGHBARRIER1", "JUMP", "LOWBARRIER1",
    "LOWBARRIER2", "ORANGETRAIN", "PILLAR", "RAMP", "RAILS",
    "SIDEWALK", "YELLOWTRAIN",
)))

# Class id -> colour painted over that class's mask on the left panel.
# Id 9 (RAILS) has no entry: rail masks are skipped by the labelling loop,
# and any other missing id falls back to white there.
CLASS_COLOURS = {
    0: (255, 255, 0),
    1: (192, 192, 192),
    2: (0, 128, 255),
    3: (0, 255, 0),
    4: (255, 0, 255),
    5: (0, 255, 255),
    6: (255, 128, 0),
    7: (128, 0, 255),
    8: (0, 0, 128),
    10: (128, 128, 0),
    11: (255, 255, 102),
}

def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """Select rail pixels whose colour is close to one of the target colours.

    Args:
        img_bgr: image in BGR channel order, shape (H, W, 3).
        rail_mask: boolean (H, W) mask of rail pixels.
        target_colors_rgb: iterable of (r, g, b) colours to match.
        tolerance: maximum Euclidean distance in colour space for a match.
        min_region_size: minimum pixel area of a connected region to keep.
        min_region_height: minimum bounding-box height of a region to keep.

    Returns:
        Boolean (H, W) mask of the matching pixels that lie inside
        ``rail_mask`` and belong to an 8-connected region that is both large
        and tall enough. All False when ``rail_mask`` is empty.
    """
    H, W = img_bgr.shape[:2]
    filtered_full = np.zeros((H, W), dtype=bool)
    if not rail_mask.any():
        return filtered_full

    # Only the bounding box of the rail mask needs to be examined.
    ys, xs = np.where(rail_mask)
    y0, y1 = ys.min(), ys.max() + 1
    x0, x1 = xs.min(), xs.max() + 1
    img_f = img_bgr[y0:y1, x0:x1].astype(np.float32)
    mask_roi = rail_mask[y0:y1, x0:x1]

    # Targets arrive as RGB; flip them to the image's BGR order.
    targets_bgr = np.array([(r, g, b)[::-1] for r, g, b in target_colors_rgb],
                           dtype=np.float32)
    diff = img_f[:, :, None, :] - targets_bgr[None, None, :, :]
    dist2 = np.sum(diff * diff, axis=-1)
    colour_hit = np.any(dist2 <= tolerance ** 2, axis=-1)
    combined = colour_hit & mask_roi

    _, lbls, stats, _ = cv2.connectedComponentsWithStats(
        combined.astype(np.uint8), connectivity=8)
    # One keep/drop flag per label, then a single lookup over the label image
    # instead of rescanning the ROI once per component.
    keep = ((stats[:, cv2.CC_STAT_AREA] >= min_region_size)
            & (stats[:, cv2.CC_STAT_HEIGHT] >= min_region_height))
    keep[0] = False  # label 0 is the background
    filtered_full[y0:y1, x0:x1] = keep[lbls]
    return filtered_full

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """Turn a red and a green mask into a smoothed uint8 "redness" map.

    Both masks are box-blurred with a ``blur_ksize`` x ``blur_ksize`` kernel.
    The blurred green response is subtracted from the red one, and the
    difference is rescaled by its largest magnitude so that a neutral pixel
    lands mid-range, red-dominated pixels tend towards 255 and
    green-dominated pixels towards 0.
    """
    kernel = (blur_ksize, blur_ksize)
    red_blur, green_blur = (
        cv2.blur(mask.astype(np.float32), kernel) for mask in (red_mask, green_mask)
    )
    balance = red_blur - green_blur
    # The small epsilon keeps the division defined when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2 * peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """Draw a filled triangle with a thin black outline onto ``img`` in place.

    The apex sits at ``(x, y)``; the base lies ``int(size * 1.2)`` pixels
    below it and reaches ``size`` pixels to either side.
    """
    base_y = y + int(size * 1.2)
    corners = np.array(((x, y), (x - size, base_y), (x + size, base_y)), np.int32)
    cv2.fillConvexPoly(img, corners, colour)
    outline = corners.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0, 0, 0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """Stack ``left`` and ``right`` side by side and put ``heat`` underneath.

    ``heat`` is expected to be as wide as both top panels together. If it
    is not (for example a placeholder that is only one panel wide), its
    columns are resampled by nearest-neighbour indexing so that the vertical
    stack cannot fail on a width mismatch.
    """
    top = np.hstack((left, right))
    top_w, heat_w = top.shape[1], heat.shape[1]
    if heat_w != top_w:
        # Map every output column to its nearest source column.
        cols = (np.arange(top_w) * heat_w) // top_w
        heat = heat[:, cols]
    return np.vstack((top, heat))

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """Return ``img`` itself if it is at most ``max_w`` wide, else a shrunken copy.

    A wider image is scaled by ``max_w / width`` in both directions (sizes
    truncated to ints) using ``cv2.INTER_AREA``.
    """
    height, width = img.shape[:2]
    if width > max_w:
        factor = max_w / float(width)
        target = (int(width * factor), int(height * factor))
        return cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    return img

# =======================
# Inference worker (persistent)
# =======================
class InferenceWorker(threading.Thread):
    """
    Consumes (path, frame) items from in_q, runs YOLO (GPU/MPS) + RF-DETR (CPU optionally),
    and puts (path, frame, masks(np), classes(np), jake_xyxy or None) on out_q.

    masks and classes are None when YOLO returned no masks. A frame that is
    None (an image that failed to load) is forwarded as
    (path, None, None, None, None) without running either model.

    A None item on in_q is the end-of-stream sentinel. Exactly one None is put
    on out_q when the thread finishes, whether that is because the sentinel
    arrived, ``stop`` was set, or inference raised, so a consumer blocked on
    out_q is never left waiting for a dead worker.
    """
    def __init__(self, in_q, out_q):
        super().__init__(daemon=True)
        self.in_q  = in_q
        self.out_q = out_q
        self.stop  = False  # set True to leave the loop after the current item

    def run(self):
        try:
            while not self.stop:
                item = self.in_q.get()
                if item is None:  # sentinel
                    break
                path, frame = item
                self.out_q.put(self._process(path, frame))
        finally:
            # Always release the consumer, even when inference raised.
            self.out_q.put(None)

    def _process(self, path, frame):
        """Run both models on one frame and return the result tuple."""
        if frame is None:
            # Nothing to infer on; let the consumer decide how to skip it.
            return (path, None, None, None, None)
        H, W = frame.shape[:2]

        # Run YOLO on device
        res = yolo_model.predict(frame, task="segment", imgsz=IMG_SIZE,
                                 device=yolo_device, conf=CONF, iou=IOU,
                                 max_det=30, verbose=False, half=half)[0]

        jake_xyxy = self._detect_jake(frame, W, H)

        if res.masks is None:
            return (path, frame, None, None, jake_xyxy)

        masks   = res.masks.data.cpu().numpy()  # (N, h_m, w_m)
        classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                   else res.boxes.cls.cpu().numpy())
        # Flatten so the consumer can iterate plain ints.
        classes = np.asarray(classes, dtype=int).ravel()
        return (path, frame, masks, classes, jake_xyxy)

    @staticmethod
    def _detect_jake(frame, W, H):
        """Return Jake's clamped (x1, y1, x2, y2) box, or None if there is none.

        Best-effort: any exception raised by the detector is swallowed and
        reported as "no box", so a detector failure never stops the pipeline.
        """
        try:
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # RF-DETR on CPU to overlap with GPU (unless both CPU)
            det_dev = "cpu" if DET_ON_CPU else ("cpu" if yolo_device != 0 else 0)
            dets = jake_model.predict(pil_img, threshold=0.5, device=det_dev)[0]
            if hasattr(dets, "xyxy") and len(dets.xyxy) > 0:
                # Only the first returned box is used.
                x1, y1, x2, y2 = dets.xyxy[0].astype(int).tolist()
                # clamp to the frame
                x1 = max(0, min(W-1, x1)); x2 = max(0, min(W-1, x2))
                y1 = max(0, min(H-1, y1)); y2 = max(0, min(H-1, y2))
                if x2 > x1 and y2 > y1:
                    return (x1, y1, x2, y2)
        except Exception:
            pass
        return None

# =======================
# Prefetcher
# =======================
class FramePrefetcher(threading.Thread):
    """Daemon thread that loads images with ``cv2.imread`` and queues them.

    For each path, in order, ``(path, image)`` is put on ``out_q``, where
    ``image`` is whatever ``cv2.imread`` returned. A final ``None`` marks the
    end of the stream, also when ``stop`` was set part-way through.
    ``maxsize`` is accepted but not used by this class.
    """
    def __init__(self, paths, out_q, maxsize=PREFETCH):
        super().__init__(daemon=True)
        self.stop = False
        self.q = out_q
        self.paths = paths

    def run(self):
        emit = self.q.put
        for image_path in self.paths:
            if self.stop:
                break
            emit((image_path, cv2.imread(image_path)))
        emit(None)  # sentinel

# =======================
# Main
# =======================
def _print_summary(times_ms, processed, total_time):
    """Print the speed-test summary; prints nothing when no frame was timed."""
    if not times_ms:
        return
    avg = statistics.mean(times_ms)
    print("\n───── Speed-test summary (warm-up skipped) ─────")
    print(f"Total frames           : {processed}")
    print(f"Warm-up frames ignored : {min(WARMUP_FRAMES, processed)}")
    print(f"Frames timed           : {len(times_ms)}")
    print(f"Total wall-clock time  : {total_time:,.2f} s")
    print(f"Average / frame        : {avg:,.2f} ms  ({1000.0/avg:,.2f} FPS)")
    print(f"Median                 : {statistics.median(times_ms):,.2f} ms")
    print(f"Fastest                : {min(times_ms):,.2f} ms")
    print(f"Slowest                : {max(times_ms):,.2f} ms")
    print("────────────────────────────────────────────────")


def _request_stop(infer_in_q):
    """Hand the worker its end-of-stream sentinel without ever blocking."""
    try:
        infer_in_q.put_nowait(None)
    except queue.Full:
        pass  # the worker is a daemon thread and ends with the process


def _next_frame(prefetch_q, block):
    """Return the next readable ``(path, frame)`` item, or ``None`` at end of stream.

    Images that could not be read (frame is ``None``) are reported on stderr
    and skipped, so they never reach the inference worker. Raises
    ``queue.Empty`` when ``block`` is false and nothing is ready.
    """
    while True:
        item = prefetch_q.get(block=block)
        if item is None or item[1] is not None:
            return item
        print(f"[WARN] could not read image, skipping: {item[0]}", file=sys.stderr)


def _render_panels(frame, masks, classes, jake_xyxy):
    """Build the ``(left, right, heat)`` panels for one frame.

    left  : class overlays with labels, warning triangles, Jake's box and a
            tint on the mask that overlaps Jake's box the most
    right : rails tinted red, colour-matched rail pixels painted green
    heat  : JET colour map of the red-vs-green score, as wide as left + right
    """
    H, W = frame.shape[:2]
    if masks is None or classes is None:
        # No segmentation: plain panels. The heat strip still has to span both
        # top panels, otherwise stacking the canvas fails on a width mismatch.
        heat = cv2.applyColorMap(np.zeros((H, 2 * W), np.uint8), cv2.COLORMAP_JET)
        return frame.copy(), frame.copy(), heat

    # Ensure classes is always a flat int array for safety
    classes = np.asarray(classes, dtype=int).ravel()

    h_m, w_m = masks.shape[1:]
    rail_union = np.zeros((h_m, w_m), dtype=bool)
    for m, c in zip(masks, classes):
        if int(c) == RAIL_ID:
            rail_union |= m.astype(bool)
    rail_mask = cv2.resize(rail_union.astype(np.uint8), (W, H),
                           interpolation=cv2.INTER_NEAREST).astype(bool)

    # Green/red+score
    green = highlight_rails_mask_only_fast(frame, rail_mask,
                                           TARGET_COLORS_RGB, TOLERANCE,
                                           MIN_REGION_SIZE, MIN_REGION_HEIGHT)
    red   = rail_mask & ~green
    score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

    # bands: ignore the top and bottom strips of the frame
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    score[:top_ex, :] = 0
    if bot_ex:
        score[H-bot_ex:, :] = 0

    dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
        iterations=1
    )

    total_dark_area  = int(dark.sum())
    frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

    # LEFT labels + under-mask tint
    left    = frame.copy()
    overlay = left.copy()

    # Jake's box in mask resolution; None when there is no usable box.
    jake_window = None
    if jake_xyxy:
        x1, y1, x2, y2 = jake_xyxy
        sx, sy = w_m / W, h_m / H
        mx1, mx2 = max(0, int(x1 * sx)), min(w_m, int(x2 * sx))
        my1, my2 = max(0, int(y1 * sy)), min(h_m, int(y2 * sy))
        if mx2 > mx1 and my2 > my1:
            jake_window = (slice(my1, my2), slice(mx1, mx2))

    best_idx, best_area = None, -1
    for idx, cid in enumerate(classes):
        cid = int(cid)
        if cid == RAIL_ID:
            continue
        m_low = masks[idx]
        mask_full = cv2.resize(m_low.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)
        overlay[mask_full] = CLASS_COLOURS.get(cid, (255,255,255))

        ys, xs = np.where(mask_full)
        if len(xs):
            x_c, y_c = int(xs.mean()), int(ys.mean())
            label = obstacle_classes.get(cid, f"CLASS {cid}")
            cv2.putText(overlay, label, (x_c-40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
            cv2.putText(overlay, label, (x_c-40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

        if jake_window is not None:
            # Remember the best overlap so far instead of recomputing it.
            area = int(m_low[jake_window].sum())
            if best_idx is None or area > best_area:
                best_idx, best_area = idx, area

    left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

    # triangles: one at the top-centre of every sufficiently large dark blob
    n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, connectivity=8)
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
            continue
        ys, xs = np.where(lbl_mat == lbl)
        y_top  = ys.min()
        x_mid  = int(xs[ys == y_top].mean())
        draw_triangle(left, x_mid, y_top)

    # Jake bbox + under-mask tint
    if jake_xyxy:
        x1, y1, x2, y2 = jake_xyxy
        cv2.rectangle(left, (x1, y1), (x2, y2), JAKE_BOX_CLR, 2)
        if best_idx is not None:
            best_mask_full = cv2.resize(masks[best_idx].astype(np.uint8), (W, H),
                                        interpolation=cv2.INTER_NEAREST).astype(bool)
            pink_layer = left.copy()
            pink_layer[best_mask_full] = UNDER_TINT_BGR
            left = cv2.addWeighted(pink_layer, 0.35, left, 0.65, 0)

    # RIGHT rails + green
    right = frame.copy()
    rails_tinted = right.copy()
    rails_tinted[rail_mask] = (0,0,255)
    right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
    right[green] = (0,255,0)

    # BOTTOM heat
    heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
    heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                      interpolation=cv2.INTER_LINEAR)
    return left, right, heat


if __name__ == "__main__":
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"[ERROR] frames folder not found: {folder}", file=sys.stderr); sys.exit(1)

    # "*.jpg" / "*.png" already match the "frame_*" files; globbing both
    # patterns listed every frame_* image twice.
    paths = sorted(
        glob.glob(str(folder / "*.jpg")) +
        glob.glob(str(folder / "*.png"))
    )
    if not paths:
        print(f"[ERROR] No frame images found in {folder}", file=sys.stderr); sys.exit(1)

    cv2.namedWindow(WIN_NAME, cv2.WINDOW_NORMAL)

    # Queues
    prefetch_q = queue.Queue(maxsize=PREFETCH)
    infer_in_q = queue.Queue(maxsize=INFER_QUEUE)
    infer_out_q = queue.Queue(maxsize=INFER_QUEUE)

    # Threads
    pf = FramePrefetcher(paths, prefetch_q, maxsize=PREFETCH)
    pf.start()

    iw = InferenceWorker(infer_in_q, infer_out_q)
    iw.start()

    in_flight = 0          # frames handed to the worker whose result is pending
    prefetch_done = False  # True once the end-of-stream sentinel was forwarded

    # Prime inference queue with a few frames
    while in_flight < INFER_QUEUE and not prefetch_done:
        item = _next_frame(prefetch_q, block=True)
        infer_in_q.put(item)
        if item is None:
            prefetch_done = True
        else:
            in_flight += 1

    times_ms = []
    processed = 0
    t_all_start = time.perf_counter()

    i = 0
    while True:
        # Keep inference fed. Only this thread puts into infer_in_q, so the
        # qsize() check guarantees the put below cannot block.
        while not prefetch_done and infer_in_q.qsize() < INFER_QUEUE:
            try:
                # With nothing in flight no result will ever arrive, so wait
                # for the prefetcher instead of stalling on infer_out_q below.
                item = _next_frame(prefetch_q, block=(in_flight == 0))
            except queue.Empty:
                break
            infer_in_q.put(item)
            if item is None:
                prefetch_done = True
            else:
                in_flight += 1

        out = infer_out_q.get()
        if out is None:
            break
        in_flight -= 1

        p, frame, masks, classes, jake_xyxy = out
        if frame is None:
            i += 1
            continue

        t0 = time.perf_counter()
        left, right, heat = _render_panels(frame, masks, classes, jake_xyxy)

        # timing
        dt_ms = (time.perf_counter() - t0) * 1000.0
        if i >= WARMUP_FRAMES:
            times_ms.append(dt_ms)
        processed += 1

        # Display & controls
        canvas = assemble_canvas(left, right, heat)
        canvas = maybe_downscale(canvas, MAX_DISPLAY_W)
        tag = f"{os.path.basename(p)}  |  proc: {dt_ms:.1f} ms"
        captions = ((tag, (12, 28), 0.8),
                    ("SPACE: next   q/ESC: quit", (12, 60), 0.7))
        for text, origin, scale in captions:
            # thick black text first, thin coloured text on top
            cv2.putText(canvas, text, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, scale, (0,0,0), 3, cv2.LINE_AA)
            cv2.putText(canvas, text, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, scale, TEXT_CLR, 1, cv2.LINE_AA)

        while True:
            cv2.imshow(WIN_NAME, canvas)
            key = cv2.waitKey(1) & 0xFF
            if cv2.getWindowProperty(WIN_NAME, cv2.WND_PROP_VISIBLE) < 1:
                # window was closed: leave without a summary
                _request_stop(infer_in_q); cv2.destroyAllWindows(); sys.exit(0)
            if key == 32:  # SPACE
                break
            if key in (ord('q'), 27):
                _request_stop(infer_in_q); cv2.destroyAllWindows()
                _print_summary(times_ms, processed,
                               time.perf_counter() - t_all_start)
                sys.exit(0)

        i += 1

    cv2.destroyAllWindows()
    _print_summary(times_ms, processed, time.perf_counter() - t_all_start)


In [None]:
#Internal printing -> trying to instantiate real move logic now

In [None]:
#!/usr/bin/env python3
"""
JN/VSCode inline visualiser + speed-test for Subway Surfers pipeline.

• Inline display: stacks each processed frame in the notebook (no cv2 windows)
• SHOW_FIRST_N controls how many frames are displayed (rest still processed)
• Pipelined: prefetch → persistent inference worker → CPU postproc
• YOLO on GPU/MPS, RF-DETR on CPU by default to overlap compute
• Warmup frames executed but not timed; summary printed at the end
"""

import os, glob, sys, time, statistics, threading, queue
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from ultralytics import YOLO
from rfdetr import RFDETRBase

# Notebook display (inline)
try:
    from IPython.display import display
    _HAS_IPY = True
except Exception:
    _HAS_IPY = False

# ── disable antialias to avoid unsupported MPS op on some ops ─
# Re-running this cell (or a later cell that applies the same patch) must not
# stack wrapper on wrapper, so recover the pristine function from the marker
# attribute that a previous patch left on its wrapper.
_original_interpolate = getattr(F.interpolate, "_antialias_unpatched", F.interpolate)


def _interpolate_no_antialias(input, *args, **kwargs):
    """Call the real ``F.interpolate`` with ``antialias`` forced to ``False``.

    Any ``antialias`` value supplied by the caller is overridden; every other
    positional and keyword argument is forwarded untouched.
    """
    kwargs['antialias'] = False
    return _original_interpolate(input, *args, **kwargs)


# Marker used by the getattr() above to find the unpatched function again.
_interpolate_no_antialias._antialias_unpatched = _original_interpolate
F.interpolate = _interpolate_no_antialias

# =======================
# Config
# =======================
# Checkpoint locations, resolved relative to the current user's home directory.
home          = os.path.expanduser("~")
yolo_weights  = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
jake_weights  = f"{home}/downloads/weightsjake.pt"

# ── New: only SHOW the first N frames inline (still process all) ──
SHOW_FIRST_N  = 20

# RAIL_ID is the class id treated as rails (id 9 is "RAILS" in obstacle_classes).
# IMG_SIZE / CONF / IOU are passed to yolo_model.predict as imgsz / conf / iou.
# ALPHA is the blend weight of the red rail tint on the right panel.
RAIL_ID       = 9
IMG_SIZE      = 512
CONF, IOU     = 0.30, 0.45
ALPHA         = 0.40

# Arguments for highlight_rails_mask_only_fast: target colours (RGB), the
# maximum colour distance, and the minimum area / height a connected region
# needs to be kept.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat map / warning triangles:
#   HEAT_BLUR_KSIZE      side of the square blur kernel used for the score
#   RED_SCORE_THRESH     score value from which a pixel counts as "dark red"
#   EXCLUDE_*_FRAC       fraction of rows at the top / bottom whose score is zeroed
#   MIN_DARK_RED_AREA    minimum component area (px) that gets a triangle
#   MIN_DARK_FRACTION    minimum share of the total dark area a component needs
#   TRI_SIZE_PX, PURPLE  default size and colour of draw_triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18
PURPLE              = (255, 0, 255)
MIN_DARK_FRACTION   = 0.15

# Display: widest canvas shown before downscaling, and the colours used for
# text, Jake's box and the tint of the mask under Jake.
MAX_DISPLAY_W   = 1600
TEXT_CLR        = (255, 255, 255)
JAKE_BOX_CLR    = (0, 255, 0)
UNDER_TINT_BGR  = (255, 0, 255)

# Pipeline: frames with index below WARMUP_FRAMES are processed but not timed;
# PREFETCH and INFER_QUEUE are the capacities of the queues created in main.
WARMUP_FRAMES   = 10
PREFETCH        = 24    # bigger prefetch helps saturate I/O
INFER_QUEUE     = 4     # frames in-flight through inference

# CPU threading (OpenCV)
# One less than the CPU count (assumed 8 when unknown), but never below 1.
OPENCV_THREADS  = max(1, (os.cpu_count() or 8) - 1)

# =======================
# Device selection
# =======================
def _select_devices():
    """Return ``(yolo_device, half, det_on_cpu)`` for the best available backend.

    CUDA is preferred over MPS, and MPS over plain CPU. Half precision is
    only enabled on CUDA.
    """
    if torch.cuda.is_available():
        return 0, True, True
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend and mps_backend.is_available():
        # RF-DETR on CPU to overlap with MPS
        return "mps", False, True
    # both CPU; pipeline still helps
    return "cpu", False, False


yolo_device, half, DET_ON_CPU = _select_devices()

# =======================
# Models
# =======================
# Segmentation model, loaded from the yolo_weights checkpoint.
yolo_model = YOLO(yolo_weights)
# fuse() is treated as an optional optimisation: if it raises, the error is
# deliberately ignored and the model is used as loaded.
try: yolo_model.fuse()
except Exception: pass

# RF-DETR detector used by InferenceWorker to find Jake's bounding box.
jake_model = RFDETRBase(pretrain_weights=jake_weights, num_classes=3)

# warmup YOLO
# One throw-away prediction on a black IMG_SIZE x IMG_SIZE image, using the
# same device / thresholds / precision as the real frames; the result is discarded.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = yolo_model.predict(_dummy, task="segment", imgsz=IMG_SIZE,
                       device=yolo_device, conf=CONF, iou=IOU,
                       verbose=False, half=half)

# OpenCV threads
# Best-effort: a failure to set the thread count is ignored.
try:
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# =======================
# Utils
# =======================
# Class id -> display name; the id of each name is its position in the tuple.
obstacle_classes = dict(enumerate((
    "BOOTS", "GREYTRAIN", "HIGHBARRIER1", "JUMP", "LOWBARRIER1",
    "LOWBARRIER2", "ORANGETRAIN", "PILLAR", "RAMP", "RAILS",
    "SIDEWALK", "YELLOWTRAIN",
)))

# Class id -> colour painted over that class's mask on the left panel.
# Id 9 (RAILS) has no entry: rail masks are skipped by the labelling loop,
# and any other missing id falls back to white there.
CLASS_COLOURS = {
    0: (255, 255, 0),
    1: (192, 192, 192),
    2: (0, 128, 255),
    3: (0, 255, 0),
    4: (255, 0, 255),
    5: (0, 255, 255),
    6: (255, 128, 0),
    7: (128, 0, 255),
    8: (0, 0, 128),
    10: (128, 128, 0),
    11: (255, 255, 102),
}

def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """Select rail pixels whose colour is close to one of the target colours.

    Args:
        img_bgr: image in BGR channel order, shape (H, W, 3).
        rail_mask: boolean (H, W) mask of rail pixels.
        target_colors_rgb: iterable of (r, g, b) colours to match.
        tolerance: maximum Euclidean distance in colour space for a match.
        min_region_size: minimum pixel area of a connected region to keep.
        min_region_height: minimum bounding-box height of a region to keep.

    Returns:
        Boolean (H, W) mask of the matching pixels that lie inside
        ``rail_mask`` and belong to an 8-connected region that is both large
        and tall enough. All False when ``rail_mask`` is empty.
    """
    H, W = img_bgr.shape[:2]
    filtered_full = np.zeros((H, W), dtype=bool)
    if not rail_mask.any():
        return filtered_full

    # Only the bounding box of the rail mask needs to be examined.
    ys, xs = np.where(rail_mask)
    y0, y1 = ys.min(), ys.max() + 1
    x0, x1 = xs.min(), xs.max() + 1
    img_f = img_bgr[y0:y1, x0:x1].astype(np.float32)
    mask_roi = rail_mask[y0:y1, x0:x1]

    # Targets arrive as RGB; flip them to the image's BGR order.
    targets_bgr = np.array([(r, g, b)[::-1] for r, g, b in target_colors_rgb],
                           dtype=np.float32)
    diff = img_f[:, :, None, :] - targets_bgr[None, None, :, :]
    dist2 = np.sum(diff * diff, axis=-1)
    colour_hit = np.any(dist2 <= tolerance ** 2, axis=-1)
    combined = colour_hit & mask_roi

    _, lbls, stats, _ = cv2.connectedComponentsWithStats(
        combined.astype(np.uint8), connectivity=8)
    # One keep/drop flag per label, then a single lookup over the label image
    # instead of rescanning the ROI once per component.
    keep = ((stats[:, cv2.CC_STAT_AREA] >= min_region_size)
            & (stats[:, cv2.CC_STAT_HEIGHT] >= min_region_height))
    keep[0] = False  # label 0 is the background
    filtered_full[y0:y1, x0:x1] = keep[lbls]
    return filtered_full

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """Turn a red and a green mask into a smoothed uint8 "redness" map.

    Both masks are box-blurred with a ``blur_ksize`` x ``blur_ksize`` kernel.
    The blurred green response is subtracted from the red one, and the
    difference is rescaled by its largest magnitude so that a neutral pixel
    lands mid-range, red-dominated pixels tend towards 255 and
    green-dominated pixels towards 0.
    """
    kernel = (blur_ksize, blur_ksize)
    red_blur, green_blur = (
        cv2.blur(mask.astype(np.float32), kernel) for mask in (red_mask, green_mask)
    )
    balance = red_blur - green_blur
    # The small epsilon keeps the division defined when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2 * peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """Draw a filled triangle with a thin black outline onto ``img`` in place.

    The apex sits at ``(x, y)``; the base lies ``int(size * 1.2)`` pixels
    below it and reaches ``size`` pixels to either side.
    """
    base_y = y + int(size * 1.2)
    corners = np.array(((x, y), (x - size, base_y), (x + size, base_y)), np.int32)
    cv2.fillConvexPoly(img, corners, colour)
    outline = corners.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0, 0, 0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """Stack ``left`` and ``right`` side by side and put ``heat`` underneath.

    ``heat`` is expected to be as wide as both top panels together. If it
    is not (for example a placeholder that is only one panel wide), its
    columns are resampled by nearest-neighbour indexing so that the vertical
    stack cannot fail on a width mismatch.
    """
    top = np.hstack((left, right))
    top_w, heat_w = top.shape[1], heat.shape[1]
    if heat_w != top_w:
        # Map every output column to its nearest source column.
        cols = (np.arange(top_w) * heat_w) // top_w
        heat = heat[:, cols]
    return np.vstack((top, heat))

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """Return ``img`` itself if it is at most ``max_w`` wide, else a shrunken copy.

    A wider image is scaled by ``max_w / width`` in both directions (sizes
    truncated to ints) using ``cv2.INTER_AREA``.
    """
    height, width = img.shape[:2]
    if width > max_w:
        factor = max_w / float(width)
        target = (int(width * factor), int(height * factor))
        return cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    return img

# =======================
# Inference worker (persistent)
# =======================
class InferenceWorker(threading.Thread):
    """
    Consumes (path, frame) items from in_q, runs YOLO (GPU/MPS) + RF-DETR (CPU optionally),
    and puts (path, frame, masks(np or None), classes(np or None), jake_xyxy or None) on out_q.

    masks and classes are None when YOLO returned no masks. A frame that is
    None (an image that failed to load) is forwarded as
    (path, None, None, None, None) without running either model.

    A None item on in_q is the end-of-stream sentinel. Exactly one None is put
    on out_q when the thread finishes, whether that is because the sentinel
    arrived, ``stop`` was set, or inference raised, so a consumer blocked on
    out_q is never left waiting for a dead worker.
    """
    def __init__(self, in_q, out_q):
        super().__init__(daemon=True)
        self.in_q  = in_q
        self.out_q = out_q
        self.stop  = False  # set True to leave the loop after the current item

    def run(self):
        try:
            while not self.stop:
                item = self.in_q.get()
                if item is None:  # sentinel
                    break
                path, frame = item
                self.out_q.put(self._process(path, frame))
        finally:
            # Always release the consumer, even when inference raised.
            self.out_q.put(None)

    def _process(self, path, frame):
        """Run both models on one frame and return the result tuple."""
        if frame is None:
            # Nothing to infer on; let the consumer decide how to skip it.
            return (path, None, None, None, None)
        H, W = frame.shape[:2]

        # YOLO (device)
        res = yolo_model.predict(frame, task="segment", imgsz=IMG_SIZE,
                                 device=yolo_device, conf=CONF, iou=IOU,
                                 max_det=30, verbose=False, half=half)[0]

        jake_xyxy = self._detect_jake(frame, W, H)

        if res.masks is None:
            return (path, frame, None, None, jake_xyxy)

        masks   = res.masks.data.cpu().numpy()  # (N, h_m, w_m)
        classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                   else res.boxes.cls.cpu().numpy())
        classes = np.asarray(classes, dtype=int).ravel()  # flatten once
        return (path, frame, masks, classes, jake_xyxy)

    @staticmethod
    def _detect_jake(frame, W, H):
        """Return Jake's clamped (x1, y1, x2, y2) box, or None if there is none.

        Best-effort: any exception raised by the detector is swallowed and
        reported as "no box", so a detector failure never stops the pipeline.
        """
        try:
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            # RF-DETR (CPU to overlap by default)
            det_dev = "cpu" if DET_ON_CPU else ("cpu" if yolo_device != 0 else 0)
            dets = jake_model.predict(pil_img, threshold=0.5, device=det_dev)[0]
            if hasattr(dets, "xyxy") and len(dets.xyxy) > 0:
                # Only the first returned box is used.
                x1, y1, x2, y2 = dets.xyxy[0].astype(int).tolist()
                # clamp to the frame
                x1 = max(0, min(W-1, x1)); x2 = max(0, min(W-1, x2))
                y1 = max(0, min(H-1, y1)); y2 = max(0, min(H-1, y2))
                if x2 > x1 and y2 > y1:
                    return (x1, y1, x2, y2)
        except Exception:
            pass
        return None

# =======================
# Prefetcher
# =======================
class FramePrefetcher(threading.Thread):
    """Daemon thread that loads images with ``cv2.imread`` and queues them.

    For each path, in order, ``(path, image)`` is put on ``out_q``, where
    ``image`` is whatever ``cv2.imread`` returned. A final ``None`` marks the
    end of the stream, also when ``stop`` was set part-way through.
    ``maxsize`` is accepted but not used by this class.
    """
    def __init__(self, paths, out_q, maxsize=PREFETCH):
        super().__init__(daemon=True)
        self.stop = False
        self.q = out_q
        self.paths = paths

    def run(self):
        emit = self.q.put
        for image_path in self.paths:
            if self.stop:
                break
            emit((image_path, cv2.imread(image_path)))
        emit(None)  # sentinel

# =======================
# Main (JN-friendly)
# =======================
def _print_summary(times_ms, processed, total_time):
    """Print the speed-test summary; prints nothing when no frame was timed."""
    if not times_ms:
        return
    avg = statistics.mean(times_ms)
    print("\n───── Speed-test summary (warm-up skipped) ─────")
    print(f"Total frames           : {processed}")
    print(f"Warm-up frames ignored : {min(WARMUP_FRAMES, processed)}")
    print(f"Frames timed           : {len(times_ms)}")
    print(f"Total wall-clock time  : {total_time:,.2f} s")
    print(f"Average / frame        : {avg:,.2f} ms  ({1000.0/avg:,.2f} FPS)")
    print(f"Median                 : {statistics.median(times_ms):,.2f} ms")
    print(f"Fastest                : {min(times_ms):,.2f} ms")
    print(f"Slowest                : {max(times_ms):,.2f} ms")
    print("────────────────────────────────────────────────")


def _next_frame(prefetch_q, block):
    """Return the next readable ``(path, frame)`` item, or ``None`` at end of stream.

    Images that could not be read (frame is ``None``) are reported on stderr
    and skipped, so they never reach the inference worker. Raises
    ``queue.Empty`` when ``block`` is false and nothing is ready.
    """
    while True:
        item = prefetch_q.get(block=block)
        if item is None or item[1] is not None:
            return item
        print(f"[WARN] could not read image, skipping: {item[0]}", file=sys.stderr)


def _render_panels(frame, masks, classes, jake_xyxy):
    """Build the ``(left, right, heat)`` panels for one frame.

    left  : class overlays with labels, warning triangles, Jake's box and a
            tint on the mask that overlaps Jake's box the most
    right : rails tinted red, colour-matched rail pixels painted green
    heat  : JET colour map of the red-vs-green score, as wide as left + right
    """
    H, W = frame.shape[:2]
    if masks is None or classes is None:
        # No segmentation: plain panels. The heat strip still has to span both
        # top panels, otherwise stacking the canvas fails on a width mismatch.
        heat = cv2.applyColorMap(np.zeros((H, 2 * W), np.uint8), cv2.COLORMAP_JET)
        return frame.copy(), frame.copy(), heat

    # Safety: ensure classes shape
    classes = np.asarray(classes, dtype=int).ravel()

    h_m, w_m = masks.shape[1:]
    rail_union = np.zeros((h_m, w_m), dtype=bool)
    for m, c in zip(masks, classes):
        if int(c) == RAIL_ID:
            rail_union |= m.astype(bool)
    rail_mask = cv2.resize(rail_union.astype(np.uint8), (W, H),
                           interpolation=cv2.INTER_NEAREST).astype(bool)

    # Green/red + score
    green = highlight_rails_mask_only_fast(frame, rail_mask,
                                           TARGET_COLORS_RGB, TOLERANCE,
                                           MIN_REGION_SIZE, MIN_REGION_HEIGHT)
    red   = rail_mask & ~green
    score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

    # Exclude bands: ignore the top and bottom strips of the frame
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    score[:top_ex, :] = 0
    if bot_ex:
        score[H-bot_ex:, :] = 0

    dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
        iterations=1
    )

    total_dark_area  = int(dark.sum())
    frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

    # LEFT labels + under-mask tint
    left    = frame.copy()
    overlay = left.copy()

    # Jake's box in mask resolution; None when there is no usable box.
    jake_window = None
    if jake_xyxy:
        x1, y1, x2, y2 = jake_xyxy
        sx, sy = w_m / W, h_m / H
        mx1, mx2 = max(0, int(x1 * sx)), min(w_m, int(x2 * sx))
        my1, my2 = max(0, int(y1 * sy)), min(h_m, int(y2 * sy))
        if mx2 > mx1 and my2 > my1:
            jake_window = (slice(my1, my2), slice(mx1, mx2))

    best_idx, best_area = None, -1
    for idx, cid in enumerate(classes):
        cid = int(cid)
        if cid == RAIL_ID:
            continue
        m_low = masks[idx]
        mask_full = cv2.resize(m_low.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)
        overlay[mask_full] = CLASS_COLOURS.get(cid, (255,255,255))

        ys, xs = np.where(mask_full)
        if len(xs):
            x_c, y_c = int(xs.mean()), int(ys.mean())
            label = obstacle_classes.get(cid, f"CLASS {cid}")
            cv2.putText(overlay, label, (x_c-40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
            cv2.putText(overlay, label, (x_c-40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

        if jake_window is not None:
            # Remember the best overlap so far instead of recomputing it.
            area = int(m_low[jake_window].sum())
            if best_idx is None or area > best_area:
                best_idx, best_area = idx, area

    left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

    # Purple-triangle warnings: one at the top-centre of every large dark blob
    n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, connectivity=8)
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
            continue
        ys, xs = np.where(lbl_mat == lbl)
        y_top  = ys.min()
        x_mid  = int(xs[ys == y_top].mean())
        draw_triangle(left, x_mid, y_top)

    # Jake bbox + under-mask tint
    if jake_xyxy:
        x1, y1, x2, y2 = jake_xyxy
        cv2.rectangle(left, (x1, y1), (x2, y2), JAKE_BOX_CLR, 2)
        if best_idx is not None:
            best_mask_full = cv2.resize(masks[best_idx].astype(np.uint8), (W, H),
                                        interpolation=cv2.INTER_NEAREST).astype(bool)
            pink_layer = left.copy()
            pink_layer[best_mask_full] = UNDER_TINT_BGR
            left = cv2.addWeighted(pink_layer, 0.35, left, 0.65, 0)

    # RIGHT rails + green
    right = frame.copy()
    rails_tinted = right.copy()
    rails_tinted[rail_mask] = (0,0,255)
    right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
    right[green] = (0,255,0)

    # BOTTOM heat
    heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
    heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                      interpolation=cv2.INTER_LINEAR)
    return left, right, heat


if __name__ == "__main__":
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"[ERROR] frames folder not found: {folder}", file=sys.stderr); sys.exit(1)

    # "*.jpg" / "*.png" already match the "frame_*" files; globbing both
    # patterns listed every frame_* image twice.
    paths = sorted(
        glob.glob(str(folder / "*.jpg")) +
        glob.glob(str(folder / "*.png"))
    )
    if not paths:
        print(f"[ERROR] No frame images found in {folder}", file=sys.stderr); sys.exit(1)

    # Queues
    prefetch_q = queue.Queue(maxsize=PREFETCH)
    infer_in_q = queue.Queue(maxsize=INFER_QUEUE)
    infer_out_q = queue.Queue(maxsize=INFER_QUEUE)

    # Threads
    pf = FramePrefetcher(paths, prefetch_q, maxsize=PREFETCH)
    pf.start()

    iw = InferenceWorker(infer_in_q, infer_out_q)
    iw.start()

    in_flight = 0          # frames handed to the worker whose result is pending
    prefetch_done = False  # True once the end-of-stream sentinel was forwarded

    # Prime inference queue with a few frames
    while in_flight < INFER_QUEUE and not prefetch_done:
        item = _next_frame(prefetch_q, block=True)
        infer_in_q.put(item)
        if item is None:
            prefetch_done = True
        else:
            in_flight += 1

    times_ms = []
    processed = 0
    shown = 0
    t_all_start = time.perf_counter()

    i = 0
    while True:
        # Keep inference fed. Only this thread puts into infer_in_q, so the
        # qsize() check guarantees the put below cannot block.
        while not prefetch_done and infer_in_q.qsize() < INFER_QUEUE:
            try:
                # With nothing in flight no result will ever arrive, so wait
                # for the prefetcher instead of stalling on infer_out_q below.
                item = _next_frame(prefetch_q, block=(in_flight == 0))
            except queue.Empty:
                break
            infer_in_q.put(item)
            if item is None:
                prefetch_done = True
            else:
                in_flight += 1

        out = infer_out_q.get()
        if out is None:
            break
        in_flight -= 1

        p, frame, masks, classes, jake_xyxy = out
        if frame is None:
            i += 1
            continue

        t0 = time.perf_counter()
        left, right, heat = _render_panels(frame, masks, classes, jake_xyxy)

        # timing
        dt_ms = (time.perf_counter() - t0) * 1000.0
        if i >= WARMUP_FRAMES:
            times_ms.append(dt_ms)
        processed += 1

        # Build canvas for inline display (only for first SHOW_FIRST_N)
        if _HAS_IPY and shown < SHOW_FIRST_N:
            canvas = assemble_canvas(left, right, heat)
            canvas = maybe_downscale(canvas, MAX_DISPLAY_W)
            tag = f"{os.path.basename(p)}  |  proc: {dt_ms:.1f} ms"
            # thick black text first, thin coloured text on top
            cv2.putText(canvas, tag, (12, 28),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,0), 3, cv2.LINE_AA)
            cv2.putText(canvas, tag, (12, 28),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.8, TEXT_CLR, 1, cv2.LINE_AA)
            # show inline
            rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
            display(Image.fromarray(rgb))
            shown += 1

        i += 1

    # Done
    _print_summary(times_ms, processed, time.perf_counter() - t_all_start)


In [None]:
#Updated timing metrics and image sizing

In [None]:
#!/usr/bin/env python3
"""
JN/VSCode inline visualiser + speed-test for Subway Surfers pipeline.

• Inline display: stacks each processed frame in the notebook (no cv2 windows)
• SHOW_FIRST_N controls how many frames are displayed (rest still processed)
• Pipelined: prefetch → persistent inference worker → CPU postproc
• YOLO on GPU/MPS, RF-DETR on CPU by default to overlap compute
• Warmup frames executed but not timed; summary printed at the end
• Timing excludes ANY display/printing work (only core processing is timed)
"""

import os, glob, sys, time, statistics, threading, queue
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from ultralytics import YOLO
from rfdetr import RFDETRBase

# Notebook display (inline)
try:
    from IPython.display import display
    _HAS_IPY = True
except Exception:
    _HAS_IPY = False

# ── disable antialias to avoid unsupported MPS op on some ops ─
# The patch is guarded by a marker attribute: re-running this cell would
# otherwise capture the already-patched function as the "original" and stack
# one more wrapper on every execution.
if not getattr(F.interpolate, "_no_antialias", False):
    _original_interpolate = F.interpolate

    def _interpolate_no_antialias(input, *args, **kwargs):
        """Call the original ``F.interpolate`` with ``antialias`` forced to False."""
        kwargs['antialias'] = False
        return _original_interpolate(input, *args, **kwargs)

    _interpolate_no_antialias._no_antialias = True
    F.interpolate = _interpolate_no_antialias

# =======================
# Config
# =======================
# Weight files: yolo_weights is loaded by YOLO(...), jake_weights by RFDETRBase(...).
home          = os.path.expanduser("~")
yolo_weights  = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
jake_weights  = f"{home}/downloads/weightsjake.pt"

# ── Only SHOW the first N frames inline (still process all) ──
SHOW_FIRST_N   = 20

# ── Make inline images neater/smaller ──
DISPLAY_MAX_W  = 960    # target max width of the displayed canvas (px)
TEXT_SCALE     = 0.6    # overlay text size on canvas
TEXT_THICK     = 1      # overlay text thickness

# RAIL_ID: class id treated as rails (9 is "RAILS" in obstacle_classes).
# IMG_SIZE / CONF / IOU: passed to YOLO predict as imgsz / conf / iou.
# ALPHA: blend weight of the red rail tint in the right panel.
RAIL_ID       = 9
IMG_SIZE      = 512
CONF, IOU     = 0.30, 0.45
ALPHA         = 0.40

# Inputs of highlight_rails_mask_only_fast: target colours (RGB), the maximum
# Euclidean colour distance for a match, and the minimum area / bounding-box
# height a connected component needs to be kept.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat-map / warning-triangle parameters used by the main loop:
#   HEAT_BLUR_KSIZE      side of the box-blur kernel in red_vs_green_score
#   RED_SCORE_THRESH     score (0-255) at or above which a pixel counts as "dark"
#   EXCLUDE_*_FRAC       fraction of the frame height zeroed at the top / bottom of the score map
#   MIN_DARK_RED_AREA    minimum component area (px) for a triangle
#   MIN_DARK_FRACTION    a component must also reach this fraction of the total dark area
#   TRI_SIZE_PX, PURPLE  default size and fill colour of draw_triangle
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18
PURPLE              = (255, 0, 255)
MIN_DARK_FRACTION   = 0.15

# Drawing colours: canvas tag text, Jake's bounding box, tint of the mask under Jake.
TEXT_CLR        = (255, 255, 255)
JAKE_BOX_CLR    = (0, 255, 0)
UNDER_TINT_BGR  = (255, 0, 255)

# WARMUP_FRAMES: frames with loop index below this are processed but not timed.
# PREFETCH / INFER_QUEUE: maxsize of the prefetch queue and of both inference queues.
WARMUP_FRAMES   = 10
PREFETCH        = 24    # prefetch helps saturate I/O
INFER_QUEUE     = 4     # frames in-flight through inference

# CPU threading (OpenCV)
# All cores but one; 8 is assumed when os.cpu_count() returns None.
OPENCV_THREADS  = max(1, (os.cpu_count() or 8) - 1)

# =======================
# Device selection
# =======================
# Pick where YOLO runs (yolo_device), whether it uses half precision (half),
# and whether RF-DETR is pinned to the CPU (DET_ON_CPU).
if torch.cuda.is_available():
    # CUDA: YOLO on GPU 0 with half=True; RF-DETR stays on the CPU.
    yolo_device = 0
    half = True
    DET_ON_CPU = True
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    # Apple MPS: RF-DETR on CPU to overlap with MPS
    yolo_device = "mps"
    half = False
    DET_ON_CPU = True
else:
    # both CPU; pipeline still helps
    yolo_device = "cpu"
    half = False
    DET_ON_CPU = False

# =======================
# Models
# =======================
yolo_model = YOLO(yolo_weights)
try:
    # Best-effort: a failing fuse() is ignored and the unfused model is used.
    yolo_model.fuse()
except Exception:
    pass

jake_model = RFDETRBase(pretrain_weights=jake_weights, num_classes=3)

# warmup YOLO: one throw-away prediction on a black IMG_SIZE x IMG_SIZE frame
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = yolo_model.predict(
    _dummy,
    task="segment",
    imgsz=IMG_SIZE,
    device=yolo_device,
    conf=CONF,
    iou=IOU,
    verbose=False,
    half=half,
)

# OpenCV threads (errors from setNumThreads are ignored)
try:
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# =======================
# Utils
# =======================
# Class id → label text drawn at each mask's centroid on the left panel.
obstacle_classes = {
    0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
    5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
    10:"SIDEWALK",11:"YELLOWTRAIN"
}
# Class id → colour written into the overlay of the (BGR) frame.
# There is no entry for 9: the main loop skips RAIL_ID masks when labelling,
# and any other missing id falls back to white via CLASS_COLOURS.get(...).
CLASS_COLOURS = {
    0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
    4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
    8:(0,0,128),10:(128,128,0),11:(255,255,102)
}

def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """
    Return a full-frame boolean mask of rail pixels that match a target colour.

    A pixel is kept when it lies inside ``rail_mask``, its BGR value is within
    ``tolerance`` (Euclidean distance) of at least one colour in
    ``target_colors_rgb``, and it belongs to an 8-connected component whose
    area is >= ``min_region_size`` and whose bounding-box height is
    >= ``min_region_height``. An empty ``rail_mask`` yields an all-False mask.
    """
    height, width = img_bgr.shape[:2]
    result = np.zeros((height, width), dtype=bool)
    if not rail_mask.any():
        return result

    # Only the tight bounding box of the rail mask is analysed.
    rows, cols = np.where(rail_mask)
    top, bottom = rows.min(), rows.max() + 1
    lft, rgt = cols.min(), cols.max() + 1
    roi = img_bgr[top:bottom, lft:rgt].astype(np.float32)
    roi_mask = rail_mask[top:bottom, lft:rgt]

    # Targets arrive as RGB; flip them to BGR to match the image.
    targets = np.array([(b, g, r) for r, g, b in target_colors_rgb], dtype=np.float32)
    delta = roi[:, :, None, :] - targets[None, None, :, :]
    sq_dist = np.sum(delta * delta, axis=-1)
    matches = np.any(sq_dist <= tolerance**2, axis=-1) & roi_mask

    count, labels, stats, _ = cv2.connectedComponentsWithStats(matches.astype(np.uint8), 8)

    # Per-label keep flag (label 0 is background and never kept), then one
    # table lookup instead of a separate mask assignment per label.
    keep = np.zeros(count, dtype=bool)
    keep[1:] = ((stats[1:, cv2.CC_STAT_AREA] >= min_region_size) &
                (stats[1:, cv2.CC_STAT_HEIGHT] >= min_region_height))
    result[top:bottom, lft:rgt] = keep[labels]
    return result

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """
    Build a uint8 score map from the local balance of red vs. green pixels.

    Both masks are box-blurred with a ``blur_ksize`` x ``blur_ksize`` kernel,
    the green density is subtracted from the red one, and the difference is
    normalised by its largest magnitude so that mid-scale means "balanced",
    higher values mean more red and lower values more green.
    """
    kernel = (blur_ksize, blur_ksize)
    red_density = cv2.blur(red_mask.astype(np.float32), kernel)
    green_density = cv2.blur(green_mask.astype(np.float32), kernel)
    balance = red_density - green_density
    # The epsilon keeps the division defined when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2*peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """
    Draw a filled triangle on ``img`` (in place) with its apex at ``(x, y)``.

    The base lies ``int(size * 1.2)`` pixels below the apex and extends
    ``size`` pixels to each side; a thin black outline is drawn on top.
    """
    base_y = y + int(size * 1.2)
    vertices = np.array([[x, y], [x-size, base_y], [x+size, base_y]], np.int32)
    cv2.fillConvexPoly(img, vertices, colour)
    outline = vertices.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0,0,0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """
    Put ``left`` and ``right`` side by side and stack ``heat`` underneath.

    ``heat`` has to be as wide as the two panels together for the stack to work.
    """
    return np.vstack([np.hstack([left, right]), heat])

def resize_to_width(img, max_w=DISPLAY_MAX_W):
    """
    Downscale ``img`` to ``max_w`` pixels wide, keeping the aspect ratio.

    Images that are already at most ``max_w`` wide are returned unchanged
    (same object, no copy). Otherwise a resized copy is returned.
    """
    h, w = img.shape[:2]
    if w <= max_w:
        return img
    scale = max_w / float(w)
    # Use max_w itself for the width: int(w * scale) can truncate to max_w - 1
    # through float rounding. The height is rounded and kept >= 1 because
    # cv2.resize rejects a zero dimension (very wide, very short images).
    new_w = max(1, int(max_w))
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

# =======================
# Inference worker (persistent)
# =======================
class InferenceWorker(threading.Thread):
    """
    Consumes (path, frame) items from in_q, runs YOLO (GPU/MPS) + RF-DETR
    (CPU optionally), and puts on out_q:
      (path, frame, masks(np or None), classes(np or None), jake_xyxy or None)

    A ``None`` item on in_q is the end-of-stream sentinel and is forwarded to
    out_q. A ``None`` frame (unreadable image) is forwarded as
    (path, None, None, None, None) without running any model. If processing
    raises, the sentinel is still emitted so the consumer does not block forever.
    """
    def __init__(self, in_q, out_q):
        super().__init__(daemon=True)
        self.in_q  = in_q
        self.out_q = out_q
        self.stop  = False   # checked before each read from in_q

    def run(self):
        try:
            while not self.stop:
                item = self.in_q.get()
                if item is None:  # sentinel
                    self.out_q.put(None)
                    break
                self.out_q.put(self._process(*item))
        except Exception:
            # The consumer waits on out_q; unblock it before this thread dies,
            # then re-raise so the failure is still reported.
            self.out_q.put(None)
            raise

    def _process(self, path, frame):
        """Run both models on one frame and return the result tuple for out_q."""
        if frame is None:
            # cv2.imread yields None for unreadable files; frame.shape would
            # raise here and kill the thread, leaving the consumer waiting.
            return (path, None, None, None, None)

        H, W = frame.shape[:2]

        # YOLO (device)
        res = yolo_model.predict(frame, task="segment", imgsz=IMG_SIZE,
                                 device=yolo_device, conf=CONF, iou=IOU,
                                 max_det=30, verbose=False, half=half)[0]

        # RF-DETR (CPU to overlap by default). Best-effort: any failure here
        # only means that no Jake box is reported for this frame.
        jake_xyxy = None
        try:
            pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            det_dev = "cpu" if DET_ON_CPU else ("cpu" if yolo_device != 0 else 0)
            dets = jake_model.predict(pil_img, threshold=0.5, device=det_dev)[0]
            if hasattr(dets, "xyxy") and len(dets.xyxy) > 0:
                x1, y1, x2, y2 = dets.xyxy[0].astype(int).tolist()
                # clamp to the frame; degenerate boxes are dropped
                x1 = max(0, min(W-1, x1)); x2 = max(0, min(W-1, x2))
                y1 = max(0, min(H-1, y1)); y2 = max(0, min(H-1, y2))
                if x2 > x1 and y2 > y1:
                    jake_xyxy = (x1, y1, x2, y2)
        except Exception:
            pass

        if res.masks is None:
            return (path, frame, None, None, jake_xyxy)

        masks   = res.masks.data.cpu().numpy()  # (N, h_m, w_m)
        classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                   else res.boxes.cls.cpu().numpy())
        classes = np.asarray(classes, dtype=int).ravel()  # flatten once

        return (path, frame, masks, classes, jake_xyxy)

# =======================
# Prefetcher
# =======================
class FramePrefetcher(threading.Thread):
    """
    Daemon thread that reads images with cv2.imread and queues (path, image).

    After the last path (or once ``stop`` is set) a ``None`` sentinel is
    queued. ``maxsize`` is accepted but not used: the queue bound comes from
    ``out_q`` itself. Whatever cv2.imread returns is queued as-is.
    """
    def __init__(self, paths, out_q, maxsize=PREFETCH):
        super().__init__(daemon=True)
        self.paths, self.q, self.stop = paths, out_q, False

    def run(self):
        for img_path in self.paths:
            if self.stop:
                break
            self.q.put((img_path, cv2.imread(img_path)))
        self.q.put(None)  # sentinel

# =======================
# Main (JN-friendly)
# =======================
if __name__ == "__main__":
    # Driver: prefetch frames → inference worker → CPU post-processing (timed)
    # → optional inline display (not timed) → speed summary.
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"[ERROR] frames folder not found: {folder}", file=sys.stderr); sys.exit(1)

    # "*.jpg" / "*.png" already match the "frame_*" files. Globbing both the
    # specific and the generic patterns listed every such frame twice, so it
    # was processed, timed and displayed twice.
    paths = sorted(
        glob.glob(str(folder / "*.jpg")) +
        glob.glob(str(folder / "*.png"))
    )
    if not paths:
        print(f"[ERROR] No frame images found in {folder}", file=sys.stderr); sys.exit(1)

    # Queues
    prefetch_q = queue.Queue(maxsize=PREFETCH)
    infer_in_q = queue.Queue(maxsize=INFER_QUEUE)
    infer_out_q = queue.Queue(maxsize=INFER_QUEUE)

    # Threads
    pf = FramePrefetcher(paths, prefetch_q, maxsize=PREFETCH)
    pf.start()

    iw = InferenceWorker(infer_in_q, infer_out_q)
    iw.start()

    feed_done = False   # True once the prefetcher's sentinel has been forwarded
    in_flight = 0       # frames handed to the worker whose result is not consumed yet

    # Prime inference queue with a few frames
    while in_flight < INFER_QUEUE:
        item = prefetch_q.get()
        infer_in_q.put(item)
        if item is None:
            feed_done = True
            break
        in_flight += 1

    times_ms = []
    processed = 0
    shown = 0
    t_all_start = time.perf_counter()

    i = 0
    while True:
        # Keep inference fed. While results are pending, only take what is
        # already prefetched. When nothing is in flight, block on the
        # prefetcher instead: otherwise a momentarily empty prefetch queue
        # would leave this loop waiting on infer_out_q for a result that can
        # never arrive.
        while not feed_done and infer_in_q.qsize() < INFER_QUEUE:
            try:
                item = prefetch_q.get(block=(in_flight == 0))
            except queue.Empty:
                break
            infer_in_q.put(item)
            if item is None:
                feed_done = True
            else:
                in_flight += 1

        out = infer_out_q.get()
        if out is None:
            break
        in_flight -= 1

        p, frame, masks, classes, jake_xyxy = out
        if frame is None:
            # unreadable image: counts towards the frame index but is not processed
            i += 1
            continue

        # Safety: ensure classes shape
        if classes is not None:
            classes = np.asarray(classes, dtype=int).ravel()

        # ─────────────────────── TIMING START (processing only) ───────────────────────
        t0 = time.perf_counter()

        H, W = frame.shape[:2]
        if masks is None or classes is None:
            left = frame.copy(); right = frame.copy()
            # The heat strip has to span both panels (2*W); with width W the
            # canvas could not be stacked for frames without detections.
            heat = cv2.applyColorMap(np.zeros((H, 2 * W), np.uint8), cv2.COLORMAP_JET)
        else:
            h_m, w_m = masks.shape[1:]
            rail_union = np.zeros((h_m, w_m), dtype=bool)
            for m, c in zip(masks, classes):
                if int(c) == RAIL_ID:
                    rail_union |= m.astype(bool)
            rail_mask = cv2.resize(rail_union.astype(np.uint8), (W, H),
                                   interpolation=cv2.INTER_NEAREST).astype(bool)

            # Green/red + score
            green = highlight_rails_mask_only_fast(frame, rail_mask,
                                                   TARGET_COLORS_RGB, TOLERANCE,
                                                   MIN_REGION_SIZE, MIN_REGION_HEIGHT)
            red   = rail_mask & ~green
            score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

            # Exclude bands
            top_ex = int(H * EXCLUDE_TOP_FRAC)
            bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
            score[:top_ex, :] = 0
            if bot_ex: score[H-bot_ex:, :] = 0

            dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
            dark = cv2.morphologyEx(
                dark, cv2.MORPH_OPEN,
                cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
                iterations=1
            )

            total_dark_area  = int(dark.sum())
            frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

            # LEFT labels + under-mask tint
            left    = frame.copy()
            overlay = left.copy()

            # Jake overlap: his box scaled down to mask resolution
            best_idx, best_area = None, 0
            jake_roi = None
            if jake_xyxy:
                x1, y1, x2, y2 = jake_xyxy
                sx, sy = w_m / W, h_m / H
                mx1, mx2 = max(0, int(x1 * sx)), min(w_m, int(x2 * sx))
                my1, my2 = max(0, int(y1 * sy)), min(h_m, int(y2 * sy))
                if mx2 > mx1 and my2 > my1:
                    jake_roi = (slice(my1, my2), slice(mx1, mx2))

            for idx, (m_low, cid) in enumerate(zip(masks, classes)):
                cid = int(cid)
                if cid == RAIL_ID:
                    continue
                mask_full = cv2.resize(m_low.astype(np.uint8), (W, H),
                                       interpolation=cv2.INTER_NEAREST).astype(bool)
                colour = CLASS_COLOURS.get(cid, (255,255,255))
                overlay[mask_full] = colour

                ys, xs = np.where(mask_full)
                if len(xs):
                    x_c, y_c = int(xs.mean()), int(ys.mean())
                    label = obstacle_classes.get(cid, f"CLASS {cid}")
                    cv2.putText(overlay, label, (x_c-40, y_c),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
                    cv2.putText(overlay, label, (x_c-40, y_c),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

                if jake_roi is not None:
                    # Only a mask that actually overlaps Jake's box can win;
                    # a zero-overlap mask is never tinted as "under Jake".
                    area = int(m_low[jake_roi].sum())
                    if area > best_area:
                        best_idx, best_area = idx, area

            left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

            # Purple-triangle warnings
            n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
            for lbl in range(1, n_lbl):
                area = stats[lbl, cv2.CC_STAT_AREA]
                if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
                    continue
                ys, xs = np.where(lbl_mat == lbl)
                y_top  = ys.min()
                x_mid  = int(xs[ys == y_top].mean())
                draw_triangle(left, x_mid, y_top)

            # Jake bbox + under-mask tint
            if jake_xyxy:
                x1, y1, x2, y2 = jake_xyxy
                cv2.rectangle(left, (x1, y1), (x2, y2), JAKE_BOX_CLR, 2)
                if best_idx is not None:
                    best_mask_full = cv2.resize(masks[best_idx].astype(np.uint8), (W, H),
                                                interpolation=cv2.INTER_NEAREST).astype(bool)
                    pink_layer = left.copy()
                    pink_layer[best_mask_full] = UNDER_TINT_BGR
                    left = cv2.addWeighted(pink_layer, 0.35, left, 0.65, 0)

            # RIGHT rails + green
            right = frame.copy()
            rails_tinted = right.copy()
            rails_tinted[rail_mask] = (0,0,255)
            right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
            right[green] = (0,255,0)

            # BOTTOM heat
            heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
            heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                              interpolation=cv2.INTER_LINEAR)

        # ─────────────────────── TIMING END (processing only) ───────────────────────
        dt_ms = (time.perf_counter() - t0) * 1000.0

        # record timing for non-warmup frames
        if i >= WARMUP_FRAMES:
            times_ms.append(dt_ms)
        processed += 1

        # ── DISPLAY (not timed): assemble, resize, annotate, convert, show ──
        if _HAS_IPY and shown < SHOW_FIRST_N:
            canvas = assemble_canvas(left, right, heat)
            canvas = resize_to_width(canvas, DISPLAY_MAX_W)

            tag = f"{os.path.basename(p)}  |  proc: {dt_ms:.1f} ms"
            cv2.putText(canvas, tag, (12, 28),
                        cv2.FONT_HERSHEY_SIMPLEX, TEXT_SCALE, TEXT_CLR, TEXT_THICK, cv2.LINE_AA)

            rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
            display(Image.fromarray(rgb))
            shown += 1

        i += 1

    # Done
    total_time = (time.perf_counter() - t_all_start)
    if times_ms:
        avg = statistics.mean(times_ms)
        print("\n───── Speed-test summary (warm-up skipped) ─────")
        print(f"Total frames           : {processed}")
        print(f"Warm-up frames ignored : {min(WARMUP_FRAMES, processed)}")
        print(f"Frames timed           : {len(times_ms)}")
        print(f"Total wall-clock time  : {total_time:,.2f} s")
        print(f"Average / frame        : {avg:,.2f} ms  ({1000.0/avg:,.2f} FPS)")
        print(f"Median                 : {statistics.median(times_ms):,.2f} ms")
        print(f"Fastest                : {min(times_ms):,.2f} ms")
        print(f"Slowest                : {max(times_ms):,.2f} ms")
        print("────────────────────────────────────────────────")
    else:
        # Every frame fell inside the warm-up window (or none could be processed).
        print(f"\nNo frames were timed: {processed} processed, "
              f"first {WARMUP_FRAMES} are warm-up.")


In [None]:
#NO CHEATING MODEL

In [None]:
#!/usr/bin/env python3
"""
JN/VSCode inline visualiser + speed-test for Subway Surfers pipeline (honest timings).

• End-to-end per-frame latency (no cheating): inference (YOLO+RF-DETR) + CPU postproc
  - Worker records infer_start_ts / infer_end_ts and returns them
  - Main times postproc and computes total_ms = post_end_ts - infer_start_ts
• Inline display (first N frames) happens AFTER timing, never included
• Pipeline: prefetch → persistent inference worker → CPU postproc
• YOLO on GPU/MPS, RF-DETR on CPU by default to overlap compute
• Warmup frames executed but not included in statistics; summary printed at the end
"""

import os, glob, sys, time, statistics, threading, queue
os.environ.setdefault("PYTORCH_ENABLE_MPS_FALLBACK", "1")

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from pathlib import Path
from PIL import Image
from ultralytics import YOLO
from rfdetr import RFDETRBase

# Notebook display (inline)
try:
    from IPython.display import display
    _HAS_IPY = True
except Exception:
    _HAS_IPY = False

# ── disable antialias to avoid unsupported MPS op on some ops ─
# Guarded by a marker attribute so that executing this cell again does not
# wrap the already-patched function in yet another wrapper.
if not getattr(F.interpolate, "_no_antialias", False):
    _original_interpolate = F.interpolate

    def _interpolate_no_antialias(input, *args, **kwargs):
        """Forward to the original ``F.interpolate`` with ``antialias=False``."""
        kwargs['antialias'] = False
        return _original_interpolate(input, *args, **kwargs)

    _interpolate_no_antialias._no_antialias = True
    F.interpolate = _interpolate_no_antialias

# =======================
# Config
# =======================
# Weight files: yolo_weights is loaded by YOLO(...), jake_weights by RFDETRBase(...).
home          = os.path.expanduser("~")
yolo_weights  = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
jake_weights  = f"{home}/downloads/weightsjake.pt"

# Display: only SHOW the first N frames inline (still process all)
SHOW_FIRST_N   = 15
DISPLAY_MAX_W  = 960    # target max width of displayed canvas
TEXT_SCALE     = 0.6
TEXT_THICK     = 1

# IMG_SIZE / CONF / IOU are passed to YOLO predict as imgsz / conf / iou.
# NOTE(review): RAIL_ID and ALPHA are presumably consumed by this cell's main
# loop, which lies outside this view; confirm there.
RAIL_ID       = 9
IMG_SIZE      = 512
CONF, IOU     = 0.30, 0.45
ALPHA         = 0.40

# NOTE(review): presumably the arguments for highlight_rails_mask_only_fast
# (target colours in RGB, colour tolerance, minimum component area / height);
# the call site is not visible here.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# TRI_SIZE_PX and PURPLE are the default size and colour of draw_triangle.
# NOTE(review): the remaining heat-map thresholds are presumably used by the
# main loop of this cell; confirm there.
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18
PURPLE              = (255, 0, 255)
MIN_DARK_FRACTION   = 0.15

# Drawing colours (names suggest: tag text, Jake's box, under-mask tint).
TEXT_CLR        = (255, 255, 255)
JAKE_BOX_CLR    = (0, 255, 0)
UNDER_TINT_BGR  = (255, 0, 255)

# PREFETCH / INFER_QUEUE bound the prefetch queue and both inference queues.
WARMUP_FRAMES   = 10
PREFETCH        = 24    # prefetch helps saturate I/O
INFER_QUEUE     = 4     # frames in-flight through inference

# CPU threading (OpenCV)
# All cores but one; 8 is assumed when os.cpu_count() returns None.
# Errors from cv2.setNumThreads are ignored.
OPENCV_THREADS  = max(1, (os.cpu_count() or 8) - 1)
try:
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# =======================
# Device selection
# =======================
# Decide where YOLO runs (yolo_device), whether it uses half precision (half),
# and whether RF-DETR is pinned to the CPU (DET_ON_CPU).
if torch.cuda.is_available():
    # CUDA: YOLO on GPU 0 with half=True; RF-DETR stays on the CPU.
    yolo_device = 0
    half = True
    DET_ON_CPU = True
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    # Apple MPS: RF-DETR on CPU to overlap with MPS
    yolo_device = "mps"
    half = False
    DET_ON_CPU = True
else:
    # both CPU; pipeline still helps
    yolo_device = "cpu"
    half = False
    DET_ON_CPU = False

# =======================
# Models
# =======================
yolo_model = YOLO(yolo_weights)
try:
    # Best-effort: a failing fuse() is ignored and the unfused model is used.
    yolo_model.fuse()
except Exception:
    pass

jake_model = RFDETRBase(pretrain_weights=jake_weights, num_classes=3)

# Warmup YOLO (so first frame isn't laggy): one throw-away prediction on a
# black IMG_SIZE x IMG_SIZE frame.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
with torch.inference_mode():
    _ = yolo_model.predict(
        _dummy,
        task="segment",
        imgsz=IMG_SIZE,
        device=yolo_device,
        conf=CONF,
        iou=IOU,
        verbose=False,
        half=half,
    )

# =======================
# Utils
# =======================
# Class id → label text.
obstacle_classes = {
    0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
    5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
    10:"SIDEWALK",11:"YELLOWTRAIN"
}
# Class id → overlay colour. There is no entry for id 9 ("RAILS").
# NOTE(review): presumably applied to BGR frames with a fallback colour for
# missing ids; confirm in this cell's main loop.
CLASS_COLOURS = {
    0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
    4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
    8:(0,0,128),10:(128,128,0),11:(255,255,102)
}

def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """
    Return a full-frame boolean mask of rail pixels that match a target colour.

    A pixel is kept when it lies inside ``rail_mask``, its BGR value is within
    ``tolerance`` (Euclidean distance) of at least one colour in
    ``target_colors_rgb``, and it belongs to an 8-connected component whose
    area is >= ``min_region_size`` and whose bounding-box height is
    >= ``min_region_height``. An empty ``rail_mask`` yields an all-False mask.
    """
    height, width = img_bgr.shape[:2]
    result = np.zeros((height, width), dtype=bool)
    if not rail_mask.any():
        return result

    # Restrict the work to the tight bounding box of the rail mask.
    rows, cols = np.where(rail_mask)
    top, bottom = rows.min(), rows.max() + 1
    lft, rgt = cols.min(), cols.max() + 1
    roi = img_bgr[top:bottom, lft:rgt].astype(np.float32)
    roi_mask = rail_mask[top:bottom, lft:rgt]

    # Targets arrive as RGB; flip them to BGR to match the image.
    targets = np.array([(b, g, r) for r, g, b in target_colors_rgb], dtype=np.float32)
    delta = roi[:, :, None, :] - targets[None, None, :, :]
    sq_dist = np.sum(delta * delta, axis=-1)
    matches = np.any(sq_dist <= tolerance**2, axis=-1) & roi_mask

    count, labels, stats, _ = cv2.connectedComponentsWithStats(matches.astype(np.uint8), 8)

    # Per-label keep flag (label 0 is background and never kept), resolved
    # for all pixels with a single table lookup.
    keep = np.zeros(count, dtype=bool)
    keep[1:] = ((stats[1:, cv2.CC_STAT_AREA] >= min_region_size) &
                (stats[1:, cv2.CC_STAT_HEIGHT] >= min_region_height))
    result[top:bottom, lft:rgt] = keep[labels]
    return result

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """
    Build a uint8 score map from the local balance of red vs. green pixels.

    Both masks are box-blurred with a ``blur_ksize`` x ``blur_ksize`` kernel,
    the green density is subtracted from the red one, and the difference is
    normalised by its largest magnitude so that mid-scale means "balanced",
    higher values mean more red and lower values more green.
    """
    kernel = (blur_ksize, blur_ksize)
    red_density = cv2.blur(red_mask.astype(np.float32), kernel)
    green_density = cv2.blur(green_mask.astype(np.float32), kernel)
    balance = red_density - green_density
    # The epsilon keeps the division defined when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2*peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """
    Draw a filled triangle on ``img`` (in place) with its apex at ``(x, y)``.

    The base lies ``int(size * 1.2)`` pixels below the apex and extends
    ``size`` pixels to each side; a thin black outline is drawn on top.
    """
    base_y = y + int(size * 1.2)
    vertices = np.array([[x, y], [x-size, base_y], [x+size, base_y]], np.int32)
    cv2.fillConvexPoly(img, vertices, colour)
    outline = vertices.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0,0,0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """
    Put ``left`` and ``right`` side by side and stack ``heat`` underneath.

    ``heat`` has to be as wide as the two panels together for the stack to work.
    """
    return np.vstack([np.hstack([left, right]), heat])

def resize_to_width(img, max_w=DISPLAY_MAX_W):
    """
    Downscale ``img`` to ``max_w`` pixels wide, keeping the aspect ratio.

    Images that are already at most ``max_w`` wide are returned unchanged
    (same object, no copy). Otherwise a resized copy is returned.
    """
    h, w = img.shape[:2]
    if w <= max_w:
        return img
    scale = max_w / float(w)
    # Use max_w itself for the width: int(w * scale) can truncate to max_w - 1
    # through float rounding. The height is rounded and kept >= 1 because
    # cv2.resize rejects a zero dimension (very wide, very short images).
    new_w = max(1, int(max_w))
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)

# =======================
# Inference worker (persistent, measures inference time)
# =======================
class InferenceWorker(threading.Thread):
    """
    Consumes (path, frame) items from in_q, runs YOLO (GPU/MPS) + RF-DETR
    (CPU optionally), and puts on out_q:
      (path, frame, masks, classes, jake_xyxy, infer_start_ts, infer_end_ts)

    The two timestamps are time.perf_counter() readings taken right before
    YOLO and right after RF-DETR. A ``None`` item on in_q is the end-of-stream
    sentinel and is forwarded to out_q. A ``None`` frame (unreadable image) is
    forwarded with all result fields None and two equal timestamps, without
    running any model. If processing raises, the sentinel is still emitted so
    the consumer does not block forever.
    """
    def __init__(self, in_q, out_q):
        super().__init__(daemon=True)
        self.in_q  = in_q
        self.out_q = out_q
        self.stop  = False   # checked before each read from in_q

    def run(self):
        try:
            while not self.stop:
                item = self.in_q.get()
                if item is None:  # sentinel
                    self.out_q.put(None)
                    break
                self.out_q.put(self._process(*item))
        except Exception:
            # The consumer waits on out_q; unblock it before this thread dies,
            # then re-raise so the failure is still reported.
            self.out_q.put(None)
            raise

    def _process(self, path, frame):
        """Run both models on one frame and return the result tuple for out_q."""
        if frame is None:
            # cv2.imread yields None for unreadable files; frame.shape would
            # raise here and kill the thread, leaving the consumer waiting.
            now = time.perf_counter()
            return (path, None, None, None, None, now, now)

        H, W = frame.shape[:2]

        with torch.inference_mode():
            infer_start_ts = time.perf_counter()

            # YOLO (device)
            res = yolo_model.predict(frame, task="segment", imgsz=IMG_SIZE,
                                     device=yolo_device, conf=CONF, iou=IOU,
                                     max_det=30, verbose=False, half=half)[0]

            # RF-DETR (CPU to overlap by default). Best-effort: any failure
            # here only means that no Jake box is reported for this frame.
            jake_xyxy = None
            try:
                pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                det_dev = "cpu" if DET_ON_CPU else ("cpu" if yolo_device != 0 else 0)
                dets = jake_model.predict(pil_img, threshold=0.5, device=det_dev)[0]
                if hasattr(dets, "xyxy") and len(dets.xyxy) > 0:
                    x1, y1, x2, y2 = dets.xyxy[0].astype(int).tolist()
                    # clamp to the frame; degenerate boxes are dropped
                    x1 = max(0, min(W-1, x1)); x2 = max(0, min(W-1, x2))
                    y1 = max(0, min(H-1, y1)); y2 = max(0, min(H-1, y2))
                    if x2 > x1 and y2 > y1:
                        jake_xyxy = (x1, y1, x2, y2)
            except Exception:
                pass

            infer_end_ts = time.perf_counter()

        if res.masks is None:
            return (path, frame, None, None, jake_xyxy, infer_start_ts, infer_end_ts)

        masks   = res.masks.data.cpu().numpy()  # (N, h_m, w_m)
        classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                   else res.boxes.cls.cpu().numpy())
        classes = np.asarray(classes, dtype=int).ravel()

        return (path, frame, masks, classes, jake_xyxy, infer_start_ts, infer_end_ts)

# =======================
# Prefetcher
# =======================
class FramePrefetcher(threading.Thread):
    """Background thread that decodes image files and pushes them onto a queue.

    Every readable image is enqueued as a ``(path, image)`` tuple, in the
    order of ``paths``. A single ``None`` sentinel is enqueued last so the
    consumer can tell that the stream has ended.

    Args:
        paths: Iterable of image file paths, loaded with ``cv2.imread``.
        out_q: Queue receiving the ``(path, image)`` tuples and the sentinel.
        maxsize: Kept for backward compatibility; it is not used by this
            class. The queue capacity is whatever ``out_q`` was built with.
    """

    def __init__(self, paths, out_q, maxsize=PREFETCH):
        super().__init__(daemon=True)
        self.paths = paths
        self.q = out_q
        # Set to True from another thread to stop reading further files;
        # the end-of-stream sentinel is still enqueued afterwards.
        self.stop = False

    def run(self):
        """Load each path in order, enqueue it, then enqueue the sentinel."""
        for p in self.paths:
            if self.stop:
                break
            img = cv2.imread(p)
            if img is None:
                # cv2.imread reports a missing/corrupt file by returning None
                # instead of raising. Forwarding that None would break any
                # consumer that reads `frame.shape`, so drop the file here.
                print(f"[WARN] could not read image, skipping: {p}", file=sys.stderr)
                continue
            self.q.put((p, img))
        self.q.put(None)  # sentinel

# =======================
# Main (Jupyter-notebook friendly, honest timing)
# =======================
if __name__ == "__main__":
    # Entry point: prefetch frames from disk, run inference on a worker
    # thread, post-process each result on this thread, optionally display the
    # first few canvases inline, and print a timing summary at the end.
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"[ERROR] frames folder not found: {folder}", file=sys.stderr)
        sys.exit(1)

    # "*.jpg" / "*.png" already match every "frame_*" file, so the patterns
    # overlap. Collect into a set so each frame is processed exactly once.
    patterns = ("frame_*.jpg", "frame_*.png", "*.jpg", "*.png")
    paths = sorted({
        match
        for pattern in patterns
        for match in glob.glob(str(folder / pattern))
    })
    if not paths:
        print(f"[ERROR] No frame images found in {folder}", file=sys.stderr)
        sys.exit(1)

    # Queues
    prefetch_q = queue.Queue(maxsize=PREFETCH)
    infer_in_q = queue.Queue(maxsize=INFER_QUEUE)
    infer_out_q = queue.Queue(maxsize=INFER_QUEUE)

    # Threads
    pf = FramePrefetcher(paths, prefetch_q, maxsize=PREFETCH)
    pf.start()

    iw = InferenceWorker(infer_in_q, infer_out_q)
    iw.start()

    # Feed bookkeeping:
    #   in_flight - frames handed to the worker whose result has not been
    #               received yet
    #   feed_done - True once the end-of-stream sentinel has been forwarded
    in_flight = 0
    feed_done = False

    # Prime inference queue with a few frames
    while in_flight < INFER_QUEUE:
        item = prefetch_q.get()
        infer_in_q.put(item)
        if item is None:
            feed_done = True
            break
        in_flight += 1

    # Timing accumulators (end-to-end), and also separate components
    total_ms_list = []
    infer_ms_list = []
    post_ms_list  = []

    processed = 0
    shown = 0
    t_all_start = time.perf_counter()

    i = 0
    while True:
        # Keep inference fed. When nothing is in flight the worker cannot
        # produce an output, so the read from the prefetcher must block;
        # otherwise a momentarily empty prefetch queue would leave this loop
        # waiting on infer_out_q forever.
        while not feed_done and infer_in_q.qsize() < INFER_QUEUE:
            try:
                item = prefetch_q.get(block=(in_flight == 0))
            except queue.Empty:
                break
            infer_in_q.put(item)
            if item is None:
                feed_done = True
            else:
                in_flight += 1

        out = infer_out_q.get()
        if out is None:
            break
        in_flight -= 1

        p, frame, masks, classes, jake_xyxy, infer_start_ts, infer_end_ts = out
        if frame is None:
            i += 1
            continue

        if classes is not None:
            classes = np.asarray(classes, dtype=int).ravel()

        # ───────────── POST-PROC TIMING (processing only) ─────────────
        post_start_ts = time.perf_counter()

        H, W = frame.shape[:2]
        if masks is None or classes is None:
            # No segmentation output: show the raw frame and an all-zero heat map.
            left = frame.copy()
            right = frame.copy()
            heat = cv2.applyColorMap(np.zeros((H, W), np.uint8), cv2.COLORMAP_JET)
        else:
            # Union of all rail-class masks at mask resolution, then upscaled
            # to frame resolution.
            h_m, w_m = masks.shape[1:]
            rail_union = np.zeros((h_m, w_m), dtype=bool)
            for m, c in zip(masks, classes):
                if int(c) == RAIL_ID:
                    rail_union |= m.astype(bool)
            rail_mask = cv2.resize(rail_union.astype(np.uint8), (W, H),
                                   interpolation=cv2.INTER_NEAREST).astype(bool)

            # Green/red + score: "red" is every rail pixel not marked green.
            green = highlight_rails_mask_only_fast(frame, rail_mask,
                                                   TARGET_COLORS_RGB, TOLERANCE,
                                                   MIN_REGION_SIZE, MIN_REGION_HEIGHT)
            red   = rail_mask & ~green
            score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

            # Exclude bands: zero the score in the top and bottom strips.
            top_ex = int(H * EXCLUDE_TOP_FRAC)
            bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
            score[:top_ex, :] = 0
            if bot_ex:
                score[H-bot_ex:, :] = 0

            # Binary map of high-score pixels, cleaned with a morphological open.
            dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
            dark = cv2.morphologyEx(
                dark, cv2.MORPH_OPEN,
                cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
                iterations=1
            )

            total_dark_area  = int(dark.sum())
            frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

            # LEFT labels + under-mask tint
            left    = frame.copy()
            overlay = left.copy()

            # Jake overlap (low-res bbox): project the bbox into mask
            # coordinates; the window stays None when there is no bbox or it
            # collapses to zero size at mask resolution.
            jake_win = None
            if jake_xyxy:
                x1, y1, x2, y2 = jake_xyxy
                sx, sy = w_m / W, h_m / H
                mx1, mx2 = max(0, int(x1 * sx)), min(w_m, int(x2 * sx))
                my1, my2 = max(0, int(y1 * sy)), min(h_m, int(y2 * sy))
                if mx2 > mx1 and my2 > my1:
                    jake_win = (slice(my1, my2), slice(mx1, mx2))

            # Index of the non-rail mask with the largest sum inside the
            # window, and that sum. The first candidate is accepted
            # unconditionally; ties keep the earlier mask.
            best_idx = None
            best_area = 0

            for idx, (cid, m_low) in enumerate(zip(classes, masks)):
                cid = int(cid)
                if cid == RAIL_ID:
                    continue
                mask_full = cv2.resize(m_low.astype(np.uint8), (W, H),
                                       interpolation=cv2.INTER_NEAREST).astype(bool)
                colour = CLASS_COLOURS.get(cid, (255,255,255))
                overlay[mask_full] = colour

                # Label at the mask centroid: black outline under white text.
                ys, xs = np.where(mask_full)
                if len(xs):
                    x_c, y_c = int(xs.mean()), int(ys.mean())
                    label = obstacle_classes.get(cid, f"CLASS {cid}")
                    cv2.putText(overlay, label, (x_c-40, y_c),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
                    cv2.putText(overlay, label, (x_c-40, y_c),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)

                if jake_win is not None:
                    area = int(m_low[jake_win].sum())
                    if best_idx is None or area > best_area:
                        best_idx, best_area = idx, area

            left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

            # Purple-triangle warnings: one marker at the top edge of every
            # connected high-score region that passes both area thresholds.
            n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
            for lbl in range(1, n_lbl):  # label 0 is the background
                area = stats[lbl, cv2.CC_STAT_AREA]
                if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
                    continue
                ys, xs = np.where(lbl_mat == lbl)
                y_top  = ys.min()
                x_mid  = int(xs[ys == y_top].mean())
                draw_triangle(left, x_mid, y_top)

            # Jake bbox + under-mask tint
            if jake_xyxy:
                x1, y1, x2, y2 = jake_xyxy
                cv2.rectangle(left, (x1, y1), (x2, y2), JAKE_BOX_CLR, 2)
                if best_idx is not None:
                    best_mask_full = cv2.resize(masks[best_idx].astype(np.uint8), (W, H),
                                                interpolation=cv2.INTER_NEAREST).astype(bool)
                    pink_layer = left.copy()
                    pink_layer[best_mask_full] = UNDER_TINT_BGR
                    left = cv2.addWeighted(pink_layer, 0.35, left, 0.65, 0)

            # RIGHT rails + green
            right = frame.copy()
            rails_tinted = right.copy()
            rails_tinted[rail_mask] = (0,0,255)
            right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
            right[green] = (0,255,0)

            # BOTTOM heat, stretched to the combined width of both panels.
            heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
            heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                              interpolation=cv2.INTER_LINEAR)

        post_end_ts = time.perf_counter()
        # ───────────── END TIMING ─────────────

        # Calculate timings. total_ms runs from inference start to post-proc
        # end, so it also contains any time the result waited in the queue.
        infer_ms = (infer_end_ts - infer_start_ts) * 1000.0
        post_ms  = (post_end_ts  - post_start_ts) * 1000.0
        total_ms = (post_end_ts  - infer_start_ts) * 1000.0  # honest end-to-end

        # record timing for non-warmup frames
        if i >= WARMUP_FRAMES:
            infer_ms_list.append(infer_ms)
            post_ms_list.append(post_ms)
            total_ms_list.append(total_ms)
        processed += 1

        # ── DISPLAY (not timed): assemble, resize, annotate, convert, show ──
        if _HAS_IPY and shown < SHOW_FIRST_N:
            canvas = assemble_canvas(left, right, heat)
            canvas = resize_to_width(canvas, DISPLAY_MAX_W)
            tag = (f"{os.path.basename(p)} | infer: {infer_ms:.1f} ms | "
                   f"post: {post_ms:.1f} ms | total: {total_ms:.1f} ms")
            cv2.putText(canvas, tag, (12, 28),
                        cv2.FONT_HERSHEY_SIMPLEX, TEXT_SCALE, TEXT_CLR, TEXT_THICK, cv2.LINE_AA)
            rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
            display(Image.fromarray(rgb))
            shown += 1

        i += 1

    # Done
    total_time = (time.perf_counter() - t_all_start)
    if total_ms_list:
        avg_total = statistics.mean(total_ms_list)
        avg_infer = statistics.mean(infer_ms_list)
        avg_post  = statistics.mean(post_ms_list)
        print("\n───── Speed-test summary (warm-up skipped) ─────")
        print(f"Total frames           : {processed}")
        print(f"Warm-up frames ignored : {min(WARMUP_FRAMES, processed)}")
        print(f"Frames timed           : {len(total_ms_list)}")
        print(f"Total wall-clock time  : {total_time:,.2f} s")
        print(f"Average inference      : {avg_infer:,.2f} ms")
        print(f"Average postproc       : {avg_post:,.2f} ms")
        print(f"Average end-to-end     : {avg_total:,.2f} ms  ({1000.0/avg_total:,.2f} FPS)")
        print(f"Median end-to-end      : {statistics.median(total_ms_list):,.2f} ms")
        print(f"Fastest end-to-end     : {min(total_ms_list):,.2f} ms")
        print(f"Slowest end-to-end     : {max(total_ms_list):,.2f} ms")
        print("────────────────────────────────────────────────")
    else:
        # Nothing got past the warm-up window, so there are no statistics.
        print(f"\nNo frames timed: processed {processed} frame(s) in "
              f"{total_time:,.2f} s, none past the {WARMUP_FRAMES}-frame warm-up.")
