# Set Paths & Select Config (CFG)

In [1]:
# ============================================================
# STAGE 0 — Set Paths & Select Config (CFG) (Kaggle-ready, offline)
# REVISI FULL v3.0 (lebih kuat + siap MULTI-CFG + lebih aman anti-error)
#
# Fokus upgrade v3.0 (sesuai strategi naik score):
# - Multi-CFG support: pilih TOP-K CFG kandidat (bukan cuma 1) untuk MATCH & PRED
#   -> enabling: stacking/selector di stage training (anti overfit satu CFG)
# - Scoring CFG lebih kaya:
#   * wajib: feat_train ada & rows > 0
#   * prefer: feat_test ada
#   * prefer: manifest_test ada, pred_summary.json ada
#   * prefer: ada folder npz test/train_all (untuk submission / mask features)
#   * tie-break: rows train/test + modified time
# - Cache/artifacts root independen:
#   * artifacts bisa dari input, cache bisa dari working (atau sebaliknya)
# - DINO cache cfg autodetect multi-backbone (large/giant/base) dari cache
# - Sanity guard lebih informatif, file opsional tidak bikin crash
#
# Output globals (TETAP, JANGAN diganti):
# - COMP_ROOT, OUT_DS_ROOT, OUT_ROOT
# - PATHS (dict)
# - MATCH_CFG_DIR, PRED_CFG_DIR, DINO_CFG_DIR
#
# Extra globals (aman, membantu training lanjutan):
# - MATCH_CFG_DIRS, PRED_CFG_DIRS (list[Path] TOP-K)
# - MATCH_CFG_INFO, PRED_CFG_INFO (list[dict] detail skor)
# - CACHE_ROOTS (list[Path]), SELECTED (dict), TRAIN_PLAN (dict)
# ============================================================

import os, re, json, time
from pathlib import Path
import pandas as pd

# ----------------------------
# Config knobs (boleh diubah)
# ----------------------------
# How many MATCH / PRED cfg candidates to keep (overridable through env vars).
TOPK_MATCH_CFGS = int(os.getenv("TOPK_MATCH_CFGS", "5"))
TOPK_PRED_CFGS = int(os.getenv("TOPK_PRED_CFGS", "8"))

# To force a specific input dataset, set before running:
#   export LUC_OUT_DS_ROOT="/kaggle/input/<output_dataset_name>"
ENV_OUT_DS_ROOT = os.getenv("LUC_OUT_DS_ROOT", "").strip()

# ----------------------------
# Helper: fast count CSV rows (binary newline count)
# ----------------------------
def _fast_count_rows_csv(path: Path, assume_header: bool = True) -> int:
    """
    Count rows quickly by counting '\n' in binary.
    Returns -1 if error.
    """
    try:
        if not path.exists() or not path.is_file():
            return -1
        # count newlines
        nl = 0
        with path.open("rb") as f:
            while True:
                b = f.read(1024 * 1024)
                if not b:
                    break
                nl += b.count(b"\n")
        # if file ends without newline, nl still ok for line count approximation
        # rows = lines - header
        rows = nl - (1 if assume_header else 0)
        return int(max(rows, 0))
    except Exception:
        return -1

def _safe_mtime(p: Path) -> float:
    try:
        return float(p.stat().st_mtime)
    except Exception:
        return 0.0

def _is_nonempty_file(p: Path) -> bool:
    try:
        return p is not None and p.exists() and p.is_file() and p.stat().st_size > 0
    except Exception:
        return False

def _dir_has_any_npz(d: Path) -> bool:
    try:
        if d is None or (not d.exists()) or (not d.is_dir()):
            return False
        # cepat: cek 1 file npz saja
        for _ in d.glob("*.npz"):
            return True
        return False
    except Exception:
        return False

# ----------------------------
# Helper: find competition root
# ----------------------------
def find_comp_root(preferred: str = "/kaggle/input/recodai-luc-scientific-image-forgery-detection") -> Path:
    """
    Locate the competition data root.

    Returns `preferred` when it exists. Otherwise scans /kaggle/input (first one
    level deep, then two levels deep) for a folder holding sample_submission.csv
    together with train_images/ or test_images/. Among matches, names containing
    "recodai" rank first, then names containing "forgery", then alphabetical.

    Raises:
        FileNotFoundError: if /kaggle/input is missing or no matching folder exists.
    """
    preferred_path = Path(preferred)
    if preferred_path.exists():
        return preferred_path

    input_root = Path("/kaggle/input")
    if not input_root.exists():
        raise FileNotFoundError("/kaggle/input tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    def _looks_like_comp(folder: Path) -> bool:
        if not (folder / "sample_submission.csv").exists():
            return False
        return (folder / "train_images").exists() or (folder / "test_images").exists()

    top_dirs = [entry for entry in input_root.iterdir() if entry.is_dir()]
    matches = [entry for entry in top_dirs if _looks_like_comp(entry)]
    if not matches:
        # one level deeper: /kaggle/input/<dataset>/<folder>
        matches = [
            child
            for entry in top_dirs
            for child in entry.iterdir()
            if child.is_dir() and _looks_like_comp(child)
        ]

    if not matches:
        raise FileNotFoundError(
            "COMP_ROOT tidak ditemukan. Harus ada folder yang memuat sample_submission.csv dan train_images/test_images."
        )

    def _rank(folder: Path):
        lowered = folder.name.lower()
        return ("recodai" not in lowered, "forgery" not in lowered, folder.name)

    # min() keeps the first of equal keys, exactly like sort()[0]
    return min(matches, key=_rank)

# ----------------------------
# Helper: find output dataset root (hasil PREP)
# ----------------------------
def find_output_dataset_root(preferred_names=(
    "recod-ailuc-dinov2-base",
    "recod-ai-luc-dinov2-base",
    "recodai-luc-dinov2-base",
    "recodai-luc-dinov2",
    "recodai-luc-dinov2-prep",
)) -> Path:
    """
    Locate the input dataset that holds the PREP outputs (<dataset>/recodai_luc/...).

    Resolution order:
      1. the module-level ENV_OUT_DS_ROOT (env var LUC_OUT_DS_ROOT) if that path
         exists; otherwise a warning is printed and the search continues;
      2. /kaggle/input/<name> for each name in `preferred_names` (first existing);
      3. any /kaggle/input/<d> that contains recodai_luc/artifacts or
         recodai_luc/cache, either directly or one level deeper
         (<d>/<x>/recodai_luc/...). Names containing "dinov2" rank first, then
         alphabetical order.

    Raises:
        FileNotFoundError: when none of the above yields a directory.
    """
    base = Path("/kaggle/input")

    # env override
    if ENV_OUT_DS_ROOT:
        p = Path(ENV_OUT_DS_ROOT)
        if p.exists():
            return p
        print(f"WARNING: ENV LUC_OUT_DS_ROOT tidak ditemukan: {p}")

    for nm in preferred_names:
        p = base / nm
        if p.exists():
            return p

    not_found_msg = "OUT_DS_ROOT tidak ditemukan. Harus ada /kaggle/input/<...>/recodai_luc/(artifacts|cache)/"
    # Same guard as find_comp_root: a clear error instead of iterdir() failing
    # on a missing /kaggle/input.
    if not base.is_dir():
        raise FileNotFoundError(not_found_msg)

    layout_parts = ("recodai_luc/artifacts", "recodai_luc/cache")
    cands = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        direct = any((d / part).exists() for part in layout_parts)
        # Nested layout: accept cache-only datasets too, as the error message promises.
        nested = any(next(d.glob(f"*/{part}"), None) is not None for part in layout_parts)
        if direct or nested:
            cands.append(d)

    if not cands:
        raise FileNotFoundError(not_found_msg)

    cands.sort(key=lambda x: (("dinov2" not in x.name.lower()), x.name))
    return cands[0]

# ----------------------------
# Helper: resolve OUT_ROOT = <dataset>/recodai_luc
# ----------------------------
def resolve_out_root(out_ds_root: Path) -> Path:
    """
    Return the `recodai_luc` directory of an output dataset.

    Looks for <out_ds_root>/recodai_luc first, then one level deeper
    (<out_ds_root>/<x>/recodai_luc). If several nested matches exist, the
    alphabetically first one is returned. glob() order depends on the
    filesystem, so taking the first unsorted hit was nondeterministic.

    Raises:
        FileNotFoundError: if no recodai_luc directory is found.
    """
    out_ds_root = Path(out_ds_root)
    direct = out_ds_root / "recodai_luc"
    if direct.is_dir():
        return direct
    hits = sorted(h for h in out_ds_root.glob("*/recodai_luc") if h.is_dir())
    if hits:
        return hits[0]
    raise FileNotFoundError(f"Folder recodai_luc tidak ditemukan di bawah {out_ds_root}")

# ----------------------------
# Helper: pick TOP-K cfg directories by multi-criteria scoring
# ----------------------------
def pick_top_cfgs(
    cache_roots,
    prefixes,
    required_train_file: str,
    prefer_files=(),
    extra_prefer_dirs=(),
    max_return: int = 5,
    dedupe_by_name: bool = True,
) -> list:
    """
    Rank cfg directories found directly under each cache root and return the best ones.

    A directory is a candidate when its name starts with one of `prefixes` and it
    contains `required_train_file` with at least one data row.

    Score (higher is better), strongest term first:
      * 1e6  per non-empty file from `prefer_files`
      * 2e5  per available sub-dir from `extra_prefer_dirs` ("test" / "train_all"
        only count when they contain at least one *.npz)
      * 1.0  per data row of the train file
      * 0.05 per data row, summed over the *.csv files in `prefer_files`
        (non-CSV files such as *.json count as hits only; their line count
        is not "rows")
      * 1e-6 * newest mtime among the dir, the train file and the prefer files

    Args:
        cache_roots: iterable of root directories (str or Path); missing roots are skipped.
        prefixes: a str or iterable of str accepted as directory-name prefixes.
        required_train_file: file name that must exist inside a candidate dir.
        prefer_files: optional file names that raise the score when present.
        extra_prefer_dirs: optional sub-directory names that raise the score.
        max_return: maximum number of candidates returned (negative -> 0).
        dedupe_by_name: if the same cfg directory name exists under several roots
            (e.g. working cache and input cache), keep only its best-scoring copy
            so TOP-K does not contain duplicate configs.

    Returns:
        List of candidate dicts sorted by score desc (ties: dir name asc), with keys
        dir, root, score, train_file, train_rows, prefer_hits, prefer_rows_sum,
        prefer_detail, dir_hits, dir_detail, mtime.
    """
    if isinstance(prefixes, str):
        prefixes = [prefixes]
    prefixes = tuple(prefixes)

    prefer_files = list(prefer_files or [])
    extra_prefer_dirs = list(extra_prefer_dirs or [])

    cands = []
    for root in cache_roots:
        root = Path(root)
        if not root.is_dir():
            continue

        for d in root.iterdir():
            if not d.is_dir() or not d.name.startswith(prefixes):
                continue

            train_fp = d / required_train_file
            if not train_fp.exists():
                continue

            train_n = _fast_count_rows_csv(train_fp)
            if train_n <= 0:
                continue

            # prefer files
            pref_hit = 0
            pref_rows_sum = 0
            pref_detail = {}
            for fn in prefer_files:
                fp = d / fn
                ok = _is_nonempty_file(fp)
                pref_detail[fn] = bool(ok)
                if ok:
                    pref_hit += 1
                    # only CSV files have meaningful "rows"
                    if str(fn).lower().endswith(".csv"):
                        pref_rows_sum += max(_fast_count_rows_csv(fp), 0)

            # prefer dirs (npz availability or other dirs)
            dir_hit = 0
            dir_detail = {}
            for dn in extra_prefer_dirs:
                dd = d / dn
                ok = dd.is_dir()
                # prediction dirs only count when they actually hold npz files
                if ok and dn in ("test", "train_all"):
                    ok = _dir_has_any_npz(dd)
                dir_detail[dn] = bool(ok)
                if ok:
                    dir_hit += 1

            # newest mtime among dir + train file + prefer files
            mt = max(
                [_safe_mtime(d), _safe_mtime(train_fp)]
                + [_safe_mtime(d / fn) for fn in prefer_files]
            )

            score = (
                1e6 * float(pref_hit)          # having test/manifest/summary files
                + 2e5 * float(dir_hit)         # having npz prediction dirs
                + 1.0 * float(train_n)         # size of train features
                + 0.05 * float(pref_rows_sum)  # size of preferred CSV files
                + 1e-6 * float(mt)             # newest
            )

            cands.append({
                "dir": d,
                "root": root,
                "score": score,
                "train_file": str(train_fp),
                "train_rows": int(train_n),
                "prefer_hits": int(pref_hit),
                "prefer_rows_sum": int(pref_rows_sum),
                "prefer_detail": pref_detail,
                "dir_hits": int(dir_hit),
                "dir_detail": dir_detail,
                "mtime": float(mt),
            })

    cands.sort(key=lambda x: (-x["score"], x["dir"].name))

    if dedupe_by_name:
        # list is sorted best-first, so the first copy of each name is the best one
        seen = set()
        unique = []
        for c in cands:
            name = c["dir"].name
            if name in seen:
                continue
            seen.add(name)
            unique.append(c)
        cands = unique

    return cands[:max(int(max_return), 0)]

# ----------------------------
# Helper: detect DINO model dir (offline)
# ----------------------------
def detect_dino_dir() -> Path:
    """
    Return the offline DINOv2 weights directory of the standard Kaggle dataset.

    Checks /kaggle/input/dinov2/pytorch/{large,giant,base}/1 in that order and
    returns the first existing one. If none exists, returns the 'large' path
    anyway, so the caller can print a warning.
    """
    dinov2_root = Path("/kaggle/input/dinov2/pytorch")
    candidates = (
        [dinov2_root / size / "1" for size in ("large", "giant", "base")]
        if dinov2_root.exists()
        else []
    )
    found = next((c for c in candidates if c.exists()), None)
    if found is None:
        return Path("/kaggle/input/dinov2/pytorch/large/1")
    return found

def detect_dino_cache_cfg(cache_dirs: list) -> Path:
    """
    Pick a DINO feature-cache cfg: <cache>/dino_v2_<size>/cfg_*/manifest_train_all.csv.

    Candidates are compared with a strict lexicographic key:
      1. backbone priority: dino_v2_large, then dino_v2_giant, then dino_v2_base;
      2. more manifest rows;
      3. newer cfg directory mtime;
      4. cfg directory name (ascending).
    A bigger manifest therefore never beats a higher-priority backbone.
    On a full tie, the first candidate found (cache_dirs order) is kept.

    Returns:
        The chosen cfg directory, or None when no candidate exists.
    """
    # tuple index == priority (lower is better)
    backbones = ("dino_v2_large", "dino_v2_giant", "dino_v2_base")

    best = None
    best_key = None
    for root in cache_dirs:
        root = Path(root)
        for prio, dino_name in enumerate(backbones):
            dino_root = root / dino_name
            # is_dir(), not exists(): iterdir() on a stray file would raise
            if not dino_root.is_dir():
                continue

            for cfg in dino_root.iterdir():
                if not (cfg.is_dir() and cfg.name.startswith("cfg_")):
                    continue
                mf = cfg / "manifest_train_all.csv"
                if not mf.exists():
                    continue

                n = _fast_count_rows_csv(mf)
                mt = _safe_mtime(cfg)
                key = (prio, -n, -mt, cfg.name)
                if best_key is None or key < best_key:
                    best, best_key = cfg, key

    return best  # may be None

# ============================================================
# 0) Locate roots
# ============================================================
COMP_ROOT = find_comp_root("/kaggle/input/recodai-luc-scientific-image-forgery-detection")
OUT_DS_ROOT = find_output_dataset_root()
OUT_ROOT = resolve_out_root(OUT_DS_ROOT)  # input dataset: .../recodai_luc

WORK_OUT_ROOT = Path("/kaggle/working/recodai_luc")

# Search order: /kaggle/working first, then the input dataset.
# Only folders that exist are kept.
CACHE_ROOTS = [
    base / "cache" for base in (WORK_OUT_ROOT, OUT_ROOT) if (base / "cache").exists()
]
ART_ROOTS = [
    base / "artifacts" for base in (WORK_OUT_ROOT, OUT_ROOT) if (base / "artifacts").exists()
]

# The first existing artifact root wins.
ART_DIR = next((cand for cand in ART_ROOTS if cand.exists()), None)
if ART_DIR is None:
    raise FileNotFoundError("ART_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# cache dirs that exist
CACHE_DIRS = [p for p in CACHE_ROOTS if Path(p).exists()]
if not CACHE_DIRS:
    raise FileNotFoundError("CACHE_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# ============================================================
# 1) Competition paths (raw images/masks)
# ============================================================
PATHS = {
    "COMP_ROOT": str(COMP_ROOT),
    "SAMPLE_SUB": str(COMP_ROOT / "sample_submission.csv"),
    "TRAIN_IMAGES": str(COMP_ROOT / "train_images"),
    "TEST_IMAGES": str(COMP_ROOT / "test_images"),
    "TRAIN_MASKS": str(COMP_ROOT / "train_masks"),
    "SUPP_IMAGES": str(COMP_ROOT / "supplemental_images"),
    "SUPP_MASKS": str(COMP_ROOT / "supplemental_masks"),
    "TRAIN_AUTH_DIR": str(COMP_ROOT / "train_images" / "authentic"),
    "TRAIN_FORG_DIR": str(COMP_ROOT / "train_images" / "forged"),
}

# ============================================================
# 2) Output dataset paths (clean artifacts + cache)
# ============================================================
PATHS.update({
    "OUT_DS_ROOT": str(OUT_DS_ROOT),
    "OUT_ROOT": str(OUT_ROOT),
    "ART_DIR": str(ART_DIR),
    "CACHE_DIRS": [str(x) for x in CACHE_DIRS],  # list
    # main artifacts
    "DF_TRAIN_ALL": str(Path(ART_DIR) / "df_train_all.parquet"),
    "DF_TRAIN_CLS": str(Path(ART_DIR) / "df_train_cls.parquet"),
    "DF_TRAIN_SEG": str(Path(ART_DIR) / "df_train_seg.parquet"),
    "DF_TEST": str(Path(ART_DIR) / "df_test.parquet"),
    "CV_CASE_FOLDS": str(Path(ART_DIR) / "cv_case_folds.csv"),
    "CV_SAMPLE_FOLDS": str(Path(ART_DIR) / "cv_sample_folds.csv"),
    "IMG_PROFILE_TRAIN": str(Path(ART_DIR) / "image_profile_train.parquet"),
    "IMG_PROFILE_TEST": str(Path(ART_DIR) / "image_profile_test.parquet"),
    "MASK_PROFILE": str(Path(ART_DIR) / "mask_profile.parquet"),
    "CASE_SUMMARY": str(Path(ART_DIR) / "case_summary.parquet"),
})

# ============================================================
# 3) Select best MATCH/PRED CFG dirs automatically (TOP-K + scoring)
# ============================================================
# MATCH candidates require match_features_train_all.csv.
MATCH_CFG_INFO = pick_top_cfgs(
    CACHE_DIRS,
    prefixes=["match_base_cfg_"],
    required_train_file="match_features_train_all.csv",
    prefer_files=[
        "match_features_test.csv",
        "manifest_match_test.csv",
        "manifest_match_train_all.csv",
    ],
    extra_prefer_dirs=[],  # match cfgs usually have no npz dirs
    max_return=max(1, TOPK_MATCH_CFGS),
)
if not MATCH_CFG_INFO:
    raise FileNotFoundError("Tidak menemukan match cfg folder yang valid (match_features_train_all.csv).")

MATCH_CFG_DIRS = [info["dir"] for info in MATCH_CFG_INFO]
MATCH_CFG_DIR = MATCH_CFG_DIRS[0]  # primary (kept for later-stage compatibility)

# PRED candidates require pred_features_train_all.csv. Test features,
# manifests, summary and npz prediction dirs are preferred.
PRED_CFG_INFO = pick_top_cfgs(
    CACHE_DIRS,
    prefixes=["pred_base"],
    required_train_file="pred_features_train_all.csv",
    prefer_files=[
        "pred_features_test.csv",
        "manifest_pred_test.csv",
        "manifest_pred_train_all.csv",
        "pred_summary.json",
    ],
    extra_prefer_dirs=["test", "train_all"],
    max_return=max(1, TOPK_PRED_CFGS),
)
if not PRED_CFG_INFO:
    raise FileNotFoundError("Tidak menemukan pred cfg folder yang valid (pred_features_train_all.csv).")

PRED_CFG_DIRS = [info["dir"] for info in PRED_CFG_INFO]
PRED_CFG_DIR = PRED_CFG_DIRS[0]  # primary (kept for later-stage compatibility)

# DINO cache cfg (optional, may be None)
DINO_CFG_DIR = detect_dino_cache_cfg(CACHE_DIRS)

PATHS.update({
    # selected cfg dirs
    "MATCH_CFG_DIR": str(MATCH_CFG_DIR),
    "PRED_CFG_DIR": str(PRED_CFG_DIR),
    "DINO_CFG_DIR": str(DINO_CFG_DIR) if DINO_CFG_DIR else "",
    # multi-cfg lists
    "MATCH_CFG_DIRS": [str(x) for x in MATCH_CFG_DIRS],
    "PRED_CFG_DIRS": [str(x) for x in PRED_CFG_DIRS],
    # feature files of the primary cfgs (test files may be missing -> warning later)
    "MATCH_FEAT_TRAIN": str(MATCH_CFG_DIR / "match_features_train_all.csv"),
    "MATCH_FEAT_TEST": str(MATCH_CFG_DIR / "match_features_test.csv"),
    "PRED_FEAT_TRAIN": str(PRED_CFG_DIR / "pred_features_train_all.csv"),
    "PRED_FEAT_TEST": str(PRED_CFG_DIR / "pred_features_test.csv"),
    # pred manifests (used for inference / submission)
    "PRED_MAN_TRAIN": str(PRED_CFG_DIR / "manifest_pred_train_all.csv"),
    "PRED_MAN_TEST": str(PRED_CFG_DIR / "manifest_pred_test.csv"),
    "PRED_SUMMARY": str(PRED_CFG_DIR / "pred_summary.json"),
    # match manifests (if present)
    "MATCH_MAN_TRAIN": str(MATCH_CFG_DIR / "manifest_match_train_all.csv"),
    "MATCH_MAN_TEST": str(MATCH_CFG_DIR / "manifest_match_test.csv"),
})

# ============================================================
# 4) DINO model dir (offline)
# ============================================================
DINO_DIR = detect_dino_dir()
PATHS["DINO_DIR"] = str(DINO_DIR)

# ============================================================
# 5) Sanity checks (required for training)
# ============================================================
must_exist = [
    ("sample_submission.csv", PATHS["SAMPLE_SUB"]),
    ("df_train_all.parquet", PATHS["DF_TRAIN_ALL"]),
    ("cv_case_folds.csv", PATHS["CV_CASE_FOLDS"]),
    ("match_features_train_all.csv", PATHS["MATCH_FEAT_TRAIN"]),
    ("pred_features_train_all.csv", PATHS["PRED_FEAT_TRAIN"]),
]
missing = [label for label, fpath in must_exist if not Path(fpath).exists()]
if missing:
    raise FileNotFoundError("Missing required files: " + ", ".join(missing))

# Optional for training, but usually needed for test-time inference.
opt_warn = [
    desc
    for key, desc in (
        ("MATCH_FEAT_TEST", "match_features_test.csv (MATCH_FEAT_TEST)"),
        ("PRED_FEAT_TEST", "pred_features_test.csv (PRED_FEAT_TEST)"),
    )
    if not Path(PATHS[key]).exists()
]
if opt_warn:
    print("WARNING: File opsional untuk inference belum ada:")
    print("\n".join(f" - {item}" for item in opt_warn))
    print("Catatan: training masih aman (pakai *_train_all). Untuk inference gate ke test, file ini biasanya dibutuhkan.")

# DINO model dir is optional: warn only
if not DINO_DIR.exists():
    print(f"WARNING: DINO dir tidak ditemukan: {PATHS['DINO_DIR']}")

# ============================================================
# 6) Print summary + export helpers
# ============================================================
SELECTED = dict(
    ART_DIR=str(ART_DIR),
    CACHE_DIRS=[str(x) for x in CACHE_DIRS],
    MATCH_CFG_DIR=str(MATCH_CFG_DIR),
    PRED_CFG_DIR=str(PRED_CFG_DIR),
    DINO_CFG_DIR=str(DINO_CFG_DIR) if DINO_CFG_DIR else "",
    DINO_DIR=str(DINO_DIR),
    TOPK_MATCH_CFGS=TOPK_MATCH_CFGS,
    TOPK_PRED_CFGS=TOPK_PRED_CFGS,
)

# Training plan consumed by later stages (stacking + calibration + threshold tuning).
TRAIN_PLAN = dict(
    seed=2025,
    group_col="case_id",
    target_col="y_forged",
    n_folds=5,
    use_calibration=True,
    calibration="isotonic",  # calibration method for OOF probabilities
    tune_threshold_on_oof=True,
    multi_cfg=dict(
        enabled=True,
        topk_match_cfgs=TOPK_MATCH_CFGS,
        topk_pred_cfgs=TOPK_PRED_CFGS,
        primary_match=MATCH_CFG_DIR.name,
        primary_pred=PRED_CFG_DIR.name,
    ),
)

def _print_top_cfg_table(title: str, info_list: list, max_rows: int = 5):
    print(title)
    if not info_list:
        print("  (none)")
        return
    show = info_list[:max_rows]
    for i, x in enumerate(show, 1):
        d = x["dir"]
        ph = x["prefer_hits"]
        dh = x["dir_hits"]
        tr = x["train_rows"]
        print(f"  #{i:02d} {d.name} | score={x['score']:.1f} | train_rows={tr} | prefer_hits={ph} | dir_hits={dh}")

print("OK — Roots")
print(f"  COMP_ROOT   : {COMP_ROOT}")
print(f"  OUT_DS_ROOT : {OUT_DS_ROOT}")
print(f"  OUT_ROOT    : {OUT_ROOT}")
print(f"  ART_DIR(use): {ART_DIR}")
print(f"  CACHE_DIRS  : {[str(x) for x in CACHE_DIRS]}")

print("\nOK — Selected CFG (PRIMARY)")
print(f"  MATCH_CFG_DIR: {MATCH_CFG_DIR.name}")
print(f"  PRED_CFG_DIR : {PRED_CFG_DIR.name}")
print(f"  DINO_CFG_DIR : {DINO_CFG_DIR.name if DINO_CFG_DIR else '(not found / optional)'}")

print("\nOK — Top candidates (for MULTI-CFG training)")
# the table helper slices to max_rows itself, so 5 == min(5, len(...))
_print_top_cfg_table("MATCH CFG TOP:", MATCH_CFG_INFO, max_rows=5)
_print_top_cfg_table("PRED  CFG TOP:", PRED_CFG_INFO, max_rows=5)

print("\nOK — Key files (train)")
for k in ("DF_TRAIN_ALL", "CV_CASE_FOLDS", "MATCH_FEAT_TRAIN", "PRED_FEAT_TRAIN", "IMG_PROFILE_TRAIN"):
    p = Path(PATHS[k])
    print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing/optional)'}")

print("\nOK — Key files (test/infer, optional)")
for k in ("MATCH_FEAT_TEST", "PRED_FEAT_TEST", "PRED_MAN_TEST", "PRED_SUMMARY"):
    p = Path(PATHS[k])
    print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing)'}")

print("\nOK — DINO model dir")
print(f"  DINO_DIR: {DINO_DIR} {'(exists)' if DINO_DIR.exists() else '(missing)'}")

# Compatibility export. MATCH_CFG_DIR, PRED_CFG_DIR, DINO_CFG_DIR, SELECTED,
# TRAIN_PLAN, MATCH/PRED_CFG_DIRS and MATCH/PRED_CFG_INFO are already
# module-level globals; only CACHE_ROOTS is rebound, to the existing cache
# dirs as Path objects.
CACHE_ROOTS = [Path(x) for x in CACHE_DIRS]


OK — Roots
  COMP_ROOT   : /kaggle/input/recodai-luc-scientific-image-forgery-detection
  OUT_DS_ROOT : /kaggle/input/recod-ailuc-dinov2-base
  OUT_ROOT    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc
  ART_DIR(use): /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts
  CACHE_DIRS  : ['/kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache']

OK — Selected CFG (PRIMARY)
  MATCH_CFG_DIR: match_base_cfg_f9f7ea3a65c5
  PRED_CFG_DIR : pred_base_v3_v7_cfg_5dbf0aa165
  DINO_CFG_DIR : cfg_3246fd54aab0

OK — Top candidates (for MULTI-CFG training)
MATCH CFG TOP:
  #01 match_base_cfg_f9f7ea3a65c5 | score=3007202.8 | train_rows=5176 | prefer_hits=3 | dir_hits=0
PRED  CFG TOP:
  #01 pred_base_v3_v7_cfg_5dbf0aa165 | score=4407206.0 | train_rows=5176 | prefer_hits=4 | dir_hits=2

OK — Key files (train)
  DF_TRAIN_ALL    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet  (exists)
  CV_CASE_FOLDS   : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/ar

# Build Training Table (X, y, folds)

In [2]:
# ============================================================
# STEP 2 — Build Training Table (X, y, folds) — REVISI FULL v3.0 (MAX-UPGRADE, robust, anti-error)
#
# Upgrade utama v3.0 (sesuai strategi naik score):
# - MULTI-CFG support (dari STAGE 0 v3.0):
#     * Primary CFG: load FULL pred_features + (opsional) FULL match_features
#     * Extra CFGs: load CORE columns saja + buat agregasi row-wise (min/max/mean) -> stabil & kuat
# - Label hygiene:
#     * Drop unlabeled (y not in {0,1}) -> menghindari noise supplemental y=-1
# - Robust parquet read (safe columns)
# - Feature engineering lebih kuat + aman:
#     * missing indicators
#     * clipping caps
#     * logabs transforms
#     * interactions (lebih lengkap)
# - Output tetap kompatibel:
#   globals: df_train_tabular, FEATURE_COLS, X_train, y_train, folds
#   save: feature_cols.json, feature_schema.json, df_train_tabular.parquet
# ============================================================

import os, json, gc, warnings
from pathlib import Path
import numpy as np
import pandas as pd

# NOTE(review): this silences *all* warnings for the rest of the session
# (pandas FutureWarning etc.); kept as-is for notebook readability.
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require PATHS (built by STAGE 0)
# ----------------------------
if not isinstance(globals().get("PATHS"), dict):
    raise RuntimeError("Missing PATHS. Jalankan dulu STAGE 0 — Set Paths & Select Config (CFG).")

# ----------------------------
# 1) Feature Engineering Config
# ----------------------------
FE_CFG = dict(
    # --- sources ---
    use_match_features=True,
    use_image_profile=True,

    # --- multi-CFG (needs STAGE 0 v3: PATHS['PRED_CFG_DIRS'] / PATHS['MATCH_CFG_DIRS']) ---
    multi_cfg_enabled=True,
    multi_cfg_max_pred=int(os.environ.get("MULTI_CFG_MAX_PRED", "6")),    # top-k pred cfgs (primary + extra)
    multi_cfg_max_match=int(os.environ.get("MULTI_CFG_MAX_MATCH", "3")),  # top-k match cfgs (primary + extra)
    multi_cfg_extra_mode="core+agg",  # "core" | "core+agg"

    # --- variant encoding ---
    encode_variant_onehot=True,
    variant_min_count=1,

    # --- transforms ---
    add_log_features=True,
    add_sqrt_features=False,
    add_interactions=True,
    add_missing_indicators=True,

    # --- outlier control ---
    clip_by_quantile=True,
    clip_q=0.999,
    clip_max_fallback=1e9,

    # --- fill ---
    fillna_value=0.0,

    # --- prune ---
    drop_constant_features=True,

    # --- dtype ---
    cast_float32=True,

    # --- label handling ---
    drop_unlabeled=True,  # drop rows whose y is not in {0, 1}
    positive_value=1,
)

# ----------------------------
# 2) Prefer WORKING features if exist (kalau regen di /kaggle/working)
# ----------------------------
def _prefer_existing(*paths):
    for p in paths:
        if p is None:
            continue
        p = Path(str(p))
        if p.exists():
            return p
    return Path(str(paths[0])) if paths and paths[0] is not None else Path("")

def _to_str_series(s: pd.Series) -> pd.Series:
    return s.astype(str).replace({"nan": "", "None": ""})

def _ensure_uid(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` that has a string `uid` column.

    If `uid` is absent, the first of `sample_id`, `id`, `key` found is renamed to
    `uid`. Values are normalised with _to_str_series.

    Raises:
        ValueError: if none of the accepted columns is present.
    """
    out = df.copy()
    if "uid" not in out.columns:
        alias = next((name for name in ("sample_id", "id", "key") if name in out.columns), None)
        if alias is None:
            raise ValueError("Cannot find uid column. Expected 'uid' or 'sample_id'.")
        out = out.rename(columns={alias: "uid"})
    out["uid"] = _to_str_series(out["uid"])
    return out

def _parse_case_variant_from_uid(uid_s: pd.Series) -> pd.DataFrame:
    """
    Parse `case_id` and `variant` out of uid strings.

    The "<digits>__<variant>" form is tried first; the "<digits>_<word>" form is
    the fallback. Unparsed case ids stay NaN and unparsed variants become "unk".
    Returns a DataFrame with columns case_id (str/NaN) and variant (str),
    aligned on the input index.
    """
    uid = _to_str_series(uid_s)

    case_double = uid.str.extract(r"^(\d+)__")[0]
    case_single = uid.str.extract(r"^(\d+)_")[0]
    variant_double = uid.str.extract(r"__(.+)$")[0]
    variant_single = uid.str.extract(r"_(\w+)$")[0]

    return pd.DataFrame({
        "case_id": case_double.fillna(case_single),
        "variant": variant_double.fillna(variant_single).fillna("unk"),
    })

def _ensure_case_variant(df: pd.DataFrame, df_base_map: pd.DataFrame = None) -> pd.DataFrame:
    """
    Return a copy of `df` with a string `uid`, nullable-int `case_id` and string `variant`.

    Gaps are filled in two stages:
      1. from `df_base_map` (left-joined on `uid`) when it has uid/case_id/variant;
      2. by parsing the uid itself (_parse_case_variant_from_uid).
    A variant counts as missing when it is empty or one of the placeholders
    "unk" / "nan" / "None", so both stages may overwrite the placeholder.
    The base map is reduced to one row per uid before joining, so the join
    never changes the number of rows of `df`.
    """
    df = _ensure_uid(df)

    def _variant_missing(v: pd.Series) -> pd.Series:
        # "unk" is this module's placeholder for an unknown variant -> fillable
        return v.isna() | v.astype(str).isin(["", "unk", "nan", "None"])

    if "case_id" in df.columns:
        df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
    if "variant" in df.columns:
        df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})

    # 1) fill from base map (preferred source)
    if df_base_map is not None and {"uid", "case_id", "variant"}.issubset(df_base_map.columns):
        need_merge = (
            "case_id" not in df.columns
            or "variant" not in df.columns
            or df["case_id"].isna().any()
            or _variant_missing(df["variant"]).any()
        )
        if need_merge:
            base = df_base_map[["uid", "case_id", "variant"]].copy()
            # same uid normalisation as df, otherwise keys may not match / dtypes clash
            base["uid"] = _to_str_series(base["uid"])
            # a duplicated uid in the map would multiply rows in the left join
            base = base.drop_duplicates(subset="uid", keep="first")
            df = df.merge(base, on="uid", how="left", suffixes=("", "_base"))
            if "case_id_base" in df.columns:
                df["case_id"] = df["case_id"].fillna(pd.to_numeric(df["case_id_base"], errors="coerce"))
                df = df.drop(columns=["case_id_base"])
            if "variant_base" in df.columns:
                df["variant"] = df["variant"].where(~_variant_missing(df["variant"]), df["variant_base"])
                df = df.drop(columns=["variant_base"])

    # 2) fill what is still missing by parsing the uid
    if (
        "case_id" not in df.columns
        or "variant" not in df.columns
        or df["case_id"].isna().any()
        or _variant_missing(df["variant"]).any()
    ):
        pv = _parse_case_variant_from_uid(df["uid"])
        parsed_case = pd.to_numeric(pv["case_id"], errors="coerce")
        if "case_id" not in df.columns:
            df["case_id"] = parsed_case
        else:
            df["case_id"] = df["case_id"].fillna(parsed_case)
        if "variant" not in df.columns:
            df["variant"] = pv["variant"]
        else:
            v = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
            df["variant"] = v.where(~_variant_missing(v), pv["variant"])

    df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce").astype("Int64")
    df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
    return df

def _pick_label_col(df: pd.DataFrame) -> str:
    for cand in ["y_forged", "has_mask", "is_forged", "forged"]:
        if cand in df.columns:
            return cand
    return ""

# parquet safe columns
def _read_parquet_cols_safe(path: Path, desired_cols: list) -> pd.DataFrame:
    """Read a Parquet file restricted to whichever of ``desired_cols`` it has.

    The file schema is inspected first (pyarrow), so only existing columns are
    requested. If none of ``desired_cols`` exist, the whole file is returned.
    Any failure on that path (e.g. pyarrow unavailable) falls back to a full
    read followed by an in-memory column subset.
    """
    src = Path(path)
    try:
        import pyarrow.parquet as pq
        available = set(pq.ParquetFile(src).schema.names)
        subset = [name for name in desired_cols if name in available]
        if subset:
            return pd.read_parquet(src, columns=subset)
        return pd.read_parquet(src)
    except Exception:
        # fallback: full read, then keep the requested columns that exist
        frame = pd.read_parquet(src)
        subset = [name for name in desired_cols if name in frame.columns]
        if subset:
            return frame[subset].copy()
        return frame

# ----------------------------
# Resolve primary paths
# ----------------------------
# Primary PRED/MATCH feature CSVs: prefer a fresh copy in the working cache
# (looked up by CFG directory *name*), otherwise the path recorded in PATHS.
match_cfg_name = Path(PATHS.get("MATCH_CFG_DIR", "")).name if PATHS.get("MATCH_CFG_DIR") else ""
pred_cfg_name  = Path(PATHS.get("PRED_CFG_DIR", "")).name  if PATHS.get("PRED_CFG_DIR") else ""

WORK_CACHE_ROOT = Path("/kaggle/working/recodai_luc/cache")
match_feat_work = (WORK_CACHE_ROOT / match_cfg_name / "match_features_train_all.csv") if match_cfg_name else None
pred_feat_work  = (WORK_CACHE_ROOT / pred_cfg_name  / "pred_features_train_all.csv")  if pred_cfg_name  else None

PRED_FEAT_TRAIN  = _prefer_existing(pred_feat_work,  PATHS.get("PRED_FEAT_TRAIN", ""))
MATCH_FEAT_TRAIN = _prefer_existing(match_feat_work, PATHS.get("MATCH_FEAT_TRAIN", ""))

DF_TRAIN_ALL      = Path(PATHS["DF_TRAIN_ALL"])
CV_CASE_FOLDS     = Path(PATHS["CV_CASE_FOLDS"])
IMG_PROFILE_TRAIN = Path(PATHS.get("IMG_PROFILE_TRAIN", ""))


def _is_real_path(p) -> bool:
    """Return True if ``p`` is a non-empty path that exists on disk.

    ``Path("")`` normalises to ``Path(".")``, the current directory, which
    always ``exists()``. A plain ``Path(p).exists()`` would therefore treat an
    unset/empty PATHS entry as present, and the problem would only surface
    later as a confusing read error.
    """
    if p is None:
        return False
    s = str(p).strip()
    return s not in ("", ".") and Path(p).exists()


for need_name, need_path in [
    ("df_train_all.parquet", DF_TRAIN_ALL),
    ("cv_case_folds.csv", CV_CASE_FOLDS),
    ("pred_features_train_all.csv", PRED_FEAT_TRAIN),
]:
    if not _is_real_path(need_path):
        raise FileNotFoundError(f"Missing required file: {need_name} -> {need_path}")

print("Using (PRIMARY):")
print("  DF_TRAIN_ALL     :", DF_TRAIN_ALL)
print("  CV_CASE_FOLDS    :", CV_CASE_FOLDS)
print("  PRED_FEAT_TRAIN  :", PRED_FEAT_TRAIN)
print("  MATCH_FEAT_TRAIN :", MATCH_FEAT_TRAIN, "(optional)" if _is_real_path(MATCH_FEAT_TRAIN) else "(missing/skip)")
print("  IMG_PROFILE_TRAIN:", IMG_PROFILE_TRAIN, "(optional)" if _is_real_path(IMG_PROFILE_TRAIN) else "(missing/skip)")

# ----------------------------
# 3) Load minimal inputs
# ----------------------------
# df_train_all: read only the minimal columns (absent ones are skipped safely)
base_cols_want = ["sample_id", "uid", "case_id", "variant", "y_forged", "has_mask", "is_forged", "forged"]
df_base = _read_parquet_cols_safe(DF_TRAIN_ALL, base_cols_want)

df_cv   = pd.read_csv(CV_CASE_FOLDS)

# primary pred features (full)
df_pred_primary = pd.read_csv(PRED_FEAT_TRAIN, low_memory=False)

# Optional inputs use is_file(), not exists(): an unset path becomes
# Path(".") (the cwd), which exists() and would be handed to the readers.

# primary match features (optional, full)
df_match_primary = None
if FE_CFG["use_match_features"] and MATCH_FEAT_TRAIN and Path(MATCH_FEAT_TRAIN).is_file():
    try:
        df_match_primary = pd.read_csv(MATCH_FEAT_TRAIN, low_memory=False)
    except Exception as e:
        # best-effort optional input: continue without it, but say why
        print(f"WARNING: could not read MATCH_FEAT_TRAIN ({MATCH_FEAT_TRAIN}): {e}")
        df_match_primary = None

# image profile (optional)
df_prof = None
if FE_CFG["use_image_profile"] and IMG_PROFILE_TRAIN and Path(IMG_PROFILE_TRAIN).is_file():
    try:
        df_prof = pd.read_parquet(IMG_PROFILE_TRAIN)
    except Exception as e:
        # best-effort optional input: continue without it, but say why
        print(f"WARNING: could not read IMG_PROFILE_TRAIN ({IMG_PROFILE_TRAIN}): {e}")
        df_prof = None

# ----------------------------
# 4) Prepare base mapping from df_train_all
# ----------------------------
# Normalise df_train_all into a one-row-per-uid lookup (df_base_map). Later
# steps left-merge it for the label and for case_id/variant.
df_base = df_base.copy()
if "uid" not in df_base.columns:
    # Derive uid: use sample_id as-is, else build "<case_id>__<variant>".
    if "sample_id" in df_base.columns:
        df_base = df_base.rename(columns={"sample_id": "uid"})
    elif ("case_id" in df_base.columns and "variant" in df_base.columns):
        df_base["uid"] = _to_str_series(df_base["case_id"]) + "__" + _to_str_series(df_base["variant"])

# NOTE(review): _ensure_uid is defined outside this view. The test-table cell
# further down redefines a global with the same name, so confirm which
# version is in scope if this cell is re-run after that one.
df_base = _ensure_uid(df_base)

# Canonical dtypes: nullable Int64 case_id; string variant with "nan"/"None" -> "unk".
if "case_id" in df_base.columns:
    df_base["case_id"] = pd.to_numeric(df_base["case_id"], errors="coerce").astype("Int64")
if "variant" in df_base.columns:
    df_base["variant"] = df_base["variant"].astype(str).replace({"nan": "unk", "None": "unk"})

# The first of y_forged / has_mask / is_forged / forged that exists is the label source.
label_col = _pick_label_col(df_base)
if not label_col:
    raise ValueError("Cannot find label column in df_train_all (y_forged/has_mask/is_forged/forged).")

# One row per uid (first occurrence wins) so merges against it cannot fan out.
df_base_map = df_base.drop_duplicates(subset=["uid"], keep="first").copy()

# ----------------------------
# 5) Prepare folds
# ----------------------------
# Folds are keyed by case_id (one fold per case). Every row of a case
# therefore lands in the same fold when merged below.
if "case_id" not in df_cv.columns or "fold" not in df_cv.columns:
    raise ValueError("cv_case_folds.csv must contain columns: case_id, fold")

df_cv = df_cv[["case_id", "fold"]].copy()
df_cv["case_id"] = pd.to_numeric(df_cv["case_id"], errors="coerce").astype("Int64")
df_cv["fold"]    = pd.to_numeric(df_cv["fold"], errors="coerce").astype("Int64")
# Drop rows where either value failed to parse, cast to plain int, and keep
# the first fold listed for each case_id.
df_cv = df_cv.dropna().astype({"case_id": int, "fold": int}).drop_duplicates("case_id")

# ----------------------------
# 6) Start from PRIMARY pred features (1 row per uid)
# ----------------------------
# Fill/normalise case_id and variant (from df_base_map, or parsed from uid),
# then keep the first row per uid so later left-merges cannot duplicate rows.
df_pred_primary = _ensure_case_variant(df_pred_primary, df_base_map=df_base_map)
if df_pred_primary["uid"].duplicated().any():
    df_pred_primary = df_pred_primary.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

df_train = df_pred_primary.copy()

# attach label from base map (the most reliable source), renamed to "y".
# NOTE(review): if the pred features already carry a column named "y", this
# merge yields y_x/y_y and the df_train["y"] accesses below fail — confirm.
df_train = df_train.merge(
    df_base_map[["uid", label_col]].rename(columns={label_col: "y"}),
    on="uid", how="left"
)

# Every pred-feature uid must exist in df_train_all. A miss means the two
# tables are misaligned, so fail loudly instead of training on partial labels.
if df_train["y"].isna().any():
    miss = int(df_train["y"].isna().sum())
    raise ValueError(f"Label merge produced NaN in y: {miss} rows. Check df_train_all vs pred_features alignment.")

# Non-numeric label values become NaN here.
df_train["y"] = pd.to_numeric(df_train["y"], errors="coerce")

# drop unlabeled rows, i.e. y not in {0,1} (recommended)
if FE_CFG["drop_unlabeled"]:
    before = len(df_train)
    df_train = df_train[df_train["y"].isin([0, 1])].copy()
    after = len(df_train)
    if before != after:
        print(f"NOTE: Dropped unlabeled rows (y not in {{0,1}}): {before-after} rows")

# Raises if any NaN survived (e.g. drop_unlabeled disabled with non-numeric labels).
df_train["y"] = df_train["y"].astype(int)

# attach folds by case_id (any pre-existing "fold" column is replaced); every
# row must get a fold, otherwise the CV split would silently lose rows.
df_train = df_train.drop(columns=["fold"], errors="ignore").merge(df_cv, on="case_id", how="left")
if df_train["fold"].isna().any():
    miss = int(df_train["fold"].isna().sum())
    raise ValueError(f"Missing fold after merging cv_case_folds.csv: {miss} rows.")
df_train["fold"] = df_train["fold"].astype(int)

# ----------------------------
# 7) Optional merge PRIMARY match features (new cols only)
# ----------------------------
# Columns already in df_train (from the pred features) win. Only match-only
# columns are added, joined by uid.
if df_match_primary is not None:
    dfm = _ensure_case_variant(df_match_primary, df_base_map=df_base_map)
    # one row per uid so the left merge cannot duplicate training rows
    if dfm["uid"].duplicated().any():
        dfm = dfm.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

    base_cols = set(df_train.columns)
    new_cols = [c for c in dfm.columns if c not in base_cols and c not in ["case_id", "variant"]]
    if new_cols:
        df_train = df_train.merge(dfm[["uid"] + new_cols], on="uid", how="left")

# ----------------------------
# 8) MULTI-CFG: load EXTRA core features from additional CFG dirs (PRED + MATCH)
# ----------------------------
def _short_cfg_tag(cfg_dir: Path, idx: int) -> str:
    """Build a short, column-name-safe tag for a CFG directory.

    The tag is ``"<idx:02d>_<tag>"``, where ``<tag>`` is the ``cfg_<hex>``
    hash found in the directory name if there is one. Otherwise it is the
    name with ``pred_base_``/``match_base_cfg_`` shortened to ``p_``/``m_``.
    Characters outside ``[0-9a-zA-Z_]`` become ``_``, and the tag is cut to
    24 characters to keep suffixed column names short.
    """
    dir_name = cfg_dir.name
    hash_hit = re.search(r"(cfg_[0-9a-f]{6,})", dir_name)
    if hash_hit is None:
        short = dir_name
    else:
        short = hash_hit.group(1)
    for long_prefix, short_prefix in (("pred_base_", "p_"), ("match_base_cfg_", "m_")):
        short = short.replace(long_prefix, short_prefix)
    short = re.sub(r"[^0-9a-zA-Z_]+", "_", short)
    return "{:02d}_{}".format(idx, short[:24])

# Core columns considered most influential for the gate and expected to be
# stable across CFGs. Only these (when present) are pulled from the EXTRA CFG
# feature files; list order determines the order of the suffixed columns.
CORE_PRED_COLS = [
    "area_frac", "grid_area_frac", "log_pred_area", "pred_area_frac",
    "best_count", "best_mean_sim", "peak_ratio", "best_weight", "best_weight_frac",
    "inlier_ratio", "n_pairs_thr", "n_pairs_mnn", "overmask_tighten_steps",
    "grid_h", "grid_w", "has_peak",
]
# Same idea for the match-feature files (no area/mask columns).
CORE_MATCH_COLS = [
    "best_count", "best_mean_sim", "peak_ratio", "best_weight", "best_weight_frac",
    "inlier_ratio", "n_pairs_thr", "n_pairs_mnn", "grid_h", "grid_w", "has_peak",
]

def _load_csv_core(csv_path: Path, core_cols: list) -> pd.DataFrame:
    """Load a feature CSV and keep ``uid`` plus whichever ``core_cols`` it has.

    The frame goes through ``_ensure_uid`` first, so a ``uid`` column is
    available. Core columns keep the order given in ``core_cols``.
    """
    raw = _ensure_uid(pd.read_csv(csv_path, low_memory=False))
    present = [name for name in core_cols if name in raw.columns]
    return raw[["uid", *present]].copy()

# get cfg dirs list from PATHS (stage0 v3): ranked candidate CFG directories
# for multi-CFG features.
def _primary_first_cfg_dirs(cands, primary, cap):
    """Order CFG dirs as ``[primary, *others]`` and cap the list.

    Only existing directories are kept. Entries are unique by directory
    *name*, because the same CFG can sit under both /kaggle/input and
    /kaggle/working. The list is cut to ``cap`` entries (at least 1), primary
    included.

    Args:
        cands: ranked path-likes (best first), e.g. PATHS["PRED_CFG_DIRS"].
        primary: path-like of the primary CFG dir, or a falsy value if none.
        cap: maximum number of dirs to return.
    Returns:
        list[Path]
    """
    ordered = ([Path(primary)] if primary else []) + [Path(c) for c in (cands or [])]
    picked, seen = [], set()
    for d in ordered:
        if d.name in seen or not d.exists():
            continue
        seen.add(d.name)
        picked.append(d)
    return picked[:max(1, int(cap))]


_multi_cfg_on = bool(FE_CFG["multi_cfg_enabled"])

# The primary is guaranteed to be first, so the cap counts it and the extras
# are the best-ranked *other* CFGs. Without multi-CFG only the primary remains.
pred_cfg_dirs = _primary_first_cfg_dirs(
    PATHS.get("PRED_CFG_DIRS") if _multi_cfg_on else [],
    PATHS.get("PRED_CFG_DIR"),
    FE_CFG.get("multi_cfg_max_pred", 1) if _multi_cfg_on else 1,
)

# Match CFGs are only relevant when match features are used. MATCH_CFG_DIR
# may be unset/None, hence PATHS.get instead of indexing.
if FE_CFG["use_match_features"]:
    match_cfg_dirs = _primary_first_cfg_dirs(
        PATHS.get("MATCH_CFG_DIRS") if _multi_cfg_on else [],
        PATHS.get("MATCH_CFG_DIR"),
        FE_CFG.get("multi_cfg_max_match", 1) if _multi_cfg_on else 1,
    )
else:
    match_cfg_dirs = []

# EXTRA PRED cfgs: every selected CFG except the primary, which is already loaded.
# The primary is excluded by CFG directory *name*, matching how the primary
# feature file itself was resolved: the same CFG can live under /kaggle/input
# and /kaggle/working. Name-duplicates among the extras are dropped as well.
_primary_pred_cfg_name = Path(PATHS["PRED_CFG_DIR"]).name
_seen_pred_cfg_names = {_primary_pred_cfg_name}
extra_pred_cfgs = []
for _d in pred_cfg_dirs:
    if _d.name in _seen_pred_cfg_names:
        continue
    _seen_pred_cfg_names.add(_d.name)
    extra_pred_cfgs.append(_d)

pred_core_cols_added = []
pred_core_matrix_cols = {}  # core metric name -> list of its per-CFG suffixed columns

for i, cfg_dir in enumerate(extra_pred_cfgs, 1):
    fp = cfg_dir / "pred_features_train_all.csv"
    if not fp.exists():
        continue
    try:
        tag = _short_cfg_tag(cfg_dir, i)
        df_extra = _load_csv_core(fp, CORE_PRED_COLS)
        df_extra = _ensure_case_variant(df_extra, df_base_map=df_base_map)
        # One row per uid; otherwise the left merge would duplicate training rows.
        df_extra = df_extra.drop_duplicates(subset=["uid"], keep="first")
        # Suffix and carry over ONLY the core metrics. _ensure_case_variant adds
        # case_id/variant, and copying those would turn the case id into numeric
        # features (case_id__<tag>, cfg_*_case_id), leaking the ID into the model.
        # Names already in df_train are skipped to avoid pandas _x/_y suffixing.
        ren = {
            c: f"{c}__{tag}"
            for c in CORE_PRED_COLS
            if c in df_extra.columns and f"{c}__{tag}" not in df_train.columns
        }
        if not ren:
            continue
        df_extra = df_extra.rename(columns=ren)
        df_train = df_train.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
        # Record the columns only after the merge succeeded.
        for c, new_c in ren.items():
            pred_core_matrix_cols.setdefault(c, []).append(new_c)
        pred_core_cols_added.extend(ren.values())
    except Exception as e:
        print(f"WARNING: skip extra pred cfg {cfg_dir.name} due to error: {e}")

# Aggregate each core pred metric across CFGs: the primary column (if numeric)
# plus its per-CFG suffixed copies, as row-wise NaN-skipping mean/max/min/std.
# Only CORE_PRED_COLS are aggregated, so id-like columns (case_id, variant)
# can never turn into cfg_* features even if they reach pred_core_matrix_cols.
if FE_CFG["multi_cfg_extra_mode"] == "core+agg":
    for base_name, cols_suff in pred_core_matrix_cols.items():
        if base_name not in CORE_PRED_COLS:
            continue
        cols_all = []
        if base_name in df_train.columns and pd.api.types.is_numeric_dtype(df_train[base_name]):
            cols_all.append(base_name)
        cols_all.extend(c for c in cols_suff if c in df_train.columns)
        if len(cols_all) < 2:
            continue  # a cross-CFG statistic needs at least two sources

        # rowwise stats (NaN-safe)
        mat = df_train[cols_all].apply(pd.to_numeric, errors="coerce")
        df_train[f"cfg_mean_{base_name}"] = mat.mean(axis=1, skipna=True)
        df_train[f"cfg_max_{base_name}"]  = mat.max(axis=1, skipna=True)
        df_train[f"cfg_min_{base_name}"]  = mat.min(axis=1, skipna=True)
        df_train[f"cfg_std_{base_name}"]  = mat.std(axis=1, skipna=True)

# EXTRA MATCH cfgs (optional): same scheme as the extra PRED cfgs. Aggregates
# are named cfg_<stat>_match_<metric>.
match_core_cols_added = []
match_core_matrix_cols = {}  # core metric name -> list of its per-CFG suffixed columns

if FE_CFG["use_match_features"]:
    # MATCH_CFG_DIR may be unset/None. The primary is excluded by CFG name, and
    # name-duplicates (same CFG under a different root) are dropped.
    _primary_match_cfg = PATHS.get("MATCH_CFG_DIR")
    _seen_match_cfg_names = {Path(_primary_match_cfg).name} if _primary_match_cfg else set()
    extra_match_cfgs = []
    for _d in match_cfg_dirs:
        if _d.name in _seen_match_cfg_names:
            continue
        _seen_match_cfg_names.add(_d.name)
        extra_match_cfgs.append(_d)

    for i, cfg_dir in enumerate(extra_match_cfgs, 1):
        fp = cfg_dir / "match_features_train_all.csv"
        if not fp.exists():
            continue
        try:
            tag = _short_cfg_tag(cfg_dir, i)
            df_extra = _load_csv_core(fp, CORE_MATCH_COLS)
            df_extra = _ensure_case_variant(df_extra, df_base_map=df_base_map)
            # One row per uid; otherwise the left merge would duplicate training rows.
            df_extra = df_extra.drop_duplicates(subset=["uid"], keep="first")
            # Suffix ONLY the core metrics. case_id/variant added by
            # _ensure_case_variant must not become (leaking) numeric features.
            # Names already in df_train are skipped to avoid _x/_y suffixing.
            ren = {
                c: f"{c}__{tag}"
                for c in CORE_MATCH_COLS
                if c in df_extra.columns and f"{c}__{tag}" not in df_train.columns
            }
            if not ren:
                continue
            df_extra = df_extra.rename(columns=ren)
            df_train = df_train.merge(df_extra[["uid"] + list(ren.values())], on="uid", how="left")
            for c, new_c in ren.items():
                match_core_matrix_cols.setdefault(c, []).append(new_c)
            match_core_cols_added.extend(ren.values())
        except Exception as e:
            print(f"WARNING: skip extra match cfg {cfg_dir.name} due to error: {e}")

    if FE_CFG["multi_cfg_extra_mode"] == "core+agg":
        for base_name, cols_suff in match_core_matrix_cols.items():
            if base_name not in CORE_MATCH_COLS:
                continue
            cols_all = []
            if base_name in df_train.columns and pd.api.types.is_numeric_dtype(df_train[base_name]):
                cols_all.append(base_name)
            cols_all.extend(c for c in cols_suff if c in df_train.columns)
            if len(cols_all) < 2:
                continue  # a cross-CFG statistic needs at least two sources
            mat = df_train[cols_all].apply(pd.to_numeric, errors="coerce")
            df_train[f"cfg_mean_match_{base_name}"] = mat.mean(axis=1, skipna=True)
            df_train[f"cfg_max_match_{base_name}"]  = mat.max(axis=1, skipna=True)
            df_train[f"cfg_min_match_{base_name}"]  = mat.min(axis=1, skipna=True)
            df_train[f"cfg_std_match_{base_name}"]  = mat.std(axis=1, skipna=True)

# ----------------------------
# 9) Optional merge image profile by case_id (numeric only, prefixed)
# ----------------------------
if df_prof is not None and "case_id" in df_prof.columns:
    df_prof2 = df_prof.copy()
    # Parse case_id, drop unparseable rows, keep one profile row per case.
    df_prof2["case_id"] = pd.to_numeric(df_prof2["case_id"], errors="coerce").astype("Int64")
    df_prof2 = df_prof2.dropna(subset=["case_id"]).astype({"case_id": int})
    df_prof2 = df_prof2.drop_duplicates("case_id")

    # Keep only numeric profile columns.
    prof_num = ["case_id"] + [
        c for c in df_prof2.columns
        if c != "case_id" and pd.api.types.is_numeric_dtype(df_prof2[c])
    ]
    df_prof2 = df_prof2[prof_num].copy()

    # Prefix with "profile_" so names cannot clash with feature-table columns.
    ren = {c: f"profile_{c}" for c in df_prof2.columns if c != "case_id"}
    df_prof2 = df_prof2.rename(columns=ren)
    # Joined on case_id: all rows of the same case receive the same profile values.
    df_train = df_train.merge(df_prof2, on="case_id", how="left")

# ----------------------------
# 10) Feature engineering helpers
# ----------------------------
def _num(s):
    """Return ``s`` converted to numbers; unparseable entries become NaN."""
    coerced = pd.to_numeric(s, errors="coerce")
    return coerced

def safe_log1p_nonneg(x):
    """Element-wise ``log1p`` after sanitising the input.

    The input is converted to float64. Non-finite entries (NaN, +/-inf) become
    0 and negatives are clipped to 0, so the result is always finite and >= 0.
    """
    values = np.nan_to_num(np.asarray(x, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    return np.log1p(np.maximum(values, 0.0))

def safe_sqrt_nonneg(x):
    """Element-wise square root after sanitising the input.

    The input is converted to float64. Non-finite entries (NaN, +/-inf) become
    0 and negatives are clipped to 0, so the result is always finite and >= 0.
    """
    values = np.nan_to_num(np.asarray(x, dtype=np.float64), nan=0.0, posinf=0.0, neginf=0.0)
    return np.sqrt(np.maximum(values, 0.0))

def get_clip_cap(series: pd.Series, q: float, fallback: float):
    """Return the ``q``-quantile of ``|series|`` as a symmetric clipping cap.

    Values are coerced to numeric, and NaN and +/-inf are ignored. Returns
    ``float(fallback)`` when no finite values remain, or when the quantile is
    not a finite positive number (e.g. a column that is mostly zeros).
    """
    vals = _num(series).astype(float).replace([np.inf, -np.inf], np.nan).dropna()
    if vals.empty:
        return float(fallback)
    vals = vals[np.isfinite(vals)]
    cap = float(np.quantile(np.abs(vals.to_numpy()), q))
    return cap if (np.isfinite(cap) and cap > 0) else float(fallback)

# ----------------------------
# 11) Candidate numeric feature list (pre-fill)
# ----------------------------
# Columns that must never become features: labels, the CV split, and the
# case id. They are skipped for missing indicators and the final selection.
TARGET_COLS = {"y", "y_forged", "has_mask", "is_forged", "forged"}
SPLIT_COLS  = {"fold"}
ID_DROP_NUM = {"case_id"}  # must never become a feature

# Replace inf -> NaN for numeric cols
for c in df_train.columns:
    if pd.api.types.is_numeric_dtype(df_train[c]):
        df_train[c] = df_train[c].replace([np.inf, -np.inf], np.nan)

# Missing indicators (computed before the NaN fill further down): one uint8
# isna_<col> column per numeric column that has any NaN. Iterating
# df_train.columns is safe while adding columns, because the loop walks the
# Index object captured when the loop starts.
missing_ind_cols = []
if FE_CFG["add_missing_indicators"]:
    for c in df_train.columns:
        if c in TARGET_COLS or c in SPLIT_COLS or c in ID_DROP_NUM:
            continue
        if pd.api.types.is_numeric_dtype(df_train[c]) and df_train[c].isna().any():
            ind = f"isna_{c}"
            df_train[ind] = df_train[c].isna().astype(np.uint8)
            missing_ind_cols.append(ind)

# Heavy-tail candidates: a fixed list plus any numeric column whose
# lower-cased name contains one of the keywords below.
# NOTE(review): the substring match also picks up the isna_* indicators
# created above (e.g. isna_best_count), which then get caps/transforms too.
heavy_candidates = set([
    "peak_ratio", "best_weight", "best_count", "best_weight_frac",
    "pair_count", "n_pairs_thr", "n_pairs_mnn", "overmask_tighten_steps",
    "largest_comp", "n_comp", "grid_h", "grid_w",
    "grid_area_frac", "area_frac", "inlier_ratio",
])
for c in df_train.columns:
    cl = c.lower()
    if any(k in cl for k in ["count", "pairs", "weight", "ratio", "area", "comp", "std_"]):
        if pd.api.types.is_numeric_dtype(df_train[c]):
            heavy_candidates.add(c)

# Per-column clip cap = clip_q quantile of |values| (see get_clip_cap), learned
# on the whole training table and saved in the schema for reuse.
clip_caps = {}
if FE_CFG["clip_by_quantile"]:
    for c in sorted(list(heavy_candidates)):
        if c in df_train.columns and pd.api.types.is_numeric_dtype(df_train[c]):
            clip_caps[c] = get_clip_cap(df_train[c], FE_CFG["clip_q"], FE_CFG["clip_max_fallback"])

# Clipped + log/sqrt transforms.
# Only columns that received a cap above are transformed, so this block does
# nothing unless FE_CFG["clip_by_quantile"] is on. For each capped column <c>:
#   <c>_cap      : values (NaN -> 0) clipped to [-cap, cap]
#   logabs_<c>   : log1p(|<c>_cap|)   (if add_log_features)
#   sqrtabs_<c>  : sqrt(|<c>_cap|)    (if add_sqrt_features)
if FE_CFG["add_log_features"] or FE_CFG["add_sqrt_features"]:
    for c, cap in clip_caps.items():
        x = _num(df_train[c]).fillna(0.0).astype(float).values
        x = np.clip(x, -cap, cap)
        df_train[f"{c}_cap"] = x.astype(np.float32)

        if FE_CFG["add_log_features"]:
            df_train[f"logabs_{c}"] = safe_log1p_nonneg(np.abs(x)).astype(np.float32)
        if FE_CFG["add_sqrt_features"]:
            df_train[f"sqrtabs_{c}"] = safe_sqrt_nonneg(np.abs(x)).astype(np.float32)

# Interactions (richer + safe). Built from the raw, uncapped primary columns.
# A missing source column is treated as all default (0.0), so every
# interaction column is always created.
if FE_CFG["add_interactions"]:
    def getf(col, default=0.0):
        # float64 ndarray of df_train[col] with NaN -> default, or a constant array if absent
        if col in df_train.columns:
            return _num(df_train[col]).fillna(default).astype(float).values
        return np.full(len(df_train), default, dtype=np.float64)

    best_mean_sim = getf("best_mean_sim", 0.0)
    best_count    = getf("best_count", 0.0)
    peak_ratio    = getf("peak_ratio", 0.0)
    has_peak      = getf("has_peak", 0.0)
    grid_area     = getf("grid_area_frac", 0.0)
    area_frac     = getf("area_frac", 0.0)
    n_pairs_thr   = getf("n_pairs_thr", 0.0)
    n_pairs_mnn   = getf("n_pairs_mnn", 0.0)
    inlier_ratio  = getf("inlier_ratio", 0.0)
    gh = getf("grid_h", 0.0)
    gw = getf("grid_w", 0.0)

    # number of grid cells (never negative)
    gridN = np.clip(gh * gw, 0.0, None)

    # Ratios use small additive constants (1e-6 / 1.0) in the denominator to
    # avoid division by zero.
    df_train["sim_x_count"]      = (best_mean_sim * best_count).astype(np.float32)
    df_train["peak_x_sim"]       = (peak_ratio * best_mean_sim).astype(np.float32)
    df_train["haspeak_x_sim"]    = (has_peak * best_mean_sim).astype(np.float32)
    df_train["area_x_sim"]       = (grid_area * best_mean_sim).astype(np.float32)
    df_train["area_x_count"]     = (grid_area * best_count).astype(np.float32)
    df_train["mask_grid_ratio"]  = (area_frac / (1e-6 + grid_area)).astype(np.float32)
    df_train["mnn_ratio"]        = (n_pairs_mnn / (1.0 + n_pairs_thr)).astype(np.float32)
    df_train["pairs_per_cell"]   = (n_pairs_thr / (1.0 + gridN)).astype(np.float32)
    df_train["inlier_x_pairs"]   = (inlier_ratio * n_pairs_thr).astype(np.float32)

    # extras that often help stability:
    df_train["log1p_pairs_thr"]  = safe_log1p_nonneg(n_pairs_thr).astype(np.float32)
    df_train["log1p_best_count"] = safe_log1p_nonneg(best_count).astype(np.float32)
    df_train["log1p_area_frac"]  = safe_log1p_nonneg(np.clip(area_frac, 0, None)).astype(np.float32)

# Fill NaN in every numeric column with FE_CFG["fillna_value"]. This runs
# after the missing indicators above, so the indicators keep the NaN pattern.
num_cols = [c for c in df_train.columns if pd.api.types.is_numeric_dtype(df_train[c])]
df_train[num_cols] = df_train[num_cols].fillna(FE_CFG["fillna_value"])

# ----------------------------
# 12) Variant encoding (optional one-hot)
# ----------------------------
# Variants seen fewer than FE_CFG["variant_min_count"] times are bucketed
# into "rare"; each remaining value becomes a uint8 column v_<variant>.
variant_dummy_cols = []
if FE_CFG["encode_variant_onehot"]:
    # (.fillna is a no-op after astype(str): NaN has already become "nan")
    vc = df_train["variant"].astype(str).fillna("unk")
    counts = vc.value_counts()
    keep = set(counts[counts >= int(FE_CFG["variant_min_count"])].index.tolist())
    vc = vc.where(vc.isin(keep), other="rare")

    dummies = pd.get_dummies(vc, prefix="v", dummy_na=False).astype(np.uint8)
    variant_dummy_cols = dummies.columns.tolist()
    # index-aligned concat: dummies share df_train's index via vc
    df_train = pd.concat([df_train, dummies], axis=1)

# ----------------------------
# 13) Select final feature columns (numeric only)
# ----------------------------
# Every numeric column except labels, the fold and case_id, in df_train's
# current column order (this order becomes FEATURE_COLS).
feature_cols = []
for c in df_train.columns:
    if not pd.api.types.is_numeric_dtype(df_train[c]):
        continue
    if c in TARGET_COLS or c in SPLIT_COLS or c in ID_DROP_NUM:
        continue
    feature_cols.append(c)

# Drop constant features (a single distinct value, NaN counted as a value)
dropped_constant = []
if FE_CFG["drop_constant_features"]:
    nun = df_train[feature_cols].nunique(dropna=False)
    nonconst = nun[nun > 1].index.tolist()
    dropped_constant = sorted(set(feature_cols) - set(nonconst))
    feature_cols = nonconst

# Cast float32 for numeric features
if FE_CFG["cast_float32"]:
    df_train[feature_cols] = df_train[feature_cols].astype(np.float32)

FEATURE_COLS = list(feature_cols)

# ----------------------------
# 14) Final outputs
# ----------------------------
# df_train_tabular = id/split/label columns followed by FEATURE_COLS.
base_out_cols = ["uid", "case_id", "variant", "fold", "y"]
df_train_tabular = df_train[base_out_cols + FEATURE_COLS].copy()

X_train = df_train_tabular[FEATURE_COLS]
y_train = df_train_tabular["y"].astype(int)
folds   = df_train_tabular["fold"].astype(int)

print("\nOK — Training table built")
print("  df_train_tabular:", df_train_tabular.shape)
print("  X_train:", X_train.shape, "| y pos%:", float(y_train.mean()) * 100.0)
print("  folds:", int(folds.nunique()), "unique folds")
print("  feature_cols:", int(len(FEATURE_COLS)))
if dropped_constant:
    print("  dropped_constant_features:", len(dropped_constant))
if variant_dummy_cols:
    print("  variant_dummies:", len(variant_dummy_cols))
if missing_ind_cols:
    print("  missing_indicators:", len(missing_ind_cols))
if pred_core_cols_added:
    print("  extra_pred_core_cols:", len(pred_core_cols_added))
if match_core_cols_added:
    print("  extra_match_core_cols:", len(match_core_cols_added))

# hard sanity. y and fold were already cast with astype(int), which raises on
# NaN, so these checks are a final guard rather than the primary check.
if X_train.shape[0] != y_train.shape[0]:
    raise RuntimeError("X_train and y_train row mismatch")
if y_train.isna().any():
    raise RuntimeError("y_train contains NaN")
if folds.isna().any():
    raise RuntimeError("folds contains NaN")

print("\nFeature head:", FEATURE_COLS[:25])
print("Feature tail:", FEATURE_COLS[-15:])

# ----------------------------
# 15) Save reproducible schema
# ----------------------------
# Persist the feature list, the FE settings (incl. learned clip caps) and the
# final training table so a later step can rebuild the exact same schema.
OUT_ART = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_ART.mkdir(parents=True, exist_ok=True)

with open(OUT_ART / "feature_cols.json", "w") as f:
    json.dump(FEATURE_COLS, f, indent=2)

# CFG dirs recorded in the schema (primary first, de-duplicated).
# PATHS["MATCH_CFG_DIR"] may be unset/None when no match CFG was selected, and
# Path(None) raises TypeError, so the match primary is only listed when set.
_pred_dirs_used = list(dict.fromkeys(
    [str(Path(PATHS["PRED_CFG_DIR"]))] + [str(d) for d in extra_pred_cfgs]
))
_match_primary = PATHS.get("MATCH_CFG_DIR")
_match_dirs_used = [str(Path(_match_primary))] if _match_primary else []
if FE_CFG["use_match_features"]:
    _match_dirs_used = list(dict.fromkeys(_match_dirs_used + [str(d) for d in match_cfg_dirs]))

schema = {
    "fe_cfg": FE_CFG,
    "label_col_source": label_col,
    "clip_caps": clip_caps,
    "dropped_constant_features": dropped_constant,
    "variant_dummy_cols": variant_dummy_cols,
    "missing_indicator_cols": missing_ind_cols,
    "multi_cfg": {
        "pred_cfg_dirs_used": _pred_dirs_used,
        "match_cfg_dirs_used": _match_dirs_used,
        "extra_pred_core_cols_added": pred_core_cols_added[:200],
        "extra_match_core_cols_added": match_core_cols_added[:200],
    },
    "n_features": int(len(FEATURE_COLS)),
    "example_feature_head": FEATURE_COLS[:30],
}
with open(OUT_ART / "feature_schema.json", "w") as f:
    json.dump(schema, f, indent=2)

df_train_tabular.to_parquet(OUT_ART / "df_train_tabular.parquet", index=False)

print(f"\nSaved -> {OUT_ART/'feature_cols.json'}")
print(f"Saved -> {OUT_ART/'feature_schema.json'}")
print(f"Saved -> {OUT_ART/'df_train_tabular.parquet'}")

# export globals
globals().update({
    "df_train_tabular": df_train_tabular,
    "FEATURE_COLS": FEATURE_COLS,
    "X_train": X_train,
    "y_train": y_train,
    "folds": folds,
})


Using (PRIMARY):
  DF_TRAIN_ALL     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet
  CV_CASE_FOLDS    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv
  PRED_FEAT_TRAIN  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv
  MATCH_FEAT_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv (optional)
  IMG_PROFILE_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet (optional)

OK — Training table built
  df_train_tabular: (5176, 89)
  X_train: (5176, 84) | y pos%: 54.07650695517774
  folds: 5 unique folds
  feature_cols: 84
  dropped_constant_features: 6
  variant_dummies: 3

Feature head: ['has_peak', 'peak_ratio', 'best_weight', 'best_count', 'best_mean_sim', 'n_pairs_thr', 'n_pairs_mnn', 'best_inlier_ratio', 'best_weight_frac', 'inlier_ratio', 'pa

# Build & Export Test Feature Table

In [3]:
# ============================================================
# Step 2.5 — Build & Export Test Feature Table (pred_features_test*)
# ONE CELL (Kaggle-ready) — FULL REVISION v1.0
#
# Goal:
# - Build the TEST feature table with the same schema as training (FEATURE_COLS)
# - Intended outputs (per the author; the export code is further down the cell):
#   /kaggle/working/recodai_luc_gate_artifacts/pred_features_test.csv
#   /kaggle/working/recodai_luc_gate_artifacts/pred_features_test_cfg_<hash>.csv
#
# REQUIRES (minimum):
# - FEATURE_COLS (list), produced by Step 2 (Build Training Table)
# - One of: df_test_tabular / df_test / PATHS["DF_TEST"] / PATHS["DF_TEST_ALL"]
#   (PATHS["TEST_META"] and a few /kaggle/working fallbacks are also tried)
#
# Notes:
# - The cell tries to "fill" FEATURE_COLS from:
#   (1) df_test_tabular if it already exists
#   (2) df_test (or a test meta CSV) if present
#   (3) scanning test feature files (*.csv/*.parquet) in /kaggle/working and /kaggle/input
# - Feature columns that are still not found are said to be filled with 0.0
#   (with a warning). NOTE(review): that fill happens past this view — confirm.
# ============================================================

import os, json, hashlib, gc
from pathlib import Path
import numpy as np
import pandas as pd

# Same artifacts directory the training-table step writes to.
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 0) Require FEATURE_COLS
# ----------------------------
# The test table must follow the training schema exactly, so fail fast if
# the training step has not been run in this session.
if "FEATURE_COLS" not in globals():
    raise RuntimeError("Missing `FEATURE_COLS`. Jalankan Step 2 (Build Training Table) dulu.")
FEATURE_COLS = list(FEATURE_COLS)
if len(FEATURE_COLS) == 0:
    raise RuntimeError("FEATURE_COLS kosong. Cek Step 2.")

# ----------------------------
# 1) Helper: load base test dataframe
# ----------------------------
def _read_table_any(p: Path):
    """Read a CSV or Parquet table, or return None if the path does not exist.

    Files ending in ``.parquet`` (case-insensitive) are read with
    ``pd.read_parquet``; everything else is read as CSV with
    ``low_memory=False``. That is the same setting the training-table step
    uses for its feature CSVs, so dtypes are inferred from the whole file
    rather than chunk by chunk. Chunked inference can yield mixed-type object
    columns and a DtypeWarning on large feature files.

    Args:
        p: path-like pointing at the table.
    Returns:
        pandas.DataFrame, or None when ``p`` does not exist.
    """
    p = Path(p)
    if not p.exists():
        return None
    if p.suffix.lower() == ".parquet":
        return pd.read_parquet(p)
    return pd.read_csv(p, low_memory=False)

def _pick_first_existing(paths):
    """Return the first entry of ``paths`` that is an existing regular file.

    ``None`` entries are skipped, and directories do not count; this also
    rules out ``""``, which normalises to ``"."``. Returns a ``Path``, or
    ``None`` if nothing qualifies.
    """
    for candidate in (Path(entry) for entry in paths if entry is not None):
        if candidate.exists() and candidate.is_file():
            return candidate
    return None

def _get_base_test_df():
    """Locate the TEST base table; return ``(DataFrame copy, source label)``.

    Lookup order:
      1. global ``df_test_tabular`` (may already hold features)
      2. global ``df_test`` (meta)
      3. first existing file among PATHS["DF_TEST"], PATHS["DF_TEST_ALL"],
         PATHS["TEST_META"] (CSV or Parquet)
      4. /kaggle/working/df_test.csv, /kaggle/working/test.csv,
         /kaggle/working/df_test.parquet

    Raises:
        FileNotFoundError: when none of the above yields a table.
    """
    env = globals()

    for var_name, label in (("df_test_tabular", "globals(df_test_tabular)"),
                            ("df_test", "globals(df_test)")):
        candidate = env.get(var_name)
        if isinstance(candidate, pd.DataFrame):
            return candidate.copy(), label

    paths_cfg = env.get("PATHS")
    if isinstance(paths_cfg, dict):
        hit = _pick_first_existing([paths_cfg.get(key, None) for key in ("DF_TEST", "DF_TEST_ALL", "TEST_META")])
        if hit is not None:
            table = _read_table_any(hit)
            if table is not None:
                return table.copy(), f"PATHS({hit})"

    # fallback scan of common places
    hit = _pick_first_existing([
        Path("/kaggle/working/df_test.csv"),
        Path("/kaggle/working/test.csv"),
        Path("/kaggle/working/df_test.parquet"),
    ])
    if hit is not None:
        table = _read_table_any(hit)
        if table is not None:
            return table.copy(), f"fallback({hit})"

    raise FileNotFoundError(
        "Tidak menemukan df_test_tabular / df_test / PATHS[DF_TEST]. "
        "Pastikan Step 1/2 sudah membentuk df_test atau PATHS mengarah ke test meta."
    )

# Load the TEST base frame (df_test_tabular > df_test > PATHS > /kaggle/working fallbacks).
# NOTE: this rebinds the global `df_base`, which the training-table step used
# for df_train_all; the training value is overwritten from here on.
df_base, base_src = _get_base_test_df()
print("Base TEST source:", base_src, "| shape:", df_base.shape)

# ----------------------------
# 2) Normalize id columns (uid/case_id/variant)
# ----------------------------
# Try to ensure: uid exists
def _ensure_uid(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with canonical id columns and a string ``uid``.

    - If ``case_id`` is missing, the first present alias among case, caseid,
      image_id, img_id, id, sample_id is renamed to ``case_id``.
    - If ``variant`` is missing, the first present alias among var, aug, view,
      split_variant is renamed to ``variant``.
    - If ``uid`` is missing it is built as ``"<case_id>_<variant>"``, or
      ``"<case_id>"``, or finally the positional row number.
    ``uid`` is always cast to str.

    NOTE(review): this redefines a global helper of the same name used by the
    training-table step, and that step's own uid fallback joins case_id and
    variant with "__" rather than "_". Confirm the uid formats line up before
    joining test features on uid.
    """
    out = df.copy()

    alias_plan = (
        ("case_id", ("case", "caseid", "image_id", "img_id", "id", "sample_id")),
        ("variant", ("var", "aug", "view", "split_variant")),
    )
    for canonical, aliases in alias_plan:
        present = set(out.columns)
        if canonical in present:
            continue
        source = next((alias for alias in aliases if alias in present), None)
        if source is not None:
            out = out.rename(columns={source: canonical})

    present = set(out.columns)
    if "uid" not in present:
        if "case_id" in present and "variant" in present:
            out["uid"] = out["case_id"].astype(str) + "_" + out["variant"].astype(str)
        elif "case_id" in present:
            out["uid"] = out["case_id"].astype(str)
        else:
            # An "id" column would already have been renamed to case_id above,
            # so the only remaining option is a positional (index-based) uid.
            out["uid"] = np.arange(len(out)).astype(str)

    out["uid"] = out["uid"].astype(str)
    return out

# Normalise id columns and guarantee a string `uid` on the TEST base frame.
df_base = _ensure_uid(df_base)

# Keep whichever of uid / case_id / variant exist (uid always does after _ensure_uid).
id_cols = []
for c in ["uid", "case_id", "variant"]:
    if c in df_base.columns:
        id_cols.append(c)

print("ID columns used:", id_cols)

# ----------------------------
# 3) Build output frame skeleton
# ----------------------------
# df_out starts with the id columns only; feature columns are merged in by uid.
df_out = df_base[id_cols].copy()

# Fill any features already present in base (first row per uid wins)
present = [c for c in FEATURE_COLS if c in df_base.columns]
if len(present) > 0:
    df_out = df_out.merge(df_base[["uid"] + present].drop_duplicates("uid"), on="uid", how="left")

# FEATURE_COLS still absent from df_out; the external-file scan below tries to fill them.
missing = [c for c in FEATURE_COLS if c not in df_out.columns]
print(f"Features present from base: {len(present)}/{len(FEATURE_COLS)} | missing: {len(missing)}")

# ----------------------------
# 4) Optional: scan and merge external feature tables (CSV/Parquet)
# ----------------------------
def _walk_with_depth(root: Path, max_depth=4):
    """Yield ``(dir_path, file_names)`` for directories up to ``max_depth`` below ``root``.

    Traversal is top-down via ``os.walk``, and ``root`` itself is depth 0.
    Recursion is pruned *at* ``max_depth``, so directories one level deeper
    are never listed. Previously they were listed and then discarded, which
    costs an extra directory scan per leaf on large trees such as
    /kaggle/input. Sub-directory and file names are sorted, so discovery
    order is reproducible across runs and filesystems. A missing ``root``
    yields nothing.

    Args:
        root: directory to walk.
        max_depth: maximum depth relative to ``root`` (0 = ``root`` only).
    """
    root = Path(root)
    root_depth = len(root.parts)
    for cur, dirs, files in os.walk(root):
        cur_p = Path(cur)
        depth = len(cur_p.parts) - root_depth
        if depth >= max_depth:
            dirs[:] = []   # do not descend below max_depth (in-place: os.walk honours it)
        else:
            dirs.sort()    # deterministic top-down order
        if depth > max_depth:
            continue       # only reachable when max_depth < 0
        yield cur_p, sorted(files)

def _find_feature_files(max_files=80):
    """Collect candidate TEST feature files by file name.

    Walks /kaggle/working (up to depth 6) and then /kaggle/input (up to
    depth 4) with ``_walk_with_depth``. Keeps every .csv/.parquet file whose
    lower-cased name contains one of the known feature-file markers, and
    stops as soon as ``max_files`` hits have been collected.

    Returns:
        list[Path] in discovery order.
    """
    markers = (
        "pred_features_test", "test_features", "features_test", "pred_feat_test", "gate_features_test",
        "match_features_test", "match_feat_test",
    )
    allowed_suffixes = {".csv", ".parquet"}
    # working first (usually the smaller/faster tree), then input
    search_roots = ((Path("/kaggle/working"), 6), (Path("/kaggle/input"), 4))

    hits = []
    for base, limit in search_roots:
        if not base.exists():
            continue
        for folder, names in _walk_with_depth(base, max_depth=limit):
            for fname in names:
                cand = folder / fname
                if cand.suffix.lower() not in allowed_suffixes:
                    continue
                lowered = cand.name.lower()
                if not any(marker in lowered for marker in markers):
                    continue
                hits.append(cand)
                if len(hits) >= max_files:
                    return hits
    return hits

def _score_feat_file(p: Path, need_cols: set):
    """Rank a candidate test-feature file by joinability and feature overlap.

    Only column names are inspected: the first rows of a CSV, or the schema of
    a Parquet file. The Parquet schema is read with
    ``pyarrow.parquet.read_schema``, so no data pages are loaded; the previous
    version read the whole file just to get its header. pandas index columns
    recorded in the file metadata are excluded, so the names match what
    ``pd.read_parquet`` exposes as columns. If pyarrow cannot read the schema,
    the file is read in full instead.

    Score = 1000 * key_score + overlap:
      - key_score: +5 for ``uid``, +3 for ``case_id`` and ``variant`` together,
        +1 for ``case_id``;
      - overlap: number of columns that are also in ``need_cols``.
    Joinability therefore always dominates feature overlap.

    Args:
        p: CSV or Parquet file path.
        need_cols: iterable of column names the caller wants to fill.

    Returns:
        ``(score, column_names)`` with columns in file order, or ``(-1, None)``
        if the file cannot be read.
    """
    p = Path(p)

    def _parquet_columns(path):
        # schema-only read; fall back to a full pandas read if pyarrow fails
        try:
            import pyarrow.parquet as pq
            schema = pq.read_schema(path)
            meta = schema.pandas_metadata or {}
            index_cols = {c for c in meta.get("index_columns", []) if isinstance(c, str)}
            return [name for name in schema.names if name not in index_cols]
        except Exception:
            return list(pd.read_parquet(path).columns)

    try:
        if p.suffix.lower() == ".csv":
            col_names = list(pd.read_csv(p, nrows=5).columns)
        else:
            col_names = _parquet_columns(p)
    except Exception:
        return -1, None

    cols = set(col_names)
    # join keys compatibility score
    key_score = 0
    if "uid" in cols:
        key_score += 5
    if ("case_id" in cols) and ("variant" in cols):
        key_score += 3
    if "case_id" in cols:
        key_score += 1

    overlap = len(cols & set(need_cols))
    score = key_score * 1000 + overlap  # prioritize joinability heavily
    return score, col_names

def _infer_join_keys(dfA: pd.DataFrame, dfB: pd.DataFrame):
    A = set(dfA.columns); B = set(dfB.columns)
    for keys in (["uid"], ["case_id", "variant"], ["case_id"], ["id"]):
        if all(k in A for k in keys) and all(k in B for k in keys):
            return keys
    return None

if len(missing) > 0:
    need_cols = set(["uid", "case_id", "variant"] + FEATURE_COLS)
    files = _find_feature_files(max_files=80)

    # Do not feed this step its own previous output back in. Section 6 below
    # writes pred_features_test.csv and pred_features_test_cfg_<hash>.csv into
    # OUT_DIR. On a re-run of this cell those files contain every FEATURE_COL,
    # zero-filled where no source existed, so they would score highest and
    # silently "resolve" all missing columns.
    # Only files with exactly those names directly inside OUT_DIR are skipped.
    # Same-named tables elsewhere (e.g. upstream cache folders) are still used.
    # NOTE(review): relies on OUT_DIR being defined earlier in this cell (it is
    # used unguarded in section 6); if it is not, nothing is excluded.
    _own_out_dir = globals().get("OUT_DIR")
    if _own_out_dir:
        _own_out_dir = Path(_own_out_dir).resolve()
        _kept = [
            fp for fp in files
            if not (
                fp.resolve().parent == _own_out_dir
                and (fp.name == "pred_features_test.csv"
                     or (fp.name.startswith("pred_features_test_cfg_") and fp.name.endswith(".csv")))
            )
        ]
        if len(_kept) < len(files):
            print(f"Ignoring {len(files) - len(_kept)} file(s) previously written to {_own_out_dir}")
        files = _kept

    scored = []
    for p in files:
        sc, cols = _score_feat_file(p, need_cols)
        if sc > 0:
            scored.append((sc, p, cols))
    scored.sort(reverse=True, key=lambda x: x[0])

    print(f"External feature files found: {len(scored)} (after scoring)")

    # Merge the 12 best-scoring files in order; stop early once nothing is missing.
    for rank, (sc, p, cols) in enumerate(scored[:12], 1):
        if len(missing) == 0:
            break
        try:
            df_feat = _read_table_any(p)
            if df_feat is None or len(df_feat) == 0:
                continue
            df_feat = _ensure_uid(df_feat)

            join_keys = _infer_join_keys(df_out, df_feat)
            if join_keys is None:
                continue

            # keep only join keys + FEATURE_COLS the file actually has
            use_cols = list(dict.fromkeys(join_keys + [c for c in FEATURE_COLS if c in df_feat.columns]))
            if len(use_cols) <= len(join_keys):
                continue

            df_feat = df_feat[use_cols].copy()

            # one row per join key so the left merge cannot add rows
            if len(join_keys) == 1:
                df_feat = df_feat.drop_duplicates(join_keys[0])
            else:
                df_feat = df_feat.drop_duplicates(join_keys)

            before_missing = len(missing)
            df_out = df_out.merge(df_feat, on=join_keys, how="left", suffixes=("", "_dup"))

            # Columns already present come back as *_dup: keep existing non-null
            # values, fill their gaps from the new file, then drop all *_dup at once.
            dup_cols = [c for c in df_out.columns if c.endswith("_dup")]
            for dc in dup_cols:
                basec = dc[:-4]
                if basec in df_out.columns:
                    df_out[basec] = df_out[basec].fillna(df_out[dc])
            if dup_cols:
                df_out = df_out.drop(columns=dup_cols)

            missing = [c for c in FEATURE_COLS if c not in df_out.columns]
            gained = before_missing - len(missing)
            print(f"  [{rank}] merged: {p} | join_keys={join_keys} | gained_cols={gained} | missing_now={len(missing)}")

            del df_feat
            gc.collect()
        except Exception as e:
            print(f"  [{rank}] skip (read/merge error): {p} | err={repr(e)}")
            continue

# ----------------------------
# 5) Finalize: ensure all FEATURE_COLS exist, numeric float32, fillna(0)
# ----------------------------
still_missing = [c for c in FEATURE_COLS if c not in df_out.columns]
# Columns that no source provided are added as constant 0.0, so the table
# always carries the full FEATURE_COLS schema.
for c in still_missing:
    df_out[c] = 0.0

# Coerce every feature column to float32. Unparseable values and the NaNs left
# by left-merges without a match become 0.0. All FEATURE_COLS exist at this point.
for c in FEATURE_COLS:
    df_out[c] = pd.to_numeric(df_out[c], errors="coerce").fillna(0.0).astype(np.float32)

# order columns: ids then features
df_out = df_out[id_cols + FEATURE_COLS].copy()

# sanity
assert df_out.shape[0] == df_base.shape[0], "Row count changed unexpectedly after merges."
# The previous check (`nunique() <= n_rows`) could never fail. Count real
# duplicates instead, and warn rather than abort so the problem is visible
# without breaking the pipeline.
n_dup_uid = int(df_out["uid"].astype(str).duplicated().sum())
if n_dup_uid > 0:
    print(f"WARNING: {n_dup_uid} duplicated uid value(s) in the TEST feature table.")
print("\nFinal TEST feature table:", df_out.shape)

# Report missing originally (if any)
if len(still_missing) > 0:
    print(f"WARNING: {len(still_missing)} feature cols were not found anywhere and set to 0.0")
    print("  examples:", still_missing[:25])

# ----------------------------
# 6) Save pred_features_test*
# ----------------------------
# Short fingerprint of the feature list: the first 12 hex chars of the SHA-1 of
# FEATURE_COLS serialized as JSON (order-sensitive). It tags a copy of the table
# with the exact feature schema it was built for.
cfg_hash = hashlib.sha1(json.dumps(FEATURE_COLS, ensure_ascii=False).encode("utf-8")).hexdigest()[:12]
p_main = OUT_DIR / "pred_features_test.csv"
p_cfg  = OUT_DIR / f"pred_features_test_cfg_{cfg_hash}.csv"

# The same table is written twice: under a fixed name and under the schema-tagged name.
df_out.to_csv(p_main, index=False)
df_out.to_csv(p_cfg, index=False)

print("\nSaved:")
print("  ->", p_main)
print("  ->", p_cfg)

# Export global for next steps
# (df_test_tabular is the in-memory table itself; the *_CSV globals are path strings.)
df_test_tabular = df_out
PRED_FEATURES_TEST_CSV = str(p_main)
PRED_FEATURES_TEST_CFG_CSV = str(p_cfg)
PRED_FEATURES_TEST_CFG_HASH = cfg_hash


Base TEST source: PATHS(/kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_test.parquet) | shape: (1, 3)
ID columns used: ['uid', 'case_id']
Features present from base: 0/84 | missing: 84
External feature files found: 3 (after scoring)
  [1] merged: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_test_cfg_5dbf0aa165.csv | join_keys=['uid'] | gained_cols=24 | missing_now=60
  [2] merged: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_test.csv | join_keys=['uid'] | gained_cols=0 | missing_now=60
  [3] merged: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_test.csv | join_keys=['uid'] | gained_cols=4 | missing_now=56

Final TEST feature table: (1, 86)
  examples: ['profile_img_H', 'profile_img_W', 'profile_aspect', 'profile_is_gray', 'profile_bg_white_frac', 'profile_roi_x0', 'profile_roi_y0', 'profile_roi_x1', 'profile_

# Train Baseline Model (Leakage-Safe CV)

In [4]:
# ============================================================
# Step 3 — Train Stronger Model (Leakage-Safe CV) — REVISI FULL v3.0 (MAX-UPGRADE, robust)
#
# Upgrade v3.0 (sesuai strategi naik score & anti error):
# - Multi-SEED bagging (auto: 1 seed CPU/small GPU, 2 seed untuk GPU besar; bisa override ENV N_SEEDS)
# - Leakage-safe PER-FOLD calibration (isotonic/sigmoid/none) -> oof_cal lebih stabil untuk thresholding
# - Simpan OOF RAW + OOF CAL (csv + npy) + fold packs lengkap (mu/sig + calibrator + cfg + metrics)
# - Threshold search lebih kaya: F1 + F0.5 + F2 (default pilih F0.5 utk menekan FP)
# - Better early stopping: monitor val_logloss (EMA-eval) + simpan best_state (EMA weights)
# - Guard kuat: NaN/Inf, fold sanity, amp safe, workers safe
#
# Output globals:
# - OOF_PRED_MHC_RAW, OOF_PRED_MHC_CAL (np.ndarray)
# - BASELINE_MHC_TF_OVERALL, BASELINE_MHC_TF_FOLD_REPORTS, BASELINE_MHC_TF_BEST_EPOCHS
# - FULL_PACK_PATH (str)
# ============================================================

import json, gc, math, time, os, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, log_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require outputs from Step 2
# ----------------------------
need_vars = ["df_train_tabular", "FEATURE_COLS"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

# Rebind to copies, so later in-place edits do not touch other references to
# Step 2's objects.
df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

required_cols = {"uid", "case_id", "variant", "fold", "y"}
# sorted(): set iteration order of strings varies between runs; keep the message stable.
missing_cols = sorted(c for c in required_cols if c not in df_train_tabular.columns)
if missing_cols:
    raise ValueError(f"df_train_tabular missing columns: {missing_cols}.")

# Fail here with a readable message, rather than with a KeyError when X is built below.
missing_feature_cols = [c for c in FEATURE_COLS if c not in df_train_tabular.columns]
if missing_feature_cols:
    raise ValueError(
        f"df_train_tabular missing {len(missing_feature_cols)} FEATURE_COLS, "
        f"e.g. {missing_feature_cols[:10]}."
    )

# ----------------------------
# 1) CFG (AUTO: CPU/SAFE/STRONG)
# ----------------------------
# Smallest profile, used on CPU. CFG_SAFE and CFG_STRONG below are derived from
# it with {**base, ...} and only override the keys they list.
CFG_CPU = {
    "seed": 2025,

    # stream mixing inside each MHCAttnBlock (passed to MHCFTTransformer)
    "n_streams": 2,
    "sinkhorn_tmax": 10,
    "alpha_init": 0.01,

    # transformer size / dropout (passed to MHCFTTransformer)
    "d_model": 192,
    "n_layers": 4,
    "n_heads": 6,
    "ffn_mult": 4,
    "dropout": 0.12,
    "attn_dropout": 0.08,

    # regularization / loss knobs. feat_token_drop_p goes to the model; the
    # others are presumably read by train_one_fold further below — confirm.
    "feat_token_drop_p": 0.05,
    "input_noise_std": 0.008,
    "label_smoothing": 0.00,
    "focal_gamma": 1.2,

    # accum_steps is read in train_one_fold (clamped to >= 1); presumably
    # gradient-accumulation steps per optimizer update — confirm.
    "batch_size": 256,
    "accum_steps": 2,
    "epochs": 45,

    # AdamW hyper-parameters (used in train_one_fold)
    "lr": 3e-4,
    "betas": (0.9, 0.95),
    "eps": 1e-8,
    "weight_decay": 5e-2,

    # presumably fed to make_warmup_step_scheduler (LR x0.316 at 80%, x0.1 at 90% of steps)
    "warmup_frac": 0.08,
    "lr_decay_milestones": (0.80, 0.90),
    "lr_decay_values": (0.316, 0.10),

    "grad_clip": 1.0,

    "early_stop_patience": 10,
    "early_stop_min_delta": 1e-4,

    "use_ema": True,
    "ema_decay": 0.999,

    # calibration (fold-wise; leakage-safe)
    "use_calibration": True,
    "calibration": None,  # auto from TRAIN_PLAN if available else "isotonic"
    "calib_min_samples": 200,  # isotonic needs enough samples (presumably min_samples for fit_calibrator)

    # reporting/thresholding
    "search_best_thr": True,
    "thr_grid": 801,          # denser threshold grid, still cheap
    "thr_objective": "f0.5",  # "f1" | "f0.5" | "f2"  (f0.5 weights precision more, i.e. penalises FPs)
    "report_thr": 0.5,
}

# Default GPU profile: CFG_CPU with a larger model and 4 streams. Keys not
# listed here are inherited from CFG_CPU.
CFG_SAFE = {
    **CFG_CPU,
    "n_streams": 4,
    "sinkhorn_tmax": 20,

    "d_model": 256,
    "n_layers": 6,
    "n_heads": 8,

    "epochs": 60,
    "lr": 3e-4,
    "weight_decay": 5e-2,

    "feat_token_drop_p": 0.05,
    "input_noise_std": 0.01,
    "focal_gamma": 1.5,

    "batch_size": 256,
    "accum_steps": 2,
}

# Large-GPU profile (selected below when the GPU has >= 30 GB). It is CFG_SAFE
# scaled up: d_model 384 with CFG_SAFE's 8 heads gives 48-dim heads. It also
# uses more dropout and weight decay, and a lower LR.
CFG_STRONG = {
    **CFG_SAFE,
    "d_model": 384,
    "n_layers": 8,
    "dropout": 0.16,
    "epochs": 80,
    "lr": 2e-4,
    "weight_decay": 7e-2,
    "feat_token_drop_p": 0.06,
    "input_noise_std": 0.012,
    "focal_gamma": 1.5,
}

# auto select by device/memory
# CPU -> CFG_CPU; CUDA -> CFG_SAFE, upgraded to CFG_STRONG when GPU 0 has >= 30 GB.
# AMP is enabled exactly when running on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")

# dict(...) makes a shallow copy, so edits to CFG below leave the profile dicts intact.
CFG = dict(CFG_SAFE)
CFG_NAME = "SAFE"
if device.type != "cuda":
    CFG = dict(CFG_CPU)
    CFG_NAME = "CPU"
else:
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if mem_gb >= 30:
            CFG = dict(CFG_STRONG)
            CFG_NAME = "STRONG"
    except Exception:
        # best-effort: if the memory query fails, stay on CFG_SAFE
        pass

# calibration preference from TRAIN_PLAN if available
# (only consulted while CFG still has calibration enabled; TRAIN_PLAN may turn it
# off and/or choose the calibration type)
if "TRAIN_PLAN" in globals() and isinstance(TRAIN_PLAN, dict):
    if CFG.get("use_calibration", True):
        CFG["use_calibration"] = bool(TRAIN_PLAN.get("use_calibration", True))
        CFG["calibration"] = TRAIN_PLAN.get("calibration", CFG.get("calibration", None))
# default calibration
# If calibration is still enabled but no type was chosen, fall back to isotonic.
if CFG.get("use_calibration", True) and (CFG.get("calibration") is None):
    CFG["calibration"] = "isotonic"  # chosen default for tabular probabilities

# ----------------------------
# 2) Seed + device optim
# ----------------------------
def seed_everything(seed: int = 2025):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Also writes ``PYTHONHASHSEED`` into ``os.environ``. That only affects
    processes started afterwards (e.g. workers), not the running interpreter's
    string hashing. cuDNN is set up for speed rather than bit-exact
    reproducibility (``deterministic=False``, ``benchmark=True``).
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Seed all RNGs from the selected profile before building data/models.
seed_everything(int(CFG["seed"]))

print("Device:", device, "| AMP:", use_amp, "| CFG:", CFG_NAME)
# On CUDA, allow TF32 for matmul and cuDNN kernels (faster, slightly lower-precision float32 math).
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
# Best-effort: guarded, presumably because older torch versions lack this API.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ----------------------------
# 3) Build arrays + guard
# ----------------------------
# Features (float32), labels (int64) and fold ids (int64) as independent copies.
X = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)

# Replace NaN and +/-inf features with 0.0 so the model never sees non-finite inputs.
if not np.isfinite(X).all():
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

n = len(df_train_tabular)
unique_folds = sorted(pd.Series(folds).unique().tolist())
n_folds = len(unique_folds)
# n_features is read as a global by train_one_fold when building the model.
n_features = X.shape[1]

print("Setup:")
print("  rows      :", n)
print("  folds     :", n_folds, "|", unique_folds)
# pos% is mean(y) * 100, i.e. the positive rate assuming y is 0/1.
print("  pos%      :", float(y.mean()) * 100.0)
print("  n_features:", n_features)
if n_folds < 2:
    raise ValueError("Need >=2 folds. Check cv_case_folds.")

# ----------------------------
# 4) Dataset
# ----------------------------
class TabDataset(Dataset):
    """Map-style dataset over a feature matrix and optional labels.

    ``X`` (and ``y`` when given) are copied to float32 once, up front, and
    wrapped as torch tensors. An item is the feature row alone when no labels
    were given, otherwise a ``(row, label)`` pair.
    """

    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        if y is None:
            self.y = None
        else:
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        n_rows = self.X.shape[0]
        return n_rows

    def __getitem__(self, idx):
        row = self.X[idx]
        return row if self.y is None else (row, self.y[idx])

# ----------------------------
# 5) Normalization (leakage-safe)
# ----------------------------
def fit_standardizer(X_tr: np.ndarray):
    """Return per-column ``(mean, std)`` of ``X_tr`` as float32 arrays.

    Statistics are accumulated in float64. Columns whose std is below 1e-8
    (effectively constant) get std 1.0, so standardizing never divides by ~0.
    """
    col_mean = np.mean(X_tr, axis=0, dtype=np.float64)
    col_std = np.std(X_tr, axis=0, dtype=np.float64)
    safe_std = np.where(col_std < 1e-8, 1.0, col_std)
    return col_mean.astype(np.float32), safe_std.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Return ``(X_in - mu) / sig`` as float32, broadcasting the column stats over rows."""
    centered = np.subtract(X_in, mu)
    scaled = np.divide(centered, sig)
    return scaled.astype(np.float32)

# ----------------------------
# 6) Metrics helpers
# ----------------------------
def safe_auc(y_true, p):
    """ROC-AUC as a float, or None when ``y_true`` holds fewer than two distinct classes."""
    labels = np.asarray(y_true)
    if np.unique(labels).size >= 2:
        return float(roc_auc_score(labels, p))
    return None

def safe_logloss(y_true, p):
    """Binary log-loss with probabilities clipped to [1e-6, 1 - 1e-6].

    ``labels=[0, 1]`` is passed explicitly, so the call also works when
    ``y_true`` contains only one class.
    """
    clipped = np.clip(np.asarray(p, dtype=np.float64), 1e-6, 1 - 1e-6)
    return float(log_loss(y_true, clipped, labels=[0, 1]))

def fbeta_np(y_true, yhat, beta=1.0):
    """F-beta score for 0/1 labels and 0/1 predictions.

    Only values equal to 1 count as positive and only values equal to 0 count
    as negative. Returns 0.0 when there is no true positive. The denominator is
    floored at 1e-12.
    """
    truth = np.asarray(y_true).astype(int)
    guess = np.asarray(yhat).astype(int)
    true_pos, true_neg = truth == 1, truth == 0
    pred_pos, pred_neg = guess == 1, guess == 0
    tp = int(np.count_nonzero(true_pos & pred_pos))
    if not tp:
        return 0.0
    fp = int(np.count_nonzero(true_neg & pred_pos))
    fn = int(np.count_nonzero(true_pos & pred_neg))
    b2 = beta * beta
    return float((1 + b2) * tp / max(1e-12, (1 + b2) * tp + b2 * fn + fp))

def find_best_threshold(y_true, p, n_grid=801, objective="f0.5"):
    """
    Grid-search the decision threshold that maximizes an F-beta objective.

    Thresholds are ``np.linspace(0, 1, n_grid)``. A sample is predicted positive
    when ``p >= thr``; NaN scores are never positive. The lowest threshold that
    reaches the best objective value wins.

    Args:
        y_true: 0/1 labels.
        p: Scores/probabilities aligned with ``y_true``.
        n_grid: Number of grid thresholds.
        objective: ``"f1"``, ``"f0.5"`` or ``"f2"``. Matching is case-insensitive
            (``"F1"`` used to silently fall back to F0.5). Any other value means
            F0.5, as before.

    Returns:
        dict with ``thr``, ``score`` (objective value), ``f1``, ``f05``, ``f2``,
        ``precision``, ``recall`` at the chosen threshold, and ``objective``
        (the value passed in). With an empty grid the defaults ``thr=0.5`` and
        ``score=-1.0`` are returned.
    """
    y_true = np.asarray(y_true).astype(int)
    p = np.asarray(p).astype(np.float64)

    obj_key = str(objective).strip().lower()
    beta = {"f1": 1.0, "f2": 2.0}.get(obj_key, 0.5)

    best = {"thr": 0.5, "score": -1.0, "f1": 0.0, "f05": 0.0, "f2": 0.0, "precision": 0.0, "recall": 0.0}
    thrs = np.linspace(0.0, 1.0, int(n_grid))
    if thrs.size == 0:
        best["objective"] = str(objective)
        return best

    # Confusion counts for all thresholds at once, instead of one pass (plus
    # several sklearn calls) per threshold. For sorted scores s,
    # #{s >= thr} = len(s) - searchsorted(s, thr, side="left").
    # NaN scores are dropped from the sorted arrays (`NaN >= thr` is False),
    # but NaN positives still count as false negatives via n_pos.
    pos = p[y_true == 1]
    neg = p[y_true == 0]
    n_pos = int(pos.size)
    pos = np.sort(pos[~np.isnan(pos)])
    neg = np.sort(neg[~np.isnan(neg)])
    tp = pos.size - np.searchsorted(pos, thrs, side="left")
    fp = neg.size - np.searchsorted(neg, thrs, side="left")
    fn = n_pos - tp

    def _fbeta(t, f_p, f_n, b):
        # Same formula as fbeta_np; 0 when there is no true positive.
        b2 = b * b
        den = np.maximum(1e-12, (1 + b2) * t + b2 * f_n + f_p)
        return np.where(t > 0, (1 + b2) * t / den, 0.0)

    scores = _fbeta(tp, fp, fn, beta)
    i = int(np.argmax(scores))  # first (lowest) threshold with the maximal score
    t_i, fp_i, fn_i = int(tp[i]), int(fp[i]), int(fn[i])

    best["thr"] = float(thrs[i])
    best["score"] = float(scores[i])
    best["f1"] = float(_fbeta(t_i, fp_i, fn_i, 1.0))
    best["f05"] = float(_fbeta(t_i, fp_i, fn_i, 0.5))
    best["f2"] = float(_fbeta(t_i, fp_i, fn_i, 2.0))
    best["precision"] = t_i / (t_i + fp_i) if (t_i + fp_i) > 0 else 0.0
    best["recall"] = t_i / (t_i + fn_i) if (t_i + fn_i) > 0 else 0.0
    best["objective"] = str(objective)
    return best

# ----------------------------
# 7) Sinkhorn projection
# ----------------------------
def sinkhorn_doubly_stochastic(logits: torch.Tensor, tmax: int = 20, eps: float = 1e-6):
    """Approximately map ``exp(logits)`` to a doubly-stochastic matrix (Sinkhorn-Knopp).

    The global max is subtracted first for numerical stability; the shift
    cancels in the normalization, up to the ``eps`` terms. Each of the ``tmax``
    rounds normalizes the last dim (rows) and then the second-to-last dim
    (columns). ``eps`` guards every division.
    """
    M = torch.exp(logits - logits.max())
    rounds_left = int(tmax)
    while rounds_left > 0:
        M = M / (M.sum(dim=-1, keepdim=True) + eps)
        M = M / (M.sum(dim=-2, keepdim=True) + eps)
        rounds_left -= 1
    return M

# ----------------------------
# 8) RMSNorm
# ----------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim with a learnable per-feature gain.

    ``y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight``. There is no
    mean-centering and no bias.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        mean_sq = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

# ----------------------------
# 9) EMA helper
# ----------------------------
class EMA:
    """Exponential moving average ("shadow") of a model's trainable parameters.

    ``update`` blends ``shadow = decay * shadow + (1 - decay) * param``.
    ``apply_shadow`` copies the averages into the model after backing up the
    current values, and ``restore`` copies the backup back. Only parameters
    with ``requires_grad=True`` at call time are touched; buffers are not tracked.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    @staticmethod
    def _trainable(model):
        # (name, param) pairs for parameters that currently require grad
        return ((name, param) for name, param in model.named_parameters() if param.requires_grad)

    @torch.no_grad()
    def update(self, model: nn.Module):
        keep = self.decay
        for name, param in self._trainable(model):
            self.shadow[name].mul_(keep).add_(param.detach(), alpha=(1.0 - keep))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, param in self._trainable(model):
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, param in self._trainable(model):
            param.copy_(self.backup[name])
        self.backup = {}

# ----------------------------
# 10) Model blocks (feature token dropout)
# ----------------------------
class MHCAttnBlock(nn.Module):
    """One pre-norm transformer layer over several residual "streams", plus learned stream mixing.

    Only stream 0 goes through the attention and feed-forward sub-layers:
    RMSNorm -> self-attention -> residual, then RMSNorm -> FFN -> residual.
    Afterwards all streams are mixed:

        Hres = (1 - alpha) * I + alpha * H,    out_i = sum_j Hres[i, j] * stream_j

    with H = sinkhorn_doubly_stochastic(h_logits) and alpha = sigmoid(alpha_logit).

    Args:
        d_model: Token width D. nn.MultiheadAttention requires d_model % n_heads == 0.
        n_heads: Number of attention heads.
        ffn_mult: FFN hidden width is ffn_mult * d_model.
        dropout: Dropout after the attention/FFN outputs and inside the FFN.
        attn_dropout: Dropout on attention weights.
        n_streams: Number of parallel streams mixed by Hres.
        sinkhorn_tmax: Sinkhorn iterations used to build H.
        alpha_init: Initial mixing weight alpha, clamped to [1e-4, 1 - 1e-4].
    """

    def __init__(self, d_model, n_heads, ffn_mult, dropout, attn_dropout,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01):
        super().__init__()
        self.n_streams = int(n_streams)
        self.sinkhorn_tmax = int(sinkhorn_tmax)

        # attention sub-layer (pre-norm, batch_first: inputs are (B, S, D))
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=attn_dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # feed-forward sub-layer (pre-norm): D -> ffn_mult*D -> D
        self.norm2 = RMSNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_mult * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * d_model, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

        # Stream-mixing logits. All zeros, so H initially has every entry ~1/n_streams.
        self.h_logits = nn.Parameter(torch.zeros(self.n_streams, self.n_streams))
        nn.init.zeros_(self.h_logits)  # redundant with the zeros above

        # Store logit(a0) so that sigmoid(alpha_logit) == a0 at initialization.
        a0 = float(alpha_init)
        a0 = min(max(a0, 1e-4), 1 - 1e-4)
        self.alpha_logit = nn.Parameter(torch.log(torch.tensor(a0 / (1 - a0), dtype=torch.float32)))

    def forward(self, streams):
        # streams: (B, n_streams, S, D)
        # Returns the same shape: stream 0 updated by the layer, then all streams mixed by Hres.
        B, nS, S, D = streams.shape  # unpacked but unused (effectively a 4-D check)

        x = streams[:, 0]  # (B,S,D) -- only stream 0 is transformed

        x0 = x
        q = self.norm1(x)
        attn_out, _ = self.attn(q, q, q, need_weights=False)
        x = x0 + self.drop1(attn_out)

        x1 = x
        h = self.norm2(x)
        h = self.ffn(h)
        x = x1 + self.drop2(h)

        # minimal clone (avoid in-place on view)
        # -> writing stream 0 below does not modify the caller's tensor
        streams = streams.clone()
        streams[:, 0] = x

        H = sinkhorn_doubly_stochastic(self.h_logits, tmax=self.sinkhorn_tmax)
        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype)
        I = torch.eye(self.n_streams, device=streams.device, dtype=streams.dtype)
        Hres = (1.0 - alpha) * I + alpha * H.to(dtype=streams.dtype)

        # out[b, i] = sum_j Hres[i, j] * streams[b, j]
        mixed = torch.einsum("ij,bjtd->bitd", Hres, streams)
        return mixed

class MHCFTTransformer(nn.Module):
    """Binary classifier over a flat feature vector: one token per feature plus a CLS token.

    Each scalar feature f becomes a D-dim token ``x_f * w[f] + b[f] + feat_emb[f]``.
    A learnable CLS token is prepended, and the sequence is copied into
    ``n_streams`` streams and passed through ``n_layers`` MHCAttnBlocks. The CLS
    token of stream 0 is then RMS-normalized and mapped by a small MLP head to
    one logit per row.

    Args:
        n_features: Number of input features F.
        d_model: Token width D.
        n_heads, n_layers, ffn_mult, dropout, attn_dropout: Per-layer settings
            forwarded to each MHCAttnBlock; ``dropout`` is also applied to the
            input sequence and inside the head.
        n_streams, sinkhorn_tmax, alpha_init: Stream-mixing settings for MHCAttnBlock.
        feat_token_drop_p: During training, probability of zeroing a whole
            feature token. The CLS token is never dropped.
    """

    def __init__(self, n_features,
                 d_model=256, n_heads=8, n_layers=6, ffn_mult=4,
                 dropout=0.12, attn_dropout=0.08,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01,
                 feat_token_drop_p=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.feat_token_drop_p = float(feat_token_drop_p)

        # per-feature scale (w), bias (b) and identity embedding (feat_emb), each (F, D)
        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            MHCAttnBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout,
                n_streams=n_streams,
                sinkhorn_tmax=sinkhorn_tmax,
                alpha_init=alpha_init
            )
            for _ in range(int(n_layers))
        ])

        self.out_norm = RMSNorm(self.d_model)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_model, 1),
        )

        self.n_streams = int(n_streams)

    def forward(self, x):
        # x: (B,F)  ->  returns logits of shape (B,)
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)   # (B,F,D)
        tok = tok + self.feat_emb.unsqueeze(0)

        # feature-token dropout (do not drop CLS)
        # NOTE(review): the unpacking below rebinds F locally, shadowing the
        # module-level torch.nn.functional alias inside forward(); harmless
        # here because F is not used afterwards in this method.
        if self.training and self.feat_token_drop_p > 0:
            B, F, D = tok.shape
            keep = (torch.rand(B, F, device=tok.device) > self.feat_token_drop_p).to(tok.dtype)
            tok = tok * keep.unsqueeze(-1)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)                                    # (B,1,D)
        seq = torch.cat([cls, tok], dim=1)                                  # (B,1+F,D)
        seq = self.in_drop(seq)

        # every stream starts as an identical copy of the token sequence
        streams = seq.unsqueeze(1).repeat(1, self.n_streams, 1, 1)          # (B,nS,S,D)
        for blk in self.blocks:
            streams = blk(streams)

        z = streams[:, 0, 0]  # CLS token of stream 0: (B, D)
        z = self.out_norm(z)
        logit = self.head(z).squeeze(-1)
        return logit

# ----------------------------
# 11) LR Scheduler
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps: int, warmup_steps: int,
                              milestones_frac=(0.8, 0.9), decay_values=(0.316, 0.1)):
    """Build a LambdaLR with linear warm-up followed by a two-step constant decay.

    The LR multiplier for scheduler step ``s`` is:

    * ``(s + 1) / max(1, warmup_steps)`` while ``s < warmup_steps``;
    * 1.0 while ``s < int(milestones_frac[0] * total_steps)``;
    * ``decay_values[0]`` while ``s < int(milestones_frac[1] * total_steps)``;
    * ``decay_values[1]`` afterwards.

    The branches are checked in that order, so warm-up wins over the milestones.
    """
    first_cut, second_cut = [int(float(milestones_frac[i]) * total_steps) for i in (0, 1)]
    first_mult, second_mult = [float(decay_values[i]) for i in (0, 1)]
    warm_div = float(max(1, warmup_steps))

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / warm_div
        if step < first_cut:
            return 1.0
        if step < second_cut:
            return first_mult
        return second_mult

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 12) Loss: BCE / Focal-BCE
# ----------------------------
def focal_bce_with_logits(logits, targets, pos_weight=None, gamma=0.0):
    """Mean binary cross-entropy on logits, optionally focal-weighted.

    When ``gamma`` is truthy and > 0, each element's BCE (which already
    includes ``pos_weight``) is multiplied by ``(1 - p_t) ** gamma``, where
    ``p_t = p * t + (1 - p) * (1 - t)`` and ``p = sigmoid(logits)``. Otherwise
    the plain mean BCE is returned.
    """
    per_elem = F.binary_cross_entropy_with_logits(
        logits, targets, reduction="none", pos_weight=pos_weight
    )
    if not (gamma and gamma > 0):
        return per_elem.mean()
    prob = torch.sigmoid(logits)
    p_t = prob * targets + (1 - prob) * (1 - targets)
    focus = (1.0 - p_t).clamp_min(0.0).pow(gamma)
    return (per_elem * focus).mean()

# ----------------------------
# 13) Predict helper (optionally with EMA)
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None):
    """Return sigmoid probabilities for every sample in ``loader`` as a float32 array.

    Puts ``model`` in eval mode and leaves it there. If ``ema`` is given, its
    shadow weights are swapped in for the pass and the original weights are
    restored afterwards. This now also happens when inference raises, so a
    failed prediction can no longer leave the model holding EMA weights.

    Batches may be bare tensors or ``(x, ...)`` lists/tuples; only ``x`` is
    used. An empty loader returns an empty float32 array instead of failing in
    ``np.concatenate``. Uses the module-level ``device`` and ``use_amp``.
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)
    try:
        probs = []
        for batch in loader:
            xb = batch[0] if isinstance(batch, (list, tuple)) else batch
            xb = xb.to(device, non_blocking=True)
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast;
            # with use_amp False (CPU) it is a no-op.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                logits = model(xb)
                p = torch.sigmoid(logits)
            probs.append(p.detach().float().cpu().numpy())
    finally:
        if ema is not None:
            ema.restore(model)
    if not probs:
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).astype(np.float32)

# ----------------------------
# 14) Calibration (fold-wise; leakage-safe)
#   - isotonic: store (x,y) knots -> apply via np.interp
#   - sigmoid: store (a,b) for calibrated_logit = a*logit(p)+b
# ----------------------------
def _fit_isotonic(p, y, min_samples=200):
    from sklearn.isotonic import IsotonicRegression
    p = np.asarray(p, dtype=np.float64)
    y = np.asarray(y, dtype=np.int32)
    ok = np.isfinite(p)
    p = p[ok]; y = y[ok]
    if len(p) < int(min_samples) or len(np.unique(y)) < 2:
        return None
    iso = IsotonicRegression(y_min=0.0, y_max=1.0, increasing=True, out_of_bounds="clip")
    iso.fit(p, y)
    # store knots for portable inference
    return {"type": "isotonic", "x": iso.X_thresholds_.astype(np.float32).tolist(),
            "y": iso.y_thresholds_.astype(np.float32).tolist()}

def _apply_isotonic(cal, p):
    x = np.asarray(cal["x"], dtype=np.float32)
    yk = np.asarray(cal["y"], dtype=np.float32)
    p = np.asarray(p, dtype=np.float32)
    p = np.clip(p, 0.0, 1.0)
    return np.interp(p, x, yk).astype(np.float32)

def _fit_sigmoid(p, y, min_samples=200):
    from sklearn.linear_model import LogisticRegression
    p = np.asarray(p, dtype=np.float64)
    y = np.asarray(y, dtype=np.int32)
    ok = np.isfinite(p)
    p = p[ok]; y = y[ok]
    if len(p) < int(min_samples) or len(np.unique(y)) < 2:
        return None
    p = np.clip(p, 1e-6, 1-1e-6)
    logit = np.log(p / (1 - p)).reshape(-1, 1)
    lr = LogisticRegression(C=1000.0, solver="lbfgs", max_iter=200)
    lr.fit(logit, y)
    a = float(lr.coef_.ravel()[0])
    b = float(lr.intercept_.ravel()[0])
    return {"type": "sigmoid", "a": a, "b": b}

def _apply_sigmoid(cal, p):
    p = np.asarray(p, dtype=np.float64)
    p = np.clip(p, 1e-6, 1-1e-6)
    logit = np.log(p / (1 - p))
    z = cal["a"] * logit + cal["b"]
    out = 1.0 / (1.0 + np.exp(-z))
    return out.astype(np.float32)

def fit_calibrator(calib_type, p_tr, y_tr, min_samples=200):
    """Fit a calibrator chosen by name and return its dict, or None.

    ``None`` and the names "none"/"off"/"false" (case-insensitive) disable
    calibration. "isotonic" uses ``_fit_isotonic``; "sigmoid"/"platt" use
    ``_fit_sigmoid``. Unknown names return None, and so do fits that the
    helpers reject (too few samples or a single class).
    """
    if calib_type is None:
        return None
    kind = str(calib_type).lower()
    if kind in ("none", "off", "false"):
        return None
    if kind == "isotonic":
        return _fit_isotonic(p_tr, y_tr, min_samples=min_samples)
    if kind in ("sigmoid", "platt"):
        return _fit_sigmoid(p_tr, y_tr, min_samples=min_samples)
    return None

def apply_calibrator(cal, p):
    """Apply a calibrator dict produced by ``fit_calibrator``.

    ``None`` or an unrecognized ``"type"`` passes ``p`` through unchanged, as a
    float32 array.
    """
    if cal is not None:
        kind = cal["type"]
        if kind == "isotonic":
            return _apply_isotonic(cal, p)
        if kind == "sigmoid":
            return _apply_sigmoid(cal, p)
    return np.asarray(p, dtype=np.float32)

# ----------------------------
# 15) Train one fold (EMA + focal + noise + best_state + fold calibration)
# ----------------------------
def train_one_fold(X_tr, y_tr, X_va, y_va, cfg):
    """Train MHCFTTransformer on one CV fold and return its artifacts.

    Features are standardised with statistics from the training split only.
    Training uses `focal_bce_with_logits` with a positive-class weight, and
    optionally: input noise, label smoothing, gradient accumulation, gradient
    clipping, AMP (`use_amp`) and an EMA of the weights.

    Each epoch is scored on the validation split (with EMA weights when EMA is
    enabled). The epoch with the best validation log-loss is kept; improvement
    must exceed `early_stop_min_delta`, with patience `early_stop_patience`.
    Those weights are restored before the final predictions.

    Args:
        X_tr, y_tr: training features and binary labels.
        X_va, y_va: validation features and binary labels.
        cfg: hyper-parameter dict (keys are read below).

    Returns:
        (pack, p_va_raw, p_va_cal, best):
        - pack: CPU state dict, standardiser (`mu`, `sig`), cfg, best-epoch
          info and the calibrator dict (or None).
        - p_va_raw, p_va_cal: validation probabilities before and after
          calibration.
        - best: dict with "val_logloss", "val_auc" and 0-based "epoch".
    """
    # Standardise with training-split statistics only (no validation leakage).
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    # Loader config: 2 worker processes only when >= 4 CPUs are available.
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    dl_tr = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    dl_va = DataLoader(ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    # For calibration: train predictions in dataset order (shuffle=False),
    # so they line up with y_tr.
    dl_tr_eval = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=False,
                            num_workers=nw, pin_memory=pin, drop_last=False,
                            persistent_workers=(nw > 0))

    model = MHCFTTransformer(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        alpha_init=float(cfg["alpha_init"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    def _cpu_state_copy():
        """Return a detached CPU copy of the model's current state dict.

        `.to("cpu", copy=True)` always allocates new storage. A bare `.cpu()`
        is a no-op for tensors already on CPU, so on a CPU run the snapshot
        would share storage with the live parameters. It would then follow
        every later in-place update (optimizer steps, the EMA restore), and
        the "best" weights would silently become the last ones.
        """
        return {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

    # Class-imbalance weight for positives: #neg / #pos (guarded against 0 positives).
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    # The LR schedule counts optimizer steps (not micro-batches), so account for accumulation.
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    optim_steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_optim_steps = int(cfg["epochs"]) * max(1, optim_steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_optim_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=warmup_steps,
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    best = {"val_logloss": 1e9, "val_auc": None, "epoch": -1}
    best_state = None
    bad = 0

    input_noise_std = float(cfg.get("input_noise_std", 0.0))
    label_smooth = float(cfg.get("label_smoothing", 0.0))
    focal_gamma = float(cfg.get("focal_gamma", 0.0))

    for epoch in range(int(cfg["epochs"])):
        model.train()
        t0 = time.time()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)
        micro_step = 0
        optim_step_in_epoch = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            if label_smooth and label_smooth > 0:
                # Symmetric smoothing toward 0.5.
                yb = yb * (1.0 - label_smooth) + 0.5 * label_smooth

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, pos_weight=pos_weight, gamma=focal_gamma)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro_step += 1

            # Undo the accumulation scaling so the logged loss is a per-sample mean.
            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro_step % accum_steps) == 0:
                if cfg.get("grad_clip", 0) and float(cfg["grad_clip"]) > 0:
                    # Clip true (unscaled) gradients.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))

                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                optim_step_in_epoch += 1

                if ema is not None:
                    ema.update(model)

        # Flush the last partial accumulation window of the epoch.
        if (micro_step % accum_steps) != 0:
            if cfg.get("grad_clip", 0) and float(cfg["grad_clip"]) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            optim_step_in_epoch += 1
            if ema is not None:
                ema.update(model)

        # Validate (with EMA weights when EMA is enabled).
        p_va_raw = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va_raw)
        vauc = safe_auc(y_va, p_va_raw)

        tr_loss = loss_sum / max(1, n_sum)
        dt = time.time() - t0
        print(f"  epoch {epoch+1:03d}/{cfg['epochs']} | train_loss={tr_loss:.5f} | val_logloss={vll:.5f} | val_auc={vauc} | opt_steps={optim_step_in_epoch} | dt={dt:.1f}s")

        improved = (best["val_logloss"] - vll) > float(cfg["early_stop_min_delta"])
        if improved:
            best["val_logloss"] = float(vll)
            best["val_auc"] = vauc
            best["epoch"] = int(epoch)

            # Save the best weights. They are the EMA weights when EMA is on,
            # because validation used them.
            if ema is not None:
                ema.apply_shadow(model)
                best_state = _cpu_state_copy()
                ema.restore(model)
            else:
                best_state = _cpu_state_copy()
            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["early_stop_patience"]):
                print(f"  early stop at epoch {epoch+1}, best_epoch={best['epoch']+1}, best_val_logloss={best['val_logloss']:.5f}")
                break

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # Load best_state (already EMA weights if EMA was used).
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    # Raw predictions on train & val with the best weights.
    p_tr_raw = predict_proba(model, dl_tr_eval, ema=None)
    p_va_raw = predict_proba(model, dl_va, ema=None)

    # Fit the fold calibrator on TRAIN predictions only, so the validation fold
    # stays untouched.
    # NOTE(review): these are in-sample predictions, which are usually more
    # confident than out-of-sample ones, so the calibrator may be biased.
    # A held-out slice of the training fold would be more faithful.
    cal = None
    p_va_cal = p_va_raw
    if bool(cfg.get("use_calibration", True)):
        cal_type = cfg.get("calibration", None)
        cal = fit_calibrator(cal_type, p_tr_raw, y_tr, min_samples=int(cfg.get("calib_min_samples", 200)))
        if cal is not None:
            p_va_cal = apply_calibrator(cal, p_va_raw)

    pack = {
        "state_dict": _cpu_state_copy(),
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
        "best": best,
        "calibrator": cal,  # portable dict or None
    }
    return pack, p_va_raw, p_va_cal, best

# ----------------------------
# 16) Multi-seed CV loop
# ----------------------------
# auto N_SEEDS: choose how many CV seeds to run. The N_SEEDS env var wins if set;
# otherwise use 2 seeds on a CUDA device with >= 30 GiB, else 1.
try:
    # Total memory of CUDA device 0 in GiB; 0 on CPU or if the query fails.
    mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3) if device.type == "cuda" else 0
except Exception:
    mem_gb = 0

N_SEEDS_ENV = os.environ.get("N_SEEDS", "").strip()
if N_SEEDS_ENV:
    N_SEEDS = int(N_SEEDS_ENV)
else:
    N_SEEDS = 2 if (device.type == "cuda" and mem_gb >= 30) else 1

# Seeds are CFG["seed"], CFG["seed"] + 17, ...; always at least one seed.
SEEDS = [int(CFG["seed"]) + i * 17 for i in range(max(1, N_SEEDS))]
print("\nSEED plan:", SEEDS)

# All CV / full-model artifacts of this stage are written under this directory.
out_dir = Path("/kaggle/working/recodai_luc_gate_artifacts")
out_dir.mkdir(parents=True, exist_ok=True)

# Per-seed accumulators, filled by the seed loop and averaged by the seed-ensemble step.
all_seed_reports = []
all_seed_oof_raw = []
all_seed_oof_cal = []
best_epochs_all = []

# For each seed: run K-fold CV, collect out-of-fold (OOF) predictions (raw and
# calibrated), save every fold model, and report per-fold and OOF metrics.
for seed_i, seed in enumerate(SEEDS):
    print(f"\n==============================")
    print(f"== SEED {seed} ({seed_i+1}/{len(SEEDS)})")
    print(f"==============================")
    seed_everything(seed)

    # Each row is filled by the fold where it is held out. Rows whose fold id
    # is not in `unique_folds` would stay 0. Assumed not to happen; TODO confirm.
    oof_raw = np.zeros(n, dtype=np.float32)
    oof_cal = np.zeros(n, dtype=np.float32)

    fold_reports = []
    best_epochs = []

    models_dir = out_dir / f"mhc_transformer_folds_seed_{seed}"
    models_dir.mkdir(parents=True, exist_ok=True)

    for f in unique_folds:
        print(f"\n[Seed {seed} | Fold {f}]")
        tr_idx = np.where(folds != f)[0]
        va_idx = np.where(folds == f)[0]

        X_tr, y_tr = X[tr_idx], y[tr_idx]
        X_va, y_va = X[va_idx], y[va_idx]

        pack, p_va_raw, p_va_cal, best = train_one_fold(X_tr, y_tr, X_va, y_va, CFG)

        oof_raw[va_idx] = p_va_raw
        oof_cal[va_idx] = p_va_cal
        best_epochs.append(int(best["epoch"] + 1))  # stored 1-based

        # metrics raw
        auc_raw = safe_auc(y_va, p_va_raw)
        ll_raw  = safe_logloss(y_va, p_va_raw)

        # metrics calibrated
        auc_cal = safe_auc(y_va, p_va_cal)
        ll_cal  = safe_logloss(y_va, p_va_cal)

        # Best-threshold search on the calibrated predictions (recommended).
        # NOTE: the threshold is tuned on this same validation fold, so the
        # per-fold "@thr" metrics below are optimistic.
        if bool(CFG.get("search_best_thr", True)):
            bt = find_best_threshold(y_va, p_va_cal, n_grid=int(CFG.get("thr_grid", 801)),
                                     objective=str(CFG.get("thr_objective", "f0.5")).lower())
            thr_use = float(bt["thr"])
        else:
            thr_use = float(CFG.get("report_thr", 0.5))
            bt = None

        yhat = (p_va_cal >= thr_use).astype(np.int32)

        rep = {
            "seed": int(seed),
            "fold": int(f),
            "n_val": int(len(va_idx)),
            "pos_val": int(y_va.sum()),
            "auc_raw": auc_raw,
            "logloss_raw": ll_raw,
            "auc_cal": auc_cal,
            "logloss_cal": ll_cal,
            "thr_used": thr_use,
            "f1@thr": float(f1_score(y_va, yhat, zero_division=0)),
            "f0.5@thr": float(fbeta_np(y_va, yhat, beta=0.5)),
            "f2@thr": float(fbeta_np(y_va, yhat, beta=2.0)),
            "precision@thr": float(precision_score(y_va, yhat, zero_division=0)),
            "recall@thr": float(recall_score(y_va, yhat, zero_division=0)),
            "best_val_logloss": float(best["val_logloss"]),
            "best_val_auc": best["val_auc"],
            "best_epoch": int(best["epoch"] + 1),
            "thr_search": bt,
            "used_calibration": bool(pack.get("calibrator") is not None),
            "calibration_type": (pack["calibrator"]["type"] if pack.get("calibrator") else None),
        }
        fold_reports.append(rep)

        # One checkpoint per fold: weights + standardiser + calibrator + feature order.
        torch.save(
            {"pack": pack, "feature_cols": FEATURE_COLS, "seed": int(seed), "fold": int(f)},
            models_dir / f"mhc_transformer_fold_{f}.pt"
        )

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # overall metrics for seed
    oof_auc_raw = safe_auc(y, oof_raw)
    oof_ll_raw  = safe_logloss(y, oof_raw)
    oof_auc_cal = safe_auc(y, oof_cal)
    oof_ll_cal  = safe_logloss(y, oof_cal)

    # Threshold search over all OOF (calibrated) predictions of this seed.
    bt_oof = None
    best_thr = None
    if bool(CFG.get("search_best_thr", True)):
        bt_oof = find_best_threshold(y, oof_cal, n_grid=int(CFG.get("thr_grid", 801)),
                                     objective=str(CFG.get("thr_objective", "f0.5")).lower())
        best_thr = float(bt_oof["thr"])

    # fixed thr baseline
    thr_fixed = float(CFG.get("report_thr", 0.5))
    yhat_fixed = (oof_cal >= thr_fixed).astype(np.int32)

    overall = {
        "seed": int(seed),
        "rows": int(n),
        "folds": int(n_folds),
        "pos_total": int(y.sum()),
        "pos_rate": float(y.mean()),
        "oof_auc_raw": oof_auc_raw,
        "oof_logloss_raw": oof_ll_raw,
        "oof_auc_cal": oof_auc_cal,
        "oof_logloss_cal": oof_ll_cal,
        f"oof_f1@{thr_fixed}": float(f1_score(y, yhat_fixed, zero_division=0)),
        f"oof_f0.5@{thr_fixed}": float(fbeta_np(y, yhat_fixed, beta=0.5)),
        f"oof_precision@{thr_fixed}": float(precision_score(y, yhat_fixed, zero_division=0)),
        f"oof_recall@{thr_fixed}": float(recall_score(y, yhat_fixed, zero_division=0)),
        "oof_best_thr_cal": best_thr,
        "oof_best_thr_detail": bt_oof,
        "best_epochs": best_epochs,
    }

    df_rep = pd.DataFrame(fold_reports).sort_values(["seed", "fold"]).reset_index(drop=True)
    print("\nPer-fold report:")
    display(df_rep)
    print("\nOOF overall (seed):")
    print(overall)

    # Save the per-seed report.
    # NOTE: `f` here rebinds the fold loop variable. Harmless, because the fold
    # loop has already finished, but easy to trip over when editing.
    with open(out_dir / f"mhc_transformer_cv_report_seed_{seed}.json", "w") as f:
        json.dump({"cfg_name": CFG_NAME, "cfg": CFG, "fold_reports": fold_reports, "overall": overall}, f, indent=2)

    all_seed_reports.append({"seed": seed, "fold_reports": fold_reports, "overall": overall})
    all_seed_oof_raw.append(oof_raw)
    all_seed_oof_cal.append(oof_cal)
    best_epochs_all.append(best_epochs)

# ----------------------------
# 17) Seed-ensemble OOF (avg)
# ----------------------------
# Per-row arithmetic mean of the OOF predictions across seeds.
OOF_PRED_MHC_RAW = np.mean(np.stack(all_seed_oof_raw, axis=0), axis=0).astype(np.float32)
OOF_PRED_MHC_CAL = np.mean(np.stack(all_seed_oof_cal, axis=0), axis=0).astype(np.float32)

ens_auc_raw = safe_auc(y, OOF_PRED_MHC_RAW)
ens_ll_raw  = safe_logloss(y, OOF_PRED_MHC_RAW)
ens_auc_cal = safe_auc(y, OOF_PRED_MHC_CAL)
ens_ll_cal  = safe_logloss(y, OOF_PRED_MHC_CAL)

# Baseline classification metrics at the fixed report threshold (default 0.5).
thr_fixed = float(CFG.get("report_thr", 0.5))
yhat_fixed = (OOF_PRED_MHC_CAL >= thr_fixed).astype(np.int32)

# Threshold search on the ensembled *calibrated* OOF predictions.
# train_full_fixed reads this value as the full pack's "recommended_thr".
bt_ens = None
ens_best_thr = None
if bool(CFG.get("search_best_thr", True)):
    bt_ens = find_best_threshold(y, OOF_PRED_MHC_CAL, n_grid=int(CFG.get("thr_grid", 801)),
                                 objective=str(CFG.get("thr_objective", "f0.5")).lower())
    ens_best_thr = float(bt_ens["thr"])

BASELINE_MHC_TF_OVERALL = {
    "model": "mHC-FTTransformer (tabular gate) v3.0",
    "cfg_name": CFG_NAME,
    "seeds": SEEDS,
    "feature_count": int(len(FEATURE_COLS)),
    "oof_auc_raw": ens_auc_raw,
    "oof_logloss_raw": ens_ll_raw,
    "oof_auc_cal": ens_auc_cal,
    "oof_logloss_cal": ens_ll_cal,
    f"oof_f1@{thr_fixed}": float(f1_score(y, yhat_fixed, zero_division=0)),
    f"oof_f0.5@{thr_fixed}": float(fbeta_np(y, yhat_fixed, beta=0.5)),
    f"oof_precision@{thr_fixed}": float(precision_score(y, yhat_fixed, zero_division=0)),
    f"oof_recall@{thr_fixed}": float(recall_score(y, yhat_fixed, zero_division=0)),
    "oof_best_thr_cal": ens_best_thr,
    "oof_best_thr_detail": bt_ens,
}

print("\n==============================")
print("OOF ENSEMBLE overall:")
print(BASELINE_MHC_TF_OVERALL)
print("==============================\n")

# flatten fold reports for export
BASELINE_MHC_TF_FOLD_REPORTS = []
for sr in all_seed_reports:
    BASELINE_MHC_TF_FOLD_REPORTS.extend(sr["fold_reports"])

# 1-based best epochs of every (seed, fold), flattened; used to size the full-data run.
BASELINE_MHC_TF_BEST_EPOCHS = [int(x) for xs in best_epochs_all for x in xs]

# ----------------------------
# 18) Train FULL model (epochs = median(best_epoch) * 1.15)
#    (pakai seed pertama; fold model sudah disimpan per seed)
# ----------------------------
def train_full_fixed(X_full_raw, y_full, cfg, epochs_full: int, seed: int):
    """Train the final model on all rows for a fixed number of epochs.

    Same recipe as `train_one_fold`: standardiser fitted on all rows, focal
    loss with positive weight, accumulation, clipping, AMP and optional EMA.
    There is no validation and no early stopping.

    The returned `full_pack` holds:
    - a CPU copy of the state dict (EMA weights when EMA is enabled);
    - the standardiser, cfg, epoch and seed info;
    - "recommended_thr", taken from the global `BASELINE_MHC_TF_OVERALL`.

    That threshold was tuned on the *calibrated* OOF predictions, so it does
    not apply directly to this model's raw outputs. The pack therefore also
    stores "calibrator", fitted the same way the fold calibrators are (on the
    model's predictions for its own training rows). It is None when
    calibration is disabled or the fitter declines.

    Args:
        X_full_raw, y_full: all training features and binary labels.
        cfg: hyper-parameter dict.
        epochs_full: number of epochs to train.
        seed: RNG seed passed to `seed_everything`.
    """
    seed_everything(seed)
    mu, sig = fit_standardizer(X_full_raw)
    X_full = apply_standardizer(X_full_raw, mu, sig)

    ds_full = TabDataset(X_full, y_full)
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    dl_full = DataLoader(ds_full, batch_size=int(cfg["batch_size"]), shuffle=True,
                         num_workers=nw, pin_memory=pin, drop_last=False,
                         persistent_workers=(nw > 0))

    model = MHCFTTransformer(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        alpha_init=float(cfg["alpha_init"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    # Class-imbalance weight for positives: #neg / #pos (guarded against 0 positives).
    pos = int(y_full.sum())
    neg = int(len(y_full) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    # The LR schedule counts optimizer steps (not micro-batches).
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    optim_steps_per_epoch = int(math.ceil(len(dl_full) / accum_steps))
    total_optim_steps = int(epochs_full) * max(1, optim_steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_optim_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=warmup_steps,
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    input_noise_std = float(cfg.get("input_noise_std", 0.0))
    label_smooth = float(cfg.get("label_smoothing", 0.0))
    focal_gamma = float(cfg.get("focal_gamma", 0.0))

    print(f"\nTraining FULL mHC transformer for {epochs_full} epochs | seed={seed} ...")

    for epoch in range(int(epochs_full)):
        model.train()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)
        micro_step = 0

        for xb, yb in dl_full:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            if label_smooth and label_smooth > 0:
                yb = yb * (1.0 - label_smooth) + 0.5 * label_smooth

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, pos_weight=pos_weight, gamma=focal_gamma)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro_step += 1

            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro_step % accum_steps) == 0:
                if cfg.get("grad_clip", 0) and float(cfg["grad_clip"]) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # Flush the last partial accumulation window of the epoch.
        if (micro_step % accum_steps) != 0:
            if cfg.get("grad_clip", 0) and float(cfg["grad_clip"]) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {epoch+1:03d}/{epochs_full} | loss={loss_sum/max(1,n_sum):.5f}")

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # From here until restore, the model holds the final weights
    # (the EMA shadow when EMA is used).
    used_ema = bool(ema is not None)
    if ema is not None:
        ema.apply_shadow(model)

    # `.to("cpu", copy=True)` always allocates new storage. A bare `.cpu()` is a
    # no-op on a CPU run, so the snapshot would share storage with the
    # parameters, and the in-place `ema.restore(model)` below would replace the
    # saved EMA weights with the raw training weights.
    state_dict = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

    # Calibrator fitted the same way as the fold calibrators, so the calibrated
    # OOF threshold stored below can be applied to this model's outputs.
    cal = None
    if bool(cfg.get("use_calibration", True)):
        dl_full_eval = DataLoader(ds_full, batch_size=int(cfg["batch_size"]), shuffle=False,
                                  num_workers=nw, pin_memory=pin, drop_last=False,
                                  persistent_workers=(nw > 0))
        p_full_raw = predict_proba(model, dl_full_eval, ema=None)
        cal = fit_calibrator(cfg.get("calibration", None), p_full_raw, y_full,
                             min_samples=int(cfg.get("calib_min_samples", 200)))

    full_pack = {
        "state_dict": state_dict,
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
        "epochs_full": int(epochs_full),
        "used_ema": used_ema,
        "seed": int(seed),
        # Tuned on calibrated OOF predictions: apply `calibrator` first.
        "recommended_thr": BASELINE_MHC_TF_OVERALL.get("oof_best_thr_cal", None),
        "recommended_thr_detail": BASELINE_MHC_TF_OVERALL.get("oof_best_thr_detail", None),
        "calibrator": cal,  # portable dict or None
    }

    if ema is not None:
        ema.restore(model)

    return full_pack

# Decide epochs_full from all folds and all seeds: median best epoch * 1.15
# (the full run sees ~1/K more data than a fold).
# int(np.median(...)) truncates; e.g. a median of 23.5 becomes 23.
flat_best_epochs = np.array(BASELINE_MHC_TF_BEST_EPOCHS, dtype=np.int32)
med_best = int(np.median(flat_best_epochs)) if len(flat_best_epochs) else int(max(12, CFG["epochs"] * 0.7))
epochs_full = int(max(12, round(med_best * 1.15)))
# The cap is applied last, so it also overrides the 12-epoch floor when CFG["epochs"] < 12.
epochs_full = int(min(epochs_full, int(CFG["epochs"])))  # safety

# The full model uses only the first seed.
full_pack = train_full_fixed(X, y, CFG, epochs_full=epochs_full, seed=int(SEEDS[0]))

FULL_PACK_PATH = str(out_dir / "mhc_transformer_model_full.pt")
torch.save({"pack": full_pack, "feature_cols": FEATURE_COLS}, FULL_PACK_PATH)

# ----------------------------
# 19) Save OOF + report
# ----------------------------
# Assumes df_train_tabular rows are in the same order as the X/y arrays used
# for the OOF vectors. TODO confirm against the table-building step.
df_oof = df_train_tabular[["uid", "case_id", "variant", "fold", "y"]].copy()
df_oof["oof_pred_mhc_raw"] = OOF_PRED_MHC_RAW
df_oof["oof_pred_mhc_cal"] = OOF_PRED_MHC_CAL
df_oof.to_csv(out_dir / "oof_mhc_transformer.csv", index=False)

np.save(out_dir / "oof_pred_mhc_raw.npy", OOF_PRED_MHC_RAW)
np.save(out_dir / "oof_pred_mhc_cal.npy", OOF_PRED_MHC_CAL)

report = {
    "model": "mHC-FTTransformer (numeric tabular gate) v3.0",
    "cfg_name": CFG_NAME,
    "cfg": CFG,
    "feature_count": int(len(FEATURE_COLS)),
    "seeds": SEEDS,
    "seed_reports": all_seed_reports,
    "ensemble_overall": BASELINE_MHC_TF_OVERALL,
    "epochs_full": int(epochs_full),
    "full_pack_path": FULL_PACK_PATH,
}
with open(out_dir / "mhc_transformer_cv_report.json", "w") as f:
    json.dump(report, f, indent=2)

print("\nSaved artifacts:")
print("  fold models  ->", out_dir, "(mhc_transformer_folds_seed_*/)")
print("  full model   ->", FULL_PACK_PATH)
print("  oof preds    ->", out_dir / "oof_mhc_transformer.csv")
print("  oof npy      ->", out_dir / "oof_pred_mhc_raw.npy", "and", out_dir / "oof_pred_mhc_cal.npy")
print("  cv report    ->", out_dir / "mhc_transformer_cv_report.json")

# Export globals for later notebook cells. These names are already module
# globals; the update just makes the exported set explicit.
globals().update({
    "OOF_PRED_MHC_RAW": OOF_PRED_MHC_RAW,
    "OOF_PRED_MHC_CAL": OOF_PRED_MHC_CAL,
    "BASELINE_MHC_TF_OVERALL": BASELINE_MHC_TF_OVERALL,
    "BASELINE_MHC_TF_FOLD_REPORTS": BASELINE_MHC_TF_FOLD_REPORTS,
    "BASELINE_MHC_TF_BEST_EPOCHS": BASELINE_MHC_TF_BEST_EPOCHS,
    "FULL_PACK_PATH": FULL_PACK_PATH,
})


Device: cuda | AMP: True | CFG: SAFE
Setup:
  rows      : 5176
  folds     : 5 | [0, 1, 2, 3, 4]
  pos%      : 54.07650695517774
  n_features: 84

SEED plan: [2025]

== SEED 2025 (1/1)

[Seed 2025 | Fold 0]
  epoch 001/60 | train_loss=0.21723 | val_logloss=0.70783 | val_auc=0.06474154975990962 | opt_steps=9 | dt=6.1s
  epoch 002/60 | train_loss=0.09139 | val_logloss=0.70457 | val_auc=0.09654458148950193 | opt_steps=9 | dt=4.5s
  epoch 003/60 | train_loss=0.03457 | val_logloss=0.69920 | val_auc=0.200711797382544 | opt_steps=9 | dt=4.5s
  epoch 004/60 | train_loss=0.01954 | val_logloss=0.69166 | val_auc=0.44811411354863007 | opt_steps=9 | dt=4.5s
  epoch 005/60 | train_loss=0.02331 | val_logloss=0.68211 | val_auc=0.7642142924395066 | opt_steps=9 | dt=4.5s
  epoch 006/60 | train_loss=0.00354 | val_logloss=0.67043 | val_auc=0.9357499293851803 | opt_steps=9 | dt=4.5s
  epoch 007/60 | train_loss=0.00282 | val_logloss=0.65666 | val_auc=0.974949628095283 | opt_steps=9 | dt=4.5s
  epoch 008/60 

Unnamed: 0,seed,fold,n_val,pos_val,auc_raw,logloss_raw,auc_cal,logloss_cal,thr_used,f1@thr,f0.5@thr,f2@thr,precision@thr,recall@thr,best_val_logloss,best_val_auc,best_epoch,thr_search,used_calibration,calibration_type
0,2025,0,1034,559,1.0,0.414488,1.0,1e-06,0.00125,1.0,1.0,1.0,1.0,1.0,0.414488,1.0,25,"{'thr': 0.00125, 'score': 1.0, 'f1': 1.0, 'f05...",True,isotonic
1,2025,1,1041,561,1.0,0.320638,1.0,0.000222,0.20625,1.0,1.0,1.0,1.0,1.0,0.320638,1.0,24,"{'thr': 0.20625000000000002, 'score': 1.0, 'f1...",True,isotonic
2,2025,2,1032,559,1.0,0.308105,1.0,1e-06,0.00125,1.0,1.0,1.0,1.0,1.0,0.308105,1.0,23,"{'thr': 0.00125, 'score': 1.0, 'f1': 1.0, 'f05...",True,isotonic
3,2025,3,1035,560,1.0,0.31403,1.0,1e-06,0.00125,1.0,1.0,1.0,1.0,1.0,0.31403,1.0,22,"{'thr': 0.00125, 'score': 1.0, 'f1': 1.0, 'f05...",True,isotonic
4,2025,4,1034,560,1.0,0.346357,1.0,9.7e-05,0.005,1.0,1.0,1.0,1.0,1.0,0.346357,1.0,23,"{'thr': 0.005, 'score': 1.0, 'f1': 1.0, 'f05':...",True,isotonic



OOF overall (seed):
{'seed': 2025, 'rows': 5176, 'folds': 5, 'pos_total': 2799, 'pos_rate': 0.5407650695517774, 'oof_auc_raw': 0.9999998496969063, 'oof_logloss_raw': 0.3407038727618474, 'oof_auc_cal': 1.0, 'oof_logloss_cal': 6.458565854921662e-05, 'oof_f1@0.5': 1.0, 'oof_f0.5@0.5': 1.0, 'oof_precision@0.5': 1.0, 'oof_recall@0.5': 1.0, 'oof_best_thr_cal': 0.20625000000000002, 'oof_best_thr_detail': {'thr': 0.20625000000000002, 'score': 1.0, 'f1': 1.0, 'f05': 1.0, 'f2': 1.0, 'precision': 1.0, 'recall': 1.0, 'objective': 'f0.5'}, 'best_epochs': [25, 24, 23, 22, 23]}

OOF ENSEMBLE overall:
{'model': 'mHC-FTTransformer (tabular gate) v3.0', 'cfg_name': 'SAFE', 'seeds': [2025], 'feature_count': 84, 'oof_auc_raw': 0.9999998496969063, 'oof_logloss_raw': 0.3407038727618474, 'oof_auc_cal': 1.0, 'oof_logloss_cal': 6.458565854921662e-05, 'oof_f1@0.5': 1.0, 'oof_f0.5@0.5': 1.0, 'oof_precision@0.5': 1.0, 'oof_recall@0.5': 1.0, 'oof_best_thr_cal': 0.20625000000000002, 'oof_best_thr_detail': {'thr': 

# Optimize Model & Hyperparameters (Iterative)

In [5]:
# ============================================================
# Step 4 — Optimize Model & Hyperparameters (Iterative) — TRANSFORMER ONLY
# REVISI FULL v3.2 (mHC-lite PDF-inspired, differentiable + EMA + accum + reg)
#
# Perbaikan v3.2 (dibanding draft kamu):
# - Stage-1 folds subset dipilih benar-benar “evenly spaced” (bukan slicing raw)
# - AMP/GradScaler aman: otomatis OFF di CPU (hindari warning/bug)
# - Scheduler benar-benar on optimizer-steps (sudah), + total_steps guard
# - Early-stop logging lebih jelas + cleanup lebih agresif
# - Save best_gate_model.pt sekarang juga menyimpan recommended_thr + best_oof_score
# - Save stage1_results.csv + bisa skip kandidat yang sudah ada (resume-safe)
#
# Primary score: OOF best F-beta (beta=0.5)
#
# Output:
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/stage1_results.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.json
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_fold_details.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/oof_preds_<cfg_name>.csv (top configs)
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_config.json
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_model.pt
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, log_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require data from Step 2
# ----------------------------
# Fail fast if the training table from Step 2 is not in this kernel.
need_vars = ["df_train_tabular", "FEATURE_COLS"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

# Work on private copies so this step does not mutate Step 2's objects.
df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds_all = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)
uids_all = df_train_tabular["uid"].astype(str).to_numpy()

# Replace NaN / +-inf feature values with 0 before any standardisation.
if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0)

# NOTE: rebinds the notebook globals `unique_folds`, `n` and `n_features`,
# which the earlier CV cell also uses.
unique_folds = sorted(pd.Series(folds_all).unique().tolist())
n = len(y_all)
pos_rate = float(y_all.mean())
n_features = X_all.shape[1]

print("Optimize setup (Transformer-only, mHC-lite):")
print(f"  rows={n} | folds={len(unique_folds)} | pos%={pos_rate*100:.2f} | n_features={n_features}")

# ----------------------------
# 1) Global settings
# ----------------------------
SEED = 2025
BETA = 0.5       # F-beta weight of the primary score; beta < 1 favours precision over recall
THR_GRID = 201   # presumably the `grid` size passed to best_fbeta_fast; verify at the call site

# 2-stage runtime controls
# Stage 1 screens candidates on a subset of at most 3 folds.
STAGE1_FOLDS = min(3, len(unique_folds))
STAGE1_EPOCH_CAP = 35   # presumably the epoch cap for stage-1 runs; confirm in the search loop
STAGE1_PAT_CAP = 6      # presumably the early-stop patience cap for stage 1; confirm

STAGE2_TOPM = 3         # presumably how many stage-1 candidates go on to stage 2; confirm
REPORT_TOPK_OOF = 3     # presumably how many top configs get oof_preds_<cfg_name>.csv; confirm

# Optional time budget (0 = off)
TIME_BUDGET_SEC = 0  # e.g. 2.5*60*60 (= 2.5 hours)

def seed_everything(seed=2025):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices).

    Also sets PYTHONHASHSEED in the environment. Note that cuDNN is left
    NON-deterministic (`deterministic=False`, `benchmark=True`): speed is
    preferred over bit-exact reproducibility on GPU.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision only on CUDA; on CPU the AMP machinery stays disabled.
use_amp = (device.type == "cuda")
print("Device:", device, "| AMP:", use_amp)

# Allow TF32 for CUDA matmuls/convolutions: faster, with slightly reduced precision.
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass  # best-effort: skip if this torch build does not support it

def get_mem_gb():
    """Return the total memory of CUDA device 0 in GiB.

    Returns 0.0 when CUDA is unavailable or the device query raises.
    """
    if torch.cuda.is_available():
        try:
            props = torch.cuda.get_device_properties(0)
            return float(props.total_memory / (1024**3))
        except Exception:
            pass
    return 0.0

# Total memory of CUDA device 0 in GiB (0.0 without CUDA).
MEM_GB = get_mem_gb()
print("GPU mem GB:", MEM_GB)

# ----------------------------
# 2) Metrics helpers (F-beta primary)
# ----------------------------
def best_fbeta_fast(y_true, p, beta=0.5, grid=201):
    """Find the threshold with the best F-beta on a fixed grid.

    The grid has `grid` evenly spaced thresholds in [0.01, 0.99]. A sample is
    predicted positive when its (clipped) score is >= the threshold. All
    thresholds are evaluated at once through an (n_samples, grid) boolean
    matrix. Ties resolve to the lowest threshold. NaN scores never count as
    positive predictions.

    Returns:
        dict with "fbeta", "thr", "precision" and "recall" at the best threshold.
    """
    labels = np.asarray(y_true).astype(np.int32) == 1
    scores = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1.0 - 1e-8)
    thresholds = np.linspace(0.01, 0.99, int(grid), dtype=np.float64)

    predicted = scores[:, None] >= thresholds[None, :]
    positives = labels[:, None]
    tp = np.count_nonzero(predicted & positives, axis=0).astype(np.float64)
    fp = np.count_nonzero(predicted & ~positives, axis=0).astype(np.float64)
    fn = labels.sum().astype(np.float64) - tp

    def _safe_ratio(num, den):
        # 0 wherever the denominator is 0 (no predictions / no positives).
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)

    beta_sq = beta * beta
    fbeta = _safe_ratio((1.0 + beta_sq) * precision * recall, beta_sq * precision + recall)

    best = int(np.argmax(fbeta))
    return {
        "fbeta": float(fbeta[best]),
        "thr": float(thresholds[best]),
        "precision": float(precision[best]),
        "recall": float(recall[best]),
    }

def safe_auc(y_true, p):
    """Return ROC-AUC as a float, or None if `y_true` has fewer than two classes."""
    has_both_classes = np.unique(y_true).size >= 2
    return float(roc_auc_score(y_true, p)) if has_both_classes else None

def safe_logloss(y_true, p):
    """Binary log-loss with probabilities clipped to [1e-8, 1 - 1e-8].

    `labels=[0, 1]` lets sklearn score inputs where only one class is present.
    """
    clipped = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1 - 1e-8)
    loss = log_loss(y_true, clipped, labels=[0, 1])
    return float(loss)

# ----------------------------
# 3) Dataset + Standardizer (no leakage)
# ----------------------------
class TabDataset(Dataset):
    """In-memory tabular dataset backed by float32 tensors.

    Indexing returns `(x_row, y)` when labels were given, otherwise just
    `x_row`. Inputs are converted with `.astype(np.float32)`, which copies.
    """

    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        if y is None:
            self.y = None
        else:
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        row = self.X[idx]
        return row if self.y is None else (row, self.y[idx])

def fit_standardizer(X_tr: np.ndarray):
    """Compute per-column mean and standard deviation for z-scoring.

    Statistics are accumulated in float64 and returned as float32. Columns
    with std < 1e-8 (near-constant) get a scale of 1.0, so they are only
    centred, never divided by ~0.
    """
    center = X_tr.mean(axis=0, dtype=np.float64)
    spread = X_tr.std(axis=0, dtype=np.float64)
    spread = np.where(spread < 1e-8, 1.0, spread)
    return center.astype(np.float32), spread.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Z-score `X_in` with precomputed statistics: (X - mu) / sig, as float32."""
    centered = X_in - mu
    return (centered / sig).astype(np.float32)

# ----------------------------
# 4) RMSNorm + Differentiable Sinkhorn + EMA
# ----------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned scale.

    y = x / sqrt(mean(x**2, dim=-1) + eps) * weight

    The statistic is computed in at least float32. For float16 inputs, e.g.
    activations produced under autocast, squaring in the input dtype overflows
    once |x| exceeds ~256 (fp16 max is 65504). The RMS then becomes inf and
    the output collapses to zeros or NaNs. The normalised tensor is cast back
    to the input dtype before the learned scale is applied, so float32 and
    float64 inputs behave exactly as before.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # promote_types keeps float64 as float64 and lifts fp16/bf16 to float32.
        work = x.to(torch.promote_types(x.dtype, torch.float32))
        rms = torch.mean(work * work, dim=-1, keepdim=True)
        normed = (work * torch.rsqrt(rms + self.eps)).to(x.dtype)
        return normed * self.weight

def sinkhorn_knopp(P, tmax=20, eps=1e-6):
    """
    Differentiable Sinkhorn-Knopp
    P: (B,n,n) non-negative

    Entries are first clamped to at least `eps`. The function then runs `tmax`
    rounds of row normalisation followed by column normalisation; every sum
    used as a divisor is also clamped to at least `eps`. The result ends on a
    column step, so its columns sum to 1 and its rows approximately so.
    Only differentiable tensor ops are used.
    """
    M = P.clamp_min(eps)
    rounds_left = int(tmax)
    while rounds_left > 0:
        row_sums = M.sum(dim=-1, keepdim=True).clamp_min(eps)
        M = M / row_sums
        col_sums = M.sum(dim=-2, keepdim=True).clamp_min(eps)
        M = M / col_sums
        rounds_left -= 1
    return M

class EMA:
    """Exponential moving average (EMA) of a model's trainable parameters.

    After each `update`, shadow <- decay * shadow + (1 - decay) * param.
    `apply_shadow` copies the averaged weights into the model after backing up
    the live ones, and `restore` copies the live weights back. Only parameters
    with requires_grad=True at construction time are tracked; buffers are not.

    apply_shadow/restore are pairing-safe:
    - a second apply_shadow() before restore() keeps the original backup;
    - restore() with nothing backed up is a no-op.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}  # param name -> averaged tensor (cloned, same device/dtype)
        self.backup = {}  # param name -> live tensor saved by apply_shadow()
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Blend the model's current trainable parameters into the shadow copy."""
        d = self.decay
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.shadow[name].mul_(d).add_(p.detach(), alpha=(1.0 - d))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        """Load the EMA weights into `model`, backing up the live weights first."""
        # Only take a backup when none is pending. Otherwise a second call
        # without restore() would back up the shadow weights themselves, and
        # the live (training) weights would be lost for good.
        take_backup = not self.backup
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if take_backup:
                self.backup[name] = p.detach().clone()
            p.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        """Copy back the weights saved by `apply_shadow` (no-op if none are saved)."""
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.backup:
                p.copy_(self.backup[name])
        self.backup = {}

# ----------------------------
# 5) mHC-lite on CLS streams
# ----------------------------
class MHCLite(nn.Module):
    """
    Maintain n_streams for CLS only.
    Build per-sample mixing matrix -> Sinkhorn -> residual blend with Identity.

    Per sample, an MLP on the RMS-normalised CLS vector predicts (n, n) logits.
    Softplus makes them non-negative, and Sinkhorn-Knopp pushes the result
    towards a doubly-stochastic matrix M. The streams are then mixed with
    H = (1 - alpha) * I + alpha * M, where alpha = sigmoid(alpha_logit) is a
    single learnable scalar that starts at `alpha_init`. The CLS vector is then
    added to every mixed stream.

    Args:
        d_model: width D of the CLS vector and of each stream.
        n_streams: number of parallel CLS streams n.
        alpha_init: initial mixing weight; clamped to [1e-4, 1 - 1e-4] so its
            logit is finite.
        tmax: number of Sinkhorn iterations.
        dropout: dropout applied to the returned streams.
    """
    def __init__(self, d_model, n_streams=4, alpha_init=0.01, tmax=20, dropout=0.0):
        super().__init__()
        self.n = int(n_streams)
        self.tmax = int(tmax)
        self.drop = nn.Dropout(float(dropout))

        self.norm = RMSNorm(d_model, eps=1e-6)
        # Maps the CLS vector to n*n mixing logits (one matrix per sample).
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, self.n * self.n),
        )
        self.softplus = nn.Softplus()

        # Clamp away from 0/1, then store the inverse sigmoid (logit) so that
        # sigmoid(alpha_logit) == a0 at initialisation.
        a0 = float(alpha_init)
        a0 = min(max(a0, 1e-4), 1.0 - 1e-4)
        self.alpha_logit = nn.Parameter(torch.log(torch.tensor(a0 / (1 - a0), dtype=torch.float32)))

    def forward(self, streams, cls_vec):
        """Mix the streams with a per-sample matrix and inject the CLS vector.

        Args:
            streams: (B, n, D) CLS streams; n must equal self.n, because the
                MLP output (n*n values) is reshaped with `.view(B, n, n)`.
            cls_vec: (B, D) CLS output of the preceding block.

        Returns:
            (B, n, D) updated streams (after dropout).

        NOTE(review): `mixed` is an (approximately) convex combination of the
        streams. In this file, the caller passes a block output that already
        includes the residual input (the mean of the streams), so adding
        `cls_vec` on top may roughly double the stream magnitude per layer.
        If only the block's update was meant to be injected, the input CLS
        should be subtracted. Confirm intent before changing, because that
        would alter trained behaviour.
        """
        # streams: (B,n,D), cls_vec: (B,D)
        B, n, D = streams.shape
        h = self.norm(cls_vec)
        logits = self.mlp(h).view(B, n, n)  # (B,n,n)

        P = self.softplus(logits)           # non-negative
        M = sinkhorn_knopp(P, tmax=self.tmax, eps=1e-6)

        # Blend the Sinkhorn matrix with the identity: alpha ~ 0 keeps streams unmixed.
        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype, device=streams.device)
        I = torch.eye(n, device=streams.device, dtype=streams.dtype).unsqueeze(0).expand(B, -1, -1)
        H = (1.0 - alpha) * I + alpha * M

        # mixed[b, i] = sum_j H[b, i, j] * streams[b, j]
        mixed = torch.einsum("bij,bjd->bid", H, streams)
        injected = mixed + cls_vec.unsqueeze(1)
        return self.drop(injected)

class TransformerBlock(nn.Module):
    """Pre-norm transformer encoder block with RMSNorm.

    x <- x + Dropout(SelfAttn(RMSNorm(x)))
    x <- x + Dropout(FFN(RMSNorm(x)))

    Attention is batch-first, so inputs are (B, T, d_model). The FFN expands
    to ffn_mult * d_model with GELU and inner dropout.
    """

    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        p_drop = float(dropout)
        hidden = int(ffn_mult) * d_model

        # Submodules are created in this exact order. It fixes the state_dict
        # keys and the sequence of random initialisations.
        self.norm1 = RMSNorm(d_model, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=int(n_heads),
            dropout=float(attn_dropout),
            batch_first=True,
        )
        self.drop1 = nn.Dropout(p_drop)

        self.norm2 = RMSNorm(d_model, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x):
        q = self.norm1(x)
        attn_out = self.attn(q, q, q, need_weights=False)[0]
        x = x + self.drop1(attn_out)
        return x + self.drop2(self.ffn(self.norm2(x)))

class FTTransformer_MHCLite(nn.Module):
    """
    Numeric FT-Transformer + CLS-stream mHC-lite between blocks.
    Regularizer: feature-token drop (zero some feature tokens, not CLS).

    Tokenisation: for an input x of shape (B, n_features), feature j becomes
    the token x_j * w_j + b_j + feat_emb_j of width d_model, and a learned CLS
    token is prepended. Before every TransformerBlock, the CLS slot is replaced
    by the mean of `n_streams` CLS streams. After the block, that layer's
    MHCLite updates the streams from the block's CLS output. The logit comes
    from an MLP head applied to the RMS-normalised mean of the final streams.

    forward(x) returns one logit per row, shape (B,).
    """
    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4,
                 dropout=0.2, attn_dropout=0.1,
                 n_streams=4, alpha_init=0.01, sinkhorn_tmax=20, mhc_dropout=0.0,
                 feat_token_drop_p=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.n_layers = int(n_layers)
        # Stored here so forward() does not depend on self.mhc[0], which does
        # not exist when n_layers == 0.
        self.n_streams = int(n_streams)
        self.feat_token_drop_p = float(feat_token_drop_p)

        # Parameter/module creation order below is deliberately unchanged: it
        # fixes both the state_dict keys and the random-initialisation sequence.
        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(float(dropout))

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout
            ) for _ in range(self.n_layers)
        ])

        self.mhc = nn.ModuleList([
            MHCLite(
                d_model=self.d_model,
                n_streams=n_streams,
                alpha_init=alpha_init,
                tmax=sinkhorn_tmax,
                dropout=mhc_dropout
            ) for _ in range(self.n_layers)
        ])

        self.out_norm = RMSNorm(self.d_model, eps=1e-6)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(self.d_model, 1),
        )

    def forward(self, x):
        # (B, F) -> (B, F, D): per-feature affine embedding plus feature identity embedding.
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)
        tok = tok + self.feat_emb.unsqueeze(0)

        if self.training and self.feat_token_drop_p > 0:
            # Zero whole feature tokens, with no 1/(1-p) rescaling; CLS is never dropped.
            B, F_ = tok.shape[:2]
            keep = (torch.rand(B, F_, device=tok.device) > self.feat_token_drop_p).to(tok.dtype)
            tok = tok * keep.unsqueeze(-1)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)
        seq = torch.cat([cls, tok], dim=1)
        seq = self.in_drop(seq)

        # All streams start as copies of the (dropped-out) CLS token.
        nS = self.n_streams
        streams = seq[:, 0, :].unsqueeze(1).expand(B, nS, self.d_model).contiguous()

        for l, blk in enumerate(self.blocks):
            # The block sees the mean of the streams in the CLS slot.
            cls_in = streams.mean(dim=1).unsqueeze(1)
            seq = torch.cat([cls_in, seq[:, 1:, :]], dim=1)

            seq = blk(seq)
            cls_vec = seq[:, 0, :]

            streams = self.mhc[l](streams, cls_vec)

        out = self.out_norm(streams.mean(dim=1))
        logit = self.head(out).squeeze(-1)
        return logit

# ----------------------------
# 6) Scheduler (optimizer-steps)
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps, warmup_steps,
                              r1=0.8, r2=0.9, d1=0.316, d2=0.1):
    """LambdaLR with linear warmup and two multiplicative step decays.

    Steps are optimizer steps. total_steps is clamped to >= 1, and warmup_steps
    is clamped to [0, total_steps]. During warmup the multiplier is
    (step + 1) / warmup_steps. After warmup it is 1.0, times d1 once
    step >= int(r1 * total_steps), times d2 as well once
    step >= int(r2 * total_steps).
    """
    n_total = int(max(1, total_steps))
    n_warm = int(max(0, min(warmup_steps, n_total)))

    first_drop = int(float(r1) * n_total)
    second_drop = int(float(r2) * n_total)

    def lr_lambda(step):
        in_warmup = n_warm > 0 and step < n_warm
        if in_warmup:
            return float(step + 1) / float(max(1, n_warm))
        factor = 1.0
        if step >= first_drop:
            factor *= float(d1)
        if step >= second_drop:
            factor *= float(d2)
        return factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 7) Predict helper
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None):
    """Return sigmoid probabilities for every sample yielded by `loader`.

    Args:
        model: binary classifier, expected to return one logit per sample.
        loader: iterable of batches. A batch is either a feature tensor or a
            (features, ...) list/tuple, in which case only the features are used.
        ema: optional EMA. When given, its shadow weights are swapped in for the
            prediction, and the live weights are restored afterwards, even if
            prediction raises.

    Returns:
        float32 numpy array of probabilities, in loader order. It is empty if
        the loader yields no batches.

    Uses module-level `device` and `use_amp`. Leaves `model` in eval mode.
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)

    probs = []
    try:
        for batch in loader:
            xb = batch[0] if isinstance(batch, (list, tuple)) else batch
            xb = xb.to(device, non_blocking=True)
            if use_amp:
                with torch.cuda.amp.autocast(enabled=True):
                    logits = model(xb)
            else:
                logits = model(xb)
            # Apply sigmoid in float32. Under autocast the logits are float16,
            # whose resolution near 1.0 (~5e-4) would quantise confident
            # probabilities and distort log-loss and threshold search.
            probs.append(torch.sigmoid(logits.float()).cpu().numpy())
    finally:
        if ema is not None:
            ema.restore(model)

    if not probs:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate(probs, axis=0).astype(np.float32)

# ----------------------------
# 8) Train one fold
# ----------------------------
def train_one_fold_transformer(X_tr, y_tr, X_va, y_va, cfg):
    """Train one FTTransformer_MHCLite on a CV fold and predict its validation split.

    The standardizer is fit on X_tr only (no leakage) and applied to both splits.
    Training uses:
    - AdamW with a warmup + step-decay schedule, stepped once per optimizer step;
    - optional gradient accumulation (`accum_steps`);
    - AMP when the module-level `use_amp` is set;
    - an optional EMA of the weights;
    - a pos_weight-balanced BCE (optionally focal and/or label-smoothed) on
      inputs with optional Gaussian noise.
    Early stopping monitors validation log-loss, scored with the EMA weights
    when EMA is enabled. The best-scoring weights are loaded back before the
    final validation prediction.

    Args:
        X_tr, y_tr: training features (n_train, n_features) and 0/1 labels.
        X_va, y_va: validation features and 0/1 labels.
        cfg: hyper-parameter dict. It holds the architecture, optimiser,
            schedule, regularisation and early-stopping keys read below.

    Returns:
        (pack, p_va).
        - pack: the CPU state_dict, the standardizer `mu`/`sig`, a copy of
          `cfg`, the 1-based `best_epoch` and `best_val_logloss`.
        - p_va: float32 validation probabilities from the restored weights.

    Relies on module-level `device`, `use_amp`, `os`, `math`, `gc` and `F`.
    """
    # Standardize with statistics from the training split only.
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")
    batch_size = int(cfg["batch_size"])

    dl_tr = DataLoader(ds_tr, batch_size=batch_size, shuffle=True,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    dl_va = DataLoader(ds_va, batch_size=batch_size, shuffle=False,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))

    model = FTTransformer_MHCLite(
        # Width comes from the data rather than a module-level `n_features`,
        # so the model always matches the matrix it is trained on.
        n_features=int(X_tr.shape[1]),
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        alpha_init=float(cfg["alpha_init"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    def _cpu_snapshot():
        # Tensor.cpu() returns the SAME tensor when it already lives on the CPU.
        # So `v.detach().cpu()` would alias the live parameters on CPU runs, and
        # the "best" snapshot would silently follow later training steps. With
        # EMA it would also end up holding the restored raw weights. Force a copy.
        return {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

    # pos_weight = neg/pos balances the BCE between classes (pos guarded >= 1).
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    focal_gamma = float(cfg.get("focal_gamma", 0.0))
    label_smoothing = float(cfg.get("label_smoothing", 0.0))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    def loss_fn(logits, targets):
        # Symmetric label smoothing towards 0.5, then pos-weighted BCE with an
        # optional focal modulation (1 - p_t) ** gamma.
        if label_smoothing and label_smoothing > 0:
            targets = targets * (1.0 - label_smoothing) + 0.5 * label_smoothing
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)
        if focal_gamma and focal_gamma > 0:
            p = torch.sigmoid(logits)
            p_t = p * targets + (1 - p) * (1 - targets)
            mod = (1.0 - p_t).clamp_min(0.0).pow(focal_gamma)
            bce = bce * mod
        return bce.mean()

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )

    # The LR schedule counts optimizer steps, not micro-batches.
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_steps = int(cfg["epochs"]) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]),
        r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]),
        d2=float(cfg["lr_decay_rate2"]),
    )

    scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None
    grad_clip = float(cfg["grad_clip"])

    def _step_optimizer():
        # One optimizer update from the accumulated gradients. Shared by the
        # regular accumulation boundary and the end-of-epoch partial flush.
        if grad_clip and grad_clip > 0:
            if use_amp:
                scaler.unscale_(opt)  # clip the true (unscaled) gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        if use_amp:
            scaler.step(opt)
            scaler.update()
        else:
            opt.step()
        opt.zero_grad(set_to_none=True)
        sch.step()
        if ema is not None:
            ema.update(model)

    best_val = 1e18
    best_state = None
    best_epoch = -1
    bad = 0

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)
        micro = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            if use_amp:
                with torch.cuda.amp.autocast(enabled=True):
                    logits = model(xb)
                    loss = loss_fn(logits, yb) / accum_steps
                scaler.scale(loss).backward()
            else:
                logits = model(xb)
                loss = loss_fn(logits, yb) / accum_steps
                loss.backward()

            micro += 1
            if (micro % accum_steps) == 0:
                _step_optimizer()

        # flush the last, partial accumulation window of the epoch
        if (micro % accum_steps) != 0:
            _step_optimizer()

        # validate with EMA (if enabled)
        p_va = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va)

        improved = (best_val - vll) > float(cfg["min_delta"])
        if improved:
            best_val = float(vll)
            best_epoch = int(epoch)

            # Snapshot exactly the weights that were just validated.
            if ema is not None:
                ema.apply_shadow(model)
                best_state = _cpu_snapshot()
                ema.restore(model)
            else:
                best_state = _cpu_snapshot()

            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["patience"]):
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    p_va = predict_proba(model, dl_va, ema=None)

    pack = {
        "state_dict": _cpu_snapshot(),
        "mu": mu,
        "sig": sig,
        "cfg": dict(cfg),
        "best_epoch": int(best_epoch + 1),
        "best_val_logloss": float(best_val),
    }
    return pack, p_va

# ----------------------------
# 9) CV evaluator for a config
# ----------------------------
def run_cv_config(cfg, cfg_name, folds_subset=None, beta=0.5, thr_grid=201):
    """Cross-validate one config. Returns (summary, fold_rows, oof, fold_packs).

    For each fold in `folds_subset` (all `unique_folds` when None), the config
    is trained on the other folds, and out-of-fold probabilities are collected.

    - Per-fold rows report AUC, log-loss and the best F-beta. The threshold
      grid has max(81, thr_grid // 2) points.
    - The summary reports the same metrics on the OOF predictions of the
      evaluated rows, using a grid of `thr_grid` points, followed by the
      config's hyper-parameters.
    - The OOF array has one entry per module-level row (`n`). Rows of folds
      that were not evaluated stay 0.0.

    Relies on module-level X_all, y_all, folds_all, unique_folds, n, device,
    gc and best_fbeta_fast.
    """
    oof = np.zeros(n, dtype=np.float32)
    per_fold = []
    packs = []

    use_folds = unique_folds if folds_subset is None else list(folds_subset)
    fold_grid = max(81, thr_grid // 2)

    for fold_id in use_folds:
        tr_idx = np.where(folds_all != fold_id)[0]
        va_idx = np.where(folds_all == fold_id)[0]

        y_va = y_all[va_idx]
        pack, p_va = train_one_fold_transformer(
            X_all[tr_idx], y_all[tr_idx], X_all[va_idx], y_va, cfg
        )
        oof[va_idx] = p_va

        fb = best_fbeta_fast(y_va, p_va, beta=beta, grid=fold_grid)
        per_fold.append({
            "cfg": cfg_name,
            "fold": int(fold_id),
            "n_val": int(len(va_idx)),
            "pos_val": int(y_va.sum()),
            "auc": safe_auc(y_va, p_va),
            "logloss": safe_logloss(y_va, p_va),
            "best_fbeta": fb["fbeta"],
            "best_thr": fb["thr"],
            "best_prec": fb["precision"],
            "best_rec": fb["recall"],
            "best_val_logloss": float(pack["best_val_logloss"]),
            "best_epoch": int(pack["best_epoch"]),
        })

        packs.append({**pack, "fold": int(fold_id)})

        del pack
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if folds_subset is None:
        eval_idx = np.arange(n)
    else:
        eval_idx = np.where(np.isin(folds_all, np.array(use_folds)))[0]

    p_eval = oof[eval_idx]
    y_eval = y_all[eval_idx]
    overall = best_fbeta_fast(y_eval, p_eval, beta=beta, grid=thr_grid)

    summary = {
        "cfg": cfg_name,
        "stage": "full" if folds_subset is None else f"subset{len(use_folds)}",
        "oof_auc": safe_auc(y_eval, p_eval),
        "oof_logloss": safe_logloss(y_eval, p_eval),
        "oof_best_fbeta": overall["fbeta"],
        "oof_best_thr": overall["thr"],
        "oof_best_prec": overall["precision"],
        "oof_best_rec": overall["recall"],
    }

    # Hyper-parameters appended in a fixed order. Keys marked `required` must be
    # present in cfg (KeyError otherwise); the others fall back to the default.
    required = object()
    hp_spec = (
        ("d_model", required),
        ("n_layers", required),
        ("n_heads", required),
        ("ffn_mult", required),
        ("dropout", required),
        ("attn_dropout", required),

        ("n_streams", required),
        ("alpha_init", required),
        ("sinkhorn_tmax", required),
        ("mhc_dropout", required),

        ("feat_token_drop_p", 0.0),
        ("input_noise_std", 0.0),
        ("focal_gamma", 0.0),
        ("label_smoothing", 0.0),

        ("batch_size", required),
        ("accum_steps", 1),
        ("epochs", required),
        ("lr", required),
        ("weight_decay", required),
        ("warmup_frac", required),
        ("beta1", required),
        ("beta2", required),
        ("adam_eps", required),
        ("lr_decay_ratio1", required),
        ("lr_decay_ratio2", required),
        ("lr_decay_rate1", required),
        ("lr_decay_rate2", required),
        ("patience", required),
        ("min_delta", required),
        ("grad_clip", required),
        ("use_ema", True),
        ("ema_decay", 0.999),
    )
    for key, default in hp_spec:
        summary[key] = cfg[key] if default is required else cfg.get(key, default)

    return summary, per_fold, oof, packs

# ----------------------------
# 10) Candidate configs
# ----------------------------
def make_base():
    """Return the default hyper-parameter dict shared by every search candidate.

    Batch size and gradient accumulation depend on the device and on the
    module-level MEM_GB (presumably GPU memory in GB; confirm in the setup cell):
    - CUDA, MEM_GB >= 30: batch 512, accumulation 2;
    - CUDA, MEM_GB >= 16: batch 384, accumulation 2;
    - CUDA, smaller: batch 256, accumulation 2;
    - CPU: batch 256, accumulation 1.
    CUDA runs also get 75 epochs instead of 40.
    """
    on_cuda = device.type == "cuda"
    if not on_cuda:
        bs, acc = 256, 1
    elif MEM_GB >= 30:
        bs, acc = 512, 2
    elif MEM_GB >= 16:
        bs, acc = 384, 2
    else:
        bs, acc = 256, 2

    return {
        # batching / optimisation / early stopping
        "batch_size": bs,
        "accum_steps": acc,
        "epochs": 75 if on_cuda else 40,
        "lr": 2e-4,
        "weight_decay": 1.0e-2,
        "warmup_frac": 0.10,
        "grad_clip": 1.0,
        "patience": 10,
        "min_delta": 1e-4,

        # AdamW
        "beta1": 0.9,
        "beta2": 0.95,
        "adam_eps": 1e-8,

        # two-step LR decay at r1/r2 of total steps, by factors d1/d2
        "lr_decay_ratio1": 0.8,
        "lr_decay_ratio2": 0.9,
        "lr_decay_rate1": 0.316,
        "lr_decay_rate2": 0.1,

        # mHC-lite
        "n_streams": 4,
        "alpha_init": 0.01,
        "sinkhorn_tmax": 20,
        "mhc_dropout": 0.00,

        # regularisation / loss shaping
        "feat_token_drop_p": 0.05,
        "input_noise_std": 0.01,
        "focal_gamma": 1.5,
        "label_smoothing": 0.00,

        # weight averaging
        "use_ema": True,
        "ema_decay": 0.999,
    }

# Shared defaults (they depend on device / GPU memory). Each candidate copies
# them via dict(BASE, ...) and overrides the listed keys.
BASE = make_base()

candidates = [
    ("mhc_384x8_main", dict(BASE, d_model=384, n_layers=8, n_heads=8, ffn_mult=4,
                            dropout=0.18, attn_dropout=0.10)),
    ("mhc_384x10_reg", dict(BASE, d_model=384, n_layers=10, n_heads=8, ffn_mult=4,
                            dropout=0.24, attn_dropout=0.12,
                            lr=1.6e-4, weight_decay=1.5e-2, patience=12,
                            feat_token_drop_p=0.06, input_noise_std=0.012)),
    ("mhc_384x8_ffn2", dict(BASE, d_model=384, n_layers=8, n_heads=8, ffn_mult=2,
                            dropout=0.16, attn_dropout=0.10,
                            lr=2.2e-4, weight_decay=8e-3)),
    ("mhc_256x6_fast", dict(BASE, d_model=256, n_layers=6, n_heads=8, ffn_mult=4,
                            dropout=0.16, attn_dropout=0.08,
                            lr=3e-4, weight_decay=6e-3,
                            epochs=min(int(BASE["epochs"]), 60), patience=9,
                            feat_token_drop_p=0.04, input_noise_std=0.010)),
]

# The widest config is only added on CUDA with MEM_GB >= 20. MEM_GB comes from
# an earlier cell and is presumably GPU memory in GB.
if device.type == "cuda" and MEM_GB >= 20:
    candidates.append(
        ("mhc_512x10_big", dict(BASE, d_model=512, n_layers=10, n_heads=16, ffn_mult=4,
                                dropout=0.26, attn_dropout=0.14,
                                lr=1.2e-4, weight_decay=2.0e-2,
                                epochs=max(int(BASE["epochs"]), 85), patience=12,
                                mhc_dropout=0.05, feat_token_drop_p=0.06))
    )

print(f"\nTotal Transformer candidates: {len(candidates)}")
# NOTE(review): this message hard-codes beta=0.5; the search itself uses BETA.
print("Primary score: OOF best F-beta (beta=0.5)")

# ----------------------------
# 11) Run 2-stage search
# ----------------------------
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OPT_DIR = OUT_DIR / "opt_search"
OPT_DIR.mkdir(parents=True, exist_ok=True)

STAGE1_PATH = OPT_DIR / "stage1_results.csv"

# stage-1 folds subset: evenly spaced indices
# (Stage-1 screens every candidate on only STAGE1_FOLDS folds, to save time.)
if STAGE1_FOLDS >= len(unique_folds):
    folds_subset = unique_folds
else:
    idxs = np.linspace(0, len(unique_folds) - 1, STAGE1_FOLDS)
    idxs = np.round(idxs).astype(int)
    idxs = np.unique(idxs).tolist()
    folds_subset = [unique_folds[i] for i in idxs]
    if len(folds_subset) < STAGE1_FOLDS:
        # pad if a collision happened (rounding mapped two positions to one fold)
        for f in unique_folds:
            if f not in folds_subset:
                folds_subset.append(f)
            if len(folds_subset) >= STAGE1_FOLDS:
                break

print("\nStage-1 folds subset:", folds_subset)

# resume-safe: load already evaluated stage1 cfg names
done_stage1 = set()
if STAGE1_PATH.exists():
    try:
        df_prev = pd.read_csv(STAGE1_PATH)
        if "cfg" in df_prev.columns:
            done_stage1 = set(df_prev["cfg"].astype(str).tolist())
            print(f"Resume: found {len(done_stage1)} configs already in stage1_results.csv")
    except Exception as e:
        # Resume stays best-effort, but the failure is reported: silently
        # ignoring an unreadable cache would re-run every config with no hint why.
        print(f"[WARN] Could not read {STAGE1_PATH} for resume ({type(e).__name__}: {e}); running Stage-1 from scratch.")

# Stage-1 screening: every candidate runs on the fold subset, with epochs and
# patience capped by STAGE1_EPOCH_CAP / STAGE1_PAT_CAP. The same `t0` also
# bounds Stage-2 below.
t0 = time.time()
stage1_rows = []
stage1_fold_rows = []

for i, (name, cfg) in enumerate(candidates, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop search.")
        break

    if name in done_stage1:
        print(f"\n[Stage-1 {i:02d}/{len(candidates)}] SKIP (already done) -> {name}")
        continue

    cfg1 = dict(cfg)
    cfg1["epochs"] = int(min(int(cfg1["epochs"]), int(STAGE1_EPOCH_CAP)))
    cfg1["patience"] = int(min(int(cfg1["patience"]), int(STAGE1_PAT_CAP)))

    print(f"\n[Stage-1 {i:02d}/{len(candidates)}] CV(subset) -> {name}")
    summ, fold_rows, _, _ = run_cv_config(cfg1, name, folds_subset=folds_subset, beta=BETA, thr_grid=101)

    stage1_rows.append(summ)
    stage1_fold_rows.extend(fold_rows)

    print(f"  stage1 best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f} | logloss: {summ['oof_logloss']:.6f}")

    # append-progress to disk (resume-safe)
    try:
        df_append = pd.DataFrame([summ])
        if STAGE1_PATH.exists():
            df_old = pd.read_csv(STAGE1_PATH)
            df_new = pd.concat([df_old, df_append], axis=0, ignore_index=True)
        else:
            df_new = df_append
        # Write to a temp file and atomically swap it in. An interrupted
        # kernel then cannot leave a truncated CSV that breaks the next resume.
        tmp_path = STAGE1_PATH.with_name(STAGE1_PATH.name + ".tmp")
        df_new.to_csv(tmp_path, index=False)
        os.replace(tmp_path, STAGE1_PATH)
    except Exception as e:
        # Still best-effort (the row is kept in `stage1_rows`), but reported:
        # a silent failure would make a later resume re-run this config unawares.
        print(f"  [WARN] Could not save Stage-1 progress for {name} to {STAGE1_PATH} ({type(e).__name__}: {e})")

# build stage1 ranking from disk (preferred) + current in-memory
# Disk rows cover configs finished in earlier (resumed) runs. In-memory rows
# cover this run, including any row whose append-to-disk failed. On duplicate
# cfg names the later row (this run) wins.
frames_s1 = []
if STAGE1_PATH.exists():
    try:
        frames_s1.append(pd.read_csv(STAGE1_PATH))
    except Exception as e:
        print(f"[WARN] Could not read {STAGE1_PATH} ({type(e).__name__}: {e}); ranking only this run's Stage-1 results.")
if stage1_rows:
    frames_s1.append(pd.DataFrame(stage1_rows))
df_s1 = pd.concat(frames_s1, axis=0, ignore_index=True) if frames_s1 else pd.DataFrame()

if len(df_s1) > 0 and "cfg" in df_s1.columns:
    df_s1["cfg"] = df_s1["cfg"].astype(str)
    df_s1 = df_s1.drop_duplicates(subset="cfg", keep="last")
    # Drop cached rows for configs that are not in this run's `candidates`,
    # e.g. renamed configs, or the memory-gated 512-wide config from a bigger
    # GPU. Such rows would take Stage-2 slots and then be skipped there,
    # because no cfg can be looked up for them.
    cand_names_s1 = {str(c_name) for c_name, _ in candidates}
    stale_s1 = sorted(set(df_s1["cfg"]) - cand_names_s1)
    if stale_s1:
        print("Stage-1: ignoring cached results for configs not in current candidates:", stale_s1)
    df_s1 = df_s1[df_s1["cfg"].isin(cand_names_s1)]

if len(df_s1) == 0:
    raise RuntimeError("Stage-1 menghasilkan 0 hasil. Cek runtime/VRAM atau kecilkan kandidat/epochs.")

df_s1 = df_s1.sort_values(["oof_best_fbeta","oof_logloss"], ascending=[False, True]).reset_index(drop=True)
print("\nStage-1 ranking (top):")
display(df_s1.head(10))

topM = min(int(STAGE2_TOPM), len(df_s1))
stage2_names = df_s1["cfg"].head(topM).astype(str).tolist()
print("\nStage-2 will run full CV for:", stage2_names)

# Stage-2: full CV (all folds) for the Stage-1 top-M configs, using the
# original cfg from `candidates`, i.e. without the Stage-1 epoch/patience caps.
# OOF predictions and per-fold packs stay in memory (oof_store / pack_store),
# so the winner can be saved later without retraining.
all_summaries = []
all_fold_rows = []
oof_store = {}
pack_store = {}

for j, nm in enumerate(stage2_names, 1):
    # Shares the wall-clock budget started (t0) before Stage-1.
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop stage-2.")
        break

    # Look up the uncapped cfg by name; names with no match are skipped silently.
    cfg = None
    for (nname, ccfg) in candidates:
        if nname == nm:
            cfg = ccfg
            break
    if cfg is None:
        continue

    print(f"\n[Stage-2 {j:02d}/{len(stage2_names)}] CV(full) -> {nm}")
    summ, fold_rows, oof, fold_packs = run_cv_config(cfg, nm, folds_subset=None, beta=BETA, thr_grid=THR_GRID)

    all_summaries.append(summ)
    all_fold_rows.extend(fold_rows)
    oof_store[nm] = oof
    pack_store[nm] = fold_packs

    # oof_auc may be None (single-class labels), so print NaN in that case.
    print(f"  OOF best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f}"
          f" | auc: {(summ['oof_auc'] if summ['oof_auc'] is not None else float('nan')):.6f}"
          f" | logloss: {summ['oof_logloss']:.6f}")

df_sum = pd.DataFrame(all_summaries)
df_fold = pd.DataFrame(all_fold_rows)

if len(df_sum) == 0:
    raise RuntimeError("Stage-2 produced no results. Turunkan kandidat/epochs atau cek device/VRAM.")

# Rank by OOF F-beta, ties broken by lower OOF log-loss.
df_sum = df_sum.sort_values(["oof_best_fbeta", "oof_logloss"], ascending=[False, True]).reset_index(drop=True)

print("\nStage-2 top candidates (full CV):")
display(df_sum)

# save search results
df_sum.to_csv(OPT_DIR / "opt_results.csv", index=False)
with open(OPT_DIR / "opt_results.json", "w") as f:
    json.dump(df_sum.to_dict(orient="records"), f, indent=2)
df_fold.to_csv(OPT_DIR / "opt_fold_details.csv", index=False)

# save OOF preds for top configs (debug)
# Writes one CSV per config among the first REPORT_TOPK_OOF rows of the Stage-2
# ranking, with columns uid / y / fold plus that config's OOF probability.
# uids_all and REPORT_TOPK_OOF come from an earlier cell (not visible here).
top_names = df_sum["cfg"].head(min(REPORT_TOPK_OOF, len(df_sum))).astype(str).tolist()
for nm in top_names:
    df_o = pd.DataFrame({
        "uid": uids_all,
        "y": y_all,
        "fold": folds_all,
        f"oof_pred_{nm}": oof_store[nm]
    })
    df_o.to_csv(OPT_DIR / f"oof_preds_{nm}.csv", index=False)

# ----------------------------
# 12) Choose best config + save BEST fold packs (no retraining)
# ----------------------------
# The winner is the top row of the Stage-2 ranking (highest OOF F-beta, then
# lowest OOF log-loss).
# NOTE(review): the logged runs of this cell show OOF AUC = 1.0 and F-beta = 1.0
# for every config, so the choice is effectively made by log-loss alone. Perfect
# scores usually point to label leakage in the features; check this before
# trusting recommended_thr.
best_single = df_sum.iloc[0].to_dict()
best_cfg_name = str(best_single["cfg"])

best_cfg = None
for nm, cfg in candidates:
    if nm == best_cfg_name:
        best_cfg = cfg
        break
if best_cfg is None:
    raise RuntimeError("Best cfg not found in candidates list (unexpected).")

# Per-fold packs (state_dict + standardizer mu/sig + cfg) kept from Stage-2.
best_fold_packs = pack_store[best_cfg_name]
best_oof = oof_store[best_cfg_name]
best_oof_best = best_fbeta_fast(y_all, best_oof, beta=BETA, grid=THR_GRID)

# Checkpoint holding all fold models of the winner, plus its tuned threshold.
best_model_path = OUT_DIR / "best_gate_model.pt"
torch.save(
    {
        "type": "mhc_lite_ft_transformer_v3",
        "feature_cols": FEATURE_COLS,
        "fold_packs": best_fold_packs,
        "cfg_name": best_cfg_name,
        "cfg": best_cfg,
        "seed": SEED,
        "beta_for_tuning": BETA,
        "recommended_thr": best_oof_best["thr"],
        "recommended_score": best_oof_best["fbeta"],
    },
    best_model_path
)

# JSON-friendly summary of the winner (no tensors); also exported in memory below.
best_bundle = {
    "type": "mhc_lite_ft_transformer_v3",
    "model_name": best_cfg_name,
    "members": [best_cfg_name],
    "random_seed": SEED,
    "beta_for_tuning": BETA,

    "feature_cols": FEATURE_COLS,
    "cfg": best_cfg,

    "oof_best_thr": best_oof_best["thr"],
    "oof_best_fbeta": best_oof_best["fbeta"],
    "oof_best_prec": best_oof_best["precision"],
    "oof_best_rec": best_oof_best["recall"],
    "oof_auc": safe_auc(y_all, best_oof),
    "oof_logloss": safe_logloss(y_all, best_oof),

    "notes": "Best config from Step 4 (Transformer-only, mHC-lite differentiable + EMA + accum + reg). Best model saved as fold_packs (no retrain).",
}

with open(OUT_DIR / "best_gate_config.json", "w") as f:
    json.dump(best_bundle, f, indent=2)

print("\nSaved best artifacts:")
print("  best model (fold packs) ->", best_model_path)
print("  best config             ->", OUT_DIR / "best_gate_config.json")
print("  opt results             ->", OPT_DIR / "opt_results.csv")
print("  fold detail             ->", OPT_DIR / "opt_fold_details.csv")
print("  stage1 cache            ->", STAGE1_PATH)

# Export globals for Step 5
BEST_GATE_BUNDLE = best_bundle
BEST_TF_CFG_NAME = best_cfg_name
BEST_TF_CFG = best_cfg
OPT_RESULTS_DF = df_sum
BEST_TF_OOF = best_oof
BEST_TF_OOF_METRIC = best_oof_best


Optimize setup (Transformer-only, mHC-lite):
  rows=5176 | folds=5 | pos%=54.08 | n_features=84
Device: cuda | AMP: True
GPU mem GB: 15.887939453125

Total Transformer candidates: 4
Primary score: OOF best F-beta (beta=0.5)

Stage-1 folds subset: [0, 2, 4]

[Stage-1 01/4] CV(subset) -> mhc_384x8_main
  stage1 best_fbeta: 1.000000 | thr: 0.412 | logloss: 0.151695

[Stage-1 02/4] CV(subset) -> mhc_384x10_reg
  stage1 best_fbeta: 1.000000 | thr: 0.382 | logloss: 0.126817

[Stage-1 03/4] CV(subset) -> mhc_384x8_ffn2
  stage1 best_fbeta: 1.000000 | thr: 0.343 | logloss: 0.175593

[Stage-1 04/4] CV(subset) -> mhc_256x6_fast
  stage1 best_fbeta: 1.000000 | thr: 0.235 | logloss: 0.149304

Stage-1 ranking (top):


Unnamed: 0,cfg,stage,oof_auc,oof_logloss,oof_best_fbeta,oof_best_thr,oof_best_prec,oof_best_rec,d_model,n_layers,...,adam_eps,lr_decay_ratio1,lr_decay_ratio2,lr_decay_rate1,lr_decay_rate2,patience,min_delta,grad_clip,use_ema,ema_decay
0,mhc_384x10_reg,subset3,1.0,0.126817,1.0,0.3824,1.0,1.0,384,10,...,1e-08,0.8,0.9,0.316,0.1,6,0.0001,1.0,True,0.999
1,mhc_256x6_fast,subset3,1.0,0.149304,1.0,0.2354,1.0,1.0,256,6,...,1e-08,0.8,0.9,0.316,0.1,6,0.0001,1.0,True,0.999
2,mhc_384x8_main,subset3,1.0,0.151695,1.0,0.4118,1.0,1.0,384,8,...,1e-08,0.8,0.9,0.316,0.1,6,0.0001,1.0,True,0.999
3,mhc_384x8_ffn2,subset3,1.0,0.175593,1.0,0.3432,1.0,1.0,384,8,...,1e-08,0.8,0.9,0.316,0.1,6,0.0001,1.0,True,0.999



Stage-2 will run full CV for: ['mhc_384x10_reg', 'mhc_256x6_fast', 'mhc_384x8_main']

[Stage-2 01/3] CV(full) -> mhc_384x10_reg
  OOF best_fbeta: 1.000000 | thr: 0.074 | auc: 1.000000 | logloss: 0.030143

[Stage-2 02/3] CV(full) -> mhc_256x6_fast
  OOF best_fbeta: 1.000000 | thr: 0.201 | auc: 1.000000 | logloss: 0.096908

[Stage-2 03/3] CV(full) -> mhc_384x8_main
  OOF best_fbeta: 1.000000 | thr: 0.294 | auc: 1.000000 | logloss: 0.064275

Stage-2 top candidates (full CV):


Unnamed: 0,cfg,stage,oof_auc,oof_logloss,oof_best_fbeta,oof_best_thr,oof_best_prec,oof_best_rec,d_model,n_layers,...,adam_eps,lr_decay_ratio1,lr_decay_ratio2,lr_decay_rate1,lr_decay_rate2,patience,min_delta,grad_clip,use_ema,ema_decay
0,mhc_384x10_reg,full,1.0,0.030143,1.0,0.0737,1.0,1.0,384,10,...,1e-08,0.8,0.9,0.316,0.1,12,0.0001,1.0,True,0.999
1,mhc_384x8_main,full,1.0,0.064275,1.0,0.2942,1.0,1.0,384,8,...,1e-08,0.8,0.9,0.316,0.1,10,0.0001,1.0,True,0.999
2,mhc_256x6_fast,full,1.0,0.096908,1.0,0.2011,1.0,1.0,256,6,...,1e-08,0.8,0.9,0.316,0.1,9,0.0001,1.0,True,0.999



Saved best artifacts:
  best model (fold packs) -> /kaggle/working/recodai_luc_gate_artifacts/best_gate_model.pt
  best config             -> /kaggle/working/recodai_luc_gate_artifacts/best_gate_config.json
  opt results             -> /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.csv
  fold detail             -> /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_fold_details.csv
  stage1 cache            -> /kaggle/working/recodai_luc_gate_artifacts/opt_search/stage1_results.csv


# Final Training (Train on Full Data)

In [6]:
# ============================================================
# Step 5 — Final Training (Train on Full Data) — TRANSFORMER ONLY
# REVISI FULL v4.1 (match Step 4 v3.2: mHC-lite differentiable + EMA + accum)
#
# Fix v4.1:
# - STRICT match best cfg dari Step 4 (tanpa auto “bigger model” kecuali kamu ON-kan)
# - AMP/GradScaler aman (CPU = no-amp, no-scaler)
# - CFG guard: d_model harus divisible oleh n_heads (auto-fix saat fallback/OOM)
# - Internal val case-level: robust fallback kalau val kosong / 1-class
# - OOM fallback lebih aman: turunkan batch -> d_model -> n_layers -> n_heads
# - Scheduler total_steps guard (>=1) + warmup clamp
#
# Output:
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_model.pt
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_bundle.json
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path
from contextlib import nullcontext

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import log_loss, roc_auc_score

# ----------------------------
# 0) REQUIRE
# ----------------------------
# Fail fast with an actionable message when the upstream cells were not run.
if "df_train_tabular" not in globals():
    raise RuntimeError("Missing `df_train_tabular`. Jalankan Step 2 dulu.")
if "FEATURE_COLS" not in globals():
    raise RuntimeError("Missing `FEATURE_COLS`. Jalankan Step 2 dulu.")

# Work on private copies so this cell never mutates the Step 2 objects in place.
df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

need_cols = {"uid", "case_id", "y"}
# sorted(): iterating a set of strings has no stable order across runs, so the
# error message would otherwise change from run to run.
miss = sorted(c for c in need_cols if c not in df_train_tabular.columns)
if miss:
    raise ValueError(f"df_train_tabular missing columns: {miss}")

# Report missing feature columns explicitly instead of a bare pandas KeyError.
missing_feats = [c for c in FEATURE_COLS if c not in df_train_tabular.columns]
if missing_feats:
    raise ValueError(
        f"df_train_tabular missing {len(missing_feats)} FEATURE_COLS, e.g. {missing_feats[:10]}"
    )
if len(df_train_tabular) == 0:
    raise ValueError("df_train_tabular is empty; nothing to train on.")

X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)

# Non-finite features would poison the standardizer statistics; replace them with 0.
if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("Final training data:")
print(f"  rows={len(y_all)} | pos%={float(y_all.mean())*100:.2f} | n_features={X_all.shape[1]}")

# ----------------------------
# 1) Load best cfg (Step 4 output)
# ----------------------------
best_bundle = None
source = None

cfg_path = OUT_DIR / "best_gate_config.json"
best_model_candidates = [
    OUT_DIR / "best_gate_model.pt",
    OUT_DIR / "best_gate_model.pth",
]

# Prefer the bundle Step 4 left in memory; otherwise read its JSON from disk.
if "BEST_GATE_BUNDLE" in globals() and isinstance(BEST_GATE_BUNDLE, dict):
    best_bundle = BEST_GATE_BUNDLE
    source = "memory(BEST_GATE_BUNDLE)"
elif cfg_path.exists():
    best_bundle = json.loads(cfg_path.read_text())
    source = str(cfg_path)

if isinstance(best_bundle, dict) and isinstance(best_bundle.get("cfg", None), dict):
    base_cfg = dict(best_bundle["cfg"])
    print("\nLoaded cfg from:", source)
else:
    base_cfg = {}
    print("\nNo best_gate_config found. Using strong default cfg.")

# Optional: carry the Step 4 fold_packs (CV ensemble) into the final artifact.
fold_packs_from_step4 = None
best_gate_model_path = next((p for p in best_model_candidates if p.exists()), None)

if best_gate_model_path is not None:
    try:
        try:
            # The checkpoint is a pickled dict written by Step 4 of this same
            # pipeline (a trusted local file). Packs built like the ones in this
            # step presumably hold numpy arrays (mu/sig), which torch>=2.6's
            # default `weights_only=True` refuses to unpickle -- so opt out.
            obj = torch.load(best_gate_model_path, map_location="cpu", weights_only=False)
        except TypeError:
            # torch < 1.13 does not accept the `weights_only` keyword at all.
            obj = torch.load(best_gate_model_path, map_location="cpu")
        if isinstance(obj, dict) and isinstance(obj.get("fold_packs", None), list):
            fold_packs_from_step4 = obj["fold_packs"]
            print("Loaded fold_packs from:", str(best_gate_model_path))
    except Exception as e:
        # Best-effort: a missing/incompatible checkpoint only means no fold ensemble.
        print("Warning: failed to load best_gate_model.* fold_packs:", repr(e))

# ----------------------------
# 2) Device + seed
# ----------------------------
def seed_everything(seed: int):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) with `seed`.

    cuDNN is deliberately left non-deterministic with autotuning enabled
    (favouring speed over bit-exact reproducibility on GPU).
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

# Reuse the Step 4 seed when available (either key name), else the fixed default.
FINAL_SEED = (
    int(best_bundle.get("seed", best_bundle.get("random_seed", 2025)))
    if isinstance(best_bundle, dict)
    else 2025
)

seed_everything(FINAL_SEED)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Mixed precision is only used on GPU.
use_amp = device.type == "cuda"

# Allow TF32-class matmuls where supported; builds without this setter are fine.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

vram_gb = (
    float(torch.cuda.get_device_properties(0).total_memory / (1024**3))
    if device.type == "cuda"
    else None
)

print("\nDevice:", device, "| AMP:", use_amp, "| VRAM_GB:", (f"{vram_gb:.1f}" if vram_gb else "CPU"))

# ----------------------------
# 3) Training policy
# ----------------------------
# By default the model strictly follows the Step 4 best cfg (no automatic enlargement).
ALLOW_UPSCALE = False           # True: let maybe_upscale_cfg adjust the cfg by available VRAM
USE_INTERNAL_VAL = True         # choose best_epoch on a case-level internal validation split
VAL_FRAC_CASE = 0.08            # fraction of cases held out for that validation split
EARLY_STOP = True

N_SEEDS = 1                     # runtime budget: a single seed by default

# Target effective batch size (batch_size * accum_steps).
if device.type == "cuda":
    TARGET_EFF_BATCH = 1024
else:
    TARGET_EFF_BATCH = 256

# ----------------------------
# 4) Dataset + Standardizer
# ----------------------------
class TabDataset(Dataset):
    """Map-style dataset over a dense feature matrix with optional labels.

    Features and labels are converted to float32 tensors once, up front.
    Items are `x` when no labels were given, otherwise the pair `(x, y)`.
    """

    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        if y is None:
            self.y = None
        else:
            self.y = torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = self.X[idx]
        return features if self.y is None else (features, self.y[idx])

def fit_standardizer(X_tr: np.ndarray):
    """Return per-column (mean, std) as float32, accumulated in float64.

    Columns whose std is below 1e-8 get a std of 1.0 so that dividing by it
    later is safe (constant columns become all-zero after centering).
    """
    center = np.mean(X_tr, axis=0, dtype=np.float64)
    spread = np.std(X_tr, axis=0, dtype=np.float64)
    spread = np.where(spread < 1e-8, 1.0, spread)
    return center.astype(np.float32), spread.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Standardize `X_in` with statistics from `fit_standardizer`; returns float32."""
    centered = X_in - mu
    return (centered / sig).astype(np.float32)

# ----------------------------
# 5) Metrics helpers
# ----------------------------
def safe_logloss(y_true, p):
    """Binary log-loss with probabilities clipped to [1e-8, 1 - 1e-8].

    `labels=[0, 1]` lets sklearn score a single-class `y_true` as well.
    """
    probs = np.asarray(p, dtype=np.float64)
    probs = np.clip(probs, 1e-8, 1 - 1e-8)
    return float(log_loss(y_true, probs, labels=[0, 1]))

def safe_auc(y_true, p):
    """ROC-AUC as a float, or None when `y_true` holds fewer than two classes."""
    has_both_classes = len(np.unique(y_true)) >= 2
    return float(roc_auc_score(y_true, p)) if has_both_classes else None

# ----------------------------
# 6) Model: FTTransformer_MHCLite (MATCH Step 4)
# ----------------------------
class RMSNorm(nn.Module):
    """Root-mean-square norm over the last dim with a learned per-channel gain.

    No centering and no bias. The mean of squares is accumulated in at least
    float32: under AMP, float16 activations with |x| > ~256 would overflow in
    ``x * x`` (float16 max is 65504), making rsqrt(inf) = 0 and zeroing the output.
    The normalized values are cast back to the input dtype before the gain is
    applied, so float32 inputs behave exactly as before.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # float16/bfloat16 -> float32; float32/float64 are left unchanged.
        x_acc = x.to(torch.promote_types(x.dtype, torch.float32))
        rms = torch.mean(x_acc * x_acc, dim=-1, keepdim=True)
        normed = (x_acc * torch.rsqrt(rms + self.eps)).to(x.dtype)
        return normed * self.weight

def sinkhorn_knopp(P, tmax=20, eps=1e-6):
    """Alternately normalize rows then columns of `P` for `tmax` rounds.

    Entries (and every normalizer) are clamped to at least `eps`. Because each
    round ends with the column pass, column sums are ~1 on return; row sums
    approach 1 as `tmax` grows.
    """
    mat = P.clamp_min(eps)
    for _ in range(int(tmax)):
        row_totals = mat.sum(dim=-1, keepdim=True).clamp_min(eps)
        mat = mat / row_totals
        col_totals = mat.sum(dim=-2, keepdim=True).clamp_min(eps)
        mat = mat / col_totals
    return mat

class MHCLite(nn.Module):
    """Lightweight multi-stream residual mixer ("mHC-lite").

    Holds ``n_streams`` parallel d_model vectors per sample. Each call builds a
    per-sample n x n mixing matrix

        H = (1 - alpha) * I + alpha * SinkhornKnopp(softplus(MLP(RMSNorm(cls_vec))))

    mixes the streams with it, adds ``cls_vec`` to every stream and applies
    dropout. ``alpha`` is a learned scalar kept in (0, 1) by a sigmoid over
    ``alpha_logit``. Submodule/parameter names define the state_dict keys, so
    they are kept unchanged.
    """

    def __init__(self, d_model, n_streams=4, alpha_init=0.01, tmax=20, dropout=0.0):
        super().__init__()
        self.n = int(n_streams)
        self.tmax = int(tmax)
        self.drop = nn.Dropout(float(dropout))

        self.norm = RMSNorm(d_model, eps=1e-6)
        n_logits = self.n * self.n
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, n_logits),
        )
        self.softplus = nn.Softplus()

        # Clamp away from {0, 1} so the initial logit stays finite.
        alpha0 = min(max(float(alpha_init), 1e-4), 1.0 - 1e-4)
        self.alpha_logit = nn.Parameter(
            torch.log(torch.tensor(alpha0 / (1 - alpha0), dtype=torch.float32))
        )

    def forward(self, streams, cls_vec):
        # streams: (batch, n_streams, d_model); cls_vec is added to every stream.
        batch, n_str, _ = streams.shape
        raw_logits = self.mlp(self.norm(cls_vec)).view(batch, n_str, n_str)
        mixing = sinkhorn_knopp(self.softplus(raw_logits), tmax=self.tmax, eps=1e-6)

        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype, device=streams.device)
        identity = torch.eye(n_str, device=streams.device, dtype=streams.dtype)
        identity = identity.unsqueeze(0).expand(batch, -1, -1)
        blend = (1.0 - alpha) * identity + alpha * mixing

        out = torch.einsum("bij,bjd->bid", blend, streams) + cls_vec.unsqueeze(1)
        return self.drop(out)

class TransformerBlock(nn.Module):
    """Pre-norm Transformer block: RMSNorm -> self-attention and RMSNorm -> GELU FFN.

    Each sub-layer output goes through dropout before its residual add.
    Submodule names (and the FFN's Sequential layout) define state_dict keys.
    """

    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        p_drop = float(dropout)
        hidden = int(ffn_mult) * d_model

        self.norm1 = RMSNorm(d_model, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=int(n_heads),
            dropout=float(attn_dropout),
            batch_first=True,
        )
        self.drop1 = nn.Dropout(p_drop)

        self.norm2 = RMSNorm(d_model, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x):
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = x + self.drop1(attended)
        return x + self.drop2(self.ffn(self.norm2(x)))

class FTTransformer_MHCLite(nn.Module):
    """FT-Transformer-style binary classifier for tabular features with mHC-lite streams.

    Each scalar feature becomes a d_model token (value * w + b + feature
    embedding), and a CLS token is prepended. Before every Transformer block the
    CLS slot is replaced by the mean of the residual streams. After the block,
    an MHCLite layer remixes the streams using the block's CLS output. The head
    reads the normalized stream mean and returns one logit per row.

    Parameters are created in the same order as before (RNG-consistent init),
    and attribute names define the state_dict keys.
    """

    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4,
                 dropout=0.2, attn_dropout=0.1,
                 n_streams=4, alpha_init=0.01, sinkhorn_tmax=20, mhc_dropout=0.0,
                 feat_token_drop_p=0.0):
        super().__init__()
        n_feat = int(n_features)
        width = int(d_model)
        self.n_features = n_feat
        self.d_model = width
        self.n_layers = int(n_layers)
        self.feat_token_drop_p = float(feat_token_drop_p)

        # Per-feature affine tokenizer plus a learned feature-identity embedding.
        self.w = nn.Parameter(torch.randn(n_feat, width) * 0.02)
        self.b = nn.Parameter(torch.zeros(n_feat, width))
        self.feat_emb = nn.Parameter(torch.randn(n_feat, width) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, width) * 0.02)
        self.in_drop = nn.Dropout(float(dropout))

        self.blocks = nn.ModuleList(
            TransformerBlock(
                d_model=width,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout,
            )
            for _ in range(self.n_layers)
        )
        self.mhc = nn.ModuleList(
            MHCLite(
                d_model=width,
                n_streams=n_streams,
                alpha_init=alpha_init,
                tmax=sinkhorn_tmax,
                dropout=mhc_dropout,
            )
            for _ in range(self.n_layers)
        )

        self.out_norm = RMSNorm(width, eps=1e-6)
        self.head = nn.Sequential(
            nn.Linear(width, width),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(width, 1),
        )

    def forward(self, x):
        # x: (batch, n_features) -> tokens: (batch, n_features, d_model)
        tokens = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)
        tokens = tokens + self.feat_emb.unsqueeze(0)

        # Train-time feature-token dropout: zero out whole tokens per sample.
        if self.training and self.feat_token_drop_p > 0:
            keep_mask = torch.rand(tokens.shape[0], tokens.shape[1], device=tokens.device) > self.feat_token_drop_p
            tokens = tokens * keep_mask.to(tokens.dtype).unsqueeze(-1)

        batch = tokens.size(0)
        seq = self.in_drop(torch.cat([self.cls.expand(batch, -1, -1), tokens], dim=1))

        # Every stream starts as a copy of the (dropped-out) CLS token.
        n_str = self.mhc[0].n
        streams = seq[:, 0, :].unsqueeze(1).expand(batch, n_str, self.d_model).contiguous()

        for block, mixer in zip(self.blocks, self.mhc):
            # The CLS slot entering each block is the mean over the streams.
            seq = block(torch.cat([streams.mean(dim=1, keepdim=True), seq[:, 1:, :]], dim=1))
            streams = mixer(streams, seq[:, 0, :])

        return self.head(self.out_norm(streams.mean(dim=1))).squeeze(-1)

# ----------------------------
# 7) EMA + scheduler (optimizer-steps)
# ----------------------------
class EMA:
    """Exponential moving average of a model's trainable parameters.

    `shadow` holds detached clones (same device/dtype as the parameters) updated as
    shadow = decay * shadow + (1 - decay) * param. `apply_shadow` swaps the
    averaged weights into the model and stashes the live ones in `backup`;
    `restore` puts the live weights back and clears `backup`.
    Parameters with requires_grad=False are ignored throughout.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }
        self.backup = {}

    @staticmethod
    def _trainable(model):
        """Yield (name, parameter) for parameters that require grad."""
        return ((name, param) for name, param in model.named_parameters() if param.requires_grad)

    @torch.no_grad()
    def update(self, model: nn.Module):
        keep = self.decay
        for name, param in self._trainable(model):
            self.shadow[name].mul_(keep).add_(param.detach(), alpha=(1.0 - keep))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, param in self._trainable(model):
            self.backup[name] = param.detach().clone()
            param.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, param in self._trainable(model):
            param.copy_(self.backup[name])
        self.backup = {}

def make_warmup_step_scheduler(optimizer, total_steps, warmup_steps,
                              r1=0.8, r2=0.9, d1=0.316, d2=0.1):
    """LambdaLR with linear warmup followed by two step decays (step = optimizer step).

    * steps [0, warmup): multiplier (step + 1) / warmup  (decays not applied here)
    * from int(r1 * total) on: multiplier d1
    * from int(r2 * total) on: multiplier d1 * d2
    `total_steps` is floored at 1, and `warmup_steps` is clamped to [0, total_steps].
    """
    n_total = int(max(1, total_steps))
    n_warm = int(max(0, min(int(warmup_steps), n_total)))
    first_drop = int(float(r1) * n_total)
    second_drop = int(float(r2) * n_total)

    def lr_lambda(step):
        in_warmup = n_warm > 0 and step < n_warm
        if in_warmup:
            return float(step + 1) / float(max(1, n_warm))
        factor = 1.0
        if step >= first_drop:
            factor *= float(d1)
        if step >= second_drop:
            factor *= float(d2)
        return factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None) -> np.ndarray:
    """Return sigmoid probabilities for every sample in `loader` as float32 numpy.

    Batches may be a bare feature tensor or a (features, labels, ...) tuple/list;
    only the first element is used. When `ema` is given, its averaged weights are
    swapped in for inference. The live weights are always restored afterwards,
    even if inference raises, so the model is never left holding the EMA copy.
    An empty loader yields an empty array instead of crashing in np.concatenate.
    """
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)

    ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()
    ps = []
    try:
        for batch in loader:
            xb = batch[0] if isinstance(batch, (list, tuple)) else batch
            xb = xb.to(device, non_blocking=True)
            with ctx:
                logits = model(xb)
                p = torch.sigmoid(logits)
            # .float(): numpy has no bfloat16, and half results upcast harmlessly.
            ps.append(p.detach().float().cpu().numpy())
    finally:
        if ema is not None:
            ema.restore(model)

    if not ps:
        return np.zeros(0, dtype=np.float32)
    return np.concatenate(ps, axis=0).astype(np.float32)

# ----------------------------
# 8) CFG merge + guards
# ----------------------------
CFG = {
    # architecture
    "d_model": 384,
    "n_layers": 8,
    "n_heads": 8,
    "ffn_mult": 4,
    "dropout": 0.20,
    "attn_dropout": 0.10,

    # mHC-lite stream mixer
    "n_streams": 4,
    "alpha_init": 0.01,
    "sinkhorn_tmax": 20,
    "mhc_dropout": 0.0,

    # regularization
    "feat_token_drop_p": 0.05,
    "input_noise_std": 0.01,
    "focal_gamma": 1.5,
    "label_smoothing": 0.00,

    # optimizer (AdamW)
    "lr": 2e-4,
    "weight_decay": 1.0e-2,
    "beta1": 0.9,
    "beta2": 0.95,
    "adam_eps": 1e-8,

    # LR schedule: warmup, then two step decays
    "warmup_frac": 0.10,
    "lr_decay_ratio1": 0.8,
    "lr_decay_ratio2": 0.9,
    "lr_decay_rate1": 0.316,
    "lr_decay_rate2": 0.1,

    # training loop (smaller / shorter on CPU)
    "batch_size": 512 if device.type == "cuda" else 256,
    "epochs": 75 if device.type == "cuda" else 35,
    "accum_steps": 2 if device.type == "cuda" else 1,
    "grad_clip": 1.0,
    "patience": 10,
    "min_delta": 1e-4,

    # EMA of weights
    "use_ema": True,
    "ema_decay": 0.999,
}

# Overlay the Step 4 best cfg (strict match); keys this step does not know are ignored.
for k, v in base_cfg.items():
    if k not in CFG:
        continue
    CFG[k] = v

def cpu_safe_cfg(cfg: dict):
    """Return a copy of `cfg`; on non-CUDA devices, cap size/runtime knobs and disable accumulation."""
    out = dict(cfg)
    if device.type == "cuda":
        return out
    # Keep CPU runs stable and reasonably fast.
    caps = {"d_model": 256, "n_layers": 6, "n_heads": 8, "batch_size": 256, "epochs": 35}
    for key, cap in caps.items():
        out[key] = min(int(out[key]), cap)
    out["accum_steps"] = 1
    return out

def maybe_upscale_cfg(cfg: dict):
    """Return a copy of `cfg`, adjusted by VRAM tier when upscaling is enabled on CUDA.

    Active only if device is CUDA, ALLOW_UPSCALE is truthy and vram_gb is known.
    Tiers: >= 24 GB -> batch_size capped at 512; >= 16 GB -> capped at 384;
    in both tiers accum_steps is raised to at least 2. Below 16 GB: unchanged.

    NOTE(review): min() means batch_size can only stay the same or shrink here,
    never grow, despite the "upscale" name -- confirm the intended direction.
    In this cell, accum_steps is recomputed from TARGET_EFF_BATCH right after.
    """
    out = dict(cfg)
    if device.type != "cuda" or not ALLOW_UPSCALE or vram_gb is None:
        return out
    if vram_gb >= 24:
        batch_cap = 512
    elif vram_gb >= 16:
        batch_cap = 384
    else:
        return out
    out["batch_size"] = min(int(out["batch_size"]), batch_cap)
    out["accum_steps"] = max(int(out.get("accum_steps", 2)), 2)
    return out

def fix_heads_divisibility(cfg: dict):
    """Return a copy of `cfg` whose n_heads is the largest divisor of d_model <= n_heads.

    n_heads below 1 is treated as 1. Since 1 divides everything, a valid value
    always exists. The input dict is not modified.
    """
    out = dict(cfg)
    width = int(out["d_model"])
    heads = max(1, int(out["n_heads"]))
    out["n_heads"] = next(k for k in range(heads, 0, -1) if width % k == 0)
    return out

# Device-dependent shaping: CPU caps first, then the optional VRAM-based tweak.
CFG = maybe_upscale_cfg(cpu_safe_cfg(CFG))

# accum_steps = ceil(TARGET_EFF_BATCH / batch_size), so batch_size * accum_steps >= TARGET_EFF_BATCH.
CFG["accum_steps"] = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

# Last guard: nn.MultiheadAttention needs n_heads to divide d_model.
CFG = fix_heads_divisibility(CFG)

print("\nCFG (final):")
for k in (
    "d_model", "n_layers", "n_heads", "ffn_mult", "dropout", "attn_dropout",
    "n_streams", "alpha_init", "sinkhorn_tmax", "mhc_dropout",
    "batch_size", "accum_steps", "epochs", "lr", "weight_decay", "warmup_frac", "patience",
):
    print(f"  {k}: {CFG[k]}")

# ----------------------------
# 9) Internal val split (case_id-safe) + robust fallback
# ----------------------------
def make_case_split(df: pd.DataFrame, val_frac=0.08, seed=2025):
    """Return a boolean mask of length len(df) marking internal-validation rows.

    Primary strategy: hold out whole ``case_id`` groups, stratified by the
    case-level label (max of ``y`` within the case), so no case straddles
    train and val. If only one case-level class exists, a random row subset is
    used instead. If the case split yields fewer than 32 val rows or a
    single-class val set, a label-stratified row-level split is used instead.

    Every hold-out count is capped so the training side keeps at least one item
    of each pool it draws from (cases of each class, rows of each class, rows
    overall) whenever that pool has more than one item. Previously a single
    positive case went entirely to val, and with fewer than 32 rows the row
    fallback could move every row into val. The returned val set is non-empty
    whenever df is non-empty.
    """
    n_rows = len(df)
    frac = float(val_frac)
    y = df["y"].to_numpy()

    def _cap(n_want, n_avail):
        # Hold out at most n_avail - 1 items so the training side keeps one.
        return max(0, min(int(n_want), n_avail - 1))

    g = df.groupby("case_id")["y"].max().reset_index().rename(columns={"y": "case_y"})
    pos_cases = g.loc[g["case_y"] == 1, "case_id"].astype(str).to_numpy()
    neg_cases = g.loc[g["case_y"] == 0, "case_id"].astype(str).to_numpy()

    rng = np.random.RandomState(int(seed))
    rng.shuffle(pos_cases)
    rng.shuffle(neg_cases)

    if len(pos_cases) == 0 or len(neg_cases) == 0:
        # Only one case-level class: stratification is impossible, so take a
        # plain random row subset (leaving at least one training row if n_rows > 1).
        idx = np.arange(n_rows)
        rng.shuffle(idx)
        n_val = min(max(1, int(n_rows * frac)), max(1, n_rows - 1))
        is_val = np.zeros(n_rows, dtype=bool)
        is_val[idx[:n_val]] = True
        return is_val

    n_val_pos = _cap(max(1, int(len(pos_cases) * frac)), len(pos_cases))
    n_val_neg = _cap(max(1, int(len(neg_cases) * frac)), len(neg_cases))

    val_cases = np.concatenate([pos_cases[:n_val_pos], neg_cases[:n_val_neg]])
    val_set = set(val_cases.tolist())
    is_val = df["case_id"].astype(str).isin(val_set).to_numpy(dtype=bool)

    # Too small / degenerate case split -> label-stratified row split.
    if int(is_val.sum()) < 32 or len(np.unique(y[is_val])) < 2:
        idx_pos = np.where(y == 1)[0]
        idx_neg = np.where(y == 0)[0]
        rng.shuffle(idx_pos)
        rng.shuffle(idx_neg)
        n_val = min(max(32, int(n_rows * frac)), max(1, n_rows - 1))
        n_val_pos = _cap(max(1, int(n_val * float(df["y"].mean()))), len(idx_pos))
        n_val_neg = _cap(n_val - n_val_pos, len(idx_neg))
        val_idx = np.concatenate([idx_pos[:n_val_pos], idx_neg[:n_val_neg]])
        is_val = np.zeros(n_rows, dtype=bool)
        is_val[val_idx] = True
        # Tiny frames (every pool has <= 1 row) could leave val empty; keep one row.
        if not is_val.any() and n_rows > 0:
            is_val[rng.randint(n_rows)] = True

    return is_val

# ----------------------------
# 10) Training helpers (workers/amp/scaler)
# ----------------------------
def make_loader(ds, batch_size, shuffle):
    """Build a DataLoader: 2 persistent workers on machines with >= 4 CPUs, else 0.

    Memory is pinned only when training on CUDA; the last partial batch is kept.
    """
    n_workers = 0 if (os.cpu_count() or 2) < 4 else 2
    return DataLoader(
        ds,
        batch_size=int(batch_size),
        shuffle=bool(shuffle),
        num_workers=n_workers,
        pin_memory=device.type == "cuda",
        drop_last=False,
        persistent_workers=n_workers > 0,
    )

def build_loss_fn(y_for_pos_weight, cfg):
    """Return a class-balanced (optionally focal, optionally smoothed) BCE-with-logits loss.

    pos_weight = #neg / max(1, #pos), computed from `y_for_pos_weight`.
    cfg["label_smoothing"] > 0 pulls targets toward 0.5. cfg["focal_gamma"] > 0
    multiplies each element by (1 - p_t) ** gamma. The loss is averaged over elements.
    """
    n_pos = float(np.sum(y_for_pos_weight))
    n_neg = float(len(y_for_pos_weight) - n_pos)
    pos_weight = torch.tensor([n_neg / max(1.0, n_pos)], device=device, dtype=torch.float32)

    gamma = float(cfg.get("focal_gamma", 0.0))
    smooth = float(cfg.get("label_smoothing", 0.0))

    def loss_fn(logits, targets):
        if smooth > 0:
            targets = targets * (1.0 - smooth) + 0.5 * smooth
        per_elem = F.binary_cross_entropy_with_logits(
            logits, targets, reduction="none", pos_weight=pos_weight
        )
        if gamma > 0:
            prob = torch.sigmoid(logits)
            p_t = prob * targets + (1 - prob) * (1 - targets)
            per_elem = per_elem * (1.0 - p_t).clamp_min(0.0).pow(gamma)
        return per_elem.mean()

    return loss_fn

def build_model(cfg, n_features):
    """Instantiate FTTransformer_MHCLite from `cfg` on `device` (n_heads made to divide d_model)."""
    cfg = fix_heads_divisibility(cfg)
    arch = dict(
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        alpha_init=float(cfg["alpha_init"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    )
    model = FTTransformer_MHCLite(n_features=int(n_features), **arch)
    return model.to(device)

def build_opt_and_sch(model, cfg, steps_per_epoch, epochs):
    """Create the AdamW optimizer and warmup/step-decay LR scheduler for `model`.

    `steps_per_epoch` must count *optimizer* steps (i.e. already divided by the
    gradient-accumulation factor); the scheduler is sized to
    max(1, epochs * max(1, steps_per_epoch)) total steps, with
    int(warmup_frac * total) warmup steps.

    Returns:
        (optimizer, scheduler)
    """
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )
    # (The previously computed `accum` local was never used: callers already
    # fold accumulation into steps_per_epoch.)
    total_steps = int(max(1, int(epochs) * int(max(1, steps_per_epoch))))
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)
    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]),
        r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]),
        d2=float(cfg["lr_decay_rate2"]),
    )
    return opt, sch

# ----------------------------
# 11) Train with internal val to get best_epoch
# ----------------------------
def train_with_internal_val_get_best_epoch(X_raw, y_raw, cfg, seed=2025):
    """Train on an internal case-level split and return the best epoch by val logloss.

    The split is computed from the global ``df_train_tabular`` (not from X_raw),
    so ``X_raw`` / ``y_raw`` are assumed to be row-aligned with it (true for
    ``X_all`` / ``y_all`` built from that frame in this cell).

    The standardizer is fit on the training rows only. Each epoch runs gradient
    accumulation: ``cfg["accum_steps"]`` micro-batches per optimizer step, with a
    trailing partial group flushed at epoch end. The LR scheduler and the EMA are
    stepped once per optimizer step. Validation is then scored with the EMA
    weights when EMA is enabled.

    Early stopping (when ``EARLY_STOP``) fires after ``cfg["patience"]`` epochs
    without a val-logloss improvement larger than ``cfg["min_delta"]``.

    Returns:
        dict with ``best_epoch`` (1-based) and ``best_val_logloss``. If no epoch
        ever improved (e.g. the val logloss was NaN), ``best_val_logloss`` is
        None and ``best_epoch`` falls back to ``max(12, int(epochs * 0.6))``.
    """
    seed_everything(int(seed))

    # Case-level split from the global frame (see docstring for the alignment assumption).
    is_val = make_case_split(df_train_tabular, val_frac=float(VAL_FRAC_CASE), seed=int(seed))
    tr_idx = np.where(~is_val)[0]
    va_idx = np.where(is_val)[0]

    X_tr, y_tr = X_raw[tr_idx], y_raw[tr_idx]
    X_va, y_va = X_raw[va_idx], y_raw[va_idx]

    # Standardizer statistics come from the training rows only (no val leakage).
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    dl_tr = make_loader(ds_tr, cfg["batch_size"], shuffle=True)
    dl_va = make_loader(ds_va, cfg["batch_size"], shuffle=False)

    model = build_model(cfg, n_features=X_raw.shape[1])

    # pos_weight is derived from the training labels only.
    loss_fn = build_loss_fn(y_tr, cfg)

    # Scheduler length counts optimizer steps: ceil(micro-batches / accum) per epoch.
    accum = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum))
    opt, sch = build_opt_and_sch(model, cfg, steps_per_epoch=steps_per_epoch, epochs=int(cfg["epochs"]))

    # GradScaler only with AMP (CUDA); on CPU plain fp32 backward/step is used.
    scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None

    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    # Sentinel: any finite val logloss beats it in the first epoch.
    best_val = 1e18
    best_epoch = -1
    bad = 0

    print(f"\nInternal val split: train={len(tr_idx)} | val={len(va_idx)} | val_pos%={float(y_va.mean())*100:.2f}")

    ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum = 0.0
        n_sum = 0
        micro = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            # Train-time Gaussian noise on the standardized features.
            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with ctx:
                logits = model(xb)
                # Divide by accum so the accumulated gradient averages over the group.
                loss = loss_fn(logits, yb) / float(accum)

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            micro += 1
            # Multiply accum back in so the logged loss is the unscaled per-sample mean.
            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                if float(cfg.get("grad_clip", 1.0)) > 0:
                    # Gradients must be unscaled before clipping when GradScaler is used.
                    if use_amp:
                        scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))

                if use_amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()

                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # flush last partial accumulation group (same step sequence as above)
        if (micro % accum) != 0:
            if float(cfg.get("grad_clip", 1.0)) > 0:
                if use_amp:
                    scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))

            if use_amp:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()

            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        # Validation uses the EMA weights (when enabled); predict_proba restores the live ones.
        p_va = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va)
        vauc = safe_auc(y_va, p_va)

        improved = (best_val - vll) > float(cfg["min_delta"])
        if improved:
            best_val = float(vll)
            best_epoch = int(epoch) + 1
            bad = 0
        else:
            bad += 1

        print(f"  epoch {epoch+1:03d}/{int(cfg['epochs'])} | tr_loss={loss_sum/max(1,n_sum):.5f} | val_ll={vll:.5f} | val_auc={(vauc if vauc is not None else float('nan')):.5f} | bad={bad}")

        if EARLY_STOP and bad >= int(cfg["patience"]):
            break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # No epoch ever improved (only possible with a non-finite val logloss): use a heuristic epoch count.
    if best_epoch < 1:
        best_epoch = max(12, int(cfg["epochs"] * 0.6))

    return {
        "best_epoch": int(best_epoch),
        "best_val_logloss": float(best_val) if best_val < 1e18 else None,
    }

# ----------------------------
# 12) Train FULL data for fixed epochs (best_epoch)
# ----------------------------
def train_full_fixed_epochs(X_raw, y_raw, cfg, epochs_fixed, seed=2025):
    """Train one model on all rows for exactly ``epochs_fixed`` epochs and return a pack.

    No validation and no early stopping. The standardizer is fit on all of
    ``X_raw``. The optimisation loop (noise augmentation, gradient accumulation
    with a flushed partial group, AMP scaling, clipping, per-optimizer-step
    scheduler and EMA updates) mirrors ``train_with_internal_val_get_best_epoch``.

    When EMA is enabled, the EMA weights are copied into the model at the end
    (``apply_shadow`` without ``restore``), so the returned ``state_dict`` holds
    the averaged weights and ``used_ema_weights`` is True.

    Returns:
        dict with keys: type, state_dict (CPU tensors), mu / sig (numpy float32
        standardizer stats), cfg, seed, train_rows, pos_rate, epochs_fixed,
        accum_steps, train_time_s, used_ema_weights.
    """
    seed_everything(int(seed))

    # Standardizer fit on the full training matrix; stored in the pack for inference.
    mu, sig = fit_standardizer(X_raw)
    Xn = apply_standardizer(X_raw, mu, sig)

    ds = TabDataset(Xn, y_raw)
    dl = make_loader(ds, cfg["batch_size"], shuffle=True)

    model = build_model(cfg, n_features=X_raw.shape[1])
    loss_fn = build_loss_fn(y_raw, cfg)

    # Scheduler length counts optimizer steps: ceil(micro-batches / accum) per epoch.
    accum = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl) / accum))
    opt, sch = build_opt_and_sch(model, cfg, steps_per_epoch=steps_per_epoch, epochs=int(epochs_fixed))

    scaler = torch.cuda.amp.GradScaler(enabled=True) if use_amp else None
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    t0 = time.time()
    ctx = torch.cuda.amp.autocast(enabled=True) if use_amp else nullcontext()

    for epoch in range(int(epochs_fixed)):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum = 0.0
        n_sum = 0
        micro = 0

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            # Train-time Gaussian noise on the standardized features.
            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with ctx:
                logits = model(xb)
                # Divide by accum so the accumulated gradient averages over the group.
                loss = loss_fn(logits, yb) / float(accum)

            if use_amp:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            micro += 1
            # Multiply accum back in so the logged loss is the unscaled per-sample mean.
            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                if float(cfg.get("grad_clip", 1.0)) > 0:
                    # Gradients must be unscaled before clipping when GradScaler is used.
                    if use_amp:
                        scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))

                if use_amp:
                    scaler.step(opt)
                    scaler.update()
                else:
                    opt.step()

                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # flush last partial accumulation group (same step sequence as above)
        if (micro % accum) != 0:
            if float(cfg.get("grad_clip", 1.0)) > 0:
                if use_amp:
                    scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))

            if use_amp:
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()

            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {epoch+1:03d}/{int(epochs_fixed)} | loss={loss_sum/max(1,n_sum):.5f}")

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # save EMA weights if enabled: swap them in permanently (no restore) before exporting state_dict
    if ema is not None:
        ema.apply_shadow(model)

    pack = {
        "type": "mhc_lite_ft_transformer_full_v4",
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": dict(cfg),
        "seed": int(seed),
        "train_rows": int(len(y_raw)),
        "pos_rate": float(np.mean(y_raw)),
        "epochs_fixed": int(epochs_fixed),
        "accum_steps": int(accum),
        "train_time_s": float(time.time() - t0),
        "used_ema_weights": bool(ema is not None),
    }
    return pack

# ----------------------------
# 13) OOM fallback policy
# ----------------------------
def apply_oom_fallback(cfg: dict):
    """
    Return a smaller copy of *cfg* to retry training after a CUDA out-of-memory error.

    Steps (the input dict itself is not modified):
      1) halve batch_size (floor 64) and recompute accum_steps so that
         batch_size * accum_steps >= TARGET_EFF_BATCH;
      2) shrink d_model by 64 (floor 256) and n_layers by 2 (floor 6);
      3) re-validate heads via fix_heads_divisibility (which may reduce n_heads).

    A knob that is already at or below its floor is left unchanged. Clamping with a
    plain max(floor, ...) would instead *grow* it (e.g. batch_size=32 -> 64), which
    makes another OOM more likely, not less. Once every knob sits at its floor, the
    returned config only differs by whatever fix_heads_divisibility changes.
    """
    cfg = dict(cfg)

    def _shrink(current, proposed, floor):
        # clamp the proposed value to the floor, but never above the current value
        return min(current, max(floor, proposed))

    # 1) reduce batch first
    batch = int(cfg["batch_size"])
    cfg["batch_size"] = max(1, _shrink(batch, batch // 2, 64))

    # recompute accum to keep eff batch
    cfg["accum_steps"] = max(1, int(math.ceil(TARGET_EFF_BATCH / int(cfg["batch_size"]))))

    # 2) reduce width/depth gradually
    d_model = int(cfg["d_model"])
    cfg["d_model"] = _shrink(d_model, d_model - 64, 256)
    n_layers = int(cfg["n_layers"])
    cfg["n_layers"] = _shrink(n_layers, n_layers - 2, 6)

    # 3) fix heads divisibility (also may reduce heads)
    cfg = fix_heads_divisibility(cfg)

    return cfg

# ----------------------------
# 14) Train final (OOM-safe)
# ----------------------------
final_full_packs = []
internal_val_infos = []

# Max training attempts per seed; each CUDA OOM shrinks cfg_run via apply_oom_fallback().
MAX_OOM_ATTEMPTS = 6

for s in range(int(N_SEEDS)):
    seed_i = FINAL_SEED + s
    print(f"\n[Final Train v4.1] seed={seed_i}")

    # every seed starts again from the full-size config
    cfg_run = dict(CFG)

    # retry loop for OOM
    for attempt in range(MAX_OOM_ATTEMPTS):
        hit_oom = False
        info = None
        try:
            # Phase A: get best_epoch from internal val
            if USE_INTERNAL_VAL:
                info = train_with_internal_val_get_best_epoch(X_all, y_all, cfg_run, seed=seed_i)
                best_epoch = int(info["best_epoch"])

                # retrain full data (+5% for stability, floored at 12, capped by cfg_run["epochs"])
                E_FULL = int(min(int(cfg_run["epochs"]), max(12, round(best_epoch * 1.05))))
                print(f"\nBest_epoch(from internal val)={best_epoch} -> Retrain FULL for E_FULL={E_FULL}")
            else:
                E_FULL = int(cfg_run["epochs"])

            # Phase B: retrain on all rows for the fixed epoch budget
            full_pack = train_full_fixed_epochs(X_all, y_all, cfg_run, epochs_fixed=E_FULL, seed=seed_i)

        except RuntimeError as e:
            if ("out of memory" in str(e).lower()) and device.type == "cuda":
                hit_oom = True
            else:
                raise

        if hit_oom:
            # Recover OUTSIDE the except clause: while it is active, the traceback still pins the
            # failed attempt's frames (model, activations, optimizer state), so their GPU memory
            # could not be released by empty_cache().
            print(f"  OOM detected (attempt {attempt+1}). Applying fallback.")
            gc.collect()
            torch.cuda.empty_cache()
            cfg_run = apply_oom_fallback(cfg_run)
            print("  New CFG after fallback:")
            for k in ["d_model","n_layers","n_heads","batch_size","accum_steps","epochs"]:
                print(f"    {k}: {cfg_run[k]}")
            continue

        # Success: record internal-val info only for the attempt that actually produced the pack,
        # so an OOM during Phase B cannot leave a stale duplicate entry for this seed.
        if info is not None:
            internal_val_infos.append({"seed": seed_i, **info, "cfg_used": dict(cfg_run)})
        final_full_packs.append(full_pack)
        break
    else:
        # every attempt ran out of memory: make the lost seed visible instead of dropping it silently
        print(f"  WARNING: seed={seed_i} produced no model after {MAX_OOM_ATTEMPTS} OOM attempts; skipping this seed.")

    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

if len(final_full_packs) == 0:
    raise RuntimeError("Final training failed: no full_packs produced.")

# ----------------------------
# 15) Save artifacts
# ----------------------------
final_model_path = OUT_DIR / "final_gate_model.pt"

# threshold recommended from Step 4 (OOF)
best_thr = None
if isinstance(best_bundle, dict):
    best_thr = best_bundle.get("oof_best_thr", None)
# Normalize to a plain Python float: numpy scalars such as np.float32 make json.dumps raise
# TypeError below, and a non-numeric value is useless as a threshold anyway.
if best_thr is not None:
    try:
        best_thr = float(best_thr)
    except (TypeError, ValueError):
        print("  Warning: ignoring non-numeric oof_best_thr:", repr(best_thr))
        best_thr = None

torch.save(
    {
        "type": "final_gate_v4_1",
        "feature_cols": FEATURE_COLS,

        # keep fold ensemble from Step 4 if available
        "fold_packs": fold_packs_from_step4,

        # full-data trained packs (list; usually 1 seed)
        "full_packs": final_full_packs,

        # training meta
        "internal_val_infos": internal_val_infos,
        "recommended_thr": best_thr,
        "bundle_source": source,
        "seed_base": int(FINAL_SEED),
    },
    final_model_path
)

final_bundle = {
    "type": "final_gate_v4_1",
    "feature_cols": FEATURE_COLS,
    "n_seeds": int(len(final_full_packs)),
    "seeds": [int(p["seed"]) for p in final_full_packs],
    # cfg_final is the *requested* config; the OOM fallback may have shrunk it per seed,
    # so also record the config each saved pack was actually trained with.
    "cfg_final": dict(CFG),
    "cfg_used_per_seed": [dict(p.get("cfg", {})) for p in final_full_packs],

    "use_internal_val": bool(USE_INTERNAL_VAL),
    "val_frac_case": float(VAL_FRAC_CASE) if USE_INTERNAL_VAL else 0.0,
    "early_stop": bool(EARLY_STOP) if USE_INTERNAL_VAL else False,

    "train_rows": int(len(y_all)),
    "pos_rate": float(np.mean(y_all)),

    "recommended_thr": best_thr,
    "has_fold_packs_from_step4": bool(fold_packs_from_step4 is not None),
    "best_cfg_source": source,
    "notes": "Final model uses Step4-matched mHC-lite FTTransformer with EMA+accum; best_epoch estimated via internal case-level val then retrain full.",
}

final_bundle_path = OUT_DIR / "final_gate_bundle.json"
final_bundle_path.write_text(json.dumps(final_bundle, indent=2))

print("\nSaved final training artifacts:")
print("  model  ->", final_model_path)
print("  bundle ->", final_bundle_path)

# Export globals
FINAL_GATE_MODEL_PT = str(final_model_path)
FINAL_GATE_BUNDLE = final_bundle


Final training data:
  rows=5176 | pos%=54.08 | n_features=84

Loaded cfg from: memory(BEST_GATE_BUNDLE)

Device: cuda | AMP: True | VRAM_GB: 15.9

CFG (final):
  d_model: 384
  n_layers: 10
  n_heads: 8
  ffn_mult: 4
  dropout: 0.24
  attn_dropout: 0.12
  n_streams: 4
  alpha_init: 0.01
  sinkhorn_tmax: 20
  mhc_dropout: 0.0
  batch_size: 256
  accum_steps: 4
  epochs: 75
  lr: 0.00016
  weight_decay: 0.015
  warmup_frac: 0.1
  patience: 12

[Final Train v4.1] seed=2025

Internal val split: train=4762 | val=414 | val_pos%=51.93
  epoch 001/75 | tr_loss=0.22686 | val_ll=0.69988 | val_auc=0.39817 | bad=0
  epoch 002/75 | tr_loss=0.19510 | val_ll=0.69912 | val_auc=0.40799 | bad=0
  epoch 003/75 | tr_loss=0.11404 | val_ll=0.69741 | val_auc=0.43477 | bad=0
  epoch 004/75 | tr_loss=0.03136 | val_ll=0.69472 | val_auc=0.48263 | bad=0
  epoch 005/75 | tr_loss=0.01792 | val_ll=0.69098 | val_auc=0.55512 | bad=0
  epoch 006/75 | tr_loss=0.00890 | val_ll=0.68635 | val_auc=0.62365 | bad=0
  epoch 0

# Finalize & Save Model Bundle (Reproducible)

In [7]:
# ============================================================
# Step 6 — Finalize & Save Model Bundle (Notebook-3 / Inference Notebook)
# REVISI FULL v5.1 (portable + robust source discovery + copies core files)
#
# Goals:
# - READ artifacts from /kaggle/input/... (read-only) OR /kaggle/working/... (if exists)
# - AUTO-find the *actual* bundle folder (supports nested model_bundle_v*/ etc.)
# - COPY core files into writable bundle folder:
#     /kaggle/working/recodai_luc_gate_artifacts/model_bundle_v5_notebook3/
# - WRITE thresholds/manifest/pack + ZIP to that bundle folder
# - Compatible with Step 5 v4+:
#     final_gate_model.pt may contain {fold_packs, full_packs, recommended_thr}
# ============================================================

import os, json, time, platform, warnings, zipfile, hashlib, shutil
from pathlib import Path

warnings.filterwarnings("ignore")

# ----------------------------
# 0) IO roots
# ----------------------------
# Everything this step writes goes under a versioned bundle folder in /kaggle/working
# (writable); the artifacts it reads may live in read-only /kaggle/input datasets.
BUNDLE_VERSION = "v5_notebook3"
OUT_ROOT = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR = OUT_ROOT / f"model_bundle_{BUNDLE_VERSION}"
OUT_ROOT.mkdir(parents=True, exist_ok=True)
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("OUTPUT (write):", OUT_DIR)

# ----------------------------
# Helpers
# ----------------------------
def read_json_safe(p: Path, default=None):
    """
    Parse the JSON file at *p* and return the decoded object.

    Best-effort: any failure (bad/None path, missing file, unreadable file,
    invalid JSON) yields *default* instead of raising.
    """
    try:
        raw_text = Path(p).read_text()
        parsed = json.loads(raw_text)
    except Exception:
        return default
    return parsed

def sha256_file(p: Path, chunk=1024 * 1024):
    """Return the hex SHA-256 digest of the file at *p*, read in *chunk*-byte pieces."""
    digest = hashlib.sha256()
    with Path(p).open("rb") as fh:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF)
        for piece in iter(lambda: fh.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()

def file_meta(p: Path):
    """
    Describe a regular file for the manifest's artifact index.

    Returns a dict with path, name, size in bytes, SHA-256 and the UTC mtime
    (ISO-8601, 'Z' suffix), or None when *p* is missing or is not a regular file.
    """
    path = Path(p)
    if not (path.exists() and path.is_file()):
        return None
    info = path.stat()
    modified = time.gmtime(info.st_mtime)
    return dict(
        path=str(path),
        name=path.name,
        bytes=int(info.st_size),
        sha256=sha256_file(path),
        mtime_utc=time.strftime("%Y-%m-%dT%H:%M:%SZ", modified),
    )

def safe_add(zf: zipfile.ZipFile, p: Path, arcname: str):
    """Write file *p* into archive *zf* as *arcname*; None, missing paths and non-files are skipped."""
    if p is None:
        return
    source = Path(p)
    if not (source.exists() and source.is_file()):
        return
    zf.write(source, arcname=arcname)

def pick_first_existing(paths):
    """
    Return the first entry of *paths* (converted to Path) that is an existing regular file.

    None entries are ignored; entries are checked lazily in order. Returns None if
    nothing qualifies.
    """
    candidates = (Path(item) for item in paths if item is not None)
    return next((c for c in candidates if c.exists() and c.is_file()), None)

def find_file_near(root: Path, filename: str, max_depth: int = 3):
    """
    Find a regular file named *filename* in *root* or up to *max_depth* directory levels below it.

    The search is bounded and breadth-first by depth (no full rglob):
      - depth 0: root/filename
      - depth k: every directory exactly k levels below root, in sorted path order
        (sorted so the result does not depend on filesystem listing order).
    Directories that merely carry the target name are not matches.

    Returns the Path of the first match, or None (also when *root* does not exist
    or max_depth < 0).
    """
    root = Path(root)
    if int(max_depth) < 0:
        return None

    hit = root / filename
    if hit.is_file():
        return hit

    for depth in range(1, int(max_depth) + 1):
        pattern = "/".join(["*"] * depth)  # "*", "*/*", "*/*/*", ...
        for d in sorted(root.glob(pattern)):
            if not d.is_dir():
                continue
            hit = d / filename
            if hit.is_file():
                return hit

    return None

def copy_if_needed(src: Path, dst: Path, verbose=True):
    """
    Copy *src* to *dst* (shutil.copy2) unless *dst* already holds identical content.

    The parent folder of *dst* is created first. "Identical" means same byte size
    AND same SHA-256; any error while comparing simply leads to a fresh copy.

    Returns the destination Path. Raises FileNotFoundError if *src* does not exist.
    """
    source = Path(src)
    target = Path(dst)
    target.parent.mkdir(parents=True, exist_ok=True)
    if not source.exists():
        raise FileNotFoundError(f"Missing src: {source}")

    same_content = False
    if target.exists() and target.is_file():
        try:
            # cheap size check first; hash both files only when sizes agree
            same_content = (
                source.stat().st_size == target.stat().st_size
                and sha256_file(source) == sha256_file(target)
            )
        except Exception:
            same_content = False

    if same_content:
        if verbose:
            print("  [skip copy] already identical:", target.name)
        return target

    shutil.copy2(source, target)
    if verbose:
        print("  [copied] ->", target)
    return target

def _score_dir(p: Path):
    """
    Score candidate directory based on presence of key artifacts directly inside.
    """
    p = Path(p)
    score = 0
    if (p / "final_gate_model.pt").exists(): score += 100
    if (p / "final_gate_bundle.json").exists(): score += 30
    if (p / "feature_cols.json").exists(): score += 20
    if (p / "thresholds.json").exists(): score += 5
    if (p / "best_gate_config.json").exists(): score += 3
    if (p / "best_gate_model.pt").exists(): score += 3
    if (p / "opt_search" / "opt_results.csv").exists(): score += 1
    return score

def _gather_candidate_dirs():
    """
    List directories that may hold the trained gate artifacts (final_gate_model.pt & co).

    Order matters, because the caller takes max() by score and ties go to the earliest entry:
      1. working roots (/kaggle/working/recodai_luc_gate_artifacts, then OUT_ROOT), each
         followed by its model_bundle_* sub-folders;
      2. for every dataset under /kaggle/input, directories at depth 0..3 whose name
         contains "recodai" or "bundle" (case-insensitive) or that directly contain
         final_gate_model.pt, plus the model_bundle_* sub-folders of any folder named
         recodai_luc_gate_artifacts.
    This covers nested layouts such as
      /kaggle/input/<ds>/recodai_luc_gate_artifacts/model_bundle_v3/final_gate_model.pt
    Returns existing directories as Paths, de-duplicated (first occurrence kept).
    """
    found = []

    def _bundle_subdirs(parent):
        return [child for child in parent.glob("model_bundle_*") if child.is_dir()]

    # 1) working roots (training may have run earlier in this same notebook session)
    for root in (Path("/kaggle/working/recodai_luc_gate_artifacts"), OUT_ROOT):
        if not root.exists():
            continue
        found.append(root)
        found.extend(_bundle_subdirs(root))

    # 2) read-only input datasets: bounded walk over ds, ds/*, ds/*/*, ds/*/*/*
    input_root = Path("/kaggle/input")
    if input_root.exists():
        for dataset in input_root.iterdir():
            if not dataset.is_dir():
                continue
            walk = [dataset]
            for pattern in ("*", "*/*", "*/*/*"):
                walk.extend(q for q in dataset.glob(pattern) if q.is_dir())
            for d in walk:
                lowered = d.name.lower()
                # name hints at artifacts
                if d.name == "recodai_luc_gate_artifacts" or "recodai" in lowered or "bundle" in lowered:
                    found.append(d)
                # holds the final model directly
                if (d / "final_gate_model.pt").exists():
                    found.append(d)
                # artifacts root: also offer its nested bundles
                if d.name == "recodai_luc_gate_artifacts":
                    found.extend(_bundle_subdirs(d))

    unique, seen = [], set()
    for cand in map(Path, found):
        key = str(cand)
        if key in seen or not (cand.exists() and cand.is_dir()):
            continue
        seen.add(key)
        unique.append(cand)
    return unique

# ----------------------------
# 1) Pick best SRC_DIR
# ----------------------------
SRC_CANDS = _gather_candidate_dirs()
if not SRC_CANDS:
    raise FileNotFoundError(
        "Tidak menemukan kandidat folder artifacts di /kaggle/working maupun /kaggle/input.\n"
        "Pastikan kamu sudah add dataset output training ke notebook ini."
    )

# Highest score wins (max() keeps the earliest candidate on ties, i.e. working roots first).
# A weak pick is tolerated: the file lookups below also search a few levels deeper.
SRC_DIR = max(SRC_CANDS, key=_score_dir)

print("\nSOURCE candidate picked:", SRC_DIR)
print("Score:", _score_dir(SRC_DIR))

# ----------------------------
# 2) Locate required artifacts (robust)
# ----------------------------
final_model_pt = find_file_near(SRC_DIR, "final_gate_model.pt", max_depth=3)
if final_model_pt is None:
    raise FileNotFoundError(
        "Missing final_gate_model.pt in selected SRC_DIR (even after nested search).\n"
        f"SRC_DIR={SRC_DIR}\n"
        "Cek struktur dataset input kamu (mungkin bundle ada lebih dalam)."
    )

# Companion files are looked up next to the model first (most reliable), then around SRC_DIR.
MODEL_DIR = final_model_pt.parent

final_bundle_json = (
    find_file_near(MODEL_DIR, "final_gate_bundle.json", max_depth=2)
    or find_file_near(SRC_DIR, "final_gate_bundle.json", max_depth=3)
)
feature_cols_path = (
    find_file_near(MODEL_DIR, "feature_cols.json", max_depth=2)
    or find_file_near(SRC_DIR, "feature_cols.json", max_depth=3)
)

if feature_cols_path is None:
    raise FileNotFoundError(
        "Missing feature_cols.json (jalankan Step 2 dulu di training notebook / pastikan ikut dibundle)."
    )

# Optional / extras
def _first_file_near(specs, extra_paths=()):
    """
    Resolve an optional artifact lazily.

    *specs* is an ordered sequence of (root, filename, max_depth) lookups. Each one runs
    find_file_near() only if every earlier lookup found nothing, and the first hit that
    is a regular file wins. If none match, the explicit *extra_paths* are checked with
    pick_first_existing(). Returns a Path or None.
    """
    for root, filename, depth in specs:
        hit = find_file_near(root, filename, max_depth=depth)
        if hit is not None and Path(hit).is_file():
            return Path(hit)
    return pick_first_existing(extra_paths)

# Lookups are lazy so the deeper (glob-heavy) SRC_DIR searches only run when the cheaper
# MODEL_DIR lookups miss; on large input datasets each */*/* walk can be expensive.
baseline_report_path = _first_file_near([
    (MODEL_DIR, "baseline_mhc_transformer_cv_report.json", 2),
    (MODEL_DIR, "baseline_transformer_cv_report.json", 2),
    (MODEL_DIR, "baseline_cv_report.json", 2),
    (SRC_DIR,   "baseline_mhc_transformer_cv_report.json", 3),
    (SRC_DIR,   "baseline_transformer_cv_report.json", 3),
    (SRC_DIR,   "baseline_cv_report.json", 3),
])

best_gate_config_path = _first_file_near([
    (MODEL_DIR, "best_gate_config.json", 2),
    (SRC_DIR,   "best_gate_config.json", 3),
])

best_gate_model_path = _first_file_near([
    (MODEL_DIR, "best_gate_model.pt", 2),
    (SRC_DIR,   "best_gate_model.pt", 3),
])

opt_results_csv = _first_file_near(
    [
        (MODEL_DIR, "opt_results.csv", 3),
        (SRC_DIR,   "opt_results.csv", 3),
    ],
    extra_paths=[
        MODEL_DIR / "opt_search" / "opt_results.csv",
        SRC_DIR   / "opt_search" / "opt_results.csv",
    ],
)

opt_fold_csv = _first_file_near(
    [
        (MODEL_DIR, "opt_fold_details.csv", 3),
        (SRC_DIR,   "opt_fold_details.csv", 3),
    ],
    extra_paths=[
        MODEL_DIR / "opt_search" / "opt_fold_details.csv",
        SRC_DIR   / "opt_search" / "opt_fold_details.csv",
    ],
)

oof_baseline_csv = _first_file_near([
    (MODEL_DIR, "oof_baseline_mhc_transformer.csv", 2),
    (MODEL_DIR, "oof_baseline_transformer.csv", 2),
    (MODEL_DIR, "oof_baseline.csv", 2),
    (SRC_DIR,   "oof_baseline_mhc_transformer.csv", 3),
    (SRC_DIR,   "oof_baseline_transformer.csv", 3),
    (SRC_DIR,   "oof_baseline.csv", 3),
])

print("\nFound artifacts (read):")
print("  final_model        :", final_model_pt)
print("  final_bundle       :", final_bundle_json if (final_bundle_json and final_bundle_json.exists()) else "(missing/skip)")
print("  feature_cols       :", feature_cols_path)
print("  best_gate_config   :", best_gate_config_path if best_gate_config_path else "(missing/skip)")
print("  best_gate_model    :", best_gate_model_path if best_gate_model_path else "(missing/skip)")
print("  baseline_report    :", baseline_report_path if baseline_report_path else "(missing/skip)")
print("  opt_results_csv    :", opt_results_csv if opt_results_csv else "(missing/skip)")
print("  opt_fold_csv       :", opt_fold_csv if opt_fold_csv else "(missing/skip)")
print("  oof_baseline_csv   :", oof_baseline_csv if oof_baseline_csv else "(missing/skip)")

# ----------------------------
# 3) COPY core artifacts into OUT_DIR (portable bundle)
# ----------------------------
print("\nCopying core files into OUT_DIR (portable):")
final_model_dst = copy_if_needed(final_model_pt, OUT_DIR / "final_gate_model.pt")
final_bundle_dst = (
    copy_if_needed(final_bundle_json, OUT_DIR / "final_gate_bundle.json")
    if (final_bundle_json is not None and final_bundle_json.exists())
    else None
)
feature_cols_dst = copy_if_needed(feature_cols_path, OUT_DIR / "feature_cols.json")


def _copy_optional(src, dest_dir):
    # Copy an optional artifact into dest_dir, keeping its file name (quietly); None when absent.
    if not src:
        return None
    return copy_if_needed(src, dest_dir / Path(src).name, verbose=False)


# Copy extras (optional, but useful)
extras_dir = OUT_DIR / "extras"
extras_dir.mkdir(parents=True, exist_ok=True)
baseline_report_dst = _copy_optional(baseline_report_path, extras_dir)
best_gate_config_dst = _copy_optional(best_gate_config_path, extras_dir)
best_gate_model_dst = _copy_optional(best_gate_model_path, extras_dir)

opt_dir = OUT_DIR / "opt_search"
opt_dir.mkdir(parents=True, exist_ok=True)
opt_results_dst = _copy_optional(opt_results_csv, opt_dir)
opt_fold_dst = _copy_optional(opt_fold_csv, opt_dir)

oof_dir = OUT_DIR / "oof"
oof_dir.mkdir(parents=True, exist_ok=True)
oof_baseline_dst = _copy_optional(oof_baseline_csv, oof_dir)

# ----------------------------
# 4) Load metadata (from copied files)
# ----------------------------
feature_cols = read_json_safe(feature_cols_dst, default=[])
if not isinstance(feature_cols, list) or len(feature_cols) == 0:
    raise ValueError(f"feature_cols invalid/empty: {feature_cols_dst}")

final_bundle = read_json_safe(final_bundle_dst, default={}) if final_bundle_dst else {}
baseline_report = read_json_safe(baseline_report_dst, default=None) if baseline_report_dst else None
best_gate_config = read_json_safe(best_gate_config_dst, default=None) if best_gate_config_dst else None

# recommended_thr from final_gate_model.pt
recommended_thr_from_pt = None
try:
    import torch
    # The .pt is a pickled training pack (configs, meta and whatever Step 5 stored, e.g. mu/sig),
    # which the restricted weights_only loader (the default since torch 2.6) may reject, so load
    # it as a full pickle.
    # NOTE(security): weights_only=False can execute arbitrary code on load -- only point this
    # at checkpoints you produced yourself / trust.
    try:
        obj = torch.load(final_model_dst, map_location="cpu", weights_only=False)
    except TypeError:
        # older torch without the weights_only keyword
        obj = torch.load(final_model_dst, map_location="cpu")
    if isinstance(obj, dict):
        recommended_thr_from_pt = obj.get("recommended_thr", None)
    # the whole checkpoint was loaded only to read one scalar; release it right away
    del obj
    # plain float keeps the manifest JSON-serializable (numpy scalars may not be)
    if recommended_thr_from_pt is not None:
        recommended_thr_from_pt = float(recommended_thr_from_pt)
except Exception as e:
    recommended_thr_from_pt = None
    print("Warning: failed to read final_gate_model.pt for recommended_thr:", repr(e))

# ----------------------------
# 5) Thresholds resolve (robust priority)
# Priority:
#   (a) SRC thresholds.json near original model dir
#   (b) existing OUT thresholds.json (if rerun)
#   (c) final_gate_bundle.json -> recommended_thr
#   (d) final_gate_model.pt -> recommended_thr
#   (e) best_gate_config.json -> oof_best_thr OR selection.oof_best_thr (legacy)
#   (f) fallback 0.5
# ----------------------------
def extract_thr_from_best_gate_config(cfg: dict):
    """
    Pull the OOF-tuned gate threshold out of a best_gate_config dict.

    Checks the top-level "oof_best_thr" first, then the legacy nested
    cfg["selection"]["oof_best_thr"]. A value that cannot be converted to float
    is skipped and the next location is tried. Returns a float, or None.
    """
    if not isinstance(cfg, dict):
        return None
    nested = cfg.get("selection", None)
    places = [cfg] + ([nested] if isinstance(nested, dict) else [])
    for place in places:
        if "oof_best_thr" not in place:
            continue
        try:
            return float(place["oof_best_thr"])
        except Exception:
            continue
    return None

def _finite_float(v):
    """Return *v* as a finite float, or None if it is None, non-numeric, NaN or +/-inf."""
    import math

    if v is None:
        return None
    try:
        f = float(v)
    except (TypeError, ValueError):
        return None
    return f if math.isfinite(f) else None


def _thr_from_thresholds_json(p):
    """
    Read T_gate from a thresholds.json file, or return None.

    Files written by this step record where their value came from ("T_gate_source").
    A file whose T_gate was only the 0.5 fallback is ignored, so a placeholder saved
    by an earlier run cannot outrank a real recommended threshold on rerun. Files
    without that key (e.g. written by the training notebook) are used as-is.
    """
    if p is None or not Path(p).exists():
        return None
    tj = read_json_safe(p, default={})
    if not isinstance(tj, dict) or tj.get("T_gate_source") == "fallback":
        return None
    return _finite_float(tj.get("T_gate", None))


src_thresh = find_file_near(MODEL_DIR, "thresholds.json", max_depth=2) or find_file_near(SRC_DIR, "thresholds.json", max_depth=3)
out_thresh = OUT_DIR / "thresholds.json"

# Candidates in priority order (a)..(e); the first usable (finite) value wins, else (f) 0.5.
_thr_candidates = [
    ("src_thresholds_json", _thr_from_thresholds_json(src_thresh)),
    ("out_thresholds_json", _thr_from_thresholds_json(out_thresh)),
    ("final_gate_bundle.recommended_thr",
     _finite_float(final_bundle.get("recommended_thr", None)) if isinstance(final_bundle, dict) else None),
    ("final_gate_model.recommended_thr", _finite_float(recommended_thr_from_pt)),
    ("best_gate_config.oof_best_thr", _finite_float(extract_thr_from_best_gate_config(best_gate_config))),
]
T_gate_source, T_gate = next(
    ((name, thr) for name, thr in _thr_candidates if thr is not None),
    ("fallback", 0.5),
)

# beta used when the threshold was tuned; a missing/invalid value falls back to 0.5 instead of crashing
_beta = _finite_float(best_gate_config.get("beta_for_tuning", 0.5)) if isinstance(best_gate_config, dict) else None

thresholds = {
    "T_gate": float(T_gate),
    "T_gate_source": T_gate_source,
    "beta_for_tuning": 0.5 if _beta is None else _beta,
    "guards": {"min_area_frac": None, "max_area_frac": None, "max_components": None},
    "source_priority": [
        "SRC thresholds.json (near model)",
        "OUT_DIR/thresholds.json (existing)",
        "final_gate_bundle.json.recommended_thr",
        "final_gate_model.pt.recommended_thr",
        "best_gate_config.json.oof_best_thr (or selection.oof_best_thr legacy)",
        "fallback 0.5",
    ],
    "notes": "Gate threshold used for binary decision. Update after calibration/OOF tuning if needed.",
}

thresholds_path = OUT_DIR / "thresholds.json"
thresholds_path.write_text(json.dumps(thresholds, indent=2))

print("\nThreshold resolved:")
print("  T_gate:", thresholds["T_gate"])
print("  T_gate_source:", thresholds["T_gate_source"])
print("  thresholds.json ->", thresholds_path)

# ----------------------------
# 6) cfg_meta (optional)
# ----------------------------
cfg_meta = {}
if "PATHS" in globals() and isinstance(PATHS, dict):
    cfg_meta = {k: PATHS.get(k, None) for k in [
        "COMP_ROOT","OUT_DS_ROOT","OUT_ROOT","MATCH_CFG_DIR","PRED_CFG_DIR","DINO_CFG_DIR","DINO_LARGE_DIR",
        "PRED_FEAT_TRAIN","MATCH_FEAT_TRAIN","DF_TRAIN_ALL","CV_CASE_FOLDS","IMG_PROFILE_TRAIN"
    ]}

# ----------------------------
# 7) Manifest (reproducible) + sha256 index (for OUT_DIR files)
# ----------------------------
task_str = "Recod.ai/LUC — Gate Model — DINOv2 features + Transformer gate (.pt)"
model_format = "torch_pt"

# Files now living in OUT_DIR; optional artifacts that were not found are None and dropped.
artifact_paths = [
    p
    for p in (
        final_model_dst,
        final_bundle_dst,
        feature_cols_dst,
        thresholds_path,
        baseline_report_dst,
        best_gate_config_dst,
        best_gate_model_dst,
        opt_results_dst,
        opt_fold_dst,
        oof_baseline_dst,
    )
    if p is not None
]

# file name -> {path, bytes, sha256, mtime}; file_meta() gives None for anything not a file
artifacts_meta = {meta["name"]: meta for meta in map(file_meta, artifact_paths) if meta is not None}

opt_summary = None
if isinstance(best_gate_config, dict):
    opt_summary = best_gate_config.get("selection", None)
    if opt_summary is None:
        # configs without a "selection" block: assemble the headline numbers directly
        opt_summary = {
            key: best_gate_config.get(key, None)
            for key in ("model_name", "oof_best_thr", "oof_best_fbeta", "oof_auc", "oof_logloss")
        }

# empty dict stand-in so every summary field reads as None when the bundle json is unusable
_fb = final_bundle if isinstance(final_bundle, dict) else {}

manifest = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "python": platform.python_version(),
    "platform": platform.platform(),
    "bundle_version": BUNDLE_VERSION,
    "task": task_str,
    "model_format": model_format,
    "source_dir_selected": str(SRC_DIR),
    "source_model_dir": str(MODEL_DIR),
    "output_dir": str(OUT_DIR),
    "artifacts_index": artifacts_meta,
    "cfg_meta": cfg_meta,
    "model_summary": {
        "type": _fb.get("type"),
        "n_seeds": _fb.get("n_seeds"),
        "seeds": _fb.get("seeds"),
        "train_rows": _fb.get("train_rows"),
        "pos_rate": _fb.get("pos_rate"),
        "feature_count": int(len(feature_cols)),
        "T_gate": float(thresholds.get("T_gate", 0.5)),
        "recommended_thr_from_pt": recommended_thr_from_pt,
    },
    "baseline_summary": (baseline_report.get("overall") if isinstance(baseline_report, dict) else None),
    "opt_summary": opt_summary,
}

manifest_path = OUT_DIR / "model_bundle_manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))

# ----------------------------
# 8) Bundle pack (portable JSON) + optional joblib
# ----------------------------
# Probe joblib once, up front, so bundle_files only advertises the .joblib pack when it can be
# written. A private alias avoids clobbering any `joblib` name imported earlier in the notebook.
try:
    import joblib as _joblib
except Exception:
    _joblib = None

bundle_pack = {
    "bundle_version": BUNDLE_VERSION,
    "model_format": model_format,
    "bundle_files": {
        "final_gate_model.pt": "final_gate_model.pt",
        "final_gate_bundle.json": "final_gate_bundle.json" if final_bundle_dst else None,
        "feature_cols.json": "feature_cols.json",
        "thresholds.json": "thresholds.json",
        "model_bundle_manifest.json": "model_bundle_manifest.json",
        "model_bundle_pack.json": "model_bundle_pack.json",
        "model_bundle_pack.joblib": "model_bundle_pack.joblib" if _joblib is not None else None,
        f"model_bundle_{BUNDLE_VERSION}.zip": f"model_bundle_{BUNDLE_VERSION}.zip",
    },
    "thresholds": thresholds,
    "feature_cols": feature_cols,
    "cfg_meta": cfg_meta,
    "manifest": manifest,
}

bundle_pack_json = OUT_DIR / "model_bundle_pack.json"
bundle_pack_json.write_text(json.dumps(bundle_pack, indent=2))

bundle_pack_joblib = OUT_DIR / "model_bundle_pack.joblib"
joblib_ok = False
if _joblib is not None:
    try:
        _joblib.dump(bundle_pack, bundle_pack_joblib)
        joblib_ok = True
    except Exception as e:
        # best-effort secondary format: report the failure instead of hiding it
        # (bundle_files above still lists the .joblib entry in this rare case)
        print("Warning: failed to write model_bundle_pack.joblib:", repr(e))

# ----------------------------
# 9) Create portable ZIP (writes to OUT_DIR)
# ----------------------------
zip_path = OUT_DIR / f"model_bundle_{BUNDLE_VERSION}.zip"

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    # core files (from OUT_DIR)
    safe_add(zf, final_model_dst, "final_gate_model.pt")
    safe_add(zf, final_bundle_dst, "final_gate_bundle.json")
    safe_add(zf, feature_cols_dst, "feature_cols.json")
    safe_add(zf, thresholds_path, "thresholds.json")
    safe_add(zf, manifest_path, "model_bundle_manifest.json")
    safe_add(zf, bundle_pack_json, "model_bundle_pack.json")
    if joblib_ok:
        safe_add(zf, bundle_pack_joblib, "model_bundle_pack.joblib")

    # Optional artifacts copied by THIS run only. Globbing extras/, opt_search/ and oof/ would
    # also sweep in stale leftovers from earlier runs that the manifest's artifacts_index omits.
    for _sub, _dst in [
        ("extras", baseline_report_dst),
        ("extras", best_gate_config_dst),
        ("extras", best_gate_model_dst),
        ("opt_search", opt_results_dst),
        ("opt_search", opt_fold_dst),
        ("oof", oof_baseline_dst),
    ]:
        if _dst is not None:
            safe_add(zf, _dst, f"{_sub}/{Path(_dst).name}")

print("\nOK — Model bundle finalized (Notebook-3 compatible)")
for _label, _value in [
    ("  SRC_DIR selected ->", SRC_DIR),
    ("  MODEL_DIR        ->", MODEL_DIR),
    ("  OUT_DIR          ->", OUT_DIR),
    ("  manifest         ->", manifest_path),
    ("  pack (json)      ->", bundle_pack_json),
    ("  pack (joblib)    ->", bundle_pack_joblib if joblib_ok else "(skip; joblib not available)"),
    ("  thresholds       ->", thresholds_path),
    ("  zip              ->", zip_path),
]:
    print(_label, _value)

print("\nBundle summary:")
for _label, _value in [
    ("  bundle_version:", BUNDLE_VERSION),
    ("  model_format  :", model_format),
    ("  feature_cnt   :", len(feature_cols)),
    ("  T_gate        :", thresholds.get("T_gate")),
    ("  task          :", task_str),
]:
    print(_label, _value)


OUTPUT (write): /kaggle/working/recodai_luc_gate_artifacts/model_bundle_v5_notebook3

SOURCE candidate picked: /kaggle/working/recodai_luc_gate_artifacts
Score: 157

Found artifacts (read):
  final_model        : /kaggle/working/recodai_luc_gate_artifacts/final_gate_model.pt
  final_bundle       : /kaggle/working/recodai_luc_gate_artifacts/final_gate_bundle.json
  feature_cols       : /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
  best_gate_config   : /kaggle/working/recodai_luc_gate_artifacts/best_gate_config.json
  best_gate_model    : /kaggle/working/recodai_luc_gate_artifacts/best_gate_model.pt
  baseline_report    : (missing/skip)
  opt_results_csv    : /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.csv
  opt_fold_csv       : /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_fold_details.csv
  oof_baseline_csv   : (missing/skip)

Copying core files into OUT_DIR (portable):
  [copied] -> /kaggle/working/recodai_luc_gate_artifacts/model_bun