# Set Paths & Select Config (CFG)

In [1]:
# ============================================================
# STAGE 0 — Set Paths & Select Config (CFG) (Kaggle-ready, offline) — REVISI FULL (anti-error)
# - English step name, Indonesian explanations.
#
# Tujuan:
# - Deteksi root kompetisi (COMP_ROOT)
# - Deteksi root output dataset hasil PREP (OUT_DS_ROOT) + OUT_ROOT (= .../recodai_luc)
# - Auto-pilih CFG terbaik untuk MATCH + PRED (berdasarkan coverage rows features train)
# - Deteksi CFG DINO cache (opsional) + simpan path model DINOv2-LARGE offline
#
# Output globals (dipakai step berikutnya, JANGAN diganti namanya):
# - COMP_ROOT, OUT_DS_ROOT, OUT_ROOT
# - PATHS (dict jalur penting)
# - MATCH_CFG_DIR, PRED_CFG_DIR, DINO_CFG_DIR
# ============================================================

import os, re, json
from pathlib import Path
import pandas as pd

# ----------------------------
# Helper: find competition root
# ----------------------------
def find_comp_root(preferred: str = "/kaggle/input/recodai-luc-scientific-image-forgery-detection") -> Path:
    """Locate the competition root directory.

    The ``preferred`` path wins whenever it exists. Otherwise the direct
    children of ``/kaggle/input`` are inspected, and only if none of them
    qualifies are their sub-directories (one level deeper) inspected too.
    A directory qualifies when it holds ``sample_submission.csv`` together
    with a ``train_images`` or ``test_images`` entry.

    Args:
        preferred: Path that is returned as-is when it exists.

    Returns:
        The chosen competition root.

    Raises:
        FileNotFoundError: ``/kaggle/input`` is missing, or no directory
            matches the heuristic.
    """

    def _looks_like_comp(folder: Path) -> bool:
        # Competition heuristic: submission template plus an image folder.
        if not (folder / "sample_submission.csv").exists():
            return False
        return (folder / "train_images").exists() or (folder / "test_images").exists()

    preferred_path = Path(preferred)
    if preferred_path.exists():
        return preferred_path

    input_root = Path("/kaggle/input")
    if not input_root.exists():
        raise FileNotFoundError("/kaggle/input tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    top_level = [entry for entry in input_root.iterdir() if entry.is_dir()]
    matches = [entry for entry in top_level if _looks_like_comp(entry)]

    if not matches:
        # Fallback: the competition may sit one level below a wrapper folder.
        matches = [
            sub
            for entry in top_level
            for sub in entry.iterdir()
            if sub.is_dir() and _looks_like_comp(sub)
        ]

    if not matches:
        raise FileNotFoundError(
            "COMP_ROOT tidak ditemukan. Harus ada folder yang memuat sample_submission.csv dan train_images/test_images."
        )

    # Names containing the keywords rank first; ties are broken by name.
    def _rank(folder: Path):
        lowered = folder.name.lower()
        return ("recodai" not in lowered, "forgery" not in lowered, folder.name)

    return min(matches, key=_rank)

# ----------------------------
# Helper: find output dataset root (hasil PREP)
# ----------------------------
def find_output_dataset_root(preferred_names=(
    "recod-ailuc-dinov2-base",
    "recod-ai-luc-dinov2-base",
    "recodai-luc-dinov2-base",
    "recodai-luc-dinov2",
    "recodai-luc-dinov2-prep",
), base: str = "/kaggle/input") -> Path:
    """Locate the dataset folder that holds the PREP outputs.

    Resolution order:
      1. the first entry of ``preferred_names`` that exists under ``base``;
      2. otherwise any direct child of ``base`` containing
         ``recodai_luc/artifacts`` either directly or one level deeper;
         names containing ``dinov2`` are preferred, then alphabetical order.

    Args:
        preferred_names: Dataset folder names tried first, in order.
        base: Directory that contains the attached datasets. Defaults to
            the Kaggle input root, so existing callers are unaffected.

    Returns:
        Path of the selected dataset folder.

    Raises:
        FileNotFoundError: ``base`` does not exist, or no folder qualifies.
    """
    base_dir = Path(base)
    # Same explicit guard as find_comp_root: without it a missing input root
    # surfaces as a raw OS error raised by iterdir() further down.
    if not base_dir.exists():
        raise FileNotFoundError(f"{base_dir} tidak ditemukan (pastikan kamu di Kaggle Notebook).")

    # Direct hit on one of the well-known dataset names.
    for nm in preferred_names:
        p = base_dir / nm
        if p.exists():
            return p

    # Scan for any dataset that carries recodai_luc/artifacts.
    cands = [
        d
        for d in base_dir.iterdir()
        if d.is_dir()
        and (
            (d / "recodai_luc" / "artifacts").exists()
            or any(d.glob("*/recodai_luc/artifacts"))
        )
    ]

    if not cands:
        raise FileNotFoundError(
            "OUT_DS_ROOT tidak ditemukan. Harus ada /kaggle/input/<...>/recodai_luc/artifacts/"
        )

    # Prefer names that mention 'dinov2'; alphabetical order breaks ties.
    return min(cands, key=lambda x: ("dinov2" not in x.name.lower(), x.name))

# ----------------------------
# Helper: resolve OUT_ROOT = <dataset>/recodai_luc
# ----------------------------
def resolve_out_root(out_ds_root: Path) -> Path:
    """Return the ``recodai_luc`` folder that belongs to a dataset root.

    The folder is looked up directly under ``out_ds_root`` first and then one
    level deeper. When several nested folders match, the alphabetically first
    one is returned, so the result does not depend on the order in which the
    filesystem happens to list directory entries.

    Args:
        out_ds_root: Dataset root directory.

    Returns:
        Path of the ``recodai_luc`` folder.

    Raises:
        FileNotFoundError: No ``recodai_luc`` entry exists at either level.
    """
    root = Path(out_ds_root)
    direct = root / "recodai_luc"
    if direct.exists():
        return direct
    # glob() yields entries in arbitrary order; sort to make the pick stable.
    hits = sorted(root.glob("*/recodai_luc"))
    if hits:
        return hits[0]
    raise FileNotFoundError(f"Folder recodai_luc tidak ditemukan di bawah {out_ds_root}")

# ----------------------------
# Helper: pick best cfg directory by train feature coverage (row count)
# ----------------------------
def _fast_count_rows_csv(path: Path) -> int:
    """Count the data rows of a CSV file without parsing it.

    Physical lines are counted and one is subtracted for the header; the
    result is never negative. Undecodable bytes are ignored while reading.

    Args:
        path: CSV file to inspect.

    Returns:
        Number of lines after the header, or ``-1`` when the file cannot be
        read at all.
    """
    try:
        with path.open("r", encoding="utf-8", errors="ignore") as handle:
            line_total = 0
            for line_total, _ in enumerate(handle, start=1):
                pass
    except Exception:
        return -1
    # Subtract the header line, clamping empty files to zero.
    return int(max(line_total - 1, 0))

def pick_best_cfg(cache_root: Path, prefix: str, feat_train_filename: str) -> Path:
    """Pick the config folder whose train feature file has the most rows.

    Args:
        cache_root: The ``.../recodai_luc/cache`` directory.
        prefix: Required start of the folder name, e.g. ``'match_base_cfg_'``
            or ``'pred_base'``.
        feat_train_filename: File that must exist inside a candidate folder,
            e.g. ``'match_features_train_all.csv'``.

    Returns:
        The candidate directory with the highest row count as reported by
        ``_fast_count_rows_csv``; equal counts are resolved by folder name.

    Raises:
        FileNotFoundError: ``cache_root`` is missing or no folder qualifies.
    """
    if not cache_root.exists():
        raise FileNotFoundError(f"cache_root tidak ditemukan: {cache_root}")

    scored = [
        (_fast_count_rows_csv(cfg_dir / feat_train_filename), cfg_dir)
        for cfg_dir in cache_root.iterdir()
        if cfg_dir.is_dir()
        and cfg_dir.name.startswith(prefix)
        and (cfg_dir / feat_train_filename).exists()
    ]

    if not scored:
        raise FileNotFoundError(
            f"Tidak ada CFG folder di {cache_root} dengan prefix='{prefix}' dan file '{feat_train_filename}'."
        )

    # Largest coverage wins; the folder name is the tie-breaker.
    _, winner = min(scored, key=lambda item: (-item[0], item[1].name))
    return winner

# ----------------------------
# 0) Locate roots
# ----------------------------
COMP_ROOT = find_comp_root("/kaggle/input/recodai-luc-scientific-image-forgery-detection")
OUT_DS_ROOT = find_output_dataset_root()
OUT_ROOT = resolve_out_root(OUT_DS_ROOT)  # .../recodai_luc

ART_DIR, CACHE_DIR = OUT_ROOT / "artifacts", OUT_ROOT / "cache"

# Both sub-folders are mandatory; artifacts is checked before cache.
for _label, _folder in (("ART_DIR", ART_DIR), ("CACHE_DIR", CACHE_DIR)):
    if not _folder.exists():
        raise FileNotFoundError(f"{_label} tidak ditemukan: {_folder}")

# ----------------------------
# 1) Competition paths (raw images/masks)
# ----------------------------
# Every value is stored as a plain string; key order follows the sections below.
PATHS = {
    "COMP_ROOT": str(COMP_ROOT),
    "SAMPLE_SUB": str(COMP_ROOT / "sample_submission.csv"),
}

# Raw competition folders (images / masks).
for _key, _folder_name in (
    ("TRAIN_IMAGES", "train_images"),
    ("TEST_IMAGES", "test_images"),
    ("TRAIN_MASKS", "train_masks"),
    ("SUPP_IMAGES", "supplemental_images"),
    ("SUPP_MASKS", "supplemental_masks"),
):
    PATHS[_key] = str(COMP_ROOT / _folder_name)

# Optional: used when train_images is split into authentic/forged.
for _key, _sub_name in (("TRAIN_AUTH_DIR", "authentic"), ("TRAIN_FORG_DIR", "forged")):
    PATHS[_key] = str(COMP_ROOT / "train_images" / _sub_name)

# ----------------------------
# 2) Output dataset paths (clean artifacts + cache)
# ----------------------------
PATHS.update({
    "OUT_DS_ROOT": str(OUT_DS_ROOT),
    "OUT_ROOT": str(OUT_ROOT),
    "ART_DIR": str(ART_DIR),
    "CACHE_DIR": str(CACHE_DIR),
})

# Main artifacts, all located directly inside ART_DIR.
for _key, _file_name in (
    ("DF_TRAIN_ALL", "df_train_all.parquet"),
    ("DF_TRAIN_CLS", "df_train_cls.parquet"),
    ("DF_TRAIN_SEG", "df_train_seg.parquet"),
    ("DF_TEST", "df_test.parquet"),
    ("CV_CASE_FOLDS", "cv_case_folds.csv"),
    ("CV_SAMPLE_FOLDS", "cv_sample_folds.csv"),
    ("IMG_PROFILE_TRAIN", "image_profile_train.parquet"),
    ("IMG_PROFILE_TEST", "image_profile_test.parquet"),
    ("MASK_PROFILE", "mask_profile.parquet"),
    ("CASE_SUMMARY", "case_summary.parquet"),
):
    PATHS[_key] = str(ART_DIR / _file_name)

# ----------------------------
# 3) Select best MATCH/PRED CFG dirs automatically
# ----------------------------
# MATCH: match_base_cfg_<hash>/match_features_train_all.csv
_MATCH_TRAIN_FILE = "match_features_train_all.csv"
_PRED_TRAIN_FILE = "pred_features_train_all.csv"

# MATCH: match_base_cfg_<hash>/match_features_train_all.csv
MATCH_CFG_DIR = pick_best_cfg(CACHE_DIR, prefix="match_base_cfg_", feat_train_filename=_MATCH_TRAIN_FILE)

# PRED: pred_base.../pred_features_train_all.csv
PRED_CFG_DIR = pick_best_cfg(CACHE_DIR, prefix="pred_base", feat_train_filename=_PRED_TRAIN_FILE)

# DINO cache cfg (optional): cache/dino_v2_large/cfg_*/manifest_train_all.csv
# pick_best_cfg applies the same filter and ranking (most manifest rows, then
# name); its FileNotFoundError covers both "no folder" and "no candidate".
dino_root = CACHE_DIR / "dino_v2_large"
try:
    DINO_CFG_DIR = pick_best_cfg(dino_root, prefix="cfg_", feat_train_filename="manifest_train_all.csv")
except FileNotFoundError:
    DINO_CFG_DIR = None

# Store the chosen cfg folders and the feature files derived from them.
PATHS.update({
    "MATCH_CFG_DIR": str(MATCH_CFG_DIR),
    "PRED_CFG_DIR": str(PRED_CFG_DIR),
    "DINO_CFG_DIR": str(DINO_CFG_DIR) if DINO_CFG_DIR else "",
    "MATCH_FEAT_TRAIN": str(MATCH_CFG_DIR / _MATCH_TRAIN_FILE),
    "MATCH_FEAT_TEST": str(MATCH_CFG_DIR / "match_features_test.csv"),
    "PRED_FEAT_TRAIN": str(PRED_CFG_DIR / _PRED_TRAIN_FILE),
    "PRED_FEAT_TEST": str(PRED_CFG_DIR / "pred_features_test.csv"),
})

# ----------------------------
# 4) DINOv2-LARGE model path (offline) — dipakai step training/infer berikutnya
# ----------------------------
DINO_LARGE_DIR = Path("/kaggle/input/dinov2/pytorch/large/1")
PATHS["DINO_LARGE_DIR"] = str(DINO_LARGE_DIR)

# ----------------------------
# 5) Sanity checks (these files must exist)
# ----------------------------
must_exist = [
    ("sample_submission.csv", PATHS["SAMPLE_SUB"]),
    ("df_train_all.parquet", PATHS["DF_TRAIN_ALL"]),
    ("cv_case_folds.csv", PATHS["CV_CASE_FOLDS"]),
    ("match_features_train_all.csv", PATHS["MATCH_FEAT_TRAIN"]),
    ("pred_features_train_all.csv", PATHS["PRED_FEAT_TRAIN"]),
]
missing = []
for _name, _location in must_exist:
    if not Path(_location).exists():
        missing.append(_name)
if missing:
    raise FileNotFoundError("Missing required files: " + ", ".join(missing))

# The DINO model dir is optional at this stage, so only warn (no hard failure).
if not DINO_LARGE_DIR.exists():
    print(f"WARNING: DINOv2-Large dir tidak ditemukan: {DINO_LARGE_DIR} (kalau butuh backbone, pastikan input ada)")

# ----------------------------
# 6) Print summary (same layout as the other steps)
# ----------------------------
print("OK — Roots")
for _tag, _root in (
    ("COMP_ROOT   ", COMP_ROOT),
    ("OUT_DS_ROOT ", OUT_DS_ROOT),
    ("OUT_ROOT    ", OUT_ROOT),
):
    print(f"  {_tag}:", _root)

print("\nOK — Selected CFG")
_dino_cfg_label = DINO_CFG_DIR.name if DINO_CFG_DIR else "(not found / optional)"
print("  MATCH_CFG_DIR:", MATCH_CFG_DIR.name)
print("  PRED_CFG_DIR :", PRED_CFG_DIR.name)
print("  DINO_CFG_DIR :", _dino_cfg_label)

print("\nOK — Key files")
for k in ["DF_TRAIN_ALL", "CV_CASE_FOLDS", "MATCH_FEAT_TRAIN", "PRED_FEAT_TRAIN", "IMG_PROFILE_TRAIN"]:
    p = Path(PATHS[k])
    _status = "(exists)" if p.exists() else "(missing/optional)"
    print("  " + k.ljust(16) + f": {p}  {_status}")

print("\nOK — DINO model dir")
_dino_status = "(exists)" if DINO_LARGE_DIR.exists() else "(missing)"
print("  DINO_LARGE_DIR:", DINO_LARGE_DIR, _dino_status)


OK — Roots
  COMP_ROOT   : /kaggle/input/recodai-luc-scientific-image-forgery-detection
  OUT_DS_ROOT : /kaggle/input/recod-ailuc-dinov2-base
  OUT_ROOT    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc

OK — Selected CFG
  MATCH_CFG_DIR: match_base_cfg_f9f7ea3a65c5
  PRED_CFG_DIR : pred_base_v3_v7_cfg_5dbf0aa165
  DINO_CFG_DIR : cfg_3246fd54aab0

OK — Key files
  DF_TRAIN_ALL    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet  (exists)
  CV_CASE_FOLDS   : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv  (exists)
  MATCH_FEAT_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv  (exists)
  PRED_FEAT_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv  (exists)
  IMG_PROFILE_TRAIN: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet  (exists)

OK 

# Build Training Table (X, y, folds)

In [2]:
# ============================================================
# STEP 2 — Build Training Table (X, y, folds) — REVISI FULL (Transformer-ready, robust)
# - Fokus: siapkan df_train_tabular + FEATURE_COLS
# - Sumber utama: pred_features + (opsional) match_features + (opsional) image_profile
# - Split: gunakan cv_case_folds.csv (anti leakage, by case_id)
# - Tidak ada submission di sini
#
# Output globals:
# - df_train_tabular, FEATURE_COLS
# - (opsional) X_train, y_train, folds (pandas series/df)
#
# Saved:
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - /kaggle/working/recodai_luc_gate_artifacts/feature_schema.json
# - /kaggle/working/recodai_luc_gate_artifacts/df_train_tabular.parquet
# ============================================================

import os, json, math, gc, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require PATHS
# ----------------------------
if "PATHS" not in globals() or not isinstance(PATHS, dict):
    raise RuntimeError("Missing PATHS. Jalankan dulu STAGE 0 — Set Paths & Select Config (CFG).")

# ----------------------------
# 1) Check DINOv2 Large local path (offline) (existence check only)
# ----------------------------
DINO_LARGE_DIR = Path(PATHS.get("DINO_LARGE_DIR", "/kaggle/input/dinov2/pytorch/large/1"))
# STAGE 0 treats this directory as optional and only warns when it is absent.
# This step builds a tabular table and never loads the backbone, so it follows
# the same policy instead of aborting. print() is used because warnings are
# globally silenced at the top of this cell.
if not DINO_LARGE_DIR.exists():
    print(f"WARNING: DINOv2-Large path not found: {DINO_LARGE_DIR} (only needed by steps that load the backbone)")
PATHS["DINO_LARGE_DIR"] = str(DINO_LARGE_DIR)

# ----------------------------
# 2) Feature Engineering Config (fleksibel)
# ----------------------------
# Feature-engineering switches read by the rest of this step.
FE_CFG = {
    # optional feature sources merged into the table when available
    "use_match_features": True,
    "use_image_profile": True,

    # derived features and pruning
    "add_log_features": True,
    "add_interactions": True,
    "drop_constant_features": True,

    # outlier control
    "clip_by_quantile": True,
    "clip_q": 0.999,              # p99.9 cap
    "clip_max_fallback": 1e9,      # fallback cap when the quantile cannot be used

    # fill
    "fillna_value": 0.0,

    # dtype
    "cast_float32": True,
}

# ----------------------------
# 3) Prefer WORKING features if exist (kalau kamu regen di /kaggle/working)
# ----------------------------
def prefer_working(input_path: str, working_candidate: str | None = None) -> Path:
    p_in = Path(input_path)
    if working_candidate is not None:
        p_w = Path(working_candidate)
        if p_w.exists():
            return p_w
    return p_in

match_cfg_name = Path(PATHS["MATCH_CFG_DIR"]).name if PATHS.get("MATCH_CFG_DIR") else ""
pred_cfg_name  = Path(PATHS["PRED_CFG_DIR"]).name  if PATHS.get("PRED_CFG_DIR") else ""

# Regenerated features, if any, live under the same cfg folder name in /kaggle/working.
WORK_ROOT = Path("/kaggle/working/recodai_luc/cache")
match_feat_work = WORK_ROOT / match_cfg_name / "match_features_train_all.csv"
pred_feat_work  = WORK_ROOT / pred_cfg_name  / "pred_features_train_all.csv"

PRED_FEAT_TRAIN  = prefer_working(PATHS["PRED_FEAT_TRAIN"],  str(pred_feat_work))
MATCH_FEAT_TRAIN = prefer_working(PATHS["MATCH_FEAT_TRAIN"], str(match_feat_work))

DF_TRAIN_ALL      = Path(PATHS["DF_TRAIN_ALL"])
CV_CASE_FOLDS     = Path(PATHS["CV_CASE_FOLDS"])
IMG_PROFILE_TRAIN = Path(PATHS.get("IMG_PROFILE_TRAIN", ""))

# Optional inputs are tested with is_file(): Path("") normalises to "." and
# "." always exists, so exists() would report a missing key as present.
has_match_feat  = MATCH_FEAT_TRAIN.is_file()
has_img_profile = IMG_PROFILE_TRAIN.is_file()

for need_name, need_path in [
    ("df_train_all.parquet", DF_TRAIN_ALL),
    ("cv_case_folds.csv", CV_CASE_FOLDS),
    ("pred_features_train_all.csv", PRED_FEAT_TRAIN),
]:
    if not need_path.exists():
        raise FileNotFoundError(f"Missing required file: {need_name} -> {need_path}")

print("Using:")
print("  DF_TRAIN_ALL      :", DF_TRAIN_ALL)
print("  CV_CASE_FOLDS     :", CV_CASE_FOLDS)
print("  PRED_FEAT_TRAIN   :", PRED_FEAT_TRAIN)
print("  MATCH_FEAT_TRAIN  :", MATCH_FEAT_TRAIN, "(optional)" if has_match_feat else "(missing/skip)")
print("  IMG_PROFILE_TRAIN :", IMG_PROFILE_TRAIN, "(optional)" if has_img_profile else "(missing/skip)")
print("  DINO_LARGE_DIR    :", DINO_LARGE_DIR)

# ----------------------------
# 4) Load minimal inputs
# ----------------------------
df_base = pd.read_parquet(DF_TRAIN_ALL)
df_cv   = pd.read_csv(CV_CASE_FOLDS)
df_pred = pd.read_csv(PRED_FEAT_TRAIN)

# Optional sources stay best-effort, but a failed read is reported instead of
# being dropped silently (it changes the resulting feature set). print() is
# used because warnings are globally silenced in this cell.
df_match = None
if FE_CFG["use_match_features"] and has_match_feat:
    try:
        df_match = pd.read_csv(MATCH_FEAT_TRAIN)
    except Exception as exc:
        df_match = None
        print(f"WARNING: could not read match features ({MATCH_FEAT_TRAIN}): {exc!r} -> skipped")

df_prof = None
if FE_CFG["use_image_profile"] and has_img_profile:
    try:
        df_prof = pd.read_parquet(IMG_PROFILE_TRAIN)
    except Exception as exc:
        df_prof = None
        print(f"WARNING: could not read image profile ({IMG_PROFILE_TRAIN}): {exc!r} -> skipped")

# ----------------------------
# 5) Normalize keys: uid/sample_id, case_id, variant
# ----------------------------
def ensure_uid_case_variant(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` with normalised ``uid``, ``case_id`` and ``variant``.

    * ``uid``: taken from ``uid`` or, if absent, the first of ``sample_id``,
      ``id``, ``key`` that is present (that column is renamed).
    * ``case_id``: kept when present, otherwise parsed from the leading
      digits of ``uid``.
    * ``variant``: kept when present, otherwise the text after the first
      ``__`` in ``uid``, else after the first ``_``, else ``"unk"``.

    Args:
        df: Input frame; it is not modified.

    Returns:
        A new frame where ``uid`` and ``variant`` are strings and ``case_id``
        is an integer column.

    Raises:
        ValueError: No uid-like column exists, or ``case_id`` is missing for
            at least one row.
    """
    df = df.copy()

    if "uid" not in df.columns:
        alt = next((name for name in ("sample_id", "id", "key") if name in df.columns), None)
        if alt is None:
            raise ValueError("Cannot find uid/sample_id column. Expected 'uid' or 'sample_id'.")
        df = df.rename(columns={alt: "uid"})

    if ("case_id" not in df.columns) or ("variant" not in df.columns):
        uid = df["uid"].astype(str)
        if "case_id" not in df.columns:
            digits = uid.str.extract(r"^(\d+)")[0]
            # Parse through to_numeric: casting the extracted strings directly
            # to "Int64" is not accepted by every pandas version.
            df["case_id"] = pd.to_numeric(digits, errors="coerce").astype("Int64")
        if "variant" not in df.columns:
            v  = uid.str.extract(r"__(\w+)$")[0]
            v2 = uid.str.extract(r"_(\w+)$")[0]
            df["variant"] = v.fillna(v2).fillna("unk")

    # Guard: rows whose case_id could not be determined.
    if df["case_id"].isna().any():
        raise ValueError("Failed to parse case_id from uid for some rows. Check uid format.")

    df["case_id"] = df["case_id"].astype(int)
    df["variant"] = df["variant"].astype(str)
    df["uid"]     = df["uid"].astype(str)
    return df

df_pred = ensure_uid_case_variant(df_pred)

# Give the base table a uid column when one can be derived.
df_base2 = df_base.copy()
_base_cols = set(df_base2.columns)
if "uid" in _base_cols:
    pass
elif "sample_id" in _base_cols:
    df_base2 = df_base2.rename(columns={"sample_id": "uid"})
elif {"case_id", "variant"} <= _base_cols:
    df_base2["uid"] = df_base2["case_id"].astype(str) + "__" + df_base2["variant"].astype(str)

# Label detection: first known label name in the base table, else y_forged
# from the pred features.
label_col = next(
    (cand for cand in ("y_forged", "has_mask", "is_forged", "forged") if cand in df_base2.columns),
    None,
)
if label_col is None:
    if "y_forged" not in df_pred.columns:
        raise ValueError("Cannot find label column in df_train_all/pred_features (y_forged/has_mask/is_forged/forged).")
    label_col = "y_forged"

# Fold table sanity: both columns are required and stored as integers.
if not {"case_id", "fold"}.issubset(df_cv.columns):
    raise ValueError("cv_case_folds.csv must contain columns: case_id, fold")
df_cv = df_cv.astype({"case_id": int, "fold": int})

# ----------------------------
# 6) Merge: start from df_pred (1 row per uid)
# ----------------------------
df_train = df_pred.copy()

# Attach the label. Right-hand tables are de-duplicated and every left merge is
# validated as many-to-one, so conflicting duplicate keys raise instead of
# silently multiplying training rows.
if "y_forged" in df_train.columns:
    df_train["y"] = pd.to_numeric(df_train["y_forged"], errors="coerce")
else:
    # A pre-existing "y" would come out of the merge as y_x / y_y.
    df_train = df_train.drop(columns=["y"], errors="ignore")
    if "uid" in df_base2.columns:
        label_keys = ["uid"]
    elif {"case_id", "variant", label_col}.issubset(df_base2.columns):
        label_keys = ["case_id", "variant"]
    else:
        raise ValueError("Could not merge label from df_train_all (missing uid or case_id+variant).")
    df_labels = (
        df_base2[label_keys + [label_col]]
        .rename(columns={label_col: "y"})
        .drop_duplicates()
    )
    df_train = df_train.merge(df_labels, on=label_keys, how="left", validate="many_to_one")

# Guard: every row needs a label.
if df_train["y"].isna().any():
    miss = int(df_train["y"].isna().sum())
    raise ValueError(f"Label merge produced NaN in y: {miss} rows. Check df_train_all vs pred_features alignment.")
df_train["y"] = df_train["y"].astype(int)

# Attach folds by case_id. After de-duplication a case listed with two
# different folds fails validation (it would leak across folds).
df_fold_map = df_cv[["case_id", "fold"]].drop_duplicates()
df_train = df_train.drop(columns=["fold"], errors="ignore").merge(
    df_fold_map, on="case_id", how="left", validate="many_to_one"
)
if df_train["fold"].isna().any():
    miss = int(df_train["fold"].isna().sum())
    raise ValueError(f"Missing fold after merging cv_case_folds.csv: {miss} rows.")
df_train["fold"] = df_train["fold"].astype(int)

# Optional: merge match features (only columns not already in the table).
if df_match is not None:
    df_match = ensure_uid_case_variant(df_match)
    base_cols = set(df_train.columns)
    new_cols = [c for c in df_match.columns if c not in base_cols]
    keep_cols = ["uid"] + [c for c in new_cols if c not in ["case_id", "variant"]]
    if len(keep_cols) > 1:
        df_train = df_train.merge(
            df_match[keep_cols].drop_duplicates(), on="uid", how="left", validate="many_to_one"
        )

# Optional: merge image profile by case_id.
# NOTE(review): only the first profile row per case_id is kept; if the profile
# is stored per image rather than per case, the variants of one case all
# receive that same row — confirm this is intended.
if df_prof is not None and "case_id" in df_prof.columns:
    df_prof2 = df_prof.copy()
    # Label/split columns must not arrive through the profile: on a name clash
    # they would be renamed to profile_<name> and no longer match the target
    # and split names excluded during feature selection.
    non_feature_names = {"y", "y_forged", "has_mask", "is_forged", "forged", "fold"}
    df_prof2 = df_prof2.drop(columns=[c for c in df_prof2.columns if c in non_feature_names])
    df_prof2["case_id"] = df_prof2["case_id"].astype(int)
    df_prof2 = df_prof2.drop_duplicates("case_id")
    clash = set(df_prof2.columns).intersection(df_train.columns)
    clash -= {"case_id"}
    if clash:
        df_prof2 = df_prof2.rename(columns={c: f"profile_{c}" for c in clash})
    df_train = df_train.merge(df_prof2, on="case_id", how="left", validate="many_to_one")
# ----------------------------
# 7) Feature engineering (lebih kaya + stabil untuk Transformer)
# ----------------------------
def safe_log1p(arr):
    """Compute ``log1p`` after neutralising invalid inputs.

    The input is converted to float64; NaN and infinite entries become 0 and
    negative entries are raised to 0 before ``log1p`` is applied.
    """
    values = np.asarray(arr, dtype=np.float64)
    finite_only = np.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)
    return np.log1p(np.maximum(finite_only, 0.0))

def get_clip_cap(series: pd.Series, q: float, fallback: float):
    """Return the ``q``-quantile of the usable values of ``series`` as a clip cap.

    Usable values are the entries that parse as numbers, are finite and are
    not negative. ``fallback`` (as a float) is returned when no usable value
    remains or when the quantile is not a finite, strictly positive number.
    """
    numeric = pd.to_numeric(series, errors="coerce").astype(float)
    usable = numeric[np.isfinite(numeric)]  # removes NaN and +/-inf
    usable = usable[usable >= 0]
    if usable.empty:
        return float(fallback)

    cap = float(usable.quantile(q))
    if np.isfinite(cap) and cap > 0:
        return float(cap)
    return float(fallback)

# Heavy-tail candidates; each one is used only if the column is present.
HEAVY_COLS = [
    "peak_ratio", "best_weight", "best_count",
    "n_pairs_thr", "n_pairs_mnn", "n_pairs",
    "n_comp", "largest_comp",
    "grid_area_frac", "mask_area_frac", "pred_area_frac",
    "overmask_tighten_steps",
]

# Per-column upper caps taken from the configured quantile; stays empty when
# quantile clipping is disabled. Also written to feature_schema.json later on.
clip_caps = {}
if FE_CFG["clip_by_quantile"]:
    for c in HEAVY_COLS:
        if c in df_train.columns:
            clip_caps[c] = get_clip_cap(df_train[c], FE_CFG["clip_q"], FE_CFG["clip_max_fallback"])

# log+cap features: for every present heavy column add "<col>_cap" (values
# clipped to [0, cap]) and "log_<col>" (log1p of the clipped values).
if FE_CFG["add_log_features"]:
    for c in HEAVY_COLS:
        if c in df_train.columns:
            # Columns without a computed cap fall back to the global maximum.
            cap = clip_caps.get(c, FE_CFG["clip_max_fallback"])
            # Non-numeric entries become NaN and are then filled with 0.
            x = pd.to_numeric(df_train[c], errors="coerce").fillna(0.0).astype(float)
            x = np.clip(x, 0.0, cap)
            df_train[f"{c}_cap"] = x.astype(np.float32)
            df_train[f"log_{c}"] = safe_log1p(x).astype(np.float32)

# interaction features
if FE_CFG["add_interactions"]:
    def getf(col, default=0.0):
        # Column as a float array (unparseable/missing entries -> default);
        # a constant array of `default` when the column does not exist.
        if col in df_train.columns:
            return pd.to_numeric(df_train[col], errors="coerce").fillna(default).astype(float).values
        return np.full(len(df_train), default, dtype=np.float64)

    best_mean_sim = getf("best_mean_sim", 0.0)
    best_count    = getf("best_count", 0.0)
    grid_area     = getf("grid_area_frac", 0.0)
    has_peak      = getf("has_peak", 0.0)
    n_comp        = getf("n_comp", 0.0)
    largest_comp  = getf("largest_comp", 0.0)

    # Pairwise products and ratios; the "1.0 +" in denominators avoids
    # division by zero when the count is 0.
    df_train["sim_x_count"]   = (best_mean_sim * best_count).astype(np.float32)
    df_train["area_x_sim"]    = (grid_area * best_mean_sim).astype(np.float32)
    df_train["area_x_count"]  = (grid_area * best_count).astype(np.float32)
    df_train["comp_density"]  = (largest_comp / (1.0 + n_comp)).astype(np.float32)
    df_train["comp_inv"]      = (1.0 / (1.0 + n_comp)).astype(np.float32)

    # mnn ratio
    n_pairs_thr = getf("n_pairs_thr", 0.0)
    n_pairs_mnn = getf("n_pairs_mnn", 0.0)
    df_train["mnn_ratio"] = (n_pairs_mnn / (1.0 + n_pairs_thr)).astype(np.float32)

    # peak gating: has_peak multiplied by the log peak-ratio feature
    # NOTE(review): the log+cap step above creates "peak_ratio_cap" and
    # "log_peak_ratio" only, so "log_peak_ratio_cap" can exist solely if the
    # input files already contain it — confirm whether this branch is needed.
    if "log_peak_ratio" in df_train.columns:
        df_train["has_peak_x_logpeak"] = (has_peak * getf("log_peak_ratio", 0.0)).astype(np.float32)
    elif "log_peak_ratio_cap" in df_train.columns:
        df_train["has_peak_x_logpeak"] = (has_peak * getf("log_peak_ratio_cap", 0.0)).astype(np.float32)
    else:
        # Neither column exists: the feature is all zeros.
        df_train["has_peak_x_logpeak"] = (has_peak * 0.0).astype(np.float32)

# replace inf -> NaN (numeric columns only)
for c in df_train.columns:
    if pd.api.types.is_numeric_dtype(df_train[c]):
        df_train[c] = df_train[c].replace([np.inf, -np.inf], np.nan)

# ----------------------------
# 8) Select feature columns (numeric only; exclude identifiers/labels/split)
# ----------------------------
# Column names that must never be used as model inputs.
TARGET_COLS = {"y", "y_forged", "has_mask", "is_forged", "forged"}
SPLIT_COLS  = {"fold"}
ID_NUM_DROP = {"case_id"}  # numeric id must not be used as a feature

# Features = every numeric column that is not a target, split or id column.
num_cols = [c for c in df_train.columns if pd.api.types.is_numeric_dtype(df_train[c])]
feature_cols = [c for c in num_cols if c not in TARGET_COLS and c not in SPLIT_COLS and c not in ID_NUM_DROP]

# fill NaN with the configured value
df_train[feature_cols] = df_train[feature_cols].fillna(FE_CFG["fillna_value"])

# drop constant cols (columns with a single distinct value); the removed names
# are kept in `dropped` for the summary and the saved schema
if FE_CFG["drop_constant_features"]:
    nun = df_train[feature_cols].nunique(dropna=False)
    nonconst = nun[nun > 1].index.tolist()
    dropped = sorted(set(feature_cols) - set(nonconst))
    feature_cols = nonconst
else:
    dropped = []

# cast float32
if FE_CFG["cast_float32"]:
    df_train[feature_cols] = df_train[feature_cols].astype(np.float32)

# ----------------------------
# 9) Final outputs
# ----------------------------
# Final table: identifier columns, fold, label, then the selected features.
df_train_tabular = df_train[["uid","case_id","variant","fold","y"] + feature_cols].copy()

# Convenience views derived from the final table.
X_train = df_train_tabular[feature_cols]
y_train = df_train_tabular["y"].astype(int)
folds   = df_train_tabular["fold"].astype(int)

print("\nOK — Training table built")
print("  df_train_tabular:", df_train_tabular.shape)
# y pos% = mean of y expressed as a percentage
print("  X_train:", X_train.shape, "| y pos%:", float(y_train.mean())*100.0)
print("  folds:", int(folds.nunique()), "unique folds")
print("  feature_cols:", int(len(feature_cols)))
if dropped:
    print("  dropped_constant_features:", len(dropped))

# quick sanity checks on the derived views
if X_train.shape[0] != y_train.shape[0]:
    raise RuntimeError("X_train and y_train row mismatch")
if y_train.isna().any():
    raise RuntimeError("y_train contains NaN")
if folds.isna().any():
    raise RuntimeError("folds contains NaN")

# Published feature list (a plain list copy of feature_cols).
FEATURE_COLS = list(feature_cols)
print("\nFeature head:", FEATURE_COLS[:20])
print("Feature tail:", FEATURE_COLS[-10:])

# ----------------------------
# 10) Save reproducible schema (feature list + FE config + clip caps)
# ----------------------------
OUT_ART = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_ART.mkdir(parents=True, exist_ok=True)

# Schema: FE config, clip caps and the dropped columns, plus a short preview.
schema = {
    "fe_cfg": FE_CFG,
    "clip_caps": clip_caps,
    "dropped_constant_features": dropped,
    "n_features": int(len(FEATURE_COLS)),
    "example_feature_head": FEATURE_COLS[:25],
}

# Both JSON files use the same indentation.
for _json_name, _payload in (("feature_cols.json", FEATURE_COLS), ("feature_schema.json", schema)):
    with open(OUT_ART / _json_name, "w") as f:
        f.write(json.dumps(_payload, indent=2))

# Optional: keep the table so a later run can resume from it.
df_train_tabular.to_parquet(OUT_ART / "df_train_tabular.parquet", index=False)

print(f"\nSaved -> {OUT_ART/'feature_cols.json'}")
for _saved_name in ("feature_schema.json", "df_train_tabular.parquet"):
    print(f"Saved -> {OUT_ART/_saved_name}")


Using:
  DF_TRAIN_ALL      : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/df_train_all.parquet
  CV_CASE_FOLDS     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/cv_case_folds.csv
  PRED_FEAT_TRAIN   : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_base_v3_v7_cfg_5dbf0aa165/pred_features_train_all.csv
  MATCH_FEAT_TRAIN  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_base_cfg_f9f7ea3a65c5/match_features_train_all.csv (optional)
  IMG_PROFILE_TRAIN : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/artifacts/image_profile_train.parquet (optional)
  DINO_LARGE_DIR    : /kaggle/input/dinov2/pytorch/large/1

OK — Training table built
  df_train_tabular: (5176, 67)
  X_train: (5176, 62) | y pos%: 54.07650695517774
  folds: 5 unique folds
  feature_cols: 62
  dropped_constant_features: 6

Feature head: ['has_peak', 'peak_ratio', 'best_weight', 'best_count', 'best_mean_sim', 'n_pairs_thr', 'n_pairs_mnn', 'best_inlier_ratio', 'best_

# Train Baseline Model (Leakage-Safe CV)

In [None]:
# ============================================================
# Step 3 — Train Baseline Model (Leakage-Safe CV)
# - Baseline: mHC-FTTransformer (numeric tabular) + Sinkhorn (PDF mHC)
# - CV: pakai kolom `fold` (by case_id)
# - Output:
#   * OOF probabilities
#   * CV report (AUC, F1, Precision, Recall, LogLoss)
#   * Simpan model per fold + model_full (torch .pt pack: state_dict + scaler + cfg)
#
# Implementasi ide PDF (mHC):
# - Multi-stream residual mixing (n_streams = 4 default)
# - H_res per-layer diproyeksikan ke doubly-stochastic (Birkhoff polytope) via Sinkhorn-Knopp (tmax=20)
# - Gating factor alpha init 0.01 (awal dekat identity -> stabil)
#
# Kaggle-safe:
# - Update hanya active stream (stream-0) tiap layer -> jauh lebih cepat
# - AMP + grad accumulation (effective batch besar tanpa OOM)
# - LR scheduler step ala tabel: decay di 0.8 & 0.9 progress
# ============================================================

import json, gc, math, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, log_loss

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require outputs from Step 2
# ----------------------------
need_vars = ["df_train_tabular", "FEATURE_COLS"]
# Report the first required global that Step 2 has not produced.
_absent = next((name for name in need_vars if name not in globals()), None)
if _absent is not None:
    raise RuntimeError(f"Missing `{_absent}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

# Work on private copies of the Step 2 outputs.
df_train_tabular, FEATURE_COLS = df_train_tabular.copy(), list(FEATURE_COLS)

required_cols = {"uid", "case_id", "variant", "fold", "y"}
missing_cols = list(filter(lambda col: col not in df_train_tabular.columns, required_cols))
if missing_cols:
    raise ValueError(f"df_train_tabular missing columns: {missing_cols}.")

# ----------------------------
# 1) CFG (dua preset: SAFE vs STRONG)
# ----------------------------
# SAFE: suited to T4/P100; STRONG: for A100/H100 (or when you are sure VRAM is sufficient)
CFG_SAFE = {
    "seed": 2025,

    # mHC (PDF)
    "n_streams": 4,         # mHC/HC expansion rate n = 4 (as in the table)
    "sinkhorn_tmax": 20,    # as in the PDF (practical value)
    "alpha_init": 0.01,     # as in the table (starts close to identity)

    # model capacity (SAFE)
    "d_model": 256,
    "n_layers": 6,
    "n_heads": 8,
    "ffn_mult": 4,
    "dropout": 0.15,
    "attn_dropout": 0.10,

    # training (large effective batch via accumulation)
    "batch_size": 256,      # micro-batch
    "accum_steps": 2,       # effective batch = 512
    "epochs": 50,

    # AdamW following the table (relevant entries only)
    "lr": 3e-4,
    "betas": (0.9, 0.95),
    "eps": 1e-8,            # the table's eps 1e-20 is too extreme for small tabular data; this is safer
    "weight_decay": 5e-2,

    # warmup + step decay (like the table: steps at 0.8 & 0.9)
    "warmup_frac": 0.10,
    "lr_decay_milestones": (0.80, 0.90),
    "lr_decay_values": (0.316, 0.10),  # piecewise multiplier after each milestone

    "grad_clip": 1.0,

    # early stopping
    "early_stop_patience": 10,
    "early_stop_min_delta": 1e-4,

    # report threshold (baseline reporting only)
    "report_thr": 0.5,
}

# STRONG starts from SAFE and overrides the keys listed below
# (batch_size and accum_steps are restated with the same values as SAFE).
CFG_STRONG = {
    **CFG_SAFE,
    "d_model": 384,
    "n_layers": 8,
    "dropout": 0.20,
    "epochs": 70,
    "lr": 2e-4,
    "weight_decay": 7e-2,
    "batch_size": 256,
    "accum_steps": 2,
}

# pilih otomatis berdasarkan GPU memory (kalau ada)
# Start from SAFE; upgrade to STRONG only when GPU 0 reports at least 30 GiB.
CFG = dict(CFG_SAFE)
if torch.cuda.is_available():
    try:
        mem_gb = torch.cuda.get_device_properties(0).total_memory / (1024**3)
    except Exception:
        # Best effort: keep SAFE when the device cannot be queried.
        pass
    else:
        if mem_gb >= 30:
            CFG = dict(CFG_STRONG)

# ----------------------------
# 2) Seed + device
# ----------------------------
def seed_everything(seed: int = 2025):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for speed.

    Args:
        seed: value handed to every random number generator and exported
            as ``PYTHONHASHSEED``.

    Note:
        cuDNN is left non-deterministic with autotuning (benchmark) enabled.
    """
    import os
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)

    # same seed for every generator this notebook touches
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed,
                       torch.cuda.manual_seed_all):
        apply_seed(seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

seed_everything(int(CFG["seed"]))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")
# CFG is always a *copy* of a preset (dict(...)), so an identity test (`is`)
# can never match; compare by value to report the preset actually selected.
print("Device:", device, "| AMP:", use_amp, "| CFG:", ("STRONG" if CFG == CFG_STRONG else "SAFE"))

# speed knobs
if device.type == "cuda":
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    # best-effort: this tuning knob is optional, so any failure is ignored
    pass

# ----------------------------
# 3) Build arrays + guard
# ----------------------------
# Independent copies of features, labels and fold ids taken from the training table.
X = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)

# Replace NaN / +inf / -inf with 0.0 so the model never sees non-finite inputs.
if not np.isfinite(X).all():
    X = np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0)

n = len(df_train_tabular)
unique_folds = sorted(pd.Series(folds).unique().tolist())
n_folds = len(unique_folds)
n_features = X.shape[1]

print("Setup:")
print("  rows      :", n)
print("  folds     :", n_folds, "|", unique_folds)
print("  pos%      :", float(y.mean()) * 100.0)
print("  n_features:", n_features)

# ----------------------------
# 4) Dataset
# ----------------------------
class TabDataset(Dataset):
    """In-memory dataset over a feature matrix and optional labels.

    Features are stored as a float32 tensor and labels (when given) as an
    int64 tensor. Indexing yields ``x`` alone when no labels were supplied,
    otherwise the pair ``(x, y)``.
    """

    def __init__(self, X, y=None):
        labels = None
        if y is not None:
            labels = torch.from_numpy(y.astype(np.int64))
        self.X = torch.from_numpy(X.astype(np.float32))
        self.y = labels

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        features = self.X[idx]
        return features if self.y is None else (features, self.y[idx])

# ----------------------------
# 5) Normalization (leakage-safe: fit on train fold only)
# ----------------------------
def fit_standardizer(X_tr: np.ndarray):
    """Compute per-column mean and standard deviation of ``X_tr``.

    Statistics are accumulated in float64 and returned as float32.
    Columns whose standard deviation is below 1e-8 get a scale of 1.0 so
    that later division is safe.

    Returns:
        Tuple ``(mu, sig)`` of float32 arrays.
    """
    center = np.mean(X_tr, axis=0, dtype=np.float64)
    spread = np.std(X_tr, axis=0, dtype=np.float64)
    # (near-)constant columns: use a unit scale instead of ~0
    spread = np.where(spread < 1e-8, 1.0, spread)
    return center.astype(np.float32), spread.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Return ``(X_in - mu) / sig`` cast to float32."""
    centered = X_in - mu
    scaled = centered / sig
    return scaled.astype(np.float32)

# ----------------------------
# 6) Metrics helpers
# ----------------------------
def safe_auc(y_true, p):
    """ROC-AUC of scores ``p``, or ``None`` when ``y_true`` holds fewer than two distinct values."""
    has_two_classes = np.unique(y_true).size >= 2
    return float(roc_auc_score(y_true, p)) if has_two_classes else None

def safe_logloss(y_true, p):
    """Binary log-loss with probabilities clipped to ``[1e-6, 1 - 1e-6]``.

    ``labels=[0, 1]`` is passed explicitly to ``log_loss``.
    """
    probs = np.asarray(p, dtype=np.float64)
    probs = np.clip(probs, 1e-6, 1-1e-6)
    return float(log_loss(y_true, probs, labels=[0, 1]))

# ----------------------------
# 7) PDF mHC: Sinkhorn projection (doubly-stochastic)
# ----------------------------
def sinkhorn_doubly_stochastic(logits: torch.Tensor, tmax: int = 20, eps: float = 1e-6):
    """
    Project unconstrained logits towards a doubly-stochastic matrix.

    logits: (n, n) unconstrained
    return: (n, n) approximately doubly-stochastic, obtained by ``tmax``
            rounds of Sinkhorn-Knopp (row then column normalisation) on
            exp(logits); ``eps`` is added to every normalising sum.
    """
    # subtract the global max before exponentiating to keep exp() bounded
    weights = (logits - logits.max()).exp()
    for _ in range(int(tmax)):
        # dim=-1 normalises rows, dim=-2 normalises columns
        for axis in (-1, -2):
            weights = weights / (weights.sum(dim=axis, keepdim=True) + eps)
    return weights

# ----------------------------
# 8) Model: RMSNorm (PDF uses RMSNorm in big transformers)
# ----------------------------
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        d: size of the last dimension (length of the gain vector).
        eps: constant added to the mean square before the inverse square root.
    """

    def __init__(self, d: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # x: (..., d)
        mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

# ----------------------------
# 9) Transformer block + mHC mixing (multi-stream)
# ----------------------------
class MHCAttnBlock(nn.Module):
    """
    Pre-norm transformer layer applied to stream 0 only, followed by an
    mHC mix across all streams.

    The mixing matrix is ``(1 - alpha) * I + alpha * H`` where ``H`` is the
    Sinkhorn projection of the learnable ``h_logits`` and ``alpha`` is the
    sigmoid of the learnable ``alpha_logit``.
    """
    def __init__(self, d_model, n_heads, ffn_mult, dropout, attn_dropout,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01):
        super().__init__()
        self.n_streams = int(n_streams)
        self.sinkhorn_tmax = int(sinkhorn_tmax)

        # --- self-attention sub-layer ---
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=attn_dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # --- feed-forward sub-layer ---
        hidden = ffn_mult * d_model
        self.norm2 = RMSNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

        # --- mHC parameters ---
        # raw logits (start at zero) are projected with Sinkhorn in forward()
        self.h_logits = nn.Parameter(torch.zeros(self.n_streams, self.n_streams))

        # alpha is stored as a logit; the initial value is clamped to
        # [1e-4, 1 - 1e-4] so the logit stays finite
        start = min(max(float(alpha_init), 1e-4), 1 - 1e-4)
        odds = torch.tensor(start / (1 - start), dtype=torch.float32)
        self.alpha_logit = nn.Parameter(torch.log(odds))

    def _update_active(self, active):
        """Run attention + FFN (both pre-norm, both residual) on one stream of shape (B,S,D)."""
        normed = self.norm1(active)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        after_attn = active + self.drop1(attended)
        return after_attn + self.drop2(self.ffn(self.norm2(after_attn)))

    def _mixing_matrix(self, like):
        """Build (1 - alpha) * I + alpha * H on the device/dtype of ``like``."""
        doubly = sinkhorn_doubly_stochastic(self.h_logits, tmax=self.sinkhorn_tmax)  # (n,n)
        alpha = torch.sigmoid(self.alpha_logit).to(dtype=like.dtype)
        identity = torch.eye(self.n_streams, device=like.device, dtype=like.dtype)
        return (1.0 - alpha) * identity + alpha * doubly.to(dtype=like.dtype)

    def forward(self, streams):
        # streams: (B, n_streams, S, D)
        _, stream_count, _, _ = streams.shape
        assert stream_count == self.n_streams

        # only stream 0 goes through the transformer layer
        refreshed = self._update_active(streams[:, 0])

        # write it back into a copy so the input tensor is left untouched
        updated = streams.clone()
        updated[:, 0] = refreshed

        # mix across the stream dimension
        return torch.einsum("ij,bjtd->bitd", self._mixing_matrix(updated), updated)

class MHCFTTransformer(nn.Module):
    """Numeric FT-Transformer whose layers are ``MHCAttnBlock``s.

    Each scalar feature becomes one token (``x_f * W_f + b_f`` plus a
    per-feature embedding), a learnable CLS token is prepended, the sequence
    is replicated into ``n_streams`` streams, and the CLS position of
    stream 0 after the last block feeds an MLP head that outputs one logit
    per row.
    """
    def __init__(self, n_features,
                 d_model=256, n_heads=8, n_layers=6, ffn_mult=4,
                 dropout=0.15, attn_dropout=0.10,
                 n_streams=4, sinkhorn_tmax=20, alpha_init=0.01):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        token_shape = (self.n_features, self.d_model)

        # numeric tokenization: x_f -> x_f * W_f + b_f
        self.w = nn.Parameter(torch.randn(*token_shape) * 0.02)
        self.b = nn.Parameter(torch.zeros(*token_shape))
        self.feat_emb = nn.Parameter(torch.randn(*token_shape) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)
        self.in_drop = nn.Dropout(dropout)

        # every layer is built from the same hyper-parameters
        layer_kwargs = dict(
            d_model=self.d_model,
            n_heads=n_heads,
            ffn_mult=ffn_mult,
            dropout=dropout,
            attn_dropout=attn_dropout,
            n_streams=n_streams,
            sinkhorn_tmax=sinkhorn_tmax,
            alpha_init=alpha_init,
        )
        self.blocks = nn.ModuleList(
            [MHCAttnBlock(**layer_kwargs) for _ in range(int(n_layers))]
        )

        self.out_norm = RMSNorm(self.d_model)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.d_model, 1),
        )

        self.n_streams = int(n_streams)

    def forward(self, x):
        # x: (B,F)
        tokens = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)   # (B,F,D)
        tokens = tokens + self.feat_emb.unsqueeze(0)

        cls_tok = self.cls.expand(tokens.size(0), -1, -1)                      # (B,1,D)
        seq = self.in_drop(torch.cat([cls_tok, tokens], dim=1))                # (B,1+F,D)

        # every stream starts as a copy of the same sequence
        streams = seq.unsqueeze(1).repeat(1, self.n_streams, 1, 1)             # (B,nS,S,D)
        for layer in self.blocks:
            streams = layer(streams)

        # CLS token of stream 0 -> norm -> head -> (B,)
        pooled = self.out_norm(streams[:, 0, 0])
        return self.head(pooled).squeeze(-1)

# ----------------------------
# 10) LR Scheduler: warmup + step decay (0.8 & 0.9)
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps: int, warmup_steps: int,
                              milestones_frac=(0.8, 0.9), decay_values=(0.316, 0.1)):
    """Build a ``LambdaLR`` with linear warmup followed by two step decays.

    Args:
        optimizer: optimizer whose learning rate is scheduled.
        total_steps: step count that the milestone fractions refer to.
        warmup_steps: number of steps over which the multiplier ramps
            linearly up to 1.
        milestones_frac: two fractions of ``total_steps`` at which the
            multiplier changes.
        decay_values: multipliers used from the first and from the second
            milestone onward.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` instance.
    """
    first_frac, second_frac = milestones_frac[0], milestones_frac[1]
    first_cut = int(float(first_frac) * total_steps)
    second_cut = int(float(second_frac) * total_steps)
    mid_mult = float(decay_values[0])
    late_mult = float(decay_values[1])

    def lr_lambda(step):
        # warmup phase: (step + 1) / warmup_steps
        if step < warmup_steps:
            return float(step + 1) / float(max(1, warmup_steps))
        # plateau until the first milestone, then piecewise-constant decay
        if step < first_cut:
            return 1.0
        return mid_mult if step < second_cut else late_mult

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 11) Predict helper (handle batch=(xb,yb))
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader):
    """Return sigmoid probabilities for every batch in ``loader`` as float32.

    Args:
        model: module mapping a feature batch to logits.
        loader: iterable of batches; a batch is either the feature tensor
            itself or a list/tuple whose first element is the feature tensor.

    Returns:
        ``np.ndarray`` of dtype float32 holding the per-batch probabilities
        concatenated along axis 0; an empty array of shape ``(0,)`` when
        ``loader`` yields no batches.
    """
    model.eval()
    probs = []
    for batch in loader:
        xb = batch[0] if isinstance(batch, (list, tuple)) else batch
        xb = xb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(xb)
        # Under AMP the logits may be half precision; a half-precision sigmoid
        # rounds probabilities near 1 to exactly 1.0. Upcast first so the
        # returned float32 values carry real float32 resolution.
        p = torch.sigmoid(logits.float())
        probs.append(p.cpu().numpy())
    if not probs:
        # np.concatenate rejects an empty list; return an empty result instead
        return np.empty((0,), dtype=np.float32)
    return np.concatenate(probs, axis=0).astype(np.float32)

# ----------------------------
# 12) Train one fold (AMP + grad accumulation + early stopping)
# ----------------------------
def train_one_fold(X_tr, y_tr, X_va, y_va, cfg):
    """Train one CV fold with AMP, gradient accumulation and early stopping.

    The standardizer is fitted on ``X_tr`` only and applied to both splits.
    After every epoch the validation log-loss is computed; the weights of the
    best epoch are snapshotted and restored before the final validation
    prediction.

    Args:
        X_tr, y_tr: raw training features and binary labels.
        X_va, y_va: raw validation features and binary labels.
        cfg: hyper-parameter dict (model size, optimizer, schedule,
            ``accum_steps``, ``grad_clip``, early-stopping settings).

    Returns:
        Tuple ``(pack, p_va, best)``: ``pack`` holds ``state_dict``, ``mu``,
        ``sig`` and ``cfg``; ``p_va`` are validation probabilities of the
        restored model; ``best`` holds ``val_logloss`` and the 0-based
        ``epoch`` of the best checkpoint.
    """
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    dl_tr = DataLoader(
        ds_tr, batch_size=int(cfg["batch_size"]), shuffle=True,
        num_workers=2, pin_memory=(device.type == "cuda"),
        drop_last=False
    )
    dl_va = DataLoader(
        ds_va, batch_size=int(cfg["batch_size"]), shuffle=False,
        num_workers=2, pin_memory=(device.type == "cuda"),
        drop_last=False
    )

    model = MHCFTTransformer(
        n_features=n_features,
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        alpha_init=float(cfg["alpha_init"]),
    ).to(device)

    # class-imbalance weight: negatives / positives of the training split
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = float(neg / max(1, pos))
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    # the schedule is defined in optimizer steps, not micro-batches
    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    optim_steps_per_epoch = int(math.ceil(len(dl_tr) / accum_steps))
    total_optim_steps = int(cfg["epochs"]) * max(1, optim_steps_per_epoch)
    warmup_steps = int(float(cfg["warmup_frac"]) * total_optim_steps)

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=warmup_steps,
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # None / 0 / negative all mean "no clipping"
    clip_norm = float(cfg["grad_clip"] or 0.0)

    def _optimizer_step():
        """Unscale and clip (if enabled), apply the update, advance the LR schedule."""
        if clip_norm > 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
        sch.step()

    best = {"val_logloss": 1e9, "epoch": -1}
    best_state = None
    bad = 0

    for epoch in range(int(cfg["epochs"])):
        model.train()
        t0 = time.time()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)

        micro_step = 0
        optim_step_in_epoch = 0

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                loss = loss / accum_steps

            scaler.scale(loss).backward()
            micro_step += 1

            loss_sum += float(loss.item()) * xb.size(0) * accum_steps  # undo divide for logging
            n_sum += xb.size(0)

            if (micro_step % accum_steps) == 0:
                _optimizer_step()
                optim_step_in_epoch += 1

        # flush the last partial accumulation group, if any
        if (micro_step % accum_steps) != 0:
            _optimizer_step()
            optim_step_in_epoch += 1

        # validate
        p_va = predict_proba(model, dl_va)
        vll = safe_logloss(y_va, p_va)

        tr_loss = loss_sum / max(1, n_sum)
        dt = time.time() - t0
        print(f"  epoch {epoch+1:03d}/{cfg['epochs']} | train_loss={tr_loss:.5f} | val_logloss={vll:.5f} | opt_steps={optim_step_in_epoch} | dt={dt:.1f}s")

        improved = (best["val_logloss"] - vll) > float(cfg["early_stop_min_delta"])
        if improved:
            best["val_logloss"] = float(vll)
            best["epoch"] = int(epoch)
            # .cpu() is a no-op for tensors already on the CPU, so without
            # .clone() the snapshot would share storage with the live
            # parameters and be overwritten by the following epochs.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= int(cfg["early_stop_patience"]):
                print(f"  early stop at epoch {epoch+1}, best_epoch={best['epoch']+1}, best_val_logloss={best['val_logloss']:.5f}")
                break

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    # final validation predictions from the restored (best) weights
    p_va = predict_proba(model, dl_va)

    pack = {
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
    }
    return pack, p_va, best

# ----------------------------
# 13) CV loop
# ----------------------------
# Out-of-fold predictions, filled fold by fold at the validation row indices.
oof_pred = np.zeros(n, dtype=np.float32)
fold_reports = []

# One checkpoint file per fold is written into this directory.
models_dir = Path("/kaggle/working/recodai_luc_gate_artifacts/baseline_mhc_transformer_folds")
models_dir.mkdir(parents=True, exist_ok=True)

for f in unique_folds:
    print(f"\n[Fold {f}]")
    # rows of fold f are validation; every other row is training
    tr_idx = np.where(folds != f)[0]
    va_idx = np.where(folds == f)[0]

    X_tr, y_tr = X[tr_idx], y[tr_idx]
    X_va, y_va = X[va_idx], y[va_idx]

    pack, p_va, best = train_one_fold(X_tr, y_tr, X_va, y_va, CFG)
    oof_pred[va_idx] = p_va

    # fold metrics; hard labels are obtained at the fixed report threshold
    auc = safe_auc(y_va, p_va)
    thr = float(CFG["report_thr"])
    yhat = (p_va >= thr).astype(np.int32)

    rep = {
        "fold": int(f),
        "n_val": int(len(va_idx)),
        "pos_val": int(y_va.sum()),
        "auc": auc,
        f"f1@{thr}": float(f1_score(y_va, yhat, zero_division=0)),
        f"precision@{thr}": float(precision_score(y_va, yhat, zero_division=0)),
        f"recall@{thr}": float(recall_score(y_va, yhat, zero_division=0)),
        "logloss": safe_logloss(y_va, p_va),
        "best_val_logloss": float(best["val_logloss"]),
        "best_epoch": int(best["epoch"] + 1),  # stored 0-based, reported 1-based
    }
    fold_reports.append(rep)

    # persist the fold pack together with the feature column order
    torch.save(
        {"pack": pack, "feature_cols": FEATURE_COLS},
        models_dir / f"baseline_mhc_transformer_fold_{f}.pt"
    )

    # release cached GPU memory and garbage before the next fold
    if device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

# ----------------------------
# 14) Overall OOF metrics
# ----------------------------
# Metrics over the concatenated out-of-fold predictions, using the same
# fixed report threshold as the per-fold reports.
oof_auc = safe_auc(y, oof_pred)
thr = float(CFG["report_thr"])
oof_yhat = (oof_pred >= thr).astype(np.int32)

overall = {
    "rows": int(n),
    "folds": int(n_folds),
    "pos_total": int(y.sum()),
    "pos_rate": float(y.mean()),
    "oof_auc": oof_auc,
    f"oof_f1@{thr}": float(f1_score(y, oof_yhat, zero_division=0)),
    f"oof_precision@{thr}": float(precision_score(y, oof_yhat, zero_division=0)),
    f"oof_recall@{thr}": float(recall_score(y, oof_yhat, zero_division=0)),
    "oof_logloss": safe_logloss(y, oof_pred),
}

# per-fold table, ordered by fold id
df_rep = pd.DataFrame(fold_reports).sort_values("fold").reset_index(drop=True)
print("\nPer-fold report:")
display(df_rep)

print("\nOOF overall:")
print(overall)

# ----------------------------
# 15) Train FULL model (fixed epochs = 70% of CV epochs)
# ----------------------------
def train_full_fixed(X_full_raw, y_full, cfg):
    """Fit one model on all rows for a fixed number of epochs (no validation).

    The epoch count is ``max(12, int(cfg["epochs"] * 0.7))``. The
    standardizer is fitted on the full matrix.

    Args:
        X_full_raw: raw feature matrix.
        y_full: binary labels.
        cfg: hyper-parameter dict, same keys as used for fold training.

    Returns:
        Dict with ``state_dict`` (CPU tensors), ``mu``, ``sig`` and ``cfg``.
    """
    mu, sig = fit_standardizer(X_full_raw)
    dl_full = DataLoader(
        TabDataset(apply_standardizer(X_full_raw, mu, sig), y_full),
        batch_size=int(cfg["batch_size"]),
        shuffle=True,
        num_workers=2,
        pin_memory=(device.type == "cuda"),
        drop_last=False,
    )

    # integer-valued and float-valued hyper-parameters are cast separately
    int_keys = ("d_model", "n_heads", "n_layers", "ffn_mult", "n_streams", "sinkhorn_tmax")
    float_keys = ("dropout", "attn_dropout", "alpha_init")
    model = MHCFTTransformer(
        n_features=n_features,
        **{key: int(cfg[key]) for key in int_keys},
        **{key: float(cfg[key]) for key in float_keys},
    ).to(device)

    # class-imbalance weight: negatives / positives
    n_pos = int(y_full.sum())
    n_neg = int(len(y_full) - n_pos)
    weight = torch.tensor([float(n_neg / max(1, n_pos))], device=device)
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=weight)

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        betas=tuple(cfg["betas"]),
        eps=float(cfg["eps"]),
        weight_decay=float(cfg["weight_decay"]),
    )

    accum_steps = max(1, int(cfg.get("accum_steps", 1)))
    E_FULL = max(12, int(cfg["epochs"] * 0.7))
    steps_per_epoch = max(1, int(math.ceil(len(dl_full) / accum_steps)))
    total_optim_steps = E_FULL * steps_per_epoch

    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_optim_steps,
        warmup_steps=int(float(cfg["warmup_frac"]) * total_optim_steps),
        milestones_frac=cfg["lr_decay_milestones"],
        decay_values=cfg["lr_decay_values"],
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    def _apply_update():
        """Clip (when enabled), step optimizer and scaler, reset grads, step scheduler."""
        clip = cfg["grad_clip"]
        if clip and float(clip) > 0:
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(clip))
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
        sch.step()

    print(f"\nTraining full mHC transformer for {E_FULL} epochs (fixed)...")
    for epoch in range(E_FULL):
        model.train()
        loss_sum, n_sum = 0.0, 0
        opt.zero_grad(set_to_none=True)

        micro_step = 0
        for micro_step, (xb, yb) in enumerate(dl_full, start=1):
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                loss = loss / accum_steps

            scaler.scale(loss).backward()

            # multiply back by accum_steps so the logged loss is per-sample
            loss_sum += float(loss.item()) * xb.size(0) * accum_steps
            n_sum += xb.size(0)

            if micro_step % accum_steps == 0:
                _apply_update()

        # leftover micro-batches that did not fill an accumulation group
        if micro_step % accum_steps:
            _apply_update()

        print(f"  full epoch {epoch+1:03d}/{E_FULL} | loss={loss_sum/max(1,n_sum):.5f}")

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    weights = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    return {
        "state_dict": weights,
        "mu": mu,
        "sig": sig,
        "cfg": cfg,
    }

# Root directory for every artifact written by this step.
out_dir = Path("/kaggle/working/recodai_luc_gate_artifacts")
out_dir.mkdir(parents=True, exist_ok=True)

# Train on all rows and store the pack together with the feature column order.
full_pack = train_full_fixed(X, y, CFG)
torch.save({"pack": full_pack, "feature_cols": FEATURE_COLS}, out_dir / "baseline_mhc_transformer_model_full.pt")

# ----------------------------
# 16) Save OOF + report
# ----------------------------
# Identifier columns plus the out-of-fold prediction of this model.
df_oof = df_train_tabular[["uid", "case_id", "variant", "fold", "y"]].copy()
df_oof["oof_pred_baseline_mhc_tf"] = oof_pred
df_oof.to_csv(out_dir / "oof_baseline_mhc_transformer.csv", index=False)

# JSON report: config, feature count, per-fold and overall metrics.
report = {
    "model": "mHC-FTTransformer (numeric tabular) — baseline",
    "cfg": CFG,
    "feature_count": int(len(FEATURE_COLS)),
    "fold_reports": fold_reports,
    "overall": overall,
}
with open(out_dir / "baseline_mhc_transformer_cv_report.json", "w") as f:
    json.dump(report, f, indent=2)

print("\nSaved artifacts:")
print("  fold models  ->", models_dir)
print("  full model   ->", out_dir / "baseline_mhc_transformer_model_full.pt")
print("  oof preds    ->", out_dir / "oof_baseline_mhc_transformer.csv")
print("  cv report    ->", out_dir / "baseline_mhc_transformer_cv_report.json")

# Export globals
# NOTE(review): presumably read by later notebook cells — confirm before renaming.
OOF_PRED_BASELINE_MHC_TF = oof_pred
BASELINE_MHC_TF_OVERALL = overall
BASELINE_MHC_TF_FOLD_REPORTS = fold_reports


Device: cpu | AMP: False | CFG: SAFE
Setup:
  rows      : 5176
  folds     : 5 | [0, 1, 2, 3, 4]
  pos%      : 54.07650695517774
  n_features: 62

[Fold 0]


# Optimize Model & Hyperparameters (Iterative)

In [None]:
# ============================================================
# Step 4 — Optimize Model & Hyperparameters (Iterative) — TRANSFORMER ONLY (REVISI FULL v2, mHC-lite from PDF)
# - FIX: predict_proba robust (batch can be xb or (xb,yb) or list/tuple) => NO ERROR
# - Implementasi materi PDF (adapted, runtime-safe):
#     * mHC-lite (multi-stream residual on CLS) with Sinkhorn-Knopp (tmax=20)
#     * n_streams=4, alpha=0.01 (default)
#     * AdamW betas=(0.9,0.95)
#     * LR schedule: Warmup + Step decay at ratios [0.8,0.9] with rates [0.316, 0.1]
# - Validasi: leakage-safe CV pakai `fold` by case_id
# - Skor utama: OOF best F-beta (beta=0.5)
#
# Output:
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_results.json
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/opt_fold_details.csv
# - /kaggle/working/recodai_luc_gate_artifacts/opt_search/oof_preds_<cfg_name>.csv (top configs)
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_config.json
# - /kaggle/working/recodai_luc_gate_artifacts/best_gate_model.pt
#
# REQUIRE:
# - Step 2 sudah jalan: df_train_tabular, FEATURE_COLS
# ============================================================

import os, json, gc, math, time, warnings, re
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

from IPython.display import display
from sklearn.metrics import roc_auc_score, log_loss

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# 0) Require data from Step 2
# ----------------------------
# Both names must already exist as notebook globals; stop with a clear error otherwise.
need_vars = ["df_train_tabular", "FEATURE_COLS"]
for v in need_vars:
    if v not in globals():
        raise RuntimeError(f"Missing `{v}`. Jalankan dulu Step 2 — Build Training Table (X, y, folds).")

# Rebind to copies so this cell works on its own DataFrame / list objects.
df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)
folds_all = df_train_tabular["fold"].to_numpy(dtype=np.int64, copy=True)
uids_all = df_train_tabular["uid"].astype(str).to_numpy()

# guard: replace NaN / +inf / -inf in the features with 0.0
if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0)

unique_folds = sorted(pd.Series(folds_all).unique().tolist())
n = len(y_all)
pos_rate = float(y_all.mean())
n_features = X_all.shape[1]

print("Optimize setup (Transformer only, mHC-lite):")
print(f"  rows={n} | folds={len(unique_folds)} | pos%={pos_rate*100:.2f} | n_features={n_features}")

# ----------------------------
# 1) Global settings
# ----------------------------
SEED = 2025
BETA = 0.5       # NOTE(review): presumably the beta of the F-beta score named in the cell header — confirm
THR_GRID = 201   # NOTE(review): presumably the number of thresholds scanned — confirm against its use

# runtime control (without changing accuracy too much)
STAGE1_FOLDS = min(3, len(unique_folds))         # stage-1 uses a subset of folds (faster)
STAGE1_EPOCH_CAP = 40                            # stage-1 epoch cap
STAGE1_PAT_CAP = 6

STAGE2_TOPM = min(3, 5)                          # stage-2 full CV only for the top-M configs (evaluates to 3)
REPORT_TOPK_OOF = 3

# optional time budget (so the search does not overrun)
TIME_BUDGET_SEC = 0
# example: TIME_BUDGET_SEC = 2.5 * 60 * 60  # 2.5 hours
# leave at 0 if the run should not be stopped automatically

def seed_everything(seed=2025):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA) RNGs.

    Also exports ``PYTHONHASHSEED`` and leaves cuDNN in its fast,
    non-deterministic mode with benchmark autotuning enabled.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)

    # one seed for every generator
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed,
                       torch.cuda.manual_seed_all):
        apply_seed(seed)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

seed_everything(SEED)

# Use the GPU when one is visible; mixed precision is enabled only on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = (device.type == "cuda")
print("Device:", device, "| AMP:", use_amp)

try:
    torch.set_float32_matmul_precision("high")
except Exception:
    # best-effort: this tuning knob is optional, so any failure is ignored
    pass

# ----------------------------
# 2) Helpers: threshold search + safe metrics
# ----------------------------
def best_fbeta_fast(y_true, p, beta=0.5, grid=201):
    """Scan a threshold grid and return the best F-beta operating point.

    ``grid`` thresholds are spaced evenly over [0.01, 0.99]; a sample is
    predicted positive when its score is >= the threshold. All thresholds
    are evaluated at once with a (samples x thresholds) boolean matrix.
    Ties are resolved in favour of the lowest threshold.

    Args:
        y_true: labels; entries equal to 1 count as positive.
        p: scores, clipped to [1e-8, 1 - 1e-8] before thresholding.
        beta: F-beta weight.
        grid: number of thresholds.

    Returns:
        Dict with keys ``fbeta``, ``thr``, ``precision``, ``recall``.
    """
    is_pos = np.asarray(y_true).astype(np.int32) == 1
    scores = np.clip(np.asarray(p, dtype=np.float64), 1e-8, 1.0 - 1e-8)
    cutoffs = np.linspace(0.01, 0.99, grid, dtype=np.float64)

    # flagged[i, j] is True when sample i is called positive at cutoff j
    flagged = scores[:, None] >= cutoffs[None, :]
    pos_col = is_pos[:, None]

    tp = np.sum(flagged & pos_col, axis=0).astype(np.float64)
    fp = np.sum(flagged & (~pos_col), axis=0).astype(np.float64)
    fn = is_pos.sum().astype(np.float64) - tp

    # ratios default to 0 wherever the denominator is 0
    called = tp + fp
    actual = tp + fn
    precision = np.divide(tp, called, out=np.zeros_like(tp), where=called > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)

    beta_sq = beta * beta
    denom = beta_sq * precision + recall
    fbeta = np.divide(
        (1.0 + beta_sq) * precision * recall,
        denom,
        out=np.zeros_like(precision),
        where=denom > 0,
    )

    top = int(np.argmax(fbeta))
    return dict(
        fbeta=float(fbeta[top]),
        thr=float(cutoffs[top]),
        precision=float(precision[top]),
        recall=float(recall[top]),
    )

def safe_auc(y_true, p):
    """Return ROC-AUC as a float, or ``None`` if ``y_true`` has fewer than two distinct values."""
    n_classes = np.unique(y_true).size
    return None if n_classes < 2 else float(roc_auc_score(y_true, p))

def safe_logloss(y_true, p):
    """Binary log-loss after clipping probabilities to ``[1e-8, 1 - 1e-8]``.

    ``labels=[0, 1]`` is passed explicitly to ``log_loss``.
    """
    clipped = np.asarray(p, dtype=np.float64)
    clipped = np.clip(clipped, 1e-8, 1 - 1e-8)
    return float(log_loss(y_true, clipped, labels=[0, 1]))

# ----------------------------
# 3) Dataset + Standardizer (fit only on train fold => no leakage)
# ----------------------------
class TabDataset(Dataset):
    """Tensor-backed dataset of float32 features with optional int64 labels.

    ``ds[i]`` is the feature row alone when no labels were given, otherwise
    the tuple ``(features, label)``.
    """

    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        if y is None:
            self.y = None
        else:
            self.y = torch.from_numpy(y.astype(np.int64))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        row = self.X[idx]
        if self.y is not None:
            return row, self.y[idx]
        return row

def fit_standardizer(X_tr: np.ndarray):
    """Return float32 ``(mean, std)`` per column of ``X_tr``.

    Both statistics are accumulated in float64. A standard deviation below
    1e-8 is replaced with 1.0 so later division is safe.
    """
    col_mean = np.mean(X_tr, axis=0, dtype=np.float64)
    col_std = np.std(X_tr, axis=0, dtype=np.float64)
    # (near-)constant columns keep a unit scale
    col_std = np.where(col_std < 1e-8, 1.0, col_std)
    return col_mean.astype(np.float32), col_std.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Centre ``X_in`` by ``mu``, scale by ``sig`` and return the result as float32."""
    shifted = X_in - mu
    standardized = shifted / sig
    return standardized.astype(np.float32)

# ----------------------------
# 4) PDF-inspired blocks: RMSNorm + Sinkhorn + mHC-lite on CLS streams
# ----------------------------
class RMSNorm(nn.Module):
    """RMS normalisation over the last dimension with a learnable per-channel gain.

    Args:
        d: size of the last dimension.
        eps: constant added to the mean square before the inverse square root.
    """

    def __init__(self, d, eps=1e-8):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # x: (..., d)
        mean_sq = (x * x).mean(dim=-1, keepdim=True)
        normalized = x * (mean_sq + self.eps).rsqrt()
        return normalized * self.weight

def sinkhorn_knopp(P, tmax=20, eps=1e-12):
    """
    Approximate a doubly-stochastic matrix by Sinkhorn-Knopp iteration.

    P: (B, n, n) non-negative
    tmax: number of row/column normalisation rounds
    eps: lower clamp applied to the entries and to every normalising sum
    output: (B, n, n) doubly-stochastic approximation obtained by
            alternating row and column normalisation

    The function is differentiable: it is deliberately NOT wrapped in
    ``torch.no_grad()``. With gradients disabled the returned matrix is
    detached from ``P``, so whatever module produced ``P`` would never
    receive a training signal through the mixing matrix.
    """
    M = P.clamp_min(eps)
    for _ in range(int(tmax)):
        M = M / (M.sum(dim=-1, keepdim=True).clamp_min(eps))  # row norm
        M = M / (M.sum(dim=-2, keepdim=True).clamp_min(eps))  # col norm
    return M

class MHCLite(nn.Module):
    """
    Lightweight adaptation of mHC applied to CLS-sized vectors only:
    - keeps ``n_streams`` residual streams of shape (B, n, D)
    - maps the current CLS vector (RMSNorm -> MLP) to an n x n matrix of
      logits, scales it by ``alpha``, makes it non-negative with softplus
      and passes it to ``sinkhorn_knopp``
    - mixes the streams with the resulting matrix, then adds the current
      CLS vector to every stream and applies dropout
    """
    def __init__(self, d_model, n_streams=4, alpha=0.01, tmax=20, dropout=0.0):
        super().__init__()
        self.n = int(n_streams)
        self.alpha = float(alpha)
        self.tmax = int(tmax)
        self.dropout = nn.Dropout(float(dropout))

        # CLS -> n*n logits; the input is RMS-normalised first
        self.norm = RMSNorm(d_model, eps=1e-8)
        matrix_size = self.n * self.n
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, matrix_size),
        )
        self.softplus = nn.Softplus()

    def forward(self, streams, cls_vec):
        # streams: (B, n, D)
        # cls_vec: (B, D)
        batch, n_streams, _ = streams.shape

        raw = self.mlp(self.norm(cls_vec)).view(batch, n_streams, n_streams)
        weights = self.softplus(raw * self.alpha)              # non-negative
        mixing = sinkhorn_knopp(weights, tmax=self.tmax)       # (B,n,n)

        blended = torch.einsum("bij,bjd->bid", mixing, streams)
        # the current CLS vector is broadcast onto every stream
        return self.dropout(blended + cls_vec.unsqueeze(1))

class TransformerBlock(nn.Module):
    """
    Pre-norm transformer encoder block built on RMSNorm.

    Attention dropout (inside ``nn.MultiheadAttention``) is configured
    separately from the residual / feed-forward dropout.
    """
    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        p_drop = float(dropout)
        hidden = int(ffn_mult) * d_model

        # attention sub-layer
        self.norm1 = RMSNorm(d_model, eps=1e-8)
        self.attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=int(n_heads),
            dropout=float(attn_dropout),
            batch_first=True,
        )
        self.drop1 = nn.Dropout(p_drop)

        # feed-forward sub-layer
        self.norm2 = RMSNorm(d_model, eps=1e-8)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden, d_model),
        )
        self.drop2 = nn.Dropout(p_drop)

    def forward(self, x):
        # residual 1: x + dropout(self-attention(norm1(x)))
        normed = self.norm1(x)
        attended, _ = self.attn(normed, normed, normed, need_weights=False)
        after_attn = x + self.drop1(attended)

        # residual 2: + dropout(ffn(norm2(.)))
        ffn_out = self.ffn(self.norm2(after_attn))
        return after_attn + self.drop2(ffn_out)

class FTTransformer_MHCLite(nn.Module):
    """
    Numeric FT-Transformer + mHC-lite on CLS between blocks (runtime-safe).

    Every numeric feature becomes one token (value * w + b + feature embedding),
    a learned CLS token is prepended, and ``n_layers`` transformer blocks are
    applied. A bank of ``n_streams`` streams is initialised from CLS; before
    each block the mean of the streams replaces the CLS token, and after each
    block the matching MHCLite layer updates the streams from the new CLS.
    The output logit is computed from the normalised mean of the final streams.

    Args:
        n_features: number of numeric input columns.
        d_model, n_heads, n_layers, ffn_mult, dropout, attn_dropout:
            transformer hyper-parameters forwarded to TransformerBlock.
        n_streams, mhc_alpha, sinkhorn_tmax, mhc_dropout:
            forwarded to MHCLite (as n_streams / alpha / tmax / dropout).
    """
    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4,
                 dropout=0.2, attn_dropout=0.1,
                 n_streams=4, mhc_alpha=0.01, sinkhorn_tmax=20, mhc_dropout=0.0):
        super().__init__()
        self.n_features = int(n_features)
        self.d_model = int(d_model)
        self.n_layers = int(n_layers)

        # per-feature linear tokenization
        self.w = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)
        self.b = nn.Parameter(torch.zeros(self.n_features, self.d_model))
        self.feat_emb = nn.Parameter(torch.randn(self.n_features, self.d_model) * 0.02)

        self.cls = nn.Parameter(torch.randn(1, 1, self.d_model) * 0.02)

        # NOTE: token dropout reuses the attention dropout rate.
        self.token_dropout = nn.Dropout(float(attn_dropout))

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=self.d_model,
                n_heads=n_heads,
                ffn_mult=ffn_mult,
                dropout=dropout,
                attn_dropout=attn_dropout
            ) for _ in range(self.n_layers)
        ])

        # One mHC-lite layer per transformer block.
        self.mhc = nn.ModuleList([
            MHCLite(
                d_model=self.d_model,
                n_streams=n_streams,
                alpha=mhc_alpha,
                tmax=sinkhorn_tmax,
                dropout=mhc_dropout
            ) for _ in range(self.n_layers)
        ])

        self.out_norm = RMSNorm(self.d_model, eps=1e-8)
        self.head = nn.Sequential(
            nn.Linear(self.d_model, self.d_model),
            nn.GELU(),
            nn.Dropout(float(dropout)),
            nn.Linear(self.d_model, 1),
        )

    def forward(self, x):
        """Return one logit per row for ``x`` of shape (B, F)."""
        # x: (B, F) -> tokens (B, F, D)
        tok = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)
        tok = tok + self.feat_emb.unsqueeze(0)
        tok = self.token_dropout(tok)

        B = tok.size(0)
        cls = self.cls.expand(B, -1, -1)   # (B,1,D)
        seq = torch.cat([cls, tok], dim=1) # (B,1+F,D)

        # init streams from CLS
        streams = seq[:, 0:1, :].expand(B, self.mhc[0].n, self.d_model).contiguous()

        for blk, mhc in zip(self.blocks, self.mhc):
            # Inject the mean stream into the CLS slot before the block.
            # This is done out-of-place: the CLS slice of the previous `seq`
            # was handed to the previous mHC layer, and overwriting it in place
            # would invalidate tensors autograd saved for the backward pass.
            cls_in = streams.mean(dim=1, keepdim=True).to(seq.dtype)
            seq = torch.cat([cls_in, seq[:, 1:, :]], dim=1)

            seq = blk(seq)

            # mHC-lite update streams using current CLS
            streams = mhc(streams, seq[:, 0, :])

        out = self.out_norm(streams.mean(dim=1))
        logit = self.head(out).squeeze(-1)
        return logit

# ----------------------------
# 5) LR schedule from PDF: warmup + step decay at ratios [0.8,0.9] with rates [0.316,0.1]
# ----------------------------
def make_warmup_step_scheduler(optimizer, total_steps, warmup_steps, r1=0.8, r2=0.9, d1=0.316, d2=0.1):
    """
    Build a LambdaLR with linear warm-up followed by two cumulative step decays.

    During the first ``warmup_steps`` steps the multiplier ramps as
    (step + 1) / warmup_steps. Afterwards it is 1.0, multiplied by ``d1`` once
    step >= int(r1 * total_steps) and additionally by ``d2`` once
    step >= int(r2 * total_steps).
    """
    first_cut = int(r1 * total_steps)
    second_cut = int(r2 * total_steps)

    def lr_lambda(step):
        in_warmup = warmup_steps > 0 and step < warmup_steps
        if in_warmup:
            return float(step + 1) / float(max(1, warmup_steps))

        factor = 1.0
        for cut, rate in ((first_cut, d1), (second_cut, d2)):
            if step >= cut:
                factor *= float(rate)
        return factor

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 6) Predict helper (FIX: handle batch=(xb,yb) / list / tuple)
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader):
    """
    Run ``model`` in eval mode over ``loader`` and return sigmoid probabilities.

    A batch may be a bare tensor or a list/tuple whose first element is the
    feature tensor (e.g. ``(xb, yb)``); any further elements are ignored.
    Returns one float32 numpy array concatenated along axis 0.
    """
    model.eval()
    chunks = []
    for batch in loader:
        features = batch if not isinstance(batch, (list, tuple)) else batch[0]
        features = features.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            prob = torch.sigmoid(model(features))
        chunks.append(prob.detach().cpu().numpy())
    return np.concatenate(chunks, axis=0).astype(np.float32)

# ----------------------------
# 7) Train one fold (AMP + early stopping)
# ----------------------------
def train_one_fold_transformer(X_tr, y_tr, X_va, y_va, cfg):
    """
    Train one FTTransformer_MHCLite on a single CV fold (AMP + early stopping).

    Features are standardised with statistics fitted on ``X_tr`` only. After
    every epoch the validation log-loss is computed; a snapshot of the weights
    is kept whenever it improves by more than ``cfg["min_delta"]``, and training
    stops after ``cfg["patience"]`` epochs without improvement. The best
    snapshot is restored before the final validation prediction.

    Args:
        X_tr, y_tr: training feature matrix and binary labels.
        X_va, y_va: validation feature matrix and labels.
        cfg: hyper-parameter dict (model, optimiser, LR schedule, early stopping).

    Returns:
        (pack, p_va): ``pack`` holds the CPU ``state_dict``, the standardiser
        (``mu``, ``sig``), a copy of ``cfg``, the 1-based best epoch
        (0 if no epoch ever improved) and the best validation log-loss;
        ``p_va`` are validation probabilities from the restored weights.
    """
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)
    X_van = apply_standardizer(X_va, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    ds_va = TabDataset(X_van, y_va)

    batch_size = int(cfg["batch_size"])
    pin = (device.type == "cuda")
    dl_tr = DataLoader(
        ds_tr, batch_size=batch_size, shuffle=True,
        num_workers=2, pin_memory=pin, drop_last=False
    )
    dl_va = DataLoader(
        ds_va, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=pin, drop_last=False
    )

    model = FTTransformer_MHCLite(
        # Width is taken from the data actually passed in, so the model always
        # matches X_tr instead of depending on a module-level global.
        n_features=int(X_trn.shape[1]),
        d_model=int(cfg["d_model"]),
        n_heads=int(cfg["n_heads"]),
        n_layers=int(cfg["n_layers"]),
        ffn_mult=int(cfg["ffn_mult"]),
        dropout=float(cfg["dropout"]),
        attn_dropout=float(cfg["attn_dropout"]),
        n_streams=int(cfg["n_streams"]),
        mhc_alpha=float(cfg["mhc_alpha"]),
        sinkhorn_tmax=int(cfg["sinkhorn_tmax"]),
        mhc_dropout=float(cfg["mhc_dropout"]),
    ).to(device)

    # imbalance -> pos_weight for BCEWithLogitsLoss (neg/pos, guarding pos == 0)
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = float(neg / max(1, pos))
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    # PDF-like AdamW betas
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(cfg["lr"]),
        weight_decay=float(cfg["weight_decay"]),
        betas=(float(cfg["beta1"]), float(cfg["beta2"])),
        eps=float(cfg["adam_eps"]),
    )

    n_epochs = int(cfg["epochs"])
    total_steps = n_epochs * max(1, len(dl_tr))
    warmup_steps = int(float(cfg["warmup_frac"]) * total_steps)

    # The scheduler is stepped once per optimiser step (not per epoch).
    sch = make_warmup_step_scheduler(
        opt,
        total_steps=total_steps,
        warmup_steps=warmup_steps,
        r1=float(cfg["lr_decay_ratio1"]),
        r2=float(cfg["lr_decay_ratio2"]),
        d1=float(cfg["lr_decay_rate1"]),
        d2=float(cfg["lr_decay_rate2"]),
    )

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    grad_clip = float(cfg["grad_clip"])
    min_delta = float(cfg["min_delta"])
    patience = int(cfg["patience"])

    best_val = 1e9
    best_state = None
    best_epoch = -1
    bad = 0

    for epoch in range(n_epochs):
        model.train()

        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)

            scaler.scale(loss).backward()
            if grad_clip > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
            scaler.step(opt)
            scaler.update()
            sch.step()

        # val
        p_va = predict_proba(model, dl_va)
        vll = safe_logloss(y_va, p_va)

        if (best_val - vll) > min_delta:
            best_val = float(vll)
            best_epoch = int(epoch)
            # clone() is required: on a CPU device `.cpu()` returns the very
            # same storage as the live parameter, so without a copy the
            # "snapshot" would keep changing as training continues.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                break

        gc.collect()

    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    p_va = predict_proba(model, dl_va)

    pack = {
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": dict(cfg),
        "best_epoch": int(best_epoch + 1),
        "best_val_logloss": float(best_val),
    }
    return pack, p_va

# ----------------------------
# 8) CV evaluator for a config (optionally limited folds)
# ----------------------------
def run_cv_config(cfg, cfg_name, folds_subset=None, beta=0.5, thr_grid=201):
    """
    Cross-validate one config, optionally on a subset of folds.

    Each selected fold is used once as validation while all remaining rows
    train the model. Out-of-fold predictions are written into a full-length
    array; rows belonging to folds that were not run stay 0 and are excluded
    from the overall metrics.

    Returns:
        (summary, fold_rows, oof, fold_packs): overall metrics plus the logged
        cfg values, per-fold metric dicts, the OOF prediction array, and the
        trained pack of every fold.
    """
    full_run = folds_subset is None
    use_folds = unique_folds if full_run else list(folds_subset)

    oof = np.zeros(n, dtype=np.float32)
    fold_rows, fold_packs = [], []
    per_fold_grid = max(51, thr_grid//2)   # coarser threshold grid for per-fold stats

    for f in use_folds:
        tr = np.where(folds_all != f)[0]
        va = np.where(folds_all == f)[0]
        y_va = y_all[va]

        pack, p_va = train_one_fold_transformer(X_all[tr], y_all[tr], X_all[va], y_va, cfg)
        oof[va] = p_va

        fold_auc = safe_auc(y_va, p_va)
        fold_ll = safe_logloss(y_va, p_va)
        best_fold = best_fbeta_fast(y_va, p_va, beta=beta, grid=per_fold_grid)

        row = {
            "cfg": cfg_name,
            "fold": int(f),
            "n_val": int(len(va)),
            "pos_val": int(y_va.sum()),
            "auc": fold_auc,
            "logloss": fold_ll,
        }
        row["best_fbeta"] = best_fold["fbeta"]
        row["best_thr"] = best_fold["thr"]
        row["best_prec"] = best_fold["precision"]
        row["best_rec"] = best_fold["recall"]
        row["best_val_logloss"] = float(pack["best_val_logloss"])
        row["best_epoch"] = int(pack["best_epoch"])
        fold_rows.append(row)

        fold_packs.append({"fold": int(f), "pack": pack})

        # Drop the local reference and release cached GPU memory between folds.
        del pack
        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # Overall metrics are computed only on rows whose fold was actually run;
    # for a subset run the remaining OOF entries are still 0.
    if full_run:
        idx_eval = np.arange(n)
        stage = "full"
    else:
        idx_eval = np.where(np.isin(folds_all, np.array(use_folds)))[0]
        stage = f"subset{len(use_folds)}"

    oof_eval = oof[idx_eval]
    y_eval = y_all[idx_eval]

    oof_auc = safe_auc(y_eval, oof_eval)
    oof_ll = safe_logloss(y_eval, oof_eval)
    best_oof = best_fbeta_fast(y_eval, oof_eval, beta=beta, grid=thr_grid)

    summary = {
        "cfg": cfg_name,
        "stage": stage,
        "oof_auc": oof_auc,
        "oof_logloss": oof_ll,
        "oof_best_fbeta": best_oof["fbeta"],
        "oof_best_thr": best_oof["thr"],
        "oof_best_prec": best_oof["precision"],
        "oof_best_rec": best_oof["recall"],
    }

    # Hyper-parameters copied verbatim into the summary (order defines column order).
    logged_keys = (
        # architecture
        "d_model", "n_layers", "n_heads", "ffn_mult", "dropout", "attn_dropout",
        # mHC-lite params
        "n_streams", "mhc_alpha", "sinkhorn_tmax", "mhc_dropout",
        # training params
        "batch_size", "epochs", "lr", "weight_decay", "warmup_frac",
        "beta1", "beta2", "adam_eps",
        "lr_decay_ratio1", "lr_decay_ratio2", "lr_decay_rate1", "lr_decay_rate2",
        "patience", "min_delta", "grad_clip",
    )
    for key in logged_keys:
        summary[key] = cfg[key]

    return summary, fold_rows, oof, fold_packs

# ----------------------------
# 9) Define candidate configs (big but still Kaggle-safe)
# ----------------------------
# Shared defaults; every candidate below overrides a subset of these keys.
BASE = {
    # training
    "batch_size": 512 if device.type=="cuda" else 256,
    "epochs": 70 if device.type=="cuda" else 40,
    "lr": 2e-4,
    "weight_decay": 8e-3,       # tabular usually wants less than 0.1; candidates vary it anyway
    "warmup_frac": 0.10,
    "grad_clip": 1.0,
    "patience": 10,
    "min_delta": 1e-4,

    # optimizer (PDF)
    "beta1": 0.9,
    "beta2": 0.95,
    "adam_eps": 1e-8,           # PDF table shows 1e-20; 1e-8 is safer with float32/AMP on Kaggle

    # LR schedule (PDF)
    "lr_decay_ratio1": 0.8,
    "lr_decay_ratio2": 0.9,
    "lr_decay_rate1": 0.316,
    "lr_decay_rate2": 0.1,

    # mHC-lite (PDF)
    "n_streams": 4,
    "mhc_alpha": 0.01,
    "sinkhorn_tmax": 20,
    "mhc_dropout": 0.0,
}

candidates = [
    # Strong default (stable)
    ("mhc_384x8", {
        **BASE,
        "d_model": 384, "n_layers": 8, "n_heads": 8, "ffn_mult": 4,
        "dropout": 0.20, "attn_dropout": 0.10,
    }),
    # Stronger (slightly slower; more regularised)
    ("mhc_384x10_reg", {
        **BASE,
        "d_model": 384, "n_layers": 10, "n_heads": 8, "ffn_mult": 4,
        "dropout": 0.25, "attn_dropout": 0.12,
        "lr": 1.6e-4, "weight_decay": 1.2e-2,
        "epochs": 85 if device.type=="cuda" else 50, "patience": 12,
    }),
    # Smaller FFN (sometimes more robust to noise)
    ("mhc_384x8_ffn2", {
        **BASE,
        "d_model": 384, "n_layers": 8, "n_heads": 8, "ffn_mult": 2,
        "dropout": 0.18, "attn_dropout": 0.10,
        "lr": 2.2e-4, "weight_decay": 6e-3,
    }),
    # Faster (as a baseline for comparison)
    ("mhc_256x6_fast", {
        **BASE,
        "d_model": 256, "n_layers": 6, "n_heads": 8, "ffn_mult": 4,
        "dropout": 0.18, "attn_dropout": 0.08,
        "lr": 3e-4, "weight_decay": 4e-3,
        "epochs": 60 if device.type=="cuda" else 35, "patience": 9,
    }),
]

# Extreme variant: only added on CUDA, skipped automatically on CPU.
if device.type == "cuda":
    candidates.append(("mhc_512x10_big", {
        **BASE,
        "d_model": 512, "n_layers": 10, "n_heads": 16, "ffn_mult": 4,
        "dropout": 0.28, "attn_dropout": 0.15,
        "lr": 1.2e-4, "weight_decay": 2e-2, "epochs": 90, "patience": 12,
        "mhc_dropout": 0.05,
    }))

print(f"\nTotal Transformer candidates: {len(candidates)}")
print("Primary score: OOF best F-beta (beta=0.5)")

# ----------------------------
# 10) Run 2-stage search (runtime-safe)
# ----------------------------
# Output locations for the search artifacts.
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OPT_DIR = OUT_DIR / "opt_search"
OPT_DIR.mkdir(parents=True, exist_ok=True)

# pick subset folds evenly spaced for stage-1
if STAGE1_FOLDS >= len(unique_folds):
    folds_subset = unique_folds
else:
    stepk = max(1, len(unique_folds) // STAGE1_FOLDS)
    folds_subset = unique_folds[::stepk][:STAGE1_FOLDS]

print("\nStage-1 folds subset:", folds_subset)

# ---- Stage 1: cheap ranking of every candidate on the fold subset ----------
t0 = time.time()   # the time budget covers stage 1 and stage 2 together
stage1_rows = []
stage1_fold_rows = []
stage1_oof_store = {}

for i, (name, cfg) in enumerate(candidates, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop search.")
        break

    # Stage 1 runs a capped copy of the config (fewer epochs, shorter patience).
    cfg1 = dict(cfg)
    cfg1["epochs"] = int(min(int(cfg1["epochs"]), int(STAGE1_EPOCH_CAP)))
    cfg1["patience"] = int(min(int(cfg1["patience"]), int(STAGE1_PAT_CAP)))

    print(f"\n[Stage-1 {i:02d}/{len(candidates)}] CV(subset) -> {name}")
    summ, fold_rows, oof, _ = run_cv_config(cfg1, name, folds_subset=folds_subset, beta=BETA, thr_grid=101)

    stage1_rows.append(summ)
    stage1_fold_rows.extend(fold_rows)
    stage1_oof_store[name] = oof

    print(f"  stage1 best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f}"
          f" | logloss: {summ['oof_logloss']:.6f}")

# An empty frame has no metric columns, so sorting it would fail with an
# opaque KeyError; report the real cause instead.
if not stage1_rows:
    raise RuntimeError("Stage-1 produced no results. Check the candidate list and TIME_BUDGET_SEC.")

df_s1 = pd.DataFrame(stage1_rows).sort_values(["oof_best_fbeta","oof_logloss"], ascending=[False, True]).reset_index(drop=True)
print("\nStage-1 ranking (top):")
display(df_s1.head(10))

# ---- Stage 2: full CV for the top-M stage-1 configs ------------------------
topM = min(int(STAGE2_TOPM), len(df_s1))
stage2_names = df_s1["cfg"].head(topM).tolist()
print("\nStage-2 will run full CV for:", stage2_names)

all_summaries = []
all_fold_rows = []
oof_store = {}

for j, nm in enumerate(stage2_names, 1):
    if TIME_BUDGET_SEC and (time.time() - t0) > TIME_BUDGET_SEC:
        print("Time budget reached. Stop stage-2.")
        break

    # First candidate with a matching name (uncapped, original config).
    cfg = next((ccfg for nname, ccfg in candidates if nname == nm), None)
    if cfg is None:
        continue

    print(f"\n[Stage-2 {j:02d}/{len(stage2_names)}] CV(full) -> {nm}")
    summ, fold_rows, oof, _ = run_cv_config(cfg, nm, folds_subset=None, beta=BETA, thr_grid=THR_GRID)

    all_summaries.append(summ)
    all_fold_rows.extend(fold_rows)
    oof_store[nm] = oof

    print(f"  OOF best_fbeta: {summ['oof_best_fbeta']:.6f} | thr: {summ['oof_best_thr']:.3f}"
          f" | auc: {(summ['oof_auc'] if summ['oof_auc'] is not None else float('nan')):.6f}"
          f" | logloss: {summ['oof_logloss']:.6f}")

# Stage 1 can consume the whole time budget; in that case there is nothing to
# rank, and sort_values()/df_sum["cfg"] below would raise KeyError on the
# empty frame before any helpful message is shown.
if not all_summaries:
    raise RuntimeError(
        "Stage-2 produced no results (time budget exhausted before any full-CV run?). "
        "Increase TIME_BUDGET_SEC or reduce candidates/epochs."
    )

df_sum = pd.DataFrame(all_summaries)
df_fold = pd.DataFrame(all_fold_rows)

df_sum = df_sum.sort_values(["oof_best_fbeta", "oof_logloss"], ascending=[False, True]).reset_index(drop=True)

print("\nStage-2 top candidates (full CV):")
display(df_sum)

# save search results
df_sum.to_csv(OPT_DIR / "opt_results.csv", index=False)
with open(OPT_DIR / "opt_results.json", "w") as f:
    json.dump(df_sum.to_dict(orient="records"), f, indent=2)
df_fold.to_csv(OPT_DIR / "opt_fold_details.csv", index=False)

# save OOF preds for top configs (debugging)
top_names = df_sum["cfg"].head(min(REPORT_TOPK_OOF, len(df_sum))).tolist()
for nm in top_names:
    df_o = pd.DataFrame({
        "uid": uids_all,
        "y": y_all,
        "fold": folds_all,
        f"oof_pred_{nm}": oof_store[nm]
    })
    df_o.to_csv(OPT_DIR / f"oof_preds_{nm}.csv", index=False)

# ----------------------------
# 11) Choose best config + retrain fold packs for best (full CV)
# ----------------------------
# Nothing to choose from if stage 2 recorded no summary rows.
if len(df_sum) == 0:
    raise RuntimeError("Stage-2 produced no results. Cek device/VRAM atau turunkan kandidat/epochs.")

# df_sum is already sorted best-first, so row 0 is the winner.
best_single = df_sum.iloc[0].to_dict()
best_cfg_name = best_single["cfg"]

# Recover the full (uncapped) config dict for the winning name.
best_cfg = None
for nm, cfg in candidates:
    if nm == best_cfg_name:
        best_cfg = cfg
        break
if best_cfg is None:
    raise RuntimeError("Best cfg not found in candidates list (unexpected).")

print("\nBest config:", best_cfg_name)
print(best_single)

# The stage-2 loop discarded its fold packs, so every fold is trained again
# here to obtain the per-fold weights and a fresh OOF vector.
print(f"\nRe-train folds for best config -> {best_cfg_name}")
best_fold_packs = []
best_oof = np.zeros(n, dtype=np.float32)

for f in unique_folds:
    tr = np.where(folds_all != f)[0]
    va = np.where(folds_all == f)[0]

    X_tr, y_tr = X_all[tr], y_all[tr]
    X_va, y_va = X_all[va], y_all[va]

    pack, p_va = train_one_fold_transformer(X_tr, y_tr, X_va, y_va, best_cfg)
    pack["fold"] = int(f)
    best_fold_packs.append(pack)
    best_oof[va] = p_va

    # Free memory between folds.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# Persist all fold packs (weights + standardiser + cfg) in a single file.
best_model_path = OUT_DIR / "best_gate_model.pt"
torch.save(
    {
        "type": "mhc_lite_ft_transformer",
        "feature_cols": FEATURE_COLS,
        "fold_packs": best_fold_packs,
        "cfg_name": best_cfg_name,
        "cfg": best_cfg,
        "seed": SEED,
    },
    best_model_path
)

# Threshold / F-beta are re-tuned on the OOF predictions of this re-train.
best_oof_best = best_fbeta_fast(y_all, best_oof, beta=BETA, grid=THR_GRID)

best_bundle = {
    # structured so that Step 5 can consume it easily:
    "type": "mhc_lite_ft_transformer",
    "model_name": best_cfg_name,
    "members": [best_cfg_name],
    # NOTE(review): the seed is stored as "random_seed" here but as "seed" in
    # the .pt file above — readers of this bundle must use the right key.
    "random_seed": SEED,
    "beta_for_tuning": BETA,

    "feature_cols": FEATURE_COLS,
    "cfg": best_cfg,

    "oof_best_thr": best_oof_best["thr"],
    "oof_best_fbeta": best_oof_best["fbeta"],
    "oof_best_prec": best_oof_best["precision"],
    "oof_best_rec": best_oof_best["recall"],
    "oof_auc": safe_auc(y_all, best_oof),
    "oof_logloss": safe_logloss(y_all, best_oof),

    "notes": "Best config from Step 4 (Transformer-only, mHC-lite PDF-inspired). Step 5 should train FULL model for inference.",
}

with open(OUT_DIR / "best_gate_config.json", "w") as f:
    json.dump(best_bundle, f, indent=2)

print("\nSaved best artifacts:")
print("  best model (fold packs) ->", best_model_path)
print("  best config             ->", OUT_DIR / "best_gate_config.json")
print("  opt results             ->", OPT_DIR / "opt_results.csv")
print("  fold detail             ->", OPT_DIR / "opt_fold_details.csv")

# Export globals for Step 5
BEST_GATE_BUNDLE = best_bundle
BEST_TF_CFG_NAME = best_cfg_name
BEST_TF_CFG = best_cfg
OPT_RESULTS_DF = df_sum


# Final Training (Train on Full Data)

In [None]:
# ============================================================
# Step 5 — Final Training (Train on Full Data) — TRANSFORMER ONLY (REVISI FULL v3)
# - FIX: predict_proba aman untuk batch=(xb,yb)/list/tuple
# - mHC-style training: AdamW betas(0.9,0.95) + Warmup + Step decay (0.8/0.9)
# - Auto "besar tapi muat" berdasarkan VRAM (safe untuk runtime Kaggle)
#
# Output:
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_model.pt
#   /kaggle/working/recodai_luc_gate_artifacts/final_gate_bundle.json
# ============================================================

import os, json, gc, math, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.metrics import log_loss, roc_auc_score

# ----------------------------
# 0) REQUIRE
# ----------------------------
# Fail early when the notebook globals this step depends on are absent
# (the error text points the user back to Step 2).
if "df_train_tabular" not in globals():
    raise RuntimeError("Missing `df_train_tabular`. Jalankan Step 2 dulu.")
if "FEATURE_COLS" not in globals():
    raise RuntimeError("Missing `FEATURE_COLS`. Jalankan Step 2 dulu.")

# Work on private copies so this step does not mutate the upstream objects.
df_train_tabular = df_train_tabular.copy()
FEATURE_COLS = list(FEATURE_COLS)

# Columns that must be present in the training table.
need_cols = {"uid","case_id","fold","y"}
miss = [c for c in need_cols if c not in df_train_tabular.columns]
if miss:
    raise ValueError(f"df_train_tabular missing columns: {miss}")

# Dense float32 feature matrix and int64 label vector.
X_all = df_train_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y_all = df_train_tabular["y"].to_numpy(dtype=np.int64, copy=True)

# Replace NaN and +/-inf with 0.0 so training never sees non-finite inputs.
if not np.isfinite(X_all).all():
    X_all = np.nan_to_num(X_all, nan=0.0, posinf=0.0, neginf=0.0)

OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("Final training data:")
print(f"  rows={len(y_all)} | pos%={float(y_all.mean())*100:.2f} | n_features={X_all.shape[1]}")

# ----------------------------
# 1) Load best cfg (optional)
#    - kompatibel: format lama transformer_ft, atau bundle apapun yang punya ["cfg"]
# ----------------------------
# Locate the best-config bundle: prefer the in-memory object left by the
# search step, otherwise fall back to the JSON file it wrote to disk.
cfg_path = OUT_DIR / "best_gate_config.json"
best_bundle, source = None, None

if isinstance(globals().get("BEST_GATE_BUNDLE"), dict):
    best_bundle, source = BEST_GATE_BUNDLE, "memory(BEST_GATE_BUNDLE)"
elif cfg_path.exists():
    best_bundle = json.loads(cfg_path.read_text())
    source = str(cfg_path)

# Only a bundle that is a dict and carries a dict under "cfg" is usable.
if not (isinstance(best_bundle, dict) and isinstance(best_bundle.get("cfg"), dict)):
    base_cfg = {}
    print("\nNo best_gate_config found. Using strong default cfg.")
else:
    base_cfg = dict(best_bundle["cfg"])
    print("\nLoaded cfg from:", source)

# ----------------------------
# 2) Device + seed
# ----------------------------
# The search step stores the seed under "random_seed" in best_gate_config.json /
# BEST_GATE_BUNDLE (only the .pt pack uses "seed"), so both keys are checked;
# reading "seed" alone always fell through to the 2025 default.
FINAL_SEED = int(best_bundle.get("seed", best_bundle.get("random_seed", 2025))) if isinstance(best_bundle, dict) else 2025

def seed_everything(seed: int):
    """
    Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices), and
    export PYTHONHASHSEED.

    cuDNN is left in its fast mode (``benchmark=True``, ``deterministic=False``),
    so runs are seeded but GPU kernels are not forced to be deterministic.
    """
    import random

    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = False
    cudnn.benchmark = True

seed_everything(FINAL_SEED)

# Prefer the GPU when one is visible; mixed precision is enabled only on CUDA.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
use_amp = device.type == "cuda"

# matmul perf (pytorch 2.x) — best-effort: any failure is deliberately ignored.
try:
    torch.set_float32_matmul_precision("high")
except Exception:
    pass

# Total memory of GPU 0 in GiB, or None when running on CPU.
vram_gb = (torch.cuda.get_device_properties(0).total_memory / (1024**3)) if device.type == "cuda" else None

print("\nDevice:", device, "| AMP:", use_amp, "| VRAM_GB:", ("CPU" if not vram_gb else f"{vram_gb:.1f}"))

# ----------------------------
# 3) Runtime-safe "big model" policy
#    - kamu minta model besar tersimpan => default ON jika CUDA
#    - kalau CPU: otomatis dikecilkan supaya runtime tidak meledak
# ----------------------------
# Policy switches for the final training run.
WANT_BIG_MODEL = True
USE_INTERNAL_VAL = True         # hold out an internal validation set: safer early stopping
VAL_FRAC_CASE = 0.08            # 8% case-level validation (stable enough without wasting much data)
EARLY_STOP = True

# multi-seed = stronger but slower; keep 1 to fit the Kaggle runtime
N_SEEDS = 1

# default cfg used when base_cfg is empty
CFG = dict(
    d_model=384,
    n_layers=8,
    n_heads=8,
    ffn_mult=4,
    dropout=0.20,
    attn_dropout=0.10,

    batch_size=512,
    epochs=60,              # enough; early stopping will cut it short
    lr=2e-4,
    weight_decay=1e-2,      # closer to the AdamW spirit (paper), but not as aggressive as 0.1
    warmup_frac=0.10,
    grad_clip=1.0,

    patience=10,
    min_delta=1e-4,
)

# merge base_cfg -> CFG: only keys already present in CFG are taken over,
# any other key in base_cfg is ignored.
for k,v in base_cfg.items():
    if k in CFG:
        CFG[k] = v

# auto upscale (big) jika GPU
def autoscale_cfg_for_device(cfg: dict):
    """
    Return a copy of ``cfg`` adapted to the current device.

    * CPU: model size, batch size and epochs are capped so training stays feasible.
    * CUDA with WANT_BIG_MODEL off, or unknown VRAM: the copy is returned unchanged.
    * CUDA otherwise: a VRAM tier (>=24 GB, >=16 GB, below) supplies minimums
      for size/regularisation/epochs/patience and maximums for lr and batch size.

    The input dict is never modified.
    """
    scaled = dict(cfg)

    if device.type != "cuda":
        # CPU fallback: (key, default if missing, upper cap)
        cpu_caps = (
            ("d_model", 384, 256),
            ("n_layers", 8, 6),
            ("n_heads", 8, 8),
            ("batch_size", 512, 256),
            ("epochs", 60, 35),
        )
        for key, default, cap in cpu_caps:
            scaled[key] = min(int(scaled.get(key, default)), cap)
        scaled["lr"] = float(scaled.get("lr", 2e-4))
        scaled["weight_decay"] = float(scaled.get("weight_decay", 1e-2))
        return scaled

    # CUDA: nothing to do when upscaling is disabled or VRAM is unknown.
    if not WANT_BIG_MODEL or vram_gb is None:
        return scaled

    # Tier tuple: (d_model, n_layers, n_heads, ffn_mult, dropout, attn_dropout,
    #              lr cap, weight_decay, batch cap, epochs, patience)
    if vram_gb >= 24:
        tier = (512, 12, 16, 4, 0.25, 0.12, 1.5e-4, 2e-2, 512, 70, 12)
    elif vram_gb >= 16:
        tier = (448, 10, 8, 4, 0.24, 0.12, 1.8e-4, 1.5e-2, 512, 65, 11)
    else:
        # 8-12 GB (common on Kaggle); the smaller batch cap guards against OOM
        tier = (384, 8, 8, 4, 0.22, 0.10, 2e-4, 1e-2, 384, 60, 10)

    (dm, nl, nh, fm, dr, adr, lr_cap, wd, bs_cap, ep, pat) = tier
    scaled.update({
        "d_model": max(int(scaled["d_model"]), dm),
        "n_layers": max(int(scaled["n_layers"]), nl),
        "n_heads": max(int(scaled["n_heads"]), nh),
        "ffn_mult": max(int(scaled["ffn_mult"]), fm),
        "dropout": max(float(scaled["dropout"]), dr),
        "attn_dropout": max(float(scaled["attn_dropout"]), adr),
        "lr": min(float(scaled["lr"]), lr_cap),
        "weight_decay": max(float(scaled["weight_decay"]), wd),
        "batch_size": min(int(scaled["batch_size"]), bs_cap),
        "epochs": max(int(scaled["epochs"]), ep),
        "patience": max(int(scaled["patience"]), pat),
    })
    return scaled

# Apply the device-dependent scaling to the merged config.
CFG = autoscale_cfg_for_device(CFG)

# mHC paper: AdamW betas (0.9,0.95); small eps (still numerically safe)
ADAM_BETAS = (0.9, 0.95)
ADAM_EPS   = 1e-8  # the paper uses a much smaller eps; 1e-8 is numerically safer for tabular data

# Step decay ratios (paper): 0.8 & 0.9 epoch fractions
STEP_RATIOS = (0.8, 0.9)
STEP_GAMMAS = (0.316, 0.1)

print("\nCFG (final):")
for k in ["d_model","n_layers","n_heads","ffn_mult","dropout","attn_dropout","batch_size","epochs","lr","weight_decay","warmup_frac","patience"]:
    print(f"  {k}: {CFG[k]}")

# effective batch via grad accumulation (keeps training stable even when the
# per-step batch was reduced): number of micro-batches needed to reach the target
TARGET_EFF_BATCH = 1024 if device.type=="cuda" else 256
GRAD_ACCUM = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

print(f"\nGradAccum: {GRAD_ACCUM} (target_eff_batch={TARGET_EFF_BATCH})")

# ----------------------------
# 4) Dataset + Standardizer
# ----------------------------
class TabDataset(Dataset):
    """
    In-memory tabular dataset.

    Features are copied to a float32 tensor; labels, when given, to an int64
    tensor. Indexing yields ``(x, y)`` with labels and just ``x`` without.
    """

    def __init__(self, X, y=None):
        self.X = torch.from_numpy(X.astype(np.float32))
        if y is None:
            self.y = None
        else:
            self.y = torch.from_numpy(y.astype(np.int64))

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        row = self.X[idx]
        return row if self.y is None else (row, self.y[idx])

def fit_standardizer(X_tr: np.ndarray):
    """
    Compute per-column mean and standard deviation of ``X_tr``.

    Statistics are accumulated in float64 and returned as float32. Columns
    whose standard deviation is below 1e-8 get a scale of 1.0 so that later
    division cannot blow up.
    """
    mean64 = np.mean(X_tr, axis=0, dtype=np.float64)
    std64 = np.std(X_tr, axis=0, dtype=np.float64)
    std64 = np.where(std64 < 1e-8, 1.0, std64)
    return mean64.astype(np.float32), std64.astype(np.float32)

def apply_standardizer(X_in: np.ndarray, mu: np.ndarray, sig: np.ndarray):
    """Return ``(X_in - mu) / sig`` as a float32 array."""
    centered = X_in - mu
    scaled = centered / sig
    return scaled.astype(np.float32)

# ----------------------------
# 5) Model (FT-Transformer numeric) + DeepNorm-like scaling (mHC spirit)
# ----------------------------
class RMSNorm(nn.Module):
    """
    Root-mean-square normalisation over the last dimension with a learned
    per-channel gain (``scale``, initialised to ones). No mean subtraction and
    no bias are applied.
    """

    def __init__(self, d, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(d))

    def forward(self, x):
        # x: (..., d); eps is added to the mean square before the square root.
        mean_sq = torch.mean(torch.pow(x, 2), dim=-1, keepdim=True)
        rms = torch.sqrt(mean_sq + self.eps)
        normed = x / rms
        return normed * self.scale

class TransformerBlock(nn.Module):
    """
    Pre-norm transformer block whose residual branch is scaled by a constant.

    Each sub-layer computes ``x * alpha + dropout(sublayer(RMSNorm(x)))`` with
    ``alpha = deepnorm_alpha`` (1.0 gives a plain residual connection).

    NOTE(review): the residual stream itself is never re-normalised inside the
    block, so with alpha > 1 its magnitude is multiplied by alpha twice per
    block — confirm this growth is intended.
    """
    def __init__(self, d_model, n_heads, ffn_mult=4, dropout=0.2, attn_dropout=0.1, deepnorm_alpha=1.0):
        super().__init__()
        self.alpha = float(deepnorm_alpha)
        width = int(ffn_mult * d_model)

        # Attention sub-layer.
        self.norm1 = RMSNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, dropout=attn_dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)

        # Feed-forward sub-layer.
        self.norm2 = RMSNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(width, d_model),
        )
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x):
        """Apply both sub-layers with the scaled residual connection."""
        normed = self.norm1(x)
        attended = self.attn(normed, normed, normed, need_weights=False)[0]
        x = (x * self.alpha) + self.drop1(attended)

        transformed = self.ffn(self.norm2(x))
        return (x * self.alpha) + self.drop2(transformed)

class FTTransformerBig(nn.Module):
    """
    Numeric-only FT-Transformer for binary classification.

    Every input column becomes one token (value * w + b + feature embedding),
    a learned CLS token is prepended, ``n_layers`` TransformerBlocks are
    applied, and the normalised CLS output is mapped to a single logit.
    All blocks share the residual scale ``(2 * n_layers) ** 0.25``.
    """
    def __init__(self, n_features, d_model=384, n_heads=8, n_layers=8, ffn_mult=4, dropout=0.2, attn_dropout=0.1):
        super().__init__()
        self.n_features = n_features
        self.d_model = d_model

        # per-feature tokenization (numeric-only); creation order is kept
        # stable because it fixes the order in which the RNG is consumed.
        self.w = nn.Parameter(0.02 * torch.randn(n_features, d_model))
        self.b = nn.Parameter(torch.zeros(n_features, d_model))
        self.feat_emb = nn.Parameter(0.02 * torch.randn(n_features, d_model))

        self.cls = nn.Parameter(0.02 * torch.randn(1, 1, d_model))

        # Residual scale passed to every block: (2L)^(1/4).
        alpha = (2.0 * n_layers) ** 0.25

        # NOTE: token dropout reuses the attention dropout rate.
        self.token_dropout = nn.Dropout(attn_dropout)
        self.blocks = nn.ModuleList()
        for _ in range(n_layers):
            self.blocks.append(
                TransformerBlock(
                    d_model,
                    n_heads,
                    ffn_mult=ffn_mult,
                    dropout=dropout,
                    attn_dropout=attn_dropout,
                    deepnorm_alpha=alpha,
                )
            )
        self.final_norm = RMSNorm(d_model)

        self.head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

    def forward(self, x):
        """Return one logit per row for ``x`` of shape (B, F)."""
        # (B,F) -> (B,F,D): scale each feature's weight vector by its value.
        tokens = x.unsqueeze(-1) * self.w.unsqueeze(0) + self.b.unsqueeze(0)
        tokens = self.token_dropout(tokens + self.feat_emb.unsqueeze(0))

        # Prepend the CLS token -> (B,1+F,D).
        cls_tok = self.cls.expand(tokens.size(0), -1, -1)
        seq = torch.cat([cls_tok, tokens], dim=1)

        for block in self.blocks:
            seq = block(seq)

        # Classify from the CLS position only.
        pooled = self.final_norm(seq[:, 0])
        return self.head(pooled).squeeze(-1)

# ----------------------------
# 6) Schedulers (Warmup + Step decay @ 0.8/0.9 epochs)
# ----------------------------
def make_warmup_then_step_scheduler(optimizer, total_steps, warmup_steps, step_milestones, gammas):
    """Build a LambdaLR with linear warmup followed by cumulative step decay.

    While ``step < warmup_steps`` the multiplier is ``(step + 1) / warmup_steps``.
    Afterwards it is the product of every gamma whose milestone has been
    reached (``step >= milestone``).  ``total_steps`` is accepted for call-site
    symmetry but is not used by the schedule itself.

    Returns:
        torch.optim.lr_scheduler.LambdaLR wrapping ``optimizer``.
    """
    # Materialise the pairs once so generators/iterators can be passed safely.
    decay_plan = list(zip(list(step_milestones), list(gammas)))

    def lr_lambda(step):
        if step >= warmup_steps:
            factor = 1.0
            for boundary, gamma in decay_plan:
                if boundary <= step:
                    factor *= gamma
            return factor
        # Warmup phase: ramp linearly up to the base learning rate.
        return float(step + 1) / float(max(1, warmup_steps))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

# ----------------------------
# 7) Predict helper (anti tuple-batch error)
# ----------------------------
@torch.no_grad()
def predict_proba(model, loader):
    """Return sigmoid probabilities for every row yielded by ``loader``.

    Batches may be bare tensors or ``(x, ...)`` tuples/lists; only the first
    element is fed to the model.  The model is switched to eval mode and left
    there.  Uses the module-level ``device`` and ``use_amp`` settings.

    Returns:
        1-D float32 numpy array; empty when the loader yields no batches.
    """
    model.eval()
    chunks = []
    for batch in loader:
        xb = batch[0] if isinstance(batch, (list, tuple)) else batch
        xb = xb.to(device, non_blocking=True)
        with torch.cuda.amp.autocast(enabled=use_amp):
            logits = model(xb)
        # Under autocast the logits can be float16; take the sigmoid in float32
        # so the probabilities (and any log-loss computed from them) are not
        # quantised to half precision.
        probs = torch.sigmoid(logits.float())
        chunks.append(probs.detach().cpu().numpy())

    if not chunks:
        # np.concatenate raises on an empty list; an empty loader means no rows.
        return np.zeros((0,), dtype=np.float32)
    return np.concatenate(chunks, axis=0).astype(np.float32)

def safe_logloss(y_true, p):
    """Binary log-loss with probabilities clipped away from 0 and 1.

    Clipping to [1e-8, 1 - 1e-8] keeps the logarithm finite; ``labels=[0, 1]``
    is passed explicitly to ``log_loss``.
    """
    probs = np.asarray(p, dtype=np.float64)
    probs = np.clip(probs, 1e-8, 1-1e-8)
    score = log_loss(y_true, probs, labels=[0,1])
    return float(score)

def safe_auc(y_true, p):
    """ROC-AUC as a float, or ``None`` when ``y_true`` has fewer than two classes."""
    n_classes = len(np.unique(y_true))
    return float(roc_auc_score(y_true, p)) if n_classes >= 2 else None

# ----------------------------
# 8) Internal val split (group-safe by case_id)
# ----------------------------
def make_case_split(df: pd.DataFrame, val_frac=0.08, seed=2025):
    """Return a boolean row mask selecting a group-safe validation subset.

    Rows are grouped by ``case_id``; a case is positive when the maximum of its
    ``y`` values is 1 and negative when it is 0.  Positive and negative cases
    are shuffled separately with ``RandomState(seed)`` and the first
    ``max(1, int(count * val_frac))`` of each non-empty class are held out, so
    all rows of one case always land on the same side of the split.

    Returns:
        numpy bool array aligned with the rows of ``df`` (True = validation).
    """
    case_table = (
        df.groupby("case_id")["y"]
        .max()
        .reset_index()
        .rename(columns={"y": "case_y"})
    )
    positive_rows = case_table["case_y"] == 1
    negative_rows = case_table["case_y"] == 0
    pos_cases = case_table.loc[positive_rows, "case_id"].to_numpy()
    neg_cases = case_table.loc[negative_rows, "case_id"].to_numpy()

    # Shuffle order (positives first, then negatives) determines the split.
    rng = np.random.RandomState(seed)
    rng.shuffle(pos_cases)
    rng.shuffle(neg_cases)

    def _holdout_count(n_cases):
        # At least one case per class is held out whenever the class exists.
        if not n_cases:
            return 0
        return max(1, int(n_cases * val_frac))

    held_out = np.concatenate([
        pos_cases[:_holdout_count(len(pos_cases))],
        neg_cases[:_holdout_count(len(neg_cases))],
    ])
    held_out_ids = {int(case) for case in held_out.tolist()}

    row_case_ids = df["case_id"].astype(int)
    return row_case_ids.isin(held_out_ids).to_numpy(dtype=bool)

# ----------------------------
# 9) Train once (full-data with optional internal val)
# ----------------------------
def train_full_once(seed_offset=0):
    """Train one FTTransformerBig gate model and return a serialisable pack.

    Configuration and data come from module-level globals (CFG, FINAL_SEED,
    USE_INTERNAL_VAL, VAL_FRAC_CASE, EARLY_STOP, GRAD_ACCUM, ADAM_BETAS,
    ADAM_EPS, STEP_RATIOS, STEP_GAMMAS, X_all, y_all, df_train_tabular,
    device, use_amp).

    Args:
        seed_offset: Added to FINAL_SEED for seeding and for the internal
            case-level validation split.

    Returns:
        dict holding the model ``state_dict`` (CPU tensors), the standardizer
        statistics ``mu`` / ``sig``, a copy of CFG and training metadata.
        When an internal validation split is used, the stored weights are
        those of the epoch with the lowest validation log-loss.
    """
    run_seed = FINAL_SEED + seed_offset
    seed_everything(run_seed)

    # One optimizer update consumes `accum` micro-batches.
    accum = max(1, int(GRAD_ACCUM))
    n_epochs = int(CFG["epochs"])
    grad_clip = float(CFG.get("grad_clip", 1.0))

    if USE_INTERNAL_VAL:
        is_val = make_case_split(df_train_tabular, val_frac=float(VAL_FRAC_CASE), seed=run_seed)
        tr_idx = np.where(~is_val)[0]
        va_idx = np.where(is_val)[0]
        X_tr, y_tr = X_all[tr_idx], y_all[tr_idx]
        X_va, y_va = X_all[va_idx], y_all[va_idx]
        print(f"  internal split: train={len(tr_idx)} | val={len(va_idx)} | val_pos%={float(y_va.mean())*100:.2f}")
    else:
        X_tr, y_tr = X_all, y_all
        X_va = y_va = None

    # Standardizer statistics are fitted on the training rows only.
    mu, sig = fit_standardizer(X_tr)
    X_trn = apply_standardizer(X_tr, mu, sig)

    ds_tr = TabDataset(X_trn, y_tr)
    dl_tr = DataLoader(
        ds_tr,
        batch_size=int(CFG["batch_size"]),
        shuffle=True,
        num_workers=2,
        pin_memory=(device.type=="cuda"),
        drop_last=False
    )

    if USE_INTERNAL_VAL:
        X_van = apply_standardizer(X_va, mu, sig)
        ds_va = TabDataset(X_van, y_va)
        dl_va = DataLoader(
            ds_va,
            batch_size=int(CFG["batch_size"]),
            shuffle=False,
            num_workers=2,
            pin_memory=(device.type=="cuda"),
            drop_last=False
        )
    else:
        dl_va = None

    model = FTTransformerBig(
        n_features=X_all.shape[1],
        d_model=int(CFG["d_model"]),
        n_heads=int(CFG["n_heads"]),
        n_layers=int(CFG["n_layers"]),
        ffn_mult=int(CFG["ffn_mult"]),
        dropout=float(CFG["dropout"]),
        attn_dropout=float(CFG["attn_dropout"]),
    ).to(device)

    # Class-imbalance weight computed from the TRAIN split only.
    pos = int(y_tr.sum())
    neg = int(len(y_tr) - pos)
    pos_weight = float(neg / max(1, pos))
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    opt = torch.optim.AdamW(
        model.parameters(),
        lr=float(CFG["lr"]),
        weight_decay=float(CFG["weight_decay"]),
        betas=ADAM_BETAS,
        eps=ADAM_EPS,
    )

    # The scheduler is stepped once per OPTIMIZER update (see the loop below),
    # so its horizon must be counted in optimizer updates, not micro-batches.
    # Counting micro-batches would stretch the warmup and leave the decay
    # milestones unreachable whenever gradient accumulation is > 1.
    n_batches = len(dl_tr)
    steps_per_epoch = max(1, (n_batches + accum - 1) // accum)
    total_steps = n_epochs * steps_per_epoch
    warmup_steps = int(float(CFG["warmup_frac"]) * total_steps)

    ms1 = int(STEP_RATIOS[0] * total_steps)
    ms2 = int(STEP_RATIOS[1] * total_steps)
    sch = make_warmup_then_step_scheduler(opt, total_steps, warmup_steps, [ms1, ms2], list(STEP_GAMMAS))

    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    best_val = 1e18
    best_state = None
    best_epoch = -1
    bad = 0  # consecutive epochs without a sufficient validation improvement

    t0 = time.time()
    for epoch in range(n_epochs):
        model.train()
        loss_sum = 0.0
        n_sum = 0

        opt.zero_grad(set_to_none=True)
        for it, (xb, yb) in enumerate(dl_tr, 1):
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True).float()

            with torch.cuda.amp.autocast(enabled=use_amp):
                logits = model(xb)
                loss = loss_fn(logits, yb)
                # Scale so the accumulated gradient averages over the group.
                loss = loss / float(accum)

            scaler.scale(loss).backward()

            # Update after every `accum` micro-batches and on the last batch.
            if (it % accum) == 0 or it == n_batches:
                if grad_clip > 0:
                    # Gradients must be unscaled before clipping by norm.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)

                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sch.step()

            # Undo the accumulation scaling for the reported training loss.
            loss_sum += float(loss.item()) * xb.size(0) * float(accum)
            n_sum += xb.size(0)

        tr_loss = loss_sum / max(1, n_sum)

        if dl_va is not None:
            p_va = predict_proba(model, dl_va)
            vll = safe_logloss(y_va, p_va)
            vauc = safe_auc(y_va, p_va)

            improved = (best_val - vll) > float(CFG["min_delta"])
            if improved:
                best_val = float(vll)
                best_epoch = int(epoch)
                # copy=True forces a real snapshot: a plain .cpu() returns the
                # live parameter storage when the model already sits on CPU,
                # which later updates would silently overwrite.
                best_state = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}
                bad = 0
            else:
                bad += 1

            print(f"  epoch {epoch+1:03d}/{int(CFG['epochs'])} | tr_loss={tr_loss:.5f} | val_ll={vll:.5f} | val_auc={(vauc if vauc is not None else float('nan')):.5f} | bad={bad}")
            if EARLY_STOP and bad >= int(CFG["patience"]):
                print(f"  early stop at epoch {epoch+1}, best_epoch={best_epoch+1}, best_val_ll={best_val:.5f}")
                break
        else:
            print(f"  epoch {epoch+1:03d}/{int(CFG['epochs'])} | tr_loss={tr_loss:.5f}")

        gc.collect()

    # Restore the best validation weights (no-op without internal validation).
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    pack = {
        "type": "ft_transformer_big_full",
        "state_dict": {k: v.detach().cpu() for k, v in model.state_dict().items()},
        "mu": mu,
        "sig": sig,
        "cfg": dict(CFG),
        "seed": int(run_seed),
        "pos_weight": float(pos_weight),
        "train_rows": int(len(y_tr)),
        "val_rows": int(len(y_va)) if USE_INTERNAL_VAL else 0,
        "best_epoch": int(best_epoch + 1) if best_epoch >= 0 else None,
        "best_val_logloss": float(best_val) if best_state is not None else None,
        "train_time_s": float(time.time() - t0),
        "grad_accum": int(accum),
        "adam_betas": list(ADAM_BETAS),
        "step_ratios": list(STEP_RATIOS),
        "step_gammas": list(STEP_GAMMAS),
    }
    return pack

# ----------------------------
# 10) Train final model(s) with OOM-safe fallback
# ----------------------------
# Train one model per seed offset.  A CUDA out-of-memory failure triggers a
# single retry with a smaller configuration; any other error propagates.
final_packs = []
for s in range(int(N_SEEDS)):
    print(f"\n[Final Train] seed_offset={s}")
    oom_hit = False
    try:
        pack = train_full_once(seed_offset=s)
    except RuntimeError as e:
        if ("out of memory" in str(e).lower()) and device.type == "cuda":
            oom_hit = True
        else:
            raise

    if oom_hit:
        # The retry runs OUTSIDE the except block on purpose: while the
        # exception is alive its traceback keeps the failed run's frames (and
        # their GPU tensors) referenced, so the cache could not be released.
        print("  OOM detected. Applying fallback: batch_size -> half, d_model/layers -> downshift")
        gc.collect()
        torch.cuda.empty_cache()

        # Shrink towards the floors but never ABOVE the current value
        # (a bare max(floor, ...) would enlarge an already-small config).
        cur_bs = int(CFG["batch_size"])
        cur_dm = int(CFG["d_model"])
        cur_nl = int(CFG["n_layers"])
        CFG["batch_size"] = min(cur_bs, max(128, cur_bs // 2))
        CFG["d_model"] = min(cur_dm, max(256, cur_dm - 64))
        CFG["n_layers"] = min(cur_nl, max(6, cur_nl - 2))

        # Keep the effective batch size near the target via accumulation.
        TARGET_EFF_BATCH = 768
        GRAD_ACCUM = max(1, int(math.ceil(TARGET_EFF_BATCH / int(CFG["batch_size"]))))

        print("  New CFG:")
        for k in ["d_model","n_layers","n_heads","batch_size","epochs"]:
            print(f"    {k}: {CFG[k]}")
        print(f"  New GradAccum: {GRAD_ACCUM}")

        pack = train_full_once(seed_offset=s)

    final_packs.append(pack)
    gc.collect()

# ----------------------------
# 11) Save artifacts
# ----------------------------
# Persist the trained pack(s) together with the feature order they expect.
final_model_path = OUT_DIR / "final_gate_model.pt"
_model_payload = {
    "feature_cols": FEATURE_COLS,
    "packs": final_packs,  # always a list, even for a single seed
    "bundle_source": source,
}
torch.save(_model_payload, final_model_path)

# JSON sidecar describing the run; key order matches the saved file layout.
final_bundle = dict(
    type="ft_transformer_big_full",
    n_seeds=int(N_SEEDS),
    seeds=[int(p["seed"]) for p in final_packs],
    feature_cols=FEATURE_COLS,
    cfg=dict(CFG),
    use_internal_val=bool(USE_INTERNAL_VAL),
    # Validation settings are only meaningful when the internal split is on.
    val_frac_case=(float(VAL_FRAC_CASE) if USE_INTERNAL_VAL else 0.0),
    early_stop=bool(USE_INTERNAL_VAL and EARLY_STOP),
    train_rows=int(len(y_all)),
    pos_rate=float(y_all.mean()),
    notes="Final model = BIG tabular transformer head over DINOv2-Large-derived features. Save this for inference loading.",
    ref_best_bundle=(best_bundle.get("selection", {}) if isinstance(best_bundle, dict) else {}),
)

final_bundle_path = OUT_DIR / "final_gate_bundle.json"
final_bundle_path.write_text(json.dumps(final_bundle, indent=2))

print("\nSaved final training artifacts:")
print("  model  ->", final_model_path)
print("  bundle ->", final_bundle_path)

# Globals exported for later notebook steps.
FINAL_GATE_MODEL_PT = str(final_model_path)
FINAL_GATE_BUNDLE = final_bundle


# Finalize & Save Model Bundle (Reproducible)

In [None]:
# ============================================================
# Step 6 — Finalize & Save Model Bundle (Reproducible) — REVISI FULL v3 (TRANSFORMER COMPAT)
# - Fokus: bundle artefak penting (Transformer .pt) + threshold placeholder + manifest + ZIP portable
# - Tidak ada submission di sini
#
# REQUIRE:
# - Step 2: feature_cols.json
# - Step 5: final_gate_model.pt + final_gate_bundle.json
# ============================================================

import os, json, time, platform, warnings, zipfile
from pathlib import Path

# Suppress every warning category from this point on.
warnings.filterwarnings("ignore")

# Directory holding the gate-model artifacts; created (with parents) if missing.
OUT_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Helpers
# ----------------------------
def read_json_safe(p: Path, default=None):
    """Parse the JSON file at ``p``; return ``default`` on any failure.

    Best-effort by design: a missing file, unreadable content, invalid JSON or
    an unusable ``p`` all yield ``default`` instead of raising.
    """
    try:
        raw_text = p.read_text()
        parsed = json.loads(raw_text)
    except Exception:
        return default
    else:
        return parsed

def pick_first_existing(paths):
    """Return the first entry of ``paths`` that is an existing regular file.

    ``None`` entries are skipped and every other entry is converted to
    ``Path``.  Returns ``None`` when nothing qualifies.
    """
    candidates = (Path(entry) for entry in paths if entry is not None)
    return next(
        (cand for cand in candidates if cand.exists() and cand.is_file()),
        None,
    )

def safe_add(zf: zipfile.ZipFile, p: Path, arcname: str):
    """Write ``p`` into ``zf`` under ``arcname`` if it is an existing file.

    ``None``, missing paths and non-files are skipped silently.
    """
    if p is None:
        return
    source = Path(p)
    if not (source.exists() and source.is_file()):
        return
    zf.write(source, arcname=arcname)

# ----------------------------
# 0) Locate required artifacts
# ----------------------------
# Required inputs produced by earlier steps.
final_model_pt = OUT_DIR / "final_gate_model.pt"
final_bundle_path = OUT_DIR / "final_gate_bundle.json"
feature_cols_path = OUT_DIR / "feature_cols.json"

# Fail fast when a required artifact is missing (feature list first).
if not feature_cols_path.exists():
    raise FileNotFoundError(f"Missing feature_cols: {feature_cols_path} (jalankan Step 2 dulu)")
if not final_model_pt.exists():
    raise FileNotFoundError(f"Missing final model: {final_model_pt} (jalankan Step 5 dulu)")

# Legacy joblib location, kept only for compatibility; the torch .pt file is
# the model that gets bundled.
final_model_joblib = OUT_DIR / "final_gate_model.joblib"
model_format = "torch_pt"
final_model_path = final_model_pt

# Optional artifacts: each resolves to a Path when present, otherwise None.
# Baseline report candidates are tried in order (most likely name first,
# legacy name last).
baseline_report_path = pick_first_existing([
    OUT_DIR / "baseline_transformer_cv_report.json",
    OUT_DIR / "baseline_transformer_cv_report_v2.json",
    OUT_DIR / "baseline_cv_report.json",
])
opt_config_path = pick_first_existing([OUT_DIR / "best_gate_config.json"])  # from Step 4
opt_results_csv = pick_first_existing([OUT_DIR / "opt_search" / "opt_results.csv"])
opt_fold_csv = pick_first_existing([OUT_DIR / "opt_search" / "opt_fold_details.csv"])
oof_tf_baseline_csv = pick_first_existing([OUT_DIR / "oof_baseline_transformer.csv"])
# Often absent; its absence must never be treated as an error.
oof_baseline_csv = pick_first_existing([OUT_DIR / "oof_baseline.csv"])

print("Found artifacts:")
print("  final_model  :", final_model_path, f"(format={model_format})")
print("  final_bundle :", final_bundle_path if final_bundle_path.exists() else "(missing/skip)")
print("  feature_cols :", feature_cols_path)
print("  baseline_report :", baseline_report_path or "(missing/skip)")
print("  best_gate_config :", opt_config_path or "(missing/skip)")

# ----------------------------
# 1) Load metadata
# ----------------------------
# The feature list is mandatory: it must be a non-empty JSON list.
feature_cols = read_json_safe(feature_cols_path, default=[])
if not (isinstance(feature_cols, list) and len(feature_cols) > 0):
    raise ValueError("feature_cols.json invalid / empty")

# Optional metadata falls back to an empty dict / None when unavailable.
final_bundle = {}
if final_bundle_path.exists():
    final_bundle = read_json_safe(final_bundle_path, default={})

baseline_report = None
if baseline_report_path:
    baseline_report = read_json_safe(baseline_report_path, default=None)

opt_config = None
if opt_config_path:
    opt_config = read_json_safe(opt_config_path, default=None)

# ----------------------------
# 2) Threshold placeholders (Transformer)
# - PRIORITY:
#   (a) thresholds.json (kalau sudah ada)
#   (b) best_gate_config.json -> selection.oof_best_thr (Step 4)
#   (c) fallback 0.5
# ----------------------------
thresholds_path = OUT_DIR / "thresholds.json"

if not thresholds_path.exists():
    # No thresholds file yet: derive a placeholder gate threshold.
    T_gate = None

    # Preferred source: the Step 4 selection stored in best_gate_config.json.
    if isinstance(opt_config, dict):
        sel = opt_config.get("selection", {})
        T_gate = sel.get("oof_best_thr", None) if isinstance(sel, dict) else None

    # Last resort: neutral 0.5 cut-off.
    T_gate = 0.5 if T_gate is None else T_gate

    _guard_defaults = {
        "min_area_frac": None,
        "max_area_frac": None,
        "max_components": None
    }
    thresholds = {
        "T_gate": float(T_gate),
        "beta_for_tuning": 0.5,
        "guards": _guard_defaults,
        "notes": "Placeholder. Update after calibration/threshold tuning on OOF or validation set."
    }
    thresholds_path.write_text(json.dumps(thresholds, indent=2))
else:
    # An existing thresholds.json always wins over derived values.
    thresholds = read_json_safe(thresholds_path, default={})

# ----------------------------
# 3) Capture dataset/cfg metadata (if available)
# ----------------------------
# Copy a fixed set of entries out of the PATHS global (when an earlier step
# defined it as a dict); keys that are absent are recorded as None.
cfg_meta = {}
if "PATHS" in globals() and isinstance(PATHS, dict):
    _cfg_meta_keys = (
        "COMP_ROOT",
        "OUT_DS_ROOT",
        "OUT_ROOT",
        "MATCH_CFG_DIR",
        "PRED_CFG_DIR",
        "DINO_CFG_DIR",
        "DINO_LARGE_DIR",
        "PRED_FEAT_TRAIN",
        "MATCH_FEAT_TRAIN",
        "DF_TRAIN_ALL",
        "CV_CASE_FOLDS",
        "IMG_PROFILE_TRAIN",
    )
    cfg_meta = {_key: PATHS.get(_key, None) for _key in _cfg_meta_keys}

# ----------------------------
# 4) Manifest (reproducible)
# ----------------------------
task_str = "Recod.ai/LUC — Gate Model (authentic vs forged) — DINOv2 features + Transformer gate (.pt)"


def _path_str(path_or_none):
    """Stringify an optional artifact path; missing artifacts stay None."""
    return str(path_or_none) if path_or_none else None


# Non-dict metadata is treated as "nothing known": every lookup yields None.
_bundle_view = final_bundle if isinstance(final_bundle, dict) else {}
_baseline_overall = baseline_report.get("overall") if isinstance(baseline_report, dict) else None
_opt_selection = opt_config.get("selection") if isinstance(opt_config, dict) else None

_artifact_index = {
    "final_model": str(final_model_path),
    "final_bundle": str(final_bundle_path) if final_bundle_path.exists() else None,
    "feature_cols": str(feature_cols_path),
    "thresholds": str(thresholds_path),
    "baseline_report": _path_str(baseline_report_path),
    "best_gate_config": _path_str(opt_config_path),
    "opt_results_csv": _path_str(opt_results_csv),
    "opt_fold_details_csv": _path_str(opt_fold_csv),
    "oof_baseline_csv": _path_str(oof_baseline_csv),
    "oof_baseline_transformer_csv": _path_str(oof_tf_baseline_csv),
}

_model_summary = {
    "type": _bundle_view.get("type"),
    "n_seeds": _bundle_view.get("n_seeds"),
    "seeds": _bundle_view.get("seeds"),
    "train_rows": _bundle_view.get("train_rows"),
    "pos_rate": _bundle_view.get("pos_rate"),
    "feature_count": int(len(feature_cols)),
    "T_gate": float(thresholds.get("T_gate", 0.5)),
}

# Manifest: environment, artifact locations and summaries, for reproducibility.
manifest = {
    "created_utc": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
    "python": platform.python_version(),
    "platform": platform.platform(),
    "bundle_version": "v3",
    "task": task_str,
    "model_format": model_format,
    "artifacts": _artifact_index,
    "cfg_meta": cfg_meta,
    "model_summary": _model_summary,
    "baseline_summary": _baseline_overall,
    "opt_summary": _opt_selection,
}

manifest_path = OUT_DIR / "model_bundle_manifest.json"
manifest_path.write_text(json.dumps(manifest, indent=2))

# ----------------------------
# 5) Bundle pack (JSON portable) + optional joblib
# ----------------------------
# Single portable structure carrying all metadata needed to reload the model.
bundle_pack = dict(
    model_format=model_format,
    final_model_path=str(final_model_path),
    final_bundle=final_bundle,
    feature_cols=feature_cols,
    thresholds=thresholds,
    cfg_meta=cfg_meta,
    manifest=manifest,
)

bundle_pack_json = OUT_DIR / "model_bundle_pack.json"
bundle_pack_json.write_text(json.dumps(bundle_pack, indent=2))

# The joblib copy is optional: any import or dump failure just disables it.
bundle_pack_joblib = OUT_DIR / "model_bundle_pack.joblib"
try:
    import joblib
    joblib.dump(bundle_pack, bundle_pack_joblib)
except Exception:
    joblib_ok = False
else:
    joblib_ok = True

# ----------------------------
# 6) Create portable ZIP
# ----------------------------
zip_path = OUT_DIR / "model_bundle_v3.zip"

# Build the (source, archive name) list first, in the order entries are
# written; safe_add skips anything that is not an existing file.
_core_files = [
    final_model_path,
    final_bundle_path,
    feature_cols_path,
    thresholds_path,
    manifest_path,
    bundle_pack_json,
]
if joblib_ok:
    _core_files.append(bundle_pack_joblib)
_zip_entries = [(_core, _core.name) for _core in _core_files]

# Optional extras keep their own file names at the archive root.
_zip_entries += [
    (_extra, _extra.name)
    for _extra in (baseline_report_path, opt_config_path)
    if _extra
]
# Search results are stored under an opt_search/ prefix.
_zip_entries += [
    (_csv, f"opt_search/{_csv.name}")
    for _csv in (opt_results_csv, opt_fold_csv)
    if _csv
]
_zip_entries += [
    (_oof, _oof.name)
    for _oof in (oof_baseline_csv, oof_tf_baseline_csv)
    if _oof
]

with zipfile.ZipFile(zip_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
    for _zip_src, _zip_arc in _zip_entries:
        safe_add(zf, _zip_src, _zip_arc)

_joblib_note = bundle_pack_joblib if joblib_ok else "(skip; joblib not available)"

print("\nOK — Model bundle finalized")
print("  manifest     ->", manifest_path)
print("  pack (json)  ->", bundle_pack_json)
print("  pack (joblib)->", _joblib_note)
print("  thresholds   ->", thresholds_path)
print("  zip          ->", zip_path)

print("\nBundle summary:")
print("  model_format :", model_format)
print("  feature_cnt  :", len(feature_cols))
print("  T_gate       :", thresholds.get("T_gate"))
print("  task         :", task_str)