# Set Paths & Select Config (CFG)

In [1]:
# ============================================================
# STAGE 1 — Set Paths & Select Config (FULL PIPELINE) (ONE CELL)
# Target pipeline:
#   DINOv2-(Giant/Large/Base) dense descriptors
# + Copy-Move Matching (self-similarity / PatchMatch-style features)
# + Segmentation decoder (UNet++ or DeepLabV3+ASPP) on token-grid
# + Gate model (image-level) + Fold Ensemble
#
# REVISI FULL (SUPER ROBUST + FUTURE-PROOF):
# - Resolve COMP_ROOT automatically
# - Ensure PROF artifacts in /kaggle/working/recodai_luc_prof:
#     * copy from /kaggle/input if present, else rebuild minimal manifests
# - Auto-pick TOKEN_ROOT (prefer giant -> large -> base) from /working or /input
# - Auto-pick MATCH_ROOT (support match_cfg_* / patchmatch_cfg_* / ssim_cfg_*) from /working or /input
# - Create RUN_DIR with standard subfolders:
#     seg/, gate/, oof/, preds/, features/, bundle/, logs/
# - Save config: run_cfg.json
# Exports globals:
#   COMP_ROOT, PROF_DIR, TOKEN_ROOT, MATCH_ROOT, RUN_DIR, CFG, PATHS
# ============================================================

import os, json, time, re, shutil, hashlib, warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", category=FutureWarning)

WORK = Path("/kaggle/working")
INP  = Path("/kaggle/input")

# ----------------------------
# Helpers
# ----------------------------
def read_json_safe(p: Path, default=None):
    try:
        return json.loads(Path(p).read_text())
    except Exception:
        return default

def write_json(p: Path, obj):
    p.write_text(json.dumps(obj, indent=2, ensure_ascii=False))

def safe_hw(p: Path):
    try:
        im = Image.open(p)
        w, h = im.size
        return int(h), int(w)
    except Exception:
        return np.nan, np.nan

def parse_case_id(stem: str):
    """
    Robust case_id parsing:
    - handle: '12345', '12345__forg', '12345__auth', 'case_12345_x', etc.
    Preference: first long-ish digit group.
    """
    s = str(stem)
    # common split tokens
    for tok in ["__", "_", "-", " "]:
        s = s.replace(tok, " ")
    m = re.search(r"\b(\d{3,})\b", s)
    if m:
        return int(m.group(1))
    m = re.search(r"(\d+)", str(stem))
    return int(m.group(1)) if m else None

def find_comp_root():
    # Prefer known folder name
    cand = INP / "recodai-luc-scientific-image-forgery-detection"
    if cand.exists() and (cand / "sample_submission.csv").exists():
        return cand

    # Fallback: any dataset containing sample_submission.csv + train_images + test_images
    for d in sorted(INP.glob("*")):
        if not d.is_dir():
            continue
        if (d / "sample_submission.csv").exists() and (d / "train_images").exists() and (d / "test_images").exists():
            return d

    raise FileNotFoundError(
        "Cannot find competition dataset under /kaggle/input "
        "(need sample_submission.csv + train_images + test_images)."
    )

def _pick_train_subdir(root: Path, keywords):
    best, best_n = None, -1
    if not root.exists():
        return None
    for d in [p for p in root.rglob("*") if p.is_dir()]:
        name = d.name.lower()
        if any(k in name for k in keywords):
            n = sum(1 for _ in d.glob("*.png"))
            if n > best_n:
                best, best_n = d, n
    return best

def find_mask_path(case_id: int, img_stem: str, mask_dirs):
    """
    Try locate mask file for a given case_id / stem.
    Supports patterns: {case_id}.png, {img_stem}.png, {case_id}__*.png
    """
    if case_id is None:
        return None
    stems = [str(case_id), str(img_stem)]
    pats = [
        f"{case_id}.png",
        f"{img_stem}.png",
        f"{case_id}__*.png",
        f"*{case_id}*.png",
    ]
    for md in mask_dirs:
        if md is None or (not md.exists()):
            continue
        # exact first
        for s in stems:
            p = md / f"{s}.png"
            if p.exists():
                return str(p)
        # patterns
        for pat in pats:
            hits = list(md.glob(pat))
            if hits:
                hits = sorted(hits, key=lambda x: x.stat().st_mtime, reverse=True)
                return str(hits[0])
    return None

# ----------------------------
# 0) Resolve competition root
# ----------------------------
COMP_ROOT = find_comp_root()
SAMPLE_SUB     = COMP_ROOT / "sample_submission.csv"
TRAIN_IMG_DIR  = COMP_ROOT / "train_images"
TEST_IMG_DIR   = COMP_ROOT / "test_images"
TRAIN_MASK_DIR = COMP_ROOT / "train_masks"
SUP_MASK_DIR   = COMP_ROOT / "supplemental_masks"

# ----------------------------
# 1) Ensure PROF_DIR in /kaggle/working
# ----------------------------
PROF_DIR = WORK / "recodai_luc_prof"
PROF_DIR.mkdir(parents=True, exist_ok=True)

# Copy as much as possible if exists in input (fast path)
PROF_WANT = [
    "paths.json",
    "train_manifest.parquet",
    "test_manifest.parquet",
    "folds.parquet",
    # optional but nice to have if you already built them before
    "image_profile.parquet",
    "mask_index.parquet",
    "mask_profile.parquet",
    "dup_case_images.csv",
    "sanity_report.json",
]

def find_input_prof_dir():
    cands = []
    for p in INP.glob("*/recodai_luc_prof"):
        if p.is_dir():
            cands.append(p)
    for p in INP.glob("*/*/recodai_luc_prof"):
        if p.is_dir():
            cands.append(p)

    good = []
    for p in cands:
        if (p/"train_manifest.parquet").exists() and (p/"test_manifest.parquet").exists():
            good.append(p)

    if good:
        good = sorted(good, key=lambda x: (x/"train_manifest.parquet").stat().st_mtime, reverse=True)
        return good[0]

    for pj in INP.rglob("recodai_luc_prof/paths.json"):
        p = pj.parent
        if (p/"train_manifest.parquet").exists() and (p/"test_manifest.parquet").exists():
            return p
    return None

def copy_prof_if_missing(src_dir: Path, dst_dir: Path):
    if src_dir is None:
        return
    for fn in PROF_WANT:
        sp = src_dir / fn
        dp = dst_dir / fn
        if sp.exists() and (not dp.exists()):
            try:
                shutil.copy2(sp, dp)
            except Exception:
                pass

copy_prof_if_missing(find_input_prof_dir(), PROF_DIR)

# If still missing core manifests -> rebuild minimal
paths_json = PROF_DIR / "paths.json"
train_pq   = PROF_DIR / "train_manifest.parquet"
test_pq    = PROF_DIR / "test_manifest.parquet"
folds_pq   = PROF_DIR / "folds.parquet"

def build_train_manifest():
    # expected structure: train_images/authentic and train_images/forged
    auth_dir = TRAIN_IMG_DIR / "authentic"
    forg_dir = TRAIN_IMG_DIR / "forged"

    if not auth_dir.exists():
        auth_dir = _pick_train_subdir(TRAIN_IMG_DIR, ["auth", "real", "clean"]) or auth_dir
    if not forg_dir.exists():
        forg_dir = _pick_train_subdir(TRAIN_IMG_DIR, ["forg", "fake", "tamper", "manip"]) or forg_dir

    mask_dirs = [TRAIN_MASK_DIR, SUP_MASK_DIR]

    rows = []
    if auth_dir.exists():
        for p in sorted(auth_dir.glob("*.png")):
            cid = parse_case_id(p.stem)
            if cid is None:
                continue
            H, W = safe_hw(p)
            rows.append({
                "uid": p.stem,
                "case_id": int(cid),
                "y": 0,
                "img_path": str(p),
                "mask_path": None,
                "H": H, "W": W
            })

    if forg_dir.exists():
        for p in sorted(forg_dir.glob("*.png")):
            cid = parse_case_id(p.stem)
            if cid is None:
                continue
            H, W = safe_hw(p)
            mpath = find_mask_path(cid, p.stem, mask_dirs)
            rows.append({
                "uid": p.stem,
                "case_id": int(cid),
                "y": 1,
                "img_path": str(p),
                "mask_path": mpath,
                "H": H, "W": W
            })

    df = pd.DataFrame(rows, columns=["uid","case_id","y","img_path","mask_path","H","W"])
    return df, str(auth_dir), str(forg_dir)

def build_test_manifest():
    rows = []
    for p in sorted(TEST_IMG_DIR.glob("*.png")):
        cid = parse_case_id(p.stem)
        if cid is None:
            continue
        H, W = safe_hw(p)
        rows.append({"uid": p.stem, "case_id": int(cid), "img_path": str(p), "H": H, "W": W})
    df = pd.DataFrame(rows, columns=["uid","case_id","img_path","H","W"])
    return df

def build_folds(df_train, n_folds=5, seed=42):
    """
    Leakage-safe by case_id.
    If multiple rows per case_id exist, we group by case_id.
    Otherwise, plain stratified works.
    """
    df = df_train[["uid","case_id","y"]].copy()
    df["fold"] = -1
    if len(df) == 0:
        return pd.DataFrame([], columns=["uid","case_id","fold"])

    # detect grouping need
    multi = (df.groupby("case_id").size().max() > 1)

    try:
        if multi:
            from sklearn.model_selection import StratifiedGroupKFold
            sgkf = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=seed)
            X = np.zeros(len(df))
            for f, (_, va) in enumerate(sgkf.split(X, df["y"].values, groups=df["case_id"].values)):
                df.loc[df.index[va], "fold"] = f
        else:
            from sklearn.model_selection import StratifiedKFold
            skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
            X = np.zeros(len(df))
            for f, (_, va) in enumerate(skf.split(X, df["y"].values)):
                df.loc[df.index[va], "fold"] = f
    except Exception:
        df["fold"] = (df["case_id"].astype(int) % int(n_folds)).astype(int)

    return df[["uid","case_id","fold"]].sort_values(["fold","case_id"]).reset_index(drop=True)

if (not paths_json.exists()) or (not train_pq.exists()) or (not test_pq.exists()) or (not folds_pq.exists()):
    df_tr, auth_dir_used, forg_dir_used = build_train_manifest()
    df_te = build_test_manifest()

    if len(df_tr) == 0:
        raise RuntimeError("train_manifest gagal dibuat (train_images kosong / struktur tidak ketemu). Cek folder train_images.")
    df_fd = build_folds(df_tr, n_folds=5, seed=42)

    PATHS = {
        "COMP_ROOT": str(COMP_ROOT),
        "SAMPLE_SUB": str(SAMPLE_SUB),
        "TRAIN_IMG_DIR": str(TRAIN_IMG_DIR),
        "TEST_IMG_DIR": str(TEST_IMG_DIR),
        "TRAIN_AUTH_DIR": str(auth_dir_used),
        "TRAIN_FORG_DIR": str(forg_dir_used),
        "TRAIN_MASK_DIR": str(TRAIN_MASK_DIR),
        "SUP_MASK_DIR": str(SUP_MASK_DIR),
    }
    write_json(paths_json, PATHS)
    df_tr.to_parquet(train_pq, index=False)
    df_te.to_parquet(test_pq, index=False)
    df_fd.to_parquet(folds_pq, index=False)

PATHS = read_json_safe(paths_json, default={})

# ----------------------------
# 2) Auto-pick TOKEN_ROOT and MATCH_ROOT (working OR input)
# ----------------------------
WORK_CACHE = WORK / "recodai_luc" / "cache"
WORK_CACHE.mkdir(parents=True, exist_ok=True)

def _pick_latest_dir(cands, must_have):
    cands = [Path(c) for c in cands if Path(c).is_dir() and (Path(c)/must_have).exists()]
    if not cands:
        return None
    cands = sorted(cands, key=lambda p: (p/must_have).stat().st_mtime, reverse=True)
    return cands[0]

def _collect_token_candidates(base_dir: Path):
    # support many names:
    # dinov2_base_518_cfg_xxx / dinov2_large_518_cfg_xxx / dinov2_giant_518_cfg_xxx
    # also allow dinov2_*cfg_* variants.
    if base_dir is None or (not base_dir.exists()):
        return []
    cands = []
    cands += list(base_dir.glob("dinov2_*cfg_*"))
    return cands

def pick_token_root(prefer_variants=("giant","large","base")):
    """
    Prefer DINOv2 giant -> large -> base.
    If not found, fallback to any token cache found.
    """
    # gather candidates from working + input
    w_cands = _collect_token_candidates(WORK_CACHE)

    i_cands = []
    for p in INP.glob("*/recodai_luc/cache"):
        i_cands += _collect_token_candidates(p)
    for p in INP.glob("*/*/recodai_luc/cache"):
        i_cands += _collect_token_candidates(p)

    all_cands = w_cands + i_cands

    def variant_of(p: Path):
        n = p.name.lower()
        if "giant" in n: return "giant"
        if "large" in n: return "large"
        if "base"  in n: return "base"
        return "unknown"

    # 1) preferred order
    for v in prefer_variants:
        subset = [p for p in all_cands if variant_of(p) == v]
        r = _pick_latest_dir(subset, "tokens_manifest_train.parquet")
        if r is not None:
            return r, v

    # 2) any
    r = _pick_latest_dir(all_cands, "tokens_manifest_train.parquet")
    if r is not None:
        return r, variant_of(r)

    # 3) last resort: locate tokens_manifest_train.parquet anywhere
    for tm in INP.rglob("tokens_manifest_train.parquet"):
        root = tm.parent
        return root, variant_of(root)

    return None, None

def _collect_match_candidates(base_dir: Path):
    if base_dir is None or (not base_dir.exists()):
        return []
    cands = []
    # support multiple naming conventions
    cands += list(base_dir.glob("match_cfg_*"))
    cands += list(base_dir.glob("patchmatch_cfg_*"))
    cands += list(base_dir.glob("ssim_cfg_*"))
    return cands

def pick_match_root():
    w_cands = _collect_match_candidates(WORK_CACHE)

    i_cands = []
    for p in INP.glob("*/recodai_luc/cache"):
        i_cands += _collect_match_candidates(p)
    for p in INP.glob("*/*/recodai_luc/cache"):
        i_cands += _collect_match_candidates(p)

    all_cands = w_cands + i_cands

    r = _pick_latest_dir(all_cands, "match_manifest_train.parquet")
    if r is not None:
        return r

    # fallback find any manifest
    for mm in INP.rglob("match_manifest_train.parquet"):
        return mm.parent

    return None

TOKEN_ROOT, DINO_VARIANT_USED = pick_token_root(prefer_variants=("giant","large","base"))
MATCH_ROOT = pick_match_root()

if TOKEN_ROOT is None:
    raise FileNotFoundError(
        "TOKEN_ROOT tidak ditemukan (butuh tokens_manifest_train.parquet). "
        "Pasang dataset token-cache atau jalankan stage token-cache."
    )
if MATCH_ROOT is None:
    raise FileNotFoundError(
        "MATCH_ROOT tidak ditemukan (butuh match_manifest_train.parquet). "
        "Pasang dataset match-cache atau jalankan stage matching-cache."
    )

# ----------------------------
# 3) Build RUN_DIR + CFG (full pipeline)
# ----------------------------
RUNS_DIR = WORK / "recodai_luc_runs"
RUNS_DIR.mkdir(parents=True, exist_ok=True)

# Defaults: target LB1-style architecture (you can tweak later)
CFG = {
    "version": "luc_fullpipe_v1",
    "created_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
    "paths": {
        "COMP_ROOT": str(COMP_ROOT),
        "PROF_DIR": str(PROF_DIR),
        "SAMPLE_SUB": str(SAMPLE_SUB),
        "TRAIN_IMG_DIR": str(TRAIN_IMG_DIR),
        "TEST_IMG_DIR": str(TEST_IMG_DIR),
        "TRAIN_MASK_DIR": str(TRAIN_MASK_DIR),
        "SUP_MASK_DIR": str(SUP_MASK_DIR),
        "TOKEN_ROOT": str(TOKEN_ROOT),
        "MATCH_ROOT": str(MATCH_ROOT),
    },

    # ---------- Token encoder (DINOv2) ----------
    "token": {
        "dino_variant_prefer": ["giant", "large", "base"],
        "dino_variant_used": DINO_VARIANT_USED,  # auto selected by cache
        "img_size": 518,               # must match token cache build
        "patch": 14,
        "tok_norm": "l2",
        "reduce_dim": {
            "enabled": True,
            "method": "linear",        # "linear" (1x1 conv) / "pca" (offline) - set later
            "out_dim": 256
        }
    },

    # ---------- Copy-Move Matching (self-similarity / PatchMatch-style) ----------
    "matching": {
        "enabled": True,
        "method": "selfsim_topk",      # computed as cache in MATCH_ROOT
        "seed_channels": ["src", "tgt", "union", "conf"],  # compatible with your current seeds
        "use_best_peak_only": True,
        # if you rebuild match-cache later, these become the canonical params:
        "params": {
            "topk": 8,
            "exclude_radius": 2,       # token units
            "bidir": True,
            "use_vote_consistency": True,
            "vote_bins": 9,            # small hough bins for displacement consistency
        },
        # what feature maps are expected for seg fusion (later stage will enforce)
        "feat_maps": [
            "max_sim", "sim_margin", "dx", "dy", "vote_strength", "bidir_score"
        ]
    },

    # ---------- Segmentation model (token-grid -> mask) ----------
    "seg": {
        "enabled": True,
        "decoder": "unetpp_aspp",      # "unetpp_aspp" or "deeplabv3p_aspp"
        "base_ch": 128,
        "dropout": 0.1,
        "use_matching_seeds": True,
        "use_matching_feats": True,
        "loss": {
            "bce": 1.0,
            "dice": 1.0,               # or "tversky" in later stage
            "pos_weight": 2.0
        },
        "train": {
            "seed": 42,
            "n_folds": 5,
            "epochs": 25,
            "batch_size": 16,
            "lr": 3e-4,
            "weight_decay": 1e-4,
            "amp": True,
            "ema": True,
            "accum_steps": 1,
            "num_workers": 2,
            "early_stop_patience": 6,
        }
    },

    # ---------- Gate model (image-level forged/auth to suppress FP) ----------
    "gate": {
        "enabled": True,
        "model": "lgbm",               # could be "logreg"/"mlp" later
        "train_from_oof": True,
        "features_policy": "auto",     # stage later will define exact feature_cols
        "thr": None                    # tuned via OOF sweep later
    },

    # ---------- Ensemble + inference policy ----------
    "ensemble": {
        "enabled": True,
        "use_folds": True,             # 5-fold ensemble
        "tta": ["none", "hflip", "vflip"],  # minimal; extend later
        "calibration": {
            "enabled": True,
            "method": "thr_sweep",      # keep simple & robust
        }
    },

    # ---------- Postprocess defaults (tuned later on OOF) ----------
    "post": {
        "prob_thr": 0.5,
        "min_cc_full_pix": 64,
        "max_area_frac": 0.90,
    }
}

cfg_id = hashlib.sha1(json.dumps(CFG, sort_keys=True).encode()).hexdigest()[:12]
CFG["cfg_id"] = cfg_id

RUN_DIR = RUNS_DIR / f"run_{cfg_id}"
RUN_DIR.mkdir(parents=True, exist_ok=True)

# standard subdirs
for sd in ["seg", "gate", "oof", "preds", "features", "bundle", "logs"]:
    (RUN_DIR / sd).mkdir(parents=True, exist_ok=True)

# save config
write_json(RUN_DIR / "run_cfg.json", CFG)

# ----------------------------
# 4) Summary
# ----------------------------
print("COMP_ROOT      :", COMP_ROOT)
print("PROF_DIR       :", PROF_DIR)
print("TOKEN_ROOT     :", TOKEN_ROOT, f"(variant_used={DINO_VARIANT_USED})")
print("MATCH_ROOT     :", MATCH_ROOT)
print("RUN_DIR        :", RUN_DIR)
print("CFG_ID         :", cfg_id)
print("Train manifest :", train_pq, "| exists:", train_pq.exists())
print("Test manifest  :", test_pq,  "| exists:", test_pq.exists())
print("Folds          :", folds_pq, "| exists:", folds_pq.exists())
print("Saved cfg      :", RUN_DIR / "run_cfg.json")

# Export globals (explicit)
PATHS = PATHS


COMP_ROOT      : /kaggle/input/recodai-luc-scientific-image-forgery-detection
PROF_DIR       : /kaggle/working/recodai_luc_prof
TOKEN_ROOT     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_bind_9894bfdb484a (variant_used=base)
MATCH_ROOT     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_cfg_2ed747746f9c
RUN_DIR        : /kaggle/working/recodai_luc_runs/run_4747afaf927b
CFG_ID         : 4747afaf927b
Train manifest : /kaggle/working/recodai_luc_prof/train_manifest.parquet | exists: True
Test manifest  : /kaggle/working/recodai_luc_prof/test_manifest.parquet | exists: True
Folds          : /kaggle/working/recodai_luc_prof/folds.parquet | exists: True
Saved cfg      : /kaggle/working/recodai_luc_runs/run_4747afaf927b/run_cfg.json


# Build Training Table (X, y, folds)

In [2]:
# ============================================================
# STAGE 2 — Build Train/Test Tables + Precompute GT Token Masks (NO FOLD) (ONE CELL)
# REVISI FULL:
# - Reads latest RUN_DIR/run_cfg.json (or uses global RUN_DIR if exists)
# - Loads train/test manifests (no fold concept at all)
# - Robust join with TOKEN_ROOT + MATCH_ROOT manifests (prefer uid, fallback case_id)
# - KEEPS ALL ROWS (adds tok_exists/match_exists flags; does NOT drop)
# - Infers HTOK/WTOK robustly (cfg.json -> manifest cols -> read one npz)
# - Precompute GT token masks (union) into RUN_DIR/seg/gt_tok_masks_<cfg_id>/
# - Saves tables to RUN_DIR/features/train_table.parquet & test_table.parquet
#
# Exports globals:
#   DF_TRAIN_ALL, DF_TEST_ALL, HTOK, WTOK, GT_TOK_DIR
# ============================================================

import os, json, re, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", category=FutureWarning)

WORK = Path("/kaggle/working")
INP  = Path("/kaggle/input")

# ----------------------------
# 0) Locate RUN_DIR + load CFG
# ----------------------------
def _latest_run_dir():
    runs = sorted((WORK / "recodai_luc_runs").glob("run_*/run_cfg.json"),
                  key=lambda p: p.stat().st_mtime, reverse=True)
    return runs[0].parent if runs else None

if "RUN_DIR" in globals():
    _rd = Path(RUN_DIR)
    cfg_path = _rd / "run_cfg.json"
    if not cfg_path.exists():
        _rd = _latest_run_dir()
        if _rd is None:
            raise FileNotFoundError("Cannot find RUN_DIR/run_cfg.json. Run STAGE 1 first.")
        cfg_path = _rd / "run_cfg.json"
    RUN_DIR = _rd
else:
    RUN_DIR = _latest_run_dir()
    if RUN_DIR is None:
        raise FileNotFoundError("Cannot find any run_*/run_cfg.json under /kaggle/working/recodai_luc_runs. Run STAGE 1 first.")
    cfg_path = RUN_DIR / "run_cfg.json"

CFG = json.loads(cfg_path.read_text())
CFG_ID = CFG.get("cfg_id", "no_cfgid")

for sd in ["seg", "gate", "oof", "preds", "features", "bundle", "logs"]:
    (RUN_DIR / sd).mkdir(parents=True, exist_ok=True)

PROF_DIR       = Path(CFG["paths"]["PROF_DIR"])
TOKEN_ROOT     = Path(CFG["paths"]["TOKEN_ROOT"])
MATCH_ROOT     = Path(CFG["paths"]["MATCH_ROOT"])
TRAIN_MASK_DIR = Path(CFG["paths"]["TRAIN_MASK_DIR"]) if CFG["paths"].get("TRAIN_MASK_DIR") else None
SUP_MASK_DIR   = Path(CFG["paths"]["SUP_MASK_DIR"]) if CFG["paths"].get("SUP_MASK_DIR") else None

# ----------------------------
# Helpers
# ----------------------------
def _infer_case_id_from_series(s: pd.Series) -> pd.Series:
    s = s.astype(str)
    out = s.str.extract(r"(\d+)")[0]
    return pd.to_numeric(out, errors="coerce")

def ensure_uid_case(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "uid" not in df.columns:
        if "img_path" in df.columns:
            df["uid"] = df["img_path"].astype(str).apply(lambda x: Path(x).stem)
        elif "npz_path" in df.columns:
            df["uid"] = df["npz_path"].astype(str).apply(lambda x: Path(x).stem)
        elif "tok_npz" in df.columns:
            df["uid"] = df["tok_npz"].astype(str).apply(lambda x: Path(x).stem)
        elif "match_npz" in df.columns:
            df["uid"] = df["match_npz"].astype(str).apply(lambda x: Path(x).stem)
        else:
            df["uid"] = None

    if "case_id" not in df.columns:
        for c in ["uid", "uid_safe", "id", "img_path", "npz_path", "tok_npz", "match_npz"]:
            if c in df.columns:
                df["case_id"] = _infer_case_id_from_series(df[c])
                break
    df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
    df = df[df["case_id"].notna()].copy()
    df["case_id"] = df["case_id"].astype(int)
    return df

def resolve_any_path(p, root: Path, split_hint: str):
    if p is None or (isinstance(p, float) and np.isnan(p)):
        return None
    s = str(p)
    if not s:
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    if not pp.is_absolute():
        cand = root / pp
        if cand.exists():
            return str(cand)

    name = pp.name
    cand1 = root / split_hint / name
    if cand1.exists():
        return str(cand1)
    cand2 = root / name
    if cand2.exists():
        return str(cand2)

    s2 = s.replace("\\", "/")
    for seg in ["train", "test"]:
        if f"/{seg}/" in s2:
            tail = s2.split(f"/{seg}/", 1)[1]
            cand3 = root / seg / tail
            if cand3.exists():
                return str(cand3)
    return None

def dedup_best(df, key_cols, path_col):
    if len(df) == 0:
        return df
    df = df.copy()
    df["_ok"] = df[path_col].apply(lambda x: isinstance(x, str) and Path(x).exists())
    df = df.sort_values(["_ok"], ascending=False).drop_duplicates(key_cols, keep="first").drop(columns=["_ok"])
    return df

def _pick_path_col(df, candidates):
    for c in candidates:
        if c in df.columns:
            return c
    return None

def find_mask_path(case_id: int, uid: str):
    mask_dirs = [TRAIN_MASK_DIR, SUP_MASK_DIR]
    if case_id is None:
        return None
    stems = [str(case_id)]
    if uid:
        stems.append(str(uid))
    pats = [f"{case_id}.png", f"{case_id}__*.png", f"*{case_id}*.png"]
    for md in mask_dirs:
        if md is None or (not md.exists()):
            continue
        for st in stems:
            p = md / f"{st}.png"
            if p.exists():
                return str(p)
        for pat in pats:
            hits = list(md.glob(pat))
            if hits:
                hits = sorted(hits, key=lambda x: x.stat().st_mtime, reverse=True)
                return str(hits[0])
    return None

# ----------------------------
# 1) Load manifests (NO FOLD)
# ----------------------------
train_pq = PROF_DIR / "train_manifest.parquet"
test_pq  = PROF_DIR / "test_manifest.parquet"

if not train_pq.exists():
    raise FileNotFoundError(f"Missing {train_pq}. Run STAGE 1 first.")
if not test_pq.exists():
    raise FileNotFoundError(f"Missing {test_pq}. Run STAGE 1 first.")

df_train = ensure_uid_case(pd.read_parquet(train_pq).copy())
df_test  = ensure_uid_case(pd.read_parquet(test_pq).copy())

if "y" not in df_train.columns:
    raise ValueError("train_manifest must contain 'y'.")
if "img_path" not in df_train.columns or "img_path" not in df_test.columns:
    raise ValueError("manifests must contain 'img_path'.")

df_train["y"] = pd.to_numeric(df_train["y"], errors="coerce").fillna(0).astype(int)
df_train["uid"] = df_train["uid"].astype(str)
df_test["uid"]  = df_test["uid"].astype(str)

# ensure mask_path exists
if "mask_path" not in df_train.columns:
    df_train["mask_path"] = None
need_mask = (df_train["y"] == 1) & (df_train["mask_path"].isna() | (df_train["mask_path"].astype(str) == "None"))
if need_mask.any():
    df_train.loc[need_mask, "mask_path"] = df_train.loc[need_mask].apply(
        lambda r: find_mask_path(int(r["case_id"]), r["uid"]), axis=1
    )

# ----------------------------
# 2) Load token + match manifests
# ----------------------------
tok_train_pq = TOKEN_ROOT / "tokens_manifest_train.parquet"
tok_test_pq  = TOKEN_ROOT / "tokens_manifest_test.parquet"
mat_train_pq = MATCH_ROOT / "match_manifest_train.parquet"
mat_test_pq  = MATCH_ROOT / "match_manifest_test.parquet"

if not tok_train_pq.exists():
    raise FileNotFoundError(f"Missing {tok_train_pq} (TOKEN_ROOT wrong?).")
if not mat_train_pq.exists():
    raise FileNotFoundError(f"Missing {mat_train_pq} (MATCH_ROOT wrong?).")

df_tok_tr = ensure_uid_case(pd.read_parquet(tok_train_pq).copy())
df_tok_te = ensure_uid_case(pd.read_parquet(tok_test_pq).copy()) if tok_test_pq.exists() else pd.DataFrame(columns=["uid","case_id"])
df_mat_tr = ensure_uid_case(pd.read_parquet(mat_train_pq).copy())
df_mat_te = ensure_uid_case(pd.read_parquet(mat_test_pq).copy()) if mat_test_pq.exists() else pd.DataFrame(columns=["uid","case_id"])

tok_path_col = _pick_path_col(df_tok_tr, ["npz_path","tok_npz","path","file","npz","token_npz"])
if tok_path_col is None:
    raise ValueError("tokens_manifest_train.parquet has no recognizable npz path column.")
df_tok_tr["tok_npz"] = df_tok_tr[tok_path_col]
if len(df_tok_te):
    tok_path_col_te = _pick_path_col(df_tok_te, ["npz_path","tok_npz","path","file","npz","token_npz"])
    df_tok_te["tok_npz"] = df_tok_te[tok_path_col_te] if tok_path_col_te else None

mat_path_col = _pick_path_col(df_mat_tr, ["match_npz","npz_path","path","file","npz"])
if mat_path_col is None:
    raise ValueError("match_manifest_train.parquet has no recognizable npz path column.")
df_mat_tr["match_npz"] = df_mat_tr[mat_path_col]
if len(df_mat_te):
    mat_path_col_te = _pick_path_col(df_mat_te, ["match_npz","npz_path","path","file","npz"])
    df_mat_te["match_npz"] = df_mat_te[mat_path_col_te] if mat_path_col_te else None

df_tok_tr["tok_npz"] = df_tok_tr["tok_npz"].apply(lambda x: resolve_any_path(x, TOKEN_ROOT, "train"))
if len(df_tok_te):
    df_tok_te["tok_npz"] = df_tok_te["tok_npz"].apply(lambda x: resolve_any_path(x, TOKEN_ROOT, "test"))

df_mat_tr["match_npz"] = df_mat_tr["match_npz"].apply(lambda x: resolve_any_path(x, MATCH_ROOT, "train"))
if len(df_mat_te):
    df_mat_te["match_npz"] = df_mat_te["match_npz"].apply(lambda x: resolve_any_path(x, MATCH_ROOT, "test"))

key_tok = ["uid"] if df_tok_tr["uid"].notna().any() else ["case_id"]
key_mat = ["uid"] if df_mat_tr["uid"].notna().any() else ["case_id"]
df_tok_tr = dedup_best(df_tok_tr, key_tok, "tok_npz")
df_mat_tr = dedup_best(df_mat_tr, key_mat, "match_npz")
if len(df_tok_te): df_tok_te = dedup_best(df_tok_te, key_tok, "tok_npz")
if len(df_mat_te): df_mat_te = dedup_best(df_mat_te, key_mat, "match_npz")

tok_tr_small = df_tok_tr[["uid","case_id","tok_npz"]].copy()
mat_tr_small = df_mat_tr[["uid","case_id","match_npz"]].copy()
tok_te_small = df_tok_te[["uid","case_id","tok_npz"]].copy() if len(df_tok_te) else pd.DataFrame(columns=["uid","case_id","tok_npz"])
mat_te_small = df_mat_te[["uid","case_id","match_npz"]].copy() if len(df_mat_te) else pd.DataFrame(columns=["uid","case_id","match_npz"])

# ----------------------------
# 3) Build joined train/test tables (KEEP ALL)
# ----------------------------
use_uid_join = (df_train["uid"].notna().all() and tok_tr_small["uid"].notna().all() and mat_tr_small["uid"].notna().all())

if use_uid_join:
    DF_TRAIN_ALL = df_train.merge(tok_tr_small[["uid","tok_npz"]], on="uid", how="left") \
                           .merge(mat_tr_small[["uid","match_npz"]], on="uid", how="left")
    DF_TEST_ALL  = df_test.merge(tok_te_small[["uid","tok_npz"]], on="uid", how="left") \
                          .merge(mat_te_small[["uid","match_npz"]], on="uid", how="left")
else:
    DF_TRAIN_ALL = df_train.merge(tok_tr_small[["case_id","tok_npz"]], on="case_id", how="left") \
                           .merge(mat_tr_small[["case_id","match_npz"]], on="case_id", how="left")
    DF_TEST_ALL  = df_test.merge(tok_te_small[["case_id","tok_npz"]], on="case_id", how="left") \
                          .merge(mat_te_small[["case_id","match_npz"]], on="case_id", how="left")

DF_TRAIN_ALL["tok_exists"]   = DF_TRAIN_ALL["tok_npz"].apply(lambda x: isinstance(x, str) and Path(x).exists())
DF_TRAIN_ALL["match_exists"] = DF_TRAIN_ALL["match_npz"].apply(lambda x: isinstance(x, str) and Path(x).exists())
DF_TEST_ALL["tok_exists"]    = DF_TEST_ALL["tok_npz"].apply(lambda x: isinstance(x, str) and Path(x).exists())
DF_TEST_ALL["match_exists"]  = DF_TEST_ALL["match_npz"].apply(lambda x: isinstance(x, str) and Path(x).exists())

print("RUN_DIR     :", RUN_DIR)
print("CFG_ID      :", CFG_ID)
print("TOKEN_ROOT  :", TOKEN_ROOT)
print("MATCH_ROOT  :", MATCH_ROOT)
print("-"*60)
print("Train rows  :", len(DF_TRAIN_ALL),
      "| tok_exists:", int(DF_TRAIN_ALL["tok_exists"].sum()),
      "| match_exists:", int(DF_TRAIN_ALL["match_exists"].sum()),
      "| forged%:", float(DF_TRAIN_ALL["y"].mean())*100)
print("Test rows   :", len(DF_TEST_ALL),
      "| tok_exists:", int(DF_TEST_ALL["tok_exists"].sum()),
      "| match_exists:", int(DF_TEST_ALL["match_exists"].sum()))

# ----------------------------
# 4) Infer token grid size (HTOK, WTOK)
# ----------------------------
HTOK = WTOK = None
cfg_tok_json = TOKEN_ROOT / "cfg.json"

def _infer_hw_from_npz(npz_path: Path):
    z = np.load(npz_path)
    keys = list(z.keys())
    pref = ["tok","tokens","x","feat","grid","emb","embedding"]
    for k in pref:
        if k in keys:
            a = z[k]
            if a.ndim == 3:
                return int(a.shape[0]), int(a.shape[1])
            if a.ndim == 2:
                n = int(a.shape[0])
                s = int(np.sqrt(n))
                if s*s == n:
                    return s, s
    for k in keys:
        a = z[k]
        if getattr(a, "ndim", 0) == 3:
            return int(a.shape[0]), int(a.shape[1])
    for k in keys:
        a = z[k]
        if getattr(a, "ndim", 0) == 2:
            n = int(a.shape[0])
            s = int(np.sqrt(n))
            if s*s == n:
                return s, s
    return None, None

if cfg_tok_json.exists():
    try:
        tok_cfg = json.loads(cfg_tok_json.read_text())
        for kH, kW in [("HTOK","WTOK"), ("htok","wtok"), ("Htok","Wtok")]:
            if kH in tok_cfg and kW in tok_cfg:
                HTOK = int(tok_cfg[kH]); WTOK = int(tok_cfg[kW])
                break
    except Exception:
        pass

if HTOK is None or WTOK is None:
    for a,b in [("htok","wtok"),("Htok","Wtok"),("HTOK","WTOK")]:
        if a in df_tok_tr.columns and b in df_tok_tr.columns and df_tok_tr[a].notna().any():
            HTOK = int(df_tok_tr[a].dropna().iloc[0])
            WTOK = int(df_tok_tr[b].dropna().iloc[0])
            break

if HTOK is None or WTOK is None:
    sample = DF_TRAIN_ALL.loc[DF_TRAIN_ALL["tok_exists"], "tok_npz"]
    if len(sample) == 0:
        files = sorted((TOKEN_ROOT / "train").glob("*.npz"))
        if not files:
            raise RuntimeError("Cannot infer HTOK/WTOK: no token npz found.")
        sp = files[0]
    else:
        sp = Path(sample.iloc[0])
    HTOK, WTOK = _infer_hw_from_npz(sp)

if HTOK is None or WTOK is None:
    raise RuntimeError("Failed to infer HTOK/WTOK from cfg/manifest/npz.")

print("Token grid  :", (HTOK, WTOK))

# ----------------------------
# 5) Precompute GT token masks (TRAIN only)
# NOTE: file name uses uid to avoid collision if case_id duplicates exist
# ----------------------------
GT_TOK_DIR = RUN_DIR / "seg" / f"gt_tok_masks_{CFG_ID}"
GT_TOK_TRAIN = GT_TOK_DIR / "train"
GT_TOK_TRAIN.mkdir(parents=True, exist_ok=True)

def load_gt_union_fullres(case_id: int, mask_path: str = None):
    if mask_path and isinstance(mask_path, str):
        p = Path(mask_path)
        if p.exists():
            try:
                im = Image.open(p).convert("L")
                return (np.asarray(im) > 0)
            except Exception:
                pass

    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if npy.exists():
            a = np.asarray(np.load(npy, mmap_mode="r"))
            if a.ndim == 2:
                return (a > 0)
            if a.ndim == 3:
                return (a > 0).any(axis=0)

    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    pats = [f"{int(case_id)}*.png", f"{int(case_id)}*.jpg", f"{int(case_id)}*.jpeg",
            f"{int(case_id)}*.tif", f"{int(case_id)}*.tiff", f"{int(case_id)}*.bmp"]
    files = []
    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        for pat in pats:
            files += list(d.glob(pat))
    files = [p for p in files if p.suffix.lower() in exts]
    if not files:
        return None

    m = None
    for p in files:
        try:
            im = Image.open(p).convert("L")
            a = (np.asarray(im) > 0)
            m = a if m is None else (m | a)
        except Exception:
            continue
    return m

def downsample_to_tok(mask_bool, htok, wtok):
    if mask_bool is None:
        return np.zeros((htok, wtok), dtype=np.uint8)
    im = Image.fromarray((mask_bool.astype(np.uint8) * 255))
    im = im.resize((wtok, htok), resample=Image.NEAREST)
    return (np.asarray(im) > 127).astype(np.uint8)

gt_tok_paths = []
gt_area_fracs = []
rebuilt = 0
t0 = time.time()

for i, r in DF_TRAIN_ALL.iterrows():
    uid = str(r["uid"])
    cid = int(r["case_id"])
    y   = int(r["y"])
    outp = GT_TOK_TRAIN / f"{uid}.npz"

    if outp.exists():
        gt_tok_paths.append(str(outp))
        try:
            a = np.load(outp)["m"]
            gt_area_fracs.append(float(a.mean()))
        except Exception:
            gt_area_fracs.append(np.nan)
        continue

    if y == 0:
        m_tok = np.zeros((HTOK, WTOK), dtype=np.uint8)
    else:
        gt_full = load_gt_union_fullres(cid, r.get("mask_path", None))
        m_tok = downsample_to_tok(gt_full, HTOK, WTOK)

    np.savez_compressed(outp, m=m_tok)
    gt_tok_paths.append(str(outp))
    gt_area_fracs.append(float(m_tok.mean()))
    rebuilt += 1

    if (i + 1) % 500 == 0:
        print(f"[gt_tok] {i+1}/{len(DF_TRAIN_ALL)} | rebuilt={rebuilt} | {time.time()-t0:.1f}s")

DF_TRAIN_ALL["gt_tok_npz"] = gt_tok_paths
DF_TRAIN_ALL["gt_area_frac_tok"] = gt_area_fracs

# ----------------------------
# 6) Save tables
# ----------------------------
train_out = RUN_DIR / "features" / "train_table.parquet"
test_out  = RUN_DIR / "features" / "test_table.parquet"
DF_TRAIN_ALL.to_parquet(train_out, index=False)
DF_TEST_ALL.to_parquet(test_out, index=False)

print("-"*60)
print("Saved:", train_out)
print("Saved:", test_out)
print("GT_TOK_DIR:", GT_TOK_DIR)
print("GT rebuilt:", rebuilt, "/", len(DF_TRAIN_ALL), "| time:", f"{time.time()-t0:.1f}s")

# Export globals
DF_TRAIN_ALL = DF_TRAIN_ALL
DF_TEST_ALL  = DF_TEST_ALL
GT_TOK_DIR   = GT_TOK_DIR
HTOK, WTOK   = HTOK, WTOK


RUN_DIR     : /kaggle/working/recodai_luc_runs/run_4747afaf927b
CFG_ID      : 4747afaf927b
TOKEN_ROOT  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_bind_9894bfdb484a
MATCH_ROOT  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_cfg_2ed747746f9c
------------------------------------------------------------
Train rows  : 2795 | tok_exists: 0 | match_exists: 2795 | forged%: 100.0
Test rows   : 1 | tok_exists: 0 | match_exists: 1
Token grid  : (37, 37)
[gt_tok] 500/2795 | rebuilt=500 | 12.5s
[gt_tok] 1000/2795 | rebuilt=1000 | 23.1s
[gt_tok] 1500/2795 | rebuilt=1500 | 35.0s
[gt_tok] 2000/2795 | rebuilt=2000 | 47.6s
[gt_tok] 2500/2795 | rebuilt=2500 | 59.8s
------------------------------------------------------------
Saved: /kaggle/working/recodai_luc_runs/run_4747afaf927b/features/train_table.parquet
Saved: /kaggle/working/recodai_luc_runs/run_4747afaf927b/features/test_table.parquet
GT_TOK_DIR: /kaggle/working/recodai_luc_runs/run_4747afaf927b

# Build & Export Test Feature Table

In [3]:
# ============================================================
# STAGE — Build & Export Test Feature Table (ONE CELL) — FIX: AUTO feature_cols.json
# - If feature_cols.json missing, infer from pred_features_{train,test}.csv numeric columns
# - Auto-pick PROF_DIR + pred_features_test.csv from /kaggle/working or /kaggle/input
# - Align to sample_submission order
# Outputs:
#   /kaggle/working/recodai_luc_gate_artifacts/test_table.parquet
#   /kaggle/working/recodai_luc_gate_artifacts/test_X.npy
#   /kaggle/working/recodai_luc_gate_artifacts/test_case_id.npy
#   /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json   (auto if missing)
# ============================================================

import os, json
from pathlib import Path
import numpy as np
import pandas as pd

WORK = Path("/kaggle/working")
INP  = Path("/kaggle/input")

ART_DIR = WORK / "recodai_luc_gate_artifacts"
ART_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 0) Find competition root + sample_submission
# ----------------------------
def find_comp_root():
    cand = INP / "recodai-luc-scientific-image-forgery-detection"
    if cand.exists() and (cand / "sample_submission.csv").exists():
        return cand
    for d in sorted(INP.glob("*")):
        if not d.is_dir():
            continue
        if (d / "sample_submission.csv").exists() and (d / "train_images").exists() and (d / "test_images").exists():
            return d
    raise FileNotFoundError("Cannot find competition dataset under /kaggle/input (need sample_submission.csv + train_images + test_images).")

COMP_ROOT  = find_comp_root()
SAMPLE_SUB = COMP_ROOT / "sample_submission.csv"

# ----------------------------
# 1) Find PROF_DIR (working first, else input)
# ----------------------------
def pick_prof_dir():
    w = WORK / "recodai_luc_prof"
    if (w / "test_manifest.parquet").exists():
        return w
    # try common input layouts
    cands = []
    for p in INP.glob("*/recodai_luc_prof"):
        if (p / "test_manifest.parquet").exists():
            cands.append(p)
    for p in INP.glob("*/*/recodai_luc_prof"):
        if (p / "test_manifest.parquet").exists():
            cands.append(p)
    if cands:
        cands = sorted(cands, key=lambda x: (x / "test_manifest.parquet").stat().st_mtime, reverse=True)
        return cands[0]
    # last resort
    for pj in INP.rglob("recodai_luc_prof/test_manifest.parquet"):
        return pj.parent
    raise FileNotFoundError("Cannot find recodai_luc_prof/test_manifest.parquet in working or input.")

PROF_DIR = pick_prof_dir()
test_manifest_pq = PROF_DIR / "test_manifest.parquet"
if not test_manifest_pq.exists():
    raise FileNotFoundError(f"Missing {test_manifest_pq}")

# ----------------------------
# 2) Find pred_features_{test,train}.csv (auto)
# ----------------------------
def newest_glob(base: Path, pattern: str):
    hits = list(base.glob(pattern))
    if not hits:
        return None
    hits = sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)
    return hits[0]

def find_pred_features(name="test"):
    # prefer working cache first
    w_cands = [
        newest_glob(WORK, f"recodai_luc/cache/pred_ens/pred_features_{name}.csv"),
        newest_glob(WORK, f"recodai_luc/cache/**/pred_features_{name}.csv"),
        newest_glob(WORK, f"**/pred_features_{name}.csv"),
    ]
    for c in w_cands:
        if c is not None and c.exists():
            return c

    # input datasets (your screenshot: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/...)
    i_hits = list(INP.glob(f"**/pred_features_{name}.csv"))
    if i_hits:
        i_hits = sorted(i_hits, key=lambda p: p.stat().st_mtime, reverse=True)
        return i_hits[0]
    return None

pred_feat_test = find_pred_features("test")
pred_feat_train = find_pred_features("train")

if pred_feat_test is None or (not pred_feat_test.exists()):
    raise FileNotFoundError("pred_features_test.csv not found in working or input.")

print("PROF_DIR        :", PROF_DIR)
print("pred_feat_test  :", pred_feat_test)
print("pred_feat_train :", pred_feat_train if pred_feat_train else "None (ok)")

# ----------------------------
# 3) Load/Build FEATURE_COLS
# ----------------------------
cols_path = ART_DIR / "feature_cols.json"

def ensure_case_id_col(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "case_id" in df.columns:
        df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
    else:
        for c in ["uid", "id", "uid_safe", "npz_path"]:
            if c in df.columns:
                src = df[c].astype(str)
                df["case_id"] = pd.to_numeric(src.str.extract(r"(\d+)")[0], errors="coerce")
                break
    if "case_id" not in df.columns:
        raise ValueError("Cannot infer case_id from pred_features csv.")
    df = df[df["case_id"].notna()].copy()
    df["case_id"] = df["case_id"].astype(int)
    return df

DROP_KEYS = {
    "case_id","uid","uid_safe","id","y","fold","split","img_path","mask_path",
    "tok_npz","match_npz","npz_path","path","file"
}

def infer_feature_cols(df: pd.DataFrame):
    df = df.copy()
    df = ensure_case_id_col(df)
    # candidate cols excluding ids/meta
    cand = [c for c in df.columns if c not in DROP_KEYS]
    if not cand:
        return []
    # coerce numeric & keep numeric-ish
    num_cols = []
    for c in cand:
        s = pd.to_numeric(df[c], errors="coerce")
        if s.notna().any():
            num_cols.append(c)
    return sorted(num_cols)

if cols_path.exists():
    FEATURE_COLS = json.loads(cols_path.read_text())
    print("Loaded FEATURE_COLS:", len(FEATURE_COLS), "from", cols_path)
else:
    # infer from train if available (preferred), else from test
    if pred_feat_train and pred_feat_train.exists():
        df_tmp = pd.read_csv(pred_feat_train)
        FEATURE_COLS = infer_feature_cols(df_tmp)
        src_used = str(pred_feat_train)
    else:
        df_tmp = pd.read_csv(pred_feat_test)
        FEATURE_COLS = infer_feature_cols(df_tmp)
        src_used = str(pred_feat_test)

    if len(FEATURE_COLS) == 0:
        raise RuntimeError("Failed to infer FEATURE_COLS (no usable numeric columns).")

    cols_path.write_text(json.dumps(FEATURE_COLS, indent=2))
    print("Created FEATURE_COLS:", len(FEATURE_COLS), "->", cols_path)
    print("Inferred from:", src_used)

# ----------------------------
# 4) Build ordered test_meta (sample_submission order)
# ----------------------------
df_sub = pd.read_csv(SAMPLE_SUB)
id_candidates = [c for c in ["case_id", "id", "uid"] if c in df_sub.columns]
id_col = id_candidates[0] if id_candidates else df_sub.columns[0]
sub_ids = pd.to_numeric(df_sub[id_col], errors="coerce").dropna().astype(int).tolist()
if len(sub_ids) == 0:
    raise RuntimeError("No numeric ids found in sample_submission.")

df_test_meta = pd.read_parquet(test_manifest_pq).copy()
if "case_id" not in df_test_meta.columns:
    raise ValueError("test_manifest.parquet must have case_id column.")
df_test_meta["case_id"] = pd.to_numeric(df_test_meta["case_id"], errors="coerce")
df_test_meta = df_test_meta[df_test_meta["case_id"].notna()].copy()
df_test_meta["case_id"] = df_test_meta["case_id"].astype(int)

base = pd.DataFrame({"case_id": sub_ids})
df_test_meta = base.merge(df_test_meta, on="case_id", how="left")

# ----------------------------
# 5) Load pred features + dedup
# ----------------------------
df_feat = pd.read_csv(pred_feat_test)
df_feat = ensure_case_id_col(df_feat)

keep_cols = ["case_id"] + [c for c in FEATURE_COLS if c in df_feat.columns]
df_feat = df_feat[keep_cols].copy()

sort_keys = []
for k in ["best_peak_score", "area_frac", "area_frac_tok"]:
    if k in df_feat.columns:
        sort_keys.append(k)

if sort_keys:
    df_feat = df_feat.sort_values(sort_keys, ascending=False).drop_duplicates("case_id", keep="first")
else:
    df_feat = df_feat.drop_duplicates("case_id", keep="first")

# ----------------------------
# 6) Join + fill
# ----------------------------
df_test_tabular = df_test_meta.merge(df_feat, on="case_id", how="left")

for c in FEATURE_COLS:
    if c not in df_test_tabular.columns:
        df_test_tabular[c] = np.nan
    df_test_tabular[c] = pd.to_numeric(df_test_tabular[c], errors="coerce")

fill_value = 0.0
n_missing_before = int(df_test_tabular[FEATURE_COLS].isna().sum().sum())
df_test_tabular[FEATURE_COLS] = df_test_tabular[FEATURE_COLS].fillna(fill_value)
n_missing_after = int(df_test_tabular[FEATURE_COLS].isna().sum().sum())

# arrays
X_test = df_test_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
case_id_test = df_test_tabular["case_id"].to_numpy(dtype=np.int64, copy=True)

# save
df_test_tabular.to_parquet(ART_DIR / "test_table.parquet", index=False)
np.save(ART_DIR / "test_X.npy", X_test)
np.save(ART_DIR / "test_case_id.npy", case_id_test)

print("Saved:")
print(" -", ART_DIR / "feature_cols.json")
print(" -", ART_DIR / "test_table.parquet")
print(" -", ART_DIR / "test_X.npy")
print(" -", ART_DIR / "test_case_id.npy")
print("-"*60)
print("Test rows:", len(df_test_tabular), "| missing_before:", n_missing_before, "| missing_after:", n_missing_after)

# Export globals
FEATURE_COLS = FEATURE_COLS
X_test = X_test
case_id_test = case_id_test
df_test_tabular = df_test_tabular


PROF_DIR        : /kaggle/working/recodai_luc_prof
pred_feat_test  : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/pred_features_test.csv
pred_feat_train : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/pred_features_train.csv
Created FEATURE_COLS: 8 -> /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
Inferred from: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/pred_features_train.csv
Saved:
 - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
 - /kaggle/working/recodai_luc_gate_artifacts/test_table.parquet
 - /kaggle/working/recodai_luc_gate_artifacts/test_X.npy
 - /kaggle/working/recodai_luc_gate_artifacts/test_case_id.npy
------------------------------------------------------------
Test rows: 1 | missing_before: 0 | missing_after: 0


# Train Baseline Model (Leakage-Safe CV)

In [4]:
# ============================================================
# STAGE — Train Gate Model (NO-FOLD, Ensemble Multi-Seed, CV OOF, Calibrate, Dice-Proxy Thr)
# REVISI FULL (sesuai pipeline: DINO dense + matching + seg decoder + gate + ensemble)
#
# Input minimal:
# - /kaggle/input/.../recodai_luc/cache/pred_ens/pred_features_train.csv  (atau di /kaggle/working)
# - pred masks: pred_ens/train/{case_id}.npz  (key: mask/m/pred/...)
# - GT masks: train_masks + supplemental_masks
# - train_manifest.parquet (recodai_luc_prof) untuk y & case_id
#
# Output:
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - models/ (final per-seed models + calibrator)
# - oof/oof_prob*.npy + oof_prob.csv
# - eval/threshold_table.csv + best_threshold.json + cv_metrics.json
# - final_gate_model.pt  (portable)
# ============================================================

import os, json, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# Paths
# ----------------------------
WORK = Path("/kaggle/working")
INP  = Path("/kaggle/input")

ART_DIR   = WORK / "recodai_luc_gate_artifacts"
MODEL_DIR = ART_DIR / "models"
OOF_DIR   = ART_DIR / "oof"
EVAL_DIR  = ART_DIR / "eval"
for d in [ART_DIR, MODEL_DIR, OOF_DIR, EVAL_DIR]:
    d.mkdir(parents=True, exist_ok=True)

cols_path = ART_DIR / "feature_cols.json"
cfg_path  = ART_DIR / "train_cfg.json"  # optional (kalau ada dipakai), kalau tidak -> auto default

# ----------------------------
# Helpers: find competition root + prof dir + pred_features
# ----------------------------
def find_comp_root():
    cand = INP / "recodai-luc-scientific-image-forgery-detection"
    if cand.exists() and (cand / "sample_submission.csv").exists():
        return cand
    for d in sorted(INP.glob("*")):
        if d.is_dir() and (d / "sample_submission.csv").exists() and (d / "train_images").exists() and (d / "test_images").exists():
            return d
    raise FileNotFoundError("Cannot find competition dataset under /kaggle/input.")

def pick_prof_dir():
    w = WORK / "recodai_luc_prof"
    if (w / "train_manifest.parquet").exists():
        return w
    cands = []
    for p in INP.glob("*/recodai_luc_prof"):
        if (p / "train_manifest.parquet").exists():
            cands.append(p)
    for p in INP.glob("*/*/recodai_luc_prof"):
        if (p / "train_manifest.parquet").exists():
            cands.append(p)
    if cands:
        cands = sorted(cands, key=lambda x: (x / "train_manifest.parquet").stat().st_mtime, reverse=True)
        return cands[0]
    for pj in INP.rglob("recodai_luc_prof/train_manifest.parquet"):
        return pj.parent
    raise FileNotFoundError("Cannot find recodai_luc_prof/train_manifest.parquet in working or input.")

def newest_glob(base: Path, pattern: str):
    hits = list(base.glob(pattern))
    if not hits:
        return None
    hits = sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)
    return hits[0]

def find_pred_features_train():
    # working first
    cands = [
        newest_glob(WORK, "recodai_luc/cache/pred_ens/pred_features_train.csv"),
        newest_glob(WORK, "recodai_luc/cache/**/pred_features_train.csv"),
        newest_glob(WORK, "**/pred_features_train.csv"),
    ]
    for c in cands:
        if c is not None and c.exists():
            return c

    hits = list(INP.glob("**/pred_features_train.csv"))
    if hits:
        hits = sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)
        return hits[0]
    return None

def find_pred_mask_train_dir():
    # try common
    cands = [
        WORK / "recodai_luc/cache/pred_ens/train",
    ]
    for c in cands:
        if c.exists():
            return c
    # input
    hits = list(INP.glob("**/pred_ens/train"))
    hits = [h for h in hits if h.is_dir()]
    if hits:
        hits = sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)
        return hits[0]
    return None

def ensure_case_id_col(df: pd.DataFrame) -> pd.DataFrame:
    df = df.copy()
    if "case_id" in df.columns:
        df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
    else:
        for c in ["uid", "id", "uid_safe", "npz_path", "path", "file"]:
            if c in df.columns:
                src = df[c].astype(str)
                df["case_id"] = pd.to_numeric(src.str.extract(r"(\d+)")[0], errors="coerce")
                break
    if "case_id" not in df.columns:
        raise ValueError("Cannot infer case_id.")
    df = df[df["case_id"].notna()].copy()
    df["case_id"] = df["case_id"].astype(int)
    return df

# ----------------------------
# Load CFG (optional)
# ----------------------------
CFG = {}
if cfg_path.exists():
    try:
        CFG = json.loads(cfg_path.read_text())
    except Exception:
        CFG = {}

COMP_ROOT = Path(CFG.get("paths", {}).get("COMP_ROOT", str(find_comp_root())))
PROF_DIR  = Path(CFG.get("paths", {}).get("PROF_DIR", str(pick_prof_dir())))

TRAIN_MASK_DIR = Path(CFG.get("paths", {}).get("TRAIN_MASK_DIR", str(COMP_ROOT / "train_masks")))
SUP_MASK_DIR   = Path(CFG.get("paths", {}).get("SUP_MASK_DIR",   str(COMP_ROOT / "supplemental_masks")))

pred_feat_train = Path(CFG.get("paths", {}).get("PRED_FEATURES_TRAIN", "")) if CFG.get("paths", {}).get("PRED_FEATURES_TRAIN") else None
if pred_feat_train is None or (not pred_feat_train.exists()):
    pred_feat_train = find_pred_features_train()
if pred_feat_train is None or (not Path(pred_feat_train).exists()):
    raise FileNotFoundError("pred_features_train.csv not found (working/input).")

PRED_TRAIN_DIR = Path(CFG.get("paths", {}).get("PRED_TRAIN_DIR", "")) if CFG.get("paths", {}).get("PRED_TRAIN_DIR") else None
if PRED_TRAIN_DIR is None or (not PRED_TRAIN_DIR.exists()):
    PRED_TRAIN_DIR = find_pred_mask_train_dir()
if PRED_TRAIN_DIR is None or (not PRED_TRAIN_DIR.exists()):
    raise FileNotFoundError("pred_ens/train directory not found (needs {case_id}.npz masks).")

print("COMP_ROOT     :", COMP_ROOT)
print("PROF_DIR      :", PROF_DIR)
print("pred_features :", pred_feat_train)
print("PRED_TRAIN_DIR:", PRED_TRAIN_DIR)
print("GT mask dirs  :", TRAIN_MASK_DIR, "|", SUP_MASK_DIR)
print("-"*80)

# ----------------------------
# Build training table for gate: train_manifest + pred_features_train
# ----------------------------
train_manifest_pq = PROF_DIR / "train_manifest.parquet"
if not train_manifest_pq.exists():
    raise FileNotFoundError(f"Missing {train_manifest_pq}")

df_man = pd.read_parquet(train_manifest_pq).copy()
df_man = ensure_case_id_col(df_man)
if "y" not in df_man.columns:
    raise ValueError("train_manifest.parquet must contain y.")
df_man["y"] = pd.to_numeric(df_man["y"], errors="coerce").fillna(0).astype(int)
df_man = df_man.drop_duplicates("case_id", keep="first")[["case_id", "y"]].copy()

df_feat = pd.read_csv(pred_feat_train)
df_feat = ensure_case_id_col(df_feat)

# ----------------------------
# FEATURE_COLS: load or infer from df_feat numeric columns
# ----------------------------
DROP_KEYS = {
    "case_id","uid","uid_safe","id","y","fold","split","img_path","mask_path",
    "tok_npz","match_npz","npz_path","path","file"
}

def infer_feature_cols_from_feat(df: pd.DataFrame):
    cand = [c for c in df.columns if c not in DROP_KEYS]
    num_cols = []
    for c in cand:
        s = pd.to_numeric(df[c], errors="coerce")
        if s.notna().any():
            num_cols.append(c)
    return sorted(num_cols)

if cols_path.exists():
    FEATURE_COLS = json.loads(cols_path.read_text())
else:
    FEATURE_COLS = infer_feature_cols_from_feat(df_feat)
    if len(FEATURE_COLS) == 0:
        raise RuntimeError("Failed to infer FEATURE_COLS from pred_features_train.csv.")
    cols_path.write_text(json.dumps(FEATURE_COLS, indent=2))

print("FEATURE_COLS:", len(FEATURE_COLS), "| saved:", cols_path)

# keep only needed cols
keep_cols = ["case_id"] + [c for c in FEATURE_COLS if c in df_feat.columns]
df_feat = df_feat[keep_cols].copy()

# dedup per case_id (prefer largest signal)
sort_keys = [k for k in ["best_peak_score","area_frac","area_frac_tok"] if k in df_feat.columns]
if sort_keys:
    df_feat = df_feat.sort_values(sort_keys, ascending=False).drop_duplicates("case_id", keep="first")
else:
    df_feat = df_feat.drop_duplicates("case_id", keep="first")

df_train_tab = df_man.merge(df_feat, on="case_id", how="left")

# ensure all feature cols exist
for c in FEATURE_COLS:
    if c not in df_train_tab.columns:
        df_train_tab[c] = np.nan
    df_train_tab[c] = pd.to_numeric(df_train_tab[c], errors="coerce")

fill_value = float(CFG.get("features", {}).get("missing_numeric_fill", 0.0))
df_train_tab[FEATURE_COLS] = df_train_tab[FEATURE_COLS].fillna(fill_value)

# save table for debugging/consistency
train_table_pq = ART_DIR / "train_table.parquet"
df_train_tab.to_parquet(train_table_pq, index=False)

X = df_train_tab[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
y = df_train_tab["y"].to_numpy(dtype=np.int64, copy=True)
case_ids = df_train_tab["case_id"].to_numpy(dtype=np.int64, copy=True)

N = len(y)
print("Train rows:", N, "| pos_rate:", float(y.mean()))
print("-"*80)

# ----------------------------
# Model: LightGBM preferred, fallback sklearn
# ----------------------------
use_lgbm = True
try:
    import lightgbm as lgb
except Exception:
    use_lgbm = False

from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_auc_score, log_loss, f1_score, precision_score, recall_score
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
import joblib

# SciPy optional (CC filtering)
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

# ----------------------------
# Gate ensemble config (multi-seed + CV OOF)
# ----------------------------
gate_cfg = CFG.get("gate_model", {})
seeds = gate_cfg.get("seeds", [42, 202, 777])
n_splits = int(gate_cfg.get("n_splits", 5))
early_rounds = int(gate_cfg.get("early_stopping_rounds", 200))

# sensible default LGB params (kalau cfg tidak isi)
default_lgb_params = {
    "objective": "binary",
    "metric": "auc",
    "learning_rate": 0.03,
    "num_leaves": 63,
    "min_data_in_leaf": 40,
    "feature_fraction": 0.9,
    "bagging_fraction": 0.8,
    "bagging_freq": 1,
    "lambda_l2": 1.0,
    "verbosity": -1,
}
params = gate_cfg.get("params", default_lgb_params)
# keep n_estimators as upper bound
num_boost_round = int(params.get("n_estimators", 4000))

print("Model:", "LightGBM" if use_lgbm else "HistGradientBoosting (fallback)")
print("Seeds:", seeds, "| n_splits:", n_splits)
print("-"*80)

# ----------------------------
# 1) CV OOF across seeds (accumulate sum/count -> averaged oof)
# ----------------------------
oof_sum = np.zeros(N, dtype=np.float32)
oof_cnt = np.zeros(N, dtype=np.int32)
cv_logs = []
best_iters = []

if use_lgbm:
    for s in seeds:
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=int(s))
        for k, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(N), y)):
            dtr = lgb.Dataset(X[tr_idx], label=y[tr_idx], feature_name=FEATURE_COLS, free_raw_data=True)
            dva = lgb.Dataset(X[va_idx], label=y[va_idx], feature_name=FEATURE_COLS, free_raw_data=True)

            # seed per run
            run_params = dict(params)
            run_params["seed"] = int(s)
            run_params["feature_fraction_seed"] = int(s) + 1
            run_params["bagging_seed"] = int(s) + 2

            booster = lgb.train(
                params=run_params,
                train_set=dtr,
                valid_sets=[dva],
                valid_names=["val"],
                num_boost_round=num_boost_round,
                callbacks=[lgb.early_stopping(stopping_rounds=early_rounds, verbose=False)]
            )

            p = booster.predict(X[va_idx], num_iteration=booster.best_iteration)
            p = np.clip(p, 1e-6, 1-1e-6).astype(np.float32)
            oof_sum[va_idx] += p
            oof_cnt[va_idx] += 1
            best_iters.append(int(booster.best_iteration))

            auc = float(roc_auc_score(y[va_idx], p)) if len(np.unique(y[va_idx])) > 1 else float("nan")
            ll  = float(log_loss(y[va_idx], p, labels=[0,1]))
            pred05 = (p >= 0.5).astype(int)
            f1v = float(f1_score(y[va_idx], pred05))

            cv_logs.append({
                "seed": int(s), "split": int(k),
                "auc": auc, "logloss": ll, "f1@0.5": f1v,
                "best_iter": int(booster.best_iteration),
            })
            print(f"[seed {s} split {k}] auc={auc:.4f} logloss={ll:.4f} f1@0.5={f1v:.4f} iter={int(booster.best_iteration)}")
else:
    # fallback (no LGBM): use HistGradientBoosting
    from sklearn.ensemble import HistGradientBoostingClassifier
    for s in seeds:
        skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=int(s))
        for k, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(N), y)):
            clf = HistGradientBoostingClassifier(
                learning_rate=0.05,
                max_leaf_nodes=63,
                min_samples_leaf=40,
                max_iter=600,
                random_state=int(s) + 1000 + int(k),
            )
            clf.fit(X[tr_idx], y[tr_idx])
            p = clf.predict_proba(X[va_idx])[:, 1]
            p = np.clip(p, 1e-6, 1-1e-6).astype(np.float32)
            oof_sum[va_idx] += p
            oof_cnt[va_idx] += 1

            auc = float(roc_auc_score(y[va_idx], p)) if len(np.unique(y[va_idx])) > 1 else float("nan")
            ll  = float(log_loss(y[va_idx], p, labels=[0,1]))
            pred05 = (p >= 0.5).astype(int)
            f1v = float(f1_score(y[va_idx], pred05))

            cv_logs.append({"seed": int(s), "split": int(k), "auc": auc, "logloss": ll, "f1@0.5": f1v})
            print(f"[seed {s} split {k}] auc={auc:.4f} logloss={ll:.4f} f1@0.5={f1v:.4f}")

oof_prob_raw = oof_sum / np.maximum(oof_cnt, 1).astype(np.float32)

auc_all = float(roc_auc_score(y, oof_prob_raw)) if len(np.unique(y)) > 1 else float("nan")
ll_all  = float(log_loss(y, oof_prob_raw, labels=[0,1]))
pred05_all = (oof_prob_raw >= 0.5).astype(int)
f1_all = float(f1_score(y, pred05_all))
prec_all = float(precision_score(y, pred05_all, zero_division=0))
rec_all  = float(recall_score(y, pred05_all, zero_division=0))

print("-"*80)
print(f"OOF raw (ensemble): auc={auc_all:.4f} logloss={ll_all:.4f} f1@0.5={f1_all:.4f} prec={prec_all:.4f} rec={rec_all:.4f}")

best_iter_mean = int(np.mean(best_iters)) if best_iters else None
best_iter_med  = int(np.median(best_iters)) if best_iters else None
print("best_iter_mean:", best_iter_mean, "| best_iter_median:", best_iter_med)
print("-"*80)

# ----------------------------
# 2) Calibration on OOF
# ----------------------------
cal_cfg = CFG.get("calibration", {"enabled": True, "method": "isotonic"})
cal_enabled = bool(cal_cfg.get("enabled", True))
cal_method = str(cal_cfg.get("method", "isotonic")).lower()

oof_prob = oof_prob_raw.copy()
cal_pack = {"enabled": False}

if cal_enabled:
    if cal_method == "isotonic":
        iso = IsotonicRegression(out_of_bounds="clip")
        iso.fit(oof_prob_raw, y)
        oof_prob = np.clip(iso.predict(oof_prob_raw).astype(np.float32), 1e-6, 1-1e-6)
        cal_pack = {"enabled": True, "method": "isotonic", "path": str(MODEL_DIR / "calibration_isotonic.joblib")}
        joblib.dump(iso, MODEL_DIR / "calibration_isotonic.joblib")
    elif cal_method in ["sigmoid", "platt"]:
        p = np.clip(oof_prob_raw, 1e-6, 1-1e-6)
        logit = np.log(p/(1-p)).reshape(-1,1)
        lr = LogisticRegression(solver="lbfgs", max_iter=200)
        lr.fit(logit, y)
        oof_prob = np.clip(lr.predict_proba(logit)[:,1].astype(np.float32), 1e-6, 1-1e-6)
        cal_pack = {"enabled": True, "method": "sigmoid", "path": str(MODEL_DIR / "calibration_sigmoid.joblib")}
        joblib.dump(lr, MODEL_DIR / "calibration_sigmoid.joblib")

auc_cal = float(roc_auc_score(y, oof_prob)) if len(np.unique(y)) > 1 else float("nan")
ll_cal  = float(log_loss(y, oof_prob, labels=[0,1]))
pred05_cal = (oof_prob >= 0.5).astype(int)
f1_cal = float(f1_score(y, pred05_cal))

(MODEL_DIR / "calibration.json").write_text(json.dumps(cal_pack, indent=2))
print(f"OOF calibrated: auc={auc_cal:.4f} logloss={ll_cal:.4f} f1@0.5={f1_cal:.4f} | calibration={cal_pack}")
print("-"*80)

# save OOF
np.save(OOF_DIR / "oof_prob_raw.npy", oof_prob_raw)
np.save(OOF_DIR / "oof_prob.npy", oof_prob)
pd.DataFrame({
    "case_id": case_ids.astype(int),
    "y": y.astype(int),
    "oof_prob_raw": oof_prob_raw.astype(float),
    "oof_prob": oof_prob.astype(float),
}).to_csv(OOF_DIR / "oof_prob.csv", index=False)

# ----------------------------
# 3) Dice-proxy (GT vs predicted masks) for threshold selection
# ----------------------------
PP = CFG.get("postprocess_ref", {})
MIN_INST_PIX  = int(PP.get("min_inst_pix", 32))
MAX_AREA_FRAC = float(PP.get("max_area_frac", 0.90))
MAX_INST_KEEP = int(PP.get("max_inst_keep", 8))

def _find_mask_files(mask_dir: Path, case_id: int):
    if mask_dir is None or (not mask_dir.exists()):
        return []
    cid = str(int(case_id))
    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    pats = [f"{cid}*.png", f"{cid}*.jpg", f"{cid}*.jpeg", f"{cid}*.tif", f"{cid}*.tiff", f"{cid}*.bmp"]
    out, seen = [], set()
    for pat in pats:
        for p in mask_dir.glob(pat):
            if p.suffix.lower() in exts:
                s = str(p)
                if s not in seen:
                    out.append(p); seen.add(s)
    return sorted(out)

def load_gt_union(case_id: int):
    # npy shortcut
    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if npy.exists():
            a = np.load(npy, mmap_mode="r")
            a = np.asarray(a)
            if a.ndim == 2:
                return (a > 0)
            if a.ndim == 3:
                return (a > 0).any(axis=0)
    # union images
    files = []
    files += _find_mask_files(TRAIN_MASK_DIR, case_id)
    files += _find_mask_files(SUP_MASK_DIR, case_id)
    if not files:
        return None
    m = None
    for p in files:
        try:
            im = Image.open(p).convert("L")
            a = (np.asarray(im) > 0)
            m = a if m is None else (m | a)
        except Exception:
            continue
    return m

def load_pred_union(case_id: int):
    p = PRED_TRAIN_DIR / f"{int(case_id)}.npz"
    if not p.exists():
        return None
    z = np.load(p)
    # support multiple keys
    for k in ["mask", "m", "pred", "pred_mask", "mask_full", "mask_up"]:
        if k in z.files:
            a = z[k]
            return (np.asarray(a) > 0)
    # fallback: first array
    for k in z.files:
        a = z[k]
        if getattr(a, "ndim", 0) == 2:
            return (np.asarray(a) > 0)
    return None

def cc_union_filtered(mask_bool: np.ndarray):
    if mask_bool is None:
        return None, {"n_inst": 0, "area": 0}
    m = mask_bool.astype(bool)
    H, W = m.shape
    area = int(m.sum())
    if area == 0:
        return m, {"n_inst": 0, "area": 0}
    if (area / float(H*W)) > MAX_AREA_FRAC:
        return np.zeros_like(m, dtype=bool), {"n_inst": 0, "area": 0}
    if not _HAS_SCIPY:
        if area < MIN_INST_PIX:
            return np.zeros_like(m, dtype=bool), {"n_inst": 0, "area": 0}
        return m, {"n_inst": 1, "area": area}
    lab, n = ndi.label(m, structure=np.ones((3,3), dtype=np.uint8))
    if n <= 0:
        return np.zeros_like(m, dtype=bool), {"n_inst": 0, "area": 0}
    areas = ndi.sum(m.astype(np.uint8), lab, index=np.arange(1, n+1)).astype(np.int64)
    keep = np.where(areas >= MIN_INST_PIX)[0]
    if keep.size == 0:
        return np.zeros_like(m, dtype=bool), {"n_inst": 0, "area": 0}
    keep = keep[np.argsort(areas[keep])[::-1]][:MAX_INST_KEEP]
    out = np.zeros_like(m, dtype=bool)
    for k in keep:
        out |= (lab == (k + 1))
    return out, {"n_inst": int(len(keep)), "area": int(out.sum())}

def dice_score(pr: np.ndarray, gt: np.ndarray):
    pr = pr.astype(bool); gt = gt.astype(bool)
    a = int(pr.sum()); b = int(gt.sum())
    if a == 0 and b == 0: return 1.0
    if a == 0 or b == 0: return 0.0
    inter = int((pr & gt).sum())
    return float((2.0 * inter) / (a + b))

dice_use  = np.zeros(N, dtype=np.float32)
dice_empty= np.zeros(N, dtype=np.float32)

miss_gt = miss_pr = bad_shape = 0
t1 = time.time()

for i, cid in enumerate(case_ids.tolist()):
    gt = load_gt_union(cid)
    pr = load_pred_union(cid)

    if gt is None:
        miss_gt += 1
        gt_mask = None
        gt_empty = True
    else:
        gt_mask = gt.astype(bool)
        gt_empty = (gt_mask.sum() == 0)

    dice_empty[i] = 1.0 if gt_empty else 0.0

    if pr is None:
        miss_pr += 1
        dice_use[i] = dice_empty[i]
        continue

    pr_mask = pr.astype(bool)
    if gt_mask is not None and pr_mask.shape != gt_mask.shape:
        bad_shape += 1
        im = Image.fromarray((pr_mask.astype(np.uint8)*255))
        im = im.resize((gt_mask.shape[1], gt_mask.shape[0]), resample=Image.NEAREST)
        pr_mask = (np.asarray(im) > 127)

    pr_f, _ = cc_union_filtered(pr_mask)
    if gt_mask is None:
        dice_use[i] = 1.0 if (pr_f.sum() == 0) else 0.0
    else:
        gt_f, _ = cc_union_filtered(gt_mask)
        dice_use[i] = dice_score(pr_f, gt_f)

    if (i + 1) % 500 == 0:
        print(f"[dice-proxy] {i+1}/{N} | miss_gt={miss_gt} miss_pr={miss_pr} bad_shape={bad_shape} | {time.time()-t1:.1f}s")

print("-"*80)
print(f"Dice-proxy ready | miss_gt={miss_gt} miss_pr={miss_pr} bad_shape={bad_shape} | {time.time()-t1:.1f}s")
print("-"*80)

# ----------------------------
# 4) Threshold sweep (optimize dice-proxy, tie-break by F1, then FP-rate)
# ----------------------------
thr_grid = np.linspace(0.0, 1.0, 201, dtype=np.float32)
rows = []
for thr in thr_grid:
    use = (oof_prob >= thr)
    score = np.where(use, dice_use, dice_empty).mean()

    pred = use.astype(int)
    f1v = f1_score(y, pred)
    prec = precision_score(y, pred, zero_division=0)
    rec  = recall_score(y, pred, zero_division=0)
    fp_rate = float(((pred==1) & (y==0)).mean())
    fn_rate = float(((pred==0) & (y==1)).mean())

    rows.append({
        "thr": float(thr),
        "score_dice_proxy": float(score),
        "f1_gate": float(f1v),
        "precision_gate": float(prec),
        "recall_gate": float(rec),
        "fp_rate": fp_rate,
        "fn_rate": fn_rate,
    })

df_thr = pd.DataFrame(rows)

# tie-break: maximize dice_proxy, then maximize f1, then minimize fp_rate
df_thr["_neg_fp"] = -df_thr["fp_rate"]
df_thr = df_thr.sort_values(["score_dice_proxy","f1_gate","_neg_fp"], ascending=[False,False,False]).reset_index(drop=True)
best_thr = float(df_thr.loc[0, "thr"])
best_score = float(df_thr.loc[0, "score_dice_proxy"])
df_thr = df_thr.drop(columns=["_neg_fp"])

df_thr.to_csv(EVAL_DIR / "threshold_table.csv", index=False)
(EVAL_DIR / "best_threshold.json").write_text(json.dumps({
    "recommended_thr": best_thr,
    "best_score_dice_proxy": best_score,
    "best_row": df_thr.loc[0].to_dict(),
    "instance_split_fullres": {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP},
}, indent=2))

print("BEST (dice-proxy): thr =", best_thr, "| score =", best_score)
print("-"*80)

# ----------------------------
# 5) Train FINAL models on full data (per-seed) for inference
# ----------------------------
final_models = []
if use_lgbm:
    # choose num_boost_round for final training
    final_rounds = best_iter_med if best_iter_med is not None else int(min(2000, num_boost_round))
    dfull = lgb.Dataset(X, label=y, feature_name=FEATURE_COLS, free_raw_data=True)
    for s in seeds:
        run_params = dict(params)
        run_params["seed"] = int(s)
        run_params["feature_fraction_seed"] = int(s) + 1
        run_params["bagging_seed"] = int(s) + 2
        booster = lgb.train(run_params, dfull, num_boost_round=int(final_rounds))
        mp = MODEL_DIR / f"final_seed{s}.txt"
        booster.save_model(str(mp))
        final_models.append({"type": "lgbm", "seed": int(s), "path": str(mp), "num_boost_round": int(final_rounds)})
else:
    from sklearn.ensemble import HistGradientBoostingClassifier
    for s in seeds:
        clf = HistGradientBoostingClassifier(
            learning_rate=0.05,
            max_leaf_nodes=63,
            min_samples_leaf=40,
            max_iter=600,
            random_state=int(s) + 9999,
        )
        clf.fit(X, y)
        mp = MODEL_DIR / f"final_seed{s}.joblib"
        joblib.dump(clf, mp)
        final_models.append({"type": "hgb", "seed": int(s), "path": str(mp)})

# ----------------------------
# 6) Save metrics + portable bundle
# ----------------------------
cv_metrics = {
    "oof_raw": {"auc": auc_all, "logloss": ll_all, "f1@0.5": f1_all, "prec": prec_all, "rec": rec_all},
    "oof_cal": {"auc": auc_cal, "logloss": ll_cal, "f1@0.5": f1_cal},
    "recommended_thr": best_thr,
    "best_score_dice_proxy": best_score,
    "cv_logs": cv_logs[:],  # can be large; ok
    "best_iter_mean": best_iter_mean,
    "best_iter_median": best_iter_med,
    "seeds": seeds,
    "n_splits": n_splits,
    "calibration": cal_pack,
    "feature_cols": FEATURE_COLS,
    "paths": {
        "pred_features_train": str(pred_feat_train),
        "pred_mask_train_dir": str(PRED_TRAIN_DIR),
        "train_manifest": str(train_manifest_pq),
    }
}
(EVAL_DIR / "cv_metrics.json").write_text(json.dumps(cv_metrics, indent=2))

import torch
torch.save({
    "cfg": CFG,
    "feature_cols": FEATURE_COLS,
    "final_models": final_models,
    "calibration": cal_pack,
    "recommended_thr": best_thr,
    "dice_proxy_cfg": {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP},
}, ART_DIR / "final_gate_model.pt")

print("Saved:")
print(" -", train_table_pq)
print(" -", cols_path)
print(" -", OOF_DIR / "oof_prob.csv")
print(" -", EVAL_DIR / "threshold_table.csv")
print(" -", EVAL_DIR / "best_threshold.json")
print(" -", EVAL_DIR / "cv_metrics.json")
print(" -", ART_DIR / "final_gate_model.pt")
print("Final models:", len(final_models), "| dir:", MODEL_DIR)


COMP_ROOT     : /kaggle/input/recodai-luc-scientific-image-forgery-detection
PROF_DIR      : /kaggle/working/recodai_luc_prof
pred_features : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/pred_features_train.csv
PRED_TRAIN_DIR: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/pred_ens/train
GT mask dirs  : /kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks | /kaggle/input/recodai-luc-scientific-image-forgery-detection/supplemental_masks
--------------------------------------------------------------------------------
FEATURE_COLS: 8 | saved: /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
Train rows: 2795 | pos_rate: 1.0
--------------------------------------------------------------------------------
Model: LightGBM
Seeds: [42, 202, 777] | n_splits: 5
--------------------------------------------------------------------------------
[seed 42 split 0] auc=nan logloss=0.0000 f1@0.5=1.0000 iter=1
[seed 42 split 1] auc=nan logloss=0

# Optimize Model & Hyperparameters (Iterative)

In [5]:
# ============================================================
# FIX v2 — PATCH train_table.parquet with valid tok_npz
# - Auto-detect TOKEN_MANIFEST_ROOT (where tokens_manifest_train.parquet exists)
# - Auto-detect TOKEN_DATA_ROOT (a sibling dir that actually contains token .npz)
# - Build fast basename->absolute_path index from TOKEN_DATA_ROOT/{train,test,train_all,test_all}
# - Join into train_table by case_id (fallback uid if needed)
# Output:
#   /kaggle/working/recodai_luc_gate_artifacts/train_table.patched_tokens.parquet
# ============================================================

import json, re
from pathlib import Path
import numpy as np
import pandas as pd

ART_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
TRAIN_TABLE_IN = ART_DIR / "train_table.parquet"
TRAIN_TABLE_OUT = ART_DIR / "train_table.patched_tokens.parquet"

if not TRAIN_TABLE_IN.exists():
    raise FileNotFoundError(f"Missing {TRAIN_TABLE_IN}")

# ------------------------------------------------------------
# 0) Locate base cache dir (from cfg if possible, else fallback)
# ------------------------------------------------------------
base_cache = None
cfg_try = [ART_DIR / "train_cfg.json", Path("/kaggle/working/recodai_luc_mask_artifacts/mask_cfg.json")]
for cp in cfg_try:
    if cp.exists():
        try:
            j = json.loads(cp.read_text())
            if "paths" in j and "TOKEN_ROOT" in j["paths"]:
                base_cache = Path(j["paths"]["TOKEN_ROOT"]).parent  # parent of token root
                break
        except Exception:
            pass

if base_cache is None:
    # fallback to your dataset structure
    base_cache = Path("/kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache")

if not base_cache.exists():
    raise FileNotFoundError(f"Base cache dir not found: {base_cache}")

print("BASE_CACHE:", base_cache)

# ------------------------------------------------------------
# 1) Find TOKEN_MANIFEST_ROOT: where tokens_manifest_train.parquet exists
# ------------------------------------------------------------
man_candidates = sorted(base_cache.glob("**/tokens_manifest_train.parquet"))
if not man_candidates:
    raise FileNotFoundError(f"Cannot find tokens_manifest_train.parquet under {base_cache}")

# prefer the one that sits next to a train/ folder OR most recently modified
def man_score(p: Path):
    root = p.parent
    has_train = int((root/"train").exists() or (root/"train_all").exists())
    return (has_train, p.stat().st_mtime)

man_candidates = sorted(man_candidates, key=man_score, reverse=True)
TOK_MAN_TRAIN = man_candidates[0]
TOKEN_MANIFEST_ROOT = TOK_MAN_TRAIN.parent

print("TOKEN_MANIFEST_ROOT:", TOKEN_MANIFEST_ROOT)
print("TOK_MAN_TRAIN      :", TOK_MAN_TRAIN)

df_tok = pd.read_parquet(TOK_MAN_TRAIN).copy()

# ------------------------------------------------------------
# 2) Pick TOKEN_DATA_ROOT: a dir under base_cache with most *.npz in {train,test,train_all,test_all}
#    (This fixes your 'bind' root that has no npz)
# ------------------------------------------------------------
def count_npz_under(root: Path):
    total = 0
    for sub in ["train", "test", "train_all", "test_all"]:
        d = root / sub
        if d.exists():
            total += sum(1 for _ in d.glob("*.npz"))
            total += sum(1 for _ in d.glob("*.npy"))
    return total

# candidate roots = direct children of base_cache that look like dinov2_*
root_candidates = [p for p in base_cache.iterdir() if p.is_dir()]
root_scored = [(count_npz_under(p), p) for p in root_candidates]
root_scored.sort(reverse=True, key=lambda x: x[0])

TOKEN_DATA_ROOT = root_scored[0][1]
best_cnt = root_scored[0][0]

print("TOKEN_DATA_ROOT    :", TOKEN_DATA_ROOT)
print("TOKEN_DATA_ROOT npz:", best_cnt)

if best_cnt == 0:
    # last resort: maybe tokens are deeper; try a broader scan but still limited
    deep = []
    for p in base_cache.glob("**/train"):
        if p.is_dir():
            deep.append(p.parent)
    deep = list(dict.fromkeys(deep))  # unique preserve order
    deep_scored = [(count_npz_under(p), p) for p in deep]
    deep_scored.sort(reverse=True, key=lambda x: x[0])
    if deep_scored and deep_scored[0][0] > 0:
        TOKEN_DATA_ROOT = deep_scored[0][1]
        best_cnt = deep_scored[0][0]
        print("DEEP PICK TOKEN_DATA_ROOT:", TOKEN_DATA_ROOT, "| npz:", best_cnt)

if best_cnt == 0:
    raise RuntimeError(
        "Still no .npz/.npy found in any candidate TOKEN_DATA_ROOT.\n"
        f"Check dataset mount under: {base_cache}"
    )

# ------------------------------------------------------------
# 3) Build basename -> abs path index (FAST, no rglob)
# ------------------------------------------------------------
name2path = {}
for sub in ["train", "test", "train_all", "test_all"]:
    d = TOKEN_DATA_ROOT / sub
    if not d.exists():
        continue
    for p in d.glob("*.npz"):
        name2path[p.name] = str(p)
    for p in d.glob("*.npy"):
        name2path[p.name] = str(p)

print("Indexed token files (unique basenames):", len(name2path))

# ------------------------------------------------------------
# 4) Normalize df_tok: ensure case_id & pick path column
# ------------------------------------------------------------
def infer_case_id_from_series(s: pd.Series) -> pd.Series:
    s = s.astype(str)
    out = s.str.extract(r"(\d+)")[0]
    return pd.to_numeric(out, errors="coerce")

if "case_id" not in df_tok.columns:
    for c in ["uid", "uid_safe", "id", "npz_path", "tok_npz", "path", "file"]:
        if c in df_tok.columns:
            df_tok["case_id"] = infer_case_id_from_series(df_tok[c])
            break

df_tok["case_id"] = pd.to_numeric(df_tok["case_id"], errors="coerce")
df_tok = df_tok[df_tok["case_id"].notna()].copy()
df_tok["case_id"] = df_tok["case_id"].astype(int)

# pick a token path column
def pick_token_path_col(df):
    priority = ["npz_path","tok_npz","token_npz","path","file"]
    for c in priority:
        if c in df.columns and df[c].astype(str).str.contains(r"\.npz$|\.npy$", regex=True, na=False).mean() > 0.2:
            return c
    # fallback: any column with many .npz/.npy strings
    best, best_rate = None, 0.0
    for c in df.columns:
        rate = df[c].astype(str).str.contains(r"\.npz$|\.npy$", regex=True, na=False).mean()
        if rate > best_rate:
            best, best_rate = c, rate
    return best

tok_path_col = pick_token_path_col(df_tok)
if tok_path_col is None:
    raise ValueError("Cannot find token path column in tokens_manifest_train.parquet")

df_tok["tok_raw"] = df_tok[tok_path_col].astype(str)

def resolve_tok(p):
    if p is None:
        return None
    s = str(p)
    if not s or s.lower() == "nan":
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    bn = pp.name
    if bn in name2path:
        return name2path[bn]

    # try tail after manifest root name (handles absolute from other root)
    s2 = s.replace("\\", "/")
    root_name = TOKEN_DATA_ROOT.name
    if root_name in s2:
        tail = s2.split(root_name, 1)[1].lstrip("/")
        cand = TOKEN_DATA_ROOT / tail
        if cand.exists():
            return str(cand)

    # try common folders
    for sub in ["train", "test", "train_all", "test_all"]:
        cand = TOKEN_DATA_ROOT / sub / bn
        if cand.exists():
            return str(cand)

    return None

df_tok["tok_npz"] = df_tok["tok_raw"].map(resolve_tok)

# dedup per case_id prefer existing
df_tok["_ok"] = df_tok["tok_npz"].map(lambda x: isinstance(x, str) and Path(x).exists())
df_tok = df_tok.sort_values(["case_id","_ok"], ascending=[True, False]).drop_duplicates("case_id", keep="first").drop(columns=["_ok"])

# ------------------------------------------------------------
# 5) Patch train_table.parquet
# ------------------------------------------------------------
df_table = pd.read_parquet(TRAIN_TABLE_IN).copy()
if "case_id" not in df_table.columns:
    raise ValueError("train_table.parquet must have case_id")
df_table["case_id"] = pd.to_numeric(df_table["case_id"], errors="coerce").fillna(-1).astype(int)

df_table = df_table.merge(df_tok[["case_id","tok_npz"]], on="case_id", how="left")
df_table["tok_exists"] = df_table["tok_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())

join_rate = float(df_table["tok_npz"].notna().mean())
exists_rate = float(df_table["tok_exists"].mean())

print("-"*70)
print("train_table rows:", len(df_table))
print("tok_npz not-null rate:", join_rate)
print("tok_npz exists rate   :", exists_rate)

if exists_rate < 0.90:
    bad = df_table[~df_table["tok_exists"]]
    print("Bad examples (case_id, tok_npz):")
    print(bad[["case_id","tok_npz"]].head(12).to_string(index=False))

df_table.to_parquet(TRAIN_TABLE_OUT, index=False)
print("-"*70)
print("Saved patched:", TRAIN_TABLE_OUT)


BASE_CACHE: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache
TOKEN_MANIFEST_ROOT: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500
TOK_MAN_TRAIN      : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500/tokens_manifest_train.parquet
TOKEN_DATA_ROOT    : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500
TOKEN_DATA_ROOT npz: 2796
Indexed token files (unique basenames): 2796
----------------------------------------------------------------------
train_table rows: 2795
tok_npz not-null rate: 1.0
tok_npz exists rate   : 1.0
----------------------------------------------------------------------
Saved patched: /kaggle/working/recodai_luc_gate_artifacts/train_table.patched_tokens.parquet


In [6]:
# ============================================================
# STAGE — Optimize Hybrid UNet(+ASPP) Token-Decoder + Gate Head (ONE CELL)
# REVISI FULL (NO-FOLD, TOKEN/MATCH PATH AUTO-RESOLVE via FILE INDEX, FAIL-SAFE SEED)
# ============================================================

import os, json, time, random, warnings, re
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import StratifiedKFold

# ----------------------------
# Config (env override)
# ----------------------------
OUT_DIR = Path("/kaggle/working/recodai_luc_hybrid_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)

WORK = Path("/kaggle/working")
INP  = Path("/kaggle/input")
CACHE_DIR = Path("/kaggle/working/recodai_luc/cache")

SEED = int(os.environ.get("SEED", "42"))
MAX_TRIALS = int(os.environ.get("MAX_TRIALS", "12"))
TRIAL_EPOCHS = int(os.environ.get("TRIAL_EPOCHS", "6"))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "2"))
ACCUM_STEPS = int(os.environ.get("ACCUM_STEPS", "1"))
USE_AMP = bool(int(os.environ.get("USE_AMP", "0")))  # CPU -> off
EARLYSTOP_PATIENCE = int(os.environ.get("EARLYSTOP_PATIENCE", "2"))

N_SPLITS = int(os.environ.get("N_SPLITS", "5"))
VAL_SPLIT_ROTATE = bool(int(os.environ.get("VAL_SPLIT_ROTATE", "1")))

LR_RANGE = (2e-4, 2e-3)
WD_RANGE = (0.0, 0.05)
DROPOUT_RANGE = (0.0, 0.25)
BASE_CH_CHOICES = [64, 96, 128]
LAMBDA_SEG_RANGE = (0.6, 1.2)
LAMBDA_CLS_RANGE = (0.3, 0.9)
FOCAL_GAMMA_CHOICES = [0.0, 1.0, 2.0]

T1_RANGE = (0.50, 0.65)
T0_RANGE = (0.25, 0.45)
SEED_DILATE_CHOICES = [0, 1, 2]
THR_GATE_RANGE = (0.20, 0.80)

MIN_TOK_AREA_CHOICES = [1, 2, 3]
MAX_TOK_AREA_FRAC_CHOICES = [0.70, 0.80, 0.90]
MAX_INST_KEEP_CHOICES = [4, 8, 12]

MIN_PEAK_SCORE_KEEP_CHOICES = [5, 6, 7]
MIN_AREA_FRAC_KEEP_CHOICES = [0.0003, 0.0005, 0.0010]

# ----------------------------
# Repro + device
# ----------------------------
def seed_everything(s=42):
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)

seed_everything(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp_ok = (USE_AMP and device.type == "cuda")
print("DEVICE:", device, "| AMP:", amp_ok)

# ----------------------------
# Find competition root
# ----------------------------
def find_comp_root():
    cand = INP / "recodai-luc-scientific-image-forgery-detection"
    if cand.exists() and (cand / "train_images").exists() and (cand / "sample_submission.csv").exists():
        return cand
    for d in sorted(INP.glob("*")):
        if d.is_dir() and (d / "train_images").exists() and (d / "sample_submission.csv").exists():
            return d
    raise FileNotFoundError("Cannot find competition dataset under /kaggle/input.")

COMP_ROOT = find_comp_root()
TRAIN_MASK_DIR = COMP_ROOT / "train_masks"
SUP_MASK_DIR   = COMP_ROOT / "supplemental_masks"
print("COMP_ROOT     :", COMP_ROOT)
print("TRAIN_MASK_DIR:", TRAIN_MASK_DIR, "| exists:", TRAIN_MASK_DIR.exists())
print("SUP_MASK_DIR  :", SUP_MASK_DIR,   "| exists:", SUP_MASK_DIR.exists())

# ----------------------------
# Helpers
# ----------------------------
def infer_case_id_any(x):
    if x is None:
        return np.nan
    s = str(x)
    if s.lower() in ["nan","none",""]:
        return np.nan
    m = re.search(r"\b(\d{3,})\b", s)
    if m:
        return float(m.group(1))
    m = re.search(r"(\d+)", s)
    return float(m.group(1)) if m else np.nan

def build_npz_index(root: Path):
    """
    Build filename -> fullpath mapping from existing npz files under root.
    Prefer shallower paths; keep latest mtime if duplicates.
    """
    if root is None or (not root.exists()):
        return {}, []
    files = list(root.glob("**/*.npz"))
    mp = {}
    for p in files:
        bn = p.name
        if bn not in mp:
            mp[bn] = p
        else:
            try:
                if p.stat().st_mtime > mp[bn].stat().st_mtime:
                    mp[bn] = p
            except Exception:
                pass
    return {k: str(v) for k,v in mp.items()}, files

def resolve_by_index(p, root: Path, idx_map: dict):
    if p is None:
        return None
    s = str(p)
    if s.lower() in ["nan","none",""]:
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    # relative under root
    if root is not None and root.exists() and (not pp.is_absolute()):
        cand = root / pp
        if cand.exists():
            return str(cand)

    # fallback basename lookup
    bn = pp.name
    if bn in idx_map:
        return idx_map[bn]

    # last: try common subdirs
    if root is not None and root.exists():
        for seg in ["train", "test", "train_all", "test_all", ""]:
            cand = (root / seg / bn) if seg else (root / bn)
            if cand.exists():
                return str(cand)

    return None

def pick_best_path_col(df, candidates, root: Path, idx_map: dict):
    best = (None, -1.0)
    for c in candidates:
        if c not in df.columns:
            continue
        rr = df[c].map(lambda x: resolve_by_index(x, root, idx_map))
        ex = rr.map(lambda x: isinstance(x,str) and Path(x).exists()).mean() if len(rr) else 0.0
        if ex > best[1]:
            best = (c, float(ex))
    return best[0], best[1]

# ----------------------------
# Locate tokens_manifest_train.parquet (BEST by real token coverage)
# + Resolve TOKEN_DATA_ROOT even if manifest is under *_bind_* (no npz)
# ----------------------------
def _count_tok_files(root: Path):
    if root is None or (not root.exists()):
        return 0
    total = 0
    for sub in ["train","test","train_all","test_all"]:
        d = root / sub
        if d.exists():
            total += sum(1 for _ in d.glob("*.npz"))
            total += sum(1 for _ in d.glob("*.npy"))
    return total

def pick_best_tokens_manifest_train():
    hits = []
    hits += list(WORK.glob("**/tokens_manifest_train.parquet"))
    hits += list(INP.glob("**/tokens_manifest_train.parquet"))
    hits = [p for p in hits if p.exists()]
    if not hits:
        return None

    # score: prefer parent that actually has token files in train/test dirs
    def score(p: Path):
        r = p.parent
        n_tok = _count_tok_files(r)
        has_train_dir = int((r/"train").exists() or (r/"train_all").exists())
        # (has_tokens, n_tokens, has_train_dir, mtime)
        return (int(n_tok > 0), int(n_tok), has_train_dir, float(p.stat().st_mtime))

    hits = sorted(hits, key=score, reverse=True)
    return hits[0]

TOK_MAN_TRAIN = pick_best_tokens_manifest_train()
if TOK_MAN_TRAIN is None:
    raise FileNotFoundError("tokens_manifest_train.parquet not found. Token cache must exist.")

TOKEN_MANIFEST_ROOT = TOK_MAN_TRAIN.parent

# base cache is one level above token cfg dir (the folder that contains dinov2_base_... siblings)
BASE_CACHE = TOKEN_MANIFEST_ROOT.parent
if not BASE_CACHE.exists():
    # fallback for odd layouts
    BASE_CACHE = Path("/kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache")

# pick best TOKEN_DATA_ROOT among siblings under BASE_CACHE
sib_dirs = [p for p in BASE_CACHE.iterdir() if p.is_dir()]
sib_scored = sorted([(_count_tok_files(d), d) for d in sib_dirs], key=lambda x: x[0], reverse=True)
best_cnt, best_root = sib_scored[0] if sib_scored else (0, None)

# If manifest root has 0 tokens, switch data root to sibling with tokens
manifest_cnt = _count_tok_files(TOKEN_MANIFEST_ROOT)
TOKEN_DATA_ROOT = TOKEN_MANIFEST_ROOT if manifest_cnt > 0 else best_root

print("TOK_MAN_TRAIN       :", TOK_MAN_TRAIN)
print("TOKEN_MANIFEST_ROOT :", TOKEN_MANIFEST_ROOT, "| token_files:", manifest_cnt)
print("BASE_CACHE          :", BASE_CACHE)
print("TOKEN_DATA_ROOT     :", TOKEN_DATA_ROOT, "| token_files:", best_cnt)

if TOKEN_DATA_ROOT is None or (not Path(TOKEN_DATA_ROOT).exists()) or _count_tok_files(Path(TOKEN_DATA_ROOT)) == 0:
    raise RuntimeError(
        "No token .npz/.npy found in any TOKEN_DATA_ROOT candidate.\n"
        f"Checked siblings under: {BASE_CACHE}"
    )

# Build fast basename->fullpath index (DO NOT use **/*.npz on big trees)
def build_tok_index(data_root: Path):
    mp = {}
    files = []
    for sub in ["train","test","train_all","test_all"]:
        d = data_root / sub
        if not d.exists():
            continue
        for p in d.glob("*.npz"):
            files.append(p); mp[p.name] = str(p)
        for p in d.glob("*.npy"):
            files.append(p); mp[p.name] = str(p)
    return mp, files

tok_idx_map, tok_files = build_tok_index(Path(TOKEN_DATA_ROOT))
print("TOKEN indexed files :", len(tok_files))

def resolve_tok_path(p):
    if p is None:
        return None
    s = str(p)
    if s.lower() in ["nan","none",""]:
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    # try relative under TOKEN_DATA_ROOT
    if not pp.is_absolute():
        cand = Path(TOKEN_DATA_ROOT) / pp
        if cand.exists():
            return str(cand)

    # basename lookup
    bn = pp.name
    if bn in tok_idx_map:
        return tok_idx_map[bn]

    # last: try common subdirs
    for sub in ["train","test","train_all","test_all"]:
        cand = Path(TOKEN_DATA_ROOT) / sub / bn
        if cand.exists():
            return str(cand)

    return None


# ----------------------------
# Load train_table.parquet
# ----------------------------
table_candidates = [
    WORK / "recodai_luc_gate_artifacts" / "train_table.parquet",
    WORK / "recodai_luc_hybrid_artifacts" / "train_table.parquet",
    WORK / "recodai_luc_prof" / "train_table.parquet",
]
TRAIN_TABLE = next((p for p in table_candidates if p.exists()), None)
if TRAIN_TABLE is None:
    raise FileNotFoundError("Cannot find train_table.parquet in known locations.")

df = pd.read_parquet(TRAIN_TABLE).copy()
for need in ["case_id","y"]:
    if need not in df.columns:
        raise ValueError(f"train_table missing required col: {need}")

df["case_id"] = pd.to_numeric(df["case_id"], errors="coerce")
df["y"] = pd.to_numeric(df["y"], errors="coerce").fillna(0).astype(int)
df = df[df["case_id"].notna()].copy()
df["case_id"] = df["case_id"].astype(int)

# ----------------------------
# Ensure token path column exists:
#  - if missing in train_table -> join from tokens_manifest_train.parquet
#  - auto-select correct path column in tokens manifest by existence-rate
# ----------------------------
tok_col = None
for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","tok_npz","npz_path","path","file","npz"]:
    if c in df.columns:
        tok_col = c
        break

df_tok = pd.read_parquet(TOK_MAN_TRAIN).copy()

# ensure case_id in tokens manifest
if "case_id" not in df_tok.columns:
    if "uid" in df_tok.columns:
        df_tok["case_id"] = df_tok["uid"].apply(infer_case_id_any)
    else:
        # try any column with digits
        any_col = df_tok.columns[0]
        df_tok["case_id"] = df_tok[any_col].apply(infer_case_id_any)

df_tok["case_id"] = pd.to_numeric(df_tok["case_id"], errors="coerce")
df_tok = df_tok[df_tok["case_id"].notna()].copy()
df_tok["case_id"] = df_tok["case_id"].astype(int)

# pick best path column in tokens manifest
tok_path_candidates = ["tok_npz","npz_path","path","file","npz","token_npz"]
best_col, best_rate = pick_best_path_col(df_tok, tok_path_candidates, TOKEN_ROOT, tok_idx_map)
if best_col is None:
    raise ValueError("Cannot detect token path column in tokens_manifest_train.parquet.")

df_tok["tok_npz"] = df_tok[best_col].map(lambda x: resolve_by_index(x, TOKEN_ROOT, tok_idx_map))
df_tok["_ok"] = df_tok["tok_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())

# keep best per case_id
def _mt(p):
    try:
        return float(Path(p).stat().st_mtime)
    except Exception:
        return -1.0

df_tok["_mt"] = df_tok["tok_npz"].map(_mt)
df_tok = df_tok.sort_values(["case_id","_ok","_mt"], ascending=[True, False, False]) \
               .drop_duplicates("case_id", keep="first") \
               .drop(columns=["_ok","_mt"])

# if train_table doesn't have token col -> merge
if tok_col is None:
    df = df.merge(df_tok[["case_id","tok_npz"]], on="case_id", how="left")
    tok_col = "tok_npz"
    print(f"INFO: token path column missing; joined from tokens manifest using '{best_col}' (exist-rate≈{best_rate:.3f})")
else:
    # normalize existing tok_col via index too
    df[tok_col] = df[tok_col].map(lambda x: resolve_by_index(x, TOKEN_ROOT, tok_idx_map))

ok_tok = df[tok_col].map(lambda x: isinstance(x,str) and Path(x).exists())
print("Token exists rate:", float(ok_tok.mean()), "| rows:", len(df))
if ok_tok.mean() < 0.50:
    # Hard fail with diagnostics (lebih jelas daripada rows_used=0)
    print("DIAG: tokens manifest best_col =", best_col, "| best_rate≈", best_rate)
    print("DIAG: df_tok columns:", df_tok.columns.tolist()[:50])
    print("DIAG: sample df_tok tok_npz:", df_tok["tok_npz"].dropna().head(5).tolist())
    raise RuntimeError(
        "Token exists rate is too low. This indicates path resolution mismatch.\n"
        "Check TOKEN_ROOT contents vs manifest paths."
    )

df = df[ok_tok].copy()
df = df.drop_duplicates("case_id", keep="first").reset_index(drop=True)
print("TRAIN_TABLE:", TRAIN_TABLE, "| rows_used:", len(df), "| pos_rate:", float(df["y"].mean()))
print("TOK_COL:", tok_col)

# ----------------------------
# MATCH ROOT (optional; fail-safe)
# ----------------------------
def pick_latest_match_root():
    cands = []
    cands += list(CACHE_DIR.glob("match_cfg_*"))
    cands += list(INP.glob("**/recodai_luc/cache/match_cfg_*"))
    cands = [c for c in cands if (c/"cfg.json").exists() and (c/"match_manifest_train.parquet").exists()]
    if not cands:
        return None
    cands = sorted(cands, key=lambda p: (p/"cfg.json").stat().st_mtime, reverse=True)
    return cands[0]

MATCH_ROOT = pick_latest_match_root()
match_map = {}
score_map = None
PATCH = 14
HTOK_MATCH = WTOK_MATCH = None

if MATCH_ROOT is None:
    print("WARNING: MATCH_ROOT not found. Seed will be zeros (still runs).")
else:
    try:
        MATCH_CFG = json.loads((MATCH_ROOT/"cfg.json").read_text())
        PATCH = int(MATCH_CFG.get("patch", MATCH_CFG.get("patch_size", 14)))
        HTOK_MATCH = int(MATCH_CFG.get("Ht", MATCH_CFG.get("htok", MATCH_CFG.get("HTOK", 37))))
        WTOK_MATCH = int(MATCH_CFG.get("Wt", MATCH_CFG.get("wtok", MATCH_CFG.get("WTOK", 37))))

        mtrain_pq = MATCH_ROOT / "match_manifest_train.parquet"
        df_m = pd.read_parquet(mtrain_pq).copy()

        # ensure case_id
        if "case_id" not in df_m.columns:
            if "uid" in df_m.columns:
                df_m["case_id"] = df_m["uid"].apply(infer_case_id_any)
            else:
                # try from any path col later
                df_m["case_id"] = np.nan
        df_m["case_id"] = pd.to_numeric(df_m["case_id"], errors="coerce")
        df_m = df_m[df_m["case_id"].notna()].copy()
        df_m["case_id"] = df_m["case_id"].astype(int)

        # build index for match npz
        match_idx_map, match_files = build_npz_index(MATCH_ROOT)
        # pick best path column
        match_path_candidates = ["match_npz","npz_path","path","file","npz"]
        best_mcol, best_mrate = pick_best_path_col(df_m, match_path_candidates, MATCH_ROOT, match_idx_map)
        if best_mcol is None:
            raise ValueError("Cannot detect match path column in match_manifest_train.parquet.")

        df_m["match_npz"] = df_m[best_mcol].map(lambda x: resolve_by_index(x, MATCH_ROOT, match_idx_map))
        df_m["_ok"] = df_m["match_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())
        df_m = df_m[df_m["_ok"]].copy()

        # pick best per case_id (score or mtime)
        score_cols = [c for c in ["best_peak_score","peak_score_max","max_peak_score","score_max","best_score","peak_max"] if c in df_m.columns]
        if score_cols:
            sc = score_cols[0]
            df_m[sc] = pd.to_numeric(df_m[sc], errors="coerce").fillna(-1)
            df_m = df_m.sort_values(["case_id", sc], ascending=[True, False]).drop_duplicates("case_id", keep="first")
            score_map = df_m.set_index("case_id")[sc].to_dict()
        else:
            df_m["_mt"] = df_m["match_npz"].map(lambda p: Path(p).stat().st_mtime if Path(p).exists() else -1.0)
            df_m = df_m.sort_values(["case_id","_mt"], ascending=[True, False]).drop_duplicates("case_id", keep="first")

        match_map = df_m.set_index("case_id")["match_npz"].to_dict()

        print("MATCH_ROOT:", MATCH_ROOT)
        print("MATCH npz coverage:", len(match_map), "/", df["case_id"].nunique())
        print("PATCH/HTOK_MATCH/WTOK_MATCH:", PATCH, HTOK_MATCH, WTOK_MATCH)
    except Exception as e:
        print("WARNING: match manifest failed -> seed zeros. Reason:", repr(e))
        MATCH_ROOT = None
        match_map = {}
        score_map = None
        PATCH = 14
        HTOK_MATCH = WTOK_MATCH = None

# ----------------------------
# Infer token grid dims + embedding dim
# ----------------------------
def load_tok_npz_any(path_str: str):
    z = np.load(path_str)
    keys = list(z.files)
    if not keys:
        raise ValueError("empty npz")
    pref = ["tok","tokens","grid","token_grid","x","feat","emb","f"]
    a = None
    for k in pref:
        if k in z.files:
            a = z[k]; break
    if a is None:
        a = z[keys[0]]
    a = np.asarray(a)
    if a.ndim != 3:
        raise ValueError(f"Unknown token array shape: {a.shape}")
    return a.astype(np.float32)

tok0 = load_tok_npz_any(df.iloc[0][tok_col])
# normalize dim inference
if tok0.ndim == 3 and tok0.shape[0] < 200 and tok0.shape[1] < 200 and tok0.shape[2] > 32:
    # likely (Ht,Wt,D)
    HTOK, WTOK, D_TOK = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])
elif tok0.ndim == 3 and tok0.shape[0] > 32 and tok0.shape[1] < 200 and tok0.shape[2] < 200:
    # maybe (D,H,W)
    D_TOK, HTOK, WTOK = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])
else:
    # fallback assume (Ht,Wt,D)
    HTOK, WTOK, D_TOK = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])

D_IN = D_TOK + 1
print("TOKEN GRID:", (HTOK, WTOK), "| D:", D_TOK, "| Input channels:", D_IN)

# ----------------------------
# SciPy optional
# ----------------------------
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

# ----------------------------
# GT/Seed caches
# ----------------------------
_gt_tok_cache = {}
_seed_tok_cache = {}

def _find_mask_files(mask_dir: Path, case_id: int):
    if mask_dir is None or (not mask_dir.exists()):
        return []
    cid = str(int(case_id))
    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    pats = [f"{cid}*.png", f"{cid}*.jpg", f"{cid}*.jpeg", f"{cid}*.tif", f"{cid}*.tiff", f"{cid}*.bmp",
            f"{cid}__*.png", f"{cid}_*.png"]
    out, seen = [], set()
    for pat in pats:
        for p in mask_dir.glob(pat):
            if p.suffix.lower() in exts:
                s = str(p)
                if s not in seen:
                    out.append(p); seen.add(s)
    return sorted(out)

def load_gt_union_full(case_id: int):
    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if npy.exists():
            a = np.load(npy, mmap_mode="r")
            a = np.asarray(a)
            if a.ndim == 2:
                return (a > 0)
            if a.ndim == 3:
                return (a > 0).any(axis=0)

    files = []
    files += _find_mask_files(TRAIN_MASK_DIR, case_id)
    files += _find_mask_files(SUP_MASK_DIR, case_id)
    if not files:
        return None

    m = None
    for p in files:
        try:
            im = Image.open(p).convert("L")
            a = (np.asarray(im) > 0)
            m = a if m is None else (m | a)
        except Exception:
            continue
    return m

def downsample_bool_to_tok(mask_bool: np.ndarray, h=HTOK, w=WTOK):
    if mask_bool is None:
        return np.zeros((h,w), dtype=np.float32)
    im = Image.fromarray((mask_bool.astype(np.uint8)*255))
    im = im.resize((w, h), resample=Image.NEAREST)
    return (np.asarray(im) > 127).astype(np.float32)

def get_gt_tok(case_id: int):
    if case_id in _gt_tok_cache:
        return _gt_tok_cache[case_id]
    gt_full = load_gt_union_full(case_id)
    gt_tok = downsample_bool_to_tok(gt_full, HTOK, WTOK)
    _gt_tok_cache[case_id] = gt_tok
    return gt_tok

def resize_bool_grid(x_bool: np.ndarray, h, w):
    if x_bool is None:
        return np.zeros((h,w), dtype=bool)
    if x_bool.shape == (h,w):
        return x_bool.astype(bool)
    im = Image.fromarray((x_bool.astype(np.uint8)*255))
    im = im.resize((w, h), resample=Image.NEAREST)
    return (np.asarray(im) > 127)

def load_seed_tok(case_id: int, topk=8):
    if case_id in _seed_tok_cache:
        return _seed_tok_cache[case_id]

    p = match_map.get(int(case_id), None)
    if p is None or (not Path(p).exists()):
        out = (np.zeros((HTOK,WTOK), dtype=np.float32), 0)
        _seed_tok_cache[case_id] = out
        return out

    try:
        z = np.load(p)
    except Exception:
        out = (np.zeros((HTOK,WTOK), dtype=np.float32), 0)
        _seed_tok_cache[case_id] = out
        return out

    scores = None
    for k in ["peak_score","scores","score","peak_scores"]:
        if k in z.files:
            scores = np.asarray(z[k]).reshape(-1)
            break

    src = tgt = None
    for a,b in [("src_masks","tgt_masks"), ("src_m","tgt_m"), ("src","tgt")]:
        if a in z.files and b in z.files:
            src = np.asarray(z[a]); tgt = np.asarray(z[b])
            break

    best = int(np.max(scores)) if scores is not None and len(scores) else 0

    if src is None or tgt is None or src.ndim != 3 or tgt.ndim != 3 or src.shape[0] == 0:
        out = (np.zeros((HTOK,WTOK), dtype=np.float32), best)
        _seed_tok_cache[case_id] = out
        return out

    if scores is not None and len(scores) == src.shape[0]:
        idx = np.argsort(scores)[::-1][:min(topk, len(scores))]
        src = src[idx]; tgt = tgt[idx]
    else:
        if src.shape[0] > topk:
            src = src[:topk]; tgt = tgt[:topk]

    seed_match = ((src>0) | (tgt>0)).any(axis=0).astype(bool)

    # if match grid differs, resize to token grid
    if seed_match.shape != (HTOK,WTOK):
        seed_tok = resize_bool_grid(seed_match, HTOK, WTOK).astype(np.float32)
    else:
        seed_tok = seed_match.astype(np.float32)

    out = (seed_tok, best)
    _seed_tok_cache[case_id] = out
    return out

# ----------------------------
# Token loader normalized to (HTOK,WTOK,D)
# ----------------------------
def load_tok_npz(path_str: str):
    a = load_tok_npz_any(path_str)
    a = np.asarray(a)
    if a.ndim != 3:
        raise ValueError(f"Bad token array ndim: {a.ndim}")

    # (Ht,Wt,D)
    if a.shape[0] == HTOK and a.shape[1] == WTOK:
        return a.astype(np.float32)

    # (D,Ht,Wt)
    if a.shape[1] == HTOK and a.shape[2] == WTOK:
        return np.transpose(a, (1,2,0)).astype(np.float32)

    # resize each channel if (H,W,D) mismatched
    H0,W0,D0 = a.shape
    out = np.zeros((HTOK,WTOK,D0), np.float32)
    for d in range(D0):
        im = Image.fromarray(a[:,:,d].astype(np.float32))
        im = im.resize((WTOK,HTOK), resample=Image.BILINEAR)
        out[:,:,d] = np.asarray(im).astype(np.float32)
    return out

# ----------------------------
# Dataset
# ----------------------------
class HybridTokDS(Dataset):
    def __init__(self, df_in: pd.DataFrame, tok_col: str):
        self.df = df_in.reset_index(drop=True)
        self.tok_col = tok_col
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        cid = int(r["case_id"])
        y = float(r["y"])
        tok = load_tok_npz(r[self.tok_col])  # (Ht,Wt,D)
        seed, best_score = load_seed_tok(cid, topk=8)
        gt_tok = get_gt_tok(cid)

        tok_ch = np.transpose(tok, (2,0,1))  # (D,H,W)
        x = np.concatenate([tok_ch, seed[None,:,:]], axis=0).astype(np.float32)

        return {
            "x": torch.from_numpy(x),
            "gt": torch.from_numpy(gt_tok[None,:,:].astype(np.float32)),
            "y": torch.tensor([y], dtype=torch.float32),
            "best_score": torch.tensor([float(best_score)], dtype=torch.float32),
            "case_id": torch.tensor([cid], dtype=torch.int64),
        }

# ----------------------------
# Internal splits (NO-FOLD)
# ----------------------------
base = df[["case_id","y"]].copy().reset_index(drop=True)
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=SEED)
split_id = np.full(len(base), -1, dtype=np.int32)
X_dummy = np.zeros(len(base), dtype=np.uint8)
for k, (_, va) in enumerate(skf.split(X_dummy, base["y"].values)):
    split_id[va] = k
df = df.merge(pd.DataFrame({"case_id": base["case_id"].values, "split": split_id}), on="case_id", how="left")
df["split"] = df["split"].astype(int)
split_ids = sorted(df["split"].unique().tolist())
print("Internal splits:", split_ids, "| N_SPLITS:", N_SPLITS)

# ----------------------------
# Morphology/CC + Dice
# ----------------------------
def dilate_tok(x_bool, it=1):
    if it <= 0: return x_bool.astype(bool)
    x = x_bool.astype(bool)
    if _HAS_SCIPY:
        return ndi.binary_dilation(x, iterations=it)
    for _ in range(it):
        xp = np.pad(x, 1, mode="constant", constant_values=False)
        y = np.zeros_like(x, dtype=bool)
        for dy in (-1,0,1):
            for dx in (-1,0,1):
                y |= xp[1+dy:1+dy+x.shape[0], 1+dx:1+dx+x.shape[1]]
        x = y
    return x

def label_cc(x_bool):
    x = x_bool.astype(bool)
    if _HAS_SCIPY:
        lab, n = ndi.label(x, structure=np.ones((3,3), dtype=np.uint8))
        return lab, int(n)
    H,W = x.shape
    lab = np.zeros((H,W), dtype=np.int32); cur=0
    for y0 in range(H):
        for x0 in range(W):
            if (not x[y0,x0]) or lab[y0,x0]!=0: continue
            cur += 1
            st=[(y0,x0)]; lab[y0,x0]=cur
            while st:
                yy,xx=st.pop()
                for dy in (-1,0,1):
                    for dx in (-1,0,1):
                        if dy==0 and dx==0: continue
                        ny,nx=yy+dy,xx+dx
                        if 0<=ny<H and 0<=nx<W and x[ny,nx] and lab[ny,nx]==0:
                            lab[ny,nx]=cur; st.append((ny,nx))
    return lab, int(cur)

def inst_split_union_tok(mask_bool, min_area=2, max_area_frac=0.8, max_keep=8):
    H,W = mask_bool.shape
    lab,n = label_cc(mask_bool)
    if n<=0:
        return np.zeros((H,W), dtype=bool), 0
    insts=[]; areas=[]
    for k in range(1,n+1):
        m = (lab==k)
        a = int(m.sum())
        if a < min_area: continue
        if a / float(H*W) > max_area_frac: continue
        insts.append(m); areas.append(a)
    if not insts:
        return np.zeros((H,W), dtype=bool), 0
    order = np.argsort(np.asarray(areas))[::-1][:max_keep]
    uni = np.zeros((H,W), dtype=bool)
    for i in order:
        uni |= insts[i]
    return uni, int(len(order))

def dice(pr_bool, gt_bool):
    a=int(pr_bool.sum()); b=int(gt_bool.sum())
    if a==0 and b==0: return 1.0
    if a==0 or b==0: return 0.0
    inter=int((pr_bool & gt_bool).sum())
    return float((2.0*inter)/(a+b))

# ----------------------------
# Model (odd-grid safe)
# ----------------------------
def _best_gn_groups(c, max_groups=32):
    for g in [32,16,8,4,2,1]:
        if g <= max_groups and c % g == 0:
            return g
    return 1

class ConvBNAct(nn.Module):
    def __init__(self, c_in, c_out, k=3, s=1, p=1, drop=0.0):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k, s, p, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout2d(drop) if drop > 0 else nn.Identity()
    def forward(self, x):
        return self.drop(self.act(self.bn(self.conv(x))))

class ASPP(nn.Module):
    def __init__(self, c_in, c_out, rates=(1,2,4)):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=r, dilation=r, bias=False),
                nn.BatchNorm2d(c_out),
                nn.SiLU(inplace=True),
            ) for r in rates
        ])
        self.proj = nn.Sequential(
            nn.Conv2d(len(rates)*c_out, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(inplace=True),
        )
    def forward(self, x):
        ys = [b(x) for b in self.blocks]
        y = torch.cat(ys, dim=1)
        return self.proj(y)

class HybridUNet(nn.Module):
    def __init__(self, c_in, base_ch=96, drop=0.1):
        super().__init__()
        c1, c2, c3 = base_ch, base_ch*2, base_ch*3
        self.in_norm = nn.GroupNorm(_best_gn_groups(c_in), c_in)

        self.e1 = nn.Sequential(ConvBNAct(c_in, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))
        self.p1 = nn.MaxPool2d(2)
        self.e2 = nn.Sequential(ConvBNAct(c1, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))
        self.p2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(ConvBNAct(c2, c3, drop=drop), ConvBNAct(c3, c3, drop=drop))
        self.aspp = ASPP(c3, c3, rates=(1,2,4))

        self.u2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.d2 = nn.Sequential(ConvBNAct(c2+c2, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))

        self.u1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.d1 = nn.Sequential(ConvBNAct(c1+c1, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))

        self.seg_head = nn.Conv2d(c1, 1, 1)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c3, c3//2),
            nn.SiLU(inplace=True),
            nn.Dropout(drop if drop>0 else 0.0),
            nn.Linear(c3//2, 1),
        )

    def forward(self, x):
        x = self.in_norm(x)
        e1 = self.e1(x)
        e2 = self.e2(self.p1(e1))
        e3 = self.e3(self.p2(e2))
        b  = self.aspp(e3)

        cls_logit = self.cls_head(b)

        d2 = self.u2(b)
        if d2.shape[-2:] != e2.shape[-2:]:
            d2 = F.interpolate(d2, size=e2.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.d2(torch.cat([d2, e2], dim=1))

        d1 = self.u1(d2)
        if d1.shape[-2:] != e1.shape[-2:]:
            d1 = F.interpolate(d1, size=e1.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.d1(torch.cat([d1, e1], dim=1))

        seg_logit = self.seg_head(d1)
        return seg_logit, cls_logit

# ----------------------------
# Loss
# ----------------------------
def dice_loss_from_logits(logits, target, eps=1e-6):
    p = torch.sigmoid(logits)
    inter = (p * target).sum(dim=(2,3))
    den = (p + target).sum(dim=(2,3)) + eps
    d = (2.0 * inter) / den
    return 1.0 - d.mean()

def bce_focal_from_logits(logits, target, gamma=2.0):
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    if gamma <= 0:
        return bce.mean()
    p = torch.sigmoid(logits)
    pt = target * p + (1-target) * (1-p)
    w = (1-pt).pow(gamma)
    return (w * bce).mean()

@torch.no_grad()
def eval_model(model, dl, cfg_pp):
    model.eval()
    scores = []

    T1 = cfg_pp["T1"]; T0 = cfg_pp["T0"]; dil_it = cfg_pp["seed_dilate_it"]
    thr_gate = cfg_pp["thr_gate"]
    min_tok_area = cfg_pp["min_tok_area"]
    max_tok_area_frac = cfg_pp["max_tok_area_frac"]
    max_inst_keep = cfg_pp["max_inst_keep"]
    min_peak_keep = cfg_pp["min_peak_score_keep"]
    min_area_frac_keep = cfg_pp["min_area_frac_keep"]

    for batch in dl:
        x = batch["x"].to(device, non_blocking=True)
        gt = batch["gt"].to(device, non_blocking=True)
        best_score = batch["best_score"].cpu().numpy().reshape(-1)

        seg_logit, cls_logit = model(x)
        p_gate = torch.sigmoid(cls_logit).detach().cpu().numpy().reshape(-1)
        p_tok = torch.sigmoid(seg_logit).detach().cpu().numpy()[:,0]
        seed = x.detach().cpu().numpy()[:,-1]
        gt_np = gt.detach().cpu().numpy()[:,0] > 0.5

        for i in range(x.shape[0]):
            if p_gate[i] < thr_gate:
                pr = np.zeros((HTOK,WTOK), dtype=bool)
            else:
                prob = p_tok[i]
                hard = prob >= T1
                soft = prob >= T0
                sd = dilate_tok(seed[i] > 0.5, dil_it)
                fused = hard | (sd & soft)

                uni, _ = inst_split_union_tok(
                    fused, min_area=min_tok_area, max_area_frac=max_tok_area_frac, max_keep=max_inst_keep
                )
                area_frac = float(uni.mean())
                if (best_score[i] < min_peak_keep) and (area_frac < min_area_frac_keep):
                    uni = np.zeros((HTOK,WTOK), dtype=bool)
                pr = uni

            scores.append(dice(pr, gt_np[i]))

    return float(np.mean(scores)) if scores else 0.0

# ----------------------------
# Trial sampling + train
# ----------------------------
def sample_trial_cfg(trial_id: int, val_split: int):
    def log_uniform(a,b):
        return float(np.exp(np.random.uniform(np.log(a), np.log(b))))
    lr = log_uniform(*LR_RANGE)
    wd = float(np.random.uniform(*WD_RANGE))
    drop = float(np.random.uniform(*DROPOUT_RANGE))
    base_ch = int(np.random.choice(BASE_CH_CHOICES))
    lam_seg = float(np.random.uniform(*LAMBDA_SEG_RANGE))
    lam_cls = float(np.random.uniform(*LAMBDA_CLS_RANGE))
    gamma = float(np.random.choice(FOCAL_GAMMA_CHOICES))

    T1 = float(np.random.uniform(*T1_RANGE))
    T0 = float(np.random.uniform(*T0_RANGE))
    if T0 > T1:
        T0, T1 = max(0.05, T1-0.05), T1

    dil_it = int(np.random.choice(SEED_DILATE_CHOICES))
    thr_gate = float(np.random.uniform(*THR_GATE_RANGE))

    min_tok_area = int(np.random.choice(MIN_TOK_AREA_CHOICES))
    max_tok_area_frac = float(np.random.choice(MAX_TOK_AREA_FRAC_CHOICES))
    max_inst_keep = int(np.random.choice(MAX_INST_KEEP_CHOICES))

    min_peak_keep = int(np.random.choice(MIN_PEAK_SCORE_KEEP_CHOICES))
    min_area_frac_keep = float(np.random.choice(MIN_AREA_FRAC_KEEP_CHOICES))

    return {
        "trial_id": int(trial_id),
        "val_split": int(val_split),
        "lr": lr, "weight_decay": wd, "dropout": drop, "base_ch": base_ch,
        "lambda_seg": lam_seg, "lambda_cls": lam_cls, "focal_gamma": gamma,
        "T1": T1, "T0": T0, "seed_dilate_it": dil_it, "thr_gate": thr_gate,
        "min_tok_area": min_tok_area, "max_tok_area_frac": max_tok_area_frac, "max_inst_keep": max_inst_keep,
        "min_peak_score_keep": min_peak_keep, "min_area_frac_keep": min_area_frac_keep,
    }

def train_trial(cfg_trial):
    val_split = cfg_trial["val_split"]
    df_tr = df[df["split"] != val_split].reset_index(drop=True)
    df_va = df[df["split"] == val_split].reset_index(drop=True)

    # oversample positives
    pos = df_tr[df_tr["y"]==1]
    neg = df_tr[df_tr["y"]==0]
    if len(pos) > 0 and len(neg) > 0:
        take = min(len(neg), len(pos)*3)
        neg_s = neg.sample(n=take, replace=False, random_state=SEED)
        pos_s = pos.sample(n=take, replace=True,  random_state=SEED)
        df_tr_use = pd.concat([neg_s, pos_s], axis=0).sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    else:
        df_tr_use = df_tr

    ds_tr = HybridTokDS(df_tr_use, tok_col=tok_col)
    ds_va = HybridTokDS(df_va,     tok_col=tok_col)

    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True,  num_workers=NUM_WORKERS,
                       pin_memory=(device.type=="cuda"), drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
                       pin_memory=(device.type=="cuda"), drop_last=False)

    model = HybridUNet(c_in=D_IN, base_ch=cfg_trial["base_ch"], drop=cfg_trial["dropout"]).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=cfg_trial["lr"], weight_decay=cfg_trial["weight_decay"])
    scaler = torch.cuda.amp.GradScaler(enabled=amp_ok)

    best_score = -1.0
    best_state = None
    bad = 0

    for ep in range(1, TRIAL_EPOCHS+1):
        model.train()
        t0 = time.time()
        loss_meter = 0.0
        nsteps = 0
        opt.zero_grad(set_to_none=True)

        for step, batch in enumerate(dl_tr, start=1):
            x  = batch["x"].to(device, non_blocking=True)
            gt = batch["gt"].to(device, non_blocking=True)
            yb = batch["y"].to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=amp_ok):
                seg_logit, cls_logit = model(x)
                l_seg = bce_focal_from_logits(seg_logit, gt, gamma=cfg_trial["focal_gamma"]) + dice_loss_from_logits(seg_logit, gt)
                l_cls = F.binary_cross_entropy_with_logits(cls_logit, yb)
                loss = cfg_trial["lambda_seg"] * l_seg + cfg_trial["lambda_cls"] * l_cls
                loss = loss / ACCUM_STEPS

            scaler.scale(loss).backward()
            if step % ACCUM_STEPS == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            loss_meter += float(loss.item()) * ACCUM_STEPS
            nsteps += 1

        score = eval_model(model, dl_va, cfg_trial)
        dt = time.time() - t0
        print(f"[trial {cfg_trial['trial_id']:02d} | split {val_split}] ep {ep}/{TRIAL_EPOCHS} "
              f"loss={loss_meter/max(1,nsteps):.4f} val_dice_proxy={score:.5f} time={dt:.1f}s")

        if score > best_score + 1e-5:
            best_score = score
            best_state = {k: v.detach().cpu() for k,v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= EARLYSTOP_PATIENCE:
                break

    return best_score, best_state

# ----------------------------
# Run trials
# ----------------------------
trials = []
global_best = {"score": -1.0, "cfg": None, "state": None}

t_all = time.time()

for t in range(1, MAX_TRIALS+1):
    val_split = split_ids[(t-1) % len(split_ids)] if VAL_SPLIT_ROTATE else split_ids[0]
    cfg_trial = sample_trial_cfg(t, val_split)

    try:
        score, state = train_trial(cfg_trial)
        cfg_trial["score"] = float(score)
        cfg_trial["status"] = "ok"
    except Exception as e:
        print(f"[trial {t}] FAILED:", repr(e))
        cfg_trial["score"] = float("nan")
        cfg_trial["status"] = "fail"
        trials.append(cfg_trial)
        pd.DataFrame(trials).to_csv(OUT_DIR / "trials.csv", index=False)
        continue

    trials.append(cfg_trial)

    if score > global_best["score"]:
        global_best["score"] = float(score)
        global_best["cfg"] = cfg_trial
        global_best["state"] = state

    pd.DataFrame(trials).to_csv(OUT_DIR / "trials.csv", index=False)
    (OUT_DIR / "best_config.json").write_text(json.dumps(global_best["cfg"], indent=2) if global_best["cfg"] else "{}")

    print("-"*60)
    print("CURRENT BEST:", global_best["score"], "| trial:", global_best["cfg"]["trial_id"], "| val_split:", global_best["cfg"]["val_split"])
    print("-"*60)

if global_best["state"] is not None:
    pack = {
        "model_type": "HybridUNet",
        "input_channels": int(D_IN),
        "HTOK": int(HTOK), "WTOK": int(WTOK), "PATCH": int(PATCH),
        "tok_col": str(tok_col),
        "token_manifest_train": str(TOK_MAN_TRAIN),
        "token_root": str(TOKEN_ROOT),
        "match_root": str(MATCH_ROOT) if MATCH_ROOT else None,
        "train_table": str(TRAIN_TABLE),
        "mask_dirs": {"TRAIN_MASK_DIR": str(TRAIN_MASK_DIR), "SUP_MASK_DIR": str(SUP_MASK_DIR)},
        "best_cfg": global_best["cfg"],
        "state_dict": global_best["state"],
        "meta": {"seed": int(SEED), "n_splits": int(N_SPLITS), "val_split_rotate": bool(VAL_SPLIT_ROTATE)},
    }
    torch.save(pack, OUT_DIR / "best_model.pt")

print("DONE in", f"{time.time()-t_all:.1f}s")
print("Saved:", OUT_DIR / "trials.csv")
print("Saved:", OUT_DIR / "best_config.json")
print("Saved:", OUT_DIR / "best_model.pt" if (OUT_DIR / "best_model.pt").exists() else "(no best_model.pt)")

DEVICE: cpu | AMP: False
COMP_ROOT     : /kaggle/input/recodai-luc-scientific-image-forgery-detection
TRAIN_MASK_DIR: /kaggle/input/recodai-luc-scientific-image-forgery-detection/train_masks | exists: True
SUP_MASK_DIR  : /kaggle/input/recodai-luc-scientific-image-forgery-detection/supplemental_masks | exists: True
TOK_MAN_TRAIN       : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500/tokens_manifest_train.parquet
TOKEN_MANIFEST_ROOT : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500 | token_files: 2796
BASE_CACHE          : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache
TOKEN_DATA_ROOT     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500 | token_files: 2796
TOKEN indexed files : 2796
INFO: token path column missing; joined from tokens manifest using 'npz_path' (exist-rate≈1.000)
Token exists rate: 1.0 | rows: 2795
TRAIN_TABLE: /kaggle/working/recodai_luc_g

# Final Training (Train on Full Data)

In [7]:
# ============================================================
# STAGE — Final Training (Train on Full Data) (ONE CELL) — HYBRID (OPSI-1) — REVISI FULL
# - Token & Match path auto-resolve via FILE INDEX (works with *_bind_* manifests)
# - Fail-safe seed (if match missing/broken -> zeros, still runs)
# - Export mask-prob cache: /kaggle/working/recodai_luc/cache/mask_prob_hybrid_<cfgid>/{case_id}.npz
#   keys: prob_tok (float16), p_gate (float16)
# - Sweep best gate threshold on TRAIN using Dice-proxy (token-space)
# - Save final model: /kaggle/working/recodai_luc_hybrid_artifacts/final_hybrid_model.pt
# ============================================================

import os, json, time, math, random, hashlib, warnings, re
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Paths
# ----------------------------
WORK     = Path("/kaggle/working")
INP      = Path("/kaggle/input")

OPT_DIR  = WORK / "recodai_luc_hybrid_opt"
OUT_DIR  = WORK / "recodai_luc_hybrid_artifacts"
PROF_DIR = WORK / "recodai_luc_prof"
CACHE_W  = WORK / "recodai_luc/cache"  # working cache (if exists)

OUT_DIR.mkdir(parents=True, exist_ok=True)
CACHE_W.mkdir(parents=True, exist_ok=True)

best_cfg_path = OPT_DIR / "best_config.json"
if not best_cfg_path.exists():
    raise FileNotFoundError(f"Missing {best_cfg_path}. Run Optimize stage first.")
BEST = json.loads(best_cfg_path.read_text())

paths_json = PROF_DIR / "paths.json"
if not paths_json.exists():
    raise FileNotFoundError(f"Missing {paths_json}")
PATHS = json.loads(paths_json.read_text())

TRAIN_MASK_DIR = Path(PATHS.get("TRAIN_MASK_DIR","")) if PATHS.get("TRAIN_MASK_DIR") else None
SUP_MASK_DIR   = Path(PATHS.get("SUP_MASK_DIR","")) if PATHS.get("SUP_MASK_DIR") else None

# ----------------------------
# Repro / device
# ----------------------------
SEED = int(os.environ.get("SEED", "42"))
def seed_everything(s=42):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)

seed_everything(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = bool(int(os.environ.get("USE_AMP", "1"))) and (device.type == "cuda")
print("DEVICE:", device, "| AMP:", USE_AMP)

# ----------------------------
# Helpers: ids
# ----------------------------
def infer_case_id_any(x):
    if x is None:
        return np.nan
    s = str(x)
    if s.lower() in ["nan","none",""]:
        return np.nan
    m = re.search(r"\b(\d{3,})\b", s)
    if m:
        return float(m.group(1))
    m = re.search(r"(\d+)", s)
    return float(m.group(1)) if m else np.nan

# ----------------------------
# Helpers: token manifest + token data root resolver (bind-safe)
# ----------------------------
def _count_tok_files(root: Path):
    if root is None or (not root.exists()):
        return 0
    total = 0
    for sub in ["train","test","train_all","test_all"]:
        d = root / sub
        if d.exists():
            total += sum(1 for _ in d.glob("*.npz"))
            total += sum(1 for _ in d.glob("*.npy"))
    return total

def pick_best_tokens_manifest_train():
    hits = []
    hits += list(WORK.glob("**/tokens_manifest_train.parquet"))
    hits += list(INP.glob("**/tokens_manifest_train.parquet"))
    hits = [p for p in hits if p.exists()]
    if not hits:
        return None

    def score(p: Path):
        r = p.parent
        n_tok = _count_tok_files(r)
        has_train_dir = int((r/"train").exists() or (r/"train_all").exists())
        return (int(n_tok > 0), int(n_tok), has_train_dir, float(p.stat().st_mtime))

    hits = sorted(hits, key=score, reverse=True)
    return hits[0]

TOK_MAN_TRAIN = pick_best_tokens_manifest_train()
if TOK_MAN_TRAIN is None:
    raise FileNotFoundError("tokens_manifest_train.parquet not found. Token cache must exist.")

TOKEN_MANIFEST_ROOT = TOK_MAN_TRAIN.parent
BASE_CACHE = TOKEN_MANIFEST_ROOT.parent if TOKEN_MANIFEST_ROOT.parent.exists() else TOKEN_MANIFEST_ROOT

# choose best sibling that has real token files
sib_dirs = [p for p in BASE_CACHE.iterdir() if p.is_dir()]
sib_scored = sorted([(_count_tok_files(d), d) for d in sib_dirs], key=lambda x: x[0], reverse=True)
best_cnt, best_root = sib_scored[0] if sib_scored else (0, None)

manifest_cnt = _count_tok_files(TOKEN_MANIFEST_ROOT)
TOKEN_DATA_ROOT = TOKEN_MANIFEST_ROOT if manifest_cnt > 0 else best_root

print("TOK_MAN_TRAIN       :", TOK_MAN_TRAIN)
print("TOKEN_MANIFEST_ROOT :", TOKEN_MANIFEST_ROOT, "| token_files:", manifest_cnt)
print("BASE_CACHE          :", BASE_CACHE)
print("TOKEN_DATA_ROOT     :", TOKEN_DATA_ROOT, "| token_files:", best_cnt)

if TOKEN_DATA_ROOT is None or (not Path(TOKEN_DATA_ROOT).exists()) or _count_tok_files(Path(TOKEN_DATA_ROOT)) == 0:
    raise RuntimeError(
        "No token .npz/.npy found for TOKEN_DATA_ROOT.\n"
        f"Checked siblings under: {BASE_CACHE}"
    )

def build_tok_index(data_root: Path):
    mp = {}
    files = []
    for sub in ["train","test","train_all","test_all"]:
        d = data_root / sub
        if not d.exists():
            continue
        for p in d.glob("*.npz"):
            files.append(p); mp[p.name] = str(p)
        for p in d.glob("*.npy"):
            files.append(p); mp[p.name] = str(p)
    return mp, files

tok_idx_map, tok_files = build_tok_index(Path(TOKEN_DATA_ROOT))
print("TOKEN indexed files :", len(tok_files))

def resolve_tok_path(p):
    if p is None:
        return None
    s = str(p)
    if s.lower() in ["nan","none",""]:
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    # relative under TOKEN_DATA_ROOT
    if not pp.is_absolute():
        cand = Path(TOKEN_DATA_ROOT) / pp
        if cand.exists():
            return str(cand)

    # basename lookup
    bn = pp.name
    if bn in tok_idx_map:
        return tok_idx_map[bn]

    # try add extension
    if "." not in bn:
        if (bn + ".npz") in tok_idx_map:
            return tok_idx_map[bn + ".npz"]
        if (bn + ".npy") in tok_idx_map:
            return tok_idx_map[bn + ".npy"]

    # last: common subdirs
    for sub in ["train","test","train_all","test_all"]:
        cand = Path(TOKEN_DATA_ROOT) / sub / bn
        if cand.exists():
            return str(cand)
        if "." not in bn:
            cand2 = Path(TOKEN_DATA_ROOT) / sub / (bn + ".npz")
            if cand2.exists():
                return str(cand2)
            cand3 = Path(TOKEN_DATA_ROOT) / sub / (bn + ".npy")
            if cand3.exists():
                return str(cand3)

    return None

def pick_best_path_col(df, candidates):
    best = (None, -1.0)
    for c in candidates:
        if c not in df.columns:
            continue
        rr = df[c].map(resolve_tok_path)
        ex = rr.map(lambda x: isinstance(x,str) and Path(x).exists()).mean() if len(rr) else 0.0
        if ex > best[1]:
            best = (c, float(ex))
    return best[0], best[1]

# ----------------------------
# Find train/test tables
# ----------------------------
train_table_cands = [
    WORK / "recodai_luc_gate_artifacts" / "train_table.patched_tokens.parquet",
    WORK / "recodai_luc_gate_artifacts" / "train_table.parquet",
    WORK / "recodai_luc_hybrid_artifacts" / "train_table.parquet",
    PROF_DIR / "train_table.parquet",
]
test_table_cands = [
    WORK / "recodai_luc_gate_artifacts" / "test_table.patched_tokens.parquet",
    WORK / "recodai_luc_gate_artifacts" / "test_table.parquet",
    WORK / "recodai_luc_hybrid_artifacts" / "test_table.parquet",
    PROF_DIR / "test_table.parquet",
]

TRAIN_TABLE = next((p for p in train_table_cands if p.exists()), None)
TEST_TABLE  = next((p for p in test_table_cands if p.exists()), None)
if TRAIN_TABLE is None:
    raise FileNotFoundError("Cannot find train_table.parquet. Run Build Training Table stage first.")

df_tr = pd.read_parquet(TRAIN_TABLE).copy()
for need in ["case_id","y"]:
    if need not in df_tr.columns:
        raise ValueError(f"train_table missing required col: {need}")
df_tr["case_id"] = pd.to_numeric(df_tr["case_id"], errors="coerce")
df_tr = df_tr[df_tr["case_id"].notna()].copy()
df_tr["case_id"] = df_tr["case_id"].astype(int)
df_tr["y"] = pd.to_numeric(df_tr["y"], errors="coerce").fillna(0).astype(int)

df_te = None
if TEST_TABLE is not None and Path(TEST_TABLE).exists():
    df_te = pd.read_parquet(TEST_TABLE).copy()
    if "case_id" not in df_te.columns:
        raise ValueError("test_table missing case_id")
    df_te["case_id"] = pd.to_numeric(df_te["case_id"], errors="coerce")
    df_te = df_te[df_te["case_id"].notna()].copy()
    df_te["case_id"] = df_te["case_id"].astype(int)

# ----------------------------
# Ensure token path columns exist in train/test (join from tokens manifests if needed)
# ----------------------------
tok_col_tr = None
for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","tok_npz","npz_path","path","file","npz"]:
    if c in df_tr.columns:
        tok_col_tr = c
        break

df_tok_tr = pd.read_parquet(TOK_MAN_TRAIN).copy()
if "case_id" not in df_tok_tr.columns:
    if "uid" in df_tok_tr.columns:
        df_tok_tr["case_id"] = df_tok_tr["uid"].apply(infer_case_id_any)
    else:
        df_tok_tr["case_id"] = df_tok_tr[df_tok_tr.columns[0]].apply(infer_case_id_any)

df_tok_tr["case_id"] = pd.to_numeric(df_tok_tr["case_id"], errors="coerce")
df_tok_tr = df_tok_tr[df_tok_tr["case_id"].notna()].copy()
df_tok_tr["case_id"] = df_tok_tr["case_id"].astype(int)

tok_path_candidates = ["tok_npz","npz_path","path","file","npz","token_npz"]
best_col_tr, best_rate_tr = pick_best_path_col(df_tok_tr, tok_path_candidates)
if best_col_tr is None:
    raise RuntimeError("Cannot detect token path column in tokens_manifest_train.parquet.")

df_tok_tr["tok_npz"] = df_tok_tr[best_col_tr].map(resolve_tok_path)
df_tok_tr["_ok"] = df_tok_tr["tok_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())

def _mt(p):
    try:
        return float(Path(p).stat().st_mtime)
    except Exception:
        return -1.0

df_tok_tr["_mt"] = df_tok_tr["tok_npz"].map(_mt)
df_tok_tr = df_tok_tr.sort_values(["case_id","_ok","_mt"], ascending=[True, False, False]) \
                     .drop_duplicates("case_id", keep="first") \
                     .drop(columns=["_ok","_mt"])

if tok_col_tr is None:
    df_tr = df_tr.merge(df_tok_tr[["case_id","tok_npz"]], on="case_id", how="left")
    tok_col_tr = "tok_npz"
    print(f"INFO: train token col missing -> joined from TOK_MAN_TRAIN using '{best_col_tr}' (exist-rate≈{best_rate_tr:.3f})")
else:
    df_tr[tok_col_tr] = df_tr[tok_col_tr].map(resolve_tok_path)

ok_tr = df_tr[tok_col_tr].map(lambda x: isinstance(x,str) and Path(x).exists())
print("TRAIN token exists rate:", float(ok_tr.mean()), "| rows:", len(df_tr))
if ok_tr.mean() < 0.50:
    print("DIAG train best_col:", best_col_tr, "| best_rate≈", best_rate_tr)
    print("DIAG sample tok_npz:", df_tok_tr["tok_npz"].dropna().head(5).tolist())
    raise RuntimeError("TRAIN token exists rate too low -> path mismatch. Fix TOKEN_DATA_ROOT / manifest mapping.")
df_tr = df_tr[ok_tr].drop_duplicates("case_id", keep="first").reset_index(drop=True)

# test tokens: try best manifest test if needed
tok_col_te = tok_col_tr
if df_te is not None:
    tok_col_te = None
    for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","tok_npz","npz_path","path","file","npz"]:
        if c in df_te.columns:
            tok_col_te = c
            break

    # locate tokens_manifest_test.parquet (prefer sibling of train manifest)
    TOK_MAN_TEST = None
    cand = TOKEN_MANIFEST_ROOT / "tokens_manifest_test.parquet"
    if cand.exists():
        TOK_MAN_TEST = cand
    else:
        hits = list(WORK.glob("**/tokens_manifest_test.parquet")) + list(INP.glob("**/tokens_manifest_test.parquet"))
        hits = [p for p in hits if p.exists()]
        TOK_MAN_TEST = sorted(hits, key=lambda p: p.stat().st_mtime, reverse=True)[0] if hits else None

    if tok_col_te is None:
        if TOK_MAN_TEST is None:
            print("WARNING: test token col missing and tokens_manifest_test.parquet not found -> skip test export.")
            df_te = None
            tok_col_te = tok_col_tr
        else:
            df_tok_te = pd.read_parquet(TOK_MAN_TEST).copy()
            if "case_id" not in df_tok_te.columns:
                if "uid" in df_tok_te.columns:
                    df_tok_te["case_id"] = df_tok_te["uid"].apply(infer_case_id_any)
                else:
                    df_tok_te["case_id"] = df_tok_te[df_tok_te.columns[0]].apply(infer_case_id_any)
            df_tok_te["case_id"] = pd.to_numeric(df_tok_te["case_id"], errors="coerce")
            df_tok_te = df_tok_te[df_tok_te["case_id"].notna()].copy()
            df_tok_te["case_id"] = df_tok_te["case_id"].astype(int)

            best_col_te, best_rate_te = pick_best_path_col(df_tok_te, tok_path_candidates)
            if best_col_te is None:
                print("WARNING: cannot detect token path column in tokens_manifest_test.parquet -> skip test export.")
                df_te = None
                tok_col_te = tok_col_tr
            else:
                df_tok_te["tok_npz"] = df_tok_te[best_col_te].map(resolve_tok_path)
                df_tok_te["_ok"] = df_tok_te["tok_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())
                df_tok_te["_mt"] = df_tok_te["tok_npz"].map(_mt)
                df_tok_te = df_tok_te.sort_values(["case_id","_ok","_mt"], ascending=[True, False, False]) \
                                     .drop_duplicates("case_id", keep="first") \
                                     .drop(columns=["_ok","_mt"])
                df_te = df_te.merge(df_tok_te[["case_id","tok_npz"]], on="case_id", how="left")
                tok_col_te = "tok_npz"
                df_te[tok_col_te] = df_te[tok_col_te].map(resolve_tok_path)
                ok_te = df_te[tok_col_te].map(lambda x: isinstance(x,str) and Path(x).exists())
                print("TEST token exists rate:", float(ok_te.mean()), "| rows:", len(df_te))
                df_te = df_te[ok_te].drop_duplicates("case_id", keep="first").reset_index(drop=True)
    else:
        df_te[tok_col_te] = df_te[tok_col_te].map(resolve_tok_path)
        ok_te = df_te[tok_col_te].map(lambda x: isinstance(x,str) and Path(x).exists())
        print("TEST token exists rate:", float(ok_te.mean()), "| rows:", len(df_te))
        df_te = df_te[ok_te].drop_duplicates("case_id", keep="first").reset_index(drop=True)

print("TRAIN_TABLE:", TRAIN_TABLE, "| rows_used:", len(df_tr), "| pos_rate:", float(df_tr["y"].mean()))
print("TEST_TABLE :", TEST_TABLE if df_te is not None else "(none/skip)")
print("TOK_COL_TR :", tok_col_tr)
print("TOK_COL_TE :", tok_col_te)

# ----------------------------
# Infer token grid dims + token dim from first train token
# ----------------------------
def load_tok_any(path_str: str):
    p = Path(path_str)
    if p.suffix.lower() == ".npy":
        a = np.load(p, mmap_mode="r")
    else:
        z = np.load(p)
        for k in ["tok","tokens","grid","token_grid","x","feat","emb","f"]:
            if k in z.files:
                a = z[k]; break
        else:
            keys = list(z.files)
            if not keys:
                raise ValueError("empty npz")
            a = z[keys[0]]
    return np.asarray(a)

tok0 = load_tok_any(df_tr.iloc[0][tok_col_tr])
if tok0.ndim != 3:
    raise RuntimeError(f"Token array ndim must be 3, got {tok0.shape}")

# decide layout
if tok0.shape[0] < 200 and tok0.shape[1] < 200 and tok0.shape[2] > 32:
    # (Ht,Wt,D)
    HTOK, WTOK, DIN = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])
elif tok0.shape[0] > 32 and tok0.shape[1] < 200 and tok0.shape[2] < 200:
    # (D,H,W)
    DIN, HTOK, WTOK = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])
else:
    HTOK, WTOK, DIN = int(tok0.shape[0]), int(tok0.shape[1]), int(tok0.shape[2])

CIN = DIN + 1
print("TOKEN GRID:", (HTOK, WTOK), "| DIN:", DIN, "| CIN:", CIN)

def load_tok_norm(path_str: str):
    a = load_tok_any(path_str).astype(np.float32)
    if a.ndim != 3:
        raise ValueError(f"Unknown token array shape: {a.shape}")

    # (Ht,Wt,D)
    if a.shape[0] == HTOK and a.shape[1] == WTOK:
        return a

    # (D,Ht,Wt)
    if a.shape[1] == HTOK and a.shape[2] == WTOK:
        return np.transpose(a, (1,2,0))

    # resize fallback (H,W,D)
    H0,W0,D0 = a.shape
    out = np.zeros((HTOK,WTOK,D0), np.float32)
    for d in range(D0):
        im = Image.fromarray(a[:,:,d].astype(np.float32))
        im = im.resize((WTOK,HTOK), resample=Image.BILINEAR)
        out[:,:,d] = np.asarray(im).astype(np.float32)
    return out

# ----------------------------
# Match root (robust; optional)
# ----------------------------
def build_npz_index(root: Path):
    if root is None or (not root.exists()):
        return {}, []
    files = list(root.glob("**/*.npz"))
    mp = {}
    for p in files:
        bn = p.name
        if bn not in mp:
            mp[bn] = p
        else:
            try:
                if p.stat().st_mtime > mp[bn].stat().st_mtime:
                    mp[bn] = p
            except Exception:
                pass
    return {k: str(v) for k,v in mp.items()}, files

def resolve_by_index(p, root: Path, idx_map: dict):
    if p is None:
        return None
    s = str(p)
    if s.lower() in ["nan","none",""]:
        return None
    pp = Path(s)
    if pp.exists():
        return str(pp)

    # relative under root
    if root is not None and root.exists() and (not pp.is_absolute()):
        cand = root / pp
        if cand.exists():
            return str(cand)

    bn = pp.name
    if bn in idx_map:
        return idx_map[bn]

    # try add extension
    if "." not in bn:
        if (bn + ".npz") in idx_map:
            return idx_map[bn + ".npz"]

    # common subdirs
    if root is not None and root.exists():
        for seg in ["train","test","train_all","test_all",""]:
            cand = (root/seg/bn) if seg else (root/bn)
            if cand.exists():
                return str(cand)
            if "." not in bn:
                cand2 = (root/seg/(bn+".npz")) if seg else (root/(bn+".npz"))
                if cand2.exists():
                    return str(cand2)
    return None

def pick_latest_match_root():
    cands = []
    cands += list(CACHE_W.glob("match_cfg_*"))
    cands += list(INP.glob("**/recodai_luc/cache/match_cfg_*"))
    cands = [c for c in cands if (c/"cfg.json").exists()]
    if not cands:
        return None
    cands = sorted(cands, key=lambda p: (p/"cfg.json").stat().st_mtime, reverse=True)
    return cands[0]

MATCH_ROOT = pick_latest_match_root()
PATCH = 14
match_map_tr = {}
match_map_te = {}

if MATCH_ROOT is None:
    print("WARNING: MATCH_ROOT not found. Seed will be zeros.")
else:
    try:
        MATCH_CFG = json.loads((MATCH_ROOT/"cfg.json").read_text())
        PATCH = int(MATCH_CFG.get("patch", MATCH_CFG.get("patch_size", 14)))

        def build_match_map(pq: Path):
            if pq is None or (not pq.exists()):
                return {}
            dfm = pd.read_parquet(pq).copy()

            # ensure case_id
            if "case_id" not in dfm.columns:
                if "uid" in dfm.columns:
                    dfm["case_id"] = dfm["uid"].apply(infer_case_id_any)
                else:
                    dfm["case_id"] = np.nan
            dfm["case_id"] = pd.to_numeric(dfm["case_id"], errors="coerce")
            dfm = dfm[dfm["case_id"].notna()].copy()
            dfm["case_id"] = dfm["case_id"].astype(int)

            idx_map, _ = build_npz_index(MATCH_ROOT)

            # pick best path col
            path_cands = ["match_npz","npz_path","path","file","npz"]
            best = (None, -1.0)
            for c in path_cands:
                if c not in dfm.columns:
                    continue
                rr = dfm[c].map(lambda x: resolve_by_index(x, MATCH_ROOT, idx_map))
                ex = rr.map(lambda x: isinstance(x,str) and Path(x).exists()).mean() if len(rr) else 0.0
                if ex > best[1]:
                    best = (c, float(ex))
            best_col = best[0]
            if best_col is None:
                return {}

            dfm["match_npz"] = dfm[best_col].map(lambda x: resolve_by_index(x, MATCH_ROOT, idx_map))
            dfm["_ok"] = dfm["match_npz"].map(lambda x: isinstance(x,str) and Path(x).exists())
            dfm = dfm[dfm["_ok"]].copy()
            if len(dfm) == 0:
                return {}

            score_cols = [c for c in ["best_peak_score","peak_score_max","max_peak_score","score_max","best_score","peak_max"] if c in dfm.columns]
            if score_cols:
                sc = score_cols[0]
                dfm[sc] = pd.to_numeric(dfm[sc], errors="coerce").fillna(-1)
                dfm = dfm.sort_values(["case_id", sc], ascending=[True, False]).drop_duplicates("case_id", keep="first")
            else:
                dfm["_mt"] = dfm["match_npz"].map(lambda p: Path(p).stat().st_mtime if Path(p).exists() else -1.0)
                dfm = dfm.sort_values(["case_id","_mt"], ascending=[True, False]).drop_duplicates("case_id", keep="first")

            return dfm.set_index("case_id")["match_npz"].to_dict()

        match_map_tr = build_match_map(MATCH_ROOT / "match_manifest_train.parquet")
        match_map_te = build_match_map(MATCH_ROOT / "match_manifest_test.parquet")

        print("MATCH_ROOT:", MATCH_ROOT)
        print("PATCH:", PATCH, "| match_map_tr:", len(match_map_tr), "| match_map_te:", len(match_map_te))
    except Exception as e:
        print("WARNING: match manifest parse failed -> seed zeros. Reason:", repr(e))
        MATCH_ROOT = None
        match_map_tr = {}
        match_map_te = {}
        PATCH = 14

# ----------------------------
# SciPy optional
# ----------------------------
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

def dilate_tok(x_bool, it=1):
    if it <= 0:
        return x_bool.astype(bool)
    x = x_bool.astype(bool)
    if _HAS_SCIPY:
        return ndi.binary_dilation(x, iterations=it)
    for _ in range(it):
        xp = np.pad(x, 1, mode="constant", constant_values=False)
        y = np.zeros_like(x, dtype=bool)
        for dy in (-1,0,1):
            for dx in (-1,0,1):
                y |= xp[1+dy:1+dy+x.shape[0], 1+dx:1+dx+x.shape[1]]
        x = y
    return x

def label_cc(x_bool):
    x = x_bool.astype(bool)
    if _HAS_SCIPY:
        lab, n = ndi.label(x, structure=np.ones((3,3), dtype=np.uint8))
        return lab, int(n)
    H,W = x.shape
    lab = np.zeros((H,W), dtype=np.int32); cur=0
    for y in range(H):
        for x0 in range(W):
            if (not x[y,x0]) or lab[y,x0]!=0:
                continue
            cur += 1
            st=[(y,x0)]; lab[y,x0]=cur
            while st:
                yy,xx=st.pop()
                for dy in (-1,0,1):
                    for dx in (-1,0,1):
                        if dy==0 and dx==0:
                            continue
                        ny,nx=yy+dy,xx+dx
                        if 0<=ny<H and 0<=nx<W and x[ny,nx] and lab[ny,nx]==0:
                            lab[ny,nx]=cur; st.append((ny,nx))
    return lab, int(cur)

def inst_split_union_tok(mask_bool, min_area=2, max_area_frac=0.8, max_keep=8):
    H,W = mask_bool.shape
    lab,n = label_cc(mask_bool)
    if n<=0:
        return np.zeros((H,W), dtype=bool), 0
    insts=[]; areas=[]
    for k in range(1,n+1):
        m = (lab==k)
        a = int(m.sum())
        if a < min_area:
            continue
        if a / float(H*W) > max_area_frac:
            continue
        insts.append(m); areas.append(a)
    if not insts:
        return np.zeros((H,W), dtype=bool), 0
    order = np.argsort(np.asarray(areas))[::-1][:max_keep]
    uni = np.zeros((H,W), dtype=bool)
    for i in order:
        uni |= insts[i]
    return uni, int(len(order))

def dice(pr_bool, gt_bool):
    a=int(pr_bool.sum()); b=int(gt_bool.sum())
    if a==0 and b==0:
        return 1.0
    if a==0 or b==0:
        return 0.0
    inter=int((pr_bool & gt_bool).sum())
    return float((2.0*inter)/(a+b))

# ----------------------------
# GT union loader -> token GT
# ----------------------------
def _find_mask_files(mask_dir: Path, case_id: int):
    if mask_dir is None or (not mask_dir.exists()):
        return []
    cid = str(int(case_id))
    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    pats = [f"{cid}*.png", f"{cid}*.jpg", f"{cid}*.jpeg", f"{cid}*.tif", f"{cid}*.tiff", f"{cid}*.bmp",
            f"{cid}__*.png", f"{cid}_*.png"]
    out, seen = [], set()
    for pat in pats:
        for p in mask_dir.glob(pat):
            if p.suffix.lower() in exts:
                s = str(p)
                if s not in seen:
                    out.append(p); seen.add(s)
    return sorted(out)

def load_gt_union_full(case_id: int):
    # fast npy union if exists
    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if npy.exists():
            a = np.load(npy, mmap_mode="r")
            a = np.asarray(a)
            if a.ndim == 2:
                return (a > 0)
            if a.ndim == 3:
                return (a > 0).any(axis=0)

    files = []
    if TRAIN_MASK_DIR is not None:
        files += _find_mask_files(TRAIN_MASK_DIR, case_id)
    if SUP_MASK_DIR is not None:
        files += _find_mask_files(SUP_MASK_DIR, case_id)
    if not files:
        return None

    m = None
    for p in files:
        try:
            im = Image.open(p).convert("L")
            a = (np.asarray(im) > 0)
            m = a if m is None else (m | a)
        except Exception:
            continue
    return m

def downsample_bool_to_tok(mask_bool: np.ndarray):
    if mask_bool is None:
        return np.zeros((HTOK,WTOK), dtype=np.float32)
    im = Image.fromarray((mask_bool.astype(np.uint8)*255))
    im = im.resize((WTOK, HTOK), resample=Image.NEAREST)
    return (np.asarray(im) > 127).astype(np.float32)

def resize_bool_grid(x_bool: np.ndarray, h, w):
    if x_bool is None:
        return np.zeros((h,w), dtype=bool)
    if x_bool.shape == (h,w):
        return x_bool.astype(bool)
    im = Image.fromarray((x_bool.astype(np.uint8)*255))
    im = im.resize((w, h), resample=Image.NEAREST)
    return (np.asarray(im) > 127)

# ----------------------------
# Seed from match_npz (token union) + best_score (fail-safe)
# ----------------------------
def load_seed_tok(case_id: int, is_test=False, topk=8):
    mm = match_map_te if is_test else match_map_tr
    p = mm.get(int(case_id), None)
    if p is None or (not Path(p).exists()):
        return np.zeros((HTOK,WTOK), dtype=np.float32), 0

    try:
        z = np.load(p)
    except Exception:
        return np.zeros((HTOK,WTOK), dtype=np.float32), 0

    scores = None
    for k in ["peak_score","scores","score","peak_scores"]:
        if k in z.files:
            scores = np.asarray(z[k]).reshape(-1)
            break

    src = tgt = None
    for a,b in [("src_masks","tgt_masks"), ("src_m","tgt_m"), ("src","tgt")]:
        if a in z.files and b in z.files:
            src = np.asarray(z[a]); tgt = np.asarray(z[b])
            break

    best = int(np.max(scores)) if scores is not None and len(scores) else 0

    if src is None or tgt is None or src.ndim != 3 or tgt.ndim != 3 or src.shape[0] == 0:
        return np.zeros((HTOK,WTOK), dtype=np.float32), best

    if scores is not None and len(scores) == src.shape[0]:
        idx = np.argsort(scores)[::-1][:min(topk, len(scores))]
        src = src[idx]; tgt = tgt[idx]
    else:
        if src.shape[0] > topk:
            src = src[:topk]; tgt = tgt[:topk]

    seed_match = ((src>0) | (tgt>0)).any(axis=0).astype(bool)
    if seed_match.shape != (HTOK,WTOK):
        seed_tok = resize_bool_grid(seed_match, HTOK, WTOK).astype(np.float32)
    else:
        seed_tok = seed_match.astype(np.float32)

    return seed_tok, best

# ----------------------------
# Dataset / loaders
# ----------------------------
class TrainDS(Dataset):
    def __init__(self, df_in: pd.DataFrame, tok_col: str):
        self.df = df_in.reset_index(drop=True)
        self.tok_col = tok_col
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        cid = int(r["case_id"])
        y = float(r["y"])
        tok = load_tok_norm(r[self.tok_col])          # (Ht,Wt,D)
        seed, best_score = load_seed_tok(cid, is_test=False, topk=8)
        gt_full = load_gt_union_full(cid)
        gt_tok = downsample_bool_to_tok(gt_full)      # (Ht,Wt) float
        tok_ch = np.transpose(tok, (2,0,1))           # (D,Ht,Wt)
        x = np.concatenate([tok_ch, seed[None,:,:]], axis=0).astype(np.float32)
        return {
            "x": torch.from_numpy(x),
            "gt": torch.from_numpy(gt_tok[None,:,:].astype(np.float32)),
            "y": torch.tensor([y], dtype=torch.float32),
            "best_score": torch.tensor([float(best_score)], dtype=torch.float32),
            "case_id": torch.tensor([cid], dtype=torch.int64),
        }

class InferDS(Dataset):
    def __init__(self, df_in: pd.DataFrame, tok_col: str, is_test: bool):
        self.df = df_in.reset_index(drop=True)
        self.tok_col = tok_col
        self.is_test = is_test
    def __len__(self): return len(self.df)
    def __getitem__(self, i):
        r = self.df.iloc[i]
        cid = int(r["case_id"])
        tok = load_tok_norm(r[self.tok_col])
        seed, best_score = load_seed_tok(cid, is_test=self.is_test, topk=8)
        tok_ch = np.transpose(tok, (2,0,1))
        x = np.concatenate([tok_ch, seed[None,:,:]], axis=0).astype(np.float32)
        return {
            "x": torch.from_numpy(x),
            "best_score": torch.tensor([float(best_score)], dtype=torch.float32),
            "case_id": torch.tensor([cid], dtype=torch.int64),
        }

# ----------------------------
# Model (odd-size safe)
# ----------------------------
class ConvBNAct(nn.Module):
    def __init__(self, c_in, c_out, k=3, s=1, p=1, drop=0.0):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, k, s, p, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout2d(drop) if drop > 0 else nn.Identity()
    def forward(self, x):
        return self.drop(self.act(self.bn(self.conv(x))))

class ASPP(nn.Module):
    def __init__(self, c_in, c_out, rates=(1,2,4)):
        super().__init__()
        self.blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, padding=r, dilation=r, bias=False),
                nn.BatchNorm2d(c_out),
                nn.SiLU(inplace=True),
            ) for r in rates
        ])
        self.proj = nn.Sequential(
            nn.Conv2d(len(rates)*c_out, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(inplace=True),
        )
    def forward(self, x):
        y = torch.cat([b(x) for b in self.blocks], dim=1)
        return self.proj(y)

class HybridUNet(nn.Module):
    def __init__(self, c_in, base_ch=96, drop=0.1):
        super().__init__()
        c1, c2, c3 = base_ch, base_ch*2, base_ch*3
        self.e1 = nn.Sequential(ConvBNAct(c_in, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))
        self.pool = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.e2 = nn.Sequential(ConvBNAct(c1, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))
        self.e3 = nn.Sequential(ConvBNAct(c2, c3, drop=drop), ConvBNAct(c3, c3, drop=drop))
        self.aspp = ASPP(c3, c3, rates=(1,2,4))

        self.d2 = nn.Sequential(ConvBNAct(c3+c2, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))
        self.d1 = nn.Sequential(ConvBNAct(c2+c1, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))

        self.seg_head = nn.Conv2d(c1, 1, 1)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c3, c3//2),
            nn.SiLU(inplace=True),
            nn.Dropout(drop if drop>0 else 0.0),
            nn.Linear(c3//2, 1),
        )

    def forward(self, x):
        e1 = self.e1(x)
        e2 = self.e2(self.pool(e1))
        e3 = self.e3(self.pool(e2))
        b  = self.aspp(e3)

        cls_logit = self.cls_head(b)

        u2 = F.interpolate(b, size=e2.shape[-2:], mode="bilinear", align_corners=False)
        d2 = self.d2(torch.cat([u2, e2], dim=1))

        u1 = F.interpolate(d2, size=e1.shape[-2:], mode="bilinear", align_corners=False)
        d1 = self.d1(torch.cat([u1, e1], dim=1))

        seg_logit = self.seg_head(d1)
        return seg_logit, cls_logit

# ----------------------------
# Loss
# ----------------------------
def dice_loss_from_logits(logits, target, eps=1e-6):
    p = torch.sigmoid(logits)
    inter = (p * target).sum(dim=(2,3))
    den = (p + target).sum(dim=(2,3)) + eps
    d = (2.0 * inter) / den
    return 1.0 - d.mean()

def bce_focal_from_logits(logits, target, gamma=2.0):
    bce = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    if gamma <= 0:
        return bce.mean()
    p = torch.sigmoid(logits)
    pt = target * p + (1-target) * (1-p)
    w = (1-pt).pow(gamma)
    return (w * bce).mean()

# ----------------------------
# Final training config (from BEST + env overrides)
# ----------------------------
EPOCHS = int(os.environ.get("EPOCHS_FINAL", "18" if device.type=="cuda" else "8"))
BATCH  = int(os.environ.get("BATCH_SIZE", "32" if device.type=="cuda" else "16"))
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "2"))
ACCUM = int(os.environ.get("ACCUM_STEPS", "1"))

base_ch = int(BEST["base_ch"])
dropout = float(BEST["dropout"])
lr = float(BEST["lr"])
wd = float(BEST["weight_decay"])
lam_seg = float(BEST["lambda_seg"])
lam_cls = float(BEST["lambda_cls"])
focal_gamma = float(BEST["focal_gamma"])

pos = int(df_tr["y"].sum())
neg = int(len(df_tr) - pos)
pos_weight = float(neg / max(1, pos))
bce_cls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], dtype=torch.float32, device=device))

print("-"*60)
print("FINAL TRAIN CFG:", {
    "epochs": EPOCHS, "batch": BATCH, "accum": ACCUM,
    "lr": lr, "wd": wd, "base_ch": base_ch, "dropout": dropout,
    "pos_weight": pos_weight, "lam_seg": lam_seg, "lam_cls": lam_cls, "focal_gamma": focal_gamma
})
print("-"*60)

ds = TrainDS(df_tr, tok_col=tok_col_tr)
dl = DataLoader(
    ds, batch_size=BATCH, shuffle=True,
    num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"),
    drop_last=True
)

model = HybridUNet(c_in=CIN, base_ch=base_ch, drop=dropout).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, EPOCHS))
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

t0 = time.time()
model.train()
opt.zero_grad(set_to_none=True)

for ep in range(1, EPOCHS+1):
    loss_meter = 0.0
    nsteps = 0
    for step, batch in enumerate(dl, start=1):
        x  = batch["x"].to(device, non_blocking=True)
        gt = batch["gt"].to(device, non_blocking=True)
        yb = batch["y"].to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            seg_logit, cls_logit = model(x)
            l_seg = bce_focal_from_logits(seg_logit, gt, gamma=focal_gamma) + dice_loss_from_logits(seg_logit, gt)
            l_cls = bce_cls(cls_logit, yb)
            loss = lam_seg * l_seg + lam_cls * l_cls
            loss = loss / ACCUM

        scaler.scale(loss).backward()
        if step % ACCUM == 0:
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)

        loss_meter += float(loss.item()) * ACCUM
        nsteps += 1

    sched.step()
    if ep == 1 or ep % 2 == 0 or ep == EPOCHS:
        print(f"[ep {ep:02d}/{EPOCHS}] loss={loss_meter/max(1,nsteps):.4f} lr={sched.get_last_lr()[0]:.6g} | {time.time()-t0:.1f}s")

print("TRAIN DONE |", f"{time.time()-t0:.1f}s")

# ----------------------------
# Export mask-prob cache (TRAIN+TEST) + p_gate
# ----------------------------
def cfg_hash(d):
    s = json.dumps(d, sort_keys=True).encode()
    return hashlib.md5(s).hexdigest()[:10]

CFG_EXPORT = {
    "hybrid": True,
    "best_config": BEST,
    "token_manifest_train": str(TOK_MAN_TRAIN),
    "token_manifest_root": str(TOKEN_MANIFEST_ROOT),
    "token_data_root": str(TOKEN_DATA_ROOT),
    "match_root": str(MATCH_ROOT) if MATCH_ROOT else None,
    "tok_grid": {"HTOK": HTOK, "WTOK": WTOK, "PATCH": PATCH},
    "train_table": str(TRAIN_TABLE),
    "test_table": str(TEST_TABLE) if df_te is not None else None,
    "tok_col_tr": tok_col_tr,
    "tok_col_te": tok_col_te,
}
CFG_ID = cfg_hash(CFG_EXPORT)

MASKPROB_DIR = CACHE_W / f"mask_prob_hybrid_{CFG_ID}"
MASKPROB_DIR.mkdir(parents=True, exist_ok=True)

@torch.no_grad()
def export_probs(df_in: pd.DataFrame, tok_col_use: str, is_test: bool, tag: str):
    ds_inf = InferDS(df_in, tok_col=tok_col_use, is_test=is_test)
    dl_inf = DataLoader(
        ds_inf, batch_size=BATCH, shuffle=False,
        num_workers=NUM_WORKERS, pin_memory=(device.type=="cuda"),
        drop_last=False
    )
    model.eval()
    t1 = time.time()
    wrote = 0
    for j, batch in enumerate(dl_inf, start=1):
        x = batch["x"].to(device, non_blocking=True)
        seg_logit, cls_logit = model(x)
        prob_tok = torch.sigmoid(seg_logit).detach().cpu().numpy()[:,0]        # (B,Ht,Wt)
        p_gate   = torch.sigmoid(cls_logit).detach().cpu().numpy().reshape(-1) # (B,)
        cids = batch["case_id"].cpu().numpy().reshape(-1)

        for i in range(len(cids)):
            cid = int(cids[i])
            np.savez_compressed(
                MASKPROB_DIR / f"{cid}.npz",
                prob_tok=prob_tok[i].astype(np.float16),
                p_gate=np.float16(p_gate[i]),
            )
            wrote += 1

        if j % 100 == 0:
            print(f"[export {tag}] {wrote}/{len(ds_inf)} | {time.time()-t1:.1f}s")

    print(f"[export {tag}] done | wrote={wrote} | {time.time()-t1:.1f}s")

export_probs(df_tr[["case_id", tok_col_tr]].copy(), tok_col_tr, is_test=False, tag="train")
if df_te is not None and len(df_te) > 0:
    export_probs(df_te[["case_id", tok_col_te]].copy(), tok_col_te, is_test=True, tag="test")

# ----------------------------
# Threshold sweep on TRAIN using Dice-proxy (token-space)
# score(thr) = mean( p_gate>=thr ? dice(postproc(seg,seed),gt) : dice(empty,gt) )
# ----------------------------
T1 = float(BEST["T1"]); T0 = float(BEST["T0"]); dil_it = int(BEST["seed_dilate_it"])
min_tok_area = int(BEST["min_tok_area"])
max_tok_area_frac = float(BEST["max_tok_area_frac"])
max_inst_keep = int(BEST["max_inst_keep"])
min_peak_keep = int(BEST["min_peak_score_keep"])
min_area_frac_keep = float(BEST["min_area_frac_keep"])

@torch.no_grad()
def build_dice_arrays_train():
    ds_inf = TrainDS(df_tr[["case_id","y",tok_col_tr]].copy(), tok_col=tok_col_tr)
    dl_inf = DataLoader(ds_inf, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS,
                        pin_memory=(device.type=="cuda"), drop_last=False)
    model.eval()

    p_gate_all = np.zeros(len(ds_inf), np.float32)
    dice_use   = np.zeros(len(ds_inf), np.float32)
    dice_empty = np.zeros(len(ds_inf), np.float32)

    k0 = 0
    t2 = time.time()
    for batch in dl_inf:
        x = batch["x"].to(device, non_blocking=True)
        gt = batch["gt"].cpu().numpy()[:,0] > 0.5
        best_score = batch["best_score"].cpu().numpy().reshape(-1)

        seg_logit, cls_logit = model(x)
        prob = torch.sigmoid(seg_logit).cpu().numpy()[:,0]
        pg   = torch.sigmoid(cls_logit).cpu().numpy().reshape(-1)
        seed = batch["x"].cpu().numpy()[:,-1]  # last channel (B,Ht,Wt)

        B = prob.shape[0]
        for i in range(B):
            idx = k0 + i
            p_gate_all[idx] = pg[i]

            gt_empty = (gt[i].sum() == 0)
            dice_empty[idx] = 1.0 if gt_empty else 0.0

            hard = prob[i] >= T1
            soft = prob[i] >= T0
            sd = dilate_tok(seed[i] > 0.5, dil_it)
            fused = hard | (sd & soft)

            uni, _ = inst_split_union_tok(fused, min_area=min_tok_area, max_area_frac=max_tok_area_frac, max_keep=max_inst_keep)
            area_frac = float(uni.mean())
            if (best_score[i] < min_peak_keep) and (area_frac < min_area_frac_keep):
                uni = np.zeros((HTOK,WTOK), dtype=bool)

            dice_use[idx] = dice(uni, gt[i])

        k0 += B
        if k0 % 800 == 0:
            print(f"[dice-proxy] {k0}/{len(ds_inf)} | {time.time()-t2:.1f}s")

    return p_gate_all, dice_use, dice_empty

p_gate_all, dice_use, dice_empty = build_dice_arrays_train()

thr_grid = np.linspace(0.0, 1.0, 201, dtype=np.float32)
rows = []
best_thr = 0.5
best_score = -1.0

for thr in thr_grid:
    use = (p_gate_all >= thr)
    score = float(np.where(use, dice_use, dice_empty).mean())
    rows.append({"thr": float(thr), "score_dice_proxy": score})
    if score > best_score:
        best_score = score
        best_thr = float(thr)

df_thr = pd.DataFrame(rows)
df_thr.to_csv(OUT_DIR / "threshold_table.csv", index=False)

(OUT_DIR / "best_threshold.json").write_text(json.dumps({
    "recommended_thr": float(best_thr),
    "best_score_dice_proxy": float(best_score),
    "pp": {
        "T1": float(T1), "T0": float(T0), "seed_dilate_it": int(dil_it),
        "min_tok_area": int(min_tok_area),
        "max_tok_area_frac": float(max_tok_area_frac),
        "max_inst_keep": int(max_inst_keep),
        "min_peak_score_keep": int(min_peak_keep),
        "min_area_frac_keep": float(min_area_frac_keep),
    },
    "cfg_id": CFG_ID
}, indent=2))

print("BEST_THR:", best_thr, "| best_score_dice_proxy:", best_score)

# ----------------------------
# Save final model bundle (.pt)
# ----------------------------
pack = {
    "model_type": "HybridUNet",
    "state_dict": {k: v.detach().cpu() for k,v in model.state_dict().items()},
    "best_config": BEST,
    "recommended_thr": float(best_thr),
    "best_score_dice_proxy": float(best_score),
    "cfg_export": CFG_EXPORT,
    "cfg_id": CFG_ID,
    "maskprob_dir": str(MASKPROB_DIR),
    "tok_grid": {"HTOK": int(HTOK), "WTOK": int(WTOK), "PATCH": int(PATCH)},
    "input_channels": int(CIN),
    "token_dim": int(DIN),
    "meta": {
        "seed": int(SEED),
        "train_rows": int(len(df_tr)),
        "test_rows": int(len(df_te)) if df_te is not None else 0,
        "train_table": str(TRAIN_TABLE),
        "test_table": str(TEST_TABLE) if df_te is not None else None,
        "tok_col_tr": str(tok_col_tr),
        "tok_col_te": str(tok_col_te),
    }
}
torch.save(pack, OUT_DIR / "final_hybrid_model.pt")

print("-"*60)
print("SAVED:")
print(" -", OUT_DIR / "final_hybrid_model.pt")
print(" -", OUT_DIR / "threshold_table.csv")
print(" -", OUT_DIR / "best_threshold.json")
print("MASKPROB_DIR:", MASKPROB_DIR)
print("CFG_ID:", CFG_ID)


DEVICE: cpu | AMP: False
TOK_MAN_TRAIN       : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500/tokens_manifest_train.parquet
TOKEN_MANIFEST_ROOT : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500 | token_files: 2796
BASE_CACHE          : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache
TOKEN_DATA_ROOT     : /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/dinov2_base_518_cfg_543289469500 | token_files: 2796
TOKEN indexed files : 2796
TRAIN token exists rate: 1.0 | rows: 2795
TEST token exists rate: 1.0 | rows: 1
TRAIN_TABLE: /kaggle/working/recodai_luc_gate_artifacts/train_table.patched_tokens.parquet | rows_used: 2795 | pos_rate: 1.0
TEST_TABLE : /kaggle/working/recodai_luc_gate_artifacts/test_table.parquet
TOK_COL_TR : tok_npz
TOK_COL_TE : tok_npz
TOKEN GRID: (37, 37) | DIN: 768 | CIN: 769
MATCH_ROOT: /kaggle/input/recod-ailuc-dinov2-base/recodai_luc/cache/match_cfg_2ed747746f9c
PATCH: 14 |

# Finalize & Save Model Bundle (Reproducible)

In [8]:
# ============================================================
# STAGE — Finalize & Save Model Bundle (Reproducible) (ONE CELL) — HYBRID (OPSI-1) — REVISI FULL
# Bundle (portable):
# - final_hybrid_model.pt
# - best_threshold.json + threshold_table.csv
# - best_config.json (if exists)
# - paths.json + match cfg.json (if resolvable)
# - manifest.json (sha256 + env + metadata)
# - ZIP: hybrid_model_bundle_<cfg_id>_<stamp>.zip
# ============================================================

import os, json, time, hashlib, shutil, platform, zipfile
from pathlib import Path

import numpy as np
import pandas as pd
import torch

OPT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_opt")
OUT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")

FINAL_PT = OUT_DIR / "final_hybrid_model.pt"
BEST_THR = OUT_DIR / "best_threshold.json"
THR_TAB  = OUT_DIR / "threshold_table.csv"
BEST_CFG = OPT_DIR / "best_config.json"
PATHS_JS = PROF_DIR / "paths.json"

for p in [FINAL_PT, BEST_THR]:
    if not p.exists():
        raise FileNotFoundError(f"Missing {p}. Run Final Training stage first.")

pack = torch.load(FINAL_PT, map_location="cpu")
cfg_id = str(pack.get("cfg_id", "unknown"))

stamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
BUNDLE_DIR = Path(f"/kaggle/working/recodai_luc_hybrid_bundle_{cfg_id}_{stamp}")
BUNDLE_DIR.mkdir(parents=True, exist_ok=True)

def sha256_file(p: Path, chunk=1<<20):
    h = hashlib.sha256()
    with open(p, "rb") as f:
        while True:
            b = f.read(chunk)
            if not b:
                break
            h.update(b)
    return h.hexdigest()

def safe_copy(src: Path, dst_dir: Path, new_name: str = None):
    if src is None:
        return None
    src = Path(src)
    if not src.exists():
        return None
    dst = dst_dir / (new_name if new_name else src.name)
    shutil.copy2(src, dst)
    return dst

# resolve match cfg.json if possible
match_cfg_json = None
try:
    match_root = pack.get("cfg_export", {}).get("match_root", None)
    if match_root and Path(match_root).exists():
        p = Path(match_root) / "cfg.json"
        if p.exists():
            match_cfg_json = p
except Exception:
    match_cfg_json = None

copied = {}
copied["final_hybrid_model.pt"] = str(safe_copy(FINAL_PT, BUNDLE_DIR, "final_hybrid_model.pt"))
copied["best_threshold.json"]   = str(safe_copy(BEST_THR, BUNDLE_DIR, "best_threshold.json"))

if THR_TAB.exists():
    copied["threshold_table.csv"] = str(safe_copy(THR_TAB, BUNDLE_DIR, "threshold_table.csv"))

if BEST_CFG.exists():
    copied["best_config.json"] = str(safe_copy(BEST_CFG, BUNDLE_DIR, "best_config.json"))

if PATHS_JS.exists():
    copied["paths.json"] = str(safe_copy(PATHS_JS, BUNDLE_DIR, "paths.json"))

if match_cfg_json is not None and Path(match_cfg_json).exists():
    copied["match_cfg.json"] = str(safe_copy(match_cfg_json, BUNDLE_DIR, "match_cfg.json"))

# README
readme = BUNDLE_DIR / "README.txt"
readme.write_text(
    "Hybrid (OPSI-1) bundle contents:\n"
    "- final_hybrid_model.pt : Torch state_dict + cfg + recommended_thr + maskprob_dir + metadata\n"
    "- best_threshold.json   : recommended_thr + postprocess params + cfg_id\n"
    "- threshold_table.csv   : sweep table (optional)\n"
    "- best_config.json      : HPO best trial (optional)\n"
    "- paths.json            : dataset paths used (optional)\n"
    "- match_cfg.json        : robust matching cfg used (optional)\n\n"
    "Load example:\n"
    "  import torch, json\n"
    "  pack = torch.load('final_hybrid_model.pt', map_location='cpu')\n"
    "  thr  = json.load(open('best_threshold.json'))['recommended_thr']\n"
)

# Manifest (sha256 + env + metadata)
files = sorted([p for p in BUNDLE_DIR.glob("*") if p.is_file() and p.name != "manifest.json"])
hashes = {p.name: sha256_file(p) for p in files}

try:
    thr_js = json.loads(BEST_THR.read_text())
    rec_thr = float(pack.get("recommended_thr", thr_js.get("recommended_thr", 0.5)))
except Exception:
    rec_thr = float(pack.get("recommended_thr", 0.5))

meta = {
    "bundle_dir": str(BUNDLE_DIR),
    "created_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
    "cfg_id": cfg_id,
    "recommended_thr": rec_thr,
    "best_score_dice_proxy": float(pack.get("best_score_dice_proxy", np.nan)),
    "tok_grid": pack.get("tok_grid", {}),
    "input_channels": int(pack.get("input_channels", -1)),
    "token_dim": int(pack.get("token_dim", -1)),
    "maskprob_dir": str(pack.get("maskprob_dir", "")),
    "cfg_export": pack.get("cfg_export", {}),
    "env": {
        "python": platform.python_version(),
        "platform": platform.platform(),
        "torch": torch.__version__,
        "cuda_available": bool(torch.cuda.is_available()),
        "cuda_version": torch.version.cuda,
        "numpy": np.__version__,
        "pandas": pd.__version__,
    },
    "files_copied": copied,
    "sha256": hashes,
}

manifest_path = BUNDLE_DIR / "manifest.json"
manifest_path.write_text(json.dumps(meta, indent=2))

# ZIP
ZIP_PATH = Path(f"/kaggle/working/hybrid_model_bundle_{cfg_id}_{stamp}.zip")
with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED) as z:
    for p in sorted(BUNDLE_DIR.glob("*")):
        if p.is_file():
            z.write(p, arcname=p.name)

print("BUNDLE_DIR:", BUNDLE_DIR)
print("ZIP_PATH  :", ZIP_PATH)
print("Files:", [p.name for p in sorted(BUNDLE_DIR.glob('*')) if p.is_file()])


BUNDLE_DIR: /kaggle/working/recodai_luc_hybrid_bundle_8d3a1c7f72_20260113_164518
ZIP_PATH  : /kaggle/working/hybrid_model_bundle_8d3a1c7f72_20260113_164518.zip
Files: ['README.txt', 'best_config.json', 'best_threshold.json', 'final_hybrid_model.pt', 'manifest.json', 'match_cfg.json', 'paths.json', 'threshold_table.csv']
