# Cluster-Track Stage

In [None]:
"""
Joint clustering-tracking for LiDAR frames.
- Per-point encoder produces embeddings for spatial-semantic assignment.
- Two-stage assignment: CORE (tight gate + margin) then GROW (looser, density- and Mahalanobis-aware).
- Optional births/deaths keep tracks consistent through occlusions and merges.

This snippet is the clean version: paths, plotting, exports and long training boilerplate are removed.
"""

from __future__ import annotations
import math, time, json, random, re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------- Config -----------------------------

@dataclass
class GeoMask:
    """Geometric point filters applied to each frame before assignment (fov_mask / simple_ground_mask)."""
    range_max_m: float = 95.0        # fov_mask drops points with range above this; also where the FOV taper ends
    hfov_deg: float = 70.4           # full horizontal FOV in degrees; fov_mask uses half of it
    r_full_m: float = 30.0           # up to this range the full half-FOV is allowed
    center_at_far_deg: float = 5.0   # half-angle (degrees) the FOV narrows to at range_max_m
    taper_power: float = 5.0         # exponent of the (1 - t) easing between r_full_m and range_max_m
    use_planar_range: bool = True    # range = sqrt(x^2 + y^2) if True, else the full 3-D norm
    use_simple_ground: bool = True   # eval_recording applies simple_ground_mask when True
    ground_z_margin: float = 0.12    # kept points must exceed the per-range-bin 5th-percentile z by this

@dataclass
class EvalCfg:
    """Thresholds and switches for assign_two_stage, birth/death handling and eval-time embedding."""
    assign_max_dist: float = 3.0          # base XY gate; CORE uses 0.75x and GROW 1.15x of the adapted gate
    adaptive_gate_k: float = 0.04         # gate grows by this fraction per unit of median seed range beyond 20
    mahal_xy_thr: float = 2.8             # GROW: max Mahalanobis XY distance to a seed's core cluster
    mahal_eps: float = 1e-4               # regularizer added to core-cluster XY covariances
    use_two_stage: bool = True            # not read by the code in this section
    allow_births: bool = True             # eval_recording: spawn new seeds from dense unassigned points
    allow_deaths: bool = True             # eval_recording: drop seeds that repeatedly get no points
    death_max_misses: int = 6             # a seed survives while its consecutive-miss count is <= this
    birth_min_pts: int = 28               # min point count for the birth checks (candidates, group, after tightening)
    birth_eps_xy: float = 0.9             # XY radius of the greedy birth grouping
    birth_min_d_from_seeds: float = 1.2   # no birth whose center is closer than this to an existing seed
    birth_keep_if_dense_q: float = 0.50   # birth candidates need density >= this x the frame's median density
    r_split_near: float = 25.0            # seeds with planar range below this count as "near"
    alpha_near: float = 0.60              # XY-distance weight of the score blend when all seeds are near
    alpha_far: float = 0.40               # XY-distance weight of the score blend when all seeds are far
    bg_gate_fraction: float = 0.75        # not read by the code in this section
    low_density_factor: float = 0.40      # GROW: a point is "sparse" below this x the median density
    far_close_ratio: float = 0.60         # not read by the code in this section
    motion_smooth: float = 0.30           # EMA weight of the new core median in the seed position update
    pca_low: float = 5.0                  # lower percentile used by pca_tighten_eval
    pca_high: float = 95.0                # upper percentile used by pca_tighten_eval
    pca_min_span: float = 0.25            # minimum principal-axis span kept by pca_tighten_eval
    emb_batch: int = 32768                # default chunk size (points) for per-point embedding
    use_amp: bool = True                  # fp16 autocast for the encoder when CUDA is available

@dataclass
class BuildCfg:
    """Label-side settings: which classes to keep and how to enlarge GT boxes when collecting points."""
    stack_w: int = 4                # not read by the code in this section
    margin: float = 0.05            # relative enlargement of GT cuboids in points_in_cuboid (seed bootstrap)
    min_points: int = 10            # not read by the code in this section
    include_classes: Iterable[str] = ("dog", "human", "atlas")  # parse_objects keeps only these (case-insensitive)

# Module-wide default configurations (mutable dataclass instances read by the helpers below).
GMASK: GeoMask = GeoMask()
ECFG: EvalCfg = EvalCfg()
BCFG: BuildCfg = BuildCfg()

# Device string used for the eval-time encoder inputs.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
TIME_DENOM = 100.0  # sin/cos time code scale

# ----------------------------- I/O -----------------------------

# Extensions probed, in this order, by _find_points_file; read_points_auto has a reader branch for each.
POINT_EXTS = [".bin", ".npz", ".npy", ".csv", ".txt", ".pcd"]

def _frame_idx_from_stem(stem: str) -> int:
    m = re.search(r"(\d+)$", stem)
    return int(m.group(1)) if m else -1

def _find_points_file(points_dir: Path, stem: str) -> Optional[Path]:
    """Return the first existing ``points_dir/<stem><ext>`` over POINT_EXTS (in order), else None."""
    candidates = (points_dir / f"{stem}{ext}" for ext in POINT_EXTS)
    return next((cand for cand in candidates if cand.exists()), None)

def read_points_auto(path: Path) -> np.ndarray:
    """Read a point-cloud file and return its XYZ columns as an ``(N, 3)`` float32 array.

    The reader is chosen by (case-insensitive) suffix:

    * ``.bin``  -- raw float32; read as 4 columns (XYZ + one extra) when the value count is
      divisible by 4, otherwise as 3 columns (a trailing partial row is dropped).
    * ``.npz``  -- the first of ``xyz``, ``points``, ``xyzi``, ``arr_0`` holding a 2-D array
      with at least 3 columns.
    * ``.npy``  -- a 2-D array in ``(N, >=3)`` layout, else ``(>=3, N)`` layout (transposed).
    * ``.csv`` / ``.txt`` -- comma-delimited text, falling back to whitespace-delimited; same
      layout rules as ``.npy``. A single data row is read as one point.
    * ``.pcd``  -- via open3d; any failure (including open3d missing) yields an empty array.

    Unknown suffixes and unusable contents give an empty ``(0, 3)`` array.
    """
    empty = np.zeros((0, 3), np.float32)

    def _xyz_2d(a) -> np.ndarray:
        # Accept (N, >=3), or failing that (>=3, N); anything else is unusable.
        a = np.asarray(a)
        if a.ndim == 2 and a.shape[1] >= 3:
            return a[:, :3].astype(np.float32)
        if a.ndim == 2 and a.shape[0] >= 3:
            return a.T[:, :3].astype(np.float32)
        return empty

    suf = path.suffix.lower()
    if suf == ".bin":
        a = np.fromfile(path, dtype=np.float32)
        # The raw layout is ambiguous; assume XYZ + one extra channel when the count divides by 4.
        a = a.reshape(-1, 4)[:, :3] if a.size % 4 == 0 else a[: (a.size // 3) * 3].reshape(-1, 3)
        return a.astype(np.float32)
    if suf == ".npz":
        # The context manager closes the zip handle NpzFile keeps open; each member is read once.
        with np.load(path) as d:
            for k in ("xyz", "points", "xyzi", "arr_0"):
                if k not in d:
                    continue
                arr = d[k]
                if arr.ndim == 2 and arr.shape[1] >= 3:
                    return arr[:, :3].astype(np.float32)
        return empty
    if suf == ".npy":
        return _xyz_2d(np.load(path))
    if suf in {".csv", ".txt"}:
        # ndmin=2 keeps a single-row file as shape (1, C) instead of a 1-D row that was dropped.
        try:
            a = np.loadtxt(path, delimiter=",", ndmin=2)
        except ValueError:
            a = np.loadtxt(path, ndmin=2)
        return _xyz_2d(a)
    if suf == ".pcd":
        try:
            import open3d as o3d
            o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
            pcd = o3d.io.read_point_cloud(str(path))
            return np.asarray(pcd.points, dtype=np.float32)
        except Exception:
            # Deliberate best effort: missing open3d or an unreadable file give an empty cloud.
            return empty
    return empty

def load_label(path: Path) -> Dict:
    """Load a JSON label file as a dict.

    Missing/unreadable files, invalid JSON and JSON whose top level is not an object all yield
    ``{}``; parse_objects then finds no objects, so a bad label file skips a frame instead of
    aborting the evaluation. Accepts a ``Path`` or a ``str``.
    """
    try:
        meta = json.loads(Path(path).read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError covers JSONDecodeError and UnicodeDecodeError.
        return {}
    # parse_objects calls .get() on the result, so a top-level list/number would crash it.
    return meta if isinstance(meta, dict) else {}

def parse_objects(meta: Dict) -> List[Dict]:
    """Extract the objects of interest from a label dict.

    Keeps entries of ``meta["objects"]`` whose lower-cased, stripped ``class`` is in
    ``BCFG.include_classes`` (``"dog_atlas"`` is renamed to ``"atlas"`` first). Each result has
    ``class``, ``pos`` (float32 array), ``yaw`` (last element of ``rotation``; 0.0 if that list
    is empty), ``scale`` (float32 array) and ``track_id`` (int; -1 when absent or null).
    """
    inc = {c.lower() for c in BCFG.include_classes}
    out = []
    for o in meta.get("objects", []):
        cls = str(o.get("class", "")).lower().strip()
        if cls == "dog_atlas":
            cls = "atlas"
        if cls not in inc:
            continue
        pos = np.array(o.get("position", [0, 0, 0]), np.float32)
        rot = np.array(o.get("rotation", [0, 0, 0]), np.float32)
        scale = np.array(o.get("scale", [0, 0, 0]), np.float32)
        tid = o.get("track_id", -1)
        out.append({
            "class": cls,
            "pos": pos,
            # An empty rotation list previously raised IndexError on rot[-1].
            "yaw": float(rot[-1]) if rot.size else 0.0,
            "scale": scale,
            # A JSON null track_id previously raised TypeError in int(None).
            "track_id": int(tid) if tid is not None else -1,
        })
    return out

# ----------------------------- Masks & PCA -----------------------------

def _width_allowed_monotone(r: np.ndarray, r_full, r_max, hfov_half, center_far, power):
    t = np.clip((r - r_full) / max(1e-6, (r_max - r_full)), 0.0, 1.0)
    ease = np.power(1.0 - t, power)
    theta = center_far + (hfov_half - center_far) * ease
    theta[r <= r_full] = hfov_half
    return r * np.tan(theta)

def fov_mask(xyz: np.ndarray, cfg: GeoMask = GMASK) -> np.ndarray:
    """Boolean mask of points within ``cfg.range_max_m`` and inside the range-tapered lateral wedge.

    The lateral test compares |y| with ``r * tan(half_angle(r))`` (see _width_allowed_monotone),
    where r is the planar or 3-D range per ``cfg.use_planar_range``. The sign of x is not
    checked separately. Non-finite ranges are rejected.
    """
    if xyz.size == 0:
        return np.zeros((0,), bool)
    lateral = xyz[:, 1]
    if cfg.use_planar_range:
        fwd = xyz[:, 0]
        rng = np.sqrt(fwd * fwd + lateral * lateral)
    else:
        rng = np.linalg.norm(xyz, axis=1)
    in_range = np.isfinite(rng) & (rng <= float(cfg.range_max_m))
    max_half_width = _width_allowed_monotone(
        rng,
        cfg.r_full_m,
        cfg.range_max_m,
        math.radians(0.5 * cfg.hfov_deg),
        math.radians(cfg.center_at_far_deg),
        cfg.taper_power,
    )
    return in_range & (np.abs(lateral) <= max_half_width)

def simple_ground_mask(xyz: np.ndarray, z_margin: float = GMASK.ground_z_margin,
                       r_max: Optional[float] = None, n_bins: int = 19) -> np.ndarray:
    """Keep points clearly above a per-range-ring ground estimate.

    Planar range ``[0, r_max]`` is split into ``n_bins`` equal rings. In each ring the ground
    height is the 5th percentile of z, and a point is kept when ``z > ground + z_margin``.

    Args:
        xyz: (N, >=3) points.
        z_margin: required height above the ring's ground estimate.
        r_max: outer edge of the last ring. Defaults to ``GMASK.range_max_m``; this was
            previously hard-coded to 95.0, the same value as the default config.
        n_bins: number of rings. 19 matches the original ``np.linspace(0, 95, 20)`` edges.

    Returns:
        Boolean mask of length N. Points beyond ``r_max`` are never kept.
    """
    if xyz.size == 0:
        return np.zeros((0,), bool)
    r_max = float(GMASK.range_max_m if r_max is None else r_max)
    x, y, z = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    r = np.sqrt(x*x + y*y)
    edges = np.linspace(0, r_max, n_bins + 1)
    keep = np.zeros(len(z), bool)
    for i in range(n_bins):
        lo, hi = edges[i], edges[i + 1]
        # Close the last ring on the right, so points exactly at r_max (kept by fov_mask's
        # `r <= range_max_m`) are not silently discarded.
        upper_ok = (r <= hi) if i == n_bins - 1 else (r < hi)
        m = (r >= lo) & upper_ok
        if not m.any():
            continue
        zg = np.percentile(z[m], 5)
        keep[m] = z[m] > (zg + z_margin)
    return keep

def pca_tighten(pts: np.ndarray, low=5.0, high=95.0, min_span=0.25) -> np.ndarray:
    """Drop outlying points of a cluster using percentile bounds on its principal XY axes and on Z.

    XY is centred on its median and rotated onto the eigenvectors of its covariance. A point is
    kept when, on each rotated axis, it lies in ``[p_low, p_low + max(p_high - p_low, min_span)]``
    (a narrow axis is widened upward only), and its z lies in ``[z_low, z_high]``. The z test is
    skipped when that range is below 1e-6. Clusters with fewer than 5 points, or where no point
    would survive, are returned unchanged.
    """
    if pts.shape[0] < 5:
        return pts
    centred = pts[:, :2] - np.median(pts[:, :2], axis=0, keepdims=True)
    cov = np.cov(centred.T) + 1e-6 * np.eye(2, dtype=np.float32)
    _, axes = np.linalg.eigh(cov)
    proj = centred @ axes
    p_lo = np.percentile(proj, low, axis=0)
    p_hi = np.percentile(proj, high, axis=0)
    upper = p_lo + np.maximum(p_hi - p_lo, min_span)
    keep = np.all((proj >= p_lo) & (proj <= upper), axis=1)

    z = pts[:, 2]
    z_lo, z_hi = np.percentile(z, low), np.percentile(z, high)
    if not ((z_hi - z_lo) < 1e-6):
        keep &= (z >= z_lo) & (z <= z_hi)
    return pts[keep] if keep.any() else pts

def pca_tighten_eval(pts: np.ndarray) -> np.ndarray:
    """Run pca_tighten with the percentile / minimum-span settings from ECFG."""
    cfg = ECFG
    return pca_tighten(pts, low=cfg.pca_low, high=cfg.pca_high, min_span=cfg.pca_min_span)

# ----------------------------- Labels to points -----------------------------

def points_in_cuboid(xyz: np.ndarray, center: np.ndarray, scale: np.ndarray, yaw: float, margin: float) -> np.ndarray:
    """Boolean mask of points inside a yaw-rotated box of full size ``scale * (1 + margin)``.

    The box is rotated about z only. A zero height is replaced by a half-height of 5.0, and a box
    with a near-zero X or Y extent (< 1e-3 half-size) contains nothing. ``yaw`` is taken as radians
    unless ``|yaw| > pi``, in which case it is assumed to be in degrees: it is wrapped to
    [0, 360) and converted.
    """
    if xyz.size == 0:
        return np.zeros((0,), bool)
    ctr = np.asarray(center, np.float32).reshape(3)
    half = 0.5*np.abs(np.asarray(scale, np.float32).reshape(3)) * (1.0 + margin)
    hx, hy, hz = half
    if hz <= 0:
        hz = 5.0
    if hx < 1e-3 or hy < 1e-3:
        return np.zeros((xyz.shape[0],), bool)
    if abs(yaw) > math.pi:
        yaw = math.radians(yaw % 360.0)
    c, s = math.cos(yaw), math.sin(yaw)
    rel = xyz - ctr[None, :]
    # Rotate the offsets into the box frame.
    local_x = c*rel[:, 0] + s*rel[:, 1]
    local_y = -s*rel[:, 0] + c*rel[:, 1]
    inside_xy = (np.abs(local_x) <= hx) & (np.abs(local_y) <= hy)
    return inside_xy & (np.abs(rel[:, 2]) <= hz)

def gt_assign_for_eval(xyz: np.ndarray, objs: List[Dict]) -> np.ndarray:
    """Ground-truth track id per point, by nearest object center in XY.

    Each point takes the ``track_id`` of the object whose XY center is closest, if it lies
    within ``1.3 * max(|scale_x|, |scale_y|)`` of that center (radius 1.0 when fewer than two
    scale values exist). Otherwise it gets -1.

    Returns:
        int64 array of length ``xyz.shape[0]``; all -1 when ``objs`` is empty.
    """
    n = xyz.shape[0]
    out = np.full((n,), -1, np.int64)
    if n == 0 or not objs:
        # Previously a length-0 array was returned even when xyz had points, which cannot be
        # compared element-wise with per-point predictions.
        return out
    centers = np.stack([o["pos"][:2] for o in objs], 0).astype(np.float32)
    tids = np.array([int(o["track_id"]) for o in objs], np.int64)
    radii = []
    for o in objs:
        sc = np.asarray(o.get("scale", [1, 1, 1]), np.float32)
        radii.append(float(1.3 * float(np.max(np.abs(sc[:2])))) if sc.size >= 2 else 1.0)
    radii = np.array(radii, np.float32)
    XY = xyz[:, :2]
    d2 = ((XY[:, None, :] - centers[None, :, :])**2).sum(2)   # (N, M) squared XY distances
    j = np.argmin(d2, axis=1)
    d = np.sqrt(d2[np.arange(n), j])
    ok = d <= radii[j]
    out[ok] = tids[j[ok]]
    return out

# ----------------------------- Encoder (eval-time) -----------------------------

class SetEncoder(nn.Module):
    """Per-point MLP, max-pool over the set, then a linear + LayerNorm embedding head.

    ``forward`` embeds whole sets. ``point_feats``, ``pool_set`` and ``head`` are also used
    piecewise: per-point embeddings skip the pool, seed prototypes use it.
    """

    def __init__(self, in_dim=7, hidden=256, emb_dim=128):
        super().__init__()
        layers = [
            nn.Linear(in_dim, hidden), nn.ReLU(True),
            nn.Linear(hidden, hidden), nn.ReLU(True),
            nn.Linear(hidden, 256), nn.ReLU(True),
        ]
        self.point_mlp = nn.Sequential(*layers)
        self.head = nn.Sequential(nn.Linear(256, emb_dim), nn.LayerNorm(emb_dim))

    def point_feats(self, X):
        """(B, N, in_dim) -> (B, N, 256) per-point features."""
        return self.point_mlp(X)

    def pool_set(self, H):
        """(B, N, 256) -> (B, 256): maximum over the point axis."""
        return H.max(dim=1).values

    def forward(self, X):
        """(B, N, in_dim) -> (B, emb_dim) unit-norm set embeddings."""
        pooled = self.pool_set(self.point_feats(X))
        return F.normalize(self.head(pooled), dim=1)

def _time_code_scalar(t_idx: int, denom: float = TIME_DENOM) -> Tuple[np.float32, np.float32]:
    """Encode a frame index as float32 ``(cos, sin)`` of ``t_idx / denom`` radians."""
    phase = float(t_idx) / float(denom)
    cos_t = np.float32(np.cos(phase))
    sin_t = np.float32(np.sin(phase))
    return cos_t, sin_t

@torch.no_grad()
def embed_points_pointwise(model: SetEncoder, P: np.ndarray, t_idx: int = 0, batch: int = None) -> np.ndarray:
    """Embed every point of *P* independently (no set pooling).

    Each point is encoded as ``(x, y, z, planar range, azimuth, cos t, sin t)``, using the frame
    time code from _time_code_scalar. It goes through ``model.point_feats`` then ``model.head``
    and is L2-normalised. Points are processed in chunks of ``batch`` (default ECFG.emb_batch).
    fp16 autocast is used when ECFG.use_amp is set and CUDA is available.

    Returns:
        (N, D) float32 array; an empty input gives a (0, 128) array.
    """
    if P.size == 0:
        return np.zeros((0, 128), np.float32)
    chunk = batch or ECFG.emb_batch
    amp_on = ECFG.use_amp and torch.cuda.is_available()
    # The time code depends only on t_idx, so it is the same for every chunk.
    tcos, tsin = _time_code_scalar(int(t_idx))

    def _project(X):
        # Per-point features straight into the embedding head (the set max-pool is skipped).
        return model.head(model.point_feats(X).squeeze(0))

    n_pts = P.shape[0]
    pieces = []
    for start in range(0, n_pts, chunk):
        A = P[start:min(n_pts, start + chunk)].astype(np.float32)
        x, y, z = A[:, 0], A[:, 1], A[:, 2]
        r = np.sqrt(x*x + y*y)
        th = np.arctan2(y, x)
        feats = np.stack([x, y, z, r, th,
                          np.full_like(r, tcos, np.float32),
                          np.full_like(r, tsin, np.float32)], 1)
        X = torch.from_numpy(feats).unsqueeze(0).to(DEVICE, non_blocking=True)
        if amp_on:
            with torch.amp.autocast(device_type="cuda", dtype=torch.float16):
                Zp = _project(X)
        else:
            Zp = _project(X)
        pieces.append(F.normalize(Zp, dim=1).float().cpu().numpy())
    return np.concatenate(pieces, 0)

@torch.no_grad()
def seed_proto(model: SetEncoder, S: np.ndarray, t_idx: int) -> np.ndarray:
    """Unit-norm set embedding ("prototype") of a cluster's points at frame *t_idx*.

    The encoder sees K = clip(N, 32, 512) points:

    * N <= 512: every point once, padded with random repeats up to K. Repeats cannot change
      the max-pool, so the result is the prototype of the whole cluster.
    * N > 512: a random subset of 512 distinct points.

    Randomness comes from the global ``np.random`` state. Inputs are cast to float32. An empty
    set returns a zero vector of length 128.
    """
    n = S.shape[0]
    if n == 0:
        return np.zeros((128,), np.float32)
    K = min(512, max(32, n))
    if n > K:
        idx = np.random.choice(n, size=K, replace=False)
    else:
        # Previously K = N indices were drawn *with* replacement, which silently left
        # about 1/e (~37%) of the cluster out of the max-pool.
        idx = np.concatenate([np.arange(n), np.random.choice(n, size=K - n, replace=True)])
    # Cast to float32 so float64 clouds do not hit a dtype mismatch in the Linear layers.
    A = S[idx].astype(np.float32)
    x, y, z = A[:, 0], A[:, 1], A[:, 2]
    r = np.sqrt(x*x + y*y)
    th = np.arctan2(y, x)
    tcos, tsin = _time_code_scalar(int(t_idx))
    X = torch.from_numpy(np.stack([x, y, z, r, th,
                                   np.full_like(r, tcos, np.float32),
                                   np.full_like(r, tsin, np.float32)], 1)
                         ).unsqueeze(0).to(DEVICE, non_blocking=True)
    with torch.amp.autocast(device_type="cuda", dtype=torch.float16, enabled=(ECFG.use_amp and torch.cuda.is_available())):
        H = model.point_feats(X)
        Z = model.head(model.pool_set(H))
    return F.normalize(Z.squeeze(0), dim=0).float().cpu().numpy()

# ----------------------------- Density + helpers -----------------------------

def knn_density_xy(Pxy: np.ndarray, k: int = 16) -> np.ndarray:
    """Per-point 2-D density estimate (float32, length N).

    Preferred estimate: ``k / (pi * r_k^2)``, where r_k is the distance to the k-th nearest
    other point, clamped at 1e-3 (scipy cKDTree). If that path raises for any reason (e.g. scipy
    is missing), the fallback is the point's cell count in a 128x128 histogram divided by the
    cell area. The two scales differ; callers here only compare against a multiple of the median.
    """
    if Pxy.size == 0:
        return np.zeros((0,), np.float32)
    try:
        from scipy.spatial import cKDTree
        # +1 because the first hit returned for each query is the query point itself.
        n_query = int(min(k + 1, max(2, len(Pxy))))
        tree = cKDTree(Pxy.astype(np.float64), leafsize=64)
        dist, _ = tree.query(Pxy, k=n_query, workers=-1)
        if dist.ndim == 1:
            dist = dist[:, None]
        radius = np.maximum(dist[:, 1:].max(axis=1).astype(np.float32), 1e-3)
        return (k / (np.pi * radius * radius)).astype(np.float32)
    except Exception:
        n_bins = 128
        grid, x_edges, y_edges = np.histogram2d(Pxy[:, 0], Pxy[:, 1], bins=n_bins)
        col = np.clip(np.searchsorted(x_edges, Pxy[:, 0], side='right') - 1, 0, n_bins - 1)
        row = np.clip(np.searchsorted(y_edges, Pxy[:, 1], side='right') - 1, 0, n_bins - 1)
        cell_area = (x_edges[1] - x_edges[0]) * (y_edges[1] - y_edges[0]) or 1.0
        return grid[col, row].astype(np.float32) / np.float32(cell_area)

def _alpha_for_range(seeds_xy: np.ndarray) -> float:
    """XY-distance weight of the assignment score.

    Linear interpolation between ECFG.alpha_far and ECFG.alpha_near by the share of seeds
    whose planar range is below ECFG.r_split_near. With no seeds, alpha_far is returned.
    """
    if seeds_xy.size == 0:
        return ECFG.alpha_far
    ranges = np.linalg.norm(seeds_xy, axis=1)
    near = float((ranges < ECFG.r_split_near).mean()) if ranges.size else 0.0
    return float(near*ECFG.alpha_near + (1.0 - near)*ECFG.alpha_far)

# ----------------------------- Assignment (two-stage) -----------------------------

def assign_two_stage(model: SetEncoder, P: np.ndarray, t_idx: int,
                     seeds_xy: np.ndarray, seeds_emb: np.ndarray,
                     mahal_state: Optional[Dict] = None) -> np.ndarray:
    """Assign each point of a frame to one of K seeds (tracks), or to -1.

    Points are scored against seeds with ``alpha * d_xy + (1 - alpha) * (1 - cos_sim)``, i.e.
    planar distance blended with embedding distance (alpha from _alpha_for_range).

    CORE: a point takes its best seed when all three hold:
      * it is within 0.75x the XY gate;
      * its blended score is under the core threshold;
      * it beats the runner-up seed by a margin of at least 0.15.
    Seed refresh: each seed with CORE points moves toward the median of its PCA-tightened core
      cluster (EMA with ECFG.motion_smooth) and gets a new prototype. The inverse of the
      cluster's 2x2 XY covariance is kept for the Mahalanobis gate.
    GROW: each still-unassigned point may join the best seed passing both the looser XY gate
      (1.15x) and the Mahalanobis gate. It is refused if it sits in a sparse region near the
      gate edge, or if its blended score exceeds the grow threshold.

    Args:
        model: encoder for per-point embeddings and refreshed prototypes.
        P: (N, >=3) frame points; the first two columns are XY.
        t_idx: frame index for the time code.
        seeds_xy: (K, 2) seed positions. **Updated in place** (EMA).
        seeds_emb: (K, D) prototypes. **Rows of seeds with CORE points are replaced in place.**
        mahal_state: optional dict that receives ``"mu"`` (K, 2) and ``"covinv"`` (K, 2, 2).

    Returns:
        (N,) int array of seed indices, -1 where unassigned.
    """
    if P.shape[0] == 0 or seeds_xy.shape[0] == 0:
        return np.full((P.shape[0],), -1, int)

    Zp = embed_points_pointwise(model, P, t_idx)
    Zp_n = Zp / (np.linalg.norm(Zp, axis=1, keepdims=True) + 1e-8)
    Ze_n = seeds_emb / (np.linalg.norm(seeds_emb, axis=1, keepdims=True) + 1e-8)

    XY = P[:, :2].astype(np.float32)
    d_xy = np.linalg.norm(XY[:, None, :] - seeds_xy[None, :, :], axis=2)   # (N, K) planar distance
    sim = np.clip(Zp_n @ Ze_n.T, -1.0, 1.0)                                  # (N, K) cosine similarity
    d_emb = 1.0 - sim

    alpha = _alpha_for_range(seeds_xy)
    seed_r_med = np.median(np.linalg.norm(seeds_xy, axis=1)) if seeds_xy.size else 0.0
    # The XY gate widens linearly with the median seed range beyond 20.
    gate_xy = ECFG.assign_max_dist * (1.0 + ECFG.adaptive_gate_k * max(0.0, seed_r_med - 20.0))
    d_blend = alpha*d_xy + (1.0 - alpha)*d_emb

    # --- CORE (tight gate + margin over the runner-up seed) ---
    j1 = np.argmin(d_blend, axis=1)
    best = np.min(d_blend, axis=1)
    if d_blend.shape[1] >= 2:
        second = np.partition(d_blend, kth=1, axis=1)[:, :2].max(axis=1)
        margin = second - best
    else:
        # Single seed: there is no runner-up, so the margin test always passes.
        margin = np.full_like(best, 1.0, dtype=np.float32)

    core_gate_xy = 0.75 * gate_xy
    thr_far_core = 1.4 + 0.02 * max(0.0, seed_r_med - 25.0)
    ok_xy = d_xy[np.arange(d_xy.shape[0]), j1] <= core_gate_xy
    ok_mix = best <= (alpha*core_gate_xy + (1.0 - alpha)*thr_far_core)
    ok_mg = margin >= 0.15

    asg = np.full((P.shape[0],), -1, int)
    core_mask = ok_xy & ok_mix & ok_mg
    asg[core_mask] = j1[core_mask]

    # Refresh seeds from their CORE clusters and fit the per-seed XY covariance used by the
    # Mahalanobis gate. Seeds without core points keep their state and get an identity inverse
    # covariance, which makes their gate a plain Euclidean one.
    K = seeds_xy.shape[0]
    mu_list, inv_list = [], []
    for k in range(K):
        hit = (asg == k)
        if not np.any(hit):
            mu_list.append(seeds_xy[k]); inv_list.append(np.eye(2, dtype=np.float32)); continue
        cl = pca_tighten_eval(P[hit])
        if cl.shape[0] == 0:
            mu_list.append(seeds_xy[k]); inv_list.append(np.eye(2, dtype=np.float32)); continue
        new_xy = np.median(cl[:, :2], 0).astype(np.float32)
        seeds_xy[k] = (1.0-ECFG.motion_smooth)*seeds_xy[k] + ECFG.motion_smooth*new_xy
        seeds_emb[k] = seed_proto(model, cl, t_idx)
        xy = cl[:, :2].astype(np.float32)
        mu = np.median(xy, axis=0).astype(np.float32)
        ctr = xy - mu
        if ctr.shape[0] < 8:
            # Too few points for a stable full covariance: use a regularized diagonal.
            var = np.maximum(ctr.var(axis=0), ECFG.mahal_eps)
            C = np.diag(var + ECFG.mahal_eps).astype(np.float32)
        else:
            C = np.cov(ctr.T, bias=False) + ECFG.mahal_eps*np.eye(2, dtype=np.float32)
        try: Cinv = np.linalg.inv(C)
        except np.linalg.LinAlgError: Cinv = np.linalg.pinv(C)
        mu_list.append(mu); inv_list.append(Cinv)

    if mahal_state is not None:
        mahal_state["mu"] = np.stack(mu_list, 0).astype(np.float32)
        mahal_state["covinv"] = np.stack(inv_list, 0).astype(np.float32)

    # --- GROW (looser, density-aware, Mahalanobis-gated), vectorized over remaining points ---
    remain = np.flatnonzero(asg < 0)
    if remain.size == 0:
        return asg

    rho = knn_density_xy(XY)
    med_rho = float(np.median(rho)) if rho.size else 0.0
    mu = np.stack(mu_list, 0).astype(np.float32)
    Cinv = np.stack(inv_list, 0).astype(np.float32)

    grow_gate_xy = 1.15 * gate_xy
    thr_far_grow = 1.7 + 0.02 * max(0.0, seed_r_med - 25.0)
    low_den_thr = ECFG.low_density_factor * (med_rho + 1e-6)
    max_blend = alpha*grow_gate_xy + (1.0 - alpha)*thr_far_grow

    diff = XY[remain][:, None, :] - mu[None, :, :]                          # (M, K, 2)
    mxy = np.sqrt(np.maximum(0.0, np.einsum('mki,kij,mkj->mk', diff, Cinv, diff)))
    s_xy = d_xy[remain]                                                     # (M, K)
    s_blend = d_blend[remain]                                               # same blend as CORE
    feas = (s_xy <= grow_gate_xy) & (mxy <= ECFG.mahal_xy_thr)
    kbest = np.argmin(np.where(feas, s_blend, np.inf), axis=1)
    rows = np.arange(remain.size)
    best_xy = s_xy[rows, kbest]
    best_blend = s_blend[rows, kbest]
    # A sparse point near the edge of the gate is more likely clutter than part of the object.
    sparse_edge = (rho[remain] < low_den_thr) & (best_xy > 0.9*grow_gate_xy)
    accept = feas.any(axis=1) & ~sparse_edge & ~(best_blend > max_blend)
    asg[remain[accept]] = kbest[accept]
    return asg

# ----------------------------- Metrics -----------------------------

def point_accuracy(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Fraction of points whose predicted id equals the ground-truth id (0.0 if either is empty)."""
    if y_true.size and y_pred.size:
        return float((y_true == y_pred).mean())
    return 0.0

def bcubed_scores(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[float, float, float]:
    """B-cubed precision, recall and F1 between two per-point labelings.

    Points labelled -1 in *both* arrays are ignored. When only one side is -1, the point is
    kept and -1 acts as an ordinary cluster id. For point i:

    * precision = |pred(i) intersect true(i)| / |pred(i)|
    * recall    = |pred(i) intersect true(i)| / |true(i)|

    Both are averaged over points.

    Returns:
        (precision, recall, F1); all 0.0 when no point is labelled on either side.
    """
    T = np.asarray(y_true)
    P = np.asarray(y_pred)
    m = (T != -1) | (P != -1)
    if not np.any(m):
        return 0.0, 0.0, 0.0
    T, P = T[m], P[m]
    # Count-based O(N log N) form of the per-point definition (the old loop was O(N^2)):
    # |pred(i) intersect true(i)| is the number of points sharing point i's (pred, true) pair.
    _, p_inv, p_cnt = np.unique(P, return_inverse=True, return_counts=True)
    _, t_inv, t_cnt = np.unique(T, return_inverse=True, return_counts=True)
    p_inv = p_inv.reshape(-1)
    t_inv = t_inv.reshape(-1)
    pair = p_inv.astype(np.int64) * t_cnt.size + t_inv
    _, pair_inv, pair_cnt = np.unique(pair, return_inverse=True, return_counts=True)
    inter = pair_cnt[pair_inv.reshape(-1)].astype(np.float64)
    Pm = float(np.mean(inter / p_cnt[p_inv]))
    Rm = float(np.mean(inter / t_cnt[t_inv]))
    f1 = 0.0 if (Pm + Rm) == 0 else (2*Pm*Rm)/(Pm + Rm)
    return Pm, Rm, f1

# ----------------------------- Minimal eval loop (per recording) -----------------------------

def _eval_load_frame(points_path) -> Optional[np.ndarray]:
    """Read a frame and apply the FOV and (optional) ground masks; None if the file yields no points."""
    P = read_points_auto(Path(points_path))
    if P.size == 0:
        return None
    P = P[fov_mask(P, GMASK)]
    if GMASK.use_simple_ground:
        P = P[simple_ground_mask(P, GMASK.ground_z_margin)]
    return P


def _eval_births(model: SetEncoder, P: np.ndarray, asg: np.ndarray,
                 seeds_xy: np.ndarray, t_idx: int) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Greedy XY clustering of dense unassigned points into new (seed position, prototype) pairs."""
    new_xy, new_emb = [], []
    un = np.where(asg < 0)[0]
    if un.size < ECFG.birth_min_pts:
        return new_xy, new_emb
    XY = P[:, :2].astype(np.float32)
    rho = knn_density_xy(XY)
    med = float(np.median(rho)) if rho.size else 0.0
    keep_un = un[rho[un] >= (ECFG.birth_keep_if_dense_q * (med + 1e-6))]
    if keep_un.size < ECFG.birth_min_pts:
        return new_xy, new_emb
    # Gather candidate XY / points once instead of re-indexing on every iteration.
    cand_xy = XY[keep_un]
    cand_pts = P[keep_un]
    used = np.zeros(keep_un.size, bool)
    for ii in range(keep_un.size):
        if used[ii]:
            continue
        d = np.linalg.norm(cand_xy - cand_xy[ii][None, :], axis=1)
        grp = np.where((d <= ECFG.birth_eps_xy) & (~used))[0]
        if grp.size < ECFG.birth_min_pts:
            continue  # too few neighbours; these points stay available to later groups
        ctr = np.median(cand_xy[grp], axis=0).astype(np.float32)
        used[grp] = True
        # Do not spawn a track on top of an existing seed.
        if seeds_xy.size and np.min(np.linalg.norm(seeds_xy - ctr[None, :], axis=1)) < ECFG.birth_min_d_from_seeds:
            continue
        cl = pca_tighten_eval(cand_pts[grp])
        if cl.shape[0] >= ECFG.birth_min_pts:
            new_xy.append(ctr)
            new_emb.append(seed_proto(model, cl, t_idx))
    return new_xy, new_emb


def _eval_map_to_gt(P: np.ndarray, asg: np.ndarray, objs: List[Dict]) -> np.ndarray:
    """Relabel predicted clusters with GT track ids by greedy nearest-center matching (-1 if unmatched)."""
    y_pred = np.full(asg.shape, -1, np.int64)
    pred_ids = np.unique(asg[asg >= 0]).astype(int)
    if not (pred_ids.size and objs):
        return y_pred
    # Each pred id comes from asg, so its point set is never empty.
    pred_centers = np.stack([np.median(P[asg == pid, :2], axis=0).astype(np.float32)
                             for pid in pred_ids], 0)
    gt_tids = np.array([int(o["track_id"]) for o in objs], np.int64)
    gt_centers = np.stack([o["pos"][:2] for o in objs], 0).astype(np.float32)
    D = np.linalg.norm(pred_centers[:, None, :] - gt_centers[None, :, :], axis=2)
    D[~np.isfinite(D)] = np.inf
    # One-to-one greedy matching: repeatedly take the globally closest remaining pair.
    while np.isfinite(D).any():
        i, j = np.unravel_index(np.argmin(D), D.shape)
        y_pred[asg == pred_ids[i]] = gt_tids[j]
        D[i, :] = np.inf
        D[:, j] = np.inf
    return y_pred


def eval_recording(model: SetEncoder, frames: List[Dict]) -> Dict[str, float]:
    """Evaluate joint clustering-tracking on one contiguous recording.

    Args:
        model: eval-time encoder.
        frames: ordered dicts with ``points_path``, ``labels_path`` and ``frame_idx``.

    Seeds are bootstrapped from the GT objects of the first frame; that frame is then evaluated
    like every other. Per frame, the steps are:
      * points are assigned with assign_two_stage (seeds are updated in place);
      * seeds may be born from dense unassigned points (ECFG.allow_births);
      * seeds with more than ECFG.death_max_misses consecutive empty frames are removed
        (ECFG.allow_deaths);
      * predicted clusters are matched greedily to GT tracks by center distance.
    Frames whose point file is empty or that have no target objects are skipped.

    Returns:
        ``{"PointAcc": mean point accuracy, "B3F1": mean B-cubed F1}`` over evaluated frames
        (0.0 when nothing could be evaluated).
    """
    if not frames:
        return {"PointAcc": 0.0, "B3F1": 0.0}

    # Bootstrap seeds from the first frame's GT.
    first = frames[0]
    P0 = _eval_load_frame(first["points_path"])
    if P0 is None:
        return {"PointAcc": 0.0, "B3F1": 0.0}
    objs0 = parse_objects(load_label(Path(first["labels_path"])))
    if not objs0:
        return {"PointAcc": 0.0, "B3F1": 0.0}

    seeds_xy = np.stack([o["pos"][:2] for o in objs0], 0).astype(np.float32)
    seeds_emb = []
    for o in objs0:
        m0 = points_in_cuboid(P0, o["pos"], o["scale"], o["yaw"], margin=BCFG.margin)
        # Fall back to the whole (masked) frame when the GT box contains no points.
        cl = pca_tighten(P0[m0]) if np.any(m0) else P0
        if cl.shape[0] == 0:
            cl = P0
        seeds_emb.append(seed_proto(model, cl, first["frame_idx"]))
    seeds_emb = np.stack(seeds_emb, 0)
    miss_counts = [0] * seeds_xy.shape[0]
    mahal_state = {"mu": None, "covinv": None}

    accs, f1s = [], []
    for m in frames:
        P = _eval_load_frame(m["points_path"])
        if P is None:
            continue
        objs = parse_objects(load_label(Path(m["labels_path"])))
        if not objs:
            continue

        y_true = gt_assign_for_eval(P, objs)
        asg = assign_two_stage(model, P, m["frame_idx"], seeds_xy, seeds_emb, mahal_state)
        n_tracked = seeds_xy.shape[0]  # seeds that took part in this frame's assignment

        # births
        if ECFG.allow_births:
            new_xy, new_emb = _eval_births(model, P, asg, seeds_xy, m["frame_idx"])
            if new_xy:
                seeds_xy = np.concatenate([seeds_xy, np.stack(new_xy, 0)], 0)
                seeds_emb = np.concatenate([seeds_emb, np.stack(new_emb, 0)], 0)
                miss_counts.extend([0] * len(new_xy))
                # Keep the Mahalanobis state row-aligned with the seeds. Previously a death in
                # the same frame indexed the shorter arrays with a longer mask -> IndexError.
                if mahal_state["mu"] is not None:
                    mahal_state["mu"] = np.concatenate([mahal_state["mu"], np.stack(new_xy, 0)], 0)
                    mahal_state["covinv"] = np.concatenate(
                        [mahal_state["covinv"],
                         np.tile(np.eye(2, dtype=np.float32), (len(new_xy), 1, 1))], 0)

        # deaths
        if ECFG.allow_deaths and seeds_xy.shape[0]:
            hit = np.bincount(asg[asg >= 0].astype(np.int64), minlength=n_tracked)[:n_tracked] > 0
            for k in range(n_tracked):
                miss_counts[k] = 0 if hit[k] else (miss_counts[k] + 1)
            # Seeds born this frame keep a miss count of 0: they were built from this frame's
            # points but could not appear in `asg`, which was computed before they existed.
            keep = np.array([c <= ECFG.death_max_misses for c in miss_counts], bool)
            if not np.all(keep):
                seeds_xy, seeds_emb = seeds_xy[keep], seeds_emb[keep]
                miss_counts = [c for c, k in zip(miss_counts, keep) if k]
                mu = mahal_state["mu"]
                if mu is not None and mu.shape[0] == keep.size:
                    mahal_state["mu"] = mu[keep]
                    mahal_state["covinv"] = mahal_state["covinv"][keep]
                elif mu is not None:
                    # Misaligned state cannot be pruned safely; assign_two_stage rewrites it.
                    mahal_state["mu"] = mahal_state["covinv"] = None

        # map predicted cluster IDs to GT IDs for pointwise metrics
        y_pred = _eval_map_to_gt(P, asg, objs)
        accs.append(point_accuracy(y_true, y_pred))
        _, _, f1 = bcubed_scores(y_true, y_pred)
        f1s.append(f1)

    return {"PointAcc": float(np.mean(accs) if accs else 0.0),
            "B3F1":    float(np.mean(f1s) if f1s else 0.0)}
# ----------------------------- End -----------------------------

# Fine-Grained Classification Stage

In [None]:
"""
Fine-grained classifier on exported crops.
Pipeline: (1) build per-sequence CSVs -> (2) train per-frame DGCNN
-> (3) cache per-frame logits -> (4) train tiny temporal head
+ rich metrics and fixed-horizon / dynamic-stopping reports.
"""

from __future__ import annotations
import json, math, time, re
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------- Paths & Config -----------------------------
# Prefer the locations from a session-global `paths` object when the tracking script defined
# one; otherwise fall back to the hard-coded Windows defaults below.
try:
    FG_ROOT, CKPT_DIR = paths.FG_EXPORT, paths.CKPT_DIR
except NameError:
    FG_ROOT = Path(r"C:\UNI\Thesis\Dog Data\fg_export")
    CKPT_DIR = Path(r"C:\UNI\Thesis\Dog Data\checkpoints")

if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

CLS_NAMES = ["dog", "human", "atlas"]
CLS_TO_ID = dict(zip(CLS_NAMES, range(len(CLS_NAMES))))  # class name -> integer label

# Per-frame graph
K_FOLDS = 5                 # also written as n_splits to folds.json
KNN_K = 16
EPOCHS = 12
BATCH = 24
LR = 1e-3
STEP = 6
GAMMA = 0.5
CLIP = 1.0
AMP = True
MAX_PTS_PER_CROP = 2048     # load_xyz_from_crop randomly subsamples larger crops to this size

# Temporal head & eval
T_EPOCHS = 25
T_BATCH = 16
T_LR = 1e-3
T_STEP = 12
T_GAMMA = 0.5
T_CLIP = 1.0
MAX_T = 64
DO_FULL_TEMPORAL_EVAL = False

# Geometry breakdown
RANGE_BINS_METERS = [0, 20, 45, 95]
CENTER_ANG_DEG = 7.0

# Dynamic-stopping knobs
DECISION_THRESH = 0.60
STABLE_K = 1
DYN_MAX_FRAMES = 8

# ----------------------------- I/O Helpers -----------------------------
def _resolve(p: str, base: Path) -> Path:
    """Resolve relative path inside FG export trees."""
    pp = Path(p)
    if pp.is_absolute() and pp.exists(): return pp
    if (base/pp).exists(): return base/pp
    hits = list(base.rglob(pp.name))
    return hits[0] if hits else (base/pp)

def load_xyz_from_crop(npz_path: Path) -> np.ndarray:
    """Load a crop as a centred ``(N, 4)`` float32 array: XYZ minus its mean, plus a zero column.

    XYZ is taken from the ``xyz`` key, else from the first of ``points``/``xyzi``/``pc``/``pts``/
    ``arr_0`` that is present. Only its first three columns are used (previously a multi-column
    ``xyz`` leaked extra channels, producing more than 4 features). Crops larger than
    MAX_PTS_PER_CROP are randomly subsampled using the global ``np.random`` state. A missing or
    unusable array yields a single zero point, so batching always sees N >= 1.
    """
    # NOTE(review): allow_pickle=True lets a crafted .npz run code on load; keep it only if the
    # export files are trusted.
    with np.load(npz_path, allow_pickle=True) as d:
        if "xyz" in d:
            key = "xyz"
        else:
            key = next((k for k in ("points", "xyzi", "pc", "pts", "arr_0") if k in d), None)
        arr = np.asarray(d[key]) if key is not None else None
    if arr is None:
        return np.zeros((1, 4), np.float32)
    if arr.ndim == 2 and arr.shape[1] >= 3:
        X = arr[:, :3].astype(np.float32)
    else:
        X = np.zeros((1, 3), np.float32)
    if X.shape[0] == 0:
        X = np.zeros((1, 3), np.float32)
    if X.shape[0] > MAX_PTS_PER_CROP:
        X = X[np.random.choice(X.shape[0], MAX_PTS_PER_CROP, replace=False)]
    X = X - X.mean(0, keepdims=True)
    return np.concatenate([X, np.zeros((X.shape[0], 1), np.float32)], 1)

def load_raw_xyz_for_pose(npz_path: Path) -> np.ndarray:
    """Raw (un-centred) XYZ of a crop as ``(N, 3)`` float32, for pose (range/bearing) estimation.

    Reads ``xyz`` if present, otherwise the first of ``points``/``xyzi``/``pc``/``pts``/``arr_0``.
    Only the first three columns are kept; previously a multi-column ``xyz`` was returned
    as-is. Returns an empty ``(0, 3)`` array when no 2-D array with >= 3 columns is found.
    The file handle is closed before returning.
    """
    # NOTE(review): allow_pickle=True is unsafe on untrusted files; kept for compatibility.
    with np.load(npz_path, allow_pickle=True) as d:
        for k in ("xyz", "points", "xyzi", "pc", "pts", "arr_0"):
            if k in d:
                arr = np.asarray(d[k])
                break
        else:
            return np.zeros((0, 3), np.float32)
    if arr.ndim == 2 and arr.shape[1] >= 3:
        return arr[:, :3].astype(np.float32)
    return np.zeros((0, 3), np.float32)

# ----------------------------- Sequence CSVs -----------------------------
def build_sequence_csvs_from_exports(fold_idx: int) -> Tuple[Path, Path, Path]:
    """Group one fold's per-crop manifests into per-track sequence CSVs (train/val).

    Reads ``FG_ROOT/{train,val}/fold_{fold_idx}/manifest.csv`` and drops rows whose class is not
    in CLS_NAMES. Writes one row per (bucket, track_id, class), holding the crop paths ordered by
    ``frame_idx`` and joined with "|", to ``FG_ROOT/clf_outputs/sequences_{train,val}_fold{k}.csv``.
    This fold's sequence ids are recorded in ``clf_outputs/folds.json`` (1-based ``fold_idx``).
    Entries for other folds already in that file are preserved; previously they were reset to
    empty lists on every call.

    Returns:
        (output directory, train CSV path, val CSV path).
    """
    outdir = FG_ROOT / "clf_outputs"
    (outdir / "seq_logits").mkdir(parents=True, exist_ok=True)
    seq_cols = ["seq_id", "bucket", "track_id", "roi_npz_paths", "class", "label"]

    def _load(split):
        m = FG_ROOT / split / f"fold_{fold_idx}" / "manifest.csv"
        df = pd.read_csv(m)
        df["label"] = df["class"].map(lambda c: CLS_TO_ID.get(str(c).lower(), -1))
        df = df[df["label"] >= 0].copy()
        df["abs_npz"] = df["rel_path"].apply(lambda r: str(_resolve(r, m.parent)))
        return df

    def _to_sequences(df: pd.DataFrame) -> pd.DataFrame:
        rows = []
        for (b, tid, cls), g in df.groupby(["bucket", "track_id", "class"], sort=False):
            g2 = g.sort_values("frame_idx")
            rows.append({
                "seq_id": f"{b}__t{int(tid)}",
                "bucket": b,
                "track_id": int(tid),
                "roi_npz_paths": "|".join(g2["abs_npz"].tolist()),
                "class": cls,
                "label": int(CLS_TO_ID.get(str(cls).lower(), -1)),
            })
        # Explicit columns so an empty split still has the expected schema (previously the
        # frame had no columns and sq["seq_id"] below raised KeyError).
        return pd.DataFrame(rows, columns=seq_cols)

    sq_tr = _to_sequences(_load("train"))
    sq_va = _to_sequences(_load("val"))
    csv_tr = outdir / f"sequences_train_fold{fold_idx}.csv"
    csv_va = outdir / f"sequences_val_fold{fold_idx}.csv"
    sq_tr.to_csv(csv_tr, index=False)
    sq_va.to_csv(csv_va, index=False)

    folds_json = outdir / "folds.json"
    folds = [{"train": [], "val": []} for _ in range(K_FOLDS)]
    if folds_json.exists():
        # Start from the existing record so building fold k does not erase earlier folds.
        try:
            prev = json.loads(folds_json.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            prev = {}
        prev_folds = prev.get("folds", []) if isinstance(prev, dict) else []
        for i, entry in enumerate(prev_folds[:K_FOLDS]):
            if isinstance(entry, dict):
                folds[i] = {"train": list(entry.get("train", [])), "val": list(entry.get("val", []))}
    if 1 <= fold_idx <= K_FOLDS:
        folds[fold_idx - 1] = {"train": sq_tr["seq_id"].tolist(), "val": sq_va["seq_id"].tolist()}
    js = {"n_splits": K_FOLDS, "folds": folds}
    folds_json.write_text(json.dumps(js, indent=2), encoding="utf-8")
    return outdir, csv_tr, csv_va

# ----------------------------- DGCNN (Var-N) -----------------------------
def knn_graph_masked(x, mask, k):
    """k-nearest-neighbour indices per batch item, computed only among its valid points.

    Args:
        x: (B, N, C) coordinates/features.
        mask: (B, N) bool, True for valid points.
        k: neighbours per point.

    Returns:
        (B, N, k) long tensor of indices into the N axis. Self-matches are excluded. If an item
        has fewer than k + 1 valid points, the nearest neighbour is repeated to fill the row.
        Masked-out rows point at the item's first valid index. Items with at most one valid
        point are filled with that point's index, or 0 when there is none.
    """
    B, N, _ = x.shape
    device = x.device
    idx_out = torch.zeros(B, N, k, dtype=torch.long, device=device)
    n_valid = mask.sum(1)
    for b in range(B):
        n = int(n_valid[b])
        valid = torch.nonzero(mask[b], as_tuple=False).squeeze(1)
        if n <= 1:
            idx_out[b].fill_(int(valid[0].item()) if n == 1 else 0)
            continue
        pts = x[b, valid, :]
        try:
            with torch.amp.autocast('cuda', enabled=x.is_cuda):
                dist = torch.cdist(pts, pts)
        except RuntimeError:
            # e.g. an out-of-memory cdist on the GPU: retry on the CPU.
            dist = torch.cdist(pts.cpu(), pts.cpu()).to(device)
        dist.fill_diagonal_(float('inf'))  # never pick the point itself
        k_eff = min(k, max(1, n - 1))
        nbr = dist.topk(k_eff, dim=1, largest=False).indices
        if k_eff < k:
            nbr = torch.cat([nbr, nbr[:, :1].expand(n, k - k_eff)], 1)
        idx_out[b, valid, :] = valid[nbr]
        invalid = ~mask[b]
        if invalid.any():
            idx_out[b, invalid, :] = valid[0]
    return idx_out

def _gather_neighbors_flat(x, idx):
    B,N,F = x.shape; k = idx.size(-1)
    base = torch.arange(B, device=x.device).view(B,1,1)*N
    return x.reshape(B*N, F)[(idx + base).reshape(-1), :].reshape(B,N,k,F)

def masked_max_mean(f, mask):
    """Global max and mean over the point axis, using only points where ``mask`` is True.

    f: (B, N, C), mask: (B, N) bool -> ``(max, mean)``, each (B, C). Items with no valid point
    get 0 for the max (instead of -inf) and 0 for the mean.
    """
    valid = mask.unsqueeze(-1)
    peak = torch.amax(f.masked_fill(~valid, float('-inf')), dim=1)
    peak = torch.where(torch.isfinite(peak), peak, torch.zeros_like(peak))
    count = mask.float().sum(1, keepdim=True).clamp_min(1.0)
    mean = (f * valid.float()).sum(1) / count
    return peak, mean

class EdgeConv(nn.Module):
    """EdgeConv layer: a shared MLP on ``[x_i, x_j - x_i]`` for each of k neighbours, max over k.

    Outputs at masked-out points are zeroed.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = [
            nn.Linear(in_ch * 2, out_ch, bias=False),
            nn.BatchNorm1d(out_ch),
            nn.LeakyReLU(0.2, inplace=False),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, mask, idx):
        """x (B, N, C), mask (B, N) bool, idx (B, N, k) -> (B, N, out_ch)."""
        B, N, C = x.shape
        k = idx.size(-1)
        centre = x.unsqueeze(2).expand(-1, -1, k, -1)
        nbrs = _gather_neighbors_flat(x, idx)
        edges = torch.cat([centre, nbrs - centre], -1).reshape(B * N * k, 2 * C)
        per_edge = self.mlp(edges).view(B, N, k, -1)
        return per_edge.max(dim=2)[0] * mask.unsqueeze(-1).float()

class DGCNNVarN(nn.Module):
    """Variable-size point-cloud classifier: three EdgeConv stages, global pooling, MLP head.

    Args:
        in_ch: per-point input channels.
        k: neighbours per point; the kNN graph is built once from the input points and
            reused by all three EdgeConv stages.
        num_classes: number of output logits.
    """
    def __init__(self, in_ch=4, k=16, num_classes=3):
        super().__init__()
        self.k = k
        widths = (64, 128, 256)
        # Submodules are created in the same order with the same names (state_dict layout).
        self.ec1 = EdgeConv(in_ch, widths[0])
        self.ec2 = EdgeConv(widths[0], widths[1])
        self.ec3 = EdgeConv(widths[1], widths[2])
        pooled_dim = 2 * sum(widths)  # max + mean of the concatenated stage outputs
        self.head = nn.Sequential(
            nn.Linear(pooled_dim, 256, bias=False), nn.BatchNorm1d(256), nn.LeakyReLU(0.2, inplace=False), nn.Dropout(0.3),
            nn.Linear(256, 256, bias=False), nn.BatchNorm1d(256), nn.LeakyReLU(0.2, inplace=False), nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x, mask):
        """x: (B, N, in_ch) points, mask: (B, N) bool -> (B, num_classes) logits."""
        graph = knn_graph_masked(x, mask, self.k)
        stage_outputs = []
        h = x
        for stage in (self.ec1, self.ec2, self.ec3):
            h = stage(h, mask, graph)
            stage_outputs.append(h)
        g_max, g_mean = masked_max_mean(torch.cat(stage_outputs, -1), mask)
        return self.head(torch.cat([g_max, g_mean], -1))

# ----------------------------- Datasets / Collates -----------------------------
class CropsFrameDS(Dataset):
    """Per-frame crop dataset driven by a manifest CSV.

    Rows whose `class` (lower-cased) is not in CLS_TO_ID are dropped. `rel_path` is
    resolved against the manifest's directory. Each item is (points tensor, int64 label),
    with points from `load_xyz_from_crop`.
    """
    def __init__(self, manifest_csv: Path):
        table = pd.read_csv(manifest_csv)
        table["label"] = table["class"].map(lambda name: CLS_TO_ID.get(str(name).lower(), -1))
        table = table[table["label"] >= 0].copy()
        root = manifest_csv.parent
        table["abs_npz"] = table["rel_path"].apply(lambda rel: str(_resolve(rel, root)))
        self.df = table.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        points = load_xyz_from_crop(Path(row["abs_npz"]))
        label = torch.tensor(int(row["label"]), dtype=torch.long)
        return torch.from_numpy(points), label

def collate_frames(batch):
    """Pad variable-length (points, label) samples into a dense batch.

    Returns:
        X: (B, Nmax, 4) float32 zero-padded points.
        M: (B, Nmax) bool, True on real points.
        Y: (B,) int64 labels.
    Nmax falls back to 1 for an empty batch.
    """
    n_batch = len(batch)
    n_max = max((sample[0].shape[0] for sample in batch), default=1)
    X = torch.zeros(n_batch, n_max, 4, dtype=torch.float32)
    M = torch.zeros(n_batch, n_max, dtype=torch.bool)
    Y = torch.zeros(n_batch, dtype=torch.long)
    for row, (pts, label) in enumerate(batch):
        count = pts.shape[0]
        X[row, :count, :] = pts
        M[row, :count] = True
        Y[row] = label
    return X, M, Y

class CropsFrameEvalDS(CropsFrameDS):
    """CropsFrameDS variant that also returns a coarse pose (range, bearing) per crop.

    The pose comes from the per-axis median of the points returned by
    `load_raw_xyz_for_pose`: range = hypot(x, y) of that median (same units as the raw
    points, presumably metres), bearing = atan2(y, x) in radians. An empty raw cloud
    yields (0.0, 0.0).
    """
    def __getitem__(self, i):
        Xc, y = super().__getitem__(i)
        r = self.df.iloc[i]
        Xraw = load_raw_xyz_for_pose(Path(r["abs_npz"]))
        if Xraw.shape[0] == 0:
            rng, th = 0.0, 0.0
        else:
            med = np.median(Xraw, 0).astype(np.float32)
            rng = float(np.hypot(med[0], med[1]))
            th = float(np.arctan2(med[1], med[0]))
        # torch.tensor's dtype is keyword-only and must be a torch dtype; passing
        # `np.float32` positionally raised TypeError for every item.
        return Xc, y, torch.tensor(rng, dtype=torch.float32), torch.tensor(th, dtype=torch.float32)

def collate_frames_with_pose(batch):
    """Collate (points, label, range, bearing) samples.

    Points and labels are padded via `collate_frames`; the range and bearing scalars are
    gathered into float32 (B,) tensors. Returns (X, M, Y, R, TH).
    """
    X, M, Y = collate_frames([(sample[0], sample[1]) for sample in batch])
    n = len(batch)
    R = torch.zeros(n, dtype=torch.float32)
    TH = torch.zeros(n, dtype=torch.float32)
    for row, sample in enumerate(batch):
        _, _, rng, bearing = sample
        R[row] = rng
        TH[row] = bearing
    return X, M, Y, R, TH

# Cached-logits datasets
class LogitSeqDS(Dataset):
    """Sequences of cached per-frame logits for one fold/split.

    Reads `sequences_{split}_fold{fold_idx}.csv` from `outdir` and keeps only rows whose
    `seq_logits/{seq_id}.npz` exists. Each item is (logits (T, C) float32 tensor, int64 label).
    When MAX_T is truthy, longer sequences are subsampled to MAX_T evenly spaced frames.

    Raises:
        RuntimeError: if no cached sequence file exists for the split.
    """
    def __init__(self, outdir: Path, split: str, fold_idx: int):
        self.df = pd.read_csv(outdir / f"sequences_{split}_fold{fold_idx}.csv")
        self.dir = outdir / "seq_logits"
        self.files = [(self.dir / f"{str(r['seq_id'])}.npz", int(r["label"])) for _, r in self.df.iterrows()
                      if (self.dir / f"{str(r['seq_id'])}.npz").exists()]
        if not self.files: raise RuntimeError(f"No cached sequences for fold {fold_idx} {split}.")

    def __len__(self): return len(self.files)

    def __getitem__(self, i):
        p, y = self.files[i]
        # Close the NpzFile deterministically instead of relying on garbage collection.
        with np.load(p) as z:
            L = z["logits"].astype(np.float32)
        if MAX_T and L.shape[0] > MAX_T:
            # Evenly spaced frame indices; the first and last frame are always kept.
            idx = np.linspace(0, L.shape[0]-1, MAX_T).round().astype(int); L = L[idx]
        # dtype must be passed by keyword: torch.tensor(y, torch.long) raises TypeError.
        return torch.from_numpy(L), torch.tensor(y, dtype=torch.long)

def collate_logits(batch):
    """Pad (logits (T, C), label) pairs to the longest T in the batch.

    Returns X (B, Tmax, C) float32 zero-padded, M (B, Tmax) bool validity mask, and the
    labels stacked along dim 0.
    """
    seqs, labels = zip(*batch)
    t_max = max(s.shape[0] for s in seqs)
    n_cls = seqs[0].shape[1]
    n_batch = len(seqs)
    X = torch.zeros(n_batch, t_max, n_cls, dtype=torch.float32)
    M = torch.zeros(n_batch, t_max, dtype=torch.bool)
    for row, seq in enumerate(seqs):
        steps = seq.shape[0]
        X[row, :steps, :] = seq
        M[row, :steps] = True
    return X, M, torch.stack(labels, 0)

# ----------------------------- Train / Cache (Per-frame) -----------------------------
@torch.no_grad()
def eval_frame(model, dl):
    """Return (mean cross-entropy, accuracy) of `model` over a loader of (X, M, Y) batches.

    Sets the model to eval mode and uses autocast when DEVICE is CUDA and AMP is enabled.
    Accuracy is 0.0 when the loader yields no samples.
    """
    model.eval()
    crit = nn.CrossEntropyLoss()
    loss_sum, seen = 0.0, 0
    truths, preds = [], []
    use_amp = (DEVICE.type == 'cuda' and AMP)
    for X, M, Y in dl:
        X, M, Y = X.to(DEVICE), M.to(DEVICE), Y.to(DEVICE)
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(X, M)
            loss = crit(logits, Y)
        batch = Y.size(0)
        loss_sum += loss.item() * batch
        seen += batch
        truths.append(Y.cpu().numpy())
        preds.append(logits.argmax(1).cpu().numpy())
    y_true = np.concatenate(truths) if truths else np.array([], int)
    y_pred = np.concatenate(preds) if preds else np.array([], int)
    acc = float((y_true == y_pred).mean()) if len(y_true) else 0.0
    return loss_sum / max(1, seen), acc

def train_per_frame_for_fold(fold_idx: int, outdir: Path) -> Path:
    """Train the per-frame DGCNN classifier for one fold and save its best checkpoint.

    Trains on FG_ROOT/train/fold_{k}/manifest.csv and validates each epoch on
    FG_ROOT/val/fold_{k}/manifest.csv. The weights with the lowest validation loss are
    kept and saved as {"state_dict", "classes"} to
    outdir/per_frame_graph_best_vloss_fold{k}.pt; that path is returned.

    If no epoch produced a validation loss below +inf (EPOCHS < 1, or NaN/inf losses),
    the final weights are saved so `load_per_frame_model` can still load the file.
    """
    man_tr = FG_ROOT / "train" / f"fold_{fold_idx}" / "manifest.csv"
    man_va = FG_ROOT / "val"   / f"fold_{fold_idx}" / "manifest.csv"
    pin = (DEVICE.type == 'cuda')
    dl_tr = DataLoader(CropsFrameDS(man_tr), batch_size=BATCH, shuffle=True,  num_workers=0,
                       collate_fn=collate_frames, pin_memory=pin)
    dl_va = DataLoader(CropsFrameDS(man_va), batch_size=BATCH, shuffle=False, num_workers=0,
                       collate_fn=collate_frames, pin_memory=pin)

    model = DGCNNVarN(in_ch=4, k=KNN_K, num_classes=len(CLS_NAMES)).to(DEVICE)
    opt = torch.optim.Adam(model.parameters(), lr=LR)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=STEP, gamma=GAMMA)
    # Mixed precision (and thus a GradScaler) only on CUDA with AMP enabled.
    scaler = torch.amp.GradScaler('cuda') if (DEVICE.type=='cuda' and AMP) else None
    crit = nn.CrossEntropyLoss()

    def _snapshot():
        # .cpu() is a no-op for tensors already on the CPU, so clone() is required: without it
        # the "best" state would alias the live parameters and silently track later epochs.
        return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    best_vloss, best_state, best_ep = float('inf'), None, -1
    for ep in range(1, EPOCHS+1):
        model.train(); run=0.0; cnt=0; t0=time.perf_counter()
        for X,M,Y in dl_tr:
            X,M,Y = X.to(DEVICE), M.to(DEVICE), Y.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            if scaler is None:
                logits = model(X,M); loss = crit(logits, Y)
                loss.backward(); nn.utils.clip_grad_norm_(model.parameters(), CLIP); opt.step()
            else:
                with torch.amp.autocast('cuda'):
                    logits = model(X,M); loss = crit(logits, Y)
                scaler.scale(loss).backward()
                # Unscale first so the clip threshold applies to the true gradient norm.
                scaler.unscale_(opt); nn.utils.clip_grad_norm_(model.parameters(), CLIP)
                scaler.step(opt); scaler.update()
            run += loss.item()*Y.size(0); cnt += Y.size(0)
        sch.step()
        vl, va = eval_frame(model, dl_va)
        if vl < best_vloss:
            best_vloss, best_ep = vl, ep
            best_state = _snapshot()
        dt=time.perf_counter()-t0  # includes the validation pass
        print(f"[frame fold {fold_idx:02d} ep {ep:02d}] train={run/max(1,cnt):.4f}  val={vl:.4f}  acc={100*va:.1f}%  ({dt:.1f}s)")

    if best_state is None:
        print(f"[frame fold {fold_idx}] WARNING: no validation loss below inf; saving last-epoch weights")
        best_state = _snapshot()
    ck = outdir / f"per_frame_graph_best_vloss_fold{fold_idx}.pt"
    torch.save({"state_dict": best_state, "classes": CLS_NAMES}, ck)
    print(f"[frame fold {fold_idx}] best_val_loss={best_vloss:.4f} @ ep {best_ep} -> {ck}")
    return ck

def load_per_frame_model(ckpt_path: Path) -> nn.Module:
    """Rebuild DGCNNVarN on DEVICE, load weights from `ckpt_path`, and return it in eval mode.

    Accepts either a {"state_dict": ...} checkpoint (as written by train_per_frame_for_fold)
    or a bare state dict.
    """
    payload = torch.load(ckpt_path, map_location=DEVICE)
    is_wrapped = isinstance(payload, dict) and "state_dict" in payload
    weights = payload["state_dict"] if is_wrapped else payload
    net = DGCNNVarN(in_ch=4, k=KNN_K, num_classes=len(CLS_NAMES)).to(DEVICE)
    net.load_state_dict(weights)
    net.eval()
    return net

@torch.no_grad()
def cache_sequence_logits_for_fold(fold_idx: int, outdir: Path, ck_graph: Path):
    """Run the per-frame model over every sequence of a fold and cache the logits.

    For each row of sequences_train_fold{k}.csv and sequences_val_fold{k}.csv, every crop
    in the "|"-separated `roi_npz_paths` cell is classified on its own (batch of one, all
    points valid). The stacked (T, C) float32 logits and the row's label are written to
    outdir/seq_logits/{seq_id}.npz. A row with no usable path is stored as a single
    all-zero (1, C) frame.
    """
    seq_dir = outdir / "seq_logits"; seq_dir.mkdir(parents=True, exist_ok=True)
    df_tr = pd.read_csv(outdir / f"sequences_train_fold{fold_idx}.csv")
    df_va = pd.read_csv(outdir / f"sequences_val_fold{fold_idx}.csv")
    model = load_per_frame_model(ck_graph)
    use_amp = (DEVICE.type=='cuda' and AMP)

    def _row_paths(cell):
        # Skip missing cells (NaN would become "nan") and empty segments ("a||b", trailing "|");
        # both would otherwise reach load_xyz_from_crop as bogus paths.
        if cell is None or (isinstance(cell, float) and math.isnan(cell)):
            return []
        return [p for p in str(cell).split("|") if p]

    def _cache(df, tag):
        for i, r in df.iterrows():
            L = []
            for p in _row_paths(r["roi_npz_paths"]):
                Xi = load_xyz_from_crop(Path(p))
                with torch.amp.autocast('cuda', enabled=use_amp):
                    X = torch.from_numpy(Xi[None]).float().to(DEVICE)
                    M = torch.ones(1, Xi.shape[0], dtype=torch.bool, device=DEVICE)
                    logits = model(X, M)
                L.append(logits.squeeze(0).detach().cpu().numpy().astype(np.float32))
            L = np.stack(L, 0) if L else np.zeros((1, len(CLS_NAMES)), np.float32)
            np.savez_compressed(seq_dir / f"{str(r['seq_id'])}.npz", logits=L, label=np.int64(int(r["label"])))
            if (i % 25) == 0: print(f"[cache fold{fold_idx} {tag}] {i}/{len(df)} {str(r['seq_id'])}  T={L.shape[0]}")
        print(f"[cache] fold{fold_idx} {tag}: {len(df)} sequences processed")

    _cache(df_tr, "train"); _cache(df_va, "val")

# ----------------------------- Temporal Head -----------------------------
class TinyTemporalHead(nn.Module):
    """Tiny conv + attention-pooling classifier over cached per-frame logits.

    Input X is (B, T, C) logits with validity mask M (B, T). Each step is projected to
    `hid` dims and refined by a residual kernel-3 Conv1d over time. The steps are then
    pooled with learned attention weights restricted to valid steps, and the result is
    classified by a small MLP.
    """
    def __init__(self, C, hid=64, num_classes=3):
        super().__init__()
        # Creation order (RNG-based init) and attribute names (state_dict keys) are unchanged.
        self.proj = nn.Linear(C, hid); self.conv = nn.Conv1d(hid, hid, 3, padding=1)
        self.attn_v = nn.Linear(hid, 1, bias=False)
        self.fc = nn.Sequential(nn.Linear(hid,128), nn.ReLU(True), nn.Dropout(0.3), nn.Linear(128, num_classes))

    def forward(self, X, M):
        H = self.proj(X)
        keep = M.unsqueeze(-1).to(H.dtype)
        # Zero padded steps *before* the temporal conv. proj(0) equals the Linear bias, so
        # otherwise the conv at the last real step reads bias-valued padding. That would make
        # a sequence's output depend on how much padding its batch needed.
        H = H * keep
        Hc = self.conv(H.transpose(1, 2)).transpose(1, 2)
        H = (H + Hc) * keep
        s = self.attn_v(torch.tanh(H)).squeeze(-1)
        # Large negative (fp16-safe) score so padded steps get ~zero attention weight.
        s = s.masked_fill(~M, torch.tensor(-1e4, dtype=s.dtype, device=s.device))
        w = torch.softmax(s - s.max(1, keepdim=True).values, 1)
        return self.fc((H * w.unsqueeze(-1)).sum(1))

@torch.no_grad()
def evaluate_temporal(model, dl, crit):
    """Return (mean loss, accuracy) of the temporal head over `dl`; both are 0.0 for an empty loader.

    Uses autocast when DEVICE is CUDA and AMP is enabled.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0.0, 0
    use_amp = (DEVICE.type == 'cuda' and AMP)
    for X, M, Y in dl:
        X, M, Y = X.to(DEVICE), M.to(DEVICE), Y.to(DEVICE)
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(X, M)
            loss = crit(logits, Y)
        batch = Y.size(0)
        loss_sum += loss.item() * batch
        seen += batch
        correct += (logits.argmax(1) == Y).sum().item()
    denom = max(1, seen)
    return loss_sum / denom, correct / denom

def train_temporal_for_fold(fold_idx: int, outdir: Path) -> Path:
    """Train the temporal head on cached sequence logits for one fold and save its best checkpoint.

    Uses LogitSeqDS for the train/val splits and keeps the weights with the lowest
    validation loss. They are saved as {"state_dict", "classes"} to
    outdir/temporal_from_logits_best_vloss_fold{k}.pt, whose path is returned. If no
    epoch produced a validation loss below +inf, the final weights are saved instead.
    """
    tr = LogitSeqDS(outdir, "train", fold_idx); va = LogitSeqDS(outdir, "val", fold_idx)
    pin = (DEVICE.type == 'cuda')
    dl_tr = DataLoader(tr, batch_size=T_BATCH, shuffle=True,  num_workers=0, collate_fn=collate_logits, pin_memory=pin)
    dl_va = DataLoader(va, batch_size=T_BATCH, shuffle=False, num_workers=0, collate_fn=collate_logits, pin_memory=pin)

    model = TinyTemporalHead(C=len(CLS_NAMES), hid=64, num_classes=len(CLS_NAMES)).to(DEVICE)
    crit = nn.CrossEntropyLoss(); opt = torch.optim.Adam(model.parameters(), lr=T_LR)
    sch = torch.optim.lr_scheduler.StepLR(opt, step_size=T_STEP, gamma=T_GAMMA)
    scaler = torch.amp.GradScaler('cuda') if (DEVICE.type=='cuda' and AMP) else None

    def _snapshot():
        # clone(): .cpu() returns the same storage for CPU tensors, which would let the
        # "best" state drift with subsequent optimizer steps.
        return {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    best_vloss, best_state, best_ep = float('inf'), None, -1
    for ep in range(1, T_EPOCHS+1):
        model.train(); run=0.0; cnt=0
        for X,M,Y in dl_tr:
            X,M,Y = X.to(DEVICE), M.to(DEVICE), Y.to(DEVICE)
            opt.zero_grad(set_to_none=True)
            if scaler is None:
                logits = model(X,M); loss = crit(logits, Y); loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), T_CLIP); opt.step()
            else:
                with torch.amp.autocast('cuda'):
                    logits = model(X,M); loss = crit(logits, Y)
                scaler.scale(loss).backward()
                # Unscale before clipping so T_CLIP applies to the true gradient norm.
                scaler.unscale_(opt); nn.utils.clip_grad_norm_(model.parameters(), T_CLIP)
                scaler.step(opt); scaler.update()
            run += loss.item()*Y.size(0); cnt += Y.size(0)
        sch.step()
        vl, va = evaluate_temporal(model, dl_va, crit)
        if vl < best_vloss: best_vloss, best_ep, best_state = vl, ep, _snapshot()
        print(f"[temp fold {fold_idx:02d} ep {ep:02d}] train={run/max(1,cnt):.4f}  val={vl:.4f}  acc={100*va:.1f}%")

    if best_state is None:
        print(f"[temp fold {fold_idx}] WARNING: no validation loss below inf; saving last-epoch weights")
        best_state = _snapshot()
    ck = outdir / f"temporal_from_logits_best_vloss_fold{fold_idx}.pt"
    torch.save({"state_dict": best_state, "classes": CLS_NAMES}, ck)
    print(f"[temp fold {fold_idx}] best_val_loss={best_vloss:.4f} @ ep {best_ep} -> {ck}")
    return ck

# ----------------------------- Metrics (NumPy, no sklearn) -----------------------------
def _softmax_np(z):
    z = z - np.max(z, 1, keepdims=True); e = np.exp(z); return e / np.clip(e.sum(1, keepdims=True), 1e-9, None)

def confusion_matrix_np(y_true, y_pred, nC):
    """Count (true, predicted) pairs into an (nC, nC) int64 matrix (rows = GT, cols = prediction).

    Pairs where either label falls outside [0, nC) are ignored.
    """
    counts = np.zeros((nC, nC), dtype=np.int64)
    in_range = ((t, p) for t, p in zip(y_true, y_pred) if 0 <= t < nC and 0 <= p < nC)
    for truth, pred in in_range:
        counts[truth, pred] += 1
    return counts

def precision_recall_f1_from_cm(cm):
    """Per-class precision, recall, F1 and support from a (rows = GT, cols = pred) confusion matrix.

    A metric whose denominator is zero for a class is reported as 0.0.
    """
    true_pos = np.diag(cm).astype(np.float64)
    gt_totals = cm.sum(1)
    false_pos = cm.sum(0) - true_pos
    false_neg = gt_totals - true_pos
    with np.errstate(divide='ignore', invalid='ignore'):
        prec = np.where(true_pos + false_pos > 0, true_pos / (true_pos + false_pos), 0.0)
        rec = np.where(true_pos + false_neg > 0, true_pos / (true_pos + false_neg), 0.0)
        f1 = np.where(prec + rec > 0, 2 * prec * rec / (prec + rec), 0.0)
    return prec, rec, f1, gt_totals.astype(int)

def top_k_accuracy(prob, y_true, k=3):
    """Fraction of samples whose true label is among the k highest-probability classes.

    k is capped at the number of classes; returns 0.0 when there are no samples.
    """
    k = min(k, prob.shape[1])
    ranked = np.argsort(-prob, 1)[:, :k]
    if len(y_true) == 0:
        return 0.0
    hits = (ranked == y_true[:, None]).any(axis=1)
    return float(hits.mean())

def brier_score(prob, y_true, nC):
    """Multi-class Brier score: per-sample squared error to the one-hot target, averaged over samples.

    `nC` is accepted for signature symmetry but unused (the class count comes from prob).
    Returns 0.0 when there are no samples.
    """
    target = np.zeros_like(prob)
    target[np.arange(len(y_true)), y_true] = 1.0
    if not len(y_true):
        return 0.0
    per_sample = np.sum((prob - target) ** 2, axis=1)
    return float(np.mean(per_sample))

def ece_score(prob, y_true, n_bins=15):
    """Expected calibration error with `n_bins` equal-width confidence bins on [0, 1].

    Confidence is the max class probability. Each non-empty bin contributes
    (fraction of samples in the bin) * |accuracy - mean confidence|. Bins are half-open
    [lo, hi), except the last, which is closed so that confidence 1.0 is counted.
    """
    conf = prob.max(1)
    correct = (prob.argmax(1) == y_true).astype(np.float64)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    last = n_bins - 1
    total = 0.0
    for b, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (conf <= hi) if b == last else (conf < hi)
        in_bin = (conf >= lo) & below_hi
        if in_bin.any():
            total += in_bin.mean() * abs(correct[in_bin].mean() - conf[in_bin].mean())
    return float(total)

def _binary_auc(scores, labels):
    if len(scores)==0: return 0.0
    o = np.argsort(-scores); y = labels[o]; tp = np.cumsum(y); fp = np.cumsum(1-y)
    tpr = tp / max(1, tp[-1]); fpr = fp / max(1, fp[-1])
    tpr = np.concatenate([[0.0], tpr, [1.0]]); fpr = np.concatenate([[0.0], fpr, [1.0]])
    return float(np.trapezoid(tpr, fpr))

def _binary_auprc(scores, labels):
    if len(scores)==0: return 0.0
    o = np.argsort(-scores); y = labels[o]; tp = np.cumsum(y); fp = np.cumsum(1-y)
    prec = tp / np.maximum(1, tp+fp); rec = tp / max(1, tp[-1])
    rec = np.concatenate([[0.0], rec]); prec = np.concatenate([[float(np.mean(labels)) if len(labels) else 0.0], prec])
    return float(np.trapezoid(prec, rec))

def auroc_auprc_ovr(prob, y_true, nC):
    """One-vs-rest AUROC and AUPRC per class, scoring class c by prob[:, c].

    Returns two float64 arrays of length nC.
    """
    auroc = np.zeros(nC, dtype=np.float64)
    auprc = np.zeros(nC, dtype=np.float64)
    for cls in range(nC):
        score = prob[:, cls]
        is_cls = (y_true == cls).astype(np.int32)
        auroc[cls] = _binary_auc(score, is_cls)
        auprc[cls] = _binary_auprc(score, is_cls)
    return auroc, auprc

def print_confusion(cm, names):
    """Print a confusion matrix (rows = ground truth, columns = predictions).

    Class names are truncated to 6 characters; counts are right-aligned in 7-wide columns.
    """
    print("Confusion (rows=GT, cols=Pred):")
    header_cells = " ".join(f"{name[:6]:>7s}" for name in names)
    print("      " + header_cells)
    for i, row in enumerate(cm):
        label = names[i][:6]
        cells = " ".join(f"{count:7d}" for count in row)
        print(f"{label:>6s} " + cells)

def print_metric_block(title, y_true, y_pred, prob, names):
    """Print a classification report for one evaluation.

    Contents: confusion matrix, per-class support/precision/recall/F1/AUROC/AUPRC, macro
    averages, accuracy, top-3 accuracy, ECE (15 bins) and Brier score. Prints "No samples."
    and returns early when y_true is empty.
    """
    nC = len(names)
    print(f"\n===== {title} =====")
    if len(y_true) == 0:
        print("No samples.")
        return
    acc = float((y_true == y_pred).mean())
    cm = confusion_matrix_np(y_true, y_pred, nC)
    prec, rec, f1, sup = precision_recall_f1_from_cm(cm)
    auroc, auprc = auroc_auprc_ovr(prob, y_true, nC)
    print_confusion(cm, names)
    print("\nPer-class:\n class      supp   prec    rec     f1    AUROC  AUPRC")
    for i, c in enumerate(names):
        print(f" {c:<9s} {sup[i]:6d}  {prec[i]:6.3f} {rec[i]:6.3f} {f1[i]:6.3f}  {auroc[i]:6.3f} {auprc[i]:6.3f}")
    macro_p, macro_r, macro_f = np.mean(prec), np.mean(rec), np.mean(f1)
    print(f"\nMacro avg : prec={macro_p:.3f} rec={macro_r:.3f} f1={macro_f:.3f}  AUROC={auroc.mean():.3f} AUPRC={auprc.mean():.3f}")
    top3 = top_k_accuracy(prob, y_true, k=min(3, nC))
    print(f"Micro/Acc : acc={acc:.3f}  top-3={top3:.3f}")
    ece = ece_score(prob, y_true, 15)
    brier = brier_score(prob, y_true, nC)
    print(f"Calibration: ECE={ece:.3f}  Brier={brier:.3f}")

def print_accuracy_by_range_and_angle(title, y_true, y_pred, ranges_m, thetas_rad, bins_m, center_deg, names):
    """Print accuracy broken down by range bin and by bearing (centre vs off-centre).

    Args:
        title: header label.
        y_true, y_pred: (N,) integer labels.
        ranges_m: (N,) range per sample, binned by consecutive half-open [lo, hi) edges in
            `bins_m`; samples outside every bin are not reported in the range section.
        thetas_rad: (N,) bearing in radians; |bearing| <= center_deg (degrees) counts as centre.
        center_deg: centre/off-centre threshold in degrees.
        names: class names (unused; kept for signature compatibility).
    """
    print(f"\n----- {title}: Accuracy by range and angle -----")
    edges = np.array(bins_m, float)
    for lo, hi in zip(edges[:-1], edges[1:]):
        m = (ranges_m >= lo) & (ranges_m < hi)
        acc = 100*float((y_true[m]==y_pred[m]).mean()) if np.any(m) else 0.0
        print(f"Range [{lo:>4.0f},{hi:>4.0f}) m : {acc:5.1f}%  (n={m.sum()})")
    th_abs = np.abs(thetas_rad) * 180.0 / np.pi
    m_center = th_abs <= float(center_deg); m_off = ~m_center
    if np.any(m_center):
        acc_c = 100*float((y_true[m_center]==y_pred[m_center]).mean())
        print(f"\nCenter (|bearing| <= {center_deg:.1f} deg): {acc_c:5.1f}%  (n={m_center.sum()})")
    if np.any(m_off):
        acc_o = 100*float((y_true[m_off]==y_pred[m_off]).mean())
        # The label previously read "Off-center (>|bearing|):", garbled and without the threshold.
        print(f"Off-center (|bearing| > {center_deg:.1f} deg): {acc_o:5.1f}%  (n={m_off.sum()})")

# Collectors that also return pose arrays
@torch.no_grad()
def collect_frame_outputs_with_pose(model, dl, nC: int):
    """Run `model` over a pose-aware loader and gather everything needed for reporting.

    `dl` yields (X, M, Y, R, TH) batches, as built by `collate_frames_with_pose`.
    Returns (y_true, y_pred, prob, ranges, bearings, mean_loss), where prob is the
    row-wise softmax of the logits as an (N, nC) float32 array. An empty loader gives
    empty arrays and a mean loss of 0.0.
    """
    model.eval(); crit = nn.CrossEntropyLoss(); tot=0.0; n=0
    ys=[]; ps=[]; logits_all=[]; rngs=[]; ths=[]
    use_amp = (DEVICE.type=='cuda' and AMP)
    for X,M,Y,R,TH in dl:
        X,M,Y = X.to(DEVICE), M.to(DEVICE), Y.to(DEVICE)
        with torch.amp.autocast('cuda', enabled=use_amp):
            logits = model(X,M); loss = crit(logits, Y)
        tot += loss.item()*Y.size(0); n += Y.size(0)
        ys.append(Y.cpu().numpy()); ps.append(logits.argmax(1).cpu().numpy())
        # Under autocast the logits can be float16; upcast so the softmax and the
        # calibration/ranking metrics run in float32, matching the empty-case dtype below.
        logits_all.append(logits.detach().float().cpu().numpy())
        rngs.append(R.numpy()); ths.append(TH.numpy())
    y_true = np.concatenate(ys) if ys else np.zeros((0,), int)
    y_pred = np.concatenate(ps) if ps else np.zeros((0,), int)
    logits = np.concatenate(logits_all) if logits_all else np.zeros((0,nC), np.float32)
    prob = _softmax_np(logits)
    Rall = np.concatenate(rngs) if rngs else np.zeros((0,), np.float32)
    THall= np.concatenate(ths)  if ths  else np.zeros((0,), np.float32)
    return y_true, y_pred, prob, Rall, THall, (tot/max(1,n))

# ----------------------------- Fixed-horizon & Dynamic eval -----------------------------
def _load_seq_csv_for_eval(outdir: Path, split: str, fold_idx: int) -> pd.DataFrame:
    df = pd.read_csv(outdir / f"sequences_{split}_fold{fold_idx}.csv")
    if "label" not in df.columns:
        df["label"] = df["class"].map(lambda c: CLS_TO_ID.get(str(c).lower(), -1))
    return df

def _seq_paths(row: pd.Series) -> List[str]:
    s = str(row["roi_npz_paths"]); return [p for p in s.split("|") if p] if s else []

def _pose_for_paths(paths: List[str], t_use: int) -> Tuple[float, float]:
    """Median (range, bearing) over the first `t_use` frames of a sequence.

    Each frame's pose is hypot/atan2 of the per-axis median of its raw points, with the
    bearing in radians; frames with no points are skipped. A component is 0.0 when no
    frame contributes.
    """
    ranges, bearings = [], []
    for path in paths[:t_use]:
        pts = load_raw_xyz_for_pose(Path(path))
        if pts.shape[0] == 0:
            continue
        centre = np.median(pts, 0).astype(np.float32)
        ranges.append(float(np.hypot(centre[0], centre[1])))
        bearings.append(float(np.arctan2(centre[1], centre[0])))
    # NOTE(review): plain (non-circular) median of bearings; fine away from the +/-pi wrap.
    med_range = float(np.median(ranges)) if ranges else 0.0
    med_bearing = float(np.median(bearings)) if bearings else 0.0
    return med_range, med_bearing

def _fixed_horizon_eval(outdir: Path, fold_idx: int, split: str, horizon: int):
    """Classify each cached sequence from the mean softmax of its first `horizon` frames.

    Rows without a cached NPZ are ignored; sequences shorter than `horizon` are counted
    in `missed`. Returns None if no sequence was scored, else
    (y_true, y_pred, probs (N, C), ranges, bearings, missed).
    """
    df = _load_seq_csv_for_eval(outdir, split, fold_idx)
    seq_dir = outdir / "seq_logits"
    labels, preds, probs, ranges, bearings = [], [], [], [], []
    missed = 0
    for _, row in df.iterrows():
        npz_path = seq_dir / f"{str(row['seq_id'])}.npz"
        if not npz_path.exists():
            continue
        logits = np.asarray(np.load(npz_path)["logits"], np.float32)
        if logits.shape[0] < horizon:
            missed += 1
            continue
        mean_prob = _softmax_np(logits)[:horizon].mean(0)
        predicted = int(mean_prob.argmax())
        truth = int(row["label"])
        labels.append(truth)
        preds.append(predicted)
        probs.append(mean_prob[None, :])
        rng, th = _pose_for_paths(_seq_paths(row), t_use=horizon)
        ranges.append(rng)
        bearings.append(th)
    if not labels:
        return None
    return (np.array(labels, int), np.array(preds, int), np.concatenate(probs, 0),
            np.array(ranges, np.float32), np.array(bearings, np.float32), missed)

def _dynamic_eval(outdir: Path, fold_idx: int, split: str, conf_thr=DECISION_THRESH, stable_k=STABLE_K, tmax=DYN_MAX_FRAMES):
    """Early-stopping evaluation over cached sequence logits.

    For t = 1..min(tmax, T) the running mean softmax of the first t frames is computed.
    A sequence is decided at the first t where the argmax has been the same for
    `stable_k` consecutive steps and its probability is >= `conf_thr`. Sequences that
    never decide, or have no frames, count as skipped; rows without a cached NPZ are
    ignored. Returns None if nothing was decided, else
    (y_true, y_pred, probs (N, C), ranges, bearings, frames_used, skipped).
    """
    df = _load_seq_csv_for_eval(outdir, split, fold_idx)
    seq_dir = outdir / "seq_logits"
    labels, preds, probs, ranges, bearings, used = [], [], [], [], [], []
    skipped = 0
    for _, row in df.iterrows():
        npz_path = seq_dir / f"{str(row['seq_id'])}.npz"
        if not npz_path.exists():
            continue
        logits = np.asarray(np.load(npz_path)["logits"], np.float32)
        n_frames = logits.shape[0]
        if n_frames < 1:
            skipped += 1
            continue
        frame_prob = _softmax_np(logits)
        prev_cls, run = None, 0
        for t in range(1, min(tmax, n_frames) + 1):
            running = frame_prob[:t].mean(0)
            cls = int(running.argmax())
            conf = float(running[cls])
            run = run + 1 if cls == prev_cls else 1
            prev_cls = cls
            if run >= stable_k and conf >= conf_thr:
                rng, th = _pose_for_paths(_seq_paths(row), t_use=t)
                labels.append(int(row["label"])); preds.append(cls); probs.append(running[None, :])
                ranges.append(rng); bearings.append(th); used.append(t)
                break
        else:
            # The loop finished without a decision.
            skipped += 1
    if not labels:
        return None
    return (np.array(labels, int), np.array(preds, int), np.concatenate(probs, 0),
            np.array(ranges, np.float32), np.array(bearings, np.float32), np.array(used, int), skipped)

def _print_fixed_block(title: str, pack, bins_m=None, center_deg=None):
    """Print the report for a `_fixed_horizon_eval` result.

    The range/angle breakdown is printed only when `bins_m` is given.
    """
    y_true, y_pred, P, R, TH, missed = pack
    print(f"\n================ {title} ================")
    print(f"Used sequences: {len(y_true)}   Skipped (shorter than horizon): {missed}")
    print_metric_block(title, y_true, y_pred, P, CLS_NAMES)
    if bins_m is None:
        return
    print_accuracy_by_range_and_angle(title, y_true, y_pred, R, TH, bins_m, center_deg, CLS_NAMES)

def _print_dynamic_block(title: str, pack, bins_m=None, center_deg=None, fps: Optional[float]=None):
    """Print the report for a `_dynamic_eval` result.

    Includes frames-to-decision statistics, also converted to seconds when a positive
    `fps` is given. The range/angle breakdown is printed only when `bins_m` is given.
    """
    y_true, y_pred, P, R, TH, usedT, skipped = pack
    print(f"\n================ {title} ================")
    print(f"Decided sequences: {len(y_true)}   Skipped (no decision <= tmax): {skipped}")
    if len(usedT):
        mean_t, med_t, p90_t = usedT.mean(), np.median(usedT), np.percentile(usedT, 90)
        print(f"Frames-to-decision: mean={mean_t:.2f}, median={med_t:.0f}, p90={p90_t:.0f}")
        if fps and fps > 0:
            print(f"In seconds: mean={mean_t/fps:.2f}s, median={med_t/fps:.2f}s, p90={p90_t/fps:.2f}s")
    print_metric_block(title, y_true, y_pred, P, CLS_NAMES)
    if bins_m is None:
        return
    print_accuracy_by_range_and_angle(title, y_true, y_pred, R, TH, bins_m, center_deg, CLS_NAMES)

# ----------------------------- Orchestrate -----------------------------
def run_fine_grained_all_folds():
    """Run the per-fold pipeline for every fold, then print an all-fold per-frame summary.

    For each fold 1..K_FOLDS:
      - build the sequence CSVs and train the per-frame model;
      - report per-frame VAL metrics with a range/angle breakdown;
      - cache sequence logits and train the temporal head.

    Afterwards the VAL predictions of all folds are pooled and reported, followed by the
    mean +/- std of per-fold accuracy.
    """
    print(f"[cfg] FG_ROOT={FG_ROOT}")
    out_all = FG_ROOT / "clf_outputs"  # NOTE(review): not used below
    pooled_true, pooled_pred, pooled_prob, pooled_rng, pooled_th = [], [], [], [], []
    fold_acc = []
    bar = "============================"
    for fold_idx in range(1, K_FOLDS + 1):
        print("\n" + bar)
        print(f"=== Fine-grained: Fold {fold_idx}/{K_FOLDS} ===")
        print(bar)
        outdir, _csv_tr, _csv_va = build_sequence_csvs_from_exports(fold_idx)
        ck_graph = train_per_frame_for_fold(fold_idx, outdir)

        # Per-frame VAL metrics + pose breakdown
        val_manifest = FG_ROOT / "val" / f"fold_{fold_idx}" / "manifest.csv"
        eval_dl = DataLoader(CropsFrameEvalDS(val_manifest), batch_size=BATCH, shuffle=False, num_workers=0,
                             collate_fn=collate_frames_with_pose, pin_memory=(DEVICE.type == 'cuda'))
        model_frame = load_per_frame_model(ck_graph)
        y_true, y_pred, prob, R, TH, _ = collect_frame_outputs_with_pose(model_frame, eval_dl, nC=len(CLS_NAMES))
        if len(y_true):
            fold_acc.append(float((y_true == y_pred).mean()))
            pooled_true.append(y_true); pooled_pred.append(y_pred); pooled_prob.append(prob)
            pooled_rng.append(R); pooled_th.append(TH)
        tag = f"Fold {fold_idx} - PER-FRAME (VAL)"
        print_metric_block(tag, y_true, y_pred, prob, CLS_NAMES)
        print_accuracy_by_range_and_angle(tag, y_true, y_pred, R, TH, RANGE_BINS_METERS, CENTER_ANG_DEG, CLS_NAMES)

        # Cache logits, then train the temporal head (its checkpoint path is not used here).
        cache_sequence_logits_for_fold(fold_idx, outdir, ck_graph)
        train_temporal_for_fold(fold_idx, outdir)

        # Optional whole-sequence temporal eval (disabled by default)
        if DO_FULL_TEMPORAL_EVAL:
            pass  # placeholder: not implemented in this version

    if not pooled_true:
        return
    Yt = np.concatenate(pooled_true); Yp = np.concatenate(pooled_pred); Pb = np.concatenate(pooled_prob)
    Rg = np.concatenate(pooled_rng); Th = np.concatenate(pooled_th)
    summary_tag = "ALL FOLDS - PER-FRAME (VAL)"
    print_metric_block(summary_tag, Yt, Yp, Pb, CLS_NAMES)
    print_accuracy_by_range_and_angle(summary_tag, Yt, Yp, Rg, Th, RANGE_BINS_METERS, CENTER_ANG_DEG, CLS_NAMES)
    print(f"\n=== SUMMARY: Per-frame ACC across folds ===\n{100*np.mean(fold_acc):.2f}% +/- {100*np.std(fold_acc):.2f}%")

# Fixed horizons (t=3/4/5) + Dynamic stopping reports
def run_fixed_horizon_and_dynamic_reports(fold_idx: int, split: str = "val", fps_for_seconds: Optional[float] = None):
    """Print reports for one fold/split from the logits cached under FG_ROOT/clf_outputs.

    Covers fixed-horizon evaluation at t = 3, 4, 5 and dynamic early-stopping evaluation.
    `fps_for_seconds`, if given, converts frames-to-decision into seconds.
    """
    outdir = FG_ROOT / "clf_outputs"
    tag = f"FOLD {fold_idx} | {split.upper()}"
    for t in (3, 4, 5):
        pack = _fixed_horizon_eval(outdir, fold_idx, split, t)
        if pack is None:
            print(f"[fold {fold_idx}] [{split}] No sequences usable for t={t}")
            continue
        _print_fixed_block(f"{tag} | FIXED HORIZON t={t}", pack, RANGE_BINS_METERS, CENTER_ANG_DEG)
    dyn = _dynamic_eval(outdir, fold_idx, split, DECISION_THRESH, STABLE_K, DYN_MAX_FRAMES)
    if dyn is None:
        print(f"[fold {fold_idx}] [{split}] No sequences produced a dynamic decision.")
        return
    _print_dynamic_block(f"{tag} | DYNAMIC (thr={DECISION_THRESH},K={STABLE_K},tmax={DYN_MAX_FRAMES})",
                         dyn, RANGE_BINS_METERS, CENTER_ANG_DEG, fps=fps_for_seconds)

# ----------------------------- Main -----------------------------
if __name__ == "__main__":
    # Train/evaluate every fold, then print per-fold fixed-horizon and dynamic reports.
    # Seconds are derived from frame counts at FPS_FOR_SECONDS.
    run_fine_grained_all_folds()
    FPS_FOR_SECONDS = 10.0
    for fold in range(1, K_FOLDS + 1):
        run_fixed_horizon_and_dynamic_reports(fold, split="val", fps_for_seconds=FPS_FOR_SECONDS)