# Set Paths & Select Config (CFG)

In [1]:
# ============================================================
# STAGE 0 — Set Paths & Select Config (CFG) (Kaggle-ready, offline)
# REVISI FULL v2.2 (lebih kuat + bantu performa model lewat pemilihan CFG terbaik)
#
# Upgrade utama v2.2:
# - Cache root multi-sumber: utamakan /kaggle/working/recodai_luc jika ada (hasil run terbaru),
#   fallback ke dataset input (OUT_ROOT) — jadi training pakai artefak paling update.
# - Pemilihan CFG pakai scoring (bukan cuma train rows):
#     * wajib: train feature ada
#     * prefer: test feature ada (match_features_test / pred_features_test)
#     * tie-break: train rows, test rows, latest modified time
# - Auto-detect DINOv2 model dir (large/giant/base) offline
# - Tambah PATHS penting untuk TRAIN/INFER: manifest_pred_*, pred_summary.json, dll.
# - Sanity guard lebih informatif + tidak nge-crash untuk file opsional.
#
# Output globals (tetap, JANGAN diganti namanya):
# - COMP_ROOT, OUT_DS_ROOT, OUT_ROOT
# - PATHS (dict)
# - MATCH_CFG_DIR, PRED_CFG_DIR, DINO_CFG_DIR
# Extra globals (opsional, membantu tahap training):
# - CACHE_ROOTS (list), SELECTED (dict), TRAIN_PLAN (dict)
# ============================================================

import os, re, json, time
from pathlib import Path
import pandas as pd

# ----------------------------
# Helper: fast count rows CSV
# ----------------------------
def _fast_count_rows_csv(path: Path) -> int:
    try:
        with path.open("r", encoding="utf-8", errors="ignore") as f:
            n = sum(1 for _ in f) - 1
        return int(max(n, 0))
    except Exception:
        return -1

def _safe_mtime(p: Path) -> float:
    try:
        return float(p.stat().st_mtime)
    except Exception:
        return 0.0

def _is_nonempty_file(p: Path) -> bool:
    try:
        return p.exists() and p.is_file() and p.stat().st_size > 0
    except Exception:
        return False

# ----------------------------
# Helper: find competition root
# ----------------------------
def find_comp_root(preferred: str = "/kaggle/input/recodai-luc-scientific-image-forgery-detection") -> Path:
    p = Path(preferred)
    if p.exists():
        return p

    base = Path("/kaggle/input")
    if not base.exists():
        raise FileNotFoundError("/kaggle/input tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    cands = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        if (d / "sample_submission.csv").exists() and ((d / "train_images").exists() or (d / "test_images").exists()):
            cands.append(d)

    if not cands:
        for d in base.iterdir():
            if not d.is_dir():
                continue
            for x in d.iterdir():
                if not x.is_dir():
                    continue
                if (x / "sample_submission.csv").exists() and ((x / "train_images").exists() or (x / "test_images").exists()):
                    cands.append(x)

    if not cands:
        raise FileNotFoundError(
            "COMP_ROOT tidak ditemukan. Harus ada folder yang memuat sample_submission.csv dan train_images/test_images."
        )

    cands.sort(key=lambda x: (("recodai" not in x.name.lower()), ("forgery" not in x.name.lower()), x.name))
    return cands[0]

# ----------------------------
# Helper: find output dataset root (hasil PREP)
# ----------------------------
def find_output_dataset_root(preferred_names=(
    "recod-ailuc-dinov2-base",
    "recod-ai-luc-dinov2-base",
    "recodai-luc-dinov2-base",
    "recodai-luc-dinov2",
    "recodai-luc-dinov2-prep",
)) -> Path:
    base = Path("/kaggle/input")

    for nm in preferred_names:
        p = base / nm
        if p.exists():
            return p

    cands = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        if (d / "recodai_luc" / "artifacts").exists():
            cands.append(d)
            continue
        inner = list(d.glob("*/recodai_luc/artifacts"))
        if inner:
            cands.append(d)

    if not cands:
        raise FileNotFoundError("OUT_DS_ROOT tidak ditemukan. Harus ada /kaggle/input/<...>/recodai_luc/artifacts/")

    cands.sort(key=lambda x: (("dinov2" not in x.name.lower()), x.name))
    return cands[0]

# ----------------------------
# Helper: resolve OUT_ROOT = <dataset>/recodai_luc
# ----------------------------
def resolve_out_root(out_ds_root: Path) -> Path:
    direct = out_ds_root / "recodai_luc"
    if direct.exists():
        return direct
    hits = list(out_ds_root.glob("*/recodai_luc"))
    if hits:
        return hits[0]
    raise FileNotFoundError(f"Folder recodai_luc tidak ditemukan di bawah {out_ds_root}")

# ----------------------------
# Helper: pick best cfg directory by multi-criteria scoring
# ----------------------------
def pick_best_cfg(cache_roots, prefix: str, feat_train_filename: str, feat_test_filename: str = None) -> Path:
    """
    cache_roots: list[Path] kandidat cache root
    prefix: 'match_base_cfg_' atau 'pred_base'
    feat_train_filename: file wajib
    feat_test_filename : file opsional yang diprefer (kalau ada)
    """
    cands = []
    for root in cache_roots:
        root = Path(root)
        if not root.exists():
            continue
        for d in root.iterdir():
            if not d.is_dir():
                continue
            if not d.name.startswith(prefix):
                continue

            train_fp = d / feat_train_filename
            if not train_fp.exists():
                continue

            train_n = _fast_count_rows_csv(train_fp)
            test_fp = (d / feat_test_filename) if feat_test_filename else None
            test_ok = _is_nonempty_file(test_fp) if test_fp else False
            test_n = _fast_count_rows_csv(test_fp) if (test_fp and test_fp.exists()) else -1

            # latest time: pakai dir atau file train
            mt = max(_safe_mtime(d), _safe_mtime(train_fp), _safe_mtime(test_fp) if test_fp else 0.0)

            # scoring: prefer punya test file, lalu train rows besar, lalu test rows besar, lalu terbaru
            score = 0.0
            if train_n <= 0:
                score -= 1e6
            score += 1e5 * (1.0 if test_ok else 0.0)
            score += 1.0 * float(max(train_n, 0))
            score += 0.1 * float(max(test_n, 0))
            score += 1e-6 * float(mt)

            cands.append((score, d, root, train_fp, test_fp, train_n, test_n, test_ok, mt))

    if not cands:
        raise FileNotFoundError(
            f"Tidak ada CFG folder untuk prefix='{prefix}' dengan file '{feat_train_filename}' pada cache_roots={cache_roots}"
        )

    cands.sort(key=lambda x: (-x[0], x[1].name))
    best = cands[0]
    return best[1]

# ----------------------------
# Helper: detect best DINO model dir (offline)
# ----------------------------
def detect_dino_dir() -> Path:
    # urutan prefer: large -> giant -> base
    cands = [
        Path("/kaggle/input/dinov2/pytorch/large/1"),
        Path("/kaggle/input/dinov2/pytorch/giant/1"),
        Path("/kaggle/input/dinov2/pytorch/base/1"),
    ]
    for p in cands:
        if p.exists():
            return p
    return Path("/kaggle/input/dinov2/pytorch/large/1")  # default (mungkin missing; hanya warning)

# ============================================================
# 0) Locate roots
# ============================================================
COMP_ROOT = find_comp_root("/kaggle/input/recodai-luc-scientific-image-forgery-detection")
OUT_DS_ROOT = find_output_dataset_root()
OUT_ROOT = resolve_out_root(OUT_DS_ROOT)  # dataset input: .../recodai_luc

# Kandidat cache/artifacts:
# - hasil run terbaru di /kaggle/working/recodai_luc (kalau ada)
# - fallback dataset input OUT_ROOT
WORK_OUT_ROOT = Path("/kaggle/working/recodai_luc")
CACHE_ROOTS = []
ART_ROOTS = []

if (WORK_OUT_ROOT / "cache").exists() and (WORK_OUT_ROOT / "artifacts").exists():
    CACHE_ROOTS.append(WORK_OUT_ROOT / "cache")
    ART_ROOTS.append(WORK_OUT_ROOT / "artifacts")

CACHE_ROOTS.append(Path(OUT_ROOT) / "cache")
ART_ROOTS.append(Path(OUT_ROOT) / "artifacts")

# pilih artifacts root yang valid (prioritas: working)
ART_DIR = None
for a in ART_ROOTS:
    if a.exists():
        ART_DIR = a
        break
if ART_DIR is None:
    raise FileNotFoundError("ART_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# pilih cache root list yang exist
CACHE_DIRS = [p for p in CACHE_ROOTS if Path(p).exists()]
if not CACHE_DIRS:
    raise FileNotFoundError("CACHE_DIR tidak ditemukan di /kaggle/working maupun dataset input.")

# ============================================================
# 1) Competition paths (raw images/masks)
# ============================================================
PATHS = {}
PATHS["COMP_ROOT"] = str(COMP_ROOT)
PATHS["SAMPLE_SUB"] = str(COMP_ROOT / "sample_submission.csv")

PATHS["TRAIN_IMAGES"] = str(COMP_ROOT / "train_images")
PATHS["TEST_IMAGES"]  = str(COMP_ROOT / "test_images")
PATHS["TRAIN_MASKS"]  = str(COMP_ROOT / "train_masks")
PATHS["SUPP_IMAGES"]  = str(COMP_ROOT / "supplemental_images")
PATHS["SUPP_MASKS"]   = str(COMP_ROOT / "supplemental_masks")

PATHS["TRAIN_AUTH_DIR"] = str(COMP_ROOT / "train_images" / "authentic")
PATHS["TRAIN_FORG_DIR"] = str(COMP_ROOT / "train_images" / "forged")

# ============================================================
# 2) Output dataset paths (clean artifacts + cache)
# ============================================================
# NOTE: OUT_DS_ROOT/OUT_ROOT tetap menunjuk dataset input (untuk kompatibilitas),
# tapi ART_DIR/CACHE_DIRS bisa berasal dari /kaggle/working jika ada (lebih baru).
PATHS["OUT_DS_ROOT"] = str(OUT_DS_ROOT)
PATHS["OUT_ROOT"]    = str(OUT_ROOT)

PATHS["ART_DIR"]     = str(ART_DIR)
PATHS["CACHE_DIRS"]  = [str(x) for x in CACHE_DIRS]  # list

# artifacts utama
PATHS["DF_TRAIN_ALL"] = str(Path(ART_DIR) / "df_train_all.parquet")
PATHS["DF_TRAIN_CLS"] = str(Path(ART_DIR) / "df_train_cls.parquet")
PATHS["DF_TRAIN_SEG"] = str(Path(ART_DIR) / "df_train_seg.parquet")
PATHS["DF_TEST"]      = str(Path(ART_DIR) / "df_test.parquet")

PATHS["CV_CASE_FOLDS"]   = str(Path(ART_DIR) / "cv_case_folds.csv")
PATHS["CV_SAMPLE_FOLDS"] = str(Path(ART_DIR) / "cv_sample_folds.csv")

PATHS["IMG_PROFILE_TRAIN"] = str(Path(ART_DIR) / "image_profile_train.parquet")
PATHS["IMG_PROFILE_TEST"]  = str(Path(ART_DIR) / "image_profile_test.parquet")
PATHS["MASK_PROFILE"]      = str(Path(ART_DIR) / "mask_profile.parquet")
PATHS["CASE_SUMMARY"]      = str(Path(ART_DIR) / "case_summary.parquet")

# ============================================================
# 3) Select best MATCH/PRED CFG dirs automatically (scoring)
# ============================================================
MATCH_CFG_DIR = pick_best_cfg(
    CACHE_DIRS,
    prefix="match_base_cfg_",
    feat_train_filename="match_features_train_all.csv",
    feat_test_filename="match_features_test.csv",
)

PRED_CFG_DIR = pick_best_cfg(
    CACHE_DIRS,
    prefix="pred_base",
    feat_train_filename="pred_features_train_all.csv",
    feat_test_filename="pred_features_test.csv",  # diprefer jika ada
)

# DINO cache cfg (opsional): cache/dino_v2_large/cfg_*/manifest_train_all.csv
DINO_CFG_DIR = None
for root in CACHE_DIRS:
    dino_root = Path(root) / "dino_v2_large"
    if not dino_root.exists():
        continue
    dino_cands = []
    for d in dino_root.iterdir():
        if d.is_dir() and d.name.startswith("cfg_") and (d / "manifest_train_all.csv").exists():
            dino_cands.append(d)
    if dino_cands:
        scored = []
        for d in dino_cands:
            mf = d / "manifest_train_all.csv"
            scored.append((_fast_count_rows_csv(mf), _safe_mtime(d), d))
        scored.sort(key=lambda x: (-x[0], -x[1], x[2].name))
        DINO_CFG_DIR = scored[0][2]
        break

# simpan cfg dir ke PATHS
PATHS["MATCH_CFG_DIR"] = str(MATCH_CFG_DIR)
PATHS["PRED_CFG_DIR"]  = str(PRED_CFG_DIR)
PATHS["DINO_CFG_DIR"]  = str(DINO_CFG_DIR) if DINO_CFG_DIR else ""

# feature paths dari cfg terpilih
PATHS["MATCH_FEAT_TRAIN"] = str(MATCH_CFG_DIR / "match_features_train_all.csv")
PATHS["MATCH_FEAT_TEST"]  = str(MATCH_CFG_DIR / "match_features_test.csv")

PATHS["PRED_FEAT_TRAIN"]  = str(PRED_CFG_DIR / "pred_features_train_all.csv")
PATHS["PRED_FEAT_TEST"]   = str(PRED_CFG_DIR / "pred_features_test.csv")  # bisa missing (warning)

# pred manifests (sering kepakai untuk infer/submission)
PATHS["PRED_MAN_TRAIN"] = str(PRED_CFG_DIR / "manifest_pred_train_all.csv")
PATHS["PRED_MAN_TEST"]  = str(PRED_CFG_DIR / "manifest_pred_test.csv")
PATHS["PRED_SUMMARY"]   = str(PRED_CFG_DIR / "pred_summary.json")

# match manifests (kalau ada)
PATHS["MATCH_MAN_TRAIN"] = str(MATCH_CFG_DIR / "manifest_match_train_all.csv")
PATHS["MATCH_MAN_TEST"]  = str(MATCH_CFG_DIR / "manifest_match_test.csv")

# ============================================================
# 4) DINO model dir (offline)
# ============================================================
DINO_DIR = detect_dino_dir()
PATHS["DINO_DIR"] = str(DINO_DIR)

# ============================================================
# 5) Sanity checks (wajib ada untuk training)
# ============================================================
must_exist = [
    ("sample_submission.csv", PATHS["SAMPLE_SUB"]),
    ("df_train_all.parquet",  PATHS["DF_TRAIN_ALL"]),
    ("cv_case_folds.csv",     PATHS["CV_CASE_FOLDS"]),
    ("match_features_train_all.csv", PATHS["MATCH_FEAT_TRAIN"]),
    ("pred_features_train_all.csv",  PATHS["PRED_FEAT_TRAIN"]),
]
missing = [name for name, p in must_exist if not Path(p).exists()]
if missing:
    raise FileNotFoundError("Missing required files: " + ", ".join(missing))

# opsional tapi penting untuk inference: test feature files
opt_warn = []
if not Path(PATHS["MATCH_FEAT_TEST"]).exists():
    opt_warn.append("match_features_test.csv (MATCH_FEAT_TEST)")
if not Path(PATHS["PRED_FEAT_TEST"]).exists():
    opt_warn.append("pred_features_test.csv (PRED_FEAT_TEST)")
if opt_warn:
    print("WARNING: File opsional untuk inference belum ada:")
    for w in opt_warn:
        print(" -", w)
    print("Catatan: training masih aman (pakai *_train_all). Untuk inference gate ke test, file ini biasanya dibutuhkan.")

# DINO model dir opsional (warning saja)
if not Path(PATHS["DINO_DIR"]).exists():
    print(f"WARNING: DINO dir tidak ditemukan: {PATHS['DINO_DIR']}")

# ============================================================
# 6) Print summary + export helpers
# ============================================================
SELECTED = {
    "ART_DIR": str(ART_DIR),
    "CACHE_DIRS": [str(x) for x in CACHE_DIRS],
    "MATCH_CFG_DIR": str(MATCH_CFG_DIR),
    "PRED_CFG_DIR": str(PRED_CFG_DIR),
    "DINO_CFG_DIR": str(DINO_CFG_DIR) if DINO_CFG_DIR else "",
    "DINO_DIR": str(DINO_DIR),
}

# Training plan (dipakai stage berikutnya supaya model bagus):
# - Group CV by case_id (anti leakage)
TRAIN_PLAN = {
    "seed": 2025,
    "group_col": "case_id",
    "target_col": "y_forged",
    "n_folds": 5,
    "use_calibration": True,
    "calibration": "isotonic",  # biasanya kuat untuk tabular probs
    "tune_threshold_on_oof": True,
}

print("OK — Roots")
print("  COMP_ROOT   :", COMP_ROOT)
print("  OUT_DS_ROOT :", OUT_DS_ROOT)
print("  OUT_ROOT    :", OUT_ROOT)
print("  ART_DIR(use):", ART_DIR)
print("  CACHE_DIRS  :", [str(x) for x in CACHE_DIRS])

print("\nOK — Selected CFG")
print("  MATCH_CFG_DIR:", MATCH_CFG_DIR.name)
print("  PRED_CFG_DIR :", PRED_CFG_DIR.name)
print("  DINO_CFG_DIR :", (DINO_CFG_DIR.name if DINO_CFG_DIR else "(not found / optional)"))

print("\nOK — Key files (train)")
for k in ["DF_TRAIN_ALL", "CV_CASE_FOLDS", "MATCH_FEAT_TRAIN", "PRED_FEAT_TRAIN", "IMG_PROFILE_TRAIN"]:
    p = Path(PATHS[k])
    print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing/optional)'}")

print("\nOK — Key files (test/infer, optional)")
for k in ["MATCH_FEAT_TEST", "PRED_FEAT_TEST", "PRED_MAN_TEST", "PRED_SUMMARY"]:
    p = Path(PATHS[k])
    print(f"  {k:16s}: {p}  {'(exists)' if p.exists() else '(missing)'}")

print("\nOK — DINO model dir")
print("  DINO_DIR:", DINO_DIR, "(exists)" if DINO_DIR.exists() else "(missing)")

# export globals
globals().update({
    "MATCH_CFG_DIR": MATCH_CFG_DIR,
    "PRED_CFG_DIR": PRED_CFG_DIR,
    "DINO_CFG_DIR": DINO_CFG_DIR,
    "CACHE_ROOTS": CACHE_DIRS,
    "SELECTED": SELECTED,
    "TRAIN_PLAN": TRAIN_PLAN,
})


OK — Roots
  COMP_ROOT   : /kaggle/input/recodai-luc-scientific-image-forgery-detection
  OUT_DS_ROOT : /kaggle/input/recod-ailuc-dinov2-base
  OUT_ROOT    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc

OK — Selected CFG
  MATCH_CFG_DIR: match_base_cfg_f9f7ea3a65c5
  PRED_CFG_DIR : pred_base_v3_v7_cfg_5dbf0aa165
  DINO_CFG_DIR : cfg_3246fd54aab0

OK — Key files
  DF_TRAIN_ALL    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet  (exists)
  CV_CASE_FOLDS   : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv  (exists)
  MATCH_FEAT_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv  (exists)
  PRED_FEAT_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv  (exists)
  IMG_PROFILE_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet  (exists)

OK 

# Build Training Table (X, y, folds)

In [2]:
# ============================================================
# STEP 2 — Build Training Table (X, y, folds) — REVISI FULL v2.1 (Tabular/Gate-ready, robust)
#
# Fokus:
# - Bangun df_train_tabular + FEATURE_COLS dari pred_features (+match_features/+image_profile opsional)
# - Split pakai cv_case_folds.csv (anti leakage by case_id)
#
# Output globals:
# - df_train_tabular, FEATURE_COLS
# - X_train, y_train, folds
#
# Saved:
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - /kaggle/working/recodai_luc_gate_artifacts/feature_schema.json
# - /kaggle/working/recodai_luc_gate_artifacts/df_train_tabular.parquet
# ============================================================

import os, json, gc, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require PATHS
# ----------------------------
if "PATHS" not in globals() or not isinstance(PATHS, dict):
    raise RuntimeError("Missing PATHS. Jalankan dulu STAGE 0 — Set Paths & Select Config (CFG).")

# ----------------------------
# 1) Feature Engineering Config
# ----------------------------
FE_CFG = {
    "use_match_features": True,
    "use_image_profile": True,

    # encode variant -> dummies (sering bantu score)
    "encode_variant_onehot": True,
    "variant_min_count": 1,          # keep semua variant; naikkan kalau mau buang yang super-rare

    # transforms
    "add_log_features": True,
    "add_sqrt_features": False,      # opsional
    "add_interactions": True,
    "add_missing_indicators": True,

    # outlier control
    "clip_by_quantile": True,
    "clip_q": 0.999,                 # p99.9 cap
    "clip_max_fallback": 1e9,

    # fill
    "fillna_value": 0.0,

    # prune
    "drop_constant_features": True,

    # dtype
    "cast_float32": True,
}

# ----------------------------
# 2) Prefer WORKING features if exist (kalau kamu regen di /kaggle/working)
# ----------------------------
def _prefer_existing(*paths):
    for p in paths:
        if p is None:
            continue
        p = Path(str(p))
        if p.exists():
            return p
    return Path(str(paths[0])) if paths and paths[0] is not None else Path("")

match_cfg_name = Path(PATHS.get("MATCH_CFG_DIR", "")).name if PATHS.get("MATCH_CFG_DIR") else ""
pred_cfg_name  = Path(PATHS.get("PRED_CFG_DIR", "")).name  if PATHS.get("PRED_CFG_DIR") else ""

WORK_CACHE_ROOT = Path("/kaggle/working/recodai_luc/cache")
match_feat_work = (WORK_CACHE_ROOT / match_cfg_name / "match_features_train_all.csv") if match_cfg_name else None
pred_feat_work  = (WORK_CACHE_ROOT / pred_cfg_name  / "pred_features_train_all.csv")  if pred_cfg_name  else None

PRED_FEAT_TRAIN  = _prefer_existing(pred_feat_work,  PATHS.get("PRED_FEAT_TRAIN", ""))
MATCH_FEAT_TRAIN = _prefer_existing(match_feat_work, PATHS.get("MATCH_FEAT_TRAIN", ""))

DF_TRAIN_ALL      = Path(PATHS["DF_TRAIN_ALL"])
CV_CASE_FOLDS     = Path(PATHS["CV_CASE_FOLDS"])
IMG_PROFILE_TRAIN = Path(PATHS.get("IMG_PROFILE_TRAIN", ""))

for need_name, need_path in [
    ("df_train_all.parquet", DF_TRAIN_ALL),
    ("cv_case_folds.csv", CV_CASE_FOLDS),
    ("pred_features_train_all.csv", PRED_FEAT_TRAIN),
]:
    if not Path(need_path).exists():
        raise FileNotFoundError(f"Missing required file: {need_name} -> {need_path}")

print("Using:")
print("  DF_TRAIN_ALL     :", DF_TRAIN_ALL)
print("  CV_CASE_FOLDS    :", CV_CASE_FOLDS)
print("  PRED_FEAT_TRAIN  :", PRED_FEAT_TRAIN)
print("  MATCH_FEAT_TRAIN :", MATCH_FEAT_TRAIN, "(optional)" if Path(MATCH_FEAT_TRAIN).exists() else "(missing/skip)")
print("  IMG_PROFILE_TRAIN:", IMG_PROFILE_TRAIN, "(optional)" if Path(IMG_PROFILE_TRAIN).exists() else "(missing/skip)")

# ----------------------------
# 3) Load minimal inputs
# ----------------------------
# df_train_all: ambil kolom minimal supaya hemat memori
df_base = pd.read_parquet(
    DF_TRAIN_ALL,
    columns=[c for c in ["sample_id", "uid", "case_id", "variant", "y_forged", "has_mask", "is_forged", "forged"]]
    if True else None
)

df_cv   = pd.read_csv(CV_CASE_FOLDS)
df_pred = pd.read_csv(PRED_FEAT_TRAIN, low_memory=False)

df_match = None
if FE_CFG["use_match_features"] and Path(MATCH_FEAT_TRAIN).exists():
    try:
        df_match = pd.read_csv(MATCH_FEAT_TRAIN, low_memory=False)
    except Exception:
        df_match = None

df_prof = None
if FE_CFG["use_image_profile"] and Path(IMG_PROFILE_TRAIN).exists():
    try:
        df_prof = pd.read_parquet(IMG_PROFILE_TRAIN)
    except Exception:
        df_prof = None

# ----------------------------
# 4) Utilities: normalize ids and ensure columns
# ----------------------------
def _to_str_series(s: pd.Series) -> pd.Series:
    return s.astype(str).replace({"nan": "", "None": ""})

def _ensure_uid(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "uid" not in df.columns:
        for alt in ["sample_id", "id", "key"]:
            if alt in df.columns:
                df = df.rename(columns={alt: "uid"})
                break
    if "uid" not in df.columns:
        raise ValueError("Cannot find uid column. Expected 'uid' or 'sample_id'.")
    df["uid"] = _to_str_series(df["uid"])
    return df

def _parse_case_variant_from_uid(uid_s: pd.Series) -> pd.DataFrame:
    uid = _to_str_series(uid_s)
    # pattern: "<case>__<variant>" (utama)
    case1 = uid.str.extract(r"^(\d+)__")[0]
    var1  = uid.str.extract(r"__(.+)$")[0]
    # fallback: "<case>_<variant>"
    case2 = uid.str.extract(r"^(\d+)_")[0]
    var2  = uid.str.extract(r"_(\w+)$")[0]
    case = case1.fillna(case2)
    var  = var1.fillna(var2).fillna("unk")
    out = pd.DataFrame({"case_id": case, "variant": var})
    return out

def _ensure_case_variant(df: pd.DataFrame, df_base_map: pd.DataFrame = None) -> pd.DataFrame:
    df = _ensure_uid(df)

    # kalau sudah ada case_id/variant, cukup rapikan
    if "case_id" in df.columns and "variant" in df.columns:
        df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
        if df["case_id"].isna().any():
            # coba fallback parse dari uid
            pv = _parse_case_variant_from_uid(df["uid"])
            df["case_id"] = df["case_id"].fillna(pd.to_numeric(pv["case_id"], errors="coerce"))
            if "variant" in df.columns:
                df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
                df["variant"] = df["variant"].where(df["variant"].str.len() > 0, pv["variant"])
        df["case_id"] = df["case_id"].astype("Int64")
        df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
        return df

    # prioritas: merge dari df_train_all (lebih akurat daripada parse)
    if df_base_map is not None and {"uid", "case_id", "variant"}.issubset(df_base_map.columns):
        df = df.merge(df_base_map[["uid", "case_id", "variant"]], on="uid", how="left")

    if "case_id" not in df.columns or "variant" not in df.columns:
        pv = _parse_case_variant_from_uid(df["uid"])
        if "case_id" not in df.columns:
            df["case_id"] = pd.to_numeric(pv["case_id"], errors="coerce")
        if "variant" not in df.columns:
            df["variant"] = pv["variant"]

    df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce").astype("Int64")
    df["variant"] = df["variant"].astype(str).replace({"nan": "unk", "None": "unk"})
    return df

def _pick_label_col(df: pd.DataFrame) -> str:
    for cand in ["y_forged", "has_mask", "is_forged", "forged"]:
        if cand in df.columns:
            return cand
    return ""

# ----------------------------
# 5) Prepare base mapping from df_train_all
# ----------------------------
df_base = df_base.copy()
if "uid" not in df_base.columns:
    if "sample_id" in df_base.columns:
        df_base = df_base.rename(columns={"sample_id": "uid"})
    elif ("case_id" in df_base.columns and "variant" in df_base.columns):
        df_base["uid"] = _to_str_series(df_base["case_id"]) + "__" + _to_str_series(df_base["variant"])
df_base = _ensure_uid(df_base)

# normalize base case/variant
if "case_id" in df_base.columns:
    df_base["case_id"] = pd.to_numeric(df_base["case_id"], errors="coerce").astype("Int64")
if "variant" in df_base.columns:
    df_base["variant"] = df_base["variant"].astype(str).replace({"nan": "unk", "None": "unk"})

label_col = _pick_label_col(df_base)
if not label_col:
    raise ValueError("Cannot find label column in df_train_all (y_forged/has_mask/is_forged/forged).")

# drop duplicates: one row per uid
df_base_map = df_base.drop_duplicates(subset=["uid"], keep="first").copy()

# ----------------------------
# 6) Prepare folds
# ----------------------------
if "case_id" not in df_cv.columns or "fold" not in df_cv.columns:
    raise ValueError("cv_case_folds.csv must contain columns: case_id, fold")

df_cv = df_cv[["case_id", "fold"]].copy()
df_cv["case_id"] = pd.to_numeric(df_cv["case_id"], errors="coerce").astype("Int64")
df_cv["fold"]    = pd.to_numeric(df_cv["fold"], errors="coerce").astype("Int64")
df_cv = df_cv.dropna().astype({"case_id": int, "fold": int}).drop_duplicates("case_id")

# ----------------------------
# 7) Start from df_pred (1 row per uid)
# ----------------------------
df_pred = _ensure_case_variant(df_pred, df_base_map=df_base_map)

# de-dup pred (kalau ada duplikat uid, keep first)
if df_pred["uid"].duplicated().any():
    df_pred = df_pred.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

df_train = df_pred.copy()

# attach label from df_base_map (lebih aman daripada percaya pred_features)
if "y" not in df_train.columns:
    df_train = df_train.merge(
        df_base_map[["uid", label_col]].rename(columns={label_col: "y"}),
        on="uid", how="left"
    )

if df_train["y"].isna().any():
    miss = int(df_train["y"].isna().sum())
    raise ValueError(f"Label merge produced NaN in y: {miss} rows. Check df_train_all vs pred_features alignment.")

df_train["y"] = pd.to_numeric(df_train["y"], errors="coerce").astype(int)

# attach folds by case_id
df_train = df_train.drop(columns=["fold"], errors="ignore").merge(df_cv, on="case_id", how="left")
if df_train["fold"].isna().any():
    miss = int(df_train["fold"].isna().sum())
    raise ValueError(f"Missing fold after merging cv_case_folds.csv: {miss} rows.")
df_train["fold"] = df_train["fold"].astype(int)

# ----------------------------
# 8) Optional merge match features (new cols only)
# ----------------------------
if df_match is not None:
    df_match = _ensure_case_variant(df_match, df_base_map=df_base_map)
    if df_match["uid"].duplicated().any():
        df_match = df_match.drop_duplicates(subset=["uid"], keep="first").reset_index(drop=True)

    base_cols = set(df_train.columns)
    new_cols = [c for c in df_match.columns if c not in base_cols and c not in ["case_id", "variant"]]
    if new_cols:
        df_train = df_train.merge(df_match[["uid"] + new_cols], on="uid", how="left")

# ----------------------------
# 9) Optional merge image profile by case_id (numeric only, prefixed)
# ----------------------------
if df_prof is not None and "case_id" in df_prof.columns:
    df_prof2 = df_prof.copy()
    df_prof2["case_id"] = pd.to_numeric(df_prof2["case_id"], errors="coerce").astype("Int64")
    df_prof2 = df_prof2.dropna(subset=["case_id"]).astype({"case_id": int})
    df_prof2 = df_prof2.drop_duplicates("case_id")

    # keep numeric columns only (+case_id)
    prof_num = ["case_id"] + [c for c in df_prof2.columns if c != "case_id" and pd.api.types.is_numeric_dtype(df_prof2[c])]
    df_prof2 = df_prof2[prof_num].copy()

    # prefix (hindari clash)
    ren = {c: f"profile_{c}" for c in df_prof2.columns if c != "case_id"}
    df_prof2 = df_prof2.rename(columns=ren)
    df_train = df_train.merge(df_prof2, on="case_id", how="left")

# ----------------------------
# 10) Feature engineering helpers
# ----------------------------
def _num(s):
    return pd.to_numeric(s, errors="coerce")

def safe_log1p_nonneg(x):
    x = np.asarray(x, dtype=np.float64)
    x = np.where(np.isfinite(x), x, 0.0)
    x = np.clip(x, 0.0, None)
    return np.log1p(x)

def safe_sqrt_nonneg(x):
    x = np.asarray(x, dtype=np.float64)
    x = np.where(np.isfinite(x), x, 0.0)
    x = np.clip(x, 0.0, None)
    return np.sqrt(x)

def get_clip_cap(series: pd.Series, q: float, fallback: float):
    s = _num(series).astype(float).replace([np.inf, -np.inf], np.nan).dropna()
    if len(s) == 0:
        return float(fallback)
    # cap pakai abs untuk robust (kalau ada nilai negatif liar)
    s = s[np.isfinite(s)]
    cap = float(np.quantile(np.abs(s.values), q))
    if (not np.isfinite(cap)) or (cap <= 0):
        return float(fallback)
    return float(cap)

# ----------------------------
# 11) Build candidate numeric feature list (pre-fill)
# ----------------------------
# Jangan masukkan target/split/ID numeric
TARGET_COLS = {"y", "y_forged", "has_mask", "is_forged", "forged"}
SPLIT_COLS  = {"fold"}
ID_DROP_NUM = {"case_id"}  # jangan dipakai sebagai feature

# Replace inf -> NaN for numeric cols
for c in df_train.columns:
    if pd.api.types.is_numeric_dtype(df_train[c]):
        df_train[c] = df_train[c].replace([np.inf, -np.inf], np.nan)

# Missing indicators (sebelum fill)
missing_ind_cols = []
if FE_CFG["add_missing_indicators"]:
    for c in df_train.columns:
        if c in TARGET_COLS or c in SPLIT_COLS or c in ID_DROP_NUM:
            continue
        if pd.api.types.is_numeric_dtype(df_train[c]):
            if df_train[c].isna().any():
                ind = f"isna_{c}"
                df_train[ind] = df_train[c].isna().astype(np.uint8)
                missing_ind_cols.append(ind)

# Heavy-tail columns (dinamis + beberapa umum)
heavy_candidates = set([
    "peak_ratio", "best_weight", "best_count", "best_weight_frac",
    "pair_count", "n_pairs_thr", "n_pairs_mnn", "overmask_tighten_steps",
    "largest_comp", "n_comp", "grid_h", "grid_w", "patch",
    "grid_area_frac", "area_frac", "inlier_ratio",
])
# tambah otomatis: semua kolom yang mengandung kata kunci ini
for c in df_train.columns:
    cl = c.lower()
    if any(k in cl for k in ["count", "pairs", "weight", "ratio", "area", "comp"]):
        if pd.api.types.is_numeric_dtype(df_train[c]):
            heavy_candidates.add(c)

clip_caps = {}
if FE_CFG["clip_by_quantile"]:
    for c in sorted(list(heavy_candidates)):
        if c in df_train.columns and pd.api.types.is_numeric_dtype(df_train[c]):
            clip_caps[c] = get_clip_cap(df_train[c], FE_CFG["clip_q"], FE_CFG["clip_max_fallback"])

# Clipped + log/sqrt transforms
if FE_CFG["add_log_features"] or FE_CFG["add_sqrt_features"]:
    for c, cap in clip_caps.items():
        x = _num(df_train[c]).fillna(0.0).astype(float).values
        x = np.clip(x, -cap, cap)  # symmetric cap, lebih aman
        df_train[f"{c}_cap"] = x.astype(np.float32)

        # log1p untuk magnitude (nonneg)
        if FE_CFG["add_log_features"]:
            df_train[f"logabs_{c}"] = safe_log1p_nonneg(np.abs(x)).astype(np.float32)

        if FE_CFG["add_sqrt_features"]:
            df_train[f"sqrtabs_{c}"] = safe_sqrt_nonneg(np.abs(x)).astype(np.float32)

# Interactions (ringkas tapi sering berguna)
if FE_CFG["add_interactions"]:
    def getf(col, default=0.0):
        if col in df_train.columns:
            return _num(df_train[col]).fillna(default).astype(float).values
        return np.full(len(df_train), default, dtype=np.float64)

    best_mean_sim = getf("best_mean_sim", 0.0)
    best_count    = getf("best_count", 0.0)
    peak_ratio    = getf("peak_ratio", 0.0)
    has_peak      = getf("has_peak", 0.0)
    grid_area     = getf("grid_area_frac", 0.0)
    area_frac     = getf("area_frac", 0.0)
    n_pairs_thr   = getf("n_pairs_thr", 0.0)
    n_pairs_mnn   = getf("n_pairs_mnn", 0.0)
    inlier_ratio  = getf("inlier_ratio", 0.0)
    gh = getf("grid_h", 0.0)
    gw = getf("grid_w", 0.0)

    gridN = np.clip(gh * gw, 0.0, None)

    df_train["sim_x_count"]      = (best_mean_sim * best_count).astype(np.float32)
    df_train["peak_x_sim"]       = (peak_ratio * best_mean_sim).astype(np.float32)
    df_train["haspeak_x_sim"]    = (has_peak * best_mean_sim).astype(np.float32)
    df_train["area_x_sim"]       = (grid_area * best_mean_sim).astype(np.float32)
    df_train["area_x_count"]     = (grid_area * best_count).astype(np.float32)
    df_train["mask_grid_ratio"]  = (area_frac / (1e-6 + grid_area)).astype(np.float32)
    df_train["mnn_ratio"]        = (n_pairs_mnn / (1.0 + n_pairs_thr)).astype(np.float32)
    df_train["pairs_per_cell"]   = (n_pairs_thr / (1.0 + gridN)).astype(np.float32)
    df_train["inlier_x_pairs"]   = (inlier_ratio * n_pairs_thr).astype(np.float32)

# Fill NaN numeric -> 0
num_cols = [c for c in df_train.columns if pd.api.types.is_numeric_dtype(df_train[c])]
df_train[num_cols] = df_train[num_cols].fillna(FE_CFG["fillna_value"])

# ----------------------------
# 12) Variant encoding (optional one-hot)
# ----------------------------
variant_dummy_cols = []
if FE_CFG["encode_variant_onehot"]:
    vc = df_train["variant"].astype(str).fillna("unk")
    # (opsional) buang variant yang terlalu jarang
    counts = vc.value_counts()
    keep = set(counts[counts >= int(FE_CFG["variant_min_count"])].index.tolist())
    vc = vc.where(vc.isin(keep), other="rare")

    dummies = pd.get_dummies(vc, prefix="v", dummy_na=False)
    # pastikan dtype kecil
    dummies = dummies.astype(np.uint8)
    variant_dummy_cols = dummies.columns.tolist()
    df_train = pd.concat([df_train, dummies], axis=1)

# ----------------------------
# 13) Select final feature columns (numeric only)
# ----------------------------
# Exclude target/split/id numeric. Variant dummies sudah numeric -> ikut.
feature_cols = []
for c in df_train.columns:
    if not pd.api.types.is_numeric_dtype(df_train[c]):
        continue
    if c in TARGET_COLS or c in SPLIT_COLS or c in ID_DROP_NUM:
        continue
    feature_cols.append(c)

# Drop constant features
dropped_constant = []
if FE_CFG["drop_constant_features"]:
    nun = df_train[feature_cols].nunique(dropna=False)
    nonconst = nun[nun > 1].index.tolist()
    dropped_constant = sorted(set(feature_cols) - set(nonconst))
    feature_cols = nonconst

# Cast float32 for numeric features (keep uint8 for dummies/indicators ok tapi boleh float32 juga)
if FE_CFG["cast_float32"]:
    # pisahkan uint8 indicators agar tidak meledak memori; cast ke float32 biar konsisten training
    df_train[feature_cols] = df_train[feature_cols].astype(np.float32)

FEATURE_COLS = list(feature_cols)

# ----------------------------
# 14) Final outputs
# ----------------------------
base_out_cols = ["uid", "case_id", "variant", "fold", "y"]
df_train_tabular = df_train[base_out_cols + FEATURE_COLS].copy()

X_train = df_train_tabular[FEATURE_COLS]
y_train = df_train_tabular["y"].astype(int)
folds   = df_train_tabular["fold"].astype(int)

print("\nOK — Training table built")
print("  df_train_tabular:", df_train_tabular.shape)
print("  X_train:", X_train.shape, "| y pos%:", float(y_train.mean()) * 100.0)
print("  folds:", int(folds.nunique()), "unique folds")
print("  feature_cols:", int(len(FEATURE_COLS)))
if dropped_constant:
    print("  dropped_constant_features:", len(dropped_constant))
if variant_dummy_cols:
    print("  variant_dummies:", len(variant_dummy_cols))
if missing_ind_cols:
    print("  missing_indicators:", len(missing_ind_cols))

# hard sanity
if X_train.shape[0] != y_train.shape[0]:
    raise RuntimeError("X_train and y_train row mismatch")
if y_train.isna().any():
    raise RuntimeError("y_train contains NaN")
if folds.isna().any():
    raise RuntimeError("folds contains NaN")

print("\nFeature head:", FEATURE_COLS[:25])
print("Feature tail:", FEATURE_COLS[-15:])

# ----------------------------
# 15) Save reproducible schema
# ----------------------------
OUT_ART = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_ART.mkdir(parents=True, exist_ok=True)

with open(OUT_ART / "feature_cols.json", "w") as f:
    json.dump(FEATURE_COLS, f, indent=2)

schema = {
    "fe_cfg": FE_CFG,
    "label_col_source": label_col,
    "clip_caps": clip_caps,
    "dropped_constant_features": dropped_constant,
    "variant_dummy_cols": variant_dummy_cols,
    "missing_indicator_cols": missing_ind_cols,
    "n_features": int(len(FEATURE_COLS)),
    "example_feature_head": FEATURE_COLS[:30],
}
with open(OUT_ART / "feature_schema.json", "w") as f:
    json.dump(schema, f, indent=2)

df_train_tabular.to_parquet(OUT_ART / "df_train_tabular.parquet", index=False)

print(f"\nSaved -> {OUT_ART/'feature_cols.json'}")
print(f"Saved -> {OUT_ART/'feature_schema.json'}")
print(f"Saved -> {OUT_ART/'df_train_tabular.parquet'}")

# export globals
globals().update({
    "df_train_tabular": df_train_tabular,
    "FEATURE_COLS": FEATURE_COLS,
    "X_train": X_train,
    "y_train": y_train,
    "folds": folds,
})


Using:
  DF_TRAIN_ALL      : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet
  CV_CASE_FOLDS     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv
  PRED_FEAT_TRAIN   : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv
  MATCH_FEAT_TRAIN  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv (optional)
  IMG_PROFILE_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet (optional)
  DINO_LARGE_DIR    : /kaggle/input/dinov2/pytorch/large/1

OK — Training table built
  df_train_tabular: (5176, 67)
  X_train: (5176, 62) | y pos%: 54.07650695517774
  folds: 5 unique folds
  feature_cols: 62
  dropped_constant_features: 6

Feature head: ['has_peak', 'peak_ratio', 'best_weight', 'best_count', 'best_mean_sim', 'n_pairs_thr', 'n_pairs_mnn', 'best_inlier_ratio', 'best_

# Train Baseline Model (Leakage-Safe CV)

In [None]:
# ============================================================
# Step 3 — Train Stronger Model (Leakage-Safe CV) — REVISI FULL v2.2
# Upgrade utama:
# - EMA evaluation + save best EMA weights
# - Focal-BCE (optional) + label smoothing (optional)
# - Feature token dropout + input noise (tabular regularization)
# - Fix CFG name print bug + better FULL training epochs (median best epoch)
# - Best threshold search per fold (for reporting)
# ============================================================

import json, gc, math, time, os, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, log_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require outputs from Step 2
# ----------------------------
need_vars = ["df_train_tabular", "FEATURE_COLS"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

required_cols = {"uid", "case_id", "variant", "fold", "y"}
missing_cols = [c for c in required_cols if c not in df_train_tabular.columns]
if missing_cols:
    raise ValueError(f"df_train_tabular missing columns: {missing_cols}.")

# ----------------------------
# 1) CFG (SAFE vs STRONG)
# ----------------------------
CFG_SAFE = {
    "seed": 2025,

    # mHC (PDF)
    "n_streams": 4,
    "sinkhorn_tmax": 20,
    "alpha_init": 0.01,

    # model capacity (SAFE)
    "d_model": 256,
    "n_layers": 6,
    "n_heads": 8,
    "ffn_mult": 4,
    "dropout": 0.12,
    "attn_dropout": 0.08,

    # regularization for tabular
    "feat_token_drop_p": 0.05,     # drop some feature-tokens (not CLS)
    "input_noise_std": 0.01,       # small gaussian noise AFTER standardize
    "label_smoothing": 0.00,       # optional
    "focal_gamma": 1.5,            # 0.0 -> pure BCE. 1.5 cukup aman buat imbalance

    # training
    "batch_size": 256,
    "accum_steps": 2,
    "epochs": 60,

    "lr": 3e-4,
    "betas": (0.9, 0.95),
    "eps": 1e-8,
    "weight_decay": 5e-2,

    "warmup_frac": 0.08,
    "lr_decay_milestones": (0.80, 0.90),
    "lr_decay_values": (0.316, 0.10),

    "grad_clip": 1.0,

    "early_stop_patience": 12,
    "early_stop_min_delta": 1e-4,

    # EMA
    "use_ema": True,
    "ema_decay": 0.999,

    # reporting
    "report_thr": 0.5,
    "search_best_thr": True,
    "thr_grid": 201,              # number of thresholds to test in [0,1]
}

CFG_STRONG = {
    **CFG_SAFE,
    "d_model": 384,
    "n_layers": 8,
    "dropout": 0.16,
    "epochs": 80,
    "lr": 2e-4,
    "weight_decay": 7e-2,
    "batch_size": 256,
    "accum_steps": 2,
    "feat_token_drop_p": 0.06,
    "input_noise_std": 0.012,
    "focal_gamma": 1.5,
}

# auto select
CFG = dict(CFG_SAFE)
CFG_NAME = "SAFE"
if torch.cuda.is_available():
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
        if mem_gb >= 30:
            CFG = dict(CFG_STRONG)
            CFG_NAME = "STRONG"
    except Exception:
        pass

# ----------------------------
# 2) Seed + device
# ----------------------------
def seed_everything(seed: int = 2025):
    import random
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(int(CFG["seed"]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")
print("Device:", device, "| AMP:", use_amp, "| CFG:", CFG_NAME)

if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# ----------------------------
# 3) Build arrays + guard
# ----------------------------
X = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)

if not np.isfinite(X).all():
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

n = len(df_train_tabular)
unique_folds = sorted(pd.Series(folds).unique().tolist())
n_folds = len(unique_folds)
n_features = X.shape[1]

print("Setup:")
print("  rows      :", n)
print("  folds     :", n_folds, "|", unique_folds)
print("  pos%      :", float(y.mean()) * 100.0)
print("  n_features:", n_features)

# ----------------------------
# 4) Dataset
# ----------------------------
class TabDataset(Dataset):
    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        self.y = None if y is None else torch.from_numpy(y.astype(np.float32))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]

# ----------------------------
# 5) Normalization (leakage-safe)
# ----------------------------
def fit_standardizer(X_tr: np.ndarray):
    mu = X_tr.mean(axis=0, dtype=np.float64)
    sig = X_tr.std(axis=0, dtype=np.float64)
    sig = np.where(sig < 1e-8, 1.0, sig)
    return mu.astype(np.float32), sig.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    return ((X_in - mu) / sig).astype(np.float32)

# ----------------------------
# 6) Metrics helpers
# ----------------------------
def safe_auc(y_true, p):
    y_true = np.asarray(y_true)
    if len(np.unique(y_true)) < 2:
        return None
    return float(roc_auc_score(y_true, p))

def safe_logloss(y_true, p):
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-6, 1-1e-6)
    return float(log_loss(y_true, p, labels=[0, 1]))

def find_best_threshold(y_true, p, n_grid=201):
    # maximize F1 (bisa diganti Youden/accuracy kalau mau)
    y_true = np.asarray(y_true).astype(int)
    p = np.asarray(p).astype(np.float64)
    best = {"thr": 0.5, "f1": -1.0, "precision": 0.0, "recall": 0.0}
    for thr in np.linspace(0.0, 1.0, int(n_grid)):
        yh = (p >= thr).astype(int)
        f1 = f1_score(y_true, yh, zero_division=0)
        if f1 > best["f1"]:
            best["thr"] = float(thr)
            best["f1"] = float(f1)
            best["precision"] = float(precision_score(y_true, yh, zero_division=0))
            best["recall"] = float(recall_score(y_true, yh, zero_division=0))
    return best

# ----------------------------
# 7) Sinkhorn projection
# ----------------------------
def sinkhorn_doubly_stochastic(logits: torch.Tensor, tmax: int = 20, eps: float = 1e-6):
    z = logits - logits.max()
    M = torch.exp(z)
    for _ in range(int(tmax)):
        M = M / (M.sum(dim=-1, keepdim=True) + eps)
        M = M / (M.sum(dim=-2, keepdim=True) + eps)
    return M

# ----------------------------
# 8) RMSNorm
# ----------------------------
class RMSNorm(nn.Module):
    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return x * rms * self.weight

# ----------------------------
# 9) EMA helper
# ----------------------------
class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}
        self.backup = {}

        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        d = self.decay
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.shadow[name].mul_(d).add_(p.detach(), alpha=(1.0 - d))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.backup[name] = p.detach().clone()
            p.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            p.copy_(self.backup[name])
        self.backup = {}

# ----------------------------
# 10) Model blocks (feature token dropout)
# ----------------------------
class MHCAttnBlock(nn.Module):
    def __init__(self, d_model, n_heads, ffn_mult, dropout, attn_dropout,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01):
        super().__init__()
        self.n_streams = int(n_streams)
        self.sinkhorn_tmax = int(sinkhorn_tmax)

        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=attn_dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        self.norm2 = RMSNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_mult * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * d_model, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

        self.h_logits = nn.Parameter(torch.zeros(self.n_streams, self.n_streams))
        nn.init.zeros_(self.h_logits)

        a0 = float(alpha_init)
        a0 = min(max(a0, 1e-4), 1 - 1e-4)
        self.alpha_logit = nn.Parameter(torch.log(torch.tensor(a0 / (1 - a0), dtype=torch.float32)))

    def forward(self, streams):
        # streams: (B, n_streams, S, D)
        B, nS, S, D = streams.shape

        # update only stream 0
        x = streams[:, 0]  # (B,S,D)

        x0 = x
        q = self.norm1(x)
        attn_out, _ = self.attn(q, q, q, need_weights=False)
        x = x0 + self.drop1(attn_out)

        x1 = x
        h = self.norm2(x)
        h = self.ffn(h)
        x = x1 + self.drop2(h)

        streams = streams.clone()
        streams[:, 0] = x

        H = sinkhorn_doubly_stochastic(self.h_logits, tmax=self.sinkhorn_tmax)
        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype)
        I = torch.eye(self.n_streams, device=streams.device, dtype=streams.dtype)
        Hres = (1.0 - alpha) * I + alpha * H.to(dtype=streams.dtype)

        mixed = torch.einsum("ij,bjtd->bitd", Hres, streams)
        return mixed

class MHCFTTransformer(nn.Module):
    def __init__(self, n_features,
                 d_model=256, n_heads=8, n_layers=6, ffn_mult=4,
                 dropout=0.12, attn_dropout=0.08,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01,
                 feat_token_drop_p=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.feat_token_drop_p = float(feat_token_drop_p)

        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(dropout)

        self.blocks = nn.ModuleList([
            MHCAttnBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout,
                n_streams=n_streams,
                sinkhorn_tmax=sinkhorn_tmax,
                alpha_init=alpha_init
            )
            for _ in range(int(n_layers))
        ])

        self.out_norm = RMSNorm(self.d_model)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_model, 1),
        )

        self.n_streams = int(n_streams)

    def forward(self, x):
        # x: (B,F)
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)   # (B,F,D)
        tok = tok + self.feat_emb.unsqueeze(0)

        # feature-token dropout (regularizer), do not drop CLS
        if self.training and self.feat_token_drop_p > 0:
            B, F, D = tok.shape
            keep = (torch.rand(B, F, device=tok.device) > self.feat_token_drop_p).to(tok.dtype)
            tok = tok * keep.unsqueeze(-1)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)                                    # (B,1,D)
        seq = torch.cat([cls, tok], dim=1)                                  # (B,1+F,D)
        seq = self.in_drop(seq)

        streams = seq.unsqueeze(1).repeat(1, self.n_streams, 1, 1)          # (B,nS,S,D)
        for blk in self.blocks:
            streams = blk(streams)

        z = streams[:, 0, 0]
        z = self.out_norm(z)
        logit = self.head(z).squeeze(-1)
        return logit

# ----------------------------
# 11) LR Scheduler
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps: int, warmup_steps: int,
                              milestones_frac=(0.8, 0.9), decay_values=(0.316, 0.1)):
    m1 = int(float(milestones_frac[0]) * total_steps)
    m2 = int(float(milestones_frac[1]) * total_steps)
    d1 = float(decay_values[0])
    d2 = float(decay_values[1])

    def lr_lambda(step):
        if step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        if step < m1:
            return 1.0
        elif step < m2:
            return d1
        else:
            return d2

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 12) Loss: BCE / Focal-BCE with pos_weight + label smoothing
# ----------------------------
def focal_bce_with_logits(logits, targets, pos_weight=None, gamma=0.0):
    # targets in [0,1]
    bce = F.binary_cross_entropy_with_logits(
        logits, targets, reduction="none", pos_weight=pos_weight
    )
    if gamma and gamma > 0:
        p = torch.sigmoid(logits)
        p_t = p * targets + (1 - p) * (1 - targets)
        mod = (1.0 - p_t).clamp_min(0.0).pow(gamma)
        bce = bce * mod
    return bce.mean()

# ----------------------------
# 13) Predict helper (optionally with EMA)
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None):
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)
    probs = []
    for batch in loader:
        xb = batch[0] if isinstance(batch, (list, tuple)) else batch
        xb = xb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(xb)
            p = torch.sigmoid(logits)
        probs.append(p.detach().cpu().numpy())
    out = np.concatenate(probs, axis=0).astype(np.float32)
    if ema is not None:
        ema.restore(model)
    return out

# ----------------------------
# 14) Train one fold (EMA + focal + noise + best_state)
# ----------------------------
def train_one_fold(X_tr, y_tr, X_va, y_va, cfg):
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")
    dl_tr = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    dl_va = DataLoader(ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))

    model = MHCFTTransformer(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        alpha_init=float(cfg["alpha_init"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    # imbalance pos_weight
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    optim_steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_optim_steps = int(cfg["epochs"]) * max(1, optim_steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_optim_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=warmup_steps,
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    best = {"val_logloss": 1e9, "val_auc": None, "epoch": -1}
    best_state = None
    bad = 0

    input_noise_std = float(cfg.get("input_noise_std", 0.0))
    label_smooth = float(cfg.get("label_smoothing", 0.0))
    focal_gamma = float(cfg.get("focal_gamma", 0.0))

    for epoch in range(int(cfg["epochs"])):
        model.train()
        t0 = time.time()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)
        micro_step = 0
        optim_step_in_epoch = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            # input noise (after standardize)
            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            # label smoothing for binary
            if label_smooth and label_smooth > 0:
                yb = yb * (1.0 - label_smooth) + 0.5 * label_smooth

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, pos_weight=pos_weight, gamma=focal_gamma)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro_step += 1

            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro_step % accum_steps) == 0:
                if cfg["grad_clip"] and float(cfg["grad_clip"]) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))

                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                optim_step_in_epoch += 1

                if ema is not None:
                    ema.update(model)

        # flush last partial
        if (micro_step % accum_steps) != 0:
            if cfg["grad_clip"] and float(cfg["grad_clip"]) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            optim_step_in_epoch += 1
            if ema is not None:
                ema.update(model)

        # validate (gunakan EMA kalau ada)
        p_va = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va)
        vauc = safe_auc(y_va, p_va)

        tr_loss = loss_sum / max(1, n_sum)
        dt = time.time() - t0
        print(f"  epoch {epoch+1:03d}/{cfg['epochs']} | train_loss={tr_loss:.5f} | val_logloss={vll:.5f} | val_auc={vauc} | opt_steps={optim_step_in_epoch} | dt={dt:.1f}s")

        improved = (best["val_logloss"] - vll) > float(cfg["early_stop_min_delta"])
        if improved:
            best["val_logloss"] = float(vll)
            best["val_auc"] = vauc
            best["epoch"] = int(epoch)
            # simpan state (EMA weight kalau ada, supaya best benar-benar yang dieval)
            if ema is not None:
                ema.apply_shadow(model)
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
                ema.restore(model)
            else:
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["early_stop_patience"]):
                print(f"  early stop at epoch {epoch+1}, best_epoch={best['epoch']+1}, best_val_logloss={best['val_logloss']:.5f}")
                break

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    # final val preds (best)
    p_va = predict_proba(model, dl_va, ema=None)  # model sudah di-load best_state

    pack = {
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
        "best": best,
    }
    return pack, p_va, best

# ----------------------------
# 15) CV loop
# ----------------------------
oof_pred = np.zeros(n, dtype=np.float32)
fold_reports = []
best_epochs = []

models_dir = Path("/kaggle/working/recodai_luc_gate_artifacts/baseline_mhc_transformer_folds")
models_dir.mkdir(parents=True, exist_ok=True)

for f in unique_folds:
    print(f"\n[Fold {f}]")
    tr_idx = np.where(folds != f)[0]
    va_idx = np.where(folds == f)[0]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]

    pack, p_va, best = train_one_fold(X_tr, y_tr, X_va, y_va, CFG)
    oof_pred[va_idx] = p_va
    best_epochs.append(int(best["epoch"] + 1))

    auc = safe_auc(y_va, p_va)
    vll = safe_logloss(y_va, p_va)

    # best threshold search (untuk report)
    if bool(CFG.get("search_best_thr", True)):
        bt = find_best_threshold(y_va, p_va, n_grid=int(CFG.get("thr_grid", 201)))
        thr_use = float(bt["thr"])
        yhat = (p_va >= thr_use).astype(np.int32)
        rep_thr = {
            "best_thr": thr_use,
            "best_f1": float(bt["f1"]),
            "best_precision": float(bt["precision"]),
            "best_recall": float(bt["recall"]),
        }
    else:
        thr_use = float(CFG["report_thr"])
        yhat = (p_va >= thr_use).astype(np.int32)
        rep_thr = {
            "best_thr": None,
            "best_f1": None,
            "best_precision": None,
            "best_recall": None,
        }

    rep = {
        "fold": int(f),
        "n_val": int(len(va_idx)),
        "pos_val": int(y_va.sum()),
        "auc": auc,
        "logloss": vll,
        f"f1@thr({thr_use:.3f})": float(f1_score(y_va, yhat, zero_division=0)),
        f"precision@thr({thr_use:.3f})": float(precision_score(y_va, yhat, zero_division=0)),
        f"recall@thr({thr_use:.3f})": float(recall_score(y_va, yhat, zero_division=0)),
        "best_val_logloss": float(best["val_logloss"]),
        "best_val_auc": best["val_auc"],
        "best_epoch": int(best["epoch"] + 1),
        **rep_thr,
    }
    fold_reports.append(rep)

    torch.save(
        {"pack": pack, "feature_cols": FEATURE_COLS},
        models_dir / f"baseline_mhc_transformer_fold_{f}.pt"
    )

    if device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

# ----------------------------
# 16) Overall OOF metrics
# ----------------------------
oof_auc = safe_auc(y, oof_pred)
oof_ll  = safe_logloss(y, oof_pred)

# report with fixed thr=0.5 and also best thr on all OOF (optional)
thr_fixed = float(CFG["report_thr"])
oof_yhat_fixed = (oof_pred >= thr_fixed).astype(np.int32)

best_oof_thr = None
best_oof = None
if bool(CFG.get("search_best_thr", True)):
    best_oof = find_best_threshold(y, oof_pred, n_grid=int(CFG.get("thr_grid", 201)))
    best_oof_thr = float(best_oof["thr"])

overall = {
    "rows": int(n),
    "folds": int(n_folds),
    "pos_total": int(y.sum()),
    "pos_rate": float(y.mean()),
    "oof_auc": oof_auc,
    "oof_logloss": oof_ll,
    f"oof_f1@{thr_fixed}": float(f1_score(y, oof_yhat_fixed, zero_division=0)),
    f"oof_precision@{thr_fixed}": float(precision_score(y, oof_yhat_fixed, zero_division=0)),
    f"oof_recall@{thr_fixed}": float(recall_score(y, oof_yhat_fixed, zero_division=0)),
    "oof_best_thr": best_oof_thr,
    "oof_best_f1": (float(best_oof["f1"]) if best_oof else None),
}

df_rep = pd.DataFrame(fold_reports).sort_values("fold").reset_index(drop=True)
print("\nPer-fold report:")
display(df_rep)

print("\nOOF overall:")
print(overall)

# ----------------------------
# 17) Train FULL model (epochs = median(best_epoch) * 1.15)
# ----------------------------
def train_full_fixed(X_full_raw, y_full, cfg, epochs_full: int):
    mu, sig = fit_standardizer(X_full_raw)
    X_full = apply_standardizer(X_full_raw, mu, sig)

    ds_full = TabDataset(X_full, y_full)
    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")
    dl_full = DataLoader(ds_full, batch_size=int(cfg["batch_size"]), shuffle=True,
                         num_workers=nw, pin_memory=pin, drop_last=False,
                         persistent_workers=(nw > 0))

    model = MHCFTTransformer(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        alpha_init=float(cfg["alpha_init"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    pos = int(y_full.sum())
    neg = int(len(y_full) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    optim_steps_per_epoch = int(math.ceil(len(dl_full) / accum_steps))
    total_optim_steps = int(epochs_full) * max(1, optim_steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_optim_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=warmup_steps,
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg["ema_decay"])) if bool(cfg.get("use_ema", True)) else None

    input_noise_std = float(cfg.get("input_noise_std", 0.0))
    label_smooth = float(cfg.get("label_smoothing", 0.0))
    focal_gamma = float(cfg.get("focal_gamma", 0.0))

    print(f"\nTraining FULL mHC transformer for {epochs_full} epochs (median-best based)...")

    for epoch in range(int(epochs_full)):
        model.train()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)
        micro_step = 0

        for xb, yb in dl_full:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            if label_smooth and label_smooth > 0:
                yb = yb * (1.0 - label_smooth) + 0.5 * label_smooth

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = focal_bce_with_logits(logits, yb, pos_weight=pos_weight, gamma=focal_gamma)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro_step += 1

            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro_step % accum_steps) == 0:
                if cfg["grad_clip"] and float(cfg["grad_clip"]) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        if (micro_step % accum_steps) != 0:
            if cfg["grad_clip"] and float(cfg["grad_clip"]) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {epoch+1:03d}/{epochs_full} | loss={loss_sum/max(1,n_sum):.5f}")

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # save EMA weights if used
    if ema is not None:
        ema.apply_shadow(model)

    full_pack = {
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
        "epochs_full": int(epochs_full),
        "used_ema": bool(ema is not None),
    }

    if ema is not None:
        ema.restore(model)

    return full_pack

# decide epochs_full from CV best epochs
med_best = int(np.median(np.array(best_epochs, dtype=np.int32))) if len(best_epochs) else int(max(12, CFG["epochs"] * 0.7))
epochs_full = int(max(12, round(med_best * 1.15)))
epochs_full = int(min(epochs_full, int(CFG["epochs"])))  # jangan lebih dari max epochs config

out_dir = Path("/kaggle/working/recodai_luc_gate_artifacts")
out_dir.mkdir(parents=True, exist_ok=True)

full_pack = train_full_fixed(X, y, CFG, epochs_full=epochs_full)
torch.save({"pack": full_pack, "feature_cols": FEATURE_COLS}, out_dir / "baseline_mhc_transformer_model_full.pt")

# ----------------------------
# 18) Save OOF + report
# ----------------------------
df_oof = df_train_tabular[["uid", "case_id", "variant", "fold", "y"]].copy()
df_oof["oof_pred_baseline_mhc_tf"] = oof_pred
df_oof.to_csv(out_dir / "oof_baseline_mhc_transformer.csv", index=False)

report = {
    "model": "mHC-FTTransformer (numeric tabular) — upgraded v2.2 (EMA + focal + reg)",
    "cfg_name": CFG_NAME,
    "cfg": CFG,
    "feature_count": int(len(FEATURE_COLS)),
    "best_epochs": best_epochs,
    "epochs_full": int(epochs_full),
    "fold_reports": fold_reports,
    "overall": overall,
}
with open(out_dir / "baseline_mhc_transformer_cv_report.json", "w") as f:
    json.dump(report, f, indent=2)

print("\nSaved artifacts:")
print("  fold models  ->", models_dir)
print("  full model   ->", out_dir / "baseline_mhc_transformer_model_full.pt")
print("  oof preds    ->", out_dir / "oof_baseline_mhc_transformer.csv")
print("  cv report    ->", out_dir / "baseline_mhc_transformer_cv_report.json")

# Export globals
OOF_PRED_BASELINE_MHC_TF = oof_pred
BASELINE_MHC_TF_OVERALL = overall
BASELINE_MHC_TF_FOLD_REPORTS = fold_reports
BASELINE_MHC_TF_BEST_EPOCHS = best_epochs


Device: cpu | AMP: False | CFG: SAFE
Setup:
  rows      : 5176
  folds     : 5 | [0, 1, 2, 3, 4]
  pos%      : 54.07650695517774
  n_features: 62

[Fold 0]
  epoch 001/50 | train_loss=0.64418 | val_logloss=0.69617 | opt_steps=9 | dt=149.7s
  epoch 002/50 | train_loss=0.64102 | val_logloss=0.70211 | opt_steps=9 | dt=151.1s
  epoch 003/50 | train_loss=0.63783 | val_logloss=0.69920 | opt_steps=9 | dt=148.9s
  epoch 004/50 | train_loss=0.63812 | val_logloss=0.69830 | opt_steps=9 | dt=148.2s
  epoch 005/50 | train_loss=0.63792 | val_logloss=0.69034 | opt_steps=9 | dt=148.4s
  epoch 006/50 | train_loss=0.63585 | val_logloss=0.68873 | opt_steps=9 | dt=149.4s
  epoch 007/50 | train_loss=0.63943 | val_logloss=0.69627 | opt_steps=9 | dt=148.6s
  epoch 008/50 | train_loss=0.63788 | val_logloss=0.70011 | opt_steps=9 | dt=146.9s
  epoch 009/50 | train_loss=0.63812 | val_logloss=0.69200 | opt_steps=9 | dt=152.9s
  epoch 010/50 | train_loss=0.63717 | val_logloss=0.68839 | opt_steps=9 | dt=148.1s
  ep

# Optimize Model & Hyperparameters (Iterative)

In [None]:
# ============================================================
# Step 4 — Optimize Model & Hyperparameters (Iterative) — TRANSFORMER ONLY
# REVISI FULL v3.1 (mHC-lite PDF-inspired, differentiable + EMA + accum + reg)
#
# Fix penting:
# - Sinkhorn-Knopp differentiable (H belajar beneran)
# - EMA eval + save best (stabil)
# - Grad accumulation + scheduler on optimizer-steps (bukan per batch)
# - Feature-token drop + input noise (tabular regularization)
# - 2-stage search: subset folds cepat -> full CV top-M
# - Best model tidak retrain ulang: pakai fold_packs dari stage-2
#
# Primary score: OOF best F-beta (beta=0.5)
#
# Output:
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.json
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_fold_details.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/oof_preds_<cfg_name>.csv (top configs)
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_config.json
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_model.pt
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, log_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require data from Step 2
# ----------------------------
need_vars = ["df_train_tabular", "FEATURE_COLS"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds_all = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)
uids_all = df_train_tabular["uid"].astype(str).to_numpy()

if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0)

unique_folds = sorted(pd.Series(folds_all).unique().tolist())
n = len(y_all)
pos_rate = float(y_all.mean())
n_features = X_all.shape[1]

print("Optimize setup (Transformer-only, mHC-lite):")
print(f"  rows={n} | folds={len(unique_folds)} | pos%={pos_rate*100:.2f} | n_features={n_features}")

# ----------------------------
# 1) Global settings
# ----------------------------
SEED = 2025
BETA = 0.5
THR_GRID = 201

# 2-stage runtime controls
STAGE1_FOLDS = min(3, len(unique_folds))
STAGE1_EPOCH_CAP = 35
STAGE1_PAT_CAP = 6

STAGE2_TOPM = 3
REPORT_TOPK_OOF = 3

# Optional time budget (0 = off)
TIME_BUDGET_SEC = 0  # contoh: 2.5*60*60

def seed_everything(seed=2025):
    import random
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

seed_everything(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")
print("Device:", device, "| AMP:", use_amp)

if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

def get_mem_gb():
    if not torch.cuda.is_available():
        return 0.0
    try:
        return float(torch.cuda.get_device_properties(0).total_memory / (1024**3))
    except Exception:
        return 0.0

MEM_GB = get_mem_gb()
print("GPU mem GB:", MEM_GB)

# ----------------------------
# 2) Metrics helpers (F-beta primary)
# ----------------------------
def best_fbeta_fast(y_true, p, beta=0.5, grid=201):
    y = (np.asarray(y_true).astype(np.int32) == 1)
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1.0 - 1e-8)

    thrs = np.linspace(0.01, 0.99, int(grid), dtype=np.float64)
    pred = (p[:, None] >= thrs[None, :])

    y1 = y[:, None]
    tp = (pred & y1).sum(axis=0).astype(np.float64)
    fp = (pred & (~y1)).sum(axis=0).astype(np.float64)
    fn = (y.sum().astype(np.float64) - tp)

    precision = np.divide(tp, tp + fp, out=np.zeros_like(tp), where=(tp + fp) > 0)
    recall    = np.divide(tp, tp + fn, out=np.zeros_like(tp), where=(tp + fn) > 0)

    b2 = beta * beta
    denom = (b2 * precision + recall)
    fbeta = np.divide((1.0 + b2) * precision * recall, denom, out=np.zeros_like(precision), where=denom > 0)

    j = int(np.argmax(fbeta))
    return {
        "fbeta": float(fbeta[j]),
        "thr": float(thrs[j]),
        "precision": float(precision[j]),
        "recall": float(recall[j]),
    }

def safe_auc(y_true, p):
    if len(np.unique(y_true)) < 2:
        return None
    return float(roc_auc_score(y_true, p))

def safe_logloss(y_true, p):
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1 - 1e-8)
    return float(log_loss(y_true, p, labels=[0, 1]))

# ----------------------------
# 3) Dataset + Standardizer (no leakage)
# ----------------------------
class TabDataset(Dataset):
    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        self.y = None if y is None else torch.from_numpy(y.astype(np.float32))

    def __len__(self): return self.X.shape[0]

    def __getitem__(self, idx):
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]

def fit_standardizer(X_tr: np.ndarray):
    mu = X_tr.mean(axis=0, dtype=np.float64)
    sig = X_tr.std(axis=0, dtype=np.float64)
    sig = np.where(sig < 1e-8, 1.0, sig)
    return mu.astype(np.float32), sig.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    return ((X_in - mu) / sig).astype(np.float32)

# ----------------------------
# 4) RMSNorm + Differentiable Sinkhorn + EMA
# ----------------------------
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        rms = torch.mean(x * x, dim=-1, keepdim=True)
        x = x * torch.rsqrt(rms + self.eps)
        return x * self.weight

def sinkhorn_knopp(P, tmax=20, eps=1e-6):
    """
    Differentiable Sinkhorn-Knopp
    P: (B,n,n) non-negative
    """
    M = P.clamp_min(eps)
    for _ in range(int(tmax)):
        M = M / (M.sum(dim=-1, keepdim=True).clamp_min(eps))  # row
        M = M / (M.sum(dim=-2, keepdim=True).clamp_min(eps))  # col
    return M

class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}
        self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        d = self.decay
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.shadow[name].mul_(d).add_(p.detach(), alpha=(1.0 - d))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.backup[name] = p.detach().clone()
            p.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            p.copy_(self.backup[name])
        self.backup = {}

# ----------------------------
# 5) mHC-lite on CLS streams (stabil: residual-to-identity + learnable alpha)
# ----------------------------
class MHCLite(nn.Module):
    """
    Maintain n_streams for CLS only.
    Build per-sample mixing matrix -> Sinkhorn -> residual blend with Identity.
    """
    def __init__(self, d_model, n_streams=4, alpha_init=0.01, tmax=20, dropout=0.0):
        super().__init__()
        self.n = int(n_streams)
        self.tmax = int(tmax)
        self.drop = nn.Dropout(float(dropout))

        self.norm = RMSNorm(d_model, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, self.n * self.n),
        )
        self.softplus = nn.Softplus()

        a0 = float(alpha_init)
        a0 = min(max(a0, 1e-4), 1.0 - 1e-4)
        self.alpha_logit = nn.Parameter(torch.log(torch.tensor(a0 / (1 - a0), dtype=torch.float32)))

    def forward(self, streams, cls_vec):
        # streams: (B,n,D), cls_vec: (B,D)
        B, n, D = streams.shape
        h = self.norm(cls_vec)
        logits = self.mlp(h).view(B, n, n)  # (B,n,n)

        # non-negative
        P = self.softplus(logits)

        # doubly-stochastic
        M = sinkhorn_knopp(P, tmax=self.tmax, eps=1e-6)  # (B,n,n)

        # residual blend with Identity (stability)
        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype, device=streams.device)
        I = torch.eye(n, device=streams.device, dtype=streams.dtype).unsqueeze(0).expand(B, -1, -1)
        H = (1.0 - alpha) * I + alpha * M

        mixed = torch.einsum("bij,bjd->bid", H, streams)
        injected = mixed + cls_vec.unsqueeze(1)
        return self.drop(injected)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=int(n_heads),
            dropout=float(attn_dropout), batch_first=True
        )
        self.drop1 = nn.Dropout(float(dropout))

        self.norm2 = RMSNorm(d_model, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, int(ffn_mult) * d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(int(ffn_mult) * d_model, d_model),
        )
        self.drop2 = nn.Dropout(float(dropout))

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop1(attn_out)

        h = self.norm2(x)
        x = x + self.drop2(self.ffn(h))
        return x

class FTTransformer_MHCLite(nn.Module):
    """
    Numeric FT-Transformer + CLS-stream mHC-lite between blocks.
    Regularizer: feature-token drop (zero some feature tokens, not CLS).
    """
    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4,
                 dropout=0.2, attn_dropout=0.1,
                 n_streams=4, alpha_init=0.01, sinkhorn_tmax=20, mhc_dropout=0.0,
                 feat_token_drop_p=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.n_layers = int(n_layers)
        self.feat_token_drop_p = float(feat_token_drop_p)

        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(float(dropout))

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout
            ) for _ in range(self.n_layers)
        ])

        self.mhc = nn.ModuleList([
            MHCLite(
                d_model=self.d_model,
                n_streams=n_streams,
                alpha_init=alpha_init,
                tmax=sinkhorn_tmax,
                dropout=mhc_dropout
            ) for _ in range(self.n_layers)
        ])

        self.out_norm = RMSNorm(self.d_model, eps=1e-6)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(self.d_model, 1),
        )

    def forward(self, x):
        # x: (B,F)
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)  # (B,F,D)
        tok = tok + self.feat_emb.unsqueeze(0)

        # feature-token drop
        if self.training and self.feat_token_drop_p > 0:
            B, F_, D = tok.shape
            keep = (torch.rand(B, F_, device=tok.device) > self.feat_token_drop_p).to(tok.dtype)
            tok = tok * keep.unsqueeze(-1)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)  # (B,1,D)
        seq = torch.cat([cls, tok], dim=1)  # (B,1+F,D)
        seq = self.in_drop(seq)

        # init CLS streams
        nS = self.mhc[0].n
        streams = seq[:, 0, :].unsqueeze(1).expand(B, nS, self.d_model).contiguous()

        for l, blk in enumerate(self.blocks):
            # avoid inplace on seq[:,0,:]
            cls_in = streams.mean(dim=1).unsqueeze(1)  # (B,1,D)
            seq = torch.cat([cls_in, seq[:, 1:, :]], dim=1)

            seq = blk(seq)
            cls_vec = seq[:, 0, :]

            streams = self.mhc[l](streams, cls_vec)

        out = self.out_norm(streams.mean(dim=1))
        logit = self.head(out).squeeze(-1)
        return logit

# ----------------------------
# 6) Scheduler: warmup + step decay at ratios [0.8,0.9]
# (scheduler on OPTIMIZER STEPS, bukan per batch)
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps, warmup_steps,
                              r1=0.8, r2=0.9, d1=0.316, d2=0.1):
    m1 = int(float(r1) * total_steps)
    m2 = int(float(r2) * total_steps)

    def lr_lambda(step):
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        mult = 1.0
        if step >= m1:
            mult *= float(d1)
        if step >= m2:
            mult *= float(d2)
        return mult

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 7) Predict helper (robust batch)
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None):
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)
    probs = []
    for batch in loader:
        xb = batch[0] if isinstance(batch, (list, tuple)) else batch
        xb = xb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(xb)
            p = torch.sigmoid(logits)
        probs.append(p.detach().cpu().numpy())
    out = np.concatenate(probs, axis=0).astype(np.float32)
    if ema is not None:
        ema.restore(model)
    return out

# ----------------------------
# 8) Train one fold (AMP + accum + EMA + reg)
# ----------------------------
def train_one_fold_transformer(X_tr, y_tr, X_va, y_va, cfg):
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    cpu_cnt = os.cpu_count() or 2
    nw = 2 if cpu_cnt >= 4 else 0
    pin = (device.type == "cuda")

    dl_tr = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))
    dl_va = DataLoader(ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
                       num_workers=nw, pin_memory=pin, drop_last=False,
                       persistent_workers=(nw > 0))

    model = FTTransformer_MHCLite(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        alpha_init=float(cfg["alpha_init"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    # imbalance: pos_weight
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = torch.tensor([float(neg / max(1, pos))], device=device, dtype=torch.float32)

    # optional focal gamma
    focal_gamma = float(cfg.get("focal_gamma", 0.0))
    label_smoothing = float(cfg.get("label_smoothing", 0.0))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    def loss_fn(logits, targets):
        # targets float [0,1]
        if label_smoothing and label_smoothing > 0:
            targets = targets * (1.0 - label_smoothing) + 0.5 * label_smoothing
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)
        if focal_gamma and focal_gamma > 0:
            p = torch.sigmoid(logits)
            p_t = p * targets + (1 - p) * (1 - targets)
            mod = (1.0 - p_t).clamp_min(0.0).pow(focal_gamma)
            bce = bce * mod
        return bce.mean()

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )

    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_steps = int(cfg["epochs"]) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]),
        r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]),
        d2=float(cfg["lr_decay_rate2"]),
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None

    best_val = 1e18
    best_state = None
    best_epoch = -1
    bad = 0
    opt_step = 0

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)

        loss_sum = 0.0
        n_sum = 0
        micro = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro += 1

            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if (micro % accum_steps) == 0:
                if float(cfg["grad_clip"]) and float(cfg["grad_clip"]) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                opt_step += 1
                if ema is not None:
                    ema.update(model)

        # flush last partial
        if (micro % accum_steps) != 0:
            if float(cfg["grad_clip"]) and float(cfg["grad_clip"]) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg["grad_clip"]))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            opt_step += 1
            if ema is not None:
                ema.update(model)

        # validate with EMA (if enabled)
        p_va = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va)

        improved = (best_val - vll) > float(cfg["min_delta"])
        if improved:
            best_val = float(vll)
            best_epoch = int(epoch)
            if ema is not None:
                ema.apply_shadow(model)
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
                ema.restore(model)
            else:
                best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["patience"]):
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    # final val preds (best)
    p_va = predict_proba(model, dl_va, ema=None)

    pack = {
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": dict(cfg),
        "best_epoch": int(best_epoch + 1),
        "best_val_logloss": float(best_val),
    }
    return pack, p_va

# ----------------------------
# 9) CV evaluator for a config
# ----------------------------
def run_cv_config(cfg, cfg_name, folds_subset=None, beta=0.5, thr_grid=201):
    oof = np.zeros(n, dtype=np.float32)
    fold_rows = []
    fold_packs = []

    use_folds = unique_folds if folds_subset is None else list(folds_subset)

    for f in use_folds:
        tr = np.where(folds_all != f)[0]
        va = np.where(folds_all == f)[0]

        X_tr, y_tr = X_all[tr], y_all[tr]
        X_va, y_va = X_all[va], y_all[va]

        pack, p_va = train_one_fold_transformer(X_tr, y_tr, X_va, y_va, cfg)
        oof[va] = p_va

        fold_auc = safe_auc(y_va, p_va)
        fold_ll  = safe_logloss(y_va, p_va)
        best_fold = best_fbeta_fast(y_va, p_va, beta=beta, grid=max(81, thr_grid//2))

        fold_rows.append({
            "cfg": cfg_name,
            "fold": int(f),
            "n_val": int(len(va)),
            "pos_val": int(y_va.sum()),
            "auc": fold_auc,
            "logloss": fold_ll,
            "best_fbeta": best_fold["fbeta"],
            "best_thr": best_fold["thr"],
            "best_prec": best_fold["precision"],
            "best_rec": best_fold["recall"],
            "best_val_logloss": float(pack["best_val_logloss"]),
            "best_epoch": int(pack["best_epoch"]),
        })

        pack2 = dict(pack)
        pack2["fold"] = int(f)
        fold_packs.append(pack2)

        del pack
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    if folds_subset is None:
        idx_eval = np.arange(n)
    else:
        idx_eval = np.where(np.isin(folds_all, np.array(use_folds)))[0]

    oof_eval = oof[idx_eval]
    y_eval = y_all[idx_eval]

    oof_auc = safe_auc(y_eval, oof_eval)
    oof_ll  = safe_logloss(y_eval, oof_eval)
    best_oof = best_fbeta_fast(y_eval, oof_eval, beta=beta, grid=thr_grid)

    summary = {
        "cfg": cfg_name,
        "stage": "full" if folds_subset is None else f"subset{len(use_folds)}",
        "oof_auc": oof_auc,
        "oof_logloss": oof_ll,
        "oof_best_fbeta": best_oof["fbeta"],
        "oof_best_thr": best_oof["thr"],
        "oof_best_prec": best_oof["precision"],
        "oof_best_rec": best_oof["recall"],

        # main arch
        "d_model": cfg["d_model"],
        "n_layers": cfg["n_layers"],
        "n_heads": cfg["n_heads"],
        "ffn_mult": cfg["ffn_mult"],
        "dropout": cfg["dropout"],
        "attn_dropout": cfg["attn_dropout"],

        # mHC-lite
        "n_streams": cfg["n_streams"],
        "alpha_init": cfg["alpha_init"],
        "sinkhorn_tmax": cfg["sinkhorn_tmax"],
        "mhc_dropout": cfg["mhc_dropout"],

        # reg
        "feat_token_drop_p": cfg.get("feat_token_drop_p", 0.0),
        "input_noise_std": cfg.get("input_noise_std", 0.0),
        "focal_gamma": cfg.get("focal_gamma", 0.0),
        "label_smoothing": cfg.get("label_smoothing", 0.0),

        # train
        "batch_size": cfg["batch_size"],
        "accum_steps": cfg.get("accum_steps", 1),
        "epochs": cfg["epochs"],
        "lr": cfg["lr"],
        "weight_decay": cfg["weight_decay"],
        "warmup_frac": cfg["warmup_frac"],
        "beta1": cfg["beta1"],
        "beta2": cfg["beta2"],
        "adam_eps": cfg["adam_eps"],
        "lr_decay_ratio1": cfg["lr_decay_ratio1"],
        "lr_decay_ratio2": cfg["lr_decay_ratio2"],
        "lr_decay_rate1": cfg["lr_decay_rate1"],
        "lr_decay_rate2": cfg["lr_decay_rate2"],
        "patience": cfg["patience"],
        "min_delta": cfg["min_delta"],
        "grad_clip": cfg["grad_clip"],
        "use_ema": cfg.get("use_ema", True),
        "ema_decay": cfg.get("ema_decay", 0.999),
    }
    return summary, fold_rows, oof, fold_packs

# ----------------------------
# 10) Candidate configs (lebih efektif: kecil tapi tajam)
# ----------------------------
def make_base():
    # heuristic batch/accum
    if device.type == "cuda":
        if MEM_GB >= 30:
            bs, acc = 512, 2
        elif MEM_GB >= 16:
            bs, acc = 384, 2
        else:
            bs, acc = 256, 2
    else:
        bs, acc = 256, 1

    return dict(
        batch_size=bs,
        accum_steps=acc,
        epochs=75 if device.type == "cuda" else 40,
        lr=2e-4,
        weight_decay=1.0e-2,
        warmup_frac=0.10,
        grad_clip=1.0,
        patience=10,
        min_delta=1e-4,

        # optimizer (PDF-ish)
        beta1=0.9,
        beta2=0.95,
        adam_eps=1e-8,

        # schedule (PDF)
        lr_decay_ratio1=0.8,
        lr_decay_ratio2=0.9,
        lr_decay_rate1=0.316,
        lr_decay_rate2=0.1,

        # mHC-lite
        n_streams=4,
        alpha_init=0.01,
        sinkhorn_tmax=20,
        mhc_dropout=0.00,

        # regularization
        feat_token_drop_p=0.05,
        input_noise_std=0.01,
        focal_gamma=1.5,
        label_smoothing=0.00,

        # EMA
        use_ema=True,
        ema_decay=0.999,
    )

BASE = make_base()

candidates = []
candidates.append(("mhc_384x8_main", dict(BASE, d_model=384, n_layers=8,  n_heads=8,  ffn_mult=4, dropout=0.18, attn_dropout=0.10)))
candidates.append(("mhc_384x10_reg", dict(BASE, d_model=384, n_layers=10, n_heads=8,  ffn_mult=4, dropout=0.24, attn_dropout=0.12,
                                         lr=1.6e-4, weight_decay=1.5e-2, patience=12,
                                         feat_token_drop_p=0.06, input_noise_std=0.012)))
candidates.append(("mhc_384x8_ffn2", dict(BASE, d_model=384, n_layers=8,  n_heads=8,  ffn_mult=2, dropout=0.16, attn_dropout=0.10,
                                         lr=2.2e-4, weight_decay=8e-3)))
candidates.append(("mhc_256x6_fast", dict(BASE, d_model=256, n_layers=6,  n_heads=8,  ffn_mult=4, dropout=0.16, attn_dropout=0.08,
                                         lr=3e-4, weight_decay=6e-3, epochs=min(BASE["epochs"], 60), patience=9,
                                         feat_token_drop_p=0.04, input_noise_std=0.010)))

if device.type == "cuda" and MEM_GB >= 20:
    candidates.append(("mhc_512x10_big", dict(BASE, d_model=512, n_layers=10, n_heads=16, ffn_mult=4, dropout=0.26, attn_dropout=0.14,
                                             lr=1.2e-4, weight_decay=2.0e-2, epochs=max(BASE["epochs"], 85), patience=12,
                                             mhc_dropout=0.05, feat_token_drop_p=0.06)))

print(f"\nTotal Transformer candidates: {len(candidates)}")
print("Primary score: OOF best F-beta (beta=0.5)")

# ----------------------------
# 11) Run 2-stage search
# ----------------------------
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OPT_DIR = OUT_DIR / "opt_search"
OPT_DIR.mkdir(parents=True, exist_ok=True)

# stage-1 folds subset (evenly spaced)
if STAGE1_FOLDS >= len(unique_folds):
    folds_subset = unique_folds
else:
    stepk = max(1, len(unique_folds) // STAGE1_FOLDS)
    folds_subset = unique_folds[::stepk][:STAGE1_FOLDS]

print("\nStage-1 folds subset:", folds_subset)

t0 = time.time()
stage1_rows = []
stage1_fold_rows = []

for i, (name, cfg) in enumerate(candidates, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop search.")
        break

    cfg1 = dict(cfg)
    cfg1["epochs"] = int(min(int(cfg1["epochs"]), int(STAGE1_EPOCH_CAP)))
    cfg1["patience"] = int(min(int(cfg1["patience"]), int(STAGE1_PAT_CAP)))

    print(f"\n[Stage-1 {i:02d}/{len(candidates)}] CV(subset) -> {name}")
    summ, fold_rows, _, _ = run_cv_config(cfg1, name, folds_subset=folds_subset, beta=BETA, thr_grid=101)

    stage1_rows.append(summ)
    stage1_fold_rows.extend(fold_rows)

    print(f"  stage1 best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f} | logloss: {summ['oof_logloss']:.6f}")

df_s1 = pd.DataFrame(stage1_rows).sort_values(["oof_best_fbeta","oof_logloss"], ascending=[False, True]).reset_index(drop=True)
print("\nStage-1 ranking (top):")
display(df_s1.head(10))

topM = min(int(STAGE2_TOPM), len(df_s1))
stage2_names = df_s1["cfg"].head(topM).tolist()
print("\nStage-2 will run full CV for:", stage2_names)

all_summaries = []
all_fold_rows = []
oof_store = {}
pack_store = {}

for j, nm in enumerate(stage2_names, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop stage-2.")
        break

    cfg = None
    for (nname, ccfg) in candidates:
        if nname == nm:
            cfg = ccfg
            break
    if cfg is None:
        continue

    print(f"\n[Stage-2 {j:02d}/{len(stage2_names)}] CV(full) -> {nm}")
    summ, fold_rows, oof, fold_packs = run_cv_config(cfg, nm, folds_subset=None, beta=BETA, thr_grid=THR_GRID)

    all_summaries.append(summ)
    all_fold_rows.extend(fold_rows)
    oof_store[nm] = oof
    pack_store[nm] = fold_packs

    print(f"  OOF best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f}"
          f" | auc: {(summ['oof_auc'] if summ['oof_auc'] is not None else float('nan')):.6f}"
          f" | logloss: {summ['oof_logloss']:.6f}")

df_sum = pd.DataFrame(all_summaries)
df_fold = pd.DataFrame(all_fold_rows)

df_sum = df_sum.sort_values(["oof_best_fbeta", "oof_logloss"], ascending=[False, True]).reset_index(drop=True)

print("\nStage-2 top candidates (full CV):")
display(df_sum)

# save search results
df_sum.to_csv(OPT_DIR / "opt_results.csv", index=False)
with open(OPT_DIR / "opt_results.json", "w") as f:
    json.dump(df_sum.to_dict(orient="records"), f, indent=2)
df_fold.to_csv(OPT_DIR / "opt_fold_details.csv", index=False)

# save OOF preds for top configs (debug)
top_names = df_sum["cfg"].head(min(REPORT_TOPK_OOF, len(df_sum))).tolist()
for nm in top_names:
    df_o = pd.DataFrame({
        "uid": uids_all,
        "y": y_all,
        "fold": folds_all,
        f"oof_pred_{nm}": oof_store[nm]
    })
    df_o.to_csv(OPT_DIR / f"oof_preds_{nm}.csv", index=False)

# ----------------------------
# 12) Choose best config + save BEST fold packs (tanpa retrain ulang)
# ----------------------------
if len(df_sum) == 0:
    raise RuntimeError("Stage-2 produced no results. Turunkan kandidat/epochs atau cek device/VRAM.")

best_single = df_sum.iloc[0].to_dict()
best_cfg_name = str(best_single["cfg"])

best_cfg = None
for nm, cfg in candidates:
    if nm == best_cfg_name:
        best_cfg = cfg
        break
if best_cfg is None:
    raise RuntimeError("Best cfg not found in candidates list (unexpected).")

best_fold_packs = pack_store[best_cfg_name]
best_oof = oof_store[best_cfg_name]
best_oof_best = best_fbeta_fast(y_all, best_oof, beta=BETA, grid=THR_GRID)

best_model_path = OUT_DIR / "best_gate_model.pt"
torch.save(
    {
        "type": "mhc_lite_ft_transformer_v3",
        "feature_cols": FEATURE_COLS,
        "fold_packs": best_fold_packs,
        "cfg_name": best_cfg_name,
        "cfg": best_cfg,
        "seed": SEED,
    },
    best_model_path
)

best_bundle = {
    "type": "mhc_lite_ft_transformer_v3",
    "model_name": best_cfg_name,
    "members": [best_cfg_name],
    "random_seed": SEED,
    "beta_for_tuning": BETA,

    "feature_cols": FEATURE_COLS,
    "cfg": best_cfg,

    "oof_best_thr": best_oof_best["thr"],
    "oof_best_fbeta": best_oof_best["fbeta"],
    "oof_best_prec": best_oof_best["precision"],
    "oof_best_rec": best_oof_best["recall"],
    "oof_auc": safe_auc(y_all, best_oof),
    "oof_logloss": safe_logloss(y_all, best_oof),

    "notes": "Best config from Step 4 (Transformer-only, mHC-lite differentiable + EMA + accum + reg). Step 5 should train FULL model or ensemble fold packs for inference.",
}

with open(OUT_DIR / "best_gate_config.json", "w") as f:
    json.dump(best_bundle, f, indent=2)

print("\nSaved best artifacts:")
print("  best model (fold packs) ->", best_model_path)
print("  best config             ->", OUT_DIR / "best_gate_config.json")
print("  opt results             ->", OPT_DIR / "opt_results.csv")
print("  fold detail             ->", OPT_DIR / "opt_fold_details.csv")

# Export globals for Step 5
BEST_GATE_BUNDLE = best_bundle
BEST_TF_CFG_NAME = best_cfg_name
BEST_TF_CFG = best_cfg
OPT_RESULTS_DF = df_sum
BEST_TF_OOF = best_oof
BEST_TF_OOF_METRIC = best_oof_best


# Final Training (Train on Full Data)

In [None]:
# ============================================================
# Step 5 — Final Training (Train on Full Data) — TRANSFORMER ONLY
# REVISI FULL v4 (match Step 4 v3.1: mHC-lite differentiable + EMA + accum)
#
# Upgrade utama vs Step 5 kamu:
# - Arsitektur disamakan dengan Step 4 best: FTTransformer_MHCLite (bukan FTTransformerBig)
# - Sinkhorn differentiable + residual-to-Identity alpha (stabil)
# - EMA eval + best saving (lebih stabil)
# - Grad accumulation + scheduler dihitung per OPTIMIZER STEP (benar)
# - Internal case-level val untuk cari best_epoch -> retrain full data dengan epoch itu
# - OOM fallback otomatis
#
# Output:
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_model.pt
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_bundle.json
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import log_loss, roc_auc_score

# ----------------------------
# 0) REQUIRE
# ----------------------------
if "df_train_tabular" not in globals():
    raise RuntimeError("Missing `df_train_tabular`. Jalankan Step 2 dulu.")
if "FEATURE_COLS" not in globals():
    raise RuntimeError("Missing `FEATURE_COLS`. Jalankan Step 2 dulu.")

df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

need_cols = {"uid", "case_id", "y"}
miss = [c for c in need_cols if c not in df_train_tabular.columns]
if miss:
    raise ValueError(f"df_train_tabular missing columns: {miss}")

X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)

if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0)

OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("Final training data:")
print(f"  rows={len(y_all)} | pos%={float(y_all.mean())*100:.2f} | n_features={X_all.shape[1]}")

# ----------------------------
# 1) Load best cfg (Step 4 output)
# ----------------------------
best_bundle = None
source = None

cfg_path = OUT_DIR / "best_gate_config.json"
best_model_candidates = [
    OUT_DIR / "best_gate_model.pt",
    OUT_DIR / "best_gate_model.pth",
]

if "BEST_GATE_BUNDLE" in globals() and isinstance(BEST_GATE_BUNDLE, dict):
    best_bundle = BEST_GATE_BUNDLE
    source = "memory(BEST_GATE_BUNDLE)"
elif cfg_path.exists():
    best_bundle = json.loads(cfg_path.read_text())
    source = str(cfg_path)

if best_bundle is not None and isinstance(best_bundle, dict) and isinstance(best_bundle.get("cfg", None), dict):
    base_cfg = dict(best_bundle["cfg"])
    print("\nLoaded cfg from:", source)
else:
    base_cfg = {}
    print("\nNo best_gate_config found. Using strong default cfg.")

# optional: load fold_packs dari best_gate_model.pt (biar final file bisa bawa ensemble fold juga)
fold_packs_from_step4 = None
best_gate_model_path = None
for p in best_model_candidates:
    if p.exists():
        best_gate_model_path = p
        break

if best_gate_model_path is not None:
    try:
        obj = torch.load(best_gate_model_path, map_location="cpu")
        if isinstance(obj, dict) and isinstance(obj.get("fold_packs", None), list):
            fold_packs_from_step4 = obj["fold_packs"]
            print("Loaded fold_packs from:", str(best_gate_model_path))
    except Exception as e:
        print("Warning: failed to load best_gate_model.pt fold_packs:", repr(e))

# ----------------------------
# 2) Device + seed
# ----------------------------
def seed_everything(seed: int):
    import random
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.benchmark = True

FINAL_SEED = 2025
if isinstance(best_bundle, dict):
    FINAL_SEED = int(best_bundle.get("seed", best_bundle.get("random_seed", 2025)))

seed_everything(FINAL_SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")

try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

vram_gb = None
if device.type == "cuda":
    vram_gb = float(torch.cuda.get_device_properties(0).total_memory / (1024**3))

print("\nDevice:", device, "| AMP:", use_amp, "| VRAM_GB:", (f"{vram_gb:.1f}" if vram_gb else "CPU"))

# ----------------------------
# 3) Training policy (stabil + kuat)
# ----------------------------
WANT_BIG_MODEL = True
USE_INTERNAL_VAL = True         # cari best_epoch dari val (case-level)
VAL_FRAC_CASE = 0.08            # 8% case untuk val
EARLY_STOP = True

# runtime: 1 seed default
N_SEEDS = 1

# target effective batch
TARGET_EFF_BATCH = 1024 if device.type == "cuda" else 256

# ----------------------------
# 4) Dataset + Standardizer
# ----------------------------
class TabDataset(Dataset):
    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        self.y = None if y is None else torch.from_numpy(y.astype(np.float32))
    def __len__(self): return self.X.shape[0]
    def __getitem__(self, idx):
        if self.y is None:
            return self.X[idx]
        return self.X[idx], self.y[idx]

def fit_standardizer(X_tr: np.ndarray):
    mu = X_tr.mean(axis=0, dtype=np.float64)
    sig = X_tr.std(axis=0, dtype=np.float64)
    sig = np.where(sig < 1e-8, 1.0, sig)
    return mu.astype(np.float32), sig.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    return ((X_in - mu) / sig).astype(np.float32)

# ----------------------------
# 5) Metrics helpers
# ----------------------------
def safe_logloss(y_true, p):
    p = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1 - 1e-8)
    return float(log_loss(y_true, p, labels=[0, 1]))

def safe_auc(y_true, p):
    if len(np.unique(y_true)) < 2:
        return None
    return float(roc_auc_score(y_true, p))

# ----------------------------
# 6) Model: FTTransformer_MHCLite (MATCH Step 4 v3.1)
# ----------------------------
class RMSNorm(nn.Module):
    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = float(eps)
        self.weight = nn.Parameter(torch.ones(d))
    def forward(self, x):
        rms = torch.mean(x * x, dim=-1, keepdim=True)
        x = x * torch.rsqrt(rms + self.eps)
        return x * self.weight

def sinkhorn_knopp(P, tmax=20, eps=1e-6):
    M = P.clamp_min(eps)
    for _ in range(int(tmax)):
        M = M / (M.sum(dim=-1, keepdim=True).clamp_min(eps))  # row
        M = M / (M.sum(dim=-2, keepdim=True).clamp_min(eps))  # col
    return M

class MHCLite(nn.Module):
    def __init__(self, d_model, n_streams=4, alpha_init=0.01, tmax=20, dropout=0.0):
        super().__init__()
        self.n = int(n_streams)
        self.tmax = int(tmax)
        self.drop = nn.Dropout(float(dropout))

        self.norm = RMSNorm(d_model, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, self.n * self.n),
        )
        self.softplus = nn.Softplus()

        a0 = float(alpha_init)
        a0 = min(max(a0, 1e-4), 1.0 - 1e-4)
        self.alpha_logit = nn.Parameter(torch.log(torch.tensor(a0 / (1 - a0), dtype=torch.float32)))

    def forward(self, streams, cls_vec):
        # streams: (B,n,D), cls_vec: (B,D)
        B, n, D = streams.shape
        h = self.norm(cls_vec)
        logits = self.mlp(h).view(B, n, n)
        P = self.softplus(logits)
        M = sinkhorn_knopp(P, tmax=self.tmax, eps=1e-6)

        alpha = torch.sigmoid(self.alpha_logit).to(dtype=streams.dtype, device=streams.device)
        I = torch.eye(n, device=streams.device, dtype=streams.dtype).unsqueeze(0).expand(B, -1, -1)
        H = (1.0 - alpha) * I + alpha * M

        mixed = torch.einsum("bij,bjd->bid", H, streams)
        injected = mixed + cls_vec.unsqueeze(1)
        return self.drop(injected)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        self.norm1 = RMSNorm(d_model, eps=1e-6)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model, num_heads=int(n_heads),
            dropout=float(attn_dropout), batch_first=True
        )
        self.drop1 = nn.Dropout(float(dropout))

        self.norm2 = RMSNorm(d_model, eps=1e-6)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, int(ffn_mult) * d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(int(ffn_mult) * d_model, d_model),
        )
        self.drop2 = nn.Dropout(float(dropout))

    def forward(self, x):
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, need_weights=False)
        x = x + self.drop1(attn_out)
        h = self.norm2(x)
        x = x + self.drop2(self.ffn(h))
        return x

class FTTransformer_MHCLite(nn.Module):
    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4,
                 dropout=0.2, attn_dropout=0.1,
                 n_streams=4, alpha_init=0.01, sinkhorn_tmax=20, mhc_dropout=0.0,
                 feat_token_drop_p=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.n_layers = int(n_layers)
        self.feat_token_drop_p = float(feat_token_drop_p)

        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(float(dropout))

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout
            ) for _ in range(self.n_layers)
        ])

        self.mhc = nn.ModuleList([
            MHCLite(
                d_model=self.d_model,
                n_streams=n_streams,
                alpha_init=alpha_init,
                tmax=sinkhorn_tmax,
                dropout=mhc_dropout
            ) for _ in range(self.n_layers)
        ])

        self.out_norm = RMSNorm(self.d_model, eps=1e-6)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(self.d_model, 1),
        )

    def forward(self, x):
        # x: (B,F)
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)  # (B,F,D)
        tok = tok + self.feat_emb.unsqueeze(0)

        # feature-token drop
        if self.training and self.feat_token_drop_p > 0:
            B, F_, D = tok.shape
            keep = (torch.rand(B, F_, device=tok.device) > self.feat_token_drop_p).to(tok.dtype)
            tok = tok * keep.unsqueeze(-1)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)  # (B,1,D)
        seq = torch.cat([cls, tok], dim=1)  # (B,1+F,D)
        seq = self.in_drop(seq)

        nS = self.mhc[0].n
        streams = seq[:, 0, :].unsqueeze(1).expand(B, nS, self.d_model).contiguous()

        for l, blk in enumerate(self.blocks):
            cls_in = streams.mean(dim=1).unsqueeze(1)  # (B,1,D)
            seq = torch.cat([cls_in, seq[:, 1:, :]], dim=1)
            seq = blk(seq)
            cls_vec = seq[:, 0, :]
            streams = self.mhc[l](streams, cls_vec)

        out = self.out_norm(streams.mean(dim=1))
        logit = self.head(out).squeeze(-1)
        return logit

# ----------------------------
# 7) EMA + scheduler (optimizer-steps)
# ----------------------------
class EMA:
    def __init__(self, model: nn.Module, decay: float = 0.999):
        self.decay = float(decay)
        self.shadow = {}
        self.backup = {}
        for name, p in model.named_parameters():
            if p.requires_grad:
                self.shadow[name] = p.detach().clone()

    @torch.no_grad()
    def update(self, model: nn.Module):
        d = self.decay
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.shadow[name].mul_(d).add_(p.detach(), alpha=(1.0 - d))

    @torch.no_grad()
    def apply_shadow(self, model: nn.Module):
        self.backup = {}
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            self.backup[name] = p.detach().clone()
            p.copy_(self.shadow[name])

    @torch.no_grad()
    def restore(self, model: nn.Module):
        for name, p in model.named_parameters():
            if not p.requires_grad:
                continue
            p.copy_(self.backup[name])
        self.backup = {}

def make_warmup_step_scheduler(optimizer, total_steps, warmup_steps,
                              r1=0.8, r2=0.9, d1=0.316, d2=0.1):
    m1 = int(float(r1) * total_steps)
    m2 = int(float(r2) * total_steps)

    def lr_lambda(step):
        if warmup_steps > 0 and step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        mult = 1.0
        if step >= m1:
            mult *= float(d1)
        if step >= m2:
            mult *= float(d2)
        return mult

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

@torch.no_grad()
def predict_proba(model, loader, ema: EMA = None):
    model.eval()
    if ema is not None:
        ema.apply_shadow(model)
    ps = []
    for batch in loader:
        xb = batch[0] if isinstance(batch, (list, tuple)) else batch
        xb = xb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(xb)
            p = torch.sigmoid(logits)
        ps.append(p.detach().cpu().numpy())
    out = np.concatenate(ps, axis=0).astype(np.float32)
    if ema is not None:
        ema.restore(model)
    return out

# ----------------------------
# 8) CFG merge + autoscale
# ----------------------------
# default fallback cfg (kalau best_bundle tidak ada)
CFG = dict(
    # arch
    d_model=384, n_layers=8, n_heads=8, ffn_mult=4,
    dropout=0.20, attn_dropout=0.10,

    # mHC-lite
    n_streams=4, alpha_init=0.01, sinkhorn_tmax=20, mhc_dropout=0.0,

    # regularization
    feat_token_drop_p=0.05,
    input_noise_std=0.01,
    focal_gamma=1.5,
    label_smoothing=0.00,

    # optim
    lr=2e-4,
    weight_decay=1.0e-2,
    beta1=0.9,
    beta2=0.95,
    adam_eps=1e-8,

    # sched
    warmup_frac=0.10,
    lr_decay_ratio1=0.8,
    lr_decay_ratio2=0.9,
    lr_decay_rate1=0.316,
    lr_decay_rate2=0.1,

    # train
    batch_size=512 if device.type == "cuda" else 256,
    epochs=75 if device.type == "cuda" else 35,
    accum_steps=2 if device.type == "cuda" else 1,
    grad_clip=1.0,
    patience=10,
    min_delta=1e-4,

    # EMA
    use_ema=True,
    ema_decay=0.999,
)

# merge best cfg -> CFG
for k, v in base_cfg.items():
    if k in CFG:
        CFG[k] = v

def autoscale_cfg(cfg: dict):
    cfg = dict(cfg)
    if device.type != "cuda":
        cfg["d_model"] = min(int(cfg["d_model"]), 256)
        cfg["n_layers"] = min(int(cfg["n_layers"]), 6)
        cfg["batch_size"] = min(int(cfg["batch_size"]), 256)
        cfg["epochs"] = min(int(cfg["epochs"]), 35)
        cfg["accum_steps"] = 1
        return cfg

    # CUDA scaling “besar tapi aman”
    if not WANT_BIG_MODEL or vram_gb is None:
        return cfg

    if vram_gb >= 24:
        cfg["d_model"] = max(int(cfg["d_model"]), 512)
        cfg["n_layers"] = max(int(cfg["n_layers"]), 10)
        cfg["n_heads"]  = max(int(cfg["n_heads"]), 16)
        cfg["dropout"] = max(float(cfg["dropout"]), 0.24)
        cfg["attn_dropout"] = max(float(cfg["attn_dropout"]), 0.12)
        cfg["batch_size"] = min(int(cfg["batch_size"]), 512)
        cfg["epochs"] = max(int(cfg["epochs"]), 85)
        cfg["weight_decay"] = max(float(cfg["weight_decay"]), 2e-2)
        cfg["lr"] = min(float(cfg["lr"]), 1.6e-4)
        cfg["accum_steps"] = max(int(cfg.get("accum_steps", 2)), 2)
    elif vram_gb >= 16:
        cfg["d_model"] = max(int(cfg["d_model"]), 448)
        cfg["n_layers"] = max(int(cfg["n_layers"]), 9)
        cfg["dropout"] = max(float(cfg["dropout"]), 0.22)
        cfg["attn_dropout"] = max(float(cfg["attn_dropout"]), 0.12)
        cfg["batch_size"] = min(int(cfg["batch_size"]), 512)
        cfg["epochs"] = max(int(cfg["epochs"]), 75)
        cfg["weight_decay"] = max(float(cfg["weight_decay"]), 1.5e-2)
        cfg["lr"] = min(float(cfg["lr"]), 1.8e-4)
        cfg["accum_steps"] = max(int(cfg.get("accum_steps", 2)), 2)
    else:
        cfg["d_model"] = max(int(cfg["d_model"]), 384)
        cfg["n_layers"] = max(int(cfg["n_layers"]), 8)
        cfg["batch_size"] = min(int(cfg["batch_size"]), 384)
        cfg["epochs"] = max(int(cfg["epochs"]), 70)
        cfg["accum_steps"] = max(int(cfg.get("accum_steps", 2)), 2)
    return cfg

CFG = autoscale_cfg(CFG)

# ensure effective batch near target
CFG["accum_steps"] = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

print("\nCFG (final):")
for k in [
    "d_model","n_layers","n_heads","ffn_mult","dropout","attn_dropout",
    "n_streams","alpha_init","sinkhorn_tmax","mhc_dropout",
    "batch_size","accum_steps","epochs","lr","weight_decay","warmup_frac","patience"
]:
    print(f"  {k}: {CFG[k]}")

# ----------------------------
# 9) Internal val split (case_id-safe)
# ----------------------------
def make_case_split(df: pd.DataFrame, val_frac=0.08, seed=2025):
    g = df.groupby("case_id")["y"].max().reset_index().rename(columns={"y": "case_y"})
    pos_cases = g.loc[g["case_y"] == 1, "case_id"].to_numpy()
    neg_cases = g.loc[g["case_y"] == 0, "case_id"].to_numpy()

    rng = np.random.RandomState(int(seed))
    rng.shuffle(pos_cases)
    rng.shuffle(neg_cases)

    n_val_pos = max(1, int(len(pos_cases) * float(val_frac))) if len(pos_cases) else 0
    n_val_neg = max(1, int(len(neg_cases) * float(val_frac))) if len(neg_cases) else 0

    val_cases = np.concatenate([pos_cases[:n_val_pos], neg_cases[:n_val_neg]])
    val_set = set(map(str, val_cases.tolist()))
    is_val = df["case_id"].astype(str).map(lambda x: str(x) in val_set).to_numpy(dtype=bool)
    return is_val

# ----------------------------
# 10) Train with internal val to get best_epoch
# ----------------------------
def train_with_internal_val_get_best_epoch(X_raw, y_raw, cfg, seed=2025):
    seed_everything(int(seed))

    is_val = make_case_split(df_train_tabular, val_frac=float(VAL_FRAC_CASE), seed=int(seed))
    tr_idx = np.where(~is_val)[0]
    va_idx = np.where(is_val)[0]

    X_tr, y_tr = X_raw[tr_idx], y_raw[tr_idx]
    X_va, y_va = X_raw[va_idx], y_raw[va_idx]

    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    dl_tr = DataLoader(ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
                       num_workers=2, pin_memory=(device.type=="cuda"), drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
                       num_workers=2, pin_memory=(device.type=="cuda"), drop_last=False)

    model = FTTransformer_MHCLite(
        n_features=X_raw.shape[1],
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        alpha_init=float(cfg["alpha_init"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    # pos_weight
    pos = float(y_tr.sum())
    neg = float(len(y_tr) - pos)
    pos_weight = torch.tensor([neg / max(1.0, pos)], device=device, dtype=torch.float32)

    focal_gamma = float(cfg.get("focal_gamma", 0.0))
    label_smoothing = float(cfg.get("label_smoothing", 0.0))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    def loss_fn(logits, targets):
        if label_smoothing and label_smoothing > 0:
            targets = targets * (1.0 - label_smoothing) + 0.5 * label_smoothing
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)
        if focal_gamma and focal_gamma > 0:
            p = torch.sigmoid(logits)
            p_t = p * targets + (1 - p) * (1 - targets)
            mod = (1.0 - p_t).clamp_min(0.0).pow(focal_gamma)
            bce = bce * mod
        return bce.mean()

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )

    accum = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl_tr) / accum))
    total_steps = int(cfg["epochs"]) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    sch = make_warmup_step_scheduler(
        opt, total_steps=total_steps, warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]), r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]), d2=float(cfg["lr_decay_rate2"])
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None

    best_val = 1e18
    best_epoch = -1
    bad = 0

    print(f"\nInternal val split: train={len(tr_idx)} | val={len(va_idx)} | val_pos%={float(y_va.mean())*100:.2f}")

    for epoch in range(int(cfg["epochs"])):
        model.train()
        opt.zero_grad(set_to_none=True)
        loss_sum = 0.0
        n_sum = 0
        micro = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                loss = loss / float(accum)

            scaler.scale(loss).backward()
            micro += 1

            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                if float(cfg.get("grad_clip", 1.0)) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # flush last partial
        if (micro % accum) != 0:
            if float(cfg.get("grad_clip", 1.0)) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        p_va = predict_proba(model, dl_va, ema=ema)
        vll = safe_logloss(y_va, p_va)
        vauc = safe_auc(y_va, p_va)

        improved = (best_val - vll) > float(cfg["min_delta"])
        if improved:
            best_val = float(vll)
            best_epoch = int(epoch) + 1
            bad = 0
        else:
            bad += 1

        print(f"  epoch {epoch+1:03d}/{int(cfg['epochs'])} | tr_loss={loss_sum/max(1,n_sum):.5f} | val_ll={vll:.5f} | val_auc={(vauc if vauc is not None else float('nan')):.5f} | bad={bad}")

        if EARLY_STOP and bad >= int(cfg["patience"]):
            break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # best_epoch fallback
    if best_epoch < 1:
        best_epoch = max(12, int(cfg["epochs"] * 0.6))

    return {
        "best_epoch": int(best_epoch),
        "best_val_logloss": float(best_val) if best_val < 1e18 else None,
    }

# ----------------------------
# 11) Train FULL data for fixed epochs (best_epoch)
# ----------------------------
def train_full_fixed_epochs(X_raw, y_raw, cfg, epochs_fixed, seed=2025):
    seed_everything(int(seed))

    mu, sig = fit_standardizer(X_raw)
    Xn = apply_standardizer(X_raw, mu, sig)

    ds = TabDataset(Xn, y_raw)
    dl = DataLoader(ds, batch_size=int(cfg["batch_size"]), shuffle=True,
                    num_workers=2, pin_memory=(device.type=="cuda"), drop_last=False)

    model = FTTransformer_MHCLite(
        n_features=X_raw.shape[1],
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        alpha_init=float(cfg["alpha_init"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
        feat_token_drop_p=float(cfg.get("feat_token_drop_p", 0.0)),
    ).to(device)

    pos = float(y_raw.sum())
    neg = float(len(y_raw) - pos)
    pos_weight = torch.tensor([neg / max(1.0, pos)], device=device, dtype=torch.float32)

    focal_gamma = float(cfg.get("focal_gamma", 0.0))
    label_smoothing = float(cfg.get("label_smoothing", 0.0))
    input_noise_std = float(cfg.get("input_noise_std", 0.0))

    def loss_fn(logits, targets):
        if label_smoothing and label_smoothing > 0:
            targets = targets * (1.0 - label_smoothing) + 0.5 * label_smoothing
        bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none", pos_weight=pos_weight)
        if focal_gamma and focal_gamma > 0:
            p = torch.sigmoid(logits)
            p_t = p * targets + (1 - p) * (1 - targets)
            mod = (1.0 - p_t).clamp_min(0.0).pow(focal_gamma)
            bce = bce * mod
        return bce.mean()

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )

    accum = max(1, int(cfg.get("accum_steps", 1)))
    steps_per_epoch = int(math.ceil(len(dl) / accum))
    total_steps = int(epochs_fixed) * max(1, steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    sch = make_warmup_step_scheduler(
        opt, total_steps=total_steps, warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]), r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]), d2=float(cfg["lr_decay_rate2"])
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    ema = EMA(model, decay=float(cfg.get("ema_decay", 0.999))) if bool(cfg.get("use_ema", True)) else None

    t0 = time.time()
    for epoch in range(int(epochs_fixed)):
        model.train()
        opt.zero_grad(set_to_none=True)
        loss_sum = 0.0
        n_sum = 0
        micro = 0

        for xb, yb in dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            if input_noise_std and input_noise_std > 0:
                xb = xb + torch.randn_like(xb) * input_noise_std

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                loss = loss / float(accum)

            scaler.scale(loss).backward()
            micro += 1

            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

            if (micro % accum) == 0:
                if float(cfg.get("grad_clip", 1.0)) > 0:
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()
                if ema is not None:
                    ema.update(model)

        # flush last partial
        if (micro % accum) != 0:
            if float(cfg.get("grad_clip", 1.0)) > 0:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(cfg.get("grad_clip", 1.0)))
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)
            sch.step()
            if ema is not None:
                ema.update(model)

        print(f"  full epoch {epoch+1:03d}/{int(epochs_fixed)} | loss={loss_sum/max(1,n_sum):.5f}")

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # save EMA weights if enabled (usually better)
    if ema is not None:
        ema.apply_shadow(model)

    pack = {
        "type": "mhc_lite_ft_transformer_full_v4",
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": dict(cfg),
        "seed": int(seed),
        "train_rows": int(len(y_raw)),
        "pos_rate": float(y_raw.mean()),
        "epochs_fixed": int(epochs_fixed),
        "accum_steps": int(accum),
        "train_time_s": float(time.time() - t0),
        "used_ema_weights": bool(ema is not None),
    }
    return pack

# ----------------------------
# 12) Train final (OOM-safe fallback)
# ----------------------------
final_full_packs = []

for s in range(int(N_SEEDS)):
    print(f"\n[Final Train v4] seed_offset={s}")

    try:
        # Phase A: get best_epoch from internal val
        if USE_INTERNAL_VAL:
            info = train_with_internal_val_get_best_epoch(X_all, y_all, CFG, seed=FINAL_SEED + s)
            best_epoch = int(info["best_epoch"])
            # retrain on full data (slightly extend by 5% for stability)
            E_FULL = int(min(int(CFG["epochs"]), max(12, round(best_epoch * 1.05))))
            print(f"\nBest_epoch(from internal val)={best_epoch} -> Retrain FULL for E_FULL={E_FULL}")
        else:
            E_FULL = int(CFG["epochs"])

        full_pack = train_full_fixed_epochs(X_all, y_all, CFG, epochs_fixed=E_FULL, seed=FINAL_SEED + s)

    except RuntimeError as e:
        msg = str(e).lower()
        if ("out of memory" in msg) and device.type == "cuda":
            print("  OOM detected. Applying fallback (batch_size/width/depth downshift).")
            torch.cuda.empty_cache()

            CFG["batch_size"] = max(128, int(CFG["batch_size"]) // 2)
            CFG["d_model"] = max(256, int(CFG["d_model"]) - 64)
            CFG["n_layers"] = max(6, int(CFG["n_layers"]) - 2)
            CFG["accum_steps"] = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

            print("  New CFG after fallback:")
            for k in ["d_model","n_layers","n_heads","batch_size","accum_steps","epochs"]:
                print(f"    {k}: {CFG[k]}")

            if USE_INTERNAL_VAL:
                info = train_with_internal_val_get_best_epoch(X_all, y_all, CFG, seed=FINAL_SEED + s)
                best_epoch = int(info["best_epoch"])
                E_FULL = int(min(int(CFG["epochs"]), max(12, round(best_epoch * 1.05))))
                print(f"\nBest_epoch={best_epoch} -> Retrain FULL for E_FULL={E_FULL}")
            else:
                E_FULL = int(CFG["epochs"])

            full_pack = train_full_fixed_epochs(X_all, y_all, CFG, epochs_fixed=E_FULL, seed=FINAL_SEED + s)
        else:
            raise

    final_full_packs.append(full_pack)
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# ----------------------------
# 13) Save artifacts
# ----------------------------
final_model_path = OUT_DIR / "final_gate_model.pt"

# threshold recommended from Step 4 (OOF)
best_thr = None
if isinstance(best_bundle, dict):
    best_thr = best_bundle.get("oof_best_thr", None)

torch.save(
    {
        "type": "final_gate_v4",
        "feature_cols": FEATURE_COLS,

        # keep fold ensemble from Step 4 if available (strong option for inference)
        "fold_packs": fold_packs_from_step4,

        # full-data trained packs (list; usually 1 seed)
        "full_packs": final_full_packs,

        "recommended_thr": best_thr,
        "bundle_source": source,
    },
    final_model_path
)

final_bundle = {
    "type": "final_gate_v4",
    "feature_cols": FEATURE_COLS,
    "n_seeds": int(N_SEEDS),
    "seeds": [int(p["seed"]) for p in final_full_packs],
    "cfg": dict(CFG),

    "use_internal_val": bool(USE_INTERNAL_VAL),
    "val_frac_case": float(VAL_FRAC_CASE) if USE_INTERNAL_VAL else 0.0,
    "early_stop": bool(EARLY_STOP) if USE_INTERNAL_VAL else False,

    "train_rows": int(len(y_all)),
    "pos_rate": float(y_all.mean()),

    "recommended_thr": best_thr,
    "has_fold_packs_from_step4": bool(fold_packs_from_step4 is not None),
    "notes": "Final model uses Step4-matched mHC-lite FTTransformer with EMA+accum and full-data retrain using best_epoch from internal val.",
}

final_bundle_path = OUT_DIR / "final_gate_bundle.json"
final_bundle_path.write_text(json.dumps(final_bundle, indent=2))

print("\nSaved final training artifacts:")
print("  model  ->", final_model_path)
print("  bundle ->", final_bundle_path)

# Export globals
FINAL_GATE_MODEL_PT = str(final_model_path)
FINAL_GATE_BUNDLE = final_bundle


# Finalize & Save Model Bundle (Reproducible)

In [None]:
# ============================================================
# Step 6 — Finalize & Save Model Bundle (Reproducible) — REVISI FULL v4 (TRANSFORMER COMPAT)
# - Compatible with Step 5 v4: final_gate_model.pt contains {fold_packs, full_packs, recommended_thr}
# - Bundle artifacts + thresholds + manifest + ZIP portable (no submission)
#
# REQUIRE:
# - Step 2: feature_cols.json (atau feature_cols.json alternative path)
# - Step 5: final_gate_model.pt + final_gate_bundle.json
# ============================================================

import os, json, time, platform, warnings, zipfile, hashlib
from pathlib import Path

warnings.filterwarnings("ignore")

OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Helpers
# ----------------------------
def read_json_safe(p: Path, default=None):
    try:
        return json.loads(Path(p).read_text())
    except Exception:
        return default

def pick_first_existing(paths):
    for p in paths:
        if p is None:
            continue
        p = Path(p)
        if p.exists() and p.is_file():
            return p
    return None

def safe_add(zf: zipfile.ZipFile, p: Path, arcname: str):
    if p is None:
        return
    p = Path(p)
    if p.exists() and p.is_file():
        zf.write(p, arcname=arcname)

def sha256_file(p: Path, chunk=1024 * 1024):
    p = Path(p)
    h = hashlib.sha256()
    with p.open("rb") as f:
        while True:
            b = f.read(chunk)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

def file_meta(p: Path):
    p = Path(p)
    if not p.exists() or not p.is_file():
        return None
    return {
        "path": str(p),
        "name": p.name,
        "bytes": int(p.stat().st_size),
        "sha256": sha256_file(p),
        "mtime_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime(p.stat().st_mtime)),
    }

# ----------------------------
# 0) Locate required artifacts
# ----------------------------
final_model_pt   = OUT_DIR / "final_gate_model.pt"
final_bundle_json = OUT_DIR / "final_gate_bundle.json"

# feature_cols: Step 2 biasanya menyimpan ini
feature_cols_path = pick_first_existing([
    OUT_DIR / "feature_cols.json",
    OUT_DIR / "model_bundle_pack_feature_cols.json",  # (optional legacy)
])

if feature_cols_path is None:
    raise FileNotFoundError("Missing feature_cols.json (jalankan Step 2 dulu).")
if not final_model_pt.exists():
    raise FileNotFoundError(f"Missing final model: {final_model_pt} (jalankan Step 5 dulu).")

model_format = "torch_pt"
final_model_path = final_model_pt

# Optional / extras
baseline_report_path = pick_first_existing([
    OUT_DIR / "baseline_mhc_transformer_cv_report.json",   # Step 3 (yang paling mungkin)
    OUT_DIR / "baseline_mhc_transformer_cv_report_v2.json",
    OUT_DIR / "baseline_transformer_cv_report.json",
    OUT_DIR / "baseline_cv_report.json",
])

best_gate_config_path = pick_first_existing([
    OUT_DIR / "best_gate_config.json",  # Step 4
])

best_gate_model_path = pick_first_existing([
    OUT_DIR / "best_gate_model.pt",     # Step 4
])

opt_results_csv = pick_first_existing([
    OUT_DIR / "opt_search" / "opt_results.csv",
])

opt_fold_csv = pick_first_existing([
    OUT_DIR / "opt_search" / "opt_fold_details.csv",
])

# OOF optional
oof_baseline_csv = pick_first_existing([
    OUT_DIR / "oof_baseline_mhc_transformer.csv",
    OUT_DIR / "oof_baseline_transformer.csv",
    OUT_DIR / "oof_baseline.csv",
])

print("Found artifacts:")
print("  final_model        :", final_model_path, f"(format={model_format})")
print("  final_bundle       :", final_bundle_json if final_bundle_json.exists() else "(missing/skip)")
print("  feature_cols       :", feature_cols_path)
print("  best_gate_config   :", best_gate_config_path if best_gate_config_path else "(missing/skip)")
print("  best_gate_model    :", best_gate_model_path if best_gate_model_path else "(missing/skip)")
print("  baseline_report    :", baseline_report_path if baseline_report_path else "(missing/skip)")
print("  opt_results_csv    :", opt_results_csv if opt_results_csv else "(missing/skip)")
print("  opt_fold_csv       :", opt_fold_csv if opt_fold_csv else "(missing/skip)")
print("  oof_baseline_csv   :", oof_baseline_csv if oof_baseline_csv else "(missing/skip)")

# ----------------------------
# 1) Load metadata
# ----------------------------
feature_cols = read_json_safe(feature_cols_path, default=[])
if not isinstance(feature_cols, list) or len(feature_cols) == 0:
    raise ValueError(f"feature_cols invalid/empty: {feature_cols_path}")

final_bundle = read_json_safe(final_bundle_json, default={}) if final_bundle_json.exists() else {}
baseline_report = read_json_safe(baseline_report_path, default=None) if baseline_report_path else None
best_gate_config = read_json_safe(best_gate_config_path, default=None) if best_gate_config_path else None

# Try read recommended_thr from final_gate_model.pt (Step 5 v4 puts it there)
recommended_thr_from_pt = None
try:
    import torch
    obj = torch.load(final_model_pt, map_location="cpu")
    if isinstance(obj, dict):
        recommended_thr_from_pt = obj.get("recommended_thr", None)
except Exception as e:
    print("Warning: failed to read final_gate_model.pt for recommended_thr:", repr(e))

# ----------------------------
# 2) Thresholds (robust priority)
# Priority:
#   (a) thresholds.json (if exists)
#   (b) final_gate_bundle.json -> recommended_thr
#   (c) final_gate_model.pt -> recommended_thr
#   (d) best_gate_config.json -> oof_best_thr (new) OR selection.oof_best_thr (legacy)
#   (e) fallback 0.5
# ----------------------------
thresholds_path = OUT_DIR / "thresholds.json"

def extract_thr_from_best_gate_config(cfg: dict):
    if not isinstance(cfg, dict):
        return None
    # new style
    if "oof_best_thr" in cfg:
        try:
            return float(cfg["oof_best_thr"])
        except Exception:
            pass
    # legacy: selection.oof_best_thr
    sel = cfg.get("selection", None)
    if isinstance(sel, dict) and ("oof_best_thr" in sel):
        try:
            return float(sel["oof_best_thr"])
        except Exception:
            pass
    return None

if thresholds_path.exists():
    thresholds = read_json_safe(thresholds_path, default={})
    if not isinstance(thresholds, dict) or ("T_gate" not in thresholds):
        thresholds = {}
else:
    thresholds = {}

T_gate = None
# (a) existing thresholds.json
if isinstance(thresholds, dict):
    try:
        if thresholds.get("T_gate", None) is not None:
            T_gate = float(thresholds["T_gate"])
    except Exception:
        T_gate = None

# (b) final_bundle recommended_thr
if T_gate is None and isinstance(final_bundle, dict):
    try:
        if final_bundle.get("recommended_thr", None) is not None:
            T_gate = float(final_bundle["recommended_thr"])
    except Exception:
        T_gate = None

# (c) final_model.pt recommended_thr
if T_gate is None and recommended_thr_from_pt is not None:
    try:
        T_gate = float(recommended_thr_from_pt)
    except Exception:
        T_gate = None

# (d) best_gate_config oof_best_thr
if T_gate is None and isinstance(best_gate_config, dict):
    T_gate = extract_thr_from_best_gate_config(best_gate_config)

# (e) fallback
if T_gate is None:
    T_gate = 0.5

# write thresholds.json in a stable schema
thresholds = {
    "T_gate": float(T_gate),
    "beta_for_tuning": float(best_gate_config.get("beta_for_tuning", 0.5)) if isinstance(best_gate_config, dict) else 0.5,
    "guards": {
        "min_area_frac": None,
        "max_area_frac": None,
        "max_components": None,
    },
    "source_priority": [
        "thresholds.json (existing)",
        "final_gate_bundle.json.recommended_thr",
        "final_gate_model.pt.recommended_thr",
        "best_gate_config.json.oof_best_thr (or selection.oof_best_thr legacy)",
        "fallback 0.5",
    ],
    "notes": "Gate threshold used for binary decision. Update after calibration/OOF tuning if needed.",
}
thresholds_path.write_text(json.dumps(thresholds, indent=2))

# ----------------------------
# 3) Capture dataset/cfg metadata (if available)
# ----------------------------
cfg_meta = {}
if "PATHS" in globals() and isinstance(PATHS, dict):
    cfg_meta = {
        "COMP_ROOT": PATHS.get("COMP_ROOT", None),
        "OUT_DS_ROOT": PATHS.get("OUT_DS_ROOT", None),
        "OUT_ROOT": PATHS.get("OUT_ROOT", None),
        "MATCH_CFG_DIR": PATHS.get("MATCH_CFG_DIR", None),
        "PRED_CFG_DIR": PATHS.get("PRED_CFG_DIR", None),
        "DINO_CFG_DIR": PATHS.get("DINO_CFG_DIR", None),
        "DINO_LARGE_DIR": PATHS.get("DINO_LARGE_DIR", None),
        "PRED_FEAT_TRAIN": PATHS.get("PRED_FEAT_TRAIN", None),
        "MATCH_FEAT_TRAIN": PATHS.get("MATCH_FEAT_TRAIN", None),
        "DF_TRAIN_ALL": PATHS.get("DF_TRAIN_ALL", None),
        "CV_CASE_FOLDS": PATHS.get("CV_CASE_FOLDS", None),
        "IMG_PROFILE_TRAIN": PATHS.get("IMG_PROFILE_TRAIN", None),
    }

# ----------------------------
# 4) Manifest (reproducible)
# ----------------------------
task_str = "Recod.ai/LUC — Gate Model — DINOv2 features + Transformer gate (.pt)"
bundle_version = "v4"

# artifact list for hashing
artifact_paths = [
    final_model_path,
    final_bundle_json if final_bundle_json.exists() else None,
    feature_cols_path,
    thresholds_path,
    baseline_report_path,
    best_gate_config_path,
    best_gate_model_path,
    opt_results_csv,
    opt_fold_csv,
    oof_baseline_csv,
]
artifact_paths = [p for p in artifact_paths if p is not None]

artifacts_meta = {}
for p in artifact_paths:
    m = file_meta(p)
    if m is not None:
        artifacts_meta[p.name] = m

# summaries (robust: handle new/old keys)
opt_summary = None
if isinstance(best_gate_config, dict):
    opt_summary = best_gate_config.get("selection", None)
    if opt_summary is None:
        # new style: keys directly on dict
        opt_summary = {
            "model_name": best_gate_config.get("model_name", None),
            "oof_best_thr": best_gate_config.get("oof_best_thr", None),
            "oof_best_fbeta": best_gate_config.get("oof_best_fbeta", None),
            "oof_auc": best_gate_config.get("oof_auc", None),
            "oof_logloss": best_gate_config.get("oof_logloss", None),
        }

manifest = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "python": platform.python_version(),
    "platform": platform.platform(),
    "bundle_version": bundle_version,
    "task": task_str,
    "model_format": model_format,
    "artifacts_index": artifacts_meta,
    "cfg_meta": cfg_meta,
    "model_summary": {
        "type": (final_bundle.get("type") if isinstance(final_bundle, dict) else None),
        "n_seeds": (final_bundle.get("n_seeds") if isinstance(final_bundle, dict) else None),
        "seeds": (final_bundle.get("seeds") if isinstance(final_bundle, dict) else None),
        "train_rows": (final_bundle.get("train_rows") if isinstance(final_bundle, dict) else None),
        "pos_rate": (final_bundle.get("pos_rate") if isinstance(final_bundle, dict) else None),
        "feature_count": int(len(feature_cols)),
        "T_gate": float(thresholds.get("T_gate", 0.5)),
        "recommended_thr_from_pt": recommended_thr_from_pt,
    },
    "baseline_summary": (baseline_report.get("overall") if isinstance(baseline_report, dict) else None),
    "opt_summary": opt_summary,
}

manifest_path = OUT_DIR / "model_bundle_manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))

# ----------------------------
# 5) Bundle pack (portable JSON) + optional joblib
# ----------------------------
bundle_pack = {
    "bundle_version": bundle_version,
    "model_format": model_format,
    "final_model_path": str(final_model_path),
    "final_bundle": final_bundle,
    "feature_cols": feature_cols,
    "thresholds": thresholds,
    "cfg_meta": cfg_meta,
    "manifest": manifest,
}

bundle_pack_json = OUT_DIR / "model_bundle_pack.json"
bundle_pack_json.write_text(json.dumps(bundle_pack, indent=2))

bundle_pack_joblib = OUT_DIR / "model_bundle_pack.joblib"
joblib_ok = False
try:
    import joblib
    joblib.dump(bundle_pack, bundle_pack_joblib)
    joblib_ok = True
except Exception:
    joblib_ok = False

# ----------------------------
# 6) Create portable ZIP
# ----------------------------
zip_path = OUT_DIR / "model_bundle_v4.zip"

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    # core
    safe_add(zf, final_model_path, final_model_path.name)
    safe_add(zf, final_bundle_json if final_bundle_json.exists() else None, final_bundle_json.name)
    safe_add(zf, feature_cols_path, feature_cols_path.name)
    safe_add(zf, thresholds_path, thresholds_path.name)
    safe_add(zf, manifest_path, manifest_path.name)
    safe_add(zf, bundle_pack_json, bundle_pack_json.name)
    if joblib_ok:
        safe_add(zf, bundle_pack_joblib, bundle_pack_joblib.name)

    # extras
    safe_add(zf, baseline_report_path, baseline_report_path.name if baseline_report_path else "baseline_report.json")
    safe_add(zf, best_gate_config_path, best_gate_config_path.name if best_gate_config_path else "best_gate_config.json")
    safe_add(zf, best_gate_model_path, best_gate_model_path.name if best_gate_model_path else "best_gate_model.pt")

    if opt_results_csv:
        safe_add(zf, opt_results_csv, f"opt_search/{opt_results_csv.name}")
    if opt_fold_csv:
        safe_add(zf, opt_fold_csv, f"opt_search/{opt_fold_csv.name}")

    if oof_baseline_csv:
        safe_add(zf, oof_baseline_csv, f"oof/{oof_baseline_csv.name}")

print("\nOK — Model bundle finalized")
print("  manifest     ->", manifest_path)
print("  pack (json)  ->", bundle_pack_json)
print("  pack (joblib)->", (bundle_pack_joblib if joblib_ok else "(skip; joblib not available)"))
print("  thresholds   ->", thresholds_path)
print("  zip          ->", zip_path)

print("\nBundle summary:")
print("  bundle_version:", bundle_version)
print("  model_format  :", model_format)
print("  feature_cnt   :", len(feature_cols))
print("  T_gate        :", thresholds.get("T_gate"))
print("  task          :", task_str)
