In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Recursively walk the Kaggle input directory and print the full path of every file found.
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/mallorn-dataset/sample_submission.csv
/kaggle/input/mallorn-dataset/test_log.csv
/kaggle/input/mallorn-dataset/train_log.csv
/kaggle/input/mallorn-dataset/split_17/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_17/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_01/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_01/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_02/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_02/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_08/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_08/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_04/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_04/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_07/train_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_07/test_full_lightcurves.csv
/kaggle/input/mallorn-dataset/split_15/train_full_lightcurves.csv
/kaggle/i

 # Kaggle CPU Environment Setup

In [3]:
# ============================================================
# STAGE 0 — Kaggle CPU Environment Setup (ONE CELL) — REVISI FULL v6
# Fokus v6:
# - Iterasi split pakai PATHS["SPLITS"] (bukan set) -> deterministik
# - Validasi: test_log == sample_submission (set & count)
# - Siapkan mapping: object_id -> split & meta
# - Guard untuk stage-stage berikutnya (resume friendly)
# ============================================================

import os, sys, gc, json, time, random, hashlib, warnings
from pathlib import Path

import numpy as np
import pandas as pd

# Silence FutureWarning and pandas' mixed-dtype CSV warning for the whole session.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# One seed shared by Python's `random` module and NumPy's legacy global RNG.
SEED = 2025
# NOTE(review): PYTHONHASHSEED is set after the interpreter has started, so it presumably
# only reaches child processes, not this kernel — confirm if hash determinism matters.
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)

# ----------------------------
# CPU thread limits (anti-freeze)
# ----------------------------
THREADS = 2
# setdefault: a value already present in the environment wins over THREADS.
# NOTE(review): numpy is imported before this point, so its BLAS/OpenMP thread pools may
# already be sized — verify that these limits actually take effect.
for k in ["OMP_NUM_THREADS","OPENBLAS_NUM_THREADS","MKL_NUM_THREADS","VECLIB_MAXIMUM_THREADS","NUMEXPR_NUM_THREADS"]:
    os.environ.setdefault(k, str(THREADS))
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# Optional PyTorch setup. The import is best-effort: if it fails for any reason,
# `torch` is set to None so later code can branch on availability.
try:
    import torch
except Exception:
    torch = None

if torch is not None:
    torch.manual_seed(SEED)
    torch.set_num_threads(THREADS)
    try:
        # The inter-op pool size can only be set once, before any inter-op work has
        # started. Re-running this cell therefore raises RuntimeError; that must not
        # make the notebook believe torch is unavailable.
        torch.set_num_interop_threads(1)
    except RuntimeError:
        pass

# ----------------------------
# PATHS
# ----------------------------
# Root of the competition dataset and the files/folders expected underneath it.
DATA_ROOT = Path("/kaggle/input/mallorn-dataset")
PATHS = {
    "DATA_ROOT": DATA_ROOT,
    "SAMPLE_SUB": DATA_ROOT / "sample_submission.csv",
    "TRAIN_LOG":  DATA_ROOT / "train_log.csv",
    "TEST_LOG":   DATA_ROOT / "test_log.csv",
    # Ordered list (split_01 .. split_20) so iteration order is deterministic.
    "SPLITS":     [DATA_ROOT / f"split_{i:02d}" for i in range(1, 21)],
}

# ----------------------------
# WORKDIR (versioned run)
# ----------------------------
WORKDIR = Path("/kaggle/working")
BASE_RUN_DIR = WORKDIR / "mallorn_run"
BASE_RUN_DIR.mkdir(parents=True, exist_ok=True)

# Pipeline configuration; hashed below to tag the run directory.
CFG = {
    # Pipeline toggles
    "USE_GBDT": True,
    "USE_DEEP_LITE": False,      # binned CNN (CPU-friendly). Enable for the hybrid setup.
    "USE_HYBRID_BLEND": False,   # blend GBDT + DEEP_LITE
    "USE_THRESHOLD_TUNING": True,

    # Feature settings
    "USE_ASINH_FLUX": True,
    "SNR_CLIP": 30.0,
    "SNR_DET_THR": 3.0,
    "SNR_STRONG_THR": 5.0,
    "MIN_FLUXERR": 1e-6,

    # Streaming
    "CHUNK_ROWS": 200_000,

    # CV
    "N_FOLDS": 5,
    "CV_STRATIFY": True,
    "CV_USE_SPLIT_COL": True,    # use group=split to avoid leakage across splits
    # Note: if sklearn does not provide StratifiedGroupKFold, fall back to GroupKFold

    # Deep-lite (binned)
    "BINS": 48,                  # time bins
    "DEEP_EPOCHS": 8,            # CPU-friendly
    "DEEP_BS": 128,
}

def _hash_cfg(d: dict) -> str:
    s = json.dumps(d, sort_keys=True)
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:10]

# Each execution gets its own run directory, tagged with the wall-clock time and the
# config fingerprint: run_<YYYYmmdd_HHMMSS>_<cfg hash>.
CFG_HASH = _hash_cfg(CFG)
RUN_TAG = time.strftime("%Y%m%d_%H%M%S")
RUN_DIR = BASE_RUN_DIR / f"run_{RUN_TAG}_{CFG_HASH}"

ART_DIR   = RUN_DIR / "artifacts"
CACHE_DIR = RUN_DIR / "cache"
OOF_DIR   = RUN_DIR / "oof"
SUB_DIR   = RUN_DIR / "submissions"
LOG_DIR   = RUN_DIR / "logs"
FEAT_DIR  = CACHE_DIR / "features"
SEQ_DIR   = CACHE_DIR / "seq"      # for the deep-lite bins

# Create the whole directory tree up front; safe to repeat.
for d in [RUN_DIR, ART_DIR, CACHE_DIR, OOF_DIR, SUB_DIR, LOG_DIR, FEAT_DIR, SEQ_DIR]:
    d.mkdir(parents=True, exist_ok=True)

def _must_exist(p: Path, what: str):
    if not p.exists():
        raise FileNotFoundError(f"[MISSING] {what}: {p}")

# Fail fast if any of the three top-level CSVs is missing.
_must_exist(PATHS["SAMPLE_SUB"], "sample_submission.csv")
_must_exist(PATHS["TRAIN_LOG"],  "train_log.csv")
_must_exist(PATHS["TEST_LOG"],   "test_log.csv")

# Every split folder must contain both the train and the test light-curve file.
for sd in PATHS["SPLITS"]:
    tr = sd / "train_full_lightcurves.csv"
    te = sd / "test_full_lightcurves.csv"
    _must_exist(tr, f"{sd.name}/train_full_lightcurves.csv")
    _must_exist(te, f"{sd.name}/test_full_lightcurves.csv")

# Extra tokens treated as missing values, passed to every pd.read_csv call in this cell
# (in addition to pandas' defaults, since keep_default_na=True).
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df.columns = [c.strip() for c in df.columns]
    return df

# ----------------------------
# Load logs + sample
# ----------------------------
# object_id (and split) are read as pandas "string" dtype so they are never parsed as numbers.
df_sub = _norm_cols(pd.read_csv(PATHS["SAMPLE_SUB"], dtype={"object_id": "string"}, **SAFE_READ_KW))
if not {"object_id", "prediction"}.issubset(df_sub.columns):
    raise ValueError(f"sample_submission must have object_id,prediction. Found: {list(df_sub.columns)}")

df_train_log = _norm_cols(pd.read_csv(PATHS["TRAIN_LOG"], dtype={"object_id":"string","split":"string"}, **SAFE_READ_KW))
df_test_log  = _norm_cols(pd.read_csv(PATHS["TEST_LOG"],  dtype={"object_id":"string","split":"string"}, **SAFE_READ_KW))

# Minimal schema each log must provide; train additionally needs the label.
need_train = {"object_id","EBV","Z","split","target"}
need_test  = {"object_id","EBV","Z","split"}
if not need_train.issubset(df_train_log.columns):
    raise ValueError(f"train_log missing: {sorted(list(need_train-set(df_train_log.columns)))}")
if not need_test.issubset(df_test_log.columns):
    raise ValueError(f"test_log missing: {sorted(list(need_test-set(df_test_log.columns)))}")

def _normalize_split(x):
    if pd.isna(x): return ""
    s = str(x).strip()
    if not s: return ""
    if s.isdigit(): return f"split_{int(s):02d}"
    s = s.lower().replace("-", "_").replace(" ", "_")
    if s.startswith("split_"):
        tail = s.split("split_",1)[1].strip("_")
        if tail.isdigit():
            return f"split_{int(tail):02d}"
    return s

# Canonicalise the split labels in both logs.
df_train_log["split"] = df_train_log["split"].map(_normalize_split)
df_test_log["split"]  = df_test_log["split"].map(_normalize_split)

# Any label outside split_01 .. split_20 (including "" for missing) is a hard error.
valid_splits = {f"split_{i:02d}" for i in range(1,21)}
bad_tr = sorted(set(df_train_log["split"]) - valid_splits)
bad_te = sorted(set(df_test_log["split"]) - valid_splits)
if bad_tr: raise ValueError(f"Invalid split in train_log: {bad_tr[:10]}")
if bad_te: raise ValueError(f"Invalid split in test_log:  {bad_te[:10]}")

# EBV and Z must be fully numeric in both logs; unparseable values become NaN and abort.
for col in ["EBV","Z"]:
    df_train_log[col] = pd.to_numeric(df_train_log[col], errors="coerce")
    df_test_log[col]  = pd.to_numeric(df_test_log[col],  errors="coerce")
    if df_train_log[col].isna().any():
        raise ValueError(f"train_log {col} has NaN after numeric coercion: {int(df_train_log[col].isna().sum())}")
    if df_test_log[col].isna().any():
        raise ValueError(f"test_log {col} has NaN after numeric coercion: {int(df_test_log[col].isna().sum())}")

# The label must be numeric, non-missing and strictly binary.
df_train_log["target"] = pd.to_numeric(df_train_log["target"], errors="coerce")
if df_train_log["target"].isna().any():
    raise ValueError(f"train_log target NaN after coercion: {int(df_train_log['target'].isna().sum())}")
u = set(pd.unique(df_train_log["target"]).tolist())
if not u.issubset({0,1}):
    raise ValueError(f"train_log target must be 0/1. Found: {sorted(list(u))}")

# Z_err handling
# Both logs are treated identically: coerce Z_err to numeric, record which rows really
# carried a value in `has_zerr` (int8), then fill the gaps with 0.0. The flag is derived
# from the data for train as well, so it stays truthful if train ever has Z_err values
# and has the same dtype in both frames.
if "Z_err" not in df_test_log.columns:
    df_test_log["Z_err"] = np.nan
df_test_log["Z_err"] = pd.to_numeric(df_test_log["Z_err"], errors="coerce")
df_test_log["has_zerr"] = (~df_test_log["Z_err"].isna()).astype("int8")
df_test_log["Z_err"] = df_test_log["Z_err"].fillna(0.0)

if "Z_err" not in df_train_log.columns:
    df_train_log["Z_err"] = np.nan
df_train_log["Z_err"] = pd.to_numeric(df_train_log["Z_err"], errors="coerce")
df_train_log["has_zerr"] = (~df_train_log["Z_err"].isna()).astype("int8")
df_train_log["Z_err"] = df_train_log["Z_err"].fillna(0.0)

# Uniqueness: each object_id may appear at most once per log.
if df_train_log["object_id"].duplicated().any():
    raise ValueError(f"Duplicated object_id in train_log: {int(df_train_log['object_id'].duplicated().sum())}")
if df_test_log["object_id"].duplicated().any():
    raise ValueError(f"Duplicated object_id in test_log:  {int(df_test_log['object_id'].duplicated().sum())}")

# sample_submission alignment check: same row count and same ID set as test_log
# (row order is not compared).
sub_ids = df_sub["object_id"].astype("string")
test_ids = df_test_log["object_id"].astype("string")
if len(sub_ids) != len(test_ids):
    raise ValueError(f"Row mismatch: sample_submission={len(sub_ids)} vs test_log={len(test_ids)}")

s_sub = set(sub_ids.tolist())
s_tst = set(test_ids.tolist())
if s_sub != s_tst:
    # Show at most five offending IDs from each side.
    missing_in_test = list(s_sub - s_tst)[:5]
    missing_in_sub  = list(s_tst - s_sub)[:5]
    raise ValueError(
        "sample_submission and test_log object_id set mismatch.\n"
        f"- sample not in test_log (up to5): {missing_in_test}\n"
        f"- test_log not in sample (up to5): {missing_in_sub}"
    )

# Basic counts: class balance of the training labels.
pos = int((df_train_log["target"]==1).sum())
neg = int((df_train_log["target"]==0).sum())
tot = int(len(df_train_log))

print("ENV OK (Kaggle CPU)")
print(f"- Python: {sys.version.split()[0]}")
print(f"- Numpy:  {np.__version__}")
print(f"- Pandas: {pd.__version__}")
if torch is not None:
    print(f"- Torch:  {torch.__version__} | CUDA: {torch.cuda.is_available()}")
else:
    print("- Torch:  not available")

print("\nDATA OK")
# max(tot, 1) guards the percentage against an empty train log.
print(f"- train_log objects: {tot:,} | pos={pos:,} | neg={neg:,} | pos%={(pos/max(tot,1))*100:.2f}%")
print(f"- test_log objects:  {len(df_test_log):,}")
print(f"- sample_submission: {len(df_sub):,}")
print(f"- splits: {len(PATHS['SPLITS'])} folders (01..20)")

# Save snapshot of the seed, config and locations used by this run.
snap = {
    "SEED": SEED,
    "CFG": CFG,
    "CFG_HASH": CFG_HASH,
    "RUN_DIR": str(RUN_DIR),
    "DATA_ROOT": str(DATA_ROOT),
}
with open(RUN_DIR / "config_stage0.json", "w", encoding="utf-8") as f:
    json.dump(snap, f, indent=2)

# Explicitly (re)publish the names that later stages look up via globals().
globals().update({
    "SEED": SEED, "THREADS": THREADS, "CFG": CFG, "CFG_HASH": CFG_HASH,
    "PATHS": PATHS, "DATA_ROOT": DATA_ROOT, "RUN_DIR": RUN_DIR,
    "ART_DIR": ART_DIR, "CACHE_DIR": CACHE_DIR, "OOF_DIR": OOF_DIR,
    "SUB_DIR": SUB_DIR, "LOG_DIR": LOG_DIR, "FEAT_DIR": FEAT_DIR, "SEQ_DIR": SEQ_DIR,
    "df_sub": df_sub, "df_train_log": df_train_log, "df_test_log": df_test_log
})

# As the cell's last expression, the return value of gc.collect() is echoed as the cell output.
gc.collect()


ENV OK (Kaggle CPU)
- Python: 3.12.12
- Numpy:  2.0.2
- Pandas: 2.2.2
- Torch:  2.8.0+cu126 | CUDA: False

DATA OK
- train_log objects: 3,043 | pos=148 | neg=2,895 | pos%=4.86%
- test_log objects:  7,135
- sample_submission: 7,135
- splits: 20 folders (01..20)


63

# Verify Dataset Paths & Split Discovery

In [4]:
# ============================================================
# STAGE 1 — Split Routing + Lightcurve Micro-Profiling (ONE CELL, CPU-SAFE)
# REVISI FULL v3 (IO-LEBIH IRIT + DETERM + LEBIH BANYAK STATS)
#
# Output:
# - logs/split_routing.csv
# - logs/lc_sample_stats.csv
# - logs/stage1_summary.json
#
# Catatan:
# - Stage ini tidak "meningkatkan akurasi" langsung, tapi memastikan ALL SPLITS kebaca
#   dan memberi statistik penting untuk tuning stage modelling berikutnya.
# ============================================================

import re, gc, json, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require STAGE 0 globals
# ----------------------------
# Abort early when the names this stage reads from the notebook namespace are absent.
need0 = ["PATHS", "df_train_log", "df_test_log"]
for need in need0:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 dulu.")

DATA_ROOT = Path(PATHS["DATA_ROOT"])

# deterministic split list (do NOT use a set for iteration)
SPLIT_LIST = [f"split_{i:02d}" for i in range(1, 21)]
VALID_SPLITS = set(SPLIT_LIST)

# split dirs mapping (deterministic): folder name -> folder path
SPLIT_DIRS = {p.name: p for p in PATHS["SPLITS"]}
# Optional dirs from stage 0; fall back to defaults when they are not defined
RUN_DIR = Path(globals().get("RUN_DIR", "/kaggle/working/mallorn_run"))
LOG_DIR = Path(globals().get("LOG_DIR", RUN_DIR / "logs"))
LOG_DIR.mkdir(parents=True, exist_ok=True)

CFG_LOCAL = globals().get("CFG", {})
SEED = int(globals().get("SEED", 2025))

# ----------------------------
# 1) Safe read config (consistent with STAGE 0)
# ----------------------------
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

# sampling knobs (CPU-safe); each can be overridden through CFG
HEAD_ROWS = int(CFG_LOCAL.get("LC_HEAD_ROWS", 2000))           # sample rows per file
SAMPLE_ID_PER_SPLIT = int(CFG_LOCAL.get("SAMPLE_ID_PER_SPLIT", 5))
CHUNK_ROWS = int(CFG_LOCAL.get("CHUNK_ROWS", 200_000))
MAX_CHUNKS_PER_FILE = int(CFG_LOCAL.get("MAX_CHUNKS_PER_FILE", 6))

# numeric policy thresholds (strict for Time/Flux_err)
MAX_TIME_NA_FRAC = float(CFG_LOCAL.get("MAX_TIME_NA_FRAC", 0.02))
MAX_FERR_NA_FRAC = float(CFG_LOCAL.get("MAX_FERR_NA_FRAC", 0.02))
ID_MISS_FAIL_FRAC = float(CFG_LOCAL.get("ID_MISS_FAIL_FRAC", 0.80))  # fail if >= 80% of the sampled IDs are not found (within the scan cap)
MIN_SAMPLE_ROWS = int(CFG_LOCAL.get("MIN_SAMPLE_ROWS", 100))         # fail if the sample has too few rows

# ----------------------------
# 2) Helpers
# ----------------------------
# Columns every light-curve CSV must provide, and the filter (band) names accepted in them.
REQ_LC_COLS = ["object_id", "Time (MJD)", "Flux", "Flux_err", "Filter"]
REQ_LC_COLS_SET = set(REQ_LC_COLS)
ALLOWED_FILTERS = {"u", "g", "r", "i", "z", "y"}

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df.columns = [c.strip() for c in df.columns]
    return df

def normalize_split_name(x) -> str:
    """Normalise a raw split label to the canonical ``split_NN`` form.

    Recognised spellings (case-insensitive, "-" and " " treated as "_"):
    a bare number ("3"), "split_3" / "split_03", and "split3" / "split03".
    Missing values (None, NaN, pd.NA, NaT) and blank strings map to "".
    Anything else is returned in its lower-cased, underscore-normalised form so
    that the caller's validity check can report it.

    Args:
        x: Raw value taken from a `split` column.

    Returns:
        Canonical split name, "" for missing input, or the cleaned-up text.
    """
    # pd.isna covers pd.NA, which is the missing marker of the "string" dtype the logs
    # use; the scalar guard keeps list-like input on the plain str() path.
    if x is None or (pd.api.types.is_scalar(x) and pd.isna(x)):
        return ""
    s = str(x).strip()
    if not s:
        return ""
    s2 = s.lower().replace("-", "_").replace(" ", "_")
    # isdecimal rather than isdigit: isdigit also accepts characters such as "²",
    # which int() cannot parse and would raise ValueError.
    if s2.isdecimal():
        return f"split_{int(s2):02d}"
    m = re.fullmatch(r"split_?(\d{1,2})", s2)
    if m:
        return f"split_{int(m.group(1)):02d}"
    return s2

def sizeof_mb(p: Path) -> float:
    """Return the size of `p` in mebibytes, or NaN if its size cannot be read."""
    bytes_per_mb = 1024**2
    try:
        n_bytes = p.stat().st_size
    except Exception:
        # Best effort: the size is informational only.
        return float("nan")
    return n_bytes / bytes_per_mb

def _read_sample_df(p: Path, nrows: int):
    """Read the first `nrows` rows of a light-curve CSV, keeping only required columns.

    One read serves the schema check, the filter sanity check and the numeric
    sample statistics.

    Column selection compares whitespace-stripped header names, so a header
    such as " Flux " is accepted and normalised instead of being reported as a
    schema error with an empty "missing" list.

    Args:
        p: Path of the CSV file.
        nrows: Maximum number of data rows to read.

    Returns:
        DataFrame with stripped column names containing the required columns.

    Raises:
        ValueError: if any required column is absent (the message lists the
            missing columns and the columns found), or if pandas itself fails
            to parse the file.
    """
    dfh = pd.read_csv(
        p,
        usecols=lambda c: str(c).strip() in REQ_LC_COLS_SET,
        nrows=nrows,
        **SAFE_READ_KW,
    )
    dfh = _norm_cols(dfh)

    miss = sorted(REQ_LC_COLS_SET - set(dfh.columns))
    if miss:
        # Re-read only the header to give a complete diagnostic.
        df0 = pd.read_csv(p, nrows=0, **SAFE_READ_KW)
        cols = [c.strip() for c in df0.columns]
        raise ValueError(
            f"[LC SCHEMA] {p} missing required columns: {miss}. Found columns: {cols}"
        )
    return dfh

def _numeric_and_filter_stats(dfh: pd.DataFrame):
    out = {"n_sample": int(len(dfh))}
    if len(dfh) == 0:
        out.update({
            "time_na_frac": 1.0, "flux_na_frac": 1.0, "ferr_na_frac": 1.0,
            "filter_bad": "", "filter_sample": ""
        })
        return out

    # Filter values
    filt = dfh["Filter"].astype("string").str.strip().str.lower()
    filt = filt[~filt.isna()]
    uniq = sorted(set(filt.tolist()))
    bad = sorted([v for v in uniq if v not in ALLOWED_FILTERS])
    out["filter_bad"] = ",".join(bad[:10]) if bad else ""
    out["filter_sample"] = ",".join(uniq[:10]) if uniq else ""

    # Band coverage from sample
    if len(filt) > 0:
        vc = filt.value_counts()
        denom = float(vc.sum())
        for b in ["u","g","r","i","z","y"]:
            out[f"frac_{b}"] = float(vc.get(b, 0) / denom)
    else:
        for b in ["u","g","r","i","z","y"]:
            out[f"frac_{b}"] = 0.0

    # Numeric coercion
    t = pd.to_numeric(dfh["Time (MJD)"], errors="coerce")
    f = pd.to_numeric(dfh["Flux"], errors="coerce")
    e = pd.to_numeric(dfh["Flux_err"], errors="coerce")

    out["time_na_frac"] = float(t.isna().mean())
    out["flux_na_frac"] = float(f.isna().mean())
    out["ferr_na_frac"] = float(e.isna().mean())

    # Quick stats (ignore NaN)
    if (~t.isna()).any():
        out["time_min"] = float(t.min())
        out["time_max"] = float(t.max())
        out["time_span"] = float(t.max() - t.min())
    else:
        out["time_min"] = np.nan
        out["time_max"] = np.nan
        out["time_span"] = np.nan

    if (~f.isna()).any():
        fd = f.dropna()
        out["flux_neg_frac"] = float((fd < 0).mean())
        out["flux_p01"] = float(np.quantile(fd, 0.01))
        out["flux_p50"] = float(np.quantile(fd, 0.50))
        out["flux_p99"] = float(np.quantile(fd, 0.99))
    else:
        out["flux_neg_frac"] = np.nan
        out["flux_p01"] = np.nan
        out["flux_p50"] = np.nan
        out["flux_p99"] = np.nan

    if (~e.isna()).any():
        ed = e.dropna()
        out["ferr_min"] = float(ed.min())
        out["ferr_p50"] = float(np.quantile(ed, 0.50))
        out["ferr_p99"] = float(np.quantile(ed, 0.99))
        out["ferr_neg_any"] = int((ed < 0).any())
    else:
        out["ferr_min"] = np.nan
        out["ferr_p50"] = np.nan
        out["ferr_p99"] = np.nan
        out["ferr_neg_any"] = 0

    return out

def _sample_id_presence(csv_path: Path, want_ids: set, chunk_rows: int, max_chunks: int):
    """
    Limited scan memastikan beberapa object_id dari log benar-benar muncul di file.
    Scan hanya kolom object_id (lebih murah).
    """
    if not want_ids:
        return 0, set(), 0
    remaining = set(want_ids)
    found = set()
    nread_chunks = 0

    for i, chunk in enumerate(pd.read_csv(csv_path, usecols=["object_id"], chunksize=chunk_rows, **SAFE_READ_KW)):
        nread_chunks += 1
        ids = set(chunk["object_id"].astype("string"))
        hit = remaining & ids
        if hit:
            found |= hit
            remaining -= hit
        if not remaining:
            break
        if i + 1 >= max_chunks:
            break

    return len(found), remaining, nread_chunks

# ----------------------------
# 3) Normalize split col in logs (idempotent)
# ----------------------------
# Canonicalise the split labels of both logs in place (safe to repeat on already-canonical labels).
for df, name in [(df_train_log, "train_log"), (df_test_log, "test_log")]:
    if "split" not in df.columns:
        raise ValueError(f"{name} missing 'split' column.")
    df["split"] = df["split"].map(normalize_split_name)

# Any label outside split_01 .. split_20 is a hard error; show up to ten examples.
bad_train_split = sorted(set(df_train_log["split"].unique()) - VALID_SPLITS)
bad_test_split  = sorted(set(df_test_log["split"].unique())  - VALID_SPLITS)
if bad_train_split:
    raise ValueError(f"train_log has invalid split values (examples): {bad_train_split[:10]}")
if bad_test_split:
    raise ValueError(f"test_log has invalid split values (examples): {bad_test_split[:10]}")

# ----------------------------
# 4) Verify disk splits set + required files exist
# ----------------------------
# List the split folders that really exist under DATA_ROOT. SPLIT_DIRS is built from
# PATHS["SPLITS"] (the expected names), so comparing its keys with VALID_SPLITS could
# never detect a missing or unexpected folder.
disk_splits = {
    p.name for p in DATA_ROOT.iterdir()
    if p.is_dir() and p.name.startswith("split_")
}
missing_dirs = sorted(list(VALID_SPLITS - disk_splits))
extra_dirs   = sorted(list(disk_splits - VALID_SPLITS))
if missing_dirs or extra_dirs:
    msg = []
    if missing_dirs: msg.append(f"Missing split folders: {missing_dirs[:10]}")
    if extra_dirs:   msg.append(f"Unexpected split folders: {extra_dirs[:10]}")
    raise RuntimeError("Split folder set mismatch.\n" + "\n".join(msg))

# Every split folder must contain both light-curve files; collect all gaps before failing.
missing_files = []
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    for kind in ["train", "test"]:
        p = sd / f"{kind}_full_lightcurves.csv"
        if not p.exists():
            missing_files.append(str(p))
if missing_files:
    raise FileNotFoundError("Some lightcurve files missing (showing up to 10):\n" + "\n".join(missing_files[:10]))

# ----------------------------
# 5) Build routing manifest (40 file)
# ----------------------------
# Number of objects each log assigns to every split.
train_counts = df_train_log["split"].value_counts().to_dict()
test_counts  = df_test_log["split"].value_counts().to_dict()

# One manifest row per split: file paths, file sizes and object counts from the logs.
routing_rows = []
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    tr = sd / "train_full_lightcurves.csv"
    te = sd / "test_full_lightcurves.csv"
    routing_rows.append({
        "split": sp,
        "train_csv": str(tr),
        "test_csv": str(te),
        "train_mb": sizeof_mb(tr),
        "test_mb": sizeof_mb(te),
        # Splits absent from a log count as 0 objects.
        "n_train_objects_log": int(train_counts.get(sp, 0)),
        "n_test_objects_log":  int(test_counts.get(sp, 0)),
    })

df_routing = pd.DataFrame(routing_rows)
routing_path = LOG_DIR / "split_routing.csv"
df_routing.to_csv(routing_path, index=False)

# ----------------------------
# 6) Sample profiling + ID crosscheck (single read sample)
# ----------------------------
stats_rows = []
# Number of files whose sample contained at least one unparseable/missing Flux value.
warn_flux_na_files = 0

t0 = time.time()

# Profile the head of every train/test light-curve file, in deterministic split order.
for sp in SPLIT_LIST:
    sd = SPLIT_DIRS[sp]
    for kind in ["train", "test"]:
        p = sd / f"{kind}_full_lightcurves.csv"

        # read sample once
        dfh = _read_sample_df(p, nrows=HEAD_ROWS)

        # minimum sanity: sample should not be empty
        if len(dfh) < MIN_SAMPLE_ROWS:
            raise ValueError(f"[LC SAMPLE] Too few rows sampled from {p} (n={len(dfh)}). Possible read issue.")

        # compute stats
        st = _numeric_and_filter_stats(dfh)

        # Filter sanity: any value outside the allowed bands aborts the stage
        if st.get("filter_bad", ""):
            raise ValueError(f"[LC FILTER] Unexpected Filter values in {p}: {st['filter_bad']} (sample={st.get('filter_sample','')})")

        # numeric policy: Time and Flux_err are hard requirements
        if st.get("time_na_frac", 0.0) > MAX_TIME_NA_FRAC:
            raise ValueError(f"[LC NUM] Time(MJD) NaN too high in sample: {p} frac={st['time_na_frac']:.4f}")
        if st.get("ferr_na_frac", 0.0) > MAX_FERR_NA_FRAC:
            raise ValueError(f"[LC NUM] Flux_err NaN too high in sample: {p} frac={st['ferr_na_frac']:.4f}")
        if int(st.get("ferr_neg_any", 0)) == 1:
            raise ValueError(f"[LC NUM] Negative Flux_err detected in sample of {p} (should be >=0).")

        # Missing Flux is only counted, never fatal.
        if st.get("flux_na_frac", 0.0) > 0:
            warn_flux_na_files += 1

        # sample ID crosscheck (limited scan)
        if kind == "train":
            ids = df_train_log.loc[df_train_log["split"] == sp, "object_id"].astype("string")
        else:
            ids = df_test_log.loc[df_test_log["split"] == sp, "object_id"].astype("string")

        # Draw up to SAMPLE_ID_PER_SPLIT IDs reproducibly (fixed random_state).
        k = min(SAMPLE_ID_PER_SPLIT, len(ids))
        want = set(ids.sample(n=k, random_state=SEED).tolist()) if k > 0 else set()

        found_n, missing_ids, nread_chunks = _sample_id_presence(p, want, CHUNK_ROWS, MAX_CHUNKS_PER_FILE)
        miss_frac = (len(missing_ids) / max(len(want), 1)) if want else 0.0

        # Fail only on a severe miss rate; a partial miss may just reflect the scan cap.
        if want and miss_frac >= ID_MISS_FAIL_FRAC:
            raise ValueError(
                f"[LC ID] Severe mismatch within limited scan: {p} missing {len(missing_ids)}/{len(want)} "
                f"(chunks_read={nread_chunks}). Example missing: {list(missing_ids)[:3]}"
            )
        if want and missing_ids:
            print(f"[WARN] ID limited-scan miss: split={sp} kind={kind} miss {len(missing_ids)}/{len(want)} (chunks_read={nread_chunks})")

        # One output row per file; statistics absent from `st` default to NaN / "".
        row = {
            "split": sp,
            "kind": kind,
            "file": str(p),
            "file_mb": sizeof_mb(p),
            "n_sample": st.get("n_sample", 0),
            "time_na_frac": st.get("time_na_frac", np.nan),
            "flux_na_frac": st.get("flux_na_frac", np.nan),
            "ferr_na_frac": st.get("ferr_na_frac", np.nan),
            "time_min": st.get("time_min", np.nan),
            "time_max": st.get("time_max", np.nan),
            "time_span": st.get("time_span", np.nan),
            "flux_neg_frac": st.get("flux_neg_frac", np.nan),
            "flux_p01": st.get("flux_p01", np.nan),
            "flux_p50": st.get("flux_p50", np.nan),
            "flux_p99": st.get("flux_p99", np.nan),
            "ferr_min": st.get("ferr_min", np.nan),
            "ferr_p50": st.get("ferr_p50", np.nan),
            "ferr_p99": st.get("ferr_p99", np.nan),
            "filter_sample": st.get("filter_sample", ""),
            "id_check_k": int(len(want)),
            "id_found": int(found_n),
            "id_missing": int(len(missing_ids)),
            "id_scan_chunks": int(nread_chunks),
        }
        # band fractions
        for b in ["u","g","r","i","z","y"]:
            row[f"frac_{b}"] = st.get(f"frac_{b}", 0.0)

        stats_rows.append(row)

df_lc_stats = pd.DataFrame(stats_rows)
lc_stats_path = LOG_DIR / "lc_sample_stats.csv"
df_lc_stats.to_csv(lc_stats_path, index=False)

# ----------------------------
# 7) Summary prints + JSON summary
# ----------------------------
# Wall-clock time since t0 was taken, just before the profiling loop.
elapsed = time.time() - t0

# worst by flux_na_frac (informational only)
worst_flux_na = (
    df_lc_stats.sort_values("flux_na_frac", ascending=False)
    .head(8)[["split","kind","flux_na_frac","time_na_frac","ferr_na_frac","file_mb"]]
)

# worst by time_na_frac (hard policy already, but show)
worst_time_na = (
    df_lc_stats.sort_values("time_na_frac", ascending=False)
    .head(8)[["split","kind","time_na_frac","flux_na_frac","ferr_na_frac","file_mb"]]
)

# Machine-readable record of the knobs, thresholds and output locations of this stage.
summary = {
    "stage": "stage1",
    "data_root": str(DATA_ROOT),
    "log_dir": str(LOG_DIR),
    "head_rows": HEAD_ROWS,
    "sample_id_per_split": SAMPLE_ID_PER_SPLIT,
    "chunk_rows": CHUNK_ROWS,
    "max_chunks_per_file": MAX_CHUNKS_PER_FILE,
    "thresholds": {
        "MAX_TIME_NA_FRAC": MAX_TIME_NA_FRAC,
        "MAX_FERR_NA_FRAC": MAX_FERR_NA_FRAC,
        "ID_MISS_FAIL_FRAC": ID_MISS_FAIL_FRAC,
        "MIN_SAMPLE_ROWS": MIN_SAMPLE_ROWS
    },
    "warn_flux_na_files": int(warn_flux_na_files),
    "routing_csv": str(routing_path),
    "lc_sample_stats_csv": str(lc_stats_path),
    "elapsed_sec": float(elapsed),
}

summary_path = LOG_DIR / "stage1_summary.json"
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)

print("STAGE 1 OK — SPLIT ROUTING READY")
print(f"- routing saved: {routing_path}")
print(f"- lc sample stats saved: {lc_stats_path}")
print(f"- summary json saved: {summary_path}")
print(f"- elapsed: {elapsed/60:.2f} min | warn_flux_na_files={warn_flux_na_files}")

print("\nOBJECT COUNTS (from logs)")
for sp in SPLIT_LIST:
    print(f"- {sp}: train={int(train_counts.get(sp,0)):,} | test={int(test_counts.get(sp,0)):,}")

print("\nWORST SAMPLE (highest flux_na_frac in sample head)")
print(worst_flux_na.to_string(index=False))

print("\nWORST SAMPLE (highest time_na_frac in sample head)")
print(worst_time_na.to_string(index=False))

# ----------------------------
# 8) Export to globals (used by the following stages)
# ----------------------------
globals().update({
    "DATA_ROOT": DATA_ROOT,
    "SPLIT_DIRS": SPLIT_DIRS,
    "SPLIT_LIST": SPLIT_LIST,
    "df_split_routing": df_routing,
    "df_lc_sample_stats": df_lc_stats,
    "STAGE1_SUMMARY_PATH": summary_path,
})

gc.collect()
print("\nStage 1 complete: splits verified + routing/stats exported.")


STAGE 1 OK — SPLIT ROUTING READY
- routing saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/logs/split_routing.csv
- lc sample stats saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/logs/lc_sample_stats.csv
- summary json saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/logs/stage1_summary.json
- elapsed: 0.04 min | warn_flux_na_files=18

OBJECT COUNTS (from logs)
- split_01: train=155 | test=364
- split_02: train=170 | test=414
- split_03: train=138 | test=338
- split_04: train=145 | test=332
- split_05: train=165 | test=375
- split_06: train=155 | test=374
- split_07: train=165 | test=398
- split_08: train=162 | test=387
- split_09: train=128 | test=289
- split_10: train=144 | test=331
- split_11: train=146 | test=325
- split_12: train=155 | test=353
- split_13: train=143 | test=379
- split_14: train=154 | test=351
- split_15: train=158 | test=342
- split_16: train=155 | test=354
- split_17: train=153 | test=351
- split_18: train=152

# Load and Validate Train/Test Logs

In [6]:
# ============================================================
# STAGE 2 — Clean Meta Logs + CV Fold Assignment (ONE CELL, CPU-SAFE)
# REVISI FULL v4 (FIX: 5-FOLD TERPAKAI + Z_ERR TIDAK MATI)
#
# Output:
#   * df_train_meta, df_test_meta (index=object_id)
#   * id2split_train, id2split_test
#   * artifacts: train_meta.(parquet|csv), test_meta.(parquet|csv)
#   * artifacts: split_stats.csv, train_folds.csv
#   * artifacts: id2split_train.json, id2split_test.json
#   * artifacts: split2fold.json (kalau CV_USE_SPLIT_COL=True)
#
# Notes:
# - Tidak load full lightcurves.
# - EBV/Z clip pakai TRAIN saja (anti leakage).
# - Z_err clip pakai TEST quantiles (tanpa label) -> supaya tidak 0..0.
# - Fold strategy:
#     * Jika CFG["CV_USE_SPLIT_COL"]=True -> assign split->fold dengan quota 4 split/fold
#     * else -> StratifiedKFold object-level
# ============================================================

import re, gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

# Silence FutureWarning and pandas' mixed-dtype CSV warning.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require STAGE 0/1 globals
# ----------------------------
for need in ["PATHS", "ART_DIR", "SPLIT_DIRS", "CFG", "SEED"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 & STAGE 1 dulu.")

TRAIN_LOG_PATH = Path(PATHS["TRAIN_LOG"])
TEST_LOG_PATH  = Path(PATHS["TEST_LOG"])

# Artifacts directory for this stage's outputs; created if absent.
ART_DIR = Path(ART_DIR)
ART_DIR.mkdir(parents=True, exist_ok=True)

SEED = int(SEED)
N_FOLDS = int(CFG.get("N_FOLDS", 5))
CV_USE_SPLIT_COL = bool(CFG.get("CV_USE_SPLIT_COL", True))

# Canonical, ordered list of the 20 expected split names.
SPLIT_LIST = [f"split_{i:02d}" for i in range(1, 21)]
VALID_SPLITS = set(SPLIT_LIST)

# SPLIT_DIRS (from STAGE 1) must describe exactly the expected splits.
disk_splits = set(SPLIT_DIRS.keys())
if disk_splits != VALID_SPLITS:
    miss = sorted(list(VALID_SPLITS - disk_splits))
    extra = sorted(list(disk_splits - VALID_SPLITS))
    raise RuntimeError(f"SPLIT_DIRS mismatch. missing={miss[:5]} extra={extra[:5]} (jalankan ulang STAGE 1)")

# Extra tokens treated as missing values when reading CSVs (on top of pandas' defaults).
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = dict(low_memory=False, na_values=SAFE_NA_VALUES, keep_default_na=True)

# ----------------------------
# 1) Helpers
# ----------------------------
def normalize_split_name(x) -> str:
    """Normalise a raw split label to the canonical ``split_NN`` form.

    Recognised spellings (case-insensitive, "-" and " " treated as "_"):
    a bare number ("3"), "split_3" / "split_03", and "split3" / "split03".
    Missing values (None, NaN, pd.NA, NaT) and blank strings map to "".
    Anything else is returned in its lower-cased, underscore-normalised form.

    Args:
        x: Raw value taken from a `split` column.

    Returns:
        Canonical split name, "" for missing input, or the cleaned-up text.
    """
    # This function is mapped over a "string"-dtype column, whose missing marker is
    # pd.NA; pd.isna recognises it. The scalar guard keeps list-like input on the
    # plain str() path.
    if x is None or (pd.api.types.is_scalar(x) and pd.isna(x)):
        return ""
    s = str(x).strip()
    if not s:
        return ""
    s2 = s.lower().replace("-", "_").replace(" ", "_")
    # isdecimal rather than isdigit: isdigit also accepts characters such as "²",
    # which int() cannot parse and would raise ValueError.
    if s2.isdecimal():
        return f"split_{int(s2):02d}"
    m = re.fullmatch(r"split_?(\d{1,2})", s2)
    if m:
        return f"split_{int(m.group(1)):02d}"
    return s2

def _norm_cols(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    df.columns = [c.strip() for c in df.columns]
    return df

def _coerce_float32(df: pd.DataFrame, col: str):
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors="coerce").astype("float32")

def _safe_clip(series: pd.Series, lo: float, hi: float) -> pd.Series:
    return series.clip(lower=np.float32(lo), upper=np.float32(hi)).astype("float32")

def _qclip_bounds(arr: np.ndarray, qlo=0.001, qhi=0.999, default=(0.0, 0.0)):
    x = np.asarray(arr, dtype=float)
    x = x[np.isfinite(x)]
    if len(x) == 0:
        return float(default[0]), float(default[1])
    lo, hi = np.quantile(x, [qlo, qhi])
    return float(lo), float(hi)

def _load_or_use_global(global_name: str, path: Path) -> pd.DataFrame:
    """Return a column-normalised log frame, preferring an already-loaded global.

    If a global named ``global_name`` exists and is a DataFrame, a copy of it is
    used; otherwise the CSV at ``path`` is read with ``object_id`` and ``split``
    as nullable strings and the shared ``SAFE_READ_KW`` options.
    """
    preloaded = globals().get(global_name)
    if isinstance(preloaded, pd.DataFrame):
        return _norm_cols(preloaded.copy())
    from_disk = pd.read_csv(path, dtype={"object_id":"string","split":"string"}, **SAFE_READ_KW)
    return _norm_cols(from_disk)

# ----------------------------
# 2) Load logs
# ----------------------------
df_train = _load_or_use_global("df_train_log", TRAIN_LOG_PATH)
df_test = _load_or_use_global("df_test_log", TEST_LOG_PATH)

# ----------------------------
# 3) Required columns check
# ----------------------------
req_common = {"object_id", "split", "EBV", "Z"}
req_train = req_common | {"target"}
req_test = req_common

miss_train = sorted(req_train.difference(df_train.columns))
miss_test = sorted(req_test.difference(df_test.columns))
if miss_train:
    raise ValueError(f"train_log missing columns: {miss_train} | found={list(df_train.columns)}")
if miss_test:
    raise ValueError(f"test_log missing columns: {miss_test} | found={list(df_test.columns)}")

# ----------------------------
# 4) Basic cleaning
# ----------------------------
# Same normalisation for both logs: trimmed string ids, canonical split names.
for _log in (df_train, df_test):
    _log["object_id"] = _log["object_id"].astype("string").str.strip()
    _log["split"] = _log["split"].astype("string").map(normalize_split_name)

# Any split label outside split_01..split_20 is a hard error.
bad_train_split = sorted(set(df_train["split"].unique()).difference(VALID_SPLITS))
bad_test_split = sorted(set(df_test["split"].unique()).difference(VALID_SPLITS))
if bad_train_split:
    raise ValueError(f"train_log invalid split values: {bad_train_split[:10]}")
if bad_test_split:
    raise ValueError(f"test_log invalid split values: {bad_test_split[:10]}")

# ----------------------------
# 5) Ensure Z_err exists + numeric coercion
# ----------------------------
for _log in (df_train, df_test):
    if "Z_err" not in _log.columns:
        _log["Z_err"] = np.nan
    # has_zerr is recorded BEFORE any fill so it reflects the raw data
    _log["has_zerr"] = pd.to_numeric(_log["Z_err"], errors="coerce").notna().astype("int8")

for c in ["EBV","Z","Z_err"]:
    for _log in (df_train, df_test):
        _coerce_float32(_log, c)

# ----------------------------
# 6) Duplicate / overlap checks
# ----------------------------
_dup_train = df_train["object_id"].duplicated()
if _dup_train.any():
    ex = df_train.loc[_dup_train, "object_id"].head(5).tolist()
    raise ValueError(f"Duplicated object_id in train_log (examples): {ex}")

_dup_test = df_test["object_id"].duplicated()
if _dup_test.any():
    ex = df_test.loc[_dup_test, "object_id"].head(5).tolist()
    raise ValueError(f"Duplicated object_id in test_log (examples): {ex}")

# An object must never appear in both train and test.
overlap = set(df_train["object_id"].tolist()) & set(df_test["object_id"].tolist())
if overlap:
    raise ValueError(f"object_id overlap train vs test (examples): {list(overlap)[:5]}")

# ----------------------------
# 7) Target validation
# ----------------------------
df_train["target"] = pd.to_numeric(df_train["target"], errors="coerce")
_target_is_nan = df_train["target"].isna()
if _target_is_nan.any():
    raise ValueError(f"train_log target has NaN after coercion: {int(df_train['target'].isna().sum())} rows.")

# Only the labels 0 and 1 are accepted.
uniq_t = set(pd.unique(df_train["target"]).tolist())
if uniq_t - {0, 1}:
    raise ValueError(f"train_log target must be binary 0/1. Found: {sorted(list(uniq_t))}")
df_train["target"] = df_train["target"].astype("int8")

# ----------------------------
# 8) Missing flags + fills
# ----------------------------
for df in (df_train, df_test):
    # Flags are taken from the raw (un-filled) columns.
    for _flag_col, _src_col in (("EBV_missing", "EBV"), ("Z_missing", "Z"), ("Zerr_missing", "Z_err")):
        df[_flag_col] = df[_src_col].isna().astype("int8")
    # Missing EBV is filled with 0.
    df["EBV"] = df["EBV"].fillna(np.float32(0.0)).astype("float32")

# Z is filled from TRAIN statistics: per-split median, then global median
_z_train_values = df_train["Z"].values.astype(float)
train_split_med = df_train.groupby("split")["Z"].median().to_dict()
if np.isfinite(_z_train_values).any():
    train_gmed = float(np.nanmedian(_z_train_values))
else:
    train_gmed = 0.0
train_gmed = np.float32(train_gmed)

def _fill_z(df: pd.DataFrame, split_med: dict, gmed: np.float32):
    z = df["Z"].copy()
    if z.isna().any():
        z = z.fillna(df["split"].map(split_med))
        z = z.fillna(gmed)
    return z.astype("float32")

df_train["Z"] = _fill_z(df_train, train_split_med, train_gmed)
df_test["Z"] = _fill_z(df_test, train_split_med, train_gmed)

# Z_err: NaN -> 0; is_photoz is a dataset-level indicator (0 for train, 1 for test)
for _meta, _photoz_flag in ((df_train, 0), (df_test, 1)):
    _meta["Z_err"] = _meta["Z_err"].fillna(np.float32(0.0)).astype("float32")
    _meta["is_photoz"] = np.int8(_photoz_flag)

# ----------------------------
# 9) Clipping + derived meta features
# ----------------------------
# EBV/Z clip bounds come from TRAIN only (anti leakage)
EBV_LO, EBV_HI = _qclip_bounds(df_train["EBV"].values, 0.001, 0.999)
Z_LO, Z_HI = _qclip_bounds(df_train["Z"].values, 0.001, 0.999)

# Z_err clip uses TEST quantiles so the range does not collapse to 0..0
# (no labels involved); if test is also all zeros the upper bound is simply 0
ZE_LO = 0.0
_, ZE_HI = _qclip_bounds(df_test["Z_err"].values, 0.001, 0.999, default=(0.0, 0.0))
ZE_HI = max(float(ZE_HI), 0.0)

eps = np.float32(1e-6)
split2id = {"split_%02d" % k: k for k in range(1, 21)}

# Identical derived columns for both frames, added in a fixed order.
for _meta in (df_train, df_test):
    _meta["EBV_clip"] = _safe_clip(_meta["EBV"], EBV_LO, EBV_HI)
    _meta["Z_clip"] = _safe_clip(_meta["Z"], Z_LO, Z_HI)
    _meta["Zerr_clip"] = _safe_clip(_meta["Z_err"], ZE_LO, ZE_HI)

    _meta["log1pZ"] = np.log1p(_meta["Z_clip"].astype("float32")).astype("float32")
    _meta["log1pZerr"] = np.log1p(_meta["Zerr_clip"].astype("float32")).astype("float32")

    # relative redshift error; eps keeps the denominator away from zero
    _meta["zerr_rel"] = (_meta["Zerr_clip"] / (_meta["Z_clip"] + eps)).astype("float32")

    _meta["split_id"] = _meta["split"].map(split2id).astype("int16")

# ----------------------------
# 10) Fold assignment
# ----------------------------
# Every training row gets a `fold` in 0..N_FOLDS-1. Two strategies:
#   * CV_USE_SPLIT_COL=True : whole splits are assigned greedily to folds
#     (all rows of a split share one fold);
#   * otherwise             : object-level StratifiedKFold, with a seeded
#     round-robin fallback if that raises.
# -1 is the "unassigned" sentinel checked at the end of this section.
df_train["fold"] = -1

if CV_USE_SPLIT_COL:
    # Per-fold quota of splits, e.g. 20 splits / 5 folds = 4 splits per fold
    total_splits = len(SPLIT_LIST)
    if total_splits % N_FOLDS != 0:
        # still workable, but the quota is rounded up
        quota = int(np.ceil(total_splits / N_FOLDS))
    else:
        quota = int(total_splits / N_FOLDS)

    # Per-split row count (n) and positive count (pos); splits absent from
    # train are kept with zeros via reindex + fillna.
    sp_stat = (
        df_train.groupby("split")["target"]
        .agg(["count","sum"])
        .rename(columns={"count":"n","sum":"pos"})
        .reindex(SPLIT_LIST)
        .fillna(0)
        .astype({"n":int,"pos":int})
        .reset_index()
    )
    sp_stat["neg"] = sp_stat["n"] - sp_stat["pos"]
    sp_stat["pos_rate"] = sp_stat["pos"] / sp_stat["n"].clip(lower=1)

    # sort: "heaviest" splits first (most positives, then most rows)
    sp_stat = sp_stat.sort_values(["pos","n"], ascending=False).reset_index(drop=True)

    global_pos_rate = float(df_train["target"].mean())
    target_fold_n = float(len(df_train) / max(N_FOLDS, 1))

    # Running per-fold totals: rows, positives, and number of splits assigned.
    fold_n = np.zeros(N_FOLDS, dtype=float)
    fold_pos = np.zeros(N_FOLDS, dtype=float)
    fold_k = np.zeros(N_FOLDS, dtype=int)
    split2fold = {}

    # Seeded generator, used only to break exact score ties below.
    rng = np.random.default_rng(SEED)

    for _, r in sp_stat.iterrows():
        sp = r["split"]
        n  = float(r["n"])
        p  = float(r["pos"])

        # candidate folds that have not yet reached their split quota
        cand = np.where(fold_k < quota)[0]
        if len(cand) == 0:
            # fallback: every fold is at quota (should not happen when the quota is correct)
            cand = np.arange(N_FOLDS)

        scores = []
        for f in cand:
            # hypothetical fold totals if this split were added to fold f
            n2 = fold_n[f] + n
            p2 = fold_pos[f] + p
            pr2 = (p2 / n2) if n2 > 0 else global_pos_rate

            # score = class balance + size balance + penalty as the fold approaches its quota
            score = abs(pr2 - global_pos_rate) \
                    + 0.20 * abs(n2 - target_fold_n) / max(target_fold_n, 1.0) \
                    + 0.05 * (fold_k[f] / max(quota, 1))

            scores.append(score)

        # Lowest score wins; exact ties are resolved with the seeded RNG.
        scores = np.asarray(scores, dtype=float)
        best_idx = np.where(scores == scores.min())[0]
        choose = int(cand[int(rng.choice(best_idx))]) if len(best_idx) > 1 else int(cand[int(best_idx[0])])

        split2fold[sp] = choose
        fold_n[choose] += n
        fold_pos[choose] += p
        fold_k[choose] += 1

    # apply the split -> fold mapping to every training row
    df_train["fold"] = df_train["split"].map(split2fold).astype("int16")

    # HARD GUARD: every fold 0..K-1 must appear and be non-empty
    uniq_folds = sorted(df_train["fold"].unique().tolist())
    if uniq_folds != list(range(N_FOLDS)):
        print(f"[WARN] split->fold tidak memakai semua fold. uniq_folds={uniq_folds}. Fallback ke StratifiedKFold object-level.")
        # Flipping the flag routes execution into the object-level branch below.
        CV_USE_SPLIT_COL = False
    else:
        # save the mapping for auditing
        with open(ART_DIR / "split2fold.json", "w", encoding="utf-8") as f:
            json.dump({k:int(v) for k,v in split2fold.items()}, f)

if not CV_USE_SPLIT_COL:
    try:
        from sklearn.model_selection import StratifiedKFold
        skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
        y = df_train["target"].to_numpy()
        idx = np.arange(len(df_train))
        # Each row's fold is the index of the CV iteration whose validation set contains it.
        for fold_id, (_, va_idx) in enumerate(skf.split(idx, y)):
            df_train.iloc[va_idx, df_train.columns.get_loc("fold")] = fold_id
        df_train["fold"] = df_train["fold"].astype("int16")
    except Exception as e:
        # Any failure above (not only a missing sklearn) lands here.
        print(f"[WARN] StratifiedKFold unavailable ({type(e).__name__}). Using round-robin fallback.")
        df_train["fold"] = -1
        rng = np.random.default_rng(SEED)
        pos_idx = df_train.index[df_train["target"] == 1].to_numpy()
        neg_idx = df_train.index[df_train["target"] == 0].to_numpy()
        rng.shuffle(pos_idx); rng.shuffle(neg_idx)
        # Positives and negatives are dealt out separately, each starting at fold 0.
        for j, ii in enumerate(pos_idx):
            df_train.at[ii, "fold"] = int(j % N_FOLDS)
        for j, ii in enumerate(neg_idx):
            df_train.at[ii, "fold"] = int(j % N_FOLDS)
        df_train["fold"] = df_train["fold"].astype("int16")

# hard check: no row may keep the -1 sentinel
if (df_train["fold"] < 0).any():
    n_bad = int((df_train["fold"] < 0).sum())
    raise RuntimeError(f"Fold assignment gagal: ada {n_bad} baris fold=-1")

# ----------------------------
# 11) Build meta tables (index=object_id)
# ----------------------------
# Columns shared by both tables; train additionally carries fold and target.
keep_test = [
    "object_id", "split", "split_id",
    "EBV", "EBV_clip",
    "Z", "Z_clip", "log1pZ",
    "Z_err", "Zerr_clip", "log1pZerr", "zerr_rel",
    "EBV_missing", "Z_missing", "Zerr_missing",
    "has_zerr", "is_photoz",
]
keep_train = keep_test + ["fold", "target"]

# SpecType is carried along for train when the log provides it.
if "SpecType" in df_train.columns:
    keep_train.append("SpecType")

df_train_meta = df_train[keep_train].copy()
df_train_meta = df_train_meta.set_index("object_id", drop=True).sort_index()

df_test_meta = df_test[keep_test].copy()
df_test_meta = df_test_meta.set_index("object_id", drop=True).sort_index()

# object_id -> split lookups
id2split_train = df_train_meta["split"].to_dict()
id2split_test = df_test_meta["split"].to_dict()

# ----------------------------
# 12) Save artifacts
# ----------------------------
train_pq = ART_DIR / "train_meta.parquet"
test_pq  = ART_DIR / "test_meta.parquet"
train_csv = ART_DIR / "train_meta.csv"
test_csv  = ART_DIR / "test_meta.csv"

# Parquet is preferred; CSV is the fallback when the parquet write fails
# (for example when no parquet engine is installed).
try:
    df_train_meta.to_parquet(train_pq, index=True)
    df_test_meta.to_parquet(test_pq, index=True)
    saved_train, saved_test = str(train_pq), str(test_pq)
except Exception as _pq_err:
    # Say why the fallback happened instead of switching formats silently.
    print(f"[WARN] Parquet save failed ({type(_pq_err).__name__}: {_pq_err}). Falling back to CSV.")
    # Remove any parquet file left behind (train may have succeeded before test
    # failed) so the artifact directory never holds a half-written parquet pair.
    for _stale_pq in (train_pq, test_pq):
        try:
            if _stale_pq.exists():
                _stale_pq.unlink()
        except OSError:
            pass
    df_train_meta.to_csv(train_csv, index=True)
    df_test_meta.to_csv(test_csv, index=True)
    saved_train, saved_test = str(train_csv), str(test_csv)

# Per-split object counts and positive rate (all 20 splits, zeros where absent).
split_stats = pd.DataFrame({
    "train_objects": df_train_meta["split"].value_counts().reindex(SPLIT_LIST).fillna(0).astype(int),
    "test_objects":  df_test_meta["split"].value_counts().reindex(SPLIT_LIST).fillna(0).astype(int),
})
split_stats.index.name = "split"
pos_by_split = df_train_meta.groupby("split")["target"].sum().reindex(SPLIT_LIST).fillna(0).astype(int)
split_stats["train_pos"] = pos_by_split.values
split_stats["train_pos_rate"] = (split_stats["train_pos"] / split_stats["train_objects"].clip(lower=1)).astype("float32")

split_stats_path = ART_DIR / "split_stats.csv"
split_stats.to_csv(split_stats_path)

# Fold assignment per object, for auditing / reuse.
fold_path = ART_DIR / "train_folds.csv"
df_train_meta.reset_index()[["object_id","split","fold","target"]].to_csv(fold_path, index=False)

# object_id -> split lookups as JSON.
with open(ART_DIR / "id2split_train.json", "w", encoding="utf-8") as f:
    json.dump(id2split_train, f)
with open(ART_DIR / "id2split_test.json", "w", encoding="utf-8") as f:
    json.dump(id2split_test, f)

# ----------------------------
# 13) Print summary
# ----------------------------
_target_col = df_train_meta["target"]
pos = int(_target_col.eq(1).sum())
neg = int(_target_col.eq(0).sum())
tot = int(len(df_train_meta))
pos_rate = pos / max(tot, 1)
# negatives-per-positive ratio (guarded against zero positives)
scale_pos_weight = float(neg / max(pos, 1))

_summary_lines = [
    "STAGE 2 OK — META READY (clean + folds)",
    f"- CV_USE_SPLIT_COL: {CV_USE_SPLIT_COL} | N_FOLDS={N_FOLDS}",
    f"- train objects: {tot:,} | pos(TDE)={pos:,} | neg={neg:,} | pos%={pos_rate*100:.3f}%",
    f"- test objects : {len(df_test_meta):,}",
    f"- saved train  : {saved_train}",
    f"- saved test   : {saved_test}",
    f"- saved stats  : {split_stats_path}",
    f"- saved folds  : {fold_path}",
    f"- scale_pos_weight (neg/pos): {scale_pos_weight:.3f}",
    "\nCLIP RANGES",
    f"- EBV clip (train): [{EBV_LO:.6f}, {EBV_HI:.6f}]",
    f"- Z   clip (train): [{Z_LO:.6f}, {Z_HI:.6f}]",
    f"- Zerr clip (test ): [{ZE_LO:.6f}, {ZE_HI:.6f}]",
]
for _summary_line in _summary_lines:
    print(_summary_line)

# Per-fold row count, positive count and positive rate (every fold 0..K-1 listed).
fold_tab = (
    df_train_meta.reset_index().groupby("fold")["target"]
    .agg(["count","sum"]).rename(columns={"sum":"pos"})
    .reindex(range(N_FOLDS)).fillna(0)
)
fold_tab["pos_rate"] = fold_tab["pos"] / fold_tab["count"].clip(lower=1)
print("\nFOLD BALANCE (count/pos/pos_rate) — MUST SHOW 0..K-1")
print(fold_tab.to_string())

# ----------------------------
# 14) Export globals
# ----------------------------
globals().update(
    df_train_meta=df_train_meta,
    df_test_meta=df_test_meta,
    id2split_train=id2split_train,
    id2split_test=id2split_test,
    split_stats=split_stats,
    split2id=split2id,
    scale_pos_weight=scale_pos_weight,
    CV_USE_SPLIT_COL_USED=CV_USE_SPLIT_COL,
)

gc.collect()


STAGE 2 OK — META READY (clean + folds)
- CV_USE_SPLIT_COL: True | N_FOLDS=5
- train objects: 3,043 | pos(TDE)=148 | neg=2,895 | pos%=4.864%
- test objects : 7,135
- saved train  : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/train_meta.parquet
- saved test   : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/test_meta.parquet
- saved stats  : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/split_stats.csv
- saved folds  : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/train_folds.csv
- scale_pos_weight (neg/pos): 19.561

CLIP RANGES
- EBV clip (train): [0.005042, 0.581790]
- Z   clip (train): [0.044923, 4.032352]
- Zerr clip (test ): [0.000000, 0.106846]

FOLD BALANCE (count/pos/pos_rate) — MUST SHOW 0..K-1
      count  pos  pos_rate
fold                      
0       607   20  0.032949
1       565    6  0.010619
2       632   53  0.083861
3       641   27  0.042122
4       598   42  0.070234


64

# Lightcurve Loading Strategy

In [9]:
# ============================================================
# STAGE 3 — Robust Lightcurve Loader Utilities (ONE CELL, Kaggle CPU-SAFE)
# FULL REVISION v3.1 (fixes IndexError + safer routing by split)
#
# Main fix:
# - groupby().groups yields object_id labels (strings), not integer positions
#   -> do not use df.index[idx]; use idx directly / idx.astype(str).tolist()
# ============================================================

import gc, re
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require previous stages
# ----------------------------
for need in ("SPLIT_DIRS", "SPLIT_LIST", "df_train_meta", "df_test_meta", "ART_DIR", "CFG", "SEED"):
    if need in globals():
        continue
    raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 -> STAGE 1 -> STAGE 2 dulu.")

ART_DIR = Path(ART_DIR)
ART_DIR.mkdir(parents=True, exist_ok=True)

SEED = int(SEED)
# Lower bound applied to flux_err values when chunks are normalised.
MIN_FLUXERR = float(CFG.get("MIN_FLUXERR", 1e-6))

# same CSV-reading options as STAGE 0/2
SAFE_NA_VALUES = ["", " ", "NA", "NaN", "nan", "NULL", "null", "None", "none"]
SAFE_READ_KW = {
    "low_memory": False,
    "na_values": SAFE_NA_VALUES,
    "keep_default_na": True,
}

# Canonical lightcurve schema and the accepted passbands (with their sort rank).
REQ_LC_KEYS = ["object_id", "mjd", "flux", "flux_err", "filter"]
FILTER_ORDER = {band: rank for rank, band in enumerate("ugrizy")}
ALLOWED_FILTERS = set(FILTER_ORDER)

# ----------------------------
# 1) Build split file mapping (train/test lightcurves)
# ----------------------------
SPLIT_FILES = {}
for s in SPLIT_LIST:
    sd = Path(SPLIT_DIRS[s])
    tr, te = sd / "train_full_lightcurves.csv", sd / "test_full_lightcurves.csv"
    if not (tr.exists() and te.exists()):
        raise FileNotFoundError(f"Missing lightcurve file(s) in {sd}: train={tr.exists()} test={te.exists()}")
    SPLIT_FILES[s] = {"train": tr, "test": te}

# Save split file manifest (paths and sizes in MiB)
_BYTES_PER_MB = 1024**2
manifest = []
for s in SPLIT_LIST:
    p_tr, p_te = SPLIT_FILES[s]["train"], SPLIT_FILES[s]["test"]
    manifest.append(dict(
        split=s,
        train_path=str(p_tr),
        test_path=str(p_te),
        train_mb=float(p_tr.stat().st_size) / _BYTES_PER_MB,
        test_mb=float(p_te.stat().st_size) / _BYTES_PER_MB,
    ))
df_manifest = pd.DataFrame(manifest).sort_values("split")
manifest_path = ART_DIR / "split_file_manifest.csv"
df_manifest.to_csv(manifest_path, index=False)

# ----------------------------
# 2) Build object routing by split (FIX IndexError)
# ----------------------------
train_ids_by_split = {s: [] for s in SPLIT_LIST}
test_ids_by_split = {s: [] for s in SPLIT_LIST}

# groups: split -> Index of object_id labels (the meta tables are indexed by object_id)
tr_groups = df_train_meta.groupby("split").groups
te_groups = df_test_meta.groupby("split").groups

for _routing, _groups in ((train_ids_by_split, tr_groups), (test_ids_by_split, te_groups)):
    for sp, idx in _groups.items():
        if sp not in _routing:
            continue
        # idx already holds object_id labels, so it is used directly
        _routing[sp] = pd.Index(idx).astype(str).tolist()

# sanity: every object must be routed to exactly one known split
_n_train_routed = sum(len(v) for v in train_ids_by_split.values())
if _n_train_routed != len(df_train_meta):
    raise RuntimeError("Routing train_ids_by_split mismatch total vs df_train_meta length.")
_n_test_routed = sum(len(v) for v in test_ids_by_split.values())
if _n_test_routed != len(df_test_meta):
    raise RuntimeError("Routing test_ids_by_split mismatch total vs df_test_meta length.")

df_counts = pd.DataFrame({
    "split": SPLIT_LIST,
    "train_objects": [len(train_ids_by_split[s]) for s in SPLIT_LIST],
    "test_objects": [len(test_ids_by_split[s]) for s in SPLIT_LIST],
})
counts_path = ART_DIR / "object_counts_by_split.csv"
df_counts.to_csv(counts_path, index=False)

# ----------------------------
# 3) Robust header mapping -> canonical columns
# ----------------------------
_LC_CFG_CACHE = {}  # (split_name, which) -> cfg dict

def _canon_col(x: str) -> str:
    s = str(x).strip().lower()
    s = s.replace("\ufeff", "")
    s = re.sub(r"\s+", "", s)
    s = s.replace("(", "").replace(")", "")
    s = s.replace("-", "_")
    return s

def _build_lc_read_cfg(p: Path):
    """Inspect the header of lightcurve CSV ``p`` and build its read configuration.

    Returns a dict with ``usecols`` (original header names to read), ``dtype``
    (string dtype for the id and filter columns) and ``rename`` (original name ->
    canonical ``object_id``/``mjd``/``flux``/``flux_err``/``filter``).

    Raises ``ValueError`` when any of the five required columns cannot be found.
    """
    header = pd.read_csv(p, nrows=0, **SAFE_READ_KW)
    orig_cols = list(header.columns)

    # canonical key -> first original column that produced it
    canon_to_orig = {}
    for col in orig_cols:
        canon_to_orig.setdefault(_canon_col(col), col)

    def _first_present(*keys):
        # first alias (in priority order) that exists in the header, else None
        return next((canon_to_orig[k] for k in keys if k in canon_to_orig), None)

    obj_col = canon_to_orig.get("object_id", None)
    time_col = _first_present("time_mjd", "timemjd", "mjd", "time")
    flux_col = canon_to_orig.get("flux", None)
    ferr_col = _first_present("flux_err", "fluxerr", "fluxerror")
    filt_col = canon_to_orig.get("filter", None)

    required = (
        ("object_id", obj_col),
        ("Time (MJD)", time_col),
        ("Flux", flux_col),
        ("Flux_err", ferr_col),
        ("Filter", filt_col),
    )
    missing = [label for label, found in required if found is None]
    if missing:
        raise ValueError(
            f"Missing required lightcurve columns in {p.name}: {missing}. "
            f"Header sample: {orig_cols[:20]}"
        )

    # only id/filter are forced to string; numeric columns are coerced later
    return {
        "usecols": [obj_col, time_col, flux_col, ferr_col, filt_col],
        "dtype": {obj_col:"string", filt_col:"string"},
        "rename": {obj_col:"object_id", time_col:"mjd", flux_col:"flux", ferr_col:"flux_err", filt_col:"filter"},
    }

def _normalize_lc_chunk(df: pd.DataFrame, drop_bad_filter: bool = True, drop_bad_mjd: bool = True):
    """Clean one lightcurve chunk that already uses the canonical column names.

    * ``object_id`` / ``filter`` become trimmed strings (filter lower-cased);
      filters outside ``ALLOWED_FILTERS`` become missing.
    * ``mjd`` / ``flux`` / ``flux_err`` become float32 (unparseable -> NaN).
    * ``flux_err`` values below ``MIN_FLUXERR`` are raised to it (when it is > 0).
    * Rows without an object_id are dropped; rows with a missing filter / mjd are
      dropped when the corresponding flag is set.

    Returns a new frame with columns in ``REQ_LC_KEYS`` order.
    """
    out = df[["object_id","mjd","flux","flux_err","filter"]].copy()

    out["object_id"] = out["object_id"].astype("string").str.strip()
    out["filter"] = out["filter"].astype("string").str.strip().str.lower()

    unknown_filter = ~out["filter"].isin(list(ALLOWED_FILTERS))
    out.loc[unknown_filter, "filter"] = pd.NA

    for numeric_col in ("mjd", "flux", "flux_err"):
        out[numeric_col] = pd.to_numeric(out[numeric_col], errors="coerce").astype("float32")

    if MIN_FLUXERR > 0:
        fe = out["flux_err"]
        below_floor = fe.notna() & (fe < MIN_FLUXERR)
        out.loc[below_floor, "flux_err"] = np.float32(MIN_FLUXERR)

    has_id = out["object_id"].notna() & (out["object_id"] != "")
    out = out[has_id]
    if drop_bad_filter:
        out = out[out["filter"].notna()]
    if drop_bad_mjd:
        out = out[out["mjd"].notna()]

    return out[REQ_LC_KEYS]

# ----------------------------
# 4) Chunked readers
# ----------------------------
def iter_lightcurve_chunks(
    split_name: str,
    which: str,
    chunksize: int = 400_000,
    drop_bad_filter: bool = True,
    drop_bad_mjd: bool = True
):
    """Yield normalised lightcurve chunks for one split file.

    Parameters
    ----------
    split_name : str
        Key of ``SPLIT_FILES`` (e.g. ``"split_01"``).
    which : str
        ``"train"`` or ``"test"``.
    chunksize : int
        Number of CSV rows read per chunk.
    drop_bad_filter, drop_bad_mjd : bool
        Forwarded to ``_normalize_lc_chunk``.

    Yields
    ------
    pandas.DataFrame
        Chunks with the canonical ``REQ_LC_KEYS`` columns.

    Raises
    ------
    KeyError
        Unknown ``split_name``.
    ValueError
        ``which`` is neither ``"train"`` nor ``"test"``.

    Notes
    -----
    This is a generator, so the argument checks run on the first ``next()``.
    The CSV reader is held in a ``with`` block so the underlying file is closed
    as soon as the generator finishes or is closed (for example when a consumer
    stops after the first chunk), rather than whenever the reader happens to be
    garbage-collected.
    """
    if split_name not in SPLIT_FILES:
        raise KeyError(f"Unknown split_name={split_name}.")
    if which not in ("train", "test"):
        raise ValueError("which must be 'train' or 'test'")

    p = SPLIT_FILES[split_name][which]
    # Header inspection is done once per (split, which) and cached.
    key = (split_name, which)
    if key not in _LC_CFG_CACHE:
        _LC_CFG_CACHE[key] = _build_lc_read_cfg(p)
    cfg = _LC_CFG_CACHE[key]

    with pd.read_csv(
        p,
        usecols=cfg["usecols"],
        dtype=cfg["dtype"],
        chunksize=int(chunksize),
        **SAFE_READ_KW
    ) as reader:
        for chunk in reader:
            chunk = chunk.rename(columns=cfg["rename"])
            yield _normalize_lc_chunk(chunk, drop_bad_filter=drop_bad_filter, drop_bad_mjd=drop_bad_mjd)

def load_object_lightcurve(
    object_id: str,
    which: str,
    chunksize: int = 400_000,
    sort_time: bool = True,
    max_chunks: int = None,
    stop_after_found_block: bool = True
):
    """Load all lightcurve rows of one object by scanning its split file in chunks.

    The object's split is looked up in ``df_train_meta`` / ``df_test_meta``.
    With ``stop_after_found_block`` the scan ends at the first chunk without a
    match that directly follows a chunk with a match; ``max_chunks`` caps the
    number of chunks read. With ``sort_time`` (and more than one row) the result
    is ordered by ``mjd`` and then by filter rank.

    Returns a DataFrame with the ``REQ_LC_KEYS`` columns (empty when no row was
    found). Raises ``KeyError`` for an unknown object and ``ValueError`` for an
    invalid ``which``.
    """
    object_id = str(object_id).strip()

    if which not in ("train", "test"):
        raise ValueError("which must be 'train' or 'test'")
    if which == "train":
        if object_id not in df_train_meta.index:
            raise KeyError(f"object_id not found in df_train_meta: {object_id}")
        split_name = str(df_train_meta.loc[object_id, "split"])
    else:
        if object_id not in df_test_meta.index:
            raise KeyError(f"object_id not found in df_test_meta: {object_id}")
        split_name = str(df_test_meta.loc[object_id, "split"])

    pieces = []
    previous_hit = False
    chunk_iter = iter_lightcurve_chunks(split_name, which, chunksize=chunksize)
    for n_seen, ch in enumerate(chunk_iter, start=1):
        sub = ch[ch["object_id"] == object_id]
        hit = len(sub) > 0
        if hit:
            pieces.append(sub)

        # a miss right after a hit marks the end of the object's block
        if stop_after_found_block and previous_hit and not hit:
            break
        previous_hit = hit

        if max_chunks is not None and n_seen >= int(max_chunks):
            break

    if not pieces:
        return pd.DataFrame(columns=REQ_LC_KEYS)

    out = pd.concat(pieces, ignore_index=True)
    if sort_time and len(out) > 1:
        # stable sort by time, ties broken by passband rank
        out["filter_ord"] = out["filter"].map(FILTER_ORDER).astype("int16")
        out = out.sort_values(["mjd", "filter_ord"], kind="mergesort")
        out = out.drop(columns=["filter_ord"]).reset_index(drop=True)
    return out

# ----------------------------
# 5) Quick smoke test
# ----------------------------
# Read the first chunk of a few splits and verify schema and filter values.
_smoke_splits = ["split_01", "split_08", "split_17"]
for s in _smoke_splits:
    if not train_ids_by_split.get(s, []) or not test_ids_by_split.get(s, []):
        raise RuntimeError(f"Split {s} has 0 objects in train/test meta (unexpected).")

    ch_tr = next(iter_lightcurve_chunks(s, "train", chunksize=50_000))
    ch_te = next(iter_lightcurve_chunks(s, "test", chunksize=50_000))

    if list(ch_tr.columns) != REQ_LC_KEYS:
        raise RuntimeError(f"Train chunk schema mismatch in {s}: {list(ch_tr.columns)}")
    if list(ch_te.columns) != REQ_LC_KEYS:
        raise RuntimeError(f"Test chunk schema mismatch in {s}: {list(ch_te.columns)}")

    badf_tr = sorted(set(ch_tr["filter"].dropna().unique()).difference(ALLOWED_FILTERS))
    badf_te = sorted(set(ch_te["filter"].dropna().unique()).difference(ALLOWED_FILTERS))
    if badf_tr or badf_te:
        raise ValueError(f"Unexpected filter values in smoke chunk split={s}: train_bad={badf_tr} test_bad={badf_te}")

for _status_line in (
    "STAGE 3 OK — LIGHTCURVE LOADING UTILITIES READY",
    f"- Saved manifest: {manifest_path}",
    f"- Saved counts  : {counts_path}",
):
    print(_status_line)

# Publish the loader utilities for the following stages.
globals().update(
    SPLIT_FILES=SPLIT_FILES,
    train_ids_by_split=train_ids_by_split,
    test_ids_by_split=test_ids_by_split,
    iter_lightcurve_chunks=iter_lightcurve_chunks,
    load_object_lightcurve=load_object_lightcurve,
    REQ_LC_KEYS=REQ_LC_KEYS,
    ALLOWED_FILTERS=ALLOWED_FILTERS,
    FILTER_ORDER=FILTER_ORDER,
)

gc.collect()


STAGE 3 OK — LIGHTCURVE LOADING UTILITIES READY
- Saved manifest: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/split_file_manifest.csv
- Saved counts  : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/object_counts_by_split.csv


303

# Photometric Cleaning (De-extinction + Negative Flux Safe Transform)

In [10]:
# ============================================================
# STAGE 4 — Photometric Cleaning (FORCE OVERWRITE) — FULL REVISION v6.2
# FIXES in v6.2:
# - Safety guard: Path.is_relative_to (avoids false positives from substring matching)
# - Atomic parquet: the tmp file keeps the .parquet suffix (part_0000.tmp.parquet)
# - Filter normalization without np.char (safe for <NA>/mixed dtype)
# - EBV uses EBV_clip when available (more stable), falling back to EBV
# ============================================================

import gc, json, warnings, time, shutil
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require previous stages
# ----------------------------
for need in ("iter_lightcurve_chunks", "df_train_meta", "df_test_meta", "ART_DIR", "SPLIT_LIST"):
    if need in globals():
        continue
    raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 -> 1 -> 2 -> 3 dulu.")

ART_DIR = Path(ART_DIR)
ART_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) Settings
# ----------------------------
CHUNKSIZE = 350_000
ERR_EPS = 1e-6
SNR_DET, DET_SIGMA = 3.0, 3.0

MIN_FLUX_POS_UJY = 1e-6
MAG_MIN, MAG_MAX = -10.0, 50.0
MAGERR_FLOOR_DET, MAGERR_FLOOR_ND = 1e-3, 0.75
MAGERR_CAP = 10.0

WRITE_FORMAT = "parquet"   # "parquet" or "csv.gz"
ONLY_SPLITS = None         # e.g. ["split_01"] for a quick test run
KEEP_FLUX_DEBUG = False
DROP_BAD_TIME_ROWS = True

# FORCE overwrite
REBUILD_MODE = "wipe_all"  # "wipe_all" | "wipe_parts_only"

# ----------------------------
# 2) Extinction coefficients (placeholder; replace with official values if available)
# ----------------------------
EXT_RLAMBDA = dict(u=4.8, g=3.6, r=2.7, i=2.1, z=1.6, y=1.3)
BAND2ID = {band: idx for idx, band in enumerate("ugrizy")}
ID2BAND = {idx: band for band, idx in BAND2ID.items()}

# use EBV_clip when present (more stable), otherwise fall back to EBV
if "EBV_clip" in df_train_meta.columns:
    EBV_TRAIN_SER = df_train_meta["EBV_clip"]
else:
    EBV_TRAIN_SER = df_train_meta["EBV"]
if "EBV_clip" in df_test_meta.columns:
    EBV_TEST_SER = df_test_meta["EBV_clip"]
else:
    EBV_TEST_SER = df_test_meta["EBV"]

MAG_ZP = float(2.5 * np.log10(3631e6))  # ~23.9

# ----------------------------
# 3) Output root + WIPE (with safety guard)
# ----------------------------
# All cleaned lightcurve parts are written below this directory.
LC_CLEAN_DIR = ART_DIR / "lc_clean_mag"

# Resolve both paths so the containment check below compares absolute paths.
art_abs = ART_DIR.resolve()
lc_abs  = LC_CLEAN_DIR.resolve()

# robust safety guard: the directory about to be wiped must live inside ART_DIR
try:
    ok_rel = lc_abs.is_relative_to(art_abs)
except AttributeError:
    # very old python fallback (shouldn't happen on Kaggle py3.12)
    ok_rel = str(lc_abs).startswith(str(art_abs) + "/") or str(lc_abs).startswith(str(art_abs) + "\\")

if not ok_rel:
    raise RuntimeError(f"Safety guard failed: LC_CLEAN_DIR bukan turunan ART_DIR.\nART_DIR={art_abs}\nLC_CLEAN_DIR={lc_abs}")

if REBUILD_MODE == "wipe_all":
    # Remove the whole output tree, then recreate it empty.
    # NOTE(review): ignore_errors=True means a failed deletion is not reported,
    # so stale files could survive the wipe -- confirm this is acceptable.
    if LC_CLEAN_DIR.exists():
        shutil.rmtree(LC_CLEAN_DIR, ignore_errors=True)
    LC_CLEAN_DIR.mkdir(parents=True, exist_ok=True)
elif REBUILD_MODE == "wipe_parts_only":
    # Here only the directory is ensured; nothing is deleted in this section.
    LC_CLEAN_DIR.mkdir(parents=True, exist_ok=True)
else:
    raise ValueError("REBUILD_MODE must be 'wipe_all' or 'wipe_parts_only'")

# ----------------------------
# 4) Atomic writer (tmp -> rename)
# ----------------------------
def _atomic_write_parquet(df: pd.DataFrame, out_path: Path):
    # tmp tetap berakhiran .parquet agar engine tidak rewel
    tmp = out_path.with_name(out_path.stem + ".tmp" + out_path.suffix)  # part_0000.tmp.parquet
    try:
        df.to_parquet(tmp, index=False)
        tmp.replace(out_path)
    finally:
        # cleanup kalau gagal sebelum replace
        if tmp.exists() and (not out_path.exists()):
            try:
                tmp.unlink()
            except Exception:
                pass

def _atomic_write_csv_gz(df: pd.DataFrame, out_path: Path):
    final_path = out_path.with_suffix(".csv.gz")
    tmp = final_path.with_name(final_path.stem + ".tmp" + "".join(final_path.suffixes))  # keep .csv.gz
    try:
        df.to_csv(tmp, index=False, compression="gzip")
        tmp.replace(final_path)
    finally:
        if tmp.exists() and (not final_path.exists()):
            try:
                tmp.unlink()
            except Exception:
                pass
    return final_path

def write_part(df: pd.DataFrame, out_path: Path, fmt: str):
    """Persist one cleaned part in the requested format.

    The parent directory of ``out_path`` is created first. ``fmt="parquet"``
    tries the parquet writer and, if that raises, falls back to the gzip CSV
    writer; ``fmt="csv.gz"`` uses the gzip CSV writer directly.

    Returns:
        ``(format_label, written_path)``. The label is ``"parquet"``,
        ``"csv.gz"``, or a fallback label naming the exception type that made
        the parquet write fail.

    Raises:
        ValueError: if ``fmt`` is neither ``"parquet"`` nor ``"csv.gz"``.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)

    if fmt == "csv.gz":
        return "csv.gz", _atomic_write_csv_gz(df, out_path)
    if fmt != "parquet":
        raise ValueError("fmt must be 'parquet' or 'csv.gz'")

    try:
        _atomic_write_parquet(df, out_path)
    except Exception as exc:
        # Any parquet failure is downgraded to the CSV writer.
        fallback_path = _atomic_write_csv_gz(df, out_path)
        return f"csv.gz (fallback from parquet: {type(exc).__name__})", fallback_path
    return "parquet", out_path

# ----------------------------
# 5) Core cleaning (NaN/negative-safe)
# ----------------------------
def clean_chunk_to_mag(ch: pd.DataFrame, ebv_ser: pd.Series):
    """Convert one raw light-curve chunk into de-extincted magnitudes.

    Reads the ``object_id``, ``filter``, ``mjd``, ``flux`` and ``flux_err``
    columns of ``ch`` together with the module-level settings ``ERR_EPS``,
    ``BAND2ID``, ``EXT_RLAMBDA``, ``SNR_DET``, ``DET_SIGMA``,
    ``MIN_FLUX_POS_UJY``, ``MAG_ZP``, ``MAG_MIN``, ``MAG_MAX``,
    ``MAGERR_FLOOR_DET``, ``MAGERR_CAP``, ``MAGERR_FLOOR_ND``,
    ``DROP_BAD_TIME_ROWS`` and ``KEEP_FLUX_DEBUG``.

    The input frame is never modified: arrays obtained with ``copy=False`` may
    be views on ``ch``, so every sanitising step builds a new array instead of
    writing in place.

    Args:
        ch: Raw chunk with the columns listed above.
        ebv_ser: Lookup from object id to E(B-V); ids without an entry (or with
            a non-finite value) are treated as 0.0.

    Returns:
        ``(out, dropped_time, nan_flux_rows)`` where ``out`` holds
        ``object_id, mjd, band_id, mag, mag_err, snr, detected`` (plus
        ``flux_deext``/``err_deext`` when ``KEEP_FLUX_DEBUG`` is set),
        ``dropped_time`` counts rows removed for a non-finite ``mjd`` and
        ``nan_flux_rows`` counts rows whose de-extincted flux is not finite.

    Raises:
        ValueError: if a filter value is not a key of ``BAND2ID``.
    """
    # Normalise id/filter text (strip whitespace; filters are lower-cased).
    oid_ser = ch["object_id"].astype("string").str.strip()
    filt_ser = ch["filter"].astype("string").str.strip().str.lower()

    # Numeric columns as float32; copy=False means these may alias ch's memory.
    mjd = ch["mjd"].to_numpy(dtype=np.float32, copy=False)
    flux = ch["flux"].to_numpy(dtype=np.float32, copy=False)
    err  = ch["flux_err"].to_numpy(dtype=np.float32, copy=False)

    # Sanitise errors: NaN/inf become ERR_EPS, and everything is floored at ERR_EPS.
    err = np.nan_to_num(err, nan=np.float32(ERR_EPS), posinf=np.float32(ERR_EPS), neginf=np.float32(ERR_EPS))
    err = np.maximum(err, np.float32(ERR_EPS))

    # Sanitise flux: inf -> NaN (NaN stays NaN). Built out of place so that the
    # caller's frame is not mutated and read-only views do not raise.
    flux = np.where(np.isfinite(flux), flux, np.float32(np.nan)).astype(np.float32, copy=False)

    # Band id mapping via vectorised masks; missing filters become "" and fail below.
    filt = filt_ser.fillna("").to_numpy(dtype=object, copy=False)
    band_id = np.full(len(ch), -1, dtype=np.int8)
    for b, bid in BAND2ID.items():
        band_id[filt == b] = np.int8(bid)

    if np.any(band_id < 0):
        bad = pd.Series(filt[band_id < 0]).value_counts().head(10).index.tolist()
        raise ValueError(f"Unknown/invalid filter values encountered (top examples): {bad}")

    # E(B-V) lookup (missing or non-finite -> 0.0), again without in-place writes.
    ebv = oid_ser.map(ebv_ser).fillna(0.0).to_numpy(dtype=np.float32)
    ebv = np.where(np.isfinite(ebv), ebv, np.float32(0.0)).astype(np.float32, copy=False)

    # Per-row extinction coefficient; bands absent from EXT_RLAMBDA keep 0.
    rlam = np.zeros(len(ch), dtype=np.float32)
    for b, rv in EXT_RLAMBDA.items():
        rlam[filt == b] = np.float32(rv)

    # Extinction A and the matching multiplicative flux correction 10**(0.4*A).
    A = (rlam * ebv).astype(np.float32)
    mul = np.power(np.float32(10.0), (np.float32(0.4) * A)).astype(np.float32)

    flux_deext = (flux * mul).astype(np.float32)
    err_deext  = (err  * mul).astype(np.float32)

    # Signal-to-noise; rows with non-finite flux keep snr = 0.
    okf = np.isfinite(flux_deext)
    snr = np.zeros_like(err_deext, dtype=np.float32)
    snr[okf] = (flux_deext[okf] / np.maximum(err_deext[okf], np.float32(ERR_EPS))).astype(np.float32)

    detected = (snr > np.float32(SNR_DET)).astype(np.int8)

    nan_flux_rows = int((~okf).sum())
    if nan_flux_rows:
        # Force non-finite-flux rows to "not detected" regardless of the threshold.
        detected[~okf] = np.int8(0)
        snr[~okf] = np.float32(0.0)

    # Non-detections are converted to magnitudes via a DET_SIGMA * error limit.
    flux_detlim = (np.float32(DET_SIGMA) * err_deext).astype(np.float32)

    flux_for_mag = np.where(
        detected == 1,
        np.maximum(flux_deext, np.float32(MIN_FLUX_POS_UJY)),
        np.maximum(flux_detlim, np.float32(MIN_FLUX_POS_UJY)),
    ).astype(np.float32)

    mag = (np.float32(MAG_ZP) - np.float32(2.5) * np.log10(flux_for_mag)).astype(np.float32)
    mag = np.clip(mag, np.float32(MAG_MIN), np.float32(MAG_MAX)).astype(np.float32)

    # 1.0857362 = 2.5 / ln(10): converts a relative flux error to a magnitude error.
    mag_err = (np.float32(1.0857362) * (err_deext / flux_for_mag)).astype(np.float32)
    mag_err = np.clip(mag_err, np.float32(MAGERR_FLOOR_DET), np.float32(MAGERR_CAP)).astype(np.float32)

    # Optional extra error floor applied to non-detections only.
    if MAGERR_FLOOR_ND is not None and float(MAGERR_FLOOR_ND) > 0:
        mag_err = np.where(
            detected == 1,
            mag_err,
            np.maximum(mag_err, np.float32(MAGERR_FLOOR_ND))
        ).astype(np.float32)

    out = pd.DataFrame({
        "object_id": pd.array(oid_ser.to_numpy(copy=False), dtype="string"),
        "mjd": mjd.astype(np.float32, copy=False),
        "band_id": band_id.astype(np.int8, copy=False),
        "mag": mag.astype(np.float32, copy=False),
        "mag_err": mag_err.astype(np.float32, copy=False),
        "snr": snr.astype(np.float32, copy=False),
        "detected": detected.astype(np.int8, copy=False),
    })

    # Debug columns are attached before the time filter so they are filtered
    # together with the other columns rather than assigned onto a slice.
    if KEEP_FLUX_DEBUG:
        out["flux_deext"] = pd.Series(np.nan_to_num(flux_deext, nan=0.0), dtype="float32")
        out["err_deext"]  = pd.Series(err_deext, dtype="float32")

    dropped_time = 0
    if DROP_BAD_TIME_ROWS:
        t = out["mjd"].to_numpy(dtype=np.float32, copy=False)
        keep = np.isfinite(t)
        dropped_time = int((~keep).sum())
        if dropped_time:
            out = out[keep]

    return out, dropped_time, nan_flux_rows

# ----------------------------
# 6) Process split-wise
# ----------------------------
# Splits handled by this stage: the ONLY_SPLITS override when it is set,
# otherwise a copy of the full SPLIT_LIST.
splits_to_use = ONLY_SPLITS if (ONLY_SPLITS is not None) else list(SPLIT_LIST)

# Accumulators appended to by process_split(): one summary row per split/which
# call and one manifest row per written part.
summary_rows, manifest_rows = [], []

def _wipe_parts_dir(out_dir: Path):
    """Delete part files and leftover temp files directly inside ``out_dir``.

    Does nothing when the directory is missing. Files that cannot be removed
    are skipped silently (best-effort cleanup).
    """
    if not out_dir.exists():
        return

    stale_patterns = ("part_*.parquet", "part_*.csv.gz", "*.tmp", "*.tmp.parquet", "*.tmp.csv.gz")
    for pattern in stale_patterns:
        for stale_file in out_dir.glob(pattern):
            try:
                stale_file.unlink()
            except Exception:
                # Best effort: a file that cannot be removed is left in place.
                pass

def process_split(split_name: str, which: str):
    """Clean every light-curve chunk of one split and write it out as parts.

    Iterates the chunks yielded by ``iter_lightcurve_chunks(split_name, which,
    chunksize=CHUNKSIZE)``, converts each with ``clean_chunk_to_mag`` and writes
    it to ``LC_CLEAN_DIR/<split_name>/<which>/part_NNNN.*`` via ``write_part``.

    Side effects: appends one row per written part to the module-level
    ``manifest_rows``, one row per call to ``summary_rows``, and prints a
    one-line progress report.

    Args:
        split_name: Name of the split directory to process.
        which: ``"train"`` selects ``EBV_TRAIN_SER``; any other value selects
            ``EBV_TEST_SER``.
    """
    ebv_ser = EBV_TRAIN_SER if which == "train" else EBV_TEST_SER
    out_dir = LC_CLEAN_DIR / split_name / which
    out_dir.mkdir(parents=True, exist_ok=True)

    # In "wipe_parts_only" mode only old part/temp files of this directory are removed.
    if REBUILD_MODE == "wipe_parts_only":
        _wipe_parts_dir(out_dir)

    # Running statistics accumulated over all chunks of this split.
    t0 = time.time()
    part_idx = 0
    n_rows_total = 0
    n_det = 0
    n_finite_mag = 0
    mag_min = np.inf
    mag_max = -np.inf
    dropped_time_total = 0
    nan_flux_total = 0

    for ch in iter_lightcurve_chunks(split_name, which, chunksize=CHUNKSIZE):
        cleaned, dropped_time, nan_flux = clean_chunk_to_mag(ch, ebv_ser)

        dropped_time_total += int(dropped_time)
        nan_flux_total += int(nan_flux)

        n_rows = int(len(cleaned))
        n_rows_total += n_rows

        det_arr = cleaned["detected"].to_numpy(dtype=np.int8, copy=False)
        n_det += int(det_arr.sum())

        # Track the magnitude range over finite values only.
        mag_arr = cleaned["mag"].to_numpy(dtype=np.float32, copy=False)
        fin = np.isfinite(mag_arr)
        n_finite_mag += int(fin.sum())
        if fin.any():
            mag_min = float(min(mag_min, float(np.min(mag_arr[fin]))))
            mag_max = float(max(mag_max, float(np.max(mag_arr[fin]))))

        # The path always carries a .parquet name; write_part returns the path
        # it actually wrote, which is what goes into the manifest.
        out_path = out_dir / f"part_{part_idx:04d}.parquet"
        used_fmt, final_path = write_part(cleaned, out_path, WRITE_FORMAT)

        manifest_rows.append({
            "split": split_name,
            "which": which,
            "part": int(part_idx),
            "path": str(final_path),
            "rows": int(n_rows),
            "format": str(used_fmt),
        })

        part_idx += 1
        del cleaned, ch
        # Garbage-collect every 10 parts.
        if part_idx % 10 == 0:
            gc.collect()

    dt = time.time() - t0
    # max(n_rows_total, 1) guards the fractions against division by zero;
    # an untouched min/max (still +/-inf) is reported as NaN.
    summary_rows.append({
        "split": split_name,
        "which": which,
        "parts": int(part_idx),
        "rows": int(n_rows_total),
        "det_frac_snr_gt_thr": float(n_det / max(n_rows_total, 1)),
        "finite_mag_frac": float(n_finite_mag / max(n_rows_total, 1)),
        "mag_min": (mag_min if np.isfinite(mag_min) else np.nan),
        "mag_max": (mag_max if np.isfinite(mag_max) else np.nan),
        "dropped_time_rows": int(dropped_time_total),
        "nan_flux_rows": int(nan_flux_total),
        "sec": float(dt),
    })

    print(
        f"[Stage 4] {split_name}/{which}: parts={part_idx} | rows={n_rows_total:,} | "
        f"det%={100*(n_det/max(n_rows_total,1)):.2f}% | "
        f"nan_flux={nan_flux_total:,} | drop_time={dropped_time_total:,} | "
        f"mag_range=[{(mag_min if np.isfinite(mag_min) else np.nan):.2f}, {(mag_max if np.isfinite(mag_max) else np.nan):.2f}] | "
        f"time={dt:.1f}s"
    )

# Run the cleaning for every selected split, train first and then test.
print(f"[Stage 4] REBUILD_MODE={REBUILD_MODE} | Writing to: {LC_CLEAN_DIR}")
for s in splits_to_use:
    process_split(s, "train")
    process_split(s, "test")

# ----------------------------
# 7) Save manifests + summary + config
# ----------------------------
# Materialise the accumulated rows as frames (also reused by get_clean_parts below).
df_parts_manifest = pd.DataFrame(manifest_rows)
df_summary  = pd.DataFrame(summary_rows)

manifest_path = LC_CLEAN_DIR / "lc_clean_mag_manifest.csv"
summary_path  = LC_CLEAN_DIR / "lc_clean_mag_summary.csv"

df_parts_manifest.to_csv(manifest_path, index=False)
df_summary.to_csv(summary_path, index=False)

# Record the settings this run was produced with, next to its outputs.
cfg_path = LC_CLEAN_DIR / "photometric_config_mag.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump({
        "EXT_RLAMBDA": EXT_RLAMBDA,
        "SNR_DET": float(SNR_DET),
        "DET_SIGMA": float(DET_SIGMA),
        "ERR_EPS": float(ERR_EPS),
        "MIN_FLUX_POS_UJY": float(MIN_FLUX_POS_UJY),
        "MAG_ZP": float(MAG_ZP),
        "MAG_MIN": float(MAG_MIN),
        "MAG_MAX": float(MAG_MAX),
        "MAGERR_FLOOR_DET": float(MAGERR_FLOOR_DET),
        # MAGERR_FLOOR_ND may legitimately be None (the cleaning code treats
        # None as "no extra floor"); float(None) would raise TypeError here,
        # so None is written through as JSON null.
        "MAGERR_FLOOR_ND": (None if MAGERR_FLOOR_ND is None else float(MAGERR_FLOOR_ND)),
        "MAGERR_CAP": float(MAGERR_CAP),
        "CHUNKSIZE": int(CHUNKSIZE),
        "WRITE_FORMAT": str(WRITE_FORMAT),
        "ONLY_SPLITS": list(splits_to_use),
        "KEEP_FLUX_DEBUG": bool(KEEP_FLUX_DEBUG),
        "DROP_BAD_TIME_ROWS": bool(DROP_BAD_TIME_ROWS),
        "REBUILD_MODE": str(REBUILD_MODE),
        "EBV_SOURCE": ("EBV_clip" if ("EBV_clip" in df_train_meta.columns and "EBV_clip" in df_test_meta.columns) else "EBV"),
    }, f, indent=2)

print("\n[Stage 4] Done.")
print(f"- LC_CLEAN_DIR  : {LC_CLEAN_DIR}")
print(f"- Saved manifest: {manifest_path}")
print(f"- Saved summary : {summary_path}")
print(f"- Saved config  : {cfg_path}")

# ----------------------------
# 8) Helper for next stages
# ----------------------------
def get_clean_parts(split_name: str, which: str):
    """Return the part paths recorded for ``split_name``/``which``, ordered by part number.

    Looks the rows up in the module-level ``df_parts_manifest``; an unknown
    combination yields an empty list.
    """
    same_split = df_parts_manifest["split"] == split_name
    same_kind = df_parts_manifest["which"] == which
    selected = df_parts_manifest[same_split & same_kind].sort_values("part")
    return selected["path"].astype(str).tolist()

# Publish the stage outputs under fixed global names. Note that the manifest and
# summary frames are exposed under different names than their local variables
# (lc_clean_mag_manifest / lc_clean_mag_summary).
globals().update({
    "EXT_RLAMBDA": EXT_RLAMBDA,
    "BAND2ID": BAND2ID,
    "ID2BAND": ID2BAND,
    "MAG_ZP": MAG_ZP,
    "LC_CLEAN_DIR": LC_CLEAN_DIR,
    "lc_clean_mag_manifest": df_parts_manifest,
    "lc_clean_mag_summary": df_summary,
    "get_clean_parts": get_clean_parts,
})

# As the last expression of the cell, the collector's return value is echoed as cell output.
gc.collect()


[Stage 4] REBUILD_MODE=wipe_all | Writing to: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/lc_clean_mag
[Stage 4] split_01/train: parts=1 | rows=26,324 | det%=19.34% | nan_flux=11 | drop_time=0 | mag_range=[19.72, 25.96] | time=0.1s
[Stage 4] split_01/test: parts=1 | rows=59,235 | det%=23.02% | nan_flux=23 | drop_time=0 | mag_range=[19.61, 26.20] | time=0.2s
[Stage 4] split_02/train: parts=1 | rows=25,609 | det%=24.45% | nan_flux=6 | drop_time=0 | mag_range=[20.10, 26.04] | time=0.1s
[Stage 4] split_02/test: parts=1 | rows=71,229 | det%=21.69% | nan_flux=8 | drop_time=0 | mag_range=[18.77, 26.32] | time=0.2s
[Stage 4] split_03/train: parts=1 | rows=21,676 | det%=21.65% | nan_flux=5 | drop_time=0 | mag_range=[20.17, 26.23] | time=0.1s
[Stage 4] split_03/test: parts=1 | rows=53,751 | det%=21.90% | nan_flux=8 | drop_time=0 | mag_range=[19.61, 26.37] | time=0.2s
[Stage 4] split_04/train: parts=1 | rows=22,898 | det%=21.11% | nan_flux=12 | drop_time=0 | mag_range=[20

273

# Sequence Tokenization (Event-based Tokens)

In [11]:
# ============================================================
# STAGE 5 — Sequence Tokenization (Event-based Tokens) (ONE CELL, Kaggle CPU-SAFE)
# REVISI FULL v5.2 (PATH+META SYNC HARDENED + MISSING-OBJECT SAFE + BUCKET ROBUST)
#
# FIX UTAMA v5.2:
# - Auto-find STAGE 4 manifest dari run manapun.
# - Sync path benar: RUN_DIR/ART_DIR/LC_CLEAN_DIR dari manifest.
# - Auto-reload df_train_meta/df_test_meta dari ART_DIR synced jika mismatch.
# - Validasi semua part path exists.
# - Jika ada object_id tidak muncul di cleaned parts -> build EMPTY sequence (len=1) agar built==expected.
# - Bucket writer aman (try/finally close) + cleanup tmp dir rmtree.
#
# OUTPUT:
# - artifacts/seq_tokens/split_XX/{train|test}/shard_*.npz
# - artifacts/seq_tokens/seq_manifest_{train|test}.csv
# - artifacts/seq_tokens/seq_build_stats.csv
# - artifacts/seq_tokens/seq_config.json
# ============================================================

import gc, json, warnings, time, shutil
from pathlib import Path
import numpy as np
import pandas as pd

# Silence FutureWarning and pandas DtypeWarning for the rest of the session.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require minimal globals
# ----------------------------
# These names must already exist in the kernel (the error message says which
# earlier stages provide them); fail early with a readable message otherwise.
for need in ["ART_DIR", "df_train_meta", "df_test_meta"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan minimal STAGE 0 + STAGE 2 dulu (ART_DIR + meta).")

# Normalise to a Path in case ART_DIR was stored as a string.
ART_DIR = Path(ART_DIR)

# ----------------------------
# 1) Helpers
# ----------------------------
def _safe_string_series(s: pd.Series) -> pd.Series:
    """Return ``s`` as whitespace-stripped text.

    The pandas ``"string"`` dtype is tried first; if that conversion or the
    strip fails, the series is converted with plain ``str`` instead.
    """
    try:
        as_text = s.astype("string")
        return as_text.str.strip()
    except Exception:
        as_text = s.astype(str)
        return as_text.str.strip()

def _find_stage4_manifest(art_dir: Path, run_root=None):
    """Locate the STAGE 4 manifest CSV.

    Lookup order:
      1. ``art_dir/lc_clean_mag/lc_clean_mag_manifest.csv``.
      2. Below ``run_root``: ``run_*/artifacts/lc_clean_mag/lc_clean_mag_manifest.csv``.
      3. Below ``run_root``: any ``run_*/**/lc_clean_mag_manifest.csv``.
    When steps 2/3 produce several candidates, the most recently modified one wins.

    Args:
        art_dir: Artifacts directory of the current run.
        run_root: Directory holding the ``run_*`` folders. Defaults to
            ``/kaggle/working/mallorn_run`` so existing callers are unaffected.

    Returns:
        Path of the manifest, or ``None`` when nothing is found.
    """
    cand = art_dir / "lc_clean_mag" / "lc_clean_mag_manifest.csv"
    if cand.exists():
        return cand

    root = Path("/kaggle/working/mallorn_run") if run_root is None else Path(run_root)
    if not root.exists():
        return None

    cands = list(root.glob("run_*/artifacts/lc_clean_mag/lc_clean_mag_manifest.csv"))
    if not cands:
        # Fall back to a recursive search inside each run directory.
        cands = list(root.glob("run_*/**/lc_clean_mag_manifest.csv"))
    if not cands:
        return None

    # Newest by modification time; on ties max() keeps the first candidate,
    # matching a stable descending sort.
    return max(cands, key=lambda p: p.stat().st_mtime)

def _sync_dirs_from_manifest(manifest_csv: Path):
    """Derive ``(run_dir, art_dir, lc_clean_dir)`` from the manifest location.

    The three directories are the manifest's parent, grandparent and
    great-grandparent, returned outermost first.
    """
    ancestors = []
    current = manifest_csv
    for _ in range(3):
        current = current.parent
        ancestors.append(current)
    lc_clean_dir, art_dir_new, run_dir_new = ancestors
    return run_dir_new, art_dir_new, lc_clean_dir

def _load_meta_if_needed(art_dir_synced: Path):
    """Reload ``df_train_meta``/``df_test_meta`` when they disagree with the synced run.

    Looks for ``train_meta``/``test_meta`` files (parquet preferred, CSV as
    fallback) inside ``art_dir_synced``. The in-memory frames are replaced by
    the file contents when the row counts differ, or when any of the first five
    in-memory ids of either the train or the test frame is missing from the
    corresponding file.

    Args:
        art_dir_synced: Artifacts directory derived from the STAGE 4 manifest.

    Returns:
        ``(reloaded, message)``: whether the globals were replaced, plus a
        short explanation. Read or compare errors are reported through the
        message with ``reloaded=False`` instead of being raised.
    """
    global df_train_meta, df_test_meta

    def _meta_file(stem: str):
        # Parquet takes precedence over CSV when both files exist.
        for candidate in (art_dir_synced / f"{stem}.parquet", art_dir_synced / f"{stem}.csv"):
            if candidate.exists():
                return candidate
        return None

    def _read_indexed(path: Path) -> pd.DataFrame:
        # Load one meta file and index it by object_id as pandas "string" dtype.
        frame = pd.read_parquet(path) if path.suffix == ".parquet" else pd.read_csv(path)
        if "object_id" in frame.columns:
            frame = frame.set_index("object_id", drop=True)
        frame.index = frame.index.astype("string")
        return frame

    def _sample_ids_present(mem_frame: pd.DataFrame, file_frame: pd.DataFrame) -> bool:
        # Cheap spot check: the first five in-memory ids must exist in the file.
        sample_ids = mem_frame.index[:5].astype(str).tolist()
        return all((sid in file_frame.index) for sid in sample_ids)

    train_file = _meta_file("train_meta")
    test_file = _meta_file("test_meta")

    # Without both meta files there is nothing to compare against: keep memory.
    if train_file is None or test_file is None:
        return False, "meta file not found in synced ART_DIR; keep in-memory"

    try:
        cand_train = _read_indexed(train_file)
        cand_test = _read_indexed(test_file)

        mem_train_n = int(len(df_train_meta))
        mem_test_n  = int(len(df_test_meta))
        cand_train_n = int(len(cand_train))
        cand_test_n  = int(len(cand_test))

        if (mem_train_n != cand_train_n) or (mem_test_n != cand_test_n):
            df_train_meta = cand_train
            df_test_meta = cand_test
            return True, f"reloaded meta due to size mismatch: mem({mem_train_n},{mem_test_n}) -> file({cand_train_n},{cand_test_n})"

        # Check both frames: equal sizes with different ids on either side
        # still mean the in-memory meta belongs to another run.
        ids_ok = _sample_ids_present(df_train_meta, cand_train) and _sample_ids_present(df_test_meta, cand_test)
        if not ids_ok:
            df_train_meta = cand_train
            df_test_meta = cand_test
            return True, "reloaded meta due to id mismatch"

        return False, "meta already consistent"
    except Exception as e:
        # Deliberately non-fatal: the caller only prints the message.
        return False, f"meta reload skipped (error: {type(e).__name__}: {e})"

# ----------------------------
# 2) Locate STAGE 4 output (robust)
# ----------------------------
# Find the STAGE 4 manifest; when none exists, list the most recent run
# folders in the error message to help debugging.
manifest_csv = _find_stage4_manifest(ART_DIR)
if manifest_csv is None:
    root = Path("/kaggle/working/mallorn_run")
    runs = sorted([p.name for p in root.glob("run_*") if p.is_dir()])[-15:] if root.exists() else []
    raise RuntimeError(
        "Output STAGE 4 (lc_clean_mag_manifest.csv) tidak ditemukan.\n"
        f"- ART_DIR saat ini: {ART_DIR}\n"
        f"- Expected: {ART_DIR/'lc_clean_mag'}\n"
        f"- Runs available (last 15): {runs}\n"
        "Solusi: pastikan STAGE 4 benar-benar selesai dan menulis artifacts/lc_clean_mag."
    )

# Rebind the directory globals to the run that actually owns the manifest
# (this overwrites ART_DIR when the manifest came from another run).
RUN_DIR, ART_DIR, LC_CLEAN_DIR = _sync_dirs_from_manifest(manifest_csv)

print("STAGE 5 ROUTING SYNC OK")
print(f"- RUN_DIR      : {RUN_DIR}")
print(f"- ART_DIR      : {ART_DIR}")
print(f"- LC_CLEAN_DIR : {LC_CLEAN_DIR}")
print(f"- manifest_csv : {manifest_csv}")

# Reload the meta frames if the in-memory ones do not match the synced run.
reloaded, msg = _load_meta_if_needed(ART_DIR)
print(f"- meta_sync    : {msg}")

# ----------------------------
# 3) Load & validate Stage4 manifest
# ----------------------------
# Load the manifest and strip whitespace from its column names.
_df_clean_manifest = pd.read_csv(manifest_csv)
_df_clean_manifest.columns = [c.strip() for c in _df_clean_manifest.columns]

# The columns this stage relies on must all be present.
need_cols = {"split", "which", "part", "path"}
miss = sorted(list(need_cols - set(_df_clean_manifest.columns)))
if miss:
    raise RuntimeError(f"Manifest STAGE 4 missing columns: {miss} | cols={list(_df_clean_manifest.columns)}")

# Every part file listed in the manifest has to exist on disk; report up to 10 examples otherwise.
paths = _df_clean_manifest["path"].astype(str).tolist()
missing_paths = [p for p in paths if not Path(p).exists()]
if missing_paths:
    ex = missing_paths[:10]
    raise RuntimeError(
        "Ada file part STAGE 4 yang hilang (manifest ada tapi file tidak ada).\n"
        f"Missing count={len(missing_paths)} | contoh={ex}\n"
        "Solusi: rerun STAGE 4 dengan mode rebuild/wipe untuk regenerasi cache."
    )

def get_clean_parts(split_name: str, which: str):
    """Return the cleaned part paths of ``split_name``/``which`` in part order.

    Rows come from the module-level ``_df_clean_manifest``; a combination
    without any rows yields an empty list.
    """
    hit = (_df_clean_manifest["split"] == split_name) & (_df_clean_manifest["which"] == which)
    if not hit.any():
        return []
    ordered = _df_clean_manifest[hit].sort_values("part")
    return ordered["path"].astype(str).tolist()

# ----------------------------
# 4) Recover SPLIT_LIST + routing ids (force stable)
# ----------------------------
# Canonical split names split_01 .. split_20.
SPLIT_LIST = [f"split_{i:02d}" for i in range(1, 21)]
splits_in_manifest = sorted(set(_df_clean_manifest["split"].astype(str).tolist()))
# Use the intersection so that splits without any parts are never routed to.
SPLITS_TO_CONSIDER = [s for s in SPLIT_LIST if s in splits_in_manifest]

# object ids grouped by split, taken from the index and the "split" column of each meta frame.
# NOTE(review): a "split" value outside SPLIT_LIST raises KeyError in the loops
# below — confirm the meta frames only ever contain split_01..split_20.
train_ids_by_split = {s: [] for s in SPLIT_LIST}
for oid, sp in df_train_meta["split"].astype(str).items():
    train_ids_by_split[str(sp)].append(str(oid))

test_ids_by_split = {s: [] for s in SPLIT_LIST}
for oid, sp in df_test_meta["split"].astype(str).items():
    test_ids_by_split[str(sp)].append(str(oid))

# ----------------------------
# 5) Settings
# ----------------------------
ONLY_SPLITS = None                 # None=all; e.g. ["split_01"] for a quick test
REBUILD_MODE = "wipe_all"          # "wipe_all" or "reuse_if_exists"

# Shard output: compression switch and maximum number of objects per shard file.
COMPRESS_NPZ = False
SHARD_MAX_OBJECTS = 1500

# Token feature settings: snr is squashed with tanh(snr / SNR_TANH_SCALE);
# TIME_CLIP_MAX_DAYS=None disables clipping of the time features.
SNR_TANH_SCALE = 10.0
TIME_CLIP_MAX_DAYS = None
DROP_BAD_TIME_ROWS = True

# Sequence length / truncation settings; values come from CFG when that global
# exists, otherwise the literal defaults on the right are used.
L_MAX = int(CFG.get("L_MAX", 256)) if "CFG" in globals() else 256
TRUNC_POLICY = str(CFG.get("TRUNC_POLICY", "smart")) if "CFG" in globals() else "smart"  # smart/head/none
KEEP_DET_FRAC = float(CFG.get("KEEP_DET_FRAC", 0.70)) if "CFG" in globals() else 0.70
KEEP_EDGE = True
USE_RESTFRAME_TIME = bool(CFG.get("USE_RESTFRAME_TIME", True)) if "CFG" in globals() else True

# Number of temporary hash buckets used while grouping rows by object.
NUM_BUCKETS = 64

SEQ_DIR = Path(ART_DIR) / "seq_tokens"
SEQ_DIR.mkdir(parents=True, exist_ok=True)

# Filled in lazily by _read_clean_part() once the schema of the first part is known.
TOKEN_MODE = None
FEATURE_NAMES = None
FEATURE_DIM = None

# Columns every cleaned part must have, plus the extra columns of each supported schema.
BASE_COLS = {"object_id", "mjd", "band_id", "snr", "detected"}
MODE_COLS = {"mag": {"mag", "mag_err"}, "asinh": {"flux_asinh", "err_log1p"}}

# ----------------------------
# 6) Reader for cleaned parts
# ----------------------------
def _read_clean_part(path: str) -> pd.DataFrame:
    """Load one cleaned STAGE 4 part and coerce its columns to fixed dtypes.

    ``.parquet`` files are read as parquet, names ending in ``.csv.gz`` as
    gzip CSV and anything else as plain CSV. On the first call the schema
    (``"mag"`` or ``"asinh"``) is detected from the columns and stored in the
    globals ``TOKEN_MODE``, ``FEATURE_NAMES`` and ``FEATURE_DIM``; later calls
    validate against that mode. Rows with a non-finite ``mjd`` are dropped when
    ``DROP_BAD_TIME_ROWS`` is set.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        RuntimeError: if the schema cannot be detected or required columns are missing.
    """
    global TOKEN_MODE, FEATURE_NAMES, FEATURE_DIM

    part_path = Path(path)
    if not part_path.exists():
        raise FileNotFoundError(f"Clean part missing: {part_path}")

    if part_path.suffix == ".parquet":
        df = pd.read_parquet(part_path)
    else:
        gzip_kwargs = {"compression": "gzip"} if part_path.name.endswith(".csv.gz") else {}
        df = pd.read_csv(part_path, **gzip_kwargs)

    df.columns = [c.strip() for c in df.columns]

    # One-time schema detection, decided by the first part that gets read.
    if TOKEN_MODE is None:
        present = set(df.columns)
        has_base = BASE_COLS.issubset(present)
        if has_base and MODE_COLS["mag"].issubset(present):
            TOKEN_MODE = "mag"
            FEATURE_NAMES = ["t_rel_log", "dt_log", "mag", "mag_err_log", "snr_tanh", "detected"]
        elif has_base and MODE_COLS["asinh"].issubset(present):
            TOKEN_MODE = "asinh"
            FEATURE_NAMES = ["t_rel_log", "dt_log", "flux_asinh", "err_log1p", "snr_tanh", "detected"]
        else:
            raise RuntimeError(
                "Cannot detect cleaned schema.\n"
                f"Found cols={list(df.columns)}\n"
                "Expected MAG or ASINH schema from STAGE 4."
            )
        FEATURE_DIM = len(FEATURE_NAMES)

    required = set(BASE_COLS) | set(MODE_COLS[TOKEN_MODE])
    miss = sorted(required - set(df.columns))
    if miss:
        raise RuntimeError(f"Clean part missing columns: {miss} | file={part_path}")

    def _as_float32(column: str) -> pd.Series:
        # Unparseable entries become NaN before the float32 cast.
        return pd.to_numeric(df[column], errors="coerce").astype(np.float32)

    df["object_id"] = _safe_string_series(df["object_id"])
    df["mjd"] = _as_float32("mjd")
    df["band_id"] = pd.to_numeric(df["band_id"], errors="coerce").astype(np.int16)
    df["snr"] = _as_float32("snr")
    df["detected"] = pd.to_numeric(df["detected"], errors="coerce").fillna(0).astype(np.int8)

    # Mode-specific value columns.
    value_columns = ("mag", "mag_err") if TOKEN_MODE == "mag" else ("flux_asinh", "err_log1p")
    for column in value_columns:
        df[column] = _as_float32(column)

    if DROP_BAD_TIME_ROWS:
        df = df[np.isfinite(df["mjd"].to_numpy())]

    return df

# ----------------------------
# 7) Truncation
# ----------------------------
def _smart_truncate(mjd, det, snr, Lmax: int):
    """Pick at most ``Lmax`` row positions to keep from a longer sequence.

    Selection order: both end points (when ``KEEP_EDGE``), then the detected
    rows with the largest ``|snr|`` (up to ``floor(Lmax * KEEP_DET_FRAC)``),
    then evenly spaced picks from whatever is left until ``Lmax`` positions are
    chosen. Sequences that already fit are returned unchanged.

    Returns:
        Sorted ``int64`` array of positions into the input arrays.
    """
    n = len(mjd)
    if n <= Lmax:
        return np.arange(n, dtype=np.int64)

    positions = np.arange(n, dtype=np.int64)
    chosen = np.zeros(n, dtype=bool)  # membership mask of the kept positions

    if KEEP_EDGE:
        chosen[0] = True
        chosen[n - 1] = True

    # Strongest detections first.
    det_idx = positions[det.astype(bool)]
    k_det = int(max(0, min(len(det_idx), int(np.floor(Lmax * KEEP_DET_FRAC)))))
    if k_det > 0 and len(det_idx) > 0:
        strength = np.abs(snr[det_idx])
        chosen[det_idx[np.argsort(-strength)[:k_det]]] = True

    # Fill the remaining budget with evenly spaced leftovers.
    n_chosen = int(chosen.sum())
    if n_chosen < Lmax:
        leftovers = positions[~chosen]
        need = Lmax - n_chosen
        if len(leftovers) and need > 0:
            spread = np.linspace(0, len(leftovers) - 1, num=need, dtype=int)
            chosen[leftovers[spread]] = True

    out = np.flatnonzero(chosen).astype(np.int64)
    if len(out) > Lmax:
        # Over budget (edges + detections): thin out evenly.
        pos = np.linspace(0, len(out) - 1, num=Lmax, dtype=int)
        out = out[pos]
    return out

# ----------------------------
# 8) Build tokens per object (empty-safe)
# ----------------------------
def build_empty_tokens():
    """Return the placeholder sequence used for objects without any rows.

    Returns:
        ``(X, B, 0, 1)``: an all-zero ``(1, FEATURE_DIM)`` float32 feature row,
        a single band id of ``-1`` (int8), and the lengths before/after.
    """
    feature_width = int(FEATURE_DIM)
    placeholder_x = np.zeros((1, feature_width), dtype=np.float32)
    placeholder_band = np.full((1,), -1, dtype=np.int8)
    return placeholder_x, placeholder_band, 0, 1

def build_object_tokens(df_obj: pd.DataFrame, z_val: float = 0.0):
    """Turn the cleaned rows of one object into a token feature matrix.

    Rows are ordered by ``mjd`` (ties by ``band_id``). Each token carries six
    features: ``log1p`` of the time since the first row, ``log1p`` of the gap
    to the previous row, two mode-specific value columns, ``tanh(snr /
    SNR_TANH_SCALE)`` and the detected flag. Times are divided by ``1 + z``
    when ``USE_RESTFRAME_TIME`` is set. Sequences longer than ``L_MAX`` are
    shortened according to ``TRUNC_POLICY`` and the gap feature is recomputed
    for the kept rows.

    ``TIME_CLIP_MAX_DAYS`` is applied to the gap feature both before and after
    truncation, so truncated and untruncated sequences obey the same cap.

    Args:
        df_obj: Rows of a single object; ``None``/empty yields the empty placeholder.
        z_val: Redshift; ``None``, non-finite or negative values act as 0.

    Returns:
        ``(X, band, L0, L)``: float32 features of shape ``(L, 6)``, int8 band
        ids of length ``L``, the length before truncation and the final length.
    """
    if df_obj is None or df_obj.empty:
        return build_empty_tokens()

    mjd = df_obj["mjd"].to_numpy(dtype=np.float32, copy=False)
    band = df_obj["band_id"].to_numpy(dtype=np.int16, copy=False)
    snr  = df_obj["snr"].to_numpy(dtype=np.float32, copy=False)
    det  = df_obj["detected"].to_numpy(dtype=np.int8, copy=False)

    # Sort by time, ties broken by band (lexsort uses the last key as primary).
    order = np.lexsort((band, mjd))
    mjd = mjd[order]; band = band[order]; snr = snr[order]; det = det[order]

    z = float(z_val) if (z_val is not None and np.isfinite(z_val)) else 0.0
    denom = (1.0 + max(z, 0.0)) if USE_RESTFRAME_TIME else 1.0

    clip_max = None if TIME_CLIP_MAX_DAYS is None else np.float32(TIME_CLIP_MAX_DAYS)

    def _gaps(t_rel_arr):
        # Non-negative gap to the previous row (0 for the first), capped when clipping is on.
        gap = np.empty_like(t_rel_arr); gap[0] = 0.0
        if len(t_rel_arr) > 1:
            gap[1:] = np.maximum(t_rel_arr[1:] - t_rel_arr[:-1], 0.0)
        if clip_max is not None:
            gap = np.clip(gap, 0.0, clip_max)
        return gap

    t0 = mjd[0]
    t_rel = (mjd - t0) / np.float32(denom)
    dt = _gaps(t_rel)  # gaps are taken from the unclipped times
    if clip_max is not None:
        t_rel = np.clip(t_rel, 0.0, clip_max)

    t_rel_log = np.log1p(t_rel).astype(np.float32)
    dt_log    = np.log1p(dt).astype(np.float32)

    snr = np.nan_to_num(snr, nan=0.0, posinf=0.0, neginf=0.0)
    snr_tanh = np.tanh(snr / np.float32(SNR_TANH_SCALE)).astype(np.float32)
    det_f = det.astype(np.float32)

    if TOKEN_MODE == "mag":
        mag = df_obj["mag"].to_numpy(dtype=np.float32, copy=False)[order]
        mag_err = df_obj["mag_err"].to_numpy(dtype=np.float32, copy=False)[order]
        mag = np.nan_to_num(mag, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        mag_err = np.nan_to_num(mag_err, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        mag_err = np.maximum(mag_err, np.float32(0.0))
        mag_err_log = np.log1p(mag_err).astype(np.float32)
        X = np.stack([t_rel_log, dt_log, mag, mag_err_log, snr_tanh, det_f], axis=1).astype(np.float32)
    else:
        flux = df_obj["flux_asinh"].to_numpy(dtype=np.float32, copy=False)[order]
        elog = df_obj["err_log1p"].to_numpy(dtype=np.float32, copy=False)[order]
        flux = np.nan_to_num(flux, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        elog = np.nan_to_num(elog, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32)
        X = np.stack([t_rel_log, dt_log, flux, elog, snr_tanh, det_f], axis=1).astype(np.float32)

    L0 = int(X.shape[0])
    if L_MAX and int(L_MAX) > 0 and X.shape[0] > int(L_MAX):
        if TRUNC_POLICY == "smart":
            keep = _smart_truncate(mjd, det, snr, int(L_MAX))
        elif TRUNC_POLICY == "head":
            keep = np.arange(int(L_MAX), dtype=np.int64)
        else:
            # Any other policy keeps the full sequence.
            keep = np.arange(X.shape[0], dtype=np.int64)

        if len(keep) != X.shape[0]:
            X = X[keep]
            band = band[keep]

            # Gaps between the kept rows differ from the original gaps, so
            # column 1 is rebuilt with the same clipping as the untruncated path.
            sel_mjd = mjd[keep]
            sel_t = (sel_mjd - sel_mjd[0]) / np.float32(denom)
            X[:, 1] = np.log1p(_gaps(sel_t)).astype(np.float32)

    return X, band.astype(np.int8), L0, int(X.shape[0])

# ----------------------------
# 9) Shard writer
# ----------------------------
def save_shard(out_path: Path, object_ids, X_concat, B_concat, offsets):
    """Write one shard as an ``.npz`` archive.

    The archive holds four arrays: ``object_id`` (ids converted to a bytes
    dtype), ``x``, ``band`` and ``offsets``. The parent directory is created
    when needed; ``COMPRESS_NPZ`` selects the compressed writer.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    payload = {
        "object_id": np.asarray(object_ids, dtype="S"),
        "x": X_concat,
        "band": B_concat,
        "offsets": offsets,
    }
    writer = np.savez_compressed if COMPRESS_NPZ else np.savez
    writer(out_path, **payload)

# ----------------------------
# 10) Robust builder: bucketize -> groupby object -> shard (missing-safe)
# ----------------------------
def build_sequences_bucket(split_name: str, which: str, expected_ids: set, out_dir: Path, num_buckets: int = 64):
    """Build token sequences for one split and write them as ``.npz`` shards.

    Works in three phases:
      1. Stream the cleaned parts of ``split_name``/``which``, keep only rows
         whose ``object_id`` is in ``expected_ids`` and append them to
         temporary parquet buckets chosen by a hash of ``object_id``.
      2. Read each bucket back, group its rows by object, build tokens with
         ``build_object_tokens`` and flush a shard every ``SHARD_MAX_OBJECTS`` objects.
      3. Add an empty placeholder sequence for every expected id that never
         showed up, so that built == expected.

    Args:
        split_name: Split to process.
        which: ``"train"`` uses ``df_train_meta``; anything else uses ``df_test_meta``.
        expected_ids: Object ids that must end up in the output.
        out_dir: Directory receiving ``shard_NNNN.npz`` files.
        num_buckets: Number of temporary hash buckets.

    Returns:
        ``(manifest_rows, st)``: one manifest dict per object (shard path,
        start offset, length) and a dict of build statistics.

    Raises:
        RuntimeError: if pyarrow cannot be imported or no cleaned parts exist.
    """
    try:
        import pyarrow as pa
        import pyarrow.parquet as pq
    except Exception as e:
        raise RuntimeError("pyarrow tidak tersedia. Di Kaggle biasanya ada.") from e

    parts = get_clean_parts(split_name, which)
    if not parts:
        raise RuntimeError(f"Tidak ada cleaned parts untuk {split_name}/{which}. Cek STAGE 4 output.")

    # Start from an empty temporary bucket directory.
    tmp_dir = Path(ART_DIR) / "tmp_seq_buckets" / split_name / which
    if tmp_dir.exists():
        shutil.rmtree(tmp_dir, ignore_errors=True)
    tmp_dir.mkdir(parents=True, exist_ok=True)

    # One lazily created ParquetWriter per bucket id.
    writers = {}
    kept_rows = 0
    t0 = time.time()

    def bucket_idx(series_objid: pd.Series) -> np.ndarray:
        """Map each object id to a bucket number in ``[0, num_buckets)`` via its pandas hash."""
        h = pd.util.hash_pandas_object(series_objid, index=False).to_numpy(dtype=np.uint64, copy=False)
        return (h % np.uint64(num_buckets)).astype(np.int16)

    try:
        # 1) write buckets
        for p in parts:
            df = _read_clean_part(p)
            if df.empty:
                continue

            df = df[df["object_id"].isin(expected_ids)]
            if df.empty:
                continue

            kept_rows += int(len(df))
            bidx = bucket_idx(df["object_id"])
            # NOTE(review): df is a filtered selection here, so this assignment
            # may emit a pandas SettingWithCopyWarning — confirm it is harmless.
            df["_b"] = bidx

            for b in np.unique(bidx):
                sub = df[df["_b"] == b].drop(columns=["_b"])
                if sub.empty:
                    continue
                fp = tmp_dir / f"bucket_{int(b):03d}.parquet"
                table = pa.Table.from_pandas(sub, preserve_index=False)
                # The writer's schema is fixed by the first table written to this bucket.
                if int(b) not in writers:
                    writers[int(b)] = pq.ParquetWriter(fp, table.schema, compression="snappy")
                writers[int(b)].write_table(table)

            del df
            gc.collect()

    finally:
        # Always close the writers, even when reading or writing a part failed.
        for w in list(writers.values()):
            try:
                w.close()
            except Exception:
                pass

    meta = df_train_meta if which == "train" else df_test_meta

    manifest_rows = []
    shard_idx = 0
    # Pending objects of the shard currently being assembled.
    batch_obj_ids, batch_X, batch_B, batch_len = [], [], [], []
    built_ids = set()

    len_before, len_after = [], []

    def flush_shard_local():
        """Write the pending batch as one shard, record its manifest rows and reset the batch."""
        nonlocal shard_idx, batch_obj_ids, batch_X, batch_B, batch_len, manifest_rows
        if not batch_obj_ids:
            return
        # offsets[i]:offsets[i+1] is the row range of object i inside the concatenated arrays.
        lengths = np.asarray(batch_len, dtype=np.int64)
        offsets = np.zeros(len(lengths) + 1, dtype=np.int64)
        offsets[1:] = np.cumsum(lengths)

        Xc = np.concatenate(batch_X, axis=0).astype(np.float32)
        Bc = np.concatenate(batch_B, axis=0).astype(np.int8)

        shard_path = out_dir / f"shard_{shard_idx:04d}.npz"
        save_shard(shard_path, batch_obj_ids, Xc, Bc, offsets)

        for i, oid in enumerate(batch_obj_ids):
            manifest_rows.append({
                "object_id": oid,
                "split": split_name,
                "which": which,
                "shard": str(shard_path),
                "start": int(offsets[i]),
                "length": int(lengths[i]),
            })

        shard_idx += 1
        batch_obj_ids, batch_X, batch_B, batch_len = [], [], [], []
        gc.collect()

    # 2) read buckets -> groupby object
    for bf in sorted(tmp_dir.glob("bucket_*.parquet")):
        dfb = pd.read_parquet(bf)
        if dfb.empty:
            continue

        for oid, g in dfb.groupby("object_id", sort=False):
            oid = str(oid)
            if oid in built_ids:
                continue

            # Redshift is only looked up when rest-frame time is enabled and the id is in meta.
            z_val = float(meta.loc[oid, "Z"]) if (USE_RESTFRAME_TIME and oid in meta.index) else 0.0
            X, B, lb, la = build_object_tokens(g, z_val=z_val)

            len_before.append(lb)
            len_after.append(la)

            batch_obj_ids.append(oid)
            batch_X.append(X)
            batch_B.append(B)
            batch_len.append(int(X.shape[0]))
            built_ids.add(oid)

            if len(batch_obj_ids) >= SHARD_MAX_OBJECTS:
                flush_shard_local()

        del dfb
        gc.collect()

    # 3) fill missing objects (empty sequences)
    # NOTE(review): the order of missing_ids follows set iteration order, so the
    # position of placeholder objects inside shards may differ between runs.
    missing_ids = list(expected_ids - built_ids)
    if missing_ids:
        for oid in missing_ids:
            oid = str(oid)
            X, B, lb, la = build_empty_tokens()
            len_before.append(lb)
            len_after.append(la)

            batch_obj_ids.append(oid)
            batch_X.append(X)
            batch_B.append(B)
            batch_len.append(int(X.shape[0]))
            built_ids.add(oid)

            if len(batch_obj_ids) >= SHARD_MAX_OBJECTS:
                flush_shard_local()

    # Write whatever is left in the final, partially filled batch.
    flush_shard_local()

    # cleanup tmp (only reached when everything above succeeded)
    shutil.rmtree(tmp_dir, ignore_errors=True)

    st = {
        "kept_rows": int(kept_rows),
        "built_objects": int(len(built_ids)),
        "missing_filled": int(len(missing_ids)),
        "len_before_mean": float(np.mean(len_before)) if len_before else 0.0,
        "len_before_p95": float(np.quantile(len_before, 0.95)) if len_before else 0.0,
        "len_after_mean": float(np.mean(len_after)) if len_after else 0.0,
        "len_after_p95": float(np.quantile(len_after, 0.95)) if len_after else 0.0,
        "truncated_frac": float(np.mean([a < b for a, b in zip(len_after, len_before)])) if len_before else 0.0,
        "time_s": float(time.time() - t0),
    }
    return manifest_rows, st

# ----------------------------
# 11) RUN
# ----------------------------
# Splits to process: the ONLY_SPLITS override when set, otherwise every split that has parts in the manifest.
splits_to_run = ONLY_SPLITS if (ONLY_SPLITS is not None) else SPLITS_TO_CONSIDER
# Accumulators filled by the run loop: per-object manifest rows (train/test) and per-split statistics.
all_manifest_train, all_manifest_test, split_run_stats = [], [], []

def expected_set_for(split_name: str, which: str) -> set:
    return set(train_ids_by_split.get(split_name, [])) if which == "train" else set(test_ids_by_split.get(split_name, []))

for split_name in splits_to_run:
    for which in ["train", "test"]:
        out_dir = SEQ_DIR / split_name / which
        out_dir.mkdir(parents=True, exist_ok=True)

        expected_ids = expected_set_for(split_name, which)
        if len(expected_ids) == 0:
            raise RuntimeError(f"Expected ids empty for {split_name}/{which}.")

        shard_exists = any(out_dir.glob("shard_*.npz"))
        if REBUILD_MODE == "reuse_if_exists" and shard_exists:
            print(f"\n[Stage 5] SKIP (exists): {split_name}/{which}")
            continue
        else:
            for f in out_dir.glob("shard_*.npz"):
                try: f.unlink()
                except Exception: pass

        print(f"\n[Stage 5] {split_name}/{which} | expected={len(expected_ids):,} | L_MAX={L_MAX} | TRUNC={TRUNC_POLICY}")

        manifest_rows, st = build_sequences_bucket(
            split_name=split_name,
            which=which,
            expected_ids=expected_ids,
            out_dir=out_dir,
            num_buckets=NUM_BUCKETS
        )

        print(f"[Stage 5] OK: built={st['built_objects']:,} (missing_filled={st['missing_filled']:,}) | "
              f"kept_rows={st['kept_rows']:,} | "
              f"len_mean {st['len_before_mean']:.1f}->{st['len_after_mean']:.1f} | "
              f"p95 {st['len_before_p95']:.1f}->{st['len_after_p95']:.1f} | "
              f"trunc%={st['truncated_frac']*100:.1f}% | "
              f"time={st['time_s']:.2f}s | mode={TOKEN_MODE}")

        split_run_stats.append({"split": split_name, "which": which, **st})

        if which == "train":
            all_manifest_train.extend(manifest_rows)
        else:
            all_manifest_test.extend(manifest_rows)

        gc.collect()

# ----------------------------
# 12) Save manifests + stats + config
# ----------------------------
# Turn the collected manifest rows into frames ordered by (split, shard, start).
# NOTE(review): if either list is empty, pd.DataFrame([]) has no columns and
# sort_values raises KeyError — confirm every run produces rows for both partitions.
df_m_train = pd.DataFrame(all_manifest_train).sort_values(["split", "shard", "start"]).reset_index(drop=True)
df_m_test  = pd.DataFrame(all_manifest_test).sort_values(["split", "shard", "start"]).reset_index(drop=True)

# Persist both manifests next to the shards.
mtrain_path = SEQ_DIR / "seq_manifest_train.csv"
mtest_path  = SEQ_DIR / "seq_manifest_test.csv"
df_m_train.to_csv(mtrain_path, index=False)
df_m_test.to_csv(mtest_path, index=False)

# One stats row per (split, which) that was built in this run.
df_stats = pd.DataFrame(split_run_stats)
stats_path = SEQ_DIR / "seq_build_stats.csv"
df_stats.to_csv(stats_path, index=False)

# Snapshot of the settings used for this build, written as JSON.
cfg = {
    "token_mode": TOKEN_MODE,
    "feature_names": FEATURE_NAMES,
    "feature_dim": int(FEATURE_DIM),
    "snr_tanh_scale": float(SNR_TANH_SCALE),
    "time_clip_max_days": None if TIME_CLIP_MAX_DAYS is None else float(TIME_CLIP_MAX_DAYS),
    "compress_npz": bool(COMPRESS_NPZ),
    "shard_max_objects": int(SHARD_MAX_OBJECTS),
    "num_buckets": int(NUM_BUCKETS),
    "L_MAX": int(L_MAX),
    "TRUNC_POLICY": str(TRUNC_POLICY),
    "KEEP_DET_FRAC": float(KEEP_DET_FRAC),
    "USE_RESTFRAME_TIME": bool(USE_RESTFRAME_TIME),
    "REBUILD_MODE": str(REBUILD_MODE),
    "RUN_DIR_USED": str(RUN_DIR),
    "ART_DIR_USED": str(ART_DIR),
    "LC_CLEAN_DIR_USED": str(LC_CLEAN_DIR),
    "manifest_csv": str(manifest_csv),
}
cfg_path = SEQ_DIR / "seq_config.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(cfg, f, indent=2)

# Summary of what was written.
print("\n[Stage 5] DONE")
print(f"- token_mode : {TOKEN_MODE}")
print(f"- features   : {FEATURE_NAMES}")
print(f"- Saved: {mtrain_path} (rows={len(df_m_train):,})")
print(f"- Saved: {mtest_path}  (rows={len(df_m_test):,})")
print(f"- Saved: {stats_path}")
print(f"- Saved: {cfg_path}")

# ----------------------------
# 13) Smoke test
# ----------------------------
def load_sequence(object_id: str, which: str):
    """Load the token sequence of one object from its shard.

    Args:
        object_id: Object identifier; converted to `str` and stripped before lookup.
        which: "train" reads `df_m_train`; any other value reads `df_m_test`.

    Returns:
        `(X, B)`: the rows `start:start+length` of the shard's `x` and `band` arrays,
        where `start`/`length` come from the first matching manifest row.

    Raises:
        KeyError: if `object_id` has no row in the selected manifest.
    """
    object_id = str(object_id).strip()
    m = df_m_train if which == "train" else df_m_test
    row = m[m["object_id"] == object_id]
    if row.empty:
        raise KeyError(f"object_id not found in seq manifest ({which}): {object_id}")
    r = row.iloc[0]
    start = int(r["start"]); length = int(r["length"])
    # np.load on a .npz returns an NpzFile holding an open file handle; the context
    # manager closes it. The arrays are fully read on access, so the slices stay
    # valid after the archive is closed.
    with np.load(r["shard"], allow_pickle=False) as data:
        X = data["x"][start:start+length]
        B = data["band"][start:start+length]
    return X, B

# Smoke test: load the sequence of the first object in df_train_meta and print its shape.
_smoke_oid = str(df_train_meta.index[0])
X_sm, B_sm = load_sequence(_smoke_oid, "train")
print(f"\n[Stage 5] Smoke test object_id={_smoke_oid}")
print(f"- seq_len={len(X_sm)} | X_shape={X_sm.shape} | bands_unique={sorted(set(B_sm.tolist()))}")

# ----------------------------
# 14) Export globals
# ----------------------------
# Publish the manifests, feature metadata and helpers under the names later stages look up.
globals().update({
    "RUN_DIR": RUN_DIR,
    "ART_DIR": ART_DIR,
    "LC_CLEAN_DIR": LC_CLEAN_DIR,
    "SEQ_DIR": SEQ_DIR,
    "seq_manifest_train": df_m_train,
    "seq_manifest_test": df_m_test,
    "SEQ_FEATURE_NAMES": FEATURE_NAMES,
    "SEQ_FEATURE_DIM": int(FEATURE_DIM),
    "SEQ_TOKEN_MODE": TOKEN_MODE,
    "get_clean_parts": get_clean_parts,
    "load_sequence": load_sequence,
})

gc.collect()


STAGE 5 ROUTING SYNC OK
- RUN_DIR      : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76
- ART_DIR      : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts
- LC_CLEAN_DIR : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/lc_clean_mag
- manifest_csv : /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/lc_clean_mag/lc_clean_mag_manifest.csv
- meta_sync    : meta already consistent

[Stage 5] split_01/train | expected=155 | L_MAX=256 | TRUNC=smart
[Stage 5] OK: built=155 (missing_filled=0) | kept_rows=26,324 | len_mean 169.8->145.3 | p95 191.7->191.7 | trunc%=3.9% | time=9.83s | mode=mag

[Stage 5] split_01/test | expected=364 | L_MAX=256 | TRUNC=smart
[Stage 5] OK: built=364 (missing_filled=0) | kept_rows=59,235 | len_mean 162.7->148.3 | p95 193.8->193.8 | trunc%=2.2% | time=10.13s | mode=mag

[Stage 5] split_02/train | expected=170 | L_MAX=256 | TRUNC=smart
[Stage 5] OK: built=170 (missing_filled=0) | kept_rows=25,609 

55

# Sequence Length Policy (Padding, Truncation, Windowing)

In [12]:
# ============================================================
# STAGE 6 — Sequence Length Policy (Padding, Truncation, Windowing)
# ONE CELL, Kaggle CPU-SAFE — REVISI FULL v2.2 (MAG/ASINH COMPAT, HARDENED)
#
# Output:
# - artifacts/fixed_seq/{train|test}_{X|B|M}.dat  (memmap)
# - artifacts/fixed_seq/{train|test}_ids.npy
# - artifacts/fixed_seq/train_y.npy
# - artifacts/fixed_seq/{train|test}_origlen.npy, {train|test}_winstart.npy, {train|test}_winend.npy
# - artifacts/fixed_seq/length_policy_config.json
# ============================================================

import gc, json, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Fail fast when a name produced by an earlier stage is missing from the kernel.
for need in ["seq_manifest_train", "seq_manifest_test", "SEQ_FEATURE_NAMES",
             "df_train_meta", "df_test_meta", "ART_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 -> 1 -> 2 -> 3 -> 4 -> 5 dulu.")

ART_DIR = Path(ART_DIR)

# Work on copies so this stage never mutates the stage-5 manifests.
m_train = seq_manifest_train.copy()
m_test  = seq_manifest_test.copy()

# feat maps a feature name to its column position in the token matrix.
SEQ_FEATURE_NAMES = list(SEQ_FEATURE_NAMES)
feat = {name: i for i, name in enumerate(SEQ_FEATURE_NAMES)}

# ----------------------------
# 0b) Detect token_mode (MAG vs ASINH)
# ----------------------------
# Use the exported SEQ_TOKEN_MODE when present; otherwise infer it from the feature names.
SEQ_TOKEN_MODE = globals().get("SEQ_TOKEN_MODE", None)
if SEQ_TOKEN_MODE is None:
    if ("flux_asinh" in feat) and ("err_log1p" in feat):
        SEQ_TOKEN_MODE = "asinh"
    elif ("mag" in feat) and ("mag_err_log" in feat):
        SEQ_TOKEN_MODE = "mag"
    else:
        raise ValueError(
            "Cannot infer SEQ_TOKEN_MODE from SEQ_FEATURE_NAMES.\n"
            f"SEQ_FEATURE_NAMES={SEQ_FEATURE_NAMES}\n"
            "Expected either (flux_asinh, err_log1p) or (mag, mag_err_log)."
        )

# Features the window scoring below reads regardless of token mode.
REQ_COMMON = ["t_rel_log", "dt_log", "snr_tanh", "detected"]
for k in REQ_COMMON:
    if k not in feat:
        raise ValueError(f"SEQ_FEATURE_NAMES must include '{k}'. Found: {SEQ_FEATURE_NAMES}")

# SCORE_VALUE_FEAT names the value feature used for scoring in the active mode.
if SEQ_TOKEN_MODE == "asinh":
    if "flux_asinh" not in feat:
        raise ValueError("token_mode=asinh requires 'flux_asinh'.")
    SCORE_VALUE_FEAT = "flux_asinh"
elif SEQ_TOKEN_MODE == "mag":
    if "mag" not in feat:
        raise ValueError("token_mode=mag requires 'mag'.")
    SCORE_VALUE_FEAT = "mag"
else:
    raise ValueError(f"Unknown SEQ_TOKEN_MODE={SEQ_TOKEN_MODE}")

print(f"[Stage 6] token_mode={SEQ_TOKEN_MODE} | score_value_feat={SCORE_VALUE_FEAT} | F={len(SEQ_FEATURE_NAMES)}")

# ----------------------------
# 1) Settings
# ----------------------------
FORCE_MAX_LEN = None          # e.g. 256 (set an int to force the fixed length)
MAXLEN_CAPS = (256, 384, 512) # CPU-safe choices

# Score weights (used by _score_tokens: |snr|, value term, detected flag)
W_SNR = 1.00
W_VAL = 0.35
W_DET = 0.25

# Padding policy
PAD_BAND_ID = 0

# AUTO: if a shard contains negative band ids (e.g. -1 from an empty token), shift band ids automatically
SHIFT_BAND_IDS = False
AUTO_SHIFT_IF_NEGATIVE_BANDS = True

# Build policy
REBUILD_MODE = "wipe_all"     # "wipe_all" or "reuse_if_exists"
DTYPE_X = np.float32          # can be fp16 if disk space is tight

# ----------------------------
# 2) Inspect length distribution -> choose MAX_LEN
# ----------------------------
def describe_lengths(m: pd.DataFrame, name: str):
    """Print a one-line summary of the manifest's `length` column and return its percentiles.

    Non-numeric lengths are coerced to 0. The returned array holds the percentiles
    at levels 0, 1, 5, 10, 25, 50, 75, 90, 95, 98, 99 and 100 (in that order).
    """
    pct_levels = [0, 1, 5, 10, 25, 50, 75, 90, 95, 98, 99, 100]
    lengths = (
        pd.to_numeric(m["length"], errors="coerce")
        .fillna(0)
        .to_numpy(dtype=np.int32, copy=False)
    )
    q = np.percentile(lengths, pct_levels)

    # positions in q of the values shown in the summary line
    lo, p50, p90, p95, p99, hi = (int(q[pos]) for pos in (0, 5, 7, 8, 10, -1))
    print(f"\n{name} length stats")
    print(f"- n_objects={len(lengths):,} | min={lo} | p50={p50} | p90={p90} | p95={p95} | p99={p99} | max={hi}")
    return q

q_tr = describe_lengths(m_train, "TRAIN")
q_te = describe_lengths(m_test,  "TEST")

# Index 8 of the percentile array is the 95th percentile (see describe_lengths).
p95 = int(max(q_tr[8], q_te[8]))
if FORCE_MAX_LEN is not None:
    MAX_LEN = int(FORCE_MAX_LEN)
    _max_len_reason = f"forced via FORCE_MAX_LEN; p95={p95}"
else:
    # Smallest configured cap that covers p95; fall back to the largest cap.
    # Deriving the thresholds from MAXLEN_CAPS keeps the choice in sync with the
    # setting (with the default caps this selects 256 / 384 / 512 exactly as before).
    _caps = sorted(int(c) for c in MAXLEN_CAPS)
    MAX_LEN = next((c for c in _caps if p95 <= c), _caps[-1])
    _max_len_reason = f"based on p95={p95}"

print(f"\n[Stage 6] MAX_LEN={MAX_LEN} ({_max_len_reason})")

# ----------------------------
# 3) Window scoring (adaptive)
# ----------------------------
def _brightness_proxy_from_mag(mag: np.ndarray) -> np.ndarray:
    mag = np.nan_to_num(mag, nan=np.float32(0.0), posinf=np.float32(0.0), neginf=np.float32(0.0)).astype(np.float32, copy=False)
    if mag.size == 0:
        return np.zeros_like(mag, dtype=np.float32)
    med = np.float32(np.median(mag))
    br = np.maximum(med - mag, np.float32(0.0))
    br = np.log1p(br).astype(np.float32, copy=False)
    return br

def _score_tokens(X: np.ndarray) -> np.ndarray:
    """Compute one float32 importance score per token (row) of `X`.

    score = W_SNR * |snr_tanh| + W_VAL * value + W_DET * detected, where `value`
    is |flux_asinh| in "asinh" mode and the magnitude brightness proxy otherwise.
    Non-finite scores are replaced by 0.
    """
    f32 = np.float32
    snr_abs = np.abs(X[:, feat["snr_tanh"]]).astype(f32, copy=False)
    det_flag = X[:, feat["detected"]].astype(f32, copy=False)

    if SEQ_TOKEN_MODE != "asinh":
        value = _brightness_proxy_from_mag(X[:, feat["mag"]].astype(f32, copy=False))
    else:
        value = np.abs(X[:, feat["flux_asinh"]]).astype(f32, copy=False)

    weighted = (f32(W_SNR) * snr_abs) + (f32(W_VAL) * value) + (f32(W_DET) * det_flag)
    return np.nan_to_num(weighted, nan=0.0, posinf=0.0, neginf=0.0).astype(f32, copy=False)

def select_best_window(score: np.ndarray, max_len: int) -> tuple[int, int]:
    """Pick the contiguous window of `max_len` tokens with the largest score sum.

    Returns `(start, end)` with `end` exclusive. Sequences no longer than
    `max_len` are kept whole as `(0, len(score))`. If no window sum is finite,
    the centred window is returned. Ties go to the earliest window.
    """
    n_tokens = int(score.shape[0])
    if n_tokens <= max_len:
        return 0, n_tokens

    # prefix[k] = sum of the first k scores (float32), so a window sum is a difference
    prefix = np.zeros(n_tokens + 1, dtype=np.float32)
    prefix[1:] = np.cumsum(score.astype(np.float32, copy=False))
    window_sums = prefix[max_len:] - prefix[:-max_len]

    if np.isfinite(window_sums).any():
        start = int(np.argmax(window_sums))
    else:
        start = (n_tokens - max_len) // 2
    return start, start + max_len

def pad_to_fixed(X: np.ndarray, B: np.ndarray, max_len: int):
    """Window/pad one sequence to exactly `max_len` tokens.

    Returns `(Xp, Bp, Mp, orig_len, win_start, win_end)`:
      - Xp: (max_len, F) features in DTYPE_X, zero padded
      - Bp: (max_len,) int8 band ids, padded with PAD_BAND_ID
      - Mp: (max_len,) int8 mask, 1 for kept tokens and 0 for padding
      - orig_len: number of input tokens; win_start/win_end: kept slice of the input

    Sequences longer than `max_len` keep the best-scoring contiguous window.
    When SHIFT_BAND_IDS is set, kept band ids are incremented by one.
    """
    n_tokens, n_feat = int(X.shape[0]), int(X.shape[1])

    Xp = np.zeros((max_len, n_feat), dtype=DTYPE_X)
    Bp = np.full((max_len,), PAD_BAND_ID, dtype=np.int8)
    Mp = np.zeros((max_len,), dtype=np.int8)

    # nothing to copy: pure padding
    if n_tokens <= 0:
        return Xp, Bp, Mp, 0, 0, 0

    ws, we = 0, n_tokens
    if n_tokens > max_len:
        ws, we = select_best_window(_score_tokens(X), max_len=max_len)
    X_win, B_win = X[ws:we], B[ws:we]

    kept = int(X_win.shape[0])
    Xp[:kept] = X_win.astype(DTYPE_X, copy=False)

    # shift in int16 first so the +1 happens before narrowing to int8
    band_vals = (B_win.astype(np.int16, copy=False) + 1) if SHIFT_BAND_IDS else B_win
    Bp[:kept] = band_vals.astype(np.int8, copy=False)

    Mp[:kept] = 1
    return Xp, Bp, Mp, n_tokens, int(ws), int(we)

# ----------------------------
# 4) Fixed cache builder setup
# ----------------------------
# Output directory for the fixed-length cache.
FIX_DIR = Path(ART_DIR) / "fixed_seq"
FIX_DIR.mkdir(parents=True, exist_ok=True)

# robust ordering: train rows follow df_train_meta.index (as stripped strings)
train_ids = df_train_meta.index.astype("string").str.strip().astype(str).to_list()

# y column robust: take the first candidate name present in df_train_meta
_y_col = None
for cand in ["target", "y", "label", "class", "target_id"]:
    if cand in df_train_meta.columns:
        _y_col = cand
        break
if _y_col is None:
    raise RuntimeError(f"Cannot find target column in df_train_meta. cols={list(df_train_meta.columns)[:30]}")

# NOTE(review): non-numeric / missing targets are silently turned into class 0 here — confirm that is intended.
y_train = pd.to_numeric(df_train_meta[_y_col], errors="coerce").fillna(0).astype(np.int16).to_numpy(copy=False)

def _try_load_sample_sub_ids():
    # 1) df_sub (kalau ada)
    if "df_sub" in globals() and isinstance(globals()["df_sub"], pd.DataFrame) and "object_id" in globals()["df_sub"].columns:
        return globals()["df_sub"]["object_id"].astype(str).str.strip().to_list()

    # 2) PATHS sample_submission
    if "PATHS" in globals() and isinstance(PATHS, dict):
        keys = ["SAMPLE_SUB", "SAMPLE_SUBMISSION", "sample_submission", "sample_sub", "SAMPLE"]
        for k in keys:
            p = PATHS.get(k, None)
            if p and Path(p).exists():
                df = pd.read_csv(p)
                if "object_id" in df.columns:
                    return df["object_id"].astype(str).str.strip().to_list()
    return None

# Test row order: sample-submission order when available, else df_test_meta.index.
test_ids = _try_load_sample_sub_ids()
if test_ids is None:
    test_ids = df_test_meta.index.astype("string").str.strip().astype(str).to_list()

# strict unique ids (each id must map to exactly one memmap row)
if len(set(train_ids)) != len(train_ids):
    raise RuntimeError("train_ids contains duplicates. Check df_train_meta.index.")
if len(set(test_ids)) != len(test_ids):
    raise RuntimeError("test_ids contains duplicates. Check ordering source (df_sub/sample_sub/df_test_meta).")

# object id -> row index in the fixed-length arrays
train_row = {oid: i for i, oid in enumerate(train_ids)}
test_row  = {oid: i for i, oid in enumerate(test_ids)}

NTR = len(train_ids)
NTE = len(test_ids)
F = len(SEQ_FEATURE_NAMES)

# bytes -> GiB, used only for the size estimate printed below
def _gb(nbytes): return float(nbytes) / (1024**3)
size_tr = NTR * MAX_LEN * F * np.dtype(DTYPE_X).itemsize
size_te = NTE * MAX_LEN * F * np.dtype(DTYPE_X).itemsize
print(f"\n[Stage 6] Memmap X sizes approx: train={_gb(size_tr):.2f} GB | test={_gb(size_te):.2f} GB | dtype={DTYPE_X}")

# memmap paths (X = features, B = band ids, M = mask)
train_X_path = FIX_DIR / "train_X.dat"
train_B_path = FIX_DIR / "train_B.dat"
train_M_path = FIX_DIR / "train_M.dat"
test_X_path  = FIX_DIR / "test_X.dat"
test_B_path  = FIX_DIR / "test_B.dat"
test_M_path  = FIX_DIR / "test_M.dat"

# per-object bookkeeping arrays: original length and the kept window [start, end)
train_len_path = FIX_DIR / "train_origlen.npy"
train_ws_path  = FIX_DIR / "train_winstart.npy"
train_we_path  = FIX_DIR / "train_winend.npy"
test_len_path  = FIX_DIR / "test_origlen.npy"
test_ws_path   = FIX_DIR / "test_winstart.npy"
test_we_path   = FIX_DIR / "test_winend.npy"

# ----------------------------
# 4b) Auto shift band ids if needed
# ----------------------------
if AUTO_SHIFT_IF_NEGATIVE_BANDS and (not SHIFT_BAND_IDS):
    try:
        # Inspect every shard referenced by either manifest: a negative band id can
        # sit in any shard (train or test), so sampling only the first train shard
        # can miss it and leave negative ids in the cache.
        _shard_list = pd.concat([m_train["shard"], m_test["shard"]]).astype(str).unique().tolist()
        bmin = 0
        for sp0 in _shard_list:
            # NpzFile holds an open file handle; the context manager closes it.
            with np.load(sp0, allow_pickle=False) as d0:
                b0 = d0["band"]
                if b0.size:
                    bmin = min(bmin, int(np.min(b0)))
            if bmin < 0:
                break  # one negative id is enough to decide
        if bmin < 0:
            SHIFT_BAND_IDS = True
            print(f"[Stage 6] AUTO SHIFT_BAND_IDS=True (detected band min={bmin} in shard sample)")
    except Exception as e:
        # best effort: the check is advisory, so report and continue without shifting
        print(f"[Stage 6] AUTO SHIFT check skipped ({type(e).__name__}: {e})")

# ----------------------------
# 4c) Rebuild handling
# ----------------------------
def _all_exist(paths):
    """Return True when every path in `paths` exists on disk."""
    return all(Path(p).exists() for p in paths)

# Every artifact this stage writes; reuse requires all of them to be present.
reuse_paths = [
    train_X_path, train_B_path, train_M_path,
    test_X_path, test_B_path, test_M_path,
    FIX_DIR / "train_ids.npy", FIX_DIR / "test_ids.npy", FIX_DIR / "train_y.npy",
    train_len_path, train_ws_path, train_we_path,
    test_len_path, test_ws_path, test_we_path,
    FIX_DIR / "length_policy_config.json"
]

# NOTE(review): reuse only checks file existence; it does not verify that the cached
# arrays were built with the current MAX_LEN / feature set / id order — confirm.
if REBUILD_MODE == "reuse_if_exists" and _all_exist(reuse_paths):
    print("[Stage 6] REUSE (exists): fixed_seq cache already present.")
    globals().update({
        "FIX_DIR": FIX_DIR, "MAX_LEN": MAX_LEN,
        "FIX_TRAIN_X_PATH": train_X_path, "FIX_TRAIN_B_PATH": train_B_path, "FIX_TRAIN_M_PATH": train_M_path,
        "FIX_TEST_X_PATH": test_X_path,  "FIX_TEST_B_PATH": test_B_path,  "FIX_TEST_M_PATH": test_M_path,
        "FIX_TRAIN_Y_PATH": FIX_DIR / "train_y.npy",
        "FIX_TRAIN_IDS_PATH": FIX_DIR / "train_ids.npy",
        "FIX_TEST_IDS_PATH": FIX_DIR / "test_ids.npy",
        "FIX_POLICY_CFG_PATH": FIX_DIR / "length_policy_config.json",
        "SEQ_TOKEN_MODE": SEQ_TOKEN_MODE,
    })
    # Stops the rest of this cell from running.
    # NOTE(review): raising SystemExit inside a notebook cell may be reported as an
    # error by some runners — confirm this is acceptable under "Run All".
    raise SystemExit

# ----------------------------
# 5) Create memmaps
# ----------------------------
# mode="w+" creates (or overwrites) each backing file with the given shape.
# Train arrays: features (NTR, MAX_LEN, F), band ids and mask (NTR, MAX_LEN).
Xtr = np.memmap(train_X_path, dtype=DTYPE_X, mode="w+", shape=(NTR, MAX_LEN, F))
Btr = np.memmap(train_B_path, dtype=np.int8,  mode="w+", shape=(NTR, MAX_LEN))
Mtr = np.memmap(train_M_path, dtype=np.int8,  mode="w+", shape=(NTR, MAX_LEN))

# Test arrays, same layout with NTE rows.
Xte = np.memmap(test_X_path, dtype=DTYPE_X, mode="w+", shape=(NTE, MAX_LEN, F))
Bte = np.memmap(test_B_path, dtype=np.int8,  mode="w+", shape=(NTE, MAX_LEN))
Mte = np.memmap(test_M_path, dtype=np.int8,  mode="w+", shape=(NTE, MAX_LEN))

# Per-object bookkeeping kept in RAM and saved with np.save later.
origlen_tr = np.zeros((NTR,), dtype=np.int32)
winstart_tr = np.zeros((NTR,), dtype=np.int32)
winend_tr   = np.zeros((NTR,), dtype=np.int32)

origlen_te = np.zeros((NTE,), dtype=np.int32)
winstart_te = np.zeros((NTE,), dtype=np.int32)
winend_te   = np.zeros((NTE,), dtype=np.int32)

# 1 once a row has been written; checked by the hard sanity step after the build.
filled_tr = np.zeros((NTR,), dtype=np.uint8)
filled_te = np.zeros((NTE,), dtype=np.uint8)

# ----------------------------
# 6) Fill memmaps per shard (fast path)
# ----------------------------
def process_manifest_into_memmap(m: pd.DataFrame, which: str):
    """Write every manifest entry into the fixed-length memmaps for one partition.

    For each shard referenced by `m`, the shard arrays are loaded once and every
    object of that shard is windowed/padded with `pad_to_fixed` and stored at the
    row given by `train_row` / `test_row`.

    Args:
        m: manifest with columns `object_id`, `shard`, `start`, `length`.
        which: "train" writes the train arrays; any other value writes the test arrays.

    Returns:
        dict with `filled` (rows written), `dup_skipped` (entries whose row was
        already written), `empty_len` (entries with length <= 0), `time_s` and
        `expected` (number of rows in the partition).

    Raises:
        RuntimeError: if a required column or a shard file is missing, or if an
            entry's slice falls outside its shard arrays.

    NOTE(review): entries with length <= 0 are counted but their row is not marked
    as filled, so the later "100% filled" check fails for such objects — confirm
    that this hard failure is the intended handling of empty sequences.
    """
    if which == "train":
        row_map = train_row
        Xmm, Bmm, Mmm = Xtr, Btr, Mtr
        origlen, ws_arr, we_arr = origlen_tr, winstart_tr, winend_tr
        filled_mask = filled_tr
        expected_n = NTR
    else:
        row_map = test_row
        Xmm, Bmm, Mmm = Xte, Bte, Mte
        origlen, ws_arr, we_arr = origlen_te, winstart_te, winend_te
        filled_mask = filled_te
        expected_n = NTE

    for c in ["object_id", "shard", "start", "length"]:
        if c not in m.columns:
            raise RuntimeError(f"Manifest missing column '{c}'. cols={list(m.columns)}")

    shard_keys = m["shard"].astype(str)
    shard_paths = shard_keys.unique().tolist()
    miss_sh = [p for p in shard_paths if not Path(p).exists()]
    if miss_sh:
        raise RuntimeError(f"Missing shard files ({which}): count={len(miss_sh)} | ex={miss_sh[:5]}")

    filled = 0
    dup = 0
    empty = 0

    t0 = time.time()
    # One groupby pass partitions the manifest by shard (instead of re-filtering the
    # whole frame per shard); sort=True visits shards in sorted path order.
    for shard_path, g in m.groupby(shard_keys.to_numpy(), sort=True):
        # NpzFile keeps a file handle open; read both arrays and close it right away.
        with np.load(shard_path, allow_pickle=False) as data:
            x_all = data["x"]
            b_all = data["band"]

        # mapping object_id -> row index (robust int, na=-1)
        oids = g["object_id"].astype(str).to_numpy()
        idxs = pd.Series(oids).map(row_map).astype("Int64").to_numpy(dtype=np.int64, na_value=-1)

        starts = pd.to_numeric(g["start"], errors="coerce").fillna(-1).to_numpy(dtype=np.int64, copy=False)
        lens   = pd.to_numeric(g["length"], errors="coerce").fillna(0).to_numpy(dtype=np.int64, copy=False)

        # keep only entries with a known row and a usable slice description
        valid = (idxs >= 0) & (starts >= 0) & (lens >= 0)
        if not valid.any():
            continue

        for oid, idx, st, ln in zip(oids[valid], idxs[valid], starts[valid], lens[valid]):
            if ln <= 0:
                empty += 1
                continue
            if filled_mask[idx]:
                dup += 1
                continue

            end = int(st + ln)
            if st < 0 or end > x_all.shape[0] or end > b_all.shape[0]:
                raise RuntimeError(
                    f"[Stage 6] Out-of-range slice in shard={shard_path}\n"
                    f"- oid={oid} idx={idx} start={st} len={ln} end={end}\n"
                    f"- shard_x_len={x_all.shape[0]} shard_b_len={b_all.shape[0]}"
                )

            Xp, Bp, Mp, L0, ws, we = pad_to_fixed(x_all[st:end], b_all[st:end], max_len=MAX_LEN)

            Xmm[idx, :, :] = Xp
            Bmm[idx, :] = Bp
            Mmm[idx, :] = Mp
            origlen[idx] = int(L0)
            ws_arr[idx] = int(ws)
            we_arr[idx] = int(we)
            filled_mask[idx] = 1
            filled += 1

        del x_all, b_all
        if filled % 2000 == 0:
            gc.collect()

    elapsed = time.time() - t0
    return {"filled": int(filled), "dup_skipped": int(dup), "empty_len": int(empty), "time_s": float(elapsed), "expected": int(expected_n)}

# Fill the train arrays from the train manifest and report the counters.
print("\n[Stage 6] Building fixed cache (TRAIN)...")
st_tr = process_manifest_into_memmap(m_train, "train")
print(f"[Stage 6] TRAIN filled={st_tr['filled']:,}/{st_tr['expected']:,} | dup={st_tr['dup_skipped']:,} | empty={st_tr['empty_len']:,} | time={st_tr['time_s']:.2f}s")

# Same for the test arrays.
print("\n[Stage 6] Building fixed cache (TEST)...")
st_te = process_manifest_into_memmap(m_test, "test")
print(f"[Stage 6] TEST  filled={st_te['filled']:,}/{st_te['expected']:,} | dup={st_te['dup_skipped']:,} | empty={st_te['empty_len']:,} | time={st_te['time_s']:.2f}s")

# Push pending memmap writes to the backing files.
Xtr.flush(); Btr.flush(); Mtr.flush()
Xte.flush(); Bte.flush(); Mte.flush()

# ----------------------------
# 7) Hard sanity: must be 100% filled
# ----------------------------
# Any row never written by the builder aborts the stage, listing up to 10 example ids.
miss_tr = np.where(filled_tr == 0)[0]
miss_te = np.where(filled_te == 0)[0]
if len(miss_tr) > 0:
    ex = [train_ids[i] for i in miss_tr[:10]]
    raise RuntimeError(f"[Stage 6] TRAIN missing filled rows: {len(miss_tr):,}/{NTR:,} | ex={ex}")
if len(miss_te) > 0:
    ex = [test_ids[i] for i in miss_te[:10]]
    raise RuntimeError(f"[Stage 6] TEST missing filled rows: {len(miss_te):,}/{NTE:,} | ex={ex}")

# ----------------------------
# 8) Save ids + y + meta arrays
# ----------------------------
# Ids are stored as fixed-width byte strings (dtype "S"), so readers must decode them.
# NOTE(review): dtype="S" conversion presumably requires ASCII-only ids — confirm.
np.save(FIX_DIR / "train_ids.npy", np.asarray(train_ids, dtype="S"))
np.save(FIX_DIR / "test_ids.npy",  np.asarray(test_ids, dtype="S"))
np.save(FIX_DIR / "train_y.npy",   y_train)

np.save(train_len_path, origlen_tr)
np.save(train_ws_path,  winstart_tr)
np.save(train_we_path,  winend_tr)

np.save(test_len_path,  origlen_te)
np.save(test_ws_path,   winstart_te)
np.save(test_we_path,   winend_te)

# ----------------------------
# 9) Quick sanity samples
# ----------------------------
def sanity_samples(which: str, n_show: int = 3, seed: int = 2025):
    """Print a few randomly chosen rows of the fixed cache for a quick eyeball check.

    For each sampled row the output shows the object id, the original length,
    the number of kept tokens (sum of the mask) and the distinct band ids among
    the kept tokens. `which == "train"` samples the train arrays, anything else
    the test arrays; sampling is reproducible through `seed`.
    """
    rng = np.random.default_rng(seed)

    if which == "train":
        band_arr, mask_arr, ids, orig_lens = Btr, Mtr, train_ids, origlen_tr
    else:
        band_arr, mask_arr, ids, orig_lens = Bte, Mte, test_ids, origlen_te

    n_pick = min(n_show, len(ids))
    picks = rng.choice(len(ids), size=n_pick, replace=False)

    print(f"\n[Stage 6] Sanity samples ({which}):")
    for i in picks:
        kept = int(mask_arr[i].sum())
        if kept > 0:
            bands = sorted(set(band_arr[i, :kept].tolist()))
        else:
            bands = []
        print(f"- idx={i} oid={ids[i]} orig_len={int(orig_lens[i])} kept={kept} bands_unique={bands}")

# Print three sampled rows per partition.
sanity_samples("train", 3)
sanity_samples("test", 3)

# ----------------------------
# 10) Save config
# ----------------------------
# Everything needed to interpret the cache: mode, length, scoring, padding, ordering, files.
policy_cfg = {
    "token_mode": SEQ_TOKEN_MODE,
    "max_len": int(MAX_LEN),
    "feature_names": list(SEQ_FEATURE_NAMES),
    "score_weights": {"W_SNR": float(W_SNR), "W_VAL": float(W_VAL), "W_DET": float(W_DET)},
    "score_value_feat": SCORE_VALUE_FEAT,
    "window_policy": "best_contiguous_window_by_max_sum(score)",
    "padding": {"PAD_BAND_ID": int(PAD_BAND_ID), "SHIFT_BAND_IDS": bool(SHIFT_BAND_IDS)},
    "dtype_X": str(DTYPE_X),
    "order": {
        "train": "df_train_meta.index",
        # NOTE(review): this label is derived from whether df_sub exists, not from the
        # branch _try_load_sample_sub_ids actually took — confirm it cannot disagree.
        "test": ("df_sub.object_id" if ("df_sub" in globals() and isinstance(df_sub, pd.DataFrame) and "object_id" in df_sub.columns) else "df_test_meta.index / sample_submission fallback"),
        "y_col": str(_y_col),
    },
    "stats": {"train": st_tr, "test": st_te},
    "files": {
        "train_X": str(train_X_path), "train_B": str(train_B_path), "train_M": str(train_M_path),
        "test_X": str(test_X_path),   "test_B": str(test_B_path),   "test_M": str(test_M_path),
        "train_y": str(FIX_DIR / "train_y.npy"),
        "train_ids": str(FIX_DIR / "train_ids.npy"),
        "test_ids": str(FIX_DIR / "test_ids.npy"),
        "train_origlen": str(train_len_path), "train_winstart": str(train_ws_path), "train_winend": str(train_we_path),
        "test_origlen": str(test_len_path),   "test_winstart": str(test_ws_path),   "test_winend": str(test_we_path),
    }
}
cfg_path = FIX_DIR / "length_policy_config.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(policy_cfg, f, indent=2)

print("\n[Stage 6] DONE")
print(f"- FIX_DIR: {FIX_DIR}")
print(f"- Saved config: {cfg_path}")

# Publish the cache locations and settings under the FIX_* names.
globals().update({
    "FIX_DIR": FIX_DIR,
    "MAX_LEN": MAX_LEN,
    "FIX_TRAIN_X_PATH": train_X_path,
    "FIX_TRAIN_B_PATH": train_B_path,
    "FIX_TRAIN_M_PATH": train_M_path,
    "FIX_TEST_X_PATH": test_X_path,
    "FIX_TEST_B_PATH": test_B_path,
    "FIX_TEST_M_PATH": test_M_path,
    "FIX_TRAIN_Y_PATH": FIX_DIR / "train_y.npy",
    "FIX_TRAIN_IDS_PATH": FIX_DIR / "train_ids.npy",
    "FIX_TEST_IDS_PATH": FIX_DIR / "test_ids.npy",
    "FIX_POLICY_CFG_PATH": cfg_path,
    "SEQ_TOKEN_MODE": SEQ_TOKEN_MODE,
    "SHIFT_BAND_IDS": SHIFT_BAND_IDS,
})

gc.collect()


[Stage 6] token_mode=mag | score_value_feat=mag | F=6

TRAIN length stats
- n_objects=3,043 | min=17 | p50=150 | p90=183 | p95=194 | p99=256 | max=256

TEST length stats
- n_objects=7,135 | min=18 | p50=152 | p90=183 | p95=193 | p99=256 | max=256

[Stage 6] MAX_LEN=256 (based on p95=194)

[Stage 6] Memmap X sizes approx: train=0.02 GB | test=0.04 GB | dtype=<class 'numpy.float32'>

[Stage 6] Building fixed cache (TRAIN)...
[Stage 6] TRAIN filled=3,043/3,043 | dup=0 | empty=0 | time=0.15s

[Stage 6] Building fixed cache (TEST)...
[Stage 6] TEST  filled=7,135/7,135 | dup=0 | empty=0 | time=0.29s

[Stage 6] Sanity samples (train):
- idx=1360 oid=gwilwileth_adel_amloth orig_len=157 kept=157 bands_unique=[0, 1, 2, 3, 4, 5]
- idx=3020 oid=vin_araf_gwador orig_len=151 kept=151 bands_unique=[0, 1, 2, 3, 4, 5]
- idx=3025 oid=ylf_alph_mindon orig_len=167 kept=167 bands_unique=[0, 1, 2, 3, 4, 5]

[Stage 6] Sanity samples (test):
- idx=3191 oid=rom_bellas_lebdas orig_len=142 kept=142 bands_unique=

616

# CV Split (Object-Level, Stratified)

In [13]:
# ============================================================
# STAGE 7 — CV Split (Object-Level, Stratified) (ONE CELL, Kaggle CPU-SAFE)
# REVISI FULL v2.2 (HARDENED + HOLDOUT FALLBACK)
#
# Output:
# - artifacts/cv/cv_folds.csv
# - artifacts/cv/cv_folds.npz   (train_idx_f + val_idx_f)
# - artifacts/cv/cv_report.txt
# - artifacts/cv/cv_config.json
# - globals: fold_assign, folds, n_splits, CV_DIR
# ============================================================

import gc, json, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Require minimal globals
# ----------------------------
# Fail fast when the metadata frame or the artifact directory is missing from the kernel.
for need in ["df_train_meta", "ART_DIR"]:
    if need not in globals():
        raise RuntimeError(f"Missing `{need}`. Jalankan minimal STAGE 2 dulu (df_train_meta & ART_DIR).")

# Reuse the global SEED when one is defined; otherwise default to 2025.
SEED = int(globals().get("SEED", 2025))
ART_DIR = Path(ART_DIR)

# ----------------------------
# 1) CV Settings
# ----------------------------
DEFAULT_SPLITS = 5
FORCE_N_SPLITS = None              # set an int to force the fold count (e.g. 3), else None
MIN_POS_PER_FOLD = 3               # for stability; 3–10 is common
ENFORCE_MIN_POS_PER_FOLD = True    # if True: n_splits is lowered automatically until min_pos>=MIN_POS_PER_FOLD (or holdout fallback)
USE_GROUP_BY_SPLIT = False         # True => prefer StratifiedGroupKFold (groups=df_train_meta["split"])
AUTO_FALLBACK_GROUP = True         # True => if group CV is not possible, fall back to StratifiedKFold
HOLDOUT_FALLBACK = True            # True => if CV is not possible, use a single holdout fold (n_splits=1)
HOLDOUT_FRAC = 0.20                # target validation fraction for the holdout

print(f"[Stage 7] seed={SEED} | default_splits={DEFAULT_SPLITS} | MIN_POS_PER_FOLD={MIN_POS_PER_FOLD} "
      f"| enforce_minpos={ENFORCE_MIN_POS_PER_FOLD} | group_by_split={USE_GROUP_BY_SPLIT} | fallback_group={AUTO_FALLBACK_GROUP}")

# ----------------------------
# 2) Helpers
# ----------------------------
def _decode_ids(arr) -> list:
    out = []
    for x in arr.tolist():
        if isinstance(x, (bytes, bytearray)):
            s = x.decode("utf-8", errors="ignore")
        else:
            s = str(x)
        out.append(s.strip())
    return out

def _find_train_ids_npy(art_dir: Path) -> Path | None:
    # priority 1: FIX_DIR
    if "FIX_DIR" in globals():
        p = Path(globals()["FIX_DIR"]) / "train_ids.npy"
        if p.exists():
            return p

    # priority 2: ART_DIR/fixed_seq
    p = art_dir / "fixed_seq" / "train_ids.npy"
    if p.exists():
        return p

    # priority 3: scan mallorn_run runs (latest mtime)
    root = Path("/kaggle/working/mallorn_run")
    if root.exists():
        cands = list(root.glob("run_*/artifacts/fixed_seq/train_ids.npy"))
        if cands:
            cands = sorted(cands, key=lambda x: x.stat().st_mtime, reverse=True)
            return cands[0]
    return None

def _safe_str_index(idx: pd.Index) -> pd.Index:
    return pd.Index([str(x).strip() for x in idx], dtype="object")

# ----------------------------
# 3) Determine train_ids ordering (prefer fixed cache from STAGE 6)
# ----------------------------
# Decide the canonical ordering of training object ids. A cached train_ids.npy
# (see _find_train_ids_npy) takes precedence; otherwise the order of
# df_train_meta.index is used. `order_source` records which one was chosen and
# is written to the CV report/config further down.
train_ids = None
order_source = "df_train_meta.index"

p_ids = _find_train_ids_npy(ART_DIR)
if p_ids is not None:
    # allow_pickle=False: the id file is loaded as a plain (non-pickled) array
    raw = np.load(p_ids, allow_pickle=False)
    train_ids = _decode_ids(raw)
    order_source = str(p_ids)
else:
    train_ids = [str(x).strip() for x in df_train_meta.index.astype(str).tolist()]
    order_source = "df_train_meta.index"

# ids must be unique: fold assignment below is positional over train_ids
if len(train_ids) != len(set(train_ids)):
    s = pd.Series(train_ids)
    dup = s[s.duplicated()].iloc[:10].tolist()
    raise RuntimeError(f"[Stage 7] train_ids has duplicates (examples): {dup}")

# number of training objects; used by the split helpers below
N = len(train_ids)

# ----------------------------
# 4) Build robust mapping: train_id -> row position in df_train_meta
#    (does not require df_train_meta.index to already be string-typed)
# ----------------------------
meta = df_train_meta

# Compare ids as stripped strings on both sides so index dtype does not matter.
meta_idx_str = _safe_str_index(meta.index)
if meta_idx_str.has_duplicates:
    d = pd.Series(meta_idx_str).value_counts()
    dup = d[d > 1].index.tolist()[:10]
    raise RuntimeError(f"[Stage 7] df_train_meta index has duplicates after str/strip (examples): {dup}")

# pos_map: stringified id -> integer row position inside df_train_meta
pos_map = pd.Series(np.arange(len(meta), dtype=np.int32), index=meta_idx_str)

# reindex yields NaN for every train_id that is absent from df_train_meta
pos_s = pos_map.reindex(train_ids)
missing = pos_s[pos_s.isna()].index.tolist()
if missing:
    ex = missing[:10]
    raise RuntimeError(
        "[Stage 7] Some train_ids not found in df_train_meta (after str/strip index).\n"
        f"Missing count={len(missing)} | ex={ex}\n"
        "Solusi: pastikan df_train_meta memang object-level meta dan index-nya object_id."
    )

# pos_idx[i] = row of df_train_meta that belongs to train_ids[i]
pos_idx = pos_s.astype(np.int32).to_numpy()

# ----------------------------
# 5) Robust target column -> y (ordered by train_ids)
# ----------------------------
# Pick the first column of df_train_meta whose name matches a known label name.
target_col = None
for cand in ["target", "y", "label", "class", "is_tde", "binary_target", "target_id"]:
    if cand in meta.columns:
        target_col = cand
        break
if target_col is None:
    raise RuntimeError(f"[Stage 7] Cannot find target column in df_train_meta. cols(sample)={list(meta.columns)[:40]}")

# NOTE(review): non-numeric / missing targets are coerced to 0 (i.e. treated as
# negatives) rather than rejected.
y_all = pd.to_numeric(meta[target_col], errors="coerce").fillna(0).astype(np.int16).to_numpy(copy=False)
y = y_all[pos_idx]           # reorder labels to follow train_ids
y = (y > 0).astype(np.int8)  # force 0/1

# Class counts; both classes are required for any stratified split.
pos = int((y == 1).sum())
neg = int((y == 0).sum())
if pos == 0 or neg == 0:
    raise RuntimeError(f"[Stage 7] Invalid class distribution: pos={pos}, neg={neg}. Cannot do stratified split.")

# ----------------------------
# 6) Optional groups (by split)
# ----------------------------
# Optional group labels for StratifiedGroupKFold. `groups` stays None unless
# USE_GROUP_BY_SPLIT is on and a split-like column exists in df_train_meta.
groups = None
group_col = None

if USE_GROUP_BY_SPLIT:
    for cand in ["split", "split_id", "split_name", "split_idx"]:
        if cand in meta.columns:
            group_col = cand
            break

    if group_col is None:
        if not AUTO_FALLBACK_GROUP:
            raise RuntimeError("[Stage 7] USE_GROUP_BY_SPLIT=True but no split column found in df_train_meta.")
        print("[Stage 7] WARN: split column not found; fallback to StratifiedKFold.")
        # NOTE(review): this overwrites the user setting, so the
        # "use_group_by_split_requested" value saved in cv_config.json later
        # reflects this fallback rather than the original request.
        USE_GROUP_BY_SPLIT = False
    else:
        g_all = meta[group_col].astype(str).to_numpy()
        groups = g_all[pos_idx]  # reorder group labels to follow train_ids

# ----------------------------
# 7) Choose n_splits safely + auto-adjust
# ----------------------------
# A fold needs at least one member of each class, so n_splits cannot exceed
# either class count; the MIN_POS_PER_FOLD cap keeps about that many positives
# per fold on average.
max_splits_by_pos = pos
max_splits_by_neg = neg
max_splits_by_minpos = max(1, pos // max(int(MIN_POS_PER_FOLD), 1))

n0 = min(DEFAULT_SPLITS, max_splits_by_pos, max_splits_by_neg, max_splits_by_minpos)
if FORCE_N_SPLITS is not None:
    # a forced value replaces the computed candidate without any of the caps above
    n0 = int(FORCE_N_SPLITS)

print(f"[Stage 7] Candidate n_splits={n0} | N={N:,} pos={pos:,} neg={neg:,} pos%={pos/max(N,1)*100:.6f}% | order_source={order_source}")

# ----------------------------
# 8) Build folds (sklearn) with robust fallback
# ----------------------------
try:
    from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
    try:
        from sklearn.model_selection import StratifiedGroupKFold
    except Exception:
        StratifiedGroupKFold = None
except Exception as e:
    raise RuntimeError("scikit-learn is not available in this environment.") from e

def _try_split_kfold(k: int, use_group: bool):
    """Attempt a shuffled, stratified k-fold split over the module-level ``y``.

    Parameters
    ----------
    k : int
        Number of folds to request.
    use_group : bool
        Use ``StratifiedGroupKFold`` with the module-level ``groups`` instead
        of plain ``StratifiedKFold``.

    Returns
    -------
    tuple
        ``(ok, cv_type, fold_assign, folds, per)``. On failure ``ok`` is False,
        ``cv_type`` carries the reason and the remaining items are ``None``.
        On success ``fold_assign`` maps every row to its validation fold,
        ``folds`` is a list of ``{"fold", "train_idx", "val_idx"}`` dicts and
        ``per`` lists ``(n_val, n_pos_val, n_neg_val)`` per fold.
    """
    assignment = np.full(N, -1, dtype=np.int16)
    fold_records = []
    val_stats = []

    def _fail(label):
        return (False, label, None, None, None)

    if not use_group:
        cv_type = "StratifiedKFold"
        cv = StratifiedKFold(n_splits=k, shuffle=True, random_state=SEED)
        pairs = cv.split(np.zeros(N), y)
    elif StratifiedGroupKFold is None:
        return _fail("StratifiedGroupKFold(unavailable)")
    else:
        cv = StratifiedGroupKFold(n_splits=k, shuffle=True, random_state=SEED)
        pairs = cv.split(np.zeros(N), y, groups=groups)
        cv_type = f"StratifiedGroupKFold({group_col})"

    # Any exception raised while iterating the splitter is reported as a failure.
    try:
        for fold_no, (train_rows, val_rows) in enumerate(pairs):
            assignment[val_rows] = fold_no
            val_labels = y[val_rows]
            n_pos_val = int((val_labels == 1).sum())
            n_neg_val = int((val_labels == 0).sum())
            val_stats.append((len(val_rows), n_pos_val, n_neg_val))
            fold_records.append({
                "fold": int(fold_no),
                "train_idx": train_rows.astype(np.int32),
                "val_idx": val_rows.astype(np.int32),
            })
    except Exception as exc:
        return _fail(f"{cv_type} (error: {type(exc).__name__})")

    if (assignment < 0).any():
        return _fail(f"{cv_type} (unassigned)")

    # hard requirement: every validation fold contains both classes
    if any(n_p == 0 or n_n == 0 for (_, n_p, n_n) in val_stats):
        return _fail(f"{cv_type} (empty class in fold)")

    return (True, cv_type, assignment, fold_records, val_stats)

def _make_holdout():
    """Build a single stratified train/validation holdout split.

    Reads the module-level ``N``, ``pos``, ``neg``, ``y``, ``HOLDOUT_FRAC`` and
    ``SEED``. The validation size starts from ``HOLDOUT_FRAC * N``, is raised
    so that the minority class is expected to contribute at least one sample,
    and is clamped to ``[2, N - 2]``.

    Returns
    -------
    tuple
        ``(n_splits, cv_type, fold_assign, folds, per)`` with ``n_splits == 1``.
        ``fold_assign`` is 0 for validation rows and -1 for training rows,
        ``folds`` holds one ``{"fold", "train_idx", "val_idx"}`` dict and
        ``per`` holds one ``(n_val, n_pos_val, n_neg_val)`` tuple.

    Raises
    ------
    RuntimeError
        If either class has fewer than 2 samples, or if the resulting split
        leaves train or validation without one of the classes.
    """
    n_pos = pos
    n_neg = neg
    n = N

    # Both classes must appear in train AND validation, so each class needs at
    # least 2 members (which also implies n >= 4, so the clamp below is valid).
    if n_pos < 2 or n_neg < 2:
        raise RuntimeError(
            f"[Stage 7] Cannot build even holdout split safely. Need pos>=2 and neg>=2. Got pos={n_pos}, neg={n_neg}."
        )

    # start from HOLDOUT_FRAC
    val_n = int(round(n * float(HOLDOUT_FRAC)))

    # A proportional (stratified) draw only hands the minority class a
    # validation sample when val_n * minority / n >= 1, i.e. val_n >= ceil(n / minority).
    minority = min(n_pos, n_neg)
    min_val_n = -(-n // minority)  # integer ceil(n / minority); <= n - 2 because minority >= 2
    val_n = max(val_n, min_val_n, 2)
    val_n = min(val_n, n - 2)

    # Pass the absolute count: a float fraction is re-multiplied and ceil'ed by
    # scikit-learn, which can yield val_n + 1 through floating-point error.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=int(val_n), random_state=SEED)
    tr_idx, val_idx = next(splitter.split(np.zeros(n), y))

    y_val = y[val_idx]
    y_tr = y[tr_idx]
    val_pos, val_neg = int((y_val == 1).sum()), int((y_val == 0).sum())
    tr_pos, tr_neg = int((y_tr == 1).sum()), int((y_tr == 0).sum())
    if min(val_pos, val_neg, tr_pos, tr_neg) == 0:
        raise RuntimeError(
            "[Stage 7] Holdout split is missing a class: "
            f"val(pos={val_pos}, neg={val_neg}) train(pos={tr_pos}, neg={tr_neg})."
        )

    fold_assign = np.full(n, -1, dtype=np.int16)
    fold_assign[val_idx] = 0

    per = [(len(val_idx), val_pos, val_neg)]

    folds = [{
        "fold": 0,
        "train_idx": tr_idx.astype(np.int32),
        "val_idx": val_idx.astype(np.int32),
    }]
    return 1, "Holdout(StratifiedShuffleSplit)", fold_assign, folds, per

# Fold selection. Pass 1 walks k = n0, n0-1, ..., 2 and keeps the first split
# that is valid and (when enforced) has enough positives in every fold.
best = None
used_group = bool(USE_GROUP_BY_SPLIT)

if n0 >= 2:
    for k in range(n0, 1, -1):
        ok, cv_type, fa, folds, per = _try_split_kfold(k, use_group=used_group)

        # group CV failed for this k: retry without groups when allowed
        if (not ok) and used_group and AUTO_FALLBACK_GROUP:
            ok2, cv_type2, fa2, folds2, per2 = _try_split_kfold(k, use_group=False)
            if ok2:
                ok, cv_type, fa, folds, per = ok2, cv_type2, fa2, folds2, per2
                # NOTE(review): sticky - smaller k values in this loop no longer try group CV
                used_group = False

        if not ok:
            continue

        min_pos_seen = min(pf for (_, pf, _) in per) if per else 0
        # the min-positives rule is not applied when FORCE_N_SPLITS is set
        if ENFORCE_MIN_POS_PER_FOLD and (min_pos_seen < MIN_POS_PER_FOLD) and (FORCE_N_SPLITS is None):
            continue

        best = (k, cv_type, fa, folds, per, min_pos_seen)
        break

# Pass 2: if the min-positives rule rejected everything, pick the first valid
# split (each fold has pos>=1 and neg>=1) and report the shortfall.
if best is None and n0 >= 2:
    for k in range(n0, 1, -1):
        ok, cv_type, fa, folds, per = _try_split_kfold(k, use_group=bool(USE_GROUP_BY_SPLIT))
        if (not ok) and USE_GROUP_BY_SPLIT and AUTO_FALLBACK_GROUP:
            ok, cv_type, fa, folds, per = _try_split_kfold(k, use_group=False)
        if ok:
            min_pos_seen = min(pf for (_, pf, _) in per) if per else 0
            best = (k, cv_type, fa, folds, per, min_pos_seen)
            print(f"[Stage 7] NOTE: Could not satisfy MIN_POS_PER_FOLD={MIN_POS_PER_FOLD}. Using k={k} with min_pos={min_pos_seen}.")
            break

# Pass 3: no k-fold split worked at all -> single holdout split, or fail
if best is None:
    if HOLDOUT_FALLBACK:
        n_splits, cv_type, fold_assign, folds, per = _make_holdout()
        min_pos_seen = per[0][1] if per else 0
        best = (n_splits, cv_type, fold_assign, folds, per, min_pos_seen)
        print(f"[Stage 7] FALLBACK -> {cv_type} | val_pos={min_pos_seen}")
    else:
        raise RuntimeError("[Stage 7] Failed to build a valid CV split. Try smaller DEFAULT_SPLITS / FORCE_N_SPLITS, or enable HOLDOUT_FALLBACK.")

# Final choice; these names are what the report/save/export sections use.
n_splits, cv_type, fold_assign, folds, per, min_pos_seen = best

# ----------------------------
# 9) Report + validation
# ----------------------------
# Human-readable report; `lines` is written to cv_report.txt and partly echoed
# to stdout by the save section.
lines = []
lines.append(f"CV={cv_type} n_splits={n_splits} seed={SEED}")
lines.append(f"Order source: {order_source}")
lines.append(f"Target column: {target_col}")
lines.append(f"Total: N={N} | pos={pos} | neg={neg} | pos%={pos/max(N,1)*100:.6f}%")
if USE_GROUP_BY_SPLIT:
    lines.append(f"Group col requested: {group_col} | used_group={('Group' in cv_type)}")
lines.append("Per-fold distribution (val):")

# Re-derive the per-fold class counts from fold_assign as an independent check.
ok = True
for f in range(n_splits):
    idx = np.where(fold_assign == f)[0]
    yf = y[idx]
    pf = int((yf == 1).sum())
    nf = int((yf == 0).sum())
    lines.append(f"- fold {f}: n={len(idx):6d} | pos={pf:5d} | neg={nf:6d} | pos%={(pf/max(len(idx),1))*100:9.6f}%")
    if pf == 0 or nf == 0:
        ok = False

if not ok:
    raise RuntimeError("[Stage 7] A fold has pos=0 or neg=0 after selection (should not happen).")

# NOTE(review): a holdout split leaves its training rows at -1 in fold_assign,
# so this check raises whenever the holdout fallback was used.
if (fold_assign < 0).any():
    bad = np.where(fold_assign < 0)[0][:10]
    ex = [train_ids[i] for i in bad]
    raise RuntimeError(f"[Stage 7] Unassigned fold entries detected: count={(fold_assign<0).sum()} | ex={ex}")

if min_pos_seen < MIN_POS_PER_FOLD and n_splits >= 2:
    lines.append(f"NOTE: min positives in a fold = {min_pos_seen} (< MIN_POS_PER_FOLD={MIN_POS_PER_FOLD}). "
                 "Threshold/F1 tuning bisa noisy; pertimbangkan n_splits lebih kecil.")

print(f"[Stage 7] FINAL: n_splits={n_splits} | cv_type={cv_type} | min_pos_in_fold={min_pos_seen}")

# ----------------------------
# 10) Save artifacts
# ----------------------------
# All CV artifacts go under <ART_DIR>/cv.
CV_DIR = ART_DIR / "cv"
CV_DIR.mkdir(parents=True, exist_ok=True)

# cv_folds.csv: one row per training object with its validation fold
df_folds = pd.DataFrame({"object_id": train_ids, "fold": fold_assign.astype(int)})
folds_csv = CV_DIR / "cv_folds.csv"
df_folds.to_csv(folds_csv, index=False)

# cv_folds.npz: positional train/val index arrays, keyed train_idx_<f> / val_idx_<f>
npz_path = CV_DIR / "cv_folds.npz"
npz_kwargs = {}
for f in range(n_splits):
    npz_kwargs[f"train_idx_{f}"] = folds[f]["train_idx"]
    npz_kwargs[f"val_idx_{f}"]   = folds[f]["val_idx"]
np.savez(npz_path, **npz_kwargs)

report_path = CV_DIR / "cv_report.txt"
with open(report_path, "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")

# cv_config.json: the settings and artifact paths of this split
cfg_path = CV_DIR / "cv_config.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(
        {
            "seed": SEED,
            "n_splits": int(n_splits),
            "cv_type": cv_type,
            "min_pos_per_fold": int(MIN_POS_PER_FOLD),
            "enforce_min_pos_per_fold": bool(ENFORCE_MIN_POS_PER_FOLD),
            "use_group_by_split_requested": bool(USE_GROUP_BY_SPLIT),
            "auto_fallback_group": bool(AUTO_FALLBACK_GROUP),
            "holdout_fallback": bool(HOLDOUT_FALLBACK),
            "order_source": order_source,
            "target_col": target_col,
            "group_col": group_col,
            "artifacts": {
                "folds_csv": str(folds_csv),
                "folds_npz": str(npz_path),
                "report_txt": str(report_path),
            },
        },
        f,
        indent=2,
    )

print("\n[Stage 7] CV split OK")
print(f"- Saved: {folds_csv}")
print(f"- Saved: {npz_path}")
print(f"- Saved: {report_path}")
print(f"- Saved: {cfg_path}")

# echo the last (n_splits + 6) report lines, or the whole report if it is shorter
tail_n = min(len(lines), n_splits + 6)
print("\n".join(lines[-tail_n:]))

# ----------------------------
# 11) Export globals for next stage
# ----------------------------
# Publish the split under stable global names for the following cells.
# `train_ids_ordered` / `y_ordered` follow the same row order that the
# positional indices inside `folds` refer to.
globals().update({
    "CV_DIR": CV_DIR,
    "n_splits": int(n_splits),
    "train_ids_ordered": train_ids,
    "y_ordered": y,
    "fold_assign": fold_assign,
    "folds": folds,
    "CV_FOLDS_CSV": folds_csv,
    "CV_FOLDS_NPZ": npz_path,
    "CV_CFG_PATH": cfg_path,
    "CV_TYPE": cv_type,
    "CV_ORDER_SOURCE": order_source,
})

gc.collect()


[Stage 7] seed=2025 | default_splits=5 | MIN_POS_PER_FOLD=3 | enforce_minpos=True | group_by_split=False | fallback_group=True
[Stage 7] Candidate n_splits=5 | N=3,043 pos=148 neg=2,895 pos%=4.863621% | order_source=/kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/fixed_seq/train_ids.npy
[Stage 7] FINAL: n_splits=5 | cv_type=StratifiedKFold | min_pos_in_fold=29

[Stage 7] CV split OK
- Saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/cv/cv_folds.csv
- Saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/cv/cv_folds.npz
- Saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/cv/cv_report.txt
- Saved: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/cv/cv_config.json
CV=StratifiedKFold n_splits=5 seed=2025
Order source: /kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/artifacts/fixed_seq/train_ids.npy
Target column: target
Total: N=3043 | pos=148 | neg=2895 | pos%=4.8

33

# Train Model (CPU-Safe Configuration)

In [None]:
# ============================================================
# STAGE 8 — Train Multiband Event Transformer (CPU-Safe)
# REVISI FULL v3.1 (BOOST + NO-LEAK + OneCycle FIX + SHIFT_BAND_IDS OK)
#
# Output:
# - checkpoints/fold_*.pt
# - oof/oof_prob.npy + oof/oof_prob.csv
# - oof/fold_metrics.json
# - logs/train_cfg_stage8.json + global_feature_spec.json
# ============================================================

import os, gc, json, math, time, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------
# 0) Require minimal previous stages
# ----------------------------
# Globals that earlier stages must already have defined; fail fast with a clear
# message instead of hitting a NameError further down.
need_min = ["FIX_DIR","MAX_LEN","SEQ_FEATURE_NAMES","df_train_meta","n_splits","folds"]
for k in need_min:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..7 dulu dengan urutan benar.")

# ----------------------------
# 0a) Resolve train_ids ordering + labels (robust)
# ----------------------------
def _decode_ids(arr):
    out = []
    for x in arr.tolist():
        if isinstance(x, (bytes, bytearray)):
            s = x.decode("utf-8", errors="ignore")
        else:
            s = str(x)
        out.append(s.strip())
    return out

# Ordering of training ids: prefer the list exported by Stage 7, then the
# cached train_ids.npy in FIX_DIR, then the df_train_meta index order.
if "train_ids_ordered" in globals() and globals()["train_ids_ordered"] is not None:
    train_ids = [str(x).strip() for x in list(globals()["train_ids_ordered"])]
else:
    p = Path(globals()["FIX_DIR"]) / "train_ids.npy"
    if p.exists():
        raw = np.load(p, allow_pickle=False)
        # byte/object arrays are decoded as-is; any other dtype is cast to str first
        train_ids = _decode_ids(raw if raw.dtype.kind in ("S","O") else raw.astype(str))
    else:
        train_ids = [str(x).strip() for x in df_train_meta.index.astype(str).tolist()]

# Pick the first column of df_train_meta whose name matches a known label name.
target_col = None
for cand in ["target","y","label","class","is_tde","binary_target","target_id"]:
    if cand in df_train_meta.columns:
        target_col = cand
        break
if target_col is None:
    raise RuntimeError(f"Cannot find target column in df_train_meta. cols(sample)={list(df_train_meta.columns)[:40]}")

# Map ids to row positions through a stringified index (avoids index dtype mismatch).
meta_idx_str = pd.Index([str(x).strip() for x in df_train_meta.index], dtype="object")
pos_map = pd.Series(np.arange(len(df_train_meta), dtype=np.int32), index=meta_idx_str)
# reindex yields NaN for ids that are not present in df_train_meta
pos_idx = pos_map.reindex(train_ids).to_numpy()
if np.isnan(pos_idx).any():
    miss = [train_ids[i] for i in np.where(np.isnan(pos_idx))[0][:10]]
    raise RuntimeError(f"Some train_ids not found in df_train_meta.index (string-mapped). ex={miss}")

pos_idx = pos_idx.astype(np.int32)

# Labels reordered to follow train_ids and binarised to 0/1.
# NOTE(review): non-numeric / missing targets are coerced to 0 (negative).
y_all = pd.to_numeric(df_train_meta[target_col], errors="coerce").fillna(0).astype(np.int16).to_numpy()
y = y_all[pos_idx]
y = (y > 0).astype(np.int8)

# ----------------------------
# 0b) Ensure output dirs exist
# ----------------------------
def _path_from_global(name, default):
    """Return ``Path(globals()[name])``, or ``Path(default)`` when the global is missing or None."""
    value = globals().get(name)
    return Path(default) if value is None else Path(value)

# RUN_DIR: reuse the existing global, else derive it from ART_DIR's parent,
# else fall back to the default Kaggle working directory.
if globals().get("RUN_DIR") is not None:
    RUN_DIR = Path(globals()["RUN_DIR"])
elif globals().get("ART_DIR") is not None:
    RUN_DIR = Path(globals()["ART_DIR"]).parent
else:
    RUN_DIR = Path("/kaggle/working/mallorn_run")

# A global that exists but is None is treated like a missing one; passing it
# straight to Path() would raise TypeError.
ART_DIR = _path_from_global("ART_DIR", RUN_DIR / "artifacts")
ART_DIR.mkdir(parents=True, exist_ok=True)

CKPT_DIR = _path_from_global("CKPT_DIR", RUN_DIR / "checkpoints")
OOF_DIR  = _path_from_global("OOF_DIR",  RUN_DIR / "oof")
LOG_DIR  = _path_from_global("LOG_DIR",  RUN_DIR / "logs")
CKPT_DIR.mkdir(parents=True, exist_ok=True)
OOF_DIR.mkdir(parents=True, exist_ok=True)
LOG_DIR.mkdir(parents=True, exist_ok=True)

globals().update({"RUN_DIR": RUN_DIR, "ART_DIR": ART_DIR, "CKPT_DIR": CKPT_DIR, "OOF_DIR": OOF_DIR, "LOG_DIR": LOG_DIR})

# ----------------------------
# 1) Torch imports + CPU safety
# ----------------------------
try:
    import torch
    import torch.nn as nn
except Exception as e:
    raise RuntimeError("PyTorch tidak tersedia di environment ini.") from e

# Reuse the notebook-wide SEED when present (2025 otherwise) and seed both
# torch and numpy's legacy global RNG.
SEED = int(globals().get("SEED", 2025))
torch.manual_seed(SEED)
np.random.seed(SEED)

# training in this cell is CPU-only
device = torch.device("cpu")

# Thread guard: intra-op threads follow OMP_NUM_THREADS (default 2), inter-op
# threads are set to 1. Best-effort: any failure here is deliberately ignored.
try:
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
    torch.set_num_interop_threads(1)
except Exception:
    pass

try:
    from sklearn.metrics import roc_auc_score
except Exception as e:
    raise RuntimeError("scikit-learn metrics tidak tersedia.") from e

# ----------------------------
# 2) Open memmaps (fixed seq) — NO RAM load
# ----------------------------
FIX_DIR = Path(globals()["FIX_DIR"])
N = len(train_ids)                                     # number of training objects
L = int(globals()["MAX_LEN"])                          # fixed sequence length
SEQ_FEATURE_NAMES = list(globals()["SEQ_FEATURE_NAMES"])
Fdim = len(SEQ_FEATURE_NAMES)                          # features per token
feat = {n:i for i,n in enumerate(SEQ_FEATURE_NAMES)}   # feature name -> column in X

train_X_path = FIX_DIR / "train_X.dat"
train_B_path = FIX_DIR / "train_B.dat"
train_M_path = FIX_DIR / "train_M.dat"

# Expected on-disk size of each raw cache for the shapes/dtypes opened below.
# np.memmap accepts a file that is LARGER than the requested shape, so a cache
# built for a different N / MAX_LEN / feature list would otherwise be read
# silently with misaligned rows.
_expected_bytes = {
    train_X_path: N * L * Fdim * np.dtype(np.float32).itemsize,
    train_B_path: N * L * np.dtype(np.int8).itemsize,
    train_M_path: N * L * np.dtype(np.int8).itemsize,
}

for p in [train_X_path, train_B_path, train_M_path]:
    if not p.exists():
        raise FileNotFoundError(f"Missing fixed cache file: {p}. Pastikan STAGE 6 sukses.")
    _actual_bytes = p.stat().st_size
    if _actual_bytes != _expected_bytes[p]:
        raise RuntimeError(
            f"Fixed cache file {p} has {_actual_bytes} bytes but N={N}, MAX_LEN={L}, Fdim={Fdim} "
            f"requires {_expected_bytes[p]} bytes. train_ids / MAX_LEN / SEQ_FEATURE_NAMES do not match this cache."
        )

# Read-only memmaps: rows are paged in on access, nothing is loaded up front.
X_mm = np.memmap(train_X_path, dtype=np.float32, mode="r", shape=(N, L, Fdim))
B_mm = np.memmap(train_B_path, dtype=np.int8,   mode="r", shape=(N, L))
M_mm = np.memmap(train_M_path, dtype=np.int8,   mode="r", shape=(N, L))

# ----------------------------
# 2b) Read Stage6 policy to detect SHIFT_BAND_IDS (important!)
# ----------------------------
# Defaults used when the Stage 6 policy file is absent or unreadable.
SHIFT_BAND_IDS = False
PAD_BAND_ID = 0
policy_path = FIX_DIR / "length_policy_config.json"
if policy_path.exists():
    try:
        with open(policy_path, "r", encoding="utf-8") as f:
            pol = json.load(f)
        SHIFT_BAND_IDS = bool(pol.get("padding", {}).get("SHIFT_BAND_IDS", False))
        PAD_BAND_ID = int(pol.get("padding", {}).get("PAD_BAND_ID", 0))
    except Exception as e:
        # Still best-effort, but no longer silent: a wrong SHIFT_BAND_IDS changes
        # how band ids are interpreted downstream, so make the fallback visible.
        print(f"[Stage 8] WARN: could not read {policy_path} ({type(e).__name__}: {e}); "
              f"using SHIFT_BAND_IDS={SHIFT_BAND_IDS}, PAD_BAND_ID={PAD_BAND_ID}.")

# Token mode: use the SEQ_TOKEN_MODE global when set, otherwise infer it from
# which value/error feature pair is present in SEQ_FEATURE_NAMES.
SEQ_TOKEN_MODE = globals().get("SEQ_TOKEN_MODE", None)
if SEQ_TOKEN_MODE is None:
    if ("mag" in feat) and ("mag_err_log" in feat):
        SEQ_TOKEN_MODE = "mag"
    elif ("flux_asinh" in feat) and ("err_log1p" in feat):
        SEQ_TOKEN_MODE = "asinh"
    else:
        raise RuntimeError(f"Cannot infer token_mode from features: {SEQ_FEATURE_NAMES}")

# these two features are read by the aggregate-feature builder below
for k in ["snr_tanh","detected"]:
    if k not in feat:
        raise RuntimeError(f"Feature '{k}' not found in SEQ_FEATURE_NAMES.")

# VAL_FEAT is the per-token value channel: "mag" in mag mode, "flux_asinh" otherwise
VAL_FEAT = "mag" if SEQ_TOKEN_MODE == "mag" else "flux_asinh"
if VAL_FEAT not in feat:
    raise RuntimeError(f"Feature '{VAL_FEAT}' missing for token_mode={SEQ_TOKEN_MODE}.")

# ----------------------------
# 3) Build RAW meta global features
# ----------------------------
# Object-level (global) meta columns fed to the model next to the sequence.
BASE_G_COLS = ["Z","Z_err","EBV","Z_missing","Z_err_missing","EBV_missing","is_photoz"]
# NOTE(review): this adds zero-filled columns to the shared df_train_meta in
# place, so the frame seen by later cells differs from what earlier cells built.
for c in BASE_G_COLS:
    if c not in df_train_meta.columns:
        df_train_meta[c] = 0.0

# Rows reordered to follow train_ids; non-numeric / missing values become 0.0.
G_meta = df_train_meta.iloc[pos_idx][BASE_G_COLS].copy()
for c in BASE_G_COLS:
    G_meta[c] = pd.to_numeric(G_meta[c], errors="coerce").fillna(0.0).astype(np.float32)
G_meta_np = G_meta.to_numpy(dtype=np.float32, copy=False)

# persist the column list next to the training logs
with open(Path(LOG_DIR)/"global_meta_cols.json", "w", encoding="utf-8") as f:
    json.dump({"cols": BASE_G_COLS}, f, indent=2)

# ----------------------------
# 3b) Sequence aggregate features (global + per-band) — BOOST
# ----------------------------
USE_AGG_SEQ_FEATURES = True   # False => no aggregate columns are appended to the meta features
N_BANDS = 6                   # number of photometric bands; also the band-embedding size used by the dataset/model

def _safe_div(a, b):
    return a / np.maximum(b, 1.0)

def build_agg_seq_features(X_mm, B_mm, M_mm, chunk=2048):
    """Compute per-object aggregate features from the fixed-length sequences.

    Rows are processed in blocks of ``chunk`` objects. Only tokens whose mask
    equals 1 contribute. Output columns, in order:

    * 4 global: real-token count, detected fraction, mean |snr_tanh|, max |snr_tanh|
    * 3 value stats of ``VAL_FEAT``: (mean, std, min) in "mag" mode, otherwise
      (mean |value|, std, max |value|)
    * 4 per band, for each of ``N_BANDS`` bands: token count, detected
      fraction, mean |snr_tanh|, mean value ("mag" mode) or mean |value|

    Parameters
    ----------
    X_mm, B_mm, M_mm
        Feature, band-id and mask arrays indexed as ``[object, token(, feature)]``.
    chunk : int
        Number of objects materialised at a time.

    Returns
    -------
    np.ndarray
        float32 array with one row per object (``N`` rows).

    Notes
    -----
    Reads the module-level ``feat``, ``VAL_FEAT``, ``SEQ_TOKEN_MODE``, ``N``,
    ``N_BANDS`` and ``SHIFT_BAND_IDS``. When ``SHIFT_BAND_IDS`` is true the
    stored band ids of real tokens are 1..N_BANDS (0 is padding), so band ``b``
    is matched against ``b + 1``. Without that offset the last band is never
    counted and the first per-band group is always empty.
    NOTE(review): any inference-time copy of this function must apply the same
    offset so that train and test features stay aligned.
    """
    snr_i = feat["snr_tanh"]
    det_i = feat["detected"]
    val_i = feat[VAL_FEAT]
    band_offset = 1 if SHIFT_BAND_IDS else 0
    is_mag = (SEQ_TOKEN_MODE == "mag")

    out_chunks = []
    for s in range(0, N, chunk):
        e = min(N, s + chunk)
        Xc = np.asarray(X_mm[s:e])  # (B,L,F)
        Bc = np.asarray(B_mm[s:e])  # (B,L)
        Mc = np.asarray(M_mm[s:e])  # (B,L)

        real = (Mc == 1)
        tok_count = real.sum(axis=1).astype(np.float32)

        snr = np.abs(Xc[:, :, snr_i]).astype(np.float32)
        det = (Xc[:, :, det_i] > 0.5).astype(np.float32)
        val = Xc[:, :, val_i].astype(np.float32)

        snr_r = snr * real
        det_r = det * real

        det_frac = _safe_div(det_r.sum(axis=1), tok_count)
        mean_abs_snr = _safe_div(snr_r.sum(axis=1), tok_count)
        max_abs_snr = np.where(tok_count > 0, (snr_r).max(axis=1), 0.0).astype(np.float32)

        # NaN-out padded tokens so the nan-aware reductions skip them; rows with
        # no real token reduce to NaN and are mapped to 0.0.
        val_r = np.where(real, val, np.nan)
        std_val = np.nan_to_num(np.nanstd(val_r, axis=1).astype(np.float32), nan=0.0)
        if is_mag:
            aval = None
            mean_val = np.nan_to_num(np.nanmean(val_r, axis=1).astype(np.float32), nan=0.0)
            min_val  = np.nan_to_num(np.nanmin(val_r, axis=1).astype(np.float32),  nan=0.0)
            global_val_feats = np.stack([mean_val, std_val, min_val], axis=1).astype(np.float32)
        else:
            # |value| is computed once per chunk and reused by the per-band loop
            aval = np.abs(val).astype(np.float32)
            aval_r = aval * real
            mean_aval = _safe_div(aval_r.sum(axis=1), tok_count).astype(np.float32)
            max_aval = np.where(tok_count > 0, (aval_r).max(axis=1), 0.0).astype(np.float32)
            global_val_feats = np.stack([mean_aval, std_val, max_aval], axis=1).astype(np.float32)

        per_band = []
        for b in range(N_BANDS):
            bm = (Bc == (b + band_offset)) & real
            cnt = bm.sum(axis=1).astype(np.float32)
            detb = (det * bm).sum(axis=1).astype(np.float32)
            snrb = (snr * bm).sum(axis=1).astype(np.float32)

            det_frac_b = _safe_div(detb, cnt)
            mean_abs_snr_b = _safe_div(snrb, cnt)

            if is_mag:
                vb = np.where(bm, val, np.nan)
                mean_val_b = np.nan_to_num(np.nanmean(vb, axis=1).astype(np.float32), nan=0.0)
            else:
                ab = (aval * bm).sum(axis=1).astype(np.float32)
                mean_val_b = _safe_div(ab, cnt).astype(np.float32)

            per_band.append(np.stack([cnt, det_frac_b, mean_abs_snr_b, mean_val_b], axis=1))

        per_band = np.concatenate(per_band, axis=1).astype(np.float32)

        glob = np.stack([tok_count, det_frac, mean_abs_snr, max_abs_snr], axis=1).astype(np.float32)
        agg = np.concatenate([glob, global_val_feats, per_band], axis=1).astype(np.float32)

        out_chunks.append(agg)
        del Xc, Bc, Mc
        # release the chunk buffers every third chunk
        if (s // chunk) % 3 == 0:
            gc.collect()

    return np.concatenate(out_chunks, axis=0).astype(np.float32)

# Build the aggregate sequence features once for all training objects, or use
# an empty (N, 0) block when they are disabled.
if USE_AGG_SEQ_FEATURES:
    print("[Stage 8] Building AGG sequence features (one-time)...")
    t0 = time.time()
    G_seq_np = build_agg_seq_features(X_mm, B_mm, M_mm, chunk=2048)
    print(f"[Stage 8] AGG built: shape={G_seq_np.shape} | time={time.time()-t0:.1f}s")
else:
    G_seq_np = np.zeros((N,0), dtype=np.float32)

# Unscaled global features: meta columns first, aggregate columns after.
G_raw_np = np.concatenate([G_meta_np, G_seq_np], axis=1).astype(np.float32)
g_dim = int(G_raw_np.shape[1])

# Record how the global feature vector is composed.
with open(Path(LOG_DIR)/"global_feature_spec.json", "w", encoding="utf-8") as f:
    json.dump(
        {
            "meta_cols": BASE_G_COLS,
            "use_agg_seq": bool(USE_AGG_SEQ_FEATURES),
            "token_mode": SEQ_TOKEN_MODE,
            "val_feat": VAL_FEAT,
            "agg_dim": int(G_seq_np.shape[1]),
            "total_g_dim": int(g_dim),
            "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
            "pad_band_id_from_stage6": int(PAD_BAND_ID),
        },
        f,
        indent=2,
    )

# ----------------------------
# 4) Dataset / Loader (num_workers=0) + optional augmentation
# ----------------------------
# Train-time augmentation used by MemmapSeqDataset when train_mode=True.
AUG_TOKENDROP_P = 0.05     # per-token probability of masking out a real token; 0.0 disables
AUG_VALUE_NOISE = 0.01     # std of Gaussian noise added to the VAL_FEAT channel of real tokens; 0.0 disables

class MemmapSeqDataset(torch.utils.data.Dataset):
    """Dataset that serves single objects from the fixed-length sequence arrays.

    Each item is ``(X, band_ids, mask, G)`` and, when ``y`` was supplied, the
    label as a float32 scalar tensor. ``idx`` holds the global row numbers this
    dataset exposes; ``G_scaled_np`` and ``y`` are indexed with those same
    global row numbers.

    With ``train_mode=True`` two augmentations are applied per item, driven by
    the module-level ``AUG_TOKENDROP_P`` and ``AUG_VALUE_NOISE``: random masking
    of real tokens (at least one is always kept) and Gaussian noise on the
    ``VAL_FEAT`` channel of real tokens.
    """

    def __init__(self, idx, X_mm, B_mm, M_mm, G_scaled_np, y=None, train_mode=False):
        self.idx = np.asarray(idx, dtype=np.int32)
        self.X_mm = X_mm
        self.B_mm = B_mm
        self.M_mm = M_mm
        self.G = G_scaled_np
        self.y = None if y is None else np.asarray(y, dtype=np.int8)
        self.train_mode = bool(train_mode)
        # separate RNG stream for the training dataset
        self.rng = np.random.default_rng(SEED + (777 if train_mode else 0))

    def __len__(self):
        return len(self.idx)

    def __getitem__(self, i):
        j = int(self.idx[i])

        # Copy X out of the backing store: the memmaps are opened read-only and
        # torch.from_numpy on a non-writable array yields a non-writable tensor
        # (PyTorch warns about this). The copy also makes the in-place noise
        # below safe.
        X = np.array(self.X_mm[j], dtype=np.float32)  # (L,F)
        B = np.asarray(self.B_mm[j])  # (L,)
        M = np.asarray(self.M_mm[j])  # (L,)
        G = np.asarray(self.G[j])     # (g_dim,)

        # IMPORTANT: handle SHIFT_BAND_IDS from Stage 6
        if SHIFT_BAND_IDS:
            # real bands are 1..6, pad is 0. Convert real->0..5
            real = (M == 1)
            if real.any():
                B = B.astype(np.int16, copy=False)
                B2 = B.copy()
                B2[real] = np.clip(B2[real] - 1, 0, N_BANDS - 1)
                B2[~real] = 0
                B = B2.astype(np.int8, copy=False)

        if self.train_mode:
            # token dropout: drop a small fraction of REAL tokens, keep at least 1
            if AUG_TOKENDROP_P and AUG_TOKENDROP_P > 0:
                real = (M == 1)
                nreal = int(real.sum())
                if nreal > 1:
                    drop = (self.rng.random(M.shape[0]) < AUG_TOKENDROP_P) & real
                    if int(drop.sum()) >= nreal:  # would drop all real tokens
                        keep_pos = np.where(real)[0][int(self.rng.integers(0, nreal))]
                        drop[keep_pos] = False
                    if drop.any():
                        M = M.copy()
                        M[drop] = 0

            # small value noise on real tokens (after dropout, so dropped tokens get none)
            if AUG_VALUE_NOISE and AUG_VALUE_NOISE > 0:
                vi = feat[VAL_FEAT]
                real = (M == 1)
                if real.any():
                    noise = self.rng.normal(0.0, AUG_VALUE_NOISE, size=int(real.sum())).astype(np.float32)
                    X[real, vi] = (X[real, vi] + noise).astype(np.float32)

        Xt = torch.from_numpy(X)
        Bt = torch.from_numpy(B.astype(np.int64, copy=False))
        Mt = torch.from_numpy(M.astype(np.int64, copy=False))
        Gt = torch.from_numpy(G.astype(np.float32, copy=False))

        if self.y is None:
            return Xt, Bt, Mt, Gt

        yy = float(self.y[j])
        return Xt, Bt, Mt, Gt, torch.tensor(yy, dtype=torch.float32)

def make_loader(ds, batch_size, shuffle, sampler=None):
    """Wrap ``ds`` in a single-process ``DataLoader``.

    ``shuffle`` only takes effect when no ``sampler`` is given. Loading runs in
    the main process (``num_workers=0``), without pinned memory, and the last
    partial batch is kept.
    """
    effective_shuffle = shuffle if sampler is None else False
    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=effective_shuffle,
        sampler=sampler,
        num_workers=0,
        pin_memory=False,
        drop_last=False,
    )
    return torch.utils.data.DataLoader(ds, **loader_kwargs)

# ----------------------------
# 5) Model — stronger pooling
# ----------------------------
class MultibandEventTransformer(nn.Module):
    """Transformer encoder over per-token features with a global-feature side branch.

    Each token embedding is the sum of a projection of its features, a learned
    band embedding and a learned positional embedding. The encoded sequence is
    pooled as ``0.6 * attention pooling + 0.4 * masked mean``, layer-normalised,
    concatenated with a projection of the global features ``G`` and mapped to a
    single logit per sample.
    """

    def __init__(self, feat_dim, max_len, n_bands=6, d_model=160, n_heads=4, n_layers=3, ff_mult=2, dropout=0.12, g_dim=0):
        super().__init__()
        self.n_bands, self.d_model, self.max_len = n_bands, d_model, max_len

        # Sub-modules are created in a fixed order so seeded initialisation stays reproducible.
        self.x_proj = nn.Sequential(nn.Linear(feat_dim, d_model), nn.GELU(), nn.Dropout(dropout))
        self.band_emb = nn.Embedding(n_bands, d_model)

        # learned positional table, one slot per sequence position
        self.pos_emb = nn.Parameter(torch.zeros(1, max_len, d_model))
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=n_heads,
                dim_feedforward=int(d_model * ff_mult),
                dropout=dropout,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            ),
            num_layers=n_layers,
        )

        # scoring layer for attention pooling + normalisation of the pooled vector
        self.attn = nn.Linear(d_model, 1)
        self.pool_ln = nn.LayerNorm(d_model)

        g_hidden = max(32, d_model // 2)
        self.g_proj = nn.Sequential(nn.Linear(g_dim, g_hidden), nn.GELU(), nn.Dropout(dropout))

        self.head = nn.Sequential(
            nn.Linear(d_model + g_hidden, d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, 1),
        )

    def forward(self, X, band_id, mask, G):
        """Return one logit per sample.

        ``mask`` marks padding with 0. ``band_id`` is clamped into
        ``[0, n_bands - 1]``. Rows that are entirely padding get their first
        position un-masked so attention always has one key to attend to.
        """
        feats = X.to(torch.float32)
        bands = band_id.to(torch.long).clamp(0, self.n_bands - 1)

        is_pad = mask.to(torch.long) == 0  # True = padding
        fully_padded = is_pad.all(dim=1)
        if fully_padded.any():
            is_pad = is_pad.clone()
            is_pad[fully_padded, 0] = False

        seq_len = feats.shape[1]
        tokens = self.x_proj(feats) + self.band_emb(bands) + self.pos_emb[:, :seq_len, :]
        tokens = self.encoder(tokens, src_key_padding_mask=is_pad)

        # attention pooling: padded positions get a very negative score
        scores = self.attn(tokens).squeeze(-1).masked_fill(is_pad, -1e9)
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        attn_pool = torch.sum(tokens * weights, dim=1)

        # masked mean over non-padded positions
        keep = (~is_pad).to(tokens.dtype).unsqueeze(-1)
        mean_pool = (tokens * keep).sum(dim=1) / keep.sum(dim=1).clamp_min(1.0)

        fused = self.pool_ln(0.6 * attn_pool + 0.4 * mean_pool)
        side = self.g_proj(G.to(torch.float32))
        return self.head(torch.cat([fused, side], dim=1)).squeeze(-1)

# ----------------------------
# 6) Training config (CPU safe)
# ----------------------------
# Training hyper-parameters; the final dict is written to train_cfg_stage8.json.
CFG = {
    "d_model": 160,
    "n_heads": 4,
    "n_layers": 3,
    "ff_mult": 2,
    "dropout": 0.12,

    "batch_size": 16,
    "grad_accum": 2,

    "epochs": 14,
    "lr": 5e-4,
    "weight_decay": 0.02,

    "patience": 4,            # early stop by AUC
    "max_grad_norm": 1.0,

    # imbalance strategy: "sampler" | "pos_weight" | "both" | "none"
    "balance_mode": "sampler",

    "label_smoothing": 0.03,
    "scheduler": "onecycle",
}

# Long sequences (MAX_LEN >= 512): switch to a smaller model, smaller batch and
# lower learning rate.
if L >= 512:
    CFG["d_model"] = 128
    CFG["n_heads"] = 4
    CFG["n_layers"] = 2
    CFG["batch_size"] = 12
    CFG["grad_accum"] = 2
    CFG["lr"] = 4e-4

cfg_path = Path(LOG_DIR) / "train_cfg_stage8.json"
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(CFG, f, indent=2)

# Summary banner of the data, model and output locations for this run.
pos_all = int((y == 1).sum())
neg_all = int((y == 0).sum())
print("[Stage 8] TRAIN CONFIG (CPU)")
print(f"- N={N:,} | pos={pos_all:,} | neg={neg_all:,} | pos%={pos_all/max(N,1)*100:.6f}%")
print(f"- token_mode={SEQ_TOKEN_MODE} | val_feat={VAL_FEAT} | g_dim={g_dim} | use_agg_seq={USE_AGG_SEQ_FEATURES}")
print(f"- SHIFT_BAND_IDS(from stage6)={SHIFT_BAND_IDS} | PAD_BAND_ID(from stage6)={PAD_BAND_ID}")
print(f"- Model: d_model={CFG['d_model']} heads={CFG['n_heads']} layers={CFG['n_layers']} dropout={CFG['dropout']}")
print(f"- Batch={CFG['batch_size']} grad_accum={CFG['grad_accum']} epochs={CFG['epochs']} lr={CFG['lr']}")
print(f"- balance_mode={CFG['balance_mode']} | label_smoothing={CFG['label_smoothing']}")
print(f"- CKPT_DIR={CKPT_DIR}")
print(f"- OOF_DIR ={OOF_DIR}")
print(f"- LOG_DIR ={LOG_DIR}")

# ----------------------------
# 7) Helpers
# ----------------------------
def sigmoid_np(x):
    """Logistic function for NumPy inputs.

    The argument is clamped to [-50, 50] first so that np.exp cannot overflow.
    """
    clamped = np.clip(x, -50, 50)
    denom = 1.0 + np.exp(-clamped)
    return 1.0 / denom

def f1_binary(y_true, y_pred01):
    """F1 score of the positive class (label 1) for hard 0/1 predictions.

    Both arguments are NumPy arrays of equal shape. Returns a Python float;
    0.0 when there is no true positive.
    """
    truth = y_true.astype(np.int32)
    guess = y_pred01.astype(np.int32)

    truth_pos = truth == 1
    guess_pos = guess == 1

    tp = int(np.count_nonzero(truth_pos & guess_pos))
    fp = int(np.count_nonzero((truth == 0) & guess_pos))
    fn = int(np.count_nonzero(truth_pos & (guess == 0)))
    if tp == 0:
        return 0.0

    prec = tp / max(tp + fp, 1)
    rec = tp / max(tp + fn, 1)
    total = prec + rec
    if total == 0:
        return 0.0
    return float(2 * prec * rec / total)

@torch.no_grad()
def eval_model(model, loader, criterion):
    """Run `model` over `loader` in eval mode and summarise the predictions.

    Each batch is a 5-tuple (X, B, M, G, y) that is moved to the global `device`.

    Returns:
        (mean of per-batch losses, sigmoid probabilities, labels as int8,
         F1 at threshold 0.5, ROC-AUC or NaN when the labels hold a single class).
    """
    model.eval()
    batch_losses = []
    logit_chunks = []
    label_chunks = []

    for Xb, Bb, Mb, Gb, yb in loader:
        Xb, Bb, Mb, Gb, yb = (t.to(device) for t in (Xb, Bb, Mb, Gb, yb))
        out = model(Xb, Bb, Mb, Gb)
        batch_losses.append(float(criterion(out, yb).item()))
        logit_chunks.append(out.detach().cpu().numpy())
        label_chunks.append(yb.detach().cpu().numpy())

    if logit_chunks:
        logits_all = np.concatenate(logit_chunks, axis=0)
    else:
        logits_all = np.zeros((0,), dtype=np.float32)

    if label_chunks:
        y_all = np.concatenate(label_chunks, axis=0).astype(np.int8)
    else:
        y_all = np.zeros((0,), dtype=np.int8)

    probs = sigmoid_np(logits_all)
    f1 = f1_binary(y_all, (probs >= 0.5).astype(np.int8))

    # ROC-AUC is only computed when both classes are present.
    if len(np.unique(y_all)) == 2:
        auc = float(roc_auc_score(y_all, probs))
    else:
        auc = float("nan")

    mean_loss = float(np.mean(batch_losses) if batch_losses else np.nan)
    return mean_loss, probs, y_all, f1, auc

# fold-wise scaler (NO leakage)
def fit_scaler_fold(G_raw_np, tr_idx):
    """Per-column mean/std computed from the rows selected by `tr_idx` only.

    Columns whose std is below 1e-6 get std = 1.0, so dividing by it later leaves
    them unscaled. Both results are float32.
    """
    train_rows = G_raw_np[tr_idx]
    mean = train_rows.mean(axis=0).astype(np.float32)
    spread = train_rows.std(axis=0).astype(np.float32)
    nearly_constant = spread < 1e-6
    std = np.where(nearly_constant, 1.0, spread).astype(np.float32)
    return mean, std

def apply_scaler(G_raw_np, mean, std):
    """Standardise `G_raw_np` with the given `mean` and `std`; the result is float32."""
    centered = G_raw_np - mean
    scaled = centered / std
    return scaled.astype(np.float32)

# ----------------------------
# 8) CV Train
# ----------------------------
# Out-of-fold probabilities: each position is filled by the fold that validates on it.
oof_prob = np.zeros((N,), dtype=np.float32)
fold_metrics = []

all_idx = np.arange(N, dtype=np.int32)
n_splits = int(globals()["n_splits"])


def _is_improvement(val_auc, val_loss, best_auc, best_loss, eps=1e-6):
    """Return True when the current epoch beats the best epoch recorded for the fold.

    Ranking is by validation AUC; ties within `eps` are broken by lower validation loss.
    `best_auc` is NaN until an epoch has been recorded. A NaN `val_auc` (eval_model
    returns NaN when the validation labels hold a single class) is only comparable to
    a NaN best, in which case validation loss alone decides. Any defined AUC beats a
    NaN best.
    """
    if math.isnan(val_auc):
        return math.isnan(best_auc) and (val_loss < best_loss - eps)
    if math.isnan(best_auc):
        return True
    if val_auc > best_auc + eps:
        return True
    return (abs(val_auc - best_auc) <= eps) and (val_loss < best_loss - eps)


start_time = time.time()

for fold_info in globals()["folds"]:
    fold = int(fold_info["fold"])
    val_idx = np.asarray(fold_info["val_idx"], dtype=np.int32)

    # training rows = everything that is not in this fold's validation set
    val_mask = np.zeros(N, dtype=bool)
    val_mask[val_idx] = True
    tr_idx = all_idx[~val_mask]

    y_tr = y[tr_idx]
    pos = int((y_tr == 1).sum())
    neg = int((y_tr == 0).sum())
    if pos == 0:
        raise RuntimeError(f"[Stage 8] Fold {fold}: no positives in training split.")

    # imbalance knobs
    balance_mode = str(CFG.get("balance_mode", "sampler")).lower()
    use_sampler = balance_mode in ("sampler", "both")
    use_posw    = balance_mode in ("pos_weight", "both")

    # neg/pos ratio of the training split; used as the loss pos_weight and as the
    # sampling weight of positives, depending on balance_mode
    pos_weight = float(neg / max(pos, 1))
    pos_weight_t = torch.tensor([pos_weight], dtype=torch.float32, device=device)

    # label smoothing: pulls targets towards 0.5 by ls/2 (training batches only)
    ls = float(CFG["label_smoothing"])
    def smooth(yb):
        if ls <= 0:
            return yb
        return yb * (1.0 - ls) + 0.5 * ls

    # criterion
    if use_posw:
        criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_t)
    else:
        criterion = nn.BCEWithLogitsLoss()

    print(f"\n[Stage 8] FOLD {fold}/{n_splits-1} | train={len(tr_idx):,} val={len(val_idx):,} "
          f"| pos={pos:,} neg={neg:,} | pos_weight={pos_weight:.4f} | balance_mode={balance_mode}")

    # fold-wise scaler (NO leakage): statistics come from the training rows only
    g_mean, g_std = fit_scaler_fold(G_raw_np, tr_idx)
    G_fold_z = apply_scaler(G_raw_np, g_mean, g_std)

    # datasets
    ds_tr = MemmapSeqDataset(tr_idx, X_mm, B_mm, M_mm, G_fold_z, y=y, train_mode=True)
    ds_va = MemmapSeqDataset(val_idx, X_mm, B_mm, M_mm, G_fold_z, y=y, train_mode=False)

    # optional weighted sampler (train only): positives are drawn with weight neg/pos,
    # negatives with weight 1, len(tr_idx) draws with replacement
    sampler = None
    if use_sampler:
        w = np.ones((len(tr_idx),), dtype=np.float32)
        ytr_local = y[tr_idx]
        w[ytr_local == 1] = float(neg / max(pos, 1))
        w_t = torch.from_numpy(w)
        sampler = torch.utils.data.WeightedRandomSampler(weights=w_t, num_samples=len(tr_idx), replacement=True)

    dl_tr = make_loader(ds_tr, batch_size=int(CFG["batch_size"]), shuffle=True, sampler=sampler)
    dl_va = make_loader(ds_va, batch_size=int(CFG["batch_size"]), shuffle=False)

    model = MultibandEventTransformer(
        feat_dim=Fdim,
        max_len=L,
        n_bands=6,
        d_model=int(CFG["d_model"]),
        n_heads=int(CFG["n_heads"]),
        n_layers=int(CFG["n_layers"]),
        ff_mult=int(CFG["ff_mult"]),
        dropout=float(CFG["dropout"]),
        g_dim=g_dim,
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=float(CFG["lr"]), weight_decay=float(CFG["weight_decay"]))

    # scheduler: steps_per_epoch counts optimizer steps, i.e. ceil(batches / grad_accum),
    # which matches the accumulate-then-step logic (including the remainder step) below
    scheduler = None
    grad_accum = int(CFG["grad_accum"])
    if str(CFG.get("scheduler","")).lower() == "onecycle":
        steps_per_epoch_opt = int(math.ceil(len(dl_tr) / max(grad_accum, 1)))
        steps_per_epoch_opt = max(steps_per_epoch_opt, 1)
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            opt,
            max_lr=float(CFG["lr"]),
            epochs=int(CFG["epochs"]),
            steps_per_epoch=steps_per_epoch_opt,
            pct_start=0.1,
            anneal_strategy="cos",
            div_factor=10.0,
            final_div_factor=50.0,
        )

    # NaN means "no epoch recorded yet"; see _is_improvement for how NaN AUCs are ranked
    best_val_auc = float("nan")
    best_val_loss = float("inf")
    best_epoch = -1
    best_probs = None
    patience_left = int(CFG["patience"])

    for epoch in range(1, int(CFG["epochs"]) + 1):
        model.train()
        opt.zero_grad(set_to_none=True)

        total_loss = 0.0
        n_batches = 0
        accum = 0
        opt_steps = 0

        for batch in dl_tr:
            Xb, Bb, Mb, Gb, yb = batch
            Xb = Xb.to(device); Bb = Bb.to(device); Mb = Mb.to(device); Gb = Gb.to(device); yb = yb.to(device)

            yb_s = smooth(yb)

            logit = model(Xb, Bb, Mb, Gb)
            loss = criterion(logit, yb_s)

            total_loss += float(loss.item())
            n_batches += 1

            # scale so that the accumulated gradient is the mean over grad_accum batches
            (loss / float(grad_accum)).backward()
            accum += 1

            if accum == grad_accum:
                if CFG["max_grad_norm"] is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), float(CFG["max_grad_norm"]))
                opt.step()
                opt.zero_grad(set_to_none=True)
                opt_steps += 1
                accum = 0
                if scheduler is not None:
                    scheduler.step()

        # remainder step: flush gradients of a trailing, incomplete accumulation group
        if accum > 0:
            if CFG["max_grad_norm"] is not None:
                torch.nn.utils.clip_grad_norm_(model.parameters(), float(CFG["max_grad_norm"]))
            opt.step()
            opt.zero_grad(set_to_none=True)
            opt_steps += 1
            if scheduler is not None:
                scheduler.step()

        train_loss = total_loss / max(n_batches, 1)

        # validate (use NON-smoothed y)
        val_loss, probs, y_val, f1_05, val_auc = eval_model(model, dl_va, criterion)

        improved = _is_improvement(val_auc, val_loss, best_val_auc, best_val_loss)

        if improved:
            best_val_auc = float(val_auc)
            best_val_loss = float(val_loss)
            best_epoch = int(epoch)
            best_probs = probs.copy()

            # one checkpoint per fold, overwritten whenever a better epoch is found;
            # it carries everything inference needs to rebuild the model and the scaler
            ckpt_path = CKPT_DIR / f"fold_{fold}.pt"
            torch.save(
                {
                    "fold": fold,
                    "epoch": epoch,
                    "model_state": model.state_dict(),
                    "cfg": CFG,
                    "seq_feature_names": SEQ_FEATURE_NAMES,
                    "max_len": L,
                    "token_mode": SEQ_TOKEN_MODE,
                    "val_feat": VAL_FEAT,
                    "global_meta_cols": BASE_G_COLS,
                    "use_agg_seq_features": bool(USE_AGG_SEQ_FEATURES),
                    "global_scaler": {"mean": g_mean, "std": g_std},
                    "pos_weight_train": float(pos_weight),
                    "balance_mode": balance_mode,
                    "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
                    "pad_band_id_from_stage6": int(PAD_BAND_ID),
                },
                ckpt_path,
            )
            patience_left = int(CFG["patience"])
        else:
            patience_left -= 1

        lr_now = opt.param_groups[0]["lr"]
        print(f"  epoch {epoch:02d} | lr={lr_now:.2e} | opt_steps={opt_steps:4d} | "
              f"train_loss={train_loss:.5f} | val_loss={val_loss:.5f} | val_auc={val_auc:.5f} | f1@0.5={f1_05:.4f} | "
              f"best_ep={best_epoch} | pat={patience_left}")

        # early stop once `patience` consecutive epochs failed to improve
        if patience_left <= 0:
            break

    if best_probs is None:
        raise RuntimeError(f"Fold {fold}: best_probs is None (unexpected).")

    # fill OOF with the validation probabilities of the best epoch
    oof_prob[val_idx] = best_probs.astype(np.float32)

    pred01 = (best_probs >= 0.5).astype(np.int8)
    best_f1_05 = f1_binary(y[val_idx], pred01)

    fold_metrics.append({
        "fold": fold,
        "val_size": int(len(val_idx)),
        "best_epoch": int(best_epoch),
        "best_val_auc": float(best_val_auc),
        "best_val_loss": float(best_val_loss),
        "f1_at_0p5": float(best_f1_05),
        "pos_weight_train": float(pos_weight),
        "g_dim": int(g_dim),
        "use_agg_seq": bool(USE_AGG_SEQ_FEATURES),
        "balance_mode": balance_mode,
        "shift_band_ids_from_stage6": bool(SHIFT_BAND_IDS),
    })

    # release per-fold objects before the next fold is built
    del model, opt, ds_tr, ds_va, dl_tr, dl_va, G_fold_z
    gc.collect()

elapsed = time.time() - start_time

# ----------------------------
# 9) Save OOF artifacts + summary
# ----------------------------
# Output locations inside OOF_DIR.
oof_path_npy = OOF_DIR / "oof_prob.npy"
oof_path_csv = OOF_DIR / "oof_prob.csv"
metrics_path = OOF_DIR / "fold_metrics.json"

# Raw probability vector, in training-row order.
np.save(oof_path_npy, oof_prob)

# Same probabilities keyed by object_id, together with the label.
df_oof = pd.DataFrame({"object_id": train_ids, "target": y.astype(int), "oof_prob": oof_prob.astype(np.float32)})
df_oof.to_csv(oof_path_csv, index=False)

# Per-fold metrics plus total wall time of the CV loop.
with metrics_path.open("w", encoding="utf-8") as f:
    json.dump({"fold_metrics": fold_metrics, "elapsed_sec": float(elapsed)}, f, indent=2)

# Pooled OOF metrics at the default 0.5 threshold; AUC needs both classes to be present.
oof_pred01 = (oof_prob >= 0.5).astype(np.int8)
oof_f1_05 = f1_binary(y, oof_pred01)
if len(np.unique(y)) == 2:
    oof_auc = float(roc_auc_score(y, oof_prob))
else:
    oof_auc = float("nan")

for report_line in (
    "\n[Stage 8] CV TRAIN DONE",
    f"- elapsed: {elapsed/60:.2f} min",
    f"- OOF saved: {oof_path_npy}",
    f"- OOF saved: {oof_path_csv}",
    f"- fold metrics: {metrics_path}",
    f"- OOF AUC (rough): {oof_auc:.5f}",
    f"- OOF F1@0.5 (rough): {oof_f1_05:.4f}",
):
    print(report_line)

# Publish the artifact paths under the names this cell exports.
globals().update({
    "oof_prob": oof_prob,
    "OOF_PROB_PATH": oof_path_npy,
    "OOF_CSV_PATH": oof_path_csv,
    "FOLD_METRICS_PATH": metrics_path,
    "TRAIN_CFG_PATH": cfg_path,
})

gc.collect()


[Stage 8] Building AGG sequence features (one-time)...


  mean_val_b = np.nan_to_num(np.nanmean(vb, axis=1).astype(np.float32), nan=0.0)


[Stage 8] AGG built: shape=(3043, 31) | time=0.2s
[Stage 8] TRAIN CONFIG (CPU)
- N=3,043 | pos=148 | neg=2,895 | pos%=4.863621%
- token_mode=mag | val_feat=mag | g_dim=38 | use_agg_seq=True
- SHIFT_BAND_IDS(from stage6)=False | PAD_BAND_ID(from stage6)=0
- Model: d_model=160 heads=4 layers=3 dropout=0.12
- Batch=16 grad_accum=2 epochs=14 lr=0.0005
- balance_mode=sampler | label_smoothing=0.03
- CKPT_DIR=/kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/checkpoints
- OOF_DIR =/kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/oof
- LOG_DIR =/kaggle/working/mallorn_run/run_20260103_134051_143e1c6a76/logs

[Stage 8] FOLD 0/4 | train=2,434 val=609 | pos=118 neg=2,316 | pos_weight=19.6271 | balance_mode=sampler


  Xt = torch.from_numpy(X.astype(np.float32, copy=False))


  epoch 01 | lr=4.19e-04 | opt_steps=  77 | train_loss=0.68123 | val_loss=0.52205 | val_auc=0.68181 | f1@0.5=0.0615 | best_ep=1 | pat=4
  epoch 02 | lr=4.97e-04 | opt_steps=  77 | train_loss=0.61258 | val_loss=0.73800 | val_auc=0.70150 | f1@0.5=0.1620 | best_ep=2 | pat=4
  epoch 03 | lr=4.80e-04 | opt_steps=  77 | train_loss=0.54667 | val_loss=0.54998 | val_auc=0.75596 | f1@0.5=0.1756 | best_ep=3 | pat=4
  epoch 04 | lr=4.49e-04 | opt_steps=  77 | train_loss=0.48130 | val_loss=0.41873 | val_auc=0.81911 | f1@0.5=0.2460 | best_ep=4 | pat=4
  epoch 05 | lr=4.05e-04 | opt_steps=  77 | train_loss=0.43914 | val_loss=0.47571 | val_auc=0.83074 | f1@0.5=0.2212 | best_ep=5 | pat=4
  epoch 06 | lr=3.52e-04 | opt_steps=  77 | train_loss=0.40878 | val_loss=0.33423 | val_auc=0.84231 | f1@0.5=0.2714 | best_ep=6 | pat=4
  epoch 07 | lr=2.93e-04 | opt_steps=  77 | train_loss=0.39504 | val_loss=0.43668 | val_auc=0.84790 | f1@0.5=0.2434 | best_ep=7 | pat=4
  epoch 08 | lr=2.31e-04 | opt_steps=  77 | trai



  epoch 01 | lr=4.19e-04 | opt_steps=  77 | train_loss=0.67701 | val_loss=0.65374 | val_auc=0.69522 | f1@0.5=0.1548 | best_ep=1 | pat=4
  epoch 02 | lr=4.97e-04 | opt_steps=  77 | train_loss=0.59427 | val_loss=0.50821 | val_auc=0.69827 | f1@0.5=0.1545 | best_ep=2 | pat=4


# OOF Prediction + Threshold Tuning

In [None]:
# ============================================================
# STAGE 9 — OOF Prediction + Threshold Tuning (ONE CELL, Kaggle CPU-SAFE)
# FULL REVISION v3.1 (ALIGN SUPER ROBUST + MULTI-METRIC + FAST SWEEP)
#
# Upgrade v3.1 vs v3:
# - String-map df_train_meta.index -> row positions (anti dtype mismatch)
# - Prefer oof_prob.csv (object_id + oof_prob) for safest alignment
# - Clean oof_prob NaN/inf + clip [0,1]
# - FAST threshold sweep via sorting + cumulative counts (vectorized)
# - Best thresholds for: F1, Accuracy, Balanced Accuracy, MCC (+ Precision/Recall)
# - Exports multiple BEST thresholds
# ============================================================

import gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Names that must already exist in the kernel before this cell can run.
need = ["OOF_DIR", "df_train_meta"]
for k in need:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..8 dulu.")

# Normalise to a Path and make sure the directory exists (artifacts are written here below).
OOF_DIR = Path(OOF_DIR)
OOF_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Helper: robust stringify id
# ----------------------------
def _to_str_list(ids):
    out = []
    for x in ids:
        if isinstance(x, (bytes, np.bytes_)):
            out.append(x.decode("utf-8", errors="ignore").strip())
        else:
            out.append(str(x).strip())
    return out

def _as_1d_float32(arr):
    a = np.asarray(arr)
    if a.dtype == object and a.ndim == 0:
        try:
            a = np.asarray(a.item())
        except Exception:
            pass
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 0:
        return a
    if a.ndim > 1:
        a = a.reshape(-1)
    return a

# ----------------------------
# Detect target column in df_train_meta
# ----------------------------
def _detect_target_col(df):
    for cand in ["target","y","label","class","is_tde","binary_target","target_id"]:
        if cand in df.columns:
            return cand
    return None

# Resolve the label column once; everything below reads df_train_meta[TARGET_COL].
TARGET_COL = _detect_target_col(df_train_meta)
if TARGET_COL is None:
    # Show a sample of the available columns to make the mismatch easy to spot.
    raise RuntimeError(
        "Cannot detect target column in df_train_meta. "
        f"Columns sample: {list(df_train_meta.columns)[:60]}"
    )

# ----------------------------
# Load OOF (prefer CSV)
# ----------------------------
def _load_oof():
    """Locate the OOF probabilities and the ids they belong to.

    Sources are tried in this order:
      1. OOF_DIR/oof_prob.csv, when it has `object_id` and `oof_prob` columns;
      2. the kernel global `oof_prob`;
      3. OOF_DIR/oof_prob.npy.
    For 2 and 3 the ids come from the global `train_ids_ordered` when it exists,
    otherwise from `df_train_meta.index` when the lengths agree.

    Returns:
        (ids as list[str], probabilities as float32 array, source label).

    Raises:
        FileNotFoundError: when no source yields both ids and probabilities.
    """
    def _with_ids(prob, label_ordered, label_meta_index):
        # Pair `prob` with ids; None when no id source fits.
        if "train_ids_ordered" in globals():
            ids = _to_str_list(list(globals()["train_ids_ordered"]))
            return ids, prob, label_ordered
        if len(prob) == len(df_train_meta):
            ids = _to_str_list(df_train_meta.index.tolist())
            return ids, prob, label_meta_index
        return None

    # 1) CSV carries its own ids, so no positional assumption is needed.
    csv_path = OOF_DIR / "oof_prob.csv"
    if csv_path.exists():
        df = pd.read_csv(csv_path)
        has_both = ("object_id" in df.columns) and ("oof_prob" in df.columns)
        if has_both:
            ids = df["object_id"].astype(str).str.strip().tolist()
            prob = _as_1d_float32(df["oof_prob"].to_numpy())
            return ids, prob, "csv(oof_prob.csv)"

    # 2) In-memory vector left in the kernel.
    if "oof_prob" in globals():
        prob = _as_1d_float32(globals()["oof_prob"])
        if isinstance(prob, np.ndarray) and prob.ndim != 0:
            found = _with_ids(
                prob,
                "globals(oof_prob + train_ids_ordered)",
                "globals(oof_prob + df_train_meta.index)",
            )
            if found is not None:
                return found

    # 3) Saved .npy vector.
    npy_path = OOF_DIR / "oof_prob.npy"
    if npy_path.exists():
        prob = _as_1d_float32(np.load(npy_path, allow_pickle=False))
        found = _with_ids(
            prob,
            "npy(oof_prob.npy + train_ids_ordered)",
            "npy(oof_prob.npy + df_train_meta.index)",
        )
        if found is not None:
            return found

    raise FileNotFoundError("OOF prob not found (csv/globals/npy). Jalankan STAGE 8 dulu.")

# ids (as strings), probabilities and a label describing where they came from
train_ids, oof_prob, src = _load_oof()

# _as_1d_float32 can hand back a 0-d array; reject that (and any non-array) here
if not isinstance(oof_prob, np.ndarray) or oof_prob.ndim == 0:
    raise TypeError(f"Invalid oof_prob. Type={type(oof_prob)} ndim={getattr(oof_prob,'ndim',None)}")

# sanitize prob: NaN/-inf -> 0, +inf -> 1, then clamp into [0, 1]
oof_prob = np.nan_to_num(oof_prob, nan=0.0, posinf=1.0, neginf=0.0).astype(np.float32)
oof_prob = np.clip(oof_prob, 0.0, 1.0).astype(np.float32)

# ----------------------------
# SUPER ROBUST alignment: string-map df_train_meta.index -> positions
# ----------------------------
# Both sides are compared as stripped strings, so int-vs-str index dtypes cannot cause misses.
meta_ids = _to_str_list(df_train_meta.index.tolist())
pos_map = pd.Series(np.arange(len(meta_ids), dtype=np.int32), index=pd.Index(meta_ids, dtype="object"))

# Row position in df_train_meta for every OOF id; ids absent from the map come back as NaN.
pos_idx = pos_map.reindex(train_ids).to_numpy()
if np.isnan(pos_idx).any():
    bad = [train_ids[i] for i in np.where(np.isnan(pos_idx))[0][:10]]
    raise KeyError(f"OOF ids not found in df_train_meta.index (string-mapped). ex={bad} | missing_n={int(np.isnan(pos_idx).sum())}")

pos_idx = pos_idx.astype(np.int32)

# load y aligned: non-numeric / missing targets are coerced to 0, anything > 0 counts as positive
y_raw = pd.to_numeric(df_train_meta[TARGET_COL], errors="coerce").fillna(0).astype(np.int16).to_numpy()
y = y_raw[pos_idx]
y = (y > 0).astype(np.int8)

if len(oof_prob) != len(y):
    raise RuntimeError(f"Length mismatch: oof_prob={len(oof_prob)} vs y={len(y)}")

uy = set(np.unique(y).tolist())
if not uy.issubset({0, 1}):
    raise ValueError(f"y must be binary 0/1. Found: {sorted(list(uy))}")

# NOTE: N, pos, neg and y rebind kernel-global names that earlier cells also use.
N = int(len(y))
pos = int((y == 1).sum())
neg = int((y == 0).sum())

print(f"[Stage 9] Loaded OOF from: {src}")
print(f"[Stage 9] N={N:,} | pos={pos:,} | neg={neg:,} | pos%={pos/max(N,1)*100:.6f}% | target_col={TARGET_COL}")

# ----------------------------
# 1) Metrics helpers (vectorized-safe)
# ----------------------------
def _safe_div(a, b):
    return a / np.maximum(b, 1e-12)

def _metrics_from_counts(tp, fp, fn, tn):
    tp = tp.astype(np.float64)
    fp = fp.astype(np.float64)
    fn = fn.astype(np.float64)
    tn = tn.astype(np.float64)

    prec = _safe_div(tp, tp + fp)
    rec  = _safe_div(tp, tp + fn)
    f1   = _safe_div(2 * prec * rec, prec + rec)

    acc  = _safe_div(tp + tn, tp + fp + fn + tn)

    tpr  = _safe_div(tp, tp + fn)
    tnr  = _safe_div(tn, tn + fp)
    bacc = 0.5 * (tpr + tnr)

    num = tp * tn - fp * fn
    den = (tp + fp) * (tp + fn) * (tn + fp) * (tn + fn)
    mcc = np.where(den > 0, num / np.sqrt(den), 0.0)

    return f1, prec, rec, acc, bacc, mcc

# ----------------------------
# 2) Threshold candidates (grid + quantiles + unique-prob sampling)
# ----------------------------
# Fixed grid: finer spacing (0.0025) in the two tails, 0.005 in the middle.
grid = np.concatenate([
    np.linspace(0.00, 0.10, 41),
    np.linspace(0.10, 0.90, 161),
    np.linspace(0.90, 1.00, 41),
]).astype(np.float32)

# Data-driven candidates: 999 quantiles of the OOF probabilities.
qs = np.linspace(0.001, 0.999, 999, dtype=np.float32)
quant_thr = np.quantile(oof_prob, qs).astype(np.float32)

# Every distinct probability is a candidate; evenly subsampled to 6000 when there are more.
uniq = np.unique(oof_prob)
if len(uniq) > 6000:
    take = np.linspace(0, len(uniq) - 1, 6000, dtype=int)
    uniq = uniq[take].astype(np.float32)

# Union of the three sets, clamped to [0, 1], sorted and de-duplicated by np.unique.
thr_candidates = np.unique(np.clip(np.concatenate([grid, quant_thr, uniq]), 0.0, 1.0)).astype(np.float32)

# ----------------------------
# 3) FAST sweep via sorting + cumulative counts
#    Predict positive if oof_prob >= thr
# ----------------------------
# sort probabilities descending
ord_desc = np.argsort(-oof_prob)
p_sorted = oof_prob[ord_desc]
y_sorted = y[ord_desc].astype(np.int8)

# cumulative pos/neg for prefix k (k predicted positive)
pos_prefix = np.cumsum(y_sorted == 1).astype(np.int64)
neg_prefix = np.cumsum(y_sorted == 0).astype(np.int64)

pos_total = int(pos_prefix[-1]) if N > 0 else 0
neg_total = int(neg_prefix[-1]) if N > 0 else 0

def _count_at_or_above(thr):
    """Number of OOF probabilities that are >= thr (scalar or array of thresholds).

    -p_sorted is ascending, and prob >= thr is the same as -prob <= -thr, so the count
    is the insertion point to the RIGHT of any ties. side="left" would silently turn
    the rule into prob > thr and drop every sample whose probability equals the
    threshold, which always happens for candidates taken from the probabilities.
    Comparison is done in float32, the dtype of p_sorted.
    """
    return np.searchsorted(-p_sorted, -np.asarray(thr, dtype=np.float32), side="right")

# for each thr, k = number of items with prob >= thr
k = _count_at_or_above(thr_candidates).astype(np.int64)

# the first k sorted items are the predicted positives; prefix sums give TP/FP directly
tp = np.where(k > 0, pos_prefix[k - 1], 0).astype(np.int64)
fp = np.where(k > 0, neg_prefix[k - 1], 0).astype(np.int64)
fn = (pos_total - tp).astype(np.int64)
tn = (neg_total - fp).astype(np.int64)

# `mmc` holds the MCC values (variable name kept as in earlier revisions of this cell)
f1, prec, rec, acc, bacc, mmc = _metrics_from_counts(tp, fp, fn, tn)
pos_pred = k.astype(np.int64)

thr_table = pd.DataFrame({
    "thr": thr_candidates.astype(np.float32),
    "f1": f1.astype(np.float32),
    "precision": prec.astype(np.float32),
    "recall": rec.astype(np.float32),
    "accuracy": acc.astype(np.float32),
    "balanced_accuracy": bacc.astype(np.float32),
    "mcc": mmc.astype(np.float32),
    "tp": tp.astype(np.int64),
    "fp": fp.astype(np.int64),
    "fn": fn.astype(np.int64),
    "tn": tn.astype(np.int64),
    "pos_pred": pos_pred.astype(np.int64),
})

# ----------------------------
# 4) Pick best thresholds with tie-breakers
# ----------------------------
def _pick_best(df, primary, tie_cols):
    """Row of `df` with the highest `primary`, ties broken by `tie_cols` in order (all descending)."""
    sort_cols = [primary] + tie_cols
    asc = [False] * len(sort_cols)
    return df.sort_values(sort_cols, ascending=asc).iloc[0]

best_f1_row  = _pick_best(thr_table, "f1", ["recall", "precision", "balanced_accuracy"])
best_acc_row = _pick_best(thr_table, "accuracy", ["balanced_accuracy", "f1"])
best_bac_row = _pick_best(thr_table, "balanced_accuracy", ["accuracy", "f1"])
best_mcc_row = _pick_best(thr_table, "mcc", ["f1", "balanced_accuracy"])

def _row_to_dict(r):
    """Convert one thr_table row into a dict of plain Python floats/ints (JSON-serialisable)."""
    return {
        "thr": float(r["thr"]),
        "f1": float(r["f1"]),
        "precision": float(r["precision"]),
        "recall": float(r["recall"]),
        "accuracy": float(r["accuracy"]),
        "balanced_accuracy": float(r["balanced_accuracy"]),
        "mcc": float(r["mcc"]),
        "tp": int(r["tp"]), "fp": int(r["fp"]), "fn": int(r["fn"]), "tn": int(r["tn"]),
        "pos_pred": int(r["pos_pred"]),
    }

def _eval_at(thr):
    """Exact metrics at a single threshold, using the same prob >= thr rule as the sweep.

    Works on Python ints/floats (no float32 rounding of the metrics) and returns a
    JSON-serialisable dict with the threshold, the metrics and the confusion counts.
    """
    import math  # local import: this cell's import list does not include `math`

    thr = float(thr)
    k0 = int(_count_at_or_above(thr))
    tp0 = int(pos_prefix[k0 - 1]) if k0 > 0 else 0
    fp0 = int(neg_prefix[k0 - 1]) if k0 > 0 else 0
    fn0 = int(pos_total - tp0)
    tn0 = int(neg_total - fp0)

    prec0 = tp0 / max(tp0 + fp0, 1)
    rec0  = tp0 / max(tp0 + fn0, 1)
    f10 = 0.0 if (tp0 == 0 or (prec0 + rec0) == 0) else (2 * prec0 * rec0 / (prec0 + rec0))
    acc0 = (tp0 + tn0) / max(tp0 + fp0 + fn0 + tn0, 1)
    bacc0 = 0.5 * ((tp0 / max(tp0 + fn0, 1)) + (tn0 / max(tn0 + fp0, 1)))
    den0 = (tp0 + fp0) * (tp0 + fn0) * (tn0 + fp0) * (tn0 + fn0)
    mcc0 = 0.0 if den0 <= 0 else ((tp0 * tn0 - fp0 * fn0) / math.sqrt(den0))

    return {
        "thr": thr,
        "f1": float(f10),
        "precision": float(prec0),
        "recall": float(rec0),
        "accuracy": float(acc0),
        "balanced_accuracy": float(bacc0),
        "mcc": float(mcc0),
        "tp": tp0, "fp": fp0, "fn": fn0, "tn": tn0,
        "pos_pred": int(k0),
    }

# reference point: the default 0.5 cut-off
base05 = _eval_at(0.5)

BEST_THR_F1   = float(best_f1_row["thr"])
BEST_THR_ACC  = float(best_acc_row["thr"])
BEST_THR_BACC = float(best_bac_row["thr"])
BEST_THR_MCC  = float(best_mcc_row["thr"])

# re-evaluate the winners exactly (the table stores float32-rounded metrics)
best_f1_full  = _eval_at(BEST_THR_F1)
best_acc_full = _eval_at(BEST_THR_ACC)
best_bac_full = _eval_at(BEST_THR_BACC)
best_mcc_full = _eval_at(BEST_THR_MCC)

# ----------------------------
# 5) Save artifacts
# ----------------------------
out_json = OOF_DIR / "threshold_tuning.json"
out_txt  = OOF_DIR / "threshold_report.txt"
out_csv  = OOF_DIR / "threshold_table_top500.csv"

# Machine-readable summary: baseline at 0.5 plus the best threshold per metric.
payload = {
    "source": src,
    "target_col": TARGET_COL,
    "n_objects": int(N),
    "pos": int(pos),
    "neg": int(neg),
    "baseline_thr_0p5": base05,
    "best_thr_f1": best_f1_full,
    "best_thr_accuracy": best_acc_full,
    "best_thr_balanced_accuracy": best_bac_full,
    "best_thr_mcc": best_mcc_full,
}

with out_json.open("w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)

# One F1-ranked view of the table feeds both the CSV export and the top-10 listing.
thr_ranked_by_f1 = thr_table.sort_values(["f1","recall","precision"], ascending=[False, False, False])
thr_ranked_by_f1.head(500).to_csv(out_csv, index=False)
top_f1 = thr_ranked_by_f1.head(10).reset_index(drop=True)

# Human-readable report, assembled line by line.
lines = [
    "OOF Threshold Tuning Report (v3.1)",
    f"- source={src}",
    f"- target_col={TARGET_COL}",
    f"- N={N} | pos={pos} | neg={neg} | pos%={pos/max(N,1)*100:.6f}%",
    "",
    "Baseline @ thr=0.5",
    (f"- F1={base05['f1']:.6f} | P={base05['precision']:.6f} | R={base05['recall']:.6f} | "
     f"ACC={base05['accuracy']:.6f} | BACC={base05['balanced_accuracy']:.6f} | MCC={base05['mcc']:.6f}"),
    f"- tp={base05['tp']} fp={base05['fp']} fn={base05['fn']} tn={base05['tn']} | pos_pred={base05['pos_pred']}",
    "",
    f"BEST-F1   @ thr={best_f1_full['thr']:.6f} | F1={best_f1_full['f1']:.6f} | P={best_f1_full['precision']:.6f} | R={best_f1_full['recall']:.6f} | pos_pred={best_f1_full['pos_pred']}",
    f"BEST-ACC  @ thr={best_acc_full['thr']:.6f} | ACC={best_acc_full['accuracy']:.6f} | BACC={best_acc_full['balanced_accuracy']:.6f} | F1={best_acc_full['f1']:.6f}",
    f"BEST-BACC @ thr={best_bac_full['thr']:.6f} | BACC={best_bac_full['balanced_accuracy']:.6f} | ACC={best_bac_full['accuracy']:.6f} | F1={best_bac_full['f1']:.6f}",
    f"BEST-MCC  @ thr={best_mcc_full['thr']:.6f} | MCC={best_mcc_full['mcc']:.6f} | F1={best_mcc_full['f1']:.6f} | BACC={best_mcc_full['balanced_accuracy']:.6f}",
    "",
    "Top 10 by F1:",
]
for i in range(len(top_f1)):
    r = top_f1.iloc[i]
    lines.append(f"{i+1:02d}. thr={float(r['thr']):.6f} | f1={float(r['f1']):.6f} | P={float(r['precision']):.6f} | "
                 f"R={float(r['recall']):.6f} | ACC={float(r['accuracy']):.6f} | BACC={float(r['balanced_accuracy']):.6f} | "
                 f"MCC={float(r['mcc']):.6f} | pos_pred={int(r['pos_pred'])}")

with out_txt.open("w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")

# Console summary: artifact locations and the best threshold per metric.
print("[Stage 9] DONE")
print(f"- Saved: {out_json}")
print(f"- Saved: {out_txt}")
print(f"- Saved: {out_csv}")
print(f"- BEST_THR_F1  ={BEST_THR_F1:.6f} | F1={best_f1_full['f1']:.6f} (P={best_f1_full['precision']:.6f} R={best_f1_full['recall']:.6f})")
print(f"- BEST_THR_ACC ={BEST_THR_ACC:.6f} | ACC={best_acc_full['accuracy']:.6f} BACC={best_acc_full['balanced_accuracy']:.6f} F1={best_acc_full['f1']:.6f}")
print(f"- BEST_THR_BACC={BEST_THR_BACC:.6f} | BACC={best_bac_full['balanced_accuracy']:.6f} ACC={best_bac_full['accuracy']:.6f} F1={best_bac_full['f1']:.6f}")
print(f"- BEST_THR_MCC ={BEST_THR_MCC:.6f} | MCC={best_mcc_full['mcc']:.6f} F1={best_mcc_full['f1']:.6f} BACC={best_mcc_full['balanced_accuracy']:.6f}")

# Publish results under the names this cell exports to the kernel namespace.
globals().update({
    "train_ids_oof": train_ids,
    "oof_prob": oof_prob,
    "BEST_THR": BEST_THR_F1,          # the default threshold remains the F1-optimal one
    "BEST_THR_F1": BEST_THR_F1,
    "BEST_THR_ACC": BEST_THR_ACC,
    "BEST_THR_BACC": BEST_THR_BACC,
    "BEST_THR_MCC": BEST_THR_MCC,
    "thr_table": thr_table,
    "THR_JSON_PATH": out_json,
    "THR_REPORT_PATH": out_txt,
    "THR_TABLE_CSV_PATH": out_csv,
})

gc.collect()


# Test Inference (Fold Ensemble)

In [None]:
# ============================================================
# STAGE 10 — Test Inference (Fold Ensemble) (ONE CELL, Kaggle CPU-SAFE)
# FULL REVISION v3.2 (MATCH STAGE8 FORWARD + BUILD TEST AGG FEATS + LOGIT ENSEMBLE + ID ALIGN HARD)
#
# Main fixes vs. your v3.1:
# - Model forward pass is the same as in STAGE 8 (x_proj includes GELU, pooling mixes attn+mean)
# - Build G_test raw = meta_cols + agg_seq_feats (from the memmap), exactly as in STAGE 8
# - Align df_test_meta.index using a string map (avoids dtype mismatches)
# - Apply fold scaler dari ckpt (NO leakage, consistent)
# ============================================================

import os, gc, json, re, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", message="enable_nested_tensor is True.*")

# ----------------------------
# 0) Require previous stages
# ----------------------------
# Names that must already exist in the kernel before this cell can run.
need = ["ART_DIR","FIX_DIR","MAX_LEN","SEQ_FEATURE_NAMES","df_test_meta","CKPT_DIR","n_splits"]
for k in need:
    if k not in globals():
        raise RuntimeError(f"Missing `{k}`. Jalankan STAGE 0..9 dulu.")

# Torch
try:
    import torch
    import torch.nn as nn
except Exception as e:
    raise RuntimeError("PyTorch tidak tersedia di environment ini.") from e

# Inference runs on CPU; seed torch and NumPy (SEED falls back to 2025 when not defined).
device = torch.device("cpu")
SEED = int(globals().get("SEED", 2025))
torch.manual_seed(SEED)
np.random.seed(SEED)

# Thread guard (CPU): best effort, failures while setting thread counts are deliberately ignored
try:
    torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "2")))
    torch.set_num_interop_threads(1)
except Exception:
    pass

FIX_DIR = Path(FIX_DIR)
ART_DIR = Path(ART_DIR); ART_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR = Path(CKPT_DIR)

# Output directory of this stage.
OUT_DIR = ART_DIR / "preds"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# helper: normalize id robustly
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _load_ids_npy(path: Path):
    """Load an id array from a .npy file (pickle disabled) and return the ids normalised via _norm_id."""
    raw = np.load(path, allow_pickle=False)
    if hasattr(raw, "tolist"):
        items = raw.tolist()
    else:
        items = list(raw)
    return list(map(_norm_id, items))

# ----------------------------
# 1) Load TEST ordering (must match STAGE 6)
# ----------------------------
# Row order of the fixed-length test cache; the memmaps opened below are indexed in this order.
test_ids_path = FIX_DIR / "test_ids.npy"
if not test_ids_path.exists():
    raise FileNotFoundError(f"Missing: {test_ids_path}. Pastikan STAGE 6 sukses.")

test_ids = _load_ids_npy(test_ids_path)
NTE = len(test_ids)
if NTE <= 0:
    raise RuntimeError("test_ids kosong (NTE=0). Pastikan STAGE 6 sukses membuat test_ids.npy.")

# Align df_test_meta.index via string-map (HARD)
# A shallow copy gets a normalised string index, so the caller's frame keeps its own index.
df_test_meta = df_test_meta.copy(deep=False)
meta_ids = [_norm_id(z) for z in df_test_meta.index.tolist()]
df_test_meta.index = pd.Index(meta_ids, name=df_test_meta.index.name)

# Row position in df_test_meta for every test id; ids absent from the map come back as NaN.
pos_map = pd.Series(np.arange(len(meta_ids), dtype=np.int32), index=pd.Index(meta_ids, dtype="object"))
pos_idx = pos_map.reindex(test_ids).to_numpy()
if np.isnan(pos_idx).any():
    bad = [test_ids[i] for i in np.where(np.isnan(pos_idx))[0][:10]]
    raise KeyError(f"Some test_ids not found in df_test_meta.index (string-mapped). ex={bad} | missing_n={int(np.isnan(pos_idx).sum())}")
pos_idx = pos_idx.astype(np.int32)

# The ordering itself must not contain repeated ids.
if len(set(test_ids)) != len(test_ids):
    s = pd.Series(test_ids)
    dup = s[s.duplicated()].head(10).tolist()
    raise ValueError(f"Duplicate object_id in test_ids ordering (examples): {dup}")

# ----------------------------
# 2) Open fixed-length TEST memmaps
# ----------------------------
SEQ_FEATURE_NAMES = list(SEQ_FEATURE_NAMES)
Fdim = len(SEQ_FEATURE_NAMES)
L = int(MAX_LEN)

test_X_path = FIX_DIR / "test_X.dat"
test_B_path = FIX_DIR / "test_B.dat"
test_M_path = FIX_DIR / "test_M.dat"
for p in [test_X_path, test_B_path, test_M_path]:
    if not p.exists():
        raise FileNotFoundError(f"Missing fixed cache file: {p}. Pastikan STAGE 6 sukses.")

# Read-only views over the cache files. The shapes and dtypes given here must match what
# the files were written with (presumably by STAGE 6 - verify there).
Xte = np.memmap(test_X_path, dtype=np.float32, mode="r", shape=(NTE, L, Fdim))
Bte = np.memmap(test_B_path, dtype=np.int8,   mode="r", shape=(NTE, L))
Mte = np.memmap(test_M_path, dtype=np.int8,   mode="r", shape=(NTE, L))

# feature name -> column index in the last axis of Xte
feat = {n:i for i,n in enumerate(SEQ_FEATURE_NAMES)}

# detect token mode (same as STAGE 8)
if ("mag" in feat) and ("mag_err_log" in feat):
    SEQ_TOKEN_MODE = "mag"
    VAL_FEAT = "mag"
elif ("flux_asinh" in feat) and ("err_log1p" in feat):
    SEQ_TOKEN_MODE = "asinh"
    VAL_FEAT = "flux_asinh"
else:
    raise RuntimeError(f"Cannot infer token_mode from SEQ_FEATURE_NAMES. Found={SEQ_FEATURE_NAMES}")

# Features required regardless of token mode.
for k in ["snr_tanh","detected"]:
    if k not in feat:
        raise RuntimeError(f"Feature '{k}' not found in SEQ_FEATURE_NAMES.")
if VAL_FEAT not in feat:
    raise RuntimeError(f"Feature '{VAL_FEAT}' missing for token_mode={SEQ_TOKEN_MODE}.")

# ----------------------------
# 3) Checkpoints (fold_*.pt)
# ----------------------------
# Collect one checkpoint path per fold, in fold order; a missing file aborts immediately.
ckpts = []
for f in range(int(n_splits)):
    p = CKPT_DIR / f"fold_{f}.pt"
    if p.exists():
        ckpts.append(p)
        continue
    raise FileNotFoundError(f"Missing checkpoint: {p}. Pastikan STAGE 8 menyimpan ckpt per fold.")

# ----------------------------
# 4) Safe/compat checkpoint loader
# ----------------------------
def torch_load_compat(path: Path):
    """Load a checkpoint onto CPU, preferring the restricted ``weights_only`` unpickler.

    Strategy:
      1. Try ``torch.load(..., weights_only=True)`` and return whatever it yields.
         A successful restricted load is complete, so it is never re-read with
         the unrestricted unpickler (that would double the I/O and defeat the
         purpose of the safe path for plain state_dict checkpoints).
      2. If this torch version does not know the ``weights_only`` keyword
         (``TypeError``), fall back to the legacy call.
      3. If the restricted unpickler rejects the payload (any other exception,
         e.g. the checkpoint stores objects it does not allow), fall back to
         ``weights_only=False``.

    Args:
        path: Checkpoint file to read.

    Returns:
        The deserialized checkpoint object (a wrapper dict or a bare state_dict).

    Note:
        Step 3 executes arbitrary pickle code; only use it on checkpoints
        produced by this notebook.
    """
    try:
        return torch.load(path, map_location="cpu", weights_only=True)
    except TypeError:
        # Older torch: the keyword does not exist.
        return torch.load(path, map_location="cpu")
    except Exception:
        # Payload contains objects the restricted unpickler refuses.
        return torch.load(path, map_location="cpu", weights_only=False)

def extract_state_and_meta(ckpt_obj):
    """Split a loaded checkpoint into ``(state_dict, metadata)``.

    - Wrapper dict with a dict under ``"model_state"``: returns that dict and
      the wrapper itself as metadata.
    - Any other dict holding at least one value with a ``shape`` attribute is
      treated as a bare state_dict: returns it with empty metadata.
    - A dict without such values is returned as both state and metadata.

    Raises:
        RuntimeError: if ``ckpt_obj`` is not a dict.
    """
    if not isinstance(ckpt_obj, dict):
        raise RuntimeError(f"Unsupported ckpt object type: {type(ckpt_obj)}")

    inner = ckpt_obj.get("model_state")
    if isinstance(inner, dict):
        return inner, ckpt_obj

    for value in ckpt_obj.values():
        if hasattr(value, "shape"):
            # Looks like a raw state_dict: no metadata available.
            return ckpt_obj, {}
    return ckpt_obj, ckpt_obj

# ----------------------------
# 5) Infer architecture from state_dict
# ----------------------------
def infer_from_state(sd: dict):
    """Recover architecture hyper-parameters from the tensor shapes in a state_dict.

    Reads:
      - ``band_emb.weight``  -> (n_bands, d_model)
      - ``pos_emb``          -> axis 1 is the checkpoint's max sequence length
      - ``x_proj.0.weight`` or ``x_proj.weight`` -> axis 1 is the input feature dim
      - ``g_proj.0.weight``  -> (g_hidden, g_dim); when absent g_dim=0 and
        g_hidden defaults to max(32, d_model // 2)
      - ``encoder.layers.<i>.*`` -> n_layers = highest index + 1
      - ``*.linear1.weight`` -> axis 0 is the feed-forward width
      - ``pool_ln.weight`` / ``pool_ln.bias`` -> whether pooled LayerNorm exists
      - ``head.<i>.weight``  -> highest index is the final head layer index

    Returns:
        dict with keys n_bands, d_model, max_len_ckpt, feat_dim, g_dim,
        g_hidden, n_layers, dim_ff, has_pool_ln, xproj_is_seq, head_final_idx.

    Raises:
        RuntimeError: when a required entry is missing.
    """
    names = set(sd.keys())

    if "band_emb.weight" not in sd:
        raise RuntimeError("state_dict missing band_emb.weight.")
    band_shape = sd["band_emb.weight"].shape
    n_bands, d_model = int(band_shape[0]), int(band_shape[1])

    if "pos_emb" not in sd:
        raise RuntimeError("state_dict missing pos_emb.")
    max_len_ckpt = int(sd["pos_emb"].shape[1])

    # Input projection: either a Sequential (index 0 is the Linear) or a bare Linear.
    xproj_is_seq = "x_proj.0.weight" in names
    if xproj_is_seq:
        xproj_key = "x_proj.0.weight"
    elif "x_proj.weight" in sd:
        xproj_key = "x_proj.weight"
    else:
        raise RuntimeError("state_dict missing x_proj.weight or x_proj.0.weight.")
    feat_dim = int(sd[xproj_key].shape[1])

    # Global-feature projection (optional).
    if "g_proj.0.weight" in sd:
        g_shape = sd["g_proj.0.weight"].shape
        g_hidden, g_dim = int(g_shape[0]), int(g_shape[1])
    else:
        g_dim, g_hidden = 0, max(32, d_model // 2)

    # Encoder depth from the largest layer index present.
    layer_pat = re.compile(r"encoder\.layers\.(\d+)\.")
    layer_ids = {int(hit.group(1)) for hit in map(layer_pat.match, names) if hit}
    n_layers = (max(layer_ids) + 1) if layer_ids else 0
    if n_layers <= 0:
        raise RuntimeError("Cannot infer n_layers from state_dict (encoder.layers.* not found).")

    # Feed-forward width from layer 0 if present, otherwise the first linear1 in sorted order.
    first_lin1 = "encoder.layers.0.linear1.weight"
    if first_lin1 not in sd:
        candidates = [name for name in names if name.endswith("linear1.weight")]
        if not candidates:
            raise RuntimeError("Cannot infer dim_feedforward (linear1.weight not found).")
        first_lin1 = min(candidates)
    dim_ff = int(sd[first_lin1].shape[0])

    has_pool_ln = ("pool_ln.weight" in names) and ("pool_ln.bias" in names)

    # Index of the last weighted layer inside the head Sequential.
    head_pat = re.compile(r"head\.(\d+)\.weight")
    head_ids = [int(hit.group(1)) for hit in map(head_pat.match, names) if hit]
    if not head_ids:
        raise RuntimeError("Cannot infer head structure (head.*.weight not found).")
    head_final_idx = max(head_ids)

    return {
        "n_bands": n_bands,
        "d_model": d_model,
        "max_len_ckpt": max_len_ckpt,
        "feat_dim": feat_dim,
        "g_dim": g_dim,
        "g_hidden": g_hidden,
        "n_layers": n_layers,
        "dim_ff": dim_ff,
        "has_pool_ln": has_pool_ln,
        "xproj_is_seq": xproj_is_seq,
        "head_final_idx": head_final_idx,
    }

# ----------------------------
# 6) Build model that matches STAGE 8 forward
# ----------------------------
class FlexMultibandEventTransformer(nn.Module):
    """Transformer encoder over light-curve tokens plus a global-feature branch.

    Each token is the sum of a projected feature vector, a band embedding and a
    learned positional embedding. The encoded sequence is pooled as
    0.6 * attention pooling + 0.4 * masked mean, optionally layer-normalised,
    concatenated with the projected global features and fed to an MLP head
    that emits one logit per object.

    Args:
        feat_dim: Number of per-token input features.
        max_len: Length of the positional-embedding table.
        n_bands: Size of the band-embedding table.
        d_model: Model width.
        n_heads: Attention heads per encoder layer.
        n_layers: Number of encoder layers.
        dim_ff: Feed-forward width inside each encoder layer.
        dropout: Dropout probability used throughout.
        g_dim: Number of global input features.
        g_hidden: Width of the global-feature projection.
        xproj_is_seq: Accepted for interface compatibility; the input
            projection is a Linear+GELU+Dropout Sequential either way.
        has_pool_ln: Whether a LayerNorm is applied to the pooled vector.
        head_final_idx: Index of the last Linear in the head Sequential
            (3: Linear/GELU/Dropout/Linear, 2: Linear/GELU/Linear,
            anything else: Linear/Linear).
    """

    def __init__(self, feat_dim, max_len, n_bands, d_model, n_heads, n_layers, dim_ff, dropout,
                 g_dim, g_hidden, xproj_is_seq=True, has_pool_ln=True, head_final_idx=3):
        super().__init__()
        width = int(d_model)
        p_drop = float(dropout)

        self.n_bands = int(n_bands)
        self.max_len = int(max_len)
        self.d_model = width

        # IMPORTANT: match STAGE 8 (Linear + GELU + Dropout).
        # Both values of `xproj_is_seq` build the same Sequential.
        self.x_proj = nn.Sequential(
            nn.Linear(int(feat_dim), width),
            nn.GELU(),
            nn.Dropout(p_drop),
        )

        self.band_emb = nn.Embedding(int(n_bands), width)

        self.pos_emb = nn.Parameter(torch.zeros(1, int(max_len), width))
        nn.init.normal_(self.pos_emb, mean=0.0, std=0.02)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=width,
                nhead=int(n_heads),
                dim_feedforward=int(dim_ff),
                dropout=p_drop,
                activation="gelu",
                batch_first=True,
                norm_first=True,
            ),
            num_layers=int(n_layers),
        )

        # Scores one scalar per token for attention pooling.
        self.attn = nn.Linear(width, 1)

        self.has_pool_ln = bool(has_pool_ln)
        if self.has_pool_ln:
            self.pool_ln = nn.LayerNorm(width)

        self.g_proj = nn.Sequential(
            nn.Linear(int(g_dim), int(g_hidden)),
            nn.GELU(),
            nn.Dropout(p_drop),
        )

        # Head layout is chosen so the Sequential indices line up with the checkpoint keys.
        head_in = int(d_model + g_hidden)
        head_layers = [nn.Linear(head_in, width)]
        if head_final_idx in (2, 3):
            head_layers.append(nn.GELU())
        if head_final_idx == 3:
            head_layers.append(nn.Dropout(p_drop))
        head_layers.append(nn.Linear(width, 1))
        self.head = nn.Sequential(*head_layers)

    def forward(self, X, band_id, mask, G):
        """Return one logit per row.

        Args:
            X: Token features, (batch, seq, feat_dim).
            band_id: Band index per token, (batch, seq); clamped into [0, n_bands-1].
            mask: (batch, seq); entries equal to 0 are treated as padding.
            G: Global features, (batch, g_dim).
        """
        feats = X.to(torch.float32)
        bands = band_id.clamp(0, self.n_bands - 1).to(torch.long)

        is_pad = mask.to(torch.long) == 0  # True=pad
        empty_rows = is_pad.all(dim=1)
        if empty_rows.any():
            # A fully padded row would leave attention with no key; expose token 0.
            is_pad = is_pad.clone()
            is_pad[empty_rows, 0] = False

        seq_len = feats.shape[1]
        tokens = self.x_proj(feats) + self.band_emb(bands) + self.pos_emb[:, :seq_len, :]
        enc = self.encoder(tokens, src_key_padding_mask=is_pad)

        # Attention pooling: softmax over non-padded token scores.
        scores = self.attn(enc).squeeze(-1).masked_fill(is_pad, -1e9)
        weights = torch.softmax(scores, dim=1)
        pooled_attn = (enc * weights.unsqueeze(-1)).sum(dim=1)

        # Mean pooling over non-padded tokens.
        keep = (~is_pad).to(enc.dtype).unsqueeze(-1)
        n_keep = keep.sum(dim=1).clamp_min(1.0)
        pooled_mean = (enc * keep).sum(dim=1) / n_keep

        pooled = 0.6 * pooled_attn + 0.4 * pooled_mean  # match STAGE 8
        if self.has_pool_ln:
            pooled = self.pool_ln(pooled)

        g_feat = self.g_proj(G.to(torch.float32))
        joined = torch.cat([pooled, g_feat], dim=1)
        return self.head(joined).squeeze(-1)  # logit

def sigmoid_np(x):
    """Logistic function for NumPy inputs; the argument is clipped to [-50, 50] before exponentiation."""
    clipped = np.clip(x, -50, 50)
    denom = 1.0 + np.exp(-clipped)
    return 1.0 / denom

# ----------------------------
# 7) Build TEST global features (meta + agg seq feats) — match STAGE 8
# ----------------------------
# Metadata columns that form the first part of the global feature vector, in this fixed order.
BASE_G_COLS_DEFAULT = ["Z","Z_err","EBV","Z_missing","Z_err_missing","EBV_missing","is_photoz"]
# Band ids 0..N_BANDS-1 each get their own group of per-band aggregate features.
N_BANDS = 6

def _safe_div(a, b):
    return a / np.maximum(b, 1.0)

def build_agg_seq_features_memmap(Xmm, Bmm, Mmm, chunk=512):
    """
    Aggregate per-object sequence statistics from fixed-length arrays, chunk by chunk.

    Match STAGE 8 agg:
      glob: [tok_count, det_frac, mean_abs_snr, max_abs_snr] (4)
      val stats:
        mag  : [mean_mag, std_mag, min_mag] (3)
        asinh: [mean_abs_flux, std_flux, max_abs_flux] (3)
      per-band (b=0..N_BANDS-1): [cnt, det_frac_b, mean_abs_snr_b, mean_val_b] (4*N_BANDS)
    total agg_dim = 4 + 3 + 4*N_BANDS (31 for six bands)

    Args:
        Xmm: (N, L, F) token features; columns are looked up through the
            module-level ``feat`` mapping ("snr_tanh", "detected", VAL_FEAT).
        Bmm: (N, L) band id per token.
        Mmm: (N, L) mask; only entries equal to 1 count as real tokens.
        chunk: Number of rows materialised in memory at a time.

    Returns:
        float32 array of shape (N, agg_dim), where N is taken from ``Xmm``
        itself rather than from a global, so the function works for any input.

    Notes:
        Rows or bands without real tokens produce NaN inside the nan-aware
        reductions (NumPy emits RuntimeWarnings for those); the NaNs are
        replaced by 0.0 afterwards.
    """
    snr_i = feat["snr_tanh"]
    det_i = feat["detected"]
    val_i = feat[VAL_FEAT]

    n_rows = int(Xmm.shape[0])
    step = int(chunk)
    agg_dim = 4 + 3 + 4 * int(N_BANDS)

    out = np.zeros((n_rows, agg_dim), dtype=np.float32)

    for start in range(0, n_rows, step):
        end = min(n_rows, start + step)
        Xc = np.asarray(Xmm[start:end])  # (B,L,F)
        Bc = np.asarray(Bmm[start:end])  # (B,L)
        Mc = np.asarray(Mmm[start:end])  # (B,L)

        real = (Mc == 1)
        tok_count = real.sum(axis=1).astype(np.float32)

        snr = np.abs(Xc[:, :, snr_i]).astype(np.float32)
        det = (Xc[:, :, det_i] > 0.5).astype(np.float32)
        val = Xc[:, :, val_i].astype(np.float32)

        # Zero out padded positions so plain sums only see real tokens.
        snr_r = snr * real
        det_r = det * real

        det_frac = _safe_div(det_r.sum(axis=1), tok_count)
        mean_abs_snr = _safe_div(snr_r.sum(axis=1), tok_count)
        max_abs_snr = np.where(tok_count > 0, snr_r.max(axis=1), 0.0).astype(np.float32)

        if SEQ_TOKEN_MODE == "mag":
            # NaN-mask padding so the nan-aware reductions ignore it.
            val_r = np.where(real, val, np.nan)
            mean_val = np.nanmean(val_r, axis=1).astype(np.float32)
            std_val  = np.nanstd(val_r, axis=1).astype(np.float32)
            min_val  = np.nanmin(val_r, axis=1).astype(np.float32)
            mean_val = np.nan_to_num(mean_val, nan=0.0).astype(np.float32)
            std_val  = np.nan_to_num(std_val,  nan=0.0).astype(np.float32)
            min_val  = np.nan_to_num(min_val,  nan=0.0).astype(np.float32)
            global_val_feats = np.stack([mean_val, std_val, min_val], axis=1)
        else:
            aval = np.abs(val)
            aval_r = aval * real
            mean_aval = _safe_div(aval_r.sum(axis=1), tok_count)
            val_r = np.where(real, val, np.nan)
            std_val = np.nanstd(val_r, axis=1).astype(np.float32)
            max_aval = np.where(tok_count > 0, aval_r.max(axis=1), 0.0).astype(np.float32)
            std_val = np.nan_to_num(std_val, nan=0.0).astype(np.float32)
            global_val_feats = np.stack([mean_aval.astype(np.float32), std_val, max_aval], axis=1)

        per_band = []
        for b in range(int(N_BANDS)):
            bm = (Bc == b) & real
            cnt = bm.sum(axis=1).astype(np.float32)

            detb = (det * bm).sum(axis=1).astype(np.float32)
            snrb = (snr * bm).sum(axis=1).astype(np.float32)

            det_frac_b = _safe_div(detb, cnt)
            mean_abs_snr_b = _safe_div(snrb, cnt)

            if SEQ_TOKEN_MODE == "mag":
                val_b = np.where(bm, val, np.nan)
                mean_val_b = np.nanmean(val_b, axis=1).astype(np.float32)
                mean_val_b = np.nan_to_num(mean_val_b, nan=0.0).astype(np.float32)
            else:
                aval_b = np.abs(val) * bm
                mean_val_b = _safe_div(aval_b.sum(axis=1).astype(np.float32), cnt)

            per_band.append(np.stack([cnt, det_frac_b, mean_abs_snr_b, mean_val_b], axis=1))

        per_band = np.concatenate(per_band, axis=1).astype(np.float32)

        glob = np.stack([tok_count, det_frac, mean_abs_snr, max_abs_snr], axis=1).astype(np.float32)
        agg = np.concatenate([glob, global_val_feats.astype(np.float32), per_band], axis=1).astype(np.float32)

        out[start:end] = agg

        del Xc, Bc, Mc
        # Periodically release the materialised chunks.
        if (start // step) % 4 == 0:
            gc.collect()

    return out

print(f"[Stage 10] Build TEST global features: token_mode={SEQ_TOKEN_MODE} | val_feat={VAL_FEAT}")

# meta part (always same order)
# Any metadata column missing from df_test_meta is created in place and filled with 0.0.
BASE_G_COLS = BASE_G_COLS_DEFAULT
for c in BASE_G_COLS:
    if c not in df_test_meta.columns:
        df_test_meta[c] = 0.0

# Reorder the metadata rows to the test_ids ordering (pos_idx), then coerce every
# column to float32; values that cannot be parsed as numbers become 0.0.
G_meta = df_test_meta.iloc[pos_idx][BASE_G_COLS].copy()
for c in BASE_G_COLS:
    G_meta[c] = pd.to_numeric(G_meta[c], errors="coerce").fillna(0.0).astype(np.float32)
G_meta_np = G_meta.to_numpy(dtype=np.float32, copy=False)

# agg part (same as STAGE 8) — compute ONCE
print("[Stage 10] Building AGG seq features for TEST (one-time)...")
t0 = time.time()
G_seq_np = build_agg_seq_features_memmap(Xte, Bte, Mte, chunk=512)
print(f"[Stage 10] AGG built: shape={G_seq_np.shape} | time={time.time()-t0:.1f}s")

# raw global
# Column layout: metadata columns first, then the aggregated sequence features.
G_raw_default = np.concatenate([G_meta_np, G_seq_np], axis=1).astype(np.float32)  # (NTE, 38)

# ----------------------------
# 8) Inference per fold (logit ensemble)
# ----------------------------
@torch.no_grad()
def predict_logits_batchwise(model, Xmm, Bmm, Mmm, G_raw, mean=None, std=None, batch_size=64):
    """Run the model over all rows in mini-batches and collect the logits.

    Args:
        model: Module called as ``model(X, band, mask, G)``; switched to eval mode here.
        Xmm, Bmm, Mmm: Array-likes sliced along axis 0 (features, band ids, mask).
        G_raw: (N, g_dim) global features, row-aligned with ``Xmm``.
        mean, std: Optional standardisation vectors; applied as (G - mean) / std
            only when both are given.
        batch_size: Rows per forward pass.

    Returns:
        float32 array of shape (N,) with one logit per row.
    """
    model.eval()
    logits_out = np.zeros((Xmm.shape[0],), dtype=np.float32)
    n_rows = int(Xmm.shape[0])
    step = int(batch_size)
    standardize = (mean is not None) and (std is not None)

    for lo in range(0, n_rows, step):
        hi = min(n_rows, lo + step)

        # Materialise the slice and convert to the dtypes the model expects.
        x_t = torch.from_numpy(np.asarray(Xmm[lo:hi]).astype(np.float32, copy=False))
        b_t = torch.from_numpy(np.asarray(Bmm[lo:hi]).astype(np.int64, copy=False))
        m_t = torch.from_numpy(np.asarray(Mmm[lo:hi]).astype(np.int64, copy=False))

        g_np = G_raw[lo:hi]
        if standardize:
            g_np = ((g_np - mean) / std).astype(np.float32, copy=False)
        g_t = torch.from_numpy(g_np.astype(np.float32, copy=False))

        batch_logit = model(x_t.to(device), b_t.to(device), m_t.to(device), g_t.to(device))
        logits_out[lo:hi] = batch_logit.detach().cpu().numpy().astype(np.float32, copy=False)

    return logits_out

BATCH_SIZE = 64
# One column of logits per fold; rows follow the test_ids ordering.
test_logit_folds = np.zeros((NTE, int(n_splits)), dtype=np.float32)

print(f"[Stage 10] Test inference: N_test={NTE:,} | folds={n_splits} | batch={BATCH_SIZE} | ensemble=mean_logits")

# Architecture inferred from the first fold, recorded later in the inference config.
arch_used = None

for fold, ckpt_path in enumerate(ckpts):
    ckpt_obj = torch_load_compat(ckpt_path)
    sd, meta = extract_state_and_meta(ckpt_obj)

    arch = infer_from_state(sd)
    if arch_used is None:
        arch_used = dict(arch)

    cfg = meta.get("cfg", {}) if isinstance(meta, dict) else {}
    dropout = float(cfg.get("dropout", 0.0)) if isinstance(cfg, dict) else 0.0

    # n_heads: use cfg if valid, else pick the largest of 8/4/2/1 that divides d_model.
    # The head count cannot be recovered from weight shapes, so a guess is reported loudly:
    # a wrong value still loads but changes the attention computation.
    n_heads = int(cfg.get("n_heads", 0)) if isinstance(cfg, dict) else 0
    if n_heads <= 0 or (arch["d_model"] % n_heads != 0):
        for h in (8, 4, 2, 1):
            if arch["d_model"] % h == 0:
                n_heads = h
                break
        print(f"[WARN] fold {fold}: no usable n_heads in checkpoint cfg; guessed n_heads={n_heads} "
              f"for d_model={arch['d_model']}. Predictions are only valid if this matches training.")

    # HARD schema checks
    if arch["feat_dim"] != Fdim:
        raise RuntimeError(
            f"Fold {fold}: feature_dim mismatch.\n"
            f"- ckpt expects feat_dim={arch['feat_dim']}\n"
            f"- memmap has Fdim={Fdim}\n"
            "Solusi: pastikan STAGE 6 feature list sama saat training ckpt dibuat."
        )
    if arch["max_len_ckpt"] != L:
        raise RuntimeError(
            f"Fold {fold}: max_len mismatch.\n"
            f"- ckpt max_len={arch['max_len_ckpt']}\n"
            f"- memmap MAX_LEN={L}\n"
        )

    # Decide G_raw to match ckpt g_dim
    g_dim = int(arch["g_dim"])
    if g_dim <= 0:
        G_raw = np.zeros((NTE, 0), dtype=np.float32)
        g_mean = None
        g_std = None
    else:
        # If ckpt expects the default width (7 meta + 31 agg), use it as-is.
        # Otherwise crop/pad default to match g_dim (best-effort) and say so,
        # because the column meaning may no longer match what the model was trained on.
        default_dim = int(G_raw_default.shape[1])
        if default_dim == g_dim:
            G_raw = G_raw_default
        elif default_dim > g_dim:
            print(f"[WARN] fold {fold}: ckpt g_dim={g_dim} < built global dim={default_dim}; "
                  "cropping trailing columns (best-effort).")
            G_raw = G_raw_default[:, :g_dim].copy()
        else:
            print(f"[WARN] fold {fold}: ckpt g_dim={g_dim} > built global dim={default_dim}; "
                  "zero-padding missing columns (best-effort).")
            pad = np.zeros((NTE, g_dim - default_dim), dtype=np.float32)
            G_raw = np.concatenate([G_raw_default, pad], axis=1).astype(np.float32)

        # Standardisation statistics saved with the checkpoint (if any).
        g_mean = None
        g_std = None
        scaler = meta.get("global_scaler", None) if isinstance(meta, dict) else None
        if scaler is not None and isinstance(scaler, dict) and ("mean" in scaler) and ("std" in scaler):
            cand_mean = np.asarray(scaler["mean"], dtype=np.float32).reshape(-1)
            cand_std  = np.asarray(scaler["std"],  dtype=np.float32).reshape(-1)
            if cand_mean.shape[0] != g_dim or cand_std.shape[0] != g_dim:
                print(f"[WARN] fold {fold}: global_scaler size (mean={cand_mean.shape[0]}, std={cand_std.shape[0]}) "
                      f"!= g_dim={g_dim}; global features are fed UNSCALED.")
            else:
                g_mean = cand_mean
                # Guard against division by (near) zero for constant columns.
                g_std = np.where(cand_std < 1e-6, 1.0, cand_std).astype(np.float32)
        else:
            print(f"[WARN] fold {fold}: checkpoint has no usable global_scaler; global features are fed UNSCALED.")

    model = FlexMultibandEventTransformer(
        feat_dim=arch["feat_dim"],
        max_len=arch["max_len_ckpt"],
        n_bands=arch["n_bands"],
        d_model=arch["d_model"],
        n_heads=n_heads,
        n_layers=arch["n_layers"],
        dim_ff=arch["dim_ff"],
        dropout=dropout,
        g_dim=g_dim,
        g_hidden=arch["g_hidden"],
        xproj_is_seq=True,                 # match Stage 8 forward
        has_pool_ln=arch["has_pool_ln"],
        head_final_idx=arch["head_final_idx"],
    ).to(device)

    # strict=True: any key or shape mismatch between the rebuilt model and the checkpoint is an error.
    model.load_state_dict(sd, strict=True)

    logits = predict_logits_batchwise(
        model, Xte, Bte, Mte, G_raw, mean=g_mean, std=g_std, batch_size=BATCH_SIZE
    )
    test_logit_folds[:, fold] = logits

    probs_tmp = sigmoid_np(logits)
    print(f"  fold {fold}: d_model={arch['d_model']} n_heads={n_heads} g_dim={g_dim} | "
          f"logit_mean={float(logits.mean()):.6f} | prob_mean={float(probs_tmp.mean()):.6f} | prob_std={float(probs_tmp.std()):.6f}")

    del model, logits, probs_tmp
    gc.collect()

# ensemble on logits
# Average the per-fold logits first, then squash; per-fold probabilities are kept for inspection.
test_logit_ens = test_logit_folds.mean(axis=1).astype(np.float32)
test_prob_folds = sigmoid_np(test_logit_folds).astype(np.float32)
test_prob_ens   = sigmoid_np(test_logit_ens).astype(np.float32)

# ----------------------------
# 9) Save artifacts
# ----------------------------
# Output locations for the inference artifacts (all under OUT_DIR).
logit_fold_path = OUT_DIR / "test_logit_folds.npy"
logit_ens_path  = OUT_DIR / "test_logit_ens.npy"
prob_fold_path  = OUT_DIR / "test_prob_folds.npy"
prob_ens_path   = OUT_DIR / "test_prob_ens.npy"
csv_path        = OUT_DIR / "test_prob_ens.csv"
cfg_path        = OUT_DIR / "test_infer_config.json"

np.save(logit_fold_path, test_logit_folds)
np.save(logit_ens_path,  test_logit_ens)
np.save(prob_fold_path,  test_prob_folds)
np.save(prob_ens_path,   test_prob_ens)

# CSV pairs each object_id with its ensembled probability, in test_ids order.
pd.DataFrame({"object_id": test_ids, "prob": test_prob_ens}).to_csv(csv_path, index=False)

# Record how the predictions were produced, including where each artifact was written.
infer_cfg = {
    "seed": int(SEED),
    "n_splits": int(n_splits),
    "ensemble": "mean_logits_then_sigmoid",
    "batch_size": int(BATCH_SIZE),
    "max_len": int(L),
    "feature_dim": int(Fdim),
    "token_mode": SEQ_TOKEN_MODE,
    "val_feat": VAL_FEAT,
    "global_meta_cols": BASE_G_COLS_DEFAULT,
    "global_agg_dim": 31,
    "global_default_dim": int(G_raw_default.shape[1]),
    "ckpt_dir": str(CKPT_DIR),
    "ckpts": [str(p) for p in ckpts],
    "arch_inferred_from_first_fold": arch_used,
    "outputs": {
        "test_logit_folds": str(logit_fold_path),
        "test_logit_ens": str(logit_ens_path),
        "test_prob_folds": str(prob_fold_path),
        "test_prob_ens": str(prob_ens_path),
        "test_prob_ens_csv": str(csv_path),
    }
}
with open(cfg_path, "w", encoding="utf-8") as f:
    json.dump(infer_cfg, f, indent=2)

print("\n[Stage 10] DONE")
print(f"- Saved logits folds: {logit_fold_path}")
print(f"- Saved logits ens  : {logit_ens_path}")
print(f"- Saved probs folds : {prob_fold_path}")
print(f"- Saved probs ens   : {prob_ens_path}")
print(f"- Saved csv         : {csv_path}")
print(f"- Saved config      : {cfg_path}")
print(f"- ens prob mean={float(test_prob_ens.mean()):.6f} | std={float(test_prob_ens.std()):.6f} | "
      f"min={float(test_prob_ens.min()):.6f} | max={float(test_prob_ens.max()):.6f}")

# Export globals for submission stage
# Publishes the arrays and the artifact paths under stable names for later cells.
globals().update({
    "test_ids": test_ids,
    "test_logit_folds": test_logit_folds,
    "test_logit_ens": test_logit_ens,
    "test_prob_folds": test_prob_folds,
    "test_prob_ens": test_prob_ens,
    "TEST_LOGIT_FOLDS_PATH": logit_fold_path,
    "TEST_LOGIT_ENS_PATH": logit_ens_path,
    "TEST_PROB_FOLDS_PATH": prob_fold_path,
    "TEST_PROB_ENS_PATH": prob_ens_path,
    "TEST_PROB_CSV_PATH": csv_path,
    "TEST_INFER_CFG_PATH": cfg_path,
})

gc.collect()


# Evaluation

In [None]:
# ============================================================
# ONE CELL — EVALUATION (Precision / Recall / F1) + Threshold Sweep (OOF)
# REVISI FULL v3 — Robust target detect + robust align + handle dup ids + better thr candidates
#
# Input minimal:
# - df_train_meta (index: object_id, kolom target: target/y/label/class/is_tde/binary_target)
# - oof_prob (globals) ATAU file OOF_DIR/oof_prob.npy ATAU OOF_DIR/oof_prob.csv
#
# Output:
# - Print ringkasan metrik
# - Save: eval_report.txt + eval_threshold_table.csv + eval_summary.json
# - Export globals: BEST_THR_F1, BEST_THR_F05, BEST_THR_F2, thr_table_eval
# ============================================================

import gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Minimal requirements
# ----------------------------
# The training metadata must already exist in the notebook namespace.
if "df_train_meta" not in globals():
    raise RuntimeError("Missing df_train_meta. Jalankan stage meta dulu.")

# Reuse ART_DIR / OOF_DIR from earlier cells when defined; otherwise fall back to
# /kaggle/working and its "oof" subfolder. The OOF folder is created if absent.
ART_DIR = Path(globals().get("ART_DIR", "/kaggle/working"))
OOF_DIR = Path(globals().get("OOF_DIR", ART_DIR / "oof"))
OOF_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Utils: id normalize + robust 1D float32
# ----------------------------
def _norm_id(x):
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    s = str(x).strip()
    if (s.startswith("b'") and s.endswith("'")) or (s.startswith('b"') and s.endswith('"')):
        s = s[2:-1]
    return s.strip()

def _as_1d_float32(arr):
    a = np.asarray(arr)
    if a.dtype == object and a.ndim == 0:
        try:
            a = np.asarray(a.item())
        except Exception:
            pass
    a = np.asarray(a, dtype=np.float32)
    if a.ndim == 0:
        return a
    if a.ndim > 1:
        a = a.reshape(-1)
    return a

def _sanitize_prob(p):
    p = np.asarray(p, dtype=np.float32)
    p = np.nan_to_num(p, nan=0.0, posinf=1.0, neginf=0.0)
    p = np.clip(p, 0.0, 1.0)
    return p.astype(np.float32)

# ensure meta index normalized
# Work on a shallow copy so the re-indexing below does not touch the original frame object,
# then normalise every index label with _norm_id (the index name is kept).
df_train_meta = df_train_meta.copy(deep=False)
df_train_meta.index = pd.Index([_norm_id(z) for z in df_train_meta.index], name=df_train_meta.index.name)

# ----------------------------
# 0b) Detect target column (robust)
# ----------------------------
def _detect_target_col(df):
    for cand in ["target","y","label","class","is_tde","binary_target"]:
        if cand in df.columns:
            return cand
    return None

# Resolve the label column once; abort with a sample of the available columns if none matches.
TARGET_COL = _detect_target_col(df_train_meta)
if TARGET_COL is None:
    raise RuntimeError(
        "Cannot detect target column in df_train_meta. "
        f"Columns sample: {list(df_train_meta.columns)[:60]}"
    )

def _get_y_aligned(ids):
    yy = pd.to_numeric(df_train_meta.loc[ids, TARGET_COL], errors="coerce").fillna(0).to_numpy()
    yy = (yy.astype(np.float32) > 0).astype(np.int8)
    return yy

# ----------------------------
# 1) Load oof_prob (prefer csv for safest alignment)
# ----------------------------
def load_oof():
    """Locate the out-of-fold probabilities.

    Sources are tried in this order:
      1. ``OOF_DIR/oof_prob.csv`` when it has ``object_id`` and ``oof_prob``
         columns (carries ids, so alignment is explicit);
      2. a non-scalar ``oof_prob`` already present in the notebook globals;
      3. ``OOF_DIR/oof_prob.npy``.

    Returns:
        ``(prob, frame_or_None, source)`` where ``prob`` is a sanitized float32
        array, ``frame_or_None`` is the id/probability table for the csv source
        (None otherwise) and ``source`` is "csv", "globals" or "npy".

    Raises:
        RuntimeError: if the csv probabilities change length while parsing.
        FileNotFoundError: if no source is available.
    """
    csv_file = OOF_DIR / "oof_prob.csv"
    if csv_file.exists():
        table = pd.read_csv(csv_file)
        if {"object_id", "oof_prob"}.issubset(table.columns):
            table["object_id"] = table["object_id"].apply(_norm_id)
            probs = _sanitize_prob(_as_1d_float32(table["oof_prob"].to_numpy()))
            if len(probs) != len(table):
                raise RuntimeError("oof_prob.csv: length mismatch after parsing.")
            # Keep only the id column plus the cleaned probabilities.
            table = table[["object_id"]].copy()
            table["oof_prob"] = probs
            return probs, table, "csv"

    if "oof_prob" in globals():
        probs = _as_1d_float32(globals()["oof_prob"])
        if isinstance(probs, np.ndarray) and probs.ndim != 0:
            return _sanitize_prob(probs), None, "globals"

    npy_file = OOF_DIR / "oof_prob.npy"
    if npy_file.exists():
        raw = np.load(npy_file, allow_pickle=False)
        return _sanitize_prob(_as_1d_float32(raw)), None, "npy"

    raise FileNotFoundError("OOF prob tidak ditemukan (oof_prob.csv / globals oof_prob / oof_prob.npy).")

# df_oof_csv is only non-None when the probabilities came from the csv source.
oof_prob, df_oof_csv, oof_src = load_oof()
# Reject scalars: everything below needs a sized array.
if not isinstance(oof_prob, np.ndarray) or oof_prob.ndim == 0:
    raise TypeError(f"Invalid oof_prob (scalar/unsized). type={type(oof_prob)} ndim={getattr(oof_prob,'ndim',None)}")

# ----------------------------
# 2) Align y (target) to oof order
# ----------------------------
# Both stay None until one of the alignment strategies below succeeds.
train_ids = None
y = None

# Strategy 1: the csv carries object ids, so labels can be looked up per row.
if df_oof_csv is not None:
    # Handle duplicate ids: mean per object_id but keep first-seen order
    ids_first = pd.unique(df_oof_csv["object_id"].to_numpy())
    if len(ids_first) != len(df_oof_csv):
        df_mean = df_oof_csv.groupby("object_id", as_index=True)["oof_prob"].mean()
        df_oof_csv = pd.DataFrame({"object_id": ids_first})
        df_oof_csv["oof_prob"] = df_mean.reindex(ids_first).to_numpy(dtype=np.float32)
        oof_prob = _sanitize_prob(df_oof_csv["oof_prob"].to_numpy())

    train_ids = df_oof_csv["object_id"].tolist()

    # drop ids not in meta
    # Rows without metadata cannot be labelled; they are reported (up to 10 examples) and removed.
    ok = np.asarray([oid in df_train_meta.index for oid in train_ids], dtype=bool)
    if not ok.all():
        bad = [train_ids[i] for i in np.where(~ok)[0][:10]]
        print(f"[WARN] oof_prob.csv contains ids not in df_train_meta: missing_n={int((~ok).sum())} examples={bad}")
        df_oof_csv = df_oof_csv.loc[ok].reset_index(drop=True)
        train_ids = df_oof_csv["object_id"].tolist()
        oof_prob = _sanitize_prob(df_oof_csv["oof_prob"].to_numpy())

    y = _get_y_aligned(train_ids)

# Strategy 2: use the explicit training order exported by an earlier stage, but only
# when its length matches the probabilities.
if y is None and ("train_ids_ordered" in globals()):
    ids = [_norm_id(z) for z in list(globals()["train_ids_ordered"])]
    if len(ids) == len(oof_prob):
        missing = [oid for oid in ids if oid not in df_train_meta.index]
        if missing:
            raise KeyError(f"train_ids_ordered contains ids not in df_train_meta. ex={missing[:10]} missing_n={len(missing)}")
        train_ids = ids
        y = _get_y_aligned(train_ids)

# Strategy 3: assume the probabilities follow the metadata row order; only possible
# when the lengths agree.
# NOTE(review): this is an ordering assumption that the code cannot verify.
if y is None:
    if len(oof_prob) != len(df_train_meta):
        raise RuntimeError(
            f"Tidak bisa align y. len(oof_prob)={len(oof_prob)} != len(df_train_meta)={len(df_train_meta)} "
            "dan tidak ada oof_prob.csv (object_id) atau train_ids_ordered."
        )
    train_ids = df_train_meta.index.astype(str).tolist()
    y = _get_y_aligned(train_ids)

# Final sanity checks: same length and strictly binary labels.
if len(y) != len(oof_prob):
    raise RuntimeError(f"Length mismatch: y={len(y)} vs oof_prob={len(oof_prob)}")

uy = set(np.unique(y).tolist())
if not uy.issubset({0, 1}):
    raise ValueError(f"y must be binary 0/1. Found: {sorted(list(uy))}")

# Class balance summary.
N = int(len(y))
pos = int((y == 1).sum())
neg = int((y == 0).sum())

print(f"[Eval] OOF source={oof_src} | target_col={TARGET_COL}")
print(f"[Eval] N={N:,} | pos={pos:,} | neg={neg:,} | pos%={pos/max(N,1)*100:.6f}%")

# ----------------------------
# 3) Metrics: P/R/F1 + Fbeta + AUC optional
# ----------------------------
def prf_from_pred(y_true, y_pred01):
    """Confusion counts plus precision, recall, F1 and accuracy for 0/1 labels and predictions.

    Precision and recall use a denominator floored at 1, so they are 0.0 when
    nothing is predicted positive / nothing is truly positive. F1 is 0.0 when
    both precision and recall are 0.

    Returns:
        dict with keys tp, fp, fn, tn, precision, recall, f1, pos_pred, acc.
    """
    truth = np.asarray(y_true, dtype=np.int32)
    guess = np.asarray(y_pred01, dtype=np.int32)

    truth_pos, truth_neg = (truth == 1), (truth == 0)
    guess_pos, guess_neg = (guess == 1), (guess == 0)

    tp = int((truth_pos & guess_pos).sum())
    fp = int((truth_neg & guess_pos).sum())
    fn = int((truth_pos & guess_neg).sum())
    tn = int((truth_neg & guess_neg).sum())

    precision = tp / max(tp + fp, 1)
    recall    = tp / max(tp + fn, 1)
    pr_sum = precision + recall
    f1 = (2.0 * precision * recall / pr_sum) if pr_sum != 0 else 0.0

    return {
        "tp": tp, "fp": fp, "fn": fn, "tn": tn,
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "pos_pred": int(guess.sum()),
        "acc": float((tp + tn) / max(len(truth), 1)),
    }

def fbeta_from_pr(precision, recall, beta=1.0):
    """F-beta score from precision and recall; returns 0.0 when the denominator is not positive."""
    beta_sq = beta * beta
    denom = beta_sq * precision + recall
    return 0.0 if denom <= 0 else float((1 + beta_sq) * precision * recall / denom)

def eval_at_threshold(prob, y_true, thr):
    """Binarise ``prob`` with ``prob >= thr`` and return the metric dict, extended with
    the threshold itself ("thr") and the F0.5 / F2 scores ("f0.5", "f2")."""
    cut = float(thr)
    result = prf_from_pred(y_true, (prob >= cut).astype(np.int8))
    result["thr"] = cut
    for key, beta in (("f0.5", 0.5), ("f2", 2.0)):
        result[key] = fbeta_from_pr(result["precision"], result["recall"], beta=beta)
    return result

# Ranking metrics are optional: they need scikit-learn and both classes present.
# Both stay None together, because the report formats them as a pair.
roc_auc = None
pr_auc = None
try:
    from sklearn.metrics import roc_auc_score, average_precision_score
    if (y.max() == 1) and (y.min() == 0):
        roc_auc = float(roc_auc_score(y, oof_prob))
        pr_auc  = float(average_precision_score(y, oof_prob))
    else:
        print("[Eval] ROC-AUC/PR-AUC skipped: y does not contain both classes.")
except Exception as e:
    # Keep the evaluation running, but say why the AUCs are missing and never
    # leave one of the two set without the other.
    roc_auc = None
    pr_auc = None
    print(f"[WARN] ROC-AUC/PR-AUC not computed: {type(e).__name__}: {e}")

# Reference point: metrics at the conventional 0.5 cut.
base = eval_at_threshold(oof_prob, y, 0.5)

# ----------------------------
# 4) Threshold candidates (grid + quantiles + sampled uniques)
# ----------------------------
# Fixed grid: finer spacing near 0 and 1 (0.0025 steps) than in the middle (0.005 steps).
grid = np.concatenate([
    np.linspace(0.00, 0.10, 41),
    np.linspace(0.10, 0.90, 161),
    np.linspace(0.90, 1.00, 41),
]).astype(np.float32)

# Data-driven candidates: 999 quantiles of the observed probabilities.
qs = np.linspace(0.001, 0.999, 999, dtype=np.float32)
quant_thr = np.quantile(oof_prob, qs).astype(np.float32)

# Distinct probability values, evenly subsampled to at most 4000.
uniq = np.unique(oof_prob)
if len(uniq) > 4000:
    take = np.linspace(0, len(uniq)-1, 4000, dtype=int)
    uniq = uniq[take].astype(np.float32)

# Merge all candidate sources, clip to [0, 1] and de-duplicate (np.unique also sorts ascending).
thr_candidates = np.unique(np.clip(np.concatenate([grid, quant_thr, uniq]), 0.0, 1.0)).astype(np.float32)

# Each running best starts from the thr=0.5 baseline, so a candidate must beat it to replace it.
rows = []
best_f1  = base.copy()
best_f05 = base.copy()
best_f2  = base.copy()

for thr in thr_candidates:
    met = eval_at_threshold(oof_prob, y, float(thr))
    rows.append([
        met["thr"], met["f1"], met["f0.5"], met["f2"],
        met["precision"], met["recall"], met["acc"],
        met["tp"], met["fp"], met["fn"], met["tn"], met["pos_pred"]
    ])

    # best F1 tie-break: recall higher, then fp lower
    if (met["f1"] > best_f1["f1"] + 1e-12) or (
        abs(met["f1"] - best_f1["f1"]) <= 1e-12 and (met["recall"] > best_f1["recall"] + 1e-12)
    ) or (
        abs(met["f1"] - best_f1["f1"]) <= 1e-12 and abs(met["recall"] - best_f1["recall"]) <= 1e-12 and (met["fp"] < best_f1["fp"])
    ):
        best_f1 = met.copy()

    # best F0.5 (precision-leaning)
    if (met["f0.5"] > best_f05.get("f0.5", -1.0) + 1e-12):
        best_f05 = met.copy()

    # best F2 (recall-leaning)
    if (met["f2"] > best_f2.get("f2", -1.0) + 1e-12):
        best_f2 = met.copy()

# Full sweep table, best F1 first (ties broken by recall, then precision).
thr_table = pd.DataFrame(
    rows,
    columns=["thr","f1","f0.5","f2","precision","recall","acc","tp","fp","fn","tn","pos_pred"]
).sort_values(["f1","recall","precision"], ascending=[False, False, False]).reset_index(drop=True)

BEST_THR_F1  = float(best_f1["thr"])
BEST_THR_F05 = float(best_f05["thr"])
BEST_THR_F2  = float(best_f2["thr"])

# ----------------------------
# 5) Print report
# ----------------------------
# Console summary: optional AUCs, the 0.5 baseline, the three selected thresholds
# and the ten best rows of the sweep table.
print("\nEVALUATION (OOF) — Precision/Recall/F1")
if roc_auc is not None:
    print(f"- ROC-AUC={roc_auc:.6f} | PR-AUC={pr_auc:.6f}")
print("\nBaseline @ thr=0.5")
print(f"- F1={base['f1']:.6f} | P={base['precision']:.6f} | R={base['recall']:.6f} | ACC={base['acc']:.6f}")
print(f"  tp={base['tp']} fp={base['fp']} fn={base['fn']} tn={base['tn']} | pos_pred={base['pos_pred']}")

print(f"\nBEST-F1  @ thr={BEST_THR_F1:.6f}")
print(f"- F1={best_f1['f1']:.6f} | P={best_f1['precision']:.6f} | R={best_f1['recall']:.6f} | ACC={best_f1['acc']:.6f}")
print(f"  tp={best_f1['tp']} fp={best_f1['fp']} fn={best_f1['fn']} tn={best_f1['tn']} | pos_pred={best_f1['pos_pred']}")

print(f"\nBEST-F0.5 @ thr={BEST_THR_F05:.6f} (precision-leaning)")
print(f"- F0.5={best_f05['f0.5']:.6f} | P={best_f05['precision']:.6f} | R={best_f05['recall']:.6f} | F1={best_f05['f1']:.6f}")

print(f"\nBEST-F2   @ thr={BEST_THR_F2:.6f} (recall-leaning)")
print(f"- F2={best_f2['f2']:.6f} | P={best_f2['precision']:.6f} | R={best_f2['recall']:.6f} | F1={best_f2['f1']:.6f}")

print("\nTop 10 thresholds by F1:")
for i in range(min(10, len(thr_table))):
    r = thr_table.iloc[i]
    print(f"{i+1:02d}. thr={r['thr']:.6f} | f1={r['f1']:.6f} | f0.5={r['f0.5']:.6f} | f2={r['f2']:.6f} | "
          f"P={r['precision']:.6f} R={r['recall']:.6f} | tp={int(r['tp'])} fp={int(r['fp'])} fn={int(r['fn'])} | pos_pred={int(r['pos_pred'])}")

# ----------------------------
# 6) Save artifacts
# ----------------------------
# Three artifacts are written next to the OOF files: a text report, the full
# threshold table and a JSON summary.
out_txt  = OOF_DIR / "eval_report.txt"
out_csv  = OOF_DIR / "eval_threshold_table.csv"
out_json = OOF_DIR / "eval_summary.json"

# The text report repeats the console summary with 10 decimal places.
lines = []
lines.append("OOF Evaluation Report (Precision/Recall/F1)")
lines.append(f"source={oof_src} | target_col={TARGET_COL}")
lines.append(f"N={N} | pos={pos} | neg={neg} | pos%={pos/max(N,1)*100:.8f}%")
if roc_auc is not None:
    lines.append(f"ROC-AUC={roc_auc:.10f} | PR-AUC={pr_auc:.10f}")
lines.append("")
lines.append("Baseline @ thr=0.5")
lines.append(f"F1={base['f1']:.10f} | P={base['precision']:.10f} | R={base['recall']:.10f} | ACC={base['acc']:.10f}")
lines.append(f"tp={base['tp']} fp={base['fp']} fn={base['fn']} tn={base['tn']} | pos_pred={base['pos_pred']}")
lines.append("")
lines.append(f"BEST-F1 @ thr={BEST_THR_F1:.10f}")
lines.append(f"F1={best_f1['f1']:.10f} | P={best_f1['precision']:.10f} | R={best_f1['recall']:.10f} | ACC={best_f1['acc']:.10f}")
lines.append(f"tp={best_f1['tp']} fp={best_f1['fp']} fn={best_f1['fn']} tn={best_f1['tn']} | pos_pred={best_f1['pos_pred']}")
lines.append("")
lines.append(f"BEST-F0.5 @ thr={BEST_THR_F05:.10f}")
lines.append(f"F0.5={best_f05['f0.5']:.10f} | P={best_f05['precision']:.10f} | R={best_f05['recall']:.10f} | F1={best_f05['f1']:.10f}")
lines.append("")
lines.append(f"BEST-F2 @ thr={BEST_THR_F2:.10f}")
lines.append(f"F2={best_f2['f2']:.10f} | P={best_f2['precision']:.10f} | R={best_f2['recall']:.10f} | F1={best_f2['f1']:.10f}")
lines.append("")
lines.append("Top 10 thresholds by F1:")
for i in range(min(10, len(thr_table))):
    r = thr_table.iloc[i]
    lines.append(f"{i+1:02d}. thr={r['thr']:.10f} | f1={r['f1']:.10f} | f0.5={r['f0.5']:.10f} | f2={r['f2']:.10f} | "
                 f"P={r['precision']:.10f} R={r['recall']:.10f} | tp={int(r['tp'])} fp={int(r['fp'])} fn={int(r['fn'])} | pos_pred={int(r['pos_pred'])}")

with open(out_txt, "w", encoding="utf-8") as f:
    f.write("\n".join(lines) + "\n")

thr_table.to_csv(out_csv, index=False)

# Machine-readable summary; the metric dicts are stored as returned by eval_at_threshold.
payload = {
    "source": oof_src,
    "target_col": TARGET_COL,
    "N": N, "pos": pos, "neg": neg,
    "roc_auc": roc_auc, "pr_auc": pr_auc,
    "baseline_thr_0p5": base,
    "best_f1": best_f1,
    "best_f0.5": best_f05,
    "best_f2": best_f2,
    "paths": {"report": str(out_txt), "table": str(out_csv)}
}
with open(out_json, "w", encoding="utf-8") as f:
    json.dump(payload, f, indent=2)

print("\nSaved:")
print(f"- {out_txt}")
print(f"- {out_csv}")
print(f"- {out_json}")

# Export for next stages
# Publishes the selected thresholds, the sweep table and the artifact paths under stable names.
globals().update({
    "BEST_THR_F1": BEST_THR_F1,
    "BEST_THR_F05": BEST_THR_F05,
    "BEST_THR_F2": BEST_THR_F2,
    "thr_table_eval": thr_table,
    "EVAL_REPORT_PATH": out_txt,
    "EVAL_TABLE_PATH": out_csv,
    "EVAL_SUMMARY_PATH": out_json,
})

gc.collect()


# Submission Build

In [None]:
# ============================================================
# STAGE 11 — Submission Build (ONE CELL, Kaggle CPU-SAFE) — REVISI FULL v3
#
# Fix utama v3:
# - Cari pred test sesuai STAGE 10 (ART_DIR/preds/test_prob_ens.csv) + baca dari TEST_INFER_CFG_PATH jika ada
# - Fallback lebih lengkap (globals / cfg json / csv / npy)
# - Strict align ke sample_submission order, prediction harus 0/1
#
# Output:
# - /kaggle/working/submission.csv
# - SUB_DIR/submission.csv (copy)
# ============================================================

import gc, json, warnings
from pathlib import Path
import numpy as np
import pandas as pd

# Keep the cell output readable: ignore these two warning categories.
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=pd.errors.DtypeWarning)

# ----------------------------
# 0) Require STAGE 0 globals
# ----------------------------
# PATHS and SUB_DIR must already exist as notebook globals; stop at the first
# one that is absent.
for need in ("PATHS", "SUB_DIR"):
    if need in globals():
        continue
    raise RuntimeError(f"Missing `{need}`. Jalankan STAGE 0 dulu (setup).")

# The sample submission fixes both the set of ids and their order.
sample_path = Path(PATHS["SAMPLE_SUB"])
if not sample_path.exists():
    raise FileNotFoundError(f"Missing sample_submission.csv: {sample_path}")

df_sub = pd.read_csv(sample_path)
if not {"object_id", "prediction"} <= set(df_sub.columns):
    raise ValueError(f"sample_submission must have columns object_id,prediction. Found: {list(df_sub.columns)}")

# ----------------------------
# Helpers
# ----------------------------
def _norm_id(x):
    """
    Turn an object id of any type into a clean string key.

    Bytes are decoded as UTF-8 (undecodable bytes dropped), everything else
    goes through ``str()``. Surrounding whitespace is stripped, and a textual
    bytes repr such as ``b'abc'`` or ``b"abc"`` is unwrapped to ``abc``.
    """
    if isinstance(x, (bytes, np.bytes_)):
        try:
            x = x.decode("utf-8", errors="ignore")
        except Exception:
            x = str(x)
    text = str(x).strip()
    # Unwrap a stringified bytes literal; only one layer is removed.
    for quote in ("'", '"'):
        if text.startswith("b" + quote) and text.endswith(quote):
            text = text[2:-1]
            break
    return text.strip()

def _as_1d_float32(arr):
    """
    Coerce ``arr`` to a float32 NumPy array with at most one dimension.

    A 0-d object array is unwrapped first (best effort) so that a boxed
    sequence becomes a real array. Inputs with more than one dimension are
    flattened. A scalar input comes back as a 0-d array; callers are expected
    to check ``ndim`` themselves.
    """
    out = np.asarray(arr)
    boxed_scalar = out.ndim == 0 and out.dtype == object
    if boxed_scalar:
        try:
            out = np.asarray(out.item())
        except Exception:
            pass
    out = np.asarray(out, dtype=np.float32)
    # 0-d and 1-d results pass through unchanged; anything larger is flattened.
    return out.reshape(-1) if out.ndim > 1 else out

def _sanitize_prob(p):
    """
    Force values into valid float32 probabilities.

    NaN and -inf become 0.0, +inf becomes 1.0, and everything else is clipped
    to the closed interval [0, 1]. The input's shape is kept.
    """
    cleaned = np.nan_to_num(
        np.asarray(p, dtype=np.float32), nan=0.0, posinf=1.0, neginf=0.0
    )
    return np.clip(cleaned, 0.0, 1.0).astype(np.float32)

def _load_ids_npy(path: Path):
    """
    Read an ``.npy`` file of ids and return them as a list of normalised strings.

    Loading is done with ``allow_pickle=False``; every element is passed
    through ``_norm_id``.
    """
    loaded = np.load(path, allow_pickle=False)
    raw_ids = loaded.tolist() if hasattr(loaded, "tolist") else list(loaded)
    return list(map(_norm_id, raw_ids))

def _try_load_stage10_cfg_csv():
    """
    Resolve the test-probability CSV recorded in the STAGE 10 config, if any.

    Reads the JSON file named by the ``TEST_INFER_CFG_PATH`` global and looks
    up ``outputs.test_prob_ens_csv`` in it.

    Returns:
        Path of the CSV when the config file exists, parses, names a CSV and
        that CSV exists on disk; otherwise None. This function never raises:
        any failure while reading or interpreting the config yields None so
        the caller can move on to its other fallbacks.
    """
    cfg_path = globals().get("TEST_INFER_CFG_PATH", None)
    if cfg_path is None:
        return None
    cfg_path = Path(cfg_path)
    if not cfg_path.exists():
        return None
    try:
        # Use a context manager so the handle is closed even when parsing
        # fails; the previous json.load(open(...)) left it to the garbage
        # collector.
        with open(cfg_path, "r", encoding="utf-8") as fh:
            cfg = json.load(fh)
        csv_entry = cfg.get("outputs", {}).get("test_prob_ens_csv", None)
        if csv_entry:
            csv_path = Path(csv_entry)
            if csv_path.exists():
                return csv_path
    except Exception:
        # Deliberately best-effort: an unreadable or malformed config is
        # treated the same as "no config".
        return None
    return None

def _load_pred_df():
    """
    Locate the test-set probabilities and return them as a DataFrame with
    columns ``object_id`` and ``prob``.

    Sources are tried in this order and the first usable one wins:
      A) notebook globals ``test_ids`` + ``test_prob_ens``
      B) the CSV named by the STAGE 10 config json (outputs.test_prob_ens_csv)
      C) CSV fallbacks: TEST_PROB_CSV_PATH, ART_DIR/preds/test_prob_ens.csv,
         ART_DIR/test_prob_ens.csv
      D) NPY fallback: FIX_DIR/test_ids.npy paired with TEST_PROB_ENS_PATH,
         ART_DIR/preds/test_prob_ens.npy or ART_DIR/test_prob_ens.npy

    Raises:
        RuntimeError: ids or probabilities cannot be found, or their lengths
            disagree.
        TypeError: the NPY probabilities load as a scalar.
    """
    g = globals()

    def _frame_from_csv(csv_path):
        # Shared by sources B and C. Returns None when the file does not have
        # an object_id column plus a prob (preferred) or prediction column.
        raw = pd.read_csv(csv_path)
        if "object_id" not in raw.columns:
            return None
        if "prob" in raw.columns:
            prob_col = "prob"
        elif "prediction" in raw.columns:
            prob_col = "prediction"
        else:
            return None
        raw = raw.copy()
        raw["object_id"] = raw["object_id"].apply(_norm_id)
        probs = _sanitize_prob(_as_1d_float32(raw[prob_col].to_numpy()))
        if len(probs) != len(raw):
            raise RuntimeError(f"CSV prob length mismatch: {csv_path}")
        return pd.DataFrame({"object_id": raw["object_id"].tolist(), "prob": probs})

    # ---- A) in-memory results left behind by an earlier cell ----
    ids_obj = g.get("test_ids")
    prob_obj = g.get("test_prob_ens")
    if prob_obj is not None and ids_obj is not None:
        ids = [_norm_id(x) for x in list(ids_obj)]
        prob = _sanitize_prob(_as_1d_float32(prob_obj))
        usable = (
            isinstance(prob, np.ndarray)
            and prob.ndim != 0
            and len(ids) == len(prob)
            and len(ids) > 0
        )
        if usable:
            return pd.DataFrame({"object_id": ids, "prob": prob})

    # ---- B) exact csv path recorded by the STAGE 10 config ----
    cfg_csv = _try_load_stage10_cfg_csv()
    if cfg_csv is not None and cfg_csv.exists():
        frame = _frame_from_csv(cfg_csv)
        if frame is not None:
            return frame

    # ---- C) csv fallbacks, most specific first ----
    art_dir = Path(g.get("ART_DIR", "/kaggle/working"))
    preds_dir = art_dir / "preds"

    csv_candidates = []
    if g.get("TEST_PROB_CSV_PATH") is not None:
        csv_candidates.append(Path(g["TEST_PROB_CSV_PATH"]))
    csv_candidates += [
        preds_dir / "test_prob_ens.csv",  # STAGE 10 default
        art_dir / "test_prob_ens.csv",    # older / alternative
    ]

    for candidate in csv_candidates:
        if not candidate.exists():
            continue
        frame = _frame_from_csv(candidate)
        if frame is not None:
            return frame

    # ---- D) npy fallback: ids and probabilities live in separate files ----
    fix_dir = Path(g.get("FIX_DIR", "/kaggle/working/mallorn_run/artifacts/fixed_seq"))
    ids_file = fix_dir / "test_ids.npy"
    if not ids_file.exists():
        raise RuntimeError("Missing test_ids. Pastikan STAGE 6 membuat fixed_seq/test_ids.npy atau STAGE 10 export test_ids.")

    ids = _load_ids_npy(ids_file)
    if len(ids) == 0:
        raise RuntimeError("test_ids.npy kosong.")

    npy_candidates = []
    if g.get("TEST_PROB_ENS_PATH") is not None:
        npy_candidates.append(Path(g["TEST_PROB_ENS_PATH"]))
    npy_candidates += [preds_dir / "test_prob_ens.npy", art_dir / "test_prob_ens.npy"]

    # Only the first existing file is loaded.
    prob_file = next((c for c in npy_candidates if c.exists()), None)
    if prob_file is None:
        raise RuntimeError("Missing test_prob_ens. Jalankan STAGE 10 dulu (Test Inference).")
    prob = _sanitize_prob(_as_1d_float32(np.load(prob_file, allow_pickle=False)))

    if not isinstance(prob, np.ndarray) or prob.ndim == 0:
        raise TypeError(f"Invalid test_prob (scalar/unsized). type={type(prob)} ndim={getattr(prob,'ndim',None)}")

    if len(prob) != len(ids):
        raise RuntimeError(f"Length mismatch (NPY): test_prob={len(prob)} vs test_ids={len(ids)}")

    return pd.DataFrame({"object_id": ids, "prob": prob})

# ----------------------------
# 1) Load prediction df
# ----------------------------
df_pred = _load_pred_df()
if df_pred.empty:
    raise RuntimeError("df_pred empty (unexpected).")

# Ids must be unique after normalisation, otherwise the merge against the
# sample submission would fan out.
df_pred["object_id"] = df_pred["object_id"].apply(_norm_id)
dup_mask = df_pred["object_id"].duplicated()
if dup_mask.any():
    dup = df_pred.loc[dup_mask, "object_id"].head(10).tolist()
    raise ValueError(f"Duplicated object_id in predictions (examples): {dup}")

# Defensive re-check of the probabilities before thresholding.
p = df_pred["prob"].to_numpy(dtype=np.float32, copy=False)
finite_mask = np.isfinite(p)
if not finite_mask.all():
    bad = int((~finite_mask).sum())
    raise ValueError(f"Found non-finite probabilities in test predictions: {bad} rows")
df_pred["prob"] = _sanitize_prob(p)

# ----------------------------
# 2) Threshold selection (priority)
# ----------------------------
FORCE_THR = None  # set manual if you want, e.g. 0.37


def _select_threshold(force_thr, namespace, default=0.5):
    """
    Choose the decision threshold used to binarise the test probabilities.

    Priority: ``force_thr`` if given, else ``namespace["BEST_THR_F1"]``, else
    ``namespace["BEST_THR"]``, else ``default``. Entries that are missing or
    None are skipped.

    Args:
        force_thr: manual override, or None to use the tuned values.
        namespace: mapping searched for the tuned thresholds (the notebook
            globals in normal use).
        default: value used when nothing else is available.

    Returns:
        The chosen threshold as a float clipped to [0, 1].

    Raises:
        ValueError: the chosen threshold is NaN. A NaN survives clipping and
            makes every ``prob >= thr`` comparison False, which would silently
            produce an all-zero submission.
    """
    if force_thr is not None:
        source, raw = "FORCE_THR", force_thr
    else:
        source, raw = "default", default
        for key in ("BEST_THR_F1", "BEST_THR"):
            if namespace.get(key) is not None:
                source, raw = key, namespace[key]
                break
    value = float(raw)
    if np.isnan(value):
        raise ValueError(f"Threshold from {source} is NaN; refusing to build an all-zero submission.")
    # Infinite values are still clipped to the nearest bound, as before.
    return float(np.clip(value, 0.0, 1.0))


thr = _select_threshold(FORCE_THR, globals())

# ----------------------------
# 3) Align to sample_submission order + build BINARY prediction (0/1)
# ----------------------------
# Work on a copy so the frame read from disk keeps its original ids.
df_sub = df_sub.copy()
df_sub["object_id"] = df_sub["object_id"].apply(_norm_id)

sub_dup_mask = df_sub["object_id"].duplicated()
if sub_dup_mask.any():
    dup = df_sub.loc[sub_dup_mask, "object_id"].head(10).tolist()
    raise ValueError(f"sample_submission has duplicate object_id (unexpected). examples={dup}")

# Left merge keeps the sample_submission row order; ids without a prediction
# show up as NaN in "prob".
df_out = pd.merge(df_sub[["object_id"]], df_pred, how="left", on="object_id")

missing_mask = df_out["prob"].isna()
if missing_mask.any():
    missing_n = int(missing_mask.sum())
    miss_ids = df_out.loc[missing_mask, "object_id"].head(10).tolist()
    raise ValueError(
        f"Some sample_submission object_id have no prediction: missing_n={missing_n}. Examples: {miss_ids}\n"
        "Biasanya karena mismatch id normalization atau pred df tidak lengkap."
    )

# Compare in float32 on both sides so the threshold and the probabilities use
# the same precision.
prob_values = df_out["prob"].to_numpy(dtype=np.float32)
df_out["prediction"] = np.greater_equal(prob_values, np.float32(thr)).astype(np.int8)
df_out = df_out[["object_id", "prediction"]]

# Final sanity checks on the frame that will be written.
u = set(df_out["prediction"].unique().tolist())
if u - {0, 1}:
    raise RuntimeError(f"submission prediction contains values outside {{0,1}}: {sorted(list(u))}")

if len(df_out) != len(df_sub):
    raise RuntimeError("submission row count mismatch with sample_submission.")

pos_pred = int(df_out["prediction"].sum())
print("[Stage 11] SUBMISSION READY (BINARY 0/1)")
print(f"- threshold_used={thr:.6f}")
print(f"- rows={len(df_out):,} | pos_pred={pos_pred:,} ({pos_pred/max(len(df_out),1)*100:.6f}%)")

# ----------------------------
# 4) Write files
# ----------------------------
# Make sure the copy destination exists before writing.
SUB_DIR = Path(SUB_DIR)
SUB_DIR.mkdir(parents=True, exist_ok=True)

out_main = Path("/kaggle/working") / "submission.csv"
out_copy = SUB_DIR.joinpath("submission.csv")

# Write the same frame to both locations, without the index column.
for destination in (out_main, out_copy):
    df_out.to_csv(destination, index=False)

print(f"- wrote: {out_main}")
print(f"- copy : {out_copy}")
print("\nPreview:")
print(df_out.head(8).to_string(index=False))

# Publish the result locations and settings as notebook globals.
SUBMISSION_PATH = out_main
SUBMISSION_COPY_PATH = out_copy
SUBMISSION_MODE = "binary"
SUBMISSION_THRESHOLD = thr

gc.collect()
