 # Kaggle CPU Environment Setup

In [None]:
# ============================================================
# STAGE 0 — Environment + Paths + Health Checks (ONE CELL)
# REVISI FULL v8.1 (BRUTAL-LB READY + CONSISTENT SIGNAL POLICY)
#
# v8.1 changes vs v8.0:
# - SIGNAL policy made consistent with the recommendation: flux-safe (asinh_flux) + snr_tanh (not mag_ulim)
# - Added key policies for the later stages:
#     * DELTA_POLICY="per_band" (avoid cross-band deltas)
#     * USE_REST_FRAME_TIME=True (t/(1+z))
#     * WINDOW_POLICY="peak_centered" + MULTI_WINDOW_K (TTA windows)
#     * Calibration & threshold robustness knobs
# - Health report now includes a short summary of the test Z_err statistics (domain shift)
# ============================================================

import os, sys, gc, json, time, random, hashlib, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# Repro / run identity
# ----------------------------
# Single seed shared by Python's `random` and NumPy's global RNG.
SEED = 2025
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes — confirm.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)

# ----------------------------
# Device detect
# ----------------------------
# torch is optional: if the import or the CUDA probe raises, the cell
# continues in CPU-only mode with `torch = None`.
try:
    import torch
    TORCH_OK = True
    _cuda_ok = torch.cuda.is_available()
except Exception:
    torch = None
    TORCH_OK = False
    _cuda_ok = False

# "cuda" only when torch imported and reports an available CUDA device.
DEVICE = "cuda" if _cuda_ok else "cpu"

# ----------------------------
# Thread policy (anti-freeze, tapi brutal mode boleh lebih tinggi)
# ----------------------------
def _pick_threads(device: str) -> int:
    """Return the thread budget for *device*.

    A "cuda" device always gets 2 threads. Any other device gets half of the
    reported CPU count (4 CPUs are assumed when the count is unavailable),
    clamped to the range [2, 10].
    """
    if device != "cuda":
        n_cpu = os.cpu_count() or 4
        half = n_cpu // 2
        # Slightly more generous for feature engineering, but still capped.
        return int(max(2, min(10, half)))
    return 2

THREADS = _pick_threads(DEVICE)
# setdefault keeps any value that is already present in the environment.
# NOTE(review): numpy (and torch, when available) are already imported at this
# point, so their BLAS/OpenMP pools may have been sized before these variables
# are set — confirm they still take effect.
for k in ["OMP_NUM_THREADS","OPENBLAS_NUM_THREADS","MKL_NUM_THREADS","VECLIB_MAXIMUM_THREADS","NUMEXPR_NUM_THREADS"]:
    os.environ.setdefault(k, str(THREADS))
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

if TORCH_OK:
    torch.manual_seed(SEED)
    try:
        torch.set_num_threads(THREADS)
        torch.set_num_interop_threads(1)
    except Exception:
        # Best-effort: a failure to tune torch threading is ignored.
        pass

# ----------------------------
# Helpers
# ----------------------------
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

def _must_exist(p: Path, what: str):
    """Raise FileNotFoundError naming *what* unless path *p* exists."""
    if p.exists():
        return
    raise FileNotFoundError(f"[MISSING] {what}: {p}")

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of *df* whose column labels have surrounding whitespace stripped.

    The input frame is left untouched.
    """
    cleaned = df.copy()
    cleaned.columns = [label.strip() for label in cleaned.columns]
    return cleaned

def _normalize_split(x):
    """Normalise a raw split label to the canonical ``split_NN`` form.

    Missing values (anything ``pd.isna`` reports) and blank strings map to "".
    Pure decimal labels ("7") and labels of the form ``split<sep>N`` — with or
    without a "_", "-" or space separator, in any letter case — map to
    ``split_NN`` with at least two digits. Any other label is returned
    lower-cased with "-" and " " replaced by "_", so the caller's validity
    check can report it.
    """
    if pd.isna(x):
        return ""
    s = str(x).strip()
    if not s:
        return ""
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as "²" that int() cannot parse.
    if s.isdecimal():
        return f"split_{int(s):02d}"
    s = s.lower().replace("-", "_").replace(" ", "_")
    if s.startswith("split"):
        # Accept both "split_3" and "split3", matching Stage 1's normaliser.
        tail = s[len("split"):].strip("_")
        if tail.isdecimal():
            return f"split_{int(tail):02d}"
    return s

def _discover_data_root(default_root: Path) -> Path:
    """Locate the dataset folder, preferring *default_root* when it exists.

    Otherwise each direct sub-directory of /kaggle/input is inspected and
    accepted when it contains:
    - train_log.csv, test_log.csv and sample_submission.csv
    - a first split folder (split_01 or split_1) and split_20

    With several matches, folders that have ``split_01`` win, ties being
    broken by folder name. *default_root* is returned when /kaggle/input is
    absent or nothing matches.
    """
    kaggle_input = Path("/kaggle/input")
    if default_root.exists() or not kaggle_input.exists():
        return default_root

    required_files = ("train_log.csv", "test_log.csv", "sample_submission.csv")

    def _qualifies(folder: Path) -> bool:
        if not folder.is_dir():
            return False
        if not all((folder / name).exists() for name in required_files):
            return False
        has_first = (folder / "split_01").exists() or (folder / "split_1").exists()
        return has_first and (folder / "split_20").exists()

    matches = [folder for folder in kaggle_input.iterdir() if _qualifies(folder)]
    if not matches:
        return default_root
    # Smallest key first: zero-padded layout before the rest, then by name.
    return min(matches, key=lambda x: (not (x / "split_01").exists(), x.name))

def _hash_cfg(d: dict) -> str:
    """Return the first 10 hex characters of the SHA-256 of *d* as key-sorted ASCII JSON."""
    canonical = json.dumps(d, sort_keys=True, ensure_ascii=True).encode("utf-8")
    digest = hashlib.sha256(canonical)
    return digest.hexdigest()[:10]

def _safe_float(x, default=0.0):
    """Convert *x* to a finite float, or return ``float(default)``.

    The fallback is used when the conversion raises or when the converted
    value is NaN or infinite.
    """
    try:
        candidate = float(x)
    except Exception:
        candidate = None
    if candidate is not None and np.isfinite(candidate):
        return candidate
    return float(default)

def _sample_ids_per_split(df_log: pd.DataFrame, split_name: str, n: int, seed: int) -> list:
    """Return up to *n* object ids (as ``str``) from the rows of one split.

    When the split holds at most *n* rows, all of its ids are returned in
    frame order. Otherwise *n* ids are drawn with ``Series.sample`` using
    *seed* as ``random_state``. A split with no rows yields an empty list.
    """
    in_split = df_log["split"] == split_name
    ids = df_log.loc[in_split, "object_id"].astype(str)
    if len(ids) == 0:
        return []
    if len(ids) > n:
        ids = ids.sample(n=n, random_state=seed)
    return ids.tolist()

def _scan_lightcurve_for_ids(csv_path: Path, target_ids: set, chunk_rows: int = 200_000):
    """Stream a lightcurve CSV and record which of *target_ids* occur in it.

    Only the ``object_id`` and ``Filter`` columns are read, in chunks of
    *chunk_rows* rows. Ids are compared after stripping whitespace; filter
    names are stripped and lower-cased.

    Returns a tuple ``(found_ids, obs_count_by_id, band_mask_by_id)``:
    - found_ids: the target ids seen at least once
    - obs_count_by_id: rows seen per target id (0 for ids never seen)
    - band_mask_by_id: bitmask of the bands seen per id, bits 0..5 = u,g,r,i,z,y

    The scan stops at the end of the first chunk after which every target id
    has been seen, so counts and masks only cover the rows read up to there.

    Raises ValueError when a chunk lacks ``object_id`` or ``Filter``.
    """
    bit_of = {name: 1 << pos for pos, name in enumerate("ugrizy")}
    seen = set()
    n_obs = dict.fromkeys(target_ids, 0)
    bands = dict.fromkeys(target_ids, 0)
    n_wanted = len(target_ids)

    reader = pd.read_csv(
        csv_path,
        usecols=["object_id", "Filter"],
        dtype={"object_id": "string"},
        chunksize=chunk_rows,
        **SAFE_READ_KW
    )

    for raw in reader:
        part = _norm_cols(raw)
        if not {"object_id", "Filter"}.issubset(part.columns):
            raise ValueError(f"[BAD LIGHTCURVE COLUMNS] {csv_path.name}: {list(part.columns)}")

        band_names = part["Filter"].astype("string").str.strip().str.lower()
        ids = part["object_id"].astype("string").str.strip()

        hit = ids.isin(target_ids)
        if hit.any():
            hit_ids = ids[hit].astype(str).values
            hit_bands = band_names[hit].astype(str).values
            for obj, band in zip(hit_ids, hit_bands):
                seen.add(obj)
                n_obs[obj] = n_obs.get(obj, 0) + 1
                # Unknown filter names contribute no bit.
                bands[obj] = bands.get(obj, 0) | bit_of.get(band, 0)

            # Early exit once every requested id has been observed.
            if len(seen) == n_wanted:
                break

    return seen, n_obs, bands

def _full_scan_lightcurve_object_ids(csv_path: Path, chunk_rows: int = 200_000):
    """Return the set of all distinct, whitespace-stripped object ids in a lightcurve CSV.

    The file is streamed in chunks of *chunk_rows* rows, reading only the
    ``object_id`` column; missing ids are dropped.
    """
    reader = pd.read_csv(
        csv_path,
        usecols=["object_id"],
        dtype={"object_id": "string"},
        chunksize=chunk_rows,
        **SAFE_READ_KW
    )
    all_ids = set()
    for part in reader:
        part = _norm_cols(part)
        present = part["object_id"].astype("string").str.strip().dropna()
        all_ids.update(present.astype(str).unique().tolist())
    return all_ids

# ----------------------------
# PATHS (auto-discovery)
# ----------------------------
# Preferred dataset location; _discover_data_root falls back to scanning
# /kaggle/input when this folder does not exist.
DEFAULT_DATA_ROOT = Path("/kaggle/input/mallorn-dataset")
DATA_ROOT = _discover_data_root(DEFAULT_DATA_ROOT)

# Every input location, derived from DATA_ROOT. SPLITS always lists the
# zero-padded folders split_01..split_20.
PATHS = {
    "DATA_ROOT": DATA_ROOT,
    "SAMPLE_SUB": DATA_ROOT / "sample_submission.csv",
    "TRAIN_LOG":  DATA_ROOT / "train_log.csv",
    "TEST_LOG":   DATA_ROOT / "test_log.csv",
    "SPLITS":     [DATA_ROOT / f"split_{i:02d}" for i in range(1, 21)],
}

# ----------------------------
# WORKDIR (versioned run)
# ----------------------------
# Parent folder under which each run creates its own sub-directory.
WORKDIR = Path("/kaggle/working")
BASE_RUN_DIR = WORKDIR / "mallorn_run"
BASE_RUN_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# CFG — Brutal LB defaults (selaras dengan saran)
# ----------------------------
# Central configuration shared by all stages. Its hash (CFG_HASH) becomes part
# of the run directory name, so any change here produces a new run folder.
CFG = {
    # Run intent
    "RUN_MODE": "brutal_lb",
    "DEVICE_PREFERRED": "auto",             # auto/cpu/cuda

    # Multi-seed (ensemble)
    "SEEDS": [2025, 2026, 2027],
    "CV_REPEATS": 1,

    # Model switches
    "USE_GBDT": True,                       # strong tabular baseline
    "USE_TRANSFORMER": True,
    "USE_HYBRID_BLEND": True,
    "USE_STACKING": False,

    # F1 essentials
    "USE_THRESHOLD_TUNING": True,
    "THR_STRATEGY": "per_fold_median",      # per_fold_median / global_best
    "USE_PROBA_CALIBRATION": True,
    "CALIBRATION_METHOD": "temperature",    # temperature / platt / isotonic

    # Photometry / signal policy (KEY settings)
    "USE_DEEXTINCTION": True,
    # SIGNAL_CHANNELS: kept flux-safe (not mag_ulim) so that Stage 6 scoring and Stage 8 aggregation stay consistent
    "SIGNAL_CHANNELS": ["asinh_flux", "snr_tanh"],

    # flux-safe transform details (used by stages 4/5)
    "FLUX_SCALE_POLICY": "per_object_median_fluxerr",  # global_constant / per_object_median_fluxerr
    "GLOBAL_FLUX_SCALE": 1.0,            # used when the policy is global_constant
    "MIN_FLUXERR": 1e-6,

    # SNR handling
    "SNR_CLIP": 30.0,
    "SNR_DET_THR_LIST": [2.0, 3.0, 4.0],
    "SNR_STRONG_THR": 5.0,
    "SNR_TANH_K": 10.0,                  # snr_tanh = tanh(snr / k)

    # Time policy (physics)
    "USE_REST_FRAME_TIME": True,         # t_rest = t/(1+z)
    "REST_FRAME_EPS": 1e-6,

    # Delta policy (avoid cross-band noise)
    "DELTA_POLICY": "per_band",          # per_band / global (do NOT use global)

    # IO
    "CHUNK_ROWS": 200_000,

    # CV
    "N_FOLDS": 10,
    "CV_STRATIFY": True,
    "CV_USE_SPLIT_COL": True,
    "CV_FORCE_POS_EACH_FOLD": True,

    # Sequence length + windowing (key for the transformer)
    "MAX_LEN_LIST": [384, 512],
    "WINDOW_POLICY": "peak_centered",    # peak_centered / best_contiguous / multi_window
    "TRAIN_RANDOM_CROP": True,           # augmentation: crop around the peak
    "MULTI_WINDOW_K": 3,                 # TTA: 2-3 windows at inference (if the CPU can afford it)
    "PEAK_SCORE": "snr_pos",             # snr_pos / abs_signal

    # Training defaults
    "DEEP_EPOCHS": 25,
    "DEEP_BS": 128,
    "DEEP_LR": 3e-4,
    "DEEP_WEIGHT_DECAY": 0.02,
    "DEEP_POS_WEIGHT_MODE": "auto",
    "DEEP_USE_EMA": True,

    # Stage0 validation
    "STAGE0_LC_VALIDATE_MODE": "sample", # off/sample/full
    "STAGE0_LC_SAMPLE_PER_SPLIT": 80,
    "STAGE0_FAIL_FAST_MISSING_RATE": 0.01,
}

# Run identity: wall-clock timestamp plus the short config hash, so every
# execution gets its own folder.
CFG_HASH = _hash_cfg(CFG)
RUN_TAG = time.strftime("%Y%m%d_%H%M%S")
RUN_DIR = BASE_RUN_DIR / f"run_{RUN_TAG}_{CFG_HASH}"

# Sub-folders of the run directory, all created up-front below.
ART_DIR   = RUN_DIR / "artifacts"
CACHE_DIR = RUN_DIR / "cache"
OOF_DIR   = RUN_DIR / "oof"
SUB_DIR   = RUN_DIR / "submissions"
LOG_DIR   = RUN_DIR / "logs"
FEAT_DIR  = CACHE_DIR / "features"
SEQ_DIR   = CACHE_DIR / "seq"

for d in [RUN_DIR, ART_DIR, CACHE_DIR, OOF_DIR, SUB_DIR, LOG_DIR, FEAT_DIR, SEQ_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Validate existence (files + split folders)
# ----------------------------
# Each check raises FileNotFoundError on the first missing input.
_must_exist(PATHS["SAMPLE_SUB"], "sample_submission.csv")
_must_exist(PATHS["TRAIN_LOG"],  "train_log.csv")
_must_exist(PATHS["TEST_LOG"],   "test_log.csv")

# Every split folder must provide both a train and a test lightcurve file.
for sd in PATHS["SPLITS"]:
    tr = sd / "train_full_lightcurves.csv"
    te = sd / "test_full_lightcurves.csv"
    _must_exist(tr, f"{sd.name}/train_full_lightcurves.csv")
    _must_exist(te, f"{sd.name}/test_full_lightcurves.csv")

# ----------------------------
# Load logs + sample
# ----------------------------
# Load the sample submission and check its schema.
df_sub = _norm_cols(pd.read_csv(PATHS["SAMPLE_SUB"], dtype={"object_id": "string"}, **SAFE_READ_KW))
if not {"object_id", "prediction"}.issubset(df_sub.columns):
    raise ValueError(f"sample_submission must have object_id,prediction. Found: {list(df_sub.columns)}")
# Normalise ids once, right after loading, so that the duplicate checks,
# SUB_ORDER, the OID2SPLIT_* maps and the lightcurve validation (which strips
# the ids it reads from the CSVs) all compare the same key form.
df_sub["object_id"] = df_sub["object_id"].astype("string").str.strip()
if df_sub["object_id"].duplicated().any():
    raise ValueError(f"Duplicated object_id in sample_submission: {int(df_sub['object_id'].duplicated().sum())}")

df_train_log = _norm_cols(pd.read_csv(PATHS["TRAIN_LOG"], dtype={"object_id":"string","split":"string"}, **SAFE_READ_KW))
df_test_log  = _norm_cols(pd.read_csv(PATHS["TEST_LOG"],  dtype={"object_id":"string","split":"string"}, **SAFE_READ_KW))

# Required metadata columns.
need_train = {"object_id","EBV","Z","split","target"}
need_test  = {"object_id","EBV","Z","split"}
missing_train = sorted(list(need_train - set(df_train_log.columns)))
missing_test  = sorted(list(need_test - set(df_test_log.columns)))
if missing_train:
    raise ValueError(f"train_log missing: {missing_train}")
if missing_test:
    raise ValueError(f"test_log missing: {missing_test}")

# Same id normalisation as for the sample submission (see above).
df_train_log["object_id"] = df_train_log["object_id"].astype("string").str.strip()
df_test_log["object_id"]  = df_test_log["object_id"].astype("string").str.strip()

# Canonicalise split labels to split_NN and reject anything outside 01..20.
df_train_log["split"] = df_train_log["split"].map(_normalize_split)
df_test_log["split"]  = df_test_log["split"].map(_normalize_split)

valid_splits = {f"split_{i:02d}" for i in range(1, 21)}
bad_tr = sorted(set(df_train_log["split"]) - valid_splits)
bad_te = sorted(set(df_test_log["split"]) - valid_splits)
if bad_tr:
    raise ValueError(f"Invalid split in train_log (examples): {bad_tr[:10]}")
if bad_te:
    raise ValueError(f"Invalid split in test_log  (examples): {bad_te[:10]}")

# Numeric coercion; unparsable values become NaN.
for col in ["EBV", "Z"]:
    df_train_log[col] = pd.to_numeric(df_train_log[col], errors="coerce")
    df_test_log[col]  = pd.to_numeric(df_test_log[col],  errors="coerce")

# Record missingness before imputing.
df_train_log["EBV_missing"] = df_train_log["EBV"].isna().astype("int8")
df_train_log["Z_missing"]   = df_train_log["Z"].isna().astype("int8")
df_test_log["EBV_missing"]  = df_test_log["EBV"].isna().astype("int8")
df_test_log["Z_missing"]    = df_test_log["Z"].isna().astype("int8")

# Impute both train and test with the TRAIN medians (0.0 if train has no values).
ebv_med = float(df_train_log["EBV"].median(skipna=True)) if df_train_log["EBV"].notna().any() else 0.0
z_med   = float(df_train_log["Z"].median(skipna=True))   if df_train_log["Z"].notna().any()   else 0.0

df_train_log["EBV"] = df_train_log["EBV"].fillna(ebv_med)
df_train_log["Z"]   = df_train_log["Z"].fillna(z_med)
df_test_log["EBV"]  = df_test_log["EBV"].fillna(ebv_med)
df_test_log["Z"]    = df_test_log["Z"].fillna(z_med)

# Target must be strictly binary with no missing values.
df_train_log["target"] = pd.to_numeric(df_train_log["target"], errors="coerce")
if df_train_log["target"].isna().any():
    raise ValueError(f"train_log target NaN after coercion: {int(df_train_log['target'].isna().sum())}")
u = set(pd.unique(df_train_log["target"]).tolist())
if not u.issubset({0, 1}):
    raise ValueError(f"train_log target must be 0/1. Found: {sorted(list(u))}")

# Z_err handling (domain shift info)
# Test: has_zerr flags rows with a parsable Z_err; missing values become 0.0.
if "Z_err" not in df_test_log.columns:
    df_test_log["Z_err"] = np.nan
df_test_log["Z_err"] = pd.to_numeric(df_test_log["Z_err"], errors="coerce")
df_test_log["has_zerr"] = (~df_test_log["Z_err"].isna()).astype("int8")
df_test_log["Z_err"] = df_test_log["Z_err"].fillna(0.0)

# Train: has_zerr is always 0, whatever the Z_err column holds.
if "Z_err" not in df_train_log.columns:
    df_train_log["Z_err"] = 0.0
df_train_log["Z_err"] = pd.to_numeric(df_train_log["Z_err"], errors="coerce").fillna(0.0)
df_train_log["has_zerr"] = np.zeros(len(df_train_log), dtype=np.int8)

if df_train_log["object_id"].duplicated().any():
    raise ValueError(f"Duplicated object_id in train_log: {int(df_train_log['object_id'].duplicated().sum())}")
if df_test_log["object_id"].duplicated().any():
    raise ValueError(f"Duplicated object_id in test_log:  {int(df_test_log['object_id'].duplicated().sum())}")

# The sample submission and the test log must describe exactly the same objects.
sub_ids  = df_sub["object_id"].astype("string").str.strip()
test_ids = df_test_log["object_id"].astype("string").str.strip()

if len(sub_ids) != len(test_ids):
    raise ValueError(f"Row mismatch: sample_submission={len(sub_ids)} vs test_log={len(test_ids)}")

s_sub = set(sub_ids.tolist())
s_tst = set(test_ids.tolist())
if s_sub != s_tst:
    missing_in_test = list(s_sub - s_tst)[:5]
    missing_in_sub  = list(s_tst - s_sub)[:5]
    raise ValueError(
        "sample_submission and test_log object_id set mismatch.\n"
        f"- sample not in test_log (up to5): {missing_in_test}\n"
        f"- test_log not in sample (up to5): {missing_in_sub}"
    )

# Submission row order and object -> split routing tables.
SUB_ORDER = sub_ids.tolist()
OID2SPLIT_TRAIN = dict(zip(df_train_log["object_id"].astype(str), df_train_log["split"].astype(str)))
OID2SPLIT_TEST  = dict(zip(df_test_log["object_id"].astype(str),  df_test_log["split"].astype(str)))

# ----------------------------
# Health report: per split meta summary
# ----------------------------
# One summary row per split: train size and positive rate, test size, train
# medians of Z / EBV, and the test Z_err coverage and median.
split_rows = []
for sp in sorted(valid_splits):
    tr = df_train_log[df_train_log["split"] == sp]
    te = df_test_log[df_test_log["split"] == sp]
    pos = int((tr["target"] == 1).sum())
    tot = int(len(tr))
    split_rows.append({
        "split": sp,
        "train_n": tot,
        "train_pos": pos,
        # max(tot, 1) avoids a division by zero for a split with no train rows.
        "train_pos_pct": (pos / max(tot, 1)) * 100.0,
        "test_n": int(len(te)),
        "train_Z_med": _safe_float(tr["Z"].median(), 0.0) if tot else 0.0,
        "train_EBV_med": _safe_float(tr["EBV"].median(), 0.0) if tot else 0.0,
        "test_has_zerr_pct": float(te["has_zerr"].mean() * 100.0) if len(te) else 0.0,
        # Z_err was filled with 0.0 earlier, so this median includes the filled rows.
        "test_Zerr_med": _safe_float(te["Z_err"].median(), 0.0) if len(te) else 0.0,
    })
df_split_summary = pd.DataFrame(split_rows).sort_values("split").reset_index(drop=True)

# ----------------------------
# FAIL-FAST Lightcurve validation (off/sample/full)
# ----------------------------
# Validation knobs (see CFG): mode, ids sampled per split, CSV chunk size and
# the missing-id rate above which the run is aborted.
lc_mode = str(CFG.get("STAGE0_LC_VALIDATE_MODE", "sample")).lower().strip()
lc_sample_n = int(CFG.get("STAGE0_LC_SAMPLE_PER_SPLIT", 80))
chunk_rows = int(CFG.get("CHUNK_ROWS", 200_000))
fail_rate = float(CFG.get("STAGE0_FAIL_FAST_MISSING_RATE", 0.01))

# Diagnostics written to run_diagnostics.json; the defaults describe mode "off".
lc_diag = {
    "mode": lc_mode,
    "sample_per_split": lc_sample_n,
    "chunk_rows": chunk_rows,
    "missing_train_ids": [],
    "missing_test_ids": [],
    "missing_train_count": 0,
    "missing_test_count": 0,
    "missing_train_rate": 0.0,
    "missing_test_rate": 0.0,
    "sample_obs_stats": {},
    "sample_band_coverage": {}
}

if lc_mode not in ["off", "sample", "full"]:
    raise ValueError("CFG['STAGE0_LC_VALIDATE_MODE'] must be one of: off/sample/full")

if lc_mode != "off":
    print("\n[STAGE0] Lightcurve validation:", lc_mode)

    # Example (split, object_id) pairs; in full mode capped at 50 per split.
    missing_train = []
    missing_test = []
    sample_obs_stats = {}
    sample_band_cov = {}

    # Exact totals, tracked separately because the example lists are capped
    # in full mode and therefore cannot be used for counts or rates.
    n_missing = {"train": 0, "test": 0}
    n_checked = {"train": 0, "test": 0}
    missing_examples = {"train": missing_train, "test": missing_test}
    log_by_kind = {"train": df_train_log, "test": df_test_log}

    bit_to_band = [(1<<0,"u"),(1<<1,"g"),(1<<2,"r"),(1<<3,"i"),(1<<4,"z"),(1<<5,"y")]

    for sp in sorted(valid_splits):
        split_dir = DATA_ROOT / sp
        lc_path = {
            "train": split_dir / "train_full_lightcurves.csv",
            "test": split_dir / "test_full_lightcurves.csv",
        }

        if lc_mode == "full":
            # Compare every logged id of this split against every id in the file.
            full_stats = {"full_mode": True}
            for kind in ("train", "test"):
                df_log = log_by_kind[kind]
                ids = set(df_log.loc[df_log["split"] == sp, "object_id"].astype(str).tolist())
                found = _full_scan_lightcurve_object_ids(lc_path[kind], chunk_rows=chunk_rows) if ids else set()
                miss = sorted(ids - found)

                missing_examples[kind].extend((sp, x) for x in miss[:50])
                n_missing[kind] += len(miss)
                n_checked[kind] += len(ids)
                full_stats[f"{kind}_missing"] = len(miss)

            sample_obs_stats[sp] = full_stats
            sample_band_cov[sp] = {"full_mode": True}

        else:
            # Sample mode: check a seeded sample of ids per split and collect
            # observation-count / band-coverage statistics for that sample.
            sampled = {
                "train": _sample_ids_per_split(df_train_log, sp, lc_sample_n, SEED),
                "test": _sample_ids_per_split(df_test_log, sp, lc_sample_n, SEED + 1),
            }
            band_cov_bits = 0

            for kind in ("train", "test"):
                ids = sampled[kind]
                if not ids:
                    continue

                found, obs_count, band_mask = _scan_lightcurve_for_ids(
                    lc_path[kind], set(ids), chunk_rows=chunk_rows
                )
                miss = [x for x in ids if x not in found]
                missing_examples[kind].extend((sp, x) for x in miss)
                n_missing[kind] += len(miss)
                n_checked[kind] += len(ids)

                counts = [obs_count.get(x, 0) for x in ids]
                bandbits = [band_mask.get(x, 0) for x in ids]
                band_cov_bits |= int(np.bitwise_or.reduce(bandbits))

                sample_obs_stats.setdefault(sp, {}).update({
                    f"{kind}_sample_n": len(ids),
                    f"{kind}_sample_missing": len(miss),
                    f"{kind}_sample_obs_min": int(np.min(counts)),
                    f"{kind}_sample_obs_med": float(np.median(counts)),
                    f"{kind}_sample_obs_p95": float(np.percentile(counts, 95)),
                })

            bands_present = [b for bit, b in bit_to_band if (band_cov_bits & bit)]
            sample_band_cov[sp] = {
                "bands_present_in_sample": bands_present,
                "bands_present_count": len(bands_present),
            }

    lc_diag["missing_train_ids"] = missing_train[:200]
    lc_diag["missing_test_ids"] = missing_test[:200]
    lc_diag["missing_train_count"] = n_missing["train"]
    lc_diag["missing_test_count"] = n_missing["test"]
    # Rates are relative to the ids actually checked: the sample in sample
    # mode, every logged id in full mode.
    lc_diag["missing_train_rate"] = n_missing["train"] / max(n_checked["train"], 1)
    lc_diag["missing_test_rate"]  = n_missing["test"]  / max(n_checked["test"], 1)

    lc_diag["sample_obs_stats"] = sample_obs_stats
    lc_diag["sample_band_coverage"] = sample_band_cov

    # Fail fast in BOTH modes; the exhaustive scan must not be more lenient
    # than the sampled one.
    if lc_diag["missing_test_rate"] > fail_rate or lc_diag["missing_train_rate"] > fail_rate:
        raise RuntimeError(
            "[FAIL-FAST] Lightcurve validation indicates missing object_id in lightcurve files.\n"
            f"- missing_train_rate={lc_diag['missing_train_rate']:.3%}\n"
            f"- missing_test_rate ={lc_diag['missing_test_rate']:.3%}\n"
            "Periksa split routing / file path / object_id normalization."
        )
    if n_missing["test"] > 0 or n_missing["train"] > 0:
        print("[WARN] Ada sample object_id yang tidak ditemukan di lightcurve file.")
        print("       Contoh missing_test:", missing_test[:5])
        print("       Contoh missing_train:", missing_train[:5])

# ----------------------------
# Basic counts
# ----------------------------
# Global class balance of the training log (re-binds pos/tot, which the
# per-split summary loop also used).
pos = int((df_train_log["target"] == 1).sum())
neg = int((df_train_log["target"] == 0).sum())
tot = int(len(df_train_log))

# Environment report.
print("ENV OK (Stage0)")
print(f"- DEVICE: {DEVICE} | THREADS: {THREADS}")
print(f"- DATA_ROOT: {DATA_ROOT}")
print(f"- Python: {sys.version.split()[0]}")
print(f"- Numpy:  {np.__version__}")
print(f"- Pandas: {pd.__version__}")
if TORCH_OK:
    print(f"- Torch:  {torch.__version__} | CUDA: {_cuda_ok}")
else:
    print("- Torch:  not available")

# Data report.
print("\nDATA OK")
print(f"- train_log objects: {tot:,} | pos={pos:,} | neg={neg:,} | pos%={(pos/max(tot,1))*100:.3f}%")
print(f"- test_log objects:  {len(df_test_log):,}")
print(f"- sample_submission: {len(df_sub):,}")
print(f"- splits: {len(PATHS['SPLITS'])} folders (01..20)")
print(f"- RUN_DIR: {RUN_DIR}")

print("\nSPLIT SUMMARY (meta)")
# `display` is not defined or imported in this cell (presumably the notebook
# built-in); any failure falls back to plain-text output.
try:
    display(df_split_summary.head(10))
except Exception:
    print(df_split_summary.head(10).to_string(index=False))

# ----------------------------
# Save diagnostics snapshot
# ----------------------------
# Full diagnostics snapshot: run identity, config, object counts, the
# per-split summary and the lightcurve validation results.
diag = {
    "SEED": SEED,
    "DEVICE": DEVICE,
    "THREADS": THREADS,
    "CFG": CFG,
    "CFG_HASH": CFG_HASH,
    "RUN_DIR": str(RUN_DIR),
    "DATA_ROOT": str(DATA_ROOT),
    "counts": {
        "train_objects": int(len(df_train_log)),
        "train_pos": int(pos),
        "train_neg": int(neg),
        "test_objects": int(len(df_test_log)),
        "sample_rows": int(len(df_sub)),
    },
    "split_meta_summary": df_split_summary.to_dict(orient="records"),
    "lightcurve_validation": lc_diag,
}

# config_stage0.json: only the run identity and configuration.
with open(RUN_DIR / "config_stage0.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "SEED": SEED,
            "DEVICE": DEVICE,
            "THREADS": THREADS,
            "CFG": CFG,
            "CFG_HASH": CFG_HASH,
            "RUN_DIR": str(RUN_DIR),
            "DATA_ROOT": str(DATA_ROOT),
        },
        f, indent=2
    )

# run_diagnostics.json: the complete snapshot built above.
with open(RUN_DIR / "run_diagnostics.json", "w", encoding="utf-8") as f:
    json.dump(diag, f, indent=2)

# ----------------------------
# Export globals
# ----------------------------
# Explicit list of the names later stages rely on. All of them were assigned
# at the top level of this cell, so this update() re-binds each name to the
# object it already refers to.
globals().update({
    "SEED": SEED, "DEVICE": DEVICE, "THREADS": THREADS,
    "CFG": CFG, "CFG_HASH": CFG_HASH,
    "PATHS": PATHS, "DATA_ROOT": DATA_ROOT, "RUN_DIR": RUN_DIR,
    "ART_DIR": ART_DIR, "CACHE_DIR": CACHE_DIR, "OOF_DIR": OOF_DIR,
    "SUB_DIR": SUB_DIR, "LOG_DIR": LOG_DIR, "FEAT_DIR": FEAT_DIR, "SEQ_DIR": SEQ_DIR,
    "df_sub": df_sub, "df_train_log": df_train_log, "df_test_log": df_test_log,
    "OID2SPLIT_TRAIN": OID2SPLIT_TRAIN, "OID2SPLIT_TEST": OID2SPLIT_TEST,
    "SUB_ORDER": SUB_ORDER,
    "df_split_summary": df_split_summary,
})

# Collect garbage left over from the temporary frames and chunks created above.
gc.collect()


# Verify Dataset Paths & Split Discovery

In [None]:
# ============================================================
# STAGE 1 — Split Routing + LC Profiling + Object Quality Features (ONE CELL)
# REVISI FULL v5.1 (BRUTAL-LB READY + ANTI-AGN + REST-FRAME READY)
#
# v5.1 vs v5.0:
# - Object quality gains stronger features:
#   * det_count abs + POS/NEG split (anti-AGN)
#   * multi-threshold det counts (following CFG["SNR_DET_THR_LIST"])
#   * snr_pos_max / snr_neg_min
#   * flux_max/min/mean/std + ferr_mean (valid numeric rows)
#   * per-band det counts (base thr) + n_bands_det
#   * timespan_rest + cadence_proxy_rest when USE_REST_FRAME_TIME=True
# - Fix: chunk["Filter"] is normalised before the band-count groupby
# - Robust: rows whose object_id is not in the log index are ignored (no crash)
# ============================================================

import re, gc, json, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require STAGE 0 globals
# ----------------------------
# Names that STAGE 0 must have left in the notebook's global namespace.
need0 = ["PATHS", "df_train_log", "df_test_log", "RUN_DIR", "LOG_DIR", "CFG", "SEED"]
for need in need0:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 dulu.")

# Re-wrap as Path objects in case the globals hold plain strings.
DATA_ROOT = Path(PATHS["DATA_ROOT"])
RUN_DIR = Path(RUN_DIR)
LOG_DIR = Path(LOG_DIR)
LOG_DIR.mkdir(parents=True, exist_ok=True)

CFG_LOCAL = globals().get("CFG", {}) or {}
SEED = int(globals().get("SEED", 2025))

# deterministic split list (follows PATHS["SPLITS"])
if "SPLITS" not in PATHS or not isinstance(PATHS["SPLITS"], (list, tuple)) or len(PATHS["SPLITS"]) == 0:
    raise RuntimeError("PATHS['SPLITS'] tidak valid. Pastikan STAGE 0 sukses.")
SPLIT_LIST = [Path(p).name for p in PATHS["SPLITS"]]
SPLIT_LIST = [s for s in SPLIT_LIST if s.startswith("split_")]
SPLIT_LIST = sorted(SPLIT_LIST)  # split_01..split_20

# The folder names must be exactly split_01..split_20.
VALID_SPLITS = set([f"split_{i:02d}" for i in range(1, 21)])
if set(SPLIT_LIST) != VALID_SPLITS:
    raise RuntimeError(
        "SPLIT_LIST mismatch dengan expected split_01..split_20.\n"
        f"Found (first 10): {sorted(list(set(SPLIT_LIST)))[:10]}"
    )

# split name -> folder path
SPLIT_DIRS = {Path(p).name: Path(p) for p in PATHS["SPLITS"]}

# ----------------------------
# 1) Safe read config
# ----------------------------
# Same NA tokens / read options as defined in STAGE 0, re-declared here.
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

# micro profiling knobs
# Each knob reads a stage-specific key first and falls back to a generic key
# (or a hard-coded default) when that key is absent from CFG.
HEAD_ROWS = int(CFG_LOCAL.get("STAGE1_HEAD_ROWS", CFG_LOCAL.get("LC_HEAD_ROWS", 4000)))

# sample ID presence knobs
SAMPLE_ID_PER_SPLIT = int(
    CFG_LOCAL.get(
        "STAGE1_ID_SAMPLE_PER_SPLIT",
        CFG_LOCAL.get("STAGE0_LC_SAMPLE_PER_SPLIT", CFG_LOCAL.get("SAMPLE_ID_PER_SPLIT", 12))
    )
)

CHUNK_ROWS = int(CFG_LOCAL.get("CHUNK_ROWS", 200_000))

# adaptive scan caps
MAX_CHUNKS_PER_FILE = int(CFG_LOCAL.get("MAX_CHUNKS_PER_FILE", 6))
MAX_CHUNKS_HARD = int(CFG_LOCAL.get("MAX_CHUNKS_HARD", 30))

# numeric sanity thresholds
MAX_TIME_NA_FRAC = float(CFG_LOCAL.get("MAX_TIME_NA_FRAC", 0.02))
MAX_FERR_NA_FRAC = float(CFG_LOCAL.get("MAX_FERR_NA_FRAC", 0.02))
MIN_SAMPLE_ROWS  = int(CFG_LOCAL.get("MIN_SAMPLE_ROWS", 200))

# ID miss handling
ID_MISS_FAIL_FRAC = float(CFG_LOCAL.get("ID_MISS_FAIL_FRAC", 0.80))
FAIL_FAST_MISSING_RATE = float(
    CFG_LOCAL.get("STAGE1_FAIL_FAST_MISSING_RATE", CFG_LOCAL.get("STAGE0_FAIL_FAST_MISSING_RATE", 0.01))
)

# Stage1 validate mode
LC_VALIDATE_MODE = str(
    CFG_LOCAL.get("STAGE1_LC_VALIDATE_MODE", CFG_LOCAL.get("STAGE0_LC_VALIDATE_MODE", "sample"))
).lower().strip()
if LC_VALIDATE_MODE not in ["off", "sample", "full"]:
    raise ValueError("CFG['STAGE1_LC_VALIDATE_MODE'] must be one of: off/sample/full")

# SNR policy
SNR_CLIP = float(CFG_LOCAL.get("SNR_CLIP", 30.0))
MIN_FLUXERR = float(CFG_LOCAL.get("MIN_FLUXERR", 1e-6))

# det threshold list: store a multi-threshold detection count per object
_thr_list = CFG_LOCAL.get("SNR_DET_THR_LIST", None)
if isinstance(_thr_list, (list, tuple)) and len(_thr_list) > 0:
    SNR_DET_THR_LIST = [float(x) for x in _thr_list]
else:
    SNR_DET_THR_LIST = [float(CFG_LOCAL.get("SNR_DET_THR", 3.0))]
# sanitize: keep finite, strictly positive thresholds, de-duplicated and sorted ascending
SNR_DET_THR_LIST = sorted(list(dict.fromkeys([float(x) for x in SNR_DET_THR_LIST if np.isfinite(float(x)) and float(x) > 0])))
if len(SNR_DET_THR_LIST) == 0:
    SNR_DET_THR_LIST = [3.0]

# base threshold for the "per-band det count": the smallest threshold in the list
SNR_DET_THR = float(SNR_DET_THR_LIST[0])
SNR_STRONG_THR = float(CFG_LOCAL.get("SNR_STRONG_THR", 5.0))

# rest-frame policy
USE_REST = bool(CFG_LOCAL.get("USE_REST_FRAME_TIME", False))
REST_EPS = float(CFG_LOCAL.get("REST_FRAME_EPS", 1e-6))

# NEW: build full object quality table
BUILD_OBJECT_QUALITY = bool(CFG_LOCAL.get("STAGE1_BUILD_OBJECT_QUALITY", True))
OBJ_QUALITY_CHUNK_ROWS = int(CFG_LOCAL.get("STAGE1_OBJ_QUALITY_CHUNK_ROWS", CHUNK_ROWS))
ZEROOBS_FAIL_RATE = float(CFG_LOCAL.get("STAGE1_ZEROOBS_FAIL_RATE", 0.0005))  # 0.05% default

# ----------------------------
# 2) Helpers
# ----------------------------
REQ_LC_COLS = ["object_id", "Time (MJD)", "Flux", "Flux_err", "Filter"]
REQ_LC_COLS_SET = set([c.strip() for c in REQ_LC_COLS])
ALLOWED_FILTERS = {"u", "g", "r", "i", "z", "y"}
BANDS = ["u","g","r","i","z","y"]
BAND_TO_IDX = {b:i for i,b in enumerate(BANDS)}

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Copy *df* and strip surrounding whitespace from every column label."""
    out = df.copy()
    trimmed = []
    for name in out.columns:
        trimmed.append(name.strip())
    out.columns = trimmed
    return out

def normalize_split_name(x) -> str:
    """Normalise a raw split label to the canonical ``split_NN`` form.

    Missing scalars — ``None``, float NaN, ``pd.NA`` and anything else
    ``pd.isna`` reports — and blank strings map to "". Labels that are a bare
    number, ``split_N`` or ``splitN`` (one or two digits, any case, with "-"
    or " " accepted in place of "_") map to ``split_NN``. Every other label
    is returned lower-cased with "-" and " " replaced by "_".
    """
    try:
        # pd.isna also recognises pd.NA (produced by "string" dtype columns)
        # and numpy NaN scalars, which a plain isinstance(x, float) test misses.
        is_missing = x is None or bool(pd.isna(x))
    except (TypeError, ValueError):
        # Non-scalar input: pd.isna returns an array with no single truth value.
        is_missing = False
    if is_missing:
        return ""
    s = str(x).strip()
    if not s:
        return ""
    s2 = s.lower().replace("-", "_").replace(" ", "_")
    if s2.isdigit():
        return f"split_{int(s2):02d}"
    m = re.fullmatch(r"split_(\d{1,2})", s2)
    if m:
        return f"split_{int(m.group(1)):02d}"
    m = re.fullmatch(r"split(\d{1,2})", s2)
    if m:
        return f"split_{int(m.group(1)):02d}"
    return s2

def sizeof_mb(p: Path) -> float:
    """Return the size of file *p* in MiB, or NaN when it cannot be stat'ed."""
    try:
        n_bytes = p.stat().st_size
    except Exception:
        return float("nan")
    return n_bytes / (1024**2)

# Per-file header cache: str(path) -> (raw column list, trimmed-name -> raw-name map).
_HEADER_CACHE = {}

def _get_usecols(csv_path: Path, required_trimmed: set):
    """Map whitespace-trimmed required column names to the raw header names of a CSV.

    The header is read once per file (``nrows=0``) and cached in
    ``_HEADER_CACHE``. When several raw columns trim to the same name, the
    first one wins.

    Returns the raw column names, one per entry of *required_trimmed*, in the
    iteration order of that set.

    Raises ValueError listing the missing names when the header lacks any
    required column.
    """
    cache_key = str(csv_path)
    entry = _HEADER_CACHE.get(cache_key)
    if entry is None:
        header = pd.read_csv(csv_path, nrows=0, **SAFE_READ_KW)
        raw_cols = list(header.columns)
        by_trimmed = {}
        for raw in raw_cols:
            # setdefault keeps the first raw column for a given trimmed name.
            by_trimmed.setdefault(str(raw).strip(), raw)
        entry = (raw_cols, by_trimmed)
        _HEADER_CACHE[cache_key] = entry
    trim2orig = entry[1]

    absent = sorted(required_trimmed - set(trim2orig))
    if absent:
        known = sorted(trim2orig)
        raise ValueError(
            f"[LC SCHEMA] {csv_path} missing required columns (trim-aware): {absent}\n"
            f"Found columns (trimmed, first 50): {known[:50]}"
        )
    return [trim2orig[name] for name in required_trimmed]

def _read_sample_df(p: Path, nrows: int):
    """Read the first *nrows* rows of the required light-curve columns from *p*.

    Column names are resolved trim-aware via ``_get_usecols`` and the returned
    frame has its column labels trimmed by ``_norm_cols``.
    """
    wanted = _get_usecols(p, REQ_LC_COLS_SET)
    return _norm_cols(pd.read_csv(p, usecols=wanted, nrows=nrows, **SAFE_READ_KW))

def _numeric_and_filter_stats(dfh: pd.DataFrame):
    """Profile a sampled light-curve frame and return a flat dict of statistics.

    Reads the ``Filter``, ``Time (MJD)``, ``Flux`` and ``Flux_err`` columns and
    reports: sample size, NaN fractions after numeric coercion, time range,
    flux / flux-error quantiles, |SNR| summaries (error floored at
    ``MIN_FLUXERR``, SNR clipped to ``±SNR_CLIP``), per-band row fractions
    (``frac_<band>``), the filter labels seen (``filter_sample``) and those not
    in ``ALLOWED_FILTERS`` (``filter_bad``). Statistics that cannot be computed
    are NaN; an empty frame yields NaN statistics and zero band fractions.
    """
    stats = {"n_sample": int(len(dfh))}

    if len(dfh) == 0:
        stats.update(dict.fromkeys(
            (
                "time_na_frac", "flux_na_frac", "ferr_na_frac",
                "time_min", "time_max", "time_span",
                "flux_neg_frac", "flux_p01", "flux_p50", "flux_p99",
                "ferr_min", "ferr_p50", "ferr_p99",
                "snr_abs_p50", "snr_abs_p95", "snr_ge_det_frac", "snr_ge_strong_frac",
                "ferr_zero_frac",
            ),
            np.nan,
        ))
        stats["filter_bad"] = ""
        stats["filter_sample"] = ""
        stats.update({f"frac_{b}": 0.0 for b in BANDS})
        stats["ferr_neg_any"] = 0
        return stats

    # ---- filter labels: trimmed, lower-cased, missing values ignored
    labels = dfh["Filter"].astype("string").str.strip().str.lower().dropna()
    distinct = sorted(set(labels.tolist()))
    rejected = [v for v in distinct if v not in ALLOWED_FILTERS]
    # joining an empty list yields "", which is the "nothing to report" value
    stats["filter_bad"] = ",".join(rejected[:10])
    stats["filter_sample"] = ",".join(distinct[:10])

    if len(labels) > 0:
        counts = labels.value_counts()
        n_labelled = float(counts.sum())
        for b in BANDS:
            stats[f"frac_{b}"] = float(counts.get(b, 0) / n_labelled)
    else:
        stats.update({f"frac_{b}": 0.0 for b in BANDS})

    # ---- numeric columns (unparseable values become NaN)
    t = pd.to_numeric(dfh["Time (MJD)"], errors="coerce")
    f = pd.to_numeric(dfh["Flux"], errors="coerce")
    e = pd.to_numeric(dfh["Flux_err"], errors="coerce")

    for key, col in (("time_na_frac", t), ("flux_na_frac", f), ("ferr_na_frac", e)):
        stats[key] = float(col.isna().mean())

    # ---- time range
    has_time = bool(t.notna().any())
    stats["time_min"] = float(t.min()) if has_time else np.nan
    stats["time_max"] = float(t.max()) if has_time else np.nan
    stats["time_span"] = float(t.max() - t.min()) if has_time else np.nan

    # ---- flux distribution
    flux_ok = f.dropna()
    flux_quantiles = (("flux_p01", 0.01), ("flux_p50", 0.50), ("flux_p99", 0.99))
    if len(flux_ok) > 0:
        stats["flux_neg_frac"] = float((flux_ok < 0).mean())
        for key, q in flux_quantiles:
            stats[key] = float(np.quantile(flux_ok, q))
    else:
        stats["flux_neg_frac"] = np.nan
        for key, _ in flux_quantiles:
            stats[key] = np.nan

    # ---- flux-error distribution
    err_ok = e.dropna()
    err_quantiles = (("ferr_p50", 0.50), ("ferr_p99", 0.99))
    if len(err_ok) > 0:
        stats["ferr_min"] = float(err_ok.min())
        for key, q in err_quantiles:
            stats[key] = float(np.quantile(err_ok, q))
        stats["ferr_neg_any"] = int((err_ok < 0).any())
        # counts errors <= 0 (zero and negative)
        stats["ferr_zero_frac"] = float((err_ok <= 0).mean())
    else:
        stats["ferr_min"] = np.nan
        for key, _ in err_quantiles:
            stats[key] = np.nan
        stats["ferr_neg_any"] = 0
        stats["ferr_zero_frac"] = np.nan

    # ---- SNR summaries over rows where both flux and error are finite
    snr_abs = None
    if len(flux_ok) > 0 and len(err_ok) > 0:
        flux_arr = f.to_numpy()
        err_arr = e.to_numpy()
        usable = np.isfinite(flux_arr) & np.isfinite(err_arr)
        if usable.any():
            floored_err = np.maximum(err_arr[usable], MIN_FLUXERR)
            snr_abs = np.abs(np.clip(flux_arr[usable] / floored_err, -SNR_CLIP, SNR_CLIP))

    if snr_abs is None:
        for key in ("snr_abs_p50", "snr_abs_p95", "snr_ge_det_frac", "snr_ge_strong_frac"):
            stats[key] = np.nan
    else:
        stats["snr_abs_p50"] = float(np.quantile(snr_abs, 0.50))
        stats["snr_abs_p95"] = float(np.quantile(snr_abs, 0.95))
        stats["snr_ge_det_frac"] = float((snr_abs >= SNR_DET_THR).mean())
        stats["snr_ge_strong_frac"] = float((snr_abs >= SNR_STRONG_THR).mean())

    return stats

def _sample_id_presence_adaptive(csv_path: Path, want_ids: set, chunk_rows: int,
                                 max_chunks_init: int, max_chunks_hard: int):
    if not want_ids:
        return 0, set(), 0, max_chunks_init

    usecols = _get_usecols(csv_path, {"object_id"})
    remaining = set(want_ids)
    found = set()
    total_chunks_read = 0
    used_cap = max_chunks_init

    cap = max_chunks_init
    while True:
        nread = 0
        for i, chunk in enumerate(pd.read_csv(csv_path, usecols=usecols, chunksize=chunk_rows, **SAFE_READ_KW)):
            if i < total_chunks_read:
                continue
            nread += 1
            total_chunks_read += 1
            chunk = _norm_cols(chunk)
            ids = set(chunk["object_id"].astype("string").dropna().str.strip().astype(str).tolist())
            hit = remaining & ids
            if hit:
                found |= hit
                remaining -= hit
            if not remaining:
                break
            if nread >= cap:
                break

        used_cap = cap
        if not remaining:
            break
        if total_chunks_read >= max_chunks_hard:
            break
        cap = min(cap * 2, max_chunks_hard)

    return len(found), remaining, total_chunks_read, used_cap

# ----------------------------
# 3) Normalize split col in logs (idempotent)
# ----------------------------
# Rewrites the "split" column of both log frames in place via normalize_split_name.
for df, name in ((df_train_log, "train_log"), (df_test_log, "test_log")):
    if "split" in df.columns:
        df["split"] = df["split"].map(normalize_split_name)
    else:
        raise ValueError(f"{name} missing 'split' column.")

# Any normalized label outside VALID_SPLITS is a hard error.
bad_train_split = sorted(set(df_train_log["split"].unique()) - VALID_SPLITS)
bad_test_split = sorted(set(df_test_log["split"].unique()) - VALID_SPLITS)
if bad_train_split:
    raise ValueError(f"train_log has invalid split values (examples): {bad_train_split[:10]}")
if bad_test_split:
    raise ValueError(f"test_log has invalid split values (examples): {bad_test_split[:10]}")

# ----------------------------
# 4) Verify disk splits set + required files exist
# ----------------------------
# The folders found on disk must match VALID_SPLITS exactly.
disk_splits = set(SPLIT_DIRS)
missing_dirs = sorted(VALID_SPLITS - disk_splits)
extra_dirs = sorted(disk_splits - VALID_SPLITS)
if missing_dirs or extra_dirs:
    msg = []
    if missing_dirs:
        msg.append(f"Missing split folders: {missing_dirs[:10]}")
    if extra_dirs:
        msg.append(f"Unexpected split folders: {extra_dirs[:10]}")
    raise RuntimeError("Split folder set mismatch.\n" + "\n".join(msg))

# Every split folder must contain both the train and the test light-curve file.
missing_files = []
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    for kind in ("train", "test"):
        p = sd / f"{kind}_full_lightcurves.csv"
        if p.exists():
            continue
        missing_files.append(str(p))
if missing_files:
    raise FileNotFoundError("Some lightcurve files missing (showing up to 10):\n" + "\n".join(missing_files[:10]))

# ----------------------------
# 5) Build routing manifest
# ----------------------------
# Objects per split according to the logs.
train_counts = df_train_log["split"].value_counts().to_dict()
test_counts = df_test_log["split"].value_counts().to_dict()

# One manifest row per split: file locations, file sizes, and logged object counts.
routing_rows = []
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    tr = sd / "train_full_lightcurves.csv"
    te = sd / "test_full_lightcurves.csv"
    routing_rows.append(dict(
        split=sp,
        train_csv=str(tr),
        test_csv=str(te),
        train_mb=sizeof_mb(tr),
        test_mb=sizeof_mb(te),
        n_train_objects_log=int(train_counts.get(sp, 0)),
        n_test_objects_log=int(test_counts.get(sp, 0)),
    ))

df_routing = pd.DataFrame(routing_rows)
routing_path = LOG_DIR / "split_routing.csv"
df_routing.to_csv(routing_path, index=False)

# ----------------------------
# 6) Micro profiling + ID crosscheck
# ----------------------------
# Accumulators filled by the per-file loop that follows.
stats_rows, id_warn_rows = [], []
warn_flux_na_files = 0
agg_id_total = agg_id_missing = 0

t0 = time.time()

print(f"[STAGE1] LC_VALIDATE_MODE={LC_VALIDATE_MODE} | HEAD_ROWS={HEAD_ROWS} | SAMPLE_ID_PER_SPLIT={SAMPLE_ID_PER_SPLIT}")
print(f"[STAGE1] SNR_DET_THR_LIST={SNR_DET_THR_LIST} | SNR_STRONG_THR={SNR_STRONG_THR} | USE_REST={USE_REST}")

# For every split and every file kind: validate a head sample, cross-check that
# object_ids from the log occur in the file, and record one stats row.
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    for kind in ["train", "test"]:
        p = sd / f"{kind}_full_lightcurves.csv"

        # ---- head sample: the first HEAD_ROWS rows must be readable and sane
        dfh = _read_sample_df(p, nrows=HEAD_ROWS)
        if len(dfh) < MIN_SAMPLE_ROWS:
            raise ValueError(f"[LC SAMPLE] Too few rows sampled from {p} (n={len(dfh)}). Possible read issue.")

        st = _numeric_and_filter_stats(dfh)

        if st.get("filter_bad", ""):
            raise ValueError(f"[LC FILTER] Unexpected Filter values in {p}: {st['filter_bad']} (sample={st.get('filter_sample','')})")

        if st.get("time_na_frac", 0.0) > MAX_TIME_NA_FRAC:
            raise ValueError(f"[LC NUM] Time(MJD) NaN too high in sample: {p} frac={st['time_na_frac']:.4f}")
        if st.get("ferr_na_frac", 0.0) > MAX_FERR_NA_FRAC:
            raise ValueError(f"[LC NUM] Flux_err NaN too high in sample: {p} frac={st['ferr_na_frac']:.4f}")
        if int(st.get("ferr_neg_any", 0)) == 1:
            raise ValueError(f"[LC NUM] Negative Flux_err detected in sample of {p} (should be >=0).")

        # NaN flux is tolerated; only the number of affected files is counted.
        if st.get("flux_na_frac", 0.0) > 0:
            warn_flux_na_files += 1

        # ---- ID cross-check defaults (kept when LC_VALIDATE_MODE disables the check)
        id_k = 0
        id_found = 0
        id_missing = 0
        id_scan_chunks = 0
        id_scan_cap_used = 0
        miss_ids_list = []

        if LC_VALIDATE_MODE in ["sample", "full"]:
            if kind == "train":
                ids = df_train_log.loc[df_train_log["split"] == sp, "object_id"].astype("string").dropna().str.strip()
            else:
                ids = df_test_log.loc[df_test_log["split"] == sp, "object_id"].astype("string").dropna().str.strip()

            if LC_VALIDATE_MODE == "full":
                # Full mode: collect every object_id in the file and diff against the log.
                usecols = _get_usecols(p, {"object_id"})
                found_all = set()
                for chunk in pd.read_csv(p, usecols=usecols, chunksize=CHUNK_ROWS, **SAFE_READ_KW):
                    chunk = _norm_cols(chunk)
                    found_all.update(chunk["object_id"].astype("string").dropna().str.strip().astype(str).unique().tolist())
                miss = sorted(list(set(ids.astype(str).tolist()) - found_all))
                id_k = int(len(ids))
                id_missing = int(len(miss))
                id_found = id_k - id_missing
                # Chunk-budget fields do not apply to a full scan.
                id_scan_chunks = None
                id_scan_cap_used = None
                miss_ids_list = miss[:20]
            else:
                # Sample mode: look for a seeded random subset of ids with a bounded scan.
                id_k = int(min(SAMPLE_ID_PER_SPLIT, len(ids)))
                want = set(ids.sample(n=id_k, random_state=SEED + (0 if kind=="train" else 7)).astype(str).tolist()) if id_k > 0 else set()
                found_n, missing_ids, chunks_read, cap_used = _sample_id_presence_adaptive(
                    p, want, CHUNK_ROWS, MAX_CHUNKS_PER_FILE, MAX_CHUNKS_HARD
                )
                id_found = int(found_n)
                id_missing = int(len(missing_ids))
                id_scan_chunks = int(chunks_read)
                id_scan_cap_used = int(cap_used)
                # Sorted so the examples in errors/logs are reproducible (set
                # iteration order is arbitrary) and consistent with full mode.
                miss_ids_list = sorted(missing_ids)[:10]

                miss_frac = (id_missing / max(id_k, 1)) if id_k else 0.0
                agg_id_total += id_k
                agg_id_missing += id_missing

                # A single file missing too large a share of its sampled ids is fatal.
                if id_k and miss_frac >= ID_MISS_FAIL_FRAC:
                    raise ValueError(
                        f"[LC ID] Severe mismatch within adaptive scan: {p} missing {id_missing}/{id_k} "
                        f"(chunks_read={chunks_read}, hard_cap={MAX_CHUNKS_HARD}). Example missing: {miss_ids_list[:3]}"
                    )

                # Smaller mismatches are only recorded as warnings.
                if id_k and id_missing > 0:
                    id_warn_rows.append({
                        "split": sp, "kind": kind, "file": str(p),
                        "k": id_k, "missing": id_missing,
                        "chunks_read": chunks_read, "cap_used": cap_used,
                        "example_missing": ",".join(miss_ids_list[:5]),
                    })

        # ---- one stats row per (split, kind)
        row = {
            "split": sp,
            "kind": kind,
            "file": str(p),
            "file_mb": sizeof_mb(p),
            "n_sample": st.get("n_sample", 0),
            "time_na_frac": st.get("time_na_frac", np.nan),
            "flux_na_frac": st.get("flux_na_frac", np.nan),
            "ferr_na_frac": st.get("ferr_na_frac", np.nan),
            "time_min": st.get("time_min", np.nan),
            "time_max": st.get("time_max", np.nan),
            "time_span": st.get("time_span", np.nan),
            "flux_neg_frac": st.get("flux_neg_frac", np.nan),
            "flux_p01": st.get("flux_p01", np.nan),
            "flux_p50": st.get("flux_p50", np.nan),
            "flux_p99": st.get("flux_p99", np.nan),
            "ferr_min": st.get("ferr_min", np.nan),
            "ferr_p50": st.get("ferr_p50", np.nan),
            "ferr_p99": st.get("ferr_p99", np.nan),
            "ferr_zero_frac": st.get("ferr_zero_frac", np.nan),
            "snr_abs_p50": st.get("snr_abs_p50", np.nan),
            "snr_abs_p95": st.get("snr_abs_p95", np.nan),
            "snr_ge_det_frac": st.get("snr_ge_det_frac", np.nan),
            "snr_ge_strong_frac": st.get("snr_ge_strong_frac", np.nan),
            "snr_det_thr_base": float(SNR_DET_THR),
            "snr_strong_thr": float(SNR_STRONG_THR),
            "filter_sample": st.get("filter_sample", ""),
            "id_check_k": int(id_k),
            "id_found": int(id_found),
            "id_missing": int(id_missing),
            "id_scan_chunks": id_scan_chunks,
            "id_scan_cap_used": id_scan_cap_used,
        }
        for b in BANDS:
            row[f"frac_{b}"] = st.get(f"frac_{b}", 0.0)

        stats_rows.append(row)

# Persist the per-file sample statistics.
df_lc_stats = pd.DataFrame(stats_rows)
lc_stats_path = LOG_DIR / "lc_sample_stats.csv"
df_lc_stats.to_csv(lc_stats_path, index=False)

# Persist the ID-presence warnings collected by the adaptive scan.
df_id_warn = pd.DataFrame(id_warn_rows)
id_warn_path = LOG_DIR / "lc_id_presence_warnings.csv"
df_id_warn.to_csv(id_warn_path, index=False)

# The aggregate missing rate is only defined (and enforced) in "sample" mode.
if LC_VALIDATE_MODE == "sample":
    agg_missing_rate = agg_id_missing / max(agg_id_total, 1)
    if agg_missing_rate > FAIL_FAST_MISSING_RATE:
        raise RuntimeError(
            "[FAIL-FAST] Aggregate sample-ID missing rate terlalu tinggi.\n"
            f"- agg_missing_rate={agg_missing_rate:.3%} (missing={agg_id_missing}, total_sample={agg_id_total})\n"
            f"- threshold={FAIL_FAST_MISSING_RATE:.3%}\n"
            "Ini indikasi routing/split/object_id normalization bermasalah."
        )
else:
    agg_missing_rate = None

# ----------------------------
# 7) Build object_quality tables (train/test) — streaming scan
# ----------------------------
# Output paths default to None; the objq_* paths are assigned inside the
# BUILD_OBJECT_QUALITY branch.
objq_train_path = objq_test_path = splitq_path = None

if BUILD_OBJECT_QUALITY:
    print("\n[STAGE1] Building object_quality (full streaming scan) ...")

    # Copies of the logs with object_id as trimmed strings (the logs themselves are not modified).
    _tr = df_train_log.assign(object_id=df_train_log["object_id"].astype("string").str.strip())
    _te = df_test_log.assign(object_id=df_test_log["object_id"].astype("string").str.strip())

    # Metadata columns carried into the object tables; has_zerr only if the log provides it.
    keep_tr = ["split", "target", "Z", "Z_err", "EBV"] + (["has_zerr"] if "has_zerr" in _tr.columns else [])
    keep_te = ["split", "Z", "Z_err", "EBV"] + (["has_zerr"] if "has_zerr" in _te.columns else [])

    # One row per object_id, indexed by object_id.
    obj_train = _tr.set_index("object_id").loc[:, keep_tr].copy()
    obj_test = _te.set_index("object_id").loc[:, keep_te].copy()

    # numeric holders
    def _init_obj_table(df):
        """Return a copy of *df* extended with zeroed per-object accumulator columns.

        Counters are int32 and start at 0; SNR extrema are float32 and start at
        0.0; flux extrema start at -inf/+inf (float32); running sums are float64;
        ``tmin``/``tmax`` start at +inf/-inf so the first observation replaces them.
        """
        table = df.copy()

        # --- int32 counters, in output-column order
        counter_cols = ["n_obs_total"]
        counter_cols += [f"n_{b}" for b in BANDS]          # rows per band
        counter_cols += [f"det_{b}" for b in BANDS]        # per-band detections at the base threshold
        counter_cols += ["neg_count", "pos_count", "ferr0_count"]
        for thr in SNR_DET_THR_LIST:
            key = int(round(thr * 10))  # 2.0->20, 3.0->30
            counter_cols += [f"snr_det_abs_{key}", f"snr_det_pos_{key}", f"snr_det_neg_{key}"]
        counter_cols += ["snr_strong_abs", "snr_strong_pos", "snr_strong_neg"]
        for col in counter_cols:
            table[col] = np.int32(0)

        # --- SNR extrema; snr_neg_min is <= 0, more negative means a stronger negative excursion
        for col in ("snr_abs_max", "snr_pos_max", "snr_neg_min"):
            table[col] = np.float32(0.0)

        # --- flux extrema start at the identity elements of max/min
        table["flux_max"] = np.float32(-np.inf)
        table["flux_min"] = np.float32(np.inf)

        # --- running sums for mean/std, plus the number of rows they cover
        for col in ("flux_sum", "flux_sumsq", "ferr_sum"):
            table[col] = np.float64(0.0)
        table["n_valid"] = np.int32(0)

        # --- observed time range
        table["tmin"] = np.float64(np.inf)
        table["tmax"] = np.float64(-np.inf)

        return table

    obj_train, obj_test = _init_obj_table(obj_train), _init_obj_table(obj_test)

    # Raw header names of the required columns, resolved from split_01's train file
    # (schema assumed same across files).
    usecols = _get_usecols(SPLIT_DIRS["split_01"] / "train_full_lightcurves.csv", REQ_LC_COLS_SET)

    def _update_obj_stats(obj_df: pd.DataFrame, csv_path: Path, unknown_counter: dict):
        """Stream one light-curve CSV and fold per-object statistics into ``obj_df`` in place.

        Parameters:
            obj_df: accumulator table indexed by object_id, as prepared by
                ``_init_obj_table``; its counter/extrema/sum columns are updated.
            csv_path: light-curve CSV to scan in chunks of ``OBJ_QUALITY_CHUNK_ROWS``.
            unknown_counter: dict whose ``"unknown_rows"`` entry is incremented by
                the number of rows whose object_id is not in ``obj_df``.

        Per chunk: rows without object_id are dropped, rows of unknown objects are
        counted and skipped; row counts, time range and per-band counts use all
        remaining rows; flux/SNR statistics use only rows where both Flux and
        Flux_err are finite.

        Raises:
            ValueError: if a chunk contains a Filter value outside
                ``ALLOWED_FILTERS`` (empty/missing filters are tolerated), or if
                the file lacks a required column.
        """
        idx_master = obj_df.index
        # Resolve the raw header names for this file (cached per path) instead of
        # assuming every file spells its header exactly like split_01's train file.
        file_usecols = _get_usecols(csv_path, REQ_LC_COLS_SET)

        for chunk in pd.read_csv(csv_path, usecols=file_usecols, chunksize=OBJ_QUALITY_CHUNK_ROWS, **SAFE_READ_KW):
            chunk = _norm_cols(chunk)
            chunk["object_id"] = chunk["object_id"].astype("string").str.strip()
            chunk = chunk[chunk["object_id"].notna()]
            if len(chunk) == 0:
                continue

            # normalize Filter into chunk column (important!)
            chunk["Filter"] = chunk["Filter"].astype("string").str.strip().str.lower().fillna("")
            bad = set(chunk["Filter"].unique().tolist()) - ALLOWED_FILTERS - {""}
            if bad:
                raise ValueError(f"[LC FILTER] Unexpected Filter values in {csv_path}: {sorted(list(bad))[:10]}")

            # restrict to known object_id (avoid crash if file contains extra ids)
            oid = chunk["object_id"].astype(str)
            known_mask = oid.isin(idx_master)
            if not known_mask.any():
                unknown_counter["unknown_rows"] += int(len(chunk))
                continue
            if (~known_mask).any():
                unknown_counter["unknown_rows"] += int((~known_mask).sum())
                chunk = chunk.loc[known_mask].copy()
                oid = chunk["object_id"].astype(str)

            # numeric (unparseable values become NaN)
            t = pd.to_numeric(chunk["Time (MJD)"], errors="coerce")
            f = pd.to_numeric(chunk["Flux"], errors="coerce")
            e = pd.to_numeric(chunk["Flux_err"], errors="coerce")

            # total n per object from ALL rows
            n_all = oid.value_counts()
            idx_all = n_all.index
            obj_df.loc[idx_all, "n_obs_total"] = (obj_df.loc[idx_all, "n_obs_total"].astype(np.int64) + n_all.astype(np.int64)).astype(np.int32)

            # time min/max (numeric coerce).
            # fmin/fmax ignore NaN: an object whose times are all NaN in this chunk
            # yields a NaN group result, which np.minimum/np.maximum would propagate
            # into the accumulator and keep there for every later chunk.
            tmin = t.groupby(oid).min()
            tmax = t.groupby(oid).max()
            obj_df.loc[idx_all, "tmin"] = np.fmin(obj_df.loc[idx_all, "tmin"].to_numpy(), tmin.reindex(idx_all).to_numpy())
            obj_df.loc[idx_all, "tmax"] = np.fmax(obj_df.loc[idx_all, "tmax"].to_numpy(), tmax.reindex(idx_all).to_numpy())

            # band counts
            bc = chunk.groupby([oid, chunk["Filter"]]).size().unstack(fill_value=0)
            for b in BANDS:
                if b in bc.columns:
                    obj_df.loc[bc.index, f"n_{b}"] = (obj_df.loc[bc.index, f"n_{b}"].astype(np.int64) + bc[b].astype(np.int64)).astype(np.int32)

            # valid numeric rows for SNR/flux stats
            ff = f.to_numpy()
            ee = e.to_numpy()
            mm = np.isfinite(ff) & np.isfinite(ee)
            if not mm.any():
                continue

            oid_m = oid.to_numpy()[mm]
            flt_m = chunk["Filter"].to_numpy()[mm]
            f_m = ff[mm].astype(np.float64, copy=False)
            e_raw_m = ee[mm].astype(np.float64, copy=False)

            # ferr0: rows whose reported error is <= 0
            ferr0 = (e_raw_m <= 0).astype(np.int32)
            # floor the error for SNR, then clip the SNR to +/- SNR_CLIP
            e_m = np.maximum(e_raw_m, MIN_FLUXERR)
            snr = (f_m / e_m).astype(np.float64, copy=False)
            snr = np.clip(snr, -SNR_CLIP, SNR_CLIP)
            snr_abs = np.abs(snr)

            neg = (f_m < 0).astype(np.int32)
            pos = (f_m > 0).astype(np.int32)

            # build tmp frame (one row per valid observation)
            tmp = pd.DataFrame({
                "object_id": oid_m,
                "filter": flt_m,
                "flux": f_m,
                "ferr": e_raw_m,
                "neg": neg,
                "pos": pos,
                "ferr0": ferr0,
                "snr": snr,
                "snr_abs": snr_abs,
                "snr_pos": np.maximum(snr, 0.0),
                "snr_neg": np.minimum(snr, 0.0),
            })

            # multi-threshold det flags (abs + pos + neg)
            for thr in SNR_DET_THR_LIST:
                key = int(round(thr * 10))
                tmp[f"det_abs_{key}"] = (tmp["snr_abs"].to_numpy() >= thr).astype(np.int32)
                tmp[f"det_pos_{key}"] = (tmp["snr"].to_numpy() >= thr).astype(np.int32)
                tmp[f"det_neg_{key}"] = (tmp["snr"].to_numpy() <= -thr).astype(np.int32)

            # strong flags
            tmp["strong_abs"] = (tmp["snr_abs"].to_numpy() >= SNR_STRONG_THR).astype(np.int32)
            tmp["strong_pos"] = (tmp["snr"].to_numpy() >= SNR_STRONG_THR).astype(np.int32)
            tmp["strong_neg"] = (tmp["snr"].to_numpy() <= -SNR_STRONG_THR).astype(np.int32)

            # aggregate per object
            agg_dict = {
                "neg_count": ("neg", "sum"),
                "pos_count": ("pos", "sum"),
                "ferr0_count": ("ferr0", "sum"),
                "snr_abs_max": ("snr_abs", "max"),
                "snr_pos_max": ("snr_pos", "max"),
                "snr_neg_min": ("snr_neg", "min"),
                "flux_max": ("flux", "max"),
                "flux_min": ("flux", "min"),
                "flux_sum": ("flux", "sum"),
                "ferr_sum": ("ferr", "sum"),
                "n_valid": ("flux", "count"),
            }
            agg = tmp.groupby("object_id").agg(**agg_dict)

            # sumsq for std
            tmp["flux2"] = tmp["flux"].to_numpy() * tmp["flux"].to_numpy()
            agg2 = tmp.groupby("object_id").agg(flux_sumsq=("flux2", "sum"))

            # counters are widened to int64 for the addition, then stored back as int32
            idx = agg.index
            obj_df.loc[idx, "neg_count"] = (obj_df.loc[idx, "neg_count"].astype(np.int64) + agg["neg_count"].astype(np.int64)).astype(np.int32)
            obj_df.loc[idx, "pos_count"] = (obj_df.loc[idx, "pos_count"].astype(np.int64) + agg["pos_count"].astype(np.int64)).astype(np.int32)
            obj_df.loc[idx, "ferr0_count"] = (obj_df.loc[idx, "ferr0_count"].astype(np.int64) + agg["ferr0_count"].astype(np.int64)).astype(np.int32)

            obj_df.loc[idx, "flux_sum"] = obj_df.loc[idx, "flux_sum"].to_numpy() + agg["flux_sum"].to_numpy()
            obj_df.loc[idx, "flux_sumsq"] = obj_df.loc[idx, "flux_sumsq"].to_numpy() + agg2["flux_sumsq"].reindex(idx).to_numpy()
            obj_df.loc[idx, "ferr_sum"] = obj_df.loc[idx, "ferr_sum"].to_numpy() + agg["ferr_sum"].to_numpy()
            obj_df.loc[idx, "n_valid"] = (obj_df.loc[idx, "n_valid"].astype(np.int64) + agg["n_valid"].astype(np.int64)).astype(np.int32)

            obj_df.loc[idx, "snr_abs_max"] = np.maximum(obj_df.loc[idx, "snr_abs_max"].to_numpy(), agg["snr_abs_max"].to_numpy()).astype(np.float32)
            obj_df.loc[idx, "snr_pos_max"] = np.maximum(obj_df.loc[idx, "snr_pos_max"].to_numpy(), agg["snr_pos_max"].to_numpy()).astype(np.float32)
            obj_df.loc[idx, "snr_neg_min"] = np.minimum(obj_df.loc[idx, "snr_neg_min"].to_numpy(), agg["snr_neg_min"].to_numpy()).astype(np.float32)

            obj_df.loc[idx, "flux_max"] = np.maximum(obj_df.loc[idx, "flux_max"].to_numpy(), agg["flux_max"].to_numpy()).astype(np.float32)
            obj_df.loc[idx, "flux_min"] = np.minimum(obj_df.loc[idx, "flux_min"].to_numpy(), agg["flux_min"].to_numpy()).astype(np.float32)

            # multi-threshold det counts update
            det_aggs = {}
            for thr in SNR_DET_THR_LIST:
                key = int(round(thr * 10))
                det_aggs[f"det_abs_{key}"] = (f"det_abs_{key}", "sum")
                det_aggs[f"det_pos_{key}"] = (f"det_pos_{key}", "sum")
                det_aggs[f"det_neg_{key}"] = (f"det_neg_{key}", "sum")
            det_sum = tmp.groupby("object_id").agg(**det_aggs)

            for thr in SNR_DET_THR_LIST:
                key = int(round(thr * 10))
                obj_df.loc[det_sum.index, f"snr_det_abs_{key}"] = (obj_df.loc[det_sum.index, f"snr_det_abs_{key}"].astype(np.int64) + det_sum[f"det_abs_{key}"].astype(np.int64)).astype(np.int32)
                obj_df.loc[det_sum.index, f"snr_det_pos_{key}"] = (obj_df.loc[det_sum.index, f"snr_det_pos_{key}"].astype(np.int64) + det_sum[f"det_pos_{key}"].astype(np.int64)).astype(np.int32)
                obj_df.loc[det_sum.index, f"snr_det_neg_{key}"] = (obj_df.loc[det_sum.index, f"snr_det_neg_{key}"].astype(np.int64) + det_sum[f"det_neg_{key}"].astype(np.int64)).astype(np.int32)

            # strong counts
            strong_sum = tmp.groupby("object_id").agg(
                strong_abs=("strong_abs", "sum"),
                strong_pos=("strong_pos", "sum"),
                strong_neg=("strong_neg", "sum"),
            )
            obj_df.loc[strong_sum.index, "snr_strong_abs"] = (obj_df.loc[strong_sum.index, "snr_strong_abs"].astype(np.int64) + strong_sum["strong_abs"].astype(np.int64)).astype(np.int32)
            obj_df.loc[strong_sum.index, "snr_strong_pos"] = (obj_df.loc[strong_sum.index, "snr_strong_pos"].astype(np.int64) + strong_sum["strong_pos"].astype(np.int64)).astype(np.int32)
            obj_df.loc[strong_sum.index, "snr_strong_neg"] = (obj_df.loc[strong_sum.index, "snr_strong_neg"].astype(np.int64) + strong_sum["strong_neg"].astype(np.int64)).astype(np.int32)

            # per-band det counts (base thr only; abs>=SNR_DET_THR)
            tmp["det_base"] = (tmp["snr_abs"].to_numpy() >= SNR_DET_THR).astype(np.int32)
            det_band = tmp[tmp["det_base"] == 1].groupby(["object_id","filter"]).size().unstack(fill_value=0)
            for b in BANDS:
                if b in det_band.columns:
                    obj_df.loc[det_band.index, f"det_{b}"] = (obj_df.loc[det_band.index, f"det_{b}"].astype(np.int64) + det_band[b].astype(np.int64)).astype(np.int32)

    # Row counters for light-curve rows whose object_id is absent from the logs.
    unknown_tr, unknown_te = {"unknown_rows": 0}, {"unknown_rows": 0}

    # Scan every split's train and test file into the matching accumulator table.
    for sp in SPLIT_LIST:
        sd = SPLIT_DIRS[sp]
        _update_obj_stats(obj_df=obj_train, csv_path=sd / "train_full_lightcurves.csv", unknown_counter=unknown_tr)
        _update_obj_stats(obj_df=obj_test, csv_path=sd / "test_full_lightcurves.csv", unknown_counter=unknown_te)

    def _finalize_obj(df: pd.DataFrame, is_train: bool):
        """Turn the raw per-object accumulators into final quality features.

        Works on a copy of *df* and adds: ``timespan``, sign and detection
        fractions (relative to ``n_obs_total``), per-band fractions and band
        coverage counts, ``cadence_proxy``, and flux mean/std and mean error
        (relative to ``n_valid``). Accumulator sentinels (+/-inf) are replaced by
        NaN. With ``USE_REST`` and a ``Z`` column, rest-frame variants divided by
        ``1 + Z`` are added, plus ``zerr_rel`` when ``Z_err`` exists.

        ``is_train`` is accepted for call-site symmetry and is not used here.
        """
        df = df.copy()
        n = df["n_obs_total"].astype(np.float64).to_numpy()
        n_safe = np.maximum(n, 1.0)  # avoid division by zero for objects without rows

        # time features
        df["timespan"] = (df["tmax"] - df["tmin"]).astype(np.float64)
        df.loc[~np.isfinite(df["timespan"]), "timespan"] = np.nan

        # untouched accumulators still hold their +/-inf start values
        df.loc[np.isinf(df["tmin"]), "tmin"] = np.nan
        df.loc[np.isinf(df["tmax"]), "tmax"] = np.nan
        df.loc[np.isinf(df["flux_max"]), "flux_max"] = np.nan
        df.loc[np.isinf(df["flux_min"]), "flux_min"] = np.nan

        # ratios
        df["neg_flux_frac"] = (df["neg_count"].astype(np.float64) / n_safe).astype(np.float32)
        df["pos_flux_frac"] = (df["pos_count"].astype(np.float64) / n_safe).astype(np.float32)

        # multi-threshold fractions
        for thr in SNR_DET_THR_LIST:
            key = int(round(thr * 10))
            df[f"snr_det_abs_frac_{key}"] = (df[f"snr_det_abs_{key}"].astype(np.float64) / n_safe).astype(np.float32)
            df[f"snr_det_pos_frac_{key}"] = (df[f"snr_det_pos_{key}"].astype(np.float64) / n_safe).astype(np.float32)
            df[f"snr_det_neg_frac_{key}"] = (df[f"snr_det_neg_{key}"].astype(np.float64) / n_safe).astype(np.float32)

        df["snr_strong_abs_frac"] = (df["snr_strong_abs"].astype(np.float64) / n_safe).astype(np.float32)
        df["snr_strong_pos_frac"] = (df["snr_strong_pos"].astype(np.float64) / n_safe).astype(np.float32)
        df["snr_strong_neg_frac"] = (df["snr_strong_neg"].astype(np.float64) / n_safe).astype(np.float32)

        # band coverage
        band_present = []
        for b in BANDS:
            band_present.append((df[f"n_{b}"].to_numpy() > 0).astype(np.int8))
            df[f"frac_{b}"] = (df[f"n_{b}"].astype(np.float64) / n_safe).astype(np.float32)

        df["n_bands_present"] = np.clip(np.sum(np.vstack(band_present), axis=0), 0, 6).astype(np.int8)

        # det band coverage (base thr)
        det_present = []
        for b in BANDS:
            det_present.append((df[f"det_{b}"].to_numpy() > 0).astype(np.int8))
            df[f"det_frac_{b}"] = (df[f"det_{b}"].astype(np.float64) / n_safe).astype(np.float32)
        df["n_bands_det"] = np.clip(np.sum(np.vstack(det_present), axis=0), 0, 6).astype(np.int8)

        # cadence proxy: time span per gap between observations
        df["cadence_proxy"] = (df["timespan"].astype(np.float64) / np.maximum(df["n_obs_total"].astype(np.float64) - 1.0, 1.0)).astype(np.float32)

        # flux / err stats from sums
        nv = np.maximum(df["n_valid"].astype(np.float64).to_numpy(), 1.0)
        flux_mean64 = df["flux_sum"].astype(np.float64).to_numpy() / nv
        df["flux_mean"] = flux_mean64.astype(np.float32)
        # var = E[x^2] - mean^2, evaluated with the float64 mean: squaring the
        # float32-rounded mean loses the variance when std << |mean|
        ex2 = (df["flux_sumsq"].astype(np.float64).to_numpy() / nv)
        var = np.maximum(ex2 - flux_mean64 ** 2, 0.0)
        df["flux_std"] = np.sqrt(var).astype(np.float32)
        df["ferr_mean"] = (df["ferr_sum"].astype(np.float64).to_numpy() / nv).astype(np.float32)

        # rest-frame
        if USE_REST and ("Z" in df.columns):
            z = pd.to_numeric(df["Z"], errors="coerce").fillna(0.0).to_numpy(dtype=np.float64)
            denom = np.maximum(1.0 + z, REST_EPS)  # floor keeps the divisor positive
            df["timespan_rest"] = (df["timespan"].astype(np.float64).to_numpy() / denom).astype(np.float32)
            df["cadence_proxy_rest"] = (df["cadence_proxy"].astype(np.float64).to_numpy() / denom).astype(np.float32)
            # zerr_rel (domain shift)
            if "Z_err" in df.columns:
                zerr = pd.to_numeric(df["Z_err"], errors="coerce").fillna(0.0).to_numpy(dtype=np.float64)
                df["zerr_rel"] = (zerr / denom).astype(np.float32)

        return df

    obj_train_f = _finalize_obj(obj_train, is_train=True)
    obj_test_f = _finalize_obj(obj_test, is_train=False)

    # fail-fast: objects from the logs that received no light-curve rows at all
    z0_tr = int(obj_train_f["n_obs_total"].eq(0).sum())
    z0_te = int(obj_test_f["n_obs_total"].eq(0).sum())
    r0_tr = z0_tr / max(len(obj_train_f), 1)
    r0_te = z0_te / max(len(obj_test_f), 1)

    if max(r0_tr, r0_te) > ZEROOBS_FAIL_RATE:
        raise RuntimeError(
            "[FAIL-FAST] Terlalu banyak object_id dengan 0 observasi (routing/reading issue).\n"
            f"- train_zero_obs: {z0_tr}/{len(obj_train_f)} ({r0_tr:.3%})\n"
            f"- test_zero_obs : {z0_te}/{len(obj_test_f)} ({r0_te:.3%})\n"
            f"- threshold      : {ZEROOBS_FAIL_RATE:.3%}\n"
        )

    # persist both tables with object_id restored as a regular column
    objq_train_path = LOG_DIR / "object_quality_train.csv"
    objq_test_path = LOG_DIR / "object_quality_test.csv"
    obj_train_f.reset_index().to_csv(objq_train_path, index=False)
    obj_test_f.reset_index().to_csv(objq_test_path, index=False)

    # split-level summary
    def _split_summary(df: pd.DataFrame, is_train: bool):
        g = df.groupby("split", dropna=False)
        rows = []
        for sp, d in g:
            r = {
                "split": sp,
                "n_objects": int(len(d)),
                "n_obs_total_sum": int(d["n_obs_total"].sum()),
                "n_obs_total_med": float(d["n_obs_total"].median()),
                "timespan_med": float(d["timespan"].median(skipna=True)),
                "neg_flux_frac_med": float(d["neg_flux_frac"].median(skipna=True)),
                "pos_flux_frac_med": float(d["pos_flux_frac"].median(skipna=True)),
                "snr_abs_max_p95": float(np.nanpercentile(d["snr_abs_max"].to_numpy(), 95)),
                "snr_pos_max_p95": float(np.nanpercentile(d["snr_pos_max"].to_numpy(), 95)),
                "snr_neg_min_p05": float(np.nanpercentile(d["snr_neg_min"].to_numpy(), 5)),
                "n_bands_present_med": float(d["n_bands_present"].median()),
                "n_bands_det_med": float(d["n_bands_det"].median()),
                "cadence_proxy_med": float(d["cadence_proxy"].median(skipna=True)),
                "flux_std_med": float(d["flux_std"].median(skipna=True)),
            }
            # include base det fractions
            key0 = int(round(SNR_DET_THR * 10))
            if f"snr_det_abs_frac_{key0}" in d.columns:
                r["snr_det_abs_frac_med_base"] = float(d[f"snr_det_abs_frac_{key0}"].median(skipna=True))
                r["snr_det_pos_frac_med_base"] = float(d[f"snr_det_pos_frac_{key0}"].median(skipna=True))
                r["snr_det_neg_frac_med_base"] = float(d[f"snr_det_neg_frac_{key0}"].median(skipna=True))

            for b in BANDS:
                r[f"frac_{b}_med"] = float(d[f"frac_{b}"].median(skipna=True))
                r[f"det_frac_{b}_med"] = float(d[f"det_frac_{b}"].median(skipna=True))

            if USE_REST and "timespan_rest" in d.columns:
                r["timespan_rest_med"] = float(d["timespan_rest"].median(skipna=True))
                r["cadence_proxy_rest_med"] = float(d["cadence_proxy_rest"].median(skipna=True))

            if is_train and "target" in d.columns:
                r["pos"] = int((d["target"] == 1).sum())
                r["pos_pct"] = float((d["target"] == 1).mean() * 100.0)
            rows.append(r)
        return pd.DataFrame(rows).sort_values("split").reset_index(drop=True)

    # Summarize train and test separately, tag each table with its origin,
    # then stack them into a single frame.
    split_train = _split_summary(obj_train_f, is_train=True)
    split_test  = _split_summary(obj_test_f,  is_train=False)
    split_train["kind"] = "train"
    split_test["kind"]  = "test"
    df_splitq = pd.concat([split_train, split_test], ignore_index=True)

    # Persist the combined split-level summary in LOG_DIR.
    splitq_path = LOG_DIR / "split_quality_summary.csv"
    df_splitq.to_csv(splitq_path, index=False)

# ----------------------------
# 8) Summary prints + JSON summary
# ----------------------------
# Wall-clock seconds since t0 (t0 is set earlier in the cell, outside this view).
elapsed = time.time() - t0

# Top-10 light-curve files by number of sampled ids that were not found
# (ties broken by id_check_k, both descending).
worst_id_missing = (
    df_lc_stats.sort_values(["id_missing", "id_check_k"], ascending=[False, False])
    .head(10)[["split","kind","id_missing","id_check_k","id_scan_chunks","id_scan_cap_used","file_mb"]]
)
# Top-10 files by fraction of negative flux values.
worst_flux_neg = (
    df_lc_stats.sort_values("flux_neg_frac", ascending=False)
    .head(10)[["split","kind","flux_neg_frac","snr_ge_det_frac","file_mb"]]
)
# Top-10 files by fraction of rows at or above the detection SNR threshold.
worst_snr = (
    df_lc_stats.sort_values("snr_ge_det_frac", ascending=False)
    .head(10)[["split","kind","snr_ge_det_frac","snr_ge_strong_frac","snr_abs_p95","file_mb"]]
)

# Run configuration and output locations, serialized to JSON below.
summary = {
    "stage": "stage1",
    "data_root": str(DATA_ROOT),
    "log_dir": str(LOG_DIR),
    "lc_validate_mode": LC_VALIDATE_MODE,
    "head_rows": HEAD_ROWS,
    "sample_id_per_split": SAMPLE_ID_PER_SPLIT,
    "chunk_rows": CHUNK_ROWS,
    "obj_quality_chunk_rows": OBJ_QUALITY_CHUNK_ROWS,
    "build_object_quality": bool(BUILD_OBJECT_QUALITY),
    "snr": {
        "SNR_CLIP": SNR_CLIP,
        "SNR_DET_THR_LIST": SNR_DET_THR_LIST,
        "SNR_STRONG_THR": SNR_STRONG_THR,
        "MIN_FLUXERR": MIN_FLUXERR,
    },
    "rest_frame": {
        "USE_REST_FRAME_TIME": USE_REST,
        "REST_EPS": REST_EPS,
    },
    "thresholds": {
        "MAX_TIME_NA_FRAC": MAX_TIME_NA_FRAC,
        "MAX_FERR_NA_FRAC": MAX_FERR_NA_FRAC,
        "ID_MISS_FAIL_FRAC": ID_MISS_FAIL_FRAC,
        "FAIL_FAST_MISSING_RATE": FAIL_FAST_MISSING_RATE,
        "MIN_SAMPLE_ROWS": MIN_SAMPLE_ROWS,
        "ZEROOBS_FAIL_RATE": ZEROOBS_FAIL_RATE,
    },
    "aggregate_id_missing": {
        "total_sample_ids": int(agg_id_total),
        "missing_sample_ids": int(agg_id_missing),
        "missing_rate": float(agg_missing_rate) if agg_missing_rate is not None else None
    },
    "warn_flux_na_files": int(warn_flux_na_files),
    "routing_csv": str(routing_path),
    "lc_sample_stats_csv": str(lc_stats_path),
    "lc_id_presence_warnings_csv": str(id_warn_path),
    # NOTE(review): the three paths below are assigned inside the object-quality
    # branch above; presumably they are pre-initialized to None before that
    # branch so these truthiness checks cannot hit a NameError — confirm.
    "object_quality_train_csv": str(objq_train_path) if objq_train_path else None,
    "object_quality_test_csv": str(objq_test_path) if objq_test_path else None,
    "split_quality_summary_csv": str(splitq_path) if splitq_path else None,
    "elapsed_sec": float(elapsed),
}

# Write the summary as indented UTF-8 JSON in LOG_DIR.
summary_path = LOG_DIR / "stage1_summary.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

# Human-readable recap of every artifact written by this stage.
print("\nSTAGE 1 OK — ROUTING + PROFILING + OBJECT_QUALITY READY")
print(f"- routing saved: {routing_path}")
print(f"- lc sample stats saved: {lc_stats_path}")
print(f"- id warnings saved: {id_warn_path}")
# Object-quality outputs are listed only when a train path is set.
if objq_train_path:
    print(f"- object_quality_train: {objq_train_path}")
    print(f"- object_quality_test : {objq_test_path}")
    print(f"- split_quality_summary: {splitq_path}")
print(f"- summary json saved: {summary_path}")
print(f"- elapsed: {elapsed/60:.2f} min | warn_flux_na_files={warn_flux_na_files}")

print("\nTOP ISSUES (ID missing in sample/adaptive scan)")
print(worst_id_missing.to_string(index=False))

print("\nTOP PATTERN (highest negative flux fraction in sample head)")
print(worst_flux_neg.to_string(index=False))

print("\nTOP PATTERN (highest high-SNR fraction in sample head)")
print(worst_snr.to_string(index=False))

# ----------------------------
# 9) Export to globals
# ----------------------------
# Publish stage-1 results as notebook globals; STAGE 2 in this file reads
# SPLIT_DIRS and the OBJECT_QUALITY_*_PATH entries via globals().
globals().update({
    "DATA_ROOT": DATA_ROOT,
    "SPLIT_DIRS": SPLIT_DIRS,
    "SPLIT_LIST": SPLIT_LIST,
    "df_split_routing": df_routing,
    "df_lc_sample_stats": df_lc_stats,
    "df_lc_id_presence_warnings": df_id_warn,
    "STAGE1_SUMMARY_PATH": summary_path,
    # Paths are exported as strings, or None when the path variable is falsy.
    "OBJECT_QUALITY_TRAIN_PATH": str(objq_train_path) if objq_train_path else None,
    "OBJECT_QUALITY_TEST_PATH": str(objq_test_path) if objq_test_path else None,
    "SPLIT_QUALITY_SUMMARY_PATH": str(splitq_path) if splitq_path else None,
})

# Run a garbage-collection pass before the next stage.
gc.collect()
print("\nStage 1 complete: splits verified + routing/stats + object_quality exported.")


# Load and Validate Train/Test Logs

In [None]:
# ============================================================
# STAGE 2 — Clean Meta Logs + CV Fold Assignment + Meta Enrichment (ONE CELL)
# REVISI FULL v6.1 (STAGE1 v5.1 compatible + safe objq join + SNR alias + robust clipping)
#
# v6.1 upgrade:
# - Join object_quality: drop duplicate cols (split/target/Z/Z_err/EBV/has_zerr) sebelum join (no _oq noise)
# - Auto-detect & include ALL useful objq feature cols (multi-threshold + pos/neg + det band + flux stats)
# - Alias safety:
#   * snr_det_frac <= snr_det_abs_frac_{base_key} jika ada
#   * snr_strong_frac <= snr_strong_abs_frac jika ada
# - Z_err clip: prefer train photo-z pool, fallback test pool
# - Optional fold balancing add mean_log1p_obs when objq exists
# ============================================================

import re, gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require STAGE 0/1 globals
# ----------------------------
# Stop early when a global produced by the earlier stages is missing.
for need in ["PATHS", "ART_DIR", "SPLIT_DIRS", "CFG", "SEED", "LOG_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 & STAGE 1 dulu.")

TRAIN_LOG_PATH = Path(PATHS["TRAIN_LOG"])
TEST_LOG_PATH  = Path(PATHS["TEST_LOG"])

# Normalize to Path objects and make sure both output directories exist.
ART_DIR = Path(ART_DIR); ART_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR = Path(LOG_DIR); LOG_DIR.mkdir(parents=True, exist_ok=True)

SEED = int(SEED)
N_FOLDS = int(CFG.get("N_FOLDS", 5))
CV_USE_SPLIT_COL = bool(CFG.get("CV_USE_SPLIT_COL", True))

# deterministic split list (kept in sync with STAGE 0/1)
SPLIT_LIST = [Path(p).name for p in PATHS["SPLITS"]]
SPLIT_LIST = sorted([s for s in SPLIT_LIST if s.startswith("split_")])
VALID_SPLITS = set(SPLIT_LIST)

# The split names derived from PATHS must match the SPLIT_DIRS keys exactly.
disk_splits = set(SPLIT_DIRS.keys())
if disk_splits != VALID_SPLITS:
    miss = sorted(list(VALID_SPLITS - disk_splits))
    extra = sorted(list(disk_splits - VALID_SPLITS))
    raise RuntimeError(
        f"SPLIT_DIRS mismatch. missing={miss[:5]} extra={extra[:5]} (jalankan ulang STAGE 1)"
    )

# Extra tokens treated as missing by read_csv, on top of the pandas defaults.
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

# fold assignment tuning
# Default quota is ceil(n_splits / n_folds) splits per fold.
FOLD_QUOTA = int(CFG.get("SPLIT_PER_FOLD_QUOTA", int(np.ceil(len(SPLIT_LIST)/max(N_FOLDS,1)))))
RESTARTS = int(CFG.get("SPLIT_ASSIGN_RESTARTS", 512))
RESTARTS_HARD = int(CFG.get("SPLIT_ASSIGN_RESTARTS_HARD", 2048))
PENALTY_ZERO_POS = float(CFG.get("FOLD_BALANCE_PENALTY_ZERO_POS", 3.0))

# weights objective
LAMBDA_COUNT = float(CFG.get("FOLD_BALANCE_LAMBDA_COUNT", 0.25))
LAMBDA_QUOTA = float(CFG.get("FOLD_BALANCE_LAMBDA_QUOTA", 0.05))
LAMBDA_ZMEAN = float(CFG.get("FOLD_BALANCE_LAMBDA_ZMEAN", 0.15))
LAMBDA_EBVMEAN = float(CFG.get("FOLD_BALANCE_LAMBDA_EBVMEAN", 0.10))
LAMBDA_OBSMEAN = float(CFG.get("FOLD_BALANCE_LAMBDA_OBSMEAN", 0.10))  # active only when objq exists

# clipping quantiles
QLO = float(CFG.get("META_QLO", 0.001))
QHI = float(CFG.get("META_QHI", 0.999))

# split prior smoothing
PRIOR_ALPHA = float(CFG.get("SPLIT_PRIOR_ALPHA", 10.0))

# optional join object_quality from stage1
# Prefer the paths exported by STAGE 1; otherwise look for the CSVs in LOG_DIR.
OBJQ_TRAIN = globals().get("OBJECT_QUALITY_TRAIN_PATH", None)
OBJQ_TEST  = globals().get("OBJECT_QUALITY_TEST_PATH", None)
if not OBJQ_TRAIN:
    p = LOG_DIR / "object_quality_train.csv"
    OBJQ_TRAIN = str(p) if p.exists() else None
if not OBJQ_TEST:
    p = LOG_DIR / "object_quality_test.csv"
    OBJQ_TEST = str(p) if p.exists() else None

# base snr det thr key (used for the snr_det_frac alias)
# The base threshold is the smallest finite entry of SNR_DET_THR_LIST, or
# SNR_DET_THR (default 3.0) when no list is configured.
# NOTE(review): STAGE 1 keys its base column on SNR_DET_THR; confirm both
# stages resolve to the same threshold when a list is configured.
_thr_list = CFG.get("SNR_DET_THR_LIST", None)
if isinstance(_thr_list, (list, tuple)) and len(_thr_list) > 0:
    _thr_list = [float(x) for x in _thr_list if np.isfinite(float(x))]
    _thr_list = sorted(list(dict.fromkeys(_thr_list)))
else:
    _thr_list = [float(CFG.get("SNR_DET_THR", 3.0))]
BASE_DET_THR = float(_thr_list[0]) if len(_thr_list) else 3.0
BASE_DET_KEY = int(round(BASE_DET_THR * 10))  # 3.0->30

# ----------------------------
# 1) Helpers
# ----------------------------
def normalize_split_name(x) -> str:
    """Return the canonical ``split_NN`` spelling of a split label.

    Accepted spellings are a bare number (``"3"``), ``split_3``, ``split-3``,
    ``split 3`` and ``split3``, in any letter case and with surrounding
    whitespace. They all become ``"split_03"``.

    Missing values (``None``, float NaN including numpy floats, ``pd.NA``) and
    blank strings yield ``""``. Anything that does not look like a split label
    is returned lower-cased with ``-`` and spaces replaced by ``_`` so the
    caller can report it as invalid.
    """
    # pd.NA is what a "string"-dtype Series hands to .map() for missing cells;
    # without this check it would be stringified to "<na>".
    if x is None or x is pd.NA:
        return ""
    if isinstance(x, (float, np.floating)) and np.isnan(x):
        return ""

    s = str(x).strip()
    if not s:
        return ""

    s2 = s.lower().replace("-", "_").replace(" ", "_")

    # isdecimal() (not isdigit()) so only characters int() can parse pass:
    # e.g. "²" is a digit but int("²") raises ValueError.
    if s2.isdecimal():
        return f"split_{int(s2):02d}"

    # "split_<n>" or "split<n>" with a one- or two-digit number.
    m = re.fullmatch(r"split_?(\d{1,2})", s2)
    if m:
        return f"split_{int(m.group(1)):02d}"
    return s2

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df.columns = [str(c).strip() for c in df.columns]
    return df

def _coerce_float32(df: pd.DataFrame, col: str):
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors="coerce").astype("float32")

def _safe_clip(series: pd.Series, lo: float, hi: float) -> pd.Series:
    # handle NaN bounds
    if not np.isfinite(lo): lo = float(np.nanmin(series.values.astype(float))) if np.isfinite(series.values.astype(float)).any() else 0.0
    if not np.isfinite(hi): hi = float(np.nanmax(series.values.astype(float))) if np.isfinite(series.values.astype(float)).any() else 0.0
    if hi < lo:
        lo, hi = hi, lo
    return series.clip(lower=np.float32(lo), upper=np.float32(hi)).astype("float32")

def _qclip_bounds(arr: np.ndarray, qlo=0.001, qhi=0.999, default=(0.0, 0.0)):
    x = np.asarray(arr, dtype=float)
    x = x[np.isfinite(x)]
    if len(x) == 0:
        return float(default[0]), float(default[1])
    lo, hi = np.quantile(x, [qlo, qhi])
    lo = float(lo); hi = float(hi)
    if not np.isfinite(lo): lo = float(default[0])
    if not np.isfinite(hi): hi = float(default[1])
    if hi < lo: lo, hi = hi, lo
    return lo, hi

def _load_or_use_global(global_name: str, path: Path) -> pd.DataFrame:
    """Return a log table with normalized column names.

    A DataFrame already stored in the notebook globals under ``global_name``
    is reused (as a copy). Otherwise the CSV at ``path`` is read with
    ``object_id`` and ``split`` forced to the pandas string dtype and the
    shared ``SAFE_READ_KW`` options.
    """
    cached = globals().get(global_name)
    if isinstance(cached, pd.DataFrame):
        return _norm_cols(cached.copy())

    id_dtypes = {"object_id":"string","split":"string"}
    loaded = pd.read_csv(path, dtype=id_dtypes, **SAFE_READ_KW)
    return _norm_cols(loaded)

def _fill_z(df: pd.DataFrame, split_med: dict, gmed: float):
    z = df["Z"].copy()
    if z.isna().any():
        z = z.fillna(df["split"].map(split_med))
        z = z.fillna(np.float32(gmed))
    return z.astype("float32")

def _read_objq(path: str) -> pd.DataFrame:
    """Load an object_quality CSV as a numeric feature table indexed by object_id.

    Parameters
    ----------
    path : str
        Location of the CSV written by STAGE 1.

    Returns
    -------
    pd.DataFrame
        Indexed by the stripped ``object_id`` string. Columns that repeat
        meta-log content (split, target, Z, Z_err, EBV, has_zerr, is_photoz)
        are removed so a later join cannot collide, and every remaining column
        is numeric: text columns are converted when all their values parse and
        dropped otherwise.

    Raises
    ------
    ValueError
        If the file has no ``object_id`` column.
    """
    df = pd.read_csv(path, dtype={"object_id": "string"})
    df = _norm_cols(df)
    if "object_id" not in df.columns:
        raise ValueError(f"object_quality missing object_id: {path}")
    df["object_id"] = df["object_id"].astype("string").str.strip()

    # drop columns that duplicate meta log content (avoid _oq collisions)
    drop_dup = {"split", "target", "Z", "Z_err", "EBV", "has_zerr", "is_photoz"}
    feature_cols = [c for c in df.columns if c not in drop_dup and c != "object_id"]
    df = df[["object_id"] + feature_cols].copy()

    # Keep the table numeric. Conversion is all-or-nothing per column, which is
    # what the deprecated ``errors="ignore"`` mode did. The dtype test is
    # "not numeric" rather than "is object" so string-typed columns are also
    # handled; columns are collected first and dropped once, instead of being
    # removed while iterating.
    non_numeric = []
    for c in feature_cols:
        if pd.api.types.is_numeric_dtype(df[c]):
            continue
        try:
            df[c] = pd.to_numeric(df[c])
        except (ValueError, TypeError):
            non_numeric.append(c)
    if non_numeric:
        df = df.drop(columns=non_numeric)

    return df.set_index("object_id", drop=True)

def _assign_splits_to_folds_greedy_multi(sp_stat: pd.DataFrame, n_folds: int, quota: int, seed: int,
                                        lam_count: float, lam_quota: float, lam_z: float, lam_ebv: float, lam_obs: float,
                                        penalty_zero_pos: float, restarts: int):
    """Assign whole splits to CV folds with a randomized-restart greedy search.

    Parameters
    ----------
    sp_stat : one row per split with columns ``split``, ``n`` (rows), ``pos``
        (positives), ``mean_log1pZ``, ``mean_EBV`` and optionally
        ``mean_log1pObs``.
    n_folds : number of folds.
    quota : maximum number of splits per fold; ignored for a split once every
        fold has reached it.
    seed : seed of the numpy Generator driving jitter and tie-breaks.
    lam_count, lam_quota, lam_z, lam_ebv, lam_obs : weights of the greedy
        score terms (fold size, quota usage, Z mean, EBV mean, obs mean).
    penalty_zero_pos : added to the final score once per fold without positives.
    restarts : number of greedy passes (at least one is always run).

    Returns
    -------
    tuple ``(score, split2fold, fold_n, fold_pos, fold_k, fold_zsum,
    fold_esum, fold_osum, zero_pos)`` of the lowest-scoring pass.

    Note: the result depends on the exact sequence of RNG draws, so the order
    of the statements that use ``rng`` must not change.
    """
    rng = np.random.default_rng(seed)

    # Global targets every fold is compared against.
    global_pos_rate = float(sp_stat["pos"].sum() / max(sp_stat["n"].sum(), 1))
    target_fold_n = float(sp_stat["n"].sum() / max(n_folds, 1))

    # Row-count-weighted global means of the balancing covariates.
    wsum = float(sp_stat["n"].sum())
    g_z = float((sp_stat["n"] * sp_stat["mean_log1pZ"]).sum() / max(wsum, 1.0))
    g_e = float((sp_stat["n"] * sp_stat["mean_EBV"]).sum() / max(wsum, 1.0))
    g_o = float((sp_stat["n"] * sp_stat["mean_log1pObs"]).sum() / max(wsum, 1.0)) if "mean_log1pObs" in sp_stat.columns else 0.0

    # Place splits with the most positives (then the most rows) first.
    base_order = sp_stat.sort_values(["pos","n"], ascending=False).reset_index(drop=True)

    best = None

    for _ in range(max(restarts, 1)):
        # Tiny random jitter column: only reorders splits tied on (pos, n).
        order = base_order.copy()
        order["_j"] = rng.normal(0, 1e-6, size=len(order))
        order = order.sort_values(["pos","n","_j"], ascending=[False,False,True]).drop(columns=["_j"]).reset_index(drop=True)

        # Running per-fold totals: rows, positives, split count and the
        # n-weighted sums of the three covariates.
        fold_n = np.zeros(n_folds, dtype=float)
        fold_pos = np.zeros(n_folds, dtype=float)
        fold_k = np.zeros(n_folds, dtype=int)
        fold_zsum = np.zeros(n_folds, dtype=float)
        fold_esum = np.zeros(n_folds, dtype=float)
        fold_osum = np.zeros(n_folds, dtype=float)

        split2fold = {}

        for _, r in order.iterrows():
            sp = r["split"]; n = float(r["n"]); p = float(r["pos"])
            zmean = float(r["mean_log1pZ"]); eme = float(r["mean_EBV"])
            omean = float(r["mean_log1pObs"]) if "mean_log1pObs" in r else 0.0

            # Candidate folds are those still under quota; if none is left,
            # every fold becomes a candidate again.
            cand = np.where(fold_k < quota)[0]
            if len(cand) == 0:
                cand = np.arange(n_folds)

            scores = []
            for f in cand:
                # Fold statistics as they would be after adding this split.
                n2 = fold_n[f] + n
                p2 = fold_pos[f] + p
                pr2 = (p2 / n2) if n2 > 0 else global_pos_rate

                z2 = (fold_zsum[f] + n * zmean) / max(n2, 1.0)
                e2 = (fold_esum[f] + n * eme) / max(n2, 1.0)
                o2 = (fold_osum[f] + n * omean) / max(n2, 1.0)

                # Weighted sum of absolute deviations from the global targets,
                # plus a term that grows with the fold's current split count.
                score = abs(pr2 - global_pos_rate) \
                        + lam_count * abs(n2 - target_fold_n) / max(target_fold_n, 1.0) \
                        + lam_z * abs(z2 - g_z) \
                        + lam_ebv * abs(e2 - g_e) \
                        + lam_obs * abs(o2 - g_o) \
                        + lam_quota * (fold_k[f] / max(quota, 1))
                scores.append(score)

            # Take the lowest-scoring fold; exact ties are broken at random.
            scores = np.asarray(scores, dtype=float)
            best_idx = np.where(scores == scores.min())[0]
            choose = int(cand[int(rng.choice(best_idx))]) if len(best_idx) > 1 else int(cand[int(best_idx[0])])

            split2fold[sp] = choose
            fold_n[choose] += n
            fold_pos[choose] += p
            fold_k[choose] += 1
            fold_zsum[choose] += n * zmean
            fold_esum[choose] += n * eme
            fold_osum[choose] += n * omean

        # Score of the whole pass: total positive-rate deviation plus the
        # relative spread of fold sizes (covariate terms are not included here).
        fold_pr = np.divide(fold_pos, np.maximum(fold_n, 1e-9))
        score = float(np.sum(np.abs(fold_pr - global_pos_rate))) \
                + float(lam_count * np.std(fold_n) / max(target_fold_n, 1.0))

        # Each fold without any positive adds a fixed penalty.
        zero_pos = int(np.sum(fold_pos == 0))
        score += penalty_zero_pos * zero_pos

        cand_pack = (score, split2fold, fold_n.copy(), fold_pos.copy(), fold_k.copy(),
                     fold_zsum.copy(), fold_esum.copy(), fold_osum.copy(), zero_pos)

        if best is None or cand_pack[0] < best[0]:
            best = cand_pack

        # Early exit once the best pass has no empty-positive fold and a
        # score below 0.01.
        if best is not None and best[-1] == 0 and best[0] < 0.01:
            break

    return best

# ----------------------------
# 2) Load logs
# ----------------------------
# Reuse the in-memory log tables when present, otherwise read them from disk.
df_train = _load_or_use_global("df_train_log", TRAIN_LOG_PATH)
df_test  = _load_or_use_global("df_test_log",  TEST_LOG_PATH)

# ----------------------------
# 3) Required columns check
# ----------------------------
# Both logs need the id, split and the two meta columns; train also needs the label.
req_common = {"object_id", "split", "EBV", "Z"}
req_train  = req_common | {"target"}
req_test   = req_common

miss_train = sorted(list(req_train - set(df_train.columns)))
miss_test  = sorted(list(req_test  - set(df_test.columns)))
if miss_train:
    raise ValueError(f"train_log missing columns: {miss_train} | found={list(df_train.columns)}")
if miss_test:
    raise ValueError(f"test_log missing columns: {miss_test} | found={list(df_test.columns)}")

# ----------------------------
# 4) Basic cleaning
# ----------------------------
# Ids are compared as stripped strings everywhere downstream.
df_train["object_id"] = df_train["object_id"].astype("string").str.strip()
df_test["object_id"]  = df_test["object_id"].astype("string").str.strip()

# Bring every split label to its canonical spelling before validating it.
df_train["split"] = df_train["split"].astype("string").map(normalize_split_name)
df_test["split"]  = df_test["split"].astype("string").map(normalize_split_name)

# Any label that is not one of the known splits is a hard error.
bad_train_split = sorted(set(df_train["split"].unique()) - VALID_SPLITS)
bad_test_split  = sorted(set(df_test["split"].unique())  - VALID_SPLITS)
if bad_train_split:
    raise ValueError(f"train_log invalid split values: {bad_train_split[:10]}")
if bad_test_split:
    raise ValueError(f"test_log invalid split values: {bad_test_split[:10]}")

# ----------------------------
# 5) Ensure Z_err exists + numeric coercion
# ----------------------------
# Z_err is optional in the logs; create it as all-missing when absent.
if "Z_err" not in df_train.columns:
    df_train["Z_err"] = np.nan
if "Z_err" not in df_test.columns:
    df_test["Z_err"] = np.nan

# Numeric meta columns become float32; unparseable values become NaN.
for c in ["EBV","Z","Z_err"]:
    _coerce_float32(df_train, c)
    _coerce_float32(df_test, c)

# has_zerr is 1 where Z_err is present (computed before Z_err is filled below);
# is_photoz is a copy of that flag.
df_train["has_zerr"] = (~pd.to_numeric(df_train["Z_err"], errors="coerce").isna()).astype("int8")
df_test["has_zerr"]  = (~pd.to_numeric(df_test["Z_err"],  errors="coerce").isna()).astype("int8")
df_train["is_photoz"] = df_train["has_zerr"].astype("int8")
df_test["is_photoz"]  = df_test["has_zerr"].astype("int8")

# ----------------------------
# 6) Duplicate / overlap checks
# ----------------------------
# object_id must be unique within each log; the message shows up to 5 examples.
if df_train["object_id"].duplicated().any():
    ex = df_train.loc[df_train["object_id"].duplicated(), "object_id"].head(5).tolist()
    raise ValueError(f"Duplicated object_id in train_log (examples): {ex}")
if df_test["object_id"].duplicated().any():
    ex = df_test.loc[df_test["object_id"].duplicated(), "object_id"].head(5).tolist()
    raise ValueError(f"Duplicated object_id in test_log (examples): {ex}")

# The same object_id must not appear in both train and test.
overlap = set(df_train["object_id"].tolist()) & set(df_test["object_id"].tolist())
if overlap:
    raise ValueError(f"object_id overlap train vs test (examples): {list(overlap)[:5]}")

# ----------------------------
# 7) Target validation
# ----------------------------
# The label must parse as a number for every row and be strictly 0/1.
df_train["target"] = pd.to_numeric(df_train["target"], errors="coerce")
if df_train["target"].isna().any():
    raise ValueError(f"train_log target has NaN after coercion: {int(df_train['target'].isna().sum())} rows.")
uniq_t = set(pd.unique(df_train["target"]).tolist())
if not uniq_t.issubset({0,1}):
    raise ValueError(f"train_log target must be binary 0/1. Found: {sorted(list(uniq_t))}")
df_train["target"] = df_train["target"].astype("int8")

# ----------------------------
# 8) Missing flags + fills (safe)
# ----------------------------
# Record missingness before any fill so the information is not lost.
for df in [df_train, df_test]:
    df["EBV_missing"]  = df["EBV"].isna().astype("int8")
    df["Z_missing"]    = df["Z"].isna().astype("int8")
    df["Zerr_missing"] = df["Z_err"].isna().astype("int8")

# fill EBV with train median
# (0.0 when train has no finite EBV); the same value is applied to test.
ebv_med = float(np.nanmedian(df_train["EBV"].values.astype(float))) if np.isfinite(df_train["EBV"].values.astype(float)).any() else 0.0
df_train["EBV"] = df_train["EBV"].fillna(np.float32(ebv_med)).astype("float32")
df_test["EBV"]  = df_test["EBV"].fillna(np.float32(ebv_med)).astype("float32")

# fill Z with split median then global median (from train only)
train_split_med = df_train.groupby("split")["Z"].median().to_dict()
train_gmed = float(np.nanmedian(df_train["Z"].values.astype(float))) if np.isfinite(df_train["Z"].values.astype(float)).any() else 0.0
df_train["Z"] = _fill_z(df_train, train_split_med, train_gmed)
df_test["Z"]  = _fill_z(df_test,  train_split_med, train_gmed)

# fill Z_err with 0 (keep has_zerr as flag)
df_train["Z_err"] = df_train["Z_err"].fillna(np.float32(0.0)).astype("float32")
df_test["Z_err"]  = df_test["Z_err"].fillna(np.float32(0.0)).astype("float32")

# ----------------------------
# 9) Clipping + derived meta features
# ----------------------------
# Clip bounds come from the train distribution only and are applied to both sets.
EBV_LO, EBV_HI = _qclip_bounds(df_train["EBV"].values, QLO, QHI)
Z_LO,   Z_HI   = _qclip_bounds(df_train["Z"].values,   QLO, QHI)

df_train["EBV_clip"] = _safe_clip(df_train["EBV"], EBV_LO, EBV_HI)
df_test["EBV_clip"]  = _safe_clip(df_test["EBV"],  EBV_LO, EBV_HI)

df_train["Z_clip"] = _safe_clip(df_train["Z"], Z_LO, Z_HI)
df_test["Z_clip"]  = _safe_clip(df_test["Z"],  Z_LO, Z_HI)

# Z_err clip: prefer train photo-z pool, fallback to test photo-z pool
# Only rows that originally had a Z_err (has_zerr == 1) enter the pool, so the
# zero fill-ins do not drag the upper quantile down. The lower bound is fixed at 0.
ZE_LO = 0.0
pool_tr = df_train.loc[df_train["has_zerr"] == 1, "Z_err"].values
pool_te = df_test.loc[df_test["has_zerr"] == 1, "Z_err"].values
if np.isfinite(pool_tr.astype(float)).any():
    _, ZE_HI = _qclip_bounds(pool_tr, QLO, QHI, default=(0.0, 0.0))
elif np.isfinite(pool_te.astype(float)).any():
    _, ZE_HI = _qclip_bounds(pool_te, QLO, QHI, default=(0.0, 0.0))
else:
    ZE_HI = 0.0
ZE_HI = max(float(ZE_HI), 0.0)

df_train["Zerr_clip"] = _safe_clip(df_train["Z_err"], ZE_LO, ZE_HI)
df_test["Zerr_clip"]  = _safe_clip(df_test["Z_err"],  ZE_LO, ZE_HI)

# Small constant added to denominators below.
eps = np.float32(1e-6)

# Derived meta features, all computed from the clipped columns.
for df in [df_train, df_test]:
    df["log1pZ"] = np.log1p(df["Z_clip"]).astype("float32")
    df["inv_1pz"] = (1.0 / (1.0 + df["Z_clip"])).astype("float32")
    df["z2"] = (df["Z_clip"] * df["Z_clip"]).astype("float32")

    df["log1pEBV"] = np.log1p(df["EBV_clip"]).astype("float32")
    df["ebv_over_1pz"] = (df["EBV_clip"] / (1.0 + df["Z_clip"] + eps)).astype("float32")
    df["ebv_x_1pz"] = (df["EBV_clip"] * (1.0 + df["Z_clip"])).astype("float32")

    df["log1pZerr"] = np.log1p(df["Zerr_clip"]).astype("float32")
    # safer denom
    df["zerr_over_1pz"] = (df["Zerr_clip"] / (1.0 + df["Z_clip"] + eps)).astype("float32")
    df["zerr_rel_z"] = (df["Zerr_clip"] / (df["Z_clip"] + eps)).astype("float32")
    df["high_zerr"] = (df["zerr_over_1pz"] > np.float32(0.30)).astype("int8")  # tunable

# Integer id per split.
# NOTE(review): ids are hard-coded for split_01..split_20; a valid split outside
# that range would map to NaN and presumably make the int16 cast fail — confirm
# the split count is fixed at 20.
split2id = {f"split_{i:02d}": i for i in range(1, 21)}
df_train["split_id"] = df_train["split"].map(split2id).astype("int16")
df_test["split_id"]  = df_test["split"].map(split2id).astype("int16")

# ----------------------------
# 10) JOIN object_quality from STAGE1 (lebih awal, bisa dipakai fold balancing juga)
# ----------------------------
# Track whether any object_quality table was joined and which columns it brought.
objq_joined = False
oq_train_cols = []
oq_test_cols = []

# Left-join on object_id so every log row is kept; objects without a quality
# row get NaN in the joined columns. object_id stays a column (drop=False) and
# the temporary index is discarded afterwards.
if OBJQ_TRAIN and Path(OBJQ_TRAIN).exists():
    oq_tr = _read_objq(OBJQ_TRAIN)
    oq_train_cols = oq_tr.columns.tolist()
    df_train = df_train.set_index("object_id", drop=False).join(oq_tr, how="left").reset_index(drop=True)
    objq_joined = True

if OBJQ_TEST and Path(OBJQ_TEST).exists():
    oq_te = _read_objq(OBJQ_TEST)
    oq_test_cols = oq_te.columns.tolist()
    df_test = df_test.set_index("object_id", drop=False).join(oq_te, how="left").reset_index(drop=True)
    objq_joined = True

# unify objq cols intersection (train/test) for meta keep
oq_cols_common = sorted(list(set(oq_train_cols) & set(oq_test_cols)))

# fill + alias safety
# Runs only when train and test share at least one object_quality column.
if objq_joined and len(oq_cols_common) > 0:
    for df in [df_train, df_test]:
        # default fills
        # Missing counts become 0; the low_* flags use fixed cut-offs
        # (< 20 observations, <= 2 bands).
        if "n_obs_total" in df.columns:
            df["n_obs_total"] = pd.to_numeric(df["n_obs_total"], errors="coerce").fillna(0).astype("int32")
            df["low_obs"] = (df["n_obs_total"] < 20).astype("int8")
        if "n_bands_present" in df.columns:
            df["n_bands_present"] = pd.to_numeric(df["n_bands_present"], errors="coerce").fillna(0).astype("int8")
            df["low_bandcov"] = (df["n_bands_present"] <= 2).astype("int8")

        # alias det/strong fractions if stage1 v5.1 naming present
        # An existing snr_det_frac / snr_strong_frac column is never overwritten.
        cand_det = f"snr_det_abs_frac_{BASE_DET_KEY}"
        if ("snr_det_frac" not in df.columns) and (cand_det in df.columns):
            df["snr_det_frac"] = pd.to_numeric(df[cand_det], errors="coerce").fillna(0.0).astype("float32")
        if ("snr_strong_frac" not in df.columns) and ("snr_strong_abs_frac" in df.columns):
            df["snr_strong_frac"] = pd.to_numeric(df["snr_strong_abs_frac"], errors="coerce").fillna(0.0).astype("float32")

        # common numeric fills
        # Missing values in these columns become 0.0 (float32).
        for c in ["timespan","cadence_proxy","neg_flux_frac","snr_det_frac","snr_strong_frac",
                  "snr_abs_max","snr_pos_max","snr_neg_min","flux_mean","flux_std","ferr_mean",
                  "timespan_rest","cadence_proxy_rest","zerr_rel"]:
            if c in df.columns:
                df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0).astype("float32")

        # per band fracs
        for b in ["u","g","r","i","z","y"]:
            c = f"frac_{b}"
            if c in df.columns:
                df[c] = pd.to_numeric(df[c], errors="coerce").fillna(0.0).astype("float32")
            c2 = f"det_frac_{b}"
            if c2 in df.columns:
                df[c2] = pd.to_numeric(df[c2], errors="coerce").fillna(0.0).astype("float32")

# ----------------------------
# 11) Fold assignment (split-aware, balanced incl Z/EBV (+ optional obs))
# ----------------------------
# -1 marks "not assigned yet"; the check after both branches relies on it.
df_train["fold"] = -1

if CV_USE_SPLIT_COL:
    # Per-split statistics in SPLIT_LIST order; splits without train rows get 0.
    sp_pos = df_train.groupby("split")["target"].sum().reindex(SPLIT_LIST).fillna(0).astype(int)
    sp_n   = df_train.groupby("split")["target"].count().reindex(SPLIT_LIST).fillna(0).astype(int)

    sp_zm  = df_train.groupby("split")["log1pZ"].mean().reindex(SPLIT_LIST).fillna(0.0).astype("float32")
    sp_em  = df_train.groupby("split")["EBV_clip"].mean().reindex(SPLIT_LIST).fillna(0.0).astype("float32")

    # optional obs summary if objq exists
    # Note: this is log1p of the per-split mean n_obs_total, not the mean of
    # log1p values, even though the column is named mean_log1pObs.
    if "n_obs_total" in df_train.columns:
        sp_om = df_train.groupby("split")["n_obs_total"].mean().reindex(SPLIT_LIST).fillna(0.0).astype("float32")
        sp_om = np.log1p(sp_om).astype("float32")
    else:
        sp_om = pd.Series(np.zeros(len(SPLIT_LIST), dtype="float32"), index=SPLIT_LIST)

    sp_stat = pd.DataFrame({
        "split": SPLIT_LIST,
        "n": sp_n.values.astype(int),
        "pos": sp_pos.values.astype(int),
        "mean_log1pZ": sp_zm.values.astype(float),
        "mean_EBV": sp_em.values.astype(float),
        "mean_log1pObs": sp_om.values.astype(float),
    })

    # First search with the regular restart budget.
    best = _assign_splits_to_folds_greedy_multi(
        sp_stat, n_folds=N_FOLDS, quota=FOLD_QUOTA, seed=SEED,
        lam_count=LAMBDA_COUNT, lam_quota=LAMBDA_QUOTA,
        lam_z=LAMBDA_ZMEAN, lam_ebv=LAMBDA_EBVMEAN, lam_obs=LAMBDA_OBSMEAN,
        penalty_zero_pos=PENALTY_ZERO_POS, restarts=RESTARTS
    )
    if best is None:
        raise RuntimeError("split->fold assignment failed unexpectedly.")

    # If some fold still has no positives, retry with a different seed and the
    # larger restart budget, and keep the retry when it scores no worse.
    if best[-1] > 0 and RESTARTS_HARD > RESTARTS:
        best2 = _assign_splits_to_folds_greedy_multi(
            sp_stat, n_folds=N_FOLDS, quota=FOLD_QUOTA, seed=SEED + 999,
            lam_count=LAMBDA_COUNT, lam_quota=LAMBDA_QUOTA,
            lam_z=LAMBDA_ZMEAN, lam_ebv=LAMBDA_EBVMEAN, lam_obs=LAMBDA_OBSMEAN,
            penalty_zero_pos=PENALTY_ZERO_POS, restarts=RESTARTS_HARD
        )
        if best2 is not None and best2[0] <= best[0]:
            best = best2

    # Unpack the winning pass; the covariate sums in the middle are not needed here.
    score, split2fold, fold_n, fold_pos, fold_k, *_rest, zero_pos = best
    df_train["fold"] = df_train["split"].map(split2fold).astype("int16")

    # Every fold id 0..N_FOLDS-1 must be used; otherwise fall back to the
    # row-level stratified split below. The mapping is saved only on success.
    uniq_folds = sorted(df_train["fold"].unique().tolist())
    if uniq_folds != list(range(N_FOLDS)):
        print(f"[WARN] split->fold tidak memakai semua fold. uniq_folds={uniq_folds}. Fallback StratifiedKFold.")
        CV_USE_SPLIT_COL = False
    else:
        with open(ART_DIR / "split2fold.json", "w", encoding="utf-8") as f:
            json.dump({k:int(v) for k,v in split2fold.items()}, f)

# Row-level fallback (also the path taken when split-aware CV is disabled):
# each row's fold is the StratifiedKFold validation fold it lands in.
if not CV_USE_SPLIT_COL:
    from sklearn.model_selection import StratifiedKFold
    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
    y = df_train["target"].to_numpy()
    idx = np.arange(len(df_train))
    for fold_id, (_, va_idx) in enumerate(skf.split(idx, y)):
        df_train.iloc[va_idx, df_train.columns.get_loc("fold")] = fold_id
    df_train["fold"] = df_train["fold"].astype("int16")

# No row may keep the -1 placeholder.
if (df_train["fold"] < 0).any():
    n_bad = int((df_train["fold"] < 0).sum())
    raise RuntimeError(f"Fold assignment gagal: ada {n_bad} baris fold=-1")

# Per-fold row count, positive count and positive rate (for reporting).
fold_tab = (
    df_train.groupby("fold")["target"]
    .agg(["count","sum"])
    .rename(columns={"sum":"pos"})
    .reindex(range(N_FOLDS)).fillna(0)
)
fold_tab["pos_rate"] = fold_tab["pos"] / fold_tab["count"].clip(lower=1)

# ----------------------------
# 12) OOF split prior (smoothed)
# ----------------------------
# Global positive rate over all training rows.
g_pos = float(df_train["target"].sum())
g_n = float(len(df_train))
g_rate = g_pos / max(g_n, 1.0)

# Smoothed per-split positive rate: (pos + alpha * global_rate) / (n + alpha).
sp_all = df_train.groupby("split")["target"].agg(["count","sum"]).rename(columns={"count":"n","sum":"pos"})
sp_all["prior"] = (sp_all["pos"] + PRIOR_ALPHA * g_rate) / (sp_all["n"] + PRIOR_ALPHA)

# Test rows take the prior of their split; splits with no train rows get the global rate.
df_test["split_pos_prior"] = df_test["split"].map(sp_all["prior"]).fillna(g_rate).astype("float32")

# Out-of-fold variant for train: the prior for rows of fold f is computed from
# all other folds only. A split with no rows outside fold f has no entry in
# sp_f, so its rows receive that fold's global rate g_rate_f.
df_train["split_pos_prior_oof"] = np.float32(g_rate)
for f in range(N_FOLDS):
    tr_idx = df_train["fold"] != f
    sp_f = df_train.loc[tr_idx].groupby("split")["target"].agg(["count","sum"]).rename(columns={"count":"n","sum":"pos"})
    g_rate_f = float(df_train.loc[tr_idx, "target"].sum() / max(tr_idx.sum(), 1))
    sp_f["prior"] = (sp_f["pos"] + PRIOR_ALPHA * g_rate_f) / (sp_f["n"] + PRIOR_ALPHA)
    m = df_train["fold"] == f
    df_train.loc[m, "split_pos_prior_oof"] = df_train.loc[m, "split"].map(sp_f["prior"]).fillna(g_rate_f).astype("float32")

# In-sample prior for train: computed from all training rows, so it includes
# each row's own target (unlike split_pos_prior_oof above).
df_train["split_pos_prior"] = df_train["split"].map(sp_all["prior"]).fillna(g_rate).astype("float32")

# ----------------------------
# 13) Build meta tables (index=object_id)
# ----------------------------
# Candidate base columns for the train meta table; only those actually present
# in df_train are kept further below.
base_train_cols = [
    "object_id","split","split_id",
    "EBV","EBV_clip","log1pEBV","ebv_over_1pz","ebv_x_1pz",
    "Z","Z_clip","log1pZ","inv_1pz","z2",
    "Z_err","Zerr_clip","log1pZerr","zerr_over_1pz","zerr_rel_z","high_zerr",
    "EBV_missing","Z_missing","Zerr_missing","has_zerr","is_photoz",
    "split_pos_prior","split_pos_prior_oof",
    "fold","target"
]
# Same list for test, without the train-only columns
# (split_pos_prior_oof, fold, target).
base_test_cols = [
    "object_id","split","split_id",
    "EBV","EBV_clip","log1pEBV","ebv_over_1pz","ebv_x_1pz",
    "Z","Z_clip","log1pZ","inv_1pz","z2",
    "Z_err","Zerr_clip","log1pZerr","zerr_over_1pz","zerr_rel_z","high_zerr",
    "EBV_missing","Z_missing","Zerr_missing","has_zerr","is_photoz",
    "split_pos_prior"
]

# choose objq feature columns that exist in BOTH train/test
objq_cols = []
for c in oq_cols_common:
    if c in df_train.columns and c in df_test.columns:
        objq_cols.append(c)

# ensure alias cols included if created
for c in ["low_obs","low_bandcov","snr_det_frac","snr_strong_frac"]:
    if (c in df_train.columns) and (c in df_test.columns) and (c not in objq_cols):
        objq_cols.append(c)

keep_train = [c for c in base_train_cols if c in df_train.columns] + objq_cols
keep_test  = [c for c in base_test_cols  if c in df_test.columns]  + objq_cols

# Final meta tables: indexed by object_id and sorted by that index.
df_train_meta = df_train[keep_train].copy().set_index("object_id", drop=True).sort_index()
df_test_meta  = df_test[keep_test].copy().set_index("object_id", drop=True).sort_index()

# object_id must identify a row uniquely in each table.
if not df_train_meta.index.is_unique:
    raise RuntimeError("df_train_meta index (object_id) not unique after processing.")
if not df_test_meta.index.is_unique:
    raise RuntimeError("df_test_meta index (object_id) not unique after processing.")

# object_id -> split lookup dictionaries.
id2split_train = df_train_meta["split"].to_dict()
id2split_test  = df_test_meta["split"].to_dict()

# ----------------------------
# 14) Save artifacts
# ----------------------------
# Parquet is the preferred format; CSV is the portable fallback.
train_pq = ART_DIR / "train_meta.parquet"
test_pq  = ART_DIR / "test_meta.parquet"
train_csv = ART_DIR / "train_meta.csv"
test_csv  = ART_DIR / "test_meta.csv"

try:
    df_train_meta.to_parquet(train_pq, index=True)
    df_test_meta.to_parquet(test_pq, index=True)
    saved_train, saved_test = str(train_pq), str(test_pq)
except Exception as e:
    # Best-effort fallback: keep the stage running, but report WHY parquet
    # failed instead of swallowing the error silently. Both tables are written
    # as CSV so that saved_train / saved_test always share one format.
    print(f"[WARN] parquet save failed ({type(e).__name__}: {e}). Fallback CSV.")
    df_train_meta.to_csv(train_csv, index=True)
    df_test_meta.to_csv(test_csv, index=True)
    saved_train, saved_test = str(train_csv), str(test_csv)

# Per-split object counts for train and test, aligned to SPLIT_LIST
# (splits with no objects are reported as 0).
split_stats = pd.DataFrame({
    "train_objects": df_train_meta["split"].value_counts().reindex(SPLIT_LIST).fillna(0).astype(int),
    "test_objects":  df_test_meta["split"].value_counts().reindex(SPLIT_LIST).fillna(0).astype(int),
})
split_stats.index.name = "split"
# Positives per split; assigned via .values, relying on both being ordered by SPLIT_LIST.
pos_by_split = df_train_meta.groupby("split")["target"].sum().reindex(SPLIT_LIST).fillna(0).astype(int)
split_stats["train_pos"] = pos_by_split.values
split_stats["train_pos_rate"] = (split_stats["train_pos"] / split_stats["train_objects"].clip(lower=1)).astype("float32")
split_stats_path = ART_DIR / "split_stats.csv"
split_stats.to_csv(split_stats_path)

# Fold assignment table (one row per train object).
fold_path = ART_DIR / "train_folds.csv"
df_train_meta.reset_index()[["object_id","split","fold","target"]].to_csv(fold_path, index=False)

# object_id -> split routing maps.
with open(ART_DIR / "id2split_train.json", "w", encoding="utf-8") as f:
    json.dump(id2split_train, f)
with open(ART_DIR / "id2split_test.json", "w", encoding="utf-8") as f:
    json.dump(id2split_test, f)

# Every train meta column except "target" and "fold" is exported as a meta feature name.
drop_nonfeat = set(["target","fold"])
meta_feature_cols = [c for c in df_train_meta.columns.tolist() if c not in drop_nonfeat]
with open(ART_DIR / "meta_feature_cols.json", "w", encoding="utf-8") as f:
    json.dump(meta_feature_cols, f, indent=2)

# Class balance of the train set; scale_pos_weight is the neg/pos ratio
# (max(pos, 1) avoids division by zero when there are no positives).
pos = int((df_train_meta["target"] == 1).sum())
neg = int((df_train_meta["target"] == 0).sum())
tot = int(len(df_train_meta))
pos_rate = pos / max(tot, 1)
scale_pos_weight = float(neg / max(pos, 1))

# Machine-readable summary of this stage; all values are cast to plain Python
# types so that json.dump can serialize them.
stage2_summary = {
    "stage": "stage2",
    "N_FOLDS": int(N_FOLDS),
    "CV_USE_SPLIT_COL_USED": bool(CV_USE_SPLIT_COL),
    "OBJQ_JOINED": bool(objq_joined),
    "objq_cols_used": int(len(objq_cols)),
    "counts": {
        "train": int(tot),
        "pos": int(pos),
        "neg": int(neg),
        "pos_rate": float(pos_rate),
        "test": int(len(df_test_meta))
    },
    "clip_ranges": {
        "EBV_train": [float(EBV_LO), float(EBV_HI)],
        "Z_train": [float(Z_LO), float(Z_HI)],
        "Zerr_used": [float(ZE_LO), float(ZE_HI)]
    },
    "split_prior": {
        "alpha": float(PRIOR_ALPHA),
        "global_pos_rate": float(g_rate)
    },
    "scale_pos_weight": float(scale_pos_weight),
    "base_det_thr": float(BASE_DET_THR),
    "artifacts": {
        "train_meta": saved_train,
        "test_meta": saved_test,
        "split_stats": str(split_stats_path),
        "train_folds": str(fold_path),
        "id2split_train": str(ART_DIR / "id2split_train.json"),
        "id2split_test": str(ART_DIR / "id2split_test.json"),
        # Reported only if the file exists on disk.
        # NOTE(review): a file left over from an earlier run would also be reported — confirm this is intended.
        "split2fold": str(ART_DIR / "split2fold.json") if (ART_DIR / "split2fold.json").exists() else None,
        "meta_feature_cols": str(ART_DIR / "meta_feature_cols.json"),
    }
}
with open(ART_DIR / "stage2_summary.json", "w", encoding="utf-8") as f:
    json.dump(stage2_summary, f, indent=2)

# ----------------------------
# 15) Print summary
# ----------------------------
# Human-readable recap of what this stage produced and where it was saved.
print("STAGE 2 OK — META READY (clean + folds + enriched)")
print(f"- CV_USE_SPLIT_COL_USED: {CV_USE_SPLIT_COL} | N_FOLDS={N_FOLDS} | split_quota={FOLD_QUOTA}")
print(f"- OBJQ_JOINED: {objq_joined} | objq_cols_used={len(objq_cols)} | base_det_thr={BASE_DET_THR}")
print(f"- train objects: {tot:,} | pos={pos:,} | neg={neg:,} | pos%={pos_rate*100:.3f}%")
print(f"- test objects : {len(df_test_meta):,}")
print(f"- saved train  : {saved_train}")
print(f"- saved test   : {saved_test}")
print(f"- saved stats  : {split_stats_path}")
print(f"- saved folds  : {fold_path}")
print(f"- saved meta_feature_cols: {ART_DIR / 'meta_feature_cols.json'}")
print(f"- scale_pos_weight (neg/pos): {scale_pos_weight:.3f}")

print("\nCLIP RANGES")
print(f"- EBV clip (train): [{EBV_LO:.6f}, {EBV_HI:.6f}]")
print(f"- Z   clip (train): [{Z_LO:.6f}, {Z_HI:.6f}]")
print(f"- Zerr clip used  : [{ZE_LO:.6f}, {ZE_HI:.6f}]")

print("\nFOLD BALANCE (count/pos/pos_rate)")
print(fold_tab.to_string())

# ----------------------------
# 16) Export globals
# ----------------------------
# Publish the stage outputs under fixed global names for the following cells.
globals().update({
    "df_train_meta": df_train_meta,
    "df_test_meta": df_test_meta,
    "id2split_train": id2split_train,
    "id2split_test": id2split_test,
    "split_stats": split_stats,
    "split2id": split2id,
    "scale_pos_weight": scale_pos_weight,
    "CV_USE_SPLIT_COL_USED": CV_USE_SPLIT_COL,
    "META_FEATURE_COLS_PATH": str(ART_DIR / "meta_feature_cols.json"),
})

# Release memory held by temporaries of this cell.
gc.collect()


# Lightcurve Loading Strategy

In [None]:
# ============================================================
# STAGE 3 — Robust Lightcurve Loader Utilities + FULL Object-Quality Build (ONE CELL)
# FULL REVISION v5.4 (UPGRADE from v5.3: numeric-safe filter + hard OBJQ sanitize + anti-StopIteration smoke test)
#
# v5.4 upgrades (compared to v5.3):
# - FIX filter "0..5" (string/number) -> mapped to u,g,r,i,z,y (no longer discarded as invalid)
# - Safer OBJQ join:
#   * drop "Unnamed:*" columns
#   * drop meta-like columns (heuristic) even when they do not overlap
#   * keep only valid objq columns (whitelist) so a dirty objq csv cannot leak into meta
#   * numeric coercion + safe dedup aggregator
# - Smoke test: handle StopIteration (when the first chunk is fully filtered out / the file is empty)
#
# Unchanged:
# - FAST dtype mode (float32 parse) + fallback SAFE mode (coerce)
# - FULL Object Quality builder (streaming per split)
# - Output: logs/object_quality_train.csv + logs/object_quality_test.csv
# - Auto-join into df_train_meta / df_test_meta + overwrite meta in ART_DIR
# - Loader functions: iter_lightcurve_chunks, load_object_lightcurve, load_many_object_lightcurves, NPZ cache
#
# Output globals:
# - OBJECT_QUALITY_TRAIN_PATH, OBJECT_QUALITY_TEST_PATH
# - df_train_meta, df_test_meta updated (enriched)
# ============================================================

import gc, re, json, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Fail fast if any global produced by the earlier stages is missing.
need_prev = ["SPLIT_DIRS", "SPLIT_LIST", "df_train_meta", "df_test_meta", "ART_DIR", "CFG", "SEED", "LOG_DIR"]
for need in need_prev:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 -> STAGE 1 -> STAGE 2 dulu.")

# Normalize to Path objects and make sure the directories exist.
ART_DIR = Path(ART_DIR); ART_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR = Path(LOG_DIR); LOG_DIR.mkdir(parents=True, exist_ok=True)
SEED = int(SEED)

# Optional CACHE_DIR (from STAGE 0). If it is not defined, fall back to ART_DIR/cache
CACHE_DIR = Path(globals().get("CACHE_DIR", ART_DIR / "cache"))
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Core config
MIN_FLUXERR = float(CFG.get("MIN_FLUXERR", 1e-6))  # lower bound applied to flux_err
CHUNK_ROWS_DEFAULT = int(CFG.get("CHUNK_ROWS", 400_000))  # default rows per CSV chunk
SNR_CLIP = float(CFG.get("SNR_CLIP", 30.0))  # symmetric clip applied to flux/flux_err
SNR_DET_THR = float(CFG.get("SNR_DET_THR", 3.0))  # |SNR| threshold counted as a detection
SNR_STRONG_THR = float(CFG.get("SNR_STRONG_THR", 5.0))  # |SNR| threshold counted as a strong detection

# Object-quality builder config (BRUTAL but CPU-safe)
BUILD_OBJECT_QUALITY = bool(CFG.get("BUILD_OBJECT_QUALITY", True))
OBJQ_REFRESH = bool(CFG.get("OBJQ_REFRESH", False))
OBJQ_CHUNK_ROWS = int(CFG.get("OBJQ_CHUNK_ROWS", max(200_000, CHUNK_ROWS_DEFAULT)))
OBJQ_SAVE_PARQUET = bool(CFG.get("OBJQ_SAVE_PARQUET", False))  # default csv for portability
OBJQ_ONLY_IF_MISSING = bool(CFG.get("OBJQ_ONLY_IF_MISSING", True))  # don't rebuild if files exist (unless refresh)

# Kept consistent with STAGE 0/2 (per the original note)
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

# Canonical lightcurve columns and the accepted filter (band) names.
REQ_LC_KEYS = ["object_id", "mjd", "flux", "flux_err", "filter"]
ALLOWED_FILTERS = {"u", "g", "r", "i", "z", "y"}
ALLOWED_FILTERS_TUP = ("u", "g", "r", "i", "z", "y")
# FILTER_ORDER is used as a sort tie-breaker, FILTER2ID as the integer encoding;
# both currently hold the same mapping.
FILTER_ORDER = {"u":0, "g":1, "r":2, "i":3, "z":4, "y":5}
FILTER2ID = {"u":0, "g":1, "r":2, "i":3, "z":4, "y":5}
ID2FILTER = {v:k for k,v in FILTER2ID.items()}

# v5.4: support numeric filters "0..5"
FILTER_NUM2STR = {"0":"u","1":"g","2":"r","3":"i","4":"z","5":"y"}

# Cache configs
_LC_CFG_CACHE = {}  # (split_name, which) -> cfg dict
_LC_OBJ_CACHE = {}  # in-memory small cache: (which, object_id) -> df
LC_CACHE_DIR = CACHE_DIR / "lightcurves_npz"
LC_CACHE_DIR.mkdir(parents=True, exist_ok=True)
MAX_MEM_CACHE = int(CFG.get("LC_MEM_CACHE_MAX", 64))  # max objects in RAM cache
# ----------------------------
# Utils: ID normalize (robust)
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_, bytearray)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

# Normalize meta indices early (prevent mismatch)
# Shallow copies: the index is replaced on a new frame object rather than on
# the object handed over by the previous stage.
df_train_meta = df_train_meta.copy(deep=False)
df_test_meta  = df_test_meta.copy(deep=False)
df_train_meta.index = pd.Index([_norm_id(z) for z in df_train_meta.index.tolist()], name=df_train_meta.index.name)
df_test_meta.index  = pd.Index([_norm_id(z) for z in df_test_meta.index.tolist()], name=df_test_meta.index.name)

# The loaders below route every object to its file through the `split` column.
if "split" not in df_train_meta.columns or "split" not in df_test_meta.columns:
    raise RuntimeError("Missing column `split` in df_train_meta/df_test_meta. Pastikan STAGE 2 membuat routing split.")

# ----------------------------
# 1) Build split file mapping (train/test lightcurves)
# ----------------------------
# split name -> {"train": path, "test": path}; both files must exist for every split.
SPLIT_FILES = {}
for s in SPLIT_LIST:
    sd = Path(SPLIT_DIRS[s])
    tr = sd / "train_full_lightcurves.csv"
    te = sd / "test_full_lightcurves.csv"
    if (not tr.exists()) or (not te.exists()):
        raise FileNotFoundError(f"Missing lightcurve file(s) in {sd}: train={tr.exists()} test={te.exists()}")
    SPLIT_FILES[s] = {"train": tr, "test": te}

# Save split file manifest
# One row per split with the two file paths and their sizes in MiB.
manifest = []
for s in SPLIT_LIST:
    p_tr = SPLIT_FILES[s]["train"]
    p_te = SPLIT_FILES[s]["test"]
    manifest.append({
        "split": s,
        "train_path": str(p_tr),
        "test_path": str(p_te),
        "train_mb": float(p_tr.stat().st_size) / (1024**2),
        "test_mb":  float(p_te.stat().st_size) / (1024**2),
    })
df_manifest = pd.DataFrame(manifest).sort_values("split")
manifest_path = ART_DIR / "split_file_manifest.csv"
df_manifest.to_csv(manifest_path, index=False)

# ----------------------------
# 2) Build object routing by split
# ----------------------------
# split name -> list of normalized object ids, for train and test separately.
train_ids_by_split = {s: [] for s in SPLIT_LIST}
test_ids_by_split  = {s: [] for s in SPLIT_LIST}

tr_groups = df_train_meta.groupby("split").groups
te_groups = df_test_meta.groupby("split").groups

# Only splits listed in SPLIT_LIST are filled; ids of any other split are left out
# and are caught by the total-count check below.
for sp, idx in tr_groups.items():
    if sp in train_ids_by_split:
        train_ids_by_split[sp] = pd.Index(idx).astype(str).map(_norm_id).tolist()

for sp, idx in te_groups.items():
    if sp in test_ids_by_split:
        test_ids_by_split[sp] = pd.Index(idx).astype(str).map(_norm_id).tolist()

# Every meta row must have been routed to exactly one known split.
if sum(len(v) for v in train_ids_by_split.values()) != len(df_train_meta):
    raise RuntimeError("Routing train_ids_by_split mismatch total vs df_train_meta length.")
if sum(len(v) for v in test_ids_by_split.values()) != len(df_test_meta):
    raise RuntimeError("Routing test_ids_by_split mismatch total vs df_test_meta length.")

# Persist the per-split object counts.
df_counts = pd.DataFrame({
    "split": SPLIT_LIST,
    "train_objects": [len(train_ids_by_split[s]) for s in SPLIT_LIST],
    "test_objects":  [len(test_ids_by_split[s]) for s in SPLIT_LIST],
})
counts_path = ART_DIR / "object_counts_by_split.csv"
df_counts.to_csv(counts_path, index=False)

# ----------------------------
# 3) Robust header mapping -> canonical columns (FAST dtype + SAFE fallback)
# ----------------------------
def _canon_col(x: str) -> str:
    s = str(x).strip().lower()
    s = s.replace("\ufeff", "")
    s = re.sub(r"\s+", "", s)
    s = s.replace("(", "").replace(")", "")
    s = s.replace("-", "_")
    return s

def _build_lc_read_cfg(p: Path):
    """Inspect the header of a lightcurve CSV and build its read configuration.

    Only the header row is read. Header names are matched through their
    canonical form (see ``_canon_col``); when several headers share a
    canonical form the first one wins.

    Returns a dict with:
      - "usecols":    original header names of the five required columns
      - "dtype_fast": dtypes for the fast path (string ids/filters, float32 numerics)
      - "dtype_safe": dtypes for the fallback path (only ids/filters are typed)
      - "rename":     original header name -> canonical column name

    Raises ValueError when a required column cannot be located.
    """
    header = pd.read_csv(p, nrows=0, **SAFE_READ_KW)
    orig_cols = list(header.columns)

    # canonical name -> first original header that produces it
    canon2orig = {}
    for col in orig_cols:
        canon2orig.setdefault(_canon_col(col), col)

    def _first_present(candidates):
        # First candidate (in priority order) that exists in the header, else None.
        return next((canon2orig[k] for k in candidates if k in canon2orig), None)

    obj_col = canon2orig.get("object_id", None)
    time_col = _first_present(("time_mjd", "timemjd", "mjd", "time"))
    flux_col = canon2orig.get("flux", None)
    ferr_col = _first_present(("flux_err", "fluxerr", "fluxerror"))
    filt_col = canon2orig.get("filter", None)

    # (resolved column, label used in the error message), in reporting order
    required = (
        (obj_col, "object_id"),
        (time_col, "Time (MJD)"),
        (flux_col, "Flux"),
        (ferr_col, "Flux_err"),
        (filt_col, "Filter"),
    )
    missing = [label for resolved, label in required if resolved is None]
    if missing:
        raise ValueError(
            f"Missing required lightcurve columns in {p.name}: {missing}. "
            f"Header sample: {orig_cols[:20]}"
        )

    dtype_fast = {
        obj_col: "string",
        filt_col: "string",
        time_col: "float32",
        flux_col: "float32",
        ferr_col: "float32",
    }
    dtype_safe = {obj_col: "string", filt_col: "string"}

    return {
        "usecols": [obj_col, time_col, flux_col, ferr_col, filt_col],
        "dtype_fast": dtype_fast,
        "dtype_safe": dtype_safe,
        "rename": {
            obj_col: "object_id",
            time_col: "mjd",
            flux_col: "flux",
            ferr_col: "flux_err",
            filt_col: "filter",
        },
    }

def _normalize_lc_chunk(
    df: pd.DataFrame,
    drop_bad_filter: bool = True,
    drop_bad_mjd: bool = True,
    drop_bad_fluxerr: bool = True,
    encode_filter: bool = False,
):
    """Clean one raw lightcurve chunk into the canonical schema.

    Parameters
    ----------
    df : DataFrame holding at least the columns object_id, mjd, flux,
        flux_err and filter (already renamed to canonical names).
    drop_bad_filter : drop rows whose filter is not one of ALLOWED_FILTERS_TUP
        (numeric codes "0".."5" are first mapped through FILTER_NUM2STR).
    drop_bad_mjd : drop rows whose mjd is missing / not numeric.
    drop_bad_fluxerr : drop rows whose flux_err is missing or <= 0.
    encode_filter : add an int8 ``filter_id`` column (FILTER2ID encoding).
        Rows without a valid filter — possible only when
        ``drop_bad_filter=False`` — are encoded as -1.

    Returns
    -------
    A new DataFrame with columns REQ_LC_KEYS (plus ``filter_id`` when
    requested). Rows with an empty/missing object_id are always dropped,
    mjd/flux/flux_err are float32, and flux_err is floored at MIN_FLUXERR
    when MIN_FLUXERR > 0. The original row labels of kept rows are preserved.
    """
    # Private copy: every write below targets this frame, never the caller's.
    df = df[["object_id","mjd","flux","flux_err","filter"]].copy()

    df["object_id"] = df["object_id"].astype("string").str.strip()

    # v5.4: robust filter parsing (numeric codes "0".."5" are supported)
    f = df["filter"].astype("string").str.strip().str.lower()
    f = f.replace(FILTER_NUM2STR)  # "0"->"u", ...
    df["filter"] = f
    df.loc[~df["filter"].isin(ALLOWED_FILTERS_TUP), "filter"] = pd.NA

    # Numeric coercion (only needed on the SAFE read path), then a uniform float32 cast.
    for col in ("mjd", "flux", "flux_err"):
        if df[col].dtype == "O" or str(df[col].dtype).startswith("string"):
            df[col] = pd.to_numeric(df[col], errors="coerce")
    for col in ("mjd", "flux", "flux_err"):
        df[col] = df[col].astype("float32")

    # Build ONE row mask from the raw values first, so the flux_err floor below
    # cannot turn a non-positive error into a "valid" one.
    fe = df["flux_err"]
    keep = df["object_id"].notna() & (df["object_id"] != "")
    keep = keep.fillna(False).astype(bool)
    if drop_bad_fluxerr:
        keep &= fe.notna() & (fe > 0)
    if drop_bad_filter:
        keep &= df["filter"].notna()
    if drop_bad_mjd:
        keep &= df["mjd"].notna()

    # Floor flux_err on the private frame (not on a filtered slice), which
    # avoids chained assignment.
    if MIN_FLUXERR > 0:
        df.loc[fe.notna() & (fe < MIN_FLUXERR), "flux_err"] = np.float32(MIN_FLUXERR)

    df = df.loc[keep, REQ_LC_KEYS].copy()

    if encode_filter:
        # fillna(-1): a missing filter would otherwise make the int8 cast raise.
        df["filter_id"] = df["filter"].map(FILTER2ID).fillna(-1).astype("int8")

    return df

# ----------------------------
# 4) Chunked readers
# ----------------------------
def iter_lightcurve_chunks(
    split_name: str,
    which: str,
    chunksize: int = None,
    drop_bad_filter: bool = True,
    drop_bad_mjd: bool = True,
    drop_bad_fluxerr: bool = True,
    encode_filter: bool = False,
):
    """Yield normalized lightcurve chunks of one split file.

    Parameters
    ----------
    split_name : key of SPLIT_FILES.
    which : "train" or "test".
    chunksize : rows per chunk; defaults to CHUNK_ROWS_DEFAULT.
    drop_bad_filter, drop_bad_mjd, drop_bad_fluxerr, encode_filter :
        forwarded to ``_normalize_lc_chunk``.

    Yields
    ------
    DataFrames as returned by ``_normalize_lc_chunk`` (a chunk may be empty
    after filtering).

    Raises
    ------
    KeyError for an unknown split, ValueError for an invalid ``which``.

    Notes
    -----
    The file is first read in FAST mode (float32 dtypes). If that fails, it is
    re-read in SAFE mode (numerics coerced later). Chunks that were already
    yielded before the failure are skipped in SAFE mode, so no row is ever
    delivered twice; chunk boundaries are identical in both modes because the
    same ``chunksize`` is used.
    """
    split_name = str(split_name).strip()
    if split_name not in SPLIT_FILES:
        raise KeyError(f"Unknown split_name={split_name}. Known={list(SPLIT_FILES.keys())[:5]}..")
    if which not in ("train", "test"):
        raise ValueError("which must be 'train' or 'test'")

    if chunksize is None:
        chunksize = CHUNK_ROWS_DEFAULT

    p = SPLIT_FILES[split_name][which]
    key = (split_name, which)
    if key not in _LC_CFG_CACHE:
        _LC_CFG_CACHE[key] = _build_lc_read_cfg(p)
    cfg = _LC_CFG_CACHE[key]

    def _open_reader(dtype):
        # Chunked reader over the five required columns with the given dtype map.
        return pd.read_csv(
            p,
            usecols=cfg["usecols"],
            dtype=dtype,
            chunksize=int(chunksize),
            engine="c",
            memory_map=True,
            **SAFE_READ_KW
        )

    def _normalize(chunk):
        # Rename to canonical column names, then clean.
        return _normalize_lc_chunk(
            chunk.rename(columns=cfg["rename"]),
            drop_bad_filter=drop_bad_filter,
            drop_bad_mjd=drop_bad_mjd,
            drop_bad_fluxerr=drop_bad_fluxerr,
            encode_filter=encode_filter,
        )

    n_done = 0  # chunks already delivered to the consumer
    try:
        for chunk in _open_reader(cfg["dtype_fast"]):
            yield _normalize(chunk)
            n_done += 1
    except Exception:
        # SAFE fallback: restart the file, but skip what was already delivered.
        for i, chunk in enumerate(_open_reader(cfg["dtype_safe"])):
            if i < n_done:
                continue
            yield _normalize(chunk)

def load_object_lightcurve(
    object_id: str,
    which: str,
    chunksize: int = None,
    sort_time: bool = True,
    max_chunks: int = None,
    stop_after_found_block: bool = True,
    drop_bad_filter: bool = True,
    drop_bad_mjd: bool = True,
    drop_bad_fluxerr: bool = True,
):
    """Load the lightcurve of a single object by scanning its split file.

    The split is looked up in df_train_meta / df_test_meta (depending on
    ``which``), and the matching CSV is streamed chunk by chunk.

    Parameters
    ----------
    object_id : id of the object (normalized with ``_norm_id``).
    which : "train" or "test".
    chunksize : rows per chunk; defaults to CHUNK_ROWS_DEFAULT.
    sort_time : sort the result by (mjd, filter order) with a stable sort.
    max_chunks : stop after this many chunks, if given.
    stop_after_found_block : stop at the first chunk without a match that
        directly follows a chunk with a match.
    drop_bad_filter, drop_bad_mjd, drop_bad_fluxerr : forwarded to the reader.

    Returns
    -------
    DataFrame with columns REQ_LC_KEYS (empty when nothing matched).

    Raises
    ------
    ValueError for an invalid ``which``; KeyError when the object is not in
    the meta table or its split is not in SPLIT_FILES.
    """
    object_id = _norm_id(object_id)

    if which not in ("train", "test"):
        raise ValueError("which must be 'train' or 'test'")

    if which == "train":
        meta = df_train_meta
        not_found = f"object_id not found in df_train_meta: {object_id}"
    else:
        meta = df_test_meta
        not_found = f"object_id not found in df_test_meta: {object_id}"
    if object_id not in meta.index:
        raise KeyError(not_found)
    split_name = str(meta.loc[object_id, "split"]).strip()

    if split_name not in SPLIT_FILES:
        raise KeyError(f"Routing split not found in SPLIT_FILES: split={split_name} object_id={object_id}")

    if chunksize is None:
        chunksize = CHUNK_ROWS_DEFAULT

    pieces = []
    prev_hit = False
    chunk_iter = iter_lightcurve_chunks(
        split_name, which, chunksize=chunksize,
        drop_bad_filter=drop_bad_filter, drop_bad_mjd=drop_bad_mjd, drop_bad_fluxerr=drop_bad_fluxerr,
        encode_filter=False
    )
    for n_chunks, ch in enumerate(chunk_iter, start=1):
        rows = ch[ch["object_id"] == object_id]
        hit = len(rows) > 0
        if hit:
            pieces.append(rows)
        elif stop_after_found_block and prev_hit:
            # The run of matching chunks has just ended.
            break
        prev_hit = hit

        if max_chunks is not None and n_chunks >= int(max_chunks):
            break

    if not pieces:
        return pd.DataFrame(columns=REQ_LC_KEYS)

    out = pd.concat(pieces, ignore_index=True)
    if sort_time and len(out) > 1:
        # Temporary key so that equal timestamps are ordered by band.
        band_rank = out["filter"].map(FILTER_ORDER).astype("int16")
        out = out.assign(filter_ord=band_rank)
        out = out.sort_values(["mjd", "filter_ord"], kind="mergesort")
        out = out.drop(columns=["filter_ord"]).reset_index(drop=True)
    return out

def load_many_object_lightcurves(
    split_name: str,
    which: str,
    object_ids,
    chunksize: int = None,
    max_chunks: int = None,
    sort_time: bool = True,
    drop_bad_filter: bool = True,
    drop_bad_mjd: bool = True,
    drop_bad_fluxerr: bool = True,
):
    """Load the lightcurves of several objects of one split in a single pass.

    Parameters
    ----------
    split_name : key of SPLIT_FILES.
    which : "train" or "test".
    object_ids : iterable of ids; each is normalized with ``_norm_id`` and
        empty ids are ignored.
    chunksize : rows per chunk; defaults to CHUNK_ROWS_DEFAULT.
    max_chunks : stop after this many chunks, if given.
    sort_time : sort each lightcurve by (mjd, filter order) with a stable sort.
    drop_bad_filter, drop_bad_mjd, drop_bad_fluxerr : forwarded to the reader.

    Returns
    -------
    dict mapping object id -> DataFrame. Ids for which no row was found are
    absent from the dict.

    Raises
    ------
    KeyError for an unknown split, ValueError for an invalid ``which``.
    """
    split_name = str(split_name).strip()
    if split_name not in SPLIT_FILES:
        raise KeyError(f"Unknown split_name={split_name}")
    if which not in ("train","test"):
        raise ValueError("which must be train/test")

    if chunksize is None:
        chunksize = CHUNK_ROWS_DEFAULT

    wanted = {nid for nid in (_norm_id(x) for x in object_ids) if nid != ""}
    if not wanted:
        return {}

    # object id -> list of per-chunk fragments
    buckets = {oid: [] for oid in wanted}

    chunk_iter = iter_lightcurve_chunks(
        split_name, which, chunksize=chunksize,
        drop_bad_filter=drop_bad_filter, drop_bad_mjd=drop_bad_mjd, drop_bad_fluxerr=drop_bad_fluxerr,
        encode_filter=False
    )
    for n_chunks, ch in enumerate(chunk_iter, start=1):
        hits = ch["object_id"].isin(wanted)
        if hits.any():
            for oid, grp in ch.loc[hits].groupby("object_id", sort=False):
                buckets[_norm_id(oid)].append(grp)

        if max_chunks is not None and n_chunks >= int(max_chunks):
            break

    final = {}
    for oid, parts in buckets.items():
        if not parts:
            continue
        lc = pd.concat(parts, ignore_index=True)
        if sort_time and len(lc) > 1:
            # Temporary key so that equal timestamps are ordered by band.
            band_rank = lc["filter"].map(FILTER_ORDER).astype("int16")
            lc = lc.assign(filter_ord=band_rank)
            lc = lc.sort_values(["mjd","filter_ord"], kind="mergesort")
            lc = lc.drop(columns=["filter_ord"]).reset_index(drop=True)
        final[oid] = lc

    return final

# ----------------------------
# 5) NPZ cache utilities
# ----------------------------
def get_lc_cache_path(object_id: str, which: str) -> Path:
    """Return the NPZ cache file path for one object: LC_CACHE_DIR/<which>__<object_id>.npz."""
    oid = _norm_id(object_id)
    kind = str(which).strip()
    filename = f"{kind}__{oid}.npz"
    return LC_CACHE_DIR / filename

def _mem_cache_put(key, value):
    """Store ``value`` in the in-memory lightcurve cache.

    When the cache grows beyond MAX_MEM_CACHE entries, the entry that comes
    first in the dict's iteration order is evicted.
    """
    _LC_OBJ_CACHE[key] = value
    if len(_LC_OBJ_CACHE) <= MAX_MEM_CACHE:
        return
    first_key = next(iter(_LC_OBJ_CACHE))
    _LC_OBJ_CACHE.pop(first_key, None)

def load_object_lightcurve_cached(
    object_id: str,
    which: str,
    chunksize: int = None,
    sort_time: bool = True,
    max_chunks: int = None,
    stop_after_found_block: bool = True,
    use_npz_cache: bool = True,
    refresh_cache: bool = False,
):
    """Load one object's lightcurve through a two-level cache.

    Lookup order: in-memory cache -> NPZ file on disk -> full scan via
    ``load_object_lightcurve``. A copy of the DataFrame is always returned,
    so callers may mutate the result freely.

    Parameters
    ----------
    object_id : id of the object (normalized with ``_norm_id``).
    which : "train" or "test".
    chunksize, sort_time, max_chunks, stop_after_found_block :
        forwarded to ``load_object_lightcurve`` on a cache miss
        (``sort_time`` is also applied to data read from the NPZ file).
    use_npz_cache : read from / write to the on-disk NPZ cache.
    refresh_cache : ignore both caches and reload from the CSV.

    Notes
    -----
    - A scan limited by ``max_chunks`` may be truncated. Because the cache key
      does not include ``max_chunks``, such a result is returned but NOT stored
      in either cache, so it cannot shadow a later complete read.
    - Writing the NPZ file is best-effort: a failure is reported and ignored.
    - Empty lightcurves are never written to disk.
    """
    object_id = _norm_id(object_id)
    which = str(which).strip()

    mem_key = (which, object_id)
    if mem_key in _LC_OBJ_CACHE and (not refresh_cache):
        return _LC_OBJ_CACHE[mem_key].copy()

    npz_path = get_lc_cache_path(object_id, which)
    if use_npz_cache and npz_path.exists() and (not refresh_cache):
        # NpzFile keeps the archive open; the context manager releases the handle.
        with np.load(npz_path, allow_pickle=False) as z:
            mjd = z["mjd"].astype("float32")
            flux = z["flux"].astype("float32")
            ferr = z["flux_err"].astype("float32")
            filt_id = z["filter_id"].astype("int8")
        filt = np.array([ID2FILTER[int(i)] for i in filt_id], dtype=object)

        df = pd.DataFrame({
            "object_id": object_id,
            "mjd": mjd,
            "flux": flux,
            "flux_err": ferr,
            "filter": filt,
        })
        if sort_time and len(df) > 1:
            df["filter_ord"] = df["filter"].map(FILTER_ORDER).astype("int16")
            df = df.sort_values(["mjd","filter_ord"], kind="mergesort").drop(columns=["filter_ord"]).reset_index(drop=True)
        _mem_cache_put(mem_key, df)
        return df.copy()

    df = load_object_lightcurve(
        object_id, which,
        chunksize=chunksize,
        sort_time=sort_time,
        max_chunks=max_chunks,
        stop_after_found_block=stop_after_found_block,
        drop_bad_filter=True, drop_bad_mjd=True, drop_bad_fluxerr=True
    )

    # A max_chunks-limited scan may have stopped before all rows were seen.
    if max_chunks is not None:
        return df.copy()

    if use_npz_cache and len(df) > 0:
        try:
            filt_id = df["filter"].map(FILTER2ID).astype("int8").to_numpy()
            np.savez_compressed(
                npz_path,
                mjd=df["mjd"].to_numpy(dtype="float32"),
                flux=df["flux"].to_numpy(dtype="float32"),
                flux_err=df["flux_err"].to_numpy(dtype="float32"),
                filter_id=filt_id
            )
        except Exception as e:
            # Best-effort: the disk cache is an optimization, never a requirement.
            print(f"[WARN] lightcurve NPZ cache write failed for {npz_path.name}: {type(e).__name__}: {e}")

    _mem_cache_put(mem_key, df)
    return df.copy()

# ----------------------------
# 6) FULL Object-Quality builder (streaming, vectorized)
# ----------------------------
def _build_object_quality(which: str, out_path: Path, refresh: bool = False, chunksize: int = 400_000):
    """Build (or reload) the per-object quality table for train or test.

    If ``out_path`` already exists and ``refresh`` is False, the CSV is read
    back, its object_id column is normalized, and it is returned as is.

    Otherwise every split file is streamed in chunks and per-object statistics
    are accumulated for the objects listed in the index of df_train_meta /
    df_test_meta (rows of other objects are ignored). The result is written to
    ``out_path`` as CSV (and, when OBJQ_SAVE_PARQUET is set, best-effort as
    parquet next to it), a one-line summary is printed, and the table is returned.

    Parameters
    ----------
    which : "train" or "test".
    out_path : CSV path of the table.
    refresh : rebuild even if ``out_path`` exists.
    chunksize : rows per chunk for the streaming read.

    Returns
    -------
    DataFrame with one row per object: object_id, n_obs_total, n_bands_present,
    timespan, cadence_proxy, neg_flux_frac, flux_mean, flux_std, abs_flux_mean,
    snr_abs_mean, snr_abs_max, snr_det_frac, snr_strong_frac, plus n_<band> and
    frac_<band> for the bands u, g, r, i, z, y.

    Raises
    ------
    ValueError for an invalid ``which``; RuntimeError when an existing file
    has no object_id column.
    """
    which = str(which).strip()
    if which not in ("train","test"):
        raise ValueError("which must be train/test")

    # Reuse a previously built file unless a refresh is requested.
    if out_path.exists() and (not refresh):
        df = pd.read_csv(out_path, dtype={"object_id":"string"}, **SAFE_READ_KW)
        if "object_id" not in df.columns:
            raise RuntimeError(f"Bad objq file (missing object_id): {out_path}")
        df["object_id"] = df["object_id"].astype("string").map(_norm_id)
        return df

    # Position in `idx` == row of every accumulator array below.
    idx = df_train_meta.index if which == "train" else df_test_meta.index
    idx = pd.Index([_norm_id(z) for z in idx.astype(str).tolist()])
    n_obj = int(len(idx))

    # Per-object accumulators. Sums are kept in float64; min/max start at
    # +inf/-inf so that objects without any observation can be detected later.
    n_obs = np.zeros(n_obj, dtype=np.int32)
    mjd_min = np.full(n_obj, np.float32(np.inf), dtype=np.float32)
    mjd_max = np.full(n_obj, np.float32(-np.inf), dtype=np.float32)

    sum_flux = np.zeros(n_obj, dtype=np.float64)
    sum_flux2 = np.zeros(n_obj, dtype=np.float64)
    sum_abs_flux = np.zeros(n_obj, dtype=np.float64)
    neg_cnt = np.zeros(n_obj, dtype=np.int32)

    snr_abs_max = np.full(n_obj, np.float32(0.0), dtype=np.float32)
    sum_snr_abs = np.zeros(n_obj, dtype=np.float64)
    det_cnt = np.zeros(n_obj, dtype=np.int32)
    strong_cnt = np.zeros(n_obj, dtype=np.int32)

    # band_cnt[b, i] = number of observations of object i in band id b.
    band_cnt = np.zeros((6, n_obj), dtype=np.int32)

    t0 = time.time()
    rows_seen = 0

    for sp in SPLIT_LIST:
        for ch in iter_lightcurve_chunks(
            sp, which,
            chunksize=chunksize,
            drop_bad_filter=True, drop_bad_mjd=True, drop_bad_fluxerr=True,
            encode_filter=True
        ):
            if ch is None or len(ch) == 0:
                continue

            # Map each row to the position of its object; keep only known objects.
            oids = ch["object_id"].to_numpy(dtype=object, copy=False)
            oid_idx = idx.get_indexer(oids)  # -1 if not found
            m = oid_idx >= 0
            if not np.any(m):
                continue

            oid_idx = oid_idx[m].astype(np.int64, copy=False)

            mjd = ch["mjd"].to_numpy(dtype=np.float32, copy=False)[m]
            flux = ch["flux"].to_numpy(dtype=np.float32, copy=False)[m]
            ferr = ch["flux_err"].to_numpy(dtype=np.float32, copy=False)[m]
            fid = ch["filter_id"].to_numpy(dtype=np.int8, copy=False)[m]

            if MIN_FLUXERR > 0:
                ferr = np.maximum(ferr, np.float32(MIN_FLUXERR))

            rows_seen += int(len(oid_idx))

            # Counts and sums per object via bincount (weights = value to sum).
            bc = np.bincount(oid_idx, minlength=n_obj)
            n_obs += bc.astype(np.int32)

            fx64 = flux.astype(np.float64)
            sum_flux += np.bincount(oid_idx, weights=fx64, minlength=n_obj)
            sum_flux2 += np.bincount(oid_idx, weights=(fx64 * fx64), minlength=n_obj)
            sum_abs_flux += np.bincount(oid_idx, weights=np.abs(fx64), minlength=n_obj)
            neg_cnt += np.bincount(oid_idx, weights=(flux < 0).astype(np.int32), minlength=n_obj).astype(np.int32)

            # Signal-to-noise ratio, clipped symmetrically to +/- SNR_CLIP.
            snr = (flux / ferr).astype(np.float32)
            snr = np.clip(snr, -np.float32(SNR_CLIP), np.float32(SNR_CLIP))
            snr_abs = np.abs(snr).astype(np.float32)

            sum_snr_abs += np.bincount(oid_idx, weights=snr_abs.astype(np.float64), minlength=n_obj)
            det_cnt += np.bincount(oid_idx, weights=(snr_abs >= np.float32(SNR_DET_THR)).astype(np.int32), minlength=n_obj).astype(np.int32)
            strong_cnt += np.bincount(oid_idx, weights=(snr_abs >= np.float32(SNR_STRONG_THR)).astype(np.int32), minlength=n_obj).astype(np.int32)

            # Per-band observation counts.
            for b in range(6):
                mb = (fid == b)
                if np.any(mb):
                    band_cnt[b] += np.bincount(oid_idx[mb], minlength=n_obj).astype(np.int32)

            # Per-object min/max inside this chunk: sort rows by object position,
            # then reduce over each contiguous run (`starts` = first row of each run).
            order = np.argsort(oid_idx, kind="mergesort")
            idx_s = oid_idx[order]
            mjd_s = mjd[order]
            snr_s = snr_abs[order]

            starts = np.r_[0, 1 + np.where(idx_s[1:] != idx_s[:-1])[0]]
            uniq = idx_s[starts]

            mn = np.minimum.reduceat(mjd_s, starts).astype(np.float32)
            mx = np.maximum.reduceat(mjd_s, starts).astype(np.float32)
            sx = np.maximum.reduceat(snr_s, starts).astype(np.float32)

            # Merge the chunk-level extremes into the running extremes.
            mjd_min[uniq] = np.minimum(mjd_min[uniq], mn)
            mjd_max[uniq] = np.maximum(mjd_max[uniq], mx)
            snr_abs_max[uniq] = np.maximum(snr_abs_max[uniq], sx)

    # Denominator floored at 1 so objects without observations yield 0 instead of NaN.
    n_obs_f = np.maximum(n_obs.astype(np.float32), np.float32(1.0))
    # Objects never seen still hold +/-inf in mjd_min/mjd_max -> timespan 0.
    timespan = np.where(
        np.isfinite(mjd_min) & np.isfinite(mjd_max),
        (mjd_max - mjd_min).astype(np.float32),
        np.float32(0.0)
    )
    # Mean gap between consecutive observations (timespan / (n_obs - 1), denominator >= 1).
    cadence_proxy = (timespan / np.maximum(n_obs - 1, 1).astype(np.float32)).astype(np.float32)

    # Population variance as E[x^2] - E[x]^2, clamped at 0 against round-off.
    flux_mean = (sum_flux / n_obs_f.astype(np.float64)).astype(np.float32)
    flux_var = (sum_flux2 / n_obs_f.astype(np.float64) - (sum_flux / n_obs_f.astype(np.float64))**2)
    flux_var = np.maximum(flux_var, 0.0)
    flux_std = np.sqrt(flux_var).astype(np.float32)

    abs_flux_mean = (sum_abs_flux / n_obs_f.astype(np.float64)).astype(np.float32)
    neg_flux_frac = (neg_cnt.astype(np.float32) / n_obs_f).astype(np.float32)

    snr_abs_mean = (sum_snr_abs / n_obs_f.astype(np.float64)).astype(np.float32)
    snr_det_frac = (det_cnt.astype(np.float32) / n_obs_f).astype(np.float32)
    snr_strong_frac = (strong_cnt.astype(np.float32) / n_obs_f).astype(np.float32)

    # Band coverage: number of bands with at least one observation, and the
    # share of each band among the object's observations.
    n_bands_present = (band_cnt > 0).sum(axis=0).astype(np.int8)
    frac_bands = (band_cnt.astype(np.float32) / n_obs_f[None, :]).astype(np.float32)

    df_out = pd.DataFrame({
        "object_id": pd.Series(idx, dtype="string"),
        "n_obs_total": n_obs.astype(np.int32),
        "n_bands_present": n_bands_present,
        "timespan": timespan,
        "cadence_proxy": cadence_proxy,
        "neg_flux_frac": neg_flux_frac,
        "flux_mean": flux_mean,
        "flux_std": flux_std,
        "abs_flux_mean": abs_flux_mean,
        "snr_abs_mean": snr_abs_mean,
        "snr_abs_max": snr_abs_max.astype(np.float32),
        "snr_det_frac": snr_det_frac,
        "snr_strong_frac": snr_strong_frac,
    })

    # Band order here matches the FILTER2ID encoding (u=0 ... y=5).
    for b, name in enumerate(["u","g","r","i","z","y"]):
        df_out[f"n_{name}"] = band_cnt[b].astype(np.int32)
        df_out[f"frac_{name}"] = frac_bands[b].astype(np.float32)

    df_out.to_csv(out_path, index=False)
    if OBJQ_SAVE_PARQUET:
        # Optional parquet copy; failures are deliberately ignored (CSV is the reference output).
        try:
            df_out.to_parquet(out_path.with_suffix(".parquet"), index=False)
        except Exception:
            pass

    elapsed = time.time() - t0
    print(f"[OBJQ] Built {which} object_quality: {out_path} | objects={len(df_out):,} | rows_seen={rows_seen:,} | {elapsed/60:.2f} min")
    return df_out

# ----------------------------
# 7) SAFE overwrite-join object-quality -> meta (v5.4 sanitize hard)
# ----------------------------
# Output locations of the object-quality tables.
OBJECT_QUALITY_TRAIN_PATH = LOG_DIR / "object_quality_train.csv"
OBJECT_QUALITY_TEST_PATH  = LOG_DIR / "object_quality_test.csv"

# Meta column names that an object-quality file must never carry;
# _looks_like_meta_col compares against them case-insensitively.
META_PROTECT_COLS = set([
    "split", "split_id", "fold",
    "target", "y", "label", "class",
    "EBV", "EBV_clip", "EBV_used", "EBV_missing", "log1pEBV",
    "Z", "Z_clip", "Z_err", "Z_missing", "Z_err_missing", "log1pZ", "log1pZerr",
    "is_photoz", "photoz", "redshift",
    "prior", "pos_prior", "neg_prior"
])

def _drop_unnamed_cols(df: pd.DataFrame) -> pd.DataFrame:
    """Return ``df`` without columns whose name starts with ``"Unnamed"``.

    The input frame itself is returned (not a copy) when there is nothing
    to drop.
    """
    unnamed = [col for col in df.columns if str(col).startswith("Unnamed")]
    if not unnamed:
        return df
    return df.drop(columns=unnamed, errors="ignore")

def _is_objq_valid_col(c: str) -> bool:
    """Tell whether ``c`` is a recognised object-quality feature column.

    A column qualifies when it is one of the fixed feature names below or
    when it starts with ``n_`` / ``frac_`` (the per-band count and fraction
    columns such as ``n_u`` .. ``n_y`` and ``frac_u`` .. ``frac_y``).
    """
    name = str(c)
    fixed_names = (
        "n_obs_total", "n_bands_present", "timespan", "cadence_proxy",
        "neg_flux_frac", "abs_flux_mean",
        "flux_mean", "flux_std",
        "snr_abs_mean", "snr_abs_max", "snr_det_frac", "snr_strong_frac",
        "low_obs", "low_bandcov",
    )
    return name in fixed_names or name.startswith(("n_", "frac_"))

def _looks_like_meta_col(c: str) -> bool:
    """Heuristically decide whether column ``c`` belongs to the meta table.

    The check is case-insensitive: first an exact match against
    ``META_PROTECT_COLS``, then a substring match against a list of
    meta-related fragments.
    """
    lowered = str(c).lower()
    protected = {name.lower() for name in META_PROTECT_COLS}
    if lowered in protected:
        return True

    fragments = (
        "split", "target", "label", "class", "fold",
        "ebv", "redshift", "photoz", "z_err", "zerr",
        "log1p", "_clip", "_missing", "prior",
    )
    for fragment in fragments:
        if fragment in lowered:
            return True
    return False

def _clean_objq_df(df: pd.DataFrame, tag: str) -> pd.DataFrame:
    """Sanitise an object-quality frame before it is joined into the meta table.

    Steps: drop ``Unnamed*`` columns, normalise ``object_id``, drop columns
    that look like meta columns, keep only recognised object-quality
    columns, and coerce every remaining feature column to numeric.

    ``None`` or an empty frame is returned unchanged. ``tag`` is only used
    in log messages.

    Raises:
        RuntimeError: if the frame has no ``object_id`` column.
    """
    if df is None or len(df) == 0:
        return df

    def _preview(cols):
        # First 12 names, with an ellipsis marker when the list is longer.
        return f"{cols[:12]}{'...' if len(cols)>12 else ''}"

    df = _drop_unnamed_cols(df).copy()
    if "object_id" not in df.columns:
        raise RuntimeError(f"[OBJQ][{tag}] missing object_id col")
    df["object_id"] = df["object_id"].astype("string").map(_norm_id)

    # 1) Anything that resembles a meta column must never come from objq.
    metaish = [col for col in df.columns if col != "object_id" and _looks_like_meta_col(col)]
    if metaish:
        print(f"[OBJQ][{tag}] drop META-like cols from objq: {_preview(metaish)}")
        df = df.drop(columns=metaish, errors="ignore")

    # 2) Whitelist: object_id plus the recognised object-quality features.
    keep = ["object_id"]
    keep.extend(col for col in df.columns if col != "object_id" and _is_objq_valid_col(col))
    rejected = [col for col in df.columns if col not in keep]
    if rejected:
        print(f"[OBJQ][{tag}] drop non-objq cols: {_preview(rejected)}")
    df = df[keep].copy()

    # 3) Features are numeric; unparsable values become NaN.
    feature_cols = [col for col in df.columns if col != "object_id"]
    for col in feature_cols:
        df[col] = pd.to_numeric(df[col], errors="coerce")

    return df

def _dedup_objq(df: pd.DataFrame, tag: str):
    """Collapse duplicated ``object_id`` rows into one row per object.

    ``None`` / empty input is returned unchanged. Otherwise a copy with a
    normalised ``object_id`` is made; if no id is duplicated that copy is
    returned as is. When duplicates exist, count-like columns (``n_*`` and
    the flag columns) are aggregated with ``max`` and every other column
    with ``mean``. ``tag`` is only used in the warning message.
    """
    if df is None or df.empty:
        return df

    df = df.copy()
    df["object_id"] = df["object_id"].astype("string").map(_norm_id)

    has_duplicates = df["object_id"].duplicated().any()
    if not has_duplicates:
        return df

    print(f"[WARN][OBJQ][{tag}] duplicated object_id detected -> aggregating")

    value_cols = [col for col in df.columns if col != "object_id"]
    count_like = {col for col in value_cols if col.startswith("n_")}
    count_like.update(["n_obs_total", "n_bands_present", "low_obs", "low_bandcov"])

    agg = {col: ("max" if col in count_like else "mean") for col in value_cols}
    return df.groupby("object_id", as_index=False).agg(agg)

def _add_objq_flags(objq_df: pd.DataFrame):
    """Return a copy of ``objq_df`` with the sparse-coverage flags added.

    * ``low_obs``      = 1 when ``n_obs_total`` < 20
    * ``low_bandcov``  = 1 when ``n_bands_present`` <= 2

    Each flag is only added when its source column exists. Source values
    are coerced to numeric, missing values count as 0, and the flags are
    stored as ``int8``.
    """
    def _as_int(column):
        return pd.to_numeric(column, errors="coerce").fillna(0).astype(int)

    flagged = objq_df.copy()
    if "n_obs_total" in flagged.columns:
        too_few_obs = _as_int(flagged["n_obs_total"]) < 20
        flagged["low_obs"] = too_few_obs.astype("int8")
    if "n_bands_present" in flagged.columns:
        too_few_bands = _as_int(flagged["n_bands_present"]) <= 2
        flagged["low_bandcov"] = too_few_bands.astype("int8")
    return flagged

def _safe_join_objq_overwrite(df_meta: pd.DataFrame, objq_df: pd.DataFrame, tag="train"):
    """Left-join object-quality columns onto the meta table, objq winning on clashes.

    Both frames are keyed by normalised object id (the meta index, and
    either the ``object_id`` column or the index of ``objq_df``). Columns
    present on both sides are dropped from the meta side first, so the
    object-quality values replace them.

    Returns:
        ``(joined_frame, added_cols)`` where ``added_cols`` lists the
        object-quality columns present in the joined frame.
    """
    meta = df_meta.copy(deep=False)
    meta.index = pd.Index([_norm_id(z) for z in meta.index.tolist()], name=meta.index.name)

    oq = objq_df.copy(deep=False)
    if "object_id" not in oq.columns:
        # Already indexed by id: only normalise the index labels.
        oq.index = pd.Index([_norm_id(z) for z in oq.index.tolist()], name=oq.index.name)
    else:
        oq["object_id"] = oq["object_id"].astype("string").map(_norm_id)
        oq = oq.set_index("object_id", drop=True)

    overlap = meta.columns.intersection(oq.columns)
    if len(overlap) > 0:
        print(f"[OBJQ][{tag}] overlap detected -> overwrite from objq: {list(overlap[:12])}{'...' if len(overlap)>12 else ''}")
        meta = meta.drop(columns=list(overlap), errors="ignore")

    joined = meta.join(oq, how="left")
    added_cols = [col for col in oq.columns if col in joined.columns]
    return joined, added_cols

objq_train_df = None
objq_test_df  = None
added_cols_train = []
added_cols_test  = []

if BUILD_OBJECT_QUALITY:
    # Rebuild both tables unless "only if missing" is requested, both CSVs
    # already exist and no refresh was asked for.
    do_train = True
    do_test = True
    if OBJQ_ONLY_IF_MISSING and (OBJECT_QUALITY_TRAIN_PATH.exists() and OBJECT_QUALITY_TEST_PATH.exists()) and (not OBJQ_REFRESH):
        do_train = do_test = False

    if do_train:
        objq_train_df = _build_object_quality("train", OBJECT_QUALITY_TRAIN_PATH, refresh=OBJQ_REFRESH, chunksize=OBJQ_CHUNK_ROWS)
    else:
        objq_train_df = pd.read_csv(OBJECT_QUALITY_TRAIN_PATH, dtype={"object_id":"string"}, **SAFE_READ_KW)

    if do_test:
        objq_test_df = _build_object_quality("test", OBJECT_QUALITY_TEST_PATH, refresh=OBJQ_REFRESH, chunksize=OBJQ_CHUNK_ROWS)
    else:
        objq_test_df = pd.read_csv(OBJECT_QUALITY_TEST_PATH, dtype={"object_id":"string"}, **SAFE_READ_KW)

    # Sanitise -> de-duplicate -> add coverage flags, in that order.
    objq_train_df = _clean_objq_df(objq_train_df, "train")
    objq_test_df  = _clean_objq_df(objq_test_df,  "test")

    objq_train_df = _dedup_objq(objq_train_df, "train")
    objq_test_df  = _dedup_objq(objq_test_df,  "test")

    objq_train_df = _add_objq_flags(objq_train_df)
    objq_test_df  = _add_objq_flags(objq_test_df)

    df_train_meta, added_cols_train = _safe_join_objq_overwrite(df_train_meta, objq_train_df, tag="train")
    df_test_meta,  added_cols_test  = _safe_join_objq_overwrite(df_test_meta,  objq_test_df,  tag="test")

    for dfm, added in [(df_train_meta, added_cols_train), (df_test_meta, added_cols_test)]:
        # Objects absent from the object-quality table come out of the left
        # join as NaN; treat them as "no observations" (0).
        for c in added:
            if c in dfm.columns:
                dfm[c] = pd.to_numeric(dfm[c], errors="coerce").fillna(0)

        # The flags were computed on the objq table before the join, so the
        # rows just filled with 0 would carry low_obs=0 / low_bandcov=0 even
        # though n_obs_total=0 / n_bands_present=0. Recompute the flags from
        # the filled counts so they agree with the flag definitions.
        flag_src = [c for c in ("n_obs_total", "n_bands_present") if c in added and c in dfm.columns]
        if flag_src:
            flags = _add_objq_flags(dfm[flag_src])
            for fc in ("low_obs", "low_bandcov"):
                if fc in flags.columns:
                    dfm[fc] = flags[fc].to_numpy()

    # Persist the enriched meta tables; fall back to CSV when parquet
    # writing fails.
    try:
        df_train_meta.to_parquet(ART_DIR / "train_meta.parquet", index=True)
        df_test_meta.to_parquet(ART_DIR / "test_meta.parquet", index=True)
        meta_saved = "parquet"
    except Exception:
        df_train_meta.to_csv(ART_DIR / "train_meta.csv", index=True)
        df_test_meta.to_csv(ART_DIR / "test_meta.csv", index=True)
        meta_saved = "csv"

    with open(ART_DIR / "stage3_objq_summary.json", "w", encoding="utf-8") as f:
        json.dump({
            "stage": "stage3",
            "version": "v5.4",
            "BUILD_OBJECT_QUALITY": bool(BUILD_OBJECT_QUALITY),
            "OBJQ_REFRESH": bool(OBJQ_REFRESH),
            "OBJQ_CHUNK_ROWS": int(OBJQ_CHUNK_ROWS),
            "object_quality_train": str(OBJECT_QUALITY_TRAIN_PATH),
            "object_quality_test": str(OBJECT_QUALITY_TEST_PATH),
            "meta_overwrite_format": meta_saved,
            "added_objq_cols_train": sorted(list(set(added_cols_train))),
            "added_objq_cols_test": sorted(list(set(added_cols_test))),
        }, f, indent=2)

    print(f"[OBJQ] Joined into meta (SAFE overwrite v5.4) + saved updated meta ({meta_saved}) in {ART_DIR}")

# ----------------------------
# 8) Smoke test (schema + find a few objects quickly) — v5.4 safe StopIteration
# ----------------------------
# Smoke test: for up to two splits that have both train and test objects,
# check the chunk schema and try to load a few randomly chosen objects.
rng = np.random.default_rng(SEED)
candidate_splits = []
for s in SPLIT_LIST:
    if len(train_ids_by_split.get(s, [])) > 0 and len(test_ids_by_split.get(s, [])) > 0:
        candidate_splits.append(s)
    # Two splits are enough for a smoke test.
    if len(candidate_splits) >= 2:
        break
if len(candidate_splits) == 0:
    raise RuntimeError("Tidak ada split yang punya train & test objects (unexpected). Cek STAGE 2.")

# Smoke-test knobs, overridable through CFG.
SMOKE_CHUNK = int(CFG.get("SMOKE_CHUNK_ROWS", 80_000))
SMOKE_MAX_CHUNKS = int(CFG.get("SMOKE_MAX_CHUNKS", 6))
SMOKE_N_IDS = int(CFG.get("SMOKE_N_IDS_PER_SPLIT", 2))

for s in candidate_splits:
    # Pull only the first chunk of each file; an exhausted iterator means
    # there is nothing to check for this split.
    try:
        ch_tr = next(iter_lightcurve_chunks(s, "train", chunksize=SMOKE_CHUNK, encode_filter=False))
        ch_te = next(iter_lightcurve_chunks(s, "test",  chunksize=SMOKE_CHUNK, encode_filter=False))
    except StopIteration:
        print(f"[WARN][SMOKE] split={s} StopIteration (file empty / all rows filtered). Skip.")
        continue

    # Column names AND their order must match REQ_LC_KEYS exactly.
    if list(ch_tr.columns) != REQ_LC_KEYS:
        raise RuntimeError(f"Train chunk schema mismatch in {s}: {list(ch_tr.columns)}")
    if list(ch_te.columns) != REQ_LC_KEYS:
        raise RuntimeError(f"Test chunk schema mismatch in {s}: {list(ch_te.columns)}")

    tr_ids = train_ids_by_split[s]
    te_ids = test_ids_by_split[s]
    # NOTE(review): rng.integers draws each index independently, so the same
    # id can be picked more than once.
    pick_tr = [tr_ids[i] for i in rng.integers(0, len(tr_ids), size=min(SMOKE_N_IDS, len(tr_ids)))]
    pick_te = [te_ids[i] for i in rng.integers(0, len(te_ids), size=min(SMOKE_N_IDS, len(te_ids)))]

    got_tr = load_many_object_lightcurves(s, "train", pick_tr, chunksize=SMOKE_CHUNK, max_chunks=SMOKE_MAX_CHUNKS, sort_time=False)
    got_te = load_many_object_lightcurves(s, "test",  pick_te, chunksize=SMOKE_CHUNK, max_chunks=SMOKE_MAX_CHUNKS, sort_time=False)

    # Membership is tested on the normalised id.
    miss_tr = [x for x in pick_tr if _norm_id(x) not in got_tr]
    miss_te = [x for x in pick_te if _norm_id(x) not in got_te]

    # A miss only produces a warning; it does not fail the stage.
    if miss_tr:
        print(f"[WARN][SMOKE] split={s} train not found within first {SMOKE_MAX_CHUNKS} chunks: {miss_tr[:2]}")
    if miss_te:
        print(f"[WARN][SMOKE] split={s} test not found within first {SMOKE_MAX_CHUNKS} chunks: {miss_te[:2]}")

print("STAGE 3 OK — LIGHTCURVE UTILITIES READY (+OBJQ full if enabled)")
print(f"- Saved manifest: {manifest_path}")
print(f"- Saved counts  : {counts_path}")
print(f"- LC_CACHE_DIR  : {LC_CACHE_DIR}")
print(f"- OBJQ train    : {OBJECT_QUALITY_TRAIN_PATH} (exists={OBJECT_QUALITY_TRAIN_PATH.exists()})")
print(f"- OBJQ test     : {OBJECT_QUALITY_TEST_PATH} (exists={OBJECT_QUALITY_TEST_PATH.exists()})")
print(f"- Smoke splits  : {candidate_splits}")

# ----------------------------
# 9) Export globals
# ----------------------------
# Publish the Stage 3 API for later cells. The two object-quality paths are
# exported as plain strings.
globals().update(
    SPLIT_FILES=SPLIT_FILES,
    train_ids_by_split=train_ids_by_split,
    test_ids_by_split=test_ids_by_split,
    iter_lightcurve_chunks=iter_lightcurve_chunks,
    load_object_lightcurve=load_object_lightcurve,
    load_many_object_lightcurves=load_many_object_lightcurves,
    load_object_lightcurve_cached=load_object_lightcurve_cached,
    get_lc_cache_path=get_lc_cache_path,
    LC_CACHE_DIR=LC_CACHE_DIR,
    REQ_LC_KEYS=REQ_LC_KEYS,
    ALLOWED_FILTERS=ALLOWED_FILTERS,
    FILTER_ORDER=FILTER_ORDER,
    FILTER2ID=FILTER2ID,
    ID2FILTER=ID2FILTER,
    OBJECT_QUALITY_TRAIN_PATH=str(OBJECT_QUALITY_TRAIN_PATH),
    OBJECT_QUALITY_TEST_PATH=str(OBJECT_QUALITY_TEST_PATH),
    df_train_meta=df_train_meta,
    df_test_meta=df_test_meta,
)

gc.collect()


# Photometric Cleaning (De-extinction + Negative Flux Safe Transform)

In [None]:
# ============================================================
# STAGE 4 — Photometric Cleaning (FORCE OVERWRITE) — REVISI FULL v7.1
# - Dust de-extinction (Rubin DustValues().r_x) for ugrizy
# - Output:
#   * mag, mag_err (positive-only; non-detect uses DET_SIGMA*err as limit)
#   * asinh_mag, asinh_mag_err (Luptitude-like; works for negative flux)
#   * snr, snr_abs, snr_asinh (preserve sign via asinh)
# - Manifest + summary + config JSON
#
# v7.1 upgrade (compared with v7.0):
# - more robust filter handling: supports "0..5", "lsst_u" and unexpected casing; rows with an invalid filter are dropped (or raise in strict mode)
# - robust EBV: if neither EBV nor EBV_clip exists, zeros are used
# - empty parts are not written (avoids the rows<=0 error)
# - summary gains two counters: dropped_bad_filter_rows, skipped_empty_parts
# ============================================================

import gc, json, warnings, time, shutil
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Fail fast when an earlier stage has not been run in this session.
for need in ["iter_lightcurve_chunks", "df_train_meta", "df_test_meta", "ART_DIR", "SPLIT_LIST", "LOG_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 -> 1 -> 2 -> 3 dulu.")

# Normalise to Path (earlier stages may have exported strings) and make sure
# the directories exist.
ART_DIR = Path(ART_DIR); ART_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR = Path(LOG_DIR); LOG_DIR.mkdir(parents=True, exist_ok=True)

# Use the global CFG only when it is a dict; otherwise run on defaults.
CFG = globals().get("CFG", {}) if isinstance(globals().get("CFG", {}), dict) else {}

# ----------------------------
# 1) Settings
# ----------------------------
CHUNKSIZE   = int(CFG.get("PHOT_CHUNKSIZE", 350_000))
ERR_EPS     = float(CFG.get("PHOT_ERR_EPS", 1e-6))  # floor for flux_err and for SNR denominators

SNR_DET_POS = float(CFG.get("SNR_DET_POS", 3.0))  # detection for MAG branch (positive snr only)
DET_SIGMA   = float(CFG.get("DET_SIGMA", 3.0))    # multiplies flux_err: non-detection flux limit and asinh softening b

SNR_CLIP    = float(CFG.get("SNR_CLIP", 30.0))    # for stability
MIN_FLUX_POS_UJY = float(CFG.get("MIN_FLUX_POS_UJY", 1e-6))  # lower bound on flux before log10 / log

# Clipping ranges for magnitudes and magnitude errors.
MAG_MIN, MAG_MAX   = float(CFG.get("MAG_MIN", -10.0)), float(CFG.get("MAG_MAX", 50.0))
MAGERR_FLOOR_DET   = float(CFG.get("MAGERR_FLOOR_DET", 1e-3))
MAGERR_FLOOR_ND    = float(CFG.get("MAGERR_FLOOR_ND", 0.75))  # extra mag_err floor applied to non-detections only
MAGERR_CAP         = float(CFG.get("MAGERR_CAP", 10.0))

WRITE_FORMAT = str(CFG.get("PHOT_WRITE_FORMAT", "parquet")).lower()  # "parquet" or "csv.gz"
ONLY_SPLITS  = CFG.get("PHOT_ONLY_SPLITS", None)   # e.g. ["split_01"]
KEEP_FLUX_DEBUG = bool(CFG.get("PHOT_KEEP_FLUX_DEBUG", False))  # adds A_x / flux_deext / err_deext / b_soft columns
DROP_BAD_TIME_ROWS = bool(CFG.get("PHOT_DROP_BAD_TIME_ROWS", True))  # drop rows whose mjd is not finite
DROP_NAN_FLUX_ROWS  = bool(CFG.get("PHOT_DROP_NAN_FLUX_ROWS", False))  # if True, drop NaN flux instead of zeroing
REBUILD_MODE = str(CFG.get("PHOT_REBUILD_MODE", "wipe_all")).lower()  # "wipe_all" | "wipe_parts_only"

# v7.1: strict filter or drop invalid
STRICT_FILTER = bool(CFG.get("PHOT_STRICT_FILTER", False))

# ----------------------------
# 2) Extinction coefficients (Rubin DustValues().r_x)
#    r_x * EBV = A_x  (extinction in mag)
# ----------------------------
EXT_RLAMBDA = {
    "u": 4.757217815396922,
    "g": 3.6605664439892625,
    "r": 2.70136780871597,
    "i": 2.0536599130965882,
    "z": 1.5900964472616756,
    "y": 1.3077049588254708,
}

BAND2ID = {"u": 0, "g": 1, "r": 2, "i": 3, "z": 4, "y": 5}
ID2BAND = {v: k for k, v in BAND2ID.items()}
# Coefficients as an array indexed by band id (same u,g,r,i,z,y order as BAND2ID).
RLAM_BY_ID = np.array([EXT_RLAMBDA["u"], EXT_RLAMBDA["g"], EXT_RLAMBDA["r"], EXT_RLAMBDA["i"], EXT_RLAMBDA["z"], EXT_RLAMBDA["y"]], dtype=np.float32)

# filter aliases
FILTER_NUM2STR = {"0": "u", "1": "g", "2": "r", "3": "i", "4": "z", "5": "y"}

def _get_ebv_series(df_meta: pd.DataFrame):
    """Return the E(B-V) values of ``df_meta`` as a numeric Series.

    ``EBV_clip`` is preferred, then ``EBV``; when neither column exists a
    zero-filled Series on the meta index is used. Values that cannot be
    parsed as numbers, and missing values, become 0.0.
    """
    for candidate in ("EBV_clip", "EBV"):
        if candidate in df_meta.columns:
            raw = df_meta[candidate]
            break
    else:
        raw = pd.Series(0.0, index=df_meta.index)
    return pd.to_numeric(raw, errors="coerce").fillna(0.0)

# E(B-V) lookups, indexed like the respective meta table.
EBV_TRAIN_SER, EBV_TEST_SER = _get_ebv_series(df_train_meta), _get_ebv_series(df_test_meta)

# Flux is assumed to be in uJy, which gives the AB zero point below (~23.9).
MAG_ZP = float(2.5 * np.log10(3631e6))
K_2P5_LN10 = np.float32(2.5 / np.log(10.0))  # 2.5 / ln(10)
K_1P0857 = np.float32(1.0857362)

# ----------------------------
# 3) Output root + WIPE (with safety guard)
# ----------------------------
LC_CLEAN_DIR = ART_DIR / "lc_clean_mag"   # keep name for compatibility

art_abs = ART_DIR.resolve()
lc_abs  = LC_CLEAN_DIR.resolve()

# Refuse to wipe anything that is not located under ART_DIR.
try:
    ok_rel = lc_abs.is_relative_to(art_abs)
except AttributeError:
    # Path.is_relative_to is missing on older Pythons: compare string prefixes.
    ok_rel = str(lc_abs).startswith((str(art_abs) + "/", str(art_abs) + "\\"))

if not ok_rel:
    raise RuntimeError(f"Safety guard failed: LC_CLEAN_DIR bukan turunan ART_DIR.\nART_DIR={art_abs}\nLC_CLEAN_DIR={lc_abs}")

if REBUILD_MODE not in ("wipe_all", "wipe_parts_only"):
    raise ValueError("REBUILD_MODE must be 'wipe_all' or 'wipe_parts_only'")

# "wipe_all" removes the whole output tree; "wipe_parts_only" keeps it and
# leaves the per-directory cleanup to process_split.
if REBUILD_MODE == "wipe_all" and LC_CLEAN_DIR.exists():
    shutil.rmtree(LC_CLEAN_DIR, ignore_errors=True)
LC_CLEAN_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 4) Atomic writer
# ----------------------------
def _atomic_write_parquet(df: pd.DataFrame, out_path: Path):
    """Write ``df`` to ``out_path`` as parquet via a temporary sibling file.

    The frame is first written to ``<stem>.tmp<suffix>`` in the same
    directory and then moved over ``out_path`` with ``Path.replace``, so a
    reader never sees a half-written ``out_path``. Any exception raised by
    the write or the move propagates to the caller.
    """
    tmp = out_path.with_name(out_path.stem + ".tmp" + out_path.suffix)
    try:
        df.to_parquet(tmp, index=False)
        tmp.replace(out_path)
    finally:
        # After a successful replace the temp file is gone. If it is still
        # here the write failed, so remove the partial file even when an
        # older out_path exists (the previous check skipped that case and
        # left stale temp files behind).
        try:
            tmp.unlink()
        except OSError:
            # Best effort: never mask the original error with a cleanup error.
            pass

def _atomic_write_csv_gz(df: pd.DataFrame, out_path: Path):
    """Write ``df`` as gzip-compressed CSV next to ``out_path`` and return the final path.

    The final file is ``out_path`` with its suffix replaced by ``.csv.gz``.
    The data is written to a temporary sibling first and then moved into
    place with ``Path.replace``. Any exception raised by the write or the
    move propagates to the caller.
    """
    final_path = out_path.with_suffix(".csv.gz")
    tmp = final_path.with_name(final_path.stem + ".tmp" + "".join(final_path.suffixes))
    try:
        df.to_csv(tmp, index=False, compression="gzip")
        tmp.replace(final_path)
    finally:
        # After a successful replace the temp file is gone. If it is still
        # here the write failed, so remove the partial file even when an
        # older final_path exists (the previous check skipped that case and
        # left stale temp files behind).
        try:
            tmp.unlink()
        except OSError:
            # Best effort: never mask the original error with a cleanup error.
            pass
    return final_path

def write_part(df: pd.DataFrame, out_path: Path, fmt: str):
    """Write one cleaned part and report how it was stored.

    The parent directory is created when needed. With ``fmt="parquet"`` a
    parquet write is attempted and, on any failure, the part is written as
    ``.csv.gz`` instead. With ``fmt="csv.gz"`` the part is written as
    gzip-compressed CSV directly.

    Returns:
        ``(format_label, written_path)``.

    Raises:
        ValueError: if ``fmt`` is neither ``"parquet"`` nor ``"csv.gz"``.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if fmt not in ("parquet", "csv.gz"):
        raise ValueError("fmt must be 'parquet' or 'csv.gz'")

    if fmt == "csv.gz":
        return "csv.gz", _atomic_write_csv_gz(df, out_path)

    try:
        _atomic_write_parquet(df, out_path)
    except Exception as e:
        # Parquet writing failed: keep the data, store it as csv.gz instead.
        alt = _atomic_write_csv_gz(df, out_path)
        return f"csv.gz (fallback from parquet: {type(e).__name__})", alt
    return "parquet", out_path

# ----------------------------
# 5) Core cleaning -> MAG + ASINH_MAG + SNR features
# ----------------------------
def _normalize_filter_series(filt_ser: pd.Series) -> pd.Series:
    """Map raw filter labels to the single-letter band names where possible.

    Values are stripped and lower-cased, the substrings ``lsst_``,
    ``rubin_`` and ``band_`` are removed, and the numeric codes of
    ``FILTER_NUM2STR`` are translated. A value that is still not a known
    band but ends in a band letter is reduced to that last letter; any
    other value is left as is.
    """
    band_names = list(BAND2ID.keys())

    f = filt_ser.astype("string").str.strip().str.lower()
    # common prefixes
    for token in ("lsst_", "rubin_", "band_"):
        f = f.str.replace(token, "", regex=False)
    # numeric 0..5
    f = f.replace(FILTER_NUM2STR)

    # Still unknown but ending in one of ugrizy: fall back to the last char.
    last_char = f.str[-1]
    fallback = last_char.where(last_char.isin(band_names), f)
    return f.where(f.isin(band_names), fallback)

def clean_chunk_to_phot(ch: pd.DataFrame, ebv_ser: pd.Series):
    """Turn one raw light-curve chunk into de-extincted photometric features.

    Args:
        ch: chunk with columns ``object_id``, ``mjd``, ``flux``,
            ``flux_err`` and ``filter``. The chunk is not modified.
        ebv_ser: E(B-V) values looked up by (stripped) object id; ids that
            are not found use 0.

    Returns:
        ``(out, dropped_time, nan_flux_rows, dropped_flux, bad_filter_rows)``
        where ``out`` holds ``object_id, mjd, band_id, mag, mag_err,
        asinh_mag, asinh_mag_err, snr, snr_abs, snr_asinh, detected_pos``
        (plus debug columns when ``KEEP_FLUX_DEBUG`` is set) and the four
        integers count rows dropped for non-finite time, rows with
        non-finite flux, rows dropped for non-finite flux, and rows with an
        unknown filter.

    Raises:
        ValueError: if a required column is missing, or if an unknown
            filter value is found while ``STRICT_FILTER`` is enabled.
    """
    # Required columns
    for c in ["object_id", "mjd", "flux", "flux_err", "filter"]:
        if c not in ch.columns:
            raise ValueError(f"iter_lightcurve_chunks chunk missing column: {c}. Found={list(ch.columns)}")

    oid_ser  = ch["object_id"].astype("string").str.strip()
    filt_ser = _normalize_filter_series(ch["filter"])

    # numeric arrays
    mjd  = ch["mjd"].to_numpy(copy=False).astype(np.float32, copy=False)
    flux = ch["flux"].to_numpy(copy=False).astype(np.float32, copy=False)
    err  = ch["flux_err"].to_numpy(copy=False).astype(np.float32, copy=False)

    # sanitize err: non-finite values become ERR_EPS, and ERR_EPS is the floor
    err = np.nan_to_num(err, nan=np.float32(ERR_EPS), posinf=np.float32(ERR_EPS), neginf=np.float32(ERR_EPS))
    err = np.maximum(err, np.float32(ERR_EPS))

    # sanitize flux: +/-inf -> NaN. `flux` may be a view of the chunk's own
    # data (copy=False above), possibly read-only, so build a new array
    # instead of assigning into it.
    flux = np.where(np.isfinite(flux), flux, np.float32(np.nan)).astype(np.float32, copy=False)

    # band id (robust)
    band_id_s = filt_ser.map(BAND2ID)
    bad_filter_rows = int(band_id_s.isna().sum())
    if bad_filter_rows:
        if STRICT_FILTER:
            bad = filt_ser[band_id_s.isna()].value_counts().head(10).index.tolist()
            raise ValueError(f"Unknown/invalid filter values encountered (top examples): {bad}")
        # Lenient mode: drop the rows with an unknown filter everywhere.
        keep_f = (~band_id_s.isna()).to_numpy()
        oid_ser = oid_ser[keep_f]
        filt_ser = filt_ser[keep_f]
        mjd = mjd[keep_f]
        flux = flux[keep_f]
        err = err[keep_f]
        band_id_s = band_id_s[keep_f]

    band_id = band_id_s.to_numpy(copy=False).astype(np.int8, copy=False)

    # EBV map (index=object_id); unknown ids and non-finite values -> 0.
    # np.where is used for the same reason as for flux above.
    ebv = oid_ser.map(ebv_ser).fillna(0.0).to_numpy(copy=False).astype(np.float32, copy=False)
    ebv = np.where(np.isfinite(ebv), ebv, np.float32(0.0)).astype(np.float32, copy=False)

    # A_x
    rlam = RLAM_BY_ID[band_id.astype(np.int32)]
    A = (rlam * ebv).astype(np.float32, copy=False)

    # de-extinction in flux space: f0 = f * 10^(0.4 A)
    mul = np.power(np.float32(10.0), (np.float32(0.4) * A)).astype(np.float32, copy=False)
    flux_deext = (flux * mul).astype(np.float32, copy=False)
    err_deext  = (err  * mul).astype(np.float32, copy=False)

    ok_flux = np.isfinite(flux_deext)
    nan_flux_rows = int((~ok_flux).sum())

    if DROP_NAN_FLUX_ROWS and nan_flux_rows:
        # Rows are dropped at the very end, once the output frame exists.
        keep_flux = ok_flux
    else:
        keep_flux = None
        if nan_flux_rows:
            # Keep the row but treat the measurement as zero flux.
            flux_deext[~ok_flux] = np.float32(0.0)

    # SNR
    denom = np.maximum(err_deext, np.float32(ERR_EPS))
    snr = (flux_deext / denom).astype(np.float32, copy=False)
    snr = np.clip(snr, -np.float32(SNR_CLIP), np.float32(SNR_CLIP)).astype(np.float32, copy=False)
    snr_abs = np.abs(snr).astype(np.float32, copy=False)

    detected_pos = (snr > np.float32(SNR_DET_POS)).astype(np.int8, copy=False)

    # MAG branch (positive-only; non-detect uses DET_SIGMA*err)
    flux_detlim = (np.float32(DET_SIGMA) * err_deext).astype(np.float32, copy=False)
    flux_for_mag = np.where(
        detected_pos == 1,
        np.maximum(flux_deext, np.float32(MIN_FLUX_POS_UJY)),
        np.maximum(flux_detlim, np.float32(MIN_FLUX_POS_UJY)),
    ).astype(np.float32, copy=False)

    mag = (np.float32(MAG_ZP) - np.float32(2.5) * np.log10(flux_for_mag)).astype(np.float32, copy=False)
    mag = np.clip(mag, np.float32(MAG_MIN), np.float32(MAG_MAX)).astype(np.float32, copy=False)

    mag_err = (K_1P0857 * (err_deext / flux_for_mag)).astype(np.float32, copy=False)
    mag_err = np.clip(mag_err, np.float32(MAGERR_FLOOR_DET), np.float32(MAGERR_CAP)).astype(np.float32, copy=False)
    if MAGERR_FLOOR_ND is not None and float(MAGERR_FLOOR_ND) > 0:
        # Non-detections get an additional, larger error floor.
        mag_err = np.where(detected_pos == 1, mag_err, np.maximum(mag_err, np.float32(MAGERR_FLOOR_ND))).astype(np.float32, copy=False)

    # ASINH magnitude (Luptitude-like; defined for negative flux)
    # b is the per-row softening flux.
    b = np.maximum(np.float32(DET_SIGMA) * err_deext, np.float32(MIN_FLUX_POS_UJY)).astype(np.float32, copy=False)
    x = (flux_deext / (np.float32(2.0) * b)).astype(np.float32, copy=False)
    asinh_term = np.arcsinh(x).astype(np.float32, copy=False)
    ln_b = np.log(b).astype(np.float32, copy=False)

    asinh_mag = (np.float32(MAG_ZP) - K_2P5_LN10 * (asinh_term + ln_b)).astype(np.float32, copy=False)
    asinh_mag = np.clip(asinh_mag, np.float32(MAG_MIN), np.float32(MAG_MAX)).astype(np.float32, copy=False)

    # Error propagation: d(asinh_mag)/d(flux) = K / sqrt(flux^2 + (2b)^2).
    denom_asinh = np.sqrt((flux_deext * flux_deext) + (np.float32(2.0) * b) * (np.float32(2.0) * b)).astype(np.float32, copy=False)
    asinh_mag_err = (K_2P5_LN10 * (err_deext / np.maximum(denom_asinh, np.float32(ERR_EPS)))).astype(np.float32, copy=False)
    asinh_mag_err = np.clip(asinh_mag_err, np.float32(MAGERR_FLOOR_DET), np.float32(MAGERR_CAP)).astype(np.float32, copy=False)

    snr_asinh = np.arcsinh(snr).astype(np.float32, copy=False)

    out = pd.DataFrame({
        "object_id": pd.array(oid_ser.to_numpy(copy=False), dtype="string"),
        "mjd": mjd,
        "band_id": band_id,
        "mag": mag,
        "mag_err": mag_err,
        "asinh_mag": asinh_mag,
        "asinh_mag_err": asinh_mag_err,
        "snr": snr,
        "snr_abs": snr_abs,
        "snr_asinh": snr_asinh,
        "detected_pos": detected_pos,
    })

    if KEEP_FLUX_DEBUG:
        out["A_x"]        = pd.Series(A, dtype="float32")
        out["flux_deext"] = pd.Series(flux_deext, dtype="float32")
        out["err_deext"]  = pd.Series(err_deext, dtype="float32")
        out["b_soft"]     = pd.Series(b, dtype="float32")

    dropped_flux = 0
    if keep_flux is not None:
        keep = keep_flux
        dropped_flux = int((~keep).sum())
        out = out[keep]

    dropped_time = 0
    if DROP_BAD_TIME_ROWS:
        t = out["mjd"].to_numpy(copy=False).astype(np.float32, copy=False)
        keep_t = np.isfinite(t)
        dropped_time = int((~keep_t).sum())
        if dropped_time:
            out = out[keep_t]

    return out, int(dropped_time), int(nan_flux_rows), int(dropped_flux), int(bad_filter_rows)

# ----------------------------
# 6) Process split-wise
# ----------------------------
# PHOT_ONLY_SPLITS may be unset (use every split), a single split name, or
# an iterable of names. A bare string must be wrapped explicitly: list("split_01")
# would otherwise explode it into single characters.
if ONLY_SPLITS is None:
    splits_to_use = list(SPLIT_LIST)
elif isinstance(ONLY_SPLITS, str):
    splits_to_use = [ONLY_SPLITS]
else:
    splits_to_use = list(ONLY_SPLITS)

# Filled by process_split: one row per (split, which) and one row per written part.
summary_rows, manifest_rows = [], []

def _wipe_parts_dir(out_dir: Path):
    """Delete previously written part files and temp files directly inside ``out_dir``.

    Does nothing when the directory does not exist. Files that cannot be
    removed are skipped silently.
    """
    if not out_dir.exists():
        return

    patterns = ("part_*.parquet", "part_*.csv.gz", "*.tmp", "*.tmp.parquet", "*.tmp.csv.gz")
    for pattern in patterns:
        for stale in out_dir.glob(pattern):
            try:
                stale.unlink()
            except Exception:
                pass

def process_split(split_name: str, which: str):
    """Clean every light-curve chunk of one split/partition and write it as parts.

    Each chunk from ``iter_lightcurve_chunks`` goes through
    ``clean_chunk_to_phot``; non-empty results are written as
    ``part_NNNN`` files under ``LC_CLEAN_DIR/<split_name>/<which>``. One
    row per written part is appended to ``manifest_rows`` and one summary
    row to ``summary_rows``; a one-line report is printed at the end.

    Args:
        split_name: name of the split to process.
        which: ``"train"`` selects the train E(B-V) lookup, anything else
            the test lookup.
    """
    def _widen(lo, hi, column):
        # Extend the running [lo, hi] range with the finite values of column.
        values = column.to_numpy(copy=False).astype(np.float32, copy=False)
        finite = np.isfinite(values)
        if not finite.any():
            return lo, hi
        lo = float(min(lo, float(np.min(values[finite]))))
        hi = float(max(hi, float(np.max(values[finite]))))
        return lo, hi

    def _or_nan(value):
        # A range bound that was never updated is still +/-inf: report NaN.
        return value if np.isfinite(value) else np.nan

    ebv_ser = EBV_TRAIN_SER if which == "train" else EBV_TEST_SER
    out_dir = LC_CLEAN_DIR / split_name / which
    out_dir.mkdir(parents=True, exist_ok=True)
    if REBUILD_MODE == "wipe_parts_only":
        _wipe_parts_dir(out_dir)

    t0 = time.time()
    part_idx = 0
    n_rows_total = 0
    n_det_pos = 0
    dropped_time_total = 0
    nan_flux_total = 0
    dropped_flux_total = 0
    dropped_bad_filter_total = 0
    skipped_empty_parts = 0
    mag_min, mag_max = np.inf, -np.inf
    asinh_mag_min, asinh_mag_max = np.inf, -np.inf

    for ch in iter_lightcurve_chunks(split_name, which, chunksize=CHUNKSIZE):
        cleaned, dropped_time, nan_flux, dropped_flux, bad_filter_rows = clean_chunk_to_phot(ch, ebv_ser)

        # Drop counters are accumulated even when the chunk ends up empty.
        dropped_time_total += int(dropped_time)
        nan_flux_total += int(nan_flux)
        dropped_flux_total += int(dropped_flux)
        dropped_bad_filter_total += int(bad_filter_rows)

        if cleaned is None or len(cleaned) == 0:
            skipped_empty_parts += 1
            continue

        n_rows = int(len(cleaned))
        n_rows_total += n_rows
        n_det_pos += int(cleaned["detected_pos"].to_numpy(copy=False).astype(np.int8).sum())

        mag_min, mag_max = _widen(mag_min, mag_max, cleaned["mag"])
        asinh_mag_min, asinh_mag_max = _widen(asinh_mag_min, asinh_mag_max, cleaned["asinh_mag"])

        # The name always ends in .parquet; write_part returns the real path.
        out_path = out_dir / f"part_{part_idx:04d}.parquet"
        used_fmt, final_path = write_part(cleaned, out_path, WRITE_FORMAT)

        manifest_rows.append({
            "split": split_name,
            "which": which,
            "part": int(part_idx),
            "path": str(final_path),
            "rows": int(n_rows),
            "format": str(used_fmt),
        })

        part_idx += 1
        del cleaned, ch
        if part_idx % 10 == 0:
            gc.collect()

    dt = time.time() - t0
    det_frac = n_det_pos / max(n_rows_total, 1)
    mag_lo, mag_hi = _or_nan(mag_min), _or_nan(mag_max)
    asinh_lo, asinh_hi = _or_nan(asinh_mag_min), _or_nan(asinh_mag_max)

    summary_rows.append({
        "split": split_name,
        "which": which,
        "parts": int(part_idx),
        "rows": int(n_rows_total),
        "det_pos_frac": float(det_frac),
        "mag_min": mag_lo,
        "mag_max": mag_hi,
        "asinh_mag_min": asinh_lo,
        "asinh_mag_max": asinh_hi,
        "dropped_time_rows": int(dropped_time_total),
        "nan_flux_rows": int(nan_flux_total),
        "dropped_nan_flux_rows": int(dropped_flux_total),
        "dropped_bad_filter_rows": int(dropped_bad_filter_total),
        "skipped_empty_parts": int(skipped_empty_parts),
        "sec": float(dt),
    })

    print(
        f"[Stage 4] {split_name}/{which}: parts={part_idx} | rows={n_rows_total:,} | "
        f"det_pos%={100*det_frac:.2f}% | "
        f"bad_filter_drop={dropped_bad_filter_total:,} | nan_flux={nan_flux_total:,} | drop_nan_flux={dropped_flux_total:,} | drop_time={dropped_time_total:,} | "
        f"mag=[{mag_lo:.2f},{mag_hi:.2f}] | "
        f"asinh_mag=[{asinh_lo:.2f},{asinh_hi:.2f}] | "
        f"time={dt:.1f}s"
    )

print(f"[Stage 4] REBUILD_MODE={REBUILD_MODE} | Writing to: {LC_CLEAN_DIR}")
print(f"[Stage 4] WRITE_FORMAT={WRITE_FORMAT} | CHUNKSIZE={CHUNKSIZE:,} | DROP_NAN_FLUX_ROWS={DROP_NAN_FLUX_ROWS} | STRICT_FILTER={STRICT_FILTER}")

# Process train then test for every selected split; process_split appends to
# manifest_rows / summary_rows as it goes.
for s in splits_to_use:
    process_split(s, "train")
    process_split(s, "test")

df_parts_manifest = pd.DataFrame(manifest_rows)
df_summary  = pd.DataFrame(summary_rows)

# Manifest (one row per written part) and summary (one row per split/partition).
manifest_path = LC_CLEAN_DIR / "lc_clean_mag_manifest.csv"
summary_path  = LC_CLEAN_DIR / "lc_clean_mag_summary.csv"
df_parts_manifest.to_csv(manifest_path, index=False)
df_summary.to_csv(summary_path, index=False)

# Record the settings used for this run next to the data.
cfg_path = LC_CLEAN_DIR / "photometric_config_mag.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump({
        "STAGE": "stage4",
        "VERSION": "v7.1",
        "EXT_RLAMBDA": EXT_RLAMBDA,
        "SNR_DET_POS": float(SNR_DET_POS),
        "DET_SIGMA": float(DET_SIGMA),
        "ERR_EPS": float(ERR_EPS),
        "SNR_CLIP": float(SNR_CLIP),
        "MAG_ZP": float(MAG_ZP),
        "MAG_MIN": float(MAG_MIN),
        "MAG_MAX": float(MAG_MAX),
        "CHUNKSIZE": int(CHUNKSIZE),
        "WRITE_FORMAT": str(WRITE_FORMAT),
        "ONLY_SPLITS": list(splits_to_use),
        "KEEP_FLUX_DEBUG": bool(KEEP_FLUX_DEBUG),
        "DROP_BAD_TIME_ROWS": bool(DROP_BAD_TIME_ROWS),
        "DROP_NAN_FLUX_ROWS": bool(DROP_NAN_FLUX_ROWS),
        "REBUILD_MODE": str(REBUILD_MODE),
        "STRICT_FILTER": bool(STRICT_FILTER),
        "SCHEMA": "mag + asinh_mag + snr_asinh",
        "COLUMNS": [
            "object_id","mjd","band_id",
            "mag","mag_err",
            "asinh_mag","asinh_mag_err",
            "snr","snr_abs","snr_asinh",
            "detected_pos",
        ] + (["A_x","flux_deext","err_deext","b_soft"] if KEEP_FLUX_DEBUG else []),
    }, f, indent=2)

# sanity: manifest coverage
# NOTE: these checks run after the manifest/summary/config files were written.
if len(df_parts_manifest) == 0:
    raise RuntimeError("Stage 4 produced empty manifest. Check iter_lightcurve_chunks and split routing.")
if (df_parts_manifest["rows"] <= 0).any():
    bad = df_parts_manifest.loc[df_parts_manifest["rows"] <= 0].head(5).to_dict("records")
    raise RuntimeError(f"Found empty part files (rows<=0). Examples: {bad}")

print("\n[Stage 4] Done.")
print(f"- LC_CLEAN_DIR  : {LC_CLEAN_DIR}")
print(f"- Saved manifest: {manifest_path}")
print(f"- Saved summary : {summary_path}")
print(f"- Saved config  : {cfg_path}")

def get_clean_parts(split_name: str, which: str):
    """Return the part file paths of one split/partition, ordered by part number.

    Paths are read from the in-memory manifest ``df_parts_manifest`` and
    returned as a list of strings.
    """
    manifest = df_parts_manifest
    wanted = (manifest["split"] == split_name) & (manifest["which"] == which)
    ordered = manifest[wanted].sort_values("part")
    return ordered["path"].astype(str).tolist()

# Publish the Stage 4 constants, tables and accessor for later cells.
globals().update(
    EXT_RLAMBDA=EXT_RLAMBDA,
    BAND2ID=BAND2ID,
    ID2BAND=ID2BAND,
    MAG_ZP=MAG_ZP,
    LC_CLEAN_DIR=LC_CLEAN_DIR,
    lc_clean_mag_manifest=df_parts_manifest,
    lc_clean_mag_summary=df_summary,
    get_clean_parts=get_clean_parts,
)

gc.collect()


# Sequence Tokenization (Event-based Tokens)

In [None]:
# ============================================================
# STAGE 5 — Sequence Tokenization (Event-based Tokens) (ONE CELL)
# REVISI FULL v6.1 (FORCE ASINH + FEATURESET v2 + STAGE4 v7.1 compatible + robust meta + safer reuse)
#
# Compatible STAGE 4 schemas:
# - v7+: columns include: asinh_mag, asinh_mag_err, detected_pos, snr (and/or snr_asinh), band_id
# - legacy: columns include: flux_asinh, err_log1p, detected, snr, band_id
#
# Output:
# - ART_DIR/seq_tokens/<split>/<train|test>/shard_*.npz
# - ART_DIR/seq_tokens/seq_manifest_train.csv
# - ART_DIR/seq_tokens/seq_manifest_test.csv
# - ART_DIR/seq_tokens/seq_build_stats.csv
# - ART_DIR/seq_tokens/seq_config.json
#
# NPZ shard arrays:
# - object_id: (n_obj,) bytes
# - x       : (total_tokens, feature_dim) float32
# - band    : (total_tokens,) int8   (band_id per token)
# - offsets : (n_obj+1,) int64       (start offsets for each object)
# ============================================================

import gc, json, warnings, time, shutil, os
from pathlib import Path
import numpy as np
import pandas as pd

# Silence pandas / library warning categories that are not actionable here.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require minimal globals
# ----------------------------
# Stage 5 needs the artifact directory and both metadata frames to already
# exist in the notebook namespace; fail fast with a hint otherwise.
for need in ["ART_DIR", "df_train_meta", "df_test_meta"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan minimal STAGE 0 + STAGE 2 dulu (ART_DIR + meta).")

# Normalize ART_DIR to a Path and make sure the directory exists.
ART_DIR = Path(ART_DIR)
ART_DIR.mkdir(parents=True, exist_ok=True)

# CFG is optional: fall back to an empty dict when it is absent or not a dict.
CFG = globals().get("CFG", {})
CFG = CFG if isinstance(CFG, dict) else {}
# Reuse the notebook-wide seed when defined, otherwise default to 2025.
SEED = int(globals().get("SEED", 2025))

# ----------------------------
# 1) Helpers
# ----------------------------
def _safe_string_series(s: pd.Series) -> pd.Series:
    """Return ``s`` converted to text with surrounding whitespace stripped.

    The pandas ``"string"`` dtype is tried first; if that conversion or the
    strip fails, the plain ``str`` (object) conversion is used instead.
    """
    try:
        as_text = s.astype("string")
        return as_text.str.strip()
    except Exception:
        as_object = s.astype(str)
        return as_object.str.strip()

def _find_stage4_manifest(art_dir: Path):
    """Locate the Stage 4 manifest CSV.

    Lookup order:
      1. ``<art_dir>/lc_clean_mag/lc_clean_mag_manifest.csv``
      2. under ``/kaggle/working/mallorn_run``: the standard
         ``run_*/artifacts/lc_clean_mag/`` location, then any
         ``lc_clean_mag_manifest.csv`` below a ``run_*`` directory.

    When several candidates match, the most recently modified one wins.

    Returns:
        Path of the manifest, or ``None`` when nothing is found.
    """
    direct = art_dir / "lc_clean_mag" / "lc_clean_mag_manifest.csv"
    if direct.exists():
        return direct

    runs_root = Path("/kaggle/working/mallorn_run")
    if not runs_root.exists():
        return None

    # Try the canonical layout first, then fall back to a recursive search.
    patterns = (
        "run_*/artifacts/lc_clean_mag/lc_clean_mag_manifest.csv",
        "run_*/**/lc_clean_mag_manifest.csv",
    )
    for pattern in patterns:
        found = list(runs_root.glob(pattern))
        if found:
            # Newest modification time wins; ties keep glob order.
            return max(found, key=lambda p: p.stat().st_mtime)
    return None

def _sync_dirs_from_manifest(manifest_csv: Path):
    """Derive the run layout from the manifest location.

    The manifest's parent is taken as the cleaned light-curve directory, its
    parent as the artifact directory, and the next parent as the run directory.

    Returns:
        tuple: ``(run_dir, art_dir, lc_clean_dir)``.
    """
    clean_dir = manifest_csv.parent
    artifacts = clean_dir.parent
    return artifacts.parent, artifacts, clean_dir

def _read_meta_file(art_dir_synced: Path, which: str) -> pd.DataFrame:
    """Load ``<which>_meta`` from the artifact directory, indexed by object id.

    ``<which>_meta.parquet`` is preferred over ``<which>_meta.csv``. Despite the
    annotation, ``None`` is returned when neither file exists.

    Index handling:
      * an ``object_id`` column becomes the index when the current index is a
        RangeIndex or is not already named ``object_id``;
      * otherwise, if the index is still a RangeIndex, the first of
        ``"Unnamed: 0"`` / ``"index"`` present as a column is used;
      * the final index is cast to the pandas ``"string"`` dtype.
    """
    parquet_path = art_dir_synced / f"{which}_meta.parquet"
    csv_path = art_dir_synced / f"{which}_meta.csv"

    if parquet_path.exists():
        frame = pd.read_parquet(parquet_path)
    elif csv_path.exists():
        frame = pd.read_csv(csv_path)
    else:
        return None

    has_id_column = "object_id" in frame.columns
    index_is_range = isinstance(frame.index, pd.RangeIndex)
    if has_id_column and (index_is_range or frame.index.name != "object_id"):
        frame = frame.set_index("object_id", drop=True)

    if isinstance(frame.index, pd.RangeIndex):
        # Typical leftovers of a CSV written with its index.
        fallback = next((c for c in ("Unnamed: 0", "index") if c in frame.columns), None)
        if fallback is not None:
            frame = frame.set_index(fallback, drop=True)

    frame.index = frame.index.astype("string")
    return frame

def _load_meta_if_needed(art_dir_synced: Path):
    """Replace the in-memory meta frames with the on-disk ones when they disagree.

    The global ``df_train_meta`` / ``df_test_meta`` are swapped for the files
    found in ``art_dir_synced`` when either row count differs, or when any of
    the first five in-memory train ids is absent from the on-disk train index.

    Returns:
        tuple[bool, str]: ``(reloaded, message)``. Any exception raised while
        comparing is reported in the message and leaves the globals untouched.
    """
    global df_train_meta, df_test_meta

    disk_train = _read_meta_file(art_dir_synced, "train")
    disk_test = _read_meta_file(art_dir_synced, "test")
    if disk_train is None or disk_test is None:
        return False, "meta file not found in synced ART_DIR; keep in-memory"

    try:
        reason = None
        if (len(df_train_meta) != len(disk_train)) or (len(df_test_meta) != len(disk_test)):
            reason = "reloaded meta due to size mismatch"
        else:
            # Cheap identity probe: only the first few train ids are checked.
            probe_ids = df_train_meta.index[:5].astype(str).tolist()
            if any(sid not in disk_train.index for sid in probe_ids):
                reason = "reloaded meta due to id mismatch"

        if reason is None:
            return False, "meta already consistent"

        df_train_meta = disk_train
        df_test_meta = disk_test
        return True, reason
    except Exception as e:
        return False, f"meta reload skipped ({type(e).__name__}: {e})"

def _ensure_meta_features(meta_df: pd.DataFrame) -> pd.DataFrame:
    """Return a shallow copy of ``meta_df`` with the per-token meta columns present.

    Columns written (all float32 except ``is_photoz`` which is int8):
      * ``EBV_clip``  - existing ``EBV_clip``, else ``EBV``, else 0; NaN -> 0.
      * ``log1pZ``    - ``log1p(max(z, 0))`` with z taken from the first of
                        ``Z`` / ``Z_clip`` / ``photoz`` that exists (else 0).
      * ``zerr_rel``  - redshift error (``Z_err`` / ``Z_err_clip``, else 0)
                        divided by ``max(z, 1e-3)``.
      * ``is_photoz`` - existing ``is_photoz``, else 1 where ``photoz`` is a
                        valid number, else 0.
    """
    out = meta_df.copy(deep=False)

    def _numeric(col_name):
        # Coerce one column to numbers; unparseable entries become NaN.
        return pd.to_numeric(out[col_name], errors="coerce")

    def _first_present(candidates):
        # First candidate column name that exists in the frame, or None.
        return next((c for c in candidates if c in out.columns), None)

    # --- extinction -----------------------------------------------------
    if "EBV_clip" not in out.columns:
        out["EBV_clip"] = _numeric("EBV") if "EBV" in out.columns else 0.0
    out["EBV_clip"] = _numeric("EBV_clip").fillna(0.0).astype(np.float32)

    # --- redshift -------------------------------------------------------
    z_col = _first_present(("Z", "Z_clip", "photoz"))
    redshift = _numeric(z_col) if z_col is not None else pd.Series(0.0, index=out.index)
    redshift = redshift.fillna(0.0)
    z_nonneg = np.maximum(redshift.to_numpy(dtype=np.float32, copy=False), 0.0).astype(np.float32)
    out["log1pZ"] = np.log1p(z_nonneg).astype(np.float32)

    # --- relative redshift error -----------------------------------------
    zerr_col = _first_present(("Z_err", "Z_err_clip"))
    if zerr_col is None:
        z_unc = np.zeros((len(out),), dtype=np.float32)
    else:
        z_unc = _numeric(zerr_col).fillna(0.0).to_numpy(dtype=np.float32, copy=False)
    # Floor the divisor so z == 0 does not divide by zero.
    divisor = np.maximum(z_nonneg, np.float32(1e-3))
    out["zerr_rel"] = (np.asarray(z_unc, dtype=np.float32) / divisor).astype(np.float32)

    # --- photometric-redshift flag ---------------------------------------
    if "is_photoz" in out.columns:
        flag = _numeric("is_photoz").fillna(0).astype(np.int8)
    elif "photoz" in out.columns:
        flag = _numeric("photoz").notna().astype(np.int8)
    else:
        flag = pd.Series(0, index=out.index, dtype=np.int8)
    out["is_photoz"] = flag.astype(np.int8)

    return out

# ----------------------------
# 2) Locate STAGE 4 output
# ----------------------------
# Find the Stage 4 manifest; without it there is nothing to tokenize.
manifest_csv = _find_stage4_manifest(ART_DIR)
if manifest_csv is None:
    # List the last 15 run directories (by name) to help the user pick the right one.
    root = Path("/kaggle/working/mallorn_run")
    runs = sorted([p.name for p in root.glob("run_*") if p.is_dir()])[-15:] if root.exists() else []
    raise RuntimeError(
        "Output STAGE 4 (lc_clean_mag_manifest.csv) tidak ditemukan.\n"
        f"- ART_DIR saat ini: {ART_DIR}\n"
        f"- Expected: {ART_DIR/'lc_clean_mag'}\n"
        f"- Runs available (last 15): {runs}\n"
        "Solusi: pastikan STAGE 4 selesai dan menulis artifacts/lc_clean_mag."
    )

# Re-point RUN_DIR / ART_DIR / LC_CLEAN_DIR at the run that actually owns the
# manifest (it may differ from the ART_DIR this cell started with).
RUN_DIR, ART_DIR, LC_CLEAN_DIR = _sync_dirs_from_manifest(manifest_csv)
ART_DIR = Path(ART_DIR); LC_CLEAN_DIR = Path(LC_CLEAN_DIR)

print("STAGE 5 ROUTING SYNC OK")
print(f"- RUN_DIR      : {RUN_DIR}")
print(f"- ART_DIR      : {ART_DIR}")
print(f"- LC_CLEAN_DIR : {LC_CLEAN_DIR}")
print(f"- manifest_csv : {manifest_csv}")

# Reload the meta frames from the synced ART_DIR when they disagree with memory.
reloaded, msg = _load_meta_if_needed(ART_DIR)
print(f"- meta_sync    : {msg}")

# Normalize the meta index to the pandas "string" dtype; work on shallow copies
# so the index of the frames owned by earlier cells is not rebound.
df_train_meta = df_train_meta.copy(deep=False)
df_test_meta  = df_test_meta.copy(deep=False)
df_train_meta.index = df_train_meta.index.astype("string")
df_test_meta.index  = df_test_meta.index.astype("string")

# Make sure the derived per-token meta columns
# (EBV_clip, log1pZ, zerr_rel, is_photoz) exist.
df_train_meta = _ensure_meta_features(df_train_meta)
df_test_meta  = _ensure_meta_features(df_test_meta)

# ----------------------------
# 3) Load & validate Stage4 manifest
# ----------------------------
# Read the Stage 4 manifest and strip stray whitespace from its column names.
_df_clean_manifest = pd.read_csv(manifest_csv)
_df_clean_manifest.columns = [c.strip() for c in _df_clean_manifest.columns]

# The manifest must carry these columns for routing to work.
need_cols = {"split", "which", "part", "path"}
miss = sorted(list(need_cols - set(_df_clean_manifest.columns)))
if miss:
    raise RuntimeError(f"Manifest STAGE 4 missing columns: {miss} | cols={list(_df_clean_manifest.columns)}")

# Every part file listed in the manifest must still exist on disk.
paths = _df_clean_manifest["path"].astype(str).tolist()
missing_paths = [p for p in paths if not Path(p).exists()]
if missing_paths:
    raise RuntimeError(
        "Ada file part STAGE 4 yang hilang.\n"
        f"Missing count={len(missing_paths)} | contoh={missing_paths[:10]}\n"
        "Solusi: rerun STAGE 4 (wipe_all) untuk regenerasi cache."
    )

def get_clean_parts(split_name: str, which: str):
    """Return the Stage 4 part paths of one split / partition, ordered by part.

    Looks up ``_df_clean_manifest`` rows with matching ``split`` and ``which``.

    Returns:
        list[str]: ``path`` values sorted by ``part``; ``[]`` when no row matches.
    """
    wanted = (_df_clean_manifest["split"] == split_name) & (_df_clean_manifest["which"] == which)
    subset = _df_clean_manifest[wanted]
    if subset.empty:
        return []
    ordered = subset.sort_values("part")
    return ordered["path"].astype(str).tolist()

# ----------------------------
# 4) Recover SPLIT_LIST + routing ids
# ----------------------------
# Both meta frames must carry the `split` routing column.
if "split" not in df_train_meta.columns or "split" not in df_test_meta.columns:
    raise RuntimeError("Kolom `split` tidak ada di meta. Pastikan STAGE 2/3 membuat routing split di meta.")

# SPLIT_LIST is the union of splits known to the meta frames and to the manifest;
# only splits that actually have Stage 4 parts are candidates for processing.
splits_meta = sorted(set(df_train_meta["split"].astype(str).tolist()) | set(df_test_meta["split"].astype(str).tolist()))
splits_in_manifest = sorted(set(_df_clean_manifest["split"].astype(str).tolist()))
SPLIT_LIST = sorted(set(splits_in_manifest) | set(splits_meta))
SPLITS_TO_CONSIDER = [s for s in SPLIT_LIST if s in splits_in_manifest]

# Map split name -> list of train object ids (as str), taken from the meta index.
train_ids_by_split = {s: [] for s in SPLIT_LIST}
for oid, sp in df_train_meta["split"].astype(str).items():
    if sp in train_ids_by_split:
        train_ids_by_split[sp].append(str(oid))

# Same mapping for the test objects.
test_ids_by_split = {s: [] for s in SPLIT_LIST}
for oid, sp in df_test_meta["split"].astype(str).items():
    if sp in test_ids_by_split:
        test_ids_by_split[sp].append(str(oid))

# ----------------------------
# 5) Settings (recommended defaults)
# ----------------------------
# Optional restriction of the splits to process (None = every split with parts).
ONLY_SPLITS = CFG.get("STAGE5_ONLY_SPLITS", None)

# Unknown rebuild modes fall back to a clean rebuild.
REBUILD_MODE = str(CFG.get("STAGE5_REBUILD_MODE", "wipe_all")).lower()   # "wipe_all" or "reuse_if_exists"
if REBUILD_MODE not in ("wipe_all", "reuse_if_exists"):
    REBUILD_MODE = "wipe_all"

# Shard writer options.
COMPRESS_NPZ = bool(CFG.get("STAGE5_COMPRESS_NPZ", False))
SHARD_MAX_OBJECTS = int(CFG.get("STAGE5_SHARD_MAX_OBJECTS", 1500))

# Feature scaling / time handling.
SNR_TANH_SCALE = float(CFG.get("STAGE5_SNR_TANH_SCALE", 10.0))
TIME_CLIP_MAX_DAYS = CFG.get("STAGE5_TIME_CLIP_MAX_DAYS", None)
# Treat the textual "none" spellings and the empty string as "no clipping".
TIME_CLIP_MAX_DAYS = None if TIME_CLIP_MAX_DAYS in [None, "None", "none", ""] else float(TIME_CLIP_MAX_DAYS)
DROP_BAD_TIME_ROWS = bool(CFG.get("STAGE5_DROP_BAD_TIME_ROWS", True))

# Sequence length cap and truncation policy.
L_MAX = int(CFG.get("L_MAX", 256))
TRUNC_POLICY = str(CFG.get("TRUNC_POLICY", "smart")).lower()      # smart/head/none
KEEP_DET_FRAC = float(CFG.get("KEEP_DET_FRAC", 0.70))
KEEP_EDGE = bool(CFG.get("KEEP_EDGE", True))
# The Stage 0 header documents this policy as USE_REST_FRAME_TIME; accept that
# spelling as a fallback so a value configured under the documented name is
# not silently ignored. The original key still takes precedence.
USE_RESTFRAME_TIME = bool(CFG.get("USE_RESTFRAME_TIME", CFG.get("USE_REST_FRAME_TIME", True)))

# Number of temporary hash buckets used while regrouping rows by object.
NUM_BUCKETS = int(CFG.get("SEQ_NUM_BUCKETS", 256))

# FORCE ASINH
TOKEN_MODE_FORCE = "asinh"
FEATURE_SET = str(CFG.get("SEQ_FEATURE_SET", "v2")).lower()   # v2 recommended
ADD_OBJ_META_PER_TOKEN = bool(CFG.get("SEQ_ADD_META_PER_TOKEN", True))

# Object-level meta columns repeated on every token when ADD_OBJ_META_PER_TOKEN is set.
OBJ_META_COLS = ["EBV_clip", "log1pZ", "zerr_rel", "is_photoz"]

# Output root for all Stage 5 artifacts.
SEQ_DIR = Path(ART_DIR) / "seq_tokens"
SEQ_DIR.mkdir(parents=True, exist_ok=True)

# Column order of the core feature block; must match the stacking order used
# when the token matrix is built.
if FEATURE_SET == "v2":
    CORE_FEATURES = ["t_rel_log", "dt_log", "dt_band_log", "signal", "err_log", "snr_tanh", "detected", "band_change", "delta_signal"]
else:
    CORE_FEATURES = ["t_rel_log", "dt_log", "signal", "err_log", "snr_tanh", "detected"]

META_FEATURES = [f"meta_{c}" for c in OBJ_META_COLS] if ADD_OBJ_META_PER_TOKEN else []
FEATURE_NAMES = CORE_FEATURES + META_FEATURES
FEATURE_DIM = len(FEATURE_NAMES)

# ----------------------------
# 6) Reader for cleaned parts (schema auto; force-asinh)
# ----------------------------
# Columns every cleaned part must provide regardless of the Stage 4 schema version.
BASE_COLS_MIN = {"object_id", "mjd", "band_id", "snr"}

def _read_clean_part(path: str) -> pd.DataFrame:
    """Read one Stage 4 cleaned part and normalize it to the Stage 5 schema.

    Supported files: ``.parquet``, ``.csv.gz`` and plain CSV. Both Stage 4
    schemas are accepted:
      * ``asinh_mag`` / ``asinh_mag_err`` -> ``signal`` / ``log1p(max(err, 0))``
      * ``flux_asinh`` / ``err_log1p``    -> ``signal`` / ``err_log`` as-is

    Rows whose ``band_id`` is missing, unparseable or outside 0..5 are dropped;
    rows with a non-finite ``mjd`` are dropped when ``DROP_BAD_TIME_ROWS`` is set.

    Returns:
        DataFrame with columns
        ``object_id, mjd, band_id, snr, detected, signal, err_log``.

    Raises:
        FileNotFoundError: the part file does not exist.
        RuntimeError: base columns or the asinh signal/error pair are missing.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Clean part missing: {p}")

    if p.suffix == ".parquet":
        df = pd.read_parquet(p)
    elif p.name.endswith(".csv.gz"):
        df = pd.read_csv(p, compression="gzip")
    else:
        df = pd.read_csv(p)

    df.columns = [c.strip() for c in df.columns]
    cols = set(df.columns)

    if not BASE_COLS_MIN.issubset(cols):
        raise RuntimeError(f"Clean part missing base cols {sorted(list(BASE_COLS_MIN - cols))} | file={p}")

    # unify detection column -> detected (int8)
    if "detected" in cols:
        df["detected"] = pd.to_numeric(df["detected"], errors="coerce").fillna(0).astype(np.int8)
    elif "detected_pos" in cols:
        df["detected"] = pd.to_numeric(df["detected_pos"], errors="coerce").fillna(0).astype(np.int8)
    else:
        # No explicit flag: treat positive SNR as a detection.
        df["detected"] = (pd.to_numeric(df["snr"], errors="coerce").fillna(0.0).astype(np.float32) > 0).astype(np.int8)

    # FORCE ASINH signal/err
    if ("asinh_mag" in cols) and ("asinh_mag_err" in cols):
        df["signal"] = pd.to_numeric(df["asinh_mag"], errors="coerce").astype(np.float32)
        err_lin = pd.to_numeric(df["asinh_mag_err"], errors="coerce").astype(np.float32).to_numpy(copy=False)
        err_lin = np.nan_to_num(err_lin, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        err_lin = np.maximum(err_lin, 0.0).astype(np.float32)
        df["err_log"] = np.log1p(err_lin).astype(np.float32)
    elif ("flux_asinh" in cols) and ("err_log1p" in cols):
        df["signal"] = pd.to_numeric(df["flux_asinh"], errors="coerce").astype(np.float32)
        df["err_log"] = pd.to_numeric(df["err_log1p"], errors="coerce").astype(np.float32)
    else:
        raise RuntimeError(
            "FORCE ASINH gagal: tidak menemukan pasangan kolom asinh_mag/asinh_mag_err "
            "atau flux_asinh/err_log1p.\n"
            f"Found cols sample={list(df.columns)[:40]} | file={p}"
        )

    df["object_id"] = _safe_string_series(df["object_id"])
    df["mjd"] = pd.to_numeric(df["mjd"], errors="coerce").astype(np.float32)

    # A NaN/inf band id cannot be cast to int16 (pandas raises). Map such
    # values to -1 so the 0..5 sanity filter below drops the row instead of
    # the whole part failing to load.
    band_num = pd.to_numeric(df["band_id"], errors="coerce").astype(np.float64)
    band_num = band_num.where(np.isfinite(band_num), -1.0)
    df["band_id"] = band_num.astype(np.int16)

    df["snr"] = pd.to_numeric(df["snr"], errors="coerce").astype(np.float32)
    df["signal"] = pd.to_numeric(df["signal"], errors="coerce").astype(np.float32)
    df["err_log"] = pd.to_numeric(df["err_log"], errors="coerce").astype(np.float32)

    # band sanity -> 0..5
    b = df["band_id"].to_numpy(copy=False)
    okb = (b >= 0) & (b <= 5)
    if not np.all(okb):
        df = df[okb]

    if DROP_BAD_TIME_ROWS:
        df = df[np.isfinite(df["mjd"].to_numpy(copy=False))]

    keep = ["object_id", "mjd", "band_id", "snr", "detected", "signal", "err_log"]
    return df[keep]

# ----------------------------
# 7) Truncation (smart)
# ----------------------------
def _smart_truncate(mjd, det, snr, Lmax: int):
    """Choose which observations to keep when a sequence exceeds ``Lmax``.

    Selection (positions refer to the caller's ordering):
      1. first and last position when ``KEEP_EDGE`` is set;
      2. up to ``floor(Lmax * KEEP_DET_FRAC)`` detected points with the
         largest ``|snr|``;
      3. evenly spaced picks among the remaining positions to fill up to ``Lmax``.
    If more than ``Lmax`` positions end up selected, they are thinned evenly.

    Args:
        mjd: observation times; only the length is used.
        det: detection flags (truthy = detected).
        snr: signal-to-noise values aligned with ``det``.
        Lmax: maximum number of positions to return.

    Returns:
        Sorted int64 array of kept positions (all positions if ``len(mjd) <= Lmax``).
    """
    total = len(mjd)
    if total <= Lmax:
        return np.arange(total, dtype=np.int64)

    positions = np.arange(total, dtype=np.int64)
    chosen = {0, total - 1} if (KEEP_EDGE and total >= 2) else set()

    # Strongest detections first.
    detected_pos = positions[det.astype(bool)]
    quota = int(max(0, min(len(detected_pos), int(np.floor(Lmax * KEEP_DET_FRAC)))))
    if quota > 0 and len(detected_pos) > 0:
        strength = np.abs(snr[detected_pos])
        strongest = detected_pos[np.argsort(-strength)[:quota]]
        chosen.update(int(i) for i in strongest.tolist())

    # Fill the remaining budget with an even sample of what is left.
    shortfall = Lmax - len(chosen)
    if shortfall > 0:
        leftovers = [i for i in positions.tolist() if i not in chosen]
        if leftovers:
            slots = np.linspace(0, len(leftovers) - 1, num=shortfall, dtype=int)
            chosen.update(int(leftovers[s]) for s in slots.tolist())

    kept = np.array(sorted(chosen), dtype=np.int64)
    if len(kept) > Lmax:
        thin = np.linspace(0, len(kept) - 1, num=Lmax, dtype=int)
        kept = kept[thin]
    return kept

# ----------------------------
# 8) Meta per-token helper
# ----------------------------
def _get_obj_meta_vec(meta_df: pd.DataFrame, oid: str) -> np.ndarray:
    """Return the object-level meta values for ``oid`` as a float32 vector.

    The vector has one slot per entry of ``OBJ_META_COLS``. A slot is 0 when the
    column is absent or the value is ``None`` / non-finite. An all-zero vector
    is returned when per-token meta is disabled or ``oid`` is not in the index.
    """
    width = len(OBJ_META_COLS)
    if not ADD_OBJ_META_PER_TOKEN or oid not in meta_df.index:
        return np.zeros((width,), dtype=np.float32)

    vec = np.zeros((width,), dtype=np.float32)
    for slot, col in enumerate(OBJ_META_COLS):
        if col not in meta_df.columns:
            continue
        raw = meta_df.loc[oid, col]
        if raw is not None and np.isfinite(raw):
            vec[slot] = float(raw)
    return vec

def build_empty_tokens():
    """Return the placeholder sequence used for objects without observations.

    Returns:
        tuple: ``(X, B, 0, 1)`` where ``X`` is a single all-zero float32 token of
        width ``FEATURE_DIM``, ``B`` is an int8 array holding band ``-1``, and the
        last two values are the lengths before / after building.
    """
    width = int(FEATURE_DIM)
    placeholder_tokens = np.zeros((1, width), dtype=np.float32)
    placeholder_band = np.full((1,), -1, dtype=np.int8)
    return placeholder_tokens, placeholder_band, 0, 1

# ----------------------------
# 9) Build tokens per object (v2)
# ----------------------------
def _stage5_core_token_features(mjd, band, snr, det, sig, err_log, denom):
    """Build the core (non-meta) feature matrix for one time-sorted light curve.

    All array arguments are equal-length 1-D arrays already ordered by time and
    ``snr`` must already be free of NaN/inf. ``denom`` is the divisor applied
    to time differences (``1 + z`` for rest-frame time, ``1.0`` otherwise).

    Returns:
        float32 array with one row per observation. Columns follow the stacking
        order below: 9 columns when ``FEATURE_SET == "v2"``, otherwise 6.
    """
    # Time since the first kept observation, optionally clipped.
    t_rel = ((mjd - np.float32(float(mjd[0]))) / np.float32(denom)).astype(np.float32)
    if TIME_CLIP_MAX_DAYS is not None:
        t_rel = np.clip(t_rel, 0.0, np.float32(TIME_CLIP_MAX_DAYS))

    # Gap to the previous observation (any band).
    dt = np.empty_like(t_rel, dtype=np.float32)
    dt[0] = np.float32(0.0)
    if len(t_rel) > 1:
        dt[1:] = np.maximum(t_rel[1:] - t_rel[:-1], 0.0).astype(np.float32)

    # Gap to the previous observation in the same band (0 for the first one).
    dt_band = np.zeros_like(dt, dtype=np.float32)
    last_mjd = {}
    for i in range(len(mjd)):
        b = int(band[i])
        if b in last_mjd:
            dt_band[i] = max(float((mjd[i] - last_mjd[b]) / denom), 0.0)
        last_mjd[b] = float(mjd[i])
    if TIME_CLIP_MAX_DAYS is not None:
        mx = np.float32(TIME_CLIP_MAX_DAYS)
        dt = np.clip(dt, 0.0, mx)
        dt_band = np.clip(dt_band, 0.0, mx)

    t_rel_log = np.log1p(t_rel).astype(np.float32)
    dt_log = np.log1p(dt).astype(np.float32)
    dt_band_log = np.log1p(dt_band).astype(np.float32)

    snr_tanh = np.tanh(snr / np.float32(SNR_TANH_SCALE)).astype(np.float32)
    det_f = det.astype(np.float32)

    # Differences between consecutive tokens, regardless of band.
    band_change = np.zeros((len(band),), dtype=np.float32)
    delta_signal = np.zeros((len(sig),), dtype=np.float32)
    if len(band) > 1:
        band_change[1:] = (band[1:] != band[:-1]).astype(np.float32)
        delta_signal[1:] = (sig[1:] - sig[:-1]).astype(np.float32)

    if FEATURE_SET == "v2":
        return np.stack(
            [t_rel_log, dt_log, dt_band_log, sig, err_log, snr_tanh, det_f, band_change, delta_signal],
            axis=1
        ).astype(np.float32)
    return np.stack([t_rel_log, dt_log, sig, err_log, snr_tanh, det_f], axis=1).astype(np.float32)


def build_object_tokens(df_obj: pd.DataFrame, meta_df: pd.DataFrame, z_val: float = 0.0):
    """Turn the cleaned observations of one object into a token matrix.

    Observations are sorted by (mjd, band). When the sequence is longer than
    ``L_MAX`` it is reduced according to ``TRUNC_POLICY`` ("smart", "head";
    "none"/"full" keep everything) and all time / delta features are computed
    on the kept observations only. The selection is made *before* the features
    are computed, so the feature pipeline runs once per object.

    Args:
        df_obj: rows of a single object with columns
            ``object_id, mjd, band_id, snr, detected, signal, err_log``.
        meta_df: metadata frame used for the per-token meta columns.
        z_val: redshift used for rest-frame time when ``USE_RESTFRAME_TIME``
            is set; non-finite or negative values are treated as 0.

    Returns:
        tuple: ``(X, band, L0, L1)`` - float32 token matrix, int8 band id per
        token, number of observations before truncation, number of tokens kept.
        For a missing / empty ``df_obj`` the ``build_empty_tokens()`` placeholder
        is returned.
    """
    if df_obj is None or df_obj.empty:
        return build_empty_tokens()

    mjd = df_obj["mjd"].to_numpy(dtype=np.float32, copy=False)
    band = df_obj["band_id"].to_numpy(dtype=np.int16, copy=False)
    snr  = df_obj["snr"].to_numpy(dtype=np.float32, copy=False)
    det  = df_obj["detected"].to_numpy(dtype=np.int8, copy=False)
    sig  = df_obj["signal"].to_numpy(dtype=np.float32, copy=False)
    err_log = df_obj["err_log"].to_numpy(dtype=np.float32, copy=False)

    # Primary key mjd, ties broken by band.
    order = np.lexsort((band, mjd))
    mjd = mjd[order]; band = band[order]; snr = snr[order]; det = det[order]
    sig = sig[order]; err_log = err_log[order]

    z = float(z_val) if (z_val is not None and np.isfinite(z_val)) else 0.0
    z = max(z, 0.0)
    denom = (1.0 + z) if USE_RESTFRAME_TIME else 1.0

    # Clean SNR once; it feeds both the truncation ranking and the features.
    snr = np.nan_to_num(snr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)

    L0 = int(len(mjd))

    if L_MAX and int(L_MAX) > 0 and L0 > int(L_MAX):
        if TRUNC_POLICY == "head":
            keep = np.arange(int(L_MAX), dtype=np.int64)
        elif TRUNC_POLICY in ("none", "full"):
            keep = None
        else:
            # "smart" and any unrecognized policy.
            keep = _smart_truncate(mjd, det, snr, int(L_MAX))

        if keep is not None and len(keep) != L0:
            keep = keep.astype(np.int64)
            mjd = mjd[keep]
            band = band[keep].astype(np.int16)
            snr = snr[keep].astype(np.float32)
            det = det[keep].astype(np.int8)
            sig = sig[keep].astype(np.float32)
            err_log = err_log[keep].astype(np.float32)

    X = _stage5_core_token_features(mjd, band, snr, det, sig, err_log, denom)

    if ADD_OBJ_META_PER_TOKEN:
        # Repeat the object-level meta vector on every token.
        oid = str(df_obj["object_id"].iloc[0])
        mv = _get_obj_meta_vec(meta_df, oid)
        mv_rep = np.repeat(mv[None, :], repeats=X.shape[0], axis=0).astype(np.float32)
        X = np.concatenate([X, mv_rep], axis=1).astype(np.float32)

    return X, band.astype(np.int8), L0, int(X.shape[0])

# ----------------------------
# 10) Shard writer + reuse manifest
# ----------------------------
def save_shard(out_path: Path, object_ids, X_concat, B_concat, offsets):
    """Write one token shard as an ``.npz`` archive.

    Stored arrays: ``object_id`` (byte strings), ``x`` (token features),
    ``band`` (band id per token) and ``offsets`` (per-object start offsets).
    The parent directory is created if needed; the archive is compressed when
    ``COMPRESS_NPZ`` is set.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    id_bytes = np.asarray(object_ids, dtype="S")
    writer = np.savez_compressed if COMPRESS_NPZ else np.savez
    writer(out_path, object_id=id_bytes, x=X_concat, band=B_concat, offsets=offsets)

def reconstruct_manifest_from_shards(split_name: str, which: str, out_dir: Path):
    """Rebuild per-object manifest rows from the shards already on disk.

    Every ``shard_*.npz`` in ``out_dir`` (sorted by name) is opened and its
    ``object_id`` / ``offsets`` arrays are read. Shards whose ``offsets`` is
    not 1-D or has fewer than two entries are skipped.

    Args:
        split_name: value written to the ``split`` field of each row.
        which: value written to the ``which`` field of each row.
        out_dir: directory containing the shard files.

    Returns:
        list[dict]: one row per object with keys
        ``object_id, split, which, shard, start, length``.

    Raises:
        RuntimeError: a shard's ``object_id`` length does not equal
            ``len(offsets) - 1``.
    """
    rows = []
    for sp in sorted(out_dir.glob("shard_*.npz")):
        # np.load on an .npz returns a lazily-read archive that keeps the file
        # open; close it deterministically instead of leaking one handle per shard.
        with np.load(sp, allow_pickle=False) as data:
            obj = data["object_id"]
            offsets = data["offsets"].astype(np.int64)

        if offsets.ndim != 1 or offsets.size < 2:
            continue
        if len(obj) != (offsets.size - 1):
            raise RuntimeError(f"Bad shard (object_id len != offsets-1): {sp}")

        lengths = offsets[1:] - offsets[:-1]
        shard_str = str(sp)
        for i, oid in enumerate(obj):
            # Ids are stored as byte strings; decode them back to text.
            oid = oid.decode("utf-8", errors="ignore") if isinstance(oid, (bytes, np.bytes_)) else str(oid)
            rows.append({
                "object_id": oid,
                "split": split_name,
                "which": which,
                "shard": shard_str,
                "start": int(offsets[i]),
                "length": int(lengths[i]),
            })
    return rows

# ----------------------------
# 11) Bucket builder (I/O-efficient)
# ----------------------------
def build_sequences_bucket(split_name: str, which: str, expected_ids: set, out_dir: Path, num_buckets: int = 256):
    """Build token shards for one split / partition using temporary hash buckets.

    Phase 1 reads every cleaned part returned by ``get_clean_parts``, keeps the
    rows whose ``object_id`` is in ``expected_ids`` and appends them to
    ``bucket_XXX.parquet`` files under ``ART_DIR/tmp_seq_buckets/<split>/<which>``,
    routed by a hash of the object id modulo ``num_buckets``.

    Phase 2 reads the buckets back one at a time, groups rows by object, builds
    tokens with ``build_object_tokens`` and writes ``shard_XXXX.npz`` files to
    ``out_dir`` through ``save_shard``; a shard is flushed once the pending
    batch reaches ``SHARD_MAX_OBJECTS`` objects. Expected objects that had no
    rows receive the ``build_empty_tokens()`` placeholder. The temporary bucket
    directory is removed at the end.

    Args:
        split_name: split to process.
        which: ``"train"`` selects ``df_train_meta``; anything else selects
            ``df_test_meta``.
        expected_ids: object ids that must appear in the output.
        out_dir: directory receiving the shard files.
        num_buckets: number of temporary hash buckets.

    Returns:
        tuple: ``(manifest_rows, st)`` - one manifest dict per object
        (``object_id, split, which, shard, start, length``) and a dict of
        build statistics.

    Raises:
        RuntimeError: pyarrow cannot be imported, or no cleaned parts exist
            for the requested split / partition.
    """
    try:
        import pyarrow as pa
        import pyarrow.parquet as pq
    except Exception as e:
        raise RuntimeError("pyarrow tidak tersedia. Di Kaggle biasanya ada.") from e

    parts = get_clean_parts(split_name, which)
    if not parts:
        raise RuntimeError(f"Tidak ada cleaned parts untuk {split_name}/{which}. Cek STAGE 4 output.")

    # Start from an empty temporary bucket directory.
    tmp_dir = Path(ART_DIR) / "tmp_seq_buckets" / split_name / which
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)

    # One ParquetWriter per bucket index, opened lazily on first use.
    writers = {}
    kept_rows = 0
    t0 = time.time()

    def bucket_idx(series_objid: pd.Series) -> np.ndarray:
        # Bucket index of each row: hash of the object id modulo num_buckets.
        h = pd.util.hash_pandas_object(series_objid, index=False).to_numpy(dtype=np.uint64, copy=False)
        return (h % np.uint64(num_buckets)).astype(np.int16)

    try:
        for p in parts:
            df = _read_clean_part(p)
            if df.empty:
                continue

            # Keep only the objects routed to this split / partition.
            df = df[df["object_id"].isin(expected_ids)]
            if df.empty:
                continue

            kept_rows += int(len(df))
            bidx = bucket_idx(df["object_id"])
            df["_b"] = bidx

            for b, sub in df.groupby("_b", sort=False):
                sub = sub.drop(columns=["_b"])
                if sub.empty:
                    continue
                fp = tmp_dir / f"bucket_{int(b):03d}.parquet"
                table = pa.Table.from_pandas(sub, preserve_index=False)
                if int(b) not in writers:
                    writers[int(b)] = pq.ParquetWriter(fp, table.schema, compression="snappy")
                writers[int(b)].write_table(table)

            del df
            gc.collect()

    finally:
        # Always close the writers, even if a part failed to load.
        for w in list(writers.values()):
            try: w.close()
            except Exception: pass

    meta = df_train_meta if which == "train" else df_test_meta

    manifest_rows = []
    shard_idx = 0
    # Pending objects of the shard currently being assembled.
    batch_obj_ids, batch_X, batch_B, batch_len = [], [], [], []
    built_ids = set()
    len_before, len_after = [], []

    def flush_shard_local():
        # Write the pending batch as one shard, record its manifest rows and reset the batch.
        nonlocal shard_idx, batch_obj_ids, batch_X, batch_B, batch_len, manifest_rows
        if not batch_obj_ids:
            return
        lengths = np.asarray(batch_len, dtype=np.int64)
        # offsets[i] is the first token row of object i; offsets[-1] is the total token count.
        offsets = np.zeros(len(lengths) + 1, dtype=np.int64)
        offsets[1:] = np.cumsum(lengths)

        Xc = np.concatenate(batch_X, axis=0).astype(np.float32)
        Bc = np.concatenate(batch_B, axis=0).astype(np.int8)

        shard_path = out_dir / f"shard_{shard_idx:04d}.npz"
        save_shard(shard_path, batch_obj_ids, Xc, Bc, offsets)

        for i, oid in enumerate(batch_obj_ids):
            manifest_rows.append({
                "object_id": oid,
                "split": split_name,
                "which": which,
                "shard": str(shard_path),
                "start": int(offsets[i]),
                "length": int(lengths[i]),
            })

        shard_idx += 1
        batch_obj_ids, batch_X, batch_B, batch_len = [], [], [], []
        gc.collect()

    for bf in sorted(tmp_dir.glob("bucket_*.parquet")):
        dfb = pd.read_parquet(bf)
        if dfb.empty:
            continue

        for oid, g in dfb.groupby("object_id", sort=False):
            oid = str(oid)
            if oid in built_ids:
                continue

            # Redshift for rest-frame time: first of Z / Z_clip / photoz present in meta.
            z_val = 0.0
            if USE_RESTFRAME_TIME and (oid in meta.index):
                if "Z" in meta.columns:
                    z_val = float(meta.loc[oid, "Z"])
                elif "Z_clip" in meta.columns:
                    z_val = float(meta.loc[oid, "Z_clip"])
                elif "photoz" in meta.columns:
                    z_val = float(meta.loc[oid, "photoz"])

            X, B, lb, la = build_object_tokens(g, meta_df=meta, z_val=z_val)
            len_before.append(lb); len_after.append(la)

            batch_obj_ids.append(oid)
            batch_X.append(X)
            batch_B.append(B)
            batch_len.append(int(X.shape[0]))
            built_ids.add(oid)

            if len(batch_obj_ids) >= SHARD_MAX_OBJECTS:
                flush_shard_local()

        del dfb
        gc.collect()

    # Expected objects without any rows get the single-token placeholder.
    missing_ids = list(expected_ids - built_ids)
    if missing_ids:
        for oid in missing_ids:
            oid = str(oid)
            X, B, lb, la = build_empty_tokens()
            len_before.append(lb); len_after.append(la)

            batch_obj_ids.append(oid)
            batch_X.append(X)
            batch_B.append(B)
            batch_len.append(int(X.shape[0]))
            built_ids.add(oid)

            if len(batch_obj_ids) >= SHARD_MAX_OBJECTS:
                flush_shard_local()

    # Write whatever is left in the last, partially filled batch.
    flush_shard_local()
    shutil.rmtree(tmp_dir, ignore_errors=True)

    # Build statistics; an object counts as truncated when fewer tokens were kept than read.
    st = {
        "kept_rows": int(kept_rows),
        "built_objects": int(len(built_ids)),
        "missing_filled": int(len(missing_ids)),
        "len_before_mean": float(np.mean(len_before)) if len_before else 0.0,
        "len_before_p95": float(np.quantile(len_before, 0.95)) if len_before else 0.0,
        "len_after_mean": float(np.mean(len_after)) if len_after else 0.0,
        "len_after_p95": float(np.quantile(len_after, 0.95)) if len_after else 0.0,
        "truncated_frac": float(np.mean([a < b for a, b in zip(len_after, len_before)])) if len_before else 0.0,
        "time_s": float(time.time() - t0),
    }
    return manifest_rows, st

# ----------------------------
# 12) RUN
# ----------------------------
# Resolve the splits to process. A single split name given as a plain string
# must be wrapped in a list: list("split_01") would otherwise explode it into
# one "split" per character.
if ONLY_SPLITS is None:
    splits_to_run = list(SPLITS_TO_CONSIDER)
elif isinstance(ONLY_SPLITS, str):
    splits_to_run = [ONLY_SPLITS]
else:
    splits_to_run = [str(s) for s in ONLY_SPLITS]

# Echo the effective configuration for this run.
print("\n[Stage 5] SETTINGS")
print(f"- TOKEN_MODE_FORCE: {TOKEN_MODE_FORCE}")
print(f"- FEATURE_SET     : {FEATURE_SET} | dim={FEATURE_DIM}")
print(f"- ADD_META        : {ADD_OBJ_META_PER_TOKEN} | meta_cols={OBJ_META_COLS}")
print(f"- L_MAX/TRUNC     : {L_MAX} / {TRUNC_POLICY} (KEEP_DET_FRAC={KEEP_DET_FRAC})")
print(f"- NUM_BUCKETS     : {NUM_BUCKETS} | SHARD_MAX_OBJECTS={SHARD_MAX_OBJECTS}")
print(f"- REBUILD_MODE    : {REBUILD_MODE} | COMPRESS_NPZ={COMPRESS_NPZ}")

# Accumulators filled by the per-split loop.
all_manifest_train, all_manifest_test, split_run_stats = [], [], []

def expected_set_for(split_name: str, which: str) -> set:
    """Return the set of object ids expected for one split/partition.

    ``which == "train"`` reads ``train_ids_by_split``; any other value reads
    ``test_ids_by_split``. A split missing from the mapping yields an empty set.
    """
    ids_by_split = train_ids_by_split if which == "train" else test_ids_by_split
    return set(ids_by_split.get(split_name, []))

# Build (or reuse) token shards for every (split, partition) pair.
for split_name in splits_to_run:
    for which in ["train", "test"]:
        out_dir = SEQ_DIR / split_name / which
        out_dir.mkdir(parents=True, exist_ok=True)

        # An empty expectation is treated as a hard error rather than skipped.
        expected_ids = expected_set_for(split_name, which)
        if len(expected_ids) == 0:
            raise RuntimeError(f"Expected ids empty for {split_name}/{which}.")

        shard_exists = any(out_dir.glob("shard_*.npz"))
        if REBUILD_MODE == "reuse_if_exists" and shard_exists:
            # Reuse path: rebuild the manifest from the shards already on disk
            # and insist that they cover every expected object id.
            print(f"\n[Stage 5] REUSE (exists): {split_name}/{which}")
            man_rows = reconstruct_manifest_from_shards(split_name, which, out_dir)
            if not man_rows:
                raise RuntimeError(f"REUSE mode aktif tapi gagal rekonstruksi manifest: {out_dir}")

            got_ids = set([r["object_id"] for r in man_rows])
            miss_ids = expected_ids - got_ids
            if miss_ids:
                raise RuntimeError(
                    f"REUSE shard tidak cover semua expected ids untuk {split_name}/{which}. "
                    f"missing={len(miss_ids)} (contoh={list(sorted(miss_ids))[:5]}). "
                    "Solusi: set STAGE5_REBUILD_MODE='wipe_all' untuk rebuild bersih."
                )

            if which == "train":
                all_manifest_train.extend(man_rows)
            else:
                all_manifest_test.extend(man_rows)

            # Nothing was rebuilt, so length/time statistics are recorded as zeros.
            split_run_stats.append({
                "split": split_name, "which": which,
                "kept_rows": 0,
                "built_objects": len(got_ids),
                "missing_filled": 0,
                "len_before_mean": 0.0,
                "len_before_p95": 0.0,
                "len_after_mean": 0.0,
                "len_after_p95": 0.0,
                "truncated_frac": 0.0,
                "time_s": 0.0,
            })
            continue
        else:
            # Rebuild path: best-effort removal of stale shards; a failed
            # unlink is deliberately ignored.
            for f in out_dir.glob("shard_*.npz"):
                try: f.unlink()
                except Exception: pass

        print(f"\n[Stage 5] {split_name}/{which} | expected={len(expected_ids):,}")

        manifest_rows, st = build_sequences_bucket(
            split_name=split_name,
            which=which,
            expected_ids=expected_ids,
            out_dir=out_dir,
            num_buckets=NUM_BUCKETS
        )

        print(f"[Stage 5] OK: built={st['built_objects']:,} (missing_filled={st['missing_filled']:,}) | "
              f"kept_rows={st['kept_rows']:,} | "
              f"len_mean {st['len_before_mean']:.1f}->{st['len_after_mean']:.1f} | "
              f"p95 {st['len_before_p95']:.1f}->{st['len_after_p95']:.1f} | "
              f"trunc%={st['truncated_frac']*100:.1f}% | "
              f"time={st['time_s']:.2f}s")

        split_run_stats.append({"split": split_name, "which": which, **st})

        if which == "train":
            all_manifest_train.extend(manifest_rows)
        else:
            all_manifest_test.extend(manifest_rows)

        # Explicit collection after each (split, partition) build.
        gc.collect()

# ----------------------------
# 13) Save manifests + stats + config
# ----------------------------
df_m_train = pd.DataFrame(all_manifest_train)
df_m_test  = pd.DataFrame(all_manifest_test)

# Sort only when rows exist: an empty frame has no columns to sort by.
if not df_m_train.empty:
    df_m_train = df_m_train.sort_values(["split", "shard", "start"]).reset_index(drop=True)
if not df_m_test.empty:
    df_m_test = df_m_test.sort_values(["split", "shard", "start"]).reset_index(drop=True)

mtrain_path = SEQ_DIR / "seq_manifest_train.csv"
mtest_path  = SEQ_DIR / "seq_manifest_test.csv"
df_m_train.to_csv(mtrain_path, index=False)
df_m_test.to_csv(mtest_path, index=False)

df_stats = pd.DataFrame(split_run_stats)
stats_path = SEQ_DIR / "seq_build_stats.csv"
df_stats.to_csv(stats_path, index=False)

# Snapshot of the settings used for this build, written next to the shards.
# Stage 6 reads "feature_names" and "token_mode" from this file.
cfg_out = {
    "stage": "stage5",
    "version": "v6.1",
    "token_mode": "asinh",
    "token_mode_force": "asinh",
    "feature_set": FEATURE_SET,
    "feature_names": FEATURE_NAMES,
    "feature_dim": int(FEATURE_DIM),
    "obj_meta_cols": OBJ_META_COLS,
    "add_meta_per_token": bool(ADD_OBJ_META_PER_TOKEN),
    "snr_tanh_scale": float(SNR_TANH_SCALE),
    "time_clip_max_days": None if TIME_CLIP_MAX_DAYS is None else float(TIME_CLIP_MAX_DAYS),
    "compress_npz": bool(COMPRESS_NPZ),
    "shard_max_objects": int(SHARD_MAX_OBJECTS),
    "num_buckets": int(NUM_BUCKETS),
    "L_MAX": int(L_MAX),
    "TRUNC_POLICY": str(TRUNC_POLICY),
    "KEEP_DET_FRAC": float(KEEP_DET_FRAC),
    "KEEP_EDGE": bool(KEEP_EDGE),
    "USE_RESTFRAME_TIME": bool(USE_RESTFRAME_TIME),
    "REBUILD_MODE": str(REBUILD_MODE),
    "RUN_DIR_USED": str(RUN_DIR),
    "ART_DIR_USED": str(ART_DIR),
    "LC_CLEAN_DIR_USED": str(LC_CLEAN_DIR),
    "manifest_csv": str(manifest_csv),
    "stage4_schema_hint": "v7(asinh_mag/asinh_mag_err, detected_pos) or legacy(flux_asinh/err_log1p, detected)",
}
cfg_path = SEQ_DIR / "seq_config.json"
cfg_path.write_text(json.dumps(cfg_out, indent=2))

print("\n[Stage 5] DONE")
print(f"- token_mode : asinh (forced)")
print(f"- feature_set: {FEATURE_SET} | dim={FEATURE_DIM}")
print(f"- Saved: {mtrain_path} (rows={len(df_m_train):,})")
print(f"- Saved: {mtest_path}  (rows={len(df_m_test):,})")
print(f"- Saved: {stats_path}")
print(f"- Saved: {cfg_path}")

# ----------------------------
# 14) Smoke test
# ----------------------------
def load_sequence(object_id: str, which: str):
    """Load the token sequence of one object from its shard file.

    Args:
        object_id: Object identifier; converted to ``str`` and stripped.
        which: ``"train"`` selects ``df_m_train``; any other value selects
            ``df_m_test``.

    Returns:
        ``(X, B)``: the ``"x"`` and ``"band"`` arrays of the shard, sliced to
        ``[start, start + length)`` as recorded in the manifest row.

    Raises:
        KeyError: if the object id has no row in the selected manifest.
    """
    object_id = str(object_id).strip()
    m = df_m_train if which == "train" else df_m_test
    # Compare as stripped strings so that ids stored with a numeric dtype
    # (e.g. after a CSV round-trip) still match the normalized lookup key.
    row = m[m["object_id"].astype(str).str.strip() == object_id]
    if row.empty:
        raise KeyError(f"object_id not found in seq manifest ({which}): {object_id}")
    r = row.iloc[0]
    start = int(r["start"]); length = int(r["length"])
    # The context manager closes the .npz archive handle once both arrays
    # have been read into memory.
    with np.load(r["shard"], allow_pickle=False) as data:
        X = data["x"][start:start+length]
        B = data["band"][start:start+length]
    return X, B

# Smoke test: load the first training object end-to-end via the manifest.
_smoke_oid = str(df_train_meta.index[0])
X_sm, B_sm = load_sequence(_smoke_oid, "train")
print(f"\n[Stage 5] Smoke test object_id={_smoke_oid}")
print(f"- seq_len={len(X_sm)} | X_shape={X_sm.shape} | bands_unique={sorted(set(B_sm.tolist()))}")

# Publish Stage 5 outputs under stable global names for the later stages.
globals().update({
    "RUN_DIR": RUN_DIR,
    "ART_DIR": ART_DIR,
    "LC_CLEAN_DIR": LC_CLEAN_DIR,
    "SEQ_DIR": SEQ_DIR,
    "seq_manifest_train": df_m_train,
    "seq_manifest_test": df_m_test,
    "SEQ_FEATURE_NAMES": FEATURE_NAMES,
    "SEQ_FEATURE_DIM": int(FEATURE_DIM),
    "SEQ_TOKEN_MODE": "asinh",
    "get_clean_parts": get_clean_parts,
    "load_sequence": load_sequence,
})

gc.collect()


# Sequence Length Policy (Padding, Truncation, Windowing)

In [None]:
# ============================================================
# STAGE 6 — Sequence Length Policy (Padding, Truncation, Windowing)
# ONE CELL, Kaggle CPU-SAFE — REVISI FULL v3.1
#
# Upgrade v3.1:
# - Remains 100% compatible with STAGE 5 v6+ (signal/err_log) & legacy
# - Manifest safety: auto-filter to expected ids, anti-duplicate object_id
# - More robust target column discovery (meta -> fallback to train.csv if PATHS exists)
# - REUSE mode: strict file-set + strict coverage
# ============================================================

import gc, json, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require minimal globals
# ----------------------------
# Stage 6 needs ART_DIR from an earlier cell; fail with a pointer to Stage 0.
for need in ["ART_DIR"]:
    if need not in globals():
        raise RuntimeError("Missing `ART_DIR`. Jalankan STAGE 0 dulu (setup run dir).")

ART_DIR = Path(ART_DIR)
# CFG is optional; anything that is not a dict is replaced by an empty dict.
CFG = globals().get("CFG", {})
CFG = CFG if isinstance(CFG, dict) else {}
SEED = int(globals().get("SEED", 2025))

# ----------------------------
# 0a) Locate STAGE 5 outputs (seq_tokens)
# ----------------------------
def _find_seq_tokens_dir(art_dir: Path) -> Path:
    cand = art_dir / "seq_tokens"
    if cand.exists() and cand.is_dir() and (cand / "seq_manifest_train.csv").exists() and (cand / "seq_manifest_test.csv").exists():
        return cand
    for p in art_dir.glob("*"):
        if p.is_dir() and (p / "seq_manifest_train.csv").exists() and (p / "seq_manifest_test.csv").exists():
            return p
    raise FileNotFoundError(
        "Cannot find STAGE 5 seq_tokens directory under ART_DIR.\n"
        f"ART_DIR={art_dir}\n"
        "Expected: ART_DIR/seq_tokens/seq_manifest_train.csv and seq_manifest_test.csv"
    )

# Prefer a SEQ_TOKENS_DIR already present in globals; otherwise search ART_DIR.
SEQ_TOKENS_DIR = Path(globals().get("SEQ_TOKENS_DIR", None) or _find_seq_tokens_dir(ART_DIR))

p_mtr = SEQ_TOKENS_DIR / "seq_manifest_train.csv"
p_mte = SEQ_TOKENS_DIR / "seq_manifest_test.csv"
p_cfg = SEQ_TOKENS_DIR / "seq_config.json"

# All three Stage 5 artifacts are mandatory.
if not p_mtr.exists(): raise FileNotFoundError(f"Missing: {p_mtr}")
if not p_mte.exists(): raise FileNotFoundError(f"Missing: {p_mte}")
if not p_cfg.exists(): raise FileNotFoundError(f"Missing: {p_cfg}")

seq_manifest_train = pd.read_csv(p_mtr)
seq_manifest_test  = pd.read_csv(p_mte)

with open(p_cfg, "r", encoding="utf-8") as f:
    seq_cfg = json.load(f) if p_cfg.exists() else {}

# Feature names: config file first (either key spelling), then globals.
SEQ_FEATURE_NAMES = (
    seq_cfg.get("feature_names", None)
    or seq_cfg.get("SEQ_FEATURE_NAMES", None)
    or globals().get("SEQ_FEATURE_NAMES", None)
)
if SEQ_FEATURE_NAMES is None:
    raise RuntimeError("SEQ_FEATURE_NAMES not found in seq_config.json and not present in globals.")
SEQ_FEATURE_NAMES = list(SEQ_FEATURE_NAMES)
# Map feature name -> column position in the token matrix.
feat = {name: i for i, name in enumerate(SEQ_FEATURE_NAMES)}

SEQ_TOKEN_MODE_IN = seq_cfg.get("token_mode", None) or globals().get("SEQ_TOKEN_MODE", None)

print(f"[Stage 6] Using SEQ_TOKENS_DIR: {SEQ_TOKENS_DIR}")
print(f"[Stage 6] Loaded manifests: train_rows={len(seq_manifest_train):,} | test_rows={len(seq_manifest_test):,}")
print(f"[Stage 6] token_mode(prefer)={SEQ_TOKEN_MODE_IN} | F={len(SEQ_FEATURE_NAMES)}")

# ----------------------------
# Helpers
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _resolve_shard_path(x, base_dir: Path) -> str:
    p = Path(str(x))
    if p.exists():
        return str(p)
    p2 = base_dir / p
    if p2.exists():
        return str(p2)
    p3 = base_dir / p.name
    if p3.exists():
        return str(p3)
    return str(p)

def _pick_first_existing(keys, d):
    for k in keys:
        if k in d:
            return k
    return None

def _pick_value_feat(feat_dict, prefer_mode=None):
    """Choose the token mode and the feature used as the value signal.

    Exact candidate names are tried first; if none is present, the
    alphabetically first feature whose name contains a known substring is
    used. ``prefer_mode="mag"`` selects a magnitude feature when one exists;
    otherwise a value-like feature wins over a magnitude feature.

    Returns:
        ``(mode, feature_name)`` with ``mode`` in ``{"asinh", "mag"}``, or
        ``(None, None)`` when nothing suitable is found.
    """
    value_exact = [
        "signal", "value", "flux",
        "flux_asinh", "asinh_flux", "asinh_mag",
        "flux_asinh_clip", "flux_asinh_norm", "flux_asinh_scaled",
    ]
    mag_exact = ["mag", "mag_norm", "mag_clip", "mag_scaled"]
    value_tokens = ("signal", "flux", "value", "asinh")

    value_feat = _pick_first_existing(value_exact, feat_dict)
    if value_feat is None:
        # Fuzzy fallback: alphabetically first name containing a value token.
        value_feat = min(
            (k for k in feat_dict.keys() if any(t in k for t in value_tokens)),
            default=None,
        )

    mag_feat = _pick_first_existing(mag_exact, feat_dict)
    if mag_feat is None:
        mag_feat = min((k for k in feat_dict.keys() if "mag" in k), default=None)

    pm = None if prefer_mode is None else str(prefer_mode).lower().strip()

    # An explicit "mag" preference is honored only if a mag feature exists.
    if pm == "mag" and mag_feat is not None:
        return "mag", mag_feat
    if value_feat is not None:
        return "asinh", value_feat
    if mag_feat is None:
        return None, None
    if pm == "asinh":
        print("[WARN] prefer asinh but no value-like feature found; fallback to mag.")
    return "mag", mag_feat

# required for downstream model
REQ_FOR_MODEL = ["t_rel_log", "dt_log"]
missing_req = [k for k in REQ_FOR_MODEL if k not in feat]
if missing_req:
    raise ValueError(f"SEQ_FEATURE_NAMES missing required feats for model: {missing_req}. Head={SEQ_FEATURE_NAMES[:40]}")

# SNR column: the tanh-squashed variant is preferred over the raw one.
SNR_FEAT = _pick_first_existing(["snr_tanh", "snr"], feat)
if SNR_FEAT is None:
    raise ValueError("SEQ_FEATURE_NAMES missing snr_tanh or snr.")

# Detection flag column, under either of its two accepted names.
DET_FEAT = _pick_first_existing(["detected", "detected_pos"], feat)
if DET_FEAT is None:
    raise ValueError("SEQ_FEATURE_NAMES missing detected or detected_pos.")

SEQ_TOKEN_MODE, SCORE_VALUE_FEAT = _pick_value_feat(feat, prefer_mode=SEQ_TOKEN_MODE_IN)
if SEQ_TOKEN_MODE is None or SCORE_VALUE_FEAT is None:
    raise ValueError(
        "Cannot infer token mode/value feature from SEQ_FEATURE_NAMES.\n"
        f"SEQ_FEATURE_NAMES={SEQ_FEATURE_NAMES}"
    )

print(f"[Stage 6] token_mode={SEQ_TOKEN_MODE} | score_value_feat={SCORE_VALUE_FEAT} | snr_feat={SNR_FEAT} | det_feat={DET_FEAT}")

# ----------------------------
# 0c) Ensure df_train_meta / df_test_meta exist (rebuild if missing)
# ----------------------------
# If the meta frame is absent, rebuild it from the *_log frame indexed by the
# normalized object_id; otherwise work on a copy of the existing frame.
if "df_train_meta" not in globals() or not isinstance(globals()["df_train_meta"], pd.DataFrame):
    if "df_train_log" in globals() and isinstance(globals()["df_train_log"], pd.DataFrame):
        df_train_meta = globals()["df_train_log"].copy()
        df_train_meta["object_id"] = df_train_meta["object_id"].astype(str).apply(_norm_id)
        df_train_meta = df_train_meta.set_index("object_id", drop=True)
        print("[Stage 6] df_train_meta missing -> rebuilt from df_train_log.")
    else:
        raise RuntimeError("Missing df_train_meta and df_train_log. Jalankan STAGE 0/1 dulu.")
else:
    df_train_meta = globals()["df_train_meta"].copy()

if "df_test_meta" not in globals() or not isinstance(globals()["df_test_meta"], pd.DataFrame):
    if "df_test_log" in globals() and isinstance(globals()["df_test_log"], pd.DataFrame):
        df_test_meta = globals()["df_test_log"].copy()
        df_test_meta["object_id"] = df_test_meta["object_id"].astype(str).apply(_norm_id)
        df_test_meta = df_test_meta.set_index("object_id", drop=True)
        print("[Stage 6] df_test_meta missing -> rebuilt from df_test_log.")
    else:
        raise RuntimeError("Missing df_test_meta and df_test_log. Jalankan STAGE 0/1 dulu.")
else:
    df_test_meta = globals()["df_test_meta"].copy()

# Normalize the index values of both frames, keeping the index name.
df_train_meta.index = pd.Index([_norm_id(z) for z in df_train_meta.index], name=df_train_meta.index.name)
df_test_meta.index  = pd.Index([_norm_id(z) for z in df_test_meta.index],  name=df_test_meta.index.name)

# ----------------------------
# 1) Settings
# ----------------------------
FORCE_MAX_LEN = CFG.get("STAGE6_FORCE_MAX_LEN", None)
MAXLEN_CAPS = (256, 384, 512)

# Window-score weights; only the defaults differ between the two token modes.
if SEQ_TOKEN_MODE == "asinh":
    W_SNR = float(CFG.get("STAGE6_W_SNR", 1.00))
    W_VAL = float(CFG.get("STAGE6_W_VAL", 0.05))
    W_DET = float(CFG.get("STAGE6_W_DET", 0.05))
else:
    W_SNR = float(CFG.get("STAGE6_W_SNR", 1.00))
    W_VAL = float(CFG.get("STAGE6_W_VAL", 0.35))
    W_DET = float(CFG.get("STAGE6_W_DET", 0.25))

PAD_BAND_ID = int(CFG.get("STAGE6_PAD_BAND_ID", 0))
DUMMY_TOKEN_FOR_EMPTY = bool(CFG.get("STAGE6_DUMMY_TOKEN_FOR_EMPTY", True))
SHIFT_BAND_IDS = bool(CFG.get("STAGE6_SHIFT_BAND_IDS", True))

DTYPE_X = np.float32
REBUILD_MODE = str(CFG.get("STAGE6_REBUILD_MODE", "wipe_all")).lower()  # wipe_all / reuse_if_exists

SNR_RAW_TO_TANH_SCALE = float(CFG.get("STAGE6_SNR_RAW_TO_TANH_SCALE", 10.0))  # only if snr(raw)

# ----------------------------
# 2) Length distribution -> choose MAX_LEN
# ----------------------------
def describe_lengths(m: pd.DataFrame, name: str):
    """Print and return percentiles of the manifest's ``length`` column.

    Non-numeric lengths are counted as 0.

    Returns:
        Array with the 0, 1, 5, 10, 25, 50, 75, 90, 95, 98, 99 and 100th
        percentiles, in that order.

    Raises:
        RuntimeError: if the manifest has no ``length`` column.
    """
    if "length" not in m.columns:
        raise RuntimeError(f"Manifest {name} missing 'length' column.")

    percent_points = [0, 1, 5, 10, 25, 50, 75, 90, 95, 98, 99, 100]
    numeric = pd.to_numeric(m["length"], errors="coerce").fillna(0)
    L = numeric.to_numpy(dtype=np.int32, copy=False)
    q = np.percentile(L, percent_points)

    print(f"\n{name} length stats")
    print(f"- n_objects={len(L):,} | min={int(q[0])} | p50={int(q[5])} | p90={int(q[7])} | p95={int(q[8])} | p99={int(q[10])} | max={int(q[-1])}")
    return q

q_tr = describe_lengths(seq_manifest_train, "TRAIN")
q_te = describe_lengths(seq_manifest_test,  "TEST")

# Index 8 of the percentile array is p95; use the larger of train and test.
p95 = int(max(q_tr[8], q_te[8]))
if FORCE_MAX_LEN is not None and str(FORCE_MAX_LEN) not in ["None", "none", ""]:
    MAX_LEN = int(FORCE_MAX_LEN)
    # A non-positive length cannot hold any token; reject it up front.
    if MAX_LEN < 1:
        raise ValueError(f"STAGE6_FORCE_MAX_LEN must be >= 1, got {FORCE_MAX_LEN!r}")
else:
    # Smallest cap that covers p95, else the largest cap. Deriving the
    # thresholds from MAXLEN_CAPS keeps the choice in sync with the cap list.
    _caps = sorted(int(c) for c in MAXLEN_CAPS)
    MAX_LEN = next((c for c in _caps if p95 <= c), _caps[-1])

print(f"\n[Stage 6] MAX_LEN={MAX_LEN} (based on p95={p95})")
print(f"[Stage 6] Weights: W_SNR={W_SNR} | W_VAL={W_VAL} | W_DET={W_DET} | SHIFT_BAND_IDS={SHIFT_BAND_IDS} | DUMMY_TOKEN_FOR_EMPTY={DUMMY_TOKEN_FOR_EMPTY}")

# ----------------------------
# 3) Window scoring + padding
# ----------------------------
def _brightness_proxy_from_mag(mag: np.ndarray) -> np.ndarray:
    mag = np.nan_to_num(mag, nan=np.float32(0.0), posinf=np.float32(0.0), neginf=np.float32(0.0)).astype(np.float32, copy=False)
    if mag.size == 0:
        return np.zeros_like(mag, dtype=np.float32)
    med = np.float32(np.median(mag))
    br = np.maximum(med - mag, np.float32(0.0))
    br = np.log1p(br).astype(np.float32, copy=False)
    return br

def _get_snr_tanh_from_X(X: np.ndarray) -> np.ndarray:
    """Return the absolute tanh-scaled SNR column of ``X`` as float32.

    Non-finite entries become 0. When the SNR feature is not already
    ``snr_tanh``, the column is divided by ``SNR_RAW_TO_TANH_SCALE`` and
    passed through ``tanh`` first.
    """
    column = X[:, feat[SNR_FEAT]].astype(np.float32, copy=False)
    column = np.nan_to_num(column, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
    if SNR_FEAT != "snr_tanh":
        column = np.tanh(column / np.float32(SNR_RAW_TO_TANH_SCALE))
    return np.abs(column).astype(np.float32, copy=False)

def _score_tokens(X: np.ndarray) -> np.ndarray:
    """Compute a per-token importance score used for window selection.

    The score is ``W_SNR * snr + W_VAL * value + W_DET * detected`` where
    ``value`` is a brightness proxy in ``mag`` mode and the absolute value of
    the value feature otherwise. Non-finite numbers are replaced by 0.
    """
    def _finite32(arr):
        # Replace NaN/inf by 0 and keep float32.
        return np.nan_to_num(arr, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)

    snr_term = _get_snr_tanh_from_X(X)
    det_term = _finite32(X[:, feat[DET_FEAT]].astype(np.float32, copy=False))

    value_col = X[:, feat[SCORE_VALUE_FEAT]]
    if SEQ_TOKEN_MODE == "mag":
        val_term = _brightness_proxy_from_mag(value_col.astype(np.float32, copy=False))
    else:
        val_term = _finite32(np.abs(value_col).astype(np.float32, copy=False))

    total = (np.float32(W_SNR) * snr_term) + (np.float32(W_VAL) * val_term) + (np.float32(W_DET) * det_term)
    return _finite32(total)

def select_best_window(score: np.ndarray, max_len: int) -> tuple[int, int]:
    """Pick the contiguous window of ``max_len`` tokens with the largest score sum.

    Returns ``(start, end)`` with ``end`` exclusive. Sequences no longer than
    ``max_len`` are returned whole. If no window sum is finite, the centered
    window is used.
    """
    n_tokens = int(score.shape[0])
    if n_tokens <= max_len:
        return 0, n_tokens

    # Prefix sums in float32: window sum i is prefix[i + max_len] - prefix[i].
    prefix = np.zeros(n_tokens + 1, dtype=np.float32)
    prefix[1:] = np.cumsum(score.astype(np.float32, copy=False))
    window_sums = prefix[max_len:] - prefix[:-max_len]

    usable = window_sums.size > 0 and np.isfinite(window_sums).any()
    start = int(np.argmax(window_sums)) if usable else int((n_tokens - max_len) // 2)
    return start, start + max_len

def _is_empty_stub(X: np.ndarray, B: np.ndarray) -> bool:
    try:
        if X is None or B is None:
            return True
        if int(X.shape[0]) != 1 or int(B.shape[0]) != 1:
            return False
        if int(B[0]) >= 0:
            return False
        return bool(np.all(np.abs(X.astype(np.float32, copy=False)) <= np.float32(1e-12)))
    except Exception:
        return False

def pad_to_fixed(X: np.ndarray, B: np.ndarray, max_len: int):
    """Pad or window one sequence to exactly ``max_len`` tokens.

    Returns:
        ``(Xp, Bp, Mp, orig_len, win_start, win_end)`` where ``Xp`` is the
        ``(max_len, F)`` feature matrix, ``Bp`` the band ids (padding slots
        hold ``PAD_BAND_ID``), ``Mp`` the 0/1 validity mask, ``orig_len`` the
        input length and ``[win_start, win_end)`` the kept window. Empty input
        yields ``orig_len == 0``; with ``DUMMY_TOKEN_FOR_EMPTY`` one masked-in
        all-zero token is kept and ``win_end`` is 1.
    """
    has_matrix = X is not None and getattr(X, "ndim", 0) == 2
    n_feat = int(X.shape[1]) if has_matrix else int(len(SEQ_FEATURE_NAMES))

    Xp = np.zeros((max_len, n_feat), dtype=DTYPE_X)
    Bp = np.full((max_len,), PAD_BAND_ID, dtype=np.int8)
    Mp = np.zeros((max_len,), dtype=np.int8)

    is_empty = (
        X is None or B is None
        or getattr(X, "size", 0) == 0
        or getattr(B, "size", 0) == 0
        or _is_empty_stub(X, B)
    )
    if is_empty:
        if not DUMMY_TOKEN_FOR_EMPTY:
            return Xp, Bp, Mp, 0, 0, 0
        # Keep one masked-in placeholder token so the sequence is not all padding.
        Mp[0] = 1
        Bp[0] = np.int8(PAD_BAND_ID)
        return Xp, Bp, Mp, 0, 0, 1

    orig_len = int(X.shape[0])
    if orig_len > max_len:
        # Too long: keep the contiguous window with the highest total score.
        win_start, win_end = select_best_window(_score_tokens(X), max_len=max_len)
        x_win, b_win = X[win_start:win_end], B[win_start:win_end]
    else:
        win_start, win_end = 0, orig_len
        x_win, b_win = X, B

    kept = int(x_win.shape[0])
    Xp[:kept] = x_win.astype(DTYPE_X, copy=False)

    if SHIFT_BAND_IDS:
        # Shift band ids up by one (computed in int16) and clip to [0, 127].
        shifted = b_win.astype(np.int16, copy=False) + 1
        Bp[:kept] = np.clip(shifted, 0, 127).astype(np.int8, copy=False)
    else:
        Bp[:kept] = b_win.astype(np.int8, copy=False)

    Mp[:kept] = 1
    return Xp, Bp, Mp, int(orig_len), int(win_start), int(win_end)

# ----------------------------
# 4) Fixed cache builder setup
# ----------------------------
# Output directory for the fixed-length cache files.
FIX_DIR = ART_DIR / "fixed_seq"
FIX_DIR.mkdir(parents=True, exist_ok=True)

# Row order of the train cache follows the df_train_meta index.
train_ids = df_train_meta.index.astype(str).tolist()

def _try_load_sample_sub_ids():
    if "df_sub" in globals() and isinstance(globals()["df_sub"], pd.DataFrame) and "object_id" in globals()["df_sub"].columns:
        return globals()["df_sub"]["object_id"].astype(str).str.strip().apply(_norm_id).to_list()
    if "PATHS" in globals() and isinstance(globals()["PATHS"], dict):
        p = globals()["PATHS"].get("SAMPLE_SUB", None)
        if p and Path(p).exists():
            df = pd.read_csv(p)
            if "object_id" in df.columns:
                return df["object_id"].astype(str).str.strip().apply(_norm_id).to_list()
    return None

# Test row order: sample-submission order when available, otherwise the
# df_test_meta index order.
test_ids = _try_load_sample_sub_ids()
if test_ids is None:
    test_ids = df_test_meta.index.astype(str).tolist()

# find y column (meta -> fallback train.csv if exists in PATHS)
def _find_target_col(df: pd.DataFrame):
    for cand in ["target", "y", "label", "class", "target_id", "binary_target", "is_tde", "is_event", "event"]:
        if cand in df.columns:
            return cand
    return None

# Look for the target column in the meta frame first; if absent, try the
# training CSV referenced by PATHS and map its target onto the meta index.
_y_col = _find_target_col(df_train_meta)
if _y_col is None and "PATHS" in globals() and isinstance(globals()["PATHS"], dict):
    p_train = globals()["PATHS"].get("TRAIN_CSV", None) or globals()["PATHS"].get("TRAIN", None)
    if p_train and Path(p_train).exists():
        dft = pd.read_csv(p_train)
        if "object_id" in dft.columns:
            dft["object_id"] = dft["object_id"].astype(str).apply(_norm_id)
            ycol = _find_target_col(dft)
            if ycol is not None:
                # join back to meta (safe)
                tmp = dft.set_index("object_id")[ycol]
                df_train_meta[ycol] = pd.to_numeric(df_train_meta.index.map(tmp), errors="coerce")
                _y_col = ycol
                print(f"[Stage 6] target col loaded from TRAIN_CSV: {_y_col}")

if _y_col is None:
    raise RuntimeError(f"Cannot find target column in df_train_meta. cols_head={list(df_train_meta.columns)[:40]}")

# NOTE: missing or non-numeric targets are filled with 0 before the int16 cast.
y_train = pd.to_numeric(df_train_meta[_y_col], errors="coerce").fillna(0).astype(np.int16).to_numpy(copy=False)

# dedup checks
if len(set(train_ids)) != len(train_ids):
    s = pd.Series(train_ids); dup = s[s.duplicated()].head(10).tolist()
    raise RuntimeError(f"train_ids has duplicates. ex={dup}")
if len(set(test_ids)) != len(test_ids):
    s = pd.Series(test_ids); dup = s[s.duplicated()].head(10).tolist()
    raise RuntimeError(f"test_ids has duplicates. ex={dup}")

# object_id -> row position in the fixed-length arrays.
train_row = {oid: i for i, oid in enumerate(train_ids)}
test_row  = {oid: i for i, oid in enumerate(test_ids)}

NTR, NTE, F = len(train_ids), len(test_ids), len(SEQ_FEATURE_NAMES)

def _gb(nbytes): return float(nbytes) / (1024**3)
# Report the approximate on-disk size of the X memmaps before creating them.
print(f"\n[Stage 6] Memmap X sizes approx: train={_gb(NTR*MAX_LEN*F*np.dtype(DTYPE_X).itemsize):.3f} GB | "
      f"test={_gb(NTE*MAX_LEN*F*np.dtype(DTYPE_X).itemsize):.3f} GB | dtype={DTYPE_X}")

# Memmap files: X = features, B = band ids, M = validity mask.
train_X_path = FIX_DIR / "train_X.dat"
train_B_path = FIX_DIR / "train_B.dat"
train_M_path = FIX_DIR / "train_M.dat"
test_X_path  = FIX_DIR / "test_X.dat"
test_B_path  = FIX_DIR / "test_B.dat"
test_M_path  = FIX_DIR / "test_M.dat"

# Per-object bookkeeping: original length and the kept window [start, end).
train_len_path = FIX_DIR / "train_origlen.npy"
train_ws_path  = FIX_DIR / "train_winstart.npy"
train_we_path  = FIX_DIR / "train_winend.npy"
test_len_path  = FIX_DIR / "test_origlen.npy"
test_ws_path   = FIX_DIR / "test_winstart.npy"
test_we_path   = FIX_DIR / "test_winend.npy"

def _all_exist(paths):
    return all(Path(p).exists() for p in paths)

# Complete file set that must exist for the cache to count as reusable.
reuse_paths = [
    train_X_path, train_B_path, train_M_path,
    test_X_path, test_B_path, test_M_path,
    FIX_DIR / "train_ids.npy", FIX_DIR / "test_ids.npy", FIX_DIR / "train_y.npy",
    train_len_path, train_ws_path, train_we_path,
    test_len_path, test_ws_path, test_we_path,
    FIX_DIR / "length_policy_config.json"
]

# wipe_all: best-effort removal of every cache file; failures are ignored.
if REBUILD_MODE == "wipe_all":
    for p in reuse_paths:
        try:
            Path(p).unlink(missing_ok=True)
        except Exception:
            pass

# Reuse branch: all cache files are present, so only publish their locations.
# NOTE(review): only file existence is checked here; the cached contents are
# not compared against the current MAX_LEN or id lists.
if REBUILD_MODE == "reuse_if_exists" and _all_exist(reuse_paths):
    print("[Stage 6] REUSE: fixed_seq cache already present.")
    globals().update({
        "SEQ_TOKENS_DIR": SEQ_TOKENS_DIR,
        "FIX_DIR": FIX_DIR, "MAX_LEN": MAX_LEN,
        "FIX_TRAIN_X_PATH": train_X_path, "FIX_TRAIN_B_PATH": train_B_path, "FIX_TRAIN_M_PATH": train_M_path,
        "FIX_TEST_X_PATH": test_X_path,  "FIX_TEST_B_PATH": test_B_path,  "FIX_TEST_M_PATH": test_M_path,
        "FIX_TRAIN_Y_PATH": FIX_DIR / "train_y.npy",
        "FIX_TRAIN_IDS_PATH": FIX_DIR / "train_ids.npy",
        "FIX_TEST_IDS_PATH": FIX_DIR / "test_ids.npy",
        "FIX_POLICY_CFG_PATH": FIX_DIR / "length_policy_config.json",
        "SEQ_TOKEN_MODE": SEQ_TOKEN_MODE,
        "SCORE_VALUE_FEAT": SCORE_VALUE_FEAT,
        "SHIFT_BAND_IDS": SHIFT_BAND_IDS,
        "PAD_BAND_ID": PAD_BAND_ID,
        "df_train_meta": df_train_meta,
        "df_test_meta": df_test_meta,
    })
    gc.collect()

else:
    # memmaps
    # mode="w+" creates (or overwrites) the backing files at the given shape.
    Xtr = np.memmap(train_X_path, dtype=DTYPE_X, mode="w+", shape=(NTR, MAX_LEN, F))
    Btr = np.memmap(train_B_path, dtype=np.int8,  mode="w+", shape=(NTR, MAX_LEN))
    Mtr = np.memmap(train_M_path, dtype=np.int8,  mode="w+", shape=(NTR, MAX_LEN))

    Xte = np.memmap(test_X_path, dtype=DTYPE_X, mode="w+", shape=(NTE, MAX_LEN, F))
    Bte = np.memmap(test_B_path, dtype=np.int8,  mode="w+", shape=(NTE, MAX_LEN))
    Mte = np.memmap(test_M_path, dtype=np.int8,  mode="w+", shape=(NTE, MAX_LEN))

    # Per-object original length and kept window bounds.
    origlen_tr  = np.zeros((NTR,), dtype=np.int32)
    winstart_tr = np.zeros((NTR,), dtype=np.int32)
    winend_tr   = np.zeros((NTR,), dtype=np.int32)

    origlen_te  = np.zeros((NTE,), dtype=np.int32)
    winstart_te = np.zeros((NTE,), dtype=np.int32)
    winend_te   = np.zeros((NTE,), dtype=np.int32)

    # 1 once a row has been written; used for duplicate and coverage checks.
    filled_tr = np.zeros((NTR,), dtype=np.uint8)
    filled_te = np.zeros((NTE,), dtype=np.uint8)

    def _prep_manifest(m: pd.DataFrame, expected_set: set, which: str) -> pd.DataFrame:
        """Return a cleaned copy of a Stage 5 manifest.

        Normalizes object ids, keeps only ids in ``expected_set``, drops
        repeated ids (first occurrence wins, with a warning), resolves shard
        paths against ``SEQ_TOKENS_DIR`` and removes rows whose ``start`` or
        ``length`` is negative or whose ``start`` is not numeric.

        Raises:
            RuntimeError: if a required manifest column is missing.
        """
        m2 = m.copy()
        for c in ["object_id", "shard", "start", "length"]:
            if c not in m2.columns:
                raise RuntimeError(f"Manifest missing '{c}'. cols={list(m2.columns)}")

        m2["object_id"] = m2["object_id"].astype(str).apply(_norm_id)
        m2 = m2[m2["object_id"].isin(expected_set)].copy()

        # Keep the first row of each object id and warn about the rest.
        repeated = m2["object_id"].duplicated()
        if repeated.any():
            dups = m2.loc[repeated, "object_id"].head(10).tolist()
            print(f"[WARN] {which} manifest has duplicate object_id; dropping duplicates. ex={dups}")
            m2 = m2[~repeated]

        m2["shard"] = m2["shard"].astype(str).apply(lambda s: _resolve_shard_path(s, SEQ_TOKENS_DIR))
        # Non-numeric start -> -1 (filtered below); non-numeric length -> 0.
        for column, fill_value in (("start", -1), ("length", 0)):
            m2[column] = pd.to_numeric(m2[column], errors="coerce").fillna(fill_value).astype(np.int64)

        valid = (m2["start"] >= 0) & (m2["length"] >= 0)
        return m2[valid].copy()

    # Restrict both manifests to the ids that own a row in the fixed arrays.
    exp_train = set(train_ids)
    exp_test  = set(test_ids)
    m_train2 = _prep_manifest(seq_manifest_train, exp_train, "train")
    m_test2  = _prep_manifest(seq_manifest_test,  exp_test,  "test")

    def process_manifest_into_memmap(m2: pd.DataFrame, which: str):
        """Write every object listed in ``m2`` into the fixed-length memmaps.

        Args:
            m2: Prepared manifest with ``object_id``, ``shard``, ``start`` and
                ``length`` columns.
            which: ``"train"`` targets the train arrays; any other value
                targets the test arrays.

        Returns:
            Dict of counters: ``filled``, ``dup_skipped``, ``empty_len``,
            ``dropped_bad_slices``, ``time_s`` and ``expected``.

        Raises:
            RuntimeError: if a referenced shard file is missing or a shard
                lacks the ``x`` / ``band`` arrays.
        """
        if which == "train":
            row_map = train_row
            Xmm, Bmm, Mmm = Xtr, Btr, Mtr
            origlen, ws_arr, we_arr = origlen_tr, winstart_tr, winend_tr
            filled_mask = filled_tr
            expected_n = NTR
        else:
            row_map = test_row
            Xmm, Bmm, Mmm = Xte, Bte, Mte
            origlen, ws_arr, we_arr = origlen_te, winstart_te, winend_te
            filled_mask = filled_te
            expected_n = NTE

        # Fail before any work if a shard referenced by the manifest is gone.
        shard_paths = m2["shard"].unique().tolist()
        miss_sh = [p for p in shard_paths if not Path(p).exists()]
        if miss_sh:
            raise RuntimeError(f"Missing shard files ({which}): count={len(miss_sh)} | ex={miss_sh[:5]}")

        filled = dup = empty = dropped_bad = 0
        t0 = time.time()
        get = row_map.get

        for si, (shard_path, g) in enumerate(m2.groupby("shard", sort=True), start=1):
            # Read both arrays inside a context manager so the archive handle
            # is closed for every shard, including on the error path.
            with np.load(shard_path, allow_pickle=False) as data:
                if "x" not in data or "band" not in data:
                    raise RuntimeError(f"Shard missing keys ['x','band']. Got={list(data.keys())} | shard={shard_path}")
                x_all = data["x"]
                b_all = data["band"]

            oids = g["object_id"].to_numpy(copy=False)
            starts = g["start"].to_numpy(dtype=np.int64, copy=False)
            lens   = g["length"].to_numpy(dtype=np.int64, copy=False)

            for oid, st, ln in zip(oids, starts, lens):
                idx = get(str(oid), -1)
                if idx < 0:
                    # Not one of the expected ids.
                    continue
                if ln <= 0:
                    empty += 1
                    continue
                if filled_mask[idx]:
                    # Row already written by an earlier manifest entry.
                    dup += 1
                    continue

                end = int(st + ln)
                if st < 0 or end > x_all.shape[0] or end > b_all.shape[0]:
                    # Slice falls outside the shard arrays.
                    dropped_bad += 1
                    continue

                X = x_all[st:end]
                B = b_all[st:end]

                Xp, Bp, Mp, L0, ws, we = pad_to_fixed(X, B, max_len=MAX_LEN)

                Xmm[idx, :, :] = Xp
                Bmm[idx, :] = Bp
                Mmm[idx, :] = Mp
                origlen[idx] = int(L0)
                ws_arr[idx] = int(ws)
                we_arr[idx] = int(we)
                filled_mask[idx] = 1
                filled += 1

                if filled % 2000 == 0:
                    gc.collect()

            if si % 25 == 0:
                print(f"[Stage 6][{which}] shards_processed={si:,}/{len(shard_paths):,} | filled={filled:,}")

        elapsed = time.time() - t0
        return {
            "filled": int(filled),
            "dup_skipped": int(dup),
            "empty_len": int(empty),
            "dropped_bad_slices": int(dropped_bad),
            "time_s": float(elapsed),
            "expected": int(expected_n)
        }

    print("\n[Stage 6] Building fixed cache (TRAIN) from STAGE 5 manifests...")
    st_tr = process_manifest_into_memmap(m_train2, "train")
    print(f"[Stage 6] TRAIN filled={st_tr['filled']:,}/{st_tr['expected']:,} | dup={st_tr['dup_skipped']:,} | empty={st_tr['empty_len']:,} | dropped_bad={st_tr['dropped_bad_slices']:,} | time={st_tr['time_s']:.2f}s")

    print("\n[Stage 6] Building fixed cache (TEST) from STAGE 5 manifests...")
    st_te = process_manifest_into_memmap(m_test2, "test")
    print(f"[Stage 6] TEST  filled={st_te['filled']:,}/{st_te['expected']:,} | dup={st_te['dup_skipped']:,} | empty={st_te['empty_len']:,} | dropped_bad={st_te['dropped_bad_slices']:,} | time={st_te['time_s']:.2f}s")

    # Push the memmap contents to disk before validating coverage.
    Xtr.flush(); Btr.flush(); Mtr.flush()
    Xte.flush(); Bte.flush(); Mte.flush()

    # Strict coverage: every expected row must have been written.
    miss_tr = np.where(filled_tr == 0)[0]
    miss_te = np.where(filled_te == 0)[0]
    if len(miss_tr) > 0:
        ex = [train_ids[i] for i in miss_tr[:10]]
        raise RuntimeError(f"[Stage 6] TRAIN missing filled rows: {len(miss_tr):,}/{NTR:,} | ex={ex}")
    if len(miss_te) > 0:
        ex = [test_ids[i] for i in miss_te[:10]]
        raise RuntimeError(f"[Stage 6] TEST missing filled rows: {len(miss_te):,}/{NTE:,} | ex={ex}")

    # Ids are stored as byte strings (dtype "S"); targets as saved in y_train.
    np.save(FIX_DIR / "train_ids.npy", np.asarray(train_ids, dtype="S"))
    np.save(FIX_DIR / "test_ids.npy",  np.asarray(test_ids,  dtype="S"))
    np.save(FIX_DIR / "train_y.npy",   y_train)

    np.save(train_len_path, origlen_tr)
    np.save(train_ws_path,  winstart_tr)
    np.save(train_we_path,  winend_tr)

    np.save(test_len_path, origlen_te)
    np.save(test_ws_path,  winstart_te)
    np.save(test_we_path,  winend_te)

    # Record of the inputs, policy and build statistics of this cache.
    policy_cfg = {
        "stage5_inputs": {
            "SEQ_TOKENS_DIR": str(SEQ_TOKENS_DIR),
            "seq_manifest_train": str(p_mtr),
            "seq_manifest_test": str(p_mte),
            "seq_config": str(p_cfg),
        },
        "token_mode": SEQ_TOKEN_MODE,
        "score_value_feat": SCORE_VALUE_FEAT,
        "snr_feat": SNR_FEAT,
        "det_feat": DET_FEAT,
        "max_len": int(MAX_LEN),
        "feature_names": list(SEQ_FEATURE_NAMES),
        "score_weights": {"W_SNR": float(W_SNR), "W_VAL": float(W_VAL), "W_DET": float(W_DET)},
        "window_policy": "best_contiguous_window_by_max_sum(score)",
        "padding": {
            "PAD_BAND_ID": int(PAD_BAND_ID),
            "SHIFT_BAND_IDS": bool(SHIFT_BAND_IDS),
            "DUMMY_TOKEN_FOR_EMPTY": bool(DUMMY_TOKEN_FOR_EMPTY),
        },
        "dtype_X": "float32",
        "order": {"train": "df_train_meta.index", "test": "df_sub.object_id if exists else df_test_meta.index", "y_col": str(_y_col)},
        "stats": {"train": st_tr, "test": st_te},
    }

    cfg_path = FIX_DIR / "length_policy_config.json"
    with open(cfg_path, "w", encoding="utf-8") as f:
        json.dump(policy_cfg, f, indent=2)

    print("\n[Stage 6] DONE")
    print(f"- FIX_DIR: {FIX_DIR}")
    print(f"- Saved config: {cfg_path}")

    # Publish the cache locations and policy choices for the later stages.
    globals().update({
        "SEQ_TOKENS_DIR": SEQ_TOKENS_DIR,
        "SEQ_FEATURE_NAMES": SEQ_FEATURE_NAMES,
        "FIX_DIR": FIX_DIR,
        "MAX_LEN": MAX_LEN,
        "FIX_TRAIN_X_PATH": train_X_path,
        "FIX_TRAIN_B_PATH": train_B_path,
        "FIX_TRAIN_M_PATH": train_M_path,
        "FIX_TEST_X_PATH": test_X_path,
        "FIX_TEST_B_PATH": test_B_path,
        "FIX_TEST_M_PATH": test_M_path,
        "FIX_TRAIN_Y_PATH": FIX_DIR / "train_y.npy",
        "FIX_TRAIN_IDS_PATH": FIX_DIR / "train_ids.npy",
        "FIX_TEST_IDS_PATH": FIX_DIR / "test_ids.npy",
        "FIX_POLICY_CFG_PATH": cfg_path,
        "SEQ_TOKEN_MODE": SEQ_TOKEN_MODE,
        "SCORE_VALUE_FEAT": SCORE_VALUE_FEAT,
        "SNR_FEAT": SNR_FEAT,
        "DET_FEAT": DET_FEAT,
        "SHIFT_BAND_IDS": SHIFT_BAND_IDS,
        "PAD_BAND_ID": PAD_BAND_ID,
        "df_train_meta": df_train_meta,
        "df_test_meta": df_test_meta,
    })

    gc.collect()


# CV Split (Object-Level, Stratified)

In [None]:
# ============================================================
# STAGE 7 — CV Split (Object-Level, Stratified) (ONE CELL)
# REVISI FULL v2.7 (UPGRADE: stronger auto-k, per-k group fallback, fold_stats csv, cv_train_ids.npy)
#
# Output:
# - artifacts/cv/cv_folds.csv
# - artifacts/cv/cv_folds.npz   (train_idx_f + val_idx_f + optional holdout_val_mask)
# - artifacts/cv/cv_report.txt
# - artifacts/cv/cv_config.json
# - artifacts/cv/fold_stats.csv
# - artifacts/cv/cv_train_ids.npy
# - (optional) artifacts/cv/cv_holdout_val_mask.npy
# - globals: fold_assign, folds, n_splits, CV_DIR
# ============================================================

import gc, json, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require minimal globals
# ----------------------------
# Fail fast when the notebook namespace lacks the objects this cell reads.
for need in ["df_train_meta", "ART_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan minimal STAGE 2 dulu (df_train_meta & ART_DIR).")

# Reuse the notebook-wide seed when one is defined (2025 otherwise) and make
# sure ART_DIR is a Path so the `/` joins below work.
SEED = int(globals().get("SEED", 2025))
ART_DIR = Path(ART_DIR)

# ----------------------------
# 1) CV Settings
# ----------------------------
DEFAULT_SPLITS = 5
FORCE_N_SPLITS = None              # set to an int to force the starting fold count (e.g. 3), else None

MIN_POS_PER_FOLD = 3               # typically 2-10 (lower it when the data is very imbalanced)
ENFORCE_MIN_POS_PER_FOLD = True    # if True: fold counts whose smallest validation fold has fewer positives are rejected in the strict pass

MIN_NEG_PER_FOLD = 1               # 1 is usually safe (negatives are usually plentiful); raise if needed
ENFORCE_MIN_NEG_PER_FOLD = True

USE_GROUP_BY_SPLIT = False         # True => prefer StratifiedGroupKFold (groups=df_train_meta["split"])
AUTO_FALLBACK_GROUP = True         # True => if group CV fails, fall back to StratifiedKFold

HOLDOUT_FALLBACK = True            # True => if no k-fold split is possible, use a single holdout
HOLDOUT_FRAC = 0.20                # target validation fraction for the (stratified) holdout

SAVE_PARQUET_FOLDS = False         # optional: also save cv_folds.parquet (skipped with a warning if it fails)

# Echo the effective configuration so it is visible in the cell output.
print(
    f"[Stage 7] seed={SEED} | default_splits={DEFAULT_SPLITS} | "
    f"MIN_POS_PER_FOLD={MIN_POS_PER_FOLD} enforce_pos={ENFORCE_MIN_POS_PER_FOLD} | "
    f"MIN_NEG_PER_FOLD={MIN_NEG_PER_FOLD} enforce_neg={ENFORCE_MIN_NEG_PER_FOLD} | "
    f"group_by_split={USE_GROUP_BY_SPLIT} fallback_group={AUTO_FALLBACK_GROUP} | "
    f"holdout_fallback={HOLDOUT_FALLBACK} holdout_frac={HOLDOUT_FRAC}"
)

# ----------------------------
# 2) Helpers
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, bytearray, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _decode_ids(arr) -> list:
    """Normalise every element of an id array (via ``_norm_id``) into a list of str."""
    return list(map(_norm_id, arr.tolist()))

def _safe_str_list(idx) -> list:
    return pd.Index(idx).astype("string").str.strip().astype(str).tolist()

def _find_train_ids_npy(art_dir: Path):
    # priority 1: FIX_DIR (Stage 6)
    if "FIX_DIR" in globals():
        p = Path(globals()["FIX_DIR"]) / "train_ids.npy"
        if p.exists():
            return p
    # priority 2: ART_DIR/fixed_seq
    p = art_dir / "fixed_seq" / "train_ids.npy"
    if p.exists():
        return p
    # priority 3: scan mallorn_run runs (latest mtime)
    root = Path("/kaggle/working/mallorn_run")
    if root.exists():
        cands = list(root.glob("run_*/artifacts/fixed_seq/train_ids.npy"))
        if cands:
            cands = sorted(cands, key=lambda x: x.stat().st_mtime, reverse=True)
            return cands[0]
    return None

def _pick_target_col(df: pd.DataFrame):
    for cand in ["target", "y", "label", "class", "is_tde", "binary_target", "target_id"]:
        if cand in df.columns:
            return cand
    return None

# ----------------------------
# 3) Determine train_ids ordering (prefer fixed cache from STAGE 6)
# ----------------------------
# Prefer the id order stored on disk (the fixed-sequence cache) so that fold
# indices line up with that cache; otherwise use the metadata index order.
p_ids = _find_train_ids_npy(ART_DIR)
if p_ids is not None:
    raw = np.load(p_ids, allow_pickle=False)
    train_ids = _decode_ids(raw)
    order_source = str(p_ids)
else:
    train_ids = _safe_str_list(df_train_meta.index)
    order_source = "df_train_meta.index"

# Duplicate ids would make the positional mapping below ambiguous: abort and
# show up to ten offenders.
if len(train_ids) != len(set(train_ids)):
    s = pd.Series(train_ids)
    dup = s[s.duplicated()].iloc[:10].tolist()
    raise RuntimeError(f"[Stage 7] train_ids has duplicates (examples): {dup}")

N = int(len(train_ids))

# ----------------------------
# 4) Normalize meta index (string+strip) + fast mapping
# ----------------------------
# Work on a copy so the caller's df_train_meta keeps its original index.
meta = df_train_meta.copy()
meta_ids = _safe_str_list(meta.index)

# Stripping/stringifying can collapse distinct raw labels into one: detect that.
if len(meta_ids) != len(set(meta_ids)):
    vc = pd.Series(meta_ids).value_counts()
    dup = vc[vc > 1].index.tolist()[:10]
    raise RuntimeError(f"[Stage 7] df_train_meta index has duplicates after str/strip (examples): {dup}")

meta.index = pd.Index(meta_ids, name="object_id")

# pos_idx[i] is the row position in `meta` of train_ids[i]; -1 marks an id
# that is absent from the metadata.
pos_idx = meta.index.get_indexer(train_ids)
missing_mask = (pos_idx < 0)
if missing_mask.any():
    ex = [train_ids[i] for i in np.where(missing_mask)[0][:10]]
    raise RuntimeError(
        "[Stage 7] Some train_ids not found in df_train_meta (after str/strip index).\n"
        f"Missing count={int(missing_mask.sum())} | ex={ex}\n"
        "Solusi: pastikan df_train_meta index adalah object_id dan konsisten dengan fixed_seq train_ids (kalau pakai Stage 6)."
    )

# ----------------------------
# 5) Robust target -> y (ordered by train_ids)
# ----------------------------
target_col = _pick_target_col(meta)
if target_col is None:
    raise RuntimeError(f"[Stage 7] Cannot find target column in df_train_meta. cols(sample)={list(meta.columns)[:40]}")

# Non-numeric / missing labels are coerced to 0 (i.e. treated as negative);
# labels are then reordered to follow train_ids and binarised as (value > 0).
y_all = pd.to_numeric(meta[target_col], errors="coerce").fillna(0).to_numpy(copy=False)
y = y_all[pos_idx]
y = (y > 0).astype(np.int8)

pos = int((y == 1).sum())
neg = int((y == 0).sum())
if pos == 0 or neg == 0:
    raise RuntimeError(f"[Stage 7] Invalid class distribution: pos={pos}, neg={neg}. Cannot do stratified split.")

pos_rate = pos / max(N, 1)
# neg/pos ratio, exported later as POS_WEIGHT_HINT.
pos_weight = float(neg / max(pos, 1))
print(f"[Stage 7] N={N:,} pos={pos:,} neg={neg:,} pos%={pos_rate*100:.6f}% | pos_weight~{pos_weight:.4f} | order_source={order_source}")

# ----------------------------
# 6) Optional groups (by split)
# ----------------------------
groups = None
group_col = None
if USE_GROUP_BY_SPLIT:
    # First matching column name wins.
    for cand in ["split", "split_id", "split_name", "split_idx"]:
        if cand in meta.columns:
            group_col = cand
            break
    if group_col is None:
        if not AUTO_FALLBACK_GROUP:
            raise RuntimeError("[Stage 7] USE_GROUP_BY_SPLIT=True but no split column found in df_train_meta.")
        print("[Stage 7] WARN: split column not found; fallback to StratifiedKFold.")
        # NOTE: this overwrites the user-facing flag, so later code only sees
        # the effective value, not what was originally requested.
        USE_GROUP_BY_SPLIT = False
    else:
        g_all = meta[group_col].astype("string").str.strip().astype(str).to_numpy(copy=False)
        groups = g_all[pos_idx]

# ----------------------------
# 7) Choose n_splits safely + auto-adjust
# ----------------------------
# Upper bounds on k: at most one fold per sample of each class, and at most
# as many folds as can each receive the configured minimum class count.
max_splits_by_pos = pos
max_splits_by_neg = neg
max_splits_by_minpos = max(1, pos // max(int(MIN_POS_PER_FOLD), 1))
max_splits_by_minneg = max(1, neg // max(int(MIN_NEG_PER_FOLD), 1))

n0 = min(DEFAULT_SPLITS, max_splits_by_pos, max_splits_by_neg, max_splits_by_minpos, max_splits_by_minneg)
# A forced value bypasses all of the bounds above.
if FORCE_N_SPLITS is not None:
    n0 = int(FORCE_N_SPLITS)

print(f"[Stage 7] Candidate n_splits={n0} | enforce_pos={ENFORCE_MIN_POS_PER_FOLD} enforce_neg={ENFORCE_MIN_NEG_PER_FOLD}")

# ----------------------------
# 8) Build folds with robust fallback (group->non-group per-k)
# ----------------------------
try:
    from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
    try:
        from sklearn.model_selection import StratifiedGroupKFold
    except Exception:
        StratifiedGroupKFold = None
except Exception as e:
    raise RuntimeError("scikit-learn is not available in this environment.") from e

def _try_split_kfold(k: int, use_group: bool):
    """Attempt a k-fold split of the ordered training labels.

    Args:
        k: number of folds to try.
        use_group: use StratifiedGroupKFold with the module-level ``groups``
            instead of plain StratifiedKFold.

    Returns:
        ``(ok, cv_type, fold_assign, folds, per)``. On failure ``ok`` is False,
        ``cv_type`` carries a short reason and the last three items are None.
        On success ``fold_assign`` holds each sample's validation fold,
        ``folds`` is a list of ``{"fold", "train_idx", "val_idx"}`` dicts and
        ``per`` lists ``(n_val, n_pos, n_neg)`` per fold.
    """
    def _failure(reason):
        return (False, reason, None, None, None)

    if use_group:
        if StratifiedGroupKFold is None:
            return _failure("StratifiedGroupKFold(unavailable)")
        if groups is None:
            return _failure("StratifiedGroupKFold(groups=None)")
        splitter = StratifiedGroupKFold(n_splits=k, shuffle=True, random_state=SEED)
        split_iter = splitter.split(np.zeros(N), y, groups=groups)
        cv_type = f"StratifiedGroupKFold({group_col})"
    else:
        splitter = StratifiedKFold(n_splits=k, shuffle=True, random_state=SEED)
        split_iter = splitter.split(np.zeros(N), y)
        cv_type = "StratifiedKFold"

    fold_assign = np.full(N, -1, dtype=np.int16)
    folds = []
    per = []

    # The splitter is lazy, so failures surface while iterating.
    try:
        for fold, (tr_idx, val_idx) in enumerate(split_iter):
            fold_assign[val_idx] = fold
            y_val = y[val_idx]
            n_pos = int((y_val == 1).sum())
            n_neg = int((y_val == 0).sum())
            per.append((len(val_idx), n_pos, n_neg))
            folds.append(
                {
                    "fold": int(fold),
                    "train_idx": tr_idx.astype(np.int32, copy=False),
                    "val_idx": val_idx.astype(np.int32, copy=False),
                }
            )
    except Exception as e:
        return _failure(f"{cv_type} (error: {type(e).__name__}: {e})")

    # Every sample must land in exactly one validation fold.
    if (fold_assign < 0).any():
        return _failure(f"{cv_type} (unassigned)")

    # Strict: a validation fold missing either class is unusable.
    if any(n_pos == 0 or n_neg == 0 for (_, n_pos, n_neg) in per):
        return _failure(f"{cv_type} (empty class in fold)")

    return (True, cv_type, fold_assign, folds, per)

def _passes_min_constraints(per):
    """Check per-fold class counts against the configured minimums.

    Args:
        per: list of ``(n_val, n_pos, n_neg)`` tuples, one per fold.

    Returns:
        ``(passed, min_pos_seen, min_neg_seen)``. An empty ``per`` fails with
        zeros. When ``FORCE_N_SPLITS`` is set the minimums are not enforced.
    """
    if not per:
        return False, 0, 0

    min_pos_seen = min(row[1] for row in per)
    min_neg_seen = min(row[2] for row in per)

    too_few_pos = ENFORCE_MIN_POS_PER_FOLD and (min_pos_seen < int(MIN_POS_PER_FOLD)) and (FORCE_N_SPLITS is None)
    if too_few_pos:
        return False, min_pos_seen, min_neg_seen

    too_few_neg = ENFORCE_MIN_NEG_PER_FOLD and (min_neg_seen < int(MIN_NEG_PER_FOLD)) and (FORCE_N_SPLITS is None)
    if too_few_neg:
        return False, min_pos_seen, min_neg_seen

    return True, min_pos_seen, min_neg_seen

def _make_holdout():
    """Build a single stratified train/validation holdout split.

    Uses the module-level ``y``, ``N``, ``pos``, ``neg``, ``HOLDOUT_FRAC`` and
    ``SEED``.

    Returns:
        ``(n_splits, cv_type, fold_assign, folds, per, val_mask)`` where
        ``n_splits`` is 1, ``fold_assign`` is all zeros (no -1 sentinel),
        ``folds`` holds one ``{"fold", "train_idx", "val_idx"}`` dict, ``per``
        is ``[(n_val, n_pos, n_neg)]`` and ``val_mask`` is a uint8 mask with 1
        on validation rows.

    Raises:
        RuntimeError: if either class has fewer than 2 samples, or the
            validation part ends up without one of the classes.
    """
    if pos < 2 or neg < 2:
        raise RuntimeError(f"[Stage 7] Cannot build holdout safely. Need pos>=2 and neg>=2. Got pos={pos}, neg={neg}.")

    # Target validation size, clamped so both parts keep at least 2 samples.
    val_n = int(round(N * float(HOLDOUT_FRAC)))
    val_n = max(val_n, 2)
    val_n = min(val_n, N - 2)

    # Pass the absolute count: a float fraction is turned back into a count
    # with ceil(), and val_n / N * N can round just above val_n, yielding one
    # extra validation sample and defeating the N - 2 clamp above.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=int(val_n), random_state=SEED)
    tr_idx, val_idx = next(splitter.split(np.zeros(N), y))

    # ensure val has both classes
    pf = int((y[val_idx] == 1).sum())
    nf = int((y[val_idx] == 0).sum())
    if pf == 0 or nf == 0:
        raise RuntimeError(f"[Stage 7] Holdout produced empty class in val. pos_val={pf} neg_val={nf}. Try different seed/frac.")

    fold_assign = np.zeros(N, dtype=np.int16)  # no -1 (downstream safe)
    val_mask = np.zeros(N, dtype=np.uint8)
    val_mask[val_idx] = 1

    folds = [{
        "fold": 0,
        "train_idx": tr_idx.astype(np.int32, copy=False),
        "val_idx": val_idx.astype(np.int32, copy=False),
    }]
    per = [(len(val_idx), pf, nf)]
    return 1, "Holdout(StratifiedShuffleSplit)", fold_assign, folds, per, val_mask

best = None
val_mask_holdout = None

# A non-group split may only be used when grouping was not requested, or when
# the user explicitly allowed falling back from group CV. Previously the plain
# splitter (and the holdout) were always tried, so AUTO_FALLBACK_GROUP=False
# silently produced splits that ignore the groups.
allow_plain_cv = (not USE_GROUP_BY_SPLIT) or bool(AUTO_FALLBACK_GROUP)

# Strict pass: largest k first; for each k try group CV (if requested) and
# then plain CV (if allowed), keeping the first split that meets the
# per-fold minimums.
if n0 >= 2:
    for k in range(n0, 1, -1):
        tried = []
        if USE_GROUP_BY_SPLIT:
            tried.append(True)
        if allow_plain_cv:
            tried.append(False)

        for use_group in tried:
            ok, cv_type, fa, folds_k, per = _try_split_kfold(k, use_group=use_group)
            if not ok:
                continue

            passed, min_pos_seen, min_neg_seen = _passes_min_constraints(per)
            if not passed:
                continue

            best = (k, cv_type, fa, folds_k, per, min_pos_seen, min_neg_seen)
            break

        if best is not None:
            break

# Relaxed pass: if the minimums cannot be met, accept the first valid split
# anyway (every fold must still contain both classes). Plain CV is used when
# allowed; otherwise group CV is retried without the minimums.
if best is None and n0 >= 2:
    relaxed_use_group = not allow_plain_cv
    for k in range(n0, 1, -1):
        ok, cv_type, fa, folds_k, per = _try_split_kfold(k, use_group=relaxed_use_group)
        if ok:
            min_pos_seen = min(pf for (_, pf, _) in per) if per else 0
            min_neg_seen = min(nf for (_, _, nf) in per) if per else 0
            best = (k, cv_type, fa, folds_k, per, min_pos_seen, min_neg_seen)
            print(f"[Stage 7] NOTE: Could not satisfy min constraints. Using k={k} with min_pos={min_pos_seen}, min_neg={min_neg_seen}.")
            break

# fallback to holdout
if best is None:
    if not allow_plain_cv:
        # The holdout ignores groups too, so it is not an acceptable
        # substitute when the fallback was explicitly disabled.
        raise RuntimeError(
            "[Stage 7] Group-aware CV could not be built and AUTO_FALLBACK_GROUP=False. "
            "Enable AUTO_FALLBACK_GROUP, lower DEFAULT_SPLITS / FORCE_N_SPLITS, or disable USE_GROUP_BY_SPLIT."
        )
    if HOLDOUT_FALLBACK:
        n_splits, cv_type, fold_assign, folds, per, val_mask_holdout = _make_holdout()
        min_pos_seen = per[0][1] if per else 0
        min_neg_seen = per[0][2] if per else 0
        best = (n_splits, cv_type, fold_assign, folds, per, min_pos_seen, min_neg_seen)
        print(f"[Stage 7] FALLBACK -> {cv_type} | val_pos={min_pos_seen} val_neg={min_neg_seen}")
    else:
        raise RuntimeError("[Stage 7] Failed to build a valid CV split. Try smaller DEFAULT_SPLITS / FORCE_N_SPLITS, or enable HOLDOUT_FALLBACK.")

# unpack
n_splits, cv_type, fold_assign, folds, per, min_pos_seen, min_neg_seen = best
print(f"[Stage 7] FINAL: n_splits={n_splits} | cv_type={cv_type} | min_pos_in_fold={min_pos_seen} | min_neg_in_fold={min_neg_seen}")

# ----------------------------
# 9) Report + fold stats
# ----------------------------
# Human-readable report lines (written to cv_report.txt later) plus one stats
# row per validation fold.
lines = []
lines.append(f"CV={cv_type} n_splits={n_splits} seed={SEED}")
lines.append(f"Order source: {order_source}")
lines.append(f"Target column: {target_col}")
lines.append(f"Total: N={N} | pos={pos} | neg={neg} | pos%={pos_rate*100:.6f}% | pos_weight~{pos_weight:.6f}")
if USE_GROUP_BY_SPLIT:
    # Whether grouping was actually used is inferred from the cv_type label.
    lines.append(f"Group requested: {group_col} | used_group={('Group' in cv_type)}")

lines.append("Per-fold distribution (val):")
fold_rows = []
if n_splits >= 2:
    # k-fold mode: fold_assign holds each sample's validation fold.
    for f in range(n_splits):
        idx = np.where(fold_assign == f)[0]
        yf = y[idx]
        pf = int((yf == 1).sum())
        nf = int((yf == 0).sum())
        lines.append(f"- fold {f}: n={len(idx):6d} | pos={pf:5d} | neg={nf:6d} | pos%={(pf/max(len(idx),1))*100:9.6f}%")
        fold_rows.append({"fold": f, "n_val": len(idx), "pos_val": pf, "neg_val": nf, "pos_frac_val": pf/max(len(idx),1)})
else:
    # Holdout mode: fold_assign is all zeros, so the validation rows must be
    # taken from folds[0]["val_idx"] instead.
    vidx = folds[0]["val_idx"]
    yf = y[vidx]
    pf = int((yf == 1).sum())
    nf = int((yf == 0).sum())
    lines.append(f"- holdout val: n={len(vidx):6d} | pos={pf:5d} | neg={nf:6d} | pos%={(pf/max(len(vidx),1))*100:9.6f}%")
    lines.append("NOTE: holdout mode uses folds[0].train_idx / folds[0].val_idx; fold_assign is all zeros (no -1).")
    fold_rows.append({"fold": 0, "n_val": len(vidx), "pos_val": pf, "neg_val": nf, "pos_frac_val": pf/max(len(vidx),1)})

df_fold_stats = pd.DataFrame(fold_rows)

# ----------------------------
# 10) Save artifacts
# ----------------------------
CV_DIR = ART_DIR / "cv"
CV_DIR.mkdir(parents=True, exist_ok=True)

# One row per training object: its id and validation-fold number.
# NOTE: in holdout mode every row has fold 0; use the holdout mask / npz
# indices to tell train from validation rows.
df_folds = pd.DataFrame({"object_id": train_ids, "fold": fold_assign.astype(int)})

folds_csv = CV_DIR / "cv_folds.csv"
df_folds.to_csv(folds_csv, index=False)

if SAVE_PARQUET_FOLDS:
    # Best effort: a failed parquet write only produces a warning.
    try:
        df_folds.to_parquet(CV_DIR / "cv_folds.parquet", index=False)
    except Exception as e:
        print(f"[Stage 7] WARN: parquet save skipped ({type(e).__name__}: {e})")

# save ids for downstream consistency (even without Stage 6)
# Stored as a fixed-width bytes array (dtype "S").
np.save(CV_DIR / "cv_train_ids.npy", np.asarray(train_ids, dtype="S"))

# Pack every fold's index arrays into one .npz under the keys
# train_idx_<fold> / val_idx_<fold>.
npz_path = CV_DIR / "cv_folds.npz"
npz_kwargs = {}
for fd in folds:
    f = int(fd["fold"])
    npz_kwargs[f"train_idx_{f}"] = fd["train_idx"].astype(np.int32, copy=False)
    npz_kwargs[f"val_idx_{f}"]   = fd["val_idx"].astype(np.int32, copy=False)

# The holdout mask only exists when the holdout fallback was taken.
if val_mask_holdout is not None:
    npz_kwargs["holdout_val_mask"] = val_mask_holdout.astype(np.uint8, copy=False)
    np.save(CV_DIR / "cv_holdout_val_mask.npy", val_mask_holdout.astype(np.uint8, copy=False))

np.savez(npz_path, **npz_kwargs)

report_path = CV_DIR / "cv_report.txt"
with open(report_path, "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")

fold_stats_path = CV_DIR / "fold_stats.csv"
df_fold_stats.to_csv(fold_stats_path, index=False)

# Machine-readable record of the settings, summary counts and artifact paths.
cfg_path = CV_DIR / "cv_config.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(
        {
            "seed": SEED,
            "n_splits": int(n_splits),
            "cv_type": cv_type,
            "min_pos_per_fold": int(MIN_POS_PER_FOLD),
            "min_neg_per_fold": int(MIN_NEG_PER_FOLD),
            "enforce_min_pos_per_fold": bool(ENFORCE_MIN_POS_PER_FOLD),
            "enforce_min_neg_per_fold": bool(ENFORCE_MIN_NEG_PER_FOLD),
            # NOTE(review): USE_GROUP_BY_SPLIT is reset to False earlier in this
            # cell when no split column is found, so this may not reflect the
            # originally requested value.
            "use_group_by_split_requested": bool(USE_GROUP_BY_SPLIT),
            "auto_fallback_group": bool(AUTO_FALLBACK_GROUP),
            "holdout_fallback": bool(HOLDOUT_FALLBACK),
            "holdout_frac": float(HOLDOUT_FRAC),
            "order_source": order_source,
            "target_col": target_col,
            "group_col": group_col,
            "pos_weight_hint": float(pos_weight),
            "summary": {
                "N": int(N), "pos": int(pos), "neg": int(neg), "pos_rate": float(pos_rate),
                "min_pos_in_fold": int(min_pos_seen), "min_neg_in_fold": int(min_neg_seen),
            },
            "artifacts": {
                "folds_csv": str(folds_csv),
                "folds_npz": str(npz_path),
                "report_txt": str(report_path),
                "fold_stats_csv": str(fold_stats_path),
                "cv_train_ids_npy": str(CV_DIR / "cv_train_ids.npy"),
                "holdout_val_mask_npy": (str(CV_DIR / "cv_holdout_val_mask.npy") if val_mask_holdout is not None else None),
            },
        },
        f,
        indent=2,
    )

print("\n[Stage 7] CV split OK")
print(f"- Saved: {folds_csv}")
print(f"- Saved: {npz_path}")
print(f"- Saved: {report_path}")
print(f"- Saved: {fold_stats_path}")
print(f"- Saved: {cfg_path}")

# Echo at most the last 12 report lines.
tail_n = min(len(lines), 12)
print("\n".join(lines[-tail_n:]))

# ----------------------------
# 11) Export globals for next stage
# ----------------------------
# Publish the split and its artifact paths into the notebook namespace so the
# following cells can pick them up by name.
globals().update(
    CV_DIR=CV_DIR,
    n_splits=int(n_splits),
    train_ids_ordered=train_ids,
    y_ordered=y,
    fold_assign=fold_assign,
    folds=folds,
    CV_FOLDS_CSV=folds_csv,
    CV_FOLDS_NPZ=npz_path,
    CV_CFG_PATH=cfg_path,
    CV_TYPE=cv_type,
    CV_ORDER_SOURCE=order_source,
    POS_WEIGHT_HINT=float(pos_weight),
    # None unless the holdout fallback was used.
    HOLDOUT_VAL_MASK=val_mask_holdout,
)

gc.collect()


# Train Model (CPU-Safe Configuration)

In [None]:
# ============================================================
# STAGE 8 — Train Multiband Event Transformer
# REVISI FULL v4.5 (IMPROVED: PR-AUC primary metric + optional focal + safer agg chunk + better logging)
#
# Key upgrades v4.5:
# - Primary metric default = PR-AUC (AveragePrecision) (better suited to class imbalance) + tie-break val_loss
# - Optional focal loss (CFG["focal_gamma"] > 0) for extreme imbalance
# - Auto chunk size for AGG features to avoid RAM spikes
# - Per-fold training history saved to LOG_DIR/fold_{fold}_history.csv
# - Stabilize encoder input with LayerNorm
# - Keeps: global_features_raw.npy caching + EMA + pos_weight default
# ============================================================

import os, gc, json, math, time, warnings, hashlib
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------
# 0) Require minimal previous stages
# ----------------------------
# Names that earlier cells must have placed in the notebook namespace; stop at
# the first missing one with a message naming it.
need_min = ["FIX_DIR","MAX_LEN","SEQ_FEATURE_NAMES","df_train_meta","n_splits","folds"]
for k in need_min:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..7 dulu dengan urutan benar.")

# ----------------------------
# 0a) Resolve train_ids ordering + labels (FAST + robust)
# ----------------------------
def _decode_ids(arr):
    out = []
    for x in arr.tolist():
        if isinstance(x, (bytes, bytearray, np.bytes_)):
            s = x.decode("utf-8", errors="ignore")
        else:
            s = str(x)
        out.append(s.strip())
    return out

# Id order priority: the list exported by the CV stage, then the ids file in
# FIX_DIR, then the metadata index itself.
if "train_ids_ordered" in globals() and globals()["train_ids_ordered"] is not None:
    train_ids = [str(x).strip() for x in list(globals()["train_ids_ordered"])]
else:
    p = Path(globals()["FIX_DIR"]) / "train_ids.npy"
    if p.exists():
        raw = np.load(p, allow_pickle=False)
        train_ids = _decode_ids(raw)
    else:
        train_ids = pd.Index(globals()["df_train_meta"].index).astype("string").str.strip().astype(str).tolist()

df_train_meta = globals()["df_train_meta"]
# First matching label column name wins.
target_col = None
for cand in ["target","y","label","class","is_tde","binary_target","target_id"]:
    if cand in df_train_meta.columns:
        target_col = cand
        break
if target_col is None:
    raise RuntimeError(f"Cannot find target column in df_train_meta. cols(sample)={list(df_train_meta.columns)[:40]}")

# Copy with a stripped string index so ids can be matched textually.
meta = df_train_meta.copy()
meta.index = pd.Index(pd.Index(meta.index).astype("string").str.strip().astype(str), name="object_id")

# pos_idx[i] = row position in `meta` of train_ids[i]; -1 means not found.
pos_idx = meta.index.get_indexer(train_ids)
if (pos_idx < 0).any():
    miss = [train_ids[i] for i in np.where(pos_idx < 0)[0][:10]]
    raise RuntimeError(f"Some train_ids not found in df_train_meta.index (after str/strip). ex={miss}")

# Missing / non-numeric labels become 0, then labels are binarised as (value > 0).
# NOTE(review): the int16 cast truncates fractional labels before the > 0
# test, whereas the CV stage thresholds the raw numeric value — confirm the
# target is strictly integer-valued.
y_all = pd.to_numeric(meta[target_col], errors="coerce").fillna(0).astype(np.int16).to_numpy(copy=False)
y = y_all[pos_idx]
y = (y > 0).astype(np.int8)

# ----------------------------
# 0b) Ensure output dirs exist
# ----------------------------
# Resolve the run directory: explicit RUN_DIR, else the parent of ART_DIR,
# else a fixed default under /kaggle/working.
if "RUN_DIR" in globals() and globals()["RUN_DIR"] is not None:
    RUN_DIR = Path(globals()["RUN_DIR"])
else:
    if "ART_DIR" in globals() and globals()["ART_DIR"] is not None:
        RUN_DIR = Path(globals()["ART_DIR"]).parent
    else:
        RUN_DIR = Path("/kaggle/working/mallorn_run")

# NOTE(review): unlike the RUN_DIR logic above, these lookups only fall back
# when the name is absent; a name that exists but is None reaches Path(None).
ART_DIR = Path(globals().get("ART_DIR", RUN_DIR / "artifacts"))
ART_DIR.mkdir(parents=True, exist_ok=True)

CKPT_DIR = Path(globals().get("CKPT_DIR", RUN_DIR / "checkpoints"))
OOF_DIR  = Path(globals().get("OOF_DIR",  RUN_DIR / "oof"))
LOG_DIR  = Path(globals().get("LOG_DIR",  RUN_DIR / "logs"))
CKPT_DIR.mkdir(parents=True, exist_ok=True)
OOF_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

# Re-export the resolved Path objects under their standard names.
globals().update({"RUN_DIR": RUN_DIR, "ART_DIR": ART_DIR, "CKPT_DIR": CKPT_DIR, "OOF_DIR": OOF_DIR, "LOG_DIR": LOG_DIR})

# ----------------------------
# 1) Torch imports + CPU safety
# ----------------------------
# PyTorch is mandatory from here on; re-raise with a clear message if absent.
try:
    import torch
    import torch.nn as nn
except Exception as e:
    raise RuntimeError("PyTorch tidak tersedia di environment ini.") from e

# Seed both torch and NumPy from the notebook-wide seed (2025 if undefined).
SEED = int(globals().get("SEED", 2025))
torch.manual_seed(SEED)
np.random.seed(SEED)

# This stage always uses the CPU device.
device = torch.device("cpu")

# threads (safe)
# Intra-op threads follow OMP_NUM_THREADS (default 2); inter-op is pinned to 1.
# Failures to apply these settings are deliberately ignored.
try:
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
    torch.set_num_interop_threads(1)
except Exception:
    pass

# metrics
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    _HAS_AP = True
except Exception as e:
    raise RuntimeError("scikit-learn metrics tidak tersedia (roc_auc_score/average_precision_score).") from e

# ----------------------------
# 2) Open memmaps (fixed seq)
# ----------------------------
FIX_DIR = Path(globals()["FIX_DIR"])
N = len(train_ids)
L = int(globals()["MAX_LEN"])
SEQ_FEATURE_NAMES = list(globals()["SEQ_FEATURE_NAMES"])
Fdim = len(SEQ_FEATURE_NAMES)
# feature name -> column index inside the last axis of X
feat = {n:i for i,n in enumerate(SEQ_FEATURE_NAMES)}

train_X_path = FIX_DIR / "train_X.dat"
train_B_path = FIX_DIR / "train_B.dat"
train_M_path = FIX_DIR / "train_M.dat"

for p in [train_X_path, train_B_path, train_M_path]:
    if not p.exists():
        raise FileNotFoundError(f"Missing fixed cache file: {p}. Pastikan STAGE 6 sukses.")

# ----------------------------
# 2b) Read Stage6 policy (SHIFT_BAND_IDS + dtype_X + score_value_feat + token_mode)
# ----------------------------
# This cell runs at notebook top level, so plain assignments below rebind the
# very names an earlier stage exported. Capture those values first; otherwise
# they are lost whenever the policy file is missing, unreadable or incomplete.
_prev_shift_band_ids = globals().get("SHIFT_BAND_IDS", None)
_prev_pad_band_id = globals().get("PAD_BAND_ID", None)
_prev_token_mode = globals().get("SEQ_TOKEN_MODE", None)

SHIFT_BAND_IDS = False if _prev_shift_band_ids is None else bool(_prev_shift_band_ids)
PAD_BAND_ID = 0 if _prev_pad_band_id is None else int(_prev_pad_band_id)
DTYPE_X_MEMMAP = np.float32
SEQ_TOKEN_MODE = None
SCORE_VALUE_FEAT = globals().get("SCORE_VALUE_FEAT", None)

policy_path = FIX_DIR / "length_policy_config.json"
if policy_path.exists():
    try:
        with open(policy_path, "r", encoding="utf-8") as f:
            pol = json.load(f)
        _padding = pol.get("padding", {})
        # Values in the policy file win; keys it lacks keep the values above.
        SHIFT_BAND_IDS = bool(_padding.get("SHIFT_BAND_IDS", SHIFT_BAND_IDS))
        PAD_BAND_ID = int(_padding.get("PAD_BAND_ID", PAD_BAND_ID))
        dt = str(pol.get("dtype_X", "float32")).lower()
        DTYPE_X_MEMMAP = np.float16 if (("float16" in dt) or ("fp16" in dt)) else np.float32
        if SCORE_VALUE_FEAT is None:
            SCORE_VALUE_FEAT = pol.get("score_value_feat", None)
        if SEQ_TOKEN_MODE is None:
            SEQ_TOKEN_MODE = pol.get("token_mode", None)
    except Exception as e:
        # Still best effort, but no longer silent: a wrong dtype or band
        # policy would otherwise go unnoticed.
        print(
            f"[Stage 8] WARN: could not read {policy_path} ({type(e).__name__}: {e}); "
            "using notebook globals/defaults for band policy, dtype and token mode."
        )

# Policy file first, then the value an earlier stage exported.
if SEQ_TOKEN_MODE is None:
    SEQ_TOKEN_MODE = _prev_token_mode

# Read-only views over the fixed-length caches: values, band ids, mask.
X_mm = np.memmap(train_X_path, dtype=DTYPE_X_MEMMAP, mode="r", shape=(N, L, Fdim))
B_mm = np.memmap(train_B_path, dtype=np.int8,        mode="r", shape=(N, L))
M_mm = np.memmap(train_M_path, dtype=np.int8,        mode="r", shape=(N, L))

# ----------------------------
# 2c) Robust value feature resolution
# ----------------------------
def _pick_val_feat(feat_map, prefer=None):
    cand = []
    if prefer is not None:
        cand.append(str(prefer))
    cand += [
        "signal", "value", "flux",
        "flux_asinh", "flux_asinh_clip", "flux_asinh_norm", "flux_asinh_scaled",
        "mag", "mag_norm", "mag_clip", "mag_scaled",
        "delta_signal", "delta_flux", "signal_clip", "signal_norm",
    ]
    for c in cand:
        if c and (c in feat_map):
            return c
    keys = list(feat_map.keys())
    fuzzy = [k for k in keys if any(t in k for t in ["signal", "flux", "value", "asinh", "mag"])]
    if fuzzy:
        return sorted(fuzzy)[0]
    return None

# If the token mode is still unset, look it up in the notebook namespace;
# failing that, infer it from the feature names ("mag" if any feature name is
# or starts with "mag", otherwise "asinh").
if SEQ_TOKEN_MODE is None:
    SEQ_TOKEN_MODE = globals().get("SEQ_TOKEN_MODE", None)
if SEQ_TOKEN_MODE is None:
    SEQ_TOKEN_MODE = "mag" if any((k == "mag") or k.startswith("mag") for k in feat.keys()) else "asinh"
SEQ_TOKEN_MODE = str(SEQ_TOKEN_MODE).lower().strip()

VAL_FEAT = _pick_val_feat(feat, prefer=SCORE_VALUE_FEAT)
if VAL_FEAT is None:
    raise RuntimeError("Cannot resolve VAL_FEAT from SEQ_FEATURE_NAMES.")

# These two features are indexed directly by the aggregate-feature builder.
for k in ["snr_tanh","detected"]:
    if k not in feat:
        raise RuntimeError(f"Feature '{k}' not found in SEQ_FEATURE_NAMES.")

# Magnitude-style statistics are used only when both the token mode and the
# chosen value feature's name indicate magnitudes.
VAL_IS_MAG = (SEQ_TOKEN_MODE == "mag") and ("mag" in VAL_FEAT)

print(f"[Stage 8] token_mode={SEQ_TOKEN_MODE} | VAL_FEAT={VAL_FEAT} | VAL_IS_MAG={VAL_IS_MAG}")
print(f"[Stage 8] SHIFT_BAND_IDS(from stage6)={SHIFT_BAND_IDS} | PAD_BAND_ID={PAD_BAND_ID} | X_dtype={DTYPE_X_MEMMAP}")

# ----------------------------
# 3) Build RAW meta global features (no leak)
# ----------------------------
# Extinction column: prefer EBV_clip, then EBV, else all zeros.
if ("EBV_clip" in meta.columns):
    EBV_used = pd.to_numeric(meta["EBV_clip"], errors="coerce")
elif ("EBV" in meta.columns):
    EBV_used = pd.to_numeric(meta["EBV"], errors="coerce")
else:
    EBV_used = pd.Series(np.zeros((len(meta),), dtype=np.float32), index=meta.index)

# Fixed column order of the metadata part of the global feature vector.
BASE_G_COLS = ["Z","Z_err","EBV_used","Z_missing","Z_err_missing","EBV_missing","is_photoz"]

tmp_meta = meta.copy()
tmp_meta["EBV_used"] = EBV_used
# Columns absent from the metadata are filled with 0.0 so the layout is stable.
for c in BASE_G_COLS:
    if c not in tmp_meta.columns:
        tmp_meta[c] = 0.0

# Reorder rows to follow train_ids, then force float32 (non-numeric / NaN -> 0.0).
G_meta = tmp_meta.iloc[pos_idx][BASE_G_COLS].copy()
for c in BASE_G_COLS:
    G_meta[c] = pd.to_numeric(G_meta[c], errors="coerce").fillna(0.0).astype(np.float32)
G_meta_np = G_meta.to_numpy(dtype=np.float32, copy=False)

# Record the column order next to the logs.
with open(Path(LOG_DIR)/"global_meta_cols.json", "w", encoding="utf-8") as f:
    json.dump({"cols": BASE_G_COLS}, f, indent=2)

# ----------------------------
# 3b) Sequence aggregate features (global + per-band) + CACHE
# ----------------------------
USE_AGG_SEQ_FEATURES = True  # False => the sequence-aggregate part has zero columns
N_BANDS = 6                  # number of band ids iterated for per-band statistics (0..N_BANDS-1)

def _safe_div(a, b):
    return a / np.maximum(b, 1.0)

def _auto_chunk(L, F, target_mb=220, min_chunk=64, max_chunk=2048):
    # approx bytes = chunk * L * F * 4
    target_bytes = float(target_mb) * (1024**2)
    denom = float(L) * float(F) * 4.0
    if denom <= 0:
        return min_chunk
    c = int(target_bytes // denom)
    c = max(min_chunk, min(max_chunk, c))
    return int(c)

def build_agg_seq_features(X_mm, B_mm, M_mm, chunk=None):
    """Compute per-object aggregate features from the fixed-length sequences.

    Args:
        X_mm: token features, indexed as ``[object, token, feature]``.
        B_mm: band id per token, indexed as ``[object, token]``.
        M_mm: token mask, ``1`` for real tokens (anything else is padding).
        chunk: objects processed per step; chosen from the array shape when None.

    Returns:
        float32 array with one row per object and ``7 + 4 * N_BANDS`` columns:
        ``[tok_count, det_frac, mean_abs_snr, max_abs_snr]``, three value
        statistics (mean/std/min of the value in mag mode, otherwise
        mean |value|, std, max |value|), then for each band
        ``[count, det_frac, mean_abs_snr, mean value or mean |value|]``.

    The loop is sized from ``X_mm`` itself rather than from module-level
    counters, so the function is correct for any array passed in; an input
    with zero rows yields an empty ``(0, 7 + 4 * N_BANDS)`` array.
    """
    snr_i = feat["snr_tanh"]
    det_i = feat["detected"]
    val_i = feat[VAL_FEAT]

    n_rows = int(X_mm.shape[0])
    if chunk is None:
        chunk = _auto_chunk(int(X_mm.shape[1]), int(X_mm.shape[2]), target_mb=220, min_chunk=64, max_chunk=1024)
    chunk = max(1, int(chunk))  # a zero/negative step would break range()

    def _nan_stat(fn, arr):
        # Rows with no real tokens (or bands with no observations) are all-NaN;
        # NumPy warns on those, and the result is mapped to 0.0 anyway.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            return np.nan_to_num(fn(arr, axis=1).astype(np.float32), nan=0.0)

    out_chunks = []
    for s in range(0, n_rows, chunk):
        e = min(n_rows, s + chunk)

        # NOTE: slicing memmap will read chunk; keep chunk modest to avoid RAM spikes
        Xc = np.asarray(X_mm[s:e])  # (B,L,F)
        Bc = np.asarray(B_mm[s:e])  # (B,L)
        Mc = np.asarray(M_mm[s:e])  # (B,L)

        real = (Mc == 1)
        tok_count = real.sum(axis=1).astype(np.float32)

        snr = np.abs(Xc[:, :, snr_i]).astype(np.float32, copy=False)
        det = (Xc[:, :, det_i] > 0.5).astype(np.float32, copy=False)
        val = Xc[:, :, val_i].astype(np.float32, copy=False)

        # Zero out padding so sums/maxima only see real tokens.
        snr_r = snr * real
        det_r = det * real

        det_frac = _safe_div(det_r.sum(axis=1), tok_count)
        mean_abs_snr = _safe_div(snr_r.sum(axis=1), tok_count)
        max_abs_snr = np.where(tok_count > 0, snr_r.max(axis=1), 0.0).astype(np.float32)

        if VAL_IS_MAG:
            # Magnitude mode: statistics of the value itself over real tokens.
            val_r = np.where(real, val, np.nan)
            mean_val = _nan_stat(np.nanmean, val_r)
            std_val = _nan_stat(np.nanstd, val_r)
            min_val = _nan_stat(np.nanmin, val_r)
            global_val_feats = np.stack([mean_val, std_val, min_val], axis=1).astype(np.float32)
        else:
            # Flux-like mode: statistics of the absolute value.
            aval = np.abs(val).astype(np.float32, copy=False)
            aval_r = aval * real
            mean_aval = _safe_div(aval_r.sum(axis=1), tok_count).astype(np.float32)
            val_r = np.where(real, val, np.nan)
            std_val = _nan_stat(np.nanstd, val_r)
            max_aval = np.where(tok_count > 0, aval_r.max(axis=1), 0.0).astype(np.float32)
            global_val_feats = np.stack([mean_aval, std_val, max_aval], axis=1).astype(np.float32)

        # band ids to [0..N_BANDS-1]
        if SHIFT_BAND_IDS:
            # Stored ids are offset by one; padding positions are mapped to 0
            # and excluded below through `real`.
            Badj = Bc.astype(np.int16, copy=False)
            Badj = np.where(real, np.clip(Badj - 1, 0, N_BANDS - 1), 0).astype(np.int16, copy=False)
        else:
            Badj = Bc.astype(np.int16, copy=False)

        per_band = []
        for b in range(N_BANDS):
            bm = (Badj == b) & real
            cnt = bm.sum(axis=1).astype(np.float32)

            detb = (det * bm).sum(axis=1).astype(np.float32)
            snrb = (snr * bm).sum(axis=1).astype(np.float32)

            det_frac_b = _safe_div(detb, cnt)
            mean_abs_snr_b = _safe_div(snrb, cnt)

            if VAL_IS_MAG:
                vb = np.where(bm, val, np.nan)
                mean_val_b = _nan_stat(np.nanmean, vb)
            else:
                ab = (np.abs(val).astype(np.float32) * bm).sum(axis=1).astype(np.float32)
                mean_val_b = _safe_div(ab, cnt).astype(np.float32)

            per_band.append(np.stack([cnt, det_frac_b, mean_abs_snr_b, mean_val_b], axis=1))

        per_band = np.concatenate(per_band, axis=1).astype(np.float32)

        glob = np.stack([tok_count, det_frac, mean_abs_snr, max_abs_snr], axis=1).astype(np.float32)
        agg = np.concatenate([glob, global_val_feats, per_band], axis=1).astype(np.float32)

        out_chunks.append(agg)

        del Xc, Bc, Mc, Badj
        # Collect every fourth chunk to keep peak memory down without paying
        # the GC cost on each iteration.
        if ((s // chunk) % 4) == 0:
            gc.collect()

    if not out_chunks:
        # np.concatenate refuses an empty list; return a correctly shaped empty result.
        return np.zeros((0, 7 + 4 * N_BANDS), dtype=np.float32)

    return np.concatenate(out_chunks, axis=0).astype(np.float32)

# ---- cache key for G_seq ----
def _hash_cfg(d: dict) -> str:
    """Return a short, deterministic fingerprint of a JSON-serialisable dict.

    The dict is serialised with sorted keys (so key insertion order does not
    matter) and ASCII-only escapes; the result is the first 12 hex characters
    of the MD5 digest of that text.
    """
    canonical = json.dumps(d, sort_keys=True, ensure_ascii=True).encode("utf-8")
    digest = hashlib.md5(canonical).hexdigest()
    return digest[:12]

# Cache location for the raw (unscaled) global feature matrix and its metadata sidecar.
G_CACHE_DIR = FIX_DIR
G_RAW_CACHE = G_CACHE_DIR / "global_features_raw.npy"
G_RAW_META  = G_CACHE_DIR / "global_features_raw_meta.json"

# Everything that is fingerprinted into the cache key: if any of these values changes,
# the hash changes and the cached matrix is rebuilt.
agg_spec = {
    "use_agg_seq": bool(USE_AGG_SEQ_FEATURES),
    "N": int(N),
    "L": int(L),
    "Fdim": int(Fdim),
    "n_bands": int(N_BANDS),
    "seq_token_mode": str(SEQ_TOKEN_MODE),
    "val_feat": str(VAL_FEAT),
    "val_is_mag": bool(VAL_IS_MAG),
    "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
    "pad_band_id_from_stage6": int(PAD_BAND_ID),
    "dtype_X_memmap": str(DTYPE_X_MEMMAP),
    "meta_cols": list(BASE_G_COLS),
}
agg_hash = _hash_cfg(agg_spec)

# Try to reuse the cache: it is accepted only when the stored hash and N match and the
# loaded array really has N rows. Any failure while reading falls back to a rebuild.
# NOTE(review): the hash covers the spec only, not the contents of the metadata columns —
# confirm that upstream value changes always come with a spec change.
G_raw_np = None
if G_RAW_CACHE.exists() and G_RAW_META.exists():
    try:
        old = json.loads(G_RAW_META.read_text())
        if old.get("agg_hash", None) == agg_hash and int(old.get("N", -1)) == int(N):
            G_raw_np = np.load(G_RAW_CACHE, allow_pickle=False).astype(np.float32, copy=False)
            if G_raw_np.shape[0] != N:
                G_raw_np = None
    except Exception:
        # best-effort cache read: an unreadable/corrupt cache simply triggers a rebuild
        G_raw_np = None

# Cache miss: (re)build the matrix and persist it together with its metadata.
if G_raw_np is None:
    if USE_AGG_SEQ_FEATURES:
        chunk0 = _auto_chunk(L, Fdim, target_mb=220, min_chunk=64, max_chunk=1024)
        print(f"[Stage 8] Building AGG sequence features (one-time, cached)... chunk={chunk0}")
        t0 = time.time()
        G_seq_np = build_agg_seq_features(X_mm, B_mm, M_mm, chunk=chunk0)
        print(f"[Stage 8] AGG built: shape={G_seq_np.shape} | time={time.time()-t0:.1f}s")
    else:
        # zero-width placeholder so the concatenation below works unchanged
        G_seq_np = np.zeros((N,0), dtype=np.float32)

    # G_meta_np is defined earlier in the file (presumably the BASE_G_COLS columns — verify).
    G_raw_np = np.concatenate([G_meta_np, G_seq_np], axis=1).astype(np.float32)
    np.save(G_RAW_CACHE, G_raw_np.astype(np.float32, copy=False))
    G_RAW_META.write_text(json.dumps({"agg_hash": agg_hash, "N": int(N), "spec": agg_spec}, indent=2))

# Width of the global feature vector; used as the model's g_dim.
g_dim = int(G_raw_np.shape[1])

# Human-readable record of the global-feature configuration for this run.
with open(Path(LOG_DIR)/"global_feature_spec.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "agg_hash": agg_hash,
            "meta_cols": BASE_G_COLS,
            "use_agg_seq": bool(USE_AGG_SEQ_FEATURES),
            "token_mode": SEQ_TOKEN_MODE,
            "val_feat": VAL_FEAT,
            "val_is_mag": bool(VAL_IS_MAG),
            "total_g_dim": int(g_dim),
            "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
            "pad_band_id_from_stage6": int(PAD_BAND_ID),
            "dtype_X_memmap": str(DTYPE_X_MEMMAP),
            "score_value_feat_from_stage6": (None if SCORE_VALUE_FEAT is None else str(SCORE_VALUE_FEAT)),
            "cache": {"path": str(G_RAW_CACHE), "meta": str(G_RAW_META)},
        },
        f,
        indent=2,
    )

# ----------------------------
# 4) Dataset / Loader
# ----------------------------
class MemmapSeqDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves one object (row) at a time from row-indexable arrays.

    `idx` selects which global rows this dataset exposes. Each item is
    `(X, band_ids, mask, global_feats)` and, when labels were supplied, a
    trailing float32 label tensor. `y` is indexed with the *global* row id,
    so it must cover every row referenced by `idx`.
    """

    def __init__(self, idx, X_mm, B_mm, M_mm, G_raw_np, y=None):
        self.idx = np.asarray(idx, dtype=np.int32)
        self.X_mm, self.B_mm, self.M_mm = X_mm, B_mm, M_mm
        self.G_raw = G_raw_np
        self.y = np.asarray(y, dtype=np.int8) if y is not None else None

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        row = int(self.idx[i])

        # X keeps the storage dtype (the model casts to float32); ids/mask become int64.
        feats = np.asarray(self.X_mm[row])
        bands = np.asarray(self.B_mm[row]).astype(np.int64)
        mask = np.asarray(self.M_mm[row]).astype(np.int64)
        glob = np.asarray(self.G_raw[row], dtype=np.float32)

        sample = (
            torch.from_numpy(feats),
            torch.from_numpy(bands),
            torch.from_numpy(mask),
            torch.from_numpy(glob),
        )
        if self.y is None:
            return sample

        label = torch.tensor(float(self.y[row]), dtype=torch.float32)
        return sample + (label,)

def make_loader(ds, batch_size, shuffle, sampler=None):
    """Build a single-process DataLoader over `ds`.

    When a `sampler` is given it decides the order, so `shuffle` is forced off
    (shuffling is only requested when no sampler is supplied). Loading happens
    in the main process, without pinned memory, and the last partial batch is kept.
    """
    do_shuffle = shuffle if sampler is None else False
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=do_shuffle,
        sampler=sampler,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
    )
    return torch.utils.data.DataLoader(ds, **loader_kwargs)

# ----------------------------
# 5) EMA helper
# ----------------------------
class EMA:
    """Exponential moving average ("shadow") copy of a model's trainable parameters.

    Only parameters with `requires_grad=True` are tracked. `store` / `copy_to` /
    `restore` allow temporarily evaluating the model with the averaged weights
    and then putting the live weights back.
    """

    def __init__(self, model, decay=0.999):
        self.decay = float(decay)
        # shadow starts as an exact copy of the current trainable weights
        self.shadow = {name: p.detach().clone() for name, p in self._trainable(model)}
        self.backup = {}

    @staticmethod
    def _trainable(model):
        """Yield (name, parameter) pairs for parameters that require gradients."""
        return ((name, p) for name, p in model.named_parameters() if p.requires_grad)

    @torch.no_grad()
    def update(self, model):
        """In place: shadow <- decay * shadow + (1 - decay) * current weights."""
        keep = self.decay
        for name, p in self._trainable(model):
            self.shadow[name].mul_(keep).add_(p.detach(), alpha=(1.0 - keep))

    def store(self, model):
        """Snapshot the live trainable weights so `restore` can bring them back."""
        self.backup = {name: p.detach().clone() for name, p in self._trainable(model)}

    def copy_to(self, model):
        """Overwrite the model's trainable weights with the shadow weights."""
        for name, p in self._trainable(model):
            p.data.copy_(self.shadow[name].data)

    def restore(self, model):
        """Write the snapshot taken by `store` back into the model, then drop it."""
        for name, p in self._trainable(model):
            p.data.copy_(self.backup[name].data)
        self.backup = {}

# ----------------------------
# 6) Model
# ----------------------------
class MultibandEventTransformer(nn.Module):
    """Transformer encoder over per-observation tokens, fused with an MLP over global features.

    Token features are projected to `d_model`, summed with a learned band embedding and a
    learned positional embedding, encoded, and pooled (attention + mean + max). The pooled
    vector is concatenated with a projection of the standardised global features and mapped
    to a single logit per sample.

    `set_global_scaler` must be called before `forward`: it provides the mean/std used to
    standardise the global features.
    """

    def __init__(
        self,
        feat_dim, max_len, n_bands=6,
        d_model=160, n_heads=4, n_layers=3, ff_mult=2,
        dropout=0.14, g_dim=0,
        shift_band_ids=False
    ):
        super().__init__()
        self.n_bands = int(n_bands)
        self.d_model = int(d_model)
        self.max_len = int(max_len)
        self.shift_band_ids = bool(shift_band_ids)

        self.x_proj = nn.Sequential(
            nn.Linear(feat_dim, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
        )
        self.x_ln = nn.LayerNorm(d_model)  # stabilizer

        self.band_emb = nn.Embedding(self.n_bands, d_model)

        # learned positional embedding, sliced to the actual sequence length in forward()
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=int(d_model * ff_mult),
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=n_layers)

        self.attn = nn.Linear(d_model, 1)
        self.pool_ln = nn.LayerNorm(d_model)

        g_out = max(32, d_model // 2)
        self.g_proj = nn.Sequential(
            nn.Linear(g_dim, g_out),
            nn.GELU(),
            nn.Dropout(dropout),
        )

        self.head = nn.Sequential(
            nn.Linear(d_model + g_out, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

    def set_global_scaler(self, mean_np, std_np):
        """Register the per-feature mean/std used to standardise global features.

        The buffers are created on the same device as the module's parameters, so
        calling this after `model.to(device)` does not leave them behind on the CPU
        (which would make `forward` fail with a device mismatch on non-CPU devices).
        They are non-persistent, i.e. not part of `state_dict()`.
        """
        dev = next(self.parameters()).device
        mean_t = torch.tensor(mean_np, dtype=torch.float32).to(dev)
        std_t  = torch.tensor(std_np, dtype=torch.float32).to(dev)
        self.register_buffer("g_mean_buf", mean_t, persistent=False)
        self.register_buffer("g_std_buf", std_t, persistent=False)

    def forward(self, X, band_id, mask, G_raw):
        """Return one logit per sample.

        Args:
            X: token features; cast to float32 here.
            band_id: per-token band ids; cast to long and clamped to [0, n_bands-1].
            mask: per-token flag, 1 = real token, 0 = padding.
            G_raw: unscaled global features; standardised with the registered scaler.
        """
        X = X.to(torch.float32)
        band_id = band_id.to(torch.long)
        mask = mask.to(torch.long)

        if self.shift_band_ids:
            # stored ids are shifted by one: real tokens map to id-1, padding maps to 0
            real = (mask == 1)
            if real.any():
                band2 = band_id.clone()
                band2[real] = (band2[real] - 1).clamp(0, self.n_bands - 1)
                band2[~real] = 0
                band_id = band2
            else:
                band_id = torch.zeros_like(band_id)

        # guard the embedding lookup against out-of-range ids
        band_id = band_id.clamp(0, self.n_bands - 1)

        pad_mask = (mask == 0)
        all_pad = pad_mask.all(dim=1)
        if all_pad.any():
            # a fully padded row would leave attention with nothing to attend to;
            # un-mask its first position so every row has at least one visible token
            pad_mask = pad_mask.clone()
            pad_mask[all_pad, 0] = False

        h = self.x_proj(X)
        h = self.x_ln(h)
        h = h + self.band_emb(band_id) + self.pos_emb[:, :X.shape[1], :]

        h = self.encoder(h, src_key_padding_mask=pad_mask)

        # attention pooling
        a = self.attn(h).squeeze(-1)
        a = a.masked_fill(pad_mask, -1e9)
        w = torch.softmax(a, dim=1)
        pooled_attn = torch.sum(h * w.unsqueeze(-1), dim=1)

        # mean pooling
        valid = (~pad_mask).to(h.dtype).unsqueeze(-1)
        denom = valid.sum(dim=1).clamp_min(1.0)
        pooled_mean = (h * valid).sum(dim=1) / denom

        # max pooling
        h_masked = h.masked_fill(pad_mask.unsqueeze(-1), -1e9)
        pooled_max = torch.max(h_masked, dim=1).values
        pooled_max = torch.where(torch.isfinite(pooled_max), pooled_max, torch.zeros_like(pooled_max))

        # fixed-weight blend of the three pooled views
        pooled = (0.50 * pooled_attn) + (0.30 * pooled_mean) + (0.20 * pooled_max)
        pooled = self.pool_ln(pooled)

        # global
        G = G_raw.to(torch.float32)
        G = (G - self.g_mean_buf) / self.g_std_buf
        g = self.g_proj(G)

        z = torch.cat([pooled, g], dim=1)
        return self.head(z).squeeze(-1)

# ----------------------------
# 7) Training config (improved defaults)
# ----------------------------
# Training hyper-parameters for Stage 8. The dict is dumped to JSON below and is also
# stored inside every fold checkpoint.
CFG = {
    "d_model": 160,
    "n_heads": 4,
    "n_layers": 3,
    "ff_mult": 2,
    "dropout": 0.14,

    "batch_size": 16,
    "grad_accum": 2,

    "epochs": 18,
    "lr": 5e-4,
    "weight_decay": 0.02,

    "patience": 6,
    "max_grad_norm": 1.0,

    # imbalance:
    # "pos_weight" recommended default; PR-AUC becomes primary metric
    "balance_mode": "pos_weight",   # pos_weight / sampler / both
    "label_smoothing": 0.02,

    # optional focal (set >0 if the imbalance is extreme; start 1.0~2.0)
    "focal_gamma": 0.0,             # 0.0 = OFF

    "primary_metric": "ap",         # "ap" (PR-AUC) or "auc"
    "scheduler": "onecycle",

    "use_ema": True,
    "ema_decay": 0.999,

    # augmentation (keep moderate on CPU)
    "aug_tokendrop_p": 0.05,
    "aug_value_noise": 0.010,
    "aug_featdrop_p": 0.00,
}

# adapt for long sequences: smaller model and batch, slightly lower lr, a bit more dropout
if L >= 512:
    CFG["d_model"] = 128
    CFG["n_heads"] = 4
    CFG["n_layers"] = 2
    CFG["batch_size"] = 12
    CFG["grad_accum"] = 2
    CFG["lr"] = 4e-4
    CFG["dropout"] = 0.16

# Persist the effective config (after the long-sequence override) for later stages/debugging.
cfg_path = Path(LOG_DIR) / "train_cfg_stage8.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(CFG, f, indent=2)

# Class balance over the full training set, then a run banner.
pos_all = int((y == 1).sum())
neg_all = int((y == 0).sum())
print("[Stage 8] TRAIN CONFIG (CPU)")
print(f"- N={N:,} | pos={pos_all:,} | neg={neg_all:,} | pos%={pos_all/max(N,1)*100:.6f}%")
print(f"- token_mode={SEQ_TOKEN_MODE} | VAL_FEAT={VAL_FEAT} | val_is_mag={VAL_IS_MAG} | g_dim={g_dim} | use_agg_seq={USE_AGG_SEQ_FEATURES}")
print(f"- Model: d_model={CFG['d_model']} heads={CFG['n_heads']} layers={CFG['n_layers']} dropout={CFG['dropout']}")
print(f"- Batch={CFG['batch_size']} grad_accum={CFG['grad_accum']} epochs={CFG['epochs']} lr={CFG['lr']}")
print(f"- balance_mode={CFG['balance_mode']} | focal_gamma={CFG['focal_gamma']} | primary_metric={CFG['primary_metric']}")
print(f"- ema={CFG['use_ema']}({CFG['ema_decay']})")
print(f"- CKPT_DIR={CKPT_DIR}")
print(f"- OOF_DIR ={OOF_DIR}")
print(f"- LOG_DIR ={LOG_DIR}")

# ----------------------------
# 8) Helpers
# ----------------------------
def sigmoid_np(x):
    """Logistic function for NumPy inputs; the argument is clipped to [-50, 50] first
    so that the exponential cannot overflow."""
    z = np.clip(x, -50, 50)
    denom = 1.0 + np.exp(-z)
    return 1.0 / denom

def f1_binary(y_true, y_pred01):
    """F1 score of the positive class (label 1) for 0/1 NumPy arrays.

    Returns 0.0 when there are no true positives (which also covers empty input).
    """
    truth = y_true.astype(np.int32)
    pred = y_pred01.astype(np.int32)
    is_pos = (truth == 1)
    said_pos = (pred == 1)

    tp = int(np.count_nonzero(is_pos & said_pos))
    if tp == 0:
        return 0.0

    fp = int(np.count_nonzero((truth == 0) & said_pos))
    fn = int(np.count_nonzero(is_pos & (pred == 0)))

    # with tp > 0 both precision and recall are strictly positive
    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    return float(2 * prec * rec / (prec + rec))

def best_f1_threshold(y_true, prob, grid=240):
    """Scan `grid` evenly spaced thresholds in [0.01, 0.99] and return `(threshold, f1)`
    for the best F1 under the rule `prob >= threshold`.

    Ties keep the lowest threshold. With an empty grid the defaults `(0.5, -1.0)` are returned.
    """
    labels = y_true.astype(np.int8)
    scores = prob.astype(np.float32)

    candidates = [
        (f1_binary(labels, (scores >= cut).astype(np.int8)), float(cut))
        for cut in np.linspace(0.01, 0.99, grid, dtype=np.float32)
    ]
    if not candidates:
        return 0.5, -1.0

    # max() returns the first maximal element, i.e. the lowest threshold among ties
    top_f1, top_thr = max(candidates, key=lambda pair: pair[0])
    return top_thr, float(top_f1)

@torch.inference_mode()
def eval_model(model, loader, criterion_eval, use_ema=False, ema=None):
    """Evaluate `model` on `loader` and return `(mean_loss, probs, y, f1@0.5, auc, ap)`.

    When `use_ema` is set and an `ema` helper is given, the EMA weights are swapped in
    for the duration of the evaluation. The live weights are restored in a `finally`
    block, so an exception raised while evaluating cannot leave the model stuck on the
    EMA weights.

    AUC and AP are NaN when the evaluated labels contain a single class. The loss is the
    unweighted mean of the per-batch losses (NaN when the loader is empty).
    """
    model.eval()
    swap_ema = bool(use_ema) and (ema is not None)
    if swap_ema:
        ema.store(model)

    losses, logits_all, y_all = [], [], []
    try:
        if swap_ema:
            ema.copy_to(model)

        for batch in loader:
            Xb, Bb, Mb, Gb, yb = batch
            Xb = Xb.to(device); Bb = Bb.to(device); Mb = Mb.to(device); Gb = Gb.to(device); yb = yb.to(device)
            logit = model(Xb, Bb, Mb, Gb)
            loss = criterion_eval(logit, yb)
            losses.append(float(loss.item()))
            logits_all.append(logit.detach().cpu().numpy())
            y_all.append(yb.detach().cpu().numpy())
    finally:
        # always put the live (non-EMA) weights back, even if evaluation failed
        if swap_ema:
            ema.restore(model)

    logits_all = np.concatenate(logits_all, axis=0) if logits_all else np.zeros((0,), dtype=np.float32)
    y_all = np.concatenate(y_all, axis=0).astype(np.int8) if y_all else np.zeros((0,), dtype=np.int8)

    probs = sigmoid_np(logits_all)
    pred01 = (probs >= 0.5).astype(np.int8)

    f1 = f1_binary(y_all, pred01)
    if len(np.unique(y_all)) == 2:
        auc = float(roc_auc_score(y_all, probs))
        ap  = float(average_precision_score(y_all, probs))
    else:
        # ranking metrics are undefined with a single class present
        auc, ap = float("nan"), float("nan")

    return float(np.mean(losses) if losses else np.nan), probs, y_all, f1, auc, ap

def fit_scaler_fold(G_raw_np, tr_idx):
    """Per-column mean and std (both float32) computed on the training rows only.

    Columns whose std is below 1e-6 get std = 1.0 so that standardisation never
    divides by (almost) zero.
    """
    train_rows = G_raw_np[tr_idx]
    col_mean = np.mean(train_rows, axis=0).astype(np.float32)
    col_std = np.std(train_rows, axis=0).astype(np.float32)
    near_constant = col_std < 1e-6
    col_std = np.where(near_constant, 1.0, col_std).astype(np.float32)
    return col_mean, col_std

def make_adamw_param_groups(model, weight_decay):
    """Split trainable parameters into two optimizer groups: decayed, then not decayed.

    A parameter is exempt from weight decay when its name ends with ".bias" or its
    lower-cased name contains "norm" or "ln". Frozen parameters are left out entirely.
    """
    buckets = {False: [], True: []}  # key: is the parameter exempt from decay?
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        lowered = name.lower()
        exempt = name.endswith(".bias") or ("norm" in lowered) or ("ln" in lowered)
        buckets[bool(exempt)].append(param)

    decayed_group = {"params": buckets[False], "weight_decay": float(weight_decay)}
    plain_group = {"params": buckets[True], "weight_decay": 0.0}
    return [decayed_group, plain_group]

# focal with logits (optional) + optional pos_weight already inside BCE
def make_train_loss(pos_weight_t=None, focal_gamma=0.0):
    """Return a `loss(logit, target)` closure: BCE-with-logits, optionally positive-weighted
    and optionally focal-modulated, averaged over all elements.

    With `focal_gamma <= 0` the result is plain (weighted) BCE. Otherwise each element's
    BCE is scaled by `(1 - p_t) ** focal_gamma`, where `p_t` is the predicted probability
    of the side selected by `target > 0.5`.
    """
    gamma = float(focal_gamma)
    # pos_weight=None is BCEWithLogitsLoss's default, so a single constructor call covers both cases
    per_element_bce = nn.BCEWithLogitsLoss(pos_weight=pos_weight_t, reduction="none")

    def _loss(logit, target):
        # target may be soft (label smoothing), i.e. anywhere in [0, 1]
        target = target.to(logit.dtype)
        elementwise = per_element_bce(logit, target)
        if gamma <= 0:
            return elementwise.mean()
        prob = torch.sigmoid(logit)
        p_true = torch.where(target > 0.5, prob, 1.0 - prob)
        modulator = (1.0 - p_true).clamp_min(0.0).pow(gamma)
        return (modulator * elementwise).mean()

    return _loss

# ----------------------------
# 9) Batch augment (safe)
# ----------------------------
def apply_batch_aug(Xb, Bb, Mb, cfg, feat_map, val_feat_name):
    """Apply the train-time augmentations enabled in `cfg` and return `(Xb, Bb, Mb)`.

    Three independent, optional steps (each off when its rate is 0 or missing):
      * token drop  (`aug_tokendrop_p`): clears mask bits of real tokens, always leaving
        at least one real token in every sample that had any;
      * value noise (`aug_value_noise`): adds Gaussian noise to the `val_feat_name`
        channel of real tokens;
      * feature drop (`aug_featdrop_p`): zeroes the value channel and the "snr_tanh"
        channel (whichever exist in `feat_map`) on a random subset of real tokens.

    Inputs are never modified in place; tensors are cloned before being changed.
    `Bb` is passed through untouched.
    """
    drop_rate = float(cfg.get("aug_tokendrop_p", 0.0))
    noise_std = float(cfg.get("aug_value_noise", 0.0))
    fdrop_rate = float(cfg.get("aug_featdrop_p", 0.0))

    # token drop
    if drop_rate > 0:
        is_real = (Mb == 1)
        if is_real.any():
            to_drop = (torch.rand_like(Mb.float()) < drop_rate) & is_real

            # rows that would lose every real token get one of them (picked at random) back
            real_per_row = is_real.sum(dim=1)
            drop_per_row = to_drop.sum(dim=1)
            emptied = (real_per_row > 0) & (drop_per_row >= real_per_row)
            for row in torch.where(emptied)[0].tolist():
                candidates = torch.where(is_real[row])[0]
                pick = int(torch.randint(0, candidates.numel(), (1,)).item())
                to_drop[row, candidates[pick]] = False

            Mb = Mb.clone()
            Mb[to_drop] = 0

    # value noise
    if noise_std > 0 and (val_feat_name in feat_map):
        value_ch = int(feat_map[val_feat_name])
        is_real = (Mb == 1)
        n_real = int(is_real.sum().item())
        if n_real > 0:
            jitter = torch.randn((n_real,), device=Xb.device, dtype=Xb.dtype) * float(noise_std)
            Xb = Xb.clone()
            channel = Xb[:, :, value_ch].clone()
            channel[is_real] = channel[is_real] + jitter
            Xb[:, :, value_ch] = channel

    # feature dropout (optional): only on key features
    if fdrop_rate > 0:
        is_real = (Mb == 1)
        if is_real.any():
            targets = [int(feat_map[nm]) for nm in (val_feat_name, "snr_tanh") if nm in feat_map]
            if targets:
                Xb = Xb.clone()
                # one random mask is drawn and shared by all targeted channels
                zero_here = (torch.rand_like(Mb.float()) < fdrop_rate) & is_real
                for ch in targets:
                    channel = Xb[:, :, ch].clone()
                    channel[zero_here] = 0.0
                    Xb[:, :, ch] = channel

    return Xb, Bb, Mb

# ----------------------------
# 10) CV Train
# ----------------------------
# Out-of-fold probabilities (NaN until a fold writes its validation rows) and per-fold summaries.
oof_prob = np.full((N,), np.nan, dtype=np.float32)
fold_metrics = []

start_time = time.time()
n_splits = int(globals()["n_splits"])
cv_type = str(globals().get("CV_TYPE", ""))

# model selection metric; anything unknown falls back to "ap"
primary_metric = str(CFG.get("primary_metric", "ap")).lower().strip()
if primary_metric not in ("ap", "auc"):
    primary_metric = "ap"

for fold_info in globals()["folds"]:
    fold = int(fold_info.get("fold", 0))
    tr_idx = np.asarray(fold_info["train_idx"], dtype=np.int32)
    val_idx = np.asarray(fold_info["val_idx"], dtype=np.int32)

    y_tr = y[tr_idx]
    pos_f = int((y_tr == 1).sum())
    neg_f = int((y_tr == 0).sum())
    if pos_f == 0:
        raise RuntimeError(f"[Stage 8] Fold {fold}: no positives in training split.")

    balance_mode = str(CFG.get("balance_mode", "pos_weight")).lower().strip()
    use_sampler = balance_mode in ("sampler", "both")
    use_posw    = balance_mode in ("pos_weight", "both")

    pos_weight = float(neg_f / max(pos_f, 1))
    pos_weight_t = torch.tensor([pos_weight], dtype=torch.float32, device=device) if use_posw else None

    ls = float(CFG.get("label_smoothing", 0.0))
    def smooth(yb):
        # symmetric label smoothing: pulls hard 0/1 targets towards 0.5 by ls/2
        if ls <= 0:
            return yb
        return yb * (1.0 - ls) + 0.5 * ls

    # train loss (weighted + optional focal)
    focal_gamma = float(CFG.get("focal_gamma", 0.0))
    loss_train_fn = make_train_loss(pos_weight_t=pos_weight_t, focal_gamma=focal_gamma)
    # eval loss (unweighted BCE for comparability)
    criterion_eval  = nn.BCEWithLogitsLoss()

    print(f"\n[Stage 8] FOLD {fold} | train={len(tr_idx):,} val={len(val_idx):,} "
          f"| pos={pos_f:,} neg={neg_f:,} | pos_weight={pos_weight:.4f} | balance_mode={balance_mode} | primary={primary_metric}")

    # global-feature scaler is fit on the training rows of this fold only (no leakage from val)
    g_mean, g_std = fit_scaler_fold(G_raw_np, tr_idx)

    ds_tr = MemmapSeqDataset(tr_idx, X_mm, B_mm, M_mm, G_raw_np, y=y)
    ds_va = MemmapSeqDataset(val_idx, X_mm, B_mm, M_mm, G_raw_np, y=y)

    sampler = None
    if use_sampler:
        # oversample positives with weight neg/pos; sampling is with replacement
        ytr_local = y[tr_idx]
        w = np.ones((len(tr_idx),), dtype=np.float32)
        w[ytr_local == 1] = float(neg_f / max(pos_f, 1))
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=torch.from_numpy(w),
            num_samples=len(tr_idx),
            replacement=True
        )

    dl_tr = make_loader(ds_tr, batch_size=int(CFG["batch_size"]), shuffle=True, sampler=sampler)
    dl_va = make_loader(ds_va, batch_size=int(CFG["batch_size"]), shuffle=False)

    model = MultibandEventTransformer(
        feat_dim=Fdim,
        max_len=L,
        n_bands=6,
        d_model=int(CFG["d_model"]),
        n_heads=int(CFG["n_heads"]),
        n_layers=int(CFG["n_layers"]),
        ff_mult=int(CFG["ff_mult"]),
        dropout=float(CFG["dropout"]),
        g_dim=g_dim,
        shift_band_ids=bool(SHIFT_BAND_IDS),
    ).to(device)
    model.set_global_scaler(g_mean, g_std)

    param_groups = make_adamw_param_groups(model, weight_decay=float(CFG["weight_decay"]))
    opt = torch.optim.AdamW(param_groups, lr=float(CFG["lr"]))

    scheduler = None
    grad_accum = int(CFG["grad_accum"])
    if str(CFG.get("scheduler","")).lower() == "onecycle":
        # the scheduler is stepped once per optimizer step, i.e. once per grad_accum batches
        # (plus one step for a trailing partial accumulation), hence the ceil
        steps_per_epoch_opt = int(math.ceil(len(dl_tr) / max(grad_accum, 1)))
        steps_per_epoch_opt = max(steps_per_epoch_opt, 1)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=float(CFG["lr"]),
            epochs=int(CFG["epochs"]),
            steps_per_epoch=steps_per_epoch_opt,
            pct_start=0.1,
            anneal_strategy="cos",
            div_factor=10.0,
            final_div_factor=50.0,
        )

    use_ema = bool(CFG.get("use_ema", True))
    ema = EMA(model, decay=float(CFG.get("ema_decay", 0.999))) if use_ema else None

    # best tracking
    best_score = -1e18
    best_val_loss = float("inf")
    best_epoch = -1
    best_probs = None
    best_thr = 0.5
    best_f1_at_bestthr = -1.0
    # AUC / AP measured at the best epoch (reported in fold_metrics)
    best_val_auc = float("nan")
    best_val_ap = float("nan")
    patience_left = int(CFG["patience"])

    # history log
    hist_rows = []

    for epoch in range(1, int(CFG["epochs"]) + 1):
        model.train()
        opt.zero_grad(set_to_none=True)

        total_loss = 0.0
        n_batches = 0
        accum = 0
        opt_steps = 0

        for batch in dl_tr:
            Xb, Bb, Mb, Gb, yb = batch
            Xb = Xb.to(device).to(torch.float32)
            Bb = Bb.to(device)
            Mb = Mb.to(device)
            Gb = Gb.to(device)
            yb = yb.to(device)

            # aug
            Xb, Bb, Mb = apply_batch_aug(Xb, Bb, Mb, CFG, feat, VAL_FEAT)

            yb_s = smooth(yb)
            logit = model(Xb, Bb, Mb, Gb)
            loss = loss_train_fn(logit, yb_s)

            total_loss += float(loss.item())
            n_batches += 1

            # scale so that the accumulated gradient is the mean over the accumulation window
            (loss / float(grad_accum)).backward()
            accum += 1

            if accum == grad_accum:
                if CFG["max_grad_norm"] is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(CFG["max_grad_norm"]))
                opt.step()
                opt.zero_grad(set_to_none=True)
                opt_steps += 1
                accum = 0
                if scheduler is not None:
                    scheduler.step()
                if ema is not None:
                    ema.update(model)

        # flush a trailing partial accumulation window
        if accum > 0:
            if CFG["max_grad_norm"] is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(CFG["max_grad_norm"]))
            opt.step()
            opt.zero_grad(set_to_none=True)
            opt_steps += 1
            if scheduler is not None:
                scheduler.step()
            if ema is not None:
                ema.update(model)

        train_loss = total_loss / max(n_batches, 1)

        val_loss, probs, y_val, f1_05, val_auc, val_ap = eval_model(model, dl_va, criterion_eval, use_ema=use_ema, ema=ema)
        score = float(val_ap) if primary_metric == "ap" else float(val_auc)

        # improve rule: primary metric first, tie-break val_loss
        improved = (score > best_score + 1e-7) or (math.isnan(best_score) and not math.isnan(score))
        if (not improved) and (abs(score - best_score) <= 1e-7) and (val_loss < best_val_loss - 1e-6):
            improved = True

        if improved:
            best_score = float(score)
            best_val_loss = float(val_loss)
            best_epoch = int(epoch)
            best_probs = probs.copy()
            best_val_auc = float(val_auc)
            best_val_ap = float(val_ap)
            best_thr, best_f1_at_bestthr = best_f1_threshold(y_val, best_probs, grid=240)

            ckpt_path = CKPT_DIR / f"fold_{fold}.pt"
            payload = {
                "fold": fold,
                "epoch": epoch,
                "model_state": model.state_dict(),
                "cfg": CFG,
                "seq_feature_names": SEQ_FEATURE_NAMES,
                "max_len": L,
                "token_mode": SEQ_TOKEN_MODE,
                "val_feat": VAL_FEAT,
                "val_is_mag": bool(VAL_IS_MAG),
                "global_meta_cols": BASE_G_COLS,
                "use_agg_seq_features": bool(USE_AGG_SEQ_FEATURES),
                "global_feature_cache": {"path": str(G_RAW_CACHE), "meta": str(G_RAW_META), "agg_hash": agg_hash},
                # the scaler buffers are non-persistent in the model, so they are stored explicitly
                "global_scaler": {"mean": g_mean.astype(np.float32), "std": g_std.astype(np.float32)},
                "pos_weight_train": float(pos_weight),
                "balance_mode": balance_mode,
                "focal_gamma": float(focal_gamma),
                "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
                "pad_band_id_from_stage6": int(PAD_BAND_ID),
                "dtype_X_memmap": str(DTYPE_X_MEMMAP),
                "cv_type": str(cv_type),
                "score_value_feat_from_stage6": (None if SCORE_VALUE_FEAT is None else str(SCORE_VALUE_FEAT)),
                "primary_metric": primary_metric,
                "best_thr_val_f1": float(best_thr),
                "best_f1_val": float(best_f1_at_bestthr),
                "best_val_auc": float(val_auc),
                "best_val_ap": float(val_ap),
                "best_val_loss": float(best_val_loss),
            }
            if ema is not None:
                payload["ema_shadow"] = {k: v.detach().cpu() for k, v in ema.shadow.items()}
                payload["ema_decay"] = float(ema.decay)

            torch.save(payload, ckpt_path)
            patience_left = int(CFG["patience"])
        else:
            patience_left -= 1

        lr_now = opt.param_groups[0]["lr"]
        hist_rows.append({
            "epoch": int(epoch),
            "lr": float(lr_now),
            "opt_steps": int(opt_steps),
            "train_loss": float(train_loss),
            "val_loss": float(val_loss),
            "val_auc": float(val_auc),
            "val_ap": float(val_ap),
            "f1_at_0p5": float(f1_05),
            "score_primary": float(score),
            "best_epoch": int(best_epoch),
            "patience_left": int(patience_left),
        })

        print(f"  epoch {epoch:02d} | lr={lr_now:.2e} | opt_steps={opt_steps:4d} | "
              f"train_loss={train_loss:.5f} | val_loss={val_loss:.5f} | auc={val_auc:.5f} | ap={val_ap:.5f} | "
              f"f1@0.5={f1_05:.4f} | best_ep={best_epoch} | pat={patience_left}")

        # early stopping
        if patience_left <= 0:
            break

    # save history (best-effort: a failed write must not abort training, but is reported)
    try:
        pd.DataFrame(hist_rows).to_csv(Path(LOG_DIR)/f"fold_{fold}_history.csv", index=False)
    except Exception as e:
        print(f"[Stage 8] WARN: could not save history for fold {fold}: {e!r}")

    # NOTE: if the primary metric is NaN in every epoch (e.g. a single-class validation
    # split), no epoch ever counts as an improvement and this error is raised.
    if best_probs is None:
        raise RuntimeError(f"Fold {fold}: best_probs is None (unexpected).")

    oof_prob[val_idx] = best_probs.astype(np.float32)

    pred01 = (best_probs >= 0.5).astype(np.int8)
    best_f1_05 = f1_binary(y[val_idx], pred01)

    fold_metrics.append({
        "fold": fold,
        "val_size": int(len(val_idx)),
        "best_epoch": int(best_epoch),
        "primary_metric": primary_metric,
        "best_primary_score": float(best_score),
        "best_val_loss": float(best_val_loss),
        # values captured at the best epoch, not whatever the last epoch produced
        "val_auc_at_best": float(best_val_auc),
        "val_ap_at_best": float(best_val_ap),
        "f1_at_0p5": float(best_f1_05),
        "best_thr_val_f1": float(best_thr),
        "best_f1_val": float(best_f1_at_bestthr),
        "pos_weight_train": float(pos_weight),
        "focal_gamma": float(focal_gamma),
        "g_dim": int(g_dim),
        "use_agg_seq": bool(USE_AGG_SEQ_FEATURES),
        "balance_mode": balance_mode,
        "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
        "ema_used": bool(use_ema),
        "val_feat": str(VAL_FEAT),
        "val_is_mag": bool(VAL_IS_MAG),
    })

    # release per-fold objects before the next fold is built
    del model, opt, ds_tr, ds_va, dl_tr, dl_va, ema
    gc.collect()

elapsed = time.time() - start_time

# ----------------------------
# 11) Save OOF artifacts + summary
# ----------------------------
# Persist the out-of-fold probabilities both as a raw array and as an id-keyed CSV.
oof_path_npy = OOF_DIR / "oof_prob.npy"
np.save(oof_path_npy, oof_prob)

# train_ids is defined earlier in the file; presumably aligned row-by-row with y/oof_prob — verify.
df_oof = pd.DataFrame({"object_id": train_ids, "target": y.astype(int), "oof_prob": oof_prob.astype(np.float32)})
oof_path_csv = OOF_DIR / "oof_prob.csv"
df_oof.to_csv(oof_path_csv, index=False)

metrics_path = OOF_DIR / "fold_metrics.json"
with open(metrics_path, "w", encoding="utf-8") as f:
    json.dump({"fold_metrics": fold_metrics, "elapsed_sec": float(elapsed), "cv_type": str(cv_type)}, f, indent=2)

# Rows that never received a prediction are still NaN and are excluded from the summary.
valid = np.isfinite(oof_prob)
if valid.any() and len(np.unique(y[valid])) == 2:
    oof_auc = float(roc_auc_score(y[valid], oof_prob[valid]))
    oof_ap  = float(average_precision_score(y[valid], oof_prob[valid]))
else:
    # ranking metrics need both classes among the valid rows
    oof_auc = float("nan")
    oof_ap  = float("nan")

if valid.any():
    oof_pred01 = (oof_prob[valid] >= 0.5).astype(np.int8)
    oof_f1_05 = f1_binary(y[valid], oof_pred01)
    oof_thr, oof_bestf1 = best_f1_threshold(y[valid], oof_prob[valid], grid=300)
else:
    oof_f1_05 = float("nan")
    oof_thr, oof_bestf1 = float("nan"), float("nan")

print("\n[Stage 8] TRAIN DONE")
print(f"- elapsed: {elapsed/60:.2f} min")
print(f"- OOF saved: {oof_path_npy}")
print(f"- OOF saved: {oof_path_csv}")
print(f"- fold metrics: {metrics_path}")
print(f"- OOF rows valid: {int(valid.sum()):,}/{N:,}")
print(f"- OOF AUC (valid-only): {oof_auc:.5f}")
print(f"- OOF AP  (valid-only): {oof_ap:.5f}")
print(f"- OOF F1@0.5 (valid-only): {oof_f1_05:.4f}")
print(f"- OOF best F1={oof_bestf1:.4f} @ thr={oof_thr:.4f}")

# Export results into the notebook's global namespace for later cells.
# OOF_BEST_THR_F1 is None when no finite threshold could be computed.
globals().update({
    "oof_prob": oof_prob,
    "OOF_PROB_PATH": oof_path_npy,
    "OOF_CSV_PATH": oof_path_csv,
    "FOLD_METRICS_PATH": metrics_path,
    "TRAIN_CFG_PATH": cfg_path,
    "VAL_FEAT": VAL_FEAT,
    "VAL_IS_MAG": VAL_IS_MAG,
    "OOF_BEST_THR_F1": float(oof_thr) if np.isfinite(oof_thr) else None,
})

gc.collect()


# OOF Prediction + Threshold Tuning

In [None]:
# ============================================================
# STAGE 9 — OOF Prediction + Threshold Tuning (ONE CELL)
# REVISI FULL v4.4 (EXACT unique-level sweep + add F-beta + constraints + better tie-break)
#
# Upgrades v4.4:
# - Exact sweep on UNIQUE prob levels (more accurate than mixed candidate grids)
# - Still supports rule="ge" (>=) and rule="gt" (>)
# - Provides ge_equiv thresholds for gt (via nextafter) for downstream that uses >=
# - Adds F0.5 and F2 (precision/recall emphasis)
# - Optional constraints: best thr with min_precision / min_recall
# - Improved tie-break: after metric ties, prefer smaller |pos_pred - pos| gap
#
# Output:
# - OOF_DIR/threshold_tuning.json
# - OOF_DIR/threshold_report.txt
# - OOF_DIR/threshold_table_top1000.csv
# - globals: BEST_THR, BEST_THR_GE_F1, BEST_THR_GE_MCC, ... + tables
# ============================================================

import gc, json, math, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Fail fast when the globals this cell relies on have not been created by earlier cells.
need = ["OOF_DIR", "df_train_meta"]
for k in need:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..8 dulu.")

# Normalise to a Path and make sure the output directory exists.
OOF_DIR = Path(OOF_DIR)
OOF_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Helper: robust stringify id
# ----------------------------
def _to_str_list(ids):
    """Convert an iterable of ids to a list of whitespace-stripped strings.

    Byte-like ids (bytes, numpy bytes, bytearray) are decoded as UTF-8 with
    undecodable bytes dropped; everything else goes through `str()`.
    """
    byte_kinds = (bytes, np.bytes_, bytearray)
    return [
        (x.decode("utf-8", errors="ignore") if isinstance(x, byte_kinds) else str(x)).strip()
        for x in ids
    ]

def _as_1d_float32(arr):
    """Coerce `arr` to a one-dimensional float32 NumPy array.

    A 0-d object array is unwrapped first (best effort), scalars become
    length-1 arrays, and anything with more than one dimension is flattened.
    """
    raw = np.asarray(arr)
    if raw.ndim == 0 and raw.dtype == object:
        # e.g. a list that got wrapped into a 0-d object array
        try:
            raw = np.asarray(raw.item())
        except Exception:
            pass

    flat = np.asarray(raw, dtype=np.float32)
    if flat.ndim == 0:
        return flat.reshape(1)
    return flat.reshape(-1) if flat.ndim > 1 else flat

# ----------------------------
# Detect target column in df_train_meta
# ----------------------------
def _detect_target_col(df):
    for cand in ["target","y","label","class","is_tde","binary_target","target_id"]:
        if cand in df.columns:
            return cand
    return None

# Resolve the label column; abort with a sample of the available columns when none matches.
TARGET_COL = _detect_target_col(df_train_meta)
if TARGET_COL is None:
    raise RuntimeError(
        "Cannot detect target column in df_train_meta. "
        f"Columns sample: {list(df_train_meta.columns)[:60]}"
    )

# Work on a copy whose index is stripped plain-str ids named "object_id",
# so the id lookups below compare strings with strings.
meta = df_train_meta.copy()
meta.index = pd.Index(pd.Index(meta.index).astype("string").str.strip().astype(str), name="object_id")

# ----------------------------
# Load OOF (prefer CSV)
# ----------------------------
def _load_oof():
    """Locate the out-of-fold probabilities together with their object ids.

    Sources are tried in order: ``OOF_DIR/oof_prob.csv`` (must contain the
    columns ``object_id`` and ``oof_prob``), the in-memory ``oof_prob``
    global, then ``OOF_DIR/oof_prob.npy``. The last two carry no ids, so ids
    come from the ``train_ids_ordered`` global when it is set, otherwise from
    ``meta.index`` when the lengths match.

    Returns:
        ``(ids, prob, source_label)`` where ``ids`` is a list of str, ``prob``
        a 1-D float32 array and ``source_label`` names the source used.

    Raises:
        FileNotFoundError: if no source yields a usable result.
    """
    def _resolve_ids(prob_arr, label_ordered, label_meta):
        # Prefer the explicit training order; the metadata index is only a
        # fallback and only when its length agrees with the probabilities.
        ordered = globals().get("train_ids_ordered")
        if ordered is not None:
            return _to_str_list(list(ordered)), label_ordered
        if len(prob_arr) == len(meta):
            return _to_str_list(meta.index.tolist()), label_meta
        return None

    csv_path = OOF_DIR / "oof_prob.csv"
    if csv_path.exists():
        frame = pd.read_csv(csv_path)
        if all(col in frame.columns for col in ("object_id", "oof_prob")):
            csv_ids = frame["object_id"].astype(str).str.strip().tolist()
            csv_prob = _as_1d_float32(frame["oof_prob"].to_numpy())
            return csv_ids, csv_prob, "csv(oof_prob.csv)"

    if "oof_prob" in globals():
        mem_prob = _as_1d_float32(globals()["oof_prob"])
        if isinstance(mem_prob, np.ndarray) and mem_prob.ndim == 1 and len(mem_prob) > 0:
            found = _resolve_ids(
                mem_prob,
                "globals(oof_prob + train_ids_ordered)",
                "globals(oof_prob + df_train_meta.index)",
            )
            if found is not None:
                return found[0], mem_prob, found[1]

    npy_path = OOF_DIR / "oof_prob.npy"
    if npy_path.exists():
        disk_prob = _as_1d_float32(np.load(npy_path, allow_pickle=False))
        found = _resolve_ids(
            disk_prob,
            "npy(oof_prob.npy + train_ids_ordered)",
            "npy(oof_prob.npy + df_train_meta.index)",
        )
        if found is not None:
            return found[0], disk_prob, found[1]

    raise FileNotFoundError("OOF prob not found (csv/globals/npy). Jalankan STAGE 8 dulu.")

# Load the OOF probabilities and their ids; `src` records which source was used.
train_ids, oof_prob, src = _load_oof()
oof_prob = _as_1d_float32(oof_prob).astype(np.float32)

if len(train_ids) != len(oof_prob):
    raise RuntimeError(f"OOF length mismatch: len(train_ids)={len(train_ids)} vs len(oof_prob)={len(oof_prob)}")

# IMPORTANT: non-finite probabilities are kept here and masked out via `valid`
# (do NOT nan_to_num them to 0), so rows without a prediction never enter the tuning.
valid = np.isfinite(oof_prob)
if not valid.any():
    raise RuntimeError("All oof_prob are non-finite (NaN/inf). Check STAGE 8 output.")

# Map each OOF id to its row position in `meta`; get_indexer returns -1 for unknown ids.
idx = meta.index.get_indexer(train_ids)
if (idx < 0).any():
    bad = [train_ids[i] for i in np.where(idx < 0)[0][:10]]
    raise KeyError(
        f"OOF ids not found in df_train_meta.index (string-normalized). ex={bad} | missing_n={int((idx<0).sum())}"
    )

# Labels: values that cannot be parsed as numbers become 0, then anything > 0 counts as positive.
y_raw = pd.to_numeric(meta[TARGET_COL], errors="coerce").fillna(0).astype(np.int16).to_numpy(copy=False)
y_all = (y_raw[idx] > 0).astype(np.int8)

# Keep only rows with a finite probability; probabilities are clipped into [0, 1].
train_ids_v = [train_ids[i] for i in np.where(valid)[0]]
p_v = np.clip(oof_prob[valid].astype(np.float32), 0.0, 1.0)
y_v = y_all[valid].astype(np.int8)

# Row and class counts (N_all: every OOF row, N: rows with a finite probability).
N_all = int(len(y_all))
N = int(len(y_v))
pos = int((y_v == 1).sum())
neg = int((y_v == 0).sum())

print(f"[Stage 9] Loaded OOF from: {src}")
print(f"[Stage 9] Valid rows: {N:,}/{N_all:,} (holdout mode => valid << total)")
print(f"[Stage 9] pos={pos:,} | neg={neg:,} | pos%={pos/max(N,1)*100:.6f}% | target_col={TARGET_COL}")

# Sanity check that the labels are binary after the > 0 binarisation above.
uy = set(np.unique(y_v).tolist())
if not uy.issubset({0, 1}):
    raise ValueError(f"y must be binary 0/1. Found: {sorted(list(uy))}")

# Threshold-free metrics; NaN when only one class is present, N <= 1,
# or sklearn cannot be imported / raises.
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    auc_oof = float(roc_auc_score(y_v, p_v)) if (len(uy) == 2 and N > 1) else float("nan")
    ap_oof  = float(average_precision_score(y_v, p_v)) if (len(uy) == 2 and N > 1) else float("nan")
except Exception:
    auc_oof = float("nan")
    ap_oof  = float("nan")

# ----------------------------
# 1) Metric helpers (vectorized)
# ----------------------------
def _safe_div(a, b):
    return a / np.maximum(b, 1e-12)

def _fbeta(prec, rec, beta):
    beta2 = float(beta) ** 2
    return _safe_div((1.0 + beta2) * prec * rec, beta2 * prec + rec)

def _metrics_from_counts(tp, fp, fn, tn):
    tp = tp.astype(np.float64); fp = fp.astype(np.float64)
    fn = fn.astype(np.float64); tn = tn.astype(np.float64)

    prec = _safe_div(tp, tp + fp)
    rec  = _safe_div(tp, tp + fn)

    f1   = _safe_div(2 * prec * rec, prec + rec)
    f05  = _fbeta(prec, rec, beta=0.5)
    f2   = _fbeta(prec, rec, beta=2.0)

    acc  = _safe_div(tp + tn, tp + fp + fn + tn)

    tpr  = _safe_div(tp, tp + fn)
    tnr  = _safe_div(tn, tn + fp)
    bacc = 0.5 * (tpr + tnr)

    num = tp * tn - fp * fn
    den = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    mcc = np.where(den > 0, num / np.sqrt(den), 0.0)

    return f1, f05, f2, prec, rec, acc, bacc, mcc

# ----------------------------
# 2) Exact sweep on UNIQUE probability levels (fast + exact)
# ----------------------------
p = p_v.astype(np.float32)
y = y_v.astype(np.int8)

# Sort rows by probability, highest first; mergesort is stable, so ties keep their input order.
ord_desc = np.argsort(-p, kind="mergesort")  # stable
p_sorted = p[ord_desc]
y_sorted = y[ord_desc]

# pos_prefix[i] / neg_prefix[i] = number of positives / negatives among the top (i+1) rows,
# i.e. TP / FP when exactly the top (i+1) rows are predicted positive.
pos_prefix = np.cumsum(y_sorted == 1).astype(np.int64)
neg_prefix = np.cumsum(y_sorted == 0).astype(np.int64)

pos_total = int(pos_prefix[-1]) if N > 0 else 0
neg_total = int(neg_prefix[-1]) if N > 0 else 0

# unique groups in descending sorted probs
# group boundaries where p changes
if N == 0:
    raise RuntimeError("No valid rows to tune thresholds.")

# chg[i] is True where row i starts a new run of equal probabilities (row 0 always does).
chg = np.ones((N,), dtype=bool)
chg[1:] = (p_sorted[1:] != p_sorted[:-1])

# One entry per distinct probability level, in descending order.
grp_starts = np.where(chg)[0].astype(np.int64)
grp_vals = p_sorted[grp_starts].astype(np.float32)

# ends are starts next -1, last ends at N-1
grp_ends = np.empty_like(grp_starts)
grp_ends[:-1] = grp_starts[1:] - 1
grp_ends[-1] = N - 1

# for ge: pred positive includes current group => k = end+1
k_ge = (grp_ends + 1).astype(np.int64)
# for gt: pred positive excludes equals => k = start (count of > thr)
k_gt = grp_starts.astype(np.int64)

# also add edge thresholds:
# - for ge: thr just above max => k=0
# - for gt: thr below min => k=N
# NOTE(review): nextafter steps toward 1.0, so when the max prob is already 1.0 this
# returns 1.0 itself rather than a value above the max.
thr_ge_edge = np.nextafter(np.float32(grp_vals[0]), np.float32(1.0))  # >max (if max<1)
# NOTE(review): thr_gt_edge does not appear to be referenced again in the visible code;
# the gt edge row further down is built from a literal 0.0.
thr_gt_edge = np.float32(0.0)  # <=0 means almost all > thr depending (gt uses >)

# build table function from k and thr
def _table_from_k(thr_vals, k_vals, rule):
    """Build one metrics row per threshold from "top-k predicted positive" counts.

    Args:
        thr_vals: array of threshold values (stored as float32).
        k_vals: array with, per threshold, how many of the highest-probability
            rows are predicted positive; clipped into ``[0, N]``.
        rule: label written to the ``rule`` column (``"ge"`` or ``"gt"``).

    Returns:
        DataFrame with threshold, rule, the eight metrics, the confusion
        counts, ``pos_pred`` and ``pos_pred_gap`` (= |pos_pred - pos_total|).
    """
    n_pred = np.clip(k_vals, 0, N).astype(np.int64)
    has_pred = n_pred > 0
    last = n_pred - 1  # -1 when nothing is predicted; masked out by has_pred below

    tp = np.where(has_pred, pos_prefix[last], 0).astype(np.int64)
    fp = np.where(has_pred, neg_prefix[last], 0).astype(np.int64)
    fn = (pos_total - tp).astype(np.int64)
    tn = (neg_total - fp).astype(np.int64)

    metric_names = ("f1", "f05", "f2", "precision", "recall", "accuracy", "balanced_accuracy", "mcc")
    metric_vals = _metrics_from_counts(tp, fp, fn, tn)

    columns = {"thr": thr_vals.astype(np.float32), "rule": rule}
    for col_name, col_vals in zip(metric_names, metric_vals):
        columns[col_name] = col_vals.astype(np.float32)

    columns["tp"] = tp.astype(np.int64)
    columns["fp"] = fp.astype(np.int64)
    columns["fn"] = fn.astype(np.int64)
    columns["tn"] = tn.astype(np.int64)
    columns["pos_pred"] = n_pred.astype(np.int64)
    columns["pos_pred_gap"] = np.abs(n_pred - pos_total).astype(np.int64)
    return pd.DataFrame(columns)

# base tables (unique)
# Base tables: one row per distinct probability level.
thr_table_ge = _table_from_k(grp_vals, k_ge, "ge")
thr_table_gt = _table_from_k(grp_vals, k_gt, "gt")

# ge edge row ("predict nothing", k=0): the threshold must be STRICTLY above the max prob.
# Step toward +inf, not toward 1.0: nextafter(1.0, 1.0) == 1.0, so with a max prob of 1.0
# the edge row used to share thr=1.0 with the real top level, and the keep="first" de-dup
# below kept the edge row (k=0, wrong counts) and dropped the real one.
thr_ge_edge = np.nextafter(np.float32(grp_vals[0]), np.float32(np.inf))
edge_ge = _table_from_k(np.array([thr_ge_edge], np.float32), np.array([0], np.int64), "ge")
thr_table_ge = pd.concat([edge_ge, thr_table_ge], axis=0, ignore_index=True)

# gt edge row ("predict everything", k=N) stored at thr=0.0. This is exact when every prob
# is > 0. When some prob equals 0.0 the sweep already has a real thr=0.0 row; it comes first
# in the concat, so the de-dup below keeps it and discards this edge row.
edge_gt = _table_from_k(np.array([np.float32(0.0)], np.float32), np.array([N], np.int64), "gt")
thr_table_gt = pd.concat([thr_table_gt, edge_gt], axis=0, ignore_index=True)

# De-duplicate thresholds per rule, keeping the first occurrence in concat order.
thr_table_ge = thr_table_ge.drop_duplicates(subset=["thr","rule"], keep="first").reset_index(drop=True)
thr_table_gt = thr_table_gt.drop_duplicates(subset=["thr","rule"], keep="first").reset_index(drop=True)

# ----------------------------
# 3) Add extra candidate thresholds (prevalence & Stage8 best)
# ----------------------------
# prevalence-match (roughly pos_pred ~ pos for ge)
# Prevalence-match threshold: the probability of the pos_total-th highest-scored row,
# so `prob >= thr_prev` flags about pos_total rows (ties at that level can add more).
if pos_total > 0:
    thr_prev = float(p_sorted[min(pos_total - 1, N - 1)])
else:
    thr_prev = 1.0

# Fixed candidates evaluated in addition to the exact sweep.
extra_thr = [0.5, thr_prev, 0.0, 1.0]
# Optionally add the OOF_BEST_THR_F1 global; a value that cannot be
# converted to float is skipped on purpose.
if "OOF_BEST_THR_F1" in globals() and globals()["OOF_BEST_THR_F1"] is not None:
    try:
        extra_thr.append(float(globals()["OOF_BEST_THR_F1"]))
    except Exception:
        pass

def _eval_specific(thr, rule):
    """Evaluate a single threshold against the sorted OOF probabilities.

    Args:
        thr: threshold; compared after casting to float32.
        rule: ``"ge"`` counts ``prob >= thr`` as positive; any other value
            counts ``prob > thr``.

    Returns:
        dict with the same keys as the columns produced by ``_table_from_k``,
        holding plain Python floats and ints.
    """
    thr = float(thr)
    # -p_sorted is ascending, so searchsorted on the negated threshold yields
    # count(prob >= thr) with side="right" and count(prob > thr) with side="left".
    side = "right" if rule == "ge" else "left"
    n_pred = int(np.searchsorted(-p_sorted, -np.float32(thr), side=side))
    n_pred = max(0, min(n_pred, N))

    if n_pred > 0:
        tp0 = int(pos_prefix[n_pred - 1])
        fp0 = int(neg_prefix[n_pred - 1])
    else:
        tp0 = fp0 = 0
    fn0 = int(pos_total - tp0)
    tn0 = int(neg_total - fp0)

    prec0 = tp0 / max(tp0 + fp0, 1)
    rec0 = tp0 / max(tp0 + fn0, 1)
    tnr0 = tn0 / max(tn0 + fp0, 1)

    def _fb(beta2):
        # F-beta with beta^2 = beta2; 0.0 when precision and recall are both zero.
        if prec0 == 0 and rec0 == 0:
            return 0.0
        return ((1.0 + beta2) * prec0 * rec0) / max(beta2 * prec0 + rec0, 1e-12)

    if tp0 == 0 or (prec0 + rec0) == 0:
        f1_val = 0.0
    else:
        f1_val = 2 * prec0 * rec0 / (prec0 + rec0)

    acc0 = (tp0 + tn0) / max(tp0 + fp0 + fn0 + tn0, 1)
    bacc0 = 0.5 * (rec0 + tnr0)

    den0 = (tp0 + fp0) * (tp0 + fn0) * (tn0 + fp0) * (tn0 + fn0)
    mcc0 = ((tp0 * tn0 - fp0 * fn0) / math.sqrt(den0)) if den0 > 0 else 0.0

    result = {"thr": thr, "rule": rule}
    result["f1"] = float(f1_val)
    result["f05"] = float(_fb(0.25))
    result["f2"] = float(_fb(4.0))
    result["precision"] = float(prec0)
    result["recall"] = float(rec0)
    result["accuracy"] = float(acc0)
    result["balanced_accuracy"] = float(bacc0)
    result["mcc"] = float(mcc0)
    result["tp"] = tp0
    result["fp"] = fp0
    result["fn"] = fn0
    result["tn"] = tn0
    result["pos_pred"] = int(n_pred)
    result["pos_pred_gap"] = int(abs(n_pred - pos_total))
    return result

# Evaluate every extra candidate under both rules.
extra_rows = []
for t in extra_thr:
    extra_rows.append(_eval_specific(t, "ge"))
    extra_rows.append(_eval_specific(t, "gt"))

extra_df = pd.DataFrame(extra_rows)
# Append the extra rows after the sweep rows, then de-duplicate on (thr, rule).
# keep="first" means a row already present in the sweep table wins over the
# extra row with the same threshold.
thr_table_ge = pd.concat([thr_table_ge, extra_df[extra_df["rule"]=="ge"]], ignore_index=True)
thr_table_gt = pd.concat([thr_table_gt, extra_df[extra_df["rule"]=="gt"]], ignore_index=True)
thr_table_ge = thr_table_ge.drop_duplicates(subset=["thr","rule"], keep="first").reset_index(drop=True)
thr_table_gt = thr_table_gt.drop_duplicates(subset=["thr","rule"], keep="first").reset_index(drop=True)

# ----------------------------
# 4) Best pickers + constraints
# ----------------------------
def _pick_best(df, primary, tie_cols=None):
    # primary desc, tie cols desc, last: pos_pred_gap asc (prefer prevalence match)
    tie_cols = tie_cols or []
    sort_cols = [primary] + tie_cols + ["pos_pred_gap"]
    asc = [False] * (1 + len(tie_cols)) + [True]
    return df.sort_values(sort_cols, ascending=asc).iloc[0]

def _pick_best_constrained(df, primary, min_precision=None, min_recall=None):
    """Pick the best row on *primary* among rows meeting optional floors.

    Rows with ``precision`` below *min_precision* or ``recall`` below
    *min_recall* are discarded first (a floor of ``None`` is not applied).

    Returns:
        The winning row, or ``None`` when no row satisfies the constraints.
    """
    eligible = df.copy()
    floors = (("precision", min_precision), ("recall", min_recall))
    for column, floor in floors:
        if floor is not None:
            eligible = eligible[eligible[column] >= float(floor)]
    if len(eligible) == 0:
        return None
    return _pick_best(eligible, primary, tie_cols=["mcc","balanced_accuracy","recall","precision"])

# baseline
# Baseline metrics at the conventional 0.5 threshold, for comparison in the report.
base05_ge = _eval_specific(0.5, "ge")
base05_gt = _eval_specific(0.5, "gt")

# Best row per metric for rule "ge" (prob >= thr). Each call lists its own tie-break columns.
best_ge_f1   = _pick_best(thr_table_ge, "f1",  ["mcc","balanced_accuracy","recall","precision"])
best_ge_f05  = _pick_best(thr_table_ge, "f05", ["mcc","balanced_accuracy","precision","recall"])
best_ge_f2   = _pick_best(thr_table_ge, "f2",  ["mcc","balanced_accuracy","recall","precision"])
best_ge_mcc  = _pick_best(thr_table_ge, "mcc", ["f1","balanced_accuracy","accuracy"])
best_ge_bacc = _pick_best(thr_table_ge, "balanced_accuracy", ["mcc","accuracy","f1"])

# Same selections for rule "gt" (prob > thr).
best_gt_f1   = _pick_best(thr_table_gt, "f1",  ["mcc","balanced_accuracy","recall","precision"])
best_gt_f05  = _pick_best(thr_table_gt, "f05", ["mcc","balanced_accuracy","precision","recall"])
best_gt_f2   = _pick_best(thr_table_gt, "f2",  ["mcc","balanced_accuracy","recall","precision"])
best_gt_mcc  = _pick_best(thr_table_gt, "mcc", ["f1","balanced_accuracy","accuracy"])
best_gt_bacc = _pick_best(thr_table_gt, "balanced_accuracy", ["mcc","accuracy","f1"])

# Optional precision/recall floors for the constrained F1 pick; None disables a floor.
# With both set to None the constrained pick considers every row of the ge table.
MIN_PREC = None   # e.g. 0.80
MIN_REC  = None   # e.g. 0.30
best_ge_f1_con = _pick_best_constrained(thr_table_ge, "f1", min_precision=MIN_PREC, min_recall=MIN_REC)

# gt -> ge equivalent threshold (so downstream can still do prob >= thr)
def _gt_to_ge_equiv(thr_gt):
    thr_gt = float(thr_gt)
    thr_ge = float(np.nextafter(np.float32(thr_gt), np.float32(1.0)))
    return float(min(max(thr_ge, 0.0), 1.0))

# Threshold of each winning row, rule "ge" (use as: pred = prob >= thr).
BEST_THR_GE_F1   = float(best_ge_f1["thr"])
BEST_THR_GE_F05  = float(best_ge_f05["thr"])
BEST_THR_GE_F2   = float(best_ge_f2["thr"])
BEST_THR_GE_MCC  = float(best_ge_mcc["thr"])
BEST_THR_GE_BACC = float(best_ge_bacc["thr"])

# Threshold of each winning row, rule "gt" (use as: pred = prob > thr).
BEST_THR_GT_F1   = float(best_gt_f1["thr"])
BEST_THR_GT_F05  = float(best_gt_f05["thr"])
BEST_THR_GT_F2   = float(best_gt_f2["thr"])
BEST_THR_GT_MCC  = float(best_gt_mcc["thr"])
BEST_THR_GT_BACC = float(best_gt_bacc["thr"])

# ">=" equivalents of the gt thresholds; only F1, MCC and BACC are converted here.
BEST_THR_GT_F1_GE_EQUIV   = _gt_to_ge_equiv(BEST_THR_GT_F1)
BEST_THR_GT_MCC_GE_EQUIV  = _gt_to_ge_equiv(BEST_THR_GT_MCC)
BEST_THR_GT_BACC_GE_EQUIV = _gt_to_ge_equiv(BEST_THR_GT_BACC)

# default threshold choice (keep F1 + GE)
BEST_THR = BEST_THR_GE_F1

# ----------------------------
# 5) Save artifacts
# ----------------------------
# Stack both rule tables so the CSV ranks ge and gt thresholds together.
thr_table_all = pd.concat([thr_table_ge, thr_table_gt], axis=0, ignore_index=True)

# Top 1000 rows ranked by F1, with the same tie-break order as the F1 picks above.
top1000 = thr_table_all.sort_values(
    ["f1","mcc","balanced_accuracy","recall","precision","pos_pred_gap"],
    ascending=[False, False, False, False, False, True]
).head(1000).reset_index(drop=True)

out_json = OOF_DIR / "threshold_tuning.json"
out_txt  = OOF_DIR / "threshold_report.txt"
out_csv  = OOF_DIR / "threshold_table_top1000.csv"
top1000.to_csv(out_csv, index=False)

# Summary written to threshold_tuning.json: data counts, threshold-free metrics,
# the 0.5 baseline, and the winning row per metric for each rule.
payload = {
    "version": "v4.4",
    "source": src,
    "target_col": TARGET_COL,
    "n_total_rows": int(N_all),
    "n_valid_rows": int(N),
    "pos_valid": int(pos),
    "neg_valid": int(neg),
    "pos_rate_valid": float(pos / max(N, 1)),
    "oof_auc_valid_only": float(auc_oof),
    "oof_ap_valid_only": float(ap_oof),
    "prevalence_match_thr_ge": float(thr_prev),
    "baseline_thr_0p5": {"ge": base05_ge, "gt": base05_gt},
    "best_ge": {
        "best_thr_f1":   best_ge_f1.to_dict(),
        "best_thr_f05":  best_ge_f05.to_dict(),
        "best_thr_f2":   best_ge_f2.to_dict(),
        "best_thr_mcc":  best_ge_mcc.to_dict(),
        "best_thr_bacc": best_ge_bacc.to_dict(),
        "best_thr_f1_constrained": (None if best_ge_f1_con is None else best_ge_f1_con.to_dict()),
        "constraints": {"min_precision": MIN_PREC, "min_recall": MIN_REC},
    },
    "best_gt": {
        "best_thr_f1":   best_gt_f1.to_dict(),
        "best_thr_f05":  best_gt_f05.to_dict(),
        "best_thr_f2":   best_gt_f2.to_dict(),
        "best_thr_mcc":  best_gt_mcc.to_dict(),
        "best_thr_bacc": best_gt_bacc.to_dict(),
        "ge_equiv_for_downstream_using_ge": {
            "f1": float(BEST_THR_GT_F1_GE_EQUIV),
            "mcc": float(BEST_THR_GT_MCC_GE_EQUIV),
            "bacc": float(BEST_THR_GT_BACC_GE_EQUIV),
        },
    },
    "default_best_thr": {"metric": "f1", "rule": "ge", "thr": float(BEST_THR)},
}

# NOTE(review): auc_oof / ap_oof may be NaN (see above); json.dump then emits the
# non-standard token NaN — confirm downstream JSON readers accept it.
with open(out_json, "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)

def _fmt_row(d):
    return (f"thr={d['thr']:.6f} | F1={d['f1']:.6f} | F0.5={d['f05']:.6f} | F2={d['f2']:.6f} | "
            f"P={d['precision']:.6f} R={d['recall']:.6f} | BACC={d['balanced_accuracy']:.6f} | MCC={d['mcc']:.6f} | "
            f"pos_pred={int(d['pos_pred'])} (gap={int(d['pos_pred_gap'])})")

# Build the human-readable report line by line, then write it in one go.
lines = []
lines.append("OOF Threshold Tuning Report (v4.4)")
lines.append(f"- source={src}")
lines.append(f"- target_col={TARGET_COL}")
lines.append(f"- total_rows={N_all} | valid_rows={N} | pos_valid={pos} | neg_valid={neg} | pos%={pos/max(N,1)*100:.6f}%")
lines.append(f"- OOF AUC (valid-only) = {auc_oof:.6f}")
lines.append(f"- OOF AP  (valid-only) = {ap_oof:.6f}")
lines.append(f"- prevalence-match thr (ge) ~ {thr_prev:.6f}")
lines.append("")
lines.append("Baseline @ thr=0.5")
lines.append(f"- rule=ge: {_fmt_row(base05_ge)}")
lines.append(f"- rule=gt: {_fmt_row(base05_gt)}")
lines.append("")
lines.append("BEST (rule=ge)  [downstream default: pred = prob >= thr]")
lines.append(f"- BEST-F1   : {_fmt_row(best_ge_f1.to_dict())}")
lines.append(f"- BEST-F0.5 : {_fmt_row(best_ge_f05.to_dict())}")
lines.append(f"- BEST-F2   : {_fmt_row(best_ge_f2.to_dict())}")
lines.append(f"- BEST-MCC  : {_fmt_row(best_ge_mcc.to_dict())}")
lines.append(f"- BEST-BACC : {_fmt_row(best_ge_bacc.to_dict())}")
# The constrained pick is only reported when some row satisfied the constraints.
if best_ge_f1_con is not None:
    lines.append("")
    lines.append(f"BEST-F1 constrained (minP={MIN_PREC}, minR={MIN_REC}): {_fmt_row(best_ge_f1_con.to_dict())}")
lines.append("")
lines.append("BEST (rule=gt)  [strict '>' boundary]")
lines.append(f"- BEST-F1   : {_fmt_row(best_gt_f1.to_dict())} | ge_equiv={BEST_THR_GT_F1_GE_EQUIV:.6f}")
lines.append(f"- BEST-MCC  : {_fmt_row(best_gt_mcc.to_dict())} | ge_equiv={BEST_THR_GT_MCC_GE_EQUIV:.6f}")
lines.append(f"- BEST-BACC : {_fmt_row(best_gt_bacc.to_dict())} | ge_equiv={BEST_THR_GT_BACC_GE_EQUIV:.6f}")
lines.append("")
lines.append("Top 12 (by F1) overall:")
# top1000 is already sorted by F1, so its first rows are the overall leaders across both rules.
for i in range(min(12, len(top1000))):
    r = top1000.iloc[i].to_dict()
    lines.append(f"{i+1:02d}. rule={r['rule']} | {_fmt_row(r)}")

with open(out_txt, "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")

# Console summary of the saved files and the headline thresholds.
print("[Stage 9] DONE")
print(f"- Saved: {out_json}")
print(f"- Saved: {out_txt}")
print(f"- Saved: {out_csv}")
print(f"- OOF AUC (valid-only): {auc_oof:.6f} | AP: {ap_oof:.6f}")
print(f"- DEFAULT BEST_THR (ge/F1) = {BEST_THR:.6f} | F1={float(best_ge_f1['f1']):.6f} (P={float(best_ge_f1['precision']):.6f} R={float(best_ge_f1['recall']):.6f})")
print(f"- BEST_THR_GE_MCC          = {BEST_THR_GE_MCC:.6f} | MCC={float(best_ge_mcc['mcc']):.6f}")
print(f"- BEST_THR_GE_BACC         = {BEST_THR_GE_BACC:.6f} | BACC={float(best_ge_bacc['balanced_accuracy']):.6f}")
print(f"- BEST_THR_GT_F1           = {BEST_THR_GT_F1:.6f} (gt) | ge_equiv={BEST_THR_GT_F1_GE_EQUIV:.6f}")

# Publish results under explicit global names for the cells that follow.
globals().update({
    # OOF ids / probabilities: "all" keeps every row, "valid" only rows with a finite probability.
    "train_ids_oof_all": train_ids,
    "train_ids_oof_valid": train_ids_v,
    "oof_prob_all": oof_prob,
    "oof_prob_valid": p_v,
    "y_oof_valid": y_v,

    "BEST_THR": float(BEST_THR),                 # default: rule=ge, metric=F1
    "BEST_THR_GE_F1": float(BEST_THR_GE_F1),
    "BEST_THR_GE_F05": float(BEST_THR_GE_F05),
    "BEST_THR_GE_F2": float(BEST_THR_GE_F2),
    "BEST_THR_GE_MCC": float(BEST_THR_GE_MCC),
    "BEST_THR_GE_BACC": float(BEST_THR_GE_BACC),

    "BEST_THR_GT_F1": float(BEST_THR_GT_F1),
    "BEST_THR_GT_F05": float(BEST_THR_GT_F05),
    "BEST_THR_GT_F2": float(BEST_THR_GT_F2),
    "BEST_THR_GT_MCC": float(BEST_THR_GT_MCC),
    "BEST_THR_GT_BACC": float(BEST_THR_GT_BACC),

    # use these if downstream always uses >= but you want strict ">"
    "BEST_THR_GT_F1_GE_EQUIV": float(BEST_THR_GT_F1_GE_EQUIV),
    "BEST_THR_GT_MCC_GE_EQUIV": float(BEST_THR_GT_MCC_GE_EQUIV),
    "BEST_THR_GT_BACC_GE_EQUIV": float(BEST_THR_GT_BACC_GE_EQUIV),

    # Full tables, artifact paths and threshold-free metrics.
    "thr_table_ge": thr_table_ge,
    "thr_table_gt": thr_table_gt,
    "thr_table_top1000": top1000,
    "THR_JSON_PATH": out_json,
    "THR_REPORT_PATH": out_txt,
    "THR_TABLE_CSV_PATH": out_csv,
    "OOF_AUC_VALID_ONLY": float(auc_oof),
    "OOF_AP_VALID_ONLY": float(ap_oof),
})

gc.collect()


# Test Inference (Fold Ensemble)

In [None]:
# ============================================================
# STAGE 10 — Test Inference (Fold Ensemble) (ONE CELL)
# REVISI FULL v4.5 (TEST G-FEAT CACHE + STRICT GDIM + BAND RANGE CHECK + EMA INFER + DEBUG EXPORTS)
#
# Upgrades v4.5:
# - Cache TEST global features: FIX_DIR/global_features_test_raw.npy (+ meta json)
# - Strict g_dim: if computed G_raw < g_dim expected by ckpt => raise (avoid silent padding)
# - Band range sanity check for SHIFT_BAND_IDS & n_bands (fail-fast on mismatch)
# - EMA inference: apply EMA shadow safely (shape match only) + report hits
# - Optional debug exports: test_prob_folds.csv
# ============================================================

import os, gc, json, re, math, time, warnings, hashlib
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="enable_nested_tensor is True.*")

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Fail fast when the globals this cell depends on are not defined yet.
need = ["ART_DIR","FIX_DIR","MAX_LEN","SEQ_FEATURE_NAMES","df_test_meta","CKPT_DIR","n_splits"]
for k in need:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..9 dulu.")

# Torch
# PyTorch is mandatory for this stage; re-raise with a clear message when the import fails.
try:
    import torch
    import torch.nn as nn
except Exception as e:
    raise RuntimeError("PyTorch tidak tersedia di environment ini.") from e

# Inference in this cell is pinned to the CPU; seeds are re-applied for reproducibility.
device = torch.device("cpu")
SEED = int(globals().get("SEED", 2025))
torch.manual_seed(SEED)
np.random.seed(SEED)

# Thread guard (CPU)
# Follow OMP_NUM_THREADS (default 2) for intra-op threads; failures are ignored on purpose.
try:
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
    torch.set_num_interop_threads(1)
except Exception:
    pass

FIX_DIR = Path(globals()["FIX_DIR"])
ART_DIR = Path(globals()["ART_DIR"]); ART_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR = Path(globals()["CKPT_DIR"])

# Prediction files go to ART_DIR/preds.
OUT_DIR = ART_DIR / "preds"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Settings
# ----------------------------
USE_EMA_WEIGHTS_FOR_INFER = True      # use EMA if ckpt has ema_shadow
EXPORT_TEST_PRED_01 = True            # export 0/1 csv using BEST_THR if available
DEFAULT_THR_IF_MISSING = 0.5
EXPORT_TEST_PROB_FOLDS_CSV = False    # debug: wide CSV per fold
STRICT_GDIM = True                   # if computed G_raw < g_dim expected -> raise

# ----------------------------
# helper: normalize id robustly
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_, bytearray)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _load_ids_npy(path: Path):
    """Load an id array from a ``.npy`` file (pickle disabled) as normalised strings."""
    raw = np.load(path, allow_pickle=False)
    items = raw.tolist() if hasattr(raw, "tolist") else list(raw)
    return list(map(_norm_id, items))

def sigmoid_np(x):
    """Logistic sigmoid of *x*, with the input clipped to ``[-50, 50]`` first so ``exp`` cannot overflow."""
    clipped = np.clip(x, -50, 50)
    denom = 1.0 + np.exp(-clipped)
    return 1.0 / denom

# ----------------------------
# 0b) Read Stage6 policy (SHIFT_BAND_IDS, PAD_BAND_ID, dtype_X, token/value hints)
# ----------------------------
# Defaults used when the Stage 6 policy file is absent or unreadable.
SHIFT_BAND_IDS = False
PAD_BAND_ID = 0
DTYPE_X_MEMMAP = np.float32
POL_TOKEN_MODE = None
POL_SCORE_VALUE_FEAT = None

policy_path = FIX_DIR / "length_policy_config.json"
if policy_path.exists():
    try:
        pol = json.loads(policy_path.read_text())
        SHIFT_BAND_IDS = bool(pol.get("padding", {}).get("SHIFT_BAND_IDS", False))
        PAD_BAND_ID = int(pol.get("padding", {}).get("PAD_BAND_ID", 0))
        # Any dtype string mentioning float16/fp16 selects float16; everything else is float32.
        dt = str(pol.get("dtype_X", "float32")).lower()
        DTYPE_X_MEMMAP = np.float16 if (("float16" in dt) or ("fp16" in dt)) else np.float32
        POL_TOKEN_MODE = pol.get("token_mode", None)
        POL_SCORE_VALUE_FEAT = pol.get("score_value_feat", None)
    except Exception as exc:
        # Still best-effort (no crash), but no longer silent: a broken policy file
        # means the memmap dtype / band shift below may not match the cached data.
        print(
            f"[Stage 10] WARNING: could not read Stage6 policy {policy_path} "
            f"({type(exc).__name__}: {exc}); continuing with default or partially applied values: "
            f"SHIFT_BAND_IDS={SHIFT_BAND_IDS} PAD_BAND_ID={PAD_BAND_ID} "
            f"dtype_X={np.dtype(DTYPE_X_MEMMAP).name}"
        )

# ----------------------------
# 1) Load TEST ordering (must match STAGE 6)
# ----------------------------
# test_ids.npy defines the row order of the fixed-length test caches opened below.
test_ids_path = FIX_DIR / "test_ids.npy"
if not test_ids_path.exists():
    raise FileNotFoundError(f"Missing: {test_ids_path}. Pastikan STAGE 6 sukses.")

test_ids = _load_ids_npy(test_ids_path)
NTE = len(test_ids)
if NTE <= 0:
    raise RuntimeError("test_ids kosong (NTE=0). Pastikan STAGE 6 sukses membuat test_ids.npy.")

# Align df_test_meta.index via string-map (HARD)
# Shallow copy with the index normalised the same way as test_ids, so lookups compare like with like.
df_test_meta = globals()["df_test_meta"].copy(deep=False)
df_test_meta.index = pd.Index([_norm_id(z) for z in df_test_meta.index.tolist()], name=df_test_meta.index.name)

# Row position of every test id inside df_test_meta; get_indexer returns -1 for unknown ids.
pos_idx = df_test_meta.index.get_indexer(test_ids)
if (pos_idx < 0).any():
    bad = [test_ids[i] for i in np.where(pos_idx < 0)[0][:10]]
    raise KeyError(f"Some test_ids not found in df_test_meta.index. ex={bad} | missing_n={int((pos_idx<0).sum())}")
pos_idx = pos_idx.astype(np.int32)

# The ordering itself must not contain the same object twice.
if len(set(test_ids)) != len(test_ids):
    s = pd.Series(test_ids)
    dup = s[s.duplicated()].head(10).tolist()
    raise ValueError(f"Duplicate object_id in test_ids ordering (examples): {dup}")

# ----------------------------
# 2) Open fixed-length TEST memmaps (dtype from Stage6 policy)
# ----------------------------
SEQ_FEATURE_NAMES = list(globals()["SEQ_FEATURE_NAMES"])
Fdim = len(SEQ_FEATURE_NAMES)
L = int(globals()["MAX_LEN"])

test_X_path = FIX_DIR / "test_X.dat"
test_B_path = FIX_DIR / "test_B.dat"
test_M_path = FIX_DIR / "test_M.dat"
for p in [test_X_path, test_B_path, test_M_path]:
    if not p.exists():
        raise FileNotFoundError(f"Missing fixed cache file: {p}. Pastikan STAGE 6 sukses.")

# Read-only memmaps: X is (NTE, L, Fdim) in the policy dtype, B and M are (NTE, L) int8.
# NOTE(review): B is presumably the per-step band id and M the real/pad mask
# (see _band_sanity_check) — confirm against Stage 6.
Xte = np.memmap(test_X_path, dtype=DTYPE_X_MEMMAP, mode="r", shape=(NTE, L, Fdim))
Bte = np.memmap(test_B_path, dtype=np.int8,        mode="r", shape=(NTE, L))
Mte = np.memmap(test_M_path, dtype=np.int8,        mode="r", shape=(NTE, L))

# Feature name -> column index in X; these two features are required.
feat = {n:i for i,n in enumerate(SEQ_FEATURE_NAMES)}
for k in ["snr_tanh","detected"]:
    if k not in feat:
        raise RuntimeError(f"Feature '{k}' not found in SEQ_FEATURE_NAMES.")

# ----------------------------
# 3) Checkpoints (fold_*.pt)
# ----------------------------
# One checkpoint per fold is required (CKPT_DIR/fold_0.pt .. fold_{n_splits-1}.pt);
# a missing file aborts before any inference starts.
n_splits = int(globals()["n_splits"])
ckpts = []
for f in range(n_splits):
    p = CKPT_DIR / f"fold_{f}.pt"
    if not p.exists():
        raise FileNotFoundError(f"Missing checkpoint: {p}. Pastikan STAGE 8 menyimpan ckpt per fold.")
    ckpts.append(p)

# ----------------------------
# 4) Safe/compat checkpoint loader
# ----------------------------
def torch_load_compat(path: Path):
    """Load a checkpoint onto the CPU, working with or without ``weights_only`` support.

    The restricted ``weights_only=True`` load is tried first and its result is
    returned when it is a dict carrying one of the keys ``model_state``,
    ``global_scaler`` or ``cfg``. In every other case the file is loaded
    again without that restriction.

    NOTE(security): the fallbacks perform full unpickling, which can execute
    arbitrary code; only use this on checkpoints from a trusted source.
    """
    try:
        obj = torch.load(path, map_location="cpu", weights_only=True)
        if isinstance(obj, dict) and ("model_state" in obj or "global_scaler" in obj or "cfg" in obj):
            return obj
        # Loaded, but not in the expected layout: reload without the restriction.
        return torch.load(path, map_location="cpu", weights_only=False)
    except TypeError:
        # Presumably a torch version whose torch.load has no `weights_only` parameter.
        return torch.load(path, map_location="cpu")
    except Exception:
        # The restricted load failed for another reason; retry unrestricted.
        return torch.load(path, map_location="cpu", weights_only=False)

def extract_state_and_meta(ckpt_obj):
    """Split a loaded checkpoint into ``(state_dict, meta)``.

    * dict with a dict under ``"model_state"``: that dict is the state and
      the whole checkpoint is the meta;
    * any other dict holding at least one tensor value: treated as a bare
      state_dict with empty meta;
    * any other dict: returned as both state and meta.

    Raises:
        RuntimeError: if *ckpt_obj* is not a dict.
    """
    if not isinstance(ckpt_obj, dict):
        raise RuntimeError(f"Unsupported ckpt object type: {type(ckpt_obj)}")

    wrapped_state = ckpt_obj.get("model_state")
    if isinstance(wrapped_state, dict):
        return wrapped_state, ckpt_obj

    holds_tensor = any(torch.is_tensor(value) for value in ckpt_obj.values())
    meta = {} if holds_tensor else ckpt_obj
    return ckpt_obj, meta

# ----------------------------
# 5) Infer architecture from state_dict
# ----------------------------
def infer_from_state(sd: dict):
    """Derive the model hyper-parameters from the tensor shapes in a state_dict.

    Reads ``band_emb.weight``, ``pos_emb``, ``x_proj.*``, the optional
    ``g_proj.0.weight``, the ``encoder.layers.<i>.*`` keys, the optional
    ``pool_ln.*`` pair and the ``head.<i>.weight`` keys.

    Returns:
        dict with ``n_bands``, ``d_model``, ``max_len_ckpt``, ``feat_dim``,
        ``g_dim``, ``g_hidden`` (both 0 without ``g_proj``), ``n_layers``,
        ``dim_ff``, ``has_pool_ln`` and ``head_final_idx``.

    Raises:
        RuntimeError: if a required key is missing.
    """
    names = set(sd.keys())

    def _captured_indices(pattern):
        # Integer captured by group 1 of `pattern` for every key it matches at the start.
        matcher = re.compile(pattern)
        hits = (matcher.match(name) for name in names)
        return {int(hit.group(1)) for hit in hits if hit}

    # Embedding table: rows = band vocabulary size, columns = model width.
    if "band_emb.weight" not in sd:
        raise RuntimeError("state_dict missing band_emb.weight.")
    band_shape = sd["band_emb.weight"].shape
    n_bands = int(band_shape[0])
    d_model = int(band_shape[1])

    # Positional embedding: axis 1 is the maximum sequence length.
    if "pos_emb" not in sd:
        raise RuntimeError("state_dict missing pos_emb.")
    max_len_ckpt = int(sd["pos_emb"].shape[1])

    # Input projection, either a Sequential ("x_proj.0") or a single Linear ("x_proj").
    x_key = next((k for k in ("x_proj.0.weight", "x_proj.weight") if k in names), None)
    if x_key is None:
        raise RuntimeError("state_dict missing x_proj.*.weight")
    feat_dim = int(sd[x_key].shape[1])

    # Optional global-feature projection.
    g_dim = 0
    g_hidden = 0
    if "g_proj.0.weight" in names:
        g_shape = sd["g_proj.0.weight"].shape
        g_dim = int(g_shape[1])
        g_hidden = int(g_shape[0])

    # Encoder depth = highest layer index + 1.
    layer_ids = _captured_indices(r"encoder\.layers\.(\d+)\.")
    n_layers = (max(layer_ids) + 1) if layer_ids else 0
    if n_layers <= 0:
        raise RuntimeError("Cannot infer n_layers (encoder.layers.* not found).")

    # Feed-forward width from layer 0, else from the first linear1 weight in sorted key order.
    lin1_key = "encoder.layers.0.linear1.weight"
    if lin1_key not in sd:
        lin1_candidates = sorted(k for k in names if k.endswith("linear1.weight"))
        if not lin1_candidates:
            raise RuntimeError("Cannot infer dim_feedforward (linear1.weight not found).")
        lin1_key = lin1_candidates[0]
    dim_ff = int(sd[lin1_key].shape[0])

    has_pool_ln = ("pool_ln.weight" in names) and ("pool_ln.bias" in names)

    # The final head layer is the one with the highest index.
    head_ids = _captured_indices(r"head\.(\d+)\.weight")
    if not head_ids:
        raise RuntimeError("Cannot infer head structure (head.*.weight not found).")
    head_final_idx = max(head_ids)

    return {
        "n_bands": n_bands,
        "d_model": d_model,
        "max_len_ckpt": max_len_ckpt,
        "feat_dim": feat_dim,
        "g_dim": g_dim,
        "g_hidden": g_hidden,
        "n_layers": n_layers,
        "dim_ff": dim_ff,
        "has_pool_ln": has_pool_ln,
        "head_final_idx": head_final_idx,
    }

# ----------------------------
# 6) Resolve token/value feature from CKPT meta + fallback
# ----------------------------
def _pick_val_feat(feat_map, prefer=None):
    cand = []
    if prefer is not None:
        cand.append(str(prefer))
    cand += [
        "signal", "value", "flux",
        "flux_asinh", "flux_asinh_clip", "flux_asinh_norm", "flux_asinh_scaled",
        "mag", "mag_norm", "mag_clip", "mag_scaled",
        "delta_signal", "delta_flux", "signal_clip", "signal_norm",
    ]
    for c in cand:
        if c and (c in feat_map):
            return c
    keys = list(feat_map.keys())
    fuzzy = [k for k in keys if any(t in k for t in ["signal", "flux", "value", "asinh", "mag"])]
    if fuzzy:
        return sorted(fuzzy)[0]
    return None

# Only the first fold's checkpoint is inspected for the metadata below.
first_obj = torch_load_compat(ckpts[0])
_, first_meta = extract_state_and_meta(first_obj)

# Values recorded in the checkpoint; None means "not recorded".
CKPT_TOKEN_MODE = None
CKPT_VAL_FEAT = None
CKPT_VAL_IS_MAG = None
CKPT_SHIFT_BAND_IDS = None
CKPT_META_COLS = None
CKPT_GCACHE = None

if isinstance(first_meta, dict):
    CKPT_TOKEN_MODE = first_meta.get("token_mode", None)
    CKPT_VAL_FEAT = first_meta.get("val_feat", None)
    CKPT_VAL_IS_MAG = first_meta.get("val_is_mag", None)
    CKPT_SHIFT_BAND_IDS = first_meta.get("shift_band_ids_from_stage6", None)
    CKPT_META_COLS = first_meta.get("global_meta_cols", None)
    CKPT_GCACHE = first_meta.get("global_feature_cache", None)

# HARD guard: SHIFT_BAND_IDS must match ckpt if present
if CKPT_SHIFT_BAND_IDS is not None and bool(CKPT_SHIFT_BAND_IDS) != bool(SHIFT_BAND_IDS):
    raise RuntimeError(
        "[Stage 10] SHIFT_BAND_IDS mismatch between Stage6 policy and ckpt meta.\n"
        f"- Stage6 policy SHIFT_BAND_IDS={SHIFT_BAND_IDS}\n"
        f"- ckpt meta shift_band_ids_from_stage6={CKPT_SHIFT_BAND_IDS}\n"
        "Solusi: inference harus pakai FIX_DIR yang sama dengan training ckpt."
    )

# Token mode priority: checkpoint meta, then Stage 6 policy, then a guess from the
# feature names ("mag" when any feature name starts with "mag", otherwise "generic").
SEQ_TOKEN_MODE = CKPT_TOKEN_MODE if CKPT_TOKEN_MODE is not None else POL_TOKEN_MODE
if SEQ_TOKEN_MODE is None:
    SEQ_TOKEN_MODE = "mag" if any(k.startswith("mag") for k in feat.keys()) else "generic"
SEQ_TOKEN_MODE = str(SEQ_TOKEN_MODE).lower().strip()

# Value feature: prefer the checkpoint's, then the policy's, then the built-in candidate list.
VAL_FEAT = _pick_val_feat(feat, prefer=(CKPT_VAL_FEAT or POL_SCORE_VALUE_FEAT))
if VAL_FEAT is None:
    raise RuntimeError("Cannot resolve VAL_FEAT from SEQ_FEATURE_NAMES.")
# Without a recorded flag, the value counts as a magnitude only in "mag" token mode
# and when the chosen feature name contains "mag".
VAL_IS_MAG = bool(CKPT_VAL_IS_MAG) if CKPT_VAL_IS_MAG is not None else (SEQ_TOKEN_MODE == "mag" and "mag" in VAL_FEAT)

# ----------------------------
# 7) Band range sanity (fail-fast)
# ----------------------------
def _band_sanity_check(Bmm, Mmm, n_bands, shift_flag, sample_n=512):
    s = min(int(sample_n), int(Bmm.shape[0]))
    if s <= 0:
        return
    Bc = np.asarray(Bmm[:s])
    Mc = np.asarray(Mmm[:s])
    real = (Mc == 1)
    if not real.any():
        return
    br = Bc[real].astype(np.int64, copy=False)
    bmin = int(br.min())
    bmax = int(br.max())

    # accepted sets:
    # - if shift_flag True: real bands expected in [1..n_bands] (pad=0)
    # - if shift_flag False: real bands expected in [0..n_bands-1]
    if shift_flag:
        ok = (bmin >= 0) and (bmax <= n_bands)  # allow 0 if some weirdness, but real typically 1..n_bands
        if not ok:
            raise RuntimeError(
                "[Stage 10] Band-id range looks incompatible with SHIFT_BAND_IDS=True.\n"
                f"- observed real band min={bmin} max={bmax}\n"
                f"- expected roughly 1..{n_bands} (pad=0)\n"
                "Solusi: pastikan test_B.dat dibuat dengan Stage6 policy yang sama seperti training."
            )
    else:
        ok = (bmin >= 0) and (bmax <= (n_bands - 1))
        if not ok:
            raise RuntimeError(
                "[Stage 10] Band-id range looks incompatible with SHIFT_BAND_IDS=False.\n"
                f"- observed real band min={bmin} max={bmax}\n"
                f"- expected 0..{n_bands-1}\n"
                "Solusi: kemungkinan Stage6 menghasilkan band_id 1..n_bands tapi SHIFT_BAND_IDS tidak diset.\n"
                "Rebuild Stage6 / gunakan FIX_DIR yang benar."
            )

# we'll run sanity after we know n_bands from ckpt arch below

# ----------------------------
# 8) Determine META_COLS and whether need agg
# ----------------------------
# Global metadata columns: use the list stored in the checkpoint when it is a
# non-empty list/tuple, otherwise fall back to the built-in default.
DEFAULT_META_COLS = ["Z","Z_err","EBV_used","Z_missing","Z_err_missing","EBV_missing","is_photoz"]
if isinstance(CKPT_META_COLS, (list, tuple)) and len(CKPT_META_COLS) > 0:
    META_COLS = list(map(str, CKPT_META_COLS))
else:
    META_COLS = DEFAULT_META_COLS

# Scan every fold checkpoint: infer its architecture, decide whether aggregated
# sequence features are required, and verify the folds agree with the resolved policy.
fold_arch = []
fold_meta_summary = []
need_agg = False
arch_used = None

for fold, p in enumerate(ckpts):
    obj = torch_load_compat(p)
    sd, meta = extract_state_and_meta(obj)
    arch = infer_from_state(sd)
    fold_arch.append(arch)
    if arch_used is None:
        # the first fold's architecture is the reference recorded in the config
        arch_used = dict(arch)

    # Aggregates are needed when the ckpt expects more global dims than the meta
    # columns provide, or when the ckpt meta explicitly says so.
    if (int(arch.get("g_dim", 0)) > len(META_COLS)) or (
        isinstance(meta, dict) and bool(meta.get("use_agg_seq_features", False))
    ):
        need_agg = True

    # Cross-fold consistency guard (only values actually present in the meta are compared).
    if isinstance(meta, dict):
        vf = meta.get("val_feat", None)
        vim = meta.get("val_is_mag", None)
        sb = meta.get("shift_band_ids_from_stage6", None)
        if vf is not None:
            if str(vf) != str(VAL_FEAT):
                raise RuntimeError(f"[Stage 10] Fold {fold}: val_feat mismatch. fold={vf} vs resolved={VAL_FEAT}")
        if vim is not None:
            if bool(vim) != bool(VAL_IS_MAG):
                raise RuntimeError(f"[Stage 10] Fold {fold}: val_is_mag mismatch. fold={vim} vs resolved={VAL_IS_MAG}")
        if sb is not None:
            if bool(sb) != bool(SHIFT_BAND_IDS):
                raise RuntimeError(f"[Stage 10] Fold {fold}: SHIFT_BAND_IDS mismatch. fold={sb} vs policy={SHIFT_BAND_IDS}")

    fold_meta_summary.append(dict(
        fold=fold,
        g_dim=int(arch.get("g_dim", 0)),
        d_model=int(arch.get("d_model", 0)),
        n_layers=int(arch.get("n_layers", 0)),
    ))

# n_bands is known now, so the band-id range of the test memmap can be checked
N_BANDS = int(arch_used["n_bands"])
_band_sanity_check(Bte, Mte, n_bands=N_BANDS, shift_flag=SHIFT_BAND_IDS, sample_n=512)

# ----------------------------
# 9) Build TEST global features (meta + optional agg seq feats) + CACHE
# ----------------------------
def _hash_cfg(d: dict) -> str:
    s = json.dumps(d, sort_keys=True, ensure_ascii=True)
    return hashlib.md5(s.encode("utf-8")).hexdigest()[:12]

def _prepare_meta_cols(df):
    """Build the float32 matrix of global metadata features.

    Works on a shallow copy of ``df``. A missing ``EBV_used`` column is filled
    from ``EBV_clip``, then ``EBV``, then 0.0; any other missing META_COLS
    column is created as 0.0. Rows are selected positionally with ``pos_idx``,
    every column is coerced to numeric (unparseable values become 0.0), and the
    result is returned as a float32 array with one column per META_COLS entry.
    """
    frame = df.copy(deep=False)

    if "EBV_used" in META_COLS and "EBV_used" not in frame.columns:
        # first available source column wins; constant 0.0 when neither exists
        source = next((name for name in ("EBV_clip", "EBV") if name in frame.columns), None)
        frame["EBV_used"] = 0.0 if source is None else frame[source]

    for col in META_COLS:
        if col not in frame.columns:
            frame[col] = 0.0

    picked = frame.iloc[pos_idx][META_COLS].copy()
    for col in META_COLS:
        numeric = pd.to_numeric(picked[col], errors="coerce")
        picked[col] = numeric.fillna(0.0).astype(np.float32)
    return picked.to_numpy(dtype=np.float32, copy=False)

def _safe_div(a, b):
    return a / np.maximum(b, 1.0)

def build_agg_seq_features_memmap(Xmm, Bmm, Mmm, chunk=512):
    """Compute per-object aggregate sequence features, chunk by chunk.

    Reads module-level state: ``feat`` (feature-name -> column index), ``VAL_FEAT``,
    ``VAL_IS_MAG``, ``N_BANDS``, ``NTE`` and ``SHIFT_BAND_IDS``.

    Args:
        Xmm: token features, indexed as [row, position, feature].
        Bmm: band id per token, indexed as [row, position].
        Mmm: token mask; a position is treated as real where the value equals 1.
        chunk: number of rows processed per iteration.

    Returns:
        float32 array of shape (NTE, 4 + 3 + N_BANDS * 4) with columns, in order:
          * tok_count, det_frac, mean_abs_snr, max_abs_snr
          * three value statistics (mean/std/min of the value when VAL_IS_MAG,
            otherwise mean|value|, std(value), max|value|)
          * for each band: count, det_frac, mean_abs_snr, mean value (or mean |value|)

    NOTE(review): the in-code comment says the band mapping mirrors the Stage 8
    aggregate builder; column order presumably has to match training — confirm
    before changing anything here.
    """
    snr_i = feat["snr_tanh"]
    det_i = feat["detected"]
    val_i = feat[VAL_FEAT]

    # dim = glob(4) + global_val_feats(3) + per_band(N_BANDS*4)
    agg_dim = 4 + 3 + (N_BANDS * 4)
    out = np.zeros((NTE, agg_dim), dtype=np.float32)

    for start in range(0, NTE, int(chunk)):
        end = min(NTE, start + int(chunk))
        Xc = np.asarray(Xmm[start:end])  # (B,L,F)
        Bc = np.asarray(Bmm[start:end])  # (B,L)
        Mc = np.asarray(Mmm[start:end])  # (B,L)

        real = (Mc == 1)

        # band mapping (same as Stage8 agg builder)
        if SHIFT_BAND_IDS:
            # real tokens: shift ids down by one and clip into [0, N_BANDS-1]; pads become 0
            Bc2 = Bc.astype(np.int16, copy=True)
            if real.any():
                Bc2[real] = np.clip(Bc2[real] - 1, 0, N_BANDS - 1)
            Bc2[~real] = 0
            Bc = Bc2.astype(np.int8, copy=False)

        tok_count = real.sum(axis=1).astype(np.float32)

        snr = np.abs(Xc[:, :, snr_i]).astype(np.float32, copy=False)
        det = (Xc[:, :, det_i] > 0.5).astype(np.float32, copy=False)
        val = Xc[:, :, val_i].astype(np.float32, copy=False)

        # zero out padded positions so sums/max only see real tokens
        snr_r = snr * real
        det_r = det * real

        # _safe_div floors the denominator at 1.0, so rows without real tokens yield 0
        det_frac = _safe_div(det_r.sum(axis=1), tok_count)
        mean_abs_snr = _safe_div(snr_r.sum(axis=1), tok_count)
        max_abs_snr = np.where(tok_count > 0, snr_r.max(axis=1), 0.0).astype(np.float32)

        if VAL_IS_MAG:
            # magnitude-like value: signed statistics over real tokens (NaN marks pads)
            val_r = np.where(real, val, np.nan)
            mean_val = np.nan_to_num(np.nanmean(val_r, axis=1).astype(np.float32), nan=0.0)
            std_val  = np.nan_to_num(np.nanstd(val_r,  axis=1).astype(np.float32), nan=0.0)
            min_val  = np.nan_to_num(np.nanmin(val_r,  axis=1).astype(np.float32), nan=0.0)
            global_val_feats = np.stack([mean_val, std_val, min_val], axis=1).astype(np.float32)
        else:
            # flux-like value: statistics on |value| plus the std of the signed value
            aval = np.abs(val).astype(np.float32, copy=False)
            aval_r = aval * real
            mean_aval = _safe_div(aval_r.sum(axis=1), tok_count).astype(np.float32)
            val_r = np.where(real, val, np.nan)
            std_val = np.nan_to_num(np.nanstd(val_r, axis=1).astype(np.float32), nan=0.0)
            max_aval = np.where(tok_count > 0, aval_r.max(axis=1), 0.0).astype(np.float32)
            global_val_feats = np.stack([mean_aval, std_val, max_aval], axis=1).astype(np.float32)

        per_band = []
        for b in range(N_BANDS):
            # tokens that are real AND belong to band b
            bm = (Bc == b) & real
            cnt = bm.sum(axis=1).astype(np.float32)

            detb = (det * bm).sum(axis=1).astype(np.float32)
            snrb = (snr * bm).sum(axis=1).astype(np.float32)

            det_frac_b = _safe_div(detb, cnt)
            mean_abs_snr_b = _safe_div(snrb, cnt)

            if VAL_IS_MAG:
                vb = np.where(bm, val, np.nan)
                mean_val_b = np.nan_to_num(np.nanmean(vb, axis=1).astype(np.float32), nan=0.0)
            else:
                ab = (np.abs(val).astype(np.float32, copy=False) * bm).sum(axis=1).astype(np.float32)
                mean_val_b = _safe_div(ab, cnt).astype(np.float32)

            per_band.append(np.stack([cnt, det_frac_b, mean_abs_snr_b, mean_val_b], axis=1))

        # final column layout: glob(4) | global_val_feats(3) | per_band(N_BANDS*4)
        per_band = np.concatenate(per_band, axis=1).astype(np.float32)
        glob = np.stack([tok_count, det_frac, mean_abs_snr, max_abs_snr], axis=1).astype(np.float32)
        agg = np.concatenate([glob, global_val_feats, per_band], axis=1).astype(np.float32)

        out[start:end] = agg

        # drop chunk buffers; run the collector on every 4th chunk
        del Xc, Bc, Mc
        if (start // int(chunk)) % 4 == 0:
            gc.collect()

    return out

# --- TEST cache paths ---
# The raw (unscaled) global feature matrix and a JSON sidecar describing how it was built.
G_TEST_CACHE = FIX_DIR / "global_features_test_raw.npy"
G_TEST_META  = FIX_DIR / "global_features_test_raw_meta.json"

# Everything that influences the content of the global features; hashed below so a
# cache built under a different configuration is not reused.
agg_spec = {
    "NTE": int(NTE),
    "L": int(L),
    "Fdim": int(Fdim),
    "n_bands": int(N_BANDS),
    "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
    "pad_band_id_from_stage6": int(PAD_BAND_ID),
    "dtype_X_memmap": str(DTYPE_X_MEMMAP),
    "meta_cols": list(META_COLS),
    "val_feat": str(VAL_FEAT),
    "val_is_mag": bool(VAL_IS_MAG),
    "need_agg": bool(need_agg),
}
agg_hash = _hash_cfg(agg_spec)

print(f"[Stage 10] token_mode={SEQ_TOKEN_MODE} | VAL_FEAT={VAL_FEAT} | VAL_IS_MAG={VAL_IS_MAG} | X_dtype={DTYPE_X_MEMMAP}")
print(f"[Stage 10] SHIFT_BAND_IDS={SHIFT_BAND_IDS} | PAD_BAND_ID={PAD_BAND_ID} | N_BANDS={N_BANDS}")
print(f"[Stage 10] META_COLS={META_COLS} | need_agg={need_agg}")
print(f"[Stage 10] USE_EMA_WEIGHTS_FOR_INFER={USE_EMA_WEIGHTS_FOR_INFER} | STRICT_GDIM={STRICT_GDIM}")

# Try the cache first. It is accepted only when the stored hash and row count match;
# any failure while reading it is treated as a cache miss (best effort, rebuilt below).
G_raw_default = None
if G_TEST_CACHE.exists() and G_TEST_META.exists():
    try:
        old = json.loads(G_TEST_META.read_text())
        if old.get("agg_hash") == agg_hash and int(old.get("NTE", -1)) == int(NTE):
            G_raw_default = np.load(G_TEST_CACHE, allow_pickle=False).astype(np.float32, copy=False)
            if G_raw_default.shape[0] != NTE:
                G_raw_default = None
    except Exception:
        G_raw_default = None

# Cache miss: build meta features (plus aggregated sequence features when needed), then persist.
if G_raw_default is None:
    print("[Stage 10] Building TEST global features (then cached)...")
    t0 = time.time()
    G_meta_np = _prepare_meta_cols(df_test_meta)

    if need_agg:
        # layout: [meta columns | aggregated sequence features]
        G_seq_np = build_agg_seq_features_memmap(Xte, Bte, Mte, chunk=512)
        G_raw_default = np.concatenate([G_meta_np, G_seq_np], axis=1).astype(np.float32)
        agg_dim = int(G_seq_np.shape[1])
    else:
        G_raw_default = G_meta_np.astype(np.float32, copy=False)
        agg_dim = 0

    np.save(G_TEST_CACHE, G_raw_default.astype(np.float32, copy=False))
    G_TEST_META.write_text(json.dumps(
        {"agg_hash": agg_hash, "NTE": int(NTE), "spec": agg_spec, "agg_dim": int(agg_dim), "g_dim": int(G_raw_default.shape[1])},
        indent=2
    ))
    print(f"[Stage 10] TEST G built: shape={G_raw_default.shape} | time={time.time()-t0:.1f}s | cached={G_TEST_CACHE}")

# ----------------------------
# 10) Model definition (pooling matches Stage 8)
# ----------------------------
class FlexMultibandEventTransformer(nn.Module):
    """Transformer encoder over multiband event tokens with an optional global-feature branch.

    Token embedding = projected features + band embedding + learned positional
    embedding. The encoded sequence is pooled as a fixed blend of attention,
    masked-mean and masked-max pooling, optionally layer-normalised, optionally
    concatenated with projected global features, and mapped to one logit per row.

    ``head_final_idx`` selects the head layout (index of the final Linear inside
    the Sequential): 3 -> Linear/GELU/Dropout/Linear, 2 -> Linear/GELU/Linear,
    anything else -> Linear/Linear.
    """

    def __init__(self, feat_dim, max_len, n_bands, d_model, n_heads, n_layers, dim_ff, dropout,
                 g_dim, g_hidden, has_pool_ln=True, head_final_idx=3):
        super().__init__()
        width = int(d_model)
        p_drop = float(dropout)

        self.n_bands = int(n_bands)
        self.max_len = int(max_len)
        self.d_model = width

        # --- token embedding ---
        self.x_proj = nn.Sequential(
            nn.Linear(int(feat_dim), width),
            nn.GELU(),
            nn.Dropout(p_drop),
        )
        self.band_emb = nn.Embedding(int(n_bands), width)

        self.pos_emb = nn.Parameter(torch.zeros(1, int(max_len), width))
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

        # --- encoder stack (pre-norm, batch-first) ---
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=int(n_heads),
                dim_feedforward=int(dim_ff),
                dropout=p_drop,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            ),
            num_layers=int(n_layers),
        )

        # --- pooling ---
        self.attn = nn.Linear(width, 1)

        self.has_pool_ln = bool(has_pool_ln)
        if self.has_pool_ln:
            self.pool_ln = nn.LayerNorm(width)

        # --- optional global-feature branch ---
        self.g_dim = int(g_dim)
        self.g_hidden = int(g_hidden)
        if self.g_dim > 0 and self.g_hidden > 0:
            self.g_proj = nn.Sequential(
                nn.Linear(int(g_dim), int(g_hidden)),
                nn.GELU(),
                nn.Dropout(p_drop),
            )
        else:
            self.g_proj = None

        # --- classification head ---
        extra = g_hidden if (self.g_proj is not None) else 0
        in_head = int(d_model + extra)
        head_layers = [nn.Linear(in_head, width)]
        if head_final_idx == 3:
            head_layers += [nn.GELU(), nn.Dropout(p_drop)]
        elif head_final_idx == 2:
            head_layers.append(nn.GELU())
        head_layers.append(nn.Linear(width, 1))
        self.head = nn.Sequential(*head_layers)

    def forward(self, X, band_id, mask, G):
        """Return one logit per row.

        ``mask`` marks padding with 0. Band ids are clamped into the embedding
        range. ``G`` is only read when the global branch exists.
        """
        feats = X.to(torch.float32)
        bands = band_id.clamp(0, self.n_bands - 1).to(torch.long)
        token_mask = mask.to(torch.long)

        is_pad = (token_mask == 0)
        fully_padded = is_pad.all(dim=1)
        if fully_padded.any():
            # keep position 0 visible for fully padded rows so no row is masked everywhere
            is_pad = is_pad.clone()
            is_pad[fully_padded, 0] = False

        seq_len = feats.shape[1]
        hidden = self.x_proj(feats) + self.band_emb(bands) + self.pos_emb[:, :seq_len, :]
        hidden = self.encoder(hidden, src_key_padding_mask=is_pad)

        # attention pooling: softmax over positions with padding pushed to -1e9
        scores = self.attn(hidden).squeeze(-1)
        scores = scores.masked_fill(is_pad, -1e9)
        weights = torch.softmax(scores, dim=1)
        pooled_attn = torch.sum(hidden * weights.unsqueeze(-1), dim=1)

        # masked mean pooling
        keep = (~is_pad).to(hidden.dtype).unsqueeze(-1)
        n_keep = keep.sum(dim=1).clamp_min(1.0)
        pooled_mean = (hidden * keep).sum(dim=1) / n_keep

        # masked max pooling
        hidden_for_max = hidden.masked_fill(is_pad.unsqueeze(-1), -1e9)
        pooled_max = torch.max(hidden_for_max, dim=1).values
        pooled_max = torch.where(torch.isfinite(pooled_max), pooled_max, torch.zeros_like(pooled_max))

        pooled = (0.50 * pooled_attn) + (0.30 * pooled_mean) + (0.20 * pooled_max)
        if self.has_pool_ln:
            pooled = self.pool_ln(pooled)

        if self.g_proj is None:
            joint = pooled
        else:
            joint = torch.cat([pooled, self.g_proj(G.to(torch.float32))], dim=1)

        return self.head(joint).squeeze(-1)

@torch.inference_mode()
def predict_logits_batchwise(model, Xmm, Bmm, Mmm, G_raw, mean, std, batch_size=64):
    """Run ``model`` over the inputs in batches and return float32 logits, one per row.

    The model is put in eval mode. For each batch the band ids are remapped once
    when the module-level ``SHIFT_BAND_IDS`` is set (real tokens: id - 1 clipped
    to [0, N_BANDS - 1]; padded tokens: 0), and the global features are
    standardised as ``(G_raw - mean) / std`` before being moved to ``device``.
    """
    model.eval()
    n_rows = int(Xmm.shape[0])
    step = int(batch_size)
    logits_out = np.zeros((Xmm.shape[0],), dtype=np.float32)

    for lo in range(0, n_rows, step):
        hi = min(n_rows, lo + step)
        window = slice(lo, hi)

        x_np = np.asarray(Xmm[window]).astype(np.float32, copy=False)
        band_np = np.asarray(Bmm[window])
        mask_np = np.asarray(Mmm[window])

        is_real = (mask_np == 1)

        # Apply band shift EXACTLY ONCE if needed
        if SHIFT_BAND_IDS:
            shifted = band_np.astype(np.int16, copy=True)
            if is_real.any():
                shifted[is_real] = np.clip(shifted[is_real] - 1, 0, N_BANDS - 1)
            shifted[~is_real] = 0
            band_np = shifted
        band_np = band_np.astype(np.int64, copy=False)

        g_np = ((G_raw[window] - mean) / std).astype(np.float32, copy=False)

        x_t = torch.from_numpy(x_np)
        band_t = torch.from_numpy(band_np)
        mask_t = torch.from_numpy(mask_np.astype(np.int64, copy=False))
        g_t = torch.from_numpy(g_np)

        batch_logits = model(x_t.to(device), band_t.to(device), mask_t.to(device), g_t.to(device))
        logits_out[window] = batch_logits.detach().cpu().numpy().astype(np.float32, copy=False)

    return logits_out

# Batch size: taken from the BATCH_SIZE env var (default 64), capped at 32 for
# long sequences (L >= 512), and never below 4.
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "64"))
BATCH_SIZE = max(4, min(BATCH_SIZE, 32) if L >= 512 else BATCH_SIZE)

# One column of test logits per fold.
test_logit_folds = np.zeros((NTE, n_splits), dtype=np.float32)

print(f"[Stage 10] Test inference: N_test={NTE:,} | folds={n_splits} | batch={BATCH_SIZE} | ensemble=mean_logits_then_sigmoid")

# ----------------------------
# 11) Fold inference
# ----------------------------
# For every fold checkpoint: rebuild the model from the inferred architecture,
# load weights (optionally EMA), standardise the global features with the fold's
# own scaler, and store the fold's test logits in test_logit_folds[:, fold].
for fold, ckpt_path in enumerate(ckpts):
    ckpt_obj = torch_load_compat(ckpt_path)
    sd, meta = extract_state_and_meta(ckpt_obj)

    arch = fold_arch[fold]
    cfg = meta.get("cfg", {}) if isinstance(meta, dict) else {}
    dropout = float(cfg.get("dropout", 0.0)) if isinstance(cfg, dict) else 0.0

    # n_heads comes from the stored cfg (default 4); if it does not divide
    # d_model, fall back to the first candidate below that does.
    n_heads = int(cfg.get("n_heads", 4)) if isinstance(cfg, dict) else 4
    if n_heads <= 0:
        n_heads = 4
    if (arch["d_model"] % n_heads) != 0:
        for h in [4, 8, 2, 1, 16, 32]:
            if h > 0 and (arch["d_model"] % h) == 0:
                n_heads = h
                break
        if (arch["d_model"] % n_heads) != 0:
            raise RuntimeError(f"Fold {fold}: cannot choose valid n_heads for d_model={arch['d_model']}")

    # HARD checks: the memmap layout must match what the checkpoint was trained on
    if arch["feat_dim"] != Fdim:
        raise RuntimeError(
            f"Fold {fold}: feature_dim mismatch.\n"
            f"- ckpt expects feat_dim={arch['feat_dim']}\n"
            f"- memmap has Fdim={Fdim}\n"
            "Solusi: pastikan STAGE 6 feature list sama saat training ckpt dibuat."
        )
    if arch["max_len_ckpt"] != L:
        raise RuntimeError(
            f"Fold {fold}: max_len mismatch.\n"
            f"- ckpt max_len={arch['max_len_ckpt']}\n"
            f"- memmap MAX_LEN={L}\n"
        )

    # Select the global-feature matrix and scaler for this fold.
    g_dim = int(arch["g_dim"])
    if g_dim <= 0:
        # model has no global branch: pass zero-width features and an identity scaler
        G_raw = np.zeros((NTE, 0), dtype=np.float32)
        g_mean = np.zeros((0,), dtype=np.float32)
        g_std  = np.ones((0,), dtype=np.float32)
    else:
        if G_raw_default.shape[1] < g_dim:
            msg = (
                f"[Stage 10] Fold {fold}: computed TEST G_raw dim is smaller than ckpt expects.\n"
                f"- G_raw_default_dim={G_raw_default.shape[1]}\n"
                f"- ckpt g_dim={g_dim}\n"
                "Ini biasanya artinya kamu tidak membangun global features yang sama seperti saat training.\n"
                "Solusi: pakai FIX_DIR yang sama, pastikan need_agg benar, dan META_COLS sama.\n"
            )
            if STRICT_GDIM:
                raise RuntimeError(msg)
            else:
                # legacy unsafe fallback: right-pad the missing columns with zeros
                pad = np.zeros((NTE, g_dim - G_raw_default.shape[1]), dtype=np.float32)
                G_raw = np.concatenate([G_raw_default, pad], axis=1).astype(np.float32, copy=False)
        elif G_raw_default.shape[1] > g_dim:
            # more columns than the ckpt expects: keep only the leading g_dim columns
            G_raw = G_raw_default[:, :g_dim]
        else:
            G_raw = G_raw_default

        scaler = meta.get("global_scaler", None) if isinstance(meta, dict) else None
        if scaler is None or not isinstance(scaler, dict) or ("mean" not in scaler) or ("std" not in scaler):
            raise RuntimeError(f"[Stage 10] Fold {fold}: missing global_scaler in checkpoint meta.")
        g_mean = np.asarray(scaler["mean"], dtype=np.float32).reshape(-1)
        g_std  = np.asarray(scaler["std"],  dtype=np.float32).reshape(-1)

        if g_mean.shape[0] != g_dim or g_std.shape[0] != g_dim:
            raise RuntimeError(
                f"[Stage 10] Fold {fold}: global_scaler shape mismatch.\n"
                f"- mean/std len: {g_mean.shape[0]}/{g_std.shape[0]}\n"
                f"- g_dim: {g_dim}\n"
                "Solusi: ckpt tidak konsisten atau fitur global tidak sama."
            )
        # near-zero std would blow up the standardisation; treat those columns as unit-scale
        g_std = np.where(g_std < 1e-6, 1.0, g_std).astype(np.float32)

    model = FlexMultibandEventTransformer(
        feat_dim=arch["feat_dim"],
        max_len=arch["max_len_ckpt"],
        n_bands=arch["n_bands"],
        d_model=arch["d_model"],
        n_heads=n_heads,
        n_layers=arch["n_layers"],
        dim_ff=arch["dim_ff"],
        dropout=dropout,
        g_dim=g_dim,
        g_hidden=arch["g_hidden"],
        has_pool_ln=arch["has_pool_ln"],
        head_final_idx=arch["head_final_idx"],
    ).to(device)

    model.load_state_dict(sd, strict=True)

    # OPTIONAL: apply EMA weights for inference (if present).
    # Only entries whose key and shape match the model's state dict are swapped in.
    used_ema = False
    ema_hits = 0
    if USE_EMA_WEIGHTS_FOR_INFER and isinstance(meta, dict) and isinstance(meta.get("ema_shadow", None), dict):
        ema_shadow = meta["ema_shadow"]
        st = model.state_dict()
        for k, v in ema_shadow.items():
            if k in st and torch.is_tensor(v) and st[k].shape == v.shape:
                st[k] = v.to(dtype=st[k].dtype, device=st[k].device)
                ema_hits += 1
        if ema_hits > 0:
            model.load_state_dict(st, strict=True)
            used_ema = True

    logits = predict_logits_batchwise(
        model, Xte, Bte, Mte, G_raw, mean=g_mean, std=g_std, batch_size=BATCH_SIZE
    )
    if not np.isfinite(logits).all():
        raise RuntimeError(f"[Stage 10] Fold {fold}: logits has NaN/inf. Check inputs/scaler.")

    test_logit_folds[:, fold] = logits
    # per-fold probabilities are computed here only for the log line below
    probs_tmp = sigmoid_np(logits)

    print(
        f"  fold {fold}: d_model={arch['d_model']} n_heads={n_heads} g_dim={g_dim} | "
        f"ema={used_ema} (hits={ema_hits}) | "
        f"logit_mean={float(logits.mean()):.6f} | prob_mean={float(probs_tmp.mean()):.6f} | prob_std={float(probs_tmp.std()):.6f}"
    )

    # release the fold's model before the next checkpoint is loaded
    del model, logits, probs_tmp
    gc.collect()

# Ensemble in logit space (mean over folds), then squash to probabilities.
test_logit_ens = np.mean(test_logit_folds, axis=1).astype(np.float32)
test_prob_folds = sigmoid_np(test_logit_folds).astype(np.float32)
test_prob_ens = sigmoid_np(test_logit_ens).astype(np.float32)

if not bool(np.all(np.isfinite(test_prob_ens))):
    raise RuntimeError("[Stage 10] test_prob_ens contains NaN/inf (unexpected).")

# ----------------------------
# 12) Save artifacts
# ----------------------------
# Output locations (all under OUT_DIR).
logit_fold_path = OUT_DIR / "test_logit_folds.npy"
logit_ens_path  = OUT_DIR / "test_logit_ens.npy"
prob_fold_path  = OUT_DIR / "test_prob_folds.npy"
prob_ens_path   = OUT_DIR / "test_prob_ens.npy"
csv_path        = OUT_DIR / "test_prob_ens.csv"
cfg_path        = OUT_DIR / "test_infer_config.json"

# Raw arrays: per-fold and ensembled logits / probabilities.
np.save(logit_fold_path, test_logit_folds)
np.save(logit_ens_path,  test_logit_ens)
np.save(prob_fold_path,  test_prob_folds)
np.save(prob_ens_path,   test_prob_ens)

# Ensemble probabilities keyed by object id.
pd.DataFrame({"object_id": test_ids, "prob": test_prob_ens}).to_csv(csv_path, index=False)

# Optional: one probability column per fold.
if EXPORT_TEST_PROB_FOLDS_CSV:
    df_f = pd.DataFrame({"object_id": test_ids})
    for f in range(n_splits):
        df_f[f"prob_fold{f}"] = test_prob_folds[:, f]
    (OUT_DIR / "test_prob_folds.csv").write_text(df_f.to_csv(index=False))

# Optional: binary predictions file (0/1) using BEST_THR from Stage 9 if available
thr_used = None
pred01_path = None
if EXPORT_TEST_PRED_01:
    # BEST_THR is read from globals when defined, else DEFAULT_THR_IF_MISSING; clamped to [0, 1]
    thr_used = float(globals().get("BEST_THR", DEFAULT_THR_IF_MISSING))
    thr_used = min(max(thr_used, 0.0), 1.0)
    test_pred01 = (test_prob_ens >= thr_used).astype(np.int8)
    pred01_path = OUT_DIR / "test_pred_01.csv"
    pd.DataFrame({"object_id": test_ids, "prediction": test_pred01.astype(int)}).to_csv(pred01_path, index=False)

# Record of the settings and artifact paths used for this inference run.
infer_cfg = {
    "seed": int(SEED),
    "n_splits": int(n_splits),
    "ensemble": "mean_logits_then_sigmoid",
    "batch_size": int(BATCH_SIZE),
    "max_len": int(L),
    "feature_dim": int(Fdim),
    "token_mode": str(SEQ_TOKEN_MODE),
    "val_feat": str(VAL_FEAT),
    "val_is_mag": bool(VAL_IS_MAG),
    "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
    "pad_band_id_from_stage6": int(PAD_BAND_ID),
    "dtype_X_memmap": str(DTYPE_X_MEMMAP),
    "global_meta_cols": META_COLS,
    "need_agg_seq": bool(need_agg),
    "global_default_dim": int(G_raw_default.shape[1]),
    "use_ema_weights_for_infer": bool(USE_EMA_WEIGHTS_FOR_INFER),
    "export_test_pred_01": bool(EXPORT_TEST_PRED_01),
    "thr_used_for_pred01": (None if thr_used is None else float(thr_used)),
    "ckpt_dir": str(CKPT_DIR),
    "ckpts": [str(p) for p in ckpts],
    "arch_inferred_from_first_fold": arch_used,
    "fold_meta_summary": fold_meta_summary,
    "test_global_cache": {"path": str(G_TEST_CACHE), "meta": str(G_TEST_META), "agg_hash": agg_hash},
    "outputs": {
        "test_logit_folds": str(logit_fold_path),
        "test_logit_ens": str(logit_ens_path),
        "test_prob_folds": str(prob_fold_path),
        "test_prob_ens": str(prob_ens_path),
        "test_prob_ens_csv": str(csv_path),
        "test_pred_01_csv": (None if pred01_path is None else str(pred01_path)),
    }
}
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(infer_cfg, f, indent=2)

# Human-readable summary of what was written.
print("\n[Stage 10] DONE")
print(f"- Saved logits folds: {logit_fold_path}")
print(f"- Saved logits ens  : {logit_ens_path}")
print(f"- Saved probs folds : {prob_fold_path}")
print(f"- Saved probs ens   : {prob_ens_path}")
print(f"- Saved csv         : {csv_path}")
if pred01_path is not None:
    print(f"- Saved pred 0/1    : {pred01_path} (thr={thr_used:.6f})")
print(f"- Saved config      : {cfg_path}")
print(f"- ens prob mean={float(test_prob_ens.mean()):.6f} | std={float(test_prob_ens.std()):.6f} | "
      f"min={float(test_prob_ens.min()):.6f} | max={float(test_prob_ens.max()):.6f}")

# Publish results and paths as module-level names (presumably for later notebook cells — confirm consumers).
globals().update({
    "test_ids": test_ids,
    "test_logit_folds": test_logit_folds,
    "test_logit_ens": test_logit_ens,
    "test_prob_folds": test_prob_folds,
    "test_prob_ens": test_prob_ens,
    "TEST_LOGIT_FOLDS_PATH": logit_fold_path,
    "TEST_LOGIT_ENS_PATH": logit_ens_path,
    "TEST_PROB_FOLDS_PATH": prob_fold_path,
    "TEST_PROB_ENS_PATH": prob_ens_path,
    "TEST_PROB_CSV_PATH": csv_path,
    "TEST_INFER_CFG_PATH": cfg_path,
    "STAGE10_VAL_FEAT": VAL_FEAT,
    "STAGE10_VAL_IS_MAG": VAL_IS_MAG,
    "STAGE10_USE_EMA_INFER": bool(USE_EMA_WEIGHTS_FOR_INFER),
    "TEST_PRED01_PATH": (pred01_path if pred01_path is not None else None),
    "TEST_PRED01_THR_USED": (thr_used if thr_used is not None else None),
    "TEST_GLOBAL_FEAT_CACHE_PATH": G_TEST_CACHE,
    "TEST_GLOBAL_FEAT_META_PATH": G_TEST_META,
})

gc.collect()


# Evaluation 

In [None]:
# ============================================================
# ONE CELL — EVALUATION (Precision / Recall / F1) + Threshold Sweep (OOF)
# REVISI FULL v3.4 (HOLDOUT-SAFE + GE/GT + PREVALENCE THR + EXTRA CANDS + GT->GE_EQUIV)
#
# Default behavior:
# - HOLDOUT_SAFE=True: drop non-finite probs (NaN/inf) from tuning
# - Evaluate BOTH rules: ge (>=) and gt (>)
# - Default BEST_THR uses rule=ge, metric=F1
# ============================================================

import gc, json, math, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require minimal
# ----------------------------
# This cell needs the training metadata frame produced by an earlier stage.
if "df_train_meta" not in globals():
    raise RuntimeError("Missing df_train_meta. Jalankan stage meta dulu.")

# Artifact root and OOF directory: reuse values already defined in the notebook,
# otherwise default to /kaggle/working and its "oof" subfolder (created if absent).
ART_DIR = Path(globals().get("ART_DIR", "/kaggle/working"))
OOF_DIR = Path(globals().get("OOF_DIR", ART_DIR / "oof"))
OOF_DIR.mkdir(parents=True, exist_ok=True)

# Switches
HOLDOUT_SAFE = True            # drop non-finite oof probs from tuning (recommended)
DO_BOTH_RULES = True           # compute ge and gt tables
ADD_NEXTAFTER_CANDIDATES = True  # add nextafter(unique_probs) as thr candidates (recommended)

# ----------------------------
# Utils
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _as_1d_float32(arr):
    a = np.asarray(arr)
    if a.dtype == object and a.ndim == 0:
        try:
            a = np.asarray(a.item())
        except Exception:
            pass
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 0:
        return a.reshape(1)
    if a.ndim > 1:
        a = a.reshape(-1)
    return a

def _safe_div(a, b):
    return a / np.maximum(b, 1e-12)

def _to_np_bool(x):
    if isinstance(x, np.ndarray):
        return x.astype(bool, copy=False)
    if hasattr(x, "to_numpy"):
        return x.to_numpy(dtype=bool, copy=False)
    return np.asarray(x, dtype=bool)

# HOLDOUT_SAFE: keep NaN as invalid (do not nan_to_num -> 0 for tuning)
def _sanitize_prob(p, holdout_safe=True):
    p = np.asarray(p, dtype=np.float32)
    if holdout_safe:
        # keep non-finite for filtering later
        p = np.clip(p, 0.0, 1.0, out=p, where=np.isfinite(p))
        return p.astype(np.float32, copy=False)
    # legacy: force finite
    p = np.nan_to_num(p, nan=0.0, posinf=1.0, neginf=0.0)
    p = np.clip(p, 0.0, 1.0)
    return p.astype(np.float32)

# ----------------------------
# 0a) Normalize meta index
# ----------------------------
# Work on a shallow copy and rewrite the index with normalised string ids (index name kept).
df_train_meta = df_train_meta.copy(deep=False)
df_train_meta.index = pd.Index(
    list(map(_norm_id, df_train_meta.index)),
    name=df_train_meta.index.name,
)

# ----------------------------
# 0b) Detect target column
# ----------------------------
def _detect_target_col(df):
    for cand in ["target", "y", "label", "class", "is_tde", "binary_target", "target_id"]:
        if cand in df.columns:
            return cand
    return None

# Resolve the label column, binarise it (> 0 -> 1) and index it by object id.
TARGET_COL = _detect_target_col(df_train_meta)
if TARGET_COL is None:
    raise RuntimeError(
        "Cannot detect target column in df_train_meta. "
        f"Columns sample: {list(df_train_meta.columns)[:60]}"
    )

# unparseable labels are coerced to NaN and then treated as 0
y_series = pd.to_numeric(df_train_meta[TARGET_COL], errors="coerce").fillna(0.0)
y_bin = (y_series.to_numpy(dtype=np.float32) > 0).astype(np.int8)
y_map = pd.Series(y_bin, index=df_train_meta.index)
if not y_map.index.is_unique:
    # duplicated ids collapse to a single label: positive if any duplicate is positive
    y_map = y_map.groupby(level=0).max()

# ----------------------------
# 1) Load oof_prob (prefer csv)
# ----------------------------
def load_oof():
    """Load out-of-fold probabilities from the first available source.

    Sources are tried in this order:
      1. ``OOF_DIR / "oof_prob.csv"``, when it has both ``object_id`` and
         ``oof_prob`` columns;
      2. a global variable named ``oof_prob``;
      3. ``OOF_DIR / "oof_prob.npy"``.

    Returns:
        A tuple ``(p, df, source)`` where ``p`` is the sanitized 1-D float32
        probability array, ``df`` is the ``object_id``/``oof_prob`` frame for
        the CSV source (``None`` for the other sources), and ``source`` is a
        short label naming the source used.

    Raises:
        RuntimeError: The CSV probabilities do not match the CSV row count.
        FileNotFoundError: None of the three sources is available.
    """
    pcsv = OOF_DIR / "oof_prob.csv"
    if pcsv.exists():
        # Read ids as text, as the submission stage does: letting pandas infer
        # a numeric dtype would rewrite digit-only ids (leading zeros, very
        # long ids) before _norm_id ever sees them.
        df = pd.read_csv(pcsv, dtype={"object_id": str})
        if ("object_id" in df.columns) and ("oof_prob" in df.columns):
            df = df[["object_id", "oof_prob"]].copy()
            df["object_id"] = df["object_id"].apply(_norm_id)
            p = _sanitize_prob(_as_1d_float32(df["oof_prob"].to_numpy()), holdout_safe=HOLDOUT_SAFE)
            if len(p) != len(df):
                raise RuntimeError("oof_prob.csv: length mismatch after parsing.")
            df["oof_prob"] = p
            return p, df, "csv(oof_prob.csv)"
        # A CSV without the expected columns falls through to the next source.

    if "oof_prob" in globals():
        p = _sanitize_prob(_as_1d_float32(globals()["oof_prob"]), holdout_safe=HOLDOUT_SAFE)
        return p, None, "globals(oof_prob)"

    pnpy = OOF_DIR / "oof_prob.npy"
    if pnpy.exists():
        p = _sanitize_prob(_as_1d_float32(np.load(pnpy, allow_pickle=False)), holdout_safe=HOLDOUT_SAFE)
        return p, None, "npy(oof_prob.npy)"

    raise FileNotFoundError("OOF prob tidak ditemukan (oof_prob.csv / globals oof_prob / oof_prob.npy).")

oof_prob_all, df_oof_csv, oof_src = load_oof()
# Everything below indexes the probabilities as a flat NumPy vector.
if not (isinstance(oof_prob_all, np.ndarray) and oof_prob_all.ndim == 1):
    raise TypeError(f"Invalid oof_prob. type={type(oof_prob_all)} ndim={getattr(oof_prob_all,'ndim',None)}")

# ----------------------------
# 2) Align y to OOF order
# ----------------------------
# Three alignment strategies, from safest to least safe:
#   1) the OOF CSV carries its own object_id column,
#   2) a global `train_ids_ordered` gives the row order of the probabilities,
#   3) fall back to the row order of df_train_meta (only if lengths match).
if df_oof_csv is not None:
    # de-dup ids in oof: mean per id but preserve first order
    ids_first = pd.unique(df_oof_csv["object_id"].to_numpy())
    if len(ids_first) != len(df_oof_csv):
        df_mean = df_oof_csv.groupby("object_id", as_index=True)["oof_prob"].mean()
        df_oof_csv = pd.DataFrame({"object_id": ids_first})
        # reindex restores first-seen order after the groupby.
        df_oof_csv["oof_prob"] = df_mean.reindex(ids_first).to_numpy(dtype=np.float32)
        oof_prob_all = df_oof_csv["oof_prob"].to_numpy(dtype=np.float32, copy=False)

    train_ids_all = df_oof_csv["object_id"].tolist()

    # Drop (with a warning) OOF rows whose id has no label in the meta table.
    ok_raw = pd.Index(train_ids_all).isin(y_map.index)
    ok = _to_np_bool(ok_raw)
    if not ok.all():
        bad = [train_ids_all[i] for i in np.where(~ok)[0][:10]]
        print(f"[WARN] oof ids not in df_train_meta: missing_n={int((~ok).sum())} examples={bad}")
        df_oof_csv = df_oof_csv.loc[ok].reset_index(drop=True)
        train_ids_all = df_oof_csv["object_id"].tolist()
        oof_prob_all = df_oof_csv["oof_prob"].to_numpy(dtype=np.float32, copy=False)

    y_all = y_map.reindex(train_ids_all).to_numpy(dtype=np.int8, copy=True)

elif "train_ids_ordered" in globals():
    # No ids stored with the probabilities: trust the externally supplied order,
    # but fail hard on any length or id mismatch instead of guessing.
    ids = [_norm_id(z) for z in list(globals()["train_ids_ordered"])]
    if len(ids) != len(oof_prob_all):
        raise RuntimeError("train_ids_ordered length mismatch with oof_prob. Gunakan oof_prob.csv agar alignment aman.")
    missing = [oid for oid in ids if oid not in y_map.index]
    if missing:
        raise KeyError(f"train_ids_ordered contains ids not in df_train_meta. ex={missing[:10]} missing_n={len(missing)}")
    train_ids_all = ids
    y_all = y_map.reindex(train_ids_all).to_numpy(dtype=np.int8, copy=True)

else:
    # Last resort: assume the probabilities follow the df_train_meta row order.
    if len(oof_prob_all) != len(df_train_meta):
        raise RuntimeError(
            f"Tidak bisa align y. len(oof_prob)={len(oof_prob_all)} != len(df_train_meta)={len(df_train_meta)} "
            "dan tidak ada oof_prob.csv atau train_ids_ordered."
        )
    # With duplicate ids, y_map was collapsed per id and no longer matches row order.
    if df_train_meta.index.has_duplicates:
        raise RuntimeError(
            "df_train_meta.index has duplicates, tapi oof source tidak punya object_id ordering. "
            "Solusi: simpan oof_prob.csv (object_id + oof_prob) atau sediakan train_ids_ordered."
        )
    train_ids_all = df_train_meta.index.astype(str).tolist()
    y_all = y_map.reindex(train_ids_all).to_numpy(dtype=np.int8, copy=True)

# Final sanity checks shared by all three branches.
if len(y_all) != len(oof_prob_all):
    raise RuntimeError(f"Length mismatch: y={len(y_all)} vs oof_prob={len(oof_prob_all)}")

uy = set(np.unique(y_all).tolist())
if not uy.issubset({0, 1}):
    raise ValueError(f"y must be binary 0/1. Found: {sorted(list(uy))}")

# ----------------------------
# 2b) HOLDOUT_SAFE filtering (valid only)
# ----------------------------
# Rows with a non-finite probability are treated as "no prediction available".
valid = np.isfinite(oof_prob_all)
if HOLDOUT_SAFE:
    # Evaluate on valid rows only; ids, probabilities and labels stay aligned.
    train_ids = [train_ids_all[i] for i in np.where(valid)[0]]
    oof_prob = np.clip(oof_prob_all[valid].astype(np.float32), 0.0, 1.0)
    y = y_all[valid].astype(np.int8)
else:
    # Legacy mode: keep every row and force non-finite values to 0.0 / 1.0.
    train_ids = train_ids_all
    oof_prob = np.clip(np.nan_to_num(oof_prob_all, nan=0.0, posinf=1.0, neginf=0.0).astype(np.float32), 0.0, 1.0)
    y = y_all.astype(np.int8)

# N_all counts rows before filtering, N counts the rows actually evaluated.
N_all = int(len(y_all))
N = int(len(y))
pos = int((y == 1).sum())
neg = int((y == 0).sum())

print(f"[Eval] OOF source={oof_src} | target_col={TARGET_COL}")
# max(N, 1) avoids a division by zero when no row survived the filter.
print(f"[Eval] rows: valid={N:,} / total={N_all:,} | pos={pos:,} neg={neg:,} pos%={pos/max(N,1)*100:.6f}%")
if HOLDOUT_SAFE and (N < N_all):
    print(f"[Eval] HOLDOUT_SAFE dropped non-finite rows: {N_all - N} rows")

# ----------------------------
# 3) Ranking metrics (threshold-free)
# ----------------------------
# ROC-AUC and PR-AUC are optional: they stay None when scikit-learn is not
# importable, when only one class is present, or when scoring fails.
roc_auc = None
pr_auc = None
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    # Check N first so an empty label array never reaches y.max()/y.min().
    if N > 1 and (y.max() == 1) and (y.min() == 0):
        # Assign both together: the report formats pr_auc whenever roc_auc is
        # set, so a half-filled pair would crash it later.
        roc_auc, pr_auc = (
            float(roc_auc_score(y, oof_prob)),
            float(average_precision_score(y, oof_prob)),
        )
except Exception as exc:
    # Still best-effort, but say why the metrics are missing.
    print(f"[WARN] ranking metrics skipped: {type(exc).__name__}: {exc}")

# ----------------------------
# 4) Threshold candidates (grid + quantiles + unique + nextafter + extras)
# ----------------------------
# Fixed grid: finer steps (0.0025) in both tails, coarser (0.005) in the middle.
grid = np.concatenate([
    np.linspace(0.00, 0.10, 41, dtype=np.float32),
    np.linspace(0.10, 0.90, 161, dtype=np.float32),
    np.linspace(0.90, 1.00, 41, dtype=np.float32),
]).astype(np.float32)

# Data-driven candidates: 999 quantiles of the observed probabilities.
qs = np.linspace(0.001, 0.999, 999, dtype=np.float32)
try:
    quant_thr = np.quantile(oof_prob, qs).astype(np.float32) if N > 0 else np.array([], dtype=np.float32)
except Exception:
    quant_thr = np.array([], dtype=np.float32)

# Every distinct probability is a candidate; evenly subsample when there are too many.
uniq = np.unique(oof_prob.astype(np.float32))
if len(uniq) > 8000:
    take = np.linspace(0, len(uniq) - 1, 8000, dtype=int)
    uniq = uniq[take].astype(np.float32)

# Next representable float32 toward 1.0 for each distinct value (optional).
uniq_up = np.nextafter(uniq, np.float32(1.0)).astype(np.float32) if (ADD_NEXTAFTER_CANDIDATES and len(uniq) > 0) else np.array([], dtype=np.float32)

# prevalence-match threshold (for ge): choose thr so predicted positives roughly == pos
if pos > 0:
    p_sorted_tmp = np.sort(oof_prob)[::-1]
    # The pos-th highest probability (index clamped to the last row).
    thr_prev = float(p_sorted_tmp[min(pos - 1, N - 1)])
else:
    thr_prev = 1.0

# Always include the trivial thresholds plus any best threshold left in globals.
extra = [0.0, 0.5, 1.0, float(thr_prev)]
for cand_name in ["BEST_THR", "OOF_BEST_THR_F1", "BEST_THR_F1"]:
    if cand_name in globals() and globals()[cand_name] is not None:
        try:
            extra.append(float(globals()[cand_name]))
        except Exception:
            pass

# Merge all sources, clip into [0, 1], de-duplicate (np.unique also sorts ascending).
thr_candidates = np.unique(
    np.clip(np.concatenate([grid, quant_thr, uniq, uniq_up, np.array(extra, dtype=np.float32)]), 0.0, 1.0)
).astype(np.float32)

# Hard cap on the sweep size; evenly subsample the sorted candidates.
if len(thr_candidates) > 20000:
    take = np.linspace(0, len(thr_candidates) - 1, 20000, dtype=int)
    thr_candidates = thr_candidates[take].astype(np.float32)

# ----------------------------
# 5) FAST sweep using sorted probabilities
# ----------------------------
# Sort rows by descending probability so "predict the top-k rows as positive"
# becomes a prefix of the sorted arrays.
ord_desc = np.argsort(-oof_prob)
p_sorted = oof_prob[ord_desc]
y_sorted = y[ord_desc].astype(np.int8)

# prefix[i] = number of positives / negatives among the first i + 1 sorted rows.
pos_prefix = np.cumsum(y_sorted == 1).astype(np.int64)
neg_prefix = np.cumsum(y_sorted == 0).astype(np.int64)
# Guarded so an empty evaluation set does not index an empty array.
pos_total = int(pos_prefix[-1]) if N > 0 else 0
neg_total = int(neg_prefix[-1]) if N > 0 else 0

def _metrics_from_counts(tp, fp, fn, tn):
    """Compute classification metrics from confusion-matrix count arrays.

    Args:
        tp, fp, fn, tn: NumPy arrays of equal shape holding true-positive,
            false-positive, false-negative and true-negative counts.

    Returns:
        Tuple ``(f1, f0.5, f2, precision, recall, accuracy,
        balanced_accuracy, mcc)`` of float64 arrays with the input shape.
        Ratios with a zero denominator come out as 0.
    """
    tp = tp.astype(np.float64)
    fp = fp.astype(np.float64)
    fn = fn.astype(np.float64)
    tn = tn.astype(np.float64)

    prec = _safe_div(tp, tp + fp)
    rec = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * prec * rec, prec + rec)

    def fbeta(prec, rec, beta):
        # General F-beta: beta > 1 favours recall, beta < 1 favours precision.
        b2 = beta * beta
        return _safe_div((1.0 + b2) * prec * rec, b2 * prec + rec)

    f05 = fbeta(prec, rec, 0.5)
    f2 = fbeta(prec, rec, 2.0)

    acc = _safe_div(tp + tn, tp + fp + fn + tn)
    tpr = _safe_div(tp, tp + fn)
    tnr = _safe_div(tn, tn + fp)
    bacc = 0.5 * (tpr + tnr)

    num = tp * tn - fp * fn
    den = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    # Divide only where the denominator is positive. np.where would evaluate
    # the division everywhere and emit divide/invalid RuntimeWarnings for the
    # degenerate thresholds (all-positive or all-negative predictions).
    mcc = np.zeros_like(num)
    np.divide(num, np.sqrt(den), out=mcc, where=den > 0)

    return f1, f05, f2, prec, rec, acc, bacc, mcc

def _sweep(rule: str):
    """Evaluate every threshold in ``thr_candidates`` under one decision rule.

    Args:
        rule: ``"ge"`` predicts positive when ``prob >= thr``; ``"gt"``
            predicts positive when ``prob > thr``.

    Returns:
        DataFrame with one row per candidate threshold: the threshold, the
        rule, the metrics from ``_metrics_from_counts``, the confusion counts
        and ``pos_pred`` (number of predicted positives).

    Raises:
        ValueError: ``rule`` is neither ``"ge"`` nor ``"gt"``.
    """
    if rule == "ge":
        side = "right"   # ties at the threshold count as positive
    elif rule == "gt":
        side = "left"    # ties at the threshold count as negative
    else:
        raise ValueError("rule must be 'ge' or 'gt'")

    # k = number of predicted positives. p_sorted is descending, so its
    # negation is ascending, which is what searchsorted requires.
    k = np.searchsorted(-p_sorted, -thr_candidates.astype(np.float32), side=side).astype(np.int64)
    k = np.clip(k, 0, N).astype(np.int64)

    # Prepend 0 so cum[k] is the count among the top-k rows. This is valid for
    # k == 0 and also when the evaluation set is empty, where indexing the raw
    # prefix arrays with k - 1 would raise IndexError.
    pos_cum = np.concatenate(([0], pos_prefix)).astype(np.int64)
    neg_cum = np.concatenate(([0], neg_prefix)).astype(np.int64)

    tp = pos_cum[k]
    fp = neg_cum[k]
    fn = (pos_total - tp).astype(np.int64)
    tn = (neg_total - fp).astype(np.int64)

    f1, f05, f2, prec, rec, acc, bacc, mcc = _metrics_from_counts(tp, fp, fn, tn)

    return pd.DataFrame({
        "thr": thr_candidates.astype(np.float32),
        "rule": rule,
        "f1": f1.astype(np.float32),
        "f0.5": f05.astype(np.float32),
        "f2": f2.astype(np.float32),
        "precision": prec.astype(np.float32),
        "recall": rec.astype(np.float32),
        "acc": acc.astype(np.float32),
        "bacc": bacc.astype(np.float32),
        "mcc": mcc.astype(np.float32),
        "tp": tp, "fp": fp, "fn": fn, "tn": tn,
        "pos_pred": k.astype(np.int64),
    })

# The ">=" table is always built; the strict ">" table only on request.
thr_ge = _sweep("ge")
if DO_BOTH_RULES:
    thr_gt = _sweep("gt")
else:
    thr_gt = None

def _pick_best(df, primary, tie_cols):
    sort_cols = [primary] + tie_cols
    asc = [False] * len(sort_cols)
    return df.sort_values(sort_cols, ascending=asc).iloc[0]

def _eval_at(thr, rule):
    """Compute confusion counts and metrics at a single threshold.

    Args:
        thr: Threshold; compared against the sorted probabilities as float32.
        rule: ``"ge"`` for ``prob >= thr``; any other value is treated as the
            strict ``prob > thr`` rule.

    Returns:
        Dict of plain Python numbers: threshold, rule, confusion counts,
        ``pos_pred`` and the precision/recall/F/accuracy/MCC metrics.
    """
    thr = float(thr)
    # "right" keeps ties at the threshold inside the predicted-positive prefix.
    side = "right" if rule == "ge" else "left"
    k0 = int(np.searchsorted(-p_sorted, -np.float32(thr), side=side))
    k0 = max(0, min(k0, N))

    if k0 > 0:
        tp0 = int(pos_prefix[k0 - 1])
        fp0 = int(neg_prefix[k0 - 1])
    else:
        tp0 = 0
        fp0 = 0
    fn0 = int(pos_total - tp0)
    tn0 = int(neg_total - fp0)

    # Denominators are floored at 1 so empty groups give 0 instead of an error.
    actual_pos = tp0 + fn0
    actual_neg = tn0 + fp0
    p0 = tp0 / max(tp0 + fp0, 1)
    r0 = tp0 / max(actual_pos, 1)

    def _fbeta(beta_sq):
        # F-beta from precision/recall; 0.0 when both are zero.
        denom = beta_sq * p0 + r0
        if denom == 0:
            return 0.0
        return (1.0 + beta_sq) * p0 * r0 / denom

    f10 = _fbeta(1.0)
    f05 = _fbeta(0.25)
    f2 = _fbeta(4.0)

    acc0 = (tp0 + tn0) / max(tp0 + fp0 + fn0 + tn0, 1)
    tpr0 = tp0 / max(actual_pos, 1)
    tnr0 = tn0 / max(actual_neg, 1)
    bacc0 = 0.5 * (tpr0 + tnr0)

    den0 = (tp0 + fp0) * (tp0 + fn0) * (tn0 + fp0) * (tn0 + fn0)
    if den0 <= 0:
        mcc0 = 0.0
    else:
        mcc0 = (tp0 * tn0 - fp0 * fn0) / math.sqrt(den0)

    return {
        "thr": thr, "rule": rule,
        "tp": tp0, "fp": fp0, "fn": fn0, "tn": tn0, "pos_pred": k0,
        "precision": float(p0), "recall": float(r0),
        "f1": float(f10), "f0.5": float(f05), "f2": float(f2),
        "acc": float(acc0), "bacc": float(bacc0), "mcc": float(mcc0),
    }

# Reference point: metrics at the conventional 0.5 cut-off.
base_ge = _eval_at(0.5, "ge")
base_gt = _eval_at(0.5, "gt") if DO_BOTH_RULES else None

# Bests (rule=ge)
# Each pick maximizes its primary metric; the list gives the tie-break order.
best_f1_ge  = _pick_best(thr_ge, "f1",   ["mcc","bacc","recall","precision","acc"])
best_f05_ge = _pick_best(thr_ge, "f0.5", ["precision","mcc","f1","acc"])
best_f2_ge  = _pick_best(thr_ge, "f2",   ["recall","mcc","f1","bacc","acc"])

BEST_THR_GE_F1  = float(best_f1_ge["thr"])
BEST_THR_GE_F05 = float(best_f05_ge["thr"])
BEST_THR_GE_F2  = float(best_f2_ge["thr"])

# Re-evaluate at the chosen thresholds to get plain-Python dicts for the report/JSON.
best_ge_f1  = _eval_at(BEST_THR_GE_F1, "ge")
best_ge_f05 = _eval_at(BEST_THR_GE_F05, "ge")
best_ge_f2  = _eval_at(BEST_THR_GE_F2, "ge")

# Bests (rule=gt)
if DO_BOTH_RULES:
    best_f1_gt  = _pick_best(thr_gt, "f1",   ["mcc","bacc","recall","precision","acc"])
    BEST_THR_GT_F1 = float(best_f1_gt["thr"])
    best_gt_f1 = _eval_at(BEST_THR_GT_F1, "gt")

    # gt -> ge equivalent (so downstream can still use prob >= thr)
    # NOTE(review): nextafter steps toward 1.0, so a gt-threshold of exactly
    # 1.0 maps to itself and the two rules would then differ — confirm this
    # edge case cannot matter.
    BEST_THR_GT_F1_GE_EQUIV = float(np.nextafter(np.float32(BEST_THR_GT_F1), np.float32(1.0)))
else:
    BEST_THR_GT_F1 = None
    BEST_THR_GT_F1_GE_EQUIV = None
    best_gt_f1 = None

# Default export for downstream (keep old variable names)
BEST_THR_F1  = BEST_THR_GE_F1
BEST_THR_F05 = BEST_THR_GE_F05
BEST_THR_F2  = BEST_THR_GE_F2
BEST_THR     = BEST_THR_GE_F1

# ----------------------------
# 6) Report + tables
# ----------------------------
# Console summary; the same numbers are written to disk in the next section.
print("\nEVALUATION (OOF) — Precision/Recall/F-scores (+BACC/MCC)")
# The AUC line is skipped when the ranking metrics were not computed.
if roc_auc is not None:
    print(f"- ROC-AUC={roc_auc:.6f} | PR-AUC={pr_auc:.6f}")
print(f"- prevalence-match thr (ge) ~ {thr_prev:.6f}")

print("\nBaseline @ thr=0.5")
print(f"- rule=ge: F1={base_ge['f1']:.6f} | P={base_ge['precision']:.6f} | R={base_ge['recall']:.6f} | "
      f"ACC={base_ge['acc']:.6f} | BACC={base_ge['bacc']:.6f} | MCC={base_ge['mcc']:.6f} | pos_pred={base_ge['pos_pred']}")
if DO_BOTH_RULES:
    print(f"- rule=gt: F1={base_gt['f1']:.6f} | P={base_gt['precision']:.6f} | R={base_gt['recall']:.6f} | "
          f"ACC={base_gt['acc']:.6f} | BACC={base_gt['bacc']:.6f} | MCC={base_gt['mcc']:.6f} | pos_pred={base_gt['pos_pred']}")

print(f"\nBEST (rule=ge) — default downstream: pred = prob >= thr")
print(f"- BEST-F1   @ thr={best_ge_f1['thr']:.6f} | F1={best_ge_f1['f1']:.6f} | P={best_ge_f1['precision']:.6f} | R={best_ge_f1['recall']:.6f} | pos_pred={best_ge_f1['pos_pred']}")
print(f"- BEST-F0.5 @ thr={best_ge_f05['thr']:.6f} | F0.5={best_ge_f05['f0.5']:.6f} | P={best_ge_f05['precision']:.6f} | R={best_ge_f05['recall']:.6f}")
print(f"- BEST-F2   @ thr={best_ge_f2['thr']:.6f} | F2={best_ge_f2['f2']:.6f} | P={best_ge_f2['precision']:.6f} | R={best_ge_f2['recall']:.6f}")

if DO_BOTH_RULES:
    print(f"\nBEST (rule=gt) — strict boundary (>)")
    print(f"- BEST-F1 @ thr={best_gt_f1['thr']:.6f} (gt) | ge_equiv={BEST_THR_GT_F1_GE_EQUIV:.6f} | "
          f"F1={best_gt_f1['f1']:.6f} | pos_pred={best_gt_f1['pos_pred']}")

# Combine + sort for top view
# Rows from both rules share one table; the "rule" column tells them apart.
thr_all = pd.concat([thr_ge, thr_gt], ignore_index=True) if DO_BOTH_RULES else thr_ge.copy()
thr_sorted = thr_all.sort_values(
    ["f1","mcc","bacc","recall","precision"],
    ascending=[False, False, False, False, False]
).reset_index(drop=True)

print("\nTop 10 by F1 (mixed rules):")
for i in range(min(10, len(thr_sorted))):
    r = thr_sorted.iloc[i]
    print(f"{i+1:02d}. rule={r['rule']} thr={float(r['thr']):.6f} | f1={float(r['f1']):.6f} | "
          f"P={float(r['precision']):.6f} R={float(r['recall']):.6f} | mcc={float(r['mcc']):.6f} bacc={float(r['bacc']):.6f} | "
          f"pos_pred={int(r['pos_pred'])}")

# ----------------------------
# 7) Save artifacts
# ----------------------------
# All artifacts are written next to the OOF predictions.
out_txt   = OOF_DIR / "eval_report.txt"
out_csv   = OOF_DIR / "eval_threshold_table.csv"
out_csv_t = OOF_DIR / "eval_threshold_table_top500.csv"
out_json  = OOF_DIR / "eval_summary.json"

# Full sweep table plus a short version holding the 500 best rows by F1.
thr_sorted.to_csv(out_csv, index=False)
thr_sorted.head(500).to_csv(out_csv_t, index=False)

# Machine-readable summary; "default_best_thr" -> "thr" carries BEST_THR
# (the rule=ge F1 optimum).
payload = {
    "version": "v3.4",
    "source": oof_src,
    "target_col": TARGET_COL,
    "n_total_rows": int(N_all),
    "n_valid_rows": int(N),
    "pos_valid": int(pos),
    "neg_valid": int(neg),
    "roc_auc_valid_only": roc_auc,
    "pr_auc_valid_only": pr_auc,
    "prevalence_match_thr_ge": float(thr_prev),
    "baseline_thr_0p5": {"ge": base_ge, "gt": base_gt},
    "best_ge": {"f1": best_ge_f1, "f0.5": best_ge_f05, "f2": best_ge_f2},
    "best_gt": {"f1": best_gt_f1, "ge_equiv_for_downstream_using_ge": {"f1": BEST_THR_GT_F1_GE_EQUIV}},
    "default_best_thr": {"metric": "f1", "rule": "ge", "thr": float(BEST_THR)},
    "switches": {
        "HOLDOUT_SAFE": bool(HOLDOUT_SAFE),
        "DO_BOTH_RULES": bool(DO_BOTH_RULES),
        "ADD_NEXTAFTER_CANDIDATES": bool(ADD_NEXTAFTER_CANDIDATES),
    },
    "paths": {"report": str(out_txt), "table": str(out_csv), "table_top500": str(out_csv_t), "summary": str(out_json)},
}

# text report (concise but clear)
lines = []
lines.append("OOF Evaluation Report (v3.4)")
lines.append(f"source={oof_src} | target_col={TARGET_COL}")
lines.append(f"valid_rows={N} / total_rows={N_all} | pos={pos} | neg={neg} | pos%={pos/max(N,1)*100:.10f}%")
if roc_auc is not None:
    lines.append(f"ROC-AUC(valid-only)={roc_auc:.10f} | PR-AUC(valid-only)={pr_auc:.10f}")
lines.append(f"prevalence_match_thr_ge={thr_prev:.10f}")
lines.append("")
lines.append("Baseline @ thr=0.5")
lines.append(f"ge: F1={base_ge['f1']:.10f} P={base_ge['precision']:.10f} R={base_ge['recall']:.10f} "
             f"ACC={base_ge['acc']:.10f} BACC={base_ge['bacc']:.10f} MCC={base_ge['mcc']:.10f} pos_pred={base_ge['pos_pred']}")
if DO_BOTH_RULES and base_gt is not None:
    lines.append(f"gt: F1={base_gt['f1']:.10f} P={base_gt['precision']:.10f} R={base_gt['recall']:.10f} "
                 f"ACC={base_gt['acc']:.10f} BACC={base_gt['bacc']:.10f} MCC={base_gt['mcc']:.10f} pos_pred={base_gt['pos_pred']}")
lines.append("")
lines.append(f"BEST (ge/F1) thr={BEST_THR_GE_F1:.10f} | F1={best_ge_f1['f1']:.10f} P={best_ge_f1['precision']:.10f} R={best_ge_f1['recall']:.10f}")
if DO_BOTH_RULES and best_gt_f1 is not None:
    lines.append(f"BEST (gt/F1) thr={BEST_THR_GT_F1:.10f} | ge_equiv={BEST_THR_GT_F1_GE_EQUIV:.10f} | F1={best_gt_f1['f1']:.10f}")
lines.append("")
lines.append("Top 10 by F1 (mixed rules):")
for i in range(min(10, len(thr_sorted))):
    r = thr_sorted.iloc[i]
    lines.append(f"{i+1:02d}. rule={r['rule']} thr={float(r['thr']):.10f} f1={float(r['f1']):.10f} "
                 f"P={float(r['precision']):.10f} R={float(r['recall']):.10f} "
                 f"mcc={float(r['mcc']):.10f} bacc={float(r['bacc']):.10f} pos_pred={int(r['pos_pred'])}")

out_txt.write_text("\n".join(lines) + "\n", encoding="utf-8")
out_json.write_text(json.dumps(payload, indent=2), encoding="utf-8")

print("\nSaved:")
print(f"- {out_txt}")
print(f"- {out_csv}")
print(f"- {out_csv_t}")
print(f"- {out_json}")

# Publish results under stable global names for later cells.
globals().update({
    "BEST_THR": float(BEST_THR),
    "BEST_THR_F1": float(BEST_THR_F1),
    "BEST_THR_F05": float(BEST_THR_F05),
    "BEST_THR_F2": float(BEST_THR_F2),

    "BEST_THR_GE_F1": float(BEST_THR_GE_F1),
    "BEST_THR_GE_F05": float(BEST_THR_GE_F05),
    "BEST_THR_GE_F2": float(BEST_THR_GE_F2),

    "BEST_THR_GT_F1": (None if BEST_THR_GT_F1 is None else float(BEST_THR_GT_F1)),
    "BEST_THR_GT_F1_GE_EQUIV": (None if BEST_THR_GT_F1_GE_EQUIV is None else float(BEST_THR_GT_F1_GE_EQUIV)),

    "thr_table_eval": thr_sorted,
    "EVAL_REPORT_PATH": out_txt,
    "EVAL_TABLE_PATH": out_csv,
    "EVAL_TABLE_TOP500_PATH": out_csv_t,
    "EVAL_SUMMARY_PATH": out_json,
    "OOF_AUC_VALID_ONLY": roc_auc,
    "OOF_AP_VALID_ONLY": pr_auc,
})

gc.collect()


# Submission Build

In [None]:
# ============================================================
# STAGE 11 — Submission Build (ONE CELL) — REVISI FULL v3.5
#
# Upgrade v3.5:
# - Threshold JSON fallback FIX (supports Stage9 threshold_tuning.json & Eval v3.4 summary)
# - Auto-handle if pred file already 0/1 (test_pred_01.csv)
# - Stronger diagnostics + strict order = sample_submission
#
# Output:
# - /kaggle/working/submission.csv
# - SUB_DIR/submission.csv (copy)
# ============================================================

import gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require STAGE 0 globals
# ----------------------------
# Fail fast when the setup cell has not been run in this session.
for need in ["PATHS", "SUB_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 dulu (setup).")

sample_path = Path(PATHS["SAMPLE_SUB"])
if not sample_path.exists():
    raise FileNotFoundError(f"Missing sample_submission.csv: {sample_path}")

# IMPORTANT: dtype object_id=str to prevent ID corruption
df_sub = pd.read_csv(sample_path, dtype={"object_id": str})
# The sample submission must carry both required columns.
if not {"object_id", "prediction"}.issubset(df_sub.columns):
    raise ValueError(f"sample_submission must have columns object_id,prediction. Found: {list(df_sub.columns)}")

# ----------------------------
# Helpers
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _as_1d_float32(arr):
    a = np.asarray(arr)
    if a.dtype == object and a.ndim == 0:
        try:
            a = np.asarray(a.item())
        except Exception:
            pass
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 0:
        return a.reshape(1)          # IMPORTANT: always 1D
    if a.ndim > 1:
        a = a.reshape(-1)
    return a

def _sanitize_prob(p):
    p = np.asarray(p, dtype=np.float32)
    p = np.nan_to_num(p, nan=0.0, posinf=1.0, neginf=0.0)
    p = np.clip(p, 0.0, 1.0)
    return p.astype(np.float32)

def _load_ids_npy(path: Path):
    """Load an ids ``.npy`` file (pickle disabled) as a list of normalized id strings."""
    raw = np.load(path, allow_pickle=False)
    if hasattr(raw, "tolist"):
        values = raw.tolist()
    else:
        values = list(raw)
    return list(map(_norm_id, values))

def _try_load_json(p):
    try:
        p = Path(p)
        if not p.exists():
            return None
        with open(p, "r", encoding="utf-8") as f:
            obj = json.load(f)
        return obj if isinstance(obj, dict) else None
    except Exception:
        return None

def _try_load_stage10_cfg():
    p = globals().get("TEST_INFER_CFG_PATH", None)
    if p is None:
        return None
    return _try_load_json(p)

def _detect_prob_col(df):
    # Prefer explicit names (probability)
    for c in ["prob", "proba", "oof_prob", "pred", "p"]:
        if c in df.columns:
            return c
    # Some users store probability in "prediction"
    if "prediction" in df.columns:
        # if it looks float-ish (not only 0/1), treat as prob
        s = pd.to_numeric(df["prediction"], errors="coerce")
        if s.notna().mean() > 0.95:
            u = set(np.unique(s.dropna().astype(float).to_numpy()).tolist())
            if not u.issubset({0.0, 1.0}):
                return "prediction"

    # else: pick a single mostly-numeric column besides object_id
    cand = [c for c in df.columns if c != "object_id"]
    floatish = []
    for c in cand:
        s = pd.to_numeric(df[c], errors="coerce")
        if s.notna().mean() > 0.95:
            floatish.append(c)
    if len(floatish) == 1:
        return floatish[0]
    return None

def _detect_binary_col(df):
    # Prefer explicit binary prediction column
    for c in ["prediction", "pred", "label", "y"]:
        if c in df.columns:
            s = pd.to_numeric(df[c], errors="coerce")
            if s.notna().mean() > 0.95:
                u = set(np.unique(s.dropna().astype(int).to_numpy()).tolist())
                if u.issubset({0, 1}):
                    return c
    return None

def _load_pred_df():
    """
    Locate the test predictions and return them in a uniform shape.

    Return (df_pred, mode, src_str)
    mode:
      - "prob": df_pred has columns object_id, prob (float [0,1])
      - "bin" : df_pred has columns object_id, prediction (int 0/1)
    src_str is a short label describing which source was used.
    Priority:
      A) globals: test_ids + test_prob_ens
      B) STAGE10 config json -> outputs.test_pred_01_csv (bin) OR outputs.test_prob_ens_csv (prob)
      C) csv fallbacks
      D) npy fallback: FIX_DIR/test_ids.npy + test_prob_ens.npy

    Raises RuntimeError when a located file has no usable column, when array
    lengths disagree, or when no source is available at all.
    """
    # ---- A) globals ----
    # Used only when both globals exist, are non-empty and have equal length;
    # otherwise fall through silently to the file-based sources.
    if ("test_prob_ens" in globals()) and (globals()["test_prob_ens"] is not None) and \
       ("test_ids" in globals()) and (globals()["test_ids"] is not None):
        ids = [_norm_id(x) for x in list(globals()["test_ids"])]
        prob = _sanitize_prob(_as_1d_float32(globals()["test_prob_ens"]))
        if len(ids) == len(prob) and len(ids) > 0:
            return pd.DataFrame({"object_id": ids, "prob": prob}), "prob", "globals(test_ids + test_prob_ens)"

    # ---- B) STAGE 10 config json ----
    cfg = _try_load_stage10_cfg()
    if isinstance(cfg, dict):
        # Tolerate a missing or malformed "outputs" entry.
        out = cfg.get("outputs", {}) if isinstance(cfg.get("outputs", {}), dict) else {}
        # prefer binary if exists
        pred01 = out.get("test_pred_01_csv", None)
        if pred01:
            p = Path(pred01)
            if p.exists():
                df = pd.read_csv(p, dtype={"object_id": str})
                if "object_id" in df.columns:
                    colb = _detect_binary_col(df)
                    if colb is None:
                        raise RuntimeError(f"Cannot detect binary column in: {p} | cols={list(df.columns)}")
                    df = df.copy()
                    df["object_id"] = df["object_id"].apply(_norm_id)
                    # Unparseable entries become 0; values are forced into {0, 1}.
                    y = pd.to_numeric(df[colb], errors="coerce").fillna(0).astype(int).clip(0,1).to_numpy()
                    return pd.DataFrame({"object_id": df["object_id"].tolist(), "prediction": y}), "bin", f"stage10_cfg_bin({p})"

        csvp = out.get("test_prob_ens_csv", None)
        if csvp:
            p = Path(csvp)
            if p.exists():
                df = pd.read_csv(p, dtype={"object_id": str})
                if "object_id" in df.columns:
                    colp = _detect_prob_col(df)
                    if colp is None:
                        raise RuntimeError(f"Cannot detect prob column in: {p} | cols={list(df.columns)}")
                    df = df.copy()
                    df["object_id"] = df["object_id"].apply(_norm_id)
                    prob = _sanitize_prob(_as_1d_float32(df[colp].to_numpy()))
                    if len(prob) != len(df):
                        raise RuntimeError(f"CSV prob length mismatch: {p}")
                    return pd.DataFrame({"object_id": df["object_id"].tolist(), "prob": prob}), "prob", f"stage10_cfg_csv({p})"

    # ---- C) csv fallback ----
    # art_dir / preds_dir are also used by the npy fallback (D) below.
    art_dir = Path(globals().get("ART_DIR", "/kaggle/working"))
    preds_dir = art_dir / "preds"

    # Explicit paths from globals are tried before the conventional locations.
    cand_csv = []
    if "TEST_PRED01_PATH" in globals() and globals()["TEST_PRED01_PATH"] is not None:
        cand_csv.append(Path(globals()["TEST_PRED01_PATH"]))
    if "TEST_PROB_CSV_PATH" in globals() and globals()["TEST_PROB_CSV_PATH"] is not None:
        cand_csv.append(Path(globals()["TEST_PROB_CSV_PATH"]))

    cand_csv += [
        preds_dir / "test_pred_01.csv",
        preds_dir / "test_prob_ens.csv",
        art_dir / "test_pred_01.csv",
        art_dir / "test_prob_ens.csv",
    ]

    # The first existing file with an object_id column decides the outcome:
    # it is either returned or raises; later candidates are not consulted.
    for p in cand_csv:
        if p.exists():
            df = pd.read_csv(p, dtype={"object_id": str})
            if "object_id" not in df.columns:
                continue

            # If binary file
            colb = _detect_binary_col(df)
            if colb is not None:
                df = df.copy()
                df["object_id"] = df["object_id"].apply(_norm_id)
                y = pd.to_numeric(df[colb], errors="coerce").fillna(0).astype(int).clip(0,1).to_numpy()
                return pd.DataFrame({"object_id": df["object_id"].tolist(), "prediction": y}), "bin", f"csv_bin({p})"

            # Else probability file
            colp = _detect_prob_col(df)
            if colp is None:
                raise RuntimeError(f"Cannot detect prob/binary column in: {p} | cols={list(df.columns)}")
            df = df.copy()
            df["object_id"] = df["object_id"].apply(_norm_id)
            prob = _sanitize_prob(_as_1d_float32(df[colp].to_numpy()))
            if len(prob) != len(df):
                raise RuntimeError(f"CSV prob length mismatch: {p}")
            return pd.DataFrame({"object_id": df["object_id"].tolist(), "prob": prob}), "prob", f"csv_prob({p})"

    # ---- D) npy fallback ----
    # Raw arrays carry no ids, so ids come from a separate test_ids.npy file.
    fix_dir = Path(globals().get("FIX_DIR", "/kaggle/working/mallorn_run/artifacts/fixed_seq"))
    p_ids = fix_dir / "test_ids.npy"
    if not p_ids.exists():
        raise RuntimeError("Missing test_ids. Pastikan STAGE 6 membuat fixed_seq/test_ids.npy atau STAGE 10 export test_ids.")
    ids = _load_ids_npy(p_ids)
    if len(ids) == 0:
        raise RuntimeError("test_ids.npy kosong.")

    cand_npy = []
    if "TEST_PROB_ENS_PATH" in globals() and globals()["TEST_PROB_ENS_PATH"] is not None:
        cand_npy.append(Path(globals()["TEST_PROB_ENS_PATH"]))
    cand_npy += [preds_dir / "test_prob_ens.npy", art_dir / "test_prob_ens.npy"]

    # Take the first probability file that exists.
    prob = None
    used = None
    for p in cand_npy:
        if p.exists():
            prob = _sanitize_prob(_as_1d_float32(np.load(p, allow_pickle=False)))
            used = p
            break
    if prob is None:
        raise RuntimeError("Missing test_prob_ens. Jalankan STAGE 10 dulu (Test Inference).")
    # Ids and probabilities come from different files, so lengths must agree.
    if len(prob) != len(ids):
        raise RuntimeError(f"Length mismatch (NPY): test_prob={len(prob)} vs test_ids={len(ids)}")

    return pd.DataFrame({"object_id": ids, "prob": prob}), "prob", f"npy({used}) + ids({p_ids})"

def _load_best_threshold_fallback():
    """
    Look up a previously tuned decision threshold in JSON artifacts on disk.

    Candidate files, in the order they are tried:
      1. the paths stored in the globals THR_JSON_PATH and EVAL_SUMMARY_PATH
         (each only when the global exists and is not None)
      2. OOF_DIR/threshold_tuning.json
      3. OOF_DIR/eval_summary.json
    OOF_DIR falls back to ART_DIR/"oof" and ART_DIR to "/kaggle/working" when
    those globals are not defined.

    Inside every file whose top-level JSON value is a dict, these locations are
    probed in order:
      - payload["default_best_thr"]["thr"]            (any file name)
      - payload["best_ge"]["best_thr_f1"]["thr"]      (threshold_tuning.json only)
      - payload["best_ge"]["f1"]["thr"]               (eval_summary.json only)
      - payload["BEST_THR"], payload["BEST_THR_F1"], payload["BEST_THR_GE_F1"]

    A value is accepted only when it converts to a *finite* float. Python's
    json parser turns the tokens NaN / Infinity into float nan / inf; such a
    value is treated like a missing key and probing continues, so a broken
    entry can never be handed back as the threshold.

    Returns:
        float | None: the first usable threshold, or None when no candidate
        file yields one.
    """
    art_dir = Path(globals().get("ART_DIR", "/kaggle/working"))
    oof_dir = Path(globals().get("OOF_DIR", art_dir / "oof"))

    # Explicitly configured paths win over the conventional OOF_DIR locations.
    cand = []
    for k in ["THR_JSON_PATH", "EVAL_SUMMARY_PATH"]:
        if globals().get(k) is not None:
            cand.append(Path(globals()[k]))
    cand += [oof_dir / "threshold_tuning.json", oof_dir / "eval_summary.json"]

    # Nested location that is specific to each known artifact file name.
    nested_by_name = {
        "threshold_tuning.json": ["best_ge", "best_thr_f1", "thr"],
        "eval_summary.json": ["best_ge", "f1", "thr"],
    }
    flat_keys = ["BEST_THR", "BEST_THR_F1", "BEST_THR_GE_F1"]

    def _dig(obj, path_list):
        """Follow `path_list` through nested dicts; None if any hop is missing."""
        cur = obj
        for key in path_list:
            if not isinstance(cur, dict) or key not in cur:
                return None
            cur = cur[key]
        return cur

    def _as_finite_float(value):
        """Return `value` as a finite float, or None when that is impossible."""
        if value is None:
            return None
        try:
            out = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        return out if np.isfinite(out) else None

    for p in cand:
        obj = _try_load_json(p)
        if not isinstance(obj, dict):
            continue

        # Probe order: generic nested key, file-specific nested key, flat keys.
        probes = [["default_best_thr", "thr"]]
        specific = nested_by_name.get(Path(p).name)
        if specific is not None:
            probes.append(specific)
        probes.extend([k] for k in flat_keys)

        for path_list in probes:
            found = _as_finite_float(_dig(obj, path_list))
            if found is not None:
                return found

    return None

# ----------------------------
# 1) Load prediction df
# ----------------------------
# Fetch the test predictions. `pred_mode` equal to "prob" means a probability
# column; any other mode is handled below as ready-made 0/1 labels.
# `pred_src` is only echoed in the log lines.
df_pred, pred_mode, pred_src = _load_pred_df()
if df_pred is None or df_pred.empty:
    raise RuntimeError("df_pred empty (unexpected).")

# Work on a copy and pass every id through _norm_id before any comparison.
df_pred = df_pred.copy()
df_pred["object_id"] = df_pred["object_id"].apply(_norm_id)

# strict: each object_id may occur once; report up to 10 repeated ids otherwise
_dup_rows = df_pred["object_id"].duplicated()
if _dup_rows.any():
    dup = df_pred.loc[_dup_rows, "object_id"].iloc[:10].tolist()
    raise ValueError(f"Duplicated object_id in predictions (examples): {dup}")

if pred_mode != "prob":
    # Label input: values that fail numeric parsing become 0, then everything
    # is clamped into {0, 1} and stored as int8.
    yb = pd.to_numeric(df_pred["prediction"], errors="coerce")
    yb = yb.fillna(0).astype(int).clip(0, 1).to_numpy()
    df_pred["prediction"] = yb.astype(np.int8)
    print("[Stage 11] Loaded test predictions (BINARY 0/1)")
    print(f"- source: {pred_src}")
    print(f"- N_pred={len(df_pred):,} | pos_pred={int(df_pred['prediction'].sum()):,} ({float(df_pred['prediction'].mean())*100:.6f}%)")
else:
    # Probability input: NaN/inf is a hard error, checked before sanitising.
    p = df_pred["prob"].to_numpy(dtype=np.float32, copy=False)
    _finite = np.isfinite(p)
    if not _finite.all():
        bad = int((~_finite).sum())
        raise ValueError(f"Found non-finite probabilities in test predictions: {bad} rows")
    df_pred["prob"] = _sanitize_prob(p)

    # Summary statistics are taken from the sanitised column, not the raw one.
    p2 = df_pred["prob"].to_numpy(dtype=np.float32, copy=False)
    print("[Stage 11] Loaded test predictions (PROB)")
    print(f"- source: {pred_src}")
    print(f"- N_pred={len(df_pred):,} | prob_mean={float(p2.mean()):.6f} | std={float(p2.std()):.6f} | min={float(p2.min()):.6f} | max={float(p2.max()):.6f}")

# ----------------------------
# 2) Threshold selection (only if needed)
# ----------------------------
FORCE_THR = None  # set manually to override every other source, e.g. 0.37

# Both stay None when the predictions are already binary (no threshold needed).
thr_src = None
thr = None

if pred_mode == "prob":
    # Source priority: FORCE_THR > global BEST_THR_F1 > global BEST_THR
    #                  > JSON artifacts on disk > hard default 0.5
    if FORCE_THR is not None:
        thr = float(FORCE_THR); thr_src = "FORCE_THR"
    elif "BEST_THR_F1" in globals() and globals()["BEST_THR_F1"] is not None:
        thr = float(globals()["BEST_THR_F1"]); thr_src = "globals(BEST_THR_F1)"
    elif "BEST_THR" in globals() and globals()["BEST_THR"] is not None:
        thr = float(globals()["BEST_THR"]); thr_src = "globals(BEST_THR)"
    else:
        fb = _load_best_threshold_fallback()
        if fb is not None:
            thr = float(fb); thr_src = "json_fallback(threshold_tuning/eval_summary)"
        else:
            thr = 0.5; thr_src = "default(0.5)"

    # NaN passes through np.clip unchanged and makes every `prob >= thr`
    # comparison False, which would silently produce an all-zero submission.
    # Stop here instead, in line with the strict non-finite check on the
    # probabilities themselves.
    if np.isnan(thr):
        raise ValueError(f"Threshold from {thr_src} is NaN; refusing to build a submission with it.")

    # Out-of-range values (including +/-inf) are pulled back into [0, 1].
    thr = float(np.clip(thr, 0.0, 1.0))

# ----------------------------
# 3) Align to sample_submission order + build output
# ----------------------------
# Normalise the sample_submission ids the same way as the prediction ids.
df_sub = df_sub.copy()
df_sub["object_id"] = df_sub["object_id"].apply(_norm_id)

_sub_dups = df_sub["object_id"].duplicated()
if _sub_dups.any():
    dup = df_sub.loc[_sub_dups, "object_id"].iloc[:10].tolist()
    raise ValueError(f"sample_submission has duplicate object_id (unexpected). examples={dup}")

sample_ids = pd.Index(df_sub["object_id"].tolist())
pred_ids = pd.Index(df_pred["object_id"].tolist())

# Membership masks in both directions: sample -> predictions, predictions -> sample.
in_pred = sample_ids.isin(pred_ids)
in_sample = pred_ids.isin(sample_ids)

print("[Stage 11] ID coverage check:")
print(f"- sample_in_pred = {int(in_pred.sum()):,} / {len(sample_ids):,}")
print(f"- pred_in_sample = {int(in_sample.sum()):,} / {len(pred_ids):,}")

# Coverage gaps are only reported here; missing predictions become a hard
# error after the merge below.
if not in_sample.all():
    extra = pred_ids[~in_sample][:10].tolist()
    print(f"[Stage 11] WARN: predictions contain extra ids not in sample (show 10): {extra}")
if not in_pred.all():
    miss = sample_ids[~in_pred][:10].tolist()
    print(f"[Stage 11] WARN: sample contains ids missing in predictions (show 10): {miss}")

# Merge in sample order (STRICT): the sample ids are the left side of the join.
df_out = df_sub[["object_id"]].merge(df_pred, on="object_id", how="left", sort=False)

if pred_mode == "prob":
    # A NaN after the left join means a sample id had no matching prediction row.
    _no_prob = df_out["prob"].isna()
    if _no_prob.any():
        missing_n = int(_no_prob.sum())
        miss_ids = df_out.loc[_no_prob, "object_id"].iloc[:10].tolist()
        raise ValueError(
            f"Some sample_submission object_id have no prediction: missing_n={missing_n}. Examples: {miss_ids}\n"
            "Penyebab umum: object_id kebaca numeric (leading zero hilang) atau pred tidak lengkap."
        )
    # Compare in float32 on both sides; a probability equal to thr counts as positive.
    _prob32 = df_out["prob"].to_numpy(dtype=np.float32)
    df_out["prediction"] = (_prob32 >= np.float32(thr)).astype(np.int8)
else:
    _no_label = df_out["prediction"].isna()
    if _no_label.any():
        missing_n = int(_no_label.sum())
        miss_ids = df_out.loc[_no_label, "object_id"].iloc[:10].tolist()
        raise ValueError(
            f"Some sample_submission object_id have no binary prediction: missing_n={missing_n}. Examples: {miss_ids}"
        )
    # Re-coerce to clean int8 labels in {0, 1} after the merge.
    _labels = pd.to_numeric(df_out["prediction"], errors="coerce").fillna(0)
    df_out["prediction"] = _labels.astype(int).clip(0, 1).astype(np.int8)

# Keep only the two columns that go into the submission file.
df_out = df_out[["object_id", "prediction"]]

# Strict checks: same row count, same id order, labels limited to {0, 1}.
if len(df_out) != len(df_sub):
    raise RuntimeError("submission row count mismatch with sample_submission.")
if not df_out["object_id"].equals(df_sub["object_id"]):
    raise RuntimeError("submission order mismatch with sample_submission (must be identical).")

u = set(np.unique(df_out["prediction"].to_numpy()).tolist())
if not (u <= {0, 1}):
    raise RuntimeError(f"submission prediction contains values outside {{0,1}}: {sorted(list(u))}")

pos_pred = int(df_out["prediction"].sum())
print("\n[Stage 11] SUBMISSION READY (BINARY 0/1)")
if pred_mode == "prob":
    print(f"- threshold_used={thr:.6f} | thr_source={thr_src}")
print(f"- rows={len(df_out):,} | pos_pred={pos_pred:,} ({pos_pred/max(len(df_out),1)*100:.6f}%)")

# ----------------------------
# 4) Write files
# ----------------------------
# Make sure the secondary output directory exists before writing into it.
SUB_DIR = Path(SUB_DIR)
SUB_DIR.mkdir(parents=True, exist_ok=True)

# The same frame is written twice: once to the working directory, once to SUB_DIR.
out_main = Path("/kaggle/working/submission.csv")
out_copy = SUB_DIR / "submission.csv"

for _dest in (out_main, out_copy):
    df_out.to_csv(_dest, index=False)

print(f"- wrote: {out_main}")
print(f"- copy : {out_copy}")
print("\nPreview:")
print(df_out.head(8).to_string(index=False))

# Publish the run's outcome as globals; the threshold fields are None when the
# predictions were already binary and no threshold was applied.
globals().update(
    SUBMISSION_PATH=out_main,
    SUBMISSION_COPY_PATH=out_copy,
    SUBMISSION_MODE="binary",
    SUBMISSION_THRESHOLD=(float(thr) if pred_mode == "prob" else None),
    SUBMISSION_THRESHOLD_SOURCE=(thr_src if pred_mode == "prob" else None),
    SUBMISSION_PRED_SOURCE=pred_src,
)

gc.collect()
