In [None]:
# ================================================================
# MABe Challenge - Reworked Baseline (Single-file, unified main)
# - Unified Config flags (VIS / TRAIN / SUB)
# - FPS-aware feature engineering preserved from baseline
# - 5-fold CV (GroupKFold by video_id) for XGB (+ full XGB)
# - CatBoost full, and 3x LightGBM full (different params)
# - Train-time light Gaussian noise (optional)
# - Test-time augmentation (TTA) with Gaussian noise
# - Adaptive thresholds + temporal smoothing -> [start, stop) segments
# - robustify() for non-overlap & full coverage
# - All models saved per (section/mode/action) under CFG.MODEL_DIR
# ================================================================

import os, sys, gc, json, math, random, hashlib, warnings
from pathlib import Path
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import itertools
from tqdm.auto import tqdm

# tree models
import lightgbm as lgb
XGBOOST_AVAILABLE = False
CATBOOST_AVAILABLE = False
try:
    from xgboost import XGBClassifier
    XGBOOST_AVAILABLE = True
except Exception:
    pass

try:
    from catboost import CatBoostClassifier
    CATBOOST_AVAILABLE = True
except Exception:
    pass

from sklearn.model_selection import GroupKFold
from sklearn.base import clone
import joblib

# Optional GPU probe
try:
    import torch
    TORCH_OK = True
except Exception:
    TORCH_OK = False

# ---------------------------
# 1) CONFIG
# ---------------------------
class CFG:
    """Notebook-wide configuration: run flags, paths, CV/augmentation knobs, model params."""

    # Flags
    VIS   = False    # run EDA only
    TRAIN = False    # train & save weights
    SUB   = True     # load weights & predict & submit

    # Reproducibility
    SEED = 42

    # Paths (Kaggle)
    INPUT_DIR = Path("/kaggle/input/MABe-mouse-behavior-detection")
    WORK_DIR  = Path("/kaggle/working")

    # TRAIN writes weights/meta under WORK_DIR; SUB reads them from an attached dataset.
    # If both flags are set, TRAIN takes precedence.
    if TRAIN:
        MODEL_DIR = WORK_DIR / "models"
        META_DIR  = WORK_DIR / "model_meta"
        SUB_PATH  = WORK_DIR / "submission.csv"
    elif SUB:
        MODEL_DIR = Path("/kaggle/input/mice2025exp-notebook-test/models")
        META_DIR  = Path("/kaggle/input/mice2025exp-notebook-test/model_meta")
        SUB_PATH  = WORK_DIR / "submission.csv"
    else:
        # Neither TRAIN nor SUB (e.g. VIS-only EDA): still define the paths so that any
        # access to CFG.MODEL_DIR / META_DIR / SUB_PATH does not raise AttributeError.
        MODEL_DIR = WORK_DIR / "models"
        META_DIR  = WORK_DIR / "model_meta"
        SUB_PATH  = WORK_DIR / "submission.csv"

    # CV
    N_FOLDS_XGB = 5

    # Train-time augmentation
    TRAIN_NOISE_STD   = 0.0   # 0 to disable (e.g. 0.003~0.01 for light noise)
    TRAIN_NOISE_TIMES = 0     # 0 to disable

    # Test-time augmentation (TTA)
    TTA_NOISE_STD = 0.005
    TTA_N         = 5

    # GPU preference
    USE_GPU = True

    # LightGBM (3 variants)
    LGBM_PARAMSETS = [
        dict(n_estimators=225, learning_rate=0.07, min_child_samples=40,
             num_leaves=31, subsample=0.8, colsample_bytree=0.8, verbosity=-1),
        dict(n_estimators=150, learning_rate=0.10, min_child_samples=20,
             num_leaves=63, max_depth=8, subsample=0.7, colsample_bytree=0.9,
             reg_alpha=0.1, reg_lambda=0.1, verbosity=-1),
        dict(n_estimators=100, learning_rate=0.05, min_child_samples=30,
             num_leaves=127, max_depth=10, subsample=0.75, verbosity=-1),
    ]

    # XGB (fold & full)
    XGB_BASE = dict(
        n_estimators=180, learning_rate=0.08, max_depth=6,
        min_child_weight=5, subsample=0.8, colsample_bytree=0.8,
        eval_metric="logloss", n_jobs=-1, random_state=SEED
    )

    # CatBoost (full)
    CAT_BASE = dict(
        iterations=120, learning_rate=0.1, depth=6,
        verbose=False, allow_writing_files=False, random_seed=SEED
    )

    # ---------- EDA knobs ----------
    EDA_SAMPLE_TRAIN_IDXS = [5772, 484, 397, 428, 8669, 306]   # train rows loaded one by one; a few key frames plotted for each
    EDA_FRAMES_PER_VIDEO  = 3                                  # number of evenly spaced positions sampled per video for plotting
    EDA_DO_ANIM           = True                               # whether to animate the skeleton over behavior segments
    EDA_ANIM_PADDING_FR   = 20                                 # extra frames before/after a segment in the animation
    EDA_MAX_ANNOT_VIDEOS  = -1                                 # max videos scanned when building the full annotation table; -1 = no limit (scan all)
    EDA_SAVE_ANIM_HTML    = False                              # whether to save animations as HTML (under kaggle/working)
    EDA_RANDOM_SEED       = 42                                 # seed for EDA random sampling

# Ensure output dirs exist; only training writes weights and meta files.
if CFG.TRAIN:
    for _out_dir in (CFG.MODEL_DIR, CFG.META_DIR):
        _out_dir.mkdir(parents=True, exist_ok=True)



In [None]:
# ---------------------------
# 2) UTILS
# ---------------------------
def seed_everything(seed: int):
    """Seed Python's `random`, NumPy's global RNG and, if importable, PyTorch (CPU + CUDA).

    PYTHONHASHSEED is also exported. It only affects interpreters started afterwards,
    because the running interpreter fixes its hash seed at start-up.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    except Exception:
        # torch is optional here; seeding it is best-effort.
        pass

def has_gpu() -> bool:
    """Return True when GPU training should be attempted.

    Returns False whenever CFG.USE_GPU is off. Otherwise it trusts torch's CUDA probe
    when torch imported. Failing that, it treats any CUDA_VISIBLE_DEVICES value other
    than "" or "-1" as "GPU present". That last check is a heuristic and may report
    True without an attached device.
    """
    if CFG.USE_GPU:
        if TORCH_OK and torch.cuda.is_available():
            return True
        visible = os.environ.get("CUDA_VISIBLE_DEVICES", "")
        return visible not in ("", "-1")
    return False

# Seed every RNG once, probe GPU availability, and report the detected environment.
seed_everything(CFG.SEED)
GPU_AVAILABLE = has_gpu()
print(f"[ENV] GPU_AVAILABLE={GPU_AVAILABLE}, XGB={XGBOOST_AVAILABLE}, CAT={CATBOOST_AVAILABLE}")

def bpt_slug(s: str) -> str:
    """Return a short directory-safe slug for *s*: the first 10 hex chars of its MD5."""
    digest = hashlib.md5(s.encode("utf-8")).hexdigest()
    return digest[:10]

def save_columns(cols, path: Path):
    """Write each entry of *cols* (as str) on its own line to *path*, creating parent dirs.

    The file is always UTF-8 so it reads back identically on any platform. The platform
    default encoding used previously may not be UTF-8.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{c}\n" for c in cols)

def save_model_any(model, path: Path, model_type: str):
    """Persist *model* at *path*, preferring the library's native format.

    "xgb" and "cat" call ``model.save_model(path)``. "lgbm" saves the fitted
    ``model.booster_``. If the native save raises, or *model_type* is anything else, the
    model is pickled with joblib to the same path instead. Parent dirs are created.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    if model_type in ("xgb", "cat", "lgbm"):
        try:
            owner = model.booster_ if model_type == "lgbm" else model
            owner.save_model(str(path))
            return
        except Exception:
            pass  # fall through to the joblib pickle below
    joblib.dump(model, path)

# ---------------------------
# 3) DATA LOADING
# ---------------------------
train = pd.read_csv(CFG.INPUT_DIR / "train.csv")
test  = pd.read_csv(CFG.INPUT_DIR / "test.csv")
# Mice present in a video = number of non-missing mouse{1..4}_strain entries.
train["n_mice"] = train[[f"mouse{k}_strain" for k in range(1, 5)]].notna().sum(axis=1)
body_parts_tracked_list = list(np.unique(train.body_parts_tracked))

# Keypoints removed when a video/section tracks more than 5 body parts.
drop_body_parts = [
    "headpiece_bottombackleft", "headpiece_bottombackright",
    "headpiece_bottomfrontleft", "headpiece_bottomfrontright",
    "headpiece_topbackleft", "headpiece_topbackright",
    "headpiece_topfrontleft", "headpiece_topfrontright",
    "spine_1", "spine_2",
    "tail_middle_1", "tail_middle_2", "tail_midpoint",
]



In [None]:
# ---------------------------
# 4) DATA PIPELINE (baseline logic)
# ---------------------------
def generate_mouse_data(dataset, traintest, traintest_directory=None, generate_single=True, generate_pair=True, verbose=False):
    """Iterate over videos in *dataset* and yield single-mouse and mouse-pair inputs.

    Args:
        dataset: metadata table. The columns read here are lab_id, video_id,
            behaviors_labeled, pix_per_cm_approx and frames_per_second.
        traintest: "train" or "test". For "train", label frames are built from the
            matching train_annotation parquet.
        traintest_directory: tracking directory. Defaults to
            f"{CFG.INPUT_DIR}/{traintest}_tracking".
        generate_single: yield ("single", ...) items for behaviors whose target is 'self'.
        generate_pair: yield ("pair", ...) items for every ordered (agent, target) pair of
            tracked mice, when the video has any non-self behavior.
        verbose: print videos skipped because behaviors_labeled is not a string.

    Yields:
        (kind, coords, meta, labels_or_actions):
          * kind: "single" or "pair".
          * coords: per-frame coordinates divided by pix_per_cm_approx. For "single" the
            columns are (bodypart, x/y). For "pair" they are ("A"=agent | "B"=target,
            bodypart, x/y).
          * meta: one row per frame with video_id, agent_id, target_id, video_frame,
            frames_per_second.
          * labels_or_actions: for "train", a float DataFrame (frames x actions) that is
            1.0 inside annotated intervals and 0.0 elsewhere. For "test", the array of
            action names labeled for that agent/target.
    """
    assert traintest in ["train", "test"]
    if traintest_directory is None:
        traintest_directory = f"{CFG.INPUT_DIR}/{traintest}_tracking"

    for _, row in dataset.iterrows():
        lab_id  = row.lab_id
        video_id = row.video_id

        # Videos without a behaviors_labeled string (e.g. NaN) are skipped.
        if type(row.behaviors_labeled) != str:
            if verbose: print("No labeled behaviors:", lab_id, video_id)
            continue

        path = f"{traintest_directory}/{lab_id}/{video_id}.parquet"
        vid = pd.read_parquet(path)
        # Richly tracked videos: drop the redundant / fine-grained keypoints.
        if len(np.unique(vid.bodypart)) > 5:
            vid = vid.query("~ bodypart.isin(@drop_body_parts)")
        # Wide format: rows = video_frame; columns (coord, mouse_id, bodypart) are
        # reordered to (mouse_id, bodypart, coord) and sorted.
        pvid = vid.pivot(columns=["mouse_id","bodypart"], index="video_frame", values=["x","y"])
        del vid
        pvid = pvid.reorder_levels([1,2,0], axis=1).T.sort_index().T
        pvid /= row.pix_per_cm_approx  # convert to cm

        # behaviors_labeled: JSON list of "agent,target,action" strings. Quotes are
        # stripped, then the list is de-duplicated and sorted.
        vid_behaviors = json.loads(row.behaviors_labeled)
        vid_behaviors = sorted(list({b.replace("'", "") for b in vid_behaviors}))
        vid_behaviors = [b.split(",") for b in vid_behaviors]
        vid_behaviors = pd.DataFrame(vid_behaviors, columns=["agent","target","action"])

        if traintest == "train":
            # Train videos without an annotation file are skipped entirely.
            try:
                annot = pd.read_parquet(path.replace("train_tracking","train_annotation"))
            except FileNotFoundError:
                continue

        if generate_single:
            vb_single = vid_behaviors.query("target == 'self'")
            for mouse_id_str in np.unique(vb_single.agent):
                try:
                    # e.g. "mouse1" -> 1 (assumes single-digit mouse ids).
                    mouse_id = int(mouse_id_str[-1])
                    actions = np.unique(vb_single.query("agent == @mouse_id_str").action)
                    single_mouse = pvid.loc[:, mouse_id]
                    single_meta = pd.DataFrame({
                        "video_id": video_id, "agent_id": mouse_id_str, "target_id": "self",
                        "video_frame": single_mouse.index, "frames_per_second": row.frames_per_second
                    })
                    if traintest == "train":
                        single_label = pd.DataFrame(0.0, columns=actions, index=single_mouse.index)
                        annot_subset = annot.query("(agent_id == @mouse_id) & (target_id == @mouse_id)")
                        for i in range(len(annot_subset)):
                            ar = annot_subset.iloc[i]
                            # NOTE(review): .loc label slicing includes stop_frame. Confirm
                            # whether stop_frame is meant to be exclusive. If ar.action is
                            # not among `actions`, this assignment adds a new column whose
                            # other rows are NaN.
                            single_label.loc[ar["start_frame"]:ar["stop_frame"], ar.action] = 1.0
                        yield "single", single_mouse, single_meta, single_label
                    else:
                        yield "single", single_mouse, single_meta, actions
                except KeyError:
                    # Presumably the agent mouse is absent from this video's tracking; skip it.
                    pass

        if generate_pair:
            vb_pair = vid_behaviors.query("target != 'self'")
            if len(vb_pair) > 0:
                # Every ordered pair of tracked mice is yielded, including pairs with no
                # labeled pair action (their `actions` array is then empty).
                for agent, target in itertools.permutations(np.unique(pvid.columns.get_level_values("mouse_id")), 2):
                    agent_str = f"mouse{agent}"
                    target_str = f"mouse{target}"
                    actions = np.unique(vb_pair.query("(agent == @agent_str) & (target == @target_str)").action)
                    mouse_pair = pd.concat([pvid[agent], pvid[target]], axis=1, keys=["A","B"])
                    pair_meta = pd.DataFrame({
                        "video_id": video_id, "agent_id": agent_str, "target_id": target_str,
                        "video_frame": mouse_pair.index, "frames_per_second": row.frames_per_second
                    })
                    if traintest == "train":
                        pair_label = pd.DataFrame(0.0, columns=actions, index=mouse_pair.index)
                        annot_subset = annot.query("(agent_id == @agent) & (target_id == @target)")
                        for i in range(len(annot_subset)):
                            ar = annot_subset.iloc[i]
                            # Same inclusive-stop / column-enlargement caveat as the single case.
                            pair_label.loc[ar["start_frame"]:ar["stop_frame"], ar.action] = 1.0
                        yield "pair", mouse_pair, pair_meta, pair_label
                    else:
                        yield "pair", mouse_pair, pair_meta, actions

# ---- Feature helpers (fps-aware)
def _scale(n_frames_at_30fps, fps, ref=30.0):
    return max(1, int(round(n_frames_at_30fps * float(fps) / ref)))

def _scale_signed(n_frames_at_30fps, fps, ref=30.0):
    if n_frames_at_30fps == 0: return 0
    s = 1 if n_frames_at_30fps > 0 else -1
    mag = max(1, int(round(abs(n_frames_at_30fps) * float(fps) / ref)))
    return s * mag

def _fps_from_meta(meta_df, fallback_lookup, default_fps=30.0):
    if "frames_per_second" in meta_df.columns and pd.notnull(meta_df["frames_per_second"]).any():
        return float(meta_df["frames_per_second"].iloc[0])
    vid = meta_df["video_id"].iloc[0]
    return float(fallback_lookup.get(vid, default_fps))

def add_curvature_features(X, cx, cy, fps):
    """Add path-curvature and turning-rate columns computed from the (cx, cy) trajectory.

    Adds curv_mean_30 / curv_mean_60 (trailing mean of |v x a| / |v|^3) and turn_rate_30
    (trailing sum of absolute heading changes). Window sizes are given at 30 fps and
    rescaled via `_scale`. *X* is modified in place and returned.
    """
    dx, dy = cx.diff(), cy.diff()
    ddx, ddy = dx.diff(), dy.diff()
    speed = np.sqrt(dx**2 + dy**2)
    curvature = (dx * ddy - dy * ddx).abs() / (speed**3 + 1e-6)
    for win in (30, 60):
        span = _scale(win, fps)
        X[f"curv_mean_{win}"] = curvature.rolling(span, min_periods=max(1, span // 6)).mean()

    heading_change = np.arctan2(dy, dx).diff().abs()
    span = _scale(30, fps)
    X["turn_rate_30"] = heading_change.rolling(span, min_periods=max(1, span // 6)).sum()
    return X

def add_multiscale_features(X, cx, cy, fps):
    """Add trailing speed mean/std at 10/40/160-frame scales (30 fps reference).

    Speed is step length times fps. A scale is skipped when the series is shorter than
    its window. sp_ratio (short/long mean speed) is added when both sp_m10 and sp_m160
    exist. *X* is modified in place and returned.
    """
    speed = np.sqrt(cx.diff()**2 + cy.diff()**2) * float(fps)
    n_frames = len(speed)
    for scale in (10, 40, 160):
        win = _scale(scale, fps)
        if n_frames < win:
            continue
        min_p = max(1, win // 4)
        X[f"sp_m{scale}"] = speed.rolling(win, min_periods=min_p).mean()
        X[f"sp_s{scale}"] = speed.rolling(win, min_periods=min_p).std()
    if {"sp_m10", "sp_m160"}.issubset(X.columns):
        X["sp_ratio"] = X["sp_m10"] / (X["sp_m160"] + 1e-6)
    return X

def add_state_features(X, cx, cy, fps):
    """Add movement-state occupancy and transition-count columns.

    The smoothed speed (trailing 15-frame mean at 30 fps) is binned into 4 states with
    edges 0.5*fps, 2.0*fps and 5.0*fps. For 60- and 120-frame windows, the code adds the
    fraction of frames in each state (s{0..3}_{w}) and the number of state changes
    (trans_{w}). Binning/rolling is best-effort: any exception inside that section is
    swallowed and the columns added so far are kept. *X* is modified in place and returned.
    """
    speed = np.sqrt(cx.diff()**2 + cy.diff()**2) * float(fps)
    smooth_win = _scale(15, fps)
    speed_ma = speed.rolling(smooth_win, min_periods=max(1, smooth_win // 3)).mean()
    try:
        edges = [-np.inf, 0.5*fps, 2.0*fps, 5.0*fps, np.inf]
        states = pd.cut(speed_ma, bins=edges, labels=[0, 1, 2, 3]).astype(float)
        for w in (60, 120):
            win = _scale(w, fps)
            if len(states) < win:
                continue
            min_p = max(1, win // 6)
            for st in range(4):
                in_state = (states == st).astype(float)
                X[f"s{st}_{w}"] = in_state.rolling(win, min_periods=min_p).mean()
            changed = (states != states.shift(1)).astype(float)
            X[f"trans_{w}"] = changed.rolling(win, min_periods=min_p).sum()
    except Exception:
        pass
    return X

def add_longrange_features(X, cx, cy, fps):
    """Add long-horizon position and speed context columns.

    x_ml/y_ml{120,240} are trailing position means, added only if the series is long
    enough. x_e/y_e{60,120} are exponentially weighted position means. sp_pct{60,120}
    is the trailing percentile rank of speed, added only if long enough. Spans are given
    at 30 fps and rescaled via `_scale`. *X* is modified in place and returned.
    """
    n_frames = len(cx)
    for w in (120, 240):
        win = _scale(w, fps)
        if n_frames >= win:
            min_p = max(5, win // 6)
            X[f"x_ml{w}"] = cx.rolling(win, min_periods=min_p).mean()
            X[f"y_ml{w}"] = cy.rolling(win, min_periods=min_p).mean()

    for span in (60, 120):
        ew_span = _scale(span, fps)
        X[f"x_e{span}"] = cx.ewm(span=ew_span, min_periods=1).mean()
        X[f"y_e{span}"] = cy.ewm(span=ew_span, min_periods=1).mean()

    speed = np.sqrt(cx.diff()**2 + cy.diff()**2) * float(fps)
    for w in (60, 120):
        win = _scale(w, fps)
        if len(speed) >= win:
            X[f"sp_pct{w}"] = speed.rolling(win, min_periods=max(5, win // 6)).rank(pct=True)
    return X

def add_interaction_features(X, pair, avail_A, avail_B, fps):
    """Add agent/target interaction columns based on both mice's body_center.

    A_ld/B_ld{30,60} hold the trailing mean cosine between each mouse's velocity and the
    vector pointing away from the other mouse. chase_30 is the trailing mean of
    (distance decrease) x B_ld. sp_cor{60,120} is the trailing correlation of the two
    speeds. *X* is returned unchanged if either mouse lacks body_center; otherwise it is
    modified in place and returned.
    """
    if not ("body_center" in avail_A and "body_center" in avail_B):
        return X
    a_x, a_y = pair["A"]["body_center"]["x"], pair["A"]["body_center"]["y"]
    b_x, b_y = pair["B"]["body_center"]["x"], pair["B"]["body_center"]["y"]

    rel_x = a_x - b_x
    rel_y = a_y - b_y
    rel_d = np.sqrt(rel_x**2 + rel_y**2)

    a_vx, a_vy = a_x.diff(), a_y.diff()
    b_vx, b_vy = b_x.diff(), b_y.diff()
    # Per-frame step lengths; reused for the lead denominators and the speed correlation.
    a_speed = np.sqrt(a_vx**2 + a_vy**2)
    b_speed = np.sqrt(b_vx**2 + b_vy**2)

    a_lead = (a_vx*rel_x + a_vy*rel_y) / (a_speed*rel_d + 1e-6)
    b_lead = (b_vx*(-rel_x) + b_vy*(-rel_y)) / (b_speed*rel_d + 1e-6)
    for w in (30, 60):
        win = _scale(w, fps)
        min_p = max(1, win // 6)
        X[f"A_ld{w}"] = a_lead.rolling(win, min_periods=min_p).mean()
        X[f"B_ld{w}"] = b_lead.rolling(win, min_periods=min_p).mean()

    chase = -rel_d.diff() * b_lead
    win = _scale(30, fps)
    X["chase_30"] = chase.rolling(win, min_periods=max(1, win // 6)).mean()

    for w in (60, 120):
        win = _scale(w, fps)
        X[f"sp_cor{w}"] = a_speed.rolling(win, min_periods=max(1, win // 6)).corr(b_speed)
    return X

def transform_single(single_mouse, body_parts_tracked, fps):
    """Build per-frame features for a single mouse.

    Args:
        single_mouse: per-frame coordinates with columns (bodypart, "x"/"y"), as yielded
            by generate_mouse_data (coordinates divided by pix_per_cm_approx there).
        body_parts_tracked: body-part names. They fix the set and order of the
            pairwise-distance columns.
        fps: frame rate. Windows and lags are specified at 30 fps and rescaled with
            `_scale` / `_scale_signed`.

    Returns:
        float32 DataFrame with one row per frame of *single_mouse*.

    NOTE: feature names and insertion order define the model input (the trainer saves
    the resulting column list), so edit with care.
    """
    avail = single_mouse.columns.get_level_values(0)
    # Squared distance between every pair of tracked body parts present in this video.
    X = pd.DataFrame({
        f"{p1}+{p2}": np.square(single_mouse[p1] - single_mouse[p2]).sum(axis=1, skipna=False)
        for p1, p2 in itertools.combinations(body_parts_tracked, 2)
        if p1 in avail and p2 in avail
    })
    # Reindex to the full combination list, so missing body parts become NaN columns
    # and the column set does not depend on which parts a video contains.
    X = X.reindex(columns=[f"{p1}+{p2}" for p1, p2 in itertools.combinations(body_parts_tracked, 2)], copy=False)
    if all(p in single_mouse.columns for p in ["ear_left","ear_right","tail_base"]):
        # Squared displacement over `lag` frames (10 at 30 fps). sp_lf2 / sp_rt2 compare
        # the current ear position with the tail_base position `lag` frames earlier.
        lag = _scale(10, fps)
        sh = single_mouse[["ear_left","ear_right","tail_base"]].shift(lag)
        speeds = pd.DataFrame({
            "sp_lf":  np.square(single_mouse["ear_left"]  - sh["ear_left"]).sum(axis=1, skipna=False),
            "sp_rt":  np.square(single_mouse["ear_right"] - sh["ear_right"]).sum(axis=1, skipna=False),
            "sp_lf2": np.square(single_mouse["ear_left"]  - sh["tail_base"]).sum(axis=1, skipna=False),
            "sp_rt2": np.square(single_mouse["ear_right"] - sh["tail_base"]).sum(axis=1, skipna=False),
        })
        X = pd.concat([X, speeds], axis=1)
    # Elongation: nose-tail_base vs ear-ear squared-distance ratio.
    if "nose+tail_base" in X.columns and "ear_left+ear_right" in X.columns:
        X["elong"] = X["nose+tail_base"] / (X["ear_left+ear_right"] + 1e-6)
    # Cosine of the angle between body_center->nose and body_center->tail_base.
    if all(p in avail for p in ["nose","body_center","tail_base"]):
        v1 = single_mouse["nose"] - single_mouse["body_center"]
        v2 = single_mouse["tail_base"] - single_mouse["body_center"]
        X["body_ang"] = (v1["x"]*v2["x"] + v1["y"]*v2["y"]) / (np.sqrt(v1["x"]**2+v1["y"]**2) * np.sqrt(v2["x"]**2+v2["y"]**2) + 1e-6)
    if "body_center" in avail:
        cx = single_mouse["body_center"]["x"]; cy = single_mouse["body_center"]["y"]
        for w in [5,15,30,60]:
            # Centered windows use both past and future frames.
            ws = _scale(w, fps); roll = dict(min_periods=1, center=True)
            X[f"cx_m{w}"] = cx.rolling(ws, **roll).mean()
            X[f"cy_m{w}"] = cy.rolling(ws, **roll).mean()
            X[f"cx_s{w}"] = cx.rolling(ws, **roll).std()
            X[f"cy_s{w}"] = cy.rolling(ws, **roll).std()
            X[f"x_rng{w}"] = cx.rolling(ws, **roll).max() - cx.rolling(ws, **roll).min()
            X[f"y_rng{w}"] = cy.rolling(ws, **roll).max() - cy.rolling(ws, **roll).min()
            # disp: net displacement over a trailing window. act: sqrt of the summed
            # per-axis step variance over a trailing window.
            X[f"disp{w}"] = np.sqrt(cx.diff().rolling(ws, min_periods=1).sum()**2 +
                                    cy.diff().rolling(ws, min_periods=1).sum()**2)
            X[f"act{w}"]  = np.sqrt(cx.diff().rolling(ws, min_periods=1).var() +
                                    cy.diff().rolling(ws, min_periods=1).var())
        X = add_curvature_features(X, cx, cy, fps)
        X = add_multiscale_features(X, cx, cy, fps)
        X = add_state_features(X, cx, cy, fps)
        X = add_longrange_features(X, cx, cy, fps)
    # Nose-tail_base length, lagged values and lagged differences.
    if all(p in avail for p in ["nose","tail_base"]):
        nt = np.sqrt((single_mouse["nose"]["x"] - single_mouse["tail_base"]["x"])**2 +
                     (single_mouse["nose"]["y"] - single_mouse["tail_base"]["y"])**2)
        for lag in [10,20,40]:
            l = _scale(lag, fps); X[f"nt_lg{lag}"] = nt.shift(l); X[f"nt_df{lag}"] = nt - nt.shift(l)
    if all(p in avail for p in ["ear_left","ear_right"]):
        ed = np.sqrt((single_mouse["ear_left"]["x"] - single_mouse["ear_right"]["x"])**2 +
                     (single_mouse["ear_left"]["y"] - single_mouse["ear_right"]["y"])**2)
        # Ear distance shifted by +-off frames; shift(-o) means a positive off reads ahead.
        for off in [-20,-10,10,20]:
            o = _scale_signed(off, fps); X[f"ear_o{off}"] = ed.shift(-o)
        w = _scale(30, fps)
        # Coefficient of variation of the ear distance over a centered window.
        X["ear_con"] = ed.rolling(w, min_periods=1, center=True).std() / (ed.rolling(w, min_periods=1, center=True).mean() + 1e-6)
    return X.astype(np.float32, copy=False)

def transform_pair(pair, body_parts_tracked, fps):
    """Build per-frame features for an (agent "A", target "B") mouse pair.

    Args:
        pair: per-frame coordinates with columns ("A"|"B", bodypart, "x"/"y"), as yielded
            by generate_mouse_data.
        body_parts_tracked: body-part names. They fix the set and order of the
            cross-mouse distance columns.
        fps: frame rate. Windows and lags are specified at 30 fps and rescaled with
            `_scale` / `_scale_signed`.

    Returns:
        float32 DataFrame with one row per frame of *pair*.

    NOTE: feature names and insertion order define the model input, so edit with care.
    """
    avail_A = pair["A"].columns.get_level_values(0)
    avail_B = pair["B"].columns.get_level_values(0)
    # Squared distance from each body part of A to each body part of B.
    X = pd.DataFrame({
        f"12+{p1}+{p2}": np.square(pair["A"][p1] - pair["B"][p2]).sum(axis=1, skipna=False)
        for p1, p2 in itertools.product(body_parts_tracked, repeat=2)
        if p1 in avail_A and p2 in avail_B
    })
    # Reindex to the full product list so missing parts become NaN columns.
    X = X.reindex(columns=[f"12+{p1}+{p2}" for p1, p2 in itertools.product(body_parts_tracked, repeat=2)], copy=False)
    if ("A","ear_left") in pair.columns and ("B","ear_left") in pair.columns:
        # Squared ear_left displacement over `lag` frames. sp_AB compares A now with B
        # `lag` frames earlier.
        lag = _scale(10, fps)
        shA = pair["A"]["ear_left"].shift(lag); shB = pair["B"]["ear_left"].shift(lag)
        speeds = pd.DataFrame({
            "sp_A":  np.square(pair["A"]["ear_left"] - shA).sum(axis=1, skipna=False),
            "sp_AB": np.square(pair["A"]["ear_left"] - shB).sum(axis=1, skipna=False),
            "sp_B":  np.square(pair["B"]["ear_left"] - shB).sum(axis=1, skipna=False),
        })
        X = pd.concat([X, speeds], axis=1)
    # NOTE(review): pair distance columns are prefixed "12+", so this condition looks
    # unreachable and "elong" is never produced here. Confirm before changing: enabling
    # it would alter the trained feature set.
    if "nose+tail_base" in X.columns and "ear_left+ear_right" in X.columns:
        X["elong"] = X["nose+tail_base"] / (X["ear_left+ear_right"] + 1e-6)
    # Cosine between the two mice's tail_base->nose direction vectors.
    if all(p in avail_A for p in ["nose","tail_base"]) and all(p in avail_B for p in ["nose","tail_base"]):
        dir_A = pair["A"]["nose"] - pair["A"]["tail_base"]
        dir_B = pair["B"]["nose"] - pair["B"]["tail_base"]
        X["rel_ori"] = (dir_A["x"]*dir_B["x"] + dir_A["y"]*dir_B["y"]) / (np.sqrt(dir_A["x"]**2+dir_A["y"]**2) * np.sqrt(dir_B["x"]**2+dir_B["y"]**2) + 1e-6)
    # Change in squared nose-nose distance over `lag` frames (negative = getting closer).
    if all(p in avail_A for p in ["nose"]) and all(p in avail_B for p in ["nose"]):
        cur = np.square(pair["A"]["nose"] - pair["B"]["nose"]).sum(axis=1, skipna=False)
        lag = _scale(10, fps)
        shA = pair["A"]["nose"].shift(lag); shB = pair["B"]["nose"].shift(lag)
        past = np.square(shA - shB).sum(axis=1, skipna=False)
        X["appr"] = cur - past
    if "body_center" in avail_A and "body_center" in avail_B:
        cd = np.sqrt((pair["A"]["body_center"]["x"] - pair["B"]["body_center"]["x"])**2 +
                     (pair["A"]["body_center"]["y"] - pair["B"]["body_center"]["y"])**2)
        # One-hot distance bands on the center-center distance (thresholds 5/15/30,
        # presumably cm given the upstream pix_per_cm scaling).
        X["v_cls"] = (cd < 5.0).astype(float)
        X["cls"]   = ((cd >= 5.0) & (cd < 15.0)).astype(float)
        X["med"]   = ((cd >= 15.0) & (cd < 30.0)).astype(float)
        X["far"]   = (cd >= 30.0).astype(float)
        # Squared center-center distance, used for the centered-window statistics below.
        cd_full = np.square(pair["A"]["body_center"] - pair["B"]["body_center"]).sum(axis=1, skipna=False)
        for w in [5,15,30,60]:
            ws = _scale(w, fps); roll = dict(min_periods=1, center=True)
            X[f"d_m{w}"]  = cd_full.rolling(ws, **roll).mean()
            X[f"d_s{w}"]  = cd_full.rolling(ws, **roll).std()
            X[f"d_mn{w}"] = cd_full.rolling(ws, **roll).min()
            X[f"d_mx{w}"] = cd_full.rolling(ws, **roll).max()
            d_var = cd_full.rolling(ws, **roll).var()
            X[f"int{w}"] = 1 / (1 + d_var)
            # Velocity diffs do not depend on `w`; they are recomputed each iteration.
            Axd = pair["A"]["body_center"]["x"].diff()
            Ayd = pair["A"]["body_center"]["y"].diff()
            Bxd = pair["B"]["body_center"]["x"].diff()
            Byd = pair["B"]["body_center"]["y"].diff()
            # Dot product of the two velocity vectors (movement coordination).
            coord = Axd*Bxd + Ayd*Byd
            X[f"co_m{w}"] = coord.rolling(ws, **roll).mean()
            X[f"co_s{w}"] = coord.rolling(ws, **roll).std()
    if "nose" in avail_A and "nose" in avail_B:
        nn = np.sqrt((pair["A"]["nose"]["x"] - pair["B"]["nose"]["x"])**2 +
                     (pair["A"]["nose"]["y"] - pair["B"]["nose"]["y"])**2)
        for lag in [10,20,40]:
            l = _scale(lag, fps)
            X[f"nn_lg{lag}"]  = nn.shift(l)
            X[f"nn_ch{lag}"]  = nn - nn.shift(l)
            # Fraction of the trailing `l` frames with nose-nose distance < 10.
            is_cl = (nn < 10.0).astype(float)
            X[f"cl_ps{lag}"]  = is_cl.rolling(l, min_periods=1).mean()
    if "body_center" in avail_A and "body_center" in avail_B:
        Avx = pair["A"]["body_center"]["x"].diff(); Avy = pair["A"]["body_center"]["y"].diff()
        Bvx = pair["B"]["body_center"]["x"].diff(); Bvy = pair["B"]["body_center"]["y"].diff()
        # Cosine between the two velocity vectors, shifted by +-off frames (shift(-o)
        # means a positive off reads ahead).
        val = (Avx*Bvx + Avy*Bvy) / (np.sqrt(Avx**2+Avy**2) * np.sqrt(Bvx**2+Bvy**2) + 1e-6)
        for off in [-20,-10,0,10,20]:
            o = _scale_signed(off, fps); X[f"va_{off}"] = val.shift(-o)
        w = _scale(30, fps)
        cd_full = np.square(pair["A"]["body_center"] - pair["B"]["body_center"]).sum(axis=1, skipna=False)
        # Coefficient of variation of the squared center distance over a centered window.
        X["int_con"] = cd_full.rolling(w, min_periods=1, center=True).std() / (cd_full.rolling(w, min_periods=1, center=True).mean() + 1e-6)
        X = add_interaction_features(X, pair, avail_A, avail_B, fps)
    return X.astype(np.float32, copy=False)

In [None]:
# ---------------------------
# 5) TRAIN PREP
# ---------------------------
def prepare_train_mats(train_subset: pd.DataFrame, body_parts_tracked, mode: str):
    """Build the training matrices for one body-part section and one mode.

    Args:
        train_subset: train metadata rows of a single body_parts_tracked section.
        body_parts_tracked: body-part names passed to the feature transform.
        mode: "single" (uses transform_single) or anything else (uses transform_pair).

    Returns:
        (X_tr, label, meta, fps_lookup). The first three are row-concatenated with a fresh
        RangeIndex; label columns are the union across videos. When no samples are
        produced, returns (None, None, None, fps_lookup). fps_lookup maps video_id to
        frames_per_second.
    """
    fps_lookup = (
        train_subset[["video_id", "frames_per_second"]]
        .drop_duplicates("video_id")
        .set_index("video_id")["frames_per_second"]
        .to_dict()
    )
    featurize = transform_single if mode == "single" else transform_pair

    feature_frames, meta_frames, label_frames = [], [], []
    samples = generate_mouse_data(
        train_subset, "train",
        generate_single=(mode == "single"),
        generate_pair=(mode == "pair"),
        verbose=False,
    )
    for kind, data_i, meta_i, label_i in samples:
        if kind == mode:
            fps_i = _fps_from_meta(meta_i, fps_lookup, default_fps=30.0)
            feature_frames.append(featurize(data_i, body_parts_tracked, fps_i).astype(np.float32))
            meta_frames.append(meta_i)
            label_frames.append(label_i)

    if not feature_frames:
        return None, None, None, fps_lookup

    X_tr = pd.concat(feature_frames, axis=0, ignore_index=True)
    meta = pd.concat(meta_frames, axis=0, ignore_index=True)
    label = pd.concat(label_frames, axis=0, ignore_index=True)
    return X_tr, label, meta, fps_lookup


# ---------------------------
# 6) TRAINERS
# ---------------------------
def _maybe_augment(X: np.ndarray, y: np.ndarray, noise_std: float, times: int, seed: int):
    if noise_std <= 0 or times <= 0:
        return X, y
    rng = np.random.default_rng(seed)
    Xs = [X]; ys = [y]
    for _ in range(times):
        Xs.append(X + rng.normal(0.0, noise_std, size=X.shape).astype(np.float32))
        ys.append(y.copy())
    return np.concatenate(Xs, 0), np.concatenate(ys, 0)

def _xgb_params():
    """Return a fresh copy of CFG.XGB_BASE with tree_method chosen for GPU or CPU."""
    use_gpu = GPU_AVAILABLE and CFG.USE_GPU
    return {**CFG.XGB_BASE, "tree_method": "gpu_hist" if use_gpu else "hist"}

def _cat_params():
    """Return a fresh copy of CFG.CAT_BASE, adding task_type="GPU" when a GPU is used."""
    params = dict(CFG.CAT_BASE)
    if not (GPU_AVAILABLE and CFG.USE_GPU):
        return params
    params["task_type"] = "GPU"
    return params

def _lgbm_wrap_params(base_params: dict):
    """Return LGBMClassifier kwargs derived from *base_params*; the input is not mutated.

    objective, random_state and n_jobs are always overridden. device_type is set to
    "gpu" (with double-precision GPU splits) or "cpu". A few conservative defaults are
    added without overriding keys already present.
    """
    p = base_params.copy()
    # Fixed parameters.
    p.update(dict(objective="binary", random_state=CFG.SEED, n_jobs=-1))
    # LightGBM selects the device via ``device_type``.
    if GPU_AVAILABLE and CFG.USE_GPU:
        p["device_type"] = "gpu"
        # Double precision on GPU makes split finding a bit more stable.
        p["gpu_use_dp"] = True
    else:
        p["device_type"] = "cpu"

    # LightGBM only performs row subsampling (``subsample`` / bagging_fraction) when
    # ``subsample_freq`` (bagging_freq) > 0, and that defaults to 0. Without this, the
    # configured ``subsample`` values would be silently ignored.
    if p.get("subsample", 1.0) < 1.0 and not any(k in p for k in ("subsample_freq", "bagging_freq")):
        p["subsample_freq"] = 1

    # Conservative stability settings (only applied when not already set).
    p.setdefault("min_data_in_leaf", max(20, p.get("min_child_samples", 20)))
    p.setdefault("feature_pre_filter", False)
    p.setdefault("zero_as_missing", True)
    p.setdefault("max_bin", 255)
    return p


def _train_xgb_models(X: np.ndarray, y: np.ndarray, groups: np.ndarray, save_root: Path, action: str):
    """Train and save XGB fold models (GroupKFold by video) plus one full-data XGB model."""
    params = _xgb_params()

    # Reduce folds so that n_folds <= number of distinct groups; skip CV below 2 folds.
    n_groups = int(pd.Series(groups).nunique())
    n_folds = min(CFG.N_FOLDS_XGB, n_groups)
    if n_folds < 2:
        tqdm.write(f"    [XGB] skip CV for {action}: only {n_groups} group(s)")
    else:
        gkf = GroupKFold(n_splits=n_folds)
        for k, (tr_idx, va_idx) in tqdm(
            enumerate(gkf.split(X, y, groups)),
            total=n_folds,
            desc=f"[{action}] XGB {n_folds}-fold",
            leave=True,
        ):
            X_tr_aug, y_tr_aug = _maybe_augment(
                X[tr_idx], y[tr_idx], CFG.TRAIN_NOISE_STD, CFG.TRAIN_NOISE_TIMES, CFG.SEED + k
            )
            clf = XGBClassifier(**params)
            # eval_set only tracks validation logloss; no early stopping is configured.
            clf.fit(X_tr_aug, y_tr_aug, eval_set=[(X[va_idx], y[va_idx])], verbose=False)
            save_model_any(clf, save_root / f"xgb_fold{k}.json", "xgb")
            del clf, X_tr_aug, y_tr_aug
            gc.collect()

    # The full-data model is trained whether or not CV ran.
    tqdm.write(f"[{action}] XGB full")
    clf_full = XGBClassifier(**params)
    X_aug, y_aug = _maybe_augment(X, y, CFG.TRAIN_NOISE_STD, CFG.TRAIN_NOISE_TIMES, CFG.SEED + 777)
    clf_full.fit(X_aug, y_aug, verbose=False)
    save_model_any(clf_full, save_root / "xgb_full.json", "xgb")
    del clf_full, X_aug, y_aug
    gc.collect()


def _train_catboost_model(X: np.ndarray, y: np.ndarray, save_root: Path):
    """Train and save one full-data CatBoost model."""
    cat = CatBoostClassifier(**_cat_params())
    X_aug, y_aug = _maybe_augment(X, y, CFG.TRAIN_NOISE_STD, CFG.TRAIN_NOISE_TIMES, CFG.SEED + 888)
    cat.fit(X_aug, y_aug, verbose=False)
    save_model_any(cat, save_root / "cat_full.cbm", "cat")
    del cat, X_aug, y_aug
    gc.collect()


def _fit_lgbm_with_fallback(params: dict, X: np.ndarray, y: np.ndarray, tag: str):
    """Fit an LGBMClassifier; if that raises, retry once on CPU with force_row_wise=True."""
    try:
        model = lgb.LGBMClassifier(**params)
        model.fit(X, y)  # do not pass `verbose` to fit()
        return model
    except Exception as e:
        # First line of the message only; fall back to repr() if the message is empty.
        reason = next(iter(str(e).splitlines()), repr(e))
        tqdm.write(
            f"    [{tag}] fit failed (device_type={params.get('device_type')}): {reason}"
            " -> fallback CPU/row-wise"
        )
    cpu_params = {**params, "device_type": "cpu", "force_row_wise": True}
    model = lgb.LGBMClassifier(**cpu_params)
    model.fit(X, y)
    return model


def _train_lgbm_variants(X: np.ndarray, y: np.ndarray, save_root: Path, action: str):
    """Train and save one full-data LightGBM model per entry of CFG.LGBM_PARAMSETS."""
    for i, base in tqdm(
        list(enumerate(CFG.LGBM_PARAMSETS, 1)),
        total=len(CFG.LGBM_PARAMSETS),
        desc=f"[{action}] LGBM variants",
        leave=False,
    ):
        params = _lgbm_wrap_params(base)
        X_aug, y_aug = _maybe_augment(X, y, CFG.TRAIN_NOISE_STD, CFG.TRAIN_NOISE_TIMES, CFG.SEED + 999 + i)
        lgbm = _fit_lgbm_with_fallback(params, X_aug, y_aug, f"LGBM-{i}")
        save_model_any(lgbm, save_root / f"lgbm_full_{i}.txt", "lgbm")
        del lgbm, X_aug, y_aug
        gc.collect()


def train_models_for_action(X_df: pd.DataFrame, y_series: pd.Series, meta_df: pd.DataFrame,
                            save_root: Path, action: str):
    """Train and save every model family for one binary action label.

    Only rows whose label is not NaN are used. NaNs arise when per-video label frames
    with different action columns are concatenated. Features that are constant (or
    all-NaN) over those rows are dropped, and the surviving column names are written to
    ``save_root / "feature_columns.txt"`` for inference-time alignment.

    Models written under *save_root*: xgb_fold{k}.json and xgb_full.json (if XGBoost is
    available), cat_full.cbm (if CatBoost is available), and lgbm_full_{i}.txt.

    Returns early, printing a [SKIP] message, when there are no labels, no usable
    features, fewer than 5 positives, or a single class.
    """
    save_root.mkdir(parents=True, exist_ok=True)

    y_raw = y_series.to_numpy()
    mask = ~pd.isna(y_raw)
    if not mask.any():
        print(f"    [SKIP] {action}: no labels")
        return

    # Subset to labeled rows once; constant-feature detection uses labeled rows only.
    X_lab = X_df.loc[mask]
    const_mask = X_lab.nunique(dropna=False) <= 1
    if const_mask.any():
        tqdm.write(f"    [Info] drop {int(const_mask.sum())} constant features for {action}")
        X_lab = X_lab.loc[:, ~const_mask]

    if X_lab.shape[1] == 0:
        print(f"    [SKIP] {action}: no usable features after constant-drop")
        return

    # Final feature columns, used for alignment at inference.
    save_columns(list(X_lab.columns), save_root / "feature_columns.txt")

    y = y_raw[mask].astype(int)
    X = X_lab.to_numpy(dtype=np.float32)
    groups = meta_df.loc[mask, "video_id"].to_numpy()
    del X_lab

    # Too few positives, or a single class: nothing useful to learn.
    if (y.sum() < 5) or (np.unique(y).size == 1):
        print(f"    [SKIP] {action}: insufficient positives ({y.sum()})")
        return

    if XGBOOST_AVAILABLE:
        _train_xgb_models(X, y, groups, save_root, action)
    else:
        print("    [WARN] XGBoost unavailable, skip XGB models.")

    if CATBOOST_AVAILABLE:
        _train_catboost_model(X, y, save_root)
    else:
        print("    [WARN] CatBoost unavailable, skip CatBoost.")

    _train_lgbm_variants(X, y, save_root, action)

def train_for_section(section_idx: int, body_parts_tracked_str: str):
    """Train and save all models for the train videos sharing one body_parts_tracked string.

    For each mode ("single", "pair") it builds the training matrices and writes
    ``CFG.META_DIR / f"{slug}_{mode}.json"`` (body parts, mode, actions). It then trains
    every action into ``CFG.MODEL_DIR / slug / mode / action``. If the JSON body-part
    list cannot be parsed/filtered, a warning is printed and the section is skipped.
    """
    bslug = bpt_slug(body_parts_tracked_str)
    try:
        parts = json.loads(body_parts_tracked_str)
        if len(parts) > 5:
            parts = [b for b in parts if b not in drop_body_parts]
    except Exception:
        print(f"[WARN] invalid body_parts json at section {section_idx}")
        return

    train_subset = train[train.body_parts_tracked == body_parts_tracked_str]
    n_videos = train_subset.video_id.nunique()
    for mode in ("single", "pair"):
        print(f"[Section {section_idx}] mode={mode} | #videos={n_videos}")
        X_tr, label_df, meta_df, _ = prepare_train_mats(train_subset, parts, mode)
        if X_tr is None:
            print("  -> No data, skip.")
            continue

        actions = list(label_df.columns)
        print(f"  -> X shape={X_tr.shape}, actions={len(actions)}")
        # Meta consumed at inference time.
        meta_payload = {"body_parts_tracked": parts, "mode": mode, "actions": actions}
        with open(CFG.META_DIR / f"{bslug}_{mode}.json", "w") as f:
            json.dump(meta_payload, f)

        for action in tqdm(actions, desc=f"[Section {section_idx}] {mode} actions", leave=True):
            tqdm.write(f"  [Train] action={action}")
            train_models_for_action(X_tr, label_df[action], meta_df,
                                    CFG.MODEL_DIR / bslug / mode / action, action)
            gc.collect()
        del X_tr, label_df, meta_df
        gc.collect()


In [None]:
# ---------------------------
# 7) EDA (VIS)
# ---------------------------
import matplotlib.pyplot as plt
# ===================== EDA: Visualizer + Plots + Animations =====================
def _safe_display_df_head(df: pd.DataFrame, n: int = 5, name: str = ""):
    """Print a ``--- name (head) ---`` banner followed by the first ``n`` rows.

    Uses IPython's rich ``display`` when it can be imported and succeeds;
    otherwise falls back to a plain-text dump.
    """
    print(f"\n--- {name} (head) ---")
    top_rows = df.head(n)
    try:
        from IPython.display import display as render
        render(top_rows)
    except Exception:
        print(top_rows.to_string())

class Visualizer:
    """
    Visualize a single frame of a mouse video (skeleton + keypoints), with optional action title.
    Inspired by: https://www.kaggle.com/code/ambrosm/mabe-eda-which-makes-sense

    Typical use: ``vis = Visualizer(train, CFG.INPUT_DIR); vis.load_video(i); vis.plot_frame(k)``.
    """
    paws = ['forepaw_left', 'forepaw_right', 'hindpaw_left', 'hindpaw_right']
    # Closed outline ear_left -> ear_right -> nose -> ear_left, filled as the head.
    head = ['ear_left', 'ear_right', 'nose', 'ear_left']

    def __init__(self, train_df: pd.DataFrame, input_dir: Path):
        """Store the metadata table (lab_id / video_id columns) and the dataset root."""
        self.train = train_df
        self.input_dir = input_dir

    def load_video(self, train_idx: int):
        """Load tracking (and annotations, if present) for positional row ``train_idx``.

        Sets ``vid`` (long tracking table), ``annot`` (annotation table or None),
        ``pvid`` (frames x (x|y, mouse_id, bodypart)), ``mouse_ids`` and ``n_mouses``.
        """
        self.train_idx = train_idx
        row = self.train.iloc[train_idx]
        lab_id, video_id = row.lab_id, row.video_id
        tpath = self.input_dir / "train_tracking" / lab_id / f"{video_id}.parquet"
        self.video_name = f"{lab_id}/{video_id}"
        self.vid = pd.read_parquet(tpath)

        try:
            apath = self.input_dir / "train_annotation" / lab_id / f"{video_id}.parquet"
            self.annot = pd.read_parquet(apath)
        except FileNotFoundError:
            self.annot = None

        self.pvid = self.vid.pivot(columns=['mouse_id', 'bodypart'],
                                   index='video_frame', values=['x', 'y'])
        # convenience
        self.mouse_ids = self._mouse_ids()
        self.n_mouses = len(self.mouse_ids)

    def _mouse_ids(self):
        """Distinct mouse ids present in ``pvid``'s columns, in column order."""
        return self.pvid.columns.get_level_values('mouse_id').unique().tolist()

    def __len__(self):
        return len(self.pvid)

    def plot_frame(self, frame_idx: int):
        """Plot the frame at position ``frame_idx`` of the loaded video.

        Per mouse: head outline/nose, body+tail polyline, lateral and hip bars and
        paws (only for tracked parts). The title lists annotated actions active at
        that frame. Frames whose coordinates are all 0 or missing are reported as
        empty and skipped.
        """
        import matplotlib.pyplot as plt
        video_frame = self.pvid.index[frame_idx]
        # empty guard: pivot() leaves untracked keypoints as NaN, so NaN counts as empty too
        coords = np.asarray(self.pvid.loc[video_frame], dtype=float)
        if np.all((coords == 0) | np.isnan(coords)):
            print(f"{self.train_idx}.{frame_idx} is empty.")
            return

        colors = ['g', 'b', 'orange', 'brown']
        # iterate over the ids actually present instead of assuming 1..n_mouses
        for k, mouse_id in enumerate(self._mouse_ids()):
            color = colors[k % len(colors)]
            mx = self.pvid.loc[video_frame, ('x', mouse_id)].copy()
            my = self.pvid.loc[video_frame, ('y', mouse_id)].copy()

            # head
            if 'nose' in mx.index and mx['nose'] != 0:
                # the outline needs both ears; indexing a missing part would raise KeyError
                if all(p in mx.index for p in self.head):
                    plt.fill(mx[self.head], my[self.head], color=color, alpha=0.5)
                plt.scatter([mx['nose']], [my['nose']], s=100, color=color)
            else:
                # fall back to ears line
                have_ears = [p for p in ['ear_left', 'ear_right'] if p in mx.index]
                if len(have_ears) == 2:
                    plt.plot(mx[have_ears], my[have_ears], color=color)

            # synthesize head center if missing
            if 'head' not in mx.index and all(p in mx.index for p in ['ear_left', 'ear_right']):
                mx['head'] = mx[['ear_left', 'ear_right']].mean()
                my['head'] = my[['ear_left', 'ear_right']].mean()

            # body + tail polyline through the tracked, non-zero parts
            parts = ['head'] + [p for p in ('neck', 'body_center', 'tail_base', 'tail_tip')
                                if p in mx.index and mx[p] != 0]
            parts = [p for p in parts if p in mx.index]
            if len(parts) >= 2:
                plt.plot(mx[parts], my[parts], color=color)

            # width (laterals)
            if all(p in mx.index for p in ['lateral_right', 'lateral_left']):
                plt.plot(mx[['lateral_right', 'lateral_left']], my[['lateral_right', 'lateral_left']], color=color)

            # hips
            if all(p in mx.index for p in ['hip_right', 'hip_left']):
                plt.plot(mx[['hip_right', 'hip_left']], my[['hip_right', 'hip_left']], color=color)

            # paws
            if 'forepaw_left' in mx.index:
                have_paws = [p for p in self.paws if p in mx.index]
                if have_paws:
                    plt.scatter(mx[have_paws], my[have_paws], color=color)

        # title with active actions (if any)
        actions = ''
        if self.annot is not None:
            cur = set(self.annot.action[(self.annot.start_frame <= video_frame) & (video_frame <= self.annot.stop_frame)])
            actions = ', '.join(sorted(cur))
        plt.title(f'{self.video_name} | frame={video_frame} | actions=[{actions}]')
        plt.gca().set_aspect('equal')
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.show()


def _pivot_wide_xy(df_tracking: pd.DataFrame) -> pd.DataFrame:
    """Reshape long tracking rows into one row per ``video_frame``.

    Each (mouse_id, bodypart, axis) becomes a column named
    ``mouse{m}_{part}_{x|y}``; columns are sorted alphabetically.
    """
    per_axis = []
    for axis_name in ("x", "y"):
        table = df_tracking.pivot(index='video_frame', columns=['mouse_id', 'bodypart'], values=axis_name)
        table.columns = [f"mouse{m}_{bp}_{axis_name}" for m, bp in table.columns]
        per_axis.append(table)
    return pd.concat(per_axis, axis=1).sort_index(axis=1)


def _plot_skeleton_frame(frame_row: pd.Series, frame_idx: int, anatomical_parts=None, connections=None):
    """Scatter keypoints and draw bones for up to four mice in one wide-format frame.

    Args:
        frame_row: one row of ``_pivot_wide_xy`` output (labels ``mouse{m}_{part}_{x|y}``).
        frame_idx: frame label, used only in the title.
        anatomical_parts: parts to scatter (defaults to eight core parts).
        connections: (part_a, part_b) pairs to draw as bones.

    A mouse is skipped when its nose x is absent or NaN. A point or bone is drawn
    only when every x AND y coordinate it needs is present and not NaN.
    """
    import matplotlib.pyplot as plt
    if anatomical_parts is None:
        anatomical_parts = ['nose', 'ear_left', 'ear_right', 'neck', 'body_center', 'lateral_left', 'lateral_right', 'tail_base']
    if connections is None:
        connections = [
            ('nose', 'ear_left'), ('nose', 'ear_right'), ('ear_left', 'ear_right'),
            ('nose', 'neck'), ('neck', 'body_center'),
            ('body_center', 'lateral_left'), ('body_center', 'lateral_right'),
            ('body_center', 'tail_base')
        ]
    mouse_colors = {1: 'blue', 2: 'orange', 3: 'green', 4: 'red'}

    def _point(mouse_id, part):
        """(x, y) for the part if both coordinates exist and are not NaN, else None."""
        cx, cy = f'mouse{mouse_id}_{part}_x', f'mouse{mouse_id}_{part}_y'
        if cx in frame_row and cy in frame_row and pd.notna(frame_row[cx]) and pd.notna(frame_row[cy]):
            return frame_row[cx], frame_row[cy]
        return None

    plt.figure(figsize=(7.5, 7.5))
    for mouse_id in range(1, 5):
        # require a basic key to decide if this mouse exists
        kx = f'mouse{mouse_id}_nose_x'
        if kx not in frame_row or pd.isna(frame_row[kx]):
            continue
        color = mouse_colors[mouse_id]
        for part in anatomical_parts:
            pt = _point(mouse_id, part)
            if pt is not None:
                plt.scatter(pt[0], pt[1], color=color)
        for a, b in connections:
            pa, pb = _point(mouse_id, a), _point(mouse_id, b)
            # checking y as well avoids half-NaN segments that the x-only check let through
            if pa is not None and pb is not None:
                plt.plot([pa[0], pb[0]], [pa[1], pb[1]], color=color, alpha=0.75)
    plt.title(f"Skeleton @ frame {frame_idx}")
    plt.gca().invert_yaxis()
    plt.gca().set_aspect('equal')
    plt.tight_layout()
    plt.show()


def _animate_behavior_segment(wide_df: pd.DataFrame, annot_row: pd.Series, padding: int = 20, save_html: bool = False, out_html_path: Path = None):
    """Create a simple skeleton animation around one annotated behavior.

    Args:
        wide_df: ``_pivot_wide_xy`` output (index video_frame, columns
            ``mouse{m}_{part}_{x|y}``).
        annot_row: annotation row with start_frame, stop_frame, action, agent_id.
        padding: frames shown before start_frame and after stop_frame.
        save_html: also write the animation to a standalone JS/HTML file.
        out_html_path: target file; defaults to ``CFG.WORK_DIR / "behavior_anim.html"``.

    The animation is rendered to HTML at most once, shown inline when IPython is
    importable, and the Matplotlib figure is closed afterwards so it is neither
    rendered again as a static image nor kept alive. Segments without frames or
    without finite coordinates are reported and skipped.
    """
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation

    start, stop = int(annot_row['start_frame']), int(annot_row['stop_frame'])
    act = annot_row['action']; agent = int(annot_row['agent_id'])
    a0 = max(0, start - padding); a1 = stop + padding
    anim_df = wide_df.loc[a0:a1].copy()
    if anim_df.empty:
        print(f"[EDA] no tracking frames in [{a0}, {a1}]; animation skipped.")
        return

    # anchor on the suffix so only coordinate columns are used for the limits
    xs = anim_df.filter(regex=r'_x$'); ys = anim_df.filter(regex=r'_y$')
    x_min, x_max = xs.min().min(), xs.max().max()
    y_min, y_max = ys.min().min(), ys.max().max()
    if not np.all(np.isfinite([x_min, x_max, y_min, y_max])):
        # set_xlim/set_ylim reject NaN limits
        print("[EDA] no valid coordinates in the segment; animation skipped.")
        return
    pad = 50

    fig, ax = plt.subplots(figsize=(7.5, 7.5))
    ax.set_xlim(x_min - pad, x_max + pad)
    ax.set_ylim(y_min - pad, y_max + pad)

    anatomical_parts = ['nose', 'ear_left', 'ear_right', 'neck', 'body_center', 'lateral_left', 'lateral_right', 'tail_base']
    connections = [
        ('nose', 'ear_left'), ('nose', 'ear_right'), ('ear_left', 'ear_right'),
        ('nose', 'neck'), ('neck', 'body_center'),
        ('body_center', 'lateral_left'), ('body_center', 'lateral_right'),
        ('body_center', 'tail_base')
    ]
    mouse_colors = {1: 'blue', 2: 'orange', 3: 'green', 4: 'red'}

    def update(i):
        # redraw frame i from scratch; clear() resets limits, so they are re-applied
        ax.clear()
        fr = anim_df.iloc[i]; real_fr = anim_df.index[i]
        for mouse_id in range(1, 5):
            base = f'mouse{mouse_id}_nose_x'
            if base not in fr or pd.isna(fr[base]):
                continue
            for p in anatomical_parts:
                cx, cy = f'mouse{mouse_id}_{p}_x', f'mouse{mouse_id}_{p}_y'
                if cx in fr and cy in fr and pd.notna(fr[cx]) and pd.notna(fr[cy]):
                    ax.scatter(fr[cx], fr[cy], color=mouse_colors[mouse_id])
            for a, b in connections:
                ax1, ay1 = f'mouse{mouse_id}_{a}_x', f'mouse{mouse_id}_{a}_y'
                bx1, by1 = f'mouse{mouse_id}_{b}_x', f'mouse{mouse_id}_{b}_y'
                if all(k in fr for k in [ax1, ay1, bx1, by1]) and pd.notna(fr[ax1]) and pd.notna(fr[bx1]):
                    ax.plot([fr[ax1], fr[bx1]], [fr[ay1], fr[by1]], color=mouse_colors[mouse_id], alpha=0.7)
        ax.set_title(f"Behavior '{act}' by mouse {agent} | frame={real_fr}")
        ax.set_xlim(x_min - pad, x_max + pad)
        ax.set_ylim(y_min - pad, y_max + pad)
        ax.invert_yaxis()
        return ax,

    ani = FuncAnimation(fig, update, frames=len(anim_df), interval=50, blit=False)
    html = None  # rendered at most once; to_jshtml() replays every frame
    if save_html:
        out_html_path = out_html_path or (CFG.WORK_DIR / "behavior_anim.html")
        try:
            html = ani.to_jshtml()
            with open(out_html_path, "w", encoding="utf-8") as f:
                f.write(html)
            print(f"[EDA] animation saved to: {out_html_path}")
        except Exception as e:
            print(f"[EDA] failed to save animation: {e}")

    try:
        from IPython.display import HTML, display
    except ImportError:
        display = None
    if display is not None:
        try:
            display(HTML(html if html is not None else ani.to_jshtml()))
        except Exception as e:
            print(f"[EDA] failed to display animation: {e}")
    plt.close(fig)


def _build_full_annotations(train_meta: pd.DataFrame, input_dir: Path, max_videos: int = -1) -> pd.DataFrame:
    """Concatenate the train_annotation parquet file of every video in ``train_meta``.

    Args:
        train_meta: table with ``lab_id`` and ``video_id`` columns.
        input_dir: dataset root holding ``train_annotation/<lab_id>/<video_id>.parquet``
            (``str`` or ``Path``).
        max_videos: when > 0, only the first ``max_videos`` rows are scanned (for
            speed); -1 / 0 / None scan every row.

    Returns:
        All annotation rows with an added ``video_id`` column. Videos without an
        annotation file are skipped; if nothing is found, an empty frame with the
        columns start_frame, stop_frame, action, agent_id, target_id, video_id.
    """
    try:
        from tqdm.auto import tqdm
    except ImportError:  # the progress bar is cosmetic; never fail because of it
        def tqdm(iterable, **_kwargs):
            return iterable

    # head() avoids materialising every row just to slice off a prefix
    meta = train_meta.head(max_videos) if (max_videos and max_videos > 0) else train_meta
    root = Path(input_dir) / "train_annotation"
    parts = []
    for lab_id, video_id in tqdm(zip(meta['lab_id'], meta['video_id']), total=len(meta)):
        apath = root / str(lab_id) / f"{video_id}.parquet"
        if apath.exists():
            tmp = pd.read_parquet(apath)
            tmp['video_id'] = video_id
            parts.append(tmp)
    if not parts:
        return pd.DataFrame(columns=['start_frame', 'stop_frame', 'action', 'agent_id', 'target_id', 'video_id'])
    return pd.concat(parts, ignore_index=True)


def run_eda():
    """Print and plot an exploratory overview of the competition data.

    Reads the module-level ``train`` / ``test`` metadata tables and the CFG knobs
    EDA_RANDOM_SEED, EDA_SAMPLE_TRAIN_IDXS, EDA_FRAMES_PER_VIDEO, EDA_DO_ANIM,
    EDA_ANIM_PADDING_FR, EDA_SAVE_ANIM_HTML, EDA_MAX_ANNOT_VIDEOS, INPUT_DIR and
    WORK_DIR. Everything is printed or shown inline; nothing is returned.
    """
    import seaborn as sns
    import matplotlib.pyplot as plt
    np.random.seed(CFG.EDA_RANDOM_SEED)

    print("=== EDA (rich) ===")

    # ---------- Meta quick look ----------
    print("\n[Meta] train.csv & test.csv")
    print(f"train shape: {train.shape} | test shape: {test.shape}")
    _safe_display_df_head(train, 5, "train.csv")
    _safe_display_df_head(test, 5, "test.csv")

    # FPS distribution
    if 'frames_per_second' in train.columns:
        fps_counts = train['frames_per_second'].value_counts().sort_index()
        print("\n[Meta] FPS distribution:")
        print(fps_counts)
        plt.figure(figsize=(6, 3.5)); fps_counts.plot(kind='bar'); plt.title("FPS distribution (train)")
        plt.xlabel("FPS"); plt.ylabel("count"); plt.tight_layout(); plt.show()

    # body_parts_tracked combos
    bp_counts = train['body_parts_tracked'].value_counts()
    print("\n[Meta] top body_parts_tracked combos:")
    print(bp_counts.head(15))
    plt.figure(figsize=(8, 3.5)); bp_counts.head(20).plot(kind='bar')
    plt.title("Top body_parts_tracked (top20)"); plt.xticks(rotation=90); plt.tight_layout(); plt.show()

    # labeled presence
    has_label = train['behaviors_labeled'].notna().astype(int)
    print(f"\n[Meta] videos with labeled behaviors: {has_label.sum()} / {len(train)}")
    # Always plot both categories: the 'No'/'Yes' tick labels are positional, so a
    # missing category would otherwise put 'No' under the 'Yes' bar.
    label_counts = has_label.value_counts().reindex([0, 1], fill_value=0)
    plt.figure(figsize=(4, 3)); label_counts.plot(kind='bar')
    plt.title("Has labeled behaviors? (train)"); plt.xticks([0, 1], ['No', 'Yes']); plt.tight_layout(); plt.show()

    # ---------- Visualize several videos (static frames) ----------
    print("\n[Frames] sampling a few videos and plotting key frames ...")
    vis = Visualizer(train, CFG.INPUT_DIR)
    for idx in CFG.EDA_SAMPLE_TRAIN_IDXS:
        if idx < 0 or idx >= len(train):
            continue
        try:
            vis.load_video(idx)
            L = len(vis)
            steps = max(1, CFG.EDA_FRAMES_PER_VIDEO)
            # positions: evenly spaced (include first frame)
            frame_indices = np.linspace(0, max(0, L - 1), num=steps, dtype=int)
            print(f"  - {vis.video_name}: total_frames={L} | sampled={list(frame_indices)}")
            for fi in frame_indices:
                vis.plot_frame(fi)
        except Exception as e:
            print(f"  [warn] fail {idx}: {e}")

    # ---------- One sample: pivot wide + skeleton frame ----------
    print("\n[Wide view] Pivot one sample video then plot skeleton @ one frame ...")
    # pick first available row that has both a tracking and an annotation file
    sample_row = None
    for row in train.itertuples(index=False):
        ap = CFG.INPUT_DIR / "train_annotation" / row.lab_id / f"{row.video_id}.parquet"
        tp = CFG.INPUT_DIR / "train_tracking" / row.lab_id / f"{row.video_id}.parquet"
        if ap.exists() and tp.exists():
            sample_row = row
            break
    if sample_row is not None:
        lab_id = sample_row.lab_id; video_id = sample_row.video_id
        tpath = CFG.INPUT_DIR / "train_tracking" / lab_id / f"{video_id}.parquet"
        apath = CFG.INPUT_DIR / "train_annotation" / lab_id / f"{video_id}.parquet"
        df_tracking_sample = pd.read_parquet(tpath)
        df_annot_sample = pd.read_parquet(apath)
        print(f"  sample -> {lab_id}/{video_id} | tracking={df_tracking_sample.shape} annot={df_annot_sample.shape}")
        wide = _pivot_wide_xy(df_tracking_sample)
        if len(wide) == 0:
            print("  [info] sample tracking has no frames; skip skeleton plot.")
        else:
            # plot a mid frame
            mid_fr = len(wide) // 2
            _plot_skeleton_frame(wide.iloc[mid_fr], wide.index[mid_fr])

        # optional animation around first annotation
        if CFG.EDA_DO_ANIM and (len(df_annot_sample) > 0):
            print("  creating short animation around first annotated segment ...")
            _animate_behavior_segment(
                wide_df=wide,
                annot_row=df_annot_sample.iloc[0],
                padding=CFG.EDA_ANIM_PADDING_FR,
                save_html=CFG.EDA_SAVE_ANIM_HTML,
                out_html_path=CFG.WORK_DIR / f"anim_{lab_id}_{video_id}.html"
            )
    else:
        print("  [info] no sample with both tracking & annotation found.")

    # ---------- Build full annotations table ----------
    print("\n[Annotations] building full annotations table (this may take a bit) ...")
    df_annotations_full = _build_full_annotations(
        train_meta=train,
        input_dir=CFG.INPUT_DIR,
        max_videos=CFG.EDA_MAX_ANNOT_VIDEOS
    )
    print(f"  annotations shape: {df_annotations_full.shape}")
    _safe_display_df_head(df_annotations_full, 5, "annotations_full")

    if len(df_annotations_full) > 0:
        # Behavior frequency
        print("\n[Behavior] frequency across all annotations")
        beh_counts = df_annotations_full['action'].value_counts()
        plt.figure(figsize=(10, 4)); sns.barplot(x=beh_counts.index, y=beh_counts.values)
        plt.title("Behavior frequency"); plt.xticks(rotation=45, ha='right'); plt.tight_layout(); plt.show()

        # Duration analysis
        print("\n[Behavior] duration statistics (frames)")
        df_annotations_full['duration_frames'] = df_annotations_full['stop_frame'] - df_annotations_full['start_frame']
        zero_duration = int((df_annotations_full['duration_frames'] == 0).sum())
        print(f"  zero-duration events: {zero_duration}")
        print(df_annotations_full['duration_frames'].describe())

        plt.figure(figsize=(10, 4))
        vals = (df_annotations_full['duration_frames'] + 1).clip(lower=1)  # avoid log(0)
        plt.hist(vals, bins=100, log=True)
        plt.title("Distribution of behavior durations (log scale, +1)")
        plt.xlabel("duration+1 (frames)"); plt.ylabel("count (log)"); plt.tight_layout(); plt.show()

        order = df_annotations_full.groupby('action')['duration_frames'].median().sort_values(ascending=False).index
        plt.figure(figsize=(11, 6))
        sns.boxplot(x=(df_annotations_full['duration_frames'] + 1), y='action', data=df_annotations_full, order=order)
        plt.xscale('log'); plt.title("Duration per behavior (log scale, +1)")
        plt.xlabel("duration+1 (frames)"); plt.ylabel("behavior"); plt.tight_layout(); plt.show()

        # Lab variability (proportions per lab)
        print("\n[Lab] behavior composition by lab (proportions)")
        # one lab per video_id, so the merge cannot duplicate annotation rows
        df_lab = train[['video_id', 'lab_id']].drop_duplicates('video_id')
        ann_lab = df_annotations_full.merge(df_lab, on='video_id', how='left')
        crosstab = pd.crosstab(ann_lab['lab_id'], ann_lab['action'], normalize='index')
        plt.figure(figsize=(12, 7))
        sns.heatmap(crosstab, cmap='viridis')
        plt.title("Proportion of behaviors by lab")
        plt.xlabel("behavior"); plt.ylabel("lab_id"); plt.tight_layout(); plt.show()
    else:
        print("  [info] no annotations found; skip behavior stats.")




In [None]:
# ---------------------------
# 8) Thresholds & robustify & inference helpers
# ---------------------------
from collections import defaultdict
# Decision thresholds consumed by _select_threshold_map: "default" is the global
# fallback, "<mode>_default" overrides it per mode, and the nested "<mode>" dict
# overrides individual actions (here: "rear" in single mode).
action_thresholds = dict(
    default=0.27,
    single_default=0.27,
    pair_default=0.27,
    single=dict(rear=0.30),
)

def _select_threshold_map(thresholds, mode: str):
    """Resolve a per-action threshold lookup for ``mode`` ("single" or "pair").

    Accepted ``thresholds`` shapes:
      * mode-aware dict (contains any of "single", "pair", "single_default",
        "pair_default"): fallback is ``thresholds["<mode>_default"]``, else
        ``thresholds["default"]``, else 0.27; ``thresholds[mode]`` maps action
        name -> override;
      * flat dict: ``{"default": d, action: thr, ...}``;
      * anything else: every action uses 0.27.

    Returns a ``defaultdict`` so unknown actions get the fallback.
    """
    if not isinstance(thresholds, dict):
        return defaultdict(lambda: 0.27)
    if any(k in thresholds for k in ("single", "pair", "single_default", "pair_default")):
        base_default = float(thresholds.get("default", 0.27))
        mode_default = float(thresholds.get(f"{mode}_default", base_default))
        mode_overrides = thresholds.get(mode, {}) or {}
        out = defaultdict(lambda: mode_default)
        out.update({str(k): float(v) for k, v in mode_overrides.items()})
        return out
    flat_default = float(thresholds.get("default", 0.27))
    out = defaultdict(lambda: flat_default)
    out.update({str(k): float(v) for k, v in thresholds.items() if k != "default"})
    return out


def predict_multiclass_adaptive(pred_df: pd.DataFrame, meta_df: pd.DataFrame, thresholds,
                                window: int = 5, min_duration: int = 3):
    """Turn per-frame action probabilities into ``[start_frame, stop_frame)`` segments.

    Applied independently to every contiguous run of rows sharing
    (video_id, agent_id, target_id):
      1. centred rolling mean over ``window`` rows (min_periods=1);
      2. per frame keep the arg-max action only if its smoothed probability is
         >= that action's threshold (``_select_threshold_map``), else "no action";
      3. collapse runs of equal labels; a segment stops at the first frame of the
         next run, or one past the group's last frame for the final run (the old
         code dropped that final run);
      4. drop segments shorter than ``min_duration`` frames.

    Args:
        pred_df: probabilities, one column per action, rows aligned positionally
            with ``meta_df``.
        meta_df: must contain video_id, agent_id, target_id, video_frame. Mode is
            "single" when every target_id equals "self", else "pair".
        thresholds: see ``_select_threshold_map``.
        window: smoothing width in rows.
        min_duration: minimum kept segment length in frames.

    Returns:
        DataFrame with columns video_id, agent_id, target_id, action, start_frame,
        stop_frame (empty when nothing passes).
    """
    out_cols = ["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"]
    if len(pred_df) == 0 or pred_df.shape[1] == 0:
        return pd.DataFrame(columns=out_cols)

    mode = "single" if ("target_id" in meta_df.columns and meta_df["target_id"].eq("self").all()) else "pair"
    th_map = _select_threshold_map(thresholds, mode)
    actions = pred_df.columns
    thr = np.array([th_map[a] for a in actions], dtype=float)

    keys = ["video_id", "agent_id", "target_id"]
    meta = meta_df.reset_index(drop=True)
    frames = meta["video_frame"].to_numpy()
    key_vals = meta[keys]
    # a row starts a new group when any key differs from the previous row
    group_start = key_vals.ne(key_vals.shift()).any(axis=1).to_numpy()
    bounds = np.r_[np.flatnonzero(group_start), len(meta)]
    probs = pred_df.to_numpy(dtype=float)

    rows = []
    for g0, g1 in zip(bounds[:-1], bounds[1:]):
        smoothed = (pd.DataFrame(probs[g0:g1])
                    .rolling(window=window, min_periods=1, center=True).mean()
                    .to_numpy())
        # windows that are entirely NaN can never win or pass a threshold
        smoothed = np.where(np.isnan(smoothed), -np.inf, smoothed)
        best = smoothed.argmax(axis=1)
        best_p = smoothed[np.arange(len(best)), best]
        labels = np.where(best_p >= thr[best], best, -1)

        g_frames = frames[g0:g1]
        run_starts = np.flatnonzero(np.r_[True, labels[1:] != labels[:-1]])
        run_ends = np.r_[run_starts[1:], len(labels)]
        vid, agent, target = key_vals.iloc[g0]
        for s, e in zip(run_starts, run_ends):
            a = labels[s]
            if a < 0:
                continue
            start_f = g_frames[s]
            stop_f = g_frames[e] if e < len(labels) else g_frames[-1] + 1
            if stop_f - start_f >= min_duration:
                rows.append((vid, agent, target, actions[a], start_f, stop_f))

    if not rows:
        return pd.DataFrame(columns=out_cols)
    return pd.DataFrame(rows, columns=out_cols)

def robustify(submission: pd.DataFrame, dataset: pd.DataFrame, traintest: str, traintest_directory=None):
    """Make a submission well-formed.

    1. Drop empty/inverted segments (start_frame >= stop_frame).
    2. Within each (video_id, agent_id, target_id), sort by start_frame and drop
       any segment that starts before the previously kept one stops.
    3. For every video in ``dataset`` with no prediction at all, split its frame
       range evenly among the labeled actions of each (agent, target) pair.

    Args:
        submission: columns video_id, agent_id, target_id, action, start_frame, stop_frame.
        dataset: metadata with lab_id, video_id and behaviors_labeled (JSON list of
            "agent,target,action" strings). Rows whose behaviors_labeled is not a
            string (e.g. NaN) are not filled.
        traintest: "train" or "test"; selects ``<CFG.INPUT_DIR>/<traintest>_tracking``
            when ``traintest_directory`` is None.
        traintest_directory: explicit tracking root.

    Returns:
        The cleaned submission with a fresh RangeIndex.
    """
    if traintest_directory is None:
        traintest_directory = f"{CFG.INPUT_DIR}/{traintest}_tracking"
    submission = submission[submission.start_frame < submission.stop_frame]

    kept = []
    for _, g in submission.groupby(["video_id", "agent_id", "target_id"]):
        g = g.sort_values("start_frame")
        keep = np.ones(len(g), dtype=bool)
        last_stop = -1
        for i, (start, stop) in enumerate(zip(g["start_frame"].to_numpy(), g["stop_frame"].to_numpy())):
            if start < last_stop:
                keep[i] = False
            else:
                last_stop = stop
        kept.append(g[keep])
    submission = pd.concat(kept) if kept else submission

    # built once: O(1) membership instead of scanning the submission per video
    covered = set(submission["video_id"].tolist())
    fill_rows = []
    for lab_id, video_id, behaviors in zip(dataset["lab_id"], dataset["video_id"], dataset["behaviors_labeled"]):
        if video_id in covered:
            continue
        if not isinstance(behaviors, str):
            # nothing labeled for this video (json.loads would fail on NaN)
            continue
        path = f"{traintest_directory}/{lab_id}/{video_id}.parquet"
        # only the frame range is needed, so skip loading the coordinates
        frames = pd.read_parquet(path, columns=["video_frame"])["video_frame"]
        vb = json.loads(behaviors)
        vb = sorted({b.replace("'", "") for b in vb})
        vb = pd.DataFrame([b.split(",") for b in vb], columns=["agent", "target", "action"])
        start_frame = frames.min(); stop_frame = frames.max() + 1
        for (agent, target), actions in vb.groupby(["agent", "target"]):
            batch_len = int(np.ceil((stop_frame - start_frame) / len(actions)))
            for i, action in enumerate(actions["action"]):
                b_start = start_frame + i * batch_len
                b_stop = min(b_start + batch_len, stop_frame)
                fill_rows.append((video_id, agent, target, action, b_start, b_stop))
    if fill_rows:
        fill_df = pd.DataFrame(fill_rows, columns=["video_id", "agent_id", "target_id", "action", "start_frame", "stop_frame"])
        submission = pd.concat([submission, fill_df], ignore_index=True)
    return submission.reset_index(drop=True)



In [None]:
# ----- model I/O for inference -----
def _read_feature_columns(cols_path: Path):
    """Read a feature list stored one name per line; names are stripped and blank lines skipped."""
    with open(cols_path, "r") as fh:
        names = [raw.strip() for raw in fh]
    return [name for name in names if name]

def _align_features(X_df: pd.DataFrame, cols_order):
    """Return ``X_df`` restricted and reordered to ``cols_order``; absent columns are 0.0.

    ``reindex`` builds the result in one step, so the caller's frame is not
    modified and adding many missing columns does not fragment the DataFrame.
    Frames with duplicate column labels (which ``reindex`` rejects) fall back to
    per-column filling on a copy.
    """
    cols = list(cols_order)
    if X_df.columns.is_unique:
        return X_df.reindex(columns=cols, fill_value=0.0)
    out = X_df.copy()
    for c in cols:
        if c not in out.columns:
            out[c] = 0.0
    return out[cols]

def _load_xgb(path: Path):
    """Load an XGBoost classifier saved natively; if that fails, try it as a joblib pickle."""
    from xgboost import XGBClassifier
    model = XGBClassifier()
    try:
        model.load_model(str(path))
    except Exception:
        return joblib.load(path)
    return model

def _load_cat(path: Path):
    """Load a CatBoost classifier (native .cbm); on any failure fall back to joblib."""
    try:
        from catboost import CatBoostClassifier
        model = CatBoostClassifier()
        model.load_model(str(path))
    except Exception:
        return joblib.load(path)
    return model

def _load_lgbm(path: Path):
    """Load a LightGBM model and tag how it must be queried.

    Returns ``("booster", lgb.Booster)`` for a native text model, otherwise
    ``("sk", obj)`` with the joblib-unpickled object.
    """
    try:
        return ("booster", lgb.Booster(model_file=str(path)))
    except Exception:
        return ("sk", joblib.load(path))

def _predict_proba_any(model_entry, X_np: np.ndarray):
    """Return positive-class scores from any supported model.

    ``model_entry`` is either a tagged tuple from ``_load_lgbm``
    (``("booster", b)`` / ``("sk", m)``) or a plain estimator. Estimators are
    asked for ``predict_proba(...)[:, 1]`` first, falling back to ``predict``.
    Whenever ``predict`` returns a two-column matrix, column 1 is taken.
    """
    def _positive(p):
        return p[:, 1] if (p.ndim == 2 and p.shape[1] == 2) else p

    def _proba_or_predict(model):
        try:
            return model.predict_proba(X_np)[:, 1]
        except Exception:
            return _positive(model.predict(X_np))

    is_tagged = isinstance(model_entry, tuple) and model_entry[0] in ("booster", "sk")
    if not is_tagged:
        return _proba_or_predict(model_entry)
    kind, model = model_entry
    if kind == "booster":
        return _positive(model.predict(X_np))
    return _proba_or_predict(model)

def _tta_predict_models(models, X_np: np.ndarray, n: int, std: float, seed: int):
    """Average the models' positive-class scores, optionally with Gaussian-noise TTA.

    With ``n <= 0`` or ``std <= 0`` the ensemble mean on clean ``X_np`` is returned.
    Otherwise ``n`` noisy copies (noise ~ N(0, std), cast to float32, from a
    generator seeded with ``seed``) are scored and their ensemble means averaged.
    """
    def _ensemble(X):
        return np.mean(np.stack([_predict_proba_any(m, X) for m in models], 0), 0)

    if n <= 0 or std <= 0:
        return _ensemble(X_np)
    rng = np.random.default_rng(seed)
    rounds = [
        _ensemble(X_np + rng.normal(0.0, std, size=X_np.shape).astype(np.float32))
        for _ in range(n)
    ]
    return np.mean(np.stack(rounds, 0), 0)

def _gather_models(action_dir: Path):
    """Load every saved model for one action directory, in a fixed order.

    Order: XGB folds 0..CFG.N_FOLDS_XGB-1, XGB full, CatBoost full (each only if
    the library is available), then LightGBM full models 1..3. Missing files are
    skipped.
    """
    plan = [(action_dir / f"xgb_fold{k}.json", _load_xgb, XGBOOST_AVAILABLE)
            for k in range(CFG.N_FOLDS_XGB)]
    plan.append((action_dir / "xgb_full.json", _load_xgb, XGBOOST_AVAILABLE))
    plan.append((action_dir / "cat_full.cbm", _load_cat, CATBOOST_AVAILABLE))
    plan.extend((action_dir / f"lgbm_full_{i}.txt", _load_lgbm, True) for i in range(1, 4))
    return [loader(path) for path, loader, enabled in plan if path.exists() and enabled]

def inference_for_section_mode(section_idx: int, bpt_str: str, mode: str, submission_list: list):
    """Predict segments for every test video of one body-part section in one mode.

    Reads ``CFG.META_DIR/<slug>_<mode>.json`` (written at training time), loads
    each action's models and feature-column list, builds features for every
    generated test batch, scores them with the (TTA) ensemble, converts the
    probabilities with ``predict_multiclass_adaptive`` and appends non-empty
    results to ``submission_list`` in place. Missing meta/models/test rows are
    reported and skipped; an error in one batch is logged and the next batch is
    processed (best-effort).
    """
    bslug = bpt_slug(bpt_str)
    meta_path = CFG.META_DIR / f"{bslug}_{mode}.json"
    if not meta_path.exists():
        print(f"  [WARN] meta not found for {bslug} {mode}, skip.")
        return
    with open(meta_path, "r") as f:
        meta_info = json.load(f)
    body_parts = meta_info["body_parts_tracked"]
    actions = meta_info["actions"]

    test_subset = test[test.body_parts_tracked == bpt_str]
    if len(test_subset) == 0:
        print("  -> No test rows."); return
    fps_lookup = (
        test_subset[["video_id", "frames_per_second"]]
        .drop_duplicates("video_id").set_index("video_id")["frames_per_second"].to_dict()
    )

    # load models & feature columns per action
    action_models, action_cols = {}, {}
    for action in actions:
        action_dir = CFG.MODEL_DIR / bslug / mode / action
        cols_file = action_dir / "feature_columns.txt"
        if not (action_dir.exists() and cols_file.exists()): continue
        models = _gather_models(action_dir)
        if len(models) == 0: continue
        action_models[action] = models
        action_cols[action] = _read_feature_columns(cols_file)
    if len(action_models) == 0:
        print("  -> No models found for this section/mode"); return

    gen = generate_mouse_data(test_subset, "test",
                              generate_single=(mode == "single"),
                              generate_pair=(mode == "pair"),
                              verbose=False)
    for switch_te, data_te, meta_te, actions_te in gen:
        if switch_te != mode: continue
        try:
            fps_i = _fps_from_meta(meta_te, fps_lookup, default_fps=30.0)
            if mode == "single":
                X_te = transform_single(data_te, body_parts, fps_i).astype(np.float32)
            else:
                X_te = transform_pair(data_te, body_parts, fps_i).astype(np.float32)

            pred_df = pd.DataFrame(index=meta_te.video_frame)
            # Keyed by the feature-column list, so actions trained on the same
            # columns share one aligned matrix (the old per-action key never hit).
            aligned_by_cols = {}
            for action in actions_te:
                if action not in action_models: continue
                cols_key = tuple(action_cols[action])
                X_np = aligned_by_cols.get(cols_key)
                if X_np is None:
                    X_np = _align_features(X_te.copy(), list(cols_key)).to_numpy(np.float32, copy=False)
                    aligned_by_cols[cols_key] = X_np
                pred_df[action] = _tta_predict_models(
                    action_models[action], X_np, CFG.TTA_N, CFG.TTA_NOISE_STD, CFG.SEED + 123
                )

            if pred_df.shape[1] > 0:
                sub_part = predict_multiclass_adaptive(pred_df, meta_te, action_thresholds)
                if len(sub_part): submission_list.append(sub_part)
        except Exception as e:
            # best-effort: one failing batch must not abort the whole submission
            print(f"  [ERR] inference error (section {section_idx}, mode={mode}): {str(e)[:120]}")
        finally:
            # drop this batch's large arrays before the generator yields the next one
            X_te = data_te = aligned_by_cols = None
            gc.collect()

In [None]:
# ---------------------------
# 9) MAIN (single)
# ---------------------------
if __name__ == "__main__":
    # Stages run in order; each is toggled independently by CFG.VIS / CFG.TRAIN / CFG.SUB.
    print("CFG:", {name: getattr(CFG, name) for name in dir(CFG) if name.isupper()})

    if CFG.VIS:
        run_eda()

    # NOTE(review): both loops start at index 1, so body_parts_tracked_list[0] is
    # never trained or predicted — confirm that skipping it is intentional.
    if CFG.TRAIN:
        print("\n=== TRAIN START ===")
        for section, bpt_str in enumerate(body_parts_tracked_list[1:], start=1):
            try:
                body_parts = json.loads(bpt_str)
                n_parts = len(body_parts)
            except Exception:
                print(f">> Section {section}: [invalid body_parts json] -> skip")
                continue
            print(f">> Section {section}: {n_parts} body parts")
            train_for_section(section, bpt_str)
            gc.collect()
        print("=== TRAIN DONE ===\n")

    if CFG.SUB:
        print("\n=== INFERENCE / SUB START ===")
        submission_list = []
        for section, bpt_str in enumerate(body_parts_tracked_list[1:], start=1):
            try:
                _ = json.loads(bpt_str)
            except Exception:
                print(f">> Section {section}: invalid body_parts json -> skip")
                continue
            print(f">> Section {section}: ensemble inference")
            for mode in ("single", "pair"):
                print(f"   - mode={mode}")
                inference_for_section_mode(section, bpt_str, mode, submission_list)
                gc.collect()

        if submission_list:
            submission = pd.concat(submission_list, ignore_index=True)
        else:
            # Placeholder row used when no section produced predictions;
            # robustify() below still fills every uncovered test video.
            submission = pd.DataFrame([{
                "video_id": 438887472,
                "agent_id": "mouse1",
                "target_id": "self",
                "action": "rear",
                "start_frame": 278,
                "stop_frame": 500,
            }])

        submission_robust = robustify(submission, test, "test")
        submission_robust.index.name = "row_id"
        submission_robust.to_csv(CFG.SUB_PATH)
        print(f"\nSubmission created: {len(submission_robust)} rows -> {CFG.SUB_PATH}")
        print(submission_robust.head())
        print("=== INFERENCE / SUB DONE ===")


