In [None]:
%%writefile baseline.py

from __future__ import annotations

import argparse
import json
import logging
import os
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Tuple, Any
from collections import defaultdict
import joblib
import numpy as np
import polars as pl
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from xgboost import XGBClassifier
from tqdm.auto import tqdm

# ========================
# Config
# ========================
# Identifier sets: agents are "mouse1".."mouse4"; targets are the same ids plus
# the literal "self".
# NOTE(review): neither set is referenced in the visible part of this file —
# presumably used for validation further down; confirm.
ALLOWED_AGENT = {f"mouse{i}" for i in range(1, 5)}
ALLOWED_TARGET = {f"mouse{i}" for i in range(1, 5)} | {"self"}

@dataclass(frozen=True)
class Config:
    """Immutable run configuration: dataset root, output names and table schemas.

    The field defaults are taken from environment variables at the moment the
    class body is executed (import time), falling back to the literals below.
    All paths are derived from ``data_root``.
    """

    data_root: Path = Path(os.getenv("MABE_DATA_ROOT", "/kaggle/input/MABe-mouse-behavior-detection"))
    submission_file: str = os.getenv("MABe_SUBMISSION", "submission.csv")
    row_id_col: str = os.getenv("MABe_ROW_ID_COL", "row_id")

    @property
    def train_csv(self) -> Path:
        """Path of the training metadata CSV."""
        return self.data_root / "train.csv"

    @property
    def test_csv(self) -> Path:
        """Path of the test metadata CSV."""
        return self.data_root / "test.csv"

    @property
    def train_annot_dir(self) -> Path:
        """Directory holding the training annotation files."""
        return self.data_root / "train_annotation"

    @property
    def train_track_dir(self) -> Path:
        """Directory holding the training tracking files."""
        return self.data_root / "train_tracking"

    @property
    def test_track_dir(self) -> Path:
        """Directory holding the test tracking files."""
        return self.data_root / "test_tracking"

    @property
    def submission_schema(self) -> Dict[str, pl.DataType]:
        """Column name -> polars dtype expected of a submission table."""
        return {
            "video_id": pl.Int64,
            "agent_id": pl.Utf8,
            "target_id": pl.Utf8,
            "action": pl.Utf8,
            "start_frame": pl.Int64,
            "stop_frame": pl.Int64,
        }

    @property
    def solution_schema(self) -> Dict[str, pl.DataType]:
        """Submission schema extended with the two solution-only text columns."""
        schema = dict(self.submission_schema)
        schema["lab_id"] = pl.Utf8
        schema["behaviors_labeled"] = pl.Utf8
        return schema

# Module-level logger; handlers and level are configured by setup_logging().
logger = logging.getLogger(__name__)

class HostVisibleError(Exception):
    """Error raised by the scorer for an invalid submission (e.g. overlapping predictions)."""

def setup_logging(verbosity: int = 1) -> None:
    """Configure the root logger.

    ``verbosity <= 0`` selects WARNING, ``verbosity == 1`` selects INFO and any
    other value selects DEBUG. Existing root handlers are replaced
    (``force=True``).
    """
    if verbosity <= 0:
        level = logging.WARNING
    elif verbosity == 1:
        level = logging.INFO
    else:
        level = logging.DEBUG
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=level,
        force=True,
    )

# ========================
# Utils & Validators
# ========================

def safe_json_loads(s: Optional[str]) -> List[str]:
    """Best-effort parse of a JSON-encoded list; never raises on bad input.

    Args:
        s: JSON text, an already-decoded list, or ``None``.

    Returns:
        * ``[]`` for ``None``, non-string input, blank text, or unparsable text.
        * ``[str(x), ...]`` when ``s`` is already a list.
        * The decoded list when ``s`` parses as a JSON array. If the text does
          not parse as-is, a second attempt is made with single quotes
          replaced by double quotes (Python-repr style lists).
        * ``[]`` when the text parses to something that is not a list
          (number, string, object). Previously that value was returned
          unchanged, so callers iterating the result would walk over
          characters or dict keys.

    Elements of a decoded list are returned as parsed (not stringified).
    """
    if s is None:
        return []
    if isinstance(s, list):
        return [str(x) for x in s]
    if not isinstance(s, str):
        return []
    text = s.strip()
    if not text:
        return []
    # First try the text verbatim, then the quote-normalised variant.
    for candidate in (text, text.replace("'", '"')):
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; keep the best-effort contract.
            continue
        return parsed if isinstance(parsed, list) else []
    return []

def validate_schema(df: pl.DataFrame, schema: Dict[str, pl.DataType], name: str) -> pl.DataFrame:
    """Check that ``df`` has every column in ``schema`` and cast mismatched dtypes.

    Raises:
        ValueError: if any schema column is absent from ``df``; ``name`` is
            used as the prefix of the message.

    Returns:
        ``df`` itself when all dtypes already match, otherwise a frame with
        the mismatching columns cast to the schema dtype.
    """
    absent = set(schema.keys()) - set(df.columns)
    if absent:
        raise ValueError(f"{name} is missing columns: {absent}")

    casts = []
    for column, wanted in schema.items():
        if df[column].dtype != wanted:
            casts.append(pl.col(column).cast(wanted))

    if not casts:
        return df
    return df.with_columns(casts)

def validate_frame_ranges(df: pl.DataFrame, name: str) -> None:
    """Raise ``ValueError`` if any row has ``start_frame`` greater than ``stop_frame``."""
    is_ordered = df["start_frame"] <= df["stop_frame"]
    if is_ordered.all():
        return
    raise ValueError(f"{name}: start_frame > stop_frame detected")

def _norm_mouse_id(x: str | int) -> str:
    s = str(x)
    return s if s.startswith("mouse") else f"mouse{s}"

def _norm_triplet(agent: str | int, target: str | int, action: str) -> str:
    """Build the comma-joined ``agent,target,action`` key with both ids normalised."""
    agent_norm = _norm_mouse_id(agent)
    target_norm = _norm_mouse_id(target)
    return f"{agent_norm},{target_norm},{action}"

def _range_frames(start: int, stop: int) -> Iterable[int]:
    """Frames of the half-open interval ``[start, stop)``."""
    half_open = range(start, stop)
    return half_open

def merge_intervals(intervals: List[Tuple[int,int]]) -> List[Tuple[int,int]]:
    """Sort intervals and fuse those that overlap or touch.

    Two intervals are fused when the later one starts at or before the end of
    the earlier one. Returns a new list; an empty input yields ``[]``.
    """
    if not intervals:
        return []

    remaining = iter(sorted(intervals))
    merged = [next(remaining)]
    for start, stop in remaining:
        last_start, last_stop = merged[-1]
        if start <= last_stop:
            # Overlapping or adjacent: extend the current run.
            merged[-1] = (last_start, max(last_stop, stop))
        else:
            merged.append((start, stop))
    return merged

def split_interval(s: int, e: int, parts: int) -> List[Tuple[int,int]]:
    """Cut ``[s, e)`` into ``parts`` consecutive pieces of near-equal length.

    The first ``(e - s) % parts`` pieces are one frame longer than the rest.
    With ``parts <= 1`` the interval is returned unchanged as a single piece.
    Every piece's end is capped at ``e``.
    """
    if parts <= 1:
        return [(s, e)]

    base_len, leftover = divmod(e - s, parts)
    pieces = []
    cursor = s
    for idx in range(parts):
        upper = cursor + base_len + (1 if idx < leftover else 0)
        pieces.append((cursor, min(upper, e)))
        cursor = upper
    return pieces

def largest_remainder_allocation(total: int, weights: List[float]) -> List[int]:
    """Split ``total`` into integer shares proportional to ``weights``.

    Each share starts as the truncated proportional amount; the units still
    missing are handed out one at a time to the entries with the largest
    fractional part (ties keep input order), wrapping around if needed.
    Returns all zeros when ``total <= 0`` or ``weights`` is empty.
    """
    count = len(weights)
    if total <= 0 or not weights:
        return [0] * count

    norm = sum(weights) or 1.0  # avoid dividing by zero when all weights are 0
    exact = [total * (w / norm) for w in weights]
    shares = [int(v) for v in exact]

    missing = total - sum(shares)
    if missing > 0:
        by_fraction = sorted(range(count), key=lambda i: exact[i] - shares[i], reverse=True)
        for k in range(missing):
            shares[by_fraction[k % count]] += 1
    return shares

# ========================
# Metrics (F-beta)
# ========================

def single_lab_f1(lab_solution: pl.DataFrame, lab_submission: pl.DataFrame, beta: float = 1.0) -> float:
    """Frame-level F-beta for one lab, macro-averaged over actions.

    Args:
        lab_solution: ground-truth rows; read columns are ``label_key``,
            ``video_id``, ``behaviors_labeled``, ``start_frame``, ``stop_frame``.
        lab_submission: predicted rows; read columns are ``prediction_key``,
            ``video_id``, ``agent_id``, ``target_id``, ``action``,
            ``start_frame``, ``stop_frame``.
        beta: F-beta weighting factor.

    Returns:
        Mean of the per-action F-beta scores, or ``0.0`` when no action is
        present in either table.

    Raises:
        HostVisibleError: when, within one video, two predictions with
            different keys cover the same frame for the same agent/target pair.
    """
    # Ground truth: key -> set of labelled frames, each row expanded as [start, stop).
    label_frames: Dict[str, Set[int]] = defaultdict(set)
    for row in lab_solution.to_dicts():
        label_frames[row["label_key"]].update(_range_frames(row["start_frame"], row["stop_frame"]))

    # Per video, the normalised "agent,target,action" triples listed in behaviors_labeled.
    active_by_video: Dict[int, Set[str]] = {}
    for row in lab_solution.select(["video_id", "behaviors_labeled"]).unique().to_dicts():
        s: Set[str] = set()
        for item in safe_json_loads(row["behaviors_labeled"]):
            parts = [p.strip() for p in str(item).replace("'", "").split(",")]
            if len(parts) == 3:
                a, t, act = parts
                s.add(_norm_triplet(a, t, act))
        active_by_video[int(row["video_id"])] = s

    # Predictions: key -> set of predicted frames, keeping only triples active for the video.
    prediction_frames: Dict[str, Set[int]] = defaultdict(set)
    for video_id in lab_solution["video_id"].unique():
        active = active_by_video.get(int(video_id), set())
        # Frames already claimed per "agent,target" pair; reset for every video.
        predicted_mouse_pairs: Dict[str, Set[int]] = defaultdict(set)
        for row in lab_submission.filter(pl.col("video_id") == video_id).to_dicts():
            triple_norm = _norm_triplet(row["agent_id"], row["target_id"], row["action"])
            if triple_norm not in active:
                continue
            pred_key = row["prediction_key"]
            agent_target = f"{row['agent_id']},{row['target_id']}"
            new_frames = set(_range_frames(row["start_frame"], row["stop_frame"]))
            # Frames already predicted under the same key are dropped, so a repeated
            # prediction of the same key is tolerated. Indexing the defaultdict here
            # also registers pred_key even if no new frame is added.
            new_frames -= prediction_frames[pred_key]
            # Whatever remains must not collide with frames claimed by another key of this pair.
            if predicted_mouse_pairs[agent_target] & new_frames:
                raise HostVisibleError("Multiple predictions for the same frame from one agent/target pair")
            prediction_frames[pred_key].update(new_frames)
            predicted_mouse_pairs[agent_target].update(new_frames)

    # Per-action confusion counts, accumulated over all keys.
    tps: Dict[str, int] = defaultdict(int)
    fns: Dict[str, int] = defaultdict(int)
    fps: Dict[str, int] = defaultdict(int)
    distinct_actions: Set[str] = set()

    for key, pred_frames in prediction_frames.items():
        # The action is taken as the text after the last "_" of the key.
        # NOTE(review): an action name containing "_" would be truncated here — confirm
        # that action names never contain underscores.
        action = key.split("_")[-1]
        distinct_actions.add(action)
        gt_frames = label_frames.get(key, set())
        tps[action] += len(pred_frames & gt_frames)
        fns[action] += len(gt_frames - pred_frames)
        fps[action] += len(pred_frames - gt_frames)

    # Labelled keys that received no prediction at all count entirely as false negatives.
    for key, gt_frames in label_frames.items():
        action = key.split("_")[-1]
        distinct_actions.add(action)
        if key not in prediction_frames:
            fns[action] += len(gt_frames)

    if not distinct_actions:
        return 0.0

    beta2 = beta * beta
    f_scores: List[float] = []
    for action in distinct_actions:
        tp, fn, fp = tps[action], fns[action], fps[action]
        denom = (1 + beta2) * tp + beta2 * fn + fp
        # An action with no TP/FN/FP at all scores 0 rather than dividing by zero.
        f_scores.append(0.0 if denom == 0 else (1 + beta2) * tp / denom)
    return sum(f_scores) / len(f_scores)

def mouse_fbeta(solution: pl.DataFrame, submission: pl.DataFrame, beta: float = 1.0, cfg: Optional[Config] = None) -> float:
    """Average of the per-lab F-beta scores.

    Both tables are schema-validated and checked for ``start_frame <=
    stop_frame``; submission rows for videos absent from the solution are
    dropped. Each lab is scored with ``single_lab_f1`` on its own videos and
    the lab scores are averaged (``0.0`` when there is no lab).
    """
    cfg = cfg or Config()
    solution = validate_schema(solution, cfg.solution_schema, "Solution")
    submission = validate_schema(submission, cfg.submission_schema, "Submission")
    validate_frame_ranges(solution, "Solution")
    validate_frame_ranges(submission, "Submission")

    known_videos = solution["video_id"].unique()
    submission = submission.filter(pl.col("video_id").is_in(known_videos))

    # "<video>_<agent>_<target>_<action>" identifies one behaviour track.
    key_expr = pl.concat_str(
        [
            pl.col("video_id").cast(pl.Utf8),
            pl.col("agent_id").cast(pl.Utf8),
            pl.col("target_id").cast(pl.Utf8),
            pl.col("action"),
        ],
        separator="_",
    )
    solution = solution.with_columns(key_expr.alias("label_key"))
    submission = submission.with_columns(key_expr.alias("prediction_key"))

    per_lab: List[float] = []
    for lab_id in solution["lab_id"].unique():
        truth = solution.filter(pl.col("lab_id") == lab_id)
        preds = submission.filter(pl.col("video_id").is_in(truth["video_id"].unique()))
        per_lab.append(single_lab_f1(truth, preds, beta=beta))

    if not per_lab:
        return 0.0
    return sum(per_lab) / len(per_lab)

def score(solution: pl.DataFrame, submission: pl.DataFrame, row_id_column_name: str = "", beta: float = 1.0, cfg: Optional[Config] = None) -> float:
    """Entry point of the metric: optionally drop the row-id column, then score.

    When ``row_id_column_name`` is non-empty the column is removed from both
    tables (silently ignored if missing) before calling ``mouse_fbeta``.
    """
    if row_id_column_name:
        solution, submission = (
            frame.drop(row_id_column_name, strict=False)
            for frame in (solution, submission)
        )
    return mouse_fbeta(solution, submission, beta=beta, cfg=cfg)

# ========================
# Build solution + video spans
# ========================

def create_solution_df(dataset: pl.DataFrame, cfg: Optional[Config] = None) -> pl.DataFrame:
    """Assemble the ground-truth table from the per-video annotation parquet files.

    For every row of ``dataset`` (labs whose id starts with ``"MABe22"`` are
    skipped) the file ``<train_annot_dir>/<lab_id>/<video_id>.parquet`` is
    read, tagged with ``lab_id``, ``video_id`` and ``behaviors_labeled``, and
    its ``agent_id`` / ``target_id`` are rewritten as ``"mouse<id>"``.
    Missing or unreadable files are logged and skipped.

    Raises:
        ValueError: if no annotation file could be loaded, or if the combined
            table lacks a solution-schema column.
    """
    cfg = cfg or Config()
    schema = cfg.solution_schema
    frames: List[pl.DataFrame] = []

    for meta in tqdm(dataset.to_dicts(), total=len(dataset), desc="Building solution"):
        lab_id: str = meta["lab_id"]
        if lab_id.startswith("MABe22"):
            continue
        video_id: int = meta["video_id"]
        annot_path = cfg.train_annot_dir / lab_id / f"{video_id}.parquet"
        if not annot_path.exists():
            logger.warning("No annotations for %s", annot_path)
            continue
        try:
            # NOTE(review): target_id is always prefixed with "mouse" here, while
            # extract_features_and_labels maps agent == target rows to "self" —
            # confirm which convention the scorer expects.
            annot = pl.read_parquet(annot_path).with_columns(
                [
                    pl.lit(lab_id).alias("lab_id"),
                    pl.lit(video_id).alias("video_id"),
                    pl.lit(meta["behaviors_labeled"]).alias("behaviors_labeled"),
                    pl.concat_str([pl.lit("mouse"), pl.col("agent_id").cast(pl.Utf8)]).alias("agent_id"),
                    pl.concat_str([pl.lit("mouse"), pl.col("target_id").cast(pl.Utf8)]).alias("target_id"),
                ]
            )
            recasts = [
                pl.col(name).cast(wanted)
                for name, wanted in schema.items()
                if name in annot.columns and annot[name].dtype != wanted
            ]
            if recasts:
                annot = annot.with_columns(recasts)
            # Keep only schema columns, in schema order.
            frames.append(annot.select([name for name in schema if name in annot.columns]))
        except Exception as exc:
            logger.error("Failed to load %s: %s", annot_path, exc)

    if not frames:
        raise ValueError("No annotation files loaded.")
    return validate_schema(pl.concat(frames, how="vertical"), schema, "Solution")

def build_video_spans(dataset: pl.DataFrame, split: str, cfg: Optional[Config] = None) -> Dict[int, Tuple[int,int]]:
    """Map ``video_id -> (min_frame, max_frame + 1)`` from the tracking files.

    Args:
        dataset: metadata table with ``lab_id`` and ``video_id`` columns.
        split: ``"train"`` reads from the training tracking directory, any
            other value from the test tracking directory.
        cfg: configuration; a default ``Config()`` is used when omitted.

    Returns:
        Half-open frame span per video. Labs whose id starts with
        ``"MABe22"``, missing files, unreadable files and files without any
        frame value are left out (the latter two with a warning).
    """
    cfg = cfg or Config()
    track_dir = cfg.train_track_dir if split == "train" else cfg.test_track_dir
    spans: Dict[int, Tuple[int,int]] = {}
    for row in tqdm(dataset.to_dicts(), total=len(dataset), desc="Scanning spans"):
        lab_id = row["lab_id"]
        if lab_id.startswith("MABe22"):
            continue
        vid = row["video_id"]
        path = track_dir / lab_id / f"{vid}.parquet"
        if not path.exists():
            continue
        try:
            # Only the frame column is needed: project it at read time instead
            # of loading the whole tracking table.
            frames = pl.read_parquet(path, columns=["video_frame"])["video_frame"]
            first, last = frames.min(), frames.max()
            if first is None or last is None:
                logger.warning("Span read failed for %s: %s", path, "no video_frame values")
                continue
            spans[int(vid)] = (int(first), int(last) + 1)
        except Exception as exc:  # best effort: one bad file must not abort the scan
            logger.warning("Span read failed for %s: %s", path, exc)
    return spans

# ========================
# Priors (duration & timing)
# ========================

def compute_action_priors(solution: pl.DataFrame, eps: float = 1.0) -> Tuple[Dict[str, Dict[str, float]], Dict[str, float], Dict[str, Dict[str, int]], Dict[str, int]]:
    """Duration-based priors per action, per lab and globally.

    Durations are ``stop_frame - start_frame``. Weight shares are the summed
    duration of an action plus ``eps`` (additive smoothing), normalised over
    all actions seen anywhere in ``solution``.

    Returns:
      per_lab_weight: {lab: {action: weight_share}}
      global_weight: {action: weight_share}
      per_lab_med_dur: {lab: {action: median_duration_frames}}
      global_med_dur: {action: median_duration_frames}

    Median durations are truncated to ``int``; groups whose median is null
    are omitted.
    """
    sol = solution.with_columns((pl.col("stop_frame") - pl.col("start_frame")).alias("dur"))

    # --- weight shares ---
    by_lab = sol.group_by(["lab_id", "action"]).agg(pl.col("dur").sum().alias("dur_sum"))
    global_ = sol.group_by(["action"]).agg(pl.col("dur").sum().alias("dur_sum"))
    actions = set(global_["action"].to_list())

    per_lab_weight: Dict[str, Dict[str, float]] = defaultdict(dict)
    for lab in by_lab["lab_id"].unique():
        sub = by_lab.filter(pl.col("lab_id") == lab)
        dmap = {r["action"]: float(r["dur_sum"]) for r in sub.to_dicts()}
        for a in actions:
            dmap[a] = dmap.get(a, 0.0) + eps
        total = sum(dmap.values()) or 1.0
        per_lab_weight[str(lab)] = {a: dmap[a] / total for a in actions}

    gmap = {r["action"]: float(r["dur_sum"]) for r in global_.to_dicts()}
    for a in actions:
        gmap[a] = gmap.get(a, 0.0) + eps
    gtotal = sum(gmap.values()) or 1.0
    global_weight = {a: gmap[a] / gtotal for a in actions}

    # --- median durations ---
    # Aggregate only "dur": a bare group_by(...).median() would also try to take
    # the median of every other column (ids, text) just to discard the result.
    med_by_lab = sol.group_by(["lab_id", "action"]).agg(pl.col("dur").median())
    per_lab_med_dur: Dict[str, Dict[str, int]] = defaultdict(dict)
    for r in med_by_lab.to_dicts():
        if r["dur"] is not None:
            per_lab_med_dur[str(r["lab_id"])][str(r["action"])] = int(r["dur"])

    med_global = sol.group_by(["action"]).agg(pl.col("dur").median())
    global_med_dur: Dict[str, int] = {
        r["action"]: int(r["dur"]) for r in med_global.to_dicts() if r["dur"] is not None
    }

    return per_lab_weight, global_weight, per_lab_med_dur, global_med_dur

def compute_timing_priors(solution: pl.DataFrame, video_spans: Dict[int, Tuple[int,int]]) -> Tuple[Dict[str, Dict[str, float]], Dict[str, float]]:
    """Median relative start position of each action, per lab and globally.

    For every solution row,
    ``start_pct = (start_frame - video_start) / max(1, video_stop - video_start)``
    clamped to ``[0, 1]``; rows whose video is not in ``video_spans`` get 0.5.

    Returns:
        ``({lab: {action: median_start_pct}}, {action: median_start_pct})``.
    """
    records = []
    for rec in solution.select(["lab_id","action","video_id","start_frame"]).to_dicts():
        vid = int(rec["video_id"])
        if vid in video_spans:
            first, last = video_spans[vid]
            span_len = max(1, last - first)
            pct = float(max(0, min(1, (int(rec["start_frame"]) - first) / span_len)))
        else:
            pct = 0.5  # unknown span: assume the middle of the video
        records.append({"lab_id": rec["lab_id"], "action": rec["action"], "start_pct": pct})

    table = pl.DataFrame(records)

    lab_medians = table.group_by(["lab_id","action"]).median().select(["lab_id","action","start_pct"])
    per_lab: Dict[str, Dict[str, float]] = defaultdict(dict)
    for entry in lab_medians.to_dicts():
        per_lab[str(entry["lab_id"])][str(entry["action"])] = float(entry["start_pct"])

    overall = table.group_by(["action"]).median().select(["action","start_pct"])
    global_: Dict[str, float] = {}
    for entry in overall.to_dicts():
        global_[entry["action"]] = float(entry["start_pct"])

    return per_lab, global_

# ========================
# Tracking features → windows (performance-optimized — no pandas)
# ========================

def _strip_mouse_prefix(s: str | int) -> str:
    s = str(s)
    return s[5:] if s.startswith("mouse") else s

def _pair_features(df: pl.DataFrame, agent_raw: str, target_raw: str, downsample: int = 1) -> Optional[pl.DataFrame]:
    """Per-frame pairwise motion features between an agent and a target.

    The frame / id / x / y columns are detected from a list of candidate
    names. If a ``bodypart`` column exists, positions are averaged per
    (frame, id) over a preferred central body part when one is available,
    otherwise over the first three body parts in file order.

    Args:
        df: tracking table.
        agent_raw: agent id, with or without the ``"mouse"`` prefix.
        target_raw: target id, with or without the ``"mouse"`` prefix.
        downsample: keep every ``downsample``-th common frame (in frame
            order) when greater than 1. Differences (speed, ddist, ...) are
            then taken between consecutive *kept* frames.

    Returns:
        A frame-sorted table with columns ``frame, dist, rel_speed, rel_acc,
        ddist, angle, dangle, radial_vel``, or ``None`` when the schema cannot
        be detected, either animal is missing, no common frame exists, or the
        computation fails.
    """
    # Candidate column names — centroid-style names take priority.
    frame_candidates = ["video_frame", "frame", "frame_idx"]
    id_candidates    = ["mouse_id", "id", "track_id", "agent_id"]
    x_candidates     = ["centroid_x", "cx", "x", "x_pos", "x_position", "x_mm"]
    y_candidates     = ["centroid_y", "cy", "y", "y_pos", "y_position", "y_mm"]

    cols = set(df.columns)
    frame_col = next((c for c in frame_candidates if c in cols), None)
    id_col    = next((c for c in id_candidates    if c in cols), None)
    x_col     = next((c for c in x_candidates     if c in cols), None)
    y_col     = next((c for c in y_candidates     if c in cols), None)

    if not all([frame_col, id_col, x_col, y_col]):
        logger.warning(f"Schema detection failed. Available: {list(cols)} | Required: frame, id, x, y")
        return None

    a_id = _strip_mouse_prefix(agent_raw)
    t_id = _strip_mouse_prefix(target_raw)

    # Collapse body parts to one (x, y) per animal and frame.
    if "bodypart" in cols:
        bodypart_col = "bodypart"
        preferred_parts = ["body_center", "centroid", "center", "torso", "hip_center", "mid"]

        # maintain_order keeps the fallback choice below deterministic.
        available_parts = df.select(bodypart_col).unique(maintain_order=True).to_series().to_list()
        chosen_parts = [bp for bp in preferred_parts if bp in available_parts]
        if not chosen_parts and available_parts:
            # No preferred part present: fall back to a few parts to limit noise.
            chosen_parts = available_parts[:3]

        if chosen_parts:
            df = df.filter(pl.col(bodypart_col).is_in(chosen_parts))
        else:
            logger.warning("bodypart exists but no valid parts found. Using all.")
        df = df.group_by([frame_col, id_col]).agg([
            pl.col(x_col).mean().alias(x_col),
            pl.col(y_col).mean().alias(y_col),
        ])

    # Normalise ids to text without the "mouse" prefix so they compare with a_id / t_id.
    try:
        df = df.with_columns([
            pl.col(id_col).cast(pl.Utf8).str.replace("mouse", "").str.strip_chars().alias(id_col)
        ])
    except Exception as e:
        logger.warning(f"ID normalization failed: {e}. Proceeding with raw IDs.")

    # One row per frame for each animal.
    a_df = df.filter(pl.col(id_col) == a_id).unique(subset=[frame_col], keep="first", maintain_order=True)
    b_df = df.filter(pl.col(id_col) == t_id).unique(subset=[frame_col], keep="first", maintain_order=True)

    if a_df.is_empty() or b_df.is_empty():
        logger.warning(f"Agent {a_id} or target {t_id} not found in tracking data.")
        return None

    # Frames where both animals are tracked. Sort explicitly: the diff()-based
    # features below are only meaningful on frame-ordered rows, and neither
    # unique() nor join() is relied upon to deliver that order.
    merged = a_df.join(b_df, on=frame_col, how="inner", suffix="_b").sort(frame_col)

    if downsample > 1:
        # Row index over the whole (sorted) table. Partitioning the index by
        # frame, as was done before, gives every row index 0 because frames are
        # unique here, which made the filter keep everything.
        merged = merged.filter(pl.int_range(0, pl.count()) % downsample == 0)

    if merged.is_empty():
        return None

    xb, yb = f"{x_col}_b", f"{y_col}_b"
    dx = pl.col(x_col) - pl.col(xb)           # agent minus target
    dy = pl.col(y_col) - pl.col(yb)
    vax = pl.col(x_col).diff().fill_null(0)   # agent displacement per step
    vay = pl.col(y_col).diff().fill_null(0)
    vbx = pl.col(xb).diff().fill_null(0)      # target displacement per step
    vby = pl.col(yb).diff().fill_null(0)

    try:
        feat = (
            merged.with_columns([
                (dx**2 + dy**2).sqrt().alias("dist"),
                (vax**2 + vay**2).sqrt().alias("speed_a"),
                (vbx**2 + vby**2).sqrt().alias("speed_b"),
                pl.arctan2(dy, dx).alias("angle"),
                (dx * vax + dy * vay).alias("radial_vel_numerator"),
            ])
            .with_columns([
                (pl.col("speed_a") - pl.col("speed_b")).alias("rel_speed"),
                (pl.col("speed_a").diff().fill_null(0) - pl.col("speed_b").diff().fill_null(0)).alias("rel_acc"),
                pl.col("dist").diff().fill_null(0).alias("ddist"),
                # NOTE(review): plain difference of angles; not unwrapped at ±pi.
                pl.col("angle").diff().fill_null(0).alias("dangle"),
                # Small epsilon avoids division by zero when the animals coincide.
                (pl.col("radial_vel_numerator") / (pl.col("dist") + 1e-8)).alias("radial_vel"),
            ])
            .select([
                frame_col,
                "dist", "rel_speed", "rel_acc", "ddist", "angle", "dangle", "radial_vel",
            ])
            .rename({frame_col: "frame"})
            .sort("frame")
        )
        return feat

    except Exception as e:
        logger.error(f"Feature computation failed: {type(e).__name__}: {e}")
        return None

def _make_windows(
    feat: pl.DataFrame,
    min_len: int,
    q_dist: float = 0.40,
    q_rel: float = 0.60,
    q_ddist: float = 0.40
) -> List[Tuple[int, int]]:
    """Turn a per-frame feature table into candidate ``[start, stop)`` windows.

    A frame is "on" when the pair is close (``dist`` at or below its
    ``q_dist`` quantile) or when relative speed is high while the distance is
    shrinking (``rel_speed`` at or above its ``q_rel`` quantile and ``ddist``
    at or below its ``q_ddist`` quantile). If a ``speed_a`` column exists and
    ``dist`` is zero everywhere (self pair), the gate is instead ``speed_a`` at
    or above its 0.70 quantile. Consecutive "on" rows form a run; runs shorter
    than ``min_len`` frames are dropped and the rest are merged.
    """
    if len(feat) == 0:
        return []

    self_like = ("speed_a" in feat.columns) and (float(feat["dist"].max()) == 0.0)
    if self_like:
        # Self pair: gate on the agent's own high-activity frames.
        speed_gate = float(feat["speed_a"].quantile(0.70))
        cond = (pl.col("speed_a") >= speed_gate)
    else:
        dist_gate = float(feat["dist"].quantile(q_dist))
        rel_gate = float(feat["rel_speed"].quantile(q_rel))
        ddist_gate = float(feat["ddist"].quantile(q_ddist))
        cond = (pl.col("dist") <= dist_gate) | ((pl.col("rel_speed") >= rel_gate) & (pl.col("ddist") <= ddist_gate))

    flags = feat.select(cond.alias("m")).to_series().to_list()
    frames = feat["frame"].to_list()

    windows: List[Tuple[int,int]] = []
    in_run = False
    run_first = run_last = None
    for pos, on in enumerate(flags):
        if on:
            if not in_run:
                in_run = True
                run_first = frames[pos]
            run_last = frames[pos]
        elif in_run:
            # Run ended on the previous row: close it as a half-open window.
            if (run_last + 1) - run_first >= min_len:
                windows.append((run_first, run_last + 1))
            in_run = False
    if in_run and (run_last + 1) - run_first >= min_len:
        windows.append((run_first, run_last + 1))

    return merge_intervals(windows)

def _intervals_by_action_for_pair(annot_pair_df: pl.DataFrame) -> Dict[str, List[Tuple[int,int]]]:
    """Group the annotations of ONE (agent, target) pair by action.

    Returns ``action -> [(start, stop), ...]`` sorted by start; rows with
    ``start_frame >= stop_frame`` are ignored.
    """
    grouped = defaultdict(list)
    rows = annot_pair_df.select(["action","start_frame","stop_frame"]).to_dicts()
    for rec in rows:
        start = int(rec["start_frame"])
        stop = int(rec["stop_frame"])
        if start < stop:
            grouped[str(rec["action"])].append((start, stop))
    # Sorted by start so the overlap lookup can stop early.
    for spans in grouped.values():
        spans.sort(key=lambda span: span[0])
    return grouped

def _actions_in_window(intervals_by_action: Dict[str, List[Tuple[int,int]]], ws: int, we: int) -> Set[str]:
    """Return the actions having at least one interval overlapping ``[ws, we)``.

    Each interval list is expected to be sorted by start, which allows the
    scan to stop at the first interval starting at or after ``we``.
    """
    def _touches_window(spans: List[Tuple[int,int]]) -> bool:
        for start, stop in spans:
            if start >= we:
                # Later intervals start even later: no overlap possible.
                return False
            if stop > ws and start < we:
                return True
        return False

    return {action for action, spans in intervals_by_action.items() if _touches_window(spans)}

def extract_features_and_labels(
    dataset: pl.DataFrame,
    cfg: Config,
    window_size: int = 15,
    step_size: int = 15,
    downsample: int = 10,
) -> Tuple[Dict[str, List[np.ndarray]], List[Tuple[int, str, str, int, int]]]:
    """
    Build windowed feature vectors labelled per (agent, target) pair.

    PAIR-AWARE + SELF-AWARE:
      - Annotation rows with agent_id == target_id are treated as the "self" pair.
      - Features for a self pair are computed as agent vs. agent.
      - A window is labelled by the overlap between [ws, we) and the annotated
        intervals of EXACTLY that pair.

    Returns:
      X_action: {action: [float32 vector of 7 statistics x 7 feature columns]},
                one vector per window overlapping that action.
      metadata: [(video_id, agent, target, ws, we)], one entry per labelled window.

    NOTE(review): a window overlapping several actions adds one vector to each
    action's list but only one metadata entry, so metadata is not index-aligned
    with any X_action list — confirm how callers use it.
    NOTE(review): _prepare_tracking_minimal and _pair_features_from_agg are not
    defined in the visible part of this file.
    """
    X_action = defaultdict(list)
    metadata = []

    for row in tqdm(dataset.to_dicts(), total=len(dataset), desc="Extracting features & labels (pair-aware)"):
        lab_id: str = row["lab_id"]
        if str(lab_id).startswith("MABe22"):
            continue
        video_id: int = row["video_id"]
        annot_path = cfg.train_annot_dir / lab_id / f"{video_id}.parquet"
        track_path = cfg.train_track_dir / lab_id / f"{video_id}.parquet"
        # Both the annotation and the tracking file are required.
        if not annot_path.exists() or not track_path.exists():
            continue

        try:
            trk_agg = _prepare_tracking_minimal(track_path)
            if trk_agg is None or trk_agg.is_empty():
                continue

            # Read the annotation and normalise ids; agent == target becomes "self".
            # Expressions inside one with_columns call see the ORIGINAL columns, so
            # the equality test compares the raw ids, not the prefixed ones.
            annot_raw = pl.read_parquet(annot_path)
            annot = annot_raw.with_columns(
                pl.concat_str([pl.lit("mouse"), pl.col("agent_id").cast(pl.Utf8)]).alias("agent_id"),
                pl.when(pl.col("agent_id") == pl.col("target_id"))
                  .then(pl.lit("self"))
                  .otherwise(pl.concat_str([pl.lit("mouse"), pl.col("target_id").cast(pl.Utf8)]))
                  .alias("target_id"),
            ).select(["agent_id","target_id","action","start_frame","stop_frame"])

            if annot.is_empty():
                continue

            # Iterate over every PAIR present in the annotation (self pairs included).
            for (agent, target), ann_pair in annot.group_by(["agent_id","target_id"]):
                agent = str(agent); target = str(target)
                # Target used for features: for "self", use the agent itself.
                target_for_feat = agent if target.lower() == "self" else target

                feat_df = _pair_features_from_agg(trk_agg, agent, target_for_feat, downsample=downsample)
                if feat_df is None or feat_df.height < window_size:
                    continue

                # Intervals belonging to exactly this pair.
                intervals_by_action = _intervals_by_action_for_pair(ann_pair)

                frames = feat_df["frame"].to_numpy()
                if len(frames) < window_size:
                    continue

                cols = ["dist","rel_speed","rel_acc","ddist","angle","dangle","radial_vel"]
                arrs = {c: feat_df[c].to_numpy() for c in cols}

                # Slide a window of window_size feature rows, advancing step_size rows.
                for start_idx in range(0, len(frames) - window_size + 1, step_size):
                    # Window bounds in frame numbers, half-open [ws, we).
                    ws = int(frames[start_idx])
                    we = int(frames[start_idx + window_size - 1]) + 1
                    feats = []
                    for c in cols:
                        arr = arrs[c][start_idx:start_idx+window_size]
                        # Seven summary statistics per feature column:
                        # mean, std, min, max, median, range, last-minus-first.
                        feats.extend([
                            float(arr.mean()), float(arr.std()),
                            float(arr.min()), float(arr.max()),
                            float(np.median(arr)), float(arr.max() - arr.min()),
                            float(arr[-1] - arr[0]),
                        ])

                    # LABEL: actions occurring in this window for THIS pair.
                    acts = _actions_in_window(intervals_by_action, ws, we)
                    # Windows without any overlapping action are dropped.
                    if not acts:
                        continue
                    for a in acts:
                        X_action[a].append(np.asarray(feats, dtype=np.float32))
                    metadata.append((video_id, agent, target, ws, we))

        except Exception as e:
            # Best effort: log the failing video and move on to the next one.
            logger.error(f"Error processing video {video_id} (lab {lab_id}): {type(e).__name__}: {e}")
            continue

    logger.info(f"Extracted pair-aware windows for {len(X_action)} actions.")
    return X_action, metadata

# ========================
# Advanced baseline prediction
# ========================

def _order_actions_by_timing(actions: List[str], lab_id: str,
                             timing_lab: Dict[str, Dict[str, float]],
                             timing_global: Dict[str, float],
                             canonical: Dict[str,int]) -> List[str]:
    """Sort actions by their typical start position, earliest first.

    The lab-specific timing is used when available, else the global one, else
    0.5. Ties are broken by the ``canonical`` rank (99 for unknown actions).
    """
    lab_timing = timing_lab[lab_id] if lab_id in timing_lab else {}

    def sort_key(action: str) -> Tuple[float, int]:
        if action in lab_timing:
            when = lab_timing[action]
        else:
            when = timing_global.get(action, 0.5)
        return (when, canonical.get(action, 99))

    return sorted(actions, key=sort_key)

def _clip_rare_actions(weights_map: Dict[str,float], actions: List[str], p_min: float, cap: float) -> Dict[str,float]:
    """Cap the weight of rare actions, then renormalise.

    Weights are looked up in ``weights_map`` (0 when missing, negatives
    clamped to 0). Any weight below ``p_min`` is limited to at most ``cap``.
    The result is divided by the total so it sums to 1, unless the total is 0.
    """
    limited: Dict[str, float] = {}
    for action in actions:
        weight = max(0.0, float(weights_map.get(action, 0.0)))
        if weight < p_min:
            weight = min(weight, cap)
        limited[action] = weight

    norm = sum(limited.values()) or 1.0
    return {action: limited[action] / norm for action in actions}

def _allocate_segments_in_windows(windows: List[Tuple[int,int]],
                                  ordered_actions: List[str],
                                  weights: Dict[str,float],
                                  med_dur: Dict[str,int],
                                  total_frames: int) -> List[Tuple[str,int,int]]:
    """Hand out window frames to actions, in order, as ``(action, start, stop)`` segments.

    Each action asks for ``int(weight * total_frames)`` frames, raised to its
    median duration if that is larger, and limited to the frames still free.
    Frames are consumed from the windows front to back; an action's request
    may span several windows and then yields one segment per window.
    """
    n_windows = len(windows)
    idx = 0
    lo, hi = windows[0] if windows else (0, 0)
    budget = sum(stop - start for start, stop in windows)  # frames not yet assigned
    allocated: List[Tuple[str,int,int]] = []

    for action in ordered_actions:
        if budget <= 0:
            break
        by_weight = int(weights.get(action, 0.0) * total_frames)
        by_median = int(med_dur.get(action, 0) or 0)
        target = min(max(by_weight, by_median), budget)

        filled = 0
        while filled < target and idx < n_windows:
            if lo >= hi:
                # Current window is exhausted (or empty): move to the next one.
                idx += 1
                if idx >= n_windows:
                    break
                lo, hi = windows[idx]
                continue
            chunk = min(target - filled, hi - lo)
            allocated.append((action, lo, lo + chunk))
            filled += chunk
            budget -= chunk
            lo += chunk
            if lo >= hi:
                idx += 1
                if idx < n_windows:
                    lo, hi = windows[idx]
    return allocated

def _smooth_segments(segments: List[Tuple[str,int,int]], min_len: int, gap_close: int) -> List[Tuple[str,int,int]]:
    """Drop short segments and fuse nearby segments of the same action.

    Args:
        segments: ``(action, start, stop)`` tuples in any order.
        min_len: segments shorter than this many frames are discarded.
        gap_close: a segment is fused into the previous kept one when both
            carry the same action and the gap between them is at most this
            many frames (overlaps count as a non-positive gap).

    Returns:
        Segments sorted by ``(start, stop, action)`` with fusions applied.
        A fused segment ends at the larger of the two ends, so a segment
        contained in the previous one can no longer shrink it.
    """
    if not segments:
        return []
    kept = sorted(
        (seg for seg in segments if seg[2] - seg[1] >= min_len),
        key=lambda seg: (seg[1], seg[2], seg[0]),
    )
    if not kept:
        return []

    out = [kept[0]]
    for action, start, stop in kept[1:]:
        prev_action, prev_start, prev_stop = out[-1]
        if action == prev_action and start - prev_stop <= gap_close:
            # max(): the later-starting segment may end before the previous one.
            out[-1] = (action, prev_start, max(prev_stop, stop))
        else:
            out.append((action, start, stop))
    return out

def predict_without_ml(dataset: pl.DataFrame, data_split: str, cfg: Optional[Config] = None,
                       priors_per_lab: Optional[Dict[str, Dict[str, float]]] = None,
                       priors_global: Optional[Dict[str, float]] = None,
                       meddur_per_lab: Optional[Dict[str, Dict[str, int]]] = None,
                       meddur_global: Optional[Dict[str, int]] = None,
                       timing_lab: Optional[Dict[str, Dict[str, float]]] = None,
                       timing_global: Optional[Dict[str, float]] = None,
                       prior_scope: str = "mixed",
                       use_windows: bool = True,
                       min_len: int = 10,
                       gap_close: int = 5,
                       p_min: float = 0.03,
                       cap: float = 0.02) -> pl.DataFrame:
    """Build a model-free submission by splitting frames among labelled actions.

    For every video row (labs whose id starts with ``MABe22`` are skipped) the
    tracking parquet is read, ``behaviors_labeled`` is parsed into
    ``(agent, target, action)`` triples, and for each (agent, target) pair the
    allowed frame windows are divided among that pair's actions according to
    the action priors and median durations.

    Args:
        dataset: metadata rows with ``lab_id``, ``video_id`` and
            ``behaviors_labeled``.
        data_split: ``"test"`` reads from ``cfg.test_track_dir``; anything
            else reads from ``cfg.train_track_dir``.
        cfg: configuration; a default ``Config()`` is created when omitted.
        priors_per_lab / priors_global: action weight tables.
        meddur_per_lab / meddur_global: median duration tables (frames).
        timing_lab / timing_global: timing tables forwarded to
            ``_order_actions_by_timing``.
        prior_scope: ``"lab"`` uses only the per-lab tables, ``"global"``
            only the global ones (each only when its prior table is given);
            otherwise the per-lab entry is used when non-empty, else global.
        use_windows: restrict allocation to windows from ``_make_windows``;
            the full video span is used when disabled or when no windows or
            pair features are available.
        min_len, gap_close: forwarded to ``_make_windows`` / ``_smooth_segments``.
        p_min, cap: forwarded to ``_clip_rare_actions``.

    Returns:
        Submission-schema frame, one row per predicted segment.

    Raises:
        ValueError: if no segment was produced for any video.
    """
    cfg = cfg or Config()
    track_dir = cfg.test_track_dir if data_split == "test" else cfg.train_track_dir
    records: List[Tuple[int, str, str, str, int, int]] = []
    canonical = {"approach": 0, "avoid": 1, "chase": 2, "chaseattack": 3, "attack": 4, "mount": 5, "submit": 6}

    for row in tqdm(dataset.to_dicts(), total=len(dataset), desc=f"Predicting ({data_split})"):
        lab_id: str = row["lab_id"]
        if lab_id.startswith("MABe22"):
            continue
        video_id: int = row["video_id"]
        path = track_dir / lab_id / f"{video_id}.parquet"
        if not path.exists():
            logger.warning("Tracking file not found: %s", path)
            continue

        # Any failure inside one video is logged and the video is skipped.
        try:
            trk = pl.read_parquet(path)
            start_frame = int(trk["video_frame"].min())
            stop_frame = int(trk["video_frame"].max()) + 1  # half-open end
            video_frames = stop_frame - start_frame
            if video_frames <= 0:
                continue

            raw_list = safe_json_loads(row["behaviors_labeled"])
            triples: List[List[str]] = []
            for b in raw_list:
                parts = [p.strip() for p in str(b).replace("'", "").split(",")]
                if len(parts) == 3:
                    triples.append(parts)
            if not triples:
                continue

            beh_df = pl.DataFrame(triples, schema=["agent","target","action"], orient="row").with_columns(
                [pl.col("agent").cast(pl.Utf8), pl.col("target").cast(pl.Utf8), pl.col("action").cast(pl.Utf8)]
            )

            # maintain_order keeps group iteration (and so record order) reproducible.
            for (agent, target), group in beh_df.group_by(["agent","target"], maintain_order=True):
                # Actions missing from `canonical` all rank 99; the name is the
                # tie-breaker so the order does not depend on set iteration
                # order (which varies with string hash randomisation).
                actions = sorted(set(group["action"].to_list()), key=lambda a: (canonical.get(a, 99), a))
                if not actions:
                    continue

                if prior_scope == "lab" and priors_per_lab is not None:
                    w_map = priors_per_lab.get(str(lab_id), {})
                    md_map = meddur_per_lab.get(str(lab_id), {}) if meddur_per_lab else {}
                elif prior_scope == "global" and priors_global is not None:
                    w_map = priors_global
                    md_map = meddur_global or {}
                else:
                    # Mixed: per-lab table when it has an entry, else global.
                    w_map = (priors_per_lab or {}).get(str(lab_id), {}) or (priors_global or {})
                    md_map = (meddur_per_lab or {}).get(str(lab_id), {}) or (meddur_global or {})

                weights = _clip_rare_actions(w_map, actions, p_min=p_min, cap=cap)
                ordered_actions = _order_actions_by_timing(
                    actions, str(lab_id), timing_lab or {}, timing_global or {}, canonical
                )

                windows: List[Tuple[int,int]]
                if use_windows:
                    feat = _pair_features(trk, _norm_mouse_id(agent), _norm_mouse_id(target))
                    if feat is None:
                        windows = [(start_frame, stop_frame)]
                    else:
                        windows = _make_windows(feat, min_len=min_len)
                        if not windows:
                            windows = [(start_frame, stop_frame)]
                else:
                    windows = [(start_frame, stop_frame)]

                windows = merge_intervals(windows)
                allowed_total = sum(e - s for s,e in windows)
                if allowed_total <= 0:
                    continue

                segs = _allocate_segments_in_windows(
                    windows=windows,
                    ordered_actions=ordered_actions,
                    weights=weights,
                    med_dur=md_map,
                    total_frames=allowed_total
                )

                segs = _smooth_segments(segs, min_len=min_len, gap_close=gap_close)

                for a, s, e in segs:
                    if e > s:
                        records.append((
                            video_id,
                            _norm_mouse_id(agent), _norm_mouse_id(target),
                            a, int(s), int(e)
                        ))

        except Exception as e:
            logger.error("Error processing %s: %s", path, e)
            continue

    if not records:
        raise ValueError("No predictions generated.")

    df = pl.DataFrame(
        records,
        schema={
            "video_id": pl.Int64, "agent_id": pl.Utf8, "target_id": pl.Utf8,
            "action": pl.Utf8, "start_frame": pl.Int64, "stop_frame": pl.Int64,
        },
        orient="row",
    )
    df = validate_schema(df, cfg.submission_schema, "Submission")
    validate_frame_ranges(df, "Submission")
    return df
# ========================
# Helpers (NEW)
# ========================

def parse_behaviors_labeled(raw: Any) -> pl.DataFrame:
    """Turn a raw ``behaviors_labeled`` value into an (agent, target, action) frame.

    Each entry is split on commas after removing single quotes; entries that
    do not yield exactly three fields are ignored. Agent ids get a ``mouse``
    prefix when they lack one. Target ids are prefixed the same way, except
    that a target equal to ``self`` (case-insensitive) is kept untouched so
    it can be emitted as-is in the submission.

    Returns an empty frame with string columns when nothing parses.
    """
    def _prefixed(mouse: str) -> str:
        return mouse if str(mouse).startswith("mouse") else f"mouse{mouse}"

    rows = []
    for entry in safe_json_loads(raw):
        fields = [tok.strip() for tok in str(entry).replace("'", "").split(",")]
        if len(fields) != 3:
            continue
        agent, target, action = fields
        if target.lower() != "self":
            target = _prefixed(target)
        rows.append([_prefixed(agent), target, action])

    if rows:
        return pl.DataFrame(rows, schema=["agent", "target", "action"], orient="row")
    return pl.DataFrame(schema={"agent": pl.Utf8, "target": pl.Utf8, "action": pl.Utf8})


def _prepare_tracking_minimal(path: Path) -> Optional[pl.DataFrame]:
    """Read a tracking parquet and reduce it to one (frame, id, x, y) row per mouse/frame.

    Meant to run once per video so the result can be reused for every pair.

    Column names are resolved from a small list of alternatives
    (``video_frame``/``frame``/``frame_idx``, ``mouse_id``/``id``/``track_id``,
    ``centroid_x``/``x``, ``centroid_y``/``y``). When a ``bodypart`` column is
    present, rows are restricted to a preferred set of central body parts
    (or, failing that, the first three parts found) and averaged per
    (frame, id).

    Returns:
        Frame with columns ``frame``, ``id`` (string, ``mouse`` prefix
        removed), ``x``, ``y`` sorted by (frame, id); ``None`` when the file
        does not exist or lacks the minimal columns.
    """
    if not path.exists():
        return None

    # Every column name the resolution logic below may pick.
    candidates = (
        "video_frame", "frame", "frame_idx",
        "mouse_id", "id", "track_id",
        "centroid_x", "centroid_y", "x", "y",
        "bodypart",
    )
    try:
        # Project only columns that actually exist; asking read_parquet for a
        # missing column (or passing an unsupported keyword) raises, which
        # previously forced a full read of every file.
        file_columns = list(pl.read_parquet_schema(path))
        wanted = [c for c in candidates if c in file_columns]
        df = pl.read_parquet(path, columns=wanted) if wanted else pl.read_parquet(path)
    except Exception:
        # Projection is only an optimisation: fall back to a plain read.
        df = pl.read_parquet(path)
        file_columns = list(df.columns)

    cols = set(df.columns)
    frame_col = "video_frame" if "video_frame" in cols else ("frame" if "frame" in cols else "frame_idx")
    id_col    = "mouse_id"    if "mouse_id"    in cols else ("id"    if "id"    in cols else "track_id")
    x_col     = "centroid_x"  if "centroid_x"  in cols else "x"
    y_col     = "centroid_y"  if "centroid_y"  in cols else "y"

    req = [frame_col, id_col, x_col, y_col]
    if any(c not in cols for c in req):
        logger.warning(f"Tracking schema thiếu cột tối thiểu: {file_columns}")
        return None

    if "bodypart" in cols:
        preferred = ["body_center", "centroid", "center", "torso", "hip_center", "mid"]
        # maintain_order makes the "first three parts" fallback reproducible.
        parts = df.select("bodypart").unique(maintain_order=True).to_series().to_list()
        chosen = [p for p in preferred if p in parts] or (parts[:3] if parts else [])
        if chosen:
            df = df.filter(pl.col("bodypart").is_in(chosen))
        df = df.group_by([frame_col, id_col]).agg([
            pl.col(x_col).mean().alias(x_col),
            pl.col(y_col).mean().alias(y_col),
        ])

    # Normalise ids (drop the 'mouse' prefix) so they can be matched on join.
    df = df.with_columns([
        pl.col(id_col).cast(pl.Utf8).str.replace("mouse", "").str.strip_chars().alias(id_col)
    ])

    out = df.select([
        pl.col(frame_col).alias("frame"),
        pl.col(id_col).alias("id"),
        pl.col(x_col).alias("x"),
        pl.col(y_col).alias("y"),
    ]).sort(["frame", "id"]).unique(
        # Without maintain_order, unique() may return rows in any order and
        # undo the sort that downstream frame-to-frame diffs rely on.
        subset=["frame", "id"], keep="first", maintain_order=True
    )

    return out


def _pair_features_from_agg(agg_df: pl.DataFrame, agent_raw: str, target_raw: str, downsample: int = 1) -> Optional[pl.DataFrame]:
    """Compute per-frame pair features from a frame produced by ``_prepare_tracking_minimal``.

    Args:
        agg_df: frame with columns ``frame``, ``id``, ``x``, ``y``.
        agent_raw, target_raw: mouse ids; normalised with
            ``_strip_mouse_prefix`` before matching against ``id``.
        downsample: keep every ``downsample``-th common frame when > 1.

    Returns:
        Frame sorted by ``frame`` with columns ``dist``, ``rel_speed``,
        ``rel_acc``, ``ddist``, ``angle``, ``dangle``, ``radial_vel`` and
        ``speed_a`` (kept for self-like gating), or ``None`` when either
        mouse has no rows or the two share no frames.
    """
    a_id = _strip_mouse_prefix(agent_raw)
    t_id = _strip_mouse_prefix(target_raw)
    a_df = agg_df.filter(pl.col("id") == a_id)
    b_df = agg_df.filter(pl.col("id") == t_id)
    if a_df.is_empty() or b_df.is_empty():
        return None

    # Sort by frame BEFORE downsampling and differencing: both operate on row
    # order, so they are only meaningful on a time-ordered frame.
    merged = a_df.join(
        b_df.rename({"x": "x_b", "y": "y_b"}), on="frame", how="inner"
    ).sort("frame")

    if downsample > 1:
        merged = merged.with_row_index("_i").filter(pl.col("_i") % downsample == 0).drop("_i")

    if merged.is_empty():
        return None

    feat = merged.with_columns([
        # Euclidean distance between the two mice and per-step displacement of each.
        ((pl.col("x") - pl.col("x_b"))**2 + (pl.col("y") - pl.col("y_b"))**2).sqrt().alias("dist"),
        ((pl.col("x").diff().fill_null(0))**2 + (pl.col("y").diff().fill_null(0))**2).sqrt().alias("speed_a"),
        ((pl.col("x_b").diff().fill_null(0))**2 + (pl.col("y_b").diff().fill_null(0))**2).sqrt().alias("speed_b"),
    ]).with_columns([
        (pl.col("speed_a") - pl.col("speed_b")).alias("rel_speed"),
        (pl.col("speed_a").diff().fill_null(0) - pl.col("speed_b").diff().fill_null(0)).alias("rel_acc"),
        (pl.col("dist").diff().fill_null(0)).alias("ddist"),
        pl.arctan2(pl.col("y") - pl.col("y_b"), pl.col("x") - pl.col("x_b")).alias("angle"),
        pl.arctan2(pl.col("y") - pl.col("y_b"), pl.col("x") - pl.col("x_b")).diff().fill_null(0).alias("dangle"),
        (
            # Agent displacement projected on the target->agent offset (unnormalised).
            (pl.col("x") - pl.col("x_b")) * pl.col("x").diff().fill_null(0) +
            (pl.col("y") - pl.col("y_b")) * pl.col("y").diff().fill_null(0)
        ).alias("rv_num"),
    ]).with_columns([
        # Epsilon avoids division by zero when both positions coincide.
        (pl.col("rv_num") / (pl.col("dist") + 1e-8)).alias("radial_vel")
    ]).select([
        "frame", "dist", "rel_speed", "rel_acc", "ddist", "angle", "dangle", "radial_vel", "speed_a"
    ]).sort("frame")

    return feat

# ========================
# MACHINE LEARNING TRAINING 
# ========================
# ========================
# Submission hardening (NEW)
# ========================

# Ids accepted in a submission: four mice as agent; targets may also be "self".
ALLOWED_AGENT  = {f"mouse{i}" for i in range(1, 5)}
ALLOWED_TARGET = ALLOWED_AGENT | {"self"}

def _resolve_non_overlap(group: pl.DataFrame) -> pl.DataFrame:
    """Make the segments of a *single pair* mutually non-overlapping.

    Segments are processed by (start_frame, stop_frame, action). The segment
    that starts first wins: any later segment is trimmed to begin where the
    accepted coverage ends and is dropped if nothing remains. A segment that
    touches the previous one and carries the same action is merged into it.

    Args:
        group: rows of one (video, agent, target) group; must contain
            ``action``, ``start_frame`` and ``stop_frame``.

    Returns:
        Frame with columns ``action``, ``start_frame``, ``stop_frame``. When
        no segment survives, an empty slice of ``group`` is returned instead
        (so it carries ``group``'s columns).
    """
    rows = group.select(["action", "start_frame", "stop_frame"]).to_dicts()
    rows.sort(key=lambda r: (int(r["start_frame"]), int(r["stop_frame"]), str(r["action"])))
    out: List[Tuple[str, int, int]] = []
    cur_end = -10**12  # sentinel: nothing accepted yet

    for r in rows:
        a = str(r["action"])
        s = int(r["start_frame"])
        e = int(r["stop_frame"])
        if s >= e:
            continue
        # Trim the part that overlaps the coverage accepted so far.
        if s < cur_end:
            s = cur_end
        if s >= e:
            continue
        if out and out[-1][0] == a and s <= out[-1][2]:
            out[-1] = (a, out[-1][1], max(out[-1][2], e))
        else:
            out.append((a, s, e))
        cur_end = out[-1][2]

    if not out:
        return group.head(0)

    # orient="row" is explicit: with exactly three segments the data is 3x3
    # and orientation inference cannot tell rows from columns.
    return pl.DataFrame(
        out,
        schema={"action": pl.Utf8, "start_frame": pl.Int64, "stop_frame": pl.Int64},
        orient="row",
    )


def _finalize_and_validate_submission(
    sub: pl.DataFrame,
    meta_df: pl.DataFrame,
    cfg: Config,
    allowed_actions: Optional[Set[str]] = None,
    split: str = "test",  # "train" | "test"
) -> pl.DataFrame:
    """Harden a raw submission and verify it before writing.

    Steps:
      0. Cast to the submission schema and strip whitespace from string columns.
      1. Keep only allowed agent/target ids (target may be ``self``) and,
         when ``allowed_actions`` is given, only those actions.
      2. Clip every segment to its video span from ``build_video_spans``.
      3. Shift frames so each video starts at 0.
      4. Drop out-of-range/empty segments and exact duplicates.
      5. Remove overlaps per (video_id, agent_id, target_id).
      6. Sort by (video, agent, target, start, stop) and prepend the row id.
      7. Run hard checks on the result.

    Raises:
        HostVisibleError: when spans cannot be built or any hard check fails
            (column mismatch, empty result, start >= stop, NULLs, invalid
            ids or actions).
    """
    action_whitelist = sorted(allowed_actions) if allowed_actions else None

    # 0) Cast & strip
    sub = validate_schema(sub, cfg.submission_schema, "Submission")
    sub = sub.with_columns([
        pl.col("agent_id").cast(pl.Utf8).str.strip_chars(),
        pl.col("target_id").cast(pl.Utf8).str.strip_chars(),
        pl.col("action").cast(pl.Utf8).str.strip_chars(),
    ])

    # 1) Filter ids/actions
    sub = sub.filter(
        pl.col("agent_id").is_in(ALLOWED_AGENT) &
        pl.col("target_id").is_in(ALLOWED_TARGET)
    )
    if action_whitelist:
        sub = sub.filter(pl.col("action").is_in(action_whitelist))

    # 2) Build spans & clip
    spans = build_video_spans(meta_df, split, cfg)  # {vid: (min_frame, max_frame+1)}
    if not spans:
        raise HostVisibleError("Cannot build spans for clipping/validation.")
    spdf = pl.DataFrame({
        "video_id": list(spans.keys()),
        "vmin":    [spans[v][0] for v in spans.keys()],
        "vmaxp1":  [spans[v][1] for v in spans.keys()],
    })
    # Inner join: segments of videos without a span are dropped here.
    sub = sub.join(spdf, on="video_id", how="inner").with_columns([
        # clip into [vmin, vmaxp1]
        pl.when(pl.col("start_frame") < pl.col("vmin")).then(pl.col("vmin")).otherwise(pl.col("start_frame")).alias("start_frame"),
        pl.when(pl.col("stop_frame")  > pl.col("vmaxp1")).then(pl.col("vmaxp1")).otherwise(pl.col("stop_frame")).alias("stop_frame"),
    ])

    # 3) Shift to 0-based frames per video
    sub = sub.with_columns([
        (pl.col("start_frame") - pl.col("vmin")).cast(pl.Int64).alias("start_frame"),
        (pl.col("stop_frame")  - pl.col("vmin")).cast(pl.Int64).alias("stop_frame"),
        (pl.col("vmaxp1") - pl.col("vmin")).cast(pl.Int64).alias("_vlen"),
    ]).drop(["vmin", "vmaxp1"])

    # 4) Drop invalid/empty, dedup exact
    sub = sub.filter((pl.col("start_frame") >= 0) & (pl.col("stop_frame") <= pl.col("_vlen")))
    sub = sub.filter(pl.col("start_frame") < pl.col("stop_frame")).unique()

    # 5) Enforce NON-OVERLAP per pair
    groups = []
    for (vid, ag, tg), g in sub.group_by(["video_id", "agent_id", "target_id"]):
        fixed = _resolve_non_overlap(g)
        if not fixed.is_empty():
            fixed = fixed.with_columns([
                pl.lit(int(vid)).alias("video_id"),
                pl.lit(str(ag)).alias("agent_id"),
                pl.lit(str(tg)).alias("target_id"),
            ]).select(["video_id","agent_id","target_id","action","start_frame","stop_frame"])
            groups.append(fixed)
    sub = pl.concat(groups, how="vertical") if groups else sub.head(0)

    # 6) Final order + sort + row_id
    sub = validate_schema(sub, cfg.submission_schema, "Final Submission")
    sub = sub.sort(["video_id","agent_id","target_id","start_frame","stop_frame"]) \
             .select(["video_id","agent_id","target_id","action","start_frame","stop_frame"]) \
             .with_row_index(cfg.row_id_col)

    # 7) Hard checks
    if set(sub.columns) != {cfg.row_id_col, *cfg.submission_schema.keys()}:
        raise HostVisibleError(f"Columns mismatch: {sub.columns}")
    if sub.height == 0:
        raise HostVisibleError("Empty submission after hardening.")
    if sub.select([(pl.col("start_frame") >= pl.col("stop_frame")).any().alias("bad")])["bad"][0]:
        raise HostVisibleError("Found start_frame >= stop_frame after hardening.")
    # any_horizontal yields one flag per row; reduce with .any() so that every
    # row is checked, not just the first one.
    has_null = sub.select([
        pl.any_horizontal([pl.col(c).is_null() for c in sub.columns]).any().alias("has_null")
    ])["has_null"][0]
    if has_null:
        raise HostVisibleError("Found NULL values.")
    bad_ids = sub.filter(~pl.col("agent_id").is_in(ALLOWED_AGENT) | ~pl.col("target_id").is_in(ALLOWED_TARGET))
    if bad_ids.height > 0:
        raise HostVisibleError(f"Found invalid ids, e.g.: {bad_ids.head(3).to_dicts()}")
    if action_whitelist:
        bad_act = sub.filter(~pl.col("action").is_in(action_whitelist))
        if bad_act.height > 0:
            raise HostVisibleError(f"Found invalid actions, e.g.: {bad_act.head(3).to_dicts()}")

    return sub

    
def _zero_base_solution(solution: pl.DataFrame, meta_df: pl.DataFrame, cfg: Config) -> pl.DataFrame:
    """Shift solution frames so every video starts at frame 0.

    Offsets come from the TRAIN spans returned by ``build_video_spans``; rows
    whose video has no span are dropped by the inner join.

    Raises:
        HostVisibleError: if no train spans can be built.
    """
    spans = build_video_spans(meta_df, "train", cfg)
    if not spans:
        raise HostVisibleError("Cannot build train spans for solution zero-basing.")

    video_ids = list(spans.keys())
    span_table = pl.DataFrame({
        "video_id": video_ids,
        "vmin":    [spans[vid][0] for vid in video_ids],
        "vmaxp1":  [spans[vid][1] for vid in video_ids],
    })

    joined = solution.join(span_table, on="video_id", how="inner")
    shifted = joined.with_columns([
        (pl.col(name) - pl.col("vmin")).cast(pl.Int64).alias(name)
        for name in ("start_frame", "stop_frame")
    ])
    sol = shifted.drop(["vmin", "vmaxp1"])

    validate_frame_ranges(sol, "Zero-based Solution")
    return sol


def build_behavioral_graph(solution: pl.DataFrame, max_gap: int = 1, min_count: int = 2, min_freq: float = 0.005) -> Dict[str, Set[str]]:
    """Derive the allowed action -> next-action transitions from annotated segments.

    Within each (video_id, agent_id, target_id) group the segments are sorted
    by ``start_frame``. Every segment counts once as a self-transition, and
    each consecutive pair whose gap (next start minus current stop) is at
    most ``max_gap`` counts as a transition from the first action to the
    second.

    Args:
        solution: frame with ``video_id``, ``agent_id``, ``target_id``,
            ``action``, ``start_frame`` and ``stop_frame``.
        max_gap: largest gap in frames still treated as a direct transition.
        min_count: absolute lower bound on the count for a transition to be kept.
        min_freq: relative lower bound, as a fraction of all counted transitions.

    Returns:
        Mapping from each action to the set of actions that may follow it.
        Every action in ``solution`` is present and always allows itself.
    """
    transition_counts: Dict[Tuple[str, str], int] = defaultdict(int)

    for _key, group in solution.group_by(["video_id", "agent_id", "target_id"]):
        rows = group.sort("start_frame").to_dicts()
        for row in rows:
            transition_counts[(row["action"], row["action"])] += 1
        for curr, next_ in zip(rows, rows[1:]):
            if next_["start_frame"] - curr["stop_frame"] <= max_gap:
                transition_counts[(curr["action"], next_["action"])] += 1

    total_trans = sum(transition_counts.values())
    min_threshold = max(min_count, int(min_freq * total_trans))

    final_transitions: Dict[str, Set[str]] = defaultdict(set)
    for (a, b), cnt in transition_counts.items():
        if cnt >= min_threshold:
            final_transitions[a].add(b)

    # Staying in the same action is always allowed, even for rare actions
    # that did not reach the threshold.
    for a in set(solution["action"].unique()):
        final_transitions[a].add(a)

    return dict(final_transitions)
    
def extract_features_and_labels(
    dataset: pl.DataFrame,
    cfg: Config,
    window_size: int = 15,
    step_size: int = 15,
    downsample: int = 10,
) -> Tuple[Dict[str, List[np.ndarray]], List[Tuple[int, str, str, int, int]]]:
    """Extract windowed pair features and weak per-action labels from the train set.

    For each video (labs starting with ``MABe22`` are skipped) and each
    annotated (agent, target) pair, ``_pair_features`` is computed and cut
    into windows of ``window_size`` rows every ``step_size`` rows. Each
    window is summarised by 7 statistics over 7 feature columns. A window is
    a positive example for every action that *the same pair* is annotated
    with on at least one frame of the window.

    Args:
        dataset: metadata rows with ``lab_id`` and ``video_id``.
        cfg: provides the annotation and tracking directories.
        window_size: rows per window (after downsampling).
        step_size: stride between window starts, in rows.
        downsample: forwarded to ``_pair_features``.

    Returns:
        ``(X_action, metadata)``: ``X_action`` maps an action to the list of
        feature vectors of its positive windows; ``metadata`` holds one
        ``(video_id, agent, target, start_frame, stop_frame)`` tuple per
        window with at least one action. The two are not index-aligned,
        since a window is appended once per action.
    """
    feature_cols = ["dist", "rel_speed", "rel_acc", "ddist", "angle", "dangle", "radial_vel"]

    X_action = defaultdict(list)
    metadata = []

    for row in tqdm(dataset.to_dicts(), total=len(dataset), desc="Extracting features & labels (WSTAL)"):
        lab_id: str = row["lab_id"]
        if lab_id.startswith("MABe22"):
            continue
        video_id: int = row["video_id"]
        annot_path = cfg.train_annot_dir / lab_id / f"{video_id}.parquet"
        track_path = cfg.train_track_dir / lab_id / f"{video_id}.parquet"
        if not annot_path.exists() or not track_path.exists():
            continue
        try:
            trk = pl.read_parquet(track_path)
            annot = pl.read_parquet(annot_path).with_columns(
                pl.concat_str([pl.lit("mouse"), pl.col("agent_id").cast(pl.Utf8)]).alias("agent_id"),
                pl.concat_str([pl.lit("mouse"), pl.col("target_id").cast(pl.Utf8)]).alias("target_id"),
            )
            behavior_pairs = annot.select(["agent_id", "target_id"]).unique().to_dicts()

            # (agent, target) -> frame -> actions. Labels are kept per pair so
            # a window is never marked with an action of a different pair
            # that happens to be annotated on the same frames.
            pair_frame_actions: Dict[Tuple[str, str], Dict[int, Set[str]]] = defaultdict(lambda: defaultdict(set))
            for ann_row in annot.to_dicts():
                action = ann_row["action"]
                per_frame = pair_frame_actions[(ann_row["agent_id"], ann_row["target_id"])]
                for f in range(ann_row["start_frame"], ann_row["stop_frame"]):
                    per_frame[f].add(action)

            for bp in behavior_pairs:
                agent = bp["agent_id"]
                target = bp["target_id"]
                feat_df = _pair_features(trk, agent, target, downsample=downsample)
                if feat_df is None or len(feat_df) == 0:
                    continue
                frames = feat_df["frame"].to_numpy()
                if len(frames) < window_size:
                    continue
                frame_actions = pair_frame_actions.get((agent, target), {})
                for start_idx in range(0, len(frames) - window_size + 1, step_size):
                    win_start_frame = int(frames[start_idx])
                    win_end_frame = int(frames[start_idx + window_size - 1]) + 1  # half-open
                    win_feat = feat_df.slice(start_idx, window_size)
                    if len(win_feat) < window_size:
                        continue
                    features = []
                    for col in feature_cols:
                        arr = win_feat[col].to_numpy()
                        features.extend([
                            np.mean(arr), np.std(arr), np.min(arr), np.max(arr),
                            np.median(arr), np.max(arr) - np.min(arr), arr[-1] - arr[0],
                        ])
                    # Weak labeling: positive for every action this pair shows in the window.
                    window_actions = set()
                    for f in range(win_start_frame, win_end_frame):
                        # .get avoids inserting an empty set for every probed frame.
                        window_actions.update(frame_actions.get(f, ()))
                    if window_actions:
                        for action in window_actions:
                            X_action[action].append(features)
                        metadata.append((video_id, agent, target, win_start_frame, win_end_frame))
        except Exception as e:
            logger.error(f"Error processing video {video_id} for lab {lab_id}: {type(e).__name__}: {e}")
            continue

    logger.info(f"Extracted data for {len(X_action)} actions.")
    return X_action, metadata

import numpy as np
import joblib

# Cache of the extracted per-action feature arrays: train_ml_model loads it
# when the file exists and otherwise writes it with np.savez_compressed.
PROCESSED_DATA_PATH = "processed_mouse_data.npz"
# Location of the model bundle. train_ml_model writes it with joblib.dump;
# predict_with_ml and run_submit_ml read it with joblib.load. Use the local
# path below when training and the dataset path when running inference.
# MODEL_PATH = "mouse_behavior_model.joblib" # train
MODEL_PATH = "/kaggle/input/mabexa/mouse_behavior_model (1).joblib" # infer

def train_ml_model(cfg: Config, window_size: int = 15, step_size: int = 15, downsample: int = 10) -> None:
    """Train one binary XGBoost classifier per action and save the model bundle.

    Features come from ``PROCESSED_DATA_PATH`` when that cache exists;
    otherwise they are extracted with ``extract_features_and_labels`` and the
    cache is written. For each action the positives are its own windows and
    the negatives are sampled (up to 5x the positives, without replacement)
    from the windows of all other actions.

    The bundle written to ``MODEL_PATH`` contains the models, the window
    size, the action list and the behavioral transition graph.

    Args:
        cfg: provides ``train_csv`` and the data directories.
        window_size, step_size, downsample: forwarded to feature extraction
            (ignored when the cache is used); ``window_size`` is also stored
            in the bundle.
    """
    logger.info("Loading train data for ML training...")
    train = pl.read_csv(cfg.train_csv)
    train_subset = train.filter(~pl.col("lab_id").str.starts_with("MABe22"))

    if os.path.exists(PROCESSED_DATA_PATH):
        logger.info(f"Loading preprocessed data from {PROCESSED_DATA_PATH}")
        # NOTE: allow_pickle=True executes pickled content; the cache must be
        # a file produced by this function, never an untrusted download.
        # The context manager closes the archive once the arrays are copied.
        with np.load(PROCESSED_DATA_PATH, allow_pickle=True) as data:
            X_action = {k: data[k].tolist() for k in data.files if k != 'metadata'}
    else:
        logger.info("Extracting features and labels (PAIR-AWARE + SELF-AWARE)...")
        X_action, metadata = extract_features_and_labels(
            train_subset, cfg,
            window_size=window_size,
            step_size=step_size,
            downsample=downsample,
        )
        save_dict = {action: np.array(feats) for action, feats in X_action.items()}
        save_dict['metadata'] = np.array(metadata, dtype=object)
        np.savez_compressed(PROCESSED_DATA_PATH, **save_dict)
        logger.info("Preprocessed data saved.")

    # Train per-action binary models
    models: Dict[str, Any] = {}
    rng = np.random.default_rng(42)
    for action, X_pos in X_action.items():
        X_pos = np.array(X_pos, dtype=np.float32)
        if len(X_pos) == 0:
            logger.warning(f"Skip action {action}: no positive samples.")
            continue

        # Negative sampling from the windows of every other action.
        # NOTE(review): a window that is positive for several actions is in
        # this pool too, so it can show up as both positive and negative.
        neg_pool = []
        for other_act, feats in X_action.items():
            if other_act != action:
                neg_pool.extend(feats)
        if not neg_pool:
            neg_pool = X_pos  # safe fallback when this is the only action
        neg_pool = np.array(neg_pool, dtype=np.float32)
        n_neg = int(min(len(neg_pool), 5 * len(X_pos)))
        neg_indices = rng.choice(len(neg_pool), size=n_neg, replace=False)
        X_neg = neg_pool[neg_indices]

        X = np.vstack([X_pos, X_neg])
        y = np.hstack([np.ones(len(X_pos), dtype=np.int32), np.zeros(len(X_neg), dtype=np.int32)])

        model = XGBClassifier(
            n_estimators=300,
            max_depth=7,
            learning_rate=0.1,
            subsample=0.8,
            colsample_bytree=0.8,
            # Re-balance the up-to-5:1 negative:positive ratio.
            scale_pos_weight=(len(X_neg) / max(1, len(X_pos))),
            random_state=42,
            n_jobs=-1,
            verbosity=0,
        )
        model.fit(X, y)
        models[action] = model
        logger.info(f"Trained action [{action}]: pos={len(X_pos)}, neg={len(X_neg)}")

    # Behavioral graph built from the train solution.
    solution = create_solution_df(train_subset, cfg)
    behavioral_graph = build_behavioral_graph(solution)

    joblib.dump({
        "models": models,
        "window_size": window_size,
        "actions": list(models.keys()),  # may include 'rear', 'attack', ...
        "behavioral_graph": behavioral_graph
    }, MODEL_PATH)
    logger.info(f"Per-action models saved to {MODEL_PATH}")


def merge_consecutive_segments(segments: List[Tuple[int,int,str]], max_gap: int = 5, min_duration: int = 15) -> List[Tuple[int,int,str]]:
    """Fuse same-label segments that lie close together, then drop short ones.

    Segments are visited in order of start frame. A segment is folded into
    the previous one when both carry the same label and it starts at most
    ``max_gap`` frames after the previous one stops. After merging, only
    segments lasting at least ``min_duration`` frames are returned.
    """
    if not segments:
        return []

    fused: List[Tuple[int,int,str]] = []
    for start, stop, label in sorted(segments, key=lambda seg: seg[0]):
        if fused:
            prev_start, prev_stop, prev_label = fused[-1]
            if label == prev_label and start - prev_stop <= max_gap:
                fused[-1] = (prev_start, max(prev_stop, stop), label)
                continue
        fused.append((start, stop, label))

    return [seg for seg in fused if seg[1] - seg[0] >= min_duration]

def predict_with_ml(
    dataset: pl.DataFrame,
    cfg: Config,
    model_path: str = MODEL_PATH,
    min_len: int = 15,
    gap_close: int = 5,
) -> pl.DataFrame:
    """Predict behaviour segments with the per-action models of a saved bundle.

    For each video and each (agent, target) pair listed in
    ``behaviors_labeled``, sliding windows are scored by every action model,
    one action is picked per window (transitions not allowed by the
    behavioral graph are down-weighted), and the accepted windows are merged
    into segments.

    Tracking files are always read from ``cfg.test_track_dir``. When nothing
    is predicted at all, the heuristic ``predict_without_ml`` is used instead.

    Args:
        dataset: metadata rows with ``lab_id``, ``video_id`` and
            ``behaviors_labeled``.
        cfg: configuration.
        model_path: joblib bundle with ``models``, ``window_size`` and
            optionally ``behavioral_graph``. joblib.load unpickles the file,
            so only load bundles from a trusted source.
        min_len: forwarded to ``_make_windows`` and used as the minimum
            segment duration when merging.
        gap_close: maximum gap for merging same-action segments.

    Returns:
        Submission-schema frame, one row per predicted segment.
    """
    logger.info(f"Loading ML model from {model_path}")
    saved = joblib.load(model_path)
    models: Dict[str, Any] = saved["models"]
    window_size: int = int(saved["window_size"])
    behavioral_graph: Dict[str, Set[str]] = saved.get("behavioral_graph", {})
    actions_sorted = sorted(models.keys())

    feature_cols = ["dist", "rel_speed", "rel_acc", "ddist", "angle", "dangle", "radial_vel"]
    step = max(1, window_size // 2)  # 50% overlap between consecutive windows
    disallowed_penalty = 0.1         # factor for transitions missing from the graph
    accept_threshold = 0.25          # minimum (penalised) probability to keep a window
    reset_gap = 30                   # frames without an accepted window before the chain resets

    records: List[Tuple[int, str, str, str, int, int]] = []

    for row in tqdm(dataset.to_dicts(), total=len(dataset), desc="Predicting with ML (per-action, fast)"):
        lab_id: str = row["lab_id"]
        if str(lab_id).startswith("MABe22"):
            continue
        video_id: int = row["video_id"]
        track_path = cfg.test_track_dir / lab_id / f"{video_id}.parquet"
        if not track_path.exists():
            logger.warning("Tracking file not found: %s", track_path)
            continue

        beh_df = parse_behaviors_labeled(row.get("behaviors_labeled"))
        if beh_df.is_empty():
            continue

        trk_agg = _prepare_tracking_minimal(track_path)
        if trk_agg is None or trk_agg.is_empty():
            logger.warning("Empty tracking after prepare for %s", track_path)
            continue

        for (agent, target), _grp in beh_df.group_by(["agent", "target"]):
            # For target 'self' the agent is paired with itself for features,
            # but 'self' is kept in the output.
            orig_target = str(target)
            pair_target_for_feat = agent if orig_target.lower() == "self" else target

            feat_df = _pair_features_from_agg(trk_agg, agent, pair_target_for_feat, downsample=10)
            if feat_df is None or feat_df.height < window_size:
                continue

            # Quick gating: only score windows inside the candidate intervals.
            gate_windows = _make_windows(feat_df, min_len=min_len)
            if not gate_windows:
                fmin = int(feat_df["frame"].min())
                fmax = int(feat_df["frame"].max()) + 1
                gate_windows = [(fmin, fmax)]

            arrays = {c: feat_df[c].to_numpy() for c in feature_cols}
            # _pair_features_from_agg returns rows sorted by frame, which is
            # what the binary searches below rely on.
            frames = feat_df["frame"].to_numpy()

            X_rows: List[np.ndarray] = []
            W_bounds: List[Tuple[int, int]] = []

            for (ws, we) in gate_windows:
                # First row with frame >= ws and last row with frame < we.
                start_idx = int(np.searchsorted(frames, ws, side="left"))
                end_idx = int(np.searchsorted(frames, we, side="left")) - 1
                if start_idx >= len(frames) or end_idx < 0 or end_idx - start_idx + 1 < window_size:
                    continue

                for si in range(start_idx, end_idx - window_size + 2, step):
                    ei = si + window_size
                    feats = []
                    for c in feature_cols:
                        arr = arrays[c][si:ei]
                        feats.extend([
                            float(arr.mean()),
                            float(arr.std()),
                            float(arr.min()),
                            float(arr.max()),
                            float(np.median(arr)),
                            float(arr.max() - arr.min()),
                            float(arr[-1] - arr[0]),
                        ])
                    X_rows.append(np.asarray(feats, dtype=np.float32))
                    W_bounds.append((int(frames[si]), int(frames[ei-1]) + 1))

            if not X_rows:
                continue

            X_batch = np.stack(X_rows, axis=0)

            # One batched predict per action for all windows of this pair.
            action_probs: Dict[str, np.ndarray] = {}
            for act in actions_sorted:
                action_probs[act] = models[act].predict_proba(X_batch)[:, 1]

            # Greedy per-window choice guided by the behavioral graph.
            segments: List[Tuple[int, int, str]] = []
            prev_action = "background"
            last_end = -10**9

            for idx, (s, e) in enumerate(W_bounds):
                # Reset the chain BEFORE scoring: after a long gap the previous
                # action must not penalise the candidates of this window.
                if s - last_end > reset_gap:
                    prev_action = "background"

                best_act, best_p = None, 0.0
                for act in actions_sorted:
                    p = float(action_probs[act][idx])
                    # Unknown previous actions (e.g. "background") allow everything.
                    allowed = behavioral_graph.get(prev_action, {act})
                    if act not in allowed:
                        p *= disallowed_penalty
                    if p > best_p:
                        best_p, best_act = p, act
                if best_act is not None and best_p >= accept_threshold:
                    segments.append((s, e, best_act))
                    prev_action = best_act
                    last_end = e

            if not segments:
                continue

            merged = merge_consecutive_segments(segments, max_gap=gap_close, min_duration=min_len)
            for s, e, lbl in merged:
                records.append((
                    video_id,
                    _norm_mouse_id(agent),
                    ("self" if orig_target.lower() == "self" else _norm_mouse_id(target)),
                    lbl, int(s), int(e)
                ))

    if not records:
        logger.warning("No ML predictions → falling back to heuristic")
        return predict_without_ml(dataset, "test", cfg, use_windows=True, min_len=min_len, gap_close=gap_close)

    df = pl.DataFrame(
        records,
        schema={
            "video_id": pl.Int64, "agent_id": pl.Utf8, "target_id": pl.Utf8,
            "action": pl.Utf8, "start_frame": pl.Int64, "stop_frame": pl.Int64,
        },
        orient="row",
    )
    df = validate_schema(df, cfg.submission_schema, "ML Submission")
    validate_frame_ranges(df, "ML Submission")
    return df


def run_train_ml(cfg: Config, window_size: int, step_size: int, downsample: int) -> None:
    """Delegate to ``train_ml_model`` with the given windowing settings.

    Args:
        cfg: Run configuration, passed through positionally.
        window_size: Forwarded unchanged as ``window_size``.
        step_size: Forwarded unchanged as ``step_size``.
        downsample: Forwarded unchanged as ``downsample``.
    """
    training_settings = {
        "window_size": window_size,
        "step_size": step_size,
        "downsample": downsample,
    }
    train_ml_model(cfg, **training_settings)

def run_submit_ml(cfg: Config, min_len: int, gap_close: int) -> None:
    """Predict on the test set with the saved models and write the submission CSV.

    NOTE(review): a function with this same name and the same behavior is
    defined again further down, in the CLI section of this file; that later
    definition is the one in effect once the module has finished loading.

    Args:
        cfg: Run configuration (test CSV path, row-id column, output path).
        min_len: Forwarded to ``predict_with_ml`` as ``min_len``.
        gap_close: Forwarded to ``predict_with_ml`` as ``gap_close``.
    """
    test_meta = pl.read_csv(cfg.test_csv)
    predictions = predict_with_ml(
        test_meta, cfg, min_len=min_len, gap_close=gap_close
    )

    # Allowed actions come from the artifact's "actions" entry, falling back
    # to the names of the stored models.
    artifact = joblib.load(MODEL_PATH)
    allowed_actions = set(
        artifact.get("actions", list(artifact["models"].keys()))
    )

    finalized = _finalize_and_validate_submission(
        predictions, test_meta, cfg, allowed_actions, split="test"
    )

    # Fix the output column order, with the row-id column first.
    output_columns = [
        cfg.row_id_col,
        "video_id",
        "agent_id",
        "target_id",
        "action",
        "start_frame",
        "stop_frame",
    ]
    finalized.select(output_columns).write_csv(cfg.submission_file)
    logger.info(f"ML submission saved to {cfg.submission_file}")


# ========================
# CLI / Main
# ========================

def run_validate(cfg: Config, beta: float, prior_scope: str, eps: float,
                 use_windows: bool, min_len: int, gap_close: int, p_min: float, cap: float) -> float:
    """Score the non-ML predictor on the training data and return the score.

    Builds the solution frame from the training rows, derives action and
    timing priors from it, runs ``predict_without_ml`` on the same rows,
    finalizes the predictions and scores them against the solution.

    Args:
        cfg: Run configuration (train CSV path, row-id column).
        beta: Passed to ``score`` as ``beta``.
        prior_scope: Forwarded to ``predict_without_ml``.
        eps: Passed to ``compute_action_priors`` as ``eps``.
        use_windows: Forwarded to ``predict_without_ml``.
        min_len: Forwarded to ``predict_without_ml``.
        gap_close: Forwarded to ``predict_without_ml``.
        p_min: Forwarded to ``predict_without_ml``.
        cap: Forwarded to ``predict_without_ml``.

    Returns:
        The value returned by ``score``; it is also printed to stdout.
    """
    logger.info("Loading train data for validation: %s", cfg.train_csv)
    # Rows whose lab_id starts with "MABe22" are left out of validation.
    keep_row = ~pl.col("lab_id").str.starts_with("MABe22")
    labelled = pl.read_csv(cfg.train_csv).filter(keep_row)

    logger.info("Building solution dataframe & spans...")
    # Zero-base the solution according to the train spans.
    solution = _zero_base_solution(create_solution_df(labelled, cfg), labelled, cfg)

    logger.info("Computing priors (eps=%.2f) & timing...", eps)
    spans = build_video_spans(labelled, "train", cfg)
    per_lab, global_w, med_lab, med_glob = compute_action_priors(solution, eps=eps)
    timing_lab, timing_glob = compute_timing_priors(solution, spans)

    logger.info("Generating predictions (advanced)...")
    predictor_options = {
        "priors_per_lab": per_lab,
        "priors_global": global_w,
        "meddur_per_lab": med_lab,
        "meddur_global": med_glob,
        "timing_lab": timing_lab,
        "timing_global": timing_glob,
        "prior_scope": prior_scope,
        "use_windows": use_windows,
        "min_len": min_len,
        "gap_close": gap_close,
        "p_min": p_min,
        "cap": cap,
    }
    raw_predictions = predict_without_ml(labelled, "train", cfg, **predictor_options)

    # Harden and zero-base the train predictions according to the train spans;
    # only actions present in the solution are allowed.
    allowed_actions = set(solution["action"].unique())
    submission_train = _finalize_and_validate_submission(
        raw_predictions, labelled, cfg, allowed_actions, split="train"
    )

    logger.info("Scoring (beta=%.3f)...", beta)
    val_score = score(solution, submission_train, cfg.row_id_col, beta=beta, cfg=cfg)
    print(f"[RESULT] Validation F1: {val_score:.6f}")
    return val_score


def run_submit_ml(cfg: Config, min_len: int, gap_close: int) -> None:
    """Predict on the test set with the saved models and write the submission CSV.

    NOTE(review): an earlier definition with the same name and behavior
    exists above, next to ``run_train_ml``; this later definition overrides
    it. Consider keeping only one of the two.

    Args:
        cfg: Run configuration (test CSV path, row-id column, output path).
        min_len: Forwarded to ``predict_with_ml`` as ``min_len``.
        gap_close: Forwarded to ``predict_with_ml`` as ``gap_close``.
    """
    test_meta = pl.read_csv(cfg.test_csv)
    predictions = predict_with_ml(
        test_meta, cfg, min_len=min_len, gap_close=gap_close
    )

    # Allowed actions come from the artifact's "actions" entry, falling back
    # to the names of the stored models.
    artifact = joblib.load(MODEL_PATH)
    allowed_actions = set(
        artifact.get("actions", list(artifact["models"].keys()))
    )

    finalized = _finalize_and_validate_submission(
        predictions, test_meta, cfg, allowed_actions, split="test"
    )

    # Fix the output column order, with the row-id column first.
    output_columns = [
        cfg.row_id_col,
        "video_id",
        "agent_id",
        "target_id",
        "action",
        "start_frame",
        "stop_frame",
    ]
    finalized.select(output_columns).write_csv(cfg.submission_file)
    logger.info(f"ML submission saved to {cfg.submission_file}")


def main() -> None:
    """Parse the command line and run the selected mode.

    Modes:
        train-ml:  call ``run_train_ml`` with fixed window settings.
        submit-ml: call ``run_submit_ml``.
        validate:  call ``run_validate`` and log the returned score.
        submit:    call ``run_submit``.
        all:       ``validate`` followed by ``submit`` (the default).
    """
    parser = argparse.ArgumentParser(description="MABe Mouse Behavior: ML Baseline (F1-optimized + FAST)")
    add = parser.add_argument
    add("--data-root", type=str, default=None, help="Dataset root directory")
    add("--beta", type=float, default=1.0, help="F-beta value")
    add("--mode", choices=["validate", "submit", "all", "train-ml", "submit-ml"], default="all", help="Run mode")
    add("--submission", type=str, default=None, help="Submission CSV output path")
    add("--prior-scope", choices=["lab", "global", "mixed"], default="mixed",
        help="Use lab-level priors, global priors, or per-lab with global fallback (default)")
    add("--eps", type=float, default=1.0, help="Laplace smoothing (frame units) for priors")
    add("--no-windows", action="store_true", help="Disable proximity windows (use full video)")
    add("--min-len", type=int, default=15, help="Minimum segment/window length (frames)")
    add("--gap-close", type=int, default=3, help="Merge same-action gaps up to this many frames")
    add("--p-min", type=float, default=0.05, help="Rare-action min prior threshold")
    add("--cap", type=float, default=0.05, help="Rare-action maximum share if below p-min")
    add("-v", "--verbose", action="count", default=1, help="Increase verbosity (-v, -vv)")
    args = parser.parse_args()

    setup_logging(args.verbose)
    warnings.filterwarnings("ignore")

    # Command-line values win; anything not given falls back to Config defaults.
    defaults = Config()
    cfg = Config(
        data_root=Path(args.data_root) if args.data_root else defaults.data_root,
        submission_file=args.submission if args.submission else defaults.submission_file,
        row_id_col=defaults.row_id_col,
    )

    mode = args.mode
    if mode == "train-ml":
        run_train_ml(cfg, window_size=15, step_size=15, downsample=10)
        return
    if mode == "submit-ml":
        run_submit_ml(cfg, min_len=args.min_len, gap_close=args.gap_close)
        return

    # Options shared by the validate and submit steps.
    shared_options = {
        "prior_scope": args.prior_scope,
        "eps": args.eps,
        "use_windows": not args.no_windows,
        "min_len": args.min_len,
        "gap_close": args.gap_close,
        "p_min": args.p_min,
        "cap": args.cap,
    }
    if mode in ("validate", "all"):
        val = run_validate(cfg, beta=args.beta, **shared_options)
        logger.info("Validation F1: %.6f", val)
    if mode in ("submit", "all"):
        run_submit(cfg, **shared_options)

# Script entry point: run the CLI only when this module is executed directly.
if __name__ == "__main__":
    main()

In [None]:
# !python baseline.py --mode train-ml -vv

In [None]:
!python baseline.py --mode submit-ml --min-len 15 --gap-close 5 -vv

In [None]:
# Inspect the saved model artifact: print the sorted action names, taken from
# its "actions" entry when present, otherwise from the keys of "models".
import joblib
saved = joblib.load('/kaggle/input/mabexa/mouse_behavior_model (1).joblib')
print(sorted(saved.get("actions", list(saved["models"].keys()))))


In [None]:
# Sanity-check the written submission: show the first 10 rows and the shape.
import pandas as pd 

df = pd.read_csv('submission.csv')
print(df.head(10))
print(df.shape)