In [None]:
# Just running the detection and ground pathing. Will need to think carefully about how to handle the tops of trains.

In [None]:
#!/usr/bin/env python3
"""
Visualise *all* segmentation masks with labelled overlays.

• Rails are tinted red (right) and used for the heat-map + purple-triangle.
• Other masks are labelled and coloured on the original (left).
• Bottom panel shows the heat-map.

Controls:
  SPACE: next frame   |   q / ESC: quit

Loads images from:
    ~/Documents/GitHub/Ai-plays-SubwaySurfers/frames

Change requested:
  Before plotting a purple triangle, require that the component's dark-red
  area is at least 15% of the *total* dark-red area observed for the frame.
  Implemented with negligible overhead and no other logic changes.
"""

import os, glob, sys, time
import cv2, torch, numpy as np
from pathlib import Path
from ultralytics import YOLO

# =======================
# Config
# =======================
# Tunable constants for this script. All colour tuples below are written
# straight into OpenCV images, i.e. they are interpreted in BGR channel order
# unless the name says RGB.
home     = os.path.expanduser("~")
weights  = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"  # YOLO segmentation weights
RAIL_ID  = 9  # class id treated as "rails" (matches obstacle_classes[9])

IMG_SIZE = 512           # passed to model.predict as imgsz
CONF, IOU = 0.30, 0.45   # passed to model.predict as conf / iou
ALPHA    = 0.40          # blend weight of the red rail tint in the right pane

# Colours (RGB) searched for inside the rail mask; a pixel matches when its
# Euclidean distance to any target colour is <= TOLERANCE.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30    # min pixel area of a matched connected component
MIN_REGION_HEIGHT  = 150   # min bounding-box height (px) of a matched component

HEAT_BLUR_KSIZE     = 51    # box-blur kernel side used for the red-vs-green score
RED_SCORE_THRESH    = 220   # score (0-255) at/above which a pixel counts as "dark red"
EXCLUDE_TOP_FRAC    = 0.40  # fraction of frame height zeroed at the top of the score map
EXCLUDE_BOTTOM_FRAC = 0.15  # fraction of frame height zeroed at the bottom of the score map
MIN_DARK_RED_AREA   = 1200  # min pixel area of a dark-red component to get a triangle
TRI_SIZE_PX         = 18    # half-width of the triangle base; its height is 1.2x this
PURPLE              = (255, 0, 255)  # triangle fill colour

# New: minimum fraction of total dark-red area a blob must contribute
MIN_DARK_FRACTION   = 0.15  # 15%

# UI / display constraints
WIN_NAME        = "Left: labels | Right: rails | Bottom: heat-map (SPACE=next, q=quit)"
MAX_DISPLAY_W   = 1600             # canvases wider than this are downscaled before display
TEXT_CLR        = (255, 255, 255)  # foreground colour of the on-screen status text

# Class id -> human-readable label drawn on the left pane.
obstacle_classes = {
    0: "BOOTS", 1: "GREYTRAIN", 2: "HIGHBARRIER1", 3: "JUMP", 4: "LOWBARRIER1",
    5: "LOWBARRIER2", 6: "ORANGETRAIN", 7: "PILLAR", 8: "RAMP", 9: "RAILS",
    10: "SIDEWALK", 11: "YELLOWTRAIN"
}
# Class id -> overlay colour. There is no entry for RAIL_ID (9): rails are
# skipped by the label overlay and drawn separately on the right pane.
CLASS_COLOURS = {
    0: (255,255,0), 1: (192,192,192), 2: (0,128,255), 3: (0,255,0),
    4: (255,0,255), 5: (0,255,255), 6: (255,128,0), 7: (128,0,255),
    8: (0,0,128), 10: (128,128,0), 11: (255,255,102)
}

# =======================
# Device / precision
# =======================
# Pick the inference device in order of preference: CUDA (device index 0,
# half precision), then Apple MPS, then CPU (both full precision).
if torch.cuda.is_available():
    device, half = 0, True
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    # getattr guard: torch builds without an `mps` backend attribute fall through to CPU
    device, half = "mps", False
else:
    device, half = "cpu", False

# Optional: reduce OpenCV CPU thread contention with PyTorch
try:
    cv2.setNumThreads(1)
except Exception:
    # best-effort tuning only; failure here must not stop the script
    pass

# Load the segmentation model at import time (module-level side effect).
model = YOLO(weights)
# Layer fusion is attempted as a best-effort step; errors are ignored.
try: model.fuse()
except Exception: pass

# warm up once so the first frame isn't laggy
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                  conf=CONF, iou=IOU, verbose=False, half=half)

# ------------------------------------------------------------------
def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0,
                                   min_region_size=50,
                                   min_region_height=150):
    """Return a full-frame boolean mask of rail pixels that match a target colour.

    Only the bounding box of ``rail_mask`` is examined. A pixel is a candidate
    when it lies inside ``rail_mask`` and its Euclidean distance (over the
    three colour channels) to any entry of ``target_colors_rgb`` is at most
    ``tolerance``. Candidates are grouped into 8-connected components, and a
    component is kept only if its pixel area is >= ``min_region_size`` and its
    bounding-box height is >= ``min_region_height``.

    Args:
        img_bgr: H x W x 3 image in BGR channel order.
        rail_mask: H x W mask of rail pixels.
        target_colors_rgb: iterable of (r, g, b) triples.
        tolerance: maximum colour distance for a match.
        min_region_size: minimum component area in pixels.
        min_region_height: minimum component bounding-box height in pixels.

    Returns:
        H x W bool array; all False when ``rail_mask`` is empty.
    """
    height, width = img_bgr.shape[:2]
    result = np.zeros((height, width), dtype=bool)

    # Bounding box of the rail mask, derived from its occupied rows/columns.
    occupied_rows = np.flatnonzero(rail_mask.any(axis=1))
    if occupied_rows.size == 0:
        return result
    occupied_cols = np.flatnonzero(rail_mask.any(axis=0))
    top, bottom = occupied_rows[0], occupied_rows[-1] + 1
    left, right = occupied_cols[0], occupied_cols[-1] + 1
    window = (slice(top, bottom), slice(left, right))

    # Target colours are given as RGB but the image is BGR, so flip each one.
    palette_bgr = np.array([(r, g, b)[::-1] for r, g, b in target_colors_rgb],
                           dtype=np.float32)
    pixels = img_bgr[window].astype(np.float32)
    delta = pixels[:, :, np.newaxis, :] - palette_bgr[np.newaxis, np.newaxis, :, :]
    sq_dist = np.square(delta).sum(axis=-1)
    matches_any = (sq_dist <= tolerance ** 2).any(axis=-1)

    candidates = matches_any & rail_mask[window]
    _, labels, stats, _ = cv2.connectedComponentsWithStats(
        candidates.astype(np.uint8), 8)

    # Per-label keep flag, applied to every pixel through a table lookup.
    keep = ((stats[:, cv2.CC_STAT_AREA] >= min_region_size)
            & (stats[:, cv2.CC_STAT_HEIGHT] >= min_region_height))
    keep[0] = False  # label 0 is the background, never a match

    result[window] = keep[labels]
    return result


def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """Build a uint8 map of how strongly "red" outweighs "green" around each pixel.

    Both masks are box-blurred with a ``blur_ksize`` x ``blur_ksize`` kernel,
    and their difference is normalised by its largest absolute value so that
    0 means locally all green, about 127 means balanced (or neither), and 255
    means locally all red.
    """
    window = (blur_ksize, blur_ksize)
    red_density = cv2.blur(red_mask.astype(np.float32), window)
    green_density = cv2.blur(green_mask.astype(np.float32), window)

    balance = red_density - green_density
    # The epsilon keeps the division defined when both masks are empty.
    peak = float(np.abs(balance).max()) + 1e-6
    unit = balance / (2 * peak) + 0.5
    return np.clip(unit * 255, 0, 255).astype(np.uint8)


def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """Draw a filled, black-outlined triangle on ``img`` in place.

    The apex is at (x, y); the base lies ``int(size * 1.2)`` pixels below it
    and extends ``size`` pixels to each side of ``x``.
    """
    base_y = y + int(size * 1.2)
    vertices = np.array(
        [[x, y], [x - size, base_y], [x + size, base_y]],
        np.int32,
    )
    cv2.fillConvexPoly(img, vertices, colour)
    outline = vertices.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0, 0, 0), 1, cv2.LINE_AA)

# ------------------------------------------------------------------
def process_frame(img_bgr):
    """Segment one frame and build the three display panes.

    Args:
        img_bgr: H x W x 3 uint8 frame in BGR order (e.g. from ``cv2.imread``).

    Returns:
        ``(left, right, heat_col)``, all uint8 BGR images:
          * ``left``  (H x W): non-rail masks coloured and labelled, plus a
            purple triangle at the top of each qualifying dark-red region.
          * ``right`` (H x W): rails tinted red, colour-matched rail pixels green.
          * ``heat_col`` (H x 2W): JET colour map of the red-vs-green score,
            already as wide as ``left`` and ``right`` side by side.
    """
    res = model.predict(img_bgr, task="segment", imgsz=IMG_SIZE,
                        device=device, conf=CONF, iou=IOU,
                        max_det=30, verbose=False, half=half)[0]

    H, W = img_bgr.shape[:2]
    if res.masks is None:
        # No detections. The heat pane must still span both top panes (2W),
        # otherwise stacking it under left|right fails on the width mismatch.
        blank_heat = np.zeros((H, 2 * W), np.uint8)
        return (img_bgr.copy(), img_bgr.copy(),
                cv2.applyColorMap(blank_heat, cv2.COLORMAP_JET))

    masks   = res.masks.data.cpu().numpy()
    classes = res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls") \
              else res.boxes.cls.cpu().numpy()
    h_m, w_m = masks.shape[1:]

    # Union of every rail instance, at mask resolution.
    rail_union = np.zeros((h_m, w_m), bool)
    for m, c in zip(masks, classes):
        if int(c) == RAIL_ID:
            rail_union |= m.astype(bool)

    # NOTE(review): masks are stretched directly to the frame size; if the
    # model's mask tensor includes letterbox padding this shifts the overlay
    # slightly — confirm against the model's preprocessing.
    rail_mask = rail_union
    if rail_mask.shape != (H, W):
        rail_mask = cv2.resize(rail_mask.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)

    green = highlight_rails_mask_only_fast(img_bgr, rail_mask,
                                           TARGET_COLORS_RGB, TOLERANCE,
                                           MIN_REGION_SIZE, MIN_REGION_HEIGHT)
    red   = rail_mask & ~green
    score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

    # Ignore the top and bottom bands of the frame.
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    score[:top_ex, :] = 0
    if bot_ex:  # guard: score[H-0:] would be an empty slice, nothing to zero
        score[H-bot_ex:, :] = 0

    dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
        iterations=1
    )

    # A component must hold at least MIN_DARK_FRACTION of all dark-red pixels.
    total_dark_area = int(dark.sum())
    frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

    # left: labels overlaid
    left    = img_bgr.copy()
    overlay = left.copy()
    for m, c in zip(masks, classes):
        cid = int(c)
        if cid == RAIL_ID:
            continue
        mask_u8 = m.astype(np.uint8)
        if mask_u8.shape != (H, W):
            mask_u8 = cv2.resize(mask_u8, (W, H),
                                 interpolation=cv2.INTER_NEAREST)
        # Always convert to bool: the raw mask is a float array, and NumPy
        # rejects float arrays as indices when no resize was needed.
        mask_full = mask_u8.astype(bool)
        label  = obstacle_classes.get(cid, f"CLASS {cid}")
        colour = CLASS_COLOURS.get(cid, (255, 255, 255))
        overlay[mask_full] = colour

        ys, xs = np.where(mask_full)
        if len(xs):
            # Label near the mask centroid: black outline first, white text on top.
            x_c, y_c = int(xs.mean()), int(ys.mean())
            cv2.putText(overlay, label, (x_c - 40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
            cv2.putText(overlay, label, (x_c - 40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)
    left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

    # purple triangle warnings
    n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
            continue
        ys, xs = np.where(lbl_mat == lbl)
        # Apex goes at the component's top row, centred on that row's pixels.
        y_top = ys.min()
        x_mid = int(xs[ys == y_top].mean())
        draw_triangle(left, x_mid, y_top)

    # right: rails tinted
    right = img_bgr.copy()
    rails_tinted = right.copy()
    rails_tinted[rail_mask] = (0,0,255)
    right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
    right[green] = (0,255,0)

    # bottom: heat map (width = left+right, height = H)
    heat_col = cv2.applyColorMap(score, cv2.COLORMAP_JET)
    heat_col = cv2.resize(heat_col, (left.shape[1] + right.shape[1], H),
                          interpolation=cv2.INTER_LINEAR)

    return left, right, heat_col

def assemble_canvas(left, right, heat):
    """Stack ``left | right`` on the top row and ``heat`` underneath.

    If ``heat`` is not exactly as wide as the combined top row it is resized
    horizontally (its height is kept) so the vertical stack cannot fail on a
    width mismatch.

    Returns:
        The combined canvas as a single image array.
    """
    top = np.hstack((left, right))
    top_w = top.shape[1]
    heat_h, heat_w = heat.shape[:2]
    if heat_w != top_w:
        heat = cv2.resize(heat, (top_w, heat_h), interpolation=cv2.INTER_LINEAR)
    return np.vstack((top, heat))

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """Return ``img`` shrunk to ``max_w`` pixels wide, keeping its aspect ratio.

    Images that are already at most ``max_w`` wide are returned unchanged
    (the same object, not a copy).
    """
    height, width = img.shape[:2]
    if width > max_w:
        ratio = max_w / float(width)
        target = (int(width * ratio), int(height * ratio))
        img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    return img

# ------------------------------------------------------------------
if __name__ == "__main__":
    # Interactive viewer: process one frame, show it, wait for SPACE / q / ESC.
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"frames folder not found: {folder}", file=sys.stderr); sys.exit(1)

    # "*.jpg" / "*.png" already match every "frame_*" file; globbing both
    # patterns listed (and displayed) each such frame twice.
    paths = sorted(
        glob.glob(str(folder / "*.jpg")) +
        glob.glob(str(folder / "*.png"))
    )
    if not paths:
        print(f"No frame images found in {folder}", file=sys.stderr); sys.exit(1)

    cv2.namedWindow(WIN_NAME, cv2.WINDOW_NORMAL)

    for p in paths:
        frame = cv2.imread(p)
        if frame is None:
            print(f"[WARN] unreadable image: {p}")
            continue

        # Per-frame boundary: a failure on one image is reported and skipped
        # so the remaining frames can still be reviewed.
        try:
            t0 = time.perf_counter()
            left, right, heat = process_frame(frame)
            proc_ms = (time.perf_counter() - t0) * 1000.0
            canvas = assemble_canvas(left, right, heat)
            canvas = maybe_downscale(canvas, MAX_DISPLAY_W)

            # On-screen text: thick black pass for contrast, thin coloured pass on top.
            captions = (
                (f"{os.path.basename(p)}  |  proc: {proc_ms:.1f} ms", (12, 28), 0.8),
                ("SPACE: next   q/ESC: quit", (12, 60), 0.7),
            )
            for text, origin, scale in captions:
                cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_SIMPLEX,
                            scale, (0,0,0), 3, cv2.LINE_AA)
                cv2.putText(canvas, text, origin, cv2.FONT_HERSHEY_SIMPLEX,
                            scale, TEXT_CLR, 1, cv2.LINE_AA)

        except Exception as e:
            print(f"[WARN] error on {os.path.basename(p)}: {e}")
            continue

        # Show once, then poll for keys with a short wait so the loop stays
        # responsive without spinning a CPU core.
        cv2.imshow(WIN_NAME, canvas)
        while True:
            key = cv2.waitKey(20) & 0xFF

            # Handle window close
            if cv2.getWindowProperty(WIN_NAME, cv2.WND_PROP_VISIBLE) < 1:
                cv2.destroyAllWindows()
                sys.exit(0)

            if key == 32:          # SPACE → next
                break
            if key in (ord('q'), 27):  # q or ESC
                cv2.destroyAllWindows()
                sys.exit(0)

    cv2.destroyAllWindows()


In [None]:
#Added multithreading

In [None]:
#!/usr/bin/env python3
"""
Visualise *all* segmentation masks with labelled overlays — high-utilization version.

• Fully pipelined to use CPU + GPU:
  - Frame prefetch (I/O)
  - Batched YOLO inference on GPU/MPS (or CPU) with configurable batch size
  - Parallel CPU postprocessing workers (triangles, labels, overlays, heat)
  - UI thread only displays; compute continues in background
• Logic unchanged: rails, heat-map, purple triangle (≥15% of total dark area), labels

Controls:
  SPACE: next frame   |   q / ESC: quit

Loads images from:
    ~/Documents/GitHub/Ai-plays-SubwaySurfers/frames
"""

import os, glob, sys, time, threading, queue
import cv2, torch, numpy as np
from pathlib import Path
from ultralytics import YOLO

# =======================
# Config (tune these)
# =======================
# All colour tuples below are written straight into OpenCV images, i.e. they
# are interpreted in BGR channel order unless the name says RGB.
home          = os.path.expanduser("~")
weights       = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"  # YOLO segmentation weights
RAIL_ID       = 9  # class id treated as "rails" (matches obstacle_classes[9])

IMG_SIZE      = 512          # passed to model.predict as imgsz
CONF, IOU     = 0.30, 0.45   # passed to model.predict as conf / iou
ALPHA         = 0.40         # blend weight of the red rail tint in the right pane

# Colours (RGB) searched for inside the rail mask; a pixel matches when its
# Euclidean distance to any target colour is <= TOLERANCE.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30    # min pixel area of a matched connected component
MIN_REGION_HEIGHT  = 150   # min bounding-box height (px) of a matched component

HEAT_BLUR_KSIZE     = 51    # box-blur kernel side used for the red-vs-green score
RED_SCORE_THRESH    = 220   # score (0-255) at/above which a pixel counts as "dark red"
EXCLUDE_TOP_FRAC    = 0.40  # fraction of frame height zeroed at the top of the score map
EXCLUDE_BOTTOM_FRAC = 0.15  # fraction of frame height zeroed at the bottom of the score map
MIN_DARK_RED_AREA   = 1200  # min pixel area of a dark-red component to get a triangle
TRI_SIZE_PX         = 18    # half-width of the triangle base; its height is 1.2x this
PURPLE              = (255, 0, 255)  # triangle fill colour

# Require blob to be at least this fraction of total dark-red area
MIN_DARK_FRACTION   = 0.15  # 15%

# UI
WIN_NAME        = "Left: labels | Right: rails | Bottom: heat-map (SPACE=next, q=quit)"
MAX_DISPLAY_W   = 1600             # canvases wider than this are downscaled before display
TEXT_CLR        = (255, 255, 255)  # foreground colour of the on-screen status text

# Pipeline sizing
YOLO_BATCH      = 4    # try 2–8 depending on VRAM
PREFETCH        = 32   # frames buffered from disk
INFER_QUEUE     = 32   # frames waiting for inference (not referenced by the pipeline code in this script)
POST_QUEUE      = 32   # inference outputs waiting for postproc
POST_WORKERS    = max(2, (os.cpu_count() or 8) - 1)  # CPU worker threads

# =======================
# Class colours/names
# =======================
# Class id -> human-readable label drawn on the left pane.
obstacle_classes = {
    0: "BOOTS", 1: "GREYTRAIN", 2: "HIGHBARRIER1", 3: "JUMP", 4: "LOWBARRIER1",
    5: "LOWBARRIER2", 6: "ORANGETRAIN", 7: "PILLAR", 8: "RAMP", 9: "RAILS",
    10: "SIDEWALK", 11: "YELLOWTRAIN"
}
# Class id -> overlay colour. There is no entry for RAIL_ID (9): rails are
# skipped by the label overlay and drawn separately on the right pane.
CLASS_COLOURS = {
    0: (255,255,0), 1: (192,192,192), 2: (0,128,255), 3: (0,255,0),
    4: (255,0,255), 5: (0,255,255), 6: (255,128,0), 7: (128,0,255),
    8: (0,0,128), 10: (128,128,0), 11: (255,255,102)
}

# =======================
# Device / precision
# =======================
# Pick the inference device in order of preference: CUDA (device index 0,
# half precision), then Apple MPS, then CPU (both full precision).
if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    # Best-effort: the setter may be missing on older torch versions.
    try: torch.set_float32_matmul_precision('high')
    except Exception: pass
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    # getattr guard: torch builds without an `mps` backend attribute fall through to CPU
    device, half = "mps", False
else:
    device, half = "cpu", False

# Optional: reduce OpenCV thread contention with PyTorch (GPU-heavy workloads)
try:
    cv2.setNumThreads(1)  # our own thread pool does the parallelism
except Exception:
    pass

# =======================
# Model
# =======================
# Loaded once at import time and shared (as a module global) by the
# inference thread.
model = YOLO(weights)
# Layer fusion is attempted as a best-effort step; errors are ignored.
try: model.fuse()
except Exception: pass

# Warm up once so the first frame isn't laggy
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE, device=device,
                  conf=CONF, iou=IOU, verbose=False, half=half)

# =======================
# Helpers (logic unchanged)
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """Return a full-frame boolean mask of rail pixels that match a target colour.

    Work is restricted to the bounding box of ``rail_mask``. A pixel qualifies
    when it is inside ``rail_mask`` and within ``tolerance`` (Euclidean
    distance over the three channels) of any colour in ``target_colors_rgb``.
    Qualifying pixels are grouped into 8-connected components; only components
    with area >= ``min_region_size`` and bounding-box height >=
    ``min_region_height`` survive.

    Returns:
        H x W bool array; all False when ``rail_mask`` is empty.
    """
    frame_h, frame_w = img_bgr.shape[:2]
    kept_mask = np.zeros((frame_h, frame_w), dtype=bool)

    rows_hit = np.flatnonzero(rail_mask.any(axis=1))
    if rows_hit.size == 0:
        return kept_mask
    cols_hit = np.flatnonzero(rail_mask.any(axis=0))
    r0, r1 = rows_hit[0], rows_hit[-1] + 1
    c0, c1 = cols_hit[0], cols_hit[-1] + 1

    # Target colours arrive as RGB; the image is BGR, so reverse each triple.
    bgr_targets = np.array([(r, g, b)[::-1] for r, g, b in target_colors_rgb],
                           dtype=np.float32)
    roi = img_bgr[r0:r1, c0:c1].astype(np.float32)
    offsets = roi[:, :, np.newaxis, :] - bgr_targets[np.newaxis, np.newaxis, :, :]
    near_target = (np.square(offsets).sum(axis=-1) <= tolerance ** 2).any(axis=-1)

    candidate = near_target & rail_mask[r0:r1, c0:c1]
    _, label_img, stats, _ = cv2.connectedComponentsWithStats(
        candidate.astype(np.uint8), 8)

    # Row 0 of stats is the background, so the test starts at label 1.
    areas = stats[1:, cv2.CC_STAT_AREA]
    heights = stats[1:, cv2.CC_STAT_HEIGHT]
    passing = (areas >= min_region_size) & (heights >= min_region_height)
    kept_labels = 1 + np.flatnonzero(passing)

    kept_mask[r0:r1, c0:c1] = np.isin(label_img, kept_labels)
    return kept_mask

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """Build a uint8 map of how strongly "red" outweighs "green" around each pixel.

    Each mask is box-blurred (``blur_ksize`` x ``blur_ksize``); the blurred
    difference is scaled by its peak magnitude into 0..255, where 0 is all
    green, about 127 is balanced or empty, and 255 is all red.
    """
    kernel = (blur_ksize, blur_ksize)
    red_local = cv2.blur(red_mask.astype(np.float32), kernel)
    green_local = cv2.blur(green_mask.astype(np.float32), kernel)

    signed = red_local - green_local
    # The epsilon keeps the division defined when both masks are empty.
    largest = float(np.abs(signed).max()) + 1e-6
    unit = signed / (2 * largest) + 0.5
    return np.clip(unit * 255, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """Draw a filled, black-outlined triangle on ``img`` in place.

    The apex sits at (x, y); the base is ``int(size * 1.2)`` pixels lower and
    reaches ``size`` pixels left and right of ``x``.
    """
    base_y = y + int(size * 1.2)
    corners = np.array(
        [[x, y], [x - size, base_y], [x + size, base_y]],
        np.int32,
    )
    cv2.fillConvexPoly(img, corners, colour)
    cv2.polylines(img, [corners.reshape(-1, 1, 2)], True, (0, 0, 0), 1,
                  cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """Place LEFT|RIGHT side by side and HEAT below, fixing heat-width drift."""
    top_row = np.concatenate((left, right), axis=1)
    row_w = top_row.shape[1]
    heat_h, heat_w = heat.shape[:2]
    if heat_w == row_w:
        bottom_row = heat
    else:
        # Only the width is adjusted; the heat pane keeps its own height.
        bottom_row = cv2.resize(heat, (row_w, heat_h),
                                interpolation=cv2.INTER_LINEAR)
    return np.concatenate((top_row, bottom_row), axis=0)

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """Shrink ``img`` to ``max_w`` pixels wide (aspect ratio kept) if it is wider.

    An image no wider than ``max_w`` is returned as-is, without copying.
    """
    rows, cols = img.shape[:2]
    if cols > max_w:
        factor = max_w / float(cols)
        img = cv2.resize(img, (int(cols * factor), int(rows * factor)),
                         interpolation=cv2.INTER_AREA)
    return img

# =======================
# Post-processing: build left/right/heat for one frame
# (logic preserved; called from CPU worker threads)
# =======================
def build_panes_from_result(frame, res):
    """Turn one frame plus its YOLO result into the three display panes.

    Args:
        frame: H x W x 3 uint8 BGR image.
        res: the per-image result from ``model.predict`` for ``frame``, or
            ``None`` when no result is available.

    Returns:
        ``(left, right, heat_col)``, all uint8 BGR images:
          * ``left``  (H x W): non-rail masks coloured and labelled, plus a
            purple triangle at the top of each qualifying dark-red region.
          * ``right`` (H x W): rails tinted red, colour-matched rail pixels green.
          * ``heat_col`` (H x 2W): JET colour map of the red-vs-green score.
        With no result or no masks, ``left`` and ``right`` are plain copies of
        ``frame`` and the heat pane is a blank colour map.
    """
    H, W = frame.shape[:2]

    # ── NO DETECTIONS: left/right are W×H; heat must be (2W)×H to match top row
    if res is None or res.masks is None:
        left  = frame.copy()
        right = frame.copy()
        blank_heat = np.zeros((H, W), np.uint8)
        heat_col = cv2.applyColorMap(blank_heat, cv2.COLORMAP_JET)
        heat_col = cv2.resize(heat_col, (left.shape[1] + right.shape[1], H),
                              interpolation=cv2.INTER_LINEAR)
        return left, right, heat_col

    masks   = res.masks.data.cpu().numpy()   # (N, h_m, w_m)
    classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
               else res.boxes.cls.cpu().numpy())
    h_m, w_m = masks.shape[1:]

    # Union of every rail instance, at mask resolution.
    rail_union = np.zeros((h_m, w_m), bool)
    for m, c in zip(masks, classes):
        if int(c) == RAIL_ID:
            rail_union |= m.astype(bool)

    # NOTE(review): masks are stretched directly to the frame size; if the
    # model's mask tensor includes letterbox padding this shifts the overlay
    # slightly — confirm against the model's preprocessing.
    rail_mask = rail_union
    if rail_mask.shape != (H, W):
        rail_mask = cv2.resize(rail_mask.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)

    green = highlight_rails_mask_only_fast(frame, rail_mask,
                                           TARGET_COLORS_RGB, TOLERANCE,
                                           MIN_REGION_SIZE, MIN_REGION_HEIGHT)
    red   = rail_mask & ~green
    score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

    # Exclude top/bottom
    top_ex = int(H * EXCLUDE_TOP_FRAC)
    bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
    score[:top_ex, :] = 0
    if bot_ex:  # guard: score[H-0:] would be an empty slice, nothing to zero
        score[H-bot_ex:, :] = 0

    dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
    dark = cv2.morphologyEx(
        dark, cv2.MORPH_OPEN,
        cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
        iterations=1
    )

    # 15% fraction filter
    total_dark_area  = int(dark.sum())
    frac_area_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark_area))

    # LEFT: labels overlaid
    left    = frame.copy()
    overlay = left.copy()
    for m, c in zip(masks, classes):
        cid = int(c)
        if cid == RAIL_ID:
            continue
        mask_u8 = m.astype(np.uint8)
        if mask_u8.shape != (H, W):
            mask_u8 = cv2.resize(mask_u8, (W, H),
                                 interpolation=cv2.INTER_NEAREST)
        # Always convert to bool: the raw mask is a float array, and NumPy
        # rejects float arrays as indices when no resize was needed.
        mask_full = mask_u8.astype(bool)
        label  = obstacle_classes.get(cid, f"CLASS {cid}")
        colour = CLASS_COLOURS.get(cid, (255, 255, 255))
        overlay[mask_full] = colour

        ys, xs = np.where(mask_full)
        if len(xs):
            # Label near the mask centroid: black outline first, white text on top.
            x_c, y_c = int(xs.mean()), int(ys.mean())
            cv2.putText(overlay, label, (x_c - 40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
            cv2.putText(overlay, label, (x_c - 40, y_c),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 1, cv2.LINE_AA)
    left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

    # Purple triangle warnings
    n_lbl, lbl_mat, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
    for lbl in range(1, n_lbl):
        area = stats[lbl, cv2.CC_STAT_AREA]
        if area < MIN_DARK_RED_AREA or area < frac_area_thresh:
            continue
        ys, xs = np.where(lbl_mat == lbl)
        # Apex goes at the component's top row, centred on that row's pixels.
        y_top = ys.min()
        x_mid = int(xs[ys == y_top].mean())
        draw_triangle(left, x_mid, y_top)

    # RIGHT: rails tinted + green
    right = frame.copy()
    rails_tinted = right.copy()
    rails_tinted[rail_mask] = (0,0,255)
    right = cv2.addWeighted(rails_tinted, ALPHA, right, 1-ALPHA, 0)
    right[green] = (0,255,0)

    # HEAT
    heat_col = cv2.applyColorMap(score, cv2.COLORMAP_JET)
    # width must equal left.width + right.width
    heat_col = cv2.resize(heat_col, (left.shape[1] + right.shape[1], H),
                          interpolation=cv2.INTER_LINEAR)

    return left, right, heat_col

# =======================
# Pipeline threads
# =======================
class Prefetcher(threading.Thread):
    """Daemon thread that reads images from disk and feeds them to a queue.

    For every path it puts ``(index, path, image)`` on ``out_q``, where
    ``image`` is whatever ``cv2.imread`` returned. After the last path it puts
    a ``(None, None, None)`` sentinel.
    """

    def __init__(self, paths, out_q):
        super().__init__(daemon=True)
        self.paths = paths
        self.out_q = out_q

    def run(self):
        emit = self.out_q.put
        for position, image_path in enumerate(self.paths):
            emit((position, image_path, cv2.imread(image_path)))
        emit((None, None, None))  # end-of-stream marker

class InferenceWorker(threading.Thread):
    """Batched YOLO inference on GPU/MPS/CPU.

    Consumes ``(idx, path, frame)`` items from ``in_q`` until a sentinel
    (``idx is None``) arrives, and puts ``(idx, path, frame, result)`` on
    ``out_q`` for every consumed item, followed by exactly one
    ``(None, None, None, None)`` sentinel.

    Guarantees:
      * The sentinel is forwarded however the input stream ends. A sentinel
        picked up while topping up a batch ends the loop after that batch.
      * Every consumed frame is forwarded, so an in-order consumer never
        waits on a missing index. Unreadable frames (``frame is None``) are
        replaced by a black placeholder with ``result=None``. Frames from a
        batch whose inference failed are forwarded with ``result=None``.
    """
    def __init__(self, in_q, out_q, batch=YOLO_BATCH):
        super().__init__(daemon=True)
        self.in_q = in_q
        self.out_q = out_q
        self.batch = batch

    def run(self):
        try:
            end_of_stream = False
            while not end_of_stream:
                # Block for the first item of a batch ...
                first = self.in_q.get()
                if first[0] is None:
                    break
                pend = [first]
                # ... then top up with whatever is already queued, without waiting.
                while len(pend) < self.batch:
                    try:
                        it = self.in_q.get_nowait()
                    except queue.Empty:
                        break
                    if it[0] is None:
                        # Sentinel consumed here: finish this batch, then stop.
                        end_of_stream = True
                        break
                    pend.append(it)

                self._infer_and_forward(pend)
        finally:
            # Always signal downstream, even if this thread is failing.
            self.out_q.put((None, None, None, None))  # sentinel

    def _infer_and_forward(self, pend):
        """Run inference on the readable frames of ``pend`` and forward every item."""
        readable = [(i, f) for i, _, f in pend if f is not None]
        results = {}
        if readable:
            try:
                res_list = model.predict(
                    [f for _, f in readable], task="segment", imgsz=IMG_SIZE,
                    device=device, conf=CONF, iou=IOU, max_det=30,
                    verbose=False, half=half
                )
            except Exception as exc:
                # Thread boundary: report and fall back to "no result" for
                # this batch instead of killing the whole pipeline.
                print(f"[WARN] inference failed for batch of {len(readable)}: {exc}")
            else:
                # Ensure GPU kernels complete before posting (best-effort).
                try:
                    if device == 0 and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elif device == "mps" and torch.backends.mps.is_available():
                        torch.mps.synchronize()
                except Exception:
                    pass
                results = {i: r for (i, _), r in zip(readable, res_list)}

        for idx, p, f in pend:
            if f is None:
                print(f"[WARN] unreadable image: {p}")
                f = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
            self.out_q.put((idx, p, f, results.get(idx)))

class PostprocWorker(threading.Thread):
    """CPU worker: convert a YOLO result into the final display canvas.

    Consumes ``(idx, path, frame, result)`` from ``in_q`` and puts
    ``(idx, path, canvas)`` on ``out_q``. On receiving the sentinel
    (``idx is None``) it puts the sentinel back on ``in_q`` so that sibling
    workers sharing the queue also stop, then emits its own
    ``(None, None, None)`` on ``out_q`` and exits.

    A frame whose post-processing raises is reported and forwarded without
    overlays, so an in-order consumer is never left waiting for it.
    """
    def __init__(self, in_q, out_q):
        super().__init__(daemon=True)
        self.in_q = in_q
        self.out_q = out_q

    def run(self):
        while True:
            item = self.in_q.get()
            if item[0] is None:
                # The producer sends one sentinel; pass it on to the other
                # workers. The slot this get() freed keeps the put from blocking.
                self.in_q.put(item)
                break
            idx, p, frame, res = item
            try:
                # Build panes and assemble canvas
                left, right, heat = build_panes_from_result(frame, res)
                canvas = assemble_canvas(left, right, heat)
                canvas = maybe_downscale(canvas, MAX_DISPLAY_W)
            except Exception as exc:
                # Per-frame boundary: keep the worker alive and the ordering intact.
                print(f"[WARN] post-processing failed for {p}: {exc}")
                canvas = (frame if frame is not None
                          else np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8))
            self.out_q.put((idx, p, canvas))
        self.out_q.put((None, None, None))

# =======================
# Main
# =======================
if __name__ == "__main__":
    # Pipelined viewer: disk prefetch -> batched inference -> N post-processing
    # workers, with this (main) thread displaying frames strictly in order.
    folder = Path.home() / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"
    if not folder.is_dir():
        print(f"[ERROR] frames folder not found: {folder}", file=sys.stderr); sys.exit(1)

    # "*.jpg" / "*.png" already match every "frame_*" file; globbing both
    # patterns queued (and displayed) each such frame twice.
    paths = sorted(
        glob.glob(str(folder / "*.jpg")) +
        glob.glob(str(folder / "*.png"))
    )
    if not paths:
        print(f"[ERROR] No frame images found in {folder}", file=sys.stderr); sys.exit(1)

    # Queues (bounded, so slow display back-pressures the producers)
    prefetch_q = queue.Queue(maxsize=PREFETCH)
    post_q     = queue.Queue(maxsize=POST_QUEUE)
    display_q  = queue.Queue(maxsize=POST_QUEUE)

    # Threads: prefetch -> inference -> postproc (N workers)
    pf = Prefetcher(paths, prefetch_q); pf.start()
    inf = InferenceWorker(prefetch_q, post_q, batch=YOLO_BATCH); inf.start()

    post_workers = []
    for _ in range(POST_WORKERS):
        w = PostprocWorker(post_q, display_q)
        w.start(); post_workers.append(w)

    # UI loop: show frames IN ORDER; compute runs in background
    cv2.namedWindow(WIN_NAME, cv2.WINDOW_NORMAL)
    next_to_show = 0
    buffer = {}  # idx -> (path, canvas), for frames that finished out of order
    total = len(paths)
    quit_keys = (ord('q'), 27)  # q / ESC

    # Stop once every frame has been shown. Waiting for worker sentinels is
    # not required for correctness; all threads are daemons and end with the
    # process.
    while next_to_show < total:
        # Drain everything that is ready into the reorder buffer.
        while True:
            try:
                idx, p, canvas = display_q.get_nowait()
            except queue.Empty:
                break
            if idx is not None:  # worker-exit sentinels carry no frame
                buffer[idx] = (p, canvas)

        if next_to_show not in buffer:
            # Next frame not ready yet: keep pumping GUI events so the window
            # stays responsive and q/ESC still works while we wait.
            if (cv2.waitKey(5) & 0xFF) in quit_keys:
                cv2.destroyAllWindows(); sys.exit(0)
            continue

        p, canvas = buffer.pop(next_to_show)
        # On-screen text: show index + filename (black outline, then foreground)
        caption = f"{next_to_show+1}/{total}  |  {os.path.basename(p)}"
        cv2.putText(canvas, caption, (12, 28), cv2.FONT_HERSHEY_SIMPLEX,
                    0.8, (0,0,0), 3, cv2.LINE_AA)
        cv2.putText(canvas, caption, (12, 28), cv2.FONT_HERSHEY_SIMPLEX,
                    0.8, TEXT_CLR, 1, cv2.LINE_AA)

        cv2.imshow(WIN_NAME, canvas)
        while True:
            # Short wait instead of a 1 ms spin keeps CPU free for the workers.
            key = cv2.waitKey(20) & 0xFF

            if cv2.getWindowProperty(WIN_NAME, cv2.WND_PROP_VISIBLE) < 1:
                cv2.destroyAllWindows(); sys.exit(0)

            if key == 32:  # SPACE → next
                next_to_show += 1
                break
            if key in quit_keys:
                cv2.destroyAllWindows(); sys.exit(0)

    cv2.destroyAllWindows()


In [None]:
# Individual processing of frames and timer

In [None]:
#!/usr/bin/env python3
"""
Visualise *all* segmentation masks with labelled overlays — single-frame, on-demand.

WHAT THIS DOES
• Waits for SPACE, then:
  - loads exactly one frame from disk (no prefetch, no batching)
  - runs YOLO segmentation on GPU/MPS (or CPU fallback)
  - does all CPU post-processing (labels, rails, heat, purple-triangle ≥15%)
  - prints honest end-to-end time for that frame and overlays it on the canvas
• Uses all CPU cores (OpenCV + PyTorch thread hints) and your GPU/MPS.

Controls:
  SPACE: process next frame   |   q / ESC: quit

Frames folder:
    ~/Documents/GitHub/Ai-plays-SubwaySurfers/frames
"""

import os, glob, sys, time
import cv2, torch, numpy as np
from pathlib import Path
from ultralytics import YOLO

# =======================
# Config
# =======================
# NOTE(review): the frame-processing function of this script is only partly
# visible here; constants without a comment are presumably used the same way
# as in the earlier scripts — confirm against the rest of the file.
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"  # YOLO segmentation weights
frames_dir = Path(home) / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"

RAIL_ID    = 9
IMG_SIZE   = 512          # passed to model.predict as imgsz
CONF, IOU  = 0.30, 0.45   # passed to model.predict as conf / iou
ALPHA      = 0.40

TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18             # default `size` of draw_triangle
PURPLE              = (255, 0, 255)  # default `colour` of draw_triangle
MIN_DARK_FRACTION   = 0.15   # require ≥15% of total dark-red area

# UI settings
WIN_NAME      = "SPACE: next | q/ESC: quit"
MAX_DISPLAY_W = 1600   # default `max_w` of maybe_downscale
TEXT_CLR      = (255,255,255)

# =======================
# Maximize CPU/GPU utilization
# =======================
# Thread-count hints are best-effort: any failure is ignored so the script
# still runs with library defaults.
try:
    torch.set_num_threads(os.cpu_count() or 1)
except Exception:
    pass

try:
    # Leave 1 core free for the OS/UI
    OPENCV_THREADS = max(1, (os.cpu_count() or 1) - 1)
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# Enable best GPU/matmul mode if CUDA
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    try:
        # guarded: the setter may be missing on older torch versions
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass

# =======================
# Device & model init
# =======================
# Pick the inference device in order of preference: CUDA (device index 0,
# half precision), then Apple MPS, then CPU (both full precision).
if torch.cuda.is_available():
    device, half = 0, True
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    # getattr guard: torch builds without an `mps` backend attribute fall through to CPU
    device, half = "mps", False
else:
    device, half = "cpu", False

# Load the segmentation model at import time (module-level side effect).
model = YOLO(weights)
# Layer fusion is attempted as a best-effort step; errors are ignored.
try:
    model.fuse()
except Exception:
    pass

# Warmup once so the first real frame isn't penalized by lazy inits
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE,
                  device=device, conf=CONF, iou=IOU,
                  verbose=False, half=half)

# =======================
# Helper functions (logic unchanged)
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """Return a full-frame boolean mask of rail pixels that match a target colour.

    Steps, all limited to the bounding box of ``rail_mask``:
      1. Mark pixels whose Euclidean distance over the three channels to any
         colour in ``target_colors_rgb`` is <= ``tolerance``.
      2. Intersect with ``rail_mask``.
      3. Keep only 8-connected components with area >= ``min_region_size``
         and bounding-box height >= ``min_region_height``.

    Returns:
        H x W bool array; all False when ``rail_mask`` is empty.
    """
    n_rows, n_cols = img_bgr.shape[:2]
    selected = np.zeros((n_rows, n_cols), dtype=bool)

    used_rows = np.flatnonzero(rail_mask.any(axis=1))
    if used_rows.size == 0:
        return selected
    used_cols = np.flatnonzero(rail_mask.any(axis=0))
    box = (slice(used_rows[0], used_rows[-1] + 1),
           slice(used_cols[0], used_cols[-1] + 1))

    # Target colours are RGB; flip them to the image's BGR channel order.
    wanted_bgr = np.array([(r, g, b)[::-1] for r, g, b in target_colors_rgb],
                          dtype=np.float32)
    patch = img_bgr[box].astype(np.float32)
    gap = patch[:, :, np.newaxis, :] - wanted_bgr[np.newaxis, np.newaxis, :, :]
    gap_sq = np.square(gap).sum(axis=-1)
    colour_ok = (gap_sq <= tolerance ** 2).any(axis=-1)

    seeds = colour_ok & rail_mask[box]
    _, label_map, stats, _ = cv2.connectedComponentsWithStats(
        seeds.astype(np.uint8), 8)

    # One keep flag per label; indexing it by the label image filters every
    # pixel at once. Label 0 (background) is always dropped.
    big_enough = stats[:, cv2.CC_STAT_AREA] >= min_region_size
    tall_enough = stats[:, cv2.CC_STAT_HEIGHT] >= min_region_height
    label_kept = big_enough & tall_enough
    label_kept[0] = False

    selected[box] = label_kept[label_map]
    return selected

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """
    Turn the red and green masks into a smooth 0..255 "redness" map.

    Both masks are box-blurred with a `blur_ksize` x `blur_ksize` kernel and
    subtracted (red minus green). The difference is normalised by its largest
    absolute value so that 0 means "as green as it gets", 255 "as red as it
    gets", and a zero difference lands at 127.

    Returns:
        uint8 array with the same height/width as the input masks.
    """
    kernel = (blur_ksize, blur_ksize)
    red_density, green_density = (
        cv2.blur(mask.astype(np.float32), kernel)
        for mask in (red_mask, green_mask)
    )
    balance = red_density - green_density
    # The epsilon keeps the division finite when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2*peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """
    Draw a filled triangle on `img` in place, apex at (x, y).

    The base sits int(size * 1.2) pixels below the apex and spans `size`
    pixels to either side of x. A 1-px anti-aliased black outline is drawn
    over the fill.
    """
    drop = int(size * 1.2)
    apex = [x, y]
    base_left = [x - size, y + drop]
    base_right = [x + size, y + drop]
    corners = np.array([apex, base_left, base_right], np.int32)

    cv2.fillConvexPoly(img, corners, colour)
    outline = corners.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0,0,0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """
    Build the display canvas: LEFT|RIGHT side by side, HEAT underneath.

    `heat` is resized (bilinear) to the exact height and width of the top row
    whenever its size differs, so the vertical stack always lines up.
    """
    top_row = np.hstack((left, right))
    wanted_h, wanted_w = top_row.shape[:2]
    if tuple(heat.shape[:2]) != (wanted_h, wanted_w):
        # cv2.resize takes the target size as (width, height).
        heat = cv2.resize(heat, (wanted_w, wanted_h),
                          interpolation=cv2.INTER_LINEAR)
    return np.vstack((top_row, heat))

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """
    Shrink `img` to at most `max_w` pixels wide, keeping the aspect ratio.

    Images that are already narrow enough are returned unchanged (same
    object); wider ones are resized with INTER_AREA.
    """
    height, width = img.shape[:2]
    if width > max_w:
        ratio = max_w / float(width)
        target = (int(width*ratio), int(height*ratio))  # (width, height) for cv2
        img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    return img

# =======================
# Single-frame processing (called after SPACE)
# =======================
def process_one_frame(frame):
    """
    Process one frame end-to-end (inference + postproc) and return:
      canvas (BGR), processing_time_ms

    Pipeline:
      1. YOLO segmentation on `frame`, followed by a CUDA/MPS synchronize.
      2. Union of all RAIL_ID masks, resized to frame resolution.
      3. Rail pixels are split into "green" (colour-matched by
         highlight_rails_mask_only_fast) and "red" (the remaining rail
         pixels); both feed red_vs_green_score, whose top and bottom bands
         are then zeroed.
      4. LEFT panel: non-rail masks coloured and labelled, plus a purple
         triangle at the top of every dark-red blob that passes both the
         absolute (MIN_DARK_RED_AREA) and relative (MIN_DARK_FRACTION) area
         thresholds.
         RIGHT panel: rails tinted red, green pixels painted solid green.
         BOTTOM panel: JET heat-map of the score.

    Args:
        frame: BGR image array, (H, W, 3), e.g. as returned by cv2.imread.

    Returns:
        (canvas, ms): `canvas` is the assembled BGR image, downscaled to at
        most MAX_DISPLAY_W wide; `ms` is the elapsed time in milliseconds from
        just before inference to just after panel assembly (the timing text
        overlay and the final downscale are not included).
    """
    t0 = time.perf_counter()

    # YOLO inference (single image)
    res = model.predict(frame, task="segment", imgsz=IMG_SIZE,
                        device=device, conf=CONF, iou=IOU,
                        max_det=30, verbose=False, half=half)[0]

    # Ensure GPU/MPS kernels are done before timing postproc or reading tensors
    if device == 0 and torch.cuda.is_available():
        torch.cuda.synchronize()
    elif device == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        torch.mps.synchronize()

    H, W = frame.shape[:2]
    if res.masks is None:
        # Nothing was segmented: show the raw frame twice over an all-zero heat-map.
        left  = frame.copy()
        right = frame.copy()
        blank = np.zeros((H, W), np.uint8)
        heat  = cv2.applyColorMap(blank, cv2.COLORMAP_JET)
        heat  = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                           interpolation=cv2.INTER_LINEAR)
    else:
        masks = res.masks.data.cpu().numpy()
        # Class ids come from res.masks.cls when that attribute exists,
        # otherwise from res.boxes.cls.
        classes = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                   else res.boxes.cls.cpu().numpy()).astype(int)

        # Build rail mask: OR all rail instances at mask resolution, then
        # resize once to frame resolution.
        # NOTE(review): masks are stretched straight to (W, H); this assumes the
        # mask grid covers exactly the frame with no padding — confirm for
        # frames whose aspect ratio differs from the model input.
        h_m, w_m = masks.shape[1:]
        union = np.zeros((h_m, w_m), bool)
        for m, c in zip(masks, classes):
            if int(c) == RAIL_ID:
                union |= m.astype(bool)
        rail_mask = cv2.resize(union.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)

        # Green/red + heat
        green = highlight_rails_mask_only_fast(frame, rail_mask,
                                               TARGET_COLORS_RGB, TOLERANCE,
                                               MIN_REGION_SIZE, MIN_REGION_HEIGHT)
        red   = rail_mask & ~green
        score = red_vs_green_score(red, green)

        # Exclude bands: zero the score in the top and bottom strips of the frame
        top_ex = int(H * EXCLUDE_TOP_FRAC)
        bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
        score[:top_ex, :] = 0
        if bot_ex:
            score[H-bot_ex:, :] = 0

        # Binary "dark red" map, opened with a 5x9 rectangle to drop small specks.
        dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
        dark = cv2.morphologyEx(
            dark, cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
            iterations=1
        )

        # Relative threshold: a blob must hold at least MIN_DARK_FRACTION of
        # all dark-red pixels in this frame.
        total_dark = int(dark.sum())
        frac_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark))

        # LEFT with labels
        left = frame.copy()
        overlay = left.copy()
        # Local colour/name tables; class 9 (rails) has no colour because rails
        # are skipped in the loop below.
        CLASS_COLOURS = {
            0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
            4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
            8:(0,0,128),10:(128,128,0),11:(255,255,102)
        }
        obstacle_classes = {
            0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
            5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
            10:"SIDEWALK",11:"YELLOWTRAIN"
        }
        for m, c in zip(masks, classes):
            cid = int(c)
            if cid == RAIL_ID:
                continue
            # Always end up with a boolean (H, W) mask. Previously a mask that
            # already had frame resolution skipped the conversion and stayed a
            # float array, which NumPy rejects as an index.
            mask_u8 = m.astype(np.uint8)
            if mask_u8.shape != (H, W):
                mask_u8 = cv2.resize(mask_u8, (W, H),
                                     interpolation=cv2.INTER_NEAREST)
            mask_full = mask_u8.astype(bool)
            colour = CLASS_COLOURS.get(cid, (255,255,255))
            overlay[mask_full] = colour

            # Label at the mask centroid: black outline first, then text colour.
            ys, xs = np.where(mask_full)
            if len(xs):
                xc, yc = int(xs.mean()), int(ys.mean())
                label = obstacle_classes.get(cid, f"CLASS {cid}")
                cv2.putText(overlay, label, (xc-40, yc),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
                cv2.putText(overlay, label, (xc-40, yc),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, TEXT_CLR, 1, cv2.LINE_AA)

        left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

        # Purple triangles subject to absolute+fraction area thresholds
        n_lbl, lbls, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
        for lbl in range(1, n_lbl):
            area = stats[lbl, cv2.CC_STAT_AREA]
            if area < MIN_DARK_RED_AREA or area < frac_thresh:
                continue
            # Apex goes on the blob's topmost row, at the mean x of that row.
            ys, xs = np.where(lbls == lbl)
            y_top = ys.min()
            x_mid = int(xs[ys == y_top].mean())
            draw_triangle(left, x_mid, y_top)

        # RIGHT rails + green
        right = frame.copy()
        tint = right.copy()
        tint[rail_mask] = (0,0,255)
        right = cv2.addWeighted(tint, ALPHA, right, 1-ALPHA, 0)
        right[green] = (0,255,0)

        # HEAT (width must be left+right)
        heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
        heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                          interpolation=cv2.INTER_LINEAR)

    # Assemble + annotate
    canvas = assemble_canvas(left, right, heat)
    ms = (time.perf_counter() - t0) * 1000.0
    tag = f"{ms:.1f} ms"
    cv2.putText(canvas, tag, (12, 28),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0,0,0), 3, cv2.LINE_AA)
    cv2.putText(canvas, tag, (12, 28),
                cv2.FONT_HERSHEY_SIMPLEX, 0.8, TEXT_CLR, 1, cv2.LINE_AA)
    canvas = maybe_downscale(canvas, MAX_DISPLAY_W)
    return canvas, ms

# =======================
# Main loop (no batching, no precompute)
# =======================
if __name__=="__main__":
    # Interactive viewer: each SPACE press loads, processes and shows the next
    # frame; q or ESC quits.
    if not frames_dir.is_dir():
        print(f"[ERROR] Frames folder not found: {frames_dir}", file=sys.stderr)
        sys.exit(1)

    # Only *.png frames are collected; the other patterns are commented out.
    paths = sorted(#glob.glob(str(frames_dir/"frame_*.jpg")) +
                   #glob.glob(str(frames_dir/"frame_*.png")) +
                   #glob.glob(str(frames_dir/"*.jpg")) +
                   glob.glob(str(frames_dir/"*.png")))
    if not paths:
        print("[ERROR] No frame images found.", file=sys.stderr)
        sys.exit(1)

    cv2.namedWindow(WIN_NAME, cv2.WINDOW_NORMAL)
    total = len(paths)
    idx = 0
    stopped_early = False  # True when the user quit or closed the window

    # Show an instruction screen
    canvas0 = np.zeros((320, 960, 3), np.uint8)
    cv2.putText(canvas0, "Press SPACE to process next frame   |   q/ESC to quit",
                (20, 170), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (255,255,255), 2, cv2.LINE_AA)
    cv2.imshow(WIN_NAME, canvas0)

    while idx < total:
        key = cv2.waitKey(0) & 0xFF
        if key in (ord('q'), 27):
            stopped_early = True
            break
        if key != 32:  # not SPACE
            continue

        p = paths[idx]
        frame = cv2.imread(p)
        if frame is None:
            print(f"[WARN] unreadable image: {p}")
            idx += 1
            continue

        canvas, ms = process_one_frame(frame)
        print(f"[{idx+1}/{total}] {os.path.basename(p)}  →  {ms:.1f} ms")
        # filename overlay (black outline, then text colour on top)
        caption = f"{idx+1}/{total}  |  {os.path.basename(p)}"
        cv2.putText(canvas, caption,
                    (12, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0,0,0), 3, cv2.LINE_AA)
        cv2.putText(canvas, caption,
                    (12, 60), cv2.FONT_HERSHEY_SIMPLEX, 0.7, TEXT_CLR, 1, cv2.LINE_AA)

        cv2.imshow(WIN_NAME, canvas)
        idx += 1

        # If window closed manually, exit cleanly
        if cv2.getWindowProperty(WIN_NAME, cv2.WND_PROP_VISIBLE) < 1:
            stopped_early = True
            break

    # imshow() is only painted once waitKey() pumps the GUI event loop. Without
    # this wait the loop would fall through after the last frame and destroy
    # the window before that frame was ever drawn. Any key (or closing the
    # window) ends the wait.
    if not stopped_early:
        while cv2.getWindowProperty(WIN_NAME, cv2.WND_PROP_VISIBLE) >= 1:
            if cv2.waitKey(50) != -1:
                break

    cv2.destroyAllWindows()


In [None]:
#First 20 frames, ready for ground logic

In [None]:
#!/usr/bin/env python3
"""
Notebook inline visualiser — single-frame processing with *granular* timing.

What it does
• Processes the first N frames (no prefetch, no batching, no UI window).
• For each frame:
  - YOLO segmentation on GPU/MPS (or CPU)
  - CPU post-processing (labels, rails, heat, purple-triangle ≥15%)
  - Prints a compact per-frame timing breakdown (read, infer, sync, to_cpu, post, assemble)
  - 'proc_total' = inference + sync + to_cpu + post + assemble (no display/printing)
• Shows the image inline (when running in a Jupyter notebook or VS Code). Display and printing are excluded from the timings.

Frames folder:
    ~/Documents/GitHub/Ai-plays-SubwaySurfers/frames
"""

import os, glob, sys, time
import cv2, torch, numpy as np
from pathlib import Path
from ultralytics import YOLO

# Inline display (not part of timing)
try:
    from IPython.display import display
    from PIL import Image
    _HAS_IPY = True
except Exception:
    _HAS_IPY = False

# =======================
# Config
# =======================
# Paths: model weights and the folder of frame images, both under the user's home.
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
frames_dir = Path(home) / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"

SHOW_FIRST_N = 20  # process only first N frames; set None to do all

# Model / inference settings.
# RAIL_ID is the class id treated as rails; IMG_SIZE, CONF and IOU are passed
# to model.predict as imgsz / conf / iou. ALPHA is the blend weight of the red
# rail tint on the right-hand panel.
RAIL_ID    = 9
IMG_SIZE   = 512
CONF, IOU  = 0.30, 0.45
ALPHA      = 0.40

# Arguments for highlight_rails_mask_only_fast: reference colours (RGB), the
# maximum colour distance, and the minimum area / height (px) of a kept region.
TARGET_COLORS_RGB  = [(119,104,67), (81,42,45)]
TOLERANCE          = 20.0
MIN_REGION_SIZE    = 30
MIN_REGION_HEIGHT  = 150

# Heat-map and purple-triangle settings.
# HEAT_BLUR_KSIZE: box-blur kernel size passed to red_vs_green_score.
# RED_SCORE_THRESH: score (0..255) at or above which a pixel counts as dark red.
# EXCLUDE_*_FRAC: fraction of the frame height zeroed at the top / bottom of the score.
# MIN_DARK_RED_AREA: minimum dark-red component area (px) to earn a triangle.
# TRI_SIZE_PX / PURPLE: default size and colour used by draw_triangle.
HEAT_BLUR_KSIZE     = 51
RED_SCORE_THRESH    = 220
EXCLUDE_TOP_FRAC    = 0.40
EXCLUDE_BOTTOM_FRAC = 0.15
MIN_DARK_RED_AREA   = 1200
TRI_SIZE_PX         = 18
PURPLE              = (255, 0, 255)
MIN_DARK_FRACTION   = 0.15   # ≥15% of total dark-red area

# Display overlay settings
# MAX_DISPLAY_W caps the canvas width (maybe_downscale); TEXT_CLR is the
# colour of overlay text.
MAX_DISPLAY_W = 1200
TEXT_CLR      = (255,255,255)

# =======================
# Maximize CPU/GPU utilization (thread hints)
# =======================
# Ask torch to use one thread per reported CPU (falls back to 1 when
# os.cpu_count() returns None). Best-effort: failures are ignored.
try:
    torch.set_num_threads(os.cpu_count() or 1)
except Exception:
    pass

# Same for OpenCV, but one fewer thread (never below 1). Best-effort as well.
try:
    # Leave 1 core for OS/UI
    OPENCV_THREADS = max(1, (os.cpu_count() or 1) - 1)
    cv2.setNumThreads(OPENCV_THREADS)
except Exception:
    pass

# CUDA backend optimizations
# Only applied when CUDA is available: enable cuDNN benchmark mode and request
# 'high' float32 matmul precision; the latter call is best-effort and any
# exception it raises is ignored.
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    try:
        torch.set_float32_matmul_precision('high')
    except Exception:
        pass

# =======================
# Device & model init
# =======================
# Pick the inference device in priority order: CUDA device 0, then Apple MPS,
# then CPU. `half` is forwarded to model.predict(half=...) and is only switched
# on for the CUDA path.
if torch.cuda.is_available():
    device, half = 0, True
# getattr() guards against torch builds where torch.backends has no `mps` attribute.
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

model = YOLO(weights)
# Best-effort: any exception raised by fuse() is ignored and the model is used as loaded.
try:
    model.fuse()
except Exception:
    pass

# Warmup to avoid first-frame penalty (not timed)
# The dummy input is an all-black IMG_SIZE x IMG_SIZE 3-channel image; the
# prediction result is discarded.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE,
                  device=device, conf=CONF, iou=IOU,
                  verbose=False, half=half)

# =======================
# Helpers (logic unchanged)
# =======================
def highlight_rails_mask_only_fast(img_bgr, rail_mask, target_colors_rgb,
                                   tolerance=30.0, min_region_size=50,
                                   min_region_height=150):
    """
    Select rail pixels whose colour is close to one of the target colours.

    Only the bounding box of `rail_mask` is examined. A pixel is a candidate
    when it lies inside the rail mask and its BGR value is within `tolerance`
    (Euclidean distance) of any colour in `target_colors_rgb`. Candidates are
    grouped into 8-connected components; a component is kept only when it has
    at least `min_region_size` pixels and its bounding box is at least
    `min_region_height` rows tall.

    Args:
        img_bgr: image array indexed (row, col, channel), channels in BGR order.
        rail_mask: mask with the same height/width as `img_bgr`.
        target_colors_rgb: iterable of (r, g, b) reference colours.
        tolerance: maximum colour distance for a pixel to count as a hit.
        min_region_size: minimum component area in pixels.
        min_region_height: minimum component bounding-box height in pixels.

    Returns:
        Boolean array of shape (H, W), True on the kept pixels. All False when
        `rail_mask` is empty.
    """
    height, width = img_bgr.shape[:2]
    result = np.zeros((height, width), dtype=bool)
    if not rail_mask.any():
        return result

    # Restrict all per-pixel work to the rail mask's bounding box.
    rows, cols = np.where(rail_mask)
    window = (slice(rows.min(), rows.max() + 1),
              slice(cols.min(), cols.max() + 1))

    # Targets arrive as RGB; flip each one to BGR to match the image layout.
    palette = np.array([(b, g, r) for r, g, b in target_colors_rgb],
                       dtype=np.float32)
    pixels = img_bgr[window].astype(np.float32)
    delta = pixels[:, :, None, :] - palette[None, None, :, :]
    sq_dist = np.sum(delta * delta, axis=-1)
    near_target = np.any(sq_dist <= tolerance**2, axis=-1)

    candidates = near_target & rail_mask[window]
    n_labels, label_img, stats, _ = cv2.connectedComponentsWithStats(
        candidates.astype(np.uint8), 8)

    # One keep-flag per label; label 0 (background) is never kept.
    keep = np.zeros(n_labels, dtype=bool)
    keep[1:] = ((stats[1:, cv2.CC_STAT_AREA] >= min_region_size)
                & (stats[1:, cv2.CC_STAT_HEIGHT] >= min_region_height))

    result[window] = keep[label_img]
    return result

def red_vs_green_score(red_mask, green_mask, blur_ksize=51):
    """
    Turn the red and green masks into a smooth 0..255 "redness" map.

    Both masks are box-blurred with a `blur_ksize` x `blur_ksize` kernel and
    subtracted (red minus green). The difference is normalised by its largest
    absolute value so that 0 means "as green as it gets", 255 "as red as it
    gets", and a zero difference lands at 127.

    Returns:
        uint8 array with the same height/width as the input masks.
    """
    kernel = (blur_ksize, blur_ksize)
    red_density, green_density = (
        cv2.blur(mask.astype(np.float32), kernel)
        for mask in (red_mask, green_mask)
    )
    balance = red_density - green_density
    # The epsilon keeps the division finite when both masks are empty.
    peak = float(np.max(np.abs(balance))) + 1e-6
    scaled = (balance / (2*peak) + 0.5) * 255
    return np.clip(scaled, 0, 255).astype(np.uint8)

def draw_triangle(img, x, y, size=TRI_SIZE_PX, colour=PURPLE):
    """
    Draw a filled triangle on `img` in place, apex at (x, y).

    The base sits int(size * 1.2) pixels below the apex and spans `size`
    pixels to either side of x. A 1-px anti-aliased black outline is drawn
    over the fill.
    """
    drop = int(size * 1.2)
    apex = [x, y]
    base_left = [x - size, y + drop]
    base_right = [x + size, y + drop]
    corners = np.array([apex, base_left, base_right], np.int32)

    cv2.fillConvexPoly(img, corners, colour)
    outline = corners.reshape(-1, 1, 2)
    cv2.polylines(img, [outline], True, (0,0,0), 1, cv2.LINE_AA)

def assemble_canvas(left, right, heat):
    """
    Build the display canvas: LEFT|RIGHT side by side, HEAT underneath.

    `heat` is resized (bilinear) to the exact height and width of the top row
    whenever its size differs, so the vertical stack always lines up.
    """
    top_row = np.hstack((left, right))
    wanted_h, wanted_w = top_row.shape[:2]
    if tuple(heat.shape[:2]) != (wanted_h, wanted_w):
        # cv2.resize takes the target size as (width, height).
        heat = cv2.resize(heat, (wanted_w, wanted_h),
                          interpolation=cv2.INTER_LINEAR)
    return np.vstack((top_row, heat))

def maybe_downscale(img, max_w=MAX_DISPLAY_W):
    """
    Shrink `img` to at most `max_w` pixels wide, keeping the aspect ratio.

    Images that are already narrow enough are returned unchanged (same
    object); wider ones are resized with INTER_AREA.
    """
    height, width = img.shape[:2]
    if width > max_w:
        ratio = max_w / float(width)
        target = (int(width*ratio), int(height*ratio))  # (width, height) for cv2
        img = cv2.resize(img, target, interpolation=cv2.INTER_AREA)
    return img

# =======================
# Per-frame processing with segmented timing
# =======================
def process_one_frame_with_timers(frame):
    """
    Run the full per-frame pipeline and time each stage separately.

    Stages (each timed with time.perf_counter):
      infer    - model.predict on the single frame
      sync     - CUDA/MPS synchronize (best-effort; exceptions are ignored)
      to_cpu   - copying mask and class tensors to NumPy
      post     - rail union, green/red split, heat score, labels, triangles,
                 and building the LEFT / RIGHT / HEAT panels
      assemble - stacking the panels and downscaling to MAX_DISPLAY_W

    Args:
        frame: BGR image array, (H, W, 3), e.g. as returned by cv2.imread.

    Returns (canvas_bgr, timings_dict)

    timings_dict fields (ms):
      infer, sync, to_cpu, post, assemble, proc_total
    where proc_total is the sum of the five stage times.
    """
    t0_infer = time.perf_counter()

    # YOLO inference (single image)
    res = model.predict(frame, task="segment", imgsz=IMG_SIZE,
                        device=device, conf=CONF, iou=IOU,
                        max_det=30, verbose=False, half=half)[0]
    t1_infer = time.perf_counter()

    # Ensure GPU/MPS kernels complete before using results
    t0_sync = time.perf_counter()
    try:
        if device == 0 and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif device == "mps" and getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
            torch.mps.synchronize()
    except Exception:
        pass
    t1_sync = time.perf_counter()

    # Move tensors to CPU (measured separately)
    t0_to_cpu = time.perf_counter()
    if res.masks is None:
        masks_np, classes_np = None, None
    else:
        masks_np = res.masks.data.cpu().numpy()
        # Class ids come from res.masks.cls when that attribute exists,
        # otherwise from res.boxes.cls.
        classes_np = (res.masks.cls.cpu().numpy() if hasattr(res.masks, "cls")
                      else res.boxes.cls.cpu().numpy()).astype(int)
    t1_to_cpu = time.perf_counter()

    # Post-processing
    t0_post = time.perf_counter()
    H, W = frame.shape[:2]
    if masks_np is None:
        # Nothing was segmented: show the raw frame twice over an all-zero heat-map.
        left  = frame.copy()
        right = frame.copy()
        blank = np.zeros((H, W), np.uint8)
        heat  = cv2.applyColorMap(blank, cv2.COLORMAP_JET)
        heat  = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                           interpolation=cv2.INTER_LINEAR)
    else:
        masks   = masks_np
        classes = classes_np

        # Rail union at model mask res → upsample to frame res
        # NOTE(review): masks are stretched straight to (W, H); this assumes the
        # mask grid covers exactly the frame with no padding — confirm for
        # frames whose aspect ratio differs from the model input.
        h_m, w_m = masks.shape[1:]
        union = np.zeros((h_m, w_m), bool)
        for m, c in zip(masks, classes):
            if int(c) == RAIL_ID:
                union |= m.astype(bool)
        rail_mask = cv2.resize(union.astype(np.uint8), (W, H),
                               interpolation=cv2.INTER_NEAREST).astype(bool)

        # Green/red + heat precursor
        green = highlight_rails_mask_only_fast(frame, rail_mask,
                                               TARGET_COLORS_RGB, TOLERANCE,
                                               MIN_REGION_SIZE, MIN_REGION_HEIGHT)
        red   = rail_mask & ~green
        score = red_vs_green_score(red, green, HEAT_BLUR_KSIZE)

        # Exclude top/bottom bands
        top_ex = int(H * EXCLUDE_TOP_FRAC)
        bot_ex = int(H * EXCLUDE_BOTTOM_FRAC)
        score[:top_ex, :] = 0
        if bot_ex:
            score[H-bot_ex:, :] = 0

        # Binary "dark red" map, opened with a 5x9 rectangle to drop small specks.
        dark = (score >= RED_SCORE_THRESH).astype(np.uint8)
        dark = cv2.morphologyEx(
            dark, cv2.MORPH_OPEN,
            cv2.getStructuringElement(cv2.MORPH_RECT, (5, 9)),
            iterations=1
        )

        # Relative threshold: a blob must hold at least MIN_DARK_FRACTION of
        # all dark-red pixels in this frame.
        total_dark  = int(dark.sum())
        frac_thresh = int(np.ceil(MIN_DARK_FRACTION * total_dark))

        # LEFT with labels
        left    = frame.copy()
        overlay = left.copy()
        # Local colour/name tables; class 9 (rails) has no colour because rails
        # are skipped in the loop below.
        CLASS_COLOURS = {
            0:(255,255,0),1:(192,192,192),2:(0,128,255),3:(0,255,0),
            4:(255,0,255),5:(0,255,255),6:(255,128,0),7:(128,0,255),
            8:(0,0,128),10:(128,128,0),11:(255,255,102)
        }
        obstacle_classes = {
            0:"BOOTS",1:"GREYTRAIN",2:"HIGHBARRIER1",3:"JUMP",4:"LOWBARRIER1",
            5:"LOWBARRIER2",6:"ORANGETRAIN",7:"PILLAR",8:"RAMP",9:"RAILS",
            10:"SIDEWALK",11:"YELLOWTRAIN"
        }
        for m, c in zip(masks, classes):
            cid = int(c)
            if cid == RAIL_ID:
                continue
            # Always end up with a boolean (H, W) mask. Previously a mask that
            # already had frame resolution skipped the conversion and stayed a
            # float array, which NumPy rejects as an index.
            mask_u8 = m.astype(np.uint8)
            if mask_u8.shape != (H, W):
                mask_u8 = cv2.resize(mask_u8, (W, H),
                                     interpolation=cv2.INTER_NEAREST)
            mask_full = mask_u8.astype(bool)
            overlay[mask_full] = CLASS_COLOURS.get(cid, (255,255,255))

            # Label at the mask centroid: black outline first, then text colour.
            ys, xs = np.where(mask_full)
            if len(xs):
                xc, yc = int(xs.mean()), int(ys.mean())
                label = obstacle_classes.get(cid, f"CLASS {cid}")
                cv2.putText(overlay, label, (xc-40, yc),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0,0,0), 2, cv2.LINE_AA)
                cv2.putText(overlay, label, (xc-40, yc),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, TEXT_CLR, 1, cv2.LINE_AA)

        left = cv2.addWeighted(overlay, 0.6, left, 0.4, 0)

        # Purple triangles: absolute + fraction thresholds
        n_lbl, lbls, stats, _ = cv2.connectedComponentsWithStats(dark, 8)
        for lbl in range(1, n_lbl):
            area = stats[lbl, cv2.CC_STAT_AREA]
            if area < MIN_DARK_RED_AREA or area < frac_thresh:
                continue
            # Apex goes on the blob's topmost row, at the mean x of that row.
            ys, xs = np.where(lbls == lbl)
            y_top = ys.min()
            x_mid = int(xs[ys == y_top].mean())
            draw_triangle(left, x_mid, y_top)

        # RIGHT rails + green
        right = frame.copy()
        tint = right.copy()
        tint[rail_mask] = (0,0,255)
        right = cv2.addWeighted(tint, ALPHA, right, 1-ALPHA, 0)
        right[green] = (0,255,0)

        # HEAT (width = left + right)
        heat = cv2.applyColorMap(score, cv2.COLORMAP_JET)
        heat = cv2.resize(heat, (left.shape[1] + right.shape[1], H),
                          interpolation=cv2.INTER_LINEAR)

    t1_post = time.perf_counter()

    # Assemble + downscale
    t0_assemble = time.perf_counter()
    canvas = assemble_canvas(left, right, heat)
    canvas = maybe_downscale(canvas, MAX_DISPLAY_W)
    t1_assemble = time.perf_counter()

    # Compose timing dict (all values in milliseconds)
    infer_ms    = (t1_infer    - t0_infer)    * 1000.0
    sync_ms     = (t1_sync     - t0_sync)     * 1000.0
    to_cpu_ms   = (t1_to_cpu   - t0_to_cpu)   * 1000.0
    post_ms     = (t1_post     - t0_post)     * 1000.0
    assemble_ms = (t1_assemble - t0_assemble) * 1000.0
    proc_total  = infer_ms + sync_ms + to_cpu_ms + post_ms + assemble_ms

    timings = {
        "infer": infer_ms,
        "sync": sync_ms,
        "to_cpu": to_cpu_ms,
        "post": post_ms,
        "assemble": assemble_ms,
        "proc_total": proc_total
    }
    return canvas, timings

def tiny_line(idx, total, fname, read_ms, t):
    """
    Format the one-line, fixed-order timing report for a frame.

    `t` is a timings dict with the keys infer, sync, to_cpu, post, assemble
    and proc_total (all in ms); `read_ms` is the disk-read time. Every value
    is printed with one decimal place.
    """
    stages = (
        ("read", read_ms),
        ("infer", t["infer"]),
        ("sync", t["sync"]),
        ("to_cpu", t["to_cpu"]),
        ("post", t["post"]),
        ("asm", t["assemble"]),
    )
    breakdown = " | ".join(f"{name} {value:.1f}" for name, value in stages)
    return f"[{idx}/{total}] {fname}  {breakdown} => proc {t['proc_total']:.1f} ms"

# =======================
# Main (first N frames, inline prints)
# =======================
if __name__=="__main__":
    # Process the first `total` frames one by one, show each canvas inline when
    # IPython/PIL are available, and print a one-line timing breakdown.
    if not frames_dir.is_dir():
        print(f"[ERROR] Frames folder not found: {frames_dir}", file=sys.stderr)
        sys.exit(1)

    # Gather frames once (I/O not included in processing timing).
    # The patterns overlap, so duplicates are removed via set() before sorting.
    patterns = ("frame_*.jpg", "frame_*.png", "*.jpg", "*.png")
    paths_raw = [hit for pat in patterns
                 for hit in glob.glob(str(frames_dir/pat))]
    paths = sorted(set(paths_raw))
    if not paths:
        print("[ERROR] No frame images found.", file=sys.stderr)
        sys.exit(1)

    if SHOW_FIRST_N is None:
        total = len(paths)
    else:
        total = min(SHOW_FIRST_N, len(paths))

    for position in range(1, total + 1):
        p = paths[position - 1]
        fname = os.path.basename(p)

        # --- disk I/O timing (for context only) ---
        t0_read = time.perf_counter()
        frame = cv2.imread(p)
        t1_read = time.perf_counter()
        read_ms = (t1_read - t0_read) * 1000.0
        if frame is None:
            print(f"[warn] unreadable: {fname}")
            continue

        # --- processing timing (excludes any display/printing) ---
        canvas, t = process_one_frame_with_timers(frame)

        # Two captions, each drawn as a black outline with the text colour on
        # top: the processing time (proc_total), then the frame index and name.
        captions = (
            (f"{t['proc_total']:.1f} ms", (12, 28), 0.8, 3),
            (f"{position}/{total} | {fname}", (12, 60), 0.6, 2),
        )
        for text, origin, scale, outline in captions:
            cv2.putText(canvas, text, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, scale, (0,0,0), outline, cv2.LINE_AA)
            cv2.putText(canvas, text, origin,
                        cv2.FONT_HERSHEY_SIMPLEX, scale, TEXT_CLR, 1, cv2.LINE_AA)

        # --- display + tiny print (not timed) ---
        if _HAS_IPY:
            rgb = cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)
            display(Image.fromarray(rgb))
        print(tiny_line(position, total, fname, read_ms, t))

    if _HAS_IPY and total < len(paths):
        print(f"\nDone. Shown first {total} of {len(paths)} frames.")


In [None]:
#!/usr/bin/env python3
"""
Pure inference timing on images in ~/Documents/GitHub/Ai-plays-SubwaySurfers/frames

What we time:
  start = just before model.predict(...)
  end   = after GPU/MPS sync completes
  => inference-only latency per image

No post-processing, no display. Small prints.
"""

import os, glob, sys, time, statistics
import cv2, torch, numpy as np
from pathlib import Path
from ultralytics import YOLO

# ---------- Config ----------
# Paths: model weights and the folder of frame images, both under the user's home.
home       = os.path.expanduser("~")
weights    = f"{home}/models/jakes-loped/jakes-finder-mk1/1/weights.pt"
frames_dir = Path(home) / "Documents" / "GitHub" / "Ai-plays-SubwaySurfers" / "frames"

# Inference settings, passed to model.predict as imgsz / conf / iou / max_det.
IMG_SIZE   = 512
CONF, IOU  = 0.30, 0.45
MAX_DET    = 30
WARMUP     = 5   # number of frames to run but not include in stats

# ---------- Device & perf hints ----------
# Pick the inference device in priority order: CUDA device 0, then Apple MPS,
# then CPU. `half` is forwarded to model.predict(half=...) and is only switched
# on for the CUDA path, which also enables cuDNN benchmark mode and requests
# 'high' float32 matmul precision (best-effort: failures are ignored).
if torch.cuda.is_available():
    device, half = 0, True
    torch.backends.cudnn.benchmark = True
    try: torch.set_float32_matmul_precision('high')
    except Exception: pass
# getattr() guards against torch builds where torch.backends has no `mps` attribute.
elif getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
    device, half = "mps", False
else:
    device, half = "cpu", False

# Use all CPU threads for any host-side work (tiny effect here, but harmless)
# Falls back to 1 thread when os.cpu_count() returns None; failures are ignored.
try: torch.set_num_threads(os.cpu_count() or 1)
except Exception: pass

# ---------- Load model ----------
model = YOLO(weights)
# Best-effort: any exception raised by fuse() is ignored and the model is used as loaded.
try: model.fuse()
except Exception: pass

# Warmup once (not timed)
# The dummy input is an all-black IMG_SIZE x IMG_SIZE 3-channel image; the
# prediction result is discarded. This is separate from the WARMUP frames that
# the timing loop runs but leaves out of the statistics.
_dummy = np.zeros((IMG_SIZE, IMG_SIZE, 3), np.uint8)
_ = model.predict(_dummy, task="segment", imgsz=IMG_SIZE,
                  device=device, conf=CONF, iou=IOU,
                  verbose=False, half=half)

# ---------- Collect frames ----------
# Abort early when the frames folder is missing.
if not frames_dir.is_dir():
    print(f"[ERROR] Frames folder not found: {frames_dir}", file=sys.stderr)
    sys.exit(1)

# The glob patterns overlap, so hits are collected into a set before sorting.
paths = sorted({
    hit
    for pattern in ("frame_*.jpg", "frame_*.png", "*.jpg", "*.png")
    for hit in glob.glob(str(frames_dir/pattern))
})
if not paths:
    print("[ERROR] No frame images found.", file=sys.stderr)
    sys.exit(1)

# ---------- Run inference timing ----------
# Time model.predict (+ device sync) for every readable frame. The first
# WARMUP frames that actually run are printed but left out of the statistics.
times_ms = []
n_total = len(paths)
# Frames actually pushed through the model. Warmup is keyed to this counter
# rather than the path index, so unreadable files among the first paths no
# longer reduce the number of warmup runs.
n_run = 0

for i, p in enumerate(paths, start=1):
    img = cv2.imread(p)
    if img is None:
        print(f"[warn] unreadable: {os.path.basename(p)}")
        continue

    t0 = time.perf_counter()
    res = model.predict(
        img,
        task="segment",
        imgsz=IMG_SIZE,
        device=device,
        conf=CONF,
        iou=IOU,
        max_det=MAX_DET,
        verbose=False,
        half=half
    )
    # Ensure kernels complete before stopping the clock
    # (best-effort: a failing synchronize is ignored).
    try:
        if device == 0 and torch.cuda.is_available():
            torch.cuda.synchronize()
        elif device == "mps" and torch.backends.mps.is_available():
            torch.mps.synchronize()
    except Exception:
        pass

    t_ms = (time.perf_counter() - t0) * 1000.0
    n_run += 1
    # Skip counting the first WARMUP frames in stats, but still show a print
    if n_run > WARMUP:
        times_ms.append(t_ms)
        print(f"[{i}/{n_total}] {os.path.basename(p)} → {t_ms:.1f} ms")
    else:
        print(f"[warmup {n_run}/{WARMUP}] {os.path.basename(p)} → {t_ms:.1f} ms")

# ---------- Summary ----------
# Report mean / median / min / max latency over the timed (non-warmup) frames.
if not times_ms:
    print("\nNo frames were timed (all were warmup or unreadable).")
else:
    avg = statistics.mean(times_ms)
    med = statistics.median(times_ms)
    report = [
        "\n--- Inference-only summary (warmup skipped) ---",
        f"Frames timed : {len(times_ms)}",
        f"Average      : {avg:.1f} ms  ({1000.0/avg:.2f} FPS)",
        f"Median       : {med:.1f} ms",
        f"Fastest      : {min(times_ms):.1f} ms",
        f"Slowest      : {max(times_ms):.1f} ms",
    ]
    for line in report:
        print(line)
