In [None]:
#!/usr/bin/env python3
# MABe: Rule-based baseline+++ (Priors + Windows + Timing + Opportunity + Dilation + Deconflict + Fallback)
# - Kaggle-ready: submission.csv
# - LocalCV: GroupKFold over valid videos
# - Key fixes:
#   * Exclude only "MABe22_" labs (duplicates)
#   * Fallback: ensure >=1 segment per (video, agent, target) when allowed_total>0
#   * Scorer uses safe_json_loads() for active labels

import os, json, warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Set
from collections import defaultdict

import numpy as np
import polars as pl
from tqdm.auto import tqdm

warnings.filterwarnings("ignore")

# ========================
# User knobs
# ========================
DO_CV        = True             # run run_local_cv() in the main section
DO_SUBMIT    = True             # predict the test set and write the submission CSV
N_SPLITS     = 5                # number of GroupKFold splits used by LocalCV
BETA         = 1.0              # beta of the F-beta score computed by mouse_fbeta()
MAX_VIDEOS_PER_FOLD = 100   # None for all; 80-120 keeps CV stable and fast

# Stable params (stable region, targeting a score of about 0.12)
EXCLUDE_PREFIXES = ("MABe22_",)  # important: exclude only the duplicate labs
LAPLACE_EPS      = 20.0          # additive smoothing added to every per-action duration sum
MIN_LEN          = 10            # minimum segment / window length kept, in frames
GAP_CLOSE        = 8             # same-action segments separated by at most this gap are merged
RARE_P_MIN       = 0.02          # p_min passed to clip_rare()
RARE_CAP         = 0.05          # cap passed to clip_rare()
USE_WINDOWS      = True          # restrict allocation to windows derived from tracking features
OPP_GAMMA        = 0.60          # opp reweight for sniff/chase/follow
DILATE_BY        = 4             # frames added to both ends of every segment before deconflicting

# ========================
# Config & schemas
# ========================
class Config:
    """Dataset paths and column schemas used throughout the script.

    Every path is derived from ``data_root``; ``submission_file`` is the
    name of the CSV written at the end of the run.
    """

    def __init__(self,
                 data_root: Path = Path("/kaggle/input/MABe-mouse-behavior-detection"),
                 submission_file: str = "submission.csv"):
        root = data_root
        self.data_root = root
        self.submission_file = submission_file

        # Metadata tables.
        self.train_csv = root / "train.csv"
        self.test_csv = root / "test.csv"

        # Per-lab parquet directories.
        self.train_annot_dir = root / "train_annotation"
        self.train_track_dir = root / "train_tracking"
        self.test_track_dir = root / "test_tracking"

        # A submission row identifies one predicted segment.
        segment_columns = {
            "video_id": pl.Int64,
            "agent_id": pl.Utf8,
            "target_id": pl.Utf8,
            "action": pl.Utf8,
            "start_frame": pl.Int64,
            "stop_frame": pl.Int64,
        }
        self.submission_schema = dict(segment_columns)
        # A solution row is a segment plus the lab and the labelled-behaviour list.
        self.solution_schema = {
            **segment_columns,
            "lab_id": pl.Utf8,
            "behaviors_labeled": pl.Utf8,
        }

# Module-wide configuration built with the default paths; read as a global by
# run_local_cv() and by the main section at the bottom of the file.
cfg = Config()

# ========================
# Utils
# ========================
def is_excluded_lab(lab: str) -> bool:
    """Return True when *lab* begins with one of the EXCLUDE_PREFIXES."""
    for prefix in EXCLUDE_PREFIXES:
        if lab.startswith(prefix):
            return True
    return False

def safe_json_loads(s) -> List[str]:
    """Parse a JSON-ish list of strings, never raising.

    Lists are stringified element-wise; strings are parsed as JSON, with a
    second attempt after swapping single quotes for double quotes. Anything
    that does not end up as a list yields an empty list.
    """
    if isinstance(s, list):
        return [str(item) for item in s]
    if not isinstance(s, str):
        # Covers None and every other non-string input.
        return []

    text = s.strip()
    if not text:
        return []

    # The first candidate that parses decides the result; the quote-swapped
    # variant is only consulted when the raw text is not valid JSON.
    for candidate in (text, text.replace("'", '"')):
        try:
            parsed = json.loads(candidate)
        except Exception:
            continue
        if isinstance(parsed, list):
            return [str(item) for item in parsed]
        return []
    return []

def _norm_mouse(x) -> str:
    s = str(x)
    return s if s.startswith("mouse") else f"mouse{s}"

def _strip_mouse(s) -> str:
    s=str(s); return s[5:] if s.startswith("mouse") else s

def validate_schema(df: pl.DataFrame, schema: Dict[str, pl.DataType], name: str) -> pl.DataFrame:
    """Make *df* conform to *schema*.

    Columns missing from *df* are appended as all-null columns of the
    requested dtype; existing columns whose dtype differs are cast.
    ``name`` is accepted for call-site readability and is not used.
    """
    # Pass 1: append the absent columns (in schema order).
    for column, wanted in schema.items():
        if column in df.columns:
            continue
        df = df.with_columns(pl.lit(None).cast(wanted).alias(column))

    # Pass 2: collect the casts that are still needed and apply them together.
    recasts = []
    for column, wanted in schema.items():
        if df[column].dtype != wanted:
            recasts.append(pl.col(column).cast(wanted))
    return df.with_columns(recasts) if recasts else df

# ========================
# Solution & spans
# ========================
def create_solution_df(meta: pl.DataFrame, cfg: Config) -> pl.DataFrame:
    """Load every available training annotation into one solution table.

    For each row of *meta* whose lab is not excluded and whose annotation
    parquet exists, the annotation is read, mouse ids are prefixed with
    ``mouse``, and ``lab_id`` / ``video_id`` / ``behaviors_labeled`` are
    attached. Files that fail to load are skipped (best effort).

    Returns:
        A DataFrame with the columns of ``cfg.solution_schema``, in schema order.

    Raises:
        ValueError: if no annotation could be loaded, or if any row has
            ``start_frame > stop_frame``.
    """
    schema_cols = list(cfg.solution_schema)
    recs: List[pl.DataFrame] = []

    def _pack(annot: pl.DataFrame, lab: str, vid: int, bl: str) -> pl.DataFrame:
        # Normalise one annotation file to the solution schema.
        a = annot
        if "agent_id" in a.columns:
            a = a.with_columns(pl.concat_str([pl.lit("mouse"), pl.col("agent_id").cast(pl.Utf8)]).alias("agent_id"))
        if "target_id" in a.columns:
            a = a.with_columns(pl.concat_str([pl.lit("mouse"), pl.col("target_id").cast(pl.Utf8)]).alias("target_id"))
        a = a.with_columns(
            pl.lit(lab).alias("lab_id"),
            pl.lit(int(vid)).alias("video_id"),
            pl.lit(bl if isinstance(bl, str) else json.dumps(bl)).alias("behaviors_labeled"),
        )
        a = a.select([c for c in ["video_id","agent_id","target_id","action","start_frame","stop_frame","lab_id","behaviors_labeled"] if c in a.columns])
        a = validate_schema(a, cfg.solution_schema, "Solution(part)")
        # validate_schema appends missing columns at the end, so a file that
        # lacked a column would end up with a different column order than the
        # others and break the strict vertical concat below. Reselect in
        # schema order so every part lines up.
        return a.select(schema_cols)

    for row in meta.to_dicts():
        lab = str(row["lab_id"])
        if is_excluded_lab(lab):
            continue
        vid = int(row["video_id"])
        p = cfg.train_annot_dir / lab / f"{vid}.parquet"
        if not p.exists():
            continue
        try:
            annot = pl.read_parquet(p)
            recs.append(_pack(annot, lab, vid, row.get("behaviors_labeled","")))
        except Exception:
            # Deliberate best effort: an unreadable or malformed file is skipped.
            continue

    if not recs:
        raise ValueError("No annotations loaded (check dataset paths).")
    sol = pl.concat(recs, how="vertical", parallel=True, rechunk=True)
    sol = validate_schema(sol, cfg.solution_schema, "Solution")
    if not (sol["start_frame"] <= sol["stop_frame"]).all():
        raise ValueError("start_frame > stop_frame exists.")
    return sol

def build_spans(meta: pl.DataFrame, split: str, cfg: Config) -> Dict[int, Tuple[int,int]]:
    """Map each video id to its half-open frame span ``(min, max + 1)``.

    The span is read from the first of ``video_frame`` / ``frame`` /
    ``frame_idx`` present in the video's tracking parquet. Excluded labs,
    missing files, files without a frame column and files that fail to
    load are skipped.
    """
    base_dir = cfg.train_track_dir if split == "train" else cfg.test_track_dir
    frame_names = ("video_frame", "frame", "frame_idx")
    spans: Dict[int, Tuple[int,int]] = {}

    for rec in meta.to_dicts():
        lab = str(rec["lab_id"])
        if is_excluded_lab(lab):
            continue
        vid = int(rec["video_id"])
        path = base_dir / lab / f"{vid}.parquet"
        if not path.exists():
            continue
        try:
            track = pl.read_parquet(path)
            frame_col = None
            for candidate in frame_names:
                if candidate in track.columns:
                    frame_col = candidate
                    break
            if frame_col is None:
                continue
            frames = track[frame_col]
            first = int(frames.min())
            end = int(frames.max()) + 1
            if end > first:
                spans[vid] = (first, end)
        except Exception:
            # Best effort: an unreadable file leaves the video without a span.
            continue
    return spans

# ========================
# Priors & timing
# ========================
def compute_action_priors(solution: pl.DataFrame, eps: float = LAPLACE_EPS):
    """Derive duration-based action priors from a solution table.

    Returns a 4-tuple:
        per_lab_weight: lab -> action -> smoothed share of labelled frames
        global_weight:  action -> smoothed share over all labs
        med_lab:        lab -> action -> median segment duration (int)
        med_glob:       action -> median segment duration (int)

    Shares are computed over every action seen anywhere in *solution*, with
    ``eps`` added to each action's summed duration before normalising.
    """
    sol = solution.with_columns((pl.col("stop_frame") - pl.col("start_frame")).alias("dur"))
    lab_totals = sol.group_by(["lab_id","action"]).agg(pl.col("dur").sum().alias("dur_sum"))
    all_totals = sol.group_by(["action"]).agg(pl.col("dur").sum().alias("dur_sum"))

    actions = set(all_totals["action"].to_list())

    def _smoothed_shares(rows) -> Dict[str, float]:
        # Turn (action, dur_sum) rows into additive-smoothed proportions.
        mass = {}
        for row in rows:
            mass[row["action"]] = float(row["dur_sum"])
        for act in actions:
            mass[act] = mass.get(act, 0.0) + eps
        total = sum(mass.values()) or 1.0
        return {act: mass[act] / total for act in actions}

    per_lab_weight: Dict[str, Dict[str, float]] = defaultdict(dict)
    for lab in lab_totals["lab_id"].unique().to_list():
        lab_rows = lab_totals.filter(pl.col("lab_id") == lab).to_dicts()
        per_lab_weight[str(lab)] = _smoothed_shares(lab_rows)

    global_weight = _smoothed_shares(all_totals.to_dicts())

    lab_medians = sol.group_by(["lab_id","action"]).median().select(["lab_id","action","dur"])
    all_medians = sol.group_by(["action"]).median().select(["action","dur"])

    med_lab: Dict[str, Dict[str,int]] = defaultdict(dict)
    for row in lab_medians.to_dicts():
        # int() truncates fractional medians.
        med_lab[str(row["lab_id"])][str(row["action"])] = int(row["dur"])

    med_glob = {}
    for row in all_medians.to_dicts():
        med_glob[row["action"]] = int(row["dur"])

    return per_lab_weight, global_weight, med_lab, med_glob

def compute_timing_priors(solution: pl.DataFrame, spans: Dict[int, Tuple[int,int]]):
    """Compute the median relative start position of every action.

    Each segment start is expressed as a fraction of its video span
    (clamped to [0, 1]; videos without a span use ``(0, 1)``).

    Returns:
        (per_lab, glob): ``lab -> action -> median fraction`` and
        ``action -> median fraction``.
    """
    records = []
    for row in solution.select(["lab_id","action","video_id","start_frame"]).to_dicts():
        lo, hi = spans.get(int(row["video_id"]), (0, 1))
        length = max(1, hi - lo)
        fraction = (int(row["start_frame"]) - lo) / length
        records.append({
            "lab_id": row["lab_id"],
            "action": row["action"],
            "start_pct": float(max(0, min(1, fraction))),
        })
    table = pl.DataFrame(records)

    per_lab: Dict[str, Dict[str,float]] = defaultdict(dict)
    by_lab = table.group_by(["lab_id","action"]).median().select(["lab_id","action","start_pct"])
    for row in by_lab.to_dicts():
        per_lab[str(row["lab_id"])][str(row["action"])] = float(row["start_pct"])

    glob = {}
    overall = table.group_by(["action"]).median().select(["action","start_pct"])
    for row in overall.to_dicts():
        glob[row["action"]] = float(row["start_pct"])

    return per_lab, glob

# ========================
# Tracking features → windows & opportunity
# ========================
def _pair_features(df: pl.DataFrame, agent: str, target: str, downsample: int = 1) -> Optional[pl.DataFrame]:
    """Build per-frame distance / speed features for one agent-target pair.

    Column names are auto-detected from fixed candidate lists. For each
    mouse only the first row per frame is kept, and the two mice are
    inner-joined on the frame column.

    Returns:
        A frame-sorted DataFrame with columns ``frame``, ``dist``,
        ``rel_speed`` (agent step length minus target step length) and
        ``ddist`` (frame-to-frame change of ``dist``), or ``None`` when a
        required column is missing or the pair shares no frames.
    """
    available = set(df.columns)

    def _pick(candidates):
        # First candidate column name present in the tracking table.
        for col_name in candidates:
            if col_name in available:
                return col_name
        return None

    frame_col = _pick(["video_frame","frame","frame_idx"])
    id_col    = _pick(["mouse_id","id","track_id","agent_id"])
    x_col     = _pick(["x","x_pos","x_position","x_mm","centroid_x","cx"])
    y_col     = _pick(["y","y_pos","y_position","y_mm","centroid_y","cy"])
    if not all([frame_col, id_col, x_col, y_col]):
        return None

    agent_key = _strip_mouse(agent)
    target_key = _strip_mouse(target)

    table = df.select([frame_col, id_col, x_col, y_col]).to_pandas()
    table[frame_col] = table[frame_col].astype(np.int64, copy=False)
    table[id_col] = table[id_col].astype(str, copy=False)

    agent_rows = table[table[id_col] == agent_key].copy()
    target_rows = table[table[id_col] == target_key].copy()
    if agent_rows.empty or target_rows.empty:
        return None
    # One row per frame and mouse: keep the first occurrence.
    agent_rows = agent_rows.drop_duplicates(subset=[frame_col], keep="first")
    target_rows = target_rows.drop_duplicates(subset=[frame_col], keep="first")

    paired = agent_rows.merge(target_rows, on=frame_col, how="inner", suffixes=("_a","_b"))
    if paired.empty:
        return None
    paired = paired.sort_values(frame_col)

    def _column(col_name, dtype):
        return paired[col_name].to_numpy(dtype, copy=False)

    ax = _column(f"{x_col}_a", np.float64)
    ay = _column(f"{y_col}_a", np.float64)
    bx = _column(f"{x_col}_b", np.float64)
    by = _column(f"{y_col}_b", np.float64)
    fr = _column(frame_col, np.int64)

    if downsample > 1:
        step = int(downsample)
        ax, ay, bx, by, fr = (arr[::step] for arr in (ax, ay, bx, by, fr))
        if ax.size == 0:
            return None

    def _step_length(u, v):
        # Displacement between consecutive rows; the first row gets 0.
        du = np.diff(u, prepend=u[0])
        dv = np.diff(v, prepend=v[0])
        return np.sqrt(du*du + dv*dv)

    dx = ax - bx
    dy = ay - by
    dist = np.sqrt(dx*dx + dy*dy)
    rel_speed = _step_length(ax, ay) - _step_length(bx, by)
    ddist = np.diff(dist, prepend=dist[0])

    features = pl.DataFrame({"frame": fr, "dist": dist, "rel_speed": rel_speed, "ddist": ddist})
    return features.sort("frame")

def _merge_intervals(iv: List[Tuple[int,int]]) -> List[Tuple[int,int]]:
    if not iv: return []
    iv = sorted(iv)
    out=[iv[0]]
    for s,e in iv[1:]:
        ps,pe = out[-1]
        if s<=pe: out[-1]=(ps, max(pe,e))
        else: out.append((s,e))
    return out

def make_windows(feat: pl.DataFrame, min_len: int,
                 q_dist: float=0.40, q_rel: float=0.60, q_ddist: float=0.40) -> List[Tuple[int,int]]:
    """Find candidate interaction windows from pair features.

    A row is flagged when ``dist`` is at or below its ``q_dist`` quantile,
    or when ``rel_speed`` is at or above its ``q_rel`` quantile while
    ``ddist`` is at or below its ``q_ddist`` quantile. Each run of
    consecutive flagged rows becomes the half-open window
    ``(first_frame, last_frame + 1)``; windows shorter than ``min_len``
    are dropped and the rest are merged.
    """
    if feat is None or len(feat) == 0:
        return []

    dist_thr = float(feat["dist"].quantile(q_dist))
    rel_thr = float(feat["rel_speed"].quantile(q_rel))
    ddist_thr = float(feat["ddist"].quantile(q_ddist))

    is_near = pl.col("dist") <= dist_thr
    is_closing = (pl.col("rel_speed") >= rel_thr) & (pl.col("ddist") <= ddist_thr)
    flags = feat.select((is_near | is_closing).alias("m")).to_series().to_list()
    frames = feat["frame"].to_list()

    windows = []
    first = None   # first frame of the run in progress, None when idle
    last = None    # most recent flagged frame of that run
    for frame, flag in zip(frames, flags):
        if flag:
            if first is None:
                first = frame
            last = frame
        elif first is not None:
            if last + 1 - first >= min_len:
                windows.append((first, last + 1))
            first = None
    # Flush a run that reaches the final row.
    if first is not None and last + 1 - first >= min_len:
        windows.append((first, last + 1))

    return _merge_intervals(windows)

def action_opportunity_weights(
    feat: Optional[pl.DataFrame],
    q_sniff_dist: float = 0.35,
    q_chase_rel: float = 0.70,
    q_chase_ddist: float = 0.35,
    q_follow_mid_low: float = 0.30,
    q_follow_mid_high: float = 0.70,
    q_follow_rel: float = 0.40,
) -> Dict[str, float]:
    """Estimate how often the pair features allow sniff / chase / follow.

    Each weight is the fraction of rows satisfying that action's
    quantile-based condition, plus 1e-3. Without features every weight is
    1e-3.
    """
    floor = 1e-3
    if feat is None or len(feat) == 0:
        return {"sniff": floor, "chase": floor, "follow": floor}

    rows = max(1, int(len(feat)))
    dist = feat["dist"]
    rel = feat["rel_speed"]
    ddist = feat["ddist"]

    sniff_dist = float(dist.quantile(q_sniff_dist))
    chase_rel = float(rel.quantile(q_chase_rel))
    chase_ddist = float(ddist.quantile(q_chase_ddist))
    follow_lo = float(dist.quantile(q_follow_mid_low))
    follow_hi = float(dist.quantile(q_follow_mid_high))
    follow_rel = float(rel.abs().quantile(q_follow_rel))

    def _share(mask) -> float:
        # Fraction of rows where the boolean mask holds.
        return float(mask.sum()) / rows

    # sniff: the pair is close.
    sniff = _share(dist <= sniff_dist)
    # chase: agent moves faster than target while the distance shrinks.
    chase = _share((rel >= chase_rel) & (ddist <= chase_ddist))
    # follow: mid-range distance with similar speeds.
    follow = _share((dist >= follow_lo) & (dist <= follow_hi) & (rel.abs() <= follow_rel))

    return {"sniff": sniff + floor, "chase": chase + floor, "follow": follow + floor}

# ========================
# Allocation (no ML)
# ========================
def order_by_timing(actions: List[str], lab_id: str,
                    timing_lab: Dict[str, Dict[str,float]],
                    timing_glob: Dict[str,float],
                    tiebreak: Dict[str,int]) -> List[str]:
    """Sort *actions* by their typical relative start time.

    The lab-specific timing is preferred, then the global timing, then
    0.5. Equal timings are ordered by ``tiebreak`` (missing entries rank
    as 99); remaining ties keep their input order.
    """
    lab_timing = timing_lab.get(lab_id, {})

    def sort_key(action: str) -> Tuple[float,int]:
        if action in lab_timing:
            when = lab_timing[action]
        else:
            when = timing_glob.get(action, 0.5)
        return (when, tiebreak.get(action, 99))

    return sorted(actions, key=sort_key)

def clip_rare(weights: Dict[str,float], actions: List[str], p_min: float, cap: float) -> Dict[str,float]:
    """Restrict *weights* to *actions*, clip rare ones and renormalise.

    Negative or missing weights become 0. A weight below ``p_min`` is
    limited to at most ``cap``. The result sums to 1 unless every weight
    is 0.

    NOTE(review): whenever ``cap >= p_min`` (true for the module defaults)
    the clip never changes a value, because anything below ``p_min`` is
    already below ``cap``.
    """
    clipped = {}
    for action in actions:
        value = max(0.0, float(weights.get(action, 0.0)))
        if value < p_min:
            value = min(value, cap)
        clipped[action] = value
    total = sum(clipped.values()) or 1.0
    return {action: clipped[action] / total for action in actions}

def allocate_in_windows(windows: List[Tuple[int,int]],
                        ordered_actions: List[str],
                        weights: Dict[str,float],
                        med_dur: Dict[str,int],
                        total_frames: int) -> List[Tuple[str,int,int]]:
    """Hand out window frames to actions, in order, without overlap.

    Each action requests ``max(int(weight * total_frames), median
    duration)`` frames, limited by what is still unassigned. Frames are
    consumed from the windows front to back, so one action may receive
    several segments when its quota crosses a window boundary.

    Returns:
        ``(action, start, stop)`` tuples in allocation order.
    """
    budget = sum(hi - lo for lo, hi in windows)
    if budget <= 0:
        return []

    n_windows = len(windows)
    segments = []
    w = 0                      # index of the window being consumed
    pos, end = windows[0]      # unconsumed part of that window

    for action in ordered_actions:
        if budget <= 0:
            break
        by_weight = int(weights.get(action, 0) * total_frames)
        by_median = int(med_dur.get(action, 0) or 0)
        quota = min(max(by_weight, by_median), budget)

        filled = 0
        while filled < quota and w < n_windows:
            if pos >= end:
                # Current window is exhausted (or empty): move on.
                w += 1
                if w >= n_windows:
                    break
                pos, end = windows[w]
                continue
            step = min(quota - filled, end - pos)
            segments.append((action, pos, pos + step))
            filled += step
            budget -= step
            pos += step
            if pos >= end:
                w += 1
                if w < n_windows:
                    pos, end = windows[w]
    return segments

def smooth_segments(segs: List[Tuple[str,int,int]], min_len: int, gap_close: int) -> List[Tuple[str,int,int]]:
    """Drop short segments and fuse nearby segments of the same action.

    Args:
        segs: ``(action, start, stop)`` tuples with half-open frame ranges.
        min_len: segments shorter than this many frames are discarded
            before merging.
        gap_close: a segment is fused into the previously kept one when
            both share the action and the gap between them is at most
            this many frames (overlaps count as a non-positive gap).

    Returns:
        Segments sorted by ``(start, stop, action)`` after filtering and
        merging.
    """
    if not segs:
        return []
    ordered = sorted(segs, key=lambda seg: (seg[1], seg[2], seg[0]))
    ordered = [seg for seg in ordered if seg[2] - seg[1] >= min_len]
    if not ordered:
        return []

    out = [ordered[0]]
    for action, start, stop in ordered[1:]:
        prev_action, prev_start, prev_stop = out[-1]
        if action == prev_action and start - prev_stop <= gap_close:
            # Take the larger stop: a segment nested inside the previous
            # one must not shrink it.
            out[-1] = (action, prev_start, max(prev_stop, stop))
        else:
            out.append((action, start, stop))
    return out

# Deconflict within same (agent,target): no overlapping frames
def subtract_interval(interval: Tuple[int,int], occ: List[Tuple[int,int]]) -> List[Tuple[int,int]]:
    """Return the parts of *interval* not covered by any interval in *occ*.

    All intervals are half-open. An empty or inverted *interval* yields
    an empty list.
    """
    start, stop = interval
    if start >= stop:
        return []

    remaining = [(start, stop)]
    for busy_lo, busy_hi in occ:
        survivors = []
        for lo, hi in remaining:
            if busy_hi <= lo or hi <= busy_lo:
                # No overlap: the piece survives untouched.
                survivors.append((lo, hi))
                continue
            if lo < busy_lo:
                survivors.append((lo, busy_lo))
            if busy_hi < hi:
                survivors.append((busy_hi, hi))
        remaining = survivors
        if not remaining:
            break
    return remaining

def add_interval(occ: List[Tuple[int,int]], interval: Tuple[int,int]) -> List[Tuple[int,int]]:
    """Return *occ* plus *interval* as a sorted list of merged intervals.

    Overlapping and touching intervals are fused. The input list is left
    untouched; callers must use the returned list.
    """
    merged: List[List[int]] = []
    # Sort a copy so the caller's list is not appended to or reordered.
    for start, stop in sorted([*occ, interval]):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], stop)
        else:
            merged.append([start, stop])
    return [(start, stop) for start, stop in merged]

def deconflict_segments(segs: List[Tuple[str,int,int]],
                        priority: Dict[str,float],
                        min_len: int) -> List[Tuple[str,int,int]]:
    """Resolve overlaps so that no frame belongs to two segments.

    Segments are visited from highest to lowest action priority (ties by
    start, then stop). Each keeps only the frames not yet claimed, and
    leftover pieces shorter than ``min_len`` are discarded without
    claiming their frames.

    Returns:
        Surviving ``(action, start, stop)`` pieces sorted by
        ``(start, stop, action)``.
    """
    if not segs:
        return []

    def visit_order(seg):
        return (-priority.get(seg[0], 0.0), seg[1], seg[2])

    claimed: List[Tuple[int,int]] = []
    survivors: List[Tuple[str,int,int]] = []
    for action, start, stop in sorted(segs, key=visit_order):
        free_pieces = subtract_interval((start, stop), claimed)
        for lo, hi in free_pieces:
            if hi - lo < min_len:
                continue
            survivors.append((action, lo, hi))
            claimed = add_interval(claimed, (lo, hi))

    survivors.sort(key=lambda seg: (seg[1], seg[2], seg[0]))
    return survivors

# ========================
# Predict
# ========================
def predict_without_ml(dataset: pl.DataFrame, split: str, cfg: Config,
                       priors_per_lab, priors_global,
                       meddur_per_lab, meddur_global,
                       timing_lab, timing_global,
                       use_windows: bool = USE_WINDOWS,
                       min_len: int = MIN_LEN, gap_close: int = GAP_CLOSE,
                       p_min: float = RARE_P_MIN, cap: float = RARE_CAP,
                       opp_gamma: float = OPP_GAMMA,
                       dilate_by: int = DILATE_BY) -> pl.DataFrame:
    """Produce rule-based segment predictions for every video in *dataset*.

    For each metadata row the video's tracking parquet is read to get the
    frame span, ``behaviors_labeled`` is parsed into (agent, target, action)
    triples, and for every (agent, target) pair the pipeline is:
    prior weights -> timing order -> candidate windows -> opportunity
    reweighting -> allocation -> smoothing -> fallback -> dilation ->
    deconfliction.

    Args:
        dataset: metadata rows; ``lab_id``, ``video_id`` and
            ``behaviors_labeled`` are read from each row.
        split: ``"test"`` reads ``cfg.test_track_dir``; any other value reads
            ``cfg.train_track_dir``.
        priors_per_lab, priors_global: action weights as returned by
            compute_action_priors().
        meddur_per_lab, meddur_global: median durations as returned by
            compute_action_priors().
        timing_lab, timing_global: as returned by compute_timing_priors().
        use_windows: derive candidate windows from pair features instead of
            using the whole video span.
        min_len, gap_close: forwarded to smooth_segments() / make_windows() /
            deconflict_segments().
        p_min, cap: forwarded to clip_rare().
        opp_gamma: exponent applied to the opportunity weights; ``<= 0``
            disables the reweighting.
        dilate_by: frames added on both sides of each segment; falsy or
            ``<= 0`` disables dilation.

    Returns:
        A DataFrame with ``cfg.submission_schema`` columns. When nothing was
        predicted and ``split == "train"`` the frame is empty.

    Raises:
        ValueError: when nothing was predicted and ``split != "train"``.

    Any exception raised while processing one video is swallowed and that
    video is skipped (rows already appended for it are kept).
    """
    track_dir = cfg.test_track_dir if split=="test" else cfg.train_track_dir
    rows=[]
    # Fixed action ranking used for ordering and as a tie-breaker; unknown actions rank 99.
    canonical = {"chase":0,"chaseattack":1,"attack":2,"approach":3,"avoid":4,"mount":5,"submit":6,"sniff":7,"follow":8}

    for r in tqdm(dataset.to_dicts(), total=len(dataset), desc=f"Predict {split}"):
        lab = str(r["lab_id"])
        if is_excluded_lab(lab):
            continue
        vid = int(r["video_id"])
        p = track_dir / lab / f"{vid}.parquet"
        if not p.exists():
            continue
        try:
            trk = pl.read_parquet(p)
            frame_col = next((c for c in ["video_frame","frame","frame_idx"] if c in trk.columns), None)
            if frame_col is None:
                continue
            # Half-open video span [s0, e0).
            s0 = int(trk[frame_col].min()); e0 = int(trk[frame_col].max())+1
            total = e0 - s0
            if total<=0:
                continue

            # Each labelled behaviour is expected as "agent,target,action"; other shapes are ignored.
            raw = safe_json_loads(r.get("behaviors_labeled",""))
            triples=[]
            for t in raw:
                parts=[pp.strip() for pp in str(t).replace("'","").split(",")]
                if len(parts)==3:
                    triples.append(parts)
            if not triples:
                continue
            beh = pl.DataFrame(triples, schema=["agent","target","action"], orient="row")

            for (agent, target), g in beh.group_by(["agent","target"]):
                actions = sorted(list(set(g["action"].to_list())), key=lambda a: canonical.get(a,99))
                if not actions:
                    continue

                # Lab-specific priors when present and non-empty, otherwise the global ones.
                w_map = priors_per_lab.get(lab, {}) or priors_global
                md_map= meddur_per_lab.get(lab, {}) or meddur_global
                weights = clip_rare(w_map, actions, p_min=p_min, cap=cap)
                ordered = order_by_timing(actions, lab, timing_lab, timing_global, canonical)

                if use_windows:
                    feat = _pair_features(trk, _norm_mouse(agent), _norm_mouse(target), downsample=1)
                    # No features or no qualifying window: fall back to the whole video span.
                    wins = make_windows(feat, min_len=min_len) if feat is not None else [(s0,e0)]
                    if not wins: wins=[(s0,e0)]
                else:
                    feat = None
                    wins=[(s0,e0)]
                wins = _merge_intervals(wins)
                allowed_total = sum(e-s for s,e in wins)
                if allowed_total<=0:
                    continue

                # opportunity reweight (sniff/chase/follow only)
                if opp_gamma > 0.0:
                    opp = action_opportunity_weights(feat) if (use_windows and feat is not None) else {"sniff":1e-3,"chase":1e-3,"follow":1e-3}
                    TARGET = {"sniff","chase","follow"}
                    w_adj = {}
                    for a in actions:
                        base = float(weights.get(a, 0.0))
                        mult = (float(opp.get(a, 1e-3)) ** opp_gamma) if a in TARGET else 1.0
                        w_adj[a] = base * mult
                    # Renormalise so the adjusted weights sum to 1 (unless all are 0).
                    s_w = sum(w_adj.values()) or 1.0
                    weights = {a: w_adj[a]/s_w for a in actions}

                segs = allocate_in_windows(wins, ordered, weights, md_map, allowed_total)
                segs = smooth_segments(segs, min_len=min_len, gap_close=gap_close)

                # --- Fallback: if still empty, force one short seg (prevents zero score) ---
                if not segs and allowed_total > 0:
                    # Highest-weight action; ties go to the lower canonical rank.
                    best_a = max(actions, key=lambda a: (weights.get(a,0.0), -canonical.get(a,99)))
                    fs, fe = wins[0]
                    # Length: at least min_len, or the lab's summed median durations divided by the
                    # number of actions, but never past the end of the first window.
                    fe = min(fe, fs + max(min_len, int((sum(meddur_per_lab.get(lab,{}).values()) or 0) / max(1,len(actions)))))
                    if fe > fs:
                        segs = [(best_a, fs, fe)]

                # dilation
                if dilate_by and dilate_by > 0 and segs:
                    dil = []
                    for a, s, e in segs:
                        # Grow both ends, clamped to the video span.
                        ss = max(s0, s - dilate_by)
                        ee = min(e0, e + dilate_by)
                        if ee > ss:
                            dil.append((a, ss, ee))
                    segs = smooth_segments(dil, min_len=min_len, gap_close=gap_close)

                # deconflict within pair
                if segs:
                    # Priority = weight plus a tiny canonical-rank bonus that only breaks exact ties.
                    pri = {a: float(weights.get(a, 0.0)) + 1e-9*(1.0/(1+canonical.get(a,99))) for a in actions}
                    segs = deconflict_segments(segs, pri, min_len=min_len)

                for a, s, e in segs:
                    if e>s:
                        rows.append((vid, _norm_mouse(agent), _norm_mouse(target), a, int(s), int(e)))
        except Exception:
            # Best effort: any failure skips the rest of this video.
            continue

    if not rows:
        if split == "train":
            return pl.DataFrame([], schema=cfg.submission_schema)
        raise ValueError("No predictions generated.")
    sub = pl.DataFrame(rows, schema=cfg.submission_schema, orient="row")
    sub = validate_schema(sub, cfg.submission_schema, "Submission")
    return sub

# ========================
# Scorer (official-compatible, robust active labels)
# ========================
class HostVisibleError(Exception):
    """Raised by the scorer when a submission breaks a scoring rule."""

def single_lab_f1(lab_solution: pl.DataFrame, lab_submission: pl.DataFrame, beta: float = 1.0) -> float:
    """Score one lab: the mean over actions of the frame-level F-beta.

    Args:
        lab_solution: ground-truth rows of one lab; the ``label_key``,
            ``video_id``, ``behaviors_labeled``, ``start_frame`` and
            ``stop_frame`` columns are read.
        lab_submission: predicted rows for the same videos; the
            ``prediction_key``, ``video_id``, ``agent_id``, ``target_id``,
            ``action``, ``start_frame`` and ``stop_frame`` columns are read.
        beta: F-beta weighting.

    Returns:
        The unweighted mean F-beta over all actions seen in either the
        labels or the accepted predictions; 0.0 when there are none.

    Raises:
        HostVisibleError: when, within one video, an agent/target pair has
            the same frame predicted under two different prediction keys.
    """
    # Ground-truth frames per label key.
    label_frames: Dict[str, Set[int]] = defaultdict(set)
    for row in lab_solution.to_dicts():
        label_frames[row["label_key"]].update(range(int(row["start_frame"]), int(row["stop_frame"])))

    # Per video: the set of "agent,target,action" triples that were labelled at all.
    active_by_video: Dict[int, Set[str]] = {}
    for row in lab_solution.select(["video_id", "behaviors_labeled"]).unique().to_dicts():
        s: Set[str] = set()
        for item in safe_json_loads(row["behaviors_labeled"]):  # robust parse
            parts = [p.strip() for p in str(item).replace("'", "").split(",")]
            if len(parts) == 3:
                a, t, act = parts
                s.add(f"{_norm_mouse(a)},{_norm_mouse(t)},{act}")
        active_by_video[int(row["video_id"])] = s

    # Predicted frames per prediction key, restricted to active triples.
    prediction_frames: Dict[str, Set[int]] = defaultdict(set)
    for video_id in lab_solution["video_id"].unique():
        vid = int(video_id)
        active = active_by_video.get(vid, set())
        # Frames already claimed per agent/target pair; reset for every video.
        predicted_mouse_pairs: Dict[str, Set[int]] = defaultdict(set)
        for row in lab_submission.filter(pl.col("video_id") == vid).to_dicts():
            triple_norm = f"{_norm_mouse(row['agent_id'])},{_norm_mouse(row['target_id'])},{row['action']}"
            if triple_norm not in active:
                # Predictions for behaviours that were not labelled in this video are ignored.
                continue
            pred_key = row["prediction_key"]
            agent_target = f"{_norm_mouse(row['agent_id'])},{_norm_mouse(row['target_id'])}"
            new_frames = set(range(int(row["start_frame"]), int(row["stop_frame"])))
            # Frames repeated under the same key are tolerated; only the new ones are checked.
            new_frames -= prediction_frames[pred_key]
            if predicted_mouse_pairs[agent_target] & new_frames:
                raise HostVisibleError("Multiple predictions for the same frame from one agent/target pair")
            prediction_frames[pred_key].update(new_frames)
            predicted_mouse_pairs[agent_target].update(new_frames)

    tps: Dict[str, int] = defaultdict(int)
    fns: Dict[str, int] = defaultdict(int)
    fps: Dict[str, int] = defaultdict(int)
    distinct_actions: Set[str] = set()

    for key, pred_frames in prediction_frames.items():
        # The action is whatever follows the last "_" of the key.
        action = key.split("_")[-1]
        distinct_actions.add(action)
        gt_frames = label_frames.get(key, set())
        tps[action] += len(pred_frames & gt_frames)
        fns[action] += len(gt_frames - pred_frames)
        fps[action] += len(pred_frames - gt_frames)

    # Label keys with no prediction at all count entirely as false negatives.
    for key, gt_frames in label_frames.items():
        action = key.split("_")[-1]
        distinct_actions.add(action)
        if key not in prediction_frames:
            fns[action] += len(gt_frames)

    if not distinct_actions:
        return 0.0

    beta2 = beta * beta
    f_scores: List[float] = []
    for action in distinct_actions:
        tp, fn, fp = tps[action], fns[action], fps[action]
        denom = (1 + beta2) * tp + beta2 * fn + fp
        f_scores.append(0.0 if denom == 0 else (1 + beta2) * tp / denom)
    return sum(f_scores) / len(f_scores)

def mouse_fbeta(solution: pl.DataFrame, submission: pl.DataFrame, beta: float = 1.0) -> float:
    """Average the per-lab F-beta of *submission* against *solution*.

    Both tables are reduced to their scoring columns and cast. Predictions
    for videos absent from *solution* are dropped. Each lab is scored with
    single_lab_f1() and the lab scores are averaged.

    Raises:
        ValueError: if either table contains ``start_frame > stop_frame``.
    """
    submission_types = {
        "video_id": pl.Int64,"agent_id": pl.Utf8,"target_id": pl.Utf8,"action": pl.Utf8,
        "start_frame": pl.Int64,"stop_frame": pl.Int64
    }
    solution_types = {**submission_types, "lab_id": pl.Utf8, "behaviors_labeled": pl.Utf8}

    solution = solution.select(list(solution_types))
    submission = submission.select(list(submission_types))
    solution = validate_schema(solution, solution_types, "Solution")
    submission = validate_schema(submission, submission_types, "Submission")

    if not (solution["start_frame"] <= solution["stop_frame"]).all():
        raise ValueError("Solution has start>stop")
    if not (submission["start_frame"] <= submission["stop_frame"]).all():
        raise ValueError("Submission has start>stop")

    # Only videos that have ground truth can be scored.
    scored_videos = solution["video_id"].unique()
    submission = submission.filter(pl.col("video_id").is_in(scored_videos))

    # "<video>_<agent>_<target>_<action>" identifies one behaviour stream.
    stream_key = pl.concat_str(
        [
            pl.col("video_id").cast(pl.Utf8),
            pl.col("agent_id").cast(pl.Utf8),
            pl.col("target_id").cast(pl.Utf8),
            pl.col("action"),
        ],
        separator="_",
    )
    solution = solution.with_columns(stream_key.alias("label_key"))
    submission = submission.with_columns(stream_key.alias("prediction_key"))

    lab_scores: List[float] = []
    for lab_id in solution["lab_id"].unique():
        truth = solution.filter(pl.col("lab_id") == lab_id)
        preds = submission.filter(pl.col("video_id").is_in(truth["video_id"].unique()))
        lab_scores.append(single_lab_f1(truth, preds, beta=beta))

    if not lab_scores:
        return 0.0
    return sum(lab_scores) / len(lab_scores)

# ========================
# Valid video filtering & folds
# ========================
def is_valid_row(r: dict, split: str, cfg: Config) -> bool:
    """Tell whether a metadata row can be used for prediction and scoring.

    A row is valid when its lab is not excluded, its tracking parquet
    exists for the given split, and ``behaviors_labeled`` contains at
    least one entry with three or more comma-separated fields.
    """
    lab = str(r["lab_id"])
    if is_excluded_lab(lab):
        return False

    vid = int(r["video_id"])
    base_dir = cfg.train_track_dir if split == "train" else cfg.test_track_dir
    if not (base_dir / lab / f"{vid}.parquet").exists():
        return False

    for item in safe_json_loads(r.get("behaviors_labeled","")):
        if len(str(item).split(",")) >= 3:
            return True
    return False

def valid_meta_only(meta: pl.DataFrame, split: str, cfg: Config) -> pl.DataFrame:
    """Return the rows of *meta* accepted by is_valid_row(), keeping the schema."""
    kept = []
    for row in meta.to_dicts():
        if is_valid_row(row, split, cfg):
            kept.append(row)
    if not kept:
        # Empty frame that still carries the original columns and dtypes.
        return meta.head(0)
    return pl.DataFrame(kept, schema=meta.schema)

def limit_videos(meta: pl.DataFrame, n: Optional[int], seed: int = 42) -> pl.DataFrame:
    """Keep the rows of at most *n* randomly chosen videos.

    Args:
        meta: metadata table with a ``video_id`` column.
        n: maximum number of distinct videos; ``None`` or ``<= 0`` disables
            the limit.
        seed: seed of the random sampler.

    Returns:
        *meta* unchanged when it already has at most *n* distinct videos,
        otherwise the rows of *n* videos sampled without replacement.
    """
    if (n is None) or (n <= 0):
        return meta
    # maintain_order=True makes the candidate list, and so the seeded sample,
    # reproducible; the default unique() order is not guaranteed.
    vids = meta["video_id"].unique(maintain_order=True).to_list()
    # Compare against the number of distinct videos (not rows): sampling more
    # ids than exist without replacement would raise.
    if len(vids) <= n:
        return meta
    rng = np.random.RandomState(seed)
    pick = rng.choice(vids, size=n, replace=False).tolist()
    return meta.filter(pl.col("video_id").is_in(pick))

def build_folds_on_valid(valid_meta: pl.DataFrame, n_splits: int = 5) -> List[Tuple[np.ndarray, np.ndarray]]:
    """Split the rows of *valid_meta* into GroupKFold folds grouped by lab.

    Returns:
        One ``(train_positions, validation_positions)`` pair per fold,
        indexing rows of *valid_meta*.
    """
    from sklearn.model_selection import GroupKFold

    frame = valid_meta.select(["video_id","lab_id"]).to_pandas()
    row_positions = np.arange(len(frame))
    lab_groups = frame["lab_id"].values
    splitter = GroupKFold(n_splits=n_splits)
    return list(splitter.split(row_positions, groups=lab_groups))

def run_local_cv(train_meta: pl.DataFrame,
                 n_splits: int = 5, beta: float = 1.0,
                 max_videos_per_fold: Optional[int] = MAX_VIDEOS_PER_FOLD,
                 use_windows: bool = USE_WINDOWS, min_len: int = MIN_LEN, gap_close: int = GAP_CLOSE,
                 p_min: float = RARE_P_MIN, cap: float = RARE_CAP,
                 opp_gamma: float = OPP_GAMMA, dilate_by: int = DILATE_BY) -> List[float]:
    """Run lab-grouped cross-validation of the rule-based predictor.

    Folds are built over the valid training videos with labs as groups, so
    a lab is never in both the training and validation part of a fold.
    Priors are fitted on the training labs, predictions are made for (at
    most ``max_videos_per_fold``) validation videos and scored with
    mouse_fbeta(). Uses the module-level ``cfg``.

    Returns:
        One score per fold. A fold without valid videos, or one that raises
        HostVisibleError while scoring, contributes 0.0. When there are no
        valid training videos at all, ``[0.0] * n_splits`` is returned.
    """

    valid_train = valid_meta_only(train_meta, "train", cfg)
    if len(valid_train) == 0:
        print("[CV] No valid videos in train.")
        return [0.0]*n_splits

    folds = build_folds_on_valid(valid_train, n_splits=n_splits)
    scores=[]
    for i, (tr_idx, va_idx) in enumerate(folds, 1):
        # Translate row positions into the sets of labs on each side of the split.
        tr_labs = set(valid_train[tr_idx.tolist()]["lab_id"].unique().to_list())
        va_labs = set(valid_train[va_idx.tolist()]["lab_id"].unique().to_list())

        # Priors come from every row of the training labs; validation uses valid rows only.
        tr_meta = train_meta.filter(pl.col("lab_id").is_in(list(tr_labs)))
        val_meta_valid = valid_train.filter(pl.col("lab_id").is_in(list(va_labs)))
        # Fold-specific seed so each fold samples its own subset.
        val_meta_valid = limit_videos(val_meta_valid, max_videos_per_fold, seed=100+i)

        if len(val_meta_valid) == 0:
            print(f"[Fold {i}] no valid videos -> score=0.000000")
            scores.append(0.0)
            continue

        solution_tr = create_solution_df(tr_meta, cfg)
        spans_tr    = build_spans(tr_meta, "train", cfg)
        per_lab, global_w, med_lab, med_glob = compute_action_priors(solution_tr, eps=LAPLACE_EPS)
        time_lab, time_glob = compute_timing_priors(solution_tr, spans_tr)

        # split="train" because validation videos live in the training tracking directory.
        pred_val = predict_without_ml(
            val_meta_valid, "train", cfg,
            priors_per_lab=per_lab, priors_global=global_w,
            meddur_per_lab=med_lab, meddur_global=med_glob,
            timing_lab=time_lab, timing_global=time_glob,
            use_windows=use_windows, min_len=min_len, gap_close=gap_close,
            p_min=p_min, cap=cap, opp_gamma=opp_gamma, dilate_by=dilate_by
        )

        solution_val = create_solution_df(val_meta_valid, cfg)

        try:
            sc = mouse_fbeta(solution_val, pred_val, beta=beta)
        except HostVisibleError as e:
            # A rule violation in the predictions scores the fold as 0 instead of aborting CV.
            print(f"[Fold {i}] HostVisibleError: {e} -> 0.0")
            sc = 0.0
        print(f"[Fold {i}] videos={val_meta_valid['video_id'].n_unique()} labs={len(va_labs)} score={sc:.6f}")
        scores.append(sc)

    print(f"[LocalCV] mean={np.mean(scores):.6f}  std={np.std(scores):.6f}")
    return scores

# ========================
# Main
# ========================
# Step 1: read the train / test metadata tables from the paths in cfg.
print("[1/3] Load metadata ...")
train = pl.read_csv(cfg.train_csv)
test  = pl.read_csv(cfg.test_csv)

# Step 2 (optional): lab-grouped local cross-validation; the fold scores are printed, not kept.
if DO_CV:
    print("\n[2/3] LocalCV ...")
    _ = run_local_cv(train,
                     n_splits=N_SPLITS, beta=BETA,
                     max_videos_per_fold=MAX_VIDEOS_PER_FOLD,
                     use_windows=USE_WINDOWS, min_len=MIN_LEN, gap_close=GAP_CLOSE,
                     p_min=RARE_P_MIN, cap=RARE_CAP,
                     opp_gamma=OPP_GAMMA, dilate_by=DILATE_BY)

# Step 3 (optional): fit priors on the full training set and predict the test set.
if DO_SUBMIT:
    print("\n[3/3] Predict test (build priors from full train) ...")
    solution_all = create_solution_df(train, cfg)
    spans_all    = build_spans(train, "train", cfg)
    per_lab, global_w, med_lab, med_glob = compute_action_priors(solution_all, eps=LAPLACE_EPS)
    time_lab, time_glob = compute_timing_priors(solution_all, spans_all)

    submission = predict_without_ml(
        test, "test", cfg,
        priors_per_lab=per_lab, priors_global=global_w,
        meddur_per_lab=med_lab, meddur_global=med_glob,
        timing_lab=time_lab, timing_global=time_glob,
        use_windows=USE_WINDOWS, min_len=MIN_LEN, gap_close=GAP_CLOSE,
        p_min=RARE_P_MIN, cap=RARE_CAP, opp_gamma=OPP_GAMMA, dilate_by=DILATE_BY
    )

    # Fix the column order and prepend a running row_id before writing the CSV.
    submission = submission.select(["video_id","agent_id","target_id","action","start_frame","stop_frame"]).with_row_index("row_id")
    submission.write_csv(cfg.submission_file)
    print("Saved:", cfg.submission_file)
    try:
        print(submission.head())
    except Exception:
        # Printing the preview is cosmetic; never fail the run because of it.
        pass
