# Set Paths & Select Config (CFG)

In [3]:
# ============================================================
# STAGE 0 — Set Paths & Select Config (CFG) (Kaggle-ready, offline)
# REVISI FULL v3.1 (lebih kuat + siap MULTI-CFG + lebih aman anti-error)
#
# Upgrade v3.1 (sesuai strategi UNet+ASPP + pipeline naik score):
# - PRED prefix priority: utamakan hasil training baru (UNet+ASPP / Fusion) jika ada
#   -> pred_unet_aspp_cfg_* dipilih dulu sebelum pred_base_*
# - MATCH juga prefer ada folder npz (test/train_all) + match_summary.json
# - Optional FORCE select via env (tanpa merusak multi-cfg list):
#     LUC_FORCE_PRED_CFG  = nama folder cfg (mis: pred_unet_aspp_cfg_xxx)
#     LUC_FORCE_MATCH_CFG = nama folder cfg (mis: match_base_cfg_xxx)
# - Sanity guard lebih jelas + PATHS tambah shortcut folder npz (aman untuk stage lanjut)
#
# Output globals (TETAP, JANGAN diganti):
# - COMP_ROOT, OUT_DS_ROOT, OUT_ROOT
# - PATHS (dict)
# - MATCH_CFG_DIR, PRED_CFG_DIR, DINO_CFG_DIR
#
# Extra globals (aman, membantu training lanjutan):
# - MATCH_CFG_DIRS, PRED_CFG_DIRS (list[Path] TOP-K)
# - MATCH_CFG_INFO, PRED_CFG_INFO (list[dict] detail skor)
# - CACHE_ROOTS (list[Path]), SELECTED (dict), TRAIN_PLAN (dict)
# ============================================================

import os, re, json, time
from pathlib import Path
import pandas as pd

# ----------------------------
# Config knobs (boleh diubah)
# ----------------------------
_getenv = os.environ.get

# How many ranked MATCH / PRED cfg folders to keep for multi-cfg training.
TOPK_MATCH_CFGS = int(_getenv("TOPK_MATCH_CFGS", "5"))
TOPK_PRED_CFGS = int(_getenv("TOPK_PRED_CFGS", "8"))

# Optional hard override of the input dataset root, e.g.
#   export LUC_OUT_DS_ROOT="/kaggle/input/<output_dataset_name>"
ENV_OUT_DS_ROOT = _getenv("LUC_OUT_DS_ROOT", "").strip()

# Optional: force the PRIMARY cfg by folder name (the ranked multi-cfg lists are kept).
FORCE_PRED_CFG_NAME = _getenv("LUC_FORCE_PRED_CFG", "").strip()
FORCE_MATCH_CFG_NAME = _getenv("LUC_FORCE_MATCH_CFG", "").strip()

# ----------------------------
# Helper: fast count CSV rows (binary newline count)
# ----------------------------
def _fast_count_rows_csv(path: Path, assume_header: bool = True) -> int:
    """
    Count rows quickly by counting '\n' in binary.
    Returns -1 if error.
    """
    try:
        if not path.exists() or not path.is_file():
            return -1
        nl = 0
        with path.open("rb") as f:
            while True:
                b = f.read(1024 * 1024)
                if not b:
                    break
                nl += b.count(b"\n")
        rows = nl - (1 if assume_header else 0)
        return int(max(rows, 0))
    except Exception:
        return -1

def _safe_mtime(p: Path) -> float:
    try:
        return float(p.stat().st_mtime)
    except Exception:
        return 0.0

def _is_nonempty_file(p: Path) -> bool:
    try:
        return p is not None and p.exists() and p.is_file() and p.stat().st_size > 0
    except Exception:
        return False

def _dir_has_any_npz(d: Path) -> bool:
    try:
        if d is None or (not d.exists()) or (not d.is_dir()):
            return False
        for _ in d.glob("*.npz"):
            return True
        return False
    except Exception:
        return False

def _resolve_cfg_by_name(cache_roots: list, cfg_name: str) -> Path:
    """
    Find cfg directory by exact folder name under any cache root (one level deep).
    Returns Path or None.
    """
    if not cfg_name:
        return None
    for root in cache_roots:
        root = Path(root)
        if not root.exists():
            continue
        cand = root / cfg_name
        if cand.exists() and cand.is_dir():
            return cand
        # fallback: search one-level deep (avoid heavy recursion)
        try:
            for d in root.iterdir():
                if d.is_dir() and d.name == cfg_name:
                    return d
        except Exception:
            pass
    return None

# ----------------------------
# Helper: find competition root
# ----------------------------
def find_comp_root(preferred: str = "/kaggle/input/recodai-luc-scientific-image-forgery-detection") -> Path:
    """
    Locate the competition data root.

    Returns *preferred* if it exists. Otherwise scans /kaggle/input for a
    folder holding sample_submission.csv plus train_images/ or test_images/:
    first the dataset folders themselves, then (only if none matched) one level
    below them. Among several hits, names containing "recodai", then
    "forgery", win, then alphabetical order.

    Raises:
        FileNotFoundError: /kaggle/input is missing or no candidate matches.
    """
    preferred_path = Path(preferred)
    if preferred_path.exists():
        return preferred_path

    input_root = Path("/kaggle/input")
    if not input_root.exists():
        raise FileNotFoundError("/kaggle/input tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    def _looks_like_comp(folder: Path) -> bool:
        if not (folder / "sample_submission.csv").exists():
            return False
        return (folder / "train_images").exists() or (folder / "test_images").exists()

    top_dirs = [d for d in input_root.iterdir() if d.is_dir()]
    hits = [d for d in top_dirs if _looks_like_comp(d)]
    if not hits:
        hits = [
            inner
            for d in top_dirs
            for inner in d.iterdir()
            if inner.is_dir() and _looks_like_comp(inner)
        ]

    if not hits:
        raise FileNotFoundError(
            "COMP_ROOT tidak ditemukan. Harus ada folder yang memuat sample_submission.csv dan train_images/test_images."
        )

    def _rank(folder: Path):
        lowered = folder.name.lower()
        return ("recodai" not in lowered, "forgery" not in lowered, folder.name)

    return min(hits, key=_rank)

# ----------------------------
# Helper: find output dataset root (hasil PREP)
# ----------------------------
def find_output_dataset_root(preferred_names=(
    "recod-ailuc-dinov2-base",
    "recod-ai-luc-dinov2-base",
    "recodai-luc-dinov2-base",
    "recodai-luc-dinov2",
    "recodai-luc-dinov2-prep",
)) -> Path:
    """
    Locate the Kaggle input dataset that holds the PREP outputs (``recodai_luc/``).

    Search order:
      1. the LUC_OUT_DS_ROOT override (module global ENV_OUT_DS_ROOT), if it exists;
      2. ``/kaggle/input/<name>`` for each of *preferred_names*;
      3. any ``/kaggle/input/<d>`` containing ``recodai_luc/artifacts`` or
         ``recodai_luc/cache``, either directly or one folder deeper; names
         containing "dinov2" first, then alphabetical.

    Raises:
        FileNotFoundError: /kaggle/input is missing or no candidate is found.
    """
    base = Path("/kaggle/input")

    if ENV_OUT_DS_ROOT:
        p = Path(ENV_OUT_DS_ROOT)
        if p.exists():
            return p
        print(f"WARNING: ENV LUC_OUT_DS_ROOT tidak ditemukan: {p}")

    # Give the same explicit error as find_comp_root instead of a bare
    # iterdir() failure further down.
    if not base.exists():
        raise FileNotFoundError("/kaggle/input tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    for nm in preferred_names:
        p = base / nm
        if p.exists():
            return p

    markers = ("recodai_luc/artifacts", "recodai_luc/cache")
    cands = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        direct_hit = any((d / m).exists() for m in markers)
        # One folder deeper: accept artifacts OR cache as well, consistent
        # with the direct check and with the error message below.
        if direct_hit or any(any(d.glob(f"*/{m}")) for m in markers):
            cands.append(d)

    if not cands:
        raise FileNotFoundError("OUT_DS_ROOT tidak ditemukan. Harus ada /kaggle/input/<...>/recodai_luc/(artifacts|cache)/")

    cands.sort(key=lambda x: (("dinov2" not in x.name.lower()), x.name))
    return cands[0]

# ----------------------------
# Helper: resolve OUT_ROOT = <dataset>/recodai_luc
# ----------------------------
def resolve_out_root(out_ds_root: Path) -> Path:
    """
    Resolve the ``recodai_luc`` folder inside an output dataset.

    Prefers ``<out_ds_root>/recodai_luc``. Otherwise returns the
    alphabetically first ``<out_ds_root>/*/recodai_luc``, so the pick is
    deterministic (glob order is not guaranteed).

    Args:
        out_ds_root: dataset root (Path or str).

    Raises:
        FileNotFoundError: no recodai_luc folder exists at either level.
    """
    out_ds_root = Path(out_ds_root)
    direct = out_ds_root / "recodai_luc"
    if direct.exists():
        return direct
    hits = sorted(out_ds_root.glob("*/recodai_luc"))
    if hits:
        return hits[0]
    raise FileNotFoundError(f"Folder recodai_luc tidak ditemukan di bawah {out_ds_root}")

# ----------------------------
# Helper: pick TOP-K cfg directories by multi-criteria scoring
# ----------------------------
def pick_top_cfgs(
    cache_roots,
    prefixes,
    required_train_file: str,
    prefer_files=(),
    extra_prefer_dirs=(),
    max_return: int = 5,
) -> list:
    """
    Rank cfg directories found directly under the cache roots and return the top ones.

    A child directory of a cache root is a candidate when its name starts with
    one of *prefixes* and *required_train_file* inside it has at least one
    data row (per _fast_count_rows_csv). Candidates are sorted by score
    (descending), ties broken by folder name, and at most *max_return* are
    returned.

    Args:
        cache_roots: directories (Path or str) whose immediate children are scanned.
        prefixes: a str or a sequence of folder-name prefixes; order is priority
            (earlier = preferred).
        required_train_file: file that must exist with >= 1 data row in the cfg dir.
        prefer_files: optional files whose non-empty presence raises the score.
        extra_prefer_dirs: optional subfolders whose presence raises the score;
            "test" and "train_all" only count when they contain *.npz files.
        max_return: maximum number of candidates returned.

    Returns:
        list of dicts with keys: dir, root, score, prefix_idx, prefix,
        train_file, train_rows, prefer_hits, prefer_rows_sum, prefer_detail,
        dir_hits, dir_detail, mtime.

    IMPORTANT:
    - prefixes order is treated as PRIORITY (earlier = more preferred).
      We add prefix bonus so pred_unet_aspp_cfg_* will win over pred_base_* if both valid.
    NOTE(review): the per-priority-step bonus (3e6) is smaller than the maximum
    prefer_files + dir bonus (4 * 1e6 + 2 * 2e5 = 4.4e6 with the callers' lists),
    so an adjacent higher-priority family with few prefer hits can still lose to
    a lower one. Confirm that is intended.
    """
    # Normalise arguments: a single prefix string becomes a one-item list.
    if isinstance(prefixes, str):
        prefixes = [prefixes]
    prefixes = list(prefixes)

    prefer_files = list(prefer_files or [])
    extra_prefer_dirs = list(extra_prefer_dirs or [])

    cands = []
    for root in cache_roots:
        root = Path(root)
        if not root.exists():
            continue

        # Unreadable roots are skipped rather than aborting the whole scan.
        try:
            it = list(root.iterdir())
        except Exception:
            continue

        for d in it:
            if not d.is_dir():
                continue
            name = d.name

            # match prefix priority (index of the first matching prefix)
            px_idx = None
            for i, px in enumerate(prefixes):
                if name.startswith(px):
                    px_idx = i
                    break
            if px_idx is None:
                continue

            # Required train file: must exist and have >= 1 data row (-1 = unreadable).
            train_fp = d / required_train_file
            if not train_fp.exists():
                continue

            train_n = _fast_count_rows_csv(train_fp)
            if train_n <= 0:
                continue

            # prefer files: count non-empty ones and sum their line counts
            pref_hit = 0
            pref_rows_sum = 0
            pref_detail = {}
            for fn in prefer_files:
                fp = d / fn
                ok = _is_nonempty_file(fp)
                pref_detail[fn] = bool(ok)
                if ok:
                    pref_hit += 1
                    pref_rows_sum += max(_fast_count_rows_csv(fp), 0)

            # prefer dirs (npz availability or other dirs)
            dir_hit = 0
            dir_detail = {}
            for dn in extra_prefer_dirs:
                dd = d / dn
                ok = dd.exists() and dd.is_dir()
                # For the npz output folders, existing-but-empty does not count.
                if ok and dn in ("test", "train_all"):
                    ok = _dir_has_any_npz(dd)
                dir_detail[dn] = bool(ok)
                if ok:
                    dir_hit += 1

            # mtime: consider dir + train file + any prefer file (missing files give 0.0)
            mt = max(_safe_mtime(d), _safe_mtime(train_fp))
            for fn in prefer_files:
                mt = max(mt, _safe_mtime(d / fn))

            # prefix bonus (earlier prefix => bigger bonus)
            # keep big enough so model family preference wins unless clearly invalid
            # example: pred_unet_aspp_cfg_* should win over pred_base_* by default
            prefix_bonus = 0.0
            if len(prefixes) > 1:
                prefix_bonus = 3e6 * float(len(prefixes) - px_idx)

            # score design (weights decrease term by term):
            # - prefix bonus (family priority)
            # - prefer files hits
            # - prefer dirs hits (npz availability)
            # - then train rows, then prefer rows
            # - then newest mtime (scaled down to a small tie-breaker)
            score = 0.0
            score += prefix_bonus
            score += 1e6 * float(pref_hit)
            score += 2e5 * float(dir_hit)
            score += 1.0  * float(train_n)
            score += 0.05 * float(pref_rows_sum)
            score += 1e-6 * float(mt)

            cands.append({
                "dir": d,
                "root": root,
                "score": score,
                "prefix_idx": int(px_idx),
                "prefix": prefixes[px_idx],
                "train_file": str(train_fp),
                "train_rows": int(train_n),
                "prefer_hits": int(pref_hit),
                "prefer_rows_sum": int(pref_rows_sum),
                "prefer_detail": pref_detail,
                "dir_hits": int(dir_hit),
                "dir_detail": dir_detail,
                "mtime": float(mt),
            })

    # Highest score first; folder name makes ties deterministic.
    cands.sort(key=lambda x: (-x["score"], x["dir"].name))
    return cands[:max_return]

# ----------------------------
# Helper: detect DINO model dir (offline)
# ----------------------------
def detect_dino_dir() -> Path:
    """
    Return the offline DINOv2 weights directory.

    Picks the first existing /kaggle/input/dinov2/pytorch/{large,giant,base}/1.
    Otherwise returns the "large" path, which may not exist; callers check it.
    """
    fallback = Path("/kaggle/input/dinov2/pytorch/large/1")
    pytorch_root = Path("/kaggle/input/dinov2/pytorch")
    if not pytorch_root.exists():
        return fallback
    variants = (pytorch_root / size / "1" for size in ("large", "giant", "base"))
    return next((v for v in variants if v.exists()), fallback)

def detect_dino_cache_cfg(cache_dirs: list) -> Path:
    """
    Pick the best DINO feature cfg under ``<cache>/dino_v2_*/cfg_*/`` that has
    a ``manifest_train_all.csv``.

    Ranking (lowest key wins):
      1. cfgs whose manifest has >= 1 data row beat empty or unreadable ones;
      2. model family: large, then giant, then base;
      3. more manifest rows;
      4. newer cfg folder mtime;
      5. folder name (deterministic tie-break).

    Args:
        cache_dirs: cache root directories (Path or str).

    Returns:
        The selected cfg directory, or None if no candidate exists.
    """
    prio = {"dino_v2_large": 0, "dino_v2_giant": 1, "dino_v2_base": 2}
    best = None
    best_key = None  # tuple compared lexicographically

    for root in cache_dirs:
        root = Path(root)
        if not root.exists():
            continue

        for dino_name, family_rank in prio.items():
            dino_root = root / dino_name
            if not dino_root.exists():
                continue

            try:
                cfgs = list(dino_root.iterdir())
            except Exception:
                continue

            for cfg in cfgs:
                if not (cfg.is_dir() and cfg.name.startswith("cfg_")):
                    continue
                mf = cfg / "manifest_train_all.csv"
                if not mf.exists():
                    continue

                n = _fast_count_rows_csv(mf)
                mt = _safe_mtime(cfg)
                # An empty or unreadable manifest (n <= 0) must not beat a
                # populated one just because it belongs to a preferred family.
                key = (n <= 0, family_rank, -n, -mt, cfg.name)
                if best_key is None or key < best_key:
                    best = cfg
                    best_key = key

    return best  # can be None

# ============================================================
# 0) Locate roots
# ============================================================
# Competition data root (raw images/masks, sample_submission.csv).
COMP_ROOT = find_comp_root("/kaggle/input/recodai-luc-scientific-image-forgery-detection")
# Input dataset produced by the PREP stage, and its recodai_luc subfolder.
OUT_DS_ROOT = find_output_dataset_root()
OUT_ROOT = resolve_out_root(OUT_DS_ROOT)  # dataset input: .../recodai_luc

# Writable copy of the same layout under /kaggle/working.
WORK_OUT_ROOT = Path("/kaggle/working/recodai_luc")

# Cache roots: the working copy (if present) goes first, then the input dataset,
# so lookups that scan roots in order (e.g. _resolve_cfg_by_name) hit working first.
CACHE_ROOTS = []
if (WORK_OUT_ROOT / "cache").exists():
    CACHE_ROOTS.append(WORK_OUT_ROOT / "cache")
if (OUT_ROOT / "cache").exists():
    CACHE_ROOTS.append(OUT_ROOT / "cache")

# Artifact roots: same ordering rule (working first, then input).
ART_ROOTS = []
if (WORK_OUT_ROOT / "artifacts").exists():
    ART_ROOTS.append(WORK_OUT_ROOT / "artifacts")
if (OUT_ROOT / "artifacts").exists():
    ART_ROOTS.append(OUT_ROOT / "artifacts")

# choose first existing artifact root (working wins when present);
# ART_ROOTS already holds only existing paths, so the re-check is defensive.
ART_DIR = None
for a in ART_ROOTS:
    if a.exists():
        ART_DIR = a
        break
if ART_DIR is None:
    raise FileNotFoundError("ART_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# cache dirs that exist (same defensive filter); at least one is required
CACHE_DIRS = [p for p in CACHE_ROOTS if Path(p).exists()]
if not CACHE_DIRS:
    raise FileNotFoundError("CACHE_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# ============================================================
# 1) Competition paths (raw images/masks)
# ============================================================
# Central registry of path strings. Values are str, except CACHE_DIRS (list of str).
# Paths are only composed here; existence is checked later, in section 5.
PATHS = {}
PATHS["COMP_ROOT"] = str(COMP_ROOT)
PATHS["SAMPLE_SUB"] = str(COMP_ROOT / "sample_submission.csv")

# Raw competition folders (supplemental_* may not exist in every release; not checked here).
PATHS["TRAIN_IMAGES"] = str(COMP_ROOT / "train_images")
PATHS["TEST_IMAGES"]  = str(COMP_ROOT / "test_images")
PATHS["TRAIN_MASKS"]  = str(COMP_ROOT / "train_masks")
PATHS["SUPP_IMAGES"]  = str(COMP_ROOT / "supplemental_images")
PATHS["SUPP_MASKS"]   = str(COMP_ROOT / "supplemental_masks")

PATHS["TRAIN_AUTH_DIR"] = str(COMP_ROOT / "train_images" / "authentic")
PATHS["TRAIN_FORG_DIR"] = str(COMP_ROOT / "train_images" / "forged")

# ============================================================
# 2) Output dataset paths (clean artifacts + cache)
# ============================================================
PATHS["OUT_DS_ROOT"] = str(OUT_DS_ROOT)
PATHS["OUT_ROOT"]    = str(OUT_ROOT)

PATHS["ART_DIR"]     = str(ART_DIR)
PATHS["CACHE_DIRS"]  = [str(x) for x in CACHE_DIRS]  # list

# main artifacts (all under the selected ART_DIR)
PATHS["DF_TRAIN_ALL"] = str(Path(ART_DIR) / "df_train_all.parquet")
PATHS["DF_TRAIN_CLS"] = str(Path(ART_DIR) / "df_train_cls.parquet")
PATHS["DF_TRAIN_SEG"] = str(Path(ART_DIR) / "df_train_seg.parquet")
PATHS["DF_TEST"]      = str(Path(ART_DIR) / "df_test.parquet")

PATHS["CV_CASE_FOLDS"]   = str(Path(ART_DIR) / "cv_case_folds.csv")
PATHS["CV_SAMPLE_FOLDS"] = str(Path(ART_DIR) / "cv_sample_folds.csv")

PATHS["IMG_PROFILE_TRAIN"] = str(Path(ART_DIR) / "image_profile_train.parquet")
PATHS["IMG_PROFILE_TEST"]  = str(Path(ART_DIR) / "image_profile_test.parquet")
PATHS["MASK_PROFILE"]      = str(Path(ART_DIR) / "mask_profile.parquet")
PATHS["CASE_SUMMARY"]      = str(Path(ART_DIR) / "case_summary.parquet")

# ============================================================
# 3) Select best MATCH/PRED CFG dirs automatically (TOP-K + scoring)
# ============================================================
# MATCH candidates: must have match_features_train_all.csv.
# Preferred: match_features_test + manifests + match_summary + npz folders.
MATCH_PREFIXES = ["match_base_cfg_"]  # if a newer match family (e.g. v2) appears, add its prefix at the FRONT of this list
MATCH_CFG_INFO = pick_top_cfgs(
    CACHE_DIRS,
    prefixes=MATCH_PREFIXES,
    required_train_file="match_features_train_all.csv",
    prefer_files=[
        "match_features_test.csv",
        "manifest_match_test.csv",
        "manifest_match_train_all.csv",
        "match_summary.json",
    ],
    extra_prefer_dirs=["test", "train_all"],  # match cfgs do produce npz outputs, so prefer folders that have them
    max_return=max(1, TOPK_MATCH_CFGS),
)

if not MATCH_CFG_INFO:
    raise FileNotFoundError("Tidak menemukan match cfg folder yang valid (match_features_train_all.csv).")

# Ranked list (best first); the primary may later be replaced by LUC_FORCE_MATCH_CFG.
MATCH_CFG_DIRS = [x["dir"] for x in MATCH_CFG_INFO]
MATCH_CFG_DIR  = MATCH_CFG_DIRS[0]  # default primary

# PRED candidates: include the UNet+ASPP prefix with top priority (once it has been trained).
# Prefix order = priority: earlier entries are preferred.
PRED_PREFIXES = [
    "pred_unet_aspp_cfg_",   # UNet+ASPP training outputs (the preferred family)
    "pred_fusion_cfg_",      # future fusion/stacking outputs
    "pred_base",             # legacy fallback (no trailing "_cfg_", so it also matches e.g. pred_base_v3_v7_cfg_*)
]
PRED_CFG_INFO = pick_top_cfgs(
    CACHE_DIRS,
    prefixes=PRED_PREFIXES,
    required_train_file="pred_features_train_all.csv",
    prefer_files=[
        "pred_features_test.csv",
        "manifest_pred_test.csv",
        "manifest_pred_train_all.csv",
        "pred_summary.json",
    ],
    extra_prefer_dirs=["test", "train_all"],  # prefer having npz predictions
    max_return=max(1, TOPK_PRED_CFGS),
)

if not PRED_CFG_INFO:
    raise FileNotFoundError("Tidak menemukan pred cfg folder yang valid (pred_features_train_all.csv).")

# Ranked list (best first); the primary may later be replaced by LUC_FORCE_PRED_CFG.
PRED_CFG_DIRS = [x["dir"] for x in PRED_CFG_INFO]
PRED_CFG_DIR  = PRED_CFG_DIRS[0]  # default primary

# Optional FORCE override (primary only): the ranked MATCH_CFG_DIRS / PRED_CFG_DIRS
# lists are left untouched; only the primary *_CFG_DIR is replaced.
_force_match = _resolve_cfg_by_name(CACHE_DIRS, FORCE_MATCH_CFG_NAME) if FORCE_MATCH_CFG_NAME else None
_force_pred  = _resolve_cfg_by_name(CACHE_DIRS, FORCE_PRED_CFG_NAME) if FORCE_PRED_CFG_NAME else None

if _force_match is not None:
    print(f"FORCE: MATCH_CFG_DIR -> {_force_match.name}")
    MATCH_CFG_DIR = _force_match
elif FORCE_MATCH_CFG_NAME:
    # Requested but not found: report it instead of silently keeping the auto pick.
    print(f"WARNING: LUC_FORCE_MATCH_CFG '{FORCE_MATCH_CFG_NAME}' tidak ditemukan di CACHE_DIRS; tetap pakai {MATCH_CFG_DIR.name}")
if _force_pred is not None:
    print(f"FORCE: PRED_CFG_DIR  -> {_force_pred.name}")
    PRED_CFG_DIR = _force_pred
elif FORCE_PRED_CFG_NAME:
    print(f"WARNING: LUC_FORCE_PRED_CFG '{FORCE_PRED_CFG_NAME}' tidak ditemukan di CACHE_DIRS; tetap pakai {PRED_CFG_DIR.name}")

# DINO cache cfg (optional; None when no dino_v2_* cfg with a manifest exists)
DINO_CFG_DIR = detect_dino_cache_cfg(CACHE_DIRS)

# store the selected cfg dirs in PATHS ("" when DINO cfg is missing)
PATHS["MATCH_CFG_DIR"]    = str(MATCH_CFG_DIR)
PATHS["PRED_CFG_DIR"]     = str(PRED_CFG_DIR)
PATHS["DINO_CFG_DIR"]     = str(DINO_CFG_DIR) if DINO_CFG_DIR else ""

# multi-cfg (new, optional): ranked lists as computed above (not affected by FORCE)
PATHS["MATCH_CFG_DIRS"]   = [str(x) for x in MATCH_CFG_DIRS]
PATHS["PRED_CFG_DIRS"]    = [str(x) for x in PRED_CFG_DIRS]

# feature paths from the primary cfg
PATHS["MATCH_FEAT_TRAIN"] = str(MATCH_CFG_DIR / "match_features_train_all.csv")
PATHS["MATCH_FEAT_TEST"]  = str(MATCH_CFG_DIR / "match_features_test.csv")

PATHS["PRED_FEAT_TRAIN"]  = str(PRED_CFG_DIR / "pred_features_train_all.csv")
PATHS["PRED_FEAT_TEST"]   = str(PRED_CFG_DIR / "pred_features_test.csv")  # may be missing (only warned about)

# manifests/summary (commonly used for inference/submission)
PATHS["PRED_MAN_TRAIN"] = str(PRED_CFG_DIR / "manifest_pred_train_all.csv")
PATHS["PRED_MAN_TEST"]  = str(PRED_CFG_DIR / "manifest_pred_test.csv")
PATHS["PRED_SUMMARY"]   = str(PRED_CFG_DIR / "pred_summary.json")

PATHS["MATCH_MAN_TRAIN"] = str(MATCH_CFG_DIR / "manifest_match_train_all.csv")
PATHS["MATCH_MAN_TEST"]  = str(MATCH_CFG_DIR / "manifest_match_test.csv")
PATHS["MATCH_SUMMARY"]   = str(MATCH_CFG_DIR / "match_summary.json")

# shortcuts to npz folders (safe for later stages even if missing/empty)
PATHS["PRED_NPZ_TRAIN_DIR"]  = str(PRED_CFG_DIR / "train_all")
PATHS["PRED_NPZ_TEST_DIR"]   = str(PRED_CFG_DIR / "test")
PATHS["MATCH_NPZ_TRAIN_DIR"] = str(MATCH_CFG_DIR / "train_all")
PATHS["MATCH_NPZ_TEST_DIR"]  = str(MATCH_CFG_DIR / "test")

# ============================================================
# 4) DINO model dir (offline)
# ============================================================
# Offline DINOv2 weights dir (may not exist; only warned about below)
DINO_DIR = detect_dino_dir()
PATHS["DINO_DIR"] = str(DINO_DIR)

# ============================================================
# 5) Sanity checks (required for training)
# ============================================================
# (display name, path) pairs; any missing one aborts the stage.
must_exist = [
    ("sample_submission.csv", PATHS["SAMPLE_SUB"]),
    ("df_train_all.parquet",  PATHS["DF_TRAIN_ALL"]),
    ("cv_case_folds.csv",     PATHS["CV_CASE_FOLDS"]),
    ("match_features_train_all.csv", PATHS["MATCH_FEAT_TRAIN"]),
    ("pred_features_train_all.csv",  PATHS["PRED_FEAT_TRAIN"]),
]
missing = [name for name, p in must_exist if not Path(p).exists()]
if missing:
    raise FileNotFoundError("Missing required files: " + ", ".join(missing))

# Optional but important for inference: test feature files (warn, never crash).
opt_warn = [
    label
    for key, label in (
        ("MATCH_FEAT_TEST", "match_features_test.csv (MATCH_FEAT_TEST)"),
        ("PRED_FEAT_TEST", "pred_features_test.csv (PRED_FEAT_TEST)"),
    )
    if not Path(PATHS[key]).exists()
]
if opt_warn:
    print("WARNING: File opsional untuk inference belum ada:")
    for w in opt_warn:
        print(f" - {w}")
    print("Catatan: training masih aman (pakai *_train_all). Untuk inference gate ke test, file ini biasanya dibutuhkan.")

# npz dirs: missing/empty only produces a warning.
for _npz_key in ("PRED_NPZ_TRAIN_DIR", "PRED_NPZ_TEST_DIR", "MATCH_NPZ_TRAIN_DIR", "MATCH_NPZ_TEST_DIR"):
    if not _dir_has_any_npz(Path(PATHS[_npz_key])):
        print(f"WARNING: {_npz_key} tidak ada/empty: {PATHS[_npz_key]}")

# DINO model dir is optional as well (warning only).
if not Path(PATHS["DINO_DIR"]).exists():
    print(f"WARNING: DINO dir tidak ditemukan: {PATHS['DINO_DIR']}")

# ============================================================
# 6) Print summary + export helpers
# ============================================================
# Snapshot of what this stage resolved (paths as str, plus the knobs that drove the choice).
SELECTED = {
    "ART_DIR": str(ART_DIR),
    "CACHE_DIRS": [str(x) for x in CACHE_DIRS],
    "MATCH_CFG_DIR": str(MATCH_CFG_DIR),
    "PRED_CFG_DIR": str(PRED_CFG_DIR),
    "DINO_CFG_DIR": str(DINO_CFG_DIR) if DINO_CFG_DIR else "",
    "DINO_DIR": str(DINO_DIR),
    "TOPK_MATCH_CFGS": TOPK_MATCH_CFGS,
    "TOPK_PRED_CFGS": TOPK_PRED_CFGS,
    "FORCE_MATCH_CFG_NAME": FORCE_MATCH_CFG_NAME,
    "FORCE_PRED_CFG_NAME": FORCE_PRED_CFG_NAME,
    "PRED_PREFIXES": PRED_PREFIXES,
    "MATCH_PREFIXES": MATCH_PREFIXES,
}

# Training plan: prepared for later upgrades (stacking + calibration + threshold tuning).
# NOTE(review): these values are only declared here; how later stages consume them is not visible.
TRAIN_PLAN = {
    "seed": 2025,
    "group_col": "case_id",
    "target_col": "y_forged",
    "n_folds": 5,
    "use_calibration": True,
    "calibration": "isotonic",
    "tune_threshold_on_oof": True,
    "multi_cfg": {
        "enabled": True,
        "topk_match_cfgs": TOPK_MATCH_CFGS,
        "topk_pred_cfgs": TOPK_PRED_CFGS,
        "primary_match": MATCH_CFG_DIR.name,
        "primary_pred": PRED_CFG_DIR.name,
        "pred_prefix_priority": PRED_PREFIXES,
        "match_prefix_priority": MATCH_PREFIXES,
    },
}

def _print_top_cfg_table(title: str, info_list: list, max_rows: int = 5):
    print(title)
    if not info_list:
        print("  (none)")
        return
    show = info_list[:max_rows]
    for i, x in enumerate(show, 1):
        d = x["dir"]
        tr = x["train_rows"]
        ph = x["prefer_hits"]
        dh = x["dir_hits"]
        px = x.get("prefix", "")
        print(f"  #{i:02d} {d.name} | px={px} | score={x['score']:.1f} | train_rows={tr} | prefer_hits={ph} | dir_hits={dh}")

# Human-readable summary of everything resolved above.
print("OK — Roots")
print("  COMP_ROOT   :", COMP_ROOT)
print("  OUT_DS_ROOT :", OUT_DS_ROOT)
print("  OUT_ROOT    :", OUT_ROOT)
print("  ART_DIR(use):", ART_DIR)
print("  CACHE_DIRS  :", [str(x) for x in CACHE_DIRS])

print("\nOK — Selected CFG (PRIMARY)")
print("  MATCH_CFG_DIR:", MATCH_CFG_DIR.name)
print("  PRED_CFG_DIR :", PRED_CFG_DIR.name)
print("  DINO_CFG_DIR :", (DINO_CFG_DIR.name if DINO_CFG_DIR else "(not found / optional)"))

print("\nOK — Top candidates (for MULTI-CFG training)")
_print_top_cfg_table("MATCH CFG TOP:", MATCH_CFG_INFO, max_rows=min(5, len(MATCH_CFG_INFO)))
_print_top_cfg_table("PRED  CFG TOP:", PRED_CFG_INFO,  max_rows=min(5, len(PRED_CFG_INFO)))

print("\nOK — Key files (train)")
for k in ["DF_TRAIN_ALL", "CV_CASE_FOLDS", "MATCH_FEAT_TRAIN", "PRED_FEAT_TRAIN", "IMG_PROFILE_TRAIN"]:
    p = Path(PATHS[k])
    print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing/optional)'}")

print("\nOK — Key files (test/infer, optional)")
for k in ["MATCH_FEAT_TEST", "PRED_FEAT_TEST", "PRED_MAN_TEST", "PRED_SUMMARY", "MATCH_MAN_TEST", "MATCH_SUMMARY"]:
    p = Path(PATHS.get(k, ""))
    # A key absent from PATHS yields Path("") which stringifies to "."
    if str(p) == ".":
        print(f"  {k:16s}: (unset)")
    else:
        print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing)'}")

print("\nOK — NPZ dirs (optional, for mask-based features / submission)")
for k in ["PRED_NPZ_TRAIN_DIR", "PRED_NPZ_TEST_DIR", "MATCH_NPZ_TRAIN_DIR", "MATCH_NPZ_TEST_DIR"]:
    p = Path(PATHS[k])
    ok = _dir_has_any_npz(p)
    # NOTE: some keys are longer than 16 chars, so this column is not perfectly aligned (cosmetic).
    print(f"  {k:16s}: {p}  {'(npz found)' if ok else '(empty/missing)'}")

print("\nOK — DINO model dir")
print("  DINO_DIR:", DINO_DIR, "(exists)" if DINO_DIR.exists() else "(missing)")

# export globals (compatibility): re-binds module-level names for later cells.
# CACHE_ROOTS is replaced by the existing cache dirs as Path objects.
globals().update({
    "MATCH_CFG_DIR": MATCH_CFG_DIR,
    "PRED_CFG_DIR": PRED_CFG_DIR,
    "DINO_CFG_DIR": DINO_CFG_DIR,
    "CACHE_ROOTS": [Path(x) for x in CACHE_DIRS],
    "SELECTED": SELECTED,
    "TRAIN_PLAN": TRAIN_PLAN,

    # extra (multi-cfg)
    "MATCH_CFG_DIRS": MATCH_CFG_DIRS,
    "PRED_CFG_DIRS": PRED_CFG_DIRS,
    "MATCH_CFG_INFO": MATCH_CFG_INFO,
    "PRED_CFG_INFO": PRED_CFG_INFO,
})


OK — Roots
  COMP_ROOT   : /kaggle/input/recodai-luc-scientific-image-forgery-detection
  OUT_DS_ROOT : /kaggle/input/recod-ailuc-dinov2-base
  OUT_ROOT    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc
  ART_DIR(use): /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts
  CACHE_DIRS  : ['/kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache']

OK — Selected CFG (PRIMARY)
  MATCH_CFG_DIR: match_base_cfg_f9f7ea3a65c5
  PRED_CFG_DIR : pred_base_v3_v7_cfg_5dbf0aa165
  DINO_CFG_DIR : cfg_3246fd54aab0

OK — Top candidates (for MULTI-CFG training)
MATCH CFG TOP:
  #01 match_base_cfg_f9f7ea3a65c5 | px=match_base_cfg_ | score=4407205.8 | train_rows=5176 | prefer_hits=4 | dir_hits=2
PRED  CFG TOP:
  #01 pred_base_v3_v7_cfg_5dbf0aa165 | px=pred_base | score=7407206.3 | train_rows=5176 | prefer_hits=4 | dir_hits=2

OK — Key files (train)
  DF_TRAIN_ALL    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet  (exists)
  CV_CASE_FOLDS   : /kaggle/input/re

# Build Training Table (X, y, folds)

In [None]:
# ============================================================
# STEP 2 — Build Training Table (X, y, folds) — REVISI FULL v3.2
# FIX utama v3.2:
# - NPZ overlap (pred vs match) dibuat SAFE: handle shape mismatch / mask mini / dim aneh
#   -> tidak lagi crash "operands could not be broadcast together"
# Output & flow lain TIDAK diubah.
# ============================================================

import os, re, json, gc, warnings, hashlib
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require PATHS
# ----------------------------
if not isinstance(globals().get("PATHS"), dict):
    raise RuntimeError("Missing PATHS. Jalankan dulu STAGE 0 — Set Paths & Select Config (CFG).")

# ----------------------------
# 1) Feature Engineering Config
# ----------------------------
# Feature-engineering switches for STEP 2. Only the values are declared here;
# how each key is consumed lives further down the notebook (not visible in this chunk).
FE_CFG = {
    # feature groups to include
    "use_match_features": True,
    "use_image_profile": True,

    # pairwise features from NPZ mask outputs (pred vs match overlap, per the cell header)
    "use_npz_pair_features": True,
    "npz_downsample_for_cc": 256,  # presumably target size for connected-component analysis — confirm
    "npz_bin_thr": 0.5,  # presumably binarisation threshold for mask probabilities — confirm
    "npz_max_rows": int(os.environ.get("NPZ_MAX_ROWS", "0")),  # env-overridable; 0 likely means "no cap" — confirm
    "npz_cache_enabled": True,

    # multi-cfg: how many extra PRED/MATCH cfgs to use (env-overridable)
    "multi_cfg_enabled": True,
    "multi_cfg_max_pred": int(os.environ.get("MULTI_CFG_MAX_PRED", "6")),
    "multi_cfg_max_match": int(os.environ.get("MULTI_CFG_MAX_MATCH", "3")),
    "multi_cfg_extra_mode": "core+agg",

    # categorical encoding of the variant column
    "encode_variant_onehot": True,
    "variant_min_count": 1,

    # derived numeric features
    "add_log_features": True,
    "add_sqrt_features": False,
    "add_interactions": True,
    "add_missing_indicators": True,

    # outlier clipping
    "clip_by_quantile": True,
    "clip_q": 0.999,
    "clip_max_fallback": 1e9,

    "fillna_value": 0.0,

    # final matrix cleanup
    "drop_constant_features": True,
    "cast_float32": True,

    # label handling
    "drop_unlabeled": True,
    "positive_value": 1,
}

# ----------------------------
# 2) Prefer WORKING features if exist
# ----------------------------
def _prefer_existing(*paths):
    for p in paths:
        if p is None:
            continue
        p = Path(str(p))
        if str(p).strip() == "":
            continue
        if p.exists():
            return p
    for p in paths:
        if p is None:
            continue
        p = Path(str(p))
        if str(p).strip() != "":
            return p
    return Path("")

def _to_str_series(s: pd.Series) -> pd.Series:
    return s.astype(str).replace({"nan": "", "None": ""})

def _ensure_uid(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of *df* guaranteed to have a string ``uid`` column.

    If ``uid`` is absent, the first present column among ``sample_id``, ``id``
    and ``key`` is renamed to ``uid``. Values are normalised with _to_str_series.

    Raises:
        ValueError: none of the accepted id columns exist.
    """
    out = df.copy()
    if "uid" not in out.columns:
        alias = next((c for c in ("sample_id", "id", "key") if c in out.columns), None)
        if alias is None:
            raise ValueError("Cannot find uid column. Expected 'uid' or 'sample_id'.")
        out = out.rename(columns={alias: "uid"})
    out["uid"] = _to_str_series(out["uid"])
    return out

def _parse_case_variant_from_uid(uid_s: pd.Series) -> pd.DataFrame:
    """
    Derive ``case_id`` / ``variant`` columns from uid strings.

    The primary pattern is ``<digits>__<variant>``; the fallback is
    ``<digits>_<word>``. case_id is the digit string, or NaN when
    unparseable; an unparseable variant becomes "unk".
    """
    uid = _to_str_series(uid_s)

    def _grab(pattern):
        return uid.str.extract(pattern)[0]

    case_id = _grab(r"^(\d+)__").fillna(_grab(r"^(\d+)_"))
    variant = _grab(r"__(.+)$").fillna(_grab(r"_(\w+)$")).fillna("unk")
    return pd.DataFrame({"case_id": case_id, "variant": variant})

def _ensure_case_variant(df: pd.DataFrame, df_base_map: pd.DataFrame = None) -> pd.DataFrame:
    """
    Return a copy of *df* with normalised ``uid``, ``case_id`` and ``variant`` columns.

    Resolution order:
      1. existing columns: case_id is coerced to numeric, and variant to str
         with "nan"/"None" mapped to "unk";
      2. if *df_base_map* has uid/case_id/variant and df lacks case_id or
         variant (or has NaN case_id), left-merge on uid and fill from it;
      3. if still incomplete, parse from the uid (_parse_case_variant_from_uid).
    Finally case_id is cast to nullable Int64 and variant to str.

    Raises:
        ValueError: from _ensure_uid when no uid-like column exists.

    NOTE(review): step 1 turns missing variants into "unk" before steps 2-3,
    and those steps only replace empty-string variants, so an "unk" variant is
    never refilled from df_base_map or the uid. Confirm that is intended.
    NOTE(review): the left merge duplicates rows if df_base_map repeats a uid;
    presumably uids are unique there. Verify.
    """
    df = _ensure_uid(df)

    # Step 1: normalise whatever columns already exist.
    if "case_id" in df.columns:
        df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
    if "variant" in df.columns:
        df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})

    # Step 2: fill gaps from the base mapping (suffix "_base" only appears for columns df already had).
    if df_base_map is not None and {"uid", "case_id", "variant"}.issubset(df_base_map.columns):
        need_merge = ("case_id" not in df.columns) or ("variant" not in df.columns) or df["case_id"].isna().any()
        if need_merge:
            df = df.merge(df_base_map[["uid", "case_id", "variant"]], on="uid", how="left", suffixes=("", "_base"))
            if "case_id_base" in df.columns:
                df["case_id"] = df["case_id"].fillna(df["case_id_base"])
                df = df.drop(columns=["case_id_base"])
            if "variant_base" in df.columns:
                df["variant"] = df["variant"].where(df["variant"].astype(str).str.len() > 0, df["variant_base"])
                df = df.drop(columns=["variant_base"])

    # Step 3: last resort, parse case/variant out of the uid string.
    if ("case_id" not in df.columns) or ("variant" not in df.columns) or df["case_id"].isna().any():
        pv = _parse_case_variant_from_uid(df["uid"])
        if "case_id" not in df.columns:
            df["case_id"] = pd.to_numeric(pv["case_id"], errors="coerce")
        else:
            df["case_id"] = df["case_id"].fillna(pd.to_numeric(pv["case_id"], errors="coerce"))
        if "variant" not in df.columns:
            df["variant"] = pv["variant"]
        else:
            v = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
            df["variant"] = v.where(v.str.len() > 0, pv["variant"])

    # Final dtypes: nullable integer case_id, string variant.
    df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce").astype("Int64")
    df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
    return df

def _pick_label_col(df: pd.DataFrame) -> str:
    for cand in ["y_forged", "has_mask", "is_forged", "forged"]:
        if cand in df.columns:
            return cand
    return ""

def _slug(s: str) -> str:
    s = str(s)
    s = re.sub(r"[^0-9a-zA-Z_]+", "_", s)
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:80] if len(s) > 80 else s

def _read_parquet_cols_safe(path: Path, desired_cols: list) -> pd.DataFrame:
    """Read only the wanted columns of a parquet file, tolerating absent ones.

    The schema is inspected with pyarrow so that only columns that actually exist
    are requested. If none of ``desired_cols`` exist, the whole file is returned.
    On any failure (e.g. pyarrow unavailable), the full file is read with pandas and
    then subset to whichever desired columns are present.
    """
    path = Path(path)
    try:
        import pyarrow.parquet as pq
        available = set(pq.ParquetFile(path).schema.names)
        wanted = [col for col in desired_cols if col in available]
        return pd.read_parquet(path, columns=wanted) if wanted else pd.read_parquet(path)
    except Exception:
        frame = pd.read_parquet(path)
        wanted = [col for col in desired_cols if col in frame.columns]
        if not wanted:
            return frame
        return frame[wanted].copy()

# ----------------------------
# Resolve primary paths
# ----------------------------
match_cfg_name = Path(PATHS.get("MATCH_CFG_DIR", "")).name if PATHS.get("MATCH_CFG_DIR") else ""
pred_cfg_name  = Path(PATHS.get("PRED_CFG_DIR", "")).name  if PATHS.get("PRED_CFG_DIR") else ""

# Feature CSVs regenerated in this session (under /kaggle/working) take priority
# over the copies shipped in the input dataset (chosen by _prefer_existing).
WORK_CACHE_ROOT = Path("/kaggle/working/recodai_luc/cache")
match_feat_work = (WORK_CACHE_ROOT / match_cfg_name / "match_features_train_all.csv") if match_cfg_name else None
pred_feat_work  = (WORK_CACHE_ROOT / pred_cfg_name  / "pred_features_train_all.csv")  if pred_cfg_name  else None

PRED_FEAT_TRAIN  = _prefer_existing(pred_feat_work,  PATHS.get("PRED_FEAT_TRAIN", ""))
MATCH_FEAT_TRAIN = _prefer_existing(match_feat_work, PATHS.get("MATCH_FEAT_TRAIN", ""))

DF_TRAIN_ALL      = Path(PATHS["DF_TRAIN_ALL"])
CV_CASE_FOLDS     = Path(PATHS["CV_CASE_FOLDS"])
IMG_PROFILE_TRAIN = Path(PATHS.get("IMG_PROFILE_TRAIN") or "")

# Use is_file(), not exists(): an unset path becomes Path("") == Path("."), and the
# current directory always "exists", so exists() would silently accept a missing input.
for need_name, need_path in [
    ("df_train_all.parquet", DF_TRAIN_ALL),
    ("cv_case_folds.csv", CV_CASE_FOLDS),
    ("pred_features_train_all.csv", PRED_FEAT_TRAIN),
]:
    if not Path(need_path or "").is_file():
        raise FileNotFoundError(f"Missing required file: {need_name} -> {need_path}")

print("Using (PRIMARY):")
print("  DF_TRAIN_ALL     :", DF_TRAIN_ALL)
print("  CV_CASE_FOLDS    :", CV_CASE_FOLDS)
print("  PRED_FEAT_TRAIN  :", PRED_FEAT_TRAIN)
print("  MATCH_FEAT_TRAIN :", MATCH_FEAT_TRAIN, "(optional)" if Path(MATCH_FEAT_TRAIN or "").is_file() else "(missing/skip)")
print("  IMG_PROFILE_TRAIN:", IMG_PROFILE_TRAIN, "(optional)" if IMG_PROFILE_TRAIN.is_file() else "(missing/skip)")

# NPZ output folders (per-sample masks/probs); default to <cfg_dir>/train_all.
PRED_NPZ_TRAIN_DIR  = Path(PATHS.get("PRED_NPZ_TRAIN_DIR", str(Path(PATHS.get("PRED_CFG_DIR", "")) / "train_all")))
MATCH_NPZ_TRAIN_DIR = Path(PATHS.get("MATCH_NPZ_TRAIN_DIR", str(Path(PATHS.get("MATCH_CFG_DIR", "")) / "train_all")))

# ----------------------------
# 3) Load minimal inputs
# ----------------------------
base_cols_want = ["sample_id", "uid", "case_id", "variant", "y_forged", "has_mask", "is_forged", "forged"]
df_base = _read_parquet_cols_safe(DF_TRAIN_ALL, base_cols_want)
df_cv   = pd.read_csv(CV_CASE_FOLDS)

df_pred_primary = pd.read_csv(PRED_FEAT_TRAIN, low_memory=False)

# Optional inputs. is_file() (not exists()) because an unset path is Path("") == ".",
# which exists; reading "." could pick up unrelated files from the working directory.
df_match_primary = None
if FE_CFG["use_match_features"] and MATCH_FEAT_TRAIN and Path(MATCH_FEAT_TRAIN).is_file():
    try:
        df_match_primary = pd.read_csv(MATCH_FEAT_TRAIN, low_memory=False)
    except Exception as e:
        print(f"WARNING: could not read MATCH_FEAT_TRAIN ({MATCH_FEAT_TRAIN}): {e} -> match features skipped")
        df_match_primary = None

df_prof = None
if FE_CFG["use_image_profile"] and IMG_PROFILE_TRAIN and Path(IMG_PROFILE_TRAIN).is_file():
    try:
        df_prof = pd.read_parquet(IMG_PROFILE_TRAIN)
    except Exception as e:
        print(f"WARNING: could not read IMG_PROFILE_TRAIN ({IMG_PROFILE_TRAIN}): {e} -> profile features skipped")
        df_prof = None

# ----------------------------
# 4) Prepare base mapping from df_train_all
# ----------------------------
# Defensive copy before mutating columns of the frame read from parquet.
df_base = df_base.copy()
# Derive a "uid" join key when df_train_all has none: prefer renaming sample_id,
# otherwise build "<case_id>__<variant>" (presumably the format used by the feature
# CSVs — NOTE(review): confirm against _ensure_uid / the feature writers).
if "uid" not in df_base.columns:
    if "sample_id" in df_base.columns:
        df_base = df_base.rename(columns={"sample_id": "uid"})
    elif ("case_id" in df_base.columns and "variant" in df_base.columns):
        df_base["uid"] = _to_str_series(df_base["case_id"]) + "__" + _to_str_series(df_base["variant"])

df_base = _ensure_uid(df_base)

# Normalise key dtypes: case_id -> nullable Int64; variant -> str with "nan"/"None" -> "unk".
if "case_id" in df_base.columns:
    df_base["case_id"] = pd.to_numeric(df_base["case_id"], errors="coerce").astype("Int64")
if "variant" in df_base.columns:
    df_base["variant"] = df_base["variant"].astype(str).replace({"nan": "unk", "None": "unk"})

# First available of y_forged / has_mask / is_forged / forged; a label is mandatory.
label_col = _pick_label_col(df_base)
if not label_col:
    raise ValueError("Cannot find label column in df_train_all (y_forged/has_mask/is_forged/forged).")

# One row per uid (first occurrence wins): lookup table for labels and case/variant backfill.
df_base_map = df_base.drop_duplicates(subset=["uid"], keep="first").copy()

# ----------------------------
# 5) Prepare folds
# ----------------------------
# cv_case_folds.csv maps each case_id to its CV fold; both columns are required.
if "case_id" not in df_cv.columns or "fold" not in df_cv.columns:
    raise ValueError("cv_case_folds.csv must contain columns: case_id, fold")

df_cv = df_cv[["case_id", "fold"]].copy()
# Coerce via nullable Int64 so unparseable entries become <NA> (dropped below).
df_cv["case_id"] = pd.to_numeric(df_cv["case_id"], errors="coerce").astype("Int64")
df_cv["fold"]    = pd.to_numeric(df_cv["fold"], errors="coerce").astype("Int64")
# Drop incomplete rows, cast to plain int, and keep the first fold seen per case_id
# so the later case_id merge cannot duplicate training rows.
df_cv = df_cv.dropna().astype({"case_id": int, "fold": int}).drop_duplicates("case_id")

# ----------------------------
# 6) Start from PRIMARY pred features
# ----------------------------
# One row per uid from the primary pred-feature table (first occurrence wins).
df_pred_primary = _ensure_case_variant(df_pred_primary, df_base_map=df_base_map)
if df_pred_primary["uid"].duplicated().any():
    df_pred_primary = df_pred_primary.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

df_train = df_pred_primary.copy()

# Attach the label as "y". The merge indicator separates two cases that a plain
# NaN check conflates:
#   - uid absent from df_train_all -> misalignment, always fatal;
#   - uid present but label missing -> handled by FE_CFG["drop_unlabeled"] below
#     (previously this raised before drop_unlabeled could ever act on it).
df_train = df_train.merge(
    df_base_map[["uid", label_col]].rename(columns={label_col: "y"}),
    on="uid", how="left", indicator="_y_merge",
)
_n_label_unmatched = int((df_train["_y_merge"] == "left_only").sum())
df_train = df_train.drop(columns=["_y_merge"])
if _n_label_unmatched:
    miss = _n_label_unmatched
    raise ValueError(f"Label merge produced NaN in y: {miss} rows. Check df_train_all vs pred_features alignment.")

df_train["y"] = pd.to_numeric(df_train["y"], errors="coerce")

if FE_CFG["drop_unlabeled"]:
    before = len(df_train)
    df_train = df_train[df_train["y"].isin([0, 1])].copy()
    after = len(df_train)
    if before != after:
        print(f"NOTE: Dropped unlabeled rows (y not in {{0,1}}): {before-after} rows")

# Anything still missing/non-numeric cannot be cast to int.
if df_train["y"].isna().any():
    miss = int(df_train["y"].isna().sum())
    raise ValueError(
        f"Label column '{label_col}' is missing/non-numeric for {miss} rows; "
        "enable FE_CFG['drop_unlabeled'] or fix df_train_all."
    )

df_train["y"] = df_train["y"].astype(int)

# Attach CV folds by case_id; every training row must belong to a fold.
df_train = df_train.drop(columns=["fold"], errors="ignore").merge(df_cv, on="case_id", how="left")
if df_train["fold"].isna().any():
    miss = int(df_train["fold"].isna().sum())
    raise ValueError(f"Missing fold after merging cv_case_folds.csv: {miss} rows.")
df_train["fold"] = df_train["fold"].astype(int)

# ----------------------------
# 7) Optional merge PRIMARY match features
# ----------------------------
# Optional: join the primary match-feature table onto df_train by uid.
if df_match_primary is not None:
    dfm = _ensure_case_variant(df_match_primary, df_base_map=df_base_map)
    # One row per uid so the left merge below cannot duplicate training rows.
    if dfm["uid"].duplicated().any():
        dfm = dfm.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

    # Only columns not already in df_train are added: a match column whose name
    # collides with an existing (pred) column is skipped, i.e. the pred value wins.
    base_cols = set(df_train.columns)
    new_cols = [c for c in dfm.columns if c not in base_cols and c not in ["case_id", "variant"]]
    if new_cols:
        df_train = df_train.merge(dfm[["uid"] + new_cols], on="uid", how="left")

# ----------------------------
# 7.5) Aliasing
# ----------------------------
def _coalesce_numeric(df: pd.DataFrame, out_col: str, candidates: list):
    if out_col in df.columns:
        return
    for c in candidates:
        if c in df.columns:
            df[out_col] = pd.to_numeric(df[c], errors="coerce")
            return

# Canonical feature names: each call creates the first argument as a numeric copy
# of the first listed source column present in df_train, unless that canonical
# column already exists (then it is left untouched).
# Match/pairing statistics:
_coalesce_numeric(df_train, "area_frac", ["area_frac", "pred_area_frac", "mask_area_frac", "area_frac_pred", "area_frac_unet", "unet_area_frac"])
_coalesce_numeric(df_train, "grid_area_frac", ["grid_area_frac", "grid_mask_area_frac", "grid_area"])
_coalesce_numeric(df_train, "best_count", ["best_count", "n_pairs", "pairs", "pair_count"])
_coalesce_numeric(df_train, "best_mean_sim", ["best_mean_sim", "mean_sim", "sim_mean"])
_coalesce_numeric(df_train, "peak_ratio", ["peak_ratio", "peak_to_mean", "peak_over_mean"])
_coalesce_numeric(df_train, "inlier_ratio", ["inlier_ratio", "ransac_inlier_ratio", "inliers_ratio"])
_coalesce_numeric(df_train, "has_peak", ["has_peak", "peak_found", "has_mode"])

# Predicted-mask shape / confidence statistics ("*_pred" canonical names):
_coalesce_numeric(df_train, "n_cc_pred", ["n_cc_pred", "n_cc", "n_components", "num_components", "n_comp", "pred_n_cc"])
_coalesce_numeric(df_train, "largest_cc_frac_pred", ["largest_cc_frac_pred", "largest_cc_frac", "largest_comp_frac"])
_coalesce_numeric(df_train, "mean_prob_inside_pred", ["mean_prob_inside_pred", "mean_prob_inside", "mean_inside_prob"])
_coalesce_numeric(df_train, "p90_prob_inside_pred", ["p90_prob_inside_pred", "p90_prob_inside", "p90_inside_prob"])
_coalesce_numeric(df_train, "max_prob_pred", ["max_prob_pred", "max_prob", "pred_max_prob"])

# ----------------------------
# 8) MULTI-CFG extras
# ----------------------------
def _short_cfg_tag(cfg_dir: Path, idx: int) -> str:
    nm = cfg_dir.name
    m = re.search(r"(cfg_[0-9a-f]{6,})", nm)
    tag = m.group(1) if m else nm
    tag = tag.replace("pred_unet_aspp_cfg_", "pu_").replace("pred_fusion_cfg_", "pf_").replace("pred_base_", "pb_")
    tag = tag.replace("match_base_cfg_", "m_")
    tag = re.sub(r"[^0-9a-zA-Z_]+", "_", tag)
    return f"{idx:02d}_{tag[:24]}"

# Core numeric columns taken from each EXTRA pred cfg's pred_features_train_all.csv
# (only those present in the file are kept; see _load_csv_core). Each is merged as
# "<col>__<cfg tag>" and, in "core+agg" mode, summarised across cfgs.
CORE_PRED_COLS = [
    "best_count","best_mean_sim","peak_ratio","best_weight","best_weight_frac","inlier_ratio","n_pairs_thr","n_pairs_mnn","has_peak",
    "grid_h","grid_w","grid_area_frac",
    "area_frac","pred_area_frac","mask_area_frac","n_cc","n_comp","n_components","largest_cc_frac","mean_prob_inside","p90_prob_inside","max_prob",
]
# Same role for EXTRA match cfgs (match_features_train_all.csv); there are no
# probability columns here.
CORE_MATCH_COLS = [
    "best_count","best_mean_sim","peak_ratio","best_weight","best_weight_frac","inlier_ratio","n_pairs_thr","n_pairs_mnn","has_peak",
    "grid_h","grid_w","grid_area_frac",
    "area_frac","mask_area_frac","n_cc","n_comp","largest_cc_frac",
]

def _load_csv_core(csv_path: Path, core_cols: list) -> pd.DataFrame:
    """Read a feature CSV and keep ``uid`` plus whichever of ``core_cols`` it has.

    The whole file is read first because ``_ensure_uid`` may need other columns
    to build ``uid``. The result is a copy.
    """
    frame = _ensure_uid(pd.read_csv(csv_path, low_memory=False))
    present = [col for col in core_cols if col in frame.columns]
    return frame[["uid", *present]].copy()

# Candidate cfg dirs for multi-cfg features (only when FE_CFG["multi_cfg_enabled"]).
# `or []` also tolerates a key that is present but set to None.
pred_cfg_dirs = [Path(x) for x in (PATHS.get("PRED_CFG_DIRS") or [])] if FE_CFG["multi_cfg_enabled"] else []
match_cfg_dirs = [Path(x) for x in (PATHS.get("MATCH_CFG_DIRS") or [])] if FE_CFG["multi_cfg_enabled"] else []

# The primary pred cfg is required. The primary match cfg is optional (match features
# can be disabled, and other code reads it via PATHS.get), so indexing
# PATHS["MATCH_CFG_DIR"] unconditionally could raise KeyError.
_primary_pred_cfg_dir = Path(PATHS["PRED_CFG_DIR"])
_primary_match_cfg_dir = Path(PATHS["MATCH_CFG_DIR"]) if PATHS.get("MATCH_CFG_DIR") else None

# Keep existing dirs only, truncated to the configured maximum (at least 1).
if pred_cfg_dirs:
    pred_cfg_dirs = [d for d in pred_cfg_dirs if d.exists()]
    pred_cfg_dirs = pred_cfg_dirs[:max(1, FE_CFG["multi_cfg_max_pred"])]
else:
    pred_cfg_dirs = [_primary_pred_cfg_dir]

if match_cfg_dirs:
    match_cfg_dirs = [d for d in match_cfg_dirs if d.exists()]
    match_cfg_dirs = match_cfg_dirs[:max(1, FE_CFG["multi_cfg_max_match"])]
else:
    match_cfg_dirs = (
        [_primary_match_cfg_dir]
        if (FE_CFG["use_match_features"] and _primary_match_cfg_dir is not None)
        else []
    )

# "Extra" cfgs = every selected cfg except the primary one (already merged above).
extra_pred_cfgs = [d for d in pred_cfg_dirs if str(d) != str(_primary_pred_cfg_dir)]
extra_match_cfgs = [
    d for d in match_cfg_dirs
    if _primary_match_cfg_dir is None or str(d) != str(_primary_match_cfg_dir)
]

pred_core_cols_added = []
pred_core_matrix_cols = {}  # base column name -> list of its per-cfg suffixed copies

for i, cfg_dir in enumerate(extra_pred_cfgs, 1):
    fp = cfg_dir / "pred_features_train_all.csv"
    if not fp.exists():
        continue
    try:
        tag = _short_cfg_tag(cfg_dir, i)
        df_extra = _load_csv_core(fp, CORE_PRED_COLS)
        df_extra = _ensure_case_variant(df_extra, df_base_map=df_base_map)
        # One row per uid: a duplicated uid would multiply df_train rows in the left merge.
        df_extra = df_extra.drop_duplicates(subset=["uid"], keep="first")
        # Suffix only the core feature columns. uid is the join key; case_id/variant are
        # added by _ensure_case_variant and must not become per-cfg identifier "features"
        # (nor feed the cfg_* aggregates below).
        ren = {c: f"{c}__{tag}" for c in df_extra.columns if c not in ("uid", "case_id", "variant")}
        df_extra = df_extra.rename(columns=ren)
        df_train = df_train.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
        for src_col, new_col in ren.items():
            pred_core_matrix_cols.setdefault(src_col, []).append(new_col)
        pred_core_cols_added.extend(ren.values())
    except Exception as e:
        print(f"WARNING: skip extra pred cfg {cfg_dir.name} due to error: {e}")

# Cross-cfg row-wise summaries (primary column + every extra copy, NaNs skipped);
# only computed when at least two sources exist for that base column.
if FE_CFG["multi_cfg_extra_mode"] == "core+agg":
    for base_name, cols_suff in pred_core_matrix_cols.items():
        cols_all = []
        if base_name in df_train.columns:
            cols_all.append(base_name)
        cols_all.extend([c for c in cols_suff if c in df_train.columns])
        cols_all = [c for c in cols_all if c in df_train.columns]
        if len(cols_all) < 2:
            continue
        mat = df_train[cols_all].apply(pd.to_numeric, errors="coerce")
        df_train[f"cfg_mean_{base_name}"] = mat.mean(axis=1, skipna=True)
        df_train[f"cfg_max_{base_name}"]  = mat.max(axis=1, skipna=True)
        df_train[f"cfg_min_{base_name}"]  = mat.min(axis=1, skipna=True)
        df_train[f"cfg_std_{base_name}"]  = mat.std(axis=1, skipna=True)

match_core_cols_added = []
match_core_matrix_cols = {}  # base column name -> list of its per-cfg suffixed copies

if FE_CFG["use_match_features"]:
    for i, cfg_dir in enumerate(extra_match_cfgs, 1):
        fp = cfg_dir / "match_features_train_all.csv"
        if not fp.exists():
            continue
        try:
            tag = _short_cfg_tag(cfg_dir, i)
            df_extra = _load_csv_core(fp, CORE_MATCH_COLS)
            df_extra = _ensure_case_variant(df_extra, df_base_map=df_base_map)
            # One row per uid: a duplicated uid would multiply df_train rows in the left merge.
            df_extra = df_extra.drop_duplicates(subset=["uid"], keep="first")
            # Suffix only the core feature columns; uid is the join key and case_id/variant
            # (added by _ensure_case_variant) must not leak in as per-cfg identifier copies.
            ren = {c: f"{c}__{tag}" for c in df_extra.columns if c not in ("uid", "case_id", "variant")}
            df_extra = df_extra.rename(columns=ren)
            df_train = df_train.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
            for src_col, new_col in ren.items():
                match_core_matrix_cols.setdefault(src_col, []).append(new_col)
            match_core_cols_added.extend(ren.values())
        except Exception as e:
            print(f"WARNING: skip extra match cfg {cfg_dir.name} due to error: {e}")

    # Cross-cfg row-wise summaries, prefixed cfg_*_match_ to keep them apart from pred ones.
    if FE_CFG["multi_cfg_extra_mode"] == "core+agg":
        for base_name, cols_suff in match_core_matrix_cols.items():
            cols_all = []
            if base_name in df_train.columns:
                cols_all.append(base_name)
            cols_all.extend([c for c in cols_suff if c in df_train.columns])
            cols_all = [c for c in cols_all if c in df_train.columns]
            if len(cols_all) < 2:
                continue
            mat = df_train[cols_all].apply(pd.to_numeric, errors="coerce")
            df_train[f"cfg_mean_match_{base_name}"] = mat.mean(axis=1, skipna=True)
            df_train[f"cfg_max_match_{base_name}"]  = mat.max(axis=1, skipna=True)
            df_train[f"cfg_min_match_{base_name}"]  = mat.min(axis=1, skipna=True)
            df_train[f"cfg_std_match_{base_name}"]  = mat.std(axis=1, skipna=True)

# ----------------------------
# 9) Optional merge image profile
# ----------------------------
# Optional per-case image profile: joined on case_id, so every variant of a case
# receives the same profile values.
if df_prof is not None and "case_id" in df_prof.columns:
    df_prof2 = df_prof.copy()
    df_prof2["case_id"] = pd.to_numeric(df_prof2["case_id"], errors="coerce").astype("Int64")
    df_prof2 = df_prof2.dropna(subset=["case_id"]).astype({"case_id": int})
    # Keep the first row per case_id so the merge cannot duplicate training rows.
    df_prof2 = df_prof2.drop_duplicates("case_id")

    # Numeric profile columns only.
    prof_num = ["case_id"] + [
        c for c in df_prof2.columns
        if c != "case_id" and pd.api.types.is_numeric_dtype(df_prof2[c])
    ]
    df_prof2 = df_prof2[prof_num].copy()

    # Prefix with "profile_" to avoid clashing with existing df_train columns.
    ren = {c: f"profile_{c}" for c in df_prof2.columns if c != "case_id"}
    df_prof2 = df_prof2.rename(columns=ren)
    df_train = df_train.merge(df_prof2, on="case_id", how="left")

# ----------------------------
# 9.5) NPZ pair features (SAFE overlap)
# ----------------------------
def _dir_has_any_npz(d: Path) -> bool:
    try:
        if d is None or (not d.exists()) or (not d.is_dir()):
            return False
        for _ in d.glob("*.npz"):
            return True
        return False
    except Exception:
        return False

def _uid_to_npz_path(npz_dir: Path, uid: str) -> Path:
    uid = str(uid)
    if uid.endswith(".npz"):
        return npz_dir / uid
    cands = [
        npz_dir / f"{uid}.npz",
        npz_dir / f"{uid.replace('__','_')}.npz",
        npz_dir / f"{uid.replace('_','__')}.npz",
    ]
    if "__authentic" in uid:
        cands.append(npz_dir / f"{uid.replace('__authentic','__auth')}.npz")
    if "__forged" in uid:
        cands.append(npz_dir / f"{uid.replace('__forged','__forg')}.npz")
    for p in cands:
        if p.exists():
            return p
    return cands[0]

# Optional backends for connected-component labelling and nearest-neighbour resizing.
# Each _HAS_* flag records whether the import succeeded; the helpers below check the
# flag before touching the module.
try:
    import scipy.ndimage as ndi
except Exception:
    _HAS_NDI = False
else:
    _HAS_NDI = True

try:
    import cv2
except Exception:
    _HAS_CV2 = False
else:
    _HAS_CV2 = True

def _maybe_downsample_mask(mask: np.ndarray, ds: int):
    if ds is None or ds <= 0:
        return mask
    H, W = mask.shape[:2]
    if H <= ds and W <= ds:
        return mask
    if _HAS_CV2:
        return cv2.resize(mask.astype(np.uint8), (ds, ds), interpolation=cv2.INTER_NEAREST).astype(bool)
    ys = (np.linspace(0, H-1, ds)).astype(int)
    xs = (np.linspace(0, W-1, ds)).astype(int)
    return mask[np.ix_(ys, xs)]

def _count_cc_and_largest(mask_bool: np.ndarray, ds: int):
    """Count connected components of a binary mask and measure its largest one.

    Labelling runs on a (possibly) downsampled copy from ``_maybe_downsample_mask``
    to bound cost. The largest-component size is then rescaled back to the ORIGINAL
    mask's pixel units (and clamped to the original foreground count), so callers
    can divide it by the full-resolution foreground area, as ``_mask_stats`` does.
    Previously it was returned in downsampled pixels, which made that fraction wrong
    whenever downsampling happened.
    NOTE(review): if another stage computes the same feature for inference, it must
    apply the same scaling to avoid train/test skew.

    The scipy path uses ``ndi.label``'s default structure (4-connectivity in 2D);
    the cv2 path uses 8-connectivity, so counts can differ by backend.

    Returns:
        (n_components, largest_size_px): (0, 0) for an empty mask, and (-1, -1)
        when neither scipy.ndimage nor cv2 is available.
    """
    full = np.asarray(mask_bool) > 0
    m = _maybe_downsample_mask(mask_bool, ds)
    m = (m > 0).astype(np.uint8)
    if m.sum() == 0:
        return 0, 0

    if _HAS_NDI:
        lab, n = ndi.label(m)
        if n <= 0:
            return 0, 0
        n_cc = int(n)
    elif _HAS_CV2:
        n, lab = cv2.connectedComponents(m, connectivity=8)
        n_cc = int(max(n - 1, 0))  # label 0 is the background
        if n_cc <= 0:
            return 0, 0
    else:
        return -1, -1

    sizes = np.bincount(lab.ravel())
    sizes[0] = 0  # ignore the background label
    largest = int(sizes.max()) if sizes.size else 0

    # Rescale from the labelled grid to original-resolution pixels.
    if m.size and m.size != full.size:
        largest = int(round(largest * (full.size / m.size)))
        largest = min(largest, int(full.sum()))
    return n_cc, largest

def _extract_mask_prob_from_npz(npz_path: Path):
    """Load one NPZ output and return ``(mask, prob)`` as 2-D arrays (or None each).

    Selection order:
      1. ``prob`` from the first present key among prob/probs/pred_prob/mask_prob/p/pred
         (only if it is an ndarray with ndim >= 2; cast to float32);
      2. ``mask`` from the first present key among mask/masks/bin_mask/pred_mask/mask_bin
         (3-D arrays are reduced with max over axis 0, then ``> 0``);
      3. if neither was found, the array with the most elements (ndim >= 2) is used:
         a floating dtype becomes ``prob``, anything else becomes ``mask``;
      4. a missing ``mask`` is derived as ``prob >= FE_CFG["npz_bin_thr"]``.
    Remaining 3-D arrays are max-reduced over axis 0, both outputs are squeezed, and
    anything that is not 2-D afterwards is replaced with None.
    NOTE(review): the axis-0 reduction assumes a channel-first (C, H, W) layout —
    an (H, W, C) array would be reduced along H instead; confirm the writer's layout.

    Returns (None, None) if the file cannot be opened or read.
    """
    try:
        # allow_pickle=True: only safe for trusted, self-generated NPZ files.
        z = np.load(npz_path, allow_pickle=True)
        keys = list(z.files)
        arrs = {k: z[k] for k in keys}  # eagerly materialise every stored array
    except Exception:
        return None, None

    def _pick_first(keys_like):
        # First key from keys_like that exists in the archive, else None.
        for k in keys_like:
            if k in arrs:
                return k
        return None

    k_prob = _pick_first(["prob", "probs", "pred_prob", "mask_prob", "p", "pred"])
    k_mask = _pick_first(["mask", "masks", "bin_mask", "pred_mask", "mask_bin"])

    prob = None
    mask = None

    if k_prob is not None:
        a = arrs[k_prob]
        # Ignore scalars/1-D values stored under a prob-like key.
        if isinstance(a, np.ndarray) and a.ndim >= 2:
            prob = a.astype(np.float32)

    if k_mask is not None:
        a = arrs[k_mask]
        if isinstance(a, np.ndarray):
            if a.ndim == 3:
                a = np.max(a, axis=0)  # union over the leading axis
            mask = (a > 0).astype(bool)

    # Fallback: no known key -> take the largest >= 2-D array and decide by dtype.
    if mask is None and prob is None:
        best_k, best_n = None, -1
        for k, a in arrs.items():
            if not isinstance(a, np.ndarray) or a.ndim < 2:
                continue
            n = int(np.prod(a.shape))
            if n > best_n:
                best_k, best_n = k, n
        if best_k is not None:
            a = arrs[best_k]
            a2 = np.max(a, axis=0) if a.ndim == 3 else a
            if np.issubdtype(a2.dtype, np.floating):
                prob = a2.astype(np.float32)
            else:
                mask = (a2 > 0).astype(bool)

    # Binarise the probability map when no explicit mask was stored.
    if mask is None and prob is not None:
        mask = (prob >= float(FE_CFG["npz_bin_thr"]))

    if prob is not None and prob.ndim == 3:
        prob = np.max(prob, axis=0)
    if mask is not None and mask.ndim == 3:
        mask = np.max(mask, axis=0).astype(bool)

    # Final squeeze; anything that is not 2-D at this point is discarded.
    if prob is not None:
        prob = np.asarray(prob).squeeze()
        if prob.ndim != 2:
            prob = None
    if mask is not None:
        mask = np.asarray(mask).squeeze()
        if mask.ndim != 2:
            mask = None

    return mask, prob

def _mask_stats(mask: np.ndarray, prob: np.ndarray, ds_cc: int):
    out = {}
    if mask is None:
        out.update({
            "npz_area_frac": np.nan,
            "npz_n_cc": np.nan,
            "npz_largest_cc_frac": np.nan,
            "npz_mean_prob_inside": np.nan,
            "npz_p90_prob_inside": np.nan,
            "npz_max_prob": np.nan,
        })
        return out

    m = (mask > 0)
    H, W = m.shape[:2]
    tot = float(H * W) if H and W else 0.0
    area = float(m.sum())
    out["npz_area_frac"] = (area / tot) if tot > 0 else 0.0

    n_cc, largest = _count_cc_and_largest(m, ds_cc)
    out["npz_n_cc"] = float(n_cc) if n_cc >= 0 else np.nan
    out["npz_largest_cc_frac"] = (float(largest) / area) if area > 0 and largest >= 0 else np.nan

    if prob is not None and prob.shape[:2] == m.shape[:2]:
        p = prob.astype(np.float32)
        inside = p[m]
        if inside.size > 0:
            out["npz_mean_prob_inside"] = float(np.mean(inside))
            out["npz_p90_prob_inside"] = float(np.quantile(inside, 0.90))
            out["npz_max_prob"] = float(np.max(inside))
        else:
            out["npz_mean_prob_inside"] = 0.0
            out["npz_p90_prob_inside"] = 0.0
            out["npz_max_prob"] = 0.0
    else:
        out["npz_mean_prob_inside"] = np.nan
        out["npz_p90_prob_inside"] = np.nan
        out["npz_max_prob"] = np.nan

    return out

# -------- FIX v3.2: SAFE overlap helpers --------
def _resize_bool_to_hw(mask_bool: np.ndarray, out_hw):
    H, W = int(out_hw[0]), int(out_hw[1])
    m = np.asarray(mask_bool).astype(np.uint8).squeeze()
    if m.ndim != 2:
        return None
    if m.shape[0] == H and m.shape[1] == W:
        return (m > 0)
    if _HAS_CV2:
        m2 = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST)
        return (m2 > 0)
    h, w = m.shape
    if h <= 0 or w <= 0 or H <= 0 or W <= 0:
        return None
    ys = (np.linspace(0, h - 1, H)).astype(int)
    xs = (np.linspace(0, w - 1, W)).astype(int)
    m2 = m[np.ix_(ys, xs)]
    return (m2 > 0)

def _safe_overlap_features(mask_a, mask_b, min_pixels=16):
    out = {
        "pm_inter_frac": np.nan,
        "pm_union_frac": np.nan,
        "pm_iou": np.nan,
        "pm_pred_minus_match_frac": np.nan,
        "pm_match_minus_pred_frac": np.nan,
    }
    if mask_a is None or mask_b is None:
        return out
    a = np.asarray(mask_a).squeeze()
    b = np.asarray(mask_b).squeeze()
    if a.ndim != 2 or b.ndim != 2:
        return out
    if a.size < int(min_pixels) or b.size < int(min_pixels):
        return out

    a = (a > 0)
    b = (b > 0)

    if a.shape != b.shape:
        b2 = _resize_bool_to_hw(b, a.shape)
        if b2 is None:
            return out
        b = b2

    inter = float(np.logical_and(a, b).sum())
    union = float(np.logical_or(a, b).sum())
    tot = float(a.size) if a.size else 0.0

    out["pm_inter_frac"] = (inter / tot) if tot > 0 else 0.0
    out["pm_union_frac"] = (union / tot) if tot > 0 else 0.0
    out["pm_iou"] = (inter / union) if union > 0 else 0.0
    out["pm_pred_minus_match_frac"] = ((float(a.sum()) - inter) / tot) if tot > 0 else 0.0
    out["pm_match_minus_pred_frac"] = ((float(b.sum()) - inter) / tot) if tot > 0 else 0.0
    return out
# -----------------------------------------------

def _build_npz_pair_features(uids: list, pred_dir: Path, match_dir: Path, cache_path: Path):
    """Build per-uid NPZ mask features for pred and match outputs plus their overlap.

    For each uid, the pred/match NPZ files are located with ``_uid_to_npz_path`` and
    reduced to (mask, prob) by ``_extract_mask_prob_from_npz``. They are summarised by
    ``_mask_stats`` (columns prefixed ``pred_`` / ``match_``) and
    ``_safe_overlap_features`` (``pm_*`` columns). Only the first
    ``FE_CFG["npz_max_rows"]`` uids are processed when that value is > 0.

    Caching (``FE_CFG["npz_cache_enabled"]``): a cached parquet is reused only if it
    contains every requested uid. Otherwise it is rebuilt, so a cache produced for a
    different/smaller uid set is no longer returned silently. Cache read/write
    problems are reported and are never fatal.

    Returns:
        DataFrame with a ``uid`` column and one row per processed uid (or the
        cached table, which may contain extra uids).
    """
    ds_cc = int(FE_CFG["npz_downsample_for_cc"] or 0)
    max_rows = int(FE_CFG["npz_max_rows"] or 0)
    if max_rows > 0:
        uids = uids[:max_rows]

    if FE_CFG["npz_cache_enabled"] and cache_path.exists():
        try:
            dfc = _ensure_uid(pd.read_parquet(cache_path))
            cached_uids = set(dfc["uid"].astype(str))
        except Exception as e:
            print(f"WARNING: could not read NPZ cache {cache_path} ({e}); rebuilding.")
        else:
            n_missing = sum(1 for u in uids if str(u) not in cached_uids)
            if n_missing == 0:
                return dfc
            print(f"NOTE: NPZ cache {cache_path.name} lacks {n_missing} requested uids; rebuilding.")

    rows = []
    for i, uid in enumerate(uids, 1):
        r = {"uid": str(uid)}

        pred_p = _uid_to_npz_path(pred_dir, uid) if pred_dir is not None else None
        match_p = _uid_to_npz_path(match_dir, uid) if match_dir is not None else None

        mp = None; pp = None
        if pred_p is not None and pred_p.exists():
            mp, pp = _extract_mask_prob_from_npz(pred_p)
        for k, v in _mask_stats(mp, pp, ds_cc).items():
            r[f"pred_{k}"] = v

        mm = None; pm = None
        if match_p is not None and match_p.exists():
            mm, pm = _extract_mask_prob_from_npz(match_p)
        for k, v in _mask_stats(mm, pm, ds_cc).items():
            r[f"match_{k}"] = v

        # Overlap is shape-safe: mismatched masks are resized, bad ones yield NaN.
        r.update(_safe_overlap_features(mp, mm, min_pixels=16))

        rows.append(r)

        if (i % 500) == 0:
            print(f"  NPZ features progress: {i}/{len(uids)}")

    dfc = pd.DataFrame(rows)
    if FE_CFG["npz_cache_enabled"]:
        try:
            cache_path.parent.mkdir(parents=True, exist_ok=True)
            dfc.to_parquet(cache_path, index=False)
        except Exception as e:
            print(f"WARNING: could not write NPZ cache {cache_path} ({e}); continuing without cache.")
    return dfc

npz_ready = FE_CFG["use_npz_pair_features"] and _dir_has_any_npz(PRED_NPZ_TRAIN_DIR) and _dir_has_any_npz(MATCH_NPZ_TRAIN_DIR)
if npz_ready:
    import hashlib  # cache-file naming; not imported in the notebook's header cell

    OUT_ART = Path("/kaggle/working/recodai_luc_gate_artifacts")
    OUT_ART.mkdir(parents=True, exist_ok=True)

    # The cache key must change whenever the cached values would: cfg names,
    # binarisation threshold, CC downsample size, row cap, and the NPZ source dirs.
    key = (
        f"npz_pair_{_slug(pred_cfg_name)}__{_slug(match_cfg_name)}__thr{FE_CFG['npz_bin_thr']}"
        f"__ds{FE_CFG['npz_downsample_for_cc']}__max{FE_CFG['npz_max_rows']}"
        f"__{PRED_NPZ_TRAIN_DIR}__{MATCH_NPZ_TRAIN_DIR}"
    )
    h = hashlib.md5(key.encode()).hexdigest()[:12]
    cache_path = OUT_ART / f"npz_pair_features_train_all_{h}.parquet"

    print("\nBuilding NPZ pair features (pred+match) ...")
    uids = df_train["uid"].astype(str).tolist()
    df_npz = _build_npz_pair_features(uids, PRED_NPZ_TRAIN_DIR, MATCH_NPZ_TRAIN_DIR, cache_path)
    df_npz = _ensure_uid(df_npz)
    # One row per uid so the left merge cannot duplicate training rows.
    df_npz = df_npz.drop_duplicates(subset=["uid"], keep="first")

    keep_cols = [c for c in df_npz.columns if c != "uid"]
    df_train = df_train.merge(df_npz[["uid"] + keep_cols], on="uid", how="left")

    # Fill canonical *_pred aliases from the NPZ stats when the CSVs lacked them.
    _coalesce_numeric(df_train, "area_frac", ["area_frac", "pred_npz_area_frac"])
    _coalesce_numeric(df_train, "n_cc_pred", ["n_cc_pred", "pred_npz_n_cc"])
    _coalesce_numeric(df_train, "largest_cc_frac_pred", ["largest_cc_frac_pred", "pred_npz_largest_cc_frac"])
    _coalesce_numeric(df_train, "mean_prob_inside_pred", ["mean_prob_inside_pred", "pred_npz_mean_prob_inside"])
    _coalesce_numeric(df_train, "p90_prob_inside_pred", ["p90_prob_inside_pred", "pred_npz_p90_prob_inside"])
    _coalesce_numeric(df_train, "max_prob_pred", ["max_prob_pred", "pred_npz_max_prob"])

else:
    print("\nNOTE: NPZ pair features skipped (missing/empty pred/match npz dirs).")
    print("  PRED_NPZ_TRAIN_DIR :", PRED_NPZ_TRAIN_DIR, "(npz found)" if _dir_has_any_npz(PRED_NPZ_TRAIN_DIR) else "(empty/missing)")
    print("  MATCH_NPZ_TRAIN_DIR:", MATCH_NPZ_TRAIN_DIR, "(npz found)" if _dir_has_any_npz(MATCH_NPZ_TRAIN_DIR) else "(empty/missing)")

# ----------------------------
# 10) Feature engineering helpers
# ----------------------------
def _num(s):
    return pd.to_numeric(s, errors="coerce")

def safe_log1p_nonneg(x):
    """Return log1p(x) as float64 after mapping NaN/+-inf to 0 and clipping negatives to 0."""
    arr = np.nan_to_num(np.asarray(x, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    return np.log1p(np.clip(arr, 0.0, None))

def safe_sqrt_nonneg(x):
    """Return sqrt(x) as float64 after mapping NaN/+-inf to 0 and clipping negatives to 0."""
    arr = np.nan_to_num(np.asarray(x, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    return np.sqrt(np.clip(arr, 0.0, None))

def get_clip_cap(series: pd.Series, q: float, fallback: float):
    """Return the *q*-quantile of ``|series|`` over its finite values (a symmetric clip cap).

    Non-numeric entries are coerced to NaN and ignored together with +/-inf.
    ``float(fallback)`` is returned when no finite value remains, or when the
    resulting quantile is non-finite or not strictly positive.
    """
    vals = _num(series).astype(float)
    vals = vals[np.isfinite(vals)]
    if vals.empty:
        return float(fallback)
    cap = float(np.quantile(np.abs(vals.to_numpy()), q))
    if np.isfinite(cap) and cap > 0:
        return cap
    return float(fallback)

# ----------------------------
# 11) Candidate numeric feature list (pre-fill)
# ----------------------------
# Columns that are never model inputs: labels, the CV split, and raw ids.
TARGET_COLS = {"y", "y_forged", "has_mask", "is_forged", "forged"}
SPLIT_COLS  = {"fold"}
ID_DROP_NUM = {"case_id"}

# Map +/-inf to NaN in every numeric column so all later NaN handling also covers them.
for c in [col for col in df_train.columns if pd.api.types.is_numeric_dtype(df_train[col])]:
    df_train[c] = df_train[c].replace([np.inf, -np.inf], np.nan)

# Optional 0/1 "isna_<col>" flags for every numeric non-label column that has gaps.
missing_ind_cols = []
if FE_CFG["add_missing_indicators"]:
    _excluded_cols = TARGET_COLS | SPLIT_COLS | ID_DROP_NUM
    _cols_with_gaps = [
        col for col in df_train.columns
        if col not in _excluded_cols
        and pd.api.types.is_numeric_dtype(df_train[col])
        and df_train[col].isna().any()
    ]
    for col in _cols_with_gaps:
        df_train[f"isna_{col}"] = df_train[col].isna().astype(np.uint8)
        missing_ind_cols.append(f"isna_{col}")

# Columns treated as heavy-tailed: a fixed list plus any numeric column whose
# lower-cased name contains one of the keywords below.
heavy_candidates = {
    "peak_ratio", "best_weight", "best_count", "best_weight_frac",
    "pair_count", "n_pairs_thr", "n_pairs_mnn", "overmask_tighten_steps",
    "largest_comp", "n_comp", "grid_h", "grid_w",
    "grid_area_frac", "area_frac", "inlier_ratio",
    "n_cc_pred", "largest_cc_frac_pred", "mean_prob_inside_pred", "p90_prob_inside_pred", "max_prob_pred",
    "pm_iou", "pm_union_frac", "pm_inter_frac",
}
_heavy_keywords = ("count", "pairs", "weight", "ratio", "area", "comp", "std_", "iou", "prob")
heavy_candidates.update(
    col for col in df_train.columns
    if any(kw in col.lower() for kw in _heavy_keywords)
    and pd.api.types.is_numeric_dtype(df_train[col])
)

# Per-column symmetric clip caps (quantile of |x|), computed in sorted column order.
clip_caps = {}
if FE_CFG["clip_by_quantile"]:
    clip_caps = {
        col: get_clip_cap(df_train[col], FE_CFG["clip_q"], FE_CFG["clip_max_fallback"])
        for col in sorted(heavy_candidates)
        if col in df_train.columns and pd.api.types.is_numeric_dtype(df_train[col])
    }

# For every capped column emit "<col>_cap" (NaN -> 0, clipped to [-cap, cap]) plus,
# optionally, log1p / sqrt of its magnitude. Nothing is emitted if both transforms are off.
if FE_CFG["add_log_features"] or FE_CFG["add_sqrt_features"]:
    for c, cap in clip_caps.items():
        clipped = np.clip(_num(df_train[c]).fillna(0.0).astype(float).values, -cap, cap)
        df_train[f"{c}_cap"] = clipped.astype(np.float32)

        magnitude = np.abs(clipped)
        if FE_CFG["add_log_features"]:
            df_train[f"logabs_{c}"] = safe_log1p_nonneg(magnitude).astype(np.float32)
        if FE_CFG["add_sqrt_features"]:
            df_train[f"sqrtabs_{c}"] = safe_sqrt_nonneg(magnitude).astype(np.float32)

# Hand-crafted interaction features (products / ratios / log1p of existing columns).
# Every new column is cast to float32.
if FE_CFG["add_interactions"]:
    # Return column `col` as a float64 array with NaN -> `default`; a missing column
    # yields a constant array of `default`, so every interaction below is always defined.
    def getf(col, default=0.0):
        if col in df_train.columns:
            return _num(df_train[col]).fillna(default).astype(float).values
        return np.full(len(df_train), default, dtype=np.float64)

    # Matching / grid signals (all default to 0.0 when the column is absent).
    best_mean_sim = getf("best_mean_sim", 0.0)
    best_count    = getf("best_count", 0.0)
    peak_ratio    = getf("peak_ratio", 0.0)
    has_peak      = getf("has_peak", 0.0)
    grid_area     = getf("grid_area_frac", 0.0)
    area_frac     = getf("area_frac", 0.0)
    n_pairs_thr   = getf("n_pairs_thr", 0.0)
    n_pairs_mnn   = getf("n_pairs_mnn", 0.0)
    inlier_ratio  = getf("inlier_ratio", 0.0)
    gh = getf("grid_h", 0.0)
    gw = getf("grid_w", 0.0)
    gridN = np.clip(gh * gw, 0.0, None)  # grid cell count, floored at 0

    # Products and ratios; denominators carry +1e-6 / +1.0 to avoid division by zero.
    df_train["sim_x_count"]      = (best_mean_sim * best_count).astype(np.float32)
    df_train["peak_x_sim"]       = (peak_ratio * best_mean_sim).astype(np.float32)
    df_train["haspeak_x_sim"]    = (has_peak * best_mean_sim).astype(np.float32)
    df_train["area_x_sim"]       = (grid_area * best_mean_sim).astype(np.float32)
    df_train["area_x_count"]     = (grid_area * best_count).astype(np.float32)
    df_train["mask_grid_ratio"]  = (area_frac / (1e-6 + grid_area)).astype(np.float32)
    df_train["mnn_ratio"]        = (n_pairs_mnn / (1.0 + n_pairs_thr)).astype(np.float32)
    df_train["pairs_per_cell"]   = (n_pairs_thr / (1.0 + gridN)).astype(np.float32)
    df_train["inlier_x_pairs"]   = (inlier_ratio * n_pairs_thr).astype(np.float32)

    # log1p of counts / area (negatives and non-finite values mapped to 0 by the helper).
    df_train["log1p_pairs_thr"]  = safe_log1p_nonneg(n_pairs_thr).astype(np.float32)
    df_train["log1p_best_count"] = safe_log1p_nonneg(best_count).astype(np.float32)
    df_train["log1p_area_frac"]  = safe_log1p_nonneg(np.clip(area_frac, 0, None)).astype(np.float32)

    # Predicted-mask statistics (default 0.0 when the column is absent).
    ncc   = getf("n_cc_pred", 0.0)
    lccf  = getf("largest_cc_frac_pred", 0.0)
    mpin  = getf("mean_prob_inside_pred", 0.0)
    p90in = getf("p90_prob_inside_pred", 0.0)
    mxp   = getf("max_prob_pred", 0.0)

    df_train["prob_mean_x_area"] = (mpin * area_frac).astype(np.float32)
    df_train["prob_p90_x_area"]  = (p90in * area_frac).astype(np.float32)
    df_train["prob_max_x_area"]  = (mxp * area_frac).astype(np.float32)
    df_train["comp_x_area"]      = (ncc * area_frac).astype(np.float32)
    df_train["largestcc_x_area"] = (lccf * area_frac).astype(np.float32)
    df_train["prob_contrast"]    = (p90in - mpin).astype(np.float32)  # p90 minus mean prob inside the mask

    # pred-vs-match overlap columns (pm_*); 0.0 when absent (presumably when the NPZ
    # pair features were skipped — confirm against the earlier NPZ step).
    iou = getf("pm_iou", 0.0)
    inter = getf("pm_inter_frac", 0.0)
    union = getf("pm_union_frac", 0.0)
    df_train["iou_x_area"]       = (iou * area_frac).astype(np.float32)
    df_train["inter_over_union"] = (inter / (1e-6 + union)).astype(np.float32)

# Fill remaining NaNs in numeric FEATURE columns only. Label / fold / id columns are
# excluded (same exclusion set as the missing-indicator and feature-selection steps):
# filling them would silently turn a missing label into class 0 and a missing fold
# into fold 0, and would make the y/fold NaN sanity checks in section 14 unreachable.
num_cols = [
    c for c in df_train.columns
    if pd.api.types.is_numeric_dtype(df_train[c])
    and c not in TARGET_COLS and c not in SPLIT_COLS and c not in ID_DROP_NUM
]
df_train[num_cols] = df_train[num_cols].fillna(FE_CFG["fillna_value"])

# ----------------------------
# 12) Variant one-hot
# ----------------------------
# One-hot of `variant`; categories seen fewer than `variant_min_count` times are
# merged into a single "rare" bucket. Dummy columns are named "v_<category>".
variant_dummy_cols = []
if FE_CFG["encode_variant_onehot"]:
    # fillna BEFORE astype(str): astype(str) renders NaN as the literal "nan", so the
    # previous order (astype then fillna) never produced the intended "unk" bucket.
    vc = df_train["variant"].fillna("unk").astype(str)
    counts = vc.value_counts()
    keep = set(counts[counts >= int(FE_CFG["variant_min_count"])].index.tolist())
    vc = vc.where(vc.isin(keep), other="rare")
    dummies = pd.get_dummies(vc, prefix="v", dummy_na=False).astype(np.uint8)
    variant_dummy_cols = dummies.columns.tolist()
    df_train = pd.concat([df_train, dummies], axis=1)

# ----------------------------
# 13) Select final feature columns
# ----------------------------
# Every numeric column except labels, the CV split and raw ids, in frame order.
_non_feature_cols = TARGET_COLS | SPLIT_COLS | ID_DROP_NUM
feature_cols = [
    col for col in df_train.columns
    if pd.api.types.is_numeric_dtype(df_train[col]) and col not in _non_feature_cols
]

# Optionally drop columns with a single distinct value (NaN counts as a value).
dropped_constant = []
if FE_CFG["drop_constant_features"]:
    n_unique = df_train[feature_cols].nunique(dropna=False)
    varying = n_unique[n_unique > 1].index.tolist()
    dropped_constant = sorted(set(feature_cols).difference(varying))
    feature_cols = varying

if FE_CFG["cast_float32"]:
    df_train[feature_cols] = df_train[feature_cols].astype(np.float32)

FEATURE_COLS = list(feature_cols)

# ----------------------------
# 14) Final outputs
# ----------------------------
base_out_cols = ["uid", "case_id", "variant", "fold", "y"]
df_train_tabular = df_train[base_out_cols + FEATURE_COLS].copy()

# Validate label / fold BEFORE the integer casts: astype(int) raises a generic pandas
# error on NaN, so checking afterwards could never reach these clearer messages.
if df_train_tabular["y"].isna().any():
    raise RuntimeError("y_train contains NaN")
if df_train_tabular["fold"].isna().any():
    raise RuntimeError("folds contains NaN")

X_train = df_train_tabular[FEATURE_COLS]
y_train = df_train_tabular["y"].astype(int)
folds   = df_train_tabular["fold"].astype(int)

print("\nOK — Training table built")
print("  df_train_tabular:", df_train_tabular.shape)
print("  X_train:", X_train.shape, "| y pos%:", float(y_train.mean()) * 100.0)
print("  folds:", int(folds.nunique()), "unique folds")
print("  feature_cols:", int(len(FEATURE_COLS)))
if dropped_constant:
    print("  dropped_constant_features:", len(dropped_constant))
if variant_dummy_cols:
    print("  variant_dummies:", len(variant_dummy_cols))
if missing_ind_cols:
    print("  missing_indicators:", len(missing_ind_cols))
if pred_core_cols_added:
    print("  extra_pred_core_cols:", len(pred_core_cols_added))
if match_core_cols_added:
    print("  extra_match_core_cols:", len(match_core_cols_added))

if X_train.shape[0] != y_train.shape[0]:
    raise RuntimeError("X_train and y_train row mismatch")

print("\nFeature head:", FEATURE_COLS[:25])
print("Feature tail:", FEATURE_COLS[-15:])

# ----------------------------
# 15) Save schema
# ----------------------------
# Writes feature_cols.json, feature_schema.json and df_train_tabular.parquet, which the
# test-feature step reads back to reproduce the same feature engineering.
OUT_ART = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_ART.mkdir(parents=True, exist_ok=True)

with open(OUT_ART / "feature_cols.json", "w") as f:
    json.dump(FEATURE_COLS, f, indent=2)

# MATCH_CFG_DIR is optional. Guard it once and reuse it everywhere; the multi_cfg entry
# used to index PATHS["MATCH_CFG_DIR"] directly and crashed when no match cfg existed.
_primary_pred_cfg_dir = str(Path(PATHS["PRED_CFG_DIR"]))
_primary_match_cfg_dir = str(Path(PATHS["MATCH_CFG_DIR"])) if PATHS.get("MATCH_CFG_DIR") else ""

schema = {
    "fe_cfg": FE_CFG,
    "label_col_source": label_col,
    "clip_caps": {k: float(v) for k, v in clip_caps.items()},
    "dropped_constant_features": dropped_constant,
    "variant_dummy_cols": variant_dummy_cols,
    "missing_indicator_cols": missing_ind_cols,
    "primary_cfg": {
        "pred_cfg_dir": _primary_pred_cfg_dir,
        "match_cfg_dir": _primary_match_cfg_dir,
        "pred_npz_train_dir": str(PRED_NPZ_TRAIN_DIR),
        "match_npz_train_dir": str(MATCH_NPZ_TRAIN_DIR),
    },
    "multi_cfg": {
        "pred_cfg_dirs_used": [_primary_pred_cfg_dir] + [str(d) for d in extra_pred_cfgs],
        "match_cfg_dirs_used": (
            ([_primary_match_cfg_dir] if _primary_match_cfg_dir else [])
            + ([str(d) for d in extra_match_cfgs] if FE_CFG["use_match_features"] else [])
        ),
        "extra_pred_core_cols_added": pred_core_cols_added[:200],
        "extra_match_core_cols_added": match_core_cols_added[:200],
    },
    "n_features": int(len(FEATURE_COLS)),
    "example_feature_head": FEATURE_COLS[:30],
}
with open(OUT_ART / "feature_schema.json", "w") as f:
    json.dump(schema, f, indent=2)

df_train_tabular.to_parquet(OUT_ART / "df_train_tabular.parquet", index=False)

print(f"\nSaved -> {OUT_ART/'feature_cols.json'}")
print(f"Saved -> {OUT_ART/'feature_schema.json'}")
print(f"Saved -> {OUT_ART/'df_train_tabular.parquet'}")

globals().update({
    "df_train_tabular": df_train_tabular,
    "FEATURE_COLS": FEATURE_COLS,
    "X_train": X_train,
    "y_train": y_train,
    "folds": folds,
})


Using (PRIMARY):
  DF_TRAIN_ALL     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet
  CV_CASE_FOLDS    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv
  PRED_FEAT_TRAIN  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv
  MATCH_FEAT_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv (optional)
  IMG_PROFILE_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet (optional)

Building NPZ pair features (pred+match) ...
  NPZ features progress: 500/5176


# Build & Export Test Feature Table

In [None]:
# ============================================================
# Step 2.5 — Build & Export Test Feature Table (pred_features_test*)
# ONE CELL (Kaggle-ready) — REVISI FULL v2.0 (SELARAS Step 2 v3.1)
#
# Tujuan (FIX utama v2.0):
# - Bukan sekadar "merge kolom yang ketemu".
# - v2.0 MEMBANGUN fitur TEST dengan *feature engineering yang sama* seperti Step 2:
#   * missing indicators (sesuai schema)
#   * clipping + *_cap + logabs_*
#   * interactions
#   * variant one-hot (kolom sama persis)
#   * (opsional) multi-CFG core+agg (pred+match)
#   * (opsional) NPZ pair features (pred+match) + overlap (cache parquet)
#
# Output:
#   /kaggle/working/recodai_luc_gate_artifacts/pred_features_test.csv
#   /kaggle/working/recodai_luc_gate_artifacts/pred_features_test_cfg_<hash>.csv
#
# REQUIRE (minimal):
# - FEATURE_COLS (list) dari Step 2
# - feature_schema.json (disarankan; dibuat oleh Step 2 v3.1)
# - PATHS dari STAGE 0 (untuk lokasi pred/match cfg + test meta)
# ============================================================

import os, re, json, hashlib, gc, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# Shared artifact folder: Step 2 wrote feature_schema.json here; test outputs go here too.
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 0) Require FEATURE_COLS
# ----------------------------
if "FEATURE_COLS" not in globals():
    raise RuntimeError("Missing `FEATURE_COLS`. Jalankan Step 2 (Build Training Table) dulu.")
FEATURE_COLS = [*FEATURE_COLS]
if not FEATURE_COLS:
    raise RuntimeError("FEATURE_COLS kosong. Cek Step 2.")

# ----------------------------
# 1) Load schema (from Step 2)
# ----------------------------
SCHEMA_PATH = OUT_DIR / "feature_schema.json"
if SCHEMA_PATH.exists():
    schema = json.loads(SCHEMA_PATH.read_text())
else:
    print("WARNING: feature_schema.json tidak ditemukan. Akan jalan mode fallback (lebih lemah).")
    schema = {}

# `or {}` / `or []` also covers keys that are present but null.
FE_CFG = schema.get("fe_cfg") or {}
clip_caps = schema.get("clip_caps") or {}
variant_dummy_cols = schema.get("variant_dummy_cols") or []
missing_indicator_cols = schema.get("missing_indicator_cols") or []

# Safe fallback defaults: only keys the schema did not provide are filled in
# (equivalent to a sequence of FE_CFG.setdefault calls).
FE_CFG.update({
    key: value
    for key, value in {
        "fillna_value": 0.0,
        "npz_bin_thr": 0.5,
        "npz_downsample_for_cc": 256,
        "use_npz_pair_features": True,
        "use_match_features": True,
        "multi_cfg_enabled": True,
        "multi_cfg_extra_mode": "core+agg",
        "multi_cfg_max_pred": int(os.environ.get("MULTI_CFG_MAX_PRED", "6")),
        "multi_cfg_max_match": int(os.environ.get("MULTI_CFG_MAX_MATCH", "3")),
    }.items()
    if key not in FE_CFG
})

# ----------------------------
# 2) Helpers
# ----------------------------
def _read_table_any(p: Path):
    p = Path(p)
    if not p.exists():
        return None
    if p.suffix.lower() == ".parquet":
        return pd.read_parquet(p)
    return pd.read_csv(p)

def _pick_first_existing(paths):
    for p in paths:
        if p is None:
            continue
        p = Path(str(p))
        if p.exists() and p.is_file():
            return p
    return None

def _to_str_series(s: pd.Series) -> pd.Series:
    return s.astype(str).replace({"nan": "", "None": ""})

def _ensure_uid_case_variant(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* that has ``uid`` and ``variant`` (and ``case_id`` when derivable).

    - ``case_id``: if absent, the first present alias among case / caseid / image_id /
      img_id / id is renamed to it. It is then coerced to nullable ``Int64``
      (unparseable -> <NA>).
    - ``variant``: if absent, the first present alias among var / aug / view /
      split_variant is renamed to it, otherwise it is created as "test". The strings
      "nan", "None" and "" are normalised to "test".
    - ``uid``: kept if present. Otherwise it is built as "<case_id>__<variant>"
      (double underscore, the NPZ naming used in this pipeline) or, with no case_id,
      as the row position. It is always returned as str.
    """
    out = df.copy()

    for target, aliases in (
        ("case_id", ("case", "caseid", "image_id", "img_id", "id")),
        ("variant", ("var", "aug", "view", "split_variant")),
    ):
        if target in out.columns:
            continue
        alias = next((a for a in aliases if a in out.columns), None)
        if alias is not None:
            out = out.rename(columns={alias: target})

    has_case = "case_id" in out.columns
    if has_case:
        out["case_id"] = pd.to_numeric(out["case_id"], errors="coerce").astype("Int64")

    if "variant" not in out.columns:
        out["variant"] = "test"
    out["variant"] = out["variant"].astype(str).replace({"nan": "test", "None": "test", "": "test"})

    if "uid" not in out.columns:
        if has_case:
            out["uid"] = _to_str_series(out["case_id"]) + "__" + _to_str_series(out["variant"])
        else:
            out["uid"] = np.arange(len(out)).astype(str)

    out["uid"] = out["uid"].astype(str)
    return out

def _infer_join_keys(dfA: pd.DataFrame, dfB: pd.DataFrame):
    A = set(dfA.columns); B = set(dfB.columns)
    for keys in (["uid"], ["case_id","variant"], ["case_id"]):
        if all(k in A for k in keys) and all(k in B for k in keys):
            return keys
    return None

def _slug(s: str) -> str:
    s = str(s)
    s = re.sub(r"[^0-9a-zA-Z_]+", "_", s)
    s = re.sub(r"_+", "_", s).strip("_")
    return s[:80] if len(s) > 80 else s

def _coalesce_numeric(df: pd.DataFrame, out_col: str, candidates: list):
    if out_col in df.columns:
        return
    for c in candidates:
        if c in df.columns:
            df[out_col] = pd.to_numeric(df[c], errors="coerce")
            return

# ----------------------------
# 3) Load base TEST meta
# ----------------------------
def _get_base_test_df():
    """Locate the TEST meta table and return ``(copy_of_df, source_description)``.

    Search order:
      1. module-level DataFrames ``df_test_tabular`` then ``df_test``;
      2. the first existing file among PATHS["DF_TEST"], ["DF_TEST_ALL"], ["TEST_META"];
      3. fixed fallbacks under /kaggle/working (df_test.csv, test.csv, df_test.parquet).

    Raises:
        FileNotFoundError: when none of the sources yields a table.
    """
    g = globals()
    for var_name in ("df_test_tabular", "df_test"):
        if var_name in g and isinstance(g[var_name], pd.DataFrame):
            return g[var_name].copy(), f"globals({var_name})"

    paths = g.get("PATHS")
    if isinstance(paths, dict):
        found = _pick_first_existing([paths.get(k, None) for k in ("DF_TEST", "DF_TEST_ALL", "TEST_META")])
        if found is not None:
            table = _read_table_any(found)
            if table is not None:
                return table.copy(), f"PATHS({found})"

    found = _pick_first_existing([
        Path("/kaggle/working/df_test.csv"),
        Path("/kaggle/working/test.csv"),
        Path("/kaggle/working/df_test.parquet"),
    ])
    if found is not None:
        table = _read_table_any(found)
        if table is not None:
            return table.copy(), f"fallback({found})"
    raise FileNotFoundError("Tidak menemukan df_test/meta. Pastikan PATHS[DF_TEST] ada atau df_test sudah dibuat.")

df_base, base_src = _get_base_test_df()
df_base = _ensure_uid_case_variant(df_base)
print("Base TEST source:", base_src, "| shape:", df_base.shape)

# Keep exactly one row per uid (first occurrence wins).
if df_base["uid"].duplicated().any():
    df_base = df_base.drop_duplicates(subset="uid").reset_index(drop=True)

# df_out starts as the id skeleton; the feature merges below left-join onto it.
id_cols = [col for col in ("uid", "case_id", "variant") if col in df_base.columns]
df_out = df_base.loc[:, id_cols].copy()

# ----------------------------
# 4) Merge PRIMARY pred/match test tables (langsung dari PATHS jika ada)
# ----------------------------
def _merge_feat(path, name):
    nonlocal df_out
    if path is None:
        return 0
    p = Path(str(path))
    if (not p.exists()) or (not p.is_file()):
        return 0
    try:
        df = _read_table_any(p)
        if df is None or len(df) == 0:
            return 0
        df = _ensure_uid_case_variant(df)
        join_keys = _infer_join_keys(df_out, df)
        if join_keys is None:
            return 0
        # keep only what helps (raw columns, not limited to FEATURE_COLS yet)
        # (kita butuh base cols untuk bikin derived features)
        df = df.drop_duplicates(join_keys if len(join_keys)>1 else join_keys[0])
        before_cols = set(df_out.columns)
        df_out = df_out.merge(df, on=join_keys, how="left", suffixes=("", "_dup"))

        # resolve dup
        dup_cols = [c for c in df_out.columns if c.endswith("_dup")]
        for dc in dup_cols:
            basec = dc[:-4]
            if basec in df_out.columns:
                df_out[basec] = df_out[basec].fillna(df_out[dc])
            df_out = df_out.drop(columns=[dc])

        gained = len(set(df_out.columns) - before_cols)
        print(f"  merged {name}: {p.name} | join_keys={join_keys} | gained_cols={gained}")
        return gained
    except Exception as e:
        print(f"  skip {name} (merge error): {p} | err={repr(e)}")
        return 0

# Primary pred/match TEST feature files, if STAGE 0 registered them in PATHS.
if "PATHS" in globals():
    PRED_FEAT_TEST = PATHS.get("PRED_FEAT_TEST", None)
    MATCH_FEAT_TEST = PATHS.get("MATCH_FEAT_TEST", None)
else:
    PRED_FEAT_TEST = MATCH_FEAT_TEST = None

print("\nMerging primary test feature CSVs (if exist):")
g1 = _merge_feat(PRED_FEAT_TEST, "PRED_FEAT_TEST")
g2 = _merge_feat(MATCH_FEAT_TEST, "MATCH_FEAT_TEST")

# ----------------------------
# 5) MULTI-CFG extra core+agg (test) — harus konsisten dg Step 2
# ----------------------------
def _short_cfg_tag(cfg_dir: Path, idx: int) -> str:
    nm = cfg_dir.name
    m = re.search(r"(cfg_[0-9a-f]{6,})", nm)
    tag = m.group(1) if m else nm
    tag = tag.replace("pred_unet_aspp_cfg_", "pu_").replace("pred_fusion_cfg_", "pf_").replace("pred_base_", "pb_")
    tag = tag.replace("match_base_cfg_", "m_")
    tag = re.sub(r"[^0-9a-zA-Z_]+", "_", tag)
    return f"{idx:02d}_{tag[:24]}"

# Core columns pulled from each EXTRA pred cfg's pred_features_test.csv. Only the names
# actually present in a given CSV are kept (see _load_csv_core). Several entries look
# like alternative spellings of the same quantity (area_frac / pred_area_frac /
# mask_area_frac, n_cc / n_comp / n_components), presumably to cover different cfg
# families — TODO confirm.
CORE_PRED_COLS = [
    "best_count","best_mean_sim","peak_ratio","best_weight","best_weight_frac","inlier_ratio","n_pairs_thr","n_pairs_mnn","has_peak",
    "grid_h","grid_w","grid_area_frac",
    "area_frac","pred_area_frac","mask_area_frac","n_cc","n_comp","n_components","largest_cc_frac","mean_prob_inside","p90_prob_inside","max_prob",
]
# Same idea for each EXTRA match cfg's match_features_test.csv.
CORE_MATCH_COLS = [
    "best_count","best_mean_sim","peak_ratio","best_weight","best_weight_frac","inlier_ratio","n_pairs_thr","n_pairs_mnn","has_peak",
    "grid_h","grid_w","grid_area_frac",
    "area_frac","mask_area_frac","n_cc","n_comp","largest_cc_frac",
]

def _load_csv_core(csv_path: Path, core_cols: list) -> pd.DataFrame:
    """Read *csv_path* and return ``uid`` plus whichever of *core_cols* it contains.

    Rows are de-duplicated on ``uid`` (first occurrence kept), the same policy the
    primary feature merge applies. Without this, a duplicated uid in an extra cfg CSV
    would multiply the rows of ``df_out`` in the subsequent left merge on ``uid``.
    """
    df = pd.read_csv(csv_path, low_memory=False)
    df = _ensure_uid_case_variant(df)
    keep = ["uid"] + [c for c in core_cols if c in df.columns]
    return df[keep].drop_duplicates(subset="uid").reset_index(drop=True)

# Extra (non-primary) pred/match cfgs: merge their core columns into df_out with a
# per-cfg suffix "<col>__<idx>_<tag>". In "core+agg" mode, also add cross-cfg
# mean/max/min/std per core column (computed over the primary column, if present,
# plus the suffixed copies).
# NOTE(review): the suffix index comes from each cfg's position in the extra list
# (after the existence filter), so the order/availability of PATHS["PRED_CFG_DIRS"] /
# ["MATCH_CFG_DIRS"] presumably has to match what Step 2 used, or the generated column
# names will not line up with FEATURE_COLS — confirm.
if FE_CFG.get("multi_cfg_enabled", True) and ("PATHS" in globals()) and isinstance(PATHS, dict):
    pred_cfg_dirs = [Path(x) for x in PATHS.get("PRED_CFG_DIRS", [])] if PATHS.get("PRED_CFG_DIRS") else []
    match_cfg_dirs = [Path(x) for x in PATHS.get("MATCH_CFG_DIRS", [])] if PATHS.get("MATCH_CFG_DIRS") else []

    # Keep only existing dirs, capped at max(1, multi_cfg_max_*). The cap is applied
    # BEFORE the primary cfg is removed below, so there may be one fewer extra than the cap.
    if pred_cfg_dirs:
        pred_cfg_dirs = [d for d in pred_cfg_dirs if d.exists()]
        pred_cfg_dirs = pred_cfg_dirs[:max(1, int(FE_CFG.get("multi_cfg_max_pred", 6)))]
    else:
        # No list available: only the primary cfg (which yields no extras below).
        if PATHS.get("PRED_CFG_DIR"):
            pred_cfg_dirs = [Path(PATHS["PRED_CFG_DIR"])]

    if match_cfg_dirs:
        match_cfg_dirs = [d for d in match_cfg_dirs if d.exists()]
        match_cfg_dirs = match_cfg_dirs[:max(1, int(FE_CFG.get("multi_cfg_max_match", 3)))]
    else:
        if PATHS.get("MATCH_CFG_DIR"):
            match_cfg_dirs = [Path(PATHS["MATCH_CFG_DIR"])]

    primary_pred = str(Path(PATHS.get("PRED_CFG_DIR",""))) if PATHS.get("PRED_CFG_DIR") else ""
    primary_match = str(Path(PATHS.get("MATCH_CFG_DIR",""))) if PATHS.get("MATCH_CFG_DIR") else ""

    # Primary cfgs are merged elsewhere; only the others are treated as "extra".
    extra_pred_cfgs = [d for d in pred_cfg_dirs if str(d) != primary_pred]
    extra_match_cfgs = [d for d in match_cfg_dirs if str(d) != primary_match]

    # base column name -> list of suffixed copies merged into df_out
    pred_core_matrix_cols = {}
    for i, cfg_dir in enumerate(extra_pred_cfgs, 1):
        fp = cfg_dir / "pred_features_test.csv"
        if not fp.exists():
            continue  # the index i is still consumed, so later tags keep their numbers
        try:
            tag = _short_cfg_tag(cfg_dir, i)
            df_extra = _load_csv_core(fp, CORE_PRED_COLS)
            ren = {}
            for c in df_extra.columns:
                if c == "uid": continue
                ren[c] = f"{c}__{tag}"
                pred_core_matrix_cols.setdefault(c, []).append(ren[c])
            df_extra = df_extra.rename(columns=ren)
            df_out = df_out.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
            print(f"  extra pred cfg merged: {cfg_dir.name}")
        except Exception as e:
            # Best effort: a broken extra cfg is reported and skipped.
            print(f"  WARNING: skip extra pred cfg {cfg_dir.name}: {e}")

    # Row-wise aggregates across cfgs; needs at least 2 available columns per base name.
    if FE_CFG.get("multi_cfg_extra_mode","core+agg") == "core+agg":
        for base_name, cols_suff in pred_core_matrix_cols.items():
            cols_all = []
            if base_name in df_out.columns:
                cols_all.append(base_name)
            cols_all.extend([c for c in cols_suff if c in df_out.columns])
            cols_all = [c for c in cols_all if c in df_out.columns]
            if len(cols_all) < 2:
                continue
            mat = df_out[cols_all].apply(pd.to_numeric, errors="coerce")
            df_out[f"cfg_mean_{base_name}"] = mat.mean(axis=1, skipna=True)
            df_out[f"cfg_max_{base_name}"]  = mat.max(axis=1, skipna=True)
            df_out[f"cfg_min_{base_name}"]  = mat.min(axis=1, skipna=True)
            df_out[f"cfg_std_{base_name}"]  = mat.std(axis=1, skipna=True)

    # Same procedure for match cfgs; aggregate columns use the "cfg_*_match_" prefix.
    match_core_matrix_cols = {}
    if FE_CFG.get("use_match_features", True):
        for i, cfg_dir in enumerate(extra_match_cfgs, 1):
            fp = cfg_dir / "match_features_test.csv"
            if not fp.exists():
                continue
            try:
                tag = _short_cfg_tag(cfg_dir, i)
                df_extra = _load_csv_core(fp, CORE_MATCH_COLS)
                ren = {}
                for c in df_extra.columns:
                    if c == "uid": continue
                    ren[c] = f"{c}__{tag}"
                    match_core_matrix_cols.setdefault(c, []).append(ren[c])
                df_extra = df_extra.rename(columns=ren)
                df_out = df_out.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
                print(f"  extra match cfg merged: {cfg_dir.name}")
            except Exception as e:
                print(f"  WARNING: skip extra match cfg {cfg_dir.name}: {e}")

        if FE_CFG.get("multi_cfg_extra_mode","core+agg") == "core+agg":
            for base_name, cols_suff in match_core_matrix_cols.items():
                cols_all = []
                if base_name in df_out.columns:
                    cols_all.append(base_name)
                cols_all.extend([c for c in cols_suff if c in df_out.columns])
                cols_all = [c for c in cols_all if c in df_out.columns]
                if len(cols_all) < 2:
                    continue
                mat = df_out[cols_all].apply(pd.to_numeric, errors="coerce")
                df_out[f"cfg_mean_match_{base_name}"] = mat.mean(axis=1, skipna=True)
                df_out[f"cfg_max_match_{base_name}"]  = mat.max(axis=1, skipna=True)
                df_out[f"cfg_min_match_{base_name}"]  = mat.min(axis=1, skipna=True)
                df_out[f"cfg_std_match_{base_name}"]  = mat.std(axis=1, skipna=True)

# ----------------------------
# 6) Optional: NPZ pair features for TEST (pred+match) + overlap (cache)
# ----------------------------
def _dir_has_any_npz(d: Path) -> bool:
    try:
        if d is None or (not d.exists()) or (not d.is_dir()):
            return False
        for _ in d.glob("*.npz"):
            return True
        return False
    except Exception:
        return False

def _uid_to_npz_path(npz_dir: Path, uid: str) -> Path:
    uid = str(uid)
    if uid.endswith(".npz"):
        p = npz_dir / uid
        return p
    cands = [
        npz_dir / f"{uid}.npz",
        npz_dir / f"{uid.replace('__','_')}.npz",
        npz_dir / f"{uid.replace('_','__')}.npz",
    ]
    for p in cands:
        if p.exists():
            return p
    return cands[0]

try:
    import scipy.ndimage as ndi
    _HAS_NDI = True
except Exception:
    _HAS_NDI = False
try:
    import cv2
    _HAS_CV2 = True
except Exception:
    _HAS_CV2 = False

def _maybe_downsample_mask(mask: np.ndarray, ds: int):
    if ds is None or ds <= 0:
        return mask
    H, W = mask.shape[:2]
    if H <= ds and W <= ds:
        return mask
    if _HAS_CV2:
        return cv2.resize(mask.astype(np.uint8), (ds, ds), interpolation=cv2.INTER_NEAREST).astype(bool)
    ys = (np.linspace(0, H-1, ds)).astype(int)
    xs = (np.linspace(0, W-1, ds)).astype(int)
    return mask[np.ix_(ys, xs)]

def _count_cc_and_largest(mask_bool: np.ndarray, ds: int):
    """Count connected components of a binary mask, after optional downsampling.

    The mask first goes through ``_maybe_downsample_mask(mask_bool, ds)``. Both returned
    numbers are therefore measured on the downsampled grid whenever downsampling applies.

    Returns:
        ``(n_components, largest_component_pixels)``. Returns ``(0, 0)`` for an empty
        mask, and ``(-1, -1)`` when neither scipy.ndimage nor OpenCV is available
        (the caller maps negatives to NaN).

    NOTE(review): the backends use different connectivity. ``ndi.label`` runs with its
    default structuring element (4-connectivity in 2-D), while OpenCV is asked for
    ``connectivity=8``, so ``n_components`` depends on which library is installed.
    Not changed here, because test features must match what the training step computed.
    Change both sides together.
    """
    m = _maybe_downsample_mask(mask_bool, ds)
    m = (m > 0).astype(np.uint8)
    if m.sum() == 0:
        return 0, 0
    if _HAS_NDI:
        lab, n = ndi.label(m)  # default structure -> 4-connectivity
        if n <= 0:
            return 0, 0
        sizes = np.bincount(lab.ravel())
        sizes[0] = 0  # label 0 is background; exclude it from the max
        largest = int(sizes.max()) if sizes.size else 0
        return int(n), int(largest)
    if _HAS_CV2:
        n, lab = cv2.connectedComponents(m, connectivity=8)
        n_cc = int(max(n - 1, 0))  # OpenCV's label count includes the background label 0
        if n_cc <= 0:
            return 0, 0
        sizes = np.bincount(lab.ravel())
        sizes[0] = 0
        largest = int(sizes.max()) if sizes.size else 0
        return int(n_cc), int(largest)
    return -1, -1  # no connected-component backend available

def _extract_mask_prob_from_npz(npz_path: Path, thr: float):
    try:
        z = np.load(npz_path, allow_pickle=True)
        keys = list(z.files)
        arrs = {k: z[k] for k in keys}
    except Exception:
        return None, None

    def _pick_first(keys_like):
        for k in keys_like:
            if k in arrs:
                return k
        return None

    k_prob = _pick_first(["prob","probs","pred_prob","mask_prob","p","pred"])
    k_mask = _pick_first(["mask","masks","bin_mask","pred_mask","mask_bin"])

    prob = None
    mask = None
    if k_prob is not None:
        a = arrs[k_prob]
        if isinstance(a, np.ndarray) and a.ndim >= 2:
            prob = a.astype(np.float32)
    if k_mask is not None:
        a = arrs[k_mask]
        if isinstance(a, np.ndarray):
            if a.ndim == 3:
                a = np.max(a, axis=0)
            mask = (a > 0).astype(bool)

    if mask is None and prob is None:
        best_k, best_n = None, -1
        for k, a in arrs.items():
            if not isinstance(a, np.ndarray): continue
            if a.ndim < 2: continue
            n = int(np.prod(a.shape))
            if n > best_n:
                best_k, best_n = k, n
        if best_k is not None:
            a = arrs[best_k]
            a2 = np.max(a, axis=0) if a.ndim == 3 else a
            if np.issubdtype(a2.dtype, np.floating):
                prob = a2.astype(np.float32)
            else:
                mask = (a2 > 0).astype(bool)

    if mask is None and prob is not None:
        m = prob
        if m.ndim == 3:
            m = np.max(m, axis=0)
        mask = (m >= float(thr))

    if prob is not None and prob.ndim == 3:
        prob = np.max(prob, axis=0)
    if mask is not None and mask.ndim == 3:
        mask = np.max(mask, axis=0).astype(bool)

    return mask, prob

def _mask_stats(mask: np.ndarray, prob: np.ndarray, ds_cc: int):
    """Summarise a binary mask (and an optional probability map) as scalar features.

    Returns a dict with keys npz_area_frac, npz_n_cc, npz_largest_cc_frac,
    npz_mean_prob_inside, npz_p90_prob_inside, npz_max_prob.
      * All values are NaN when *mask* is None.
      * npz_n_cc / npz_largest_cc_frac are NaN when no CC backend is available;
        npz_largest_cc_frac is also NaN for an empty mask.
      * The prob-based values are NaN when *prob* is None or its first two dims differ
        from the mask's. They are 0.0 when prob is usable but the mask has no positive pixel.

    NOTE(review): ``largest`` comes from ``_count_cc_and_largest`` on the mask resampled
    to ``ds_cc x ds_cc``, while ``area`` is the full-resolution pixel count. So
    npz_largest_cc_frac mixes two scales whenever resampling happens. Not changed here,
    because test features must stay identical to what the training step computed.
    Fix both sides together.
    """
    out = {}
    if mask is None:
        out.update({
            "npz_area_frac": np.nan,
            "npz_n_cc": np.nan,
            "npz_largest_cc_frac": np.nan,
            "npz_mean_prob_inside": np.nan,
            "npz_p90_prob_inside": np.nan,
            "npz_max_prob": np.nan,
        })
        return out

    m = (mask > 0)
    H, W = m.shape[:2]
    tot = float(H * W) if H and W else 0.0
    area = float(m.sum())
    out["npz_area_frac"] = (area / tot) if tot > 0 else 0.0

    n_cc, largest = _count_cc_and_largest(m, ds_cc)
    out["npz_n_cc"] = float(n_cc) if n_cc >= 0 else np.nan  # -1 means "no CC backend"
    out["npz_largest_cc_frac"] = (float(largest) / area) if area > 0 and largest >= 0 else np.nan

    # Probability statistics restricted to pixels inside the mask.
    if prob is not None and prob.shape[:2] == m.shape[:2]:
        p = prob.astype(np.float32)
        inside = p[m]
        if inside.size > 0:
            out["npz_mean_prob_inside"] = float(np.mean(inside))
            out["npz_p90_prob_inside"] = float(np.quantile(inside, 0.90))
            out["npz_max_prob"] = float(np.max(inside))
        else:
            out["npz_mean_prob_inside"] = 0.0
            out["npz_p90_prob_inside"] = 0.0
            out["npz_max_prob"] = 0.0
    else:
        out["npz_mean_prob_inside"] = np.nan
        out["npz_p90_prob_inside"] = np.nan
        out["npz_max_prob"] = np.nan
    return out

def _build_npz_pair_features(uids: list, pred_dir: Path, match_dir: Path, cache_path: Path, thr: float, ds_cc: int):
    """Build one feature row per uid from the PRED and MATCH npz outputs, with a parquet cache.

    For every uid the PRED and MATCH npz files (located with ``_uid_to_npz_path``) are
    reduced to a (mask, prob) pair by ``_extract_mask_prob_from_npz(path, thr)`` and
    summarized by ``_mask_stats`` (columns prefixed ``pred_`` / ``match_``).

    When both masks exist AND have the same shape, overlap features are added:
      pm_inter_frac, pm_union_frac (as fractions of all pixels), pm_iou,
      pm_pred_minus_match_frac, pm_match_minus_pred_frac.
    Otherwise the pm_* columns are NaN. A shape mismatch used to crash the whole
    stage inside np.logical_and; it now prints a warning once instead.

    Args:
        uids: uids to featurize (one output row each, in this order).
        pred_dir / match_dir: npz directories, or None to skip that side.
        cache_path: parquet cache. If it exists and is readable it is returned as-is
            (after ``_ensure_uid_case_variant``). A fresh table is written best-effort.
        thr: binarization threshold forwarded to the npz extractor.
        ds_cc: downsample size forwarded to the connected-component counter.

    Returns:
        pd.DataFrame with a "uid" column plus the feature columns.
    """
    if cache_path.exists():
        try:
            dfc = pd.read_parquet(cache_path)
            dfc = _ensure_uid_case_variant(dfc)
            return dfc
        except Exception as e:
            # An unreadable or corrupt cache is not fatal: rebuild it below.
            print(f"  [WARN] NPZ feature cache unreadable, rebuilding: {cache_path} ({e!r})")

    pm_keys = (
        "pm_inter_frac",
        "pm_union_frac",
        "pm_iou",
        "pm_pred_minus_match_frac",
        "pm_match_minus_pred_frac",
    )
    warned_shape = False
    rows = []
    n_total = len(uids)
    for i, uid in enumerate(uids, 1):
        r = {"uid": str(uid)}

        pred_p = _uid_to_npz_path(pred_dir, uid) if pred_dir is not None else None
        match_p = _uid_to_npz_path(match_dir, uid) if match_dir is not None else None

        mp = pp = None
        if pred_p is not None and pred_p.exists():
            mp, pp = _extract_mask_prob_from_npz(pred_p, thr)
        for k, v in _mask_stats(mp, pp, ds_cc).items():
            r[f"pred_{k}"] = v

        mm = pm = None
        if match_p is not None and match_p.exists():
            mm, pm = _extract_mask_prob_from_npz(match_p, thr)
        for k, v in _mask_stats(mm, pm, ds_cc).items():
            r[f"match_{k}"] = v

        both = (mp is not None) and (mm is not None)
        if both and np.shape(mp) == np.shape(mm):
            a = mp > 0
            b = mm > 0
            inter = float(np.logical_and(a, b).sum())
            union = float(np.logical_or(a, b).sum())
            tot = float(a.size) if a.size else 0.0
            r["pm_inter_frac"] = (inter / tot) if tot > 0 else 0.0
            r["pm_union_frac"] = (union / tot) if tot > 0 else 0.0
            r["pm_iou"] = (inter / union) if union > 0 else 0.0
            r["pm_pred_minus_match_frac"] = ((float(a.sum()) - inter) / tot) if tot > 0 else 0.0
            r["pm_match_minus_pred_frac"] = ((float(b.sum()) - inter) / tot) if tot > 0 else 0.0
        else:
            if both and not warned_shape:
                # Different resolutions cannot be compared pixel-wise without resampling.
                print(
                    f"  [WARN] PRED/MATCH mask shapes differ (e.g. uid={uid}: "
                    f"{np.shape(mp)} vs {np.shape(mm)}); pm_* features set to NaN."
                )
                warned_shape = True
            r.update(dict.fromkeys(pm_keys, np.nan))

        rows.append(r)

        if (i % 500) == 0:
            print(f"  NPZ test progress: {i}/{n_total}")

    dfc = pd.DataFrame(rows)
    try:
        cache_path.parent.mkdir(parents=True, exist_ok=True)
        dfc.to_parquet(cache_path, index=False)
    except Exception as e:
        # Caching is only an optimization (e.g. no parquet engine installed): keep going.
        print(f"  [WARN] could not write NPZ feature cache {cache_path}: {e!r}")
    return dfc

# Locate the NPZ directories of the primary PRED / MATCH configs.
# PATHS may be absent (Stage 0 not run in this session) or may hold ""/None for
# entries that were not found. Those are treated as missing, because Path("")
# silently resolves to the current working directory. The fallback order is the
# global of the same name, then "<cfg>/test".
_PATHS_SAFE = PATHS if isinstance(globals().get("PATHS"), dict) else {}
PRED_CFG_DIR = Path(_PATHS_SAFE.get("PRED_CFG_DIR") or globals().get("PRED_CFG_DIR") or "")
MATCH_CFG_DIR = Path(_PATHS_SAFE.get("MATCH_CFG_DIR") or globals().get("MATCH_CFG_DIR") or "")
PRED_NPZ_TEST_DIR = Path(_PATHS_SAFE.get("PRED_NPZ_TEST_DIR") or (PRED_CFG_DIR / "test"))
MATCH_NPZ_TEST_DIR = Path(_PATHS_SAFE.get("MATCH_NPZ_TEST_DIR") or (MATCH_CFG_DIR / "test"))

npz_ready = (
    bool(FE_CFG.get("use_npz_pair_features", True))
    and _dir_has_any_npz(PRED_NPZ_TEST_DIR)
    and _dir_has_any_npz(MATCH_NPZ_TEST_DIR)
)
if npz_ready:
    thr = float(FE_CFG.get("npz_bin_thr", 0.5))
    ds_cc = int(FE_CFG.get("npz_downsample_for_cc", 256) or 0)
    uids = df_out["uid"].astype(str).tolist()

    # The cache key covers every input that changes the features. That means both
    # cfg names, the binarization threshold, the CC downsample size and the uid set,
    # so a cache built for another configuration or test set is never reused.
    uid_digest = hashlib.md5("\n".join(uids).encode("utf-8")).hexdigest()[:8]
    key = (
        f"npz_pair_test_{_slug(PRED_CFG_DIR.name)}__{_slug(MATCH_CFG_DIR.name)}"
        f"__thr{thr}__ds{ds_cc}__u{uid_digest}"
    )
    h = hashlib.md5(key.encode()).hexdigest()[:12]
    cache_path = OUT_DIR / f"npz_pair_features_test_{h}.parquet"

    print("\nBuilding NPZ pair features (TEST) ...")
    df_npz = _build_npz_pair_features(uids, PRED_NPZ_TEST_DIR, MATCH_NPZ_TEST_DIR, cache_path, thr, ds_cc)
    df_npz = _ensure_uid_case_variant(df_npz)
    # One row per uid, so the left merge cannot duplicate test rows.
    df_npz = df_npz.drop_duplicates(subset="uid", keep="first")
    # Columns that df_out already has (e.g. id columns) would otherwise come back
    # as *_x / *_y. Merge them under a temporary suffix instead: existing values
    # win, and the NPZ side only fills their gaps.
    overlap = [c for c in df_npz.columns if c != "uid" and c in df_out.columns]
    df_out = df_out.merge(df_npz, on="uid", how="left", suffixes=("", "__npz"))
    for c in overlap:
        df_out[c] = df_out[c].where(df_out[c].notna(), df_out[f"{c}__npz"])
    df_out = df_out.drop(columns=[f"{c}__npz" for c in overlap])
else:
    print("\nNOTE: NPZ pair features TEST skipped (missing/empty dirs).")
    print("  PRED_NPZ_TEST_DIR :", PRED_NPZ_TEST_DIR, "(npz found)" if _dir_has_any_npz(PRED_NPZ_TEST_DIR) else "(empty/missing)")
    print("  MATCH_NPZ_TEST_DIR:", MATCH_NPZ_TEST_DIR, "(npz found)" if _dir_has_any_npz(MATCH_NPZ_TEST_DIR) else "(empty/missing)")

# ----------------------------
# 7) Canonical aliases (so the interaction features below always find stable column names)
# ----------------------------
# Each canonical column is filled from its candidate columns, which are listed in
# priority order. The exact merge rule is defined by _coalesce_numeric.
_coalesce_numeric(df_out, "area_frac", ["area_frac", "pred_area_frac", "mask_area_frac", "pred_npz_area_frac"])
_coalesce_numeric(df_out, "grid_area_frac", ["grid_area_frac", "grid_area"])
_coalesce_numeric(df_out, "best_count", ["best_count", "pair_count", "n_pairs", "pairs"])
_coalesce_numeric(df_out, "best_mean_sim", ["best_mean_sim", "mean_sim", "sim_mean"])
_coalesce_numeric(df_out, "peak_ratio", ["peak_ratio", "peak_to_mean", "peak_over_mean"])
_coalesce_numeric(df_out, "inlier_ratio", ["inlier_ratio", "ransac_inlier_ratio", "inliers_ratio"])
_coalesce_numeric(df_out, "has_peak", ["has_peak", "peak_found", "has_mode"])

# NPZ-derived statistics of the PRED masks (columns produced by _build_npz_pair_features)
_coalesce_numeric(df_out, "n_cc_pred", ["n_cc_pred", "pred_npz_n_cc"])
_coalesce_numeric(df_out, "largest_cc_frac_pred", ["largest_cc_frac_pred", "pred_npz_largest_cc_frac"])
_coalesce_numeric(df_out, "mean_prob_inside_pred", ["mean_prob_inside_pred", "pred_npz_mean_prob_inside"])
_coalesce_numeric(df_out, "p90_prob_inside_pred", ["p90_prob_inside_pred", "pred_npz_p90_prob_inside"])
_coalesce_numeric(df_out, "max_prob_pred", ["max_prob_pred", "pred_npz_max_prob"])

# ----------------------------
# 8) Recompute engineered columns to match FEATURE_COLS
#    (hanya buat kolom-kolom yang dibutuhkan)
# ----------------------------
def _num(s):
    """Coerce *s* to numbers; entries that cannot be parsed become NaN."""
    converted = pd.to_numeric(s, errors="coerce")
    return converted

def safe_log1p_nonneg(x):
    """Return ``log1p`` of *x* in float64, with non-finite entries mapped to 0 and negatives clamped to 0."""
    arr = np.asarray(x, dtype=np.float64)
    finite = np.isfinite(arr)
    nonneg = np.clip(np.where(finite, arr, 0.0), 0.0, None)
    return np.log1p(nonneg)

def safe_sqrt_nonneg(x):
    """Return ``sqrt`` of *x* in float64, with non-finite entries mapped to 0 and negatives clamped to 0."""
    arr = np.asarray(x, dtype=np.float64)
    finite = np.isfinite(arr)
    nonneg = np.clip(np.where(finite, arr, 0.0), 0.0, None)
    return np.sqrt(nonneg)

# Replace +/-inf with NaN in every numeric column, so that the missing-value
# indicators (isna) and the numeric fillna further below treat them as missing.
for c in df_out.columns:
    if pd.api.types.is_numeric_dtype(df_out[c]):
        df_out[c] = df_out[c].replace([np.inf,-np.inf], np.nan)

# Missing-value indicators: build the isna_<col> columns listed in the training
# schema (missing_indicator_cols), so train and test share the same columns.
if missing_indicator_cols:
    for ind in missing_indicator_cols:
        # expected format: isna_<col>. Names without that prefix are not built here;
        # the finalize step fills any FEATURE_COLS still missing with 0.
        if ind in df_out.columns:
            continue
        if ind.startswith("isna_"):
            base = ind[5:]
            if base in df_out.columns:
                if pd.api.types.is_numeric_dtype(df_out[base]):
                    df_out[ind] = df_out[base].isna().astype(np.uint8)
                else:
                    df_out[ind] = 0
            else:
                df_out[ind] = 1  # base column is absent entirely -> every row counts as missing
else:
    # Fallback (no schema list): create isna_* columns only for names FEATURE_COLS requests.
    # NOTE(review): a non-numeric base gets 1 here but 0 in the schema branch above.
    for c in FEATURE_COLS:
        if c.startswith("isna_") and c not in df_out.columns:
            base = c[5:]
            if base in df_out.columns and pd.api.types.is_numeric_dtype(df_out[base]):
                df_out[c] = df_out[base].isna().astype(np.uint8)
            else:
                df_out[c] = 1

# Clipped / capped features: for every base column in the schema's clip_caps, build
# <c>_cap (value clipped to [-cap, cap]), logabs_<c> = log1p(|clipped|) and
# sqrtabs_<c> = sqrt(|clipped|), but only the ones FEATURE_COLS actually requests.
for c, cap in (clip_caps or {}).items():
    try:
        cap = float(cap)
    except Exception:
        continue  # non-numeric cap in the schema: skip this column

    need_cap = (f"{c}_cap" in FEATURE_COLS)
    need_log = (f"logabs_{c}" in FEATURE_COLS)
    need_sqrt = (f"sqrtabs_{c}" in FEATURE_COLS)

    if (not need_cap) and (not need_log) and (not need_sqrt):
        continue
    if c not in df_out.columns:
        # base column absent: fill the requested derived columns with 0
        if need_cap:  df_out[f"{c}_cap"] = 0.0
        if need_log:  df_out[f"logabs_{c}"] = 0.0
        if need_sqrt: df_out[f"sqrtabs_{c}"] = 0.0
        continue

    # NaN -> 0 before clipping. NOTE(review): a negative cap would invert the clip bounds.
    x = _num(df_out[c]).fillna(0.0).astype(float).values
    x = np.clip(x, -cap, cap)
    if need_cap:
        df_out[f"{c}_cap"] = x.astype(np.float32)
    if need_log:
        df_out[f"logabs_{c}"] = safe_log1p_nonneg(np.abs(x)).astype(np.float32)
    if need_sqrt:
        df_out[f"sqrtabs_{c}"] = safe_sqrt_nonneg(np.abs(x)).astype(np.float32)

# interactions: hitung hanya kalau kolomnya diminta FEATURE_COLS
def _getf(col, default=0.0):
    """Return column *col* of df_out as a float64 ndarray.

    Unparseable values and NaN become *default*. If the column does not exist, the
    result is an array of *default* with one entry per row of df_out.
    """
    if col not in df_out.columns:
        return np.full(len(df_out), default, dtype=np.float64)
    values = _num(df_out[col]).fillna(default)
    return values.astype(float).values

# Raw inputs of the interaction features below, as float64 arrays
# (a missing column or NaN becomes 0.0 via _getf).
best_mean_sim = _getf("best_mean_sim", 0.0)
best_count    = _getf("best_count", 0.0)
peak_ratio    = _getf("peak_ratio", 0.0)
has_peak      = _getf("has_peak", 0.0)
grid_area     = _getf("grid_area_frac", 0.0)
area_frac     = _getf("area_frac", 0.0)
n_pairs_thr   = _getf("n_pairs_thr", 0.0)
n_pairs_mnn   = _getf("n_pairs_mnn", 0.0)
inlier_ratio  = _getf("inlier_ratio", 0.0)
gh = _getf("grid_h", 0.0)
gw = _getf("grid_w", 0.0)
gridN = np.clip(gh * gw, 0.0, None)  # number of grid cells, never negative

def _set_if_needed(name, arr):
    """Store *arr* as float32 column *name* of df_out, only if FEATURE_COLS wants it and it is not there yet."""
    if name not in FEATURE_COLS:
        return
    if name in df_out.columns:
        return
    df_out[name] = arr.astype(np.float32)

# Engineered interactions. Each is added only if FEATURE_COLS lists it and df_out
# does not already have it. The formulas (including the 1e-6 / 1.0 guards against
# division by zero) presumably mirror the training-time feature code; keep them identical.
_set_if_needed("sim_x_count",      best_mean_sim * best_count)
_set_if_needed("peak_x_sim",       peak_ratio * best_mean_sim)
_set_if_needed("haspeak_x_sim",    has_peak * best_mean_sim)
_set_if_needed("area_x_sim",       grid_area * best_mean_sim)
_set_if_needed("area_x_count",     grid_area * best_count)
_set_if_needed("mask_grid_ratio",  area_frac / (1e-6 + grid_area))
_set_if_needed("mnn_ratio",        n_pairs_mnn / (1.0 + n_pairs_thr))
_set_if_needed("pairs_per_cell",   n_pairs_thr / (1.0 + gridN))
_set_if_needed("inlier_x_pairs",   inlier_ratio * n_pairs_thr)
_set_if_needed("log1p_pairs_thr",  safe_log1p_nonneg(n_pairs_thr))
_set_if_needed("log1p_best_count", safe_log1p_nonneg(best_count))
_set_if_needed("log1p_area_frac",  safe_log1p_nonneg(np.clip(area_frac, 0, None)))

# segmentation/gate stability extras (only if requested)
ncc   = _getf("n_cc_pred", 0.0)
lccf  = _getf("largest_cc_frac_pred", 0.0)
mpin  = _getf("mean_prob_inside_pred", 0.0)
p90in = _getf("p90_prob_inside_pred", 0.0)
mxp   = _getf("max_prob_pred", 0.0)

_set_if_needed("prob_mean_x_area", mpin * area_frac)
_set_if_needed("prob_p90_x_area",  p90in * area_frac)
_set_if_needed("prob_max_x_area",  mxp * area_frac)
_set_if_needed("comp_x_area",      ncc * area_frac)
_set_if_needed("largestcc_x_area", lccf * area_frac)
_set_if_needed("prob_contrast",    p90in - mpin)

# PRED-vs-MATCH overlap extras (pm_* columns from the NPZ pair features)
iou   = _getf("pm_iou", 0.0)
inter = _getf("pm_inter_frac", 0.0)
union = _getf("pm_union_frac", 0.0)
_set_if_needed("iou_x_area",       iou * area_frac)
_set_if_needed("inter_over_union", inter / (1e-6 + union))

# Variant one-hot: rebuild exactly the v_<category> columns the training schema used.
if variant_dummy_cols:
    cats = [c[len("v_"):] for c in variant_dummy_cols if c.startswith("v_")]
    cats_set = set(cats)
    # Fill missing variants BEFORE converting to str: astype(str) turns NaN into the
    # string "nan", so a fillna afterwards never fires. Reuse "nan" as the token when
    # the training schema contains it, which keeps parity with schemas built that way;
    # otherwise use "unk". A missing "variant" column counts as all-missing.
    missing_token = "nan" if "nan" in cats_set else "unk"
    if "variant" in df_out.columns:
        v_raw = df_out["variant"].astype(object)
        v = v_raw.where(v_raw.notna(), missing_token).astype(str)
    else:
        v = pd.Series(missing_token, index=df_out.index, dtype=object).astype(str)
    # Unseen categories map to "rare" when the schema has that bucket. Otherwise they
    # keep their own label, and all of their dummy columns end up 0.
    v2 = v.where(v.isin(cats_set), other="rare") if "rare" in cats_set else v
    for col in variant_dummy_cols:
        if col in df_out.columns or not col.startswith("v_"):
            continue
        df_out[col] = (v2 == col[len("v_"):]).astype(np.uint8)

# Fill NaN in every numeric column with FE_CFG["fillna_value"] (default 0.0).
num_cols = [c for c in df_out.columns if pd.api.types.is_numeric_dtype(df_out[c])]
df_out[num_cols] = df_out[num_cols].fillna(float(FE_CFG.get("fillna_value", 0.0)))

# ----------------------------
# 9) Finalize: ensure all FEATURE_COLS exist + numeric float32
# ----------------------------
# Any feature the schema expects but nothing above produced is added as 0.0.
still_missing = [c for c in FEATURE_COLS if c not in df_out.columns]
if still_missing:
    for c in still_missing:
        df_out[c] = 0.0

# Non-numeric leftovers become NaN -> 0.0; every feature is stored as float32.
for c in FEATURE_COLS:
    df_out[c] = pd.to_numeric(df_out[c], errors="coerce").fillna(0.0).astype(np.float32)

# Final column order: id columns first, then FEATURE_COLS in schema order.
df_final = df_out[id_cols + FEATURE_COLS].copy()
print("\nFinal TEST feature table:", df_final.shape)

# ----------------------------
# 10) Save pred_features_test*
# ----------------------------
# cfg_hash fingerprints the FEATURE_COLS list, so the cfg-specific file name changes
# whenever the feature schema changes. The generic pred_features_test.csv is overwritten.
cfg_hash = hashlib.sha1(json.dumps(FEATURE_COLS, ensure_ascii=False).encode("utf-8")).hexdigest()[:12]
p_main = OUT_DIR / "pred_features_test.csv"
p_cfg  = OUT_DIR / f"pred_features_test_cfg_{cfg_hash}.csv"

df_final.to_csv(p_main, index=False)
df_final.to_csv(p_cfg, index=False)

print("\nSaved:")
print("  ->", p_main)
print("  ->", p_cfg)

# Export globals for the next steps
df_test_tabular = df_final
PRED_FEATURES_TEST_CSV = str(p_main)
PRED_FEATURES_TEST_CFG_CSV = str(p_cfg)
PRED_FEATURES_TEST_CFG_HASH = cfg_hash


# Train Baseline Model (Leakage-Safe CV)

In [None]:
# ============================================================
# Step 3 — Train Mask Model (UNet + ASPP) on DINOv2 Token-Grid Embeddings (Decoder-only)
# REVISI FULL v3.1-mask (Leakage-Safe CV by fold/case_id, AMP+EMA+accum+earlystop)
#
# Target: UNet+ASPP untuk prediksi mask, decoder di atas token-grid embedding DINOv2.
# - Input utama: token-grid (Htok, Wtok, D) per uid (DINOv2 patch tokens/grid).
# - GT: mask full-res (opsional) -> downsample ke (Htok, Wtok) untuk training.
#
# REQUIRE (minimal):
# - df_train_tabular: kolom wajib: uid, case_id, fold, y  (variant opsional)
#   * y=0 -> mask dianggap kosong
# - TOKEN CACHE: file embedding per uid (.npz/.npy) yang berisi array token grid
#   * Auto-search beberapa direktori umum. Kalau tidak ketemu -> fallback zeros (warn sekali).
# - TRAIN MASK DIR (opsional): kalau mask file ada (nama sama dengan uid.*)
#
# OUTPUT:
# - /kaggle/working/recodai_luc_mask_artifacts/mask_folds_seed_<seed>/mask_fold_<fold>.pt
# - /kaggle/working/recodai_luc_mask_artifacts/oof_mask_metrics.csv
# - /kaggle/working/recodai_luc_mask_artifacts/final_mask_model.pt   (bundle: fold_packs + full_packs + recommended_thr)
#
# Export globals:
# - FINAL_MASK_MODEL_PT (str)
# - MASK_OOF_REPORT (dict)
# ============================================================

import os, gc, json, math, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import roc_auc_score

# ----------------------------
# 0) Require df_train_tabular
# ----------------------------
# Fail fast if the tabular training table from the previous stage is not in memory.
need_vars = ["df_train_tabular"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Butuh df_train_tabular (uid/case_id/fold/y).")

# Work on a copy, so the dtype coercions below do not mutate a DataFrame object
# that other cells may still reference.
df_train_tabular = df_train_tabular.copy()

# NOTE(review): required_cols is a set, so the order of names in the error message is arbitrary.
required_cols = {"uid", "case_id", "fold", "y"}
missing_cols = [c for c in required_cols if c not in df_train_tabular.columns]
if missing_cols:
    raise ValueError(f"df_train_tabular missing columns: {missing_cols}")

# Normalize id/label dtypes (astype(int) raises if fold or y contains NaN).
df_train_tabular["uid"] = df_train_tabular["uid"].astype(str)
df_train_tabular["case_id"] = df_train_tabular["case_id"].astype(str)
df_train_tabular["fold"] = df_train_tabular["fold"].astype(int)
df_train_tabular["y"] = df_train_tabular["y"].astype(int)

# Row-aligned arrays used by the rest of this cell; y is the image-level forged flag.
uids = df_train_tabular["uid"].to_numpy()
y_img = df_train_tabular["y"].to_numpy(np.int64)
folds = df_train_tabular["fold"].to_numpy(np.int64)
unique_folds = sorted(df_train_tabular["fold"].unique().tolist())

print("Step3-mask setup:")
print("  rows :", len(df_train_tabular))
print("  folds:", len(unique_folds), unique_folds)
print("  forged%:", float(y_img.mean())*100.0)

# ----------------------------
# 1) Device + CFG (AUTO)
# ----------------------------
SEED0 = 2025


def seed_everything(seed=2025):
    """Seed Python's `random`, NumPy and PyTorch (CPU and all CUDA devices).

    cuDNN is intentionally left non-deterministic with autotuning enabled
    (benchmark=True): faster convolutions at the cost of bitwise reproducibility.
    Setting PYTHONHASHSEED here only affects child processes started afterwards,
    not the hash seed of the running interpreter.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"  # AMP is enabled only when training on CUDA


def get_mem_gb():
    """Return the total memory of CUDA device 0 in GiB; 0.0 on CPU or if the query fails."""
    if device.type != "cuda":
        return 0.0
    try:
        total_bytes = torch.cuda.get_device_properties(0).total_memory
        return float(total_bytes / (1024 ** 3))
    except Exception:
        return 0.0


MEM_GB = get_mem_gb()
seed_everything(SEED0)

print("Device:", device, "| AMP:", use_amp, "| VRAM(GB):", MEM_GB)

# Decoder-only training: the token dim D is not configured here; it is inferred
# from the token cache further below (TOKEN_DIM).
# Three presets: CPU, GPU (VRAM-dependent sizes) and STRONG (CUDA with >= 30 GB VRAM).
CFG_CPU = dict(
    seed=SEED0,
    epochs=35,
    batch_size=16,
    accum_steps=1,
    lr=3e-4,
    weight_decay=1e-2,
    warmup_frac=0.05,
    grad_clip=1.0,
    early_patience=8,
    early_min_delta=1e-4,
    use_ema=True,
    ema_decay=0.999,

    # model
    base_ch=128,
    drop=0.10,
    aspp_rates=(1, 2, 4, 6),

    # loss
    bce_weight=1.0,
    dice_weight=1.0,
    focal_gamma=0.0,           # 0=off

    # aug on token-grid (flips only; probabilities per sample)
    aug_hflip=0.5,
    aug_vflip=0.2,

    # threshold search (dice on OOF); presumably the number of candidate thresholds — confirm in the training loop
    thr_grid=81,
)

CFG_GPU = dict(CFG_CPU)
CFG_GPU.update(
    epochs=50,
    batch_size=32 if MEM_GB >= 16 else 24,
    accum_steps=1 if MEM_GB >= 16 else 2,
    lr=3e-4,
    base_ch=160 if MEM_GB >= 16 else 128,
    drop=0.10,
)

CFG_STRONG = dict(CFG_GPU)
CFG_STRONG.update(
    epochs=65,
    batch_size=48 if MEM_GB >= 30 else CFG_GPU["batch_size"],
    base_ch=192 if MEM_GB >= 30 else CFG_GPU["base_ch"],
    lr=2.5e-4,
)

# Pick the preset: STRONG overrides GPU when CUDA has >= 30 GB of VRAM.
CFG = dict(CFG_GPU if device.type == "cuda" else CFG_CPU)
CFG_NAME = "GPU" if device.type == "cuda" else "CPU"
if device.type == "cuda" and MEM_GB >= 30:
    CFG = dict(CFG_STRONG)
    CFG_NAME = "STRONG"

# Optional environment overrides (non-empty values only; an invalid number raises ValueError)
if os.environ.get("MASK_EPOCHS","").strip():
    CFG["epochs"] = int(os.environ["MASK_EPOCHS"])
if os.environ.get("MASK_BS","").strip():
    CFG["batch_size"] = int(os.environ["MASK_BS"])
if os.environ.get("MASK_LR","").strip():
    CFG["lr"] = float(os.environ["MASK_LR"])
if os.environ.get("MASK_ACCUM","").strip():
    CFG["accum_steps"] = int(os.environ["MASK_ACCUM"])

print("CFG:", CFG_NAME, json.dumps({k: CFG[k] for k in ["epochs","batch_size","accum_steps","lr","base_ch"]}, indent=2))

# ----------------------------
# 2) Directories: token cache + mask dir (auto-detect)
# ----------------------------
def _first_existing(paths):
    """Return the first entry of *paths* that exists on disk, as a Path; None if none do.

    None entries and blank strings are skipped. Path("") is Path(".") (the current
    working directory), which always exists, so an empty PATHS entry would otherwise
    silently select the CWD.
    """
    for p in paths:
        if p is None:
            continue
        if isinstance(p, str) and not p.strip():
            continue
        cand = Path(p)
        if cand.exists():
            return cand
    return None

# Mask directory: explicit PATHS entries first (TRAIN_MASK_DIR, MASK_DIR,
# TRAIN_MASKS), then common Kaggle layouts; the first path that exists wins.
MASK_DIR = None
if "PATHS" in globals() and isinstance(PATHS, dict):
    MASK_DIR = _first_existing([PATHS.get("TRAIN_MASK_DIR"), PATHS.get("MASK_DIR"), PATHS.get("TRAIN_MASKS")])

if MASK_DIR is None:
    # common competition layouts
    MASK_DIR = _first_existing([
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks_train",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_mask",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks",
        "/kaggle/working/recodai_luc/train_masks",
    ])

print("MASK_DIR:", str(MASK_DIR) if MASK_DIR else "(None / will assume empty mask for y=0 and missing files)")

# Token cache directory. Resolution order:
#   1. an existing CACHE_ROOT global (ignored when None / blank / ".": Path("") is the CWD);
#   2. the newest cfg_* sub-directory of the FIRST known root that has any
#      (roots are checked in list order), so .../cache/dino_v2/cfg_* is found even
#      when .../cache itself exists without cfg_* children;
#   3. otherwise the first known root that exists.
TOKEN_CACHE_DIR = None
if "CACHE_ROOT" in globals():
    try:
        if CACHE_ROOT is not None and str(CACHE_ROOT).strip() not in ("", "."):
            cr = Path(CACHE_ROOT)
            if cr.exists():
                TOKEN_CACHE_DIR = cr
    except Exception:
        pass  # unusable CACHE_ROOT value: fall through to auto-detection

if TOKEN_CACHE_DIR is None:
    # typical cache roots, most specific layouts listed after their parents
    base_candidates = [
        Path("/kaggle/working/recodai_luc/cache"),
        Path("/kaggle/working/recodai_luc/cache/dino_v2"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache/dino_v2"),
    ]
    existing_roots = [b for b in base_candidates if b.exists()]
    for b in existing_roots:
        # newest cfg_* directory under this root (by modification time)
        cfg_dirs = sorted([p for p in b.glob("cfg_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
        if cfg_dirs:
            TOKEN_CACHE_DIR = cfg_dirs[0]
            break
    if TOKEN_CACHE_DIR is None and existing_roots:
        # no cfg_* anywhere: use the first existing root directly
        TOKEN_CACHE_DIR = existing_roots[0]

print("TOKEN_CACHE_DIR:", str(TOKEN_CACHE_DIR) if TOKEN_CACHE_DIR else "(None)")

# ----------------------------
# 3) Token loader (robust) + mask loader
# ----------------------------
_WARN_TOKEN_MISS = False  # set after the first token-file warning (print once)
_WARN_MASK_MISS = False   # set after the first mask-file warning (print once)

def _load_np_any(p: Path):
    """Load an array from a .npy or .npz file.

    For .npz archives, the first member with ndim 2 or 3 is returned. If there is
    none, the first member is returned, or None for an empty archive. The archive is
    opened as a context manager so its file handle is closed after reading; the
    returned array is fully materialized.

    Raises:
        ValueError: for any other file suffix.
    """
    p = Path(p)
    suffix = p.suffix.lower()
    if suffix == ".npy":
        return np.load(p, allow_pickle=False)
    if suffix == ".npz":
        with np.load(p, allow_pickle=False) as z:
            keys = list(z.keys())
            for k in keys:
                a = z[k]
                if isinstance(a, np.ndarray) and a.ndim in (2, 3):
                    return a
            return z[keys[0]] if keys else None
    raise ValueError(f"Unsupported token file: {p}")

def _reshape_to_grid(a: np.ndarray):
    """Coerce a token array into an (H, W, D) grid.

    Accepted layouts:
      * 3-D (D, H, W): recognized only when the leading axis is a common ViT width
        (384/512/768/1024/1536) and H*W > 16; transposed to (H, W, D).
      * any other 3-D array is assumed to be (H, W, D) already and returned unchanged.
      * 2-D (N, D): becomes (s, s, D) when N is a perfect square. Otherwise it becomes
        (h, N // h, D) for the smallest divisor h <= sqrt(N) with h >= 4 and N // h >= 4,
        and as a last resort (1, N, D).

    Raises:
        ValueError: for arrays that are neither 2-D nor 3-D.
    """
    arr = np.asarray(a)
    if arr.ndim == 3:
        lead, rows, cols = arr.shape
        channel_first = lead in (384, 512, 768, 1024, 1536) and rows * cols > 16
        return np.transpose(arr, (1, 2, 0)) if channel_first else arr
    if arr.ndim != 2:
        raise ValueError(f"Token array has unsupported shape: {arr.shape}")

    n_tok, dim = arr.shape
    side = int(round(math.sqrt(n_tok)))
    if side * side == n_tok:
        return arr.reshape(side, side, dim)

    divisors = (
        h for h in range(1, int(math.sqrt(n_tok)) + 1)
        if n_tok % h == 0 and h >= 4 and n_tok // h >= 4
    )
    h = next(divisors, None)
    if h is not None:
        return arr.reshape(h, n_tok // h, dim)
    return arr.reshape(1, n_tok, dim)

# Memo of find_token_file results (hits AND misses), keyed by (TOKEN_CACHE_DIR, uid).
# Clear it (_TOKEN_FILE_CACHE.clear()) if token files are created later in the session.
_TOKEN_FILE_CACHE = {}

def _find_token_file_uncached(root: Path, uid: str):
    """Search *root* for uid's token file without memoization (see find_token_file)."""
    import glob as _glob
    import re as _re

    root = Path(root)
    # 1) fixed layouts, in priority order (.npz before .npy in each directory)
    for sub in ("", "tokens_train", "train", "feat_train"):
        base = root / sub if sub else root
        for ext in (".npz", ".npy"):
            p = base / f"{uid}{ext}"
            if p.exists():
                return p

    # 2) recursive fallback. The uid must appear as a whole token in the file stem
    #    (not inside a longer alphanumeric run), so uid "12" never matches "112.npz".
    #    Exact stems are preferred; otherwise paths are sorted for a deterministic pick.
    token_re = _re.compile(rf"(?<![0-9A-Za-z]){_re.escape(uid)}(?![0-9A-Za-z])")
    pattern_uid = _glob.escape(uid)
    for ext in (".npz", ".npy"):
        try:
            hits = [p for p in root.glob(f"**/*{pattern_uid}*{ext}") if token_re.search(p.stem)]
        except (OSError, ValueError):
            hits = []  # best-effort search, as before
        if hits:
            return min(hits, key=lambda p: (p.stem != uid, str(p)))
    return None

def find_token_file(uid: str):
    """Locate the cached token-grid file for *uid* under TOKEN_CACHE_DIR.

    Lookup order:
      1. <root>/{,tokens_train/,train/,feat_train/}<uid>.{npz,npy}
      2. a recursive search for .npz, then .npy, files whose stem contains uid as a
         whole token (delimited by non-alphanumerics).
    Results are memoized, because this runs for every sample in every epoch and a
    miss would otherwise repeat the recursive search each time.

    Returns:
        Path of the token file, or None (also when TOKEN_CACHE_DIR is None).
    """
    if TOKEN_CACHE_DIR is None:
        return None
    uid = str(uid)
    key = (str(TOKEN_CACHE_DIR), uid)
    if key not in _TOKEN_FILE_CACHE:
        _TOKEN_FILE_CACHE[key] = _find_token_file_uncached(TOKEN_CACHE_DIR, uid)
    return _TOKEN_FILE_CACHE[key]

def load_token_grid(uid: str):
    """Load uid's token grid as a finite float32 (H, W, D) array, or None.

    Returns None when no token file is found, or when reading/reshaping it fails;
    callers substitute zeros. NaN/inf values are replaced by 0. Both failure kinds
    share the _WARN_TOKEN_MISS flag, so only the first one is reported.
    """
    global _WARN_TOKEN_MISS
    path = find_token_file(uid)
    if path is None:
        if not _WARN_TOKEN_MISS:
            print("[WARN] Token file tidak ketemu untuk sebagian uid. Fallback -> zeros (sekali warn).")
            _WARN_TOKEN_MISS = True
        return None
    try:
        grid = _reshape_to_grid(_load_np_any(path)).astype(np.float32, copy=False)
        if np.isfinite(grid).all():
            return grid
        return np.nan_to_num(grid, nan=0.0, posinf=0.0, neginf=0.0)
    except Exception as e:
        if not _WARN_TOKEN_MISS:
            print("[WARN] Gagal baca token file, fallback zeros. err:", repr(e))
            _WARN_TOKEN_MISS = True
        return None

def find_mask_file(uid: str):
    """Return the mask file for *uid* inside MASK_DIR, or None.

    Tries <uid>.png/.jpg/.jpeg/.tif/.tiff/.bmp in that order, then falls back to
    the first <uid>.* glob hit (for example a .npy mask). Errors during the glob
    are treated as "not found".
    """
    if MASK_DIR is None:
        return None
    root = Path(MASK_DIR)
    stem = str(uid)
    for ext in (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"):
        candidate = root / f"{stem}{ext}"
        if candidate.exists():
            return candidate
    try:
        matches = list(root.glob(f"{stem}.*"))
    except Exception:
        return None
    return matches[0] if matches else None

def load_mask_full(uid: str):
    """Load uid's full-resolution GT mask as a (H, W) uint8 array of 0/1, or None.

    Image masks (anything PIL can open) are converted to grayscale and binarized.
    The cut is > 127 for the usual 0/255 encoding, or > 0 when the image only holds
    values 0..1, which a fixed > 127 cut would erase completely. NumPy masks
    (.npy / .npz) are binarized with > 0; a 3-D stack is merged with a max over its
    smallest axis, assumed to be the instance axis (TODO confirm against the data).

    Returns None when there is no mask file or it cannot be read; the caller then
    uses an empty mask. Only the first failure prints a warning (_WARN_MASK_MISS).
    """
    global _WARN_MASK_MISS
    p = find_mask_file(uid)
    if p is None:
        if not _WARN_MASK_MISS and MASK_DIR is not None:
            print("[WARN] Sebagian mask file tidak ditemukan. Untuk missing -> mask kosong (sekali warn).")
            _WARN_MASK_MISS = True
        return None
    try:
        if Path(p).suffix.lower() in (".npy", ".npz"):
            arr = np.asarray(_load_np_any(p))
            if arr.ndim == 3:
                arr = arr.max(axis=int(np.argmin(arr.shape)))
            if arr.ndim != 2:
                raise ValueError(f"unsupported mask shape {arr.shape}")
            return (arr > 0).astype(np.uint8)

        # Context manager closes the underlying file handle once pixels are read.
        with Image.open(p) as im:
            m = np.array(im.convert("L"), dtype=np.uint8)
        cut = 127 if m.max() > 1 else 0
        return (m > cut).astype(np.uint8)
    except Exception as e:
        if not _WARN_MASK_MISS:
            print("[WARN] Gagal baca mask file, fallback kosong. err:", repr(e))
            _WARN_MASK_MISS = True
        return None

def downsample_mask_to_grid(mask_full: np.ndarray, H: int, W: int):
    """Resize a full-resolution mask to the (H, W) token grid as a soft float32 mask in [0, 1].

    Uses PIL's BILINEAR resize, which for downscaling considers all input pixels that
    contribute to an output pixel, so each cell holds roughly its foreground coverage.
    The result is NOT thresholded: fractional coverage is kept as a soft target for
    the loss. ``mask_full is None`` gives an all-zero grid.
    """
    if mask_full is None:
        return np.zeros((H, W), dtype=np.float32)
    # Binarize before scaling to 0/255. Multiplying a mask already stored as 0/255
    # by 255 in uint8 would wrap around (255 * 255 mod 256 == 1).
    fg = (np.asarray(mask_full) > 0).astype(np.uint8) * 255
    im = Image.fromarray(fg)
    im = im.resize((W, H), resample=Image.BILINEAR)
    m = np.array(im, dtype=np.float32) / 255.0
    return m.astype(np.float32)

# Infer TOKEN_DIM (and an example grid size TOKEN_HW_EX) from the first of up to
# 80 uids whose token grid loads with at least 4x4 cells. TOKEN_HW_EX is also the
# shape of the zero grid that TokenMaskDataset substitutes for missing token files.
TOKEN_DIM = None
TOKEN_HW_EX = None
for uid in uids[:min(80, len(uids))]:
    g = load_token_grid(uid)
    if g is not None and g.ndim == 3 and g.shape[0] >= 4 and g.shape[1] >= 4:
        TOKEN_DIM = int(g.shape[2])
        TOKEN_HW_EX = (int(g.shape[0]), int(g.shape[1]))
        break

# No usable token grid among the probed uids: training cannot proceed.
if TOKEN_DIM is None:
    raise RuntimeError(
        "Tidak menemukan token-grid embedding di TOKEN_CACHE_DIR. "
        "Pastikan cache DINOv2 token-grid per uid tersedia (npz/npy)."
    )

print("Token example grid:", TOKEN_HW_EX, "| token_dim:", TOKEN_DIM)

# ----------------------------
# 4) Dataset (token-grid -> mask-grid)
# ----------------------------
class TokenMaskDataset(Dataset):
    """Pairs each uid's token grid (input) with its GT mask resized to the same grid (target).

    Items are ``(x, y, uid)``: x is a float32 tensor (D, H, W) and y a float32 tensor
    (1, H, W) holding a soft mask in [0, 1]. Rows with y == 0 always get an empty
    target. Rows with y == 1 use the loaded mask, or an empty one if it is missing.
    A missing token file is replaced by zeros of shape TOKEN_HW_EX x TOKEN_DIM.
    When ``is_train``, the grid and mask are flipped together: horizontally with
    probability cfg["aug_hflip"], vertically with probability cfg["aug_vflip"].
    """

    def __init__(self, df: pd.DataFrame, is_train: bool, cfg: dict):
        self.df = df.reset_index(drop=True)
        self.is_train = bool(is_train)
        self.cfg = cfg

    def __len__(self):
        return len(self.df)

    def _flip_drawn(self, key):
        """True with probability cfg[key]; np.random is consulted only when cfg.get(key, 0) > 0."""
        return self.cfg.get(key, 0) > 0 and np.random.rand() < float(self.cfg[key])

    def __getitem__(self, i):
        record = self.df.iloc[i]
        uid = str(record["uid"])
        is_forged = int(record["y"]) == 1

        tokens = load_token_grid(uid)  # (H, W, D) or None
        if tokens is None:
            fallback_h, fallback_w = TOKEN_HW_EX
            tokens = np.zeros((fallback_h, fallback_w, TOKEN_DIM), dtype=np.float32)
        grid_h, grid_w, _ = tokens.shape

        mask_full = load_mask_full(uid) if is_forged else None
        target = downsample_mask_to_grid(mask_full, grid_h, grid_w)  # (H, W) float in [0, 1]

        if self.is_train:
            if self._flip_drawn("aug_hflip"):
                tokens, target = tokens[:, ::-1, :].copy(), target[:, ::-1].copy()
            if self._flip_drawn("aug_vflip"):
                tokens, target = tokens[::-1, :, :].copy(), target[::-1, :].copy()

        x = torch.from_numpy(np.transpose(tokens, (2, 0, 1)).astype(np.float32))  # (D, H, W)
        y = torch.from_numpy(target[None, ...].astype(np.float32))                # (1, H, W)
        return x, y, uid

# ----------------------------
# 5) Model: UNet + ASPP (decoder-only on token-grid)
# ----------------------------
class ConvGNAct(nn.Module):
    """Conv2d (no bias) -> GroupNorm -> SiLU -> optional Dropout2d.

    Args:
        in_ch, out_ch: input / output channels.
        k, s, p: kernel size, stride and padding of the convolution.
        drop: Dropout2d probability; 0 means no dropout (Identity).
        groups: upper bound on the GroupNorm group count. The largest divisor of
            out_ch that is <= min(groups, out_ch) is used, because nn.GroupNorm
            requires num_channels % num_groups == 0 (e.g. out_ch=12 with groups=8
            would raise). When min(groups, out_ch) already divides out_ch, that value
            is used unchanged.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, drop=0.0, groups=8):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        n_ch = int(out_ch)
        g = max(1, min(int(groups), n_ch))
        while n_ch % g:
            g -= 1
        self.gn = nn.GroupNorm(g, out_ch)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.gn(x)
        x = self.act(x)
        x = self.drop(x)
        return x

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling with a constant channel count ``ch``.

    One 3x3 conv -> GroupNorm -> SiLU branch per dilation rate (padding = rate keeps
    the spatial size). The branch outputs are concatenated and projected back to
    ``ch`` channels by 1x1 conv -> GroupNorm -> SiLU -> optional Dropout2d.

    GroupNorm uses the largest divisor of ``ch`` that is <= min(8, ch), because
    nn.GroupNorm requires num_channels % num_groups == 0; a plain min(8, ch) would
    fail for channel counts such as 12 or 20. Multiples of 8 still get 8 groups.
    """

    def __init__(self, ch, rates=(1,2,4,6), drop=0.0):
        super().__init__()
        rs = list(rates)
        n_groups = max(1, min(8, int(ch)))
        while int(ch) % n_groups:
            n_groups -= 1
        self.branches = nn.ModuleList()
        for r in rs:
            self.branches.append(
                nn.Sequential(
                    nn.Conv2d(ch, ch, 3, padding=int(r), dilation=int(r), bias=False),
                    nn.GroupNorm(n_groups, ch),
                    nn.SiLU(inplace=True),
                )
            )
        self.proj = nn.Sequential(
            nn.Conv2d(ch * len(rs), ch, 1, bias=False),
            nn.GroupNorm(n_groups, ch),
            nn.SiLU(inplace=True),
            nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()
        )

    def forward(self, x):
        xs = [b(x) for b in self.branches]
        x = torch.cat(xs, dim=1)
        return self.proj(x)

class UNetASPP(nn.Module):
    """UNet-style encoder/decoder with an ASPP bottleneck, run directly on a token grid.

    Input (B, in_dim, H, W) token features -> output (B, 1, H, W) mask logits at the
    same token-grid resolution. Three stride-2 downsamplings go C -> 2C -> 4C -> 6C
    channels (C = base_ch), with an ASPP block in the bottleneck. The decoder
    bilinearly upsamples to each skip connection's exact size before concatenating,
    so grid sizes that are not multiples of 8 are handled.

    NOTE: the attribute names (stem, enc1, down1, ...) are the state_dict keys of
    saved checkpoints, and their construction order determines how the seeded RNG is
    consumed for weight init; keep both stable.
    """

    def __init__(self, in_dim, base_ch=128, drop=0.10, aspp_rates=(1,2,4,6)):
        super().__init__()
        C = int(base_ch)

        # light encoder on token-grid (1x1 projection in_dim -> C, then a 3x3 conv)
        self.stem = nn.Sequential(
            ConvGNAct(in_dim, C, k=1, s=1, p=0, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )
        self.enc1 = nn.Sequential(
            ConvGNAct(C, C,   k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C,   k=3, s=1, p=1, drop=drop),
        )
        self.down1 = ConvGNAct(C, 2*C, k=3, s=2, p=1, drop=drop)
        self.enc2 = nn.Sequential(
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.down2 = ConvGNAct(2*C, 4*C, k=3, s=2, p=1, drop=drop)
        self.enc3 = nn.Sequential(
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.down3 = ConvGNAct(4*C, 6*C, k=3, s=2, p=1, drop=drop)

        # bottleneck: conv -> multi-rate ASPP -> conv, all at 6C channels
        self.bottleneck = nn.Sequential(
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
            ASPP(6*C, rates=aspp_rates, drop=drop),
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
        )

        # decoder: input channels = upsampled features + skip features
        self.dec3 = nn.Sequential(
            ConvGNAct(6*C + 4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec2 = nn.Sequential(
            ConvGNAct(4*C + 2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec1 = nn.Sequential(
            ConvGNAct(2*C + C, C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )

        self.head = nn.Conv2d(C, 1, kernel_size=1)  # single-channel mask logits

    def forward(self, x):
        """Map (B, in_dim, H, W) token features to (B, 1, H, W) mask logits."""
        # x: (B, D, H, W)
        x0 = self.stem(x)     # (B,C,H,W)
        s1 = self.enc1(x0)    # (B,C,H,W)

        x1 = self.down1(s1)   # (B,2C,ceil(H/2),ceil(W/2)) (k=3, s=2, p=1)
        s2 = self.enc2(x1)

        x2 = self.down2(s2)   # (B,4C,~H/4,~W/4)
        s3 = self.enc3(x2)

        x3 = self.down3(s3)   # (B,6C,~H/8,~W/8)
        b  = self.bottleneck(x3)

        # up to s3 (exact skip size, so odd grid sizes still line up)
        u3 = F.interpolate(b, size=s3.shape[-2:], mode="bilinear", align_corners=False)
        d3 = self.dec3(torch.cat([u3, s3], dim=1))

        # up to s2
        u2 = F.interpolate(d3, size=s2.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.dec2(torch.cat([u2, s2], dim=1))

        # up to s1
        u1 = F.interpolate(d2, size=s1.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.dec1(torch.cat([u1, s1], dim=1))

        logit = self.head(d1)  # (B,1,H,W) token-grid resolution
        return logit

# ----------------------------
# 6) EMA
# ----------------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``shadow`` holds the averaged copies (initialised from the current values).
    ``apply_shadow`` swaps them into the live model after stashing the live
    values in ``backup``; ``restore`` copies the stash back and clears it.
    Only parameters with ``requires_grad`` are tracked; buffers are not averaged.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        """shadow <- decay * shadow + (1 - decay) * param, in place."""
        keep = self.decay
        blend = 1.0 - keep
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name].mul_(keep).add_(param.detach(), alpha=blend)

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        """Load the averaged weights into ``model``, remembering the live ones."""
        self.backup = {}
        trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
        for name, param in trainable:
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        """Put back the weights stashed by the last ``apply_shadow`` call."""
        trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
        for name, param in trainable:
            param.copy_(self.backup[name])
        self.backup = {}

# ----------------------------
# 7) Loss + metrics
# ----------------------------
def dice_score(prob, target, eps=1e-6):
    """Soft Dice coefficient, computed per (sample, channel) and averaged.

    ``prob`` and ``target`` are (B, C, H, W) tensors (C == 1 here); both are
    cast to float. The denominator is clamped to ``eps``, so an all-zero
    prediction against an all-zero target scores 0, not 1.
    """
    p = prob.float()
    t = target.float()
    spatial = (2, 3)
    overlap = (p * t).sum(dim=spatial) * 2.0
    total = (p + t).sum(dim=spatial).clamp_min(eps)
    return (overlap / total).mean()

def dice_loss_from_logits(logits, target, eps=1e-6):
    """Soft Dice loss: ``1 - dice_score(sigmoid(logits), target)``."""
    score = dice_score(torch.sigmoid(logits), target, eps=eps)
    return 1.0 - score

def focal_bce_with_logits(logits, targets, gamma=0.0):
    """Mean binary cross-entropy on logits, optionally focal-weighted.

    With a falsy or non-positive ``gamma`` this is plain BCE. Otherwise each
    element's loss is scaled by ``(1 - p_t) ** gamma``, where ``p_t`` is the
    predicted probability of the (possibly soft) target.
    """
    per_elem = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    if not gamma or not (gamma > 0):
        return per_elem.mean()

    prob = torch.sigmoid(logits)
    p_true = prob * targets + (1.0 - prob) * (1.0 - targets)
    weight = (1.0 - p_true).clamp_min(0.0).pow(float(gamma))
    return (per_elem * weight).mean()

@torch.no_grad()
def eval_loader(model, loader, ema=None, thr_list=None, cfg=None):
    """Evaluate ``model`` on ``loader``: validation loss, soft Dice and a hard-mask Dice sweep.

    Args:
        model: network mapping the loader's ``xb`` to per-pixel logits shaped like ``yb``.
        loader: iterable of ``(xb, yb, uid)`` batches.
        ema: optional EMA. Its shadow weights are swapped in for evaluation, and
            the live weights are restored afterwards even if evaluation raises.
        thr_list: thresholds for the hard-mask Dice sweep (default ``[0.5]``).
        cfg: dict providing ``bce_weight``, ``dice_weight`` and optional
            ``focal_gamma`` for the loss. Defaults to the global ``CFG``
            (previous behaviour).

    Returns:
        dict with ``val_loss`` and ``val_dice_prob`` (per-sample weighted means),
        ``thr_list`` and ``val_dice_thr`` (mean per-image hard Dice per threshold).

    In the sweep, an image whose thresholded prediction and target are both
    empty scores 1.0. Scoring it 0.0 made false positives on authentic
    (all-zero mask) images free, which biased the selected threshold low.
    """
    cfg = CFG if cfg is None else cfg
    gamma = float(cfg.get("focal_gamma", 0.0))
    w_bce = float(cfg["bce_weight"])
    w_dice = float(cfg["dice_weight"])

    thr_list = [float(t) for t in (thr_list if thr_list is not None else [0.5])]
    dice_thr = np.zeros(len(thr_list), dtype=np.float64)

    tot_loss = 0.0
    tot_dice = 0.0
    n_samples = 0

    model.eval()
    if ema is not None:
        ema.apply_shadow(model)
    try:
        for xb, yb, _ in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            bs = int(xb.size(0))

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, gamma=gamma) * w_bce \
                     + dice_loss_from_logits(logits, yb) * w_dice

            prob = torch.sigmoid(logits.float())
            # Weight batch means by batch size so a short final batch is not over-weighted.
            tot_loss += float(loss.item()) * bs
            tot_dice += float(dice_score(prob, yb).item()) * bs
            n_samples += bs

            # Threshold sweep on CPU: per-image Dice of the binarised prediction.
            prob_np = prob.cpu().numpy()
            y_np = yb.float().cpu().numpy()
            y_sum = y_np.sum(axis=(2, 3))
            for j, t in enumerate(thr_list):
                pr = (prob_np >= t).astype(np.float32)
                inter = (pr * y_np).sum(axis=(2, 3))
                den = pr.sum(axis=(2, 3)) + y_sum
                per_img = np.where(den > 0, 2.0 * inter / np.maximum(den, 1e-6), 1.0)
                dice_thr[j] += float(per_img.sum())
    finally:
        if ema is not None:
            ema.restore(model)

    denom = max(1, n_samples)
    return {
        "val_loss": tot_loss / denom,
        "val_dice_prob": tot_dice / denom,
        "thr_list": thr_list,
        "val_dice_thr": (dice_thr / denom).tolist(),
    }

# ----------------------------
# 8) Train one fold
# ----------------------------
def train_one_fold(df_tr, df_va, cfg, seed):
    """Train one UNetASPP mask decoder on a train/val split and pick its threshold.

    Loss is (focal-)BCE * bce_weight + soft Dice * dice_weight. Optimisation
    uses AdamW with linear warmup and cosine decay (the scheduler steps once per
    optimizer step), gradient accumulation, AMP when ``use_amp``, optional
    gradient clipping and optional EMA. Early stopping monitors validation loss,
    computed with EMA weights when EMA is enabled. The best epoch's weights (the
    EMA weights if enabled) are reloaded. A final hard-mask Dice sweep over
    ``cfg["thr_grid"]`` thresholds in [0.05, 0.95] then selects ``best_thr``.

    Args:
        df_tr, df_va: train / validation rows for ``TokenMaskDataset``.
        cfg: hyper-parameter dict. Required keys read here: batch_size, base_ch,
            drop, aspp_rates, lr, weight_decay, epochs, warmup_frac, ema_decay,
            thr_grid, early_min_delta, early_patience, bce_weight, dice_weight.
            Optional keys: accum_steps (1), use_ema (True), grad_clip (0 = off),
            focal_gamma (0).
        seed: seed passed to ``seed_everything``.

    Returns:
        dict "pack" with a CPU ``state_dict``, a copy of ``cfg``, ``token_dim``,
        ``best_epoch`` (1-based; 0 if validation loss never improved),
        best-epoch losses/Dice, and ``best_thr`` / ``best_val_dice_hard`` from
        the final sweep.

    NOTE(review): ``eval_loader`` reads the loss weights from the global ``CFG``,
    not ``cfg``. The validation loss matches the training loss only when
    ``cfg`` carries the same loss weights as ``CFG``.
    """
    seed_everything(seed)

    ds_tr = TokenMaskDataset(df_tr, is_train=True, cfg=cfg)
    ds_va = TokenMaskDataset(df_va, is_train=False, cfg=cfg)

    # Use 2 loader workers only when at least 4 CPUs are available.
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    dl_tr = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    dl_va = DataLoader(ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))

    model = UNetASPP(in_dim=TOKEN_DIM, base_ch=int(cfg["base_ch"]), drop=float(cfg["drop"]), aspp_rates=tuple(cfg["aspp_rates"])).to(device)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(0.9, 0.99),
        eps=1e-8,
    )

    # Schedule length counts optimizer steps (micro-batches / accum_steps, rounded up),
    # matching the step-per-accumulation-window + end-of-epoch flush below.
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_steps = int(cfg["epochs"]) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    def lr_lambda(step):
        # LR multiplier: linear ramp (step+1)/warmup during warmup, then cosine 1 -> 0.
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        # cosine decay after warmup
        t = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        t = min(max(t, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    best = {"val_loss": 1e18, "val_dice_prob": -1.0, "epoch": -1}
    best_state = None
    bad = 0  # consecutive epochs without sufficient val_loss improvement
    opt_step = 0

    thr_list = np.linspace(0.05, 0.95, int(cfg["thr_grid"])).tolist()

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum = 0.0
        n_sum = 0
        micro = 0
        t0 = time.time()

        for xb, yb, _ in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, gamma=float(cfg.get("focal_gamma", 0.0))) * float(cfg["bce_weight"]) \
                     + dice_loss_from_logits(logits, yb) * float(cfg["dice_weight"])
                # Scale down so accumulated gradients average over the window.
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro += 1
            # Undo the accumulation scaling for logging; weight by batch size.
            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro % accum_steps) == 0:
                if float(cfg.get("grad_clip", 0.0)) > 0:
                    # Clip on true (unscaled) gradients when AMP is active.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))

                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                opt_step += 1
                if ema is not None:
                    ema.update(model)

        # flush last: apply gradients from an incomplete accumulation window at epoch end
        if (micro % accum_steps) != 0:
            if float(cfg.get("grad_clip", 0.0)) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            opt_step += 1
            if ema is not None:
                ema.update(model)

        # validate (EMA)
        ev = eval_loader(model, dl_va, ema=ema, thr_list=thr_list)
        vloss = float(ev["val_loss"])
        vdice = float(ev["val_dice_prob"])

        dt = time.time() - t0
        print(f"  epoch {epoch+1:03d}/{cfg['epochs']} | tr_loss={loss_sum/max(1,n_sum):.5f} | val_loss={vloss:.5f} | val_dice(sigmoid)={vdice:.5f} | opt_step={opt_step} | dt={dt:.1f}s")

        # Early stopping / checkpoint selection is on val_loss, not Dice.
        improved = (best["val_loss"] - vloss) > float(cfg["early_min_delta"])
        if improved:
            best["val_loss"] = vloss
            best["val_dice_prob"] = vdice
            best["epoch"] = int(epoch)

            # save EMA-weighted state (since eval uses EMA)
            if ema is not None:
                ema.apply_shadow(model)
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
                ema.restore(model)
            else:
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}

            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["early_patience"]):
                print(f"  early stop at epoch {epoch+1}, best_epoch={best['epoch']+1}, best_val_loss={best['val_loss']:.5f}")
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # load best (EMA weights when EMA is on); the live model keeps them from here on
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    # final val sweep for best thr (dice on hard mask); ema=None because the
    # best (possibly EMA) weights are already loaded into the model
    ev = eval_loader(model, dl_va, ema=None, thr_list=thr_list)
    dice_thr = np.array(ev["val_dice_thr"], dtype=np.float64)
    j = int(np.argmax(dice_thr))
    best_thr = float(ev["thr_list"][j])
    best_dice_hard = float(dice_thr[j])

    pack = {
        "arch": "UNetASPP_on_DINOv2TokenGrid_v3.1",
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "cfg": dict(cfg),
        "token_dim": int(TOKEN_DIM),
        "best_epoch": int(best["epoch"] + 1),
        "best_val_loss": float(best["val_loss"]),
        "best_val_dice_prob": float(best["val_dice_prob"]),
        "best_thr": float(best_thr),
        "best_val_dice_hard": float(best_dice_hard),
    }
    return pack

# ----------------------------
# 9) Multi-seed CV + save fold models + bundle
# ----------------------------
OUT_DIR = Path("/kaggle/working/recodai_luc_mask_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# seed plan (mirip pola kamu)
N_SEEDS = 2 if (device.type == "cuda" and MEM_GB >= 30) else 1
if os.environ.get("MASK_NSEEDS","").strip():
    N_SEEDS = int(os.environ["MASK_NSEEDS"])
SEEDS = [int(CFG["seed"]) + i*17 for i in range(max(1, N_SEEDS))]
print("SEED plan:", SEEDS)

fold_packs_all = []
fold_rows = []
best_epochs_all = []

for si, seed in enumerate(SEEDS, 1):
    print("\n==============================")
    print(f"== SEED {seed} ({si}/{len(SEEDS)})")
    print("==============================")

    models_dir = OUT_DIR / f"mask_folds_seed_{seed}"
    models_dir.mkdir(parents=True, exist_ok=True)

    for f in unique_folds:
        print(f"\n[Seed {seed} | Fold {f}]")

        df_tr = df_train_tabular[df_train_tabular["fold"] != f].reset_index(drop=True)
        df_va = df_train_tabular[df_train_tabular["fold"] == f].reset_index(drop=True)

        pack = train_one_fold(df_tr, df_va, CFG, seed=int(seed) + int(f)*101)
        pack["seed"] = int(seed)
        pack["fold"] = int(f)

        # save fold pt
        pt_path = models_dir / f"mask_fold_{f}.pt"
        torch.save({"pack": pack}, pt_path)

        fold_packs_all.append(pack)
        best_epochs_all.append(int(pack["best_epoch"]))

        fold_rows.append({
            "seed": int(seed),
            "fold": int(f),
            "best_epoch": int(pack["best_epoch"]),
            "best_val_loss": float(pack["best_val_loss"]),
            "best_val_dice_prob": float(pack["best_val_dice_prob"]),
            "best_thr": float(pack["best_thr"]),
            "best_val_dice_hard": float(pack["best_val_dice_hard"]),
            "pt_path": str(pt_path),
        })

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

df_fold = pd.DataFrame(fold_rows).sort_values(["seed","fold"]).reset_index(drop=True)
display(df_fold)
df_fold.to_csv(OUT_DIR / "oof_mask_metrics.csv", index=False)

# recommended_thr: median best_thr across all (seed,fold)
recommended_thr = float(np.median(df_fold["best_thr"].to_numpy(dtype=np.float64)))
print("\nRecommended thr (median of fold best_thr):", recommended_thr)

# ----------------------------
# 10) Train FULL model(s) per seed (epochs = median best_epoch * 1.15, capped)
# ----------------------------
def train_full(df_full, cfg, seed, epochs_full):
    """Train a UNetASPP mask decoder on all rows for a fixed number of epochs.

    Uses the same loss, optimizer, warmup+cosine schedule, accumulation, AMP,
    clipping and EMA logic as ``train_one_fold``, with no validation and no
    early stopping.

    Args:
        df_full: every training row (fed to ``TokenMaskDataset`` with augmentation).
        cfg: hyper-parameter dict (same keys as ``train_one_fold`` apart from the
            validation / early-stopping / threshold ones).
        seed: seed passed to ``seed_everything``; also stored in the pack.
        epochs_full: number of epochs; also sets the LR schedule length.

    Returns:
        dict "pack" with a CPU ``state_dict``, which holds the EMA weights when
        EMA is enabled (flagged by ``used_ema``). The live model is restored to
        its raw trained weights afterwards.
    """
    seed_everything(seed)

    ds = TokenMaskDataset(df_full, is_train=True, cfg=cfg)
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    dl = DataLoader(ds, batch_size=int(cfg["batch_size"]), shuffle=True,
                    num_workers=nw, pin_memory=pin, drop_last=False,
                    persistent_workers=(nw > 0))

    model = UNetASPP(in_dim=TOKEN_DIM, base_ch=int(cfg["base_ch"]), drop=float(cfg["drop"]), aspp_rates=tuple(cfg["aspp_rates"])).to(device)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(0.9, 0.99),
        eps=1e-8,
    )

    # Schedule length in optimizer steps (ceil of micro-batches / accum_steps per epoch).
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl) / accum_steps))
    total_steps = int(epochs_full) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    def lr_lambda(step):
        # LR multiplier: linear warmup, then cosine decay from 1 to 0.
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        t = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        t = min(max(t, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    print(f"\nTraining FULL mask model | seed={seed} | epochs={epochs_full}")

    for ep in range(int(epochs_full)):
        model.train()
        opt.zero_grad(set_to_none=True)
        loss_sum, n_sum, micro = 0.0, 0, 0

        for xb, yb, _ in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, gamma=float(cfg.get("focal_gamma", 0.0))) * float(cfg["bce_weight"]) \
                     + dice_loss_from_logits(logits, yb) * float(cfg["dice_weight"])
                loss = loss / accum_steps  # average gradients over the accumulation window

            scaler.scale(loss).backward()
            micro += 1
            loss_sum += float(loss.item()) * xb.size(0) * accum_steps  # un-scaled, sample-weighted for logging
            n_sum += xb.size(0)

            if (micro % accum_steps) == 0:
                if float(cfg.get("grad_clip", 0.0)) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # Flush an incomplete accumulation window at the end of the epoch.
        if (micro % accum_steps) != 0:
            if float(cfg.get("grad_clip", 0.0)) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {ep+1:03d}/{epochs_full} | loss={loss_sum/max(1,n_sum):.5f}")

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Export EMA weights when EMA is enabled, then put the raw weights back.
    used_ema = bool(ema is not None)
    if ema is not None:
        ema.apply_shadow(model)

    full_pack = {
        "arch": "UNetASPP_on_DINOv2TokenGrid_v3.1",
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "cfg": dict(cfg),
        "token_dim": int(TOKEN_DIM),
        "epochs_full": int(epochs_full),
        "seed": int(seed),
        "used_ema": bool(used_ema),
    }

    if ema is not None:
        ema.restore(model)

    return full_pack

# Full-data epoch budget: 1.15 x the median CV best epoch, at least 10 and at
# most CFG["epochs"]. Without CV results, fall back to max(10, epochs // 2).
if best_epochs_all:
    flat_best = np.array(best_epochs_all, dtype=np.int32)
else:
    flat_best = np.array([max(10, CFG["epochs"]//2)], dtype=np.int32)
med_best = int(np.median(flat_best))
epochs_full = min(int(max(10, round(med_best * 1.15))), int(CFG["epochs"]))

# One full-data model per CV seed.
full_packs = []
for seed in SEEDS:
    pack_full = train_full(df_train_tabular, CFG, seed=int(seed), epochs_full=int(epochs_full))
    full_packs.append(pack_full)

# ----------------------------
# 11) Save final bundle (compatible style)
# ----------------------------
# Everything needed downstream: CV fold packs, full-data packs, the recommended
# threshold and config metadata, saved as one torch file.
final_bundle = {
    "type": "mask_unet_aspp_decoder_on_dinov2_tokengrid_v3.1",
    "cfg_name": CFG_NAME,
    "cfg": CFG,
    "seed_plan": SEEDS,
    "token_dim": int(TOKEN_DIM),
    "token_grid_example": TOKEN_HW_EX,
    "fold_packs": fold_packs_all,   # list: seed x fold
    "full_packs": full_packs,       # list: per seed
    "recommended_thr": float(recommended_thr),
    "fold_metrics_csv": str(OUT_DIR / "oof_mask_metrics.csv"),
    "notes": "Decoder-only UNet+ASPP trained on DINOv2 token-grid embeddings; GT mask downsampled to token grid.",
}

FINAL_MASK_MODEL_PT = str(OUT_DIR / "final_mask_model.pt")
torch.save(final_bundle, FINAL_MASK_MODEL_PT)

# Lightweight JSON summary (no weights) next to the bundle.
with open(OUT_DIR / "final_mask_bundle.json", "w") as f:
    json.dump({
        "final_mask_model_pt": FINAL_MASK_MODEL_PT,
        "cfg_name": CFG_NAME,
        "seeds": SEEDS,
        "n_fold_packs": int(len(fold_packs_all)),
        "n_full_packs": int(len(full_packs)),
        "recommended_thr": float(recommended_thr),
        "token_dim": int(TOKEN_DIM),
        "epochs_full": int(epochs_full),
    }, f, indent=2)

MASK_OOF_REPORT = {
    "cfg_name": CFG_NAME,
    "seeds": SEEDS,
    "recommended_thr": float(recommended_thr),
    "fold_metrics_path": str(OUT_DIR / "oof_mask_metrics.csv"),
    "epochs_full": int(epochs_full),
}

print("\nSaved artifacts:")
print("  fold models   ->", OUT_DIR, "(mask_folds_seed_*/mask_fold_*.pt)")
print("  fold metrics  ->", OUT_DIR / "oof_mask_metrics.csv")
print("  final bundle  ->", FINAL_MASK_MODEL_PT)
print("  meta json     ->", OUT_DIR / "final_mask_bundle.json")

# Explicitly (re)publish the artifacts later cells look up by name.
globals().update({
    "FINAL_MASK_MODEL_PT": FINAL_MASK_MODEL_PT,
    "MASK_OOF_REPORT": MASK_OOF_REPORT,
})


# Optimize Model & Hyperparameters (Iterative)

In [None]:
# ============================================================
# Step 4 — Optimize Mask Model & Hyperparameters (Iterative) — UNet+ASPP on DINOv2 Token-Grid
# REVISI FULL v4.0 (2-stage search, resume-safe, AMP+EMA+accum, robust token-grid resize)
#
# Primary score: OOF best Dice (hard mask) over threshold grid
#
# Output:
# - /kaggle/working/recodai_luc_mask_artifacts/opt_search/stage1_results.csv
# - /kaggle/working/recodai_luc_mask_artifacts/opt_search/opt_results.csv
# - /kaggle/working/recodai_luc_mask_artifacts/opt_search/opt_results.json
# - /kaggle/working/recodai_luc_mask_artifacts/opt_search/opt_fold_details.csv
# - /kaggle/working/recodai_luc_mask_artifacts/opt_search/oof_scalars_<cfg_name>.csv (top configs)
# - /kaggle/working/recodai_luc_mask_artifacts/best_mask_config.json
# - /kaggle/working/recodai_luc_mask_artifacts/best_mask_model.pt
#
# REQUIRE:
# - df_train_tabular with columns: uid, case_id, fold, y   (variant optional)
# - TOKEN_CACHE_DIR contains per-uid token-grid npy/npz (H,W,D) or (N,D) reshapeable
# - MASK_DIR (optional but recommended) contains per-uid mask images (png/jpg/tif)
# ============================================================

import os, gc, json, math, time, warnings
from pathlib import Path
from collections import Counter

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from IPython.display import display

# ----------------------------
# 0) Require data
# ----------------------------
need_vars = ["df_train_tabular"]
absent = [name for name in need_vars if name not in globals()]
if absent:
    raise RuntimeError(f"Missing `{absent[0]}`. Butuh df_train_tabular (uid/case_id/fold/y).")

# Private copy with the key columns coerced to canonical dtypes.
df_train_tabular = df_train_tabular.copy()
req = {"uid","case_id","fold","y"}
miss = [c for c in req if c not in df_train_tabular.columns]
if miss:
    raise ValueError(f"df_train_tabular missing columns: {miss}")

df_train_tabular = df_train_tabular.astype({"uid": str, "case_id": str, "fold": int, "y": int})

uids_all  = df_train_tabular["uid"].to_numpy()
y_img_all = df_train_tabular["y"].to_numpy(np.int64)
folds_all = df_train_tabular["fold"].to_numpy(np.int64)
unique_folds = sorted(df_train_tabular["fold"].unique().tolist())

print("Optimize setup (MASK UNet+ASPP on token-grid):")
print(f"  rows={len(df_train_tabular)} | folds={len(unique_folds)} | forged%={float(y_img_all.mean())*100:.2f}")

# ----------------------------
# 1) Global settings
# ----------------------------
SEED = 2025     # base RNG seed for this optimization cell
THR_GRID = 81   # number of thresholds in the Dice sweep grid (consumer is outside this view)

# 2-stage runtime controls (consumers are outside this view; meanings below are presumed from names)
STAGE1_FOLDS = min(3, len(unique_folds))  # stage 1 uses at most 3 folds
STAGE1_EPOCH_CAP = 30                     # presumably the epoch cap for stage-1 trials
STAGE1_PAT_CAP = 6                        # presumably the early-stopping patience cap in stage 1

STAGE2_TOPM = 3       # presumably how many stage-1 configs advance to stage 2 — TODO confirm
REPORT_TOPK_OOF = 3   # presumably how many top configs get an OOF scalar report — TODO confirm

TIME_BUDGET_SEC = 0  # 0=off

def seed_everything(seed=2025):
    """Seed Python's ``random``, NumPy and torch (CPU + all CUDA devices).

    Also exports PYTHONHASHSEED. cuDNN is deliberately left non-deterministic
    with autotuning on (``benchmark=True``), trading bitwise reproducibility
    for speed.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

# Prefer CUDA when present; mixed precision is enabled only on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"
print("Device:", device, "| AMP:", use_amp)

def get_mem_gb():
    """Total memory of CUDA device 0 in GiB; 0.0 on CPU or if the query fails."""
    if device.type == "cuda":
        try:
            return float(torch.cuda.get_device_properties(0).total_memory / (1024**3))
        except Exception:
            return 0.0  # best effort: unknown memory is treated as "no GPU memory"
    return 0.0

MEM_GB = get_mem_gb()
print("GPU mem GB:", MEM_GB)

# Allow TF32 tensor-core math on CUDA; the matmul precision hint is best effort.
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ----------------------------
# 2) Auto-detect dirs (TOKEN_CACHE_DIR + MASK_DIR)
# ----------------------------
def _first_existing(paths):
    for p in paths:
        if p is None:
            continue
        p = Path(p)
        if p.exists():
            return p
    return None

# MASK_DIR: explicit entries in the Stage-0 PATHS dict win; otherwise probe the
# usual competition / working locations. Stays None when nothing exists.
MASK_DIR = None
if isinstance(globals().get("PATHS"), dict):
    MASK_DIR = _first_existing([PATHS.get(key) for key in ("TRAIN_MASK_DIR", "MASK_DIR", "TRAIN_MASKS")])
if MASK_DIR is None:
    MASK_DIR = _first_existing([
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks_train",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks",
        "/kaggle/working/recodai_luc/train_masks",
    ])
print("MASK_DIR:", str(MASK_DIR) if MASK_DIR else "(None)")

# TOKEN_CACHE_DIR: an existing CACHE_ROOT from an earlier stage wins. Otherwise
# use the first existing base directory below, preferring its most recently
# modified cfg_* subdirectory.
TOKEN_CACHE_DIR = None
if "CACHE_ROOT" in globals():
    try:
        cr = Path(CACHE_ROOT)
        TOKEN_CACHE_DIR = cr if cr.exists() else None
    except Exception:
        pass  # best effort: an unusable CACHE_ROOT falls through to auto-detection

if TOKEN_CACHE_DIR is None:
    base_candidates = [
        Path("/kaggle/working/recodai_luc/cache/dino_v2"),
        Path("/kaggle/working/recodai_luc/cache"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache/dino_v2"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache"),
    ]
    b = next((cand for cand in base_candidates if cand.exists()), None)
    if b is not None:
        cfg_dirs = sorted(
            (p for p in b.glob("cfg_*") if p.is_dir()),
            key=lambda p: p.stat().st_mtime,
            reverse=True,
        )
        TOKEN_CACHE_DIR = cfg_dirs[0] if cfg_dirs else b

print("TOKEN_CACHE_DIR:", str(TOKEN_CACHE_DIR) if TOKEN_CACHE_DIR else "(None)")
if TOKEN_CACHE_DIR is None:
    raise RuntimeError("TOKEN_CACHE_DIR tidak ditemukan. Pastikan cache token-grid DINOv2 per uid tersedia.")

# ----------------------------
# 3) Fast path maps (token/mask) + token shape inference
# ----------------------------
def _load_np_any(p: Path):
    p = Path(p)
    if p.suffix.lower() == ".npy":
        return np.load(p, allow_pickle=False)
    if p.suffix.lower() == ".npz":
        z = np.load(p, allow_pickle=False)
        keys = list(z.keys())
        for k in keys:
            a = z[k]
            if isinstance(a, np.ndarray) and a.ndim in (2,3):
                return a
        return z[keys[0]] if keys else None
    return None

def _reshape_to_grid(a: np.ndarray):
    a = np.asarray(a)
    if a.ndim == 3:
        # (D,H,W) -> (H,W,D)
        if a.shape[0] in [384, 512, 768, 1024, 1536] and (a.shape[1] * a.shape[2] > 16):
            D,H,W = a.shape
            return np.transpose(a, (1,2,0))
        return a  # assume (H,W,D)
    if a.ndim == 2:
        N,D = a.shape
        s = int(round(math.sqrt(N)))
        if s*s == N:
            return a.reshape(s,s,D)
        for h in range(1, int(math.sqrt(N))+1):
            if N % h == 0:
                w = N // h
                if h >= 4 and w >= 4:
                    return a.reshape(h,w,D)
        return a.reshape(1,N,D)
    return None

def find_token_file_fast(uid: str):
    """Return the cached token file for ``uid`` under TOKEN_CACHE_DIR, or None.

    Checks the cache root, then ``tokens_train/``, ``train/`` and ``feat_train/``,
    preferring ``.npz`` over ``.npy`` within each directory. No globbing is used.
    """
    name = str(uid)
    for sub in ("", "tokens_train", "train", "feat_train"):
        folder = TOKEN_CACHE_DIR / sub if sub else TOKEN_CACHE_DIR
        for ext in (".npz", ".npy"):
            candidate = folder / f"{name}{ext}"
            if candidate.exists():
                return candidate
    return None

def find_mask_file_fast(uid: str):
    """Return the mask file for ``uid`` inside MASK_DIR, or None.

    Common image extensions are tried directly first. If none exists, a single
    non-recursive ``<uid>.*`` glob is done and its first hit (any extension)
    is returned.
    """
    if MASK_DIR is None:
        return None
    root = Path(MASK_DIR)
    stem = str(uid)
    direct = (root / f"{stem}{ext}" for ext in (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"))
    hit = next((p for p in direct if p.exists()), None)
    if hit is not None:
        return hit
    return next(iter(root.glob(f"{stem}.*")), None)

# Build uid -> file maps (direct path probes only, no recursive glob), then
# infer the dominant token-grid (H, W) and embedding dim D from a sample of
# cached token files.
TOKEN_PATH = {uid: find_token_file_fast(uid) for uid in uids_all}
MASK_PATH = {}
if MASK_DIR is not None:
    MASK_PATH = {uid: find_mask_file_fast(uid) for uid in uids_all}

token_hw_counter = Counter()
token_dim_counter = Counter()

# Probe the first `max_probe` uids that actually HAVE a token file. Probing
# the first `max_probe` rows regardless made shape inference fail whenever
# those rows happened to lack cached tokens.
max_probe = min(300, len(uids_all))
n_probed = 0
for uid in uids_all:
    if n_probed >= max_probe:
        break
    tp = TOKEN_PATH[uid]
    if tp is None:
        continue
    n_probed += 1
    a = _load_np_any(tp)
    if a is None:
        continue
    g = _reshape_to_grid(a)
    if g is None or g.ndim != 3:
        continue
    H,W,D = int(g.shape[0]), int(g.shape[1]), int(g.shape[2])
    # Ignore degenerate grids (e.g. (1, N, D) fallbacks) and tiny "dims".
    if H >= 4 and W >= 4 and D >= 16:
        token_hw_counter[(H,W)] += 1
        token_dim_counter[D] += 1

if not token_hw_counter or not token_dim_counter:
    raise RuntimeError("Tidak bisa infer token grid shape/dim dari token cache. Cek format file token-grid per uid.")

TOKEN_HW = token_hw_counter.most_common(1)[0][0]
TOKEN_DIM = token_dim_counter.most_common(1)[0][0]

print("Token grid mode:", TOKEN_HW, "| token_dim mode:", TOKEN_DIM)

# ----------------------------
# 4) Utilities: load/resize tokens + load/resize mask
# ----------------------------
def load_token_grid_from_path(p: Path):
    """Load a cached token file as a float32 grid with non-finite values zeroed.

    Returns None when the file cannot be loaded as an array (unsupported
    suffix) or cannot be coerced to a grid by ``_reshape_to_grid``.
    """
    grid = _reshape_to_grid(_load_np_any(p))
    if grid is None:
        return None
    grid = grid.astype(np.float32, copy=False)
    if np.isfinite(grid).all():
        return grid  # (H,W,D)
    return np.nan_to_num(grid, nan=0.0, posinf=0.0, neginf=0.0)

def resize_grid_hw(g: np.ndarray, out_hw):
    """Bilinearly resize the spatial axes of an (H, W, D) grid to ``out_hw``.

    Returns ``g`` itself when it already has the target size. Otherwise returns
    a new float32 (Ht, Wt, D) array (``align_corners=False``).
    """
    target = (int(out_hw[0]), int(out_hw[1]))
    if tuple(g.shape[:2]) == target:
        return g
    chw = torch.from_numpy(np.transpose(g, (2, 0, 1)))[None]  # (1,D,H,W)
    resized = F.interpolate(chw, size=target, mode="bilinear", align_corners=False)[0]
    hwd = resized.permute(1, 2, 0).contiguous().cpu().numpy()
    return hwd.astype(np.float32, copy=False)

def load_mask_bin(path: Path):
    """Load a ground-truth mask file as a binary (H, W) uint8 array of 0/1.

    Image files (png/jpg/tif/...) are read as 8-bit grayscale and thresholded
    at > 127. ``.npy``/``.npz`` arrays are thresholded at > 0. PIL cannot open
    these array files, and ``find_mask_file_fast``'s ``<uid>.*`` glob fallback
    can return them. A 3-D array is collapsed to (H, W) by a logical OR over
    its smallest axis, assumed to be the instance/channel axis (TODO confirm
    against the mask files).

    Raises:
        ValueError: for an empty ``.npz`` or an array that is not 2-D/3-D.
    """
    path = Path(path)
    suffix = path.suffix.lower()
    if suffix in (".npy", ".npz"):
        if suffix == ".npy":
            arr = np.load(path, allow_pickle=False)
        else:
            with np.load(path, allow_pickle=False) as z:
                keys = list(z.keys())
                if not keys:
                    raise ValueError(f"empty mask archive: {path}")
                arr = z[keys[0]]
        arr = np.asarray(arr)
        if arr.ndim == 3:
            arr = (arr > 0).any(axis=int(np.argmin(arr.shape)))
        elif arr.ndim != 2:
            raise ValueError(f"unsupported mask array shape {arr.shape} in {path}")
        return (arr > 0).astype(np.uint8)

    # Close the file handle deterministically instead of relying on GC.
    with Image.open(path) as im:
        gray = np.array(im.convert("L"), dtype=np.uint8)
    return (gray > 127).astype(np.uint8)

def resize_mask_soft(mask_bin: np.ndarray, out_hw):
    """Downsample a binary (H, W) mask to ``out_hw`` as soft float32 values in [0, 1].

    The mask is scaled to 0/255, resized with PIL bilinear resampling, and
    divided back by 255, so boundary cells get fractional coverage.
    """
    h_out, w_out = int(out_hw[0]), int(out_hw[1])
    img = Image.fromarray(mask_bin.astype(np.uint8) * 255)
    img = img.resize((w_out, h_out), resample=Image.BILINEAR)  # PIL size is (W, H)
    return np.asarray(img, dtype=np.float32) / 255.0

# ----------------------------
# 5) Dataset (with optional dropping missing pos masks)
# ----------------------------
DROP_MISSING_POS_MASKS = True  # recommended

def build_df_filtered(df: pd.DataFrame):
    """Return a filtered copy of ``df`` with ``token_exists`` / ``mask_exists`` flag columns.

    Rows without a cached token file are always dropped. When
    DROP_MISSING_POS_MASKS is set and MASK_DIR is known, positive rows
    (y == 1) without a mask file are dropped too. Negatives are always kept.
    """
    out = df.copy()
    has_token = out["uid"].map(lambda u: TOKEN_PATH.get(str(u), None) is not None)
    out["token_exists"] = has_token.astype(int)

    if MASK_DIR is None:
        out["mask_exists"] = 0
    else:
        has_mask = out["uid"].map(lambda u: MASK_PATH.get(str(u), None) is not None)
        out["mask_exists"] = has_mask.astype(int)

    # Rows without tokens have no input and cannot be trained on.
    out = out.loc[out["token_exists"] == 1].reset_index(drop=True)

    if DROP_MISSING_POS_MASKS and MASK_DIR is not None:
        keep = (out["y"] == 0) | (out["mask_exists"] == 1)
        out = out.loc[keep].reset_index(drop=True)

    return out

df_train_filtered = build_df_filtered(df_train_tabular)

# Report drops by cause. The previous single figure was labelled "token missing
# dropped" but also counted positives dropped for a missing mask file.
# build_df_filtered drops token-less rows first, so the remainder is mask-related.
n_dropped_total = int(df_train_tabular.shape[0] - df_train_filtered.shape[0])
n_dropped_token = int(sum(TOKEN_PATH.get(str(u), None) is None for u in df_train_tabular["uid"]))
print("After filtering:")
print("  rows:", len(df_train_filtered),
      "| forged%:", float(df_train_filtered["y"].mean())*100.0,
      "| token missing dropped:", n_dropped_token,
      "| pos-mask missing dropped:", n_dropped_total - n_dropped_token)

class TokenMaskDataset(Dataset):
    """(token grid, soft mask, uid) samples for training the mask decoder.

    ``x`` is a float32 tensor (TOKEN_DIM, Ht, Wt) built from the uid's cached
    token grid. A missing file, or one that cannot be coerced to a grid, yields
    zeros. Channels are cropped or zero-padded to TOKEN_DIM, and the spatial
    size is resized to ``TOKEN_HW == (Ht, Wt)``.
    ``y`` is a float32 tensor (1, Ht, Wt) holding the GT mask bilinearly
    downsampled to the token grid. Negatives, and positives without a mask
    file, get an all-zero target.
    With ``is_train``, cfg keys ``aug_hflip`` / ``aug_vflip`` (flip
    probabilities) and ``input_noise_std`` (Gaussian token noise) are applied.
    """

    def __init__(self, df: pd.DataFrame, is_train: bool, cfg: dict):
        self.df = df.reset_index(drop=True)
        self.is_train = bool(is_train)
        self.cfg = cfg

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _fit_token_dim(grid):
        """Crop or zero-pad the channel axis of an (H, W, D) grid to TOKEN_DIM."""
        have = grid.shape[2]
        if have == TOKEN_DIM:
            return grid
        if have > TOKEN_DIM:
            return grid[:, :, :TOKEN_DIM]
        filler = np.zeros((grid.shape[0], grid.shape[1], TOKEN_DIM - have), dtype=np.float32)
        return np.concatenate([grid, filler], axis=2)

    def _load_grid(self, uid):
        """Token grid for ``uid`` at TOKEN_HW / TOKEN_DIM (zeros when unavailable)."""
        path = TOKEN_PATH.get(uid, None)
        grid = None if path is None else load_token_grid_from_path(path)
        if grid is None:
            grid = np.zeros((TOKEN_HW[0], TOKEN_HW[1], TOKEN_DIM), dtype=np.float32)
        return resize_grid_hw(self._fit_token_dim(grid), TOKEN_HW)

    def _load_target(self, uid, positive):
        """Soft (Ht, Wt) mask for positives with a mask file; zeros otherwise."""
        blank = np.zeros((TOKEN_HW[0], TOKEN_HW[1]), dtype=np.float32)
        if not positive or MASK_DIR is None:
            return blank
        mask_path = MASK_PATH.get(uid, None)
        if mask_path is None:
            return blank
        return resize_mask_soft(load_mask_bin(mask_path), TOKEN_HW)

    def _augment(self, grid, target):
        """Random horizontal/vertical flips (applied jointly) and token noise."""
        cfg = self.cfg
        p_h = float(cfg.get("aug_hflip", 0.0))
        if p_h > 0 and np.random.rand() < p_h:
            grid = grid[:, ::-1, :].copy()
            target = target[:, ::-1].copy()
        p_v = float(cfg.get("aug_vflip", 0.0))
        if p_v > 0 and np.random.rand() < p_v:
            grid = grid[::-1, :, :].copy()
            target = target[::-1, :].copy()
        sigma = float(cfg.get("input_noise_std", 0.0))
        if sigma > 0:
            grid = grid + np.random.randn(*grid.shape).astype(np.float32) * sigma
        return grid, target

    def __getitem__(self, i):
        row = self.df.iloc[i]
        uid = str(row["uid"])
        positive = int(row["y"]) == 1

        grid = self._load_grid(uid)
        target = self._load_target(uid, positive)
        if self.is_train:
            grid, target = self._augment(grid, target)

        x = torch.from_numpy(np.transpose(grid, (2, 0, 1)).astype(np.float32))  # (D,H,W)
        y = torch.from_numpy(target[None, ...].astype(np.float32))              # (1,H,W)
        return x, y, uid

# ----------------------------
# 6) Model: UNet + ASPP
# ----------------------------
class ConvGNAct(nn.Module):
    """Conv2d (no bias) -> GroupNorm -> SiLU -> optional Dropout2d.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        k, s, p: kernel size, stride and padding of the convolution.
        drop: Dropout2d probability; 0 replaces dropout with Identity.
        groups: upper bound on the number of GroupNorm groups. The group count
            actually used is the largest divisor of ``out_ch`` that does not
            exceed ``groups``, because GroupNorm requires
            ``out_ch % num_groups == 0``. This equals ``min(groups, out_ch)``
            whenever that value is already valid.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, drop=0.0, groups=8):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        self.gn = nn.GroupNorm(self._num_groups(out_ch, groups), out_ch)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()

    @staticmethod
    def _num_groups(ch, max_groups):
        """Largest divisor of ``ch`` that is <= ``max_groups`` (at least 1)."""
        ch = int(ch)
        g = max(1, min(int(max_groups), ch))
        # Step down until g divides ch; this terminates at 1 at the latest.
        while ch % g:
            g -= 1
        return g

    def forward(self, x):
        return self.drop(self.act(self.gn(self.conv(x))))

class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling block (conv branches only).

    Each rate ``r`` in ``rates`` gives a 3x3 conv with dilation ``r`` and
    padding ``r`` (so spatial size is preserved), followed by GroupNorm and
    SiLU. The branch outputs are concatenated and projected back to ``ch``
    channels with a 1x1 conv + GroupNorm + SiLU (+ optional Dropout2d).

    GroupNorm uses the largest divisor of ``ch`` that is <= 8. This is
    identical to the former ``min(8, ch)`` whenever that was valid, and no
    longer crashes for widths that are not multiples of 8.
    """

    def __init__(self, ch, rates=(1,2,4,6), drop=0.0):
        super().__init__()
        rs = list(rates)
        ng = self._num_groups(ch, 8)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(ch, ch, 3, padding=int(r), dilation=int(r), bias=False),
                nn.GroupNorm(ng, ch),
                nn.SiLU(inplace=True),
            ) for r in rs
        ])
        self.proj = nn.Sequential(
            nn.Conv2d(ch * len(rs), ch, 1, bias=False),
            nn.GroupNorm(ng, ch),
            nn.SiLU(inplace=True),
            nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()
        )

    @staticmethod
    def _num_groups(ch, max_groups):
        """Largest divisor of ``ch`` that is <= ``max_groups`` (at least 1)."""
        ch = int(ch)
        g = max(1, min(int(max_groups), ch))
        while ch % g:
            g -= 1
        return g

    def forward(self, x):
        xs = [b(x) for b in self.branches]
        return self.proj(torch.cat(xs, dim=1))

class UNetASPP(nn.Module):
    """U-Net style mask decoder with an ASPP bottleneck.

    Encoder: a stem (1x1 projection from ``in_dim`` to C = ``base_ch``
    channels, then a 3x3 conv) and three stride-2 downsamplings to widths
    2C, 4C and 6C, each followed by two 3x3 ConvGNAct layers.
    Bottleneck (6C): conv -> ASPP(aspp_rates) -> conv.
    Decoder: at each level the coarser map is bilinearly resized to the skip
    feature's spatial size, concatenated with it, and fused by two 3x3 convs.
    A 1x1 head then produces one logit channel.

    Args:
        in_dim: number of input channels (callers pass TOKEN_DIM).
        base_ch: base channel width C.
        drop: Dropout2d probability forwarded to every ConvGNAct / ASPP.
        aspp_rates: dilation rates of the ASPP branches.

    Shapes: input (B, in_dim, H, W) -> logits (B, 1, H, W). The stem and
    enc1 keep H, W, and the head runs at enc1's resolution.

    NOTE: do not reorder submodule creation. It fixes the state_dict keys
    (checkpoints are reloaded with strict=True) and the order in which
    parameters draw from the RNG at init.
    """
    def __init__(self, in_dim, base_ch=160, drop=0.10, aspp_rates=(1,2,4,6)):
        super().__init__()
        C = int(base_ch)

        # Full-resolution stem: 1x1 projects token features to C channels.
        self.stem = nn.Sequential(
            ConvGNAct(in_dim, C, k=1, s=1, p=0, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )
        self.enc1 = nn.Sequential(
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )
        # Stride-2 convs halve H, W (rounded up) at each level.
        self.down1 = ConvGNAct(C, 2*C, k=3, s=2, p=1, drop=drop)
        self.enc2 = nn.Sequential(
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.down2 = ConvGNAct(2*C, 4*C, k=3, s=2, p=1, drop=drop)
        self.enc3 = nn.Sequential(
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.down3 = ConvGNAct(4*C, 6*C, k=3, s=2, p=1, drop=drop)

        # Multi-dilation context at the coarsest scale.
        self.bottleneck = nn.Sequential(
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
            ASPP(6*C, rates=aspp_rates, drop=drop),
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
        )

        # Decoder input widths = upsampled channels + skip channels.
        self.dec3 = nn.Sequential(
            ConvGNAct(6*C + 4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec2 = nn.Sequential(
            ConvGNAct(4*C + 2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec1 = nn.Sequential(
            ConvGNAct(2*C + C, C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )

        self.head = nn.Conv2d(C, 1, 1)

    def forward(self, x):
        """Return mask logits (B, 1, H, W) for a (B, in_dim, H, W) input."""
        x0 = self.stem(x)
        s1 = self.enc1(x0)

        x1 = self.down1(s1)
        s2 = self.enc2(x1)

        x2 = self.down2(s2)
        s3 = self.enc3(x2)

        x3 = self.down3(s3)
        b  = self.bottleneck(x3)

        # Resize to the skip's exact size (not x2), so odd H/W still align.
        u3 = F.interpolate(b,  size=s3.shape[-2:], mode="bilinear", align_corners=False)
        d3 = self.dec3(torch.cat([u3, s3], dim=1))

        u2 = F.interpolate(d3, size=s2.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.dec2(torch.cat([u2, s2], dim=1))

        u1 = F.interpolate(d2, size=s1.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.dec1(torch.cat([u1, s1], dim=1))

        return self.head(d1)  # (B,1,H,W)

# ----------------------------
# 7) EMA + loss/metrics
# ----------------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``update`` folds the current weights into the shadow copy:
    ``shadow = decay * shadow + (1 - decay) * param``.
    ``apply_shadow`` swaps the shadow weights into the model (keeping a
    backup), and ``restore`` swaps the backed-up weights back. Parameters
    with ``requires_grad=False`` are ignored throughout.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}
        for name, param in self._trainable(model):
            self.shadow[name] = param.detach().clone()
        self.backup = {}

    @staticmethod
    def _trainable(model: nn.Module):
        """Yield (name, param) for every parameter that requires grad."""
        return ((name, param) for name, param in model.named_parameters() if param.requires_grad)

    @torch.no_grad()
    def update(self, model: nn.Module):
        keep = self.decay
        for name, param in self._trainable(model):
            self.shadow[name].mul_(keep).add_(param.detach(), alpha=(1.0 - keep))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, param in self._trainable(model):
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, param in self._trainable(model):
            param.copy_(self.backup[name])
        self.backup = {}

def dice_from_probs(prob, target, eps=1e-6):
    """Soft Dice coefficient averaged over the batch.

    ``prob`` and ``target`` are (B, 1, H, W) tensors. Per sample and channel:
    ``2 * sum(prob * target) / max(sum(prob + target), eps)``. The result is
    the mean over batch (and channel).
    """
    spatial = (2, 3)
    overlap = (prob * target).sum(dim=spatial)
    denom = (prob + target).sum(dim=spatial).clamp_min(eps)
    per_sample = (overlap * 2.0) / denom
    return per_sample.mean()

def dice_hard_np(prob_np, target_np, thr):
    """Hard Dice (probabilities binarised at ``prob >= thr``), batch-averaged.

    Arrays are (B, C, H, W) NumPy arrays. A sample whose prediction and target
    are both empty scores 0 (the 1e-6 term only avoids division by zero).
    """
    axes = (2, 3)
    binarized = np.greater_equal(prob_np, thr).astype(np.float32)
    twice_inter = (binarized * target_np).sum(axis=axes) * 2.0
    denom = (binarized + target_np).sum(axis=axes) + 1e-6
    return float(np.mean(twice_inter / denom))

def loss_bce_dice(logits, target, bce_w=1.0, dice_w=1.0, focal_gamma=0.0):
    """Weighted sum of (optionally focal) BCE and soft-Dice loss.

    Args:
        logits: raw model outputs, same shape as ``target`` (B, 1, H, W).
        target: mask target in [0, 1].
        bce_w, dice_w: weights of the BCE term and of ``1 - soft Dice``.
        focal_gamma: when > 0, each pixel's BCE is scaled by
            ``(1 - p_t) ** focal_gamma``, where p_t is the predicted
            probability of the true label.

    Returns:
        scalar tensor ``bce_w * mean(BCE) + dice_w * (1 - dice)``.
    """
    per_pixel = F.binary_cross_entropy_with_logits(logits, target, reduction="none")

    use_focal = bool(focal_gamma) and focal_gamma > 0
    if use_focal:
        prob_pos = torch.sigmoid(logits)
        prob_true = prob_pos * target + (1.0 - prob_pos) * (1.0 - target)
        per_pixel = per_pixel * (1.0 - prob_true).clamp_min(0.0).pow(float(focal_gamma))
    bce_term = per_pixel.mean()

    dice_term = 1.0 - dice_from_probs(torch.sigmoid(logits), target)
    return float(bce_w) * bce_term + float(dice_w) * dice_term

@torch.no_grad()
def eval_loader(model, loader, ema=None, thr_grid=81):
    """Evaluate ``model`` on ``loader`` and sweep a hard-Dice threshold.

    Loss weights are read from the module-level ``CFG_EVAL`` dict, which
    ``train_one_fold_mask`` sets before calling this.

    Args:
        model: segmentation model returning (B, 1, H, W) logits.
        loader: yields ``(x, y, uid)`` batches with (B, 1, H, W) targets.
        ema: optional EMA helper. Its shadow weights are swapped in for the
            evaluation, and the live weights are restored afterwards even if
            evaluation raises.
        thr_grid: number of thresholds swept over [0.05, 0.95].

    Returns:
        dict with ``val_loss`` and ``val_dice_soft`` (per-sample averages),
        ``best_thr`` and ``best_dice_hard`` (the per-sample-averaged hard Dice
        at the best threshold).
    """
    from contextlib import nullcontext

    model.eval()

    thrs = np.linspace(0.05, 0.95, int(thr_grid), dtype=np.float64)
    dice_sweep = np.zeros_like(thrs, dtype=np.float64)
    tot_loss = 0.0
    tot_dice_soft = 0.0
    n_seen = 0

    if ema is not None:
        ema.apply_shadow(model)
    try:
        for xb, yb, _ in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            bs = int(xb.size(0))

            amp_ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()
            with amp_ctx:
                logits = model(xb)
                loss = loss_bce_dice(
                    logits, yb,
                    bce_w=float(CFG_EVAL["bce_weight"]),
                    dice_w=float(CFG_EVAL["dice_weight"]),
                    focal_gamma=float(CFG_EVAL.get("focal_gamma", 0.0)),
                )

            prob = torch.sigmoid(logits)
            dsoft = dice_from_probs(prob, yb)

            # Weight each batch mean by its size. With drop_last=False the last
            # batch is usually smaller, and a plain mean of batch means would
            # over-weight it.
            tot_loss += float(loss.item() if hasattr(loss, "item") else loss) * bs
            tot_dice_soft += float(dsoft.item()) * bs
            n_seen += bs

            prob_np = prob.detach().float().cpu().numpy()
            y_np = yb.detach().float().cpu().numpy()
            for i, t in enumerate(thrs):
                dice_sweep[i] += dice_hard_np(prob_np, y_np, thr=float(t)) * bs
    finally:
        # Never leave EMA weights installed in the live model.
        if ema is not None:
            ema.restore(model)

    denom = max(1, n_seen)
    dice_sweep = dice_sweep / denom
    j = int(np.argmax(dice_sweep))
    return {
        "val_loss": tot_loss / denom,
        "val_dice_soft": tot_dice_soft / denom,
        "best_thr": float(thrs[j]),
        "best_dice_hard": float(dice_sweep[j]),
    }

# ----------------------------
# 8) Train one fold (mask)
# ----------------------------
def train_one_fold_mask(df_tr, df_va, cfg):
    """Train a UNetASPP mask model on one CV fold and return its checkpoint pack.

    Training setup:
    - AdamW with linear warmup followed by cosine decay, stepped once per
      optimizer step.
    - Optional gradient accumulation, AMP when the global ``use_amp`` is set,
      optional EMA, and early stopping on validation loss.
    - The weights kept are those of the best-val-loss epoch (the EMA weights
      when EMA is enabled).

    Args:
        df_tr, df_va: training / validation rows (consumed by TokenMaskDataset).
        cfg: hyper-parameter dict (see ``make_base_cfg`` for the keys read).

    Returns:
        dict "pack" with a CPU state_dict, the cfg, the token geometry and the
        best-epoch validation metrics.
    """
    seed_everything(int(cfg["seed"]))

    ds_tr = TokenMaskDataset(df_tr, is_train=True, cfg=cfg)
    ds_va = TokenMaskDataset(df_va, is_train=False, cfg=cfg)

    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    loader_kw = dict(
        batch_size=int(cfg["batch_size"]),
        num_workers=nw,
        pin_memory=(device.type == "cuda"),
        drop_last=False,
        persistent_workers=(nw > 0),
    )
    dl_tr = DataLoader(ds_tr, shuffle=True, **loader_kw)
    dl_va = DataLoader(ds_va, shuffle=False, **loader_kw)

    model = UNetASPP(
        in_dim=int(TOKEN_DIM),
        base_ch=int(cfg["base_ch"]),
        drop=float(cfg["drop"]),
        aspp_rates=tuple(cfg["aspp_rates"]),
    ).to(device)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(0.9, 0.99),
        eps=1e-8,
    )

    # The schedule is defined in optimizer steps; the trailing partial
    # accumulation window of each epoch also counts as one step (ceil).
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_steps = int(cfg["epochs"]) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    def lr_lambda(step):
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        t = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        t = min(max(t, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp) if use_amp else None
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    grad_clip = float(cfg["grad_clip"])
    loss_kw = dict(
        bce_w=float(cfg["bce_weight"]),
        dice_w=float(cfg["dice_weight"]),
        focal_gamma=float(cfg.get("focal_gamma", 0.0)),
    )

    best = {"val_loss": 1e18, "epoch": -1, "best_thr": 0.5, "best_dice_hard": -1.0, "val_dice_soft": -1.0}
    best_state = None
    bad = 0
    opt_step = 0

    def _snapshot_state():
        """Return a CPU copy of the model's state_dict that never aliases live tensors."""
        # `.cpu()` alone is a no-op for tensors already on the CPU, so on CPU runs
        # a "best" snapshot would keep tracking training (and, after
        # ema.restore, would hold the non-EMA weights). `copy=True` forces a copy.
        return {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

    def _optimizer_step():
        """Clip (optional), step optimizer + scheduler, zero grads, update EMA."""
        nonlocal opt_step
        if grad_clip > 0:
            if use_amp:
                scaler.unscale_(opt)  # clip on true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        if use_amp:
            scaler.step(opt)
            scaler.update()
        else:
            opt.step()

        opt.zero_grad(set_to_none=True)
        sch.step()
        opt_step += 1
        if ema is not None:
            ema.update(model)

    global CFG_EVAL
    CFG_EVAL = cfg  # for eval_loader access

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum, n_sum, micro = 0.0, 0, 0
        t0 = time.time()

        for xb, yb, _ in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            # Loss is divided by accum_steps so accumulated grads average over micro-batches.
            if use_amp:
                with torch.cuda.amp.autocast(enabled=True):
                    logits = model(xb)
                    loss = loss_bce_dice(logits, yb, **loss_kw) / accum_steps
                scaler.scale(loss).backward()
            else:
                logits = model(xb)
                loss = loss_bce_dice(logits, yb, **loss_kw) / accum_steps
                loss.backward()

            micro += 1
            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro % accum_steps) == 0:
                _optimizer_step()

        # Flush gradients left over from an incomplete accumulation window.
        if (micro % accum_steps) != 0:
            _optimizer_step()

        ev = eval_loader(model, dl_va, ema=ema, thr_grid=int(cfg["thr_grid"]))
        dt = time.time() - t0
        print(f"  ep {epoch+1:03d}/{cfg['epochs']} | tr_loss={loss_sum/max(1,n_sum):.5f} | "
              f"val_loss={ev['val_loss']:.5f} | val_dice_soft={ev['val_dice_soft']:.5f} | "
              f"best_dice_hard={ev['best_dice_hard']:.5f}@{ev['best_thr']:.2f} | opt_step={opt_step} | dt={dt:.1f}s")

        improved = (best["val_loss"] - float(ev["val_loss"])) > float(cfg["min_delta"])
        if improved:
            best["val_loss"] = float(ev["val_loss"])
            best["val_dice_soft"] = float(ev["val_dice_soft"])
            best["best_thr"] = float(ev["best_thr"])
            best["best_dice_hard"] = float(ev["best_dice_hard"])
            best["epoch"] = int(epoch)

            # Snapshot the weights that were actually evaluated (EMA if enabled).
            if ema is not None:
                ema.apply_shadow(model)
                best_state = _snapshot_state()
                ema.restore(model)
            else:
                best_state = _snapshot_state()
            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["patience"]):
                print(f"  early stop: best_epoch={best['epoch']+1}, best_val_loss={best['val_loss']:.5f}")
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    pack = {
        "arch": "UNetASPP_on_DINOv2TokenGrid_opt_v4.0",
        "state_dict": _snapshot_state(),
        "cfg": dict(cfg),
        "token_dim": int(TOKEN_DIM),
        "token_hw": tuple(map(int, TOKEN_HW)),
        "best_epoch": int(best["epoch"] + 1),
        "best_val_loss": float(best["val_loss"]),
        "best_val_dice_soft": float(best["val_dice_soft"]),
        "best_thr": float(best["best_thr"]),
        "best_val_dice_hard": float(best["best_dice_hard"]),
    }
    return pack

# ----------------------------
# 9) CV evaluator for a config (store only scalars OOF)
# ----------------------------
@torch.no_grad()
def infer_val_scalars(model, loader, thr=0.5):
    """Run ``model`` over ``loader`` and summarise each sample by two scalars.

    Returns:
        (uids, area_frac, mean_prob):
        - uids: list of uids in loader order.
        - area_frac: float32 array, the fraction of pixels with ``prob >= thr``.
        - mean_prob: float32 array, the mean sigmoid probability per sample.
    """
    model.eval()
    cut = float(thr)
    area_chunks, meanp_chunks, uids = [], [], []

    for xb, _, uidb in loader:
        xb = xb.to(device, non_blocking=True)
        if use_amp:
            with torch.cuda.amp.autocast(enabled=True):
                prob = torch.sigmoid(model(xb))
        else:
            prob = torch.sigmoid(model(xb))

        hard = prob.ge(cut).float()
        area_chunks.append(hard.mean(dim=(1, 2, 3)).detach().cpu().numpy())  # (B,)
        meanp_chunks.append(prob.mean(dim=(1, 2, 3)).detach().cpu().numpy())
        uids += list(uidb)

    area_frac = np.concatenate(area_chunks, axis=0).astype(np.float32)
    mean_prob = np.concatenate(meanp_chunks, axis=0).astype(np.float32)
    return uids, area_frac, mean_prob

def run_cv_config(cfg, cfg_name, folds_subset=None):
    """Cross-validate one mask config; collect per-fold metrics and OOF scalars.

    Args:
        cfg: hyper-parameter dict. Each fold gets seed ``cfg["seed"] + 101 * fold``.
        cfg_name: label stored in every result row.
        folds_subset: folds to run; None means all of ``unique_folds``.

    Returns:
        ``(summary, fold_rows, oof_area, oof_mean, fold_packs)``. The two OOF
        dicts map ``str(uid)`` to a scalar, or NaN for uids outside the
        evaluated folds.
    """
    use_folds = unique_folds if folds_subset is None else list(folds_subset)

    fold_rows = []
    fold_packs = []
    # store per-uid scalars (not full masks). Key by str(uid): the per-fold
    # update below writes str keys and downstream lookups use `.astype(str)`,
    # so seeding with raw (possibly non-str) uids would leave stale NaN twins.
    all_uids = df_train_filtered["uid"].astype(str).tolist()
    oof_area = dict.fromkeys(all_uids, np.nan)
    oof_mean = dict.fromkeys(all_uids, np.nan)

    # DataLoader settings are identical for every fold.
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    for f in use_folds:
        df_tr = df_train_filtered[df_train_filtered["fold"] != f].reset_index(drop=True)
        df_va = df_train_filtered[df_train_filtered["fold"] == f].reset_index(drop=True)

        # unique seed per fold for stability
        cfg_fold = dict(cfg)
        cfg_fold["seed"] = int(cfg.get("seed", SEED)) + int(f)*101

        print(f"    -> fold {f} | tr={len(df_tr)} va={len(df_va)} pos_va={int(df_va['y'].sum())}")

        pack = train_one_fold_mask(df_tr, df_va, cfg_fold)

        # val scalar inference (use model from pack)
        ds_va = TokenMaskDataset(df_va, is_train=False, cfg=cfg_fold)
        dl_va = DataLoader(ds_va, batch_size=int(cfg_fold["batch_size"]), shuffle=False,
                           num_workers=nw, pin_memory=pin, drop_last=False,
                           persistent_workers=(nw > 0))

        model = UNetASPP(in_dim=int(TOKEN_DIM), base_ch=int(cfg_fold["base_ch"]),
                         drop=float(cfg_fold["drop"]), aspp_rates=tuple(cfg_fold["aspp_rates"])).to(device)
        model.load_state_dict(pack["state_dict"], strict=True)

        ulist, area_frac, mean_prob = infer_val_scalars(model, dl_va, thr=float(pack["best_thr"]))

        for u, af, mp in zip(ulist, area_frac, mean_prob):
            oof_area[str(u)] = float(af)
            oof_mean[str(u)] = float(mp)

        fold_rows.append({
            "cfg": cfg_name,
            "fold": int(f),
            "best_epoch": int(pack["best_epoch"]),
            "best_val_loss": float(pack["best_val_loss"]),
            "best_val_dice_soft": float(pack["best_val_dice_soft"]),
            "best_thr": float(pack["best_thr"]),
            "best_val_dice_hard": float(pack["best_val_dice_hard"]),
        })

        pack2 = dict(pack)
        pack2["fold"] = int(f)
        fold_packs.append(pack2)

        # Drop the val loader as well, so its persistent workers shut down now
        # instead of lingering into the next fold's training.
        del model, dl_va, ds_va
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    df_f = pd.DataFrame(fold_rows)
    # aggregate score: mean of best_val_dice_hard across folds (works for stage1 subset too)
    score = float(df_f["best_val_dice_hard"].mean()) if len(df_f) else 0.0
    thr_med = float(np.median(df_f["best_thr"].to_numpy(dtype=np.float64))) if len(df_f) else 0.5

    summary = {
        "cfg": cfg_name,
        "stage": "full" if folds_subset is None else f"subset{len(use_folds)}",
        "score_mean_best_dice_hard": score,
        "thr_median": thr_med,

        "base_ch": int(cfg["base_ch"]),
        "drop": float(cfg["drop"]),
        "aspp_rates": str(tuple(cfg["aspp_rates"])),

        "epochs": int(cfg["epochs"]),
        "batch_size": int(cfg["batch_size"]),
        "accum_steps": int(cfg.get("accum_steps", 1)),
        "lr": float(cfg["lr"]),
        "weight_decay": float(cfg["weight_decay"]),
        "warmup_frac": float(cfg["warmup_frac"]),
        "grad_clip": float(cfg["grad_clip"]),
        "patience": int(cfg["patience"]),
        "min_delta": float(cfg["min_delta"]),
        "use_ema": bool(cfg.get("use_ema", True)),
        "ema_decay": float(cfg.get("ema_decay", 0.999)),

        "bce_weight": float(cfg["bce_weight"]),
        "dice_weight": float(cfg["dice_weight"]),
        "focal_gamma": float(cfg.get("focal_gamma", 0.0)),
        "input_noise_std": float(cfg.get("input_noise_std", 0.0)),
        "aug_hflip": float(cfg.get("aug_hflip", 0.0)),
        "aug_vflip": float(cfg.get("aug_vflip", 0.0)),
        "thr_grid": int(cfg["thr_grid"]),
    }

    # scalars oof (dict) only for folds evaluated
    return summary, fold_rows, oof_area, oof_mean, fold_packs

# ----------------------------
# 10) Candidate configs (mask)
# ----------------------------
def make_base_cfg():
    """Build the baseline mask-training config, sized to the current device.

    On CUDA, batch size and accumulation are picked from the global MEM_GB.
    CPU runs get a small batch, fewer epochs and a narrower network.
    """
    on_gpu = device.type == "cuda"

    if not on_gpu:
        bs, acc = 12, 1
    elif MEM_GB >= 30:
        bs, acc = 48, 1
    elif MEM_GB >= 16:
        bs, acc = 32, 1
    else:
        bs, acc = 24, 2

    return {
        # optimisation schedule
        "seed": SEED,
        "epochs": 55 if on_gpu else 30,
        "batch_size": bs,
        "accum_steps": acc,
        "lr": 3e-4,
        "weight_decay": 1e-2,
        "warmup_frac": 0.05,
        "grad_clip": 1.0,
        "patience": 10,
        "min_delta": 1e-4,
        "use_ema": True,
        "ema_decay": 0.999,
        # architecture
        "base_ch": 160 if on_gpu else 128,
        "drop": 0.10,
        "aspp_rates": (1, 2, 4, 6),
        # loss
        "bce_weight": 1.0,
        "dice_weight": 1.0,
        "focal_gamma": 0.0,
        # augmentation
        "input_noise_std": 0.008,
        "aug_hflip": 0.5,
        "aug_vflip": 0.2,
        # threshold sweep resolution
        "thr_grid": THR_GRID,
    }

BASE = make_base_cfg()

# Each candidate is (name, BASE with overrides); keys keep BASE's order.
candidates = [
    ("mask_160_r1246", {**BASE, "base_ch": 160, "drop": 0.10, "aspp_rates": (1, 2, 4, 6),
                        "lr": 3e-4, "weight_decay": 1e-2}),
    ("mask_192_reg",   {**BASE, "base_ch": 192, "drop": 0.14, "aspp_rates": (1, 2, 4, 6),
                        "lr": 2.5e-4, "weight_decay": 1.5e-2,
                        "input_noise_std": 0.010, "dice_weight": 1.2, "bce_weight": 0.9}),
    ("mask_128_fast",  {**BASE, "base_ch": 128, "drop": 0.10, "aspp_rates": (1, 2, 3, 4),
                        "lr": 3.5e-4, "weight_decay": 8e-3,
                        "epochs": min(int(BASE["epochs"]), 45), "patience": 9}),
    ("mask_focal",     {**BASE, "base_ch": 160, "drop": 0.12, "aspp_rates": (1, 2, 4, 6),
                        "lr": 3e-4, "weight_decay": 1.2e-2,
                        "focal_gamma": 1.5, "bce_weight": 1.0, "dice_weight": 1.1}),
]

# The widest model is only tried on GPUs with enough memory.
if device.type == "cuda" and MEM_GB >= 20:
    candidates += [
        ("mask_big_224", {**BASE, "base_ch": 224, "drop": 0.16, "aspp_rates": (1, 2, 4, 8),
                          "lr": 2.0e-4, "weight_decay": 2.0e-2,
                          "epochs": max(int(BASE["epochs"]), 65), "patience": 12,
                          "dice_weight": 1.25, "bce_weight": 0.85, "input_noise_std": 0.012}),
    ]

print(f"\nTotal mask candidates: {len(candidates)}")
print("Primary score: mean(best_val_dice_hard) across folds (subset for stage1, full for stage2)")
print("Primary score: mean(best_val_dice_hard) across folds (subset for stage1, full for stage2)")

# ----------------------------
# 11) Run 2-stage search (resume-safe)
# ----------------------------
# Artifact locations for the 2-stage search.
OUT_DIR = Path("/kaggle/working/recodai_luc_mask_artifacts")
OPT_DIR = OUT_DIR / "opt_search"
OPT_DIR.mkdir(parents=True, exist_ok=True)
STAGE1_PATH = OPT_DIR / "stage1_results.csv"

# Stage-1 uses STAGE1_FOLDS folds spread evenly over unique_folds.
if STAGE1_FOLDS < len(unique_folds):
    spread = np.round(np.linspace(0, len(unique_folds) - 1, STAGE1_FOLDS)).astype(int)
    folds_subset = [unique_folds[k] for k in np.unique(spread).tolist()]
    # Rounding can collapse indices; top up with the first unused folds.
    if len(folds_subset) < STAGE1_FOLDS:
        for f in unique_folds:
            if f not in folds_subset:
                folds_subset.append(f)
            if len(folds_subset) >= STAGE1_FOLDS:
                break
else:
    folds_subset = unique_folds

print("\nStage-1 folds subset:", folds_subset)

# Resume support: configs already recorded in stage1_results.csv are skipped.
done_stage1 = set()
if STAGE1_PATH.exists():
    try:
        df_prev = pd.read_csv(STAGE1_PATH)
        if "cfg" in df_prev.columns:
            done_stage1 = set(df_prev["cfg"].astype(str).tolist())
            print(f"Resume: found {len(done_stage1)} configs already in stage1_results.csv")
    except Exception as e:
        # Best-effort resume: an unreadable cache just means every config re-runs,
        # but say so instead of silently ignoring it.
        print(f"Warning: could not read {STAGE1_PATH} for resume ({e!r}); running all candidates.")
        done_stage1 = set()

t0 = time.time()

for i, (name, cfg) in enumerate(candidates, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop search.")
        break

    if name in done_stage1:
        print(f"\n[Stage-1 {i:02d}/{len(candidates)}] SKIP -> {name}")
        continue

    # Stage-1 is a cheap screen: cap epochs and patience.
    cfg1 = dict(cfg)
    cfg1["epochs"] = int(min(int(cfg1["epochs"]), int(STAGE1_EPOCH_CAP)))
    cfg1["patience"] = int(min(int(cfg1["patience"]), int(STAGE1_PAT_CAP)))

    print(f"\n[Stage-1 {i:02d}/{len(candidates)}] CV(subset) -> {name}")
    summ, fold_rows, _, _, _ = run_cv_config(cfg1, name, folds_subset=folds_subset)

    print(f"  stage1 score(mean dice hard): {summ['score_mean_best_dice_hard']:.6f} | thr_median: {summ['thr_median']:.3f}")

    # append to CSV (resume-safe)
    try:
        df_append = pd.DataFrame([summ])
        if STAGE1_PATH.exists():
            df_old = pd.read_csv(STAGE1_PATH)
            df_new = pd.concat([df_old, df_append], axis=0, ignore_index=True)
        else:
            df_new = df_append
        df_new.to_csv(STAGE1_PATH, index=False)
    except Exception as e:
        # Keep the search running, but make the lost row visible: the ranking
        # below is read back from this CSV.
        print(f"  Warning: failed to append stage-1 result for {name} to {STAGE1_PATH}: {e!r}")

# Rank stage-1 results (read back from the resume-safe CSV).
df_s1 = pd.read_csv(STAGE1_PATH) if STAGE1_PATH.exists() else pd.DataFrame([])

if len(df_s1) == 0:
    raise RuntimeError("Stage-1 menghasilkan 0 hasil. Cek token/mask cache atau turunkan kandidat/epochs.")

# Best score first; ties broken by the lower median threshold.
df_s1 = df_s1.sort_values(
    by=["score_mean_best_dice_hard", "thr_median"],
    ascending=[False, True],
).reset_index(drop=True)
print("\nStage-1 ranking (top):")
display(df_s1.head(10))

topM = min(int(STAGE2_TOPM), len(df_s1))
stage2_names = df_s1["cfg"].head(topM).astype(str).tolist()
print("\nStage-2 will run full CV for:", stage2_names)

# Stage-2: full CV for the top stage-1 configs; keep summaries, fold rows,
# OOF scalars and fold packs per config name.
all_summaries, all_fold_rows = [], []
oof_area_store, oof_mean_store, pack_store = {}, {}, {}

for j, nm in enumerate(stage2_names, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop stage-2.")
        break

    # Stage-1 CSV may list names no longer in `candidates` (e.g. after edits).
    cfg = next((ccfg for nname, ccfg in candidates if nname == nm), None)
    if cfg is None:
        continue

    print(f"\n[Stage-2 {j:02d}/{len(stage2_names)}] CV(full) -> {nm}")
    summ, fold_rows, oof_area, oof_mean, fold_packs = run_cv_config(cfg, nm, folds_subset=None)

    all_summaries.append(summ)
    all_fold_rows.extend(fold_rows)
    oof_area_store[nm], oof_mean_store[nm], pack_store[nm] = oof_area, oof_mean, fold_packs

    print(f"  OOF score(mean dice hard): {summ['score_mean_best_dice_hard']:.6f} | thr_median: {summ['thr_median']:.3f}")

df_sum = pd.DataFrame(all_summaries)
df_fold = pd.DataFrame(all_fold_rows)

if len(df_sum) == 0:
    raise RuntimeError("Stage-2 produced no results. Turunkan kandidat/epochs atau cek device/VRAM.")

df_sum = df_sum.sort_values(["score_mean_best_dice_hard"], ascending=[False]).reset_index(drop=True)

print("\nStage-2 top candidates (full CV):")
display(df_sum)

# Persist search results (CSV + JSON records + per-fold details).
df_sum.to_csv(OPT_DIR / "opt_results.csv", index=False)
(OPT_DIR / "opt_results.json").write_text(json.dumps(df_sum.to_dict(orient="records"), indent=2))
df_fold.to_csv(OPT_DIR / "opt_fold_details.csv", index=False)

# Export OOF scalars for the top configs, aligned to df_train_filtered row order.
ulist = df_train_filtered["uid"].astype(str).tolist()
top_names = df_sum["cfg"].head(min(REPORT_TOPK_OOF, len(df_sum))).astype(str).tolist()
for nm in top_names:
    area_map, mean_map = oof_area_store[nm], oof_mean_store[nm]
    df_o = df_train_filtered[["uid","y","fold"]].copy()
    df_o[f"oof_area_frac_{nm}"] = np.array([area_map.get(u, np.nan) for u in ulist], dtype=np.float32)
    df_o[f"oof_mean_prob_{nm}"] = np.array([mean_map.get(u, np.nan) for u in ulist], dtype=np.float32)
    df_o.to_csv(OPT_DIR / f"oof_scalars_{nm}.csv", index=False)

# ----------------------------
# 12) Choose best config + save BEST fold packs
# ----------------------------
# Pick the top stage-2 config (df_sum is sorted by score, best first).
best_single = df_sum.iloc[0].to_dict()
best_cfg_name = str(best_single["cfg"])

# Recover the original (uncapped) cfg dict for that name.
best_cfg = None
for nm, cfg in candidates:
    if nm == best_cfg_name:
        best_cfg = cfg
        break
if best_cfg is None:
    raise RuntimeError("Best cfg not found in candidates list (unexpected).")

# recommended_thr is the median of the per-fold best thresholds (run_cv_config's thr_median).
best_fold_packs = pack_store[best_cfg_name]
recommended_thr = float(best_single["thr_median"])
best_score = float(best_single["score_mean_best_dice_hard"])

# Checkpoint: the per-fold packs from stage-2 CV (no full-data retrain here).
best_model_path = OUT_DIR / "best_mask_model.pt"
torch.save(
    {
        "type": "unet_aspp_decoder_on_dinov2_tokengrid_opt_v4",
        "cfg_name": best_cfg_name,
        "cfg": best_cfg,
        "seed": SEED,
        "token_dim": int(TOKEN_DIM),
        "token_hw": tuple(map(int, TOKEN_HW)),
        "fold_packs": best_fold_packs,
        "recommended_thr": recommended_thr,
        "best_oof_score_mean_dice_hard": best_score,
        "notes": "Best config from Step 4 (mask optimization). fold_packs only (no full retrain here).",
    },
    best_model_path
)

# JSON-serialisable summary consumed by Step 5. Note the seed is stored as
# "random_seed" here (vs "seed" in the .pt); Step 5 reads either key.
# Tuples (token_hw, cfg["aspp_rates"]) become JSON lists.
best_bundle = {
    "type": "unet_aspp_decoder_on_dinov2_tokengrid_opt_v4",
    "model_name": best_cfg_name,
    "random_seed": SEED,
    "token_dim": int(TOKEN_DIM),
    "token_hw": tuple(map(int, TOKEN_HW)),

    "cfg": best_cfg,
    "recommended_thr": recommended_thr,
    "best_oof_score_mean_dice_hard": best_score,

    "paths": {
        "opt_results_csv": str(OPT_DIR / "opt_results.csv"),
        "opt_fold_details_csv": str(OPT_DIR / "opt_fold_details.csv"),
        "stage1_results_csv": str(STAGE1_PATH),
        "best_model_pt": str(best_model_path),
    },
}

with open(OUT_DIR / "best_mask_config.json", "w") as f:
    json.dump(best_bundle, f, indent=2)

print("\nSaved best artifacts:")
print("  best model (fold packs) ->", best_model_path)
print("  best config             ->", OUT_DIR / "best_mask_config.json")
print("  opt results             ->", OPT_DIR / "opt_results.csv")
print("  fold detail             ->", OPT_DIR / "opt_fold_details.csv")
print("  stage1 cache            ->", STAGE1_PATH)

# Export globals for next steps
# (Step 5 checks for BEST_MASK_BUNDLE in memory before falling back to the JSON file.)
BEST_MASK_BUNDLE = best_bundle
BEST_MASK_CFG_NAME = best_cfg_name
BEST_MASK_CFG = best_cfg
OPT_RESULTS_DF_MASK = df_sum


# Final Training (Train on Full Data)

In [None]:
# ============================================================
# Step 5 — Final Training (Train on Full Data) — MASK ONLY
# REVISI FULL v1.0 (match Step 4 MASK optimizer: UNet+ASPP on DINOv2 token-grid)
#
# Sesuai saran sebelumnya:
# - Kalau Step 3 kamu sudah pakai UNet+ASPP mask-decoder di atas token-grid DINOv2,
#   maka Step 5 final training juga harus train MASK model (bukan transformer gate).
#
# Fix/fitur:
# - STRICT load cfg terbaik dari Step 4: best_mask_config.json (tanpa upscale kecuali ON)
# - AMP/GradScaler aman: CPU => no-amp, no-scaler
# - Internal val case-level (stratified by case y) untuk cari best_epoch + recommended_thr
# - Retrain FULL data pakai epochs_fixed (best_epoch * 1.05) + EMA final-weights
# - OOM fallback aman: turun batch -> base_ch -> accum
# - Token-grid size guard: auto-detect mode (H,W,D) lalu resize semua token+mask ke itu
#
# Output:
#   /kaggle/working/recodai_luc_mask_artifacts/final_mask_model.pt
#   /kaggle/working/recodai_luc_mask_artifacts/final_mask_bundle.json
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path
from collections import Counter
from contextlib import nullcontext

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) REQUIRE
# ----------------------------
if "df_train_tabular" not in globals():
    raise RuntimeError("Missing `df_train_tabular`. Pastikan Step 2 (build table) sudah jalan.")
# Rebind to a copy before normalising dtypes below; any other references to
# the original frame are left untouched.
df_train_tabular = df_train_tabular.copy()

need_cols = {"uid","case_id","y"}
# sorted(): `need_cols` is a set, so plain iteration order (and therefore the
# error message) could differ between interpreter runs.
miss = sorted(c for c in need_cols if c not in df_train_tabular.columns)
if miss:
    raise ValueError(f"df_train_tabular missing columns: {miss}")

df_train_tabular["uid"] = df_train_tabular["uid"].astype(str)
df_train_tabular["case_id"] = df_train_tabular["case_id"].astype(str)
df_train_tabular["y"] = df_train_tabular["y"].astype(int)

print("Final MASK training data:")
print(f"  rows={len(df_train_tabular)} | forged%={float(df_train_tabular['y'].mean())*100:.2f}")

OUT_DIR = Path("/kaggle/working/recodai_luc_mask_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) Load best cfg (Step 4 output)
# ----------------------------
cfg_path = OUT_DIR / "best_mask_config.json"
best_model_candidates = [OUT_DIR / fname for fname in ("best_mask_model.pt", "best_mask_model.pth")]

# Prefer the bundle left in memory by Step 4; otherwise fall back to its JSON file.
best_bundle, source = None, None
if isinstance(globals().get("BEST_MASK_BUNDLE"), dict):
    best_bundle, source = BEST_MASK_BUNDLE, "memory(BEST_MASK_BUNDLE)"
elif cfg_path.exists():
    best_bundle, source = json.loads(cfg_path.read_text()), str(cfg_path)

if isinstance(best_bundle, dict) and isinstance(best_bundle.get("cfg", None), dict):
    base_cfg = dict(best_bundle["cfg"])
    print("\nLoaded cfg from:", source)
else:
    base_cfg = {}
    print("\nNo best_mask_config found. Using strong default cfg.")

# optional: load fold_packs dari best_mask_model.pt
# Optional: reuse Step-4 fold packs from the first checkpoint file that exists.
fold_packs_from_step4 = None
best_mask_model_path = next((p for p in best_model_candidates if p.exists()), None)

if best_mask_model_path is not None:
    try:
        obj = torch.load(best_mask_model_path, map_location="cpu")
        packs = obj.get("fold_packs", None) if isinstance(obj, dict) else None
        if isinstance(packs, list):
            fold_packs_from_step4 = packs
            print("Loaded fold_packs from:", str(best_mask_model_path))
    except Exception as e:
        # Best-effort: training proceeds without the Step-4 packs.
        print("Warning: failed to load best_mask_model.* fold_packs:", repr(e))

# ----------------------------
# 2) Device + seed
# ----------------------------
def seed_everything(seed: int):
    """Seed Python's `random`, NumPy and torch (CPU + all CUDA devices).

    PYTHONHASHSEED is exported too; it only affects interpreters started
    afterwards. cuDNN is left in fast mode (benchmark on, deterministic off),
    so GPU runs are not bit-exact reproducible.
    """
    import random

    for seed_fn in (random.seed, np.random.seed):
        seed_fn(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed: prefer the seed recorded in the Step-4 bundle ("seed", then
# "random_seed"), else 2025.
FINAL_SEED = 2025
if isinstance(best_bundle, dict):
    FINAL_SEED = int(best_bundle.get("seed", best_bundle.get("random_seed", 2025)))

seed_everything(FINAL_SEED)

# Mixed precision is used exactly when a CUDA device is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")

# Allow reduced-precision (e.g. TF32) float32 matmuls where supported.
# Best-effort: older torch versions may not have this API.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Total memory of GPU 0 in GiB (None on CPU); read by maybe_upscale_cfg.
vram_gb = None
if device.type == "cuda":
    vram_gb = float(torch.cuda.get_device_properties(0).total_memory / (1024**3))

print("\nDevice:", device, "| AMP:", use_amp, "| VRAM_GB:", (f"{vram_gb:.1f}" if vram_gb else "CPU"))

# ----------------------------
# 3) Training policy
# ----------------------------
ALLOW_UPSCALE = False          # True = let maybe_upscale_cfg adjust batch size from available VRAM
USE_INTERNAL_VAL = True        # find best_epoch + recommended_thr on an internal (case-level) val split
VAL_FRAC_CASE = 0.08           # fraction of cases held out for validation (8%)
EARLY_STOP = True              # stop internal-val training after `patience` epochs without improvement

N_SEEDS = 1                    # number of seeds trained by the final training loop

DROP_MISSING_POS_MASKS = True  # recommended: positives without a mask would train on empty targets (noisy)
TIME_BUDGET_SEC = 0            # 0 = off; NOTE(review): not read anywhere in this section - confirm it is used

# ----------------------------
# 4) Auto-detect dirs (TOKEN_CACHE_DIR + MASK_DIR)
# ----------------------------
def _first_existing(paths):
    for p in paths:
        if p is None:
            continue
        p = Path(p)
        if p.exists():
            return p
    return None

# MASK_DIR: prefer paths registered in PATHS by an earlier stage, then known
# Kaggle input/working locations. Stays None when none of them exists.
MASK_DIR = None
if "PATHS" in globals() and isinstance(PATHS, dict):
    MASK_DIR = _first_existing([PATHS.get("TRAIN_MASK_DIR"), PATHS.get("MASK_DIR"), PATHS.get("TRAIN_MASKS")])
if MASK_DIR is None:
    MASK_DIR = _first_existing([
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks_train",
        "/kaggle/input/recodai-luc-scientific-image-forgery-detection/masks",
        "/kaggle/working/recodai_luc/train_masks",
    ])
print("MASK_DIR:", str(MASK_DIR) if MASK_DIR else "(None)")

# TOKEN_CACHE_DIR: an existing CACHE_ROOT global wins; otherwise scan the base
# candidates below.
TOKEN_CACHE_DIR = None
if "CACHE_ROOT" in globals():
    try:
        cr = Path(CACHE_ROOT)
        if cr.exists():
            TOKEN_CACHE_DIR = cr
    except Exception:
        # Best-effort: an unusable CACHE_ROOT value falls through to the scan.
        pass

if TOKEN_CACHE_DIR is None:
    base_candidates = [
        Path("/kaggle/working/recodai_luc/cache/dino_v2"),
        Path("/kaggle/working/recodai_luc/cache"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache/dino_v2"),
        Path("/kaggle/input/recod-ailuc-dinov2-train/recodai_luc/cache"),
    ]
    # Only the FIRST existing base is considered: its most recently modified
    # cfg_* subfolder if it has any, otherwise the base itself.
    # NOTE(review): later candidates are never examined once an existing base
    # is found, even if that base holds no token files. Confirm this is intended.
    for b in base_candidates:
        if not b.exists():
            continue
        cfg_dirs = sorted([p for p in b.glob("cfg_*") if p.is_dir()], key=lambda p: p.stat().st_mtime, reverse=True)
        if cfg_dirs:
            TOKEN_CACHE_DIR = cfg_dirs[0]
            break
        TOKEN_CACHE_DIR = b
        break

print("TOKEN_CACHE_DIR:", str(TOKEN_CACHE_DIR) if TOKEN_CACHE_DIR else "(None)")
if TOKEN_CACHE_DIR is None:
    # Message: "TOKEN_CACHE_DIR not found. Make sure the per-uid DINOv2 token-grid cache is available."
    raise RuntimeError("TOKEN_CACHE_DIR tidak ditemukan. Pastikan cache token-grid DINOv2 per uid tersedia.")

# ----------------------------
# 5) Token/mask path map + infer token grid mode (H,W,D)
# ----------------------------
def _load_np_any(p: Path):
    p = Path(p)
    if p.suffix.lower() == ".npy":
        return np.load(p, allow_pickle=False)
    if p.suffix.lower() == ".npz":
        z = np.load(p, allow_pickle=False)
        keys = list(z.keys())
        for k in keys:
            a = z[k]
            if isinstance(a, np.ndarray) and a.ndim in (2,3):
                return a
        return z[keys[0]] if keys else None
    return None

def _reshape_to_grid(a: np.ndarray):
    a = np.asarray(a)
    if a.ndim == 3:
        # (D,H,W) -> (H,W,D)
        if a.shape[0] in [256, 384, 512, 768, 1024, 1536] and (a.shape[1] * a.shape[2] > 16):
            D,H,W = a.shape
            return np.transpose(a, (1,2,0))
        return a  # assume (H,W,D)
    if a.ndim == 2:
        N,D = a.shape
        s = int(round(math.sqrt(N)))
        if s*s == N:
            return a.reshape(s,s,D)
        for h in range(1, int(math.sqrt(N))+1):
            if N % h == 0:
                w = N // h
                if h >= 4 and w >= 4:
                    return a.reshape(h,w,D)
        return a.reshape(1,N,D)
    return None

def find_token_file_fast(uid: str):
    """Locate the cached token file for *uid* under TOKEN_CACHE_DIR.

    Probes the cache root, then its ``tokens_train``, ``train`` and
    ``feat_train`` subfolders, trying ``.npz`` before ``.npy`` in each.
    Returns the first existing path, or ``None``.
    """
    name = str(uid)
    search_dirs = [TOKEN_CACHE_DIR] + [TOKEN_CACHE_DIR / sub for sub in ("tokens_train", "train", "feat_train")]
    candidates = [folder / f"{name}{ext}" for folder in search_dirs for ext in (".npz", ".npy")]
    for path in candidates:
        if path.exists():
            return path
    return None

def find_mask_file_fast(uid: str):
    """Return the mask image for *uid* inside MASK_DIR, or ``None``.

    Common image extensions are probed first, in a fixed order. Otherwise any
    ``<uid>.*`` file is accepted. In that wildcard fallback, glob
    metacharacters in *uid* (``[``, ``*``, ``?``) are escaped so they match
    literally. The hits are sorted so the chosen file does not depend on
    directory-listing order.
    """
    import glob

    if MASK_DIR is None:
        return None
    uid = str(uid)
    mask_dir = Path(MASK_DIR)
    for ext in (".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"):
        p = mask_dir / f"{uid}{ext}"
        if p.exists():
            return p
    hits = sorted(mask_dir.glob(f"{glob.escape(uid)}.*"))
    return hits[0] if hits else None

uids_all = df_train_tabular["uid"].astype(str).tolist()

TOKEN_PATH = {}
MASK_PATH = {}

token_hw_counter = Counter()
token_dim_counter = Counter()

# Resolve token files for every uid. Only the first `max_probe` files are
# actually loaded, to vote on the dominant grid size (H, W) and width D.
max_probe = min(300, len(uids_all))
for uid in uids_all[:max_probe]:
    tp = find_token_file_fast(uid)
    TOKEN_PATH[uid] = tp
    if tp is None:
        continue
    a = _load_np_any(tp)
    if a is None:
        continue
    g = _reshape_to_grid(a)
    if g is None or g.ndim != 3:
        continue
    H,W,D = int(g.shape[0]), int(g.shape[1]), int(g.shape[2])
    if H >= 4 and W >= 4 and D >= 16:
        token_hw_counter[(H,W)] += 1
        token_dim_counter[D] += 1

for uid in uids_all[max_probe:]:
    TOKEN_PATH[uid] = find_token_file_fast(uid)

# Mask paths need no probing split: resolve all uids in a single pass.
if MASK_DIR is not None:
    for uid in uids_all:
        MASK_PATH[uid] = find_mask_file_fast(uid)

if not token_hw_counter or not token_dim_counter:
    # Message: "Cannot infer token grid shape/dim from the token cache. Check the per-uid token-grid format."
    raise RuntimeError("Tidak bisa infer token grid shape/dim dari token cache. Cek format token-grid per uid.")

# Most frequent grid size / width among the probed files. The dataset later
# truncates/pads and resizes every sample to these.
TOKEN_HW = token_hw_counter.most_common(1)[0][0]
TOKEN_DIM = token_dim_counter.most_common(1)[0][0]

print("Token grid mode:", TOKEN_HW, "| token_dim mode:", TOKEN_DIM)

# ----------------------------
# 6) Filter df (drop missing token, optionally drop missing positive masks)
# ----------------------------
def build_df_filtered(df: pd.DataFrame):
    """Return a filtered copy of *df* keeping only rows whose token file exists.

    Adds integer flags ``token_exists`` / ``mask_exists``, looked up in the
    TOKEN_PATH / MASK_PATH maps; ``mask_exists`` is 0 everywhere without a
    MASK_DIR. When DROP_MISSING_POS_MASKS is set and a mask directory is
    configured, positive rows (y == 1) without a mask are dropped too. The
    index is reset after each filter.
    """
    out = df.copy()
    out["uid"] = out["uid"].astype(str)

    def _has(mapping):
        return lambda u: mapping.get(str(u), None) is not None

    out["token_exists"] = out["uid"].map(_has(TOKEN_PATH)).astype(int)
    out["mask_exists"] = (
        out["uid"].map(_has(MASK_PATH)).astype(int) if MASK_DIR is not None else 0
    )

    has_tokens = out["token_exists"] == 1
    out = out[has_tokens].reset_index(drop=True)

    if DROP_MISSING_POS_MASKS and MASK_DIR is not None:
        usable = (out["y"] == 0) | (out["mask_exists"] == 1)
        out = out[usable].reset_index(drop=True)

    return out

df_train_filtered = build_df_filtered(df_train_tabular)

# Split the drop count by cause: rows without a token file vs. positive rows
# removed for a missing mask (only when DROP_MISSING_POS_MASKS is active).
# Previously the total was printed under the "token_missing_dropped" label.
n_rows_before = int(df_train_tabular.shape[0])
n_token_missing = int(sum(TOKEN_PATH.get(str(u), None) is None for u in df_train_tabular["uid"]))
n_dropped_total = n_rows_before - int(df_train_filtered.shape[0])

print("After filtering:")
print("  rows:", len(df_train_filtered),
      "| forged%:", float(df_train_filtered["y"].mean())*100.0,
      "| token_missing_dropped:", n_token_missing,
      "| pos_mask_missing_dropped:", n_dropped_total - n_token_missing)

if len(df_train_filtered) < 64:
    # Message: "Too little data after filtering. Check the token cache & mask dir."
    raise RuntimeError("Data terlalu sedikit setelah filtering. Cek token cache & mask dir.")

# ----------------------------
# 7) IO utils: load/resize tokens + masks
# ----------------------------
def load_token_grid_from_path(p: Path):
    """Load a cached token file as a float32 (H, W, D) grid, or ``None``.

    ``None`` is returned when the file cannot be read into a grid by
    ``_load_np_any`` / ``_reshape_to_grid``. Non-finite entries (NaN, +-inf)
    are replaced with 0.
    """
    grid = _reshape_to_grid(_load_np_any(p))
    if grid is None:
        return None
    grid = grid.astype(np.float32, copy=False)
    if np.isfinite(grid).all():
        return grid
    return np.nan_to_num(grid, nan=0.0, posinf=0.0, neginf=0.0)

def resize_grid_hw(g: np.ndarray, out_hw):
    """Bilinearly resize an (H, W, D) grid to ``out_hw = (H_out, W_out)``.

    Returns *g* itself when it already has the requested spatial size.
    Otherwise returns a new float32 array (torch interpolate,
    ``align_corners=False``).
    """
    target = (int(out_hw[0]), int(out_hw[1]))
    if tuple(g.shape[:2]) == target:
        return g
    chw = torch.from_numpy(np.transpose(g, (2, 0, 1)))[None]  # (1, D, H, W)
    resized = F.interpolate(chw, size=target, mode="bilinear", align_corners=False)[0]
    return resized.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.float32, copy=False)

def load_mask_bin(path: Path):
    """Read a mask image as a 2-D uint8 array: 1 where grayscale > 127, else 0.

    The image is opened in a ``with`` block so its file handle is released
    right away instead of waiting for garbage collection. Dataset workers
    call this per sample, so leaked handles would accumulate.
    """
    with Image.open(path) as im:
        gray = im.convert("L")  # convert() loads the pixels into a new image
    return (np.array(gray, dtype=np.uint8) > 127).astype(np.uint8)

def resize_mask_soft(mask_bin: np.ndarray, out_hw):
    """Resize a 0/1 mask to ``out_hw = (H, W)`` with bilinear filtering.

    Returns float32 values in [0, 1]. Pixels along mask borders become
    fractional, which gives a soft training target.
    """
    h_out, w_out = int(out_hw[0]), int(out_hw[1])
    as_image = Image.fromarray(mask_bin.astype(np.uint8)*255)
    scaled = as_image.resize((w_out, h_out), resample=Image.BILINEAR)  # PIL sizes are (width, height)
    return (np.array(scaled, dtype=np.float32) / 255.0).astype(np.float32)

# ----------------------------
# 8) Dataset
# ----------------------------
class TokenMaskDataset(Dataset):
    """Per-uid (token grid, soft mask) samples for mask training.

    Each item is ``(x, y, uid)``:
      * ``x``: float32 tensor (TOKEN_DIM, H, W) with (H, W) = TOKEN_HW.
        A missing token path, or a file that cannot be turned into a grid,
        gives all zeros.
      * ``y``: float32 tensor (1, H, W), the bilinearly resized mask. It is
        all zeros for negatives (y == 0) and for positives without a mask file.

    Train-time augmentation (``is_train=True``) is driven by ``cfg``:
      * ``aug_hflip`` / ``aug_vflip``: flip probabilities, applied jointly to
        grid and mask.
      * ``input_noise_std``: the training loops in this file already add this
        Gaussian noise on-device to every batch. Here it is applied only when
        ``cfg["noise_in_dataset"]`` is truthy. Adding it unconditionally
        applied the noise twice (effective std ~ sqrt(2) * std) and cost a
        full-size CPU randn per sample.
    """

    def __init__(self, df: pd.DataFrame, is_train: bool, cfg: dict):
        self.df = df.reset_index(drop=True)
        self.is_train = bool(is_train)
        self.cfg = cfg

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        uid = str(row["uid"])
        yflag = int(row["y"])

        # Token grid; fall back to zeros when it is missing / not gridable.
        tp = TOKEN_PATH.get(uid, None)
        g = load_token_grid_from_path(tp) if tp is not None else None
        if g is None:
            g = np.zeros((TOKEN_HW[0], TOKEN_HW[1], TOKEN_DIM), dtype=np.float32)

        # Width fix: truncate or zero-pad the last axis to TOKEN_DIM.
        if g.shape[2] != TOKEN_DIM:
            D = g.shape[2]
            if D > TOKEN_DIM:
                g = g[:, :, :TOKEN_DIM]
            else:
                pad = np.zeros((g.shape[0], g.shape[1], TOKEN_DIM - D), dtype=np.float32)
                g = np.concatenate([g, pad], axis=2)

        # Spatial fix: resample to the dominant grid size.
        g = resize_grid_hw(g, TOKEN_HW)

        # Soft target mask at grid resolution (positives with a mask file only).
        m_soft = np.zeros((TOKEN_HW[0], TOKEN_HW[1]), dtype=np.float32)
        if yflag == 1 and MASK_DIR is not None:
            mp = MASK_PATH.get(uid, None)
            if mp is not None:
                m_bin = load_mask_bin(mp)
                m_soft = resize_mask_soft(m_bin, TOKEN_HW)

        # Train-time augmentation.
        if self.is_train:
            if float(self.cfg.get("aug_hflip", 0.0)) > 0 and np.random.rand() < float(self.cfg["aug_hflip"]):
                g = g[:, ::-1, :].copy()
                m_soft = m_soft[:, ::-1].copy()
            if float(self.cfg.get("aug_vflip", 0.0)) > 0 and np.random.rand() < float(self.cfg["aug_vflip"]):
                g = g[::-1, :, :].copy()
                m_soft = m_soft[::-1, :].copy()

            # Opt-in only: the training loops add input_noise_std on-device.
            ns = float(self.cfg.get("input_noise_std", 0.0))
            if ns > 0 and bool(self.cfg.get("noise_in_dataset", False)):
                g = g + np.random.randn(*g.shape).astype(np.float32) * ns

        x = torch.from_numpy(np.transpose(g, (2,0,1)).astype(np.float32))   # (D,H,W)
        y = torch.from_numpy(m_soft[None, ...].astype(np.float32))          # (1,H,W)
        return x, y, uid

def make_loader(ds, batch_size, shuffle):
    """Build a DataLoader for *ds*.

    Uses 2 persistent worker processes on hosts with >= 4 CPUs, otherwise
    loads in-process. Memory is pinned only when ``device`` is CUDA. The last
    partial batch is kept.
    """
    workers = 0 if (os.cpu_count() or 2) < 4 else 2
    return DataLoader(
        ds,
        batch_size=int(batch_size),
        shuffle=bool(shuffle),
        num_workers=workers,
        pin_memory=device.type == "cuda",
        drop_last=False,
        persistent_workers=workers > 0,
    )

# ----------------------------
# 9) Model: UNet + ASPP
# ----------------------------
class ConvGNAct(nn.Module):
    """Conv2d -> GroupNorm -> SiLU -> optional Dropout2d.

    ``groups`` is an upper bound on the number of GroupNorm groups. The
    largest divisor of ``out_ch`` that does not exceed it is used, so widths
    that are not a multiple of ``groups`` (e.g. a tuned ``base_ch``) no longer
    make GroupNorm raise. For divisible widths the group count is unchanged.
    The conv has no bias because GroupNorm's affine shift follows it.
    """

    def __init__(self, in_ch, out_ch, k=3, s=1, p=1, drop=0.0, groups=8):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=k, stride=s, padding=p, bias=False)
        n_ch = int(out_ch)
        cap = max(1, min(int(groups), n_ch))
        # GroupNorm requires num_channels % num_groups == 0; 1 always divides.
        g = next(d for d in range(cap, 0, -1) if n_ch % d == 0)
        self.gn = nn.GroupNorm(g, out_ch)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()

    def forward(self, x):
        return self.drop(self.act(self.gn(self.conv(x))))

class ASPP(nn.Module):
    """ASPP-style block of parallel dilated 3x3 convolutions.

    There is one branch per rate. Padding equals the dilation, so H and W are
    preserved. Each branch is followed by GroupNorm + SiLU. Branch outputs are
    concatenated and projected back to ``ch`` channels by a 1x1 conv, then
    GroupNorm, SiLU and optional Dropout2d.

    GroupNorm uses the largest divisor of ``ch`` that is <= 8, so widths that
    are not a multiple of 8 are accepted; for multiples of 8 it stays 8.
    """

    def __init__(self, ch, rates=(1,2,4,6), drop=0.0):
        super().__init__()
        rs = list(rates)
        n_ch = int(ch)
        # GroupNorm requires num_channels % num_groups == 0; 1 always divides.
        n_groups = next(d for d in range(max(1, min(8, n_ch)), 0, -1) if n_ch % d == 0)
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(ch, ch, 3, padding=int(r), dilation=int(r), bias=False),
                nn.GroupNorm(n_groups, ch),
                nn.SiLU(inplace=True),
            ) for r in rs
        ])
        self.proj = nn.Sequential(
            nn.Conv2d(ch * len(rs), ch, 1, bias=False),
            nn.GroupNorm(n_groups, ch),
            nn.SiLU(inplace=True),
            nn.Dropout2d(float(drop)) if float(drop) > 0 else nn.Identity()
        )

    def forward(self, x):
        xs = [b(x) for b in self.branches]
        return self.proj(torch.cat(xs, dim=1))

class UNetASPP(nn.Module):
    """UNet with three stride-2 downsamples and an ASPP bottleneck, over a token grid.

    Input (B, in_dim, H, W) -> logits (B, 1, H, W). Widths: C at full
    resolution, 2C and 4C after the first two downsamples, 6C at the
    bottleneck. The decoder bilinearly upsamples to each skip tensor's exact
    size and concatenates it.

    NOTE: the attribute names (stem, enc1, down1, ..., head) and the order
    inside each Sequential define the state_dict keys of saved packs.
    Renaming or reordering them breaks loading existing checkpoints.
    """
    def __init__(self, in_dim, base_ch=160, drop=0.10, aspp_rates=(1,2,4,6)):
        super().__init__()
        C = int(base_ch)

        # 1x1 projection of the token embedding to C channels, then a 3x3 conv.
        self.stem = nn.Sequential(
            ConvGNAct(in_dim, C, k=1, s=1, p=0, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )
        self.enc1 = nn.Sequential(
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )
        self.down1 = ConvGNAct(C, 2*C, k=3, s=2, p=1, drop=drop)  # stride-2 downsample
        self.enc2 = nn.Sequential(
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.down2 = ConvGNAct(2*C, 4*C, k=3, s=2, p=1, drop=drop)  # stride-2 downsample
        self.enc3 = nn.Sequential(
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.down3 = ConvGNAct(4*C, 6*C, k=3, s=2, p=1, drop=drop)  # stride-2 downsample

        # Multi-dilation context at the coarsest resolution.
        self.bottleneck = nn.Sequential(
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
            ASPP(6*C, rates=aspp_rates, drop=drop),
            ConvGNAct(6*C, 6*C, k=3, s=1, p=1, drop=drop),
        )

        # Decoder stages consume [upsampled, skip] concatenations.
        self.dec3 = nn.Sequential(
            ConvGNAct(6*C + 4*C, 4*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(4*C, 4*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec2 = nn.Sequential(
            ConvGNAct(4*C + 2*C, 2*C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(2*C, 2*C, k=3, s=1, p=1, drop=drop),
        )
        self.dec1 = nn.Sequential(
            ConvGNAct(2*C + C, C, k=3, s=1, p=1, drop=drop),
            ConvGNAct(C, C, k=3, s=1, p=1, drop=drop),
        )

        self.head = nn.Conv2d(C, 1, 1)  # one mask logit per grid cell

    def forward(self, x):
        x0 = self.stem(x)
        s1 = self.enc1(x0)              # skip at full grid resolution

        x1 = self.down1(s1)
        s2 = self.enc2(x1)              # skip at 1/2 resolution

        x2 = self.down2(s2)
        s3 = self.enc3(x2)              # skip at 1/4 resolution

        x3 = self.down3(s3)
        b  = self.bottleneck(x3)

        # Interpolating to the skip's exact size handles odd grid sizes (stride-2 rounds up).
        u3 = F.interpolate(b,  size=s3.shape[-2:], mode="bilinear", align_corners=False)
        d3 = self.dec3(torch.cat([u3, s3], dim=1))

        u2 = F.interpolate(d3, size=s2.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.dec2(torch.cat([u2, s2], dim=1))

        u1 = F.interpolate(d2, size=s1.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.dec1(torch.cat([u1, s1], dim=1))

        return self.head(d1)  # (B,1,H,W)

# ----------------------------
# 10) EMA + loss + eval
# ----------------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    ``update`` blends the live weights into the shadow copy
    (``shadow = decay * shadow + (1 - decay) * param``). ``apply_shadow``
    swaps the averaged weights into the model after backing up the live ones,
    and ``restore`` copies that backup back. Only parameters with
    ``requires_grad`` are tracked; buffers are not.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.detach().clone()
        self.backup = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        keep = self.decay
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name].mul_(keep).add_(param.detach(), alpha=(1.0 - keep))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.detach().clone()
                param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.copy_(self.backup[name])
        self.backup = {}

def dice_from_probs(prob, target, eps=1e-6):
    """Mean soft Dice between probability and target tensors.

    Sums over dims 2 and 3 (the spatial axes of (N, C, H, W) inputs), then
    averages ``2 * sum(p * t) / max(sum(p + t), eps)`` over the remaining
    dims. Note that an empty target always scores 0.
    """
    spatial = (2, 3)
    overlap = (prob * target).sum(dim=spatial) * 2.0
    total = (prob + target).sum(dim=spatial).clamp_min(eps)
    ratio = overlap / total
    return ratio.mean()

def dice_hard_np(prob_np, target_np, thr):
    """Mean hard Dice of ``prob_np >= thr`` against ``target_np``.

    Dice is computed per leading (N, C) slice, summing over axes 2 and 3, and
    then averaged. A slice where both the binarized prediction and the target
    are empty scores 1.0, because "nothing here" is the correct answer.
    Without this, every prediction on a negative sample scored 0 whatever it
    was, and a threshold sweep could not penalize false positives on
    negatives.
    """
    pr = (prob_np >= thr).astype(np.float32)
    inter = (pr * target_np).sum(axis=(2,3)) * 2.0
    den = (pr + target_np).sum(axis=(2,3))
    both_empty = den == 0  # inputs are non-negative, so den == 0 <=> both empty
    dice = np.where(both_empty, 1.0, inter / (den + 1e-6))
    return float(dice.mean())

def loss_bce_dice(logits, target, bce_w=1.0, dice_w=1.0, focal_gamma=0.0):
    """Weighted BCE (optionally focal-modulated) plus soft-Dice loss.

    With ``focal_gamma > 0``, each pixel's BCE is scaled by
    ``(1 - p_t) ** focal_gamma``, where ``p_t`` is the predicted probability
    of the true label. Returns
    ``bce_w * mean(bce) + dice_w * (1 - dice_from_probs(sigmoid(logits), target))``.
    """
    per_pixel = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    if focal_gamma and focal_gamma > 0:
        probs = torch.sigmoid(logits)
        p_true = probs * target + (1.0 - probs) * (1.0 - target)
        per_pixel = per_pixel * (1.0 - p_true).clamp_min(0.0).pow(float(focal_gamma))
    bce_term = per_pixel.mean()
    dice_term = 1.0 - dice_from_probs(torch.sigmoid(logits), target)
    return float(bce_w)*bce_term + float(dice_w)*dice_term

@torch.no_grad()
def eval_loader(model, loader, cfg, ema=None):
    """Evaluate *model* on *loader*: loss, soft Dice and a hard-Dice threshold sweep.

    When *ema* is given, its shadow weights are swapped in for the
    evaluation. The live weights are restored afterwards in a ``finally``,
    so an exception (e.g. CUDA OOM) never leaves the model holding EMA
    weights.

    All metrics are per-sample averages: each batch is weighted by its size,
    so a smaller last batch is not over-weighted. The sweep uses
    ``cfg["thr_grid"]`` thresholds evenly spaced on [0.05, 0.95].

    Returns a dict with ``val_loss``, ``val_dice_soft``, ``best_thr`` and
    ``best_dice_hard``. Leaves the model in eval mode.
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)

    try:
        ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()

        thrs = np.linspace(0.05, 0.95, int(cfg["thr_grid"]), dtype=np.float64)
        dice_sweep = np.zeros_like(thrs, dtype=np.float64)

        # Running sums weighted by batch size -> per-sample means at the end.
        tot_loss, tot_dice_soft, n_seen = 0.0, 0.0, 0

        for xb, yb, _ in loader:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            bs = int(xb.size(0))

            with ctx:
                logits = model(xb)
                loss = loss_bce_dice(
                    logits, yb,
                    bce_w=float(cfg["bce_weight"]),
                    dice_w=float(cfg["dice_weight"]),
                    focal_gamma=float(cfg.get("focal_gamma", 0.0)),
                )
                prob = torch.sigmoid(logits)
                dsoft = dice_from_probs(prob, yb)

            tot_loss += float(loss.item() if hasattr(loss, "item") else loss) * bs
            tot_dice_soft += float(dsoft.item()) * bs
            n_seen += bs

            # Hard-Dice sweep on CPU for every candidate threshold.
            prob_np = prob.detach().float().cpu().numpy()
            y_np = yb.detach().float().cpu().numpy()
            for i, t in enumerate(thrs):
                dice_sweep[i] += dice_hard_np(prob_np, y_np, thr=float(t)) * bs

        denom = max(1, n_seen)
        dice_sweep = dice_sweep / denom
        j = int(np.argmax(dice_sweep))

        return {
            "val_loss": tot_loss / denom,
            "val_dice_soft": tot_dice_soft / denom,
            "best_thr": float(thrs[j]),
            "best_dice_hard": float(dice_sweep[j]),
        }
    finally:
        # Always put the live (non-EMA) weights back, even on error.
        if ema is not None:
            ema.restore(model)

# ----------------------------
# 11) CFG merge + guards (strict match Step 4)
# ----------------------------
# Stage defaults. The values depend on whether a CUDA device is present.
CFG = {
    "seed": FINAL_SEED,
    "epochs": 55 if device.type == "cuda" else 30,
    "batch_size": 32 if device.type == "cuda" else 12,
    "accum_steps": 1,
    "lr": 3e-4,
    "weight_decay": 1e-2,
    "warmup_frac": 0.05,
    "grad_clip": 1.0,
    "patience": 10,
    "min_delta": 1e-4,
    "use_ema": True,
    "ema_decay": 0.999,

    # architecture
    "base_ch": 160 if device.type == "cuda" else 128,
    "drop": 0.10,
    "aspp_rates": (1, 2, 4, 6),

    # loss
    "bce_weight": 1.0,
    "dice_weight": 1.0,
    "focal_gamma": 0.0,

    # augmentation
    "input_noise_std": 0.008,
    "aug_hflip": 0.5,
    "aug_vflip": 0.2,

    # number of thresholds swept on [0.05, 0.95] during validation
    "thr_grid": 81,
}

# Overlay Step-4 tuned values, but only for keys this stage already defines;
# unknown keys in base_cfg are ignored.
for k, v in base_cfg.items():
    if k not in CFG:
        continue
    CFG[k] = v

def cpu_safe_cfg(cfg: dict):
    """Return a copy of *cfg* clamped to CPU-friendly sizes.

    On CUDA the copy is returned unchanged. On CPU: base_ch <= 128,
    batch_size <= 12, epochs <= 35, and accum_steps is forced to 1.
    """
    out = dict(cfg)
    if device.type == "cuda":
        return out
    caps = {"base_ch": 128, "batch_size": 12, "epochs": 35}
    for key, cap in caps.items():
        out[key] = min(int(out[key]), cap)
    out["accum_steps"] = 1
    return out

def maybe_upscale_cfg(cfg: dict):
    """Return a copy of *cfg* whose batch_size is raised to suit a large GPU.

    This only applies on CUDA with ALLOW_UPSCALE set and a known ``vram_gb``.
    With >= 24 GB the batch is raised to at least 48; with >= 16 GB to at
    least 32. A larger existing batch size is kept, and accum_steps is not
    touched here.
    """
    if not (device.type == "cuda" and ALLOW_UPSCALE and vram_gb is not None):
        return dict(cfg)
    cfg = dict(cfg)
    # `max`, not `min`: this knob must only grow the batch. A `min` cap can
    # never upscale, and it silently shrinks batches above the tier size.
    if vram_gb >= 24:
        cfg["batch_size"] = max(int(cfg["batch_size"]), 48)
    elif vram_gb >= 16:
        cfg["batch_size"] = max(int(cfg["batch_size"]), 32)
    return cfg

# Device-dependent caps, then optional VRAM-based batch adjustment (each returns a copy).
CFG = cpu_safe_cfg(CFG)
CFG = maybe_upscale_cfg(CFG)

# effective batch heuristic (mask model is heavier; keep conservative)
# accum_steps = ceil(TARGET / batch_size), so batch_size * accum_steps >= TARGET
# (e.g. batch 32 -> accum 2 -> effective 64). This overrides any accum_steps set
# above, including the 1 forced by cpu_safe_cfg.
TARGET_EFF_BATCH = 48 if device.type == "cuda" else 12
CFG["accum_steps"] = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

print("\nCFG (final MASK):")
for k in ["base_ch","drop","aspp_rates","batch_size","accum_steps","epochs","lr","weight_decay","warmup_frac","patience","thr_grid"]:
    print(f"  {k}: {CFG[k]}")

# ----------------------------
# 12) Internal val split (case-level stratified)
# ----------------------------
def make_case_split(df: pd.DataFrame, val_frac=0.08, seed=2025):
    """Return a boolean row mask (len(df)) selecting an internal validation set.

    Primary strategy is case-level: all rows of a case go to the same side.
    It takes ``max(1, int(n * val_frac))`` cases from each of the positive
    cases (any row with y == 1) and the negative cases.

    Fallbacks are row-level, so rows of one case may land on both sides:
      * no positive or no negative cases -> a random ``val_frac`` of rows;
      * the case split yields < 32 val rows or a single class -> a
        class-proportional random sample of ``max(32, int(n * val_frac))`` rows.

    The split is deterministic for a given ``seed`` (NumPy RandomState).
    The order of the shuffles below matters.
    """
    # Case label = 1 if any row of the case is forged.
    g = df.groupby("case_id")["y"].max().reset_index().rename(columns={"y": "case_y"})
    pos_cases = g.loc[g["case_y"] == 1, "case_id"].astype(str).to_numpy()
    neg_cases = g.loc[g["case_y"] == 0, "case_id"].astype(str).to_numpy()

    rng = np.random.RandomState(int(seed))
    rng.shuffle(pos_cases); rng.shuffle(neg_cases)

    # Only one case class exists: plain random row split.
    if len(pos_cases) == 0 or len(neg_cases) == 0:
        idx = np.arange(len(df))
        rng.shuffle(idx)
        n_val = max(1, int(len(df) * float(val_frac)))
        is_val = np.zeros(len(df), dtype=bool)
        is_val[idx[:n_val]] = True
        return is_val

    # Stratified by case label: at least one case of each class goes to val.
    # NOTE(review): with a single positive case, it lands in val and train has no positives.
    n_val_pos = max(1, int(len(pos_cases) * float(val_frac)))
    n_val_neg = max(1, int(len(neg_cases) * float(val_frac)))
    val_cases = np.concatenate([pos_cases[:n_val_pos], neg_cases[:n_val_neg]])
    val_set = set(val_cases.tolist())

    is_val = df["case_id"].astype(str).isin(val_set).to_numpy(dtype=bool)

    # fallback if degenerate: too few val rows or a single class in val
    va = df.loc[is_val, "y"].to_numpy()
    if is_val.sum() < 32 or len(np.unique(va)) < 2:
        idx_pos = np.where(df["y"].values == 1)[0]
        idx_neg = np.where(df["y"].values == 0)[0]
        rng.shuffle(idx_pos); rng.shuffle(idx_neg)
        n_val = max(32, int(len(df) * float(val_frac)))
        # Keep the overall positive rate in the fallback val set.
        n_val_pos = max(1, int(n_val * float(df["y"].mean())))
        n_val_pos = min(n_val_pos, len(idx_pos))
        n_val_neg = min(n_val - n_val_pos, len(idx_neg))
        val_idx = np.concatenate([idx_pos[:n_val_pos], idx_neg[:n_val_neg]])
        is_val = np.zeros(len(df), dtype=bool)
        is_val[val_idx] = True

    return is_val

# ----------------------------
# 13) Train with internal val -> best_epoch + recommended_thr
# ----------------------------
def build_model(cfg: dict):
    """Instantiate UNetASPP for TOKEN_DIM-channel inputs from *cfg* and move it to ``device``."""
    arch_kwargs = dict(
        in_dim=int(TOKEN_DIM),
        base_ch=int(cfg["base_ch"]),
        drop=float(cfg["drop"]),
        aspp_rates=tuple(cfg["aspp_rates"]),
    )
    net = UNetASPP(**arch_kwargs)
    return net.to(device)

def build_opt_and_sch(model, cfg, steps_per_epoch, epochs):
    """Create AdamW plus a per-optimizer-step LambdaLR schedule.

    The schedule warms up linearly over the first ``warmup_frac`` of the
    total optimizer steps (``epochs * steps_per_epoch``), then decays from
    the base lr towards 0 on a cosine. ``steps_per_epoch`` must count
    optimizer steps (already divided by gradient accumulation), because the
    scheduler is stepped once per ``opt.step()``.

    Returns ``(optimizer, scheduler)``.
    """
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(0.9, 0.99),
        eps=1e-8,
    )
    total_steps = int(max(1, int(epochs) * int(max(1, steps_per_epoch))))
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    def lr_lambda(step):
        # Linear warmup: 1/W, 2/W, ..., reaching 1.0 at step W-1.
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        # Cosine decay from 1 to 0 across the remaining steps (clamped past the end).
        t = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
        t = min(max(t, 0.0), 1.0)
        return 0.5 * (1.0 + math.cos(math.pi * t))

    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
    return opt, sch

def train_with_internal_val_get_best(df_all: pd.DataFrame, cfg: dict, seed=2025):
    """Train on a case-level train/val split to pick the best epoch and threshold.

    Trains for up to ``cfg["epochs"]`` epochs with gradient accumulation,
    optional AMP, grad clipping and EMA. Evaluation uses the EMA weights when
    EMA is on. Training stops early (if EARLY_STOP) after ``patience`` epochs
    without a ``min_delta`` gain in validation hard Dice.

    Only metadata of the best epoch is returned; the trained weights are
    discarded.

    Returns a dict with ``dice_hard``, ``thr``, ``epoch`` (1-based),
    ``val_loss`` and ``val_dice_soft``. If no epoch ever improved (e.g. zero
    epochs), ``epoch`` falls back to ``max(12, int(0.6 * epochs))`` and
    ``thr`` to the Step-4 ``recommended_thr`` (or 0.5).
    """
    seed_everything(int(seed))

    is_val = make_case_split(df_all, val_frac=float(VAL_FRAC_CASE), seed=int(seed))
    df_tr = df_all.loc[~is_val].reset_index(drop=True)
    df_va = df_all.loc[is_val].reset_index(drop=True)

    ds_tr = TokenMaskDataset(df_tr, is_train=True, cfg=cfg)
    ds_va = TokenMaskDataset(df_va, is_train=False, cfg=cfg)
    dl_tr = make_loader(ds_tr, cfg["batch_size"], shuffle=True)
    dl_va = make_loader(ds_va, cfg["batch_size"], shuffle=False)

    model = build_model(cfg)
    # steps_per_epoch counts optimizer steps (micro-batches / accum, rounded up).
    opt, sch = build_opt_and_sch(
        model, cfg,
        steps_per_epoch=int(math.ceil(len(dl_tr) / max(1, int(cfg.get("accum_steps", 1))))),
        epochs=int(cfg["epochs"])
    )

    scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None
    ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()

    accum = max(1, int(cfg.get("accum_steps", 1)))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))
    grad_clip = float(cfg["grad_clip"])

    def _optimizer_step():
        # One optimizer update: unscale (AMP) before clipping so the clip sees
        # the true grad norm, step, clear grads, advance the lr schedule, then
        # fold the new weights into the EMA.
        if grad_clip > 0:
            if use_amp:
                scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

        if use_amp:
            scaler.step(opt); scaler.update()
        else:
            opt.step()

        opt.zero_grad(set_to_none=True)
        sch.step()
        if ema is not None:
            ema.update(model)

    best = {"dice_hard": -1.0, "thr": 0.5, "epoch": -1, "val_loss": 1e18, "val_dice_soft": -1.0}
    # Only *whether* some epoch improved matters, because the best weights are
    # never returned. Snapshotting them cost a full state_dict copy to CPU
    # (plus an EMA swap) on every improvement for nothing.
    found_best = False
    bad = 0

    print(f"\nInternal val split (MASK): train={len(df_tr)} | val={len(df_va)} | val_pos%={float(df_va['y'].mean())*100:.2f}")

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum, n_sum, micro = 0.0, 0, 0
        t0 = time.time()

        for xb, yb, _ in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            # On-device Gaussian input noise (train-time augmentation).
            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with ctx:
                logits = model(xb)
                # Divided by accum so accumulated grads average over micro-batches.
                loss = loss_bce_dice(
                    logits, yb,
                    bce_w=float(cfg["bce_weight"]),
                    dice_w=float(cfg["dice_weight"]),
                    focal_gamma=float(cfg.get("focal_gamma", 0.0)),
                ) / float(accum)

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            micro += 1
            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                _optimizer_step()

        # Flush leftover micro-batches when the epoch did not end on an
        # accumulation boundary.
        if (micro % accum) != 0:
            _optimizer_step()

        ev = eval_loader(model, dl_va, cfg, ema=ema)
        dt = time.time() - t0
        print(f"  ep {epoch+1:03d}/{cfg['epochs']} | tr_loss={loss_sum/max(1,n_sum):.5f} | "
              f"val_loss={ev['val_loss']:.5f} | val_dice_soft={ev['val_dice_soft']:.5f} | "
              f"best_dice_hard={ev['best_dice_hard']:.5f}@{ev['best_thr']:.2f} | bad={bad} | dt={dt:.1f}s")

        improved = (ev["best_dice_hard"] - best["dice_hard"]) > float(cfg["min_delta"])
        if improved:
            best["dice_hard"] = float(ev["best_dice_hard"])
            best["thr"] = float(ev["best_thr"])
            best["epoch"] = int(epoch) + 1
            best["val_loss"] = float(ev["val_loss"])
            best["val_dice_soft"] = float(ev["val_dice_soft"])
            found_best = True
            bad = 0
        else:
            bad += 1
            if EARLY_STOP and bad >= int(cfg["patience"]):
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if not found_best:
        best["epoch"] = max(12, int(cfg["epochs"] * 0.6))
        best["thr"] = float(best_bundle.get("recommended_thr", 0.5)) if isinstance(best_bundle, dict) else 0.5

    return best

# ----------------------------
# 14) Train FULL fixed epochs (use EMA at end)
# ----------------------------
def train_full_fixed_epochs(df_all: pd.DataFrame, cfg: dict, epochs_fixed: int, seed=2025):
    """Train a fresh model on all of *df_all* for exactly ``epochs_fixed`` epochs.

    Uses the same optimization recipe as the internal-val run (AdamW with
    warmup + cosine, gradient accumulation, optional AMP, grad clipping,
    EMA), but without validation or early stopping. If EMA is enabled, its
    shadow weights are copied into the model at the end and not restored,
    so the returned ``state_dict`` holds the EMA weights.

    Returns a checkpoint "pack" dict containing:
      * architecture tag, CPU state_dict, cfg and seed;
      * token grid metadata (``token_dim`` / ``token_hw``) and dataset stats;
      * epochs and accumulation used, wall time, and whether EMA weights were used.
    """
    seed_everything(int(seed))

    ds = TokenMaskDataset(df_all, is_train=True, cfg=cfg)
    dl = make_loader(ds, cfg["batch_size"], shuffle=True)

    model = build_model(cfg)
    # steps_per_epoch counts optimizer steps (micro-batches / accum, rounded up).
    opt, sch = build_opt_and_sch(
        model, cfg,
        steps_per_epoch=int(math.ceil(len(dl) / max(1, int(cfg.get("accum_steps", 1))))),
        epochs=int(epochs_fixed)
    )

    # NOTE(review): torch.cuda.amp.GradScaler/autocast are deprecated aliases in
    # recent PyTorch (torch.amp.*); they still work but may emit FutureWarning.
    scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None
    ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()

    accum = max(1, int(cfg.get("accum_steps", 1)))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    t0 = time.time()

    for epoch in range(int(epochs_fixed)):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum, n_sum, micro = 0.0, 0, 0

        for xb, yb, _ in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)

            # On-device Gaussian input noise (train-time augmentation).
            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with ctx:
                logits = model(xb)
                # Divided by accum so accumulated grads average over micro-batches.
                loss = loss_bce_dice(
                    logits, yb,
                    bce_w=float(cfg["bce_weight"]),
                    dice_w=float(cfg["dice_weight"]),
                    focal_gamma=float(cfg.get("focal_gamma", 0.0)),
                ) / float(accum)

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            micro += 1
            # Undo the /accum so the logged value is a sample-weighted average loss.
            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                # Unscale first (AMP) so clipping sees the true gradient norm.
                if float(cfg["grad_clip"]) > 0:
                    if use_amp:
                        scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))

                if use_amp:
                    scaler.step(opt); scaler.update()
                else:
                    opt.step()

                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # Flush leftover micro-batches when the epoch did not end on an
        # accumulation boundary.
        if (micro % accum) != 0:
            if float(cfg["grad_clip"]) > 0:
                if use_amp:
                    scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))

            if use_amp:
                scaler.step(opt); scaler.update()
            else:
                opt.step()

            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {epoch+1:03d}/{int(epochs_fixed)} | loss={loss_sum/max(1,n_sum):.5f}")

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Keep the EMA weights in the model (no restore) so the pack stores them.
    if ema is not None:
        ema.apply_shadow(model)

    pack = {
        "type": "unet_aspp_token_grid_full_v1",
        "arch": "UNetASPP",
        "state_dict": {k: v.detach().cpu() for k,v in model.state_dict().items()},
        "cfg": dict(cfg),
        "seed": int(seed),
        "token_dim": int(TOKEN_DIM),
        "token_hw": tuple(map(int, TOKEN_HW)),
        "train_rows": int(len(df_all)),
        "pos_rate": float(df_all["y"].mean()),
        "epochs_fixed": int(epochs_fixed),
        "accum_steps": int(accum),
        "train_time_s": float(time.time() - t0),
        "used_ema_weights": bool(ema is not None),
    }
    return pack

# ----------------------------
# 15) OOM fallback
# ----------------------------
def apply_oom_fallback(cfg: dict, target_eff_batch=None):
    """Return a copy of *cfg* shrunk to use less GPU memory after a CUDA OOM.

    - ``batch_size`` is halved (floor 4) and ``base_ch`` reduced by 32
      (floor 96), but neither is ever *increased*: a value already at or
      below its floor is kept as-is (a fallback must not grow the model).
    - ``accum_steps`` is recomputed (ceil division, at least 1) so that
      ``batch_size * accum_steps`` still reaches the target effective batch.
    - ``drop`` is nudged up by 0.02, capped at 0.25, and never lowered.

    Args:
        cfg: training config with "batch_size", "base_ch" and "drop" keys.
            Not mutated.
        target_eff_batch: effective batch size to preserve. Defaults to the
            global ``TARGET_EFF_BATCH``; if that is undefined, the cfg's
            current ``batch_size * accum_steps`` is preserved instead.

    Returns:
        A new dict with the adjusted values.
    """
    import math

    cfg = dict(cfg)
    old_bs = int(cfg["batch_size"])
    old_ch = int(cfg["base_ch"])
    old_drop = float(cfg["drop"])

    if target_eff_batch is None:
        target_eff_batch = globals().get("TARGET_EFF_BATCH")
    if target_eff_batch is None:
        target_eff_batch = old_bs * max(1, int(cfg.get("accum_steps", 1)))

    # min(old, ...) guarantees the fallback never enlarges memory use;
    # max(1, ...) guards the division below against a degenerate batch of 0.
    new_bs = max(1, min(old_bs, max(4, old_bs // 2)))
    cfg["batch_size"] = new_bs
    cfg["accum_steps"] = max(1, int(math.ceil(target_eff_batch / new_bs)))
    cfg["base_ch"] = min(old_ch, max(96, old_ch - 32))
    cfg["drop"] = max(old_drop, min(0.25, old_drop + 0.02))
    return cfg

# ----------------------------
# 16) Train final (OOM-safe)
# ----------------------------
final_full_packs = []
internal_val_infos = []

t_global = time.time()

# Per seed: one initial try plus up to (OOM_MAX_ATTEMPTS - 1) retries with a shrunken cfg.
OOM_MAX_ATTEMPTS = 6

for s in range(int(N_SEEDS)):
    seed_i = FINAL_SEED + s
    print(f"\n[Final MASK Train] seed={seed_i}")

    cfg_run = dict(CFG)
    cfg_run["seed"] = int(seed_i)

    for attempt in range(OOM_MAX_ATTEMPTS):
        hit_oom = False
        try:
            # Phase A: best epoch + thr
            if USE_INTERNAL_VAL:
                info = train_with_internal_val_get_best(df_train_filtered, cfg_run, seed=seed_i)
                best_epoch = int(info["epoch"])
                best_thr = float(info["thr"])
                val_info = {"seed": seed_i, "best_epoch": best_epoch, "best_thr": best_thr, **info, "cfg_used": dict(cfg_run)}

                E_FULL = int(min(int(cfg_run["epochs"]), max(8, round(best_epoch * 1.05))))
                print(f"\nBest_epoch={best_epoch} | best_thr={best_thr:.3f} -> Retrain FULL for E_FULL={E_FULL}")
            else:
                best_thr = float(best_bundle.get("recommended_thr", 0.5)) if isinstance(best_bundle, dict) else 0.5
                E_FULL = int(cfg_run["epochs"])
                val_info = {"seed": seed_i, "best_epoch": None, "best_thr": best_thr, "cfg_used": dict(cfg_run)}

            # Phase B: retrain on all rows for the fixed epoch budget
            full_pack = train_full_fixed_epochs(df_train_filtered, cfg_run, epochs_fixed=E_FULL, seed=seed_i)

        except RuntimeError as e:
            if "out of memory" not in str(e).lower() or device.type != "cuda":
                raise
            hit_oom = True

        if hit_oom:
            # Cleanup runs after the except block has exited, so the exception's
            # traceback (which references the failed frames) is already released.
            print(f"  OOM detected (attempt {attempt+1}). Applying fallback.")
            gc.collect()
            torch.cuda.empty_cache()
            cfg_run = apply_oom_fallback(cfg_run)
            print("  New CFG after fallback:")
            for k in ["base_ch","batch_size","accum_steps","epochs","drop"]:
                print(f"    {k}: {cfg_run[k]}")
            continue

        # Record the seed only once BOTH phases succeeded: appending the
        # Phase-A info earlier would leave duplicate/stale entries after an
        # OOM in Phase B and skew the median threshold computed below.
        internal_val_infos.append(val_info)
        final_full_packs.append(full_pack)
        break
    else:
        print(f"  WARNING: seed={seed_i} produced no model after {OOM_MAX_ATTEMPTS} OOM attempts; skipping this seed.")

    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

if len(final_full_packs) == 0:
    raise RuntimeError("Final training failed: no full_packs produced.")

# recommended_thr: prefer internal val median if available, else from Step4 bundle, else 0.5
recommended_thr = None
if len(internal_val_infos) > 0 and all(("best_thr" in x and x["best_thr"] is not None) for x in internal_val_infos):
    recommended_thr = float(np.median([float(x["best_thr"]) for x in internal_val_infos]))
elif isinstance(best_bundle, dict) and best_bundle.get("recommended_thr", None) is not None:
    recommended_thr = float(best_bundle["recommended_thr"])
else:
    recommended_thr = 0.5

# ----------------------------
# 17) Save artifacts
# ----------------------------
final_model_path = OUT_DIR / "final_mask_model.pt"

# Single torch checkpoint: the optional Step-4 fold ensemble, the full-data
# packs trained above (one per successful seed) and the threshold/training meta.
torch.save(
    {
        "type": "final_mask_v1",
        "arch": "UNetASPP_on_DINOv2TokenGrid",
        "token_dim": int(TOKEN_DIM),
        "token_hw": tuple(map(int, TOKEN_HW)),

        # keep fold ensemble from Step 4 if available
        "fold_packs": fold_packs_from_step4,

        # full-data trained packs (list; usually 1 seed)
        "full_packs": final_full_packs,

        # training meta
        "internal_val_infos": internal_val_infos,
        "recommended_thr": float(recommended_thr),
        "bundle_source": source,
        "seed_base": int(FINAL_SEED),
        "train_time_total_s": float(time.time() - t_global),
    },
    final_model_path
)

# Human-readable JSON summary written next to the checkpoint.
# NOTE(review): "cfg_final" records the pre-fallback CFG; the cfg actually used
# per seed (possibly shrunk by apply_oom_fallback) is stored as "cfg" inside
# each full_pack of the checkpoint above.
# NOTE(review): json.dumps below assumes CFG holds only JSON-native values — confirm.
final_bundle = {
    "type": "final_mask_v1",
    "arch": "UNetASPP_on_DINOv2TokenGrid",
    "token_dim": int(TOKEN_DIM),
    "token_hw": tuple(map(int, TOKEN_HW)),

    "n_seeds": int(len(final_full_packs)),
    "seeds": [int(p["seed"]) for p in final_full_packs],
    "cfg_final": dict(CFG),

    "use_internal_val": bool(USE_INTERNAL_VAL),
    "val_frac_case": float(VAL_FRAC_CASE) if USE_INTERNAL_VAL else 0.0,
    "early_stop": bool(EARLY_STOP) if USE_INTERNAL_VAL else False,

    "train_rows": int(len(df_train_filtered)),
    "pos_rate": float(df_train_filtered["y"].mean()),

    "recommended_thr": float(recommended_thr),
    "has_fold_packs_from_step4": bool(fold_packs_from_step4 is not None),
    "best_cfg_source": source,

    # NOTE(review): assumes cfg_path is a Path (never None) from an earlier cell — confirm.
    "paths": {
        "final_model_pt": str(final_model_path),
        "best_cfg_json": str(cfg_path) if cfg_path.exists() else None,
        "best_model_pt": str(best_mask_model_path) if best_mask_model_path is not None else None,
    },

    # NOTE(review): EMA weights are only saved when training created an EMA;
    # see "used_ema_weights" in each full_pack.
    "notes": "Final MASK model: internal case-level val -> pick best_epoch & thr -> retrain full; saves EMA weights.",
}

final_bundle_path = OUT_DIR / "final_mask_bundle.json"
final_bundle_path.write_text(json.dumps(final_bundle, indent=2))

print("\nSaved final MASK training artifacts:")
print("  model  ->", final_model_path)
print("  bundle ->", final_bundle_path)
print("  recommended_thr ->", float(recommended_thr))

# Export globals
FINAL_MASK_MODEL_PT = str(final_model_path)
FINAL_MASK_BUNDLE = final_bundle


# Finalize & Save Model Bundle (Reproducible)

In [None]:
# ============================================================
# Step 6 — Finalize & Save Model Bundle (Notebook-3 / Inference Notebook)
# REVISI FULL v5.2 (portable + robust discovery + SUPPORT Gate + Mask in one bundle)
#
# Main upgrades vs v5.1:
# - Auto-discovers *both* final_gate_model.pt and final_mask_model.pt (when present)
# - Emits a single bundle, so the inference notebook can load Gate only, Mask only, or both
# - Robust thresholds: T_gate and T_mask are resolved through a tiered priority list
# - Manifest + pack store metadata for both models (sha256, cfg summary, token_hw/dim, feature_cols)
#
# Output (write):
#   /kaggle/working/recodai_luc_gate_artifacts/model_bundle_v5_notebook3/
#     - final_gate_model.pt (optional)
#     - final_gate_bundle.json (optional)
#     - feature_cols.json (required IF gate exists)
#     - final_mask_model.pt (optional)
#     - final_mask_bundle.json (optional)
#     - thresholds.json (T_gate, T_mask)
#     - model_bundle_manifest.json
#     - model_bundle_pack.json (+ optional joblib)
#     - model_bundle_v5_notebook3.zip
# ============================================================

import os, json, time, platform, warnings, zipfile, hashlib, shutil
from pathlib import Path

# Suppress every Python warning for the rest of the session (global side effect).
warnings.filterwarnings("ignore")

# ----------------------------
# 0) IO roots — everything this step writes goes under OUT_DIR.
# ----------------------------
OUT_ROOT = Path("/kaggle/working/recodai_luc_gate_artifacts")   # keep as requested
BUNDLE_VERSION = "v5_notebook3"
OUT_DIR = OUT_ROOT / f"model_bundle_{BUNDLE_VERSION}"

# Create the root first, then the versioned bundle folder beneath it.
for _out_dir in (OUT_ROOT, OUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

print("OUTPUT (write):", OUT_DIR)

# ----------------------------
# Helpers
# ----------------------------
def read_json_safe(p: Path, default=None):
    """Parse the JSON file at *p*, returning *default* if it cannot be loaded.

    Best-effort by design: a missing or unreadable file, bytes that are not
    valid UTF-8, malformed JSON, or an invalid path argument all yield
    *default* instead of raising. The file is always decoded as UTF-8, which
    is the JSON standard encoding, rather than the locale-dependent default.
    """
    try:
        return json.loads(Path(p).read_text(encoding="utf-8"))
    except (OSError, ValueError, TypeError):
        # OSError: missing/unreadable; ValueError covers JSONDecodeError and
        # UnicodeDecodeError; TypeError: p is not path-like.
        return default

def sha256_file(p: Path, chunk=1024 * 1024):
    """Return the hex SHA-256 digest of file *p*.

    The file is streamed in *chunk*-byte pieces so large checkpoints never
    have to fit in memory at once.
    """
    digest = hashlib.sha256()
    with Path(p).open("rb") as fh:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for piece in iter(lambda: fh.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()

def file_meta(p: Path):
    """Describe a regular file for the manifest, or return None.

    Returns a dict with the path, base name, size in bytes, SHA-256 digest
    and UTC modification time (ISO-8601, 'Z' suffix). Missing paths and
    non-files (e.g. directories) yield None.
    """
    path = Path(p)
    if not path.is_file():
        return None

    info = path.stat()
    modified_utc = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(info.st_mtime))
    return {
        "path": str(path),
        "name": path.name,
        "bytes": int(info.st_size),
        "sha256": sha256_file(path),
        "mtime_utc": modified_utc,
    }

def safe_add(zf: zipfile.ZipFile, p: Path, arcname: str):
    """Write file *p* into archive *zf* as *arcname*; skip silently otherwise.

    ``None``, missing paths and non-files (e.g. directories) are ignored, so
    optional artifacts can be passed without pre-checking.
    """
    if p is not None:
        src = Path(p)
        if src.is_file():
            zf.write(src, arcname=arcname)

def pick_first_existing(paths):
    """Return the first entry of *paths* that is an existing regular file.

    ``None`` entries are skipped. The match is returned as a ``Path``;
    if nothing qualifies, the result is None.
    """
    candidates = (Path(c) for c in paths if c is not None)
    return next((c for c in candidates if c.exists() and c.is_file()), None)

def find_file_near(root: Path, filename: str, max_depth: int = 3):
    """Return the shallowest regular file named *filename* under *root*, or None.

    *root* itself is checked first. The search then moves through directories
    1, 2, ... *max_depth* levels below *root*, in that order. Within one depth
    the directories are visited in sorted order, so the result does not depend
    on the filesystem's directory-listing order (``Path.glob`` order is
    arbitrary). ``max_depth`` is honoured exactly: 0 checks *root* only, and
    values above 3 search deeper.
    """
    root = Path(root)
    direct = root / filename
    if direct.is_file():
        return direct

    for depth in range(1, int(max_depth) + 1):
        pattern = "/".join(["*"] * depth)  # "*", "*/*", "*/*/*", ...
        for d in sorted(root.glob(pattern)):
            candidate = d / filename
            if d.is_dir() and candidate.is_file():
                return candidate

    return None

def copy_if_needed(src: Path, dst: Path, verbose=True):
    """Copy *src* to *dst* unless *dst* already holds byte-identical content.

    The destination's parent directories are created first. Identity is judged
    by file size and then SHA-256. Any error while comparing simply forces a
    fresh copy. The copy preserves metadata (``shutil.copy2``).

    Returns:
        The destination Path.
    Raises:
        FileNotFoundError: if *src* does not exist.
    """
    source, target = Path(src), Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    if not source.exists():
        raise FileNotFoundError(f"Missing src: {source}")

    identical = False
    if target.exists() and target.is_file():
        try:
            # Cheap size check first; hash only when sizes agree.
            identical = (source.stat().st_size == target.stat().st_size
                         and sha256_file(source) == sha256_file(target))
        except Exception:
            identical = False

    if identical:
        if verbose:
            print("  [skip copy] already identical:", target.name)
        return target

    shutil.copy2(source, target)
    if verbose:
        print("  [copied] ->", target)
    return target

def _score_dir_gate(p: Path):
    """Score how complete a Gate artifact directory looks (higher is better).

    The final model dominates (100 points). The bundle json, feature columns,
    thresholds, best-config/model and opt_search results add smaller bonuses.
    """
    weighted_parts = (
        (("final_gate_model.pt",), 100),
        (("final_gate_bundle.json",), 30),
        (("feature_cols.json",), 20),
        (("thresholds.json",), 5),
        (("best_gate_config.json",), 3),
        (("best_gate_model.pt",), 3),
        (("opt_search", "opt_results.csv"), 1),
    )
    base = Path(p)
    return sum(weight for parts, weight in weighted_parts if base.joinpath(*parts).exists())

def _score_dir_mask(p: Path):
    """Score how complete a Mask artifact directory looks (higher is better).

    The final mask model dominates (100 points). Its bundle json, threshold
    files and best-config/model add smaller bonuses. A dedicated
    thresholds_mask.json outweighs a generic thresholds.json.
    """
    weights = {
        "final_mask_model.pt": 100,
        "final_mask_bundle.json": 30,
        "thresholds_mask.json": 5,
        "thresholds.json": 2,
        "best_mask_config.json": 3,
        "best_mask_model.pt": 3,
    }
    base = Path(p)
    total = 0
    for fname, weight in weights.items():
        if (base / fname).exists():
            total += weight
    return total

def _bounded_dirs(ds: Path, max_depth: int = 3):
    """List *ds* and every sub-directory up to *max_depth* levels below it.

    Returns [] when *ds* is missing or is not a directory. *ds* comes first,
    followed by depth 1, then depth 2, and so on. Each level is sorted, so
    the candidate order (and therefore tie-breaking by ``max`` in
    ``_pick_best_dir``) is deterministic instead of following the arbitrary
    order of ``Path.glob``.
    """
    ds = Path(ds)
    if not ds.is_dir():
        return []
    dirs = [ds]
    for depth in range(1, int(max_depth) + 1):
        pattern = "/".join(["*"] * depth)  # "*", "*/*", "*/*/*", ...
        dirs.extend(p for p in sorted(ds.glob(pattern)) if p.is_dir())
    return dirs

def _gather_candidate_dirs(kind: str):
    """Collect directories that might hold the *kind* ("gate"/"mask") artifacts.

    Candidates come from two places:
      * the /kaggle/working roots and their ``model_bundle_*`` sub-folders;
      * directories (up to 3 levels deep) inside each /kaggle/input dataset.
        An input directory is kept when its name mentions
        recodai/luc/bundle/artifacts, when it directly contains the final
        model for *kind*, or when it is one of the known artifact roots (its
        ``model_bundle_*`` children are added too).

    All listings are sorted and duplicates are dropped keeping first-seen
    order, so the result (and downstream tie-breaking) is deterministic.

    Raises:
        ValueError: if *kind* is not "gate" or "mask".
    """
    if kind not in ("gate", "mask"):
        # Explicit raise instead of assert: asserts vanish under `python -O`.
        raise ValueError(f"kind must be 'gate' or 'mask', got {kind!r}")
    model_name = "final_gate_model.pt" if kind == "gate" else "final_mask_model.pt"
    cands = []

    # working roots
    work_roots = [
        Path("/kaggle/working/recodai_luc_gate_artifacts"),
        Path("/kaggle/working/recodai_luc_mask_artifacts"),
        Path("/kaggle/working"),
    ]
    for base in work_roots:
        if not base.exists():
            continue
        cands.append(base)
        cands.extend(sub for sub in sorted(base.glob("model_bundle_*")) if sub.is_dir())

    # inputs (read-only datasets)
    inp = Path("/kaggle/input")
    if inp.exists():
        for ds in sorted(inp.iterdir()):
            if not ds.is_dir():
                continue
            for d in _bounded_dirs(ds):
                name = d.name.lower()
                if any(tok in name for tok in ("recodai", "luc", "bundle", "artifacts")):
                    cands.append(d)
                if (d / model_name).exists():
                    cands.append(d)
                if d.name in ("recodai_luc_gate_artifacts", "recodai_luc_mask_artifacts"):
                    cands.extend(sub for sub in sorted(d.glob("model_bundle_*")) if sub.is_dir())

    # de-dup (first occurrence wins) and keep only existing directories
    seen = set()
    out = []
    for p in cands:
        p = Path(p)
        sp = str(p)
        if sp not in seen and p.exists() and p.is_dir():
            seen.add(sp)
            out.append(p)
    return out

def _pick_best_dir(kind: str):
    """Return the candidate directory with the highest artifact score for *kind*.

    "gate" is scored with ``_score_dir_gate``; any other kind falls through to
    ``_score_dir_mask``. When several candidates tie, ``max`` keeps the
    earliest one. Returns None when no candidate directories exist.
    """
    candidates = _gather_candidate_dirs(kind)
    if not candidates:
        return None
    scorer = _score_dir_gate if kind == "gate" else _score_dir_mask
    return max(candidates, key=scorer)

# ----------------------------
# 1) Pick best SRC dirs (gate + mask)
# ----------------------------
SRC_DIR_GATE = _pick_best_dir("gate")
SRC_DIR_MASK = _pick_best_dir("mask")

# Report each picked directory together with the score that won.
print("\nSOURCE dirs picked:")
for _kind, _picked, _scorer in (
    ("gate", SRC_DIR_GATE, _score_dir_gate),
    ("mask", SRC_DIR_MASK, _score_dir_mask),
):
    print(f"  {_kind}:", _picked, "| score:", (_scorer(_picked) if _picked else None))

# ----------------------------
# 2) Locate artifacts (robust) — Gate
# ----------------------------
# Gate artifact paths; each one stays None when the corresponding file is not found.
final_gate_pt = None
final_gate_bundle_json = None
feature_cols_path = None
best_gate_config_path = None
best_gate_model_path = None
baseline_report_path = None
opt_results_csv = None
opt_fold_csv = None
oof_baseline_csv = None
MODEL_DIR_GATE = None

if SRC_DIR_GATE is not None:
    # The final model anchors the search. Companion files are looked up first
    # next to it (MODEL_DIR_GATE, depth 2), then anywhere under SRC_DIR_GATE (depth 3).
    final_gate_pt = find_file_near(SRC_DIR_GATE, "final_gate_model.pt", max_depth=3)
    if final_gate_pt is not None:
        MODEL_DIR_GATE = final_gate_pt.parent
        final_gate_bundle_json = find_file_near(MODEL_DIR_GATE, "final_gate_bundle.json", max_depth=2) \
                                 or find_file_near(SRC_DIR_GATE, "final_gate_bundle.json", max_depth=3)
        feature_cols_path = find_file_near(MODEL_DIR_GATE, "feature_cols.json", max_depth=2) \
                            or find_file_near(SRC_DIR_GATE, "feature_cols.json", max_depth=3)

        # For each pick_first_existing(...) below, all list entries are evaluated
        # eagerly and list order sets the priority (first existing file wins).
        # CV report: try the known baseline report names, model dir first.
        baseline_report_path = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "baseline_mhc_transformer_cv_report.json", max_depth=2),
            find_file_near(MODEL_DIR_GATE, "baseline_transformer_cv_report.json", max_depth=2),
            find_file_near(MODEL_DIR_GATE, "baseline_cv_report.json", max_depth=2),
            find_file_near(SRC_DIR_GATE,   "baseline_mhc_transformer_cv_report.json", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "baseline_transformer_cv_report.json", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "baseline_cv_report.json", max_depth=3),
        ])

        best_gate_config_path = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "best_gate_config.json", max_depth=2),
            find_file_near(SRC_DIR_GATE,   "best_gate_config.json", max_depth=3),
        ])

        best_gate_model_path = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "best_gate_model.pt", max_depth=2),
            find_file_near(SRC_DIR_GATE,   "best_gate_model.pt", max_depth=3),
        ])

        # The explicit opt_search/... entries appear to be a safety net only:
        # find_file_near already checks depth-1 folders such as opt_search/.
        opt_results_csv = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "opt_results.csv", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "opt_results.csv", max_depth=3),
            (MODEL_DIR_GATE / "opt_search" / "opt_results.csv") if (MODEL_DIR_GATE / "opt_search" / "opt_results.csv").exists() else None,
            (SRC_DIR_GATE   / "opt_search" / "opt_results.csv") if (SRC_DIR_GATE   / "opt_search" / "opt_results.csv").exists() else None,
        ])

        opt_fold_csv = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "opt_fold_details.csv", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "opt_fold_details.csv", max_depth=3),
            (MODEL_DIR_GATE / "opt_search" / "opt_fold_details.csv") if (MODEL_DIR_GATE / "opt_search" / "opt_fold_details.csv").exists() else None,
            (SRC_DIR_GATE   / "opt_search" / "opt_fold_details.csv") if (SRC_DIR_GATE   / "opt_search" / "opt_fold_details.csv").exists() else None,
        ])

        # Out-of-fold predictions: same name-priority scheme as the CV report.
        oof_baseline_csv = pick_first_existing([
            find_file_near(MODEL_DIR_GATE, "oof_baseline_mhc_transformer.csv", max_depth=2),
            find_file_near(MODEL_DIR_GATE, "oof_baseline_transformer.csv", max_depth=2),
            find_file_near(MODEL_DIR_GATE, "oof_baseline.csv", max_depth=2),
            find_file_near(SRC_DIR_GATE,   "oof_baseline_mhc_transformer.csv", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "oof_baseline_transformer.csv", max_depth=3),
            find_file_near(SRC_DIR_GATE,   "oof_baseline.csv", max_depth=3),
        ])

# ----------------------------
# 2b) Locate artifacts (robust) — Mask
# ----------------------------
# Mask artifact paths; each one stays None when the corresponding file is not found.
final_mask_pt = None
final_mask_bundle_json = None
best_mask_config_path = None
best_mask_model_path = None
MODEL_DIR_MASK = None
src_thresh_mask = None

if SRC_DIR_MASK is not None:
    # Anchor on the final mask model. Companions are searched next to it first
    # (MODEL_DIR_MASK, depth 2), then anywhere under SRC_DIR_MASK (depth 3).
    final_mask_pt = find_file_near(SRC_DIR_MASK, "final_mask_model.pt", max_depth=3)
    if final_mask_pt is not None:
        MODEL_DIR_MASK = final_mask_pt.parent
        final_mask_bundle_json = find_file_near(MODEL_DIR_MASK, "final_mask_bundle.json", max_depth=2) \
                                 or find_file_near(SRC_DIR_MASK, "final_mask_bundle.json", max_depth=3)

        best_mask_config_path = pick_first_existing([
            find_file_near(MODEL_DIR_MASK, "best_mask_config.json", max_depth=2),
            find_file_near(SRC_DIR_MASK,   "best_mask_config.json", max_depth=3),
        ])

        best_mask_model_path = pick_first_existing([
            find_file_near(MODEL_DIR_MASK, "best_mask_model.pt", max_depth=2),
            find_file_near(SRC_DIR_MASK,   "best_mask_model.pt", max_depth=3),
        ])

        # Optional dedicated mask-threshold file; its copy is the first source
        # consulted when T_mask is resolved in step 5.
        src_thresh_mask = find_file_near(MODEL_DIR_MASK, "thresholds_mask.json", max_depth=2) \
                          or find_file_near(SRC_DIR_MASK, "thresholds_mask.json", max_depth=3)

# ----------------------------
# 2c) Hard requirement check
# ----------------------------
# At least one of the two final models must exist; otherwise there is nothing to bundle.
# (Message: neither final_gate_model.pt nor final_mask_model.pt was found in
#  /kaggle/working or /kaggle/input — add the training-output dataset to this notebook.)
if final_gate_pt is None and final_mask_pt is None:
    raise FileNotFoundError(
        "Tidak menemukan final_gate_model.pt maupun final_mask_model.pt di /kaggle/working atau /kaggle/input.\n"
        "Pastikan kamu sudah add dataset output training ke notebook ini."
    )

# (label, located path, text shown when the path is missing)
_found_report = (
    ("  [GATE] final_model      :", final_gate_pt, "(missing)"),
    ("  [GATE] final_bundle     :", final_gate_bundle_json, "(missing/skip)"),
    ("  [GATE] feature_cols     :", feature_cols_path, "(missing)"),
    ("  [MASK] final_model      :", final_mask_pt, "(missing)"),
    ("  [MASK] final_bundle     :", final_mask_bundle_json, "(missing/skip)"),
    ("  [MASK] thresholds_mask  :", src_thresh_mask, "(missing/skip)"),
)
print("\nFound artifacts (read):")
for _label, _found, _if_missing in _found_report:
    print(_label, _found if _found else _if_missing)

# ----------------------------
# 3) COPY core artifacts into OUT_DIR (portable bundle)
# ----------------------------
print("\nCopying core files into OUT_DIR (portable):")
# Sub-folders for optional, non-core artifacts.
# NOTE(review): these folders are never cleared, so files left by an earlier
# run stay here and are later picked up by the glob("*") in the manifest step.
extras_dir = OUT_DIR / "extras"
extras_dir.mkdir(parents=True, exist_ok=True)
opt_dir = OUT_DIR / "opt_search"
opt_dir.mkdir(parents=True, exist_ok=True)
oof_dir = OUT_DIR / "oof"
oof_dir.mkdir(parents=True, exist_ok=True)

# Destination paths inside OUT_DIR; stay None when the gate model is absent.
final_gate_dst = None
final_gate_bundle_dst = None
feature_cols_dst = None

if final_gate_pt is not None:
    final_gate_dst = copy_if_needed(final_gate_pt, OUT_DIR / "final_gate_model.pt")
    if final_gate_bundle_json is not None and final_gate_bundle_json.exists():
        final_gate_bundle_dst = copy_if_needed(final_gate_bundle_json, OUT_DIR / "final_gate_bundle.json")
    # feature_cols.json is mandatory whenever a gate model is bundled.
    # (Message: gate found but feature_cols.json missing — make sure Step 2 output is bundled.)
    if feature_cols_path is None or not feature_cols_path.exists():
        raise FileNotFoundError("Gate ditemukan, tapi feature_cols.json tidak ditemukan. Pastikan Step 2 ikut dibundle.")
    feature_cols_dst = copy_if_needed(feature_cols_path, OUT_DIR / "feature_cols.json")

    # optional extras (copied under their original base name)
    if baseline_report_path:
        copy_if_needed(baseline_report_path, extras_dir / Path(baseline_report_path).name, verbose=False)
    if best_gate_config_path:
        copy_if_needed(best_gate_config_path, extras_dir / Path(best_gate_config_path).name, verbose=False)
    if best_gate_model_path:
        copy_if_needed(best_gate_model_path, extras_dir / Path(best_gate_model_path).name, verbose=False)
    if opt_results_csv:
        copy_if_needed(opt_results_csv, opt_dir / Path(opt_results_csv).name, verbose=False)
    if opt_fold_csv:
        copy_if_needed(opt_fold_csv, opt_dir / Path(opt_fold_csv).name, verbose=False)
    if oof_baseline_csv:
        copy_if_needed(oof_baseline_csv, oof_dir / Path(oof_baseline_csv).name, verbose=False)

# Mask destinations; stay None when the mask model is absent.
final_mask_dst = None
final_mask_bundle_dst = None
thresholds_mask_dst = None

if final_mask_pt is not None:
    final_mask_dst = copy_if_needed(final_mask_pt, OUT_DIR / "final_mask_model.pt")
    if final_mask_bundle_json is not None and final_mask_bundle_json.exists():
        final_mask_bundle_dst = copy_if_needed(final_mask_bundle_json, OUT_DIR / "final_mask_bundle.json")
    if best_mask_config_path:
        copy_if_needed(best_mask_config_path, extras_dir / Path(best_mask_config_path).name, verbose=False)
    if best_mask_model_path:
        copy_if_needed(best_mask_model_path, extras_dir / Path(best_mask_model_path).name, verbose=False)
    if src_thresh_mask and src_thresh_mask.exists():
        thresholds_mask_dst = copy_if_needed(src_thresh_mask, OUT_DIR / "thresholds_mask.json", verbose=False)

# ----------------------------
# 4) Load metadata (from copied files)
# ----------------------------
gate_feature_cols = []
gate_final_bundle = {}
mask_final_bundle = {}
best_gate_config = None
best_mask_config = None

if feature_cols_dst is not None:
    gate_feature_cols = read_json_safe(feature_cols_dst, default=[])
    if not isinstance(gate_feature_cols, list) or len(gate_feature_cols) == 0:
        raise ValueError(f"feature_cols invalid/empty: {feature_cols_dst}")

if final_gate_bundle_dst is not None:
    gate_final_bundle = read_json_safe(final_gate_bundle_dst, default={}) or {}

if final_mask_bundle_dst is not None:
    mask_final_bundle = read_json_safe(final_mask_bundle_dst, default={}) or {}

if best_gate_config_path:
    best_gate_config = read_json_safe(best_gate_config_path, default=None)
if best_mask_config_path:
    best_mask_config = read_json_safe(best_mask_config_path, default=None)


def _read_pt_meta(pt_path, keys):
    """Load a torch checkpoint on CPU and return only its top-level *keys*.

    Each checkpoint is handled independently, so a failure on one file does
    not hide the other's metadata. The full object (including all weights)
    is local to this call and is released on return instead of lingering
    as a notebook global. Returns {} when *pt_path* is None, loading fails
    (a warning is printed), or the checkpoint is not a dict.
    """
    if pt_path is None:
        return {}
    try:
        import torch
        obj = torch.load(pt_path, map_location="cpu")
    except Exception as e:
        print(f"Warning: failed to read *.pt metadata ({Path(pt_path).name}):", repr(e))
        return {}
    if not isinstance(obj, dict):
        return {}
    return {k: obj.get(k, None) for k in keys}


# recommended thr from pt(s)
_gate_pt_meta = _read_pt_meta(final_gate_dst, ["recommended_thr"])
_mask_pt_meta = _read_pt_meta(final_mask_dst, ["recommended_thr", "token_hw", "token_dim"])

recommended_thr_gate_from_pt = _gate_pt_meta.get("recommended_thr", None)
recommended_thr_mask_from_pt = _mask_pt_meta.get("recommended_thr", None)
token_hw_from_mask_pt = _mask_pt_meta.get("token_hw", None)
token_dim_from_mask_pt = _mask_pt_meta.get("token_dim", None)

# ----------------------------
# 5) Threshold resolve (T_gate + T_mask)
# ----------------------------
def extract_thr_from_best_cfg(cfg: dict):
    """Pull a decision threshold out of a best-config dict, or return None.

    Candidates are tried in this order:
      1. ``cfg["oof_best_thr"]``
      2. ``cfg["selection"]["oof_best_thr"]`` (legacy layout)
      3. ``cfg["recommended_thr"]``
    The first candidate that is not None and converts with ``float()`` wins.
    Values that cannot be converted are skipped. Non-dict input yields None.
    """
    if not isinstance(cfg, dict):
        return None

    candidates = [cfg.get("oof_best_thr", None)]
    selection = cfg.get("selection", None)
    if isinstance(selection, dict):
        candidates.append(selection.get("oof_best_thr", None))
    candidates.append(cfg.get("recommended_thr", None))

    for raw in candidates:
        if raw is None:
            continue
        try:
            return float(raw)
        except Exception:
            continue
    return None

# Gate threshold priority:
# (a) SRC thresholds.json near gate model dir
# (b) OUT_DIR thresholds.json if rerun
# (c) final_gate_bundle.json recommended_thr
# (d) final_gate_model.pt recommended_thr
# (e) best_gate_config.json oof_best_thr
# (f) 0.5


def _as_float_or_none(v):
    """Return float(v), or None when v is None or not convertible to float."""
    if v is None:
        return None
    try:
        return float(v)
    except (TypeError, ValueError, OverflowError):
        return None


def _get_if_dict(d, key):
    """Return d.get(key) when d is a dict, else None (tolerates missing/invalid JSON payloads)."""
    return d.get(key, None) if isinstance(d, dict) else None


T_gate = None
src_thresh_gate = None
if MODEL_DIR_GATE is not None:
    src_thresh_gate = find_file_near(MODEL_DIR_GATE, "thresholds.json", max_depth=2) or \
                      (find_file_near(SRC_DIR_GATE, "thresholds.json", max_depth=3) if SRC_DIR_GATE else None)

out_thresh_path = OUT_DIR / "thresholds.json"
existing_out_thresh = read_json_safe(out_thresh_path, default=None) if out_thresh_path.exists() else None

# Walk the priority list; each step only runs while no valid threshold has been found.
if src_thresh_gate and src_thresh_gate.exists():
    T_gate = _as_float_or_none(_get_if_dict(read_json_safe(src_thresh_gate, default={}), "T_gate"))
if T_gate is None:
    T_gate = _as_float_or_none(_get_if_dict(existing_out_thresh, "T_gate"))
if T_gate is None:
    T_gate = _as_float_or_none(_get_if_dict(gate_final_bundle, "recommended_thr"))
if T_gate is None:
    T_gate = _as_float_or_none(recommended_thr_gate_from_pt)
if T_gate is None and isinstance(best_gate_config, dict):
    T_gate = extract_thr_from_best_cfg(best_gate_config)
if T_gate is None:
    T_gate = 0.5

# Mask threshold priority:
# (a) thresholds_mask.json near mask model dir (T_mask, then recommended_thr)
# (b) thresholds.json existing OUT_DIR (T_mask)
# (c) final_mask_bundle.json recommended_thr
# (d) final_mask_model.pt recommended_thr
# (e) best_mask_config.json oof_best_thr/recommended_thr
# (f) 0.5
T_mask = None
if thresholds_mask_dst is not None and thresholds_mask_dst.exists():
    _tj_mask = read_json_safe(thresholds_mask_dst, default={})
    T_mask = _as_float_or_none(_get_if_dict(_tj_mask, "T_mask"))
    if T_mask is None:
        # Also used when T_mask is present but not numeric.
        T_mask = _as_float_or_none(_get_if_dict(_tj_mask, "recommended_thr"))
if T_mask is None:
    T_mask = _as_float_or_none(_get_if_dict(existing_out_thresh, "T_mask"))
if T_mask is None:
    T_mask = _as_float_or_none(_get_if_dict(mask_final_bundle, "recommended_thr"))
if T_mask is None:
    T_mask = _as_float_or_none(recommended_thr_mask_from_pt)
if T_mask is None and isinstance(best_mask_config, dict):
    T_mask = extract_thr_from_best_cfg(best_mask_config)
if T_mask is None:
    T_mask = 0.5

# beta_for_tuning: a missing, null or non-numeric value falls back to 0.5
# (float(None) would otherwise crash when the key is present but null).
_beta_for_tuning = _as_float_or_none(_get_if_dict(best_gate_config, "beta_for_tuning"))
if _beta_for_tuning is None:
    _beta_for_tuning = 0.5

thresholds = {
    "T_gate": float(T_gate) if final_gate_dst is not None else None,
    "T_mask": float(T_mask) if final_mask_dst is not None else None,
    "beta_for_tuning": float(_beta_for_tuning),
    "guards": {
        "min_area_frac": None,
        "max_area_frac": None,
        "max_components": None,
    },
    "source_priority": {
        "T_gate": [
            "SRC thresholds.json (near gate model)",
            "OUT_DIR/thresholds.json (existing)",
            "final_gate_bundle.json.recommended_thr",
            "final_gate_model.pt.recommended_thr",
            "best_gate_config.json.oof_best_thr (or selection.oof_best_thr legacy)",
            "fallback 0.5",
        ],
        "T_mask": [
            "SRC thresholds_mask.json (near mask model)",
            "OUT_DIR/thresholds.json (existing)",
            "final_mask_bundle.json.recommended_thr",
            "final_mask_model.pt.recommended_thr",
            "best_mask_config.json oof_best_thr/recommended_thr",
            "fallback 0.5",
        ],
    },
    "notes": "Thresholds used in inference. Update after calibration/OOF tuning if needed.",
}

thresholds_path = OUT_DIR / "thresholds.json"
thresholds_path.write_text(json.dumps(thresholds, indent=2))

print("\nThreshold resolved:")
print("  T_gate:", thresholds["T_gate"])
print("  T_mask:", thresholds["T_mask"])
print("  thresholds.json ->", thresholds_path)

# ----------------------------
# 6) cfg_meta (optional)
# ----------------------------
def _to_json_safe(v):
    """Return *v* in a form ``json.dumps`` accepts.

    None/str/int/float/bool pass through unchanged. Lists/tuples and dicts
    are converted element-wise. Anything else (e.g. ``pathlib.Path``) is
    converted with ``str(v)``.
    """
    if v is None or isinstance(v, (str, int, float, bool)):
        return v
    if isinstance(v, (list, tuple)):
        return [_to_json_safe(x) for x in v]
    if isinstance(v, dict):
        return {str(k): _to_json_safe(x) for k, x in v.items()}
    return str(v)


cfg_meta = {}
if "PATHS" in globals() and isinstance(PATHS, dict):
    # NOTE(review): PATHS (from Stage 0) may hold pathlib.Path values; they are
    # coerced so that json.dumps of the manifest/pack cannot raise TypeError.
    cfg_meta = {k: _to_json_safe(PATHS.get(k, None)) for k in [
        "COMP_ROOT","OUT_DS_ROOT","OUT_ROOT","MATCH_CFG_DIR","PRED_CFG_DIR","DINO_CFG_DIR","DINO_LARGE_DIR",
        "PRED_FEAT_TRAIN","MATCH_FEAT_TRAIN","DF_TRAIN_ALL","CV_CASE_FOLDS","IMG_PROFILE_TRAIN"
    ]}

# ----------------------------
# 7) Manifest (reproducible) + sha256 index
# ----------------------------
task_str = "Recod.ai/LUC — Portable Bundle — (optional) Gate + (optional) Mask"
model_format = "torch_pt"

artifact_paths = [
    final_gate_dst,
    final_gate_bundle_dst,
    feature_cols_dst,
    final_mask_dst,
    final_mask_bundle_dst,
    thresholds_mask_dst,
    thresholds_path,
]
# extras
for p in extras_dir.glob("*"):
    artifact_paths.append(p)
for p in opt_dir.glob("*"):
    artifact_paths.append(p)
for p in oof_dir.glob("*"):
    artifact_paths.append(p)

artifact_paths = [p for p in artifact_paths if p is not None]

artifacts_meta = {}
for p in artifact_paths:
    m = file_meta(p)
    if m is not None:
        artifacts_meta[m["name"]] = m

manifest = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "python": platform.python_version(),
    "platform": platform.platform(),
    "bundle_version": BUNDLE_VERSION,
    "task": task_str,
    "model_format": model_format,

    "source_dirs": {
        "gate": str(SRC_DIR_GATE) if SRC_DIR_GATE else None,
        "mask": str(SRC_DIR_MASK) if SRC_DIR_MASK else None,
    },
    "source_model_dirs": {
        "gate": str(MODEL_DIR_GATE) if MODEL_DIR_GATE else None,
        "mask": str(MODEL_DIR_MASK) if MODEL_DIR_MASK else None,
    },
    "output_dir": str(OUT_DIR),

    "artifacts_index": artifacts_meta,
    "cfg_meta": cfg_meta,

    "summary_gate": {
        "present": bool(final_gate_dst is not None),
        "type": (gate_final_bundle.get("type") if isinstance(gate_final_bundle, dict) else None),
        "n_seeds": (gate_final_bundle.get("n_seeds") if isinstance(gate_final_bundle, dict) else None),
        "seeds": (gate_final_bundle.get("seeds") if isinstance(gate_final_bundle, dict) else None),
        "train_rows": (gate_final_bundle.get("train_rows") if isinstance(gate_final_bundle, dict) else None),
        "pos_rate": (gate_final_bundle.get("pos_rate") if isinstance(gate_final_bundle, dict) else None),
        "feature_count": int(len(gate_feature_cols)) if gate_feature_cols else None,
        "T_gate": thresholds.get("T_gate", None),
        "recommended_thr_from_pt": recommended_thr_gate_from_pt,
    },
    "summary_mask": {
        "present": bool(final_mask_dst is not None),
        "type": (mask_final_bundle.get("type") if isinstance(mask_final_bundle, dict) else None),
        "train_rows": (mask_final_bundle.get("train_rows") if isinstance(mask_final_bundle, dict) else None),
        "pos_rate": (mask_final_bundle.get("pos_rate") if isinstance(mask_final_bundle, dict) else None),
        "token_hw_from_pt": token_hw_from_mask_pt,
        "token_dim_from_pt": token_dim_from_mask_pt,
        "T_mask": thresholds.get("T_mask", None),
        "recommended_thr_from_pt": recommended_thr_mask_from_pt,
    },
}

# Serialize the manifest first (pretty-printed, indent=2), then persist it in OUT_DIR.
_manifest_text = json.dumps(manifest, indent=2)
manifest_path = OUT_DIR / "model_bundle_manifest.json"
manifest_path.write_text(_manifest_text)
del _manifest_text

# ----------------------------
# 8) Bundle pack (portable JSON) + optional joblib
# ----------------------------
# `bundle_files` indexes the artifacts expected in OUT_DIR / the ZIP.
# Optional artifacts map to None when they were not produced in this run.
bundle_pack = {
    "bundle_version": BUNDLE_VERSION,
    "model_format": model_format,
    "bundle_files": {
        "final_gate_model.pt": "final_gate_model.pt" if final_gate_dst is not None else None,
        "final_gate_bundle.json": "final_gate_bundle.json" if final_gate_bundle_dst is not None else None,
        "feature_cols.json": "feature_cols.json" if feature_cols_dst is not None else None,

        "final_mask_model.pt": "final_mask_model.pt" if final_mask_dst is not None else None,
        "final_mask_bundle.json": "final_mask_bundle.json" if final_mask_bundle_dst is not None else None,
        "thresholds_mask.json": "thresholds_mask.json" if thresholds_mask_dst is not None else None,

        "thresholds.json": "thresholds.json",
        "model_bundle_manifest.json": "model_bundle_manifest.json",
        "model_bundle_pack.json": "model_bundle_pack.json",
        # Reset to None below if the optional joblib dump does not succeed.
        "model_bundle_pack.joblib": "model_bundle_pack.joblib",
        "model_bundle_v5_notebook3.zip": "model_bundle_v5_notebook3.zip",
    },
    "thresholds": thresholds,
    "feature_cols": gate_feature_cols if gate_feature_cols else None,
    "token_meta": {
        "token_hw": token_hw_from_mask_pt,
        "token_dim": token_dim_from_mask_pt,
    } if final_mask_dst is not None else None,
    "cfg_meta": cfg_meta,
    "manifest": manifest,
}

bundle_pack_json = OUT_DIR / "model_bundle_pack.json"
bundle_pack_joblib = OUT_DIR / "model_bundle_pack.joblib"

# The joblib copy is best-effort (joblib may be unavailable). It is attempted
# BEFORE the JSON pack is written so the JSON index reflects whether the
# .joblib file really exists.
joblib_ok = False
joblib_error = None
try:
    import joblib
    joblib.dump(bundle_pack, bundle_pack_joblib)
    joblib_ok = True
except Exception as exc:  # deliberate: a joblib problem must not abort bundling
    joblib_error = f"{type(exc).__name__}: {exc}"
    # Remove a partially written file (or a stale one from an earlier run) so
    # OUT_DIR never holds a .joblib that disagrees with the JSON pack.
    try:
        bundle_pack_joblib.unlink(missing_ok=True)
    except OSError:
        pass  # still safe: the ZIP step only adds the .joblib when joblib_ok is True
    bundle_pack["bundle_files"]["model_bundle_pack.joblib"] = None
    print("WARN: joblib pack skipped ->", joblib_error)

bundle_pack_json.write_text(json.dumps(bundle_pack, indent=2))

# ----------------------------
# 9) Create portable ZIP (writes to OUT_DIR)
# ----------------------------
zip_path = OUT_DIR / "model_bundle_v5_notebook3.zip"

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    # core artifacts; some sources may be None (e.g. no gate/mask model),
    # which safe_add presumably tolerates — the original passed them as-is too.
    safe_add(zf, final_gate_dst, "final_gate_model.pt")
    safe_add(zf, final_gate_bundle_dst, "final_gate_bundle.json")
    safe_add(zf, feature_cols_dst, "feature_cols.json")

    safe_add(zf, final_mask_dst, "final_mask_model.pt")
    safe_add(zf, final_mask_bundle_dst, "final_mask_bundle.json")
    safe_add(zf, thresholds_mask_dst, "thresholds_mask.json")

    safe_add(zf, thresholds_path, "thresholds.json")
    safe_add(zf, manifest_path, "model_bundle_manifest.json")
    safe_add(zf, bundle_pack_json, "model_bundle_pack.json")
    if joblib_ok:
        safe_add(zf, bundle_pack_joblib, "model_bundle_pack.joblib")

    # Optional sub-folders (extras, opt_search, oof), copied one level deep
    # under a fixed archive prefix. Entries are sorted so the archive layout
    # does not depend on the filesystem's directory-listing order.
    for src_dir, arc_prefix in (
        (extras_dir, "extras"),
        (opt_dir, "opt_search"),
        (oof_dir, "oof"),
    ):
        if not src_dir.exists():
            continue
        for p in sorted(src_dir.glob("*")):
            safe_add(zf, p, f"{arc_prefix}/{p.name}")

# Final human-readable report: where each artifact landed, then a short summary.
print("\nOK — Model bundle finalized (Notebook-3 compatible)")
for _label, _value in (
    ("  OUT_DIR          ->", OUT_DIR),
    ("  manifest         ->", manifest_path),
    ("  pack (json)      ->", bundle_pack_json),
    ("  pack (joblib)    ->", bundle_pack_joblib if joblib_ok else "(skip; joblib not available)"),
    ("  thresholds       ->", thresholds_path),
    ("  zip              ->", zip_path),
):
    print(_label, _value)

print("\nBundle summary:")
for _label, _value in (
    ("  bundle_version:", BUNDLE_VERSION),
    ("  model_format  :", model_format),
    ("  has_gate      :", final_gate_dst is not None),
    ("  has_mask      :", final_mask_dst is not None),
    ("  feature_cnt   :", len(gate_feature_cols) if gate_feature_cols else 0),
    ("  T_gate        :", thresholds.get("T_gate")),
    ("  T_mask        :", thresholds.get("T_mask")),
    ("  task          :", task_str),
):
    print(_label, _value)
# Keep the notebook namespace free of the loop temporaries.
del _label, _value
