# Set Paths & Select Config (CFG)

In [None]:
# ============================================================
# STAGE — Set Paths & Select Config (CFG) (ONE CELL)
# - Auto-load competition + prep artifacts
# - Auto-pick latest TOKEN cache / MATCH cache / PRED_ENS dir (unless already in globals)
# - Build and freeze feature_cols for Gate (LightGBM) from pred_features_train.csv
# - Save:
#     /kaggle/working/recodai_luc_gate_artifacts/train_cfg.json
#     /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# ============================================================

import os, json, hashlib, time
from pathlib import Path
import numpy as np
import pandas as pd

# ----------------------------
# 0) Locations (fixed)
# ----------------------------
# Fixed working directories: PREP outputs, cache root, and this stage's artifacts.
PROF_DIR, CACHE_DIR, ART_DIR = (
    Path("/kaggle/working/recodai_luc_prof"),
    Path("/kaggle/working/recodai_luc/cache"),
    Path("/kaggle/working/recodai_luc_gate_artifacts"),
)
ART_DIR.mkdir(parents=True, exist_ok=True)

# Fail fast on the first PREP output that is absent.
for req in ("paths.json", "train_manifest.parquet", "test_manifest.parquet", "folds.parquet"):
    p = PROF_DIR / req
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing {p}. Run PREP stages first.")

# Load the path registry and the three PREP tables (train, test, folds).
PATHS = json.loads((PROF_DIR / "paths.json").read_text())
df_train, df_test, df_folds = (
    pd.read_parquet(PROF_DIR / fname)
    for fname in ("train_manifest.parquet", "test_manifest.parquet", "folds.parquet")
)

# ----------------------------
# Helpers: auto-pick latest cache roots
# ----------------------------
def _pick_latest_dir(glob_pat: str, must_have: str):
    """Return the most recently updated cache directory matching *glob_pat*.

    Only directories directly under ``CACHE_DIR`` that contain the marker
    file *must_have* qualify; recency is the mtime of that marker file.
    Returns ``None`` when no directory qualifies.
    """
    def _marker_mtime(directory):
        return (directory / must_have).stat().st_mtime

    eligible = [
        d for d in sorted(CACHE_DIR.glob(glob_pat))
        if d.is_dir() and (d / must_have).exists()
    ]
    # max() returns the first of several equally recent candidates, i.e. the
    # lexicographically smallest path, exactly like a stable descending sort.
    return max(eligible, key=_marker_mtime, default=None)

def pick_token_root():
    """Resolve the TOKEN cache root.

    A ``TOKEN_CACHE_ROOT`` global wins when it points at an existing
    directory holding ``cfg.json``; otherwise the newest matching cache
    directory is auto-picked (``None`` if there is none).
    """
    try:
        preset = Path(str(globals()["TOKEN_CACHE_ROOT"]))
    except KeyError:
        preset = None
    if preset is not None and preset.exists() and (preset / "cfg.json").exists():
        return preset
    return _pick_latest_dir("dinov2_*cfg_*", "cfg.json")

def pick_match_root():
    """Resolve the MATCH cache root.

    A ``MATCH_CACHE_ROOT`` global wins when it points at an existing
    directory holding ``cfg.json``; otherwise the newest matching cache
    directory is auto-picked (``None`` if there is none).
    """
    try:
        preset = Path(str(globals()["MATCH_CACHE_ROOT"]))
    except KeyError:
        preset = None
    if preset is not None and preset.exists() and (preset / "cfg.json").exists():
        return preset
    return _pick_latest_dir("match_cfg_*", "cfg.json")

def pick_pred_ens_dir():
    """Resolve the PRED_ENS directory.

    Resolution order:
      1. a ``PRED_ENS_DIR`` global, if that path exists;
      2. the canonical ``CACHE_DIR / "pred_ens"``, if it exists;
      3. the direct child of ``CACHE_DIR`` whose ``pred_features_train.csv``
         has the newest mtime.
    Returns ``None`` when none of these is available.
    """
    scope = globals()
    if "PRED_ENS_DIR" in scope:
        preset = Path(str(scope["PRED_ENS_DIR"]))
        if preset.exists():
            return preset

    canonical = CACHE_DIR / "pred_ens"
    if canonical.exists():
        return canonical

    def _csv_mtime(directory):
        return (directory / "pred_features_train.csv").stat().st_mtime

    found = [
        d for d in CACHE_DIR.glob("*")
        if d.is_dir() and (d / "pred_features_train.csv").exists()
    ]
    return max(found, key=_csv_mtime, default=None)

# Resolve the three upstream artifact locations in one go.
TOKEN_ROOT, MATCH_ROOT, PRED_ENS_DIR = pick_token_root(), pick_match_root(), pick_pred_ens_dir()

# Each location is mandatory; report the first one that could not be resolved.
for _root, _hint in (
    (TOKEN_ROOT, "TOKEN cache not found. Run 'DINOv2 Feature Cache' stage first."),
    (MATCH_ROOT, "MATCH cache not found. Run 'Robust Matching' stage first."),
    (PRED_ENS_DIR, "PRED_ENS_DIR not found. Run 'Verification, Mask Reconstruction & Postprocess' stage first."),
):
    if _root is None:
        raise FileNotFoundError(_hint)

# Both feature tables must exist inside the resolved PRED_ENS directory.
PRED_FEATURES_TRAIN = PRED_ENS_DIR / "pred_features_train.csv"
PRED_FEATURES_TEST  = PRED_ENS_DIR / "pred_features_test.csv"
for _feat_csv in (PRED_FEATURES_TRAIN, PRED_FEATURES_TEST):
    if not _feat_csv.exists():
        raise FileNotFoundError(f"Missing {_feat_csv}. Run Verification stage first.")

# ----------------------------
# 1) Feature columns (freeze)
# ----------------------------
# Read the train feature table; it must be keyed by case_id.
df_feat_train = pd.read_csv(PRED_FEATURES_TRAIN)
if "case_id" not in df_feat_train.columns:
    raise ValueError("pred_features_train.csv must contain 'case_id'")

# Core features: always kept when present, and always listed first.
core = [
    "best_peak_score", "has_match", "n_inst",
    "area_frac", "area_frac_tok", "mean_prob_tok",
    "has_prob",
]
# Identifier / path columns are never features.
exclude = {"case_id", "split", "img_path", "npz_path"}

# Core columns are accepted whatever their dtype (they are coerced to numeric
# when the tables are built); any other column must already be numeric.
numeric_cols = [
    c for c in df_feat_train.columns
    if c not in exclude
    and (c in core or pd.api.types.is_numeric_dtype(df_feat_train[c]))
]

# Final order: core features (in `core` order), then the rest in file order.
feature_cols = [c for c in core if c in numeric_cols]
feature_cols += [c for c in numeric_cols if c not in core]

if not feature_cols:
    raise RuntimeError("No usable feature columns found in pred_features_train.csv")

(ART_DIR / "feature_cols.json").write_text(json.dumps(feature_cols, indent=2))

# ----------------------------
# 2) CFG (Gate + calibration + postprocess knobs reference)
# ----------------------------
# Full gate-stage configuration: resolved paths, CV layout, frozen features,
# LightGBM parameters, calibration choice and reference postprocess knobs.
CFG = {
    "version": "gate_v1",
    "created_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
    "paths": {
        "COMP_ROOT": PATHS.get("COMP_ROOT"),
        "SAMPLE_SUB": PATHS.get("SAMPLE_SUB"),
        "PROF_DIR": str(PROF_DIR),
        "CACHE_DIR": str(CACHE_DIR),
        "ART_DIR": str(ART_DIR),
        "TOKEN_ROOT": str(TOKEN_ROOT),
        "MATCH_ROOT": str(MATCH_ROOT),
        "PRED_ENS_DIR": str(PRED_ENS_DIR),
        "PRED_FEATURES_TRAIN": str(PRED_FEATURES_TRAIN),
        "PRED_FEATURES_TEST": str(PRED_FEATURES_TEST),
    },
    "cv": {
        "n_folds": int(df_folds["fold"].nunique()),
        "seed": 42,
        "split_key": "case_id",
        "fold_col": "fold",
        "label_col": "y",
    },
    "features": {
        "feature_cols_path": str(ART_DIR / "feature_cols.json"),
        "feature_cols": feature_cols,
        "missing_numeric_fill": 0.0,
    },
    "gate_model": {
        "name": "LightGBM",
        "params": {
            "objective": "binary",
            "metric": "binary_logloss",
            "learning_rate": 0.05,
            "num_leaves": 63,
            "min_data_in_leaf": 50,
            "feature_fraction": 0.9,
            "bagging_fraction": 0.8,
            "bagging_freq": 1,
            "lambda_l2": 1.0,
            "max_depth": -1,
            "n_estimators": 2000,
        },
        "early_stopping_rounds": 200,
    },
    "calibration": {
        "enabled": True,
        "method": "isotonic",   # "sigmoid" or "isotonic"
    },
    # Reference knobs used in Verification/Postprocess (kept here for reproducibility)
    "postprocess_ref": {
        "min_peak_score_keep": 6,
        "min_area_frac_keep": 0.0005,
        "max_inst_keep": 8,
    },
}

# Content hash of the configuration. The creation timestamp is left out of the
# hashed payload so that two runs with identical settings share one cfg_id;
# hashing it would make the id change on every execution.
_cfg_hash_payload = {k: v for k, v in CFG.items() if k != "created_utc"}
cfg_id = hashlib.sha1(json.dumps(_cfg_hash_payload, sort_keys=True).encode()).hexdigest()[:12]
CFG["cfg_id"] = cfg_id
(ART_DIR / "train_cfg.json").write_text(json.dumps(CFG, indent=2))

# ----------------------------
# 3) Print summary (tight)
# ----------------------------
# Compact run summary.
print("ART_DIR:", ART_DIR)
print("CFG_ID :", cfg_id)
for _label, _root in (("TOKEN_ROOT:", TOKEN_ROOT), ("MATCH_ROOT:", MATCH_ROOT)):
    print(_label, _root.name)
print("PRED_ENS_DIR:", str(PRED_ENS_DIR))
print("Features:", len(feature_cols), "| First 10:", feature_cols[:10])
print("Saved:")
for _fname in ("train_cfg.json", "feature_cols.json"):
    print(" -", ART_DIR / _fname)

# Export globals for next stages
GATE_ART_DIR, GATE_CFG, FEATURE_COLS = ART_DIR, CFG, feature_cols


# Build Training Table (X, y, folds)

In [None]:
# ============================================================
# STAGE — Build Training Table (X, y, folds) (ONE CELL)
# Uses:
# - /kaggle/working/recodai_luc_gate_artifacts/train_cfg.json
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - /kaggle/working/recodai_luc_prof/train_manifest.parquet
# - /kaggle/working/recodai_luc_prof/folds.parquet
# - pred_features_train.csv (from Verification stage)
#
# Produces:
# - /kaggle/working/recodai_luc_gate_artifacts/train_table.parquet
# - /kaggle/working/recodai_luc_gate_artifacts/train_X.npy
# - /kaggle/working/recodai_luc_gate_artifacts/train_y.npy
# - /kaggle/working/recodai_luc_gate_artifacts/train_folds.npy
# - /kaggle/working/recodai_luc_gate_artifacts/feature_stats.json
#
# Exports globals:
# - X_train, y_train, folds_train, df_train_tabular
# ============================================================

import json
from pathlib import Path
import numpy as np
import pandas as pd

ART_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")

# Inputs written by earlier stages.
cfg_path, cols_path = ART_DIR / "train_cfg.json", ART_DIR / "feature_cols.json"
train_manifest_pq, folds_pq = PROF_DIR / "train_manifest.parquet", PROF_DIR / "folds.parquet"

# Stop at the first input that is not on disk.
for p in (cfg_path, cols_path, train_manifest_pq, folds_pq):
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing {p}. Run previous stages first.")

CFG = json.loads(cfg_path.read_text())
FEATURE_COLS = json.loads(cols_path.read_text())

# The train feature CSV location is taken from the saved config.
pred_feat_train = Path(CFG["paths"]["PRED_FEATURES_TRAIN"])
if not pred_feat_train.exists():
    raise FileNotFoundError(f"Missing pred_features_train.csv: {pred_feat_train}")

# Constant used to fill missing feature values, plus the label / fold column names.
fill_value = float(CFG["features"].get("missing_numeric_fill", 0.0))
label_col, fold_col = CFG["cv"]["label_col"], CFG["cv"]["fold_col"]

# ----------------------------
# 1) Load base labels + folds
# ----------------------------
# Labels come from the train manifest, fold ids from the folds table.
df_label = pd.read_parquet(train_manifest_pq)[["case_id", label_col]].copy()
df_folds = pd.read_parquet(folds_pq)[["case_id", fold_col]].copy()

# Every labelled case must have a fold assignment.
df_base = pd.merge(df_label, df_folds, on="case_id", how="left")
_no_fold = df_base[fold_col].isna()
if _no_fold.any():
    miss = df_base.loc[_no_fold, "case_id"].head(10).tolist()
    raise RuntimeError(f"Missing fold for some case_id (first 10): {miss}")

# ----------------------------
# 2) Load features (from verification) and dedup per case_id
# ----------------------------
df_feat = pd.read_csv(pred_feat_train)
if "case_id" not in df_feat.columns:
    raise ValueError("pred_features_train.csv must contain 'case_id'")

# Restrict to the key plus whichever frozen features the CSV provides.
keep = ["case_id"] + [c for c in FEATURE_COLS if c in df_feat.columns]
df_feat = df_feat[keep].copy()

# One row per case_id: prefer the highest best_peak_score, then the highest
# area_frac (whichever of the two columns are available); otherwise the first row.
sort_keys = [k for k in ("best_peak_score", "area_frac") if k in df_feat.columns]
if sort_keys:
    df_feat = df_feat.sort_values(sort_keys, ascending=False)
df_feat = df_feat.drop_duplicates("case_id", keep="first")

# ----------------------------
# 3) Join to build train table
# ----------------------------
# Left join keeps every labelled case, including those without feature rows.
df_train_tabular = pd.merge(df_base, df_feat, on="case_id", how="left")

# Add any frozen feature the join did not supply as an all-NaN column.
for c in FEATURE_COLS:
    if c in df_train_tabular.columns:
        continue
    df_train_tabular[c] = np.nan

# Force every feature to a numeric dtype; unparseable values become NaN.
for c in FEATURE_COLS:
    df_train_tabular[c] = pd.to_numeric(df_train_tabular[c], errors="coerce")

# Count NaNs before and after the constant fill.
_feat_view = df_train_tabular[FEATURE_COLS]
n_missing_before = int(_feat_view.isna().sum().sum())
df_train_tabular[FEATURE_COLS] = _feat_view.fillna(fill_value)
n_missing_after = int(df_train_tabular[FEATURE_COLS].isna().sum().sum())

# ----------------------------
# 4) Add a few safe derived features (does NOT touch prep)
# ----------------------------
def safe_log1p(x):
    """Return ``log1p`` of *x* after raising negative entries to 0."""
    non_negative = np.maximum(x, 0)
    return np.log1p(non_negative)
def safe_sqrt(x):
    """Return the square root of *x* after raising negative entries to 0."""
    non_negative = np.maximum(x, 0)
    return np.sqrt(non_negative)

# Each derived column is built only when its source column is present.
if "best_peak_score" in df_train_tabular.columns:
    _peak = df_train_tabular["best_peak_score"].values.astype(np.float32)
    df_train_tabular["log1p_best_peak_score"] = safe_log1p(_peak)
if "area_frac_tok" in df_train_tabular.columns:
    _area_tok = df_train_tabular["area_frac_tok"].values.astype(np.float32)
    df_train_tabular["sqrt_area_frac_tok"] = safe_sqrt(_area_tok)
if "n_inst" in df_train_tabular.columns:
    _n_inst = df_train_tabular["n_inst"].values.astype(np.float32)
    df_train_tabular["has_inst"] = (_n_inst > 0).astype(np.float32)

# Append the derived names (those not already listed) to the frozen feature
# list and persist it so later stages use the same columns.
derived = [
    c for c in ("log1p_best_peak_score", "sqrt_area_frac_tok", "has_inst")
    if c in df_train_tabular.columns
]
FEATURE_COLS_FINAL = list(FEATURE_COLS)
FEATURE_COLS_FINAL.extend(c for c in derived if c not in FEATURE_COLS)

(ART_DIR / "feature_cols.json").write_text(json.dumps(FEATURE_COLS_FINAL, indent=2))

# ----------------------------
# 5) Build arrays
# ----------------------------
# Dense arrays in the final feature order, plus labels and fold ids.
X_train = df_train_tabular[FEATURE_COLS_FINAL].to_numpy(dtype=np.float32, copy=True)
y_train = df_train_tabular[label_col].to_numpy(dtype=np.int64, copy=True)
folds_train = df_train_tabular[fold_col].to_numpy(dtype=np.int64, copy=True)

# ----------------------------
# 6) Save
# ----------------------------
df_train_tabular.to_parquet(ART_DIR / "train_table.parquet", index=False)
for _fname, _arr in (
    ("train_X.npy", X_train),
    ("train_y.npy", y_train),
    ("train_folds.npy", folds_train),
):
    np.save(ART_DIR / _fname, _arr)

# feature stats (for sanity / later debugging): per-column min / max / mean
_col_min = np.nanmin(X_train, axis=0)
_col_max = np.nanmax(X_train, axis=0)
_col_mean = np.nanmean(X_train, axis=0)
stats = {
    "n_rows": int(len(df_train_tabular)),
    "pos_rate": float(y_train.mean()),
    "n_folds": int(np.unique(folds_train).size),
    "missing_before_fill": n_missing_before,
    "missing_after_fill": n_missing_after,
    "feature_cols": FEATURE_COLS_FINAL,
    "feature_min": _col_min.tolist(),
    "feature_max": _col_max.tolist(),
    "feature_mean": _col_mean.tolist(),
}
(ART_DIR / "feature_stats.json").write_text(json.dumps(stats, indent=2))

# List every artifact written by this stage.
print("Saved:")
for _fname in (
    "train_table.parquet",
    "train_X.npy",
    "train_y.npy",
    "train_folds.npy",
    "feature_stats.json",
):
    print(" -", ART_DIR / _fname)
print("Updated:")
print(" -", ART_DIR / "feature_cols.json (added derived if any)")
print("-"*60)
_pos_rate = float(y_train.mean())
print("Train rows:", len(df_train_tabular), "| pos_rate:", _pos_rate)
print("Features:", len(FEATURE_COLS_FINAL), "| missing_before:", n_missing_before, "| missing_after:", n_missing_after)

# Export globals for next stages
FEATURE_COLS = FEATURE_COLS_FINAL


# Build & Export Test Feature Table

In [None]:
# ============================================================
# STAGE — Build & Export Test Feature Table (ONE CELL)
# Uses:
# - /kaggle/working/recodai_luc_gate_artifacts/train_cfg.json
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - /kaggle/working/recodai_luc_prof/test_manifest.parquet
# - pred_features_test.csv (from Verification stage)
# - sample_submission.csv (to enforce ordering)
#
# Produces:
# - /kaggle/working/recodai_luc_gate_artifacts/test_table.parquet
# - /kaggle/working/recodai_luc_gate_artifacts/test_X.npy
# - /kaggle/working/recodai_luc_gate_artifacts/test_case_id.npy
# ============================================================

import json
from pathlib import Path
import numpy as np
import pandas as pd

ART_DIR = Path("/kaggle/working/recodai_luc_gate_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")

# Inputs written by earlier stages.
cfg_path, cols_path = ART_DIR / "train_cfg.json", ART_DIR / "feature_cols.json"
test_manifest_pq, paths_json = PROF_DIR / "test_manifest.parquet", PROF_DIR / "paths.json"

# Stop at the first input that is not on disk.
for p in (cfg_path, cols_path, test_manifest_pq, paths_json):
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing {p}. Run previous stages first.")

CFG = json.loads(cfg_path.read_text())
FEATURE_COLS = json.loads(cols_path.read_text())
PATHS = json.loads(paths_json.read_text())

# Test feature CSV (location from the saved config) and the sample submission
# (location from the path registry) are both required.
pred_feat_test = Path(CFG["paths"]["PRED_FEATURES_TEST"])
if not pred_feat_test.exists():
    raise FileNotFoundError(f"Missing pred_features_test.csv: {pred_feat_test}")

sample_sub = Path(PATHS["SAMPLE_SUB"])
if not sample_sub.exists():
    raise FileNotFoundError(f"Missing sample_submission.csv: {sample_sub}")

# Constant used to fill missing feature values.
fill_value = float(CFG["features"].get("missing_numeric_fill", 0.0))

# ----------------------------
# 1) Load base test ids (order must follow sample_submission)
# ----------------------------
df_sub = pd.read_csv(sample_sub)
if "case_id" not in df_sub.columns:
    raise ValueError("sample_submission.csv must contain 'case_id'")

# Row order of the final table follows the sample submission.
sub_ids = df_sub["case_id"].astype(int).tolist()

# Test metadata, keyed by integer case_id and re-ordered to the submission;
# ids absent from the manifest remain as rows with missing metadata.
df_test_meta = (
    pd.read_parquet(test_manifest_pq)[["case_id", "img_path", "H", "W"]]
    .copy()
    .assign(case_id=lambda d: d["case_id"].astype(int))
    .set_index("case_id")
    .reindex(sub_ids)
    .reset_index()
)

# ----------------------------
# 2) Load features and dedup per case_id
# ----------------------------
df_feat = pd.read_csv(pred_feat_test)
if "case_id" not in df_feat.columns:
    raise ValueError("pred_features_test.csv must contain 'case_id'")

# Restrict to the key plus whichever frozen features the CSV provides,
# with the key cast to int so it joins against the metadata table.
keep_cols = ["case_id"] + [c for c in FEATURE_COLS if c in df_feat.columns]
df_feat = df_feat[keep_cols].copy()
df_feat["case_id"] = df_feat["case_id"].astype(int)

# One row per case_id: prefer the highest best_peak_score, then the highest
# area_frac (whichever of the two columns are available); otherwise the first row.
sort_keys = [k for k in ("best_peak_score", "area_frac") if k in df_feat.columns]
if sort_keys:
    df_feat = df_feat.sort_values(sort_keys, ascending=False)
df_feat = df_feat.drop_duplicates("case_id", keep="first")

# ----------------------------
# 3) Join to build test table
# ----------------------------
# Left join keeps every submission row, including those without feature rows.
df_test_tabular = pd.merge(df_test_meta, df_feat, on="case_id", how="left")

# Add any frozen feature the join did not supply as an all-NaN column.
for c in FEATURE_COLS:
    if c in df_test_tabular.columns:
        continue
    df_test_tabular[c] = np.nan

# Force every feature to a numeric dtype; unparseable values become NaN.
for c in FEATURE_COLS:
    df_test_tabular[c] = pd.to_numeric(df_test_tabular[c], errors="coerce")

# Count NaNs before and after the constant fill.
_feat_view = df_test_tabular[FEATURE_COLS]
n_missing_before = int(_feat_view.isna().sum().sum())
df_test_tabular[FEATURE_COLS] = _feat_view.fillna(fill_value)
n_missing_after = int(df_test_tabular[FEATURE_COLS].isna().sum().sum())

# ----------------------------
# 4) Derived features must match TRAIN stage logic
# (These were appended to feature_cols.json by the training-table stage.)
# Every derived feature listed in FEATURE_COLS is ALWAYS recomputed here from
# its source column. A "compute only if the column is absent" guard never
# fires: the "ensure all features exist" step above has already created the
# derived columns as placeholders and filled them with the missing-value
# constant, which would leave them constant on test while train holds real
# values.
# ----------------------------
def safe_log1p(x):
    """Return ``log1p`` of *x* after raising negative entries to 0."""
    return np.log1p(np.maximum(x, 0))


def safe_sqrt(x):
    """Return the square root of *x* after raising negative entries to 0."""
    return np.sqrt(np.maximum(x, 0))


if "log1p_best_peak_score" in FEATURE_COLS:
    if "best_peak_score" in df_test_tabular.columns:
        df_test_tabular["log1p_best_peak_score"] = safe_log1p(df_test_tabular["best_peak_score"].values.astype(np.float32))
    else:
        df_test_tabular["log1p_best_peak_score"] = 0.0

if "sqrt_area_frac_tok" in FEATURE_COLS:
    if "area_frac_tok" in df_test_tabular.columns:
        df_test_tabular["sqrt_area_frac_tok"] = safe_sqrt(df_test_tabular["area_frac_tok"].values.astype(np.float32))
    else:
        df_test_tabular["sqrt_area_frac_tok"] = 0.0

if "has_inst" in FEATURE_COLS:
    if "n_inst" in df_test_tabular.columns:
        df_test_tabular["has_inst"] = (df_test_tabular["n_inst"].values.astype(np.float32) > 0).astype(np.float32)
    else:
        df_test_tabular["has_inst"] = 0.0

# Final safety: ensure all FEATURE_COLS present after derived
for c in FEATURE_COLS:
    if c not in df_test_tabular.columns:
        df_test_tabular[c] = fill_value

# ----------------------------
# 5) Build arrays
# ----------------------------
# Dense feature matrix in FEATURE_COLS order, plus the matching case ids.
X_test = df_test_tabular[FEATURE_COLS].to_numpy(dtype=np.float32, copy=True)
case_id_test = df_test_tabular["case_id"].to_numpy(dtype=np.int64, copy=True)

# ----------------------------
# 6) Save
# ----------------------------
df_test_tabular.to_parquet(ART_DIR / "test_table.parquet", index=False)
for _fname, _arr in (("test_X.npy", X_test), ("test_case_id.npy", case_id_test)):
    np.save(ART_DIR / _fname, _arr)

print("Saved:")
for _fname in ("test_table.parquet", "test_X.npy", "test_case_id.npy"):
    print(" -", ART_DIR / _fname)
print("-"*60)
print("Test rows:", len(df_test_tabular), "| missing_before:", n_missing_before, "| missing_after:", n_missing_after)

# Export globals for next stages: X_test, case_id_test and df_test_tabular are
# already bound at top level by the assignments above, so nothing else is needed.


# Train Baseline Model (Leakage-Safe CV)

In [None]:
# ============================================================
# STAGE — Train Baseline Model (Leakage-Safe CV) (ONE CELL) — FULL REVISION v8.1 (STACK-CONSISTENT)
# Gate model: LightGBM (fallback: HistGradientBoosting if LGBM missing)
# + calibration (isotonic/sigmoid)
# + threshold sweep using Dice-proxy computed from:
#     - if gate predicts "authentic" => empty mask
#     - else => use FULL-RES mask from pred_ens/train/{case_id}.npz (key: "mask")
#     - GT mask union loaded from: train_masks / supplemental_masks (npy union OR png union)
# Dice-proxy uses SAME postprocess spirit as submission (instance split & filtering),
# but runs on FULL-RES masks (robust for any tok/grid mismatch).
#
# Requires:
# - /kaggle/working/recodai_luc_gate_artifacts/train_cfg.json
# - /kaggle/working/recodai_luc_gate_artifacts/feature_cols.json
# - /kaggle/working/recodai_luc_gate_artifacts/train_X.npy, train_y.npy, train_folds.npy
# - /kaggle/working/recodai_luc_prof/paths.json
# - /kaggle/working/recodai_luc_gate_artifacts/train_table.parquet (case_id order)
# - pred_ens train cache: /kaggle/working/recodai_luc/cache/pred_ens/train/{case_id}.npz
#
# Produces:
# - models/: fold_*.txt (or .joblib fallback), calibration*.joblib, calibration.json
# - oof/: oof_prob_raw.npy, oof_prob.npy, oof_prob.csv
# - eval/: threshold_table.csv + best_threshold.json + fold_metrics.json
# - final_gate_model.pt  (portable torch-save dict)
# ============================================================

import os, json, time, warnings
from pathlib import Path
import numpy as np
import pandas as pd
from PIL import Image, ImageFile

# Let Pillow load truncated image files, and silence all warnings in this cell.
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

ART_DIR  = Path("/kaggle/working/recodai_luc_gate_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
CACHE_DIR = Path("/kaggle/working/recodai_luc/cache")

# Inputs written by earlier stages.
cfg_path, cols_path = ART_DIR / "train_cfg.json", ART_DIR / "feature_cols.json"
X_path, y_path, f_path = ART_DIR / "train_X.npy", ART_DIR / "train_y.npy", ART_DIR / "train_folds.npy"
paths_json = PROF_DIR / "paths.json"
train_table_pq = ART_DIR / "train_table.parquet"

# Stop at the first input that is not on disk.
for p in (cfg_path, cols_path, X_path, y_path, f_path, paths_json, train_table_pq):
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing {p}. Run previous stages first.")

CFG = json.loads(cfg_path.read_text())
FEATURE_COLS = json.loads(cols_path.read_text())
PATHS = json.loads(paths_json.read_text())

# Training arrays; all three must describe the same number of rows.
X = np.load(X_path).astype(np.float32)
y = np.load(y_path).astype(np.int64)
folds = np.load(f_path).astype(np.int64)
N = len(y)

if not (X.shape[0] == N and folds.shape[0] == N):
    raise RuntimeError("Shape mismatch among train_X/train_y/train_folds")

# dirs
MODEL_DIR = ART_DIR / "models"
OOF_DIR   = ART_DIR / "oof"
EVAL_DIR  = ART_DIR / "eval"
for _out_dir in (MODEL_DIR, OOF_DIR, EVAL_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Model import (LGBM preferred)
# ----------------------------
use_lgbm = True
try:
    import lightgbm as lgb
except Exception:
    use_lgbm = False

from sklearn.metrics import roc_auc_score, log_loss, f1_score, precision_score, recall_score
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression
import joblib

# SciPy optional (CC)
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

# CV seed (default 42) and the sorted list of distinct fold ids.
_cv_section = CFG.get("cv", {})
seed = int(_cv_section.get("seed", 42))
fold_ids = sorted(set(folds.tolist()))
n_folds = len(fold_ids)

# ----------------------------
# Instance-split / filtering (FULL-RES) for Dice-proxy
# Keep in sync with Verification intent
# ----------------------------
PP = CFG.get("postprocess_ref", {})
MIN_INST_PIX, MAX_AREA_FRAC, MAX_INST_KEEP = (
    int(PP.get("min_inst_pix", 32)),        # minimum component area (pixels) to keep
    float(PP.get("max_area_frac", 0.90)),   # masks covering more than this fraction are discarded
    int(PP.get("max_inst_keep", 8)),        # keep top-K by area
)

# ----------------------------
# Helpers: GT union loader (npy union OR PNG union)
# ----------------------------
# Ground-truth mask directories from the path registry; None when the entry
# is missing or empty.
_train_mask_raw = PATHS.get("TRAIN_MASK_DIR")
_sup_mask_raw = PATHS.get("SUP_MASK_DIR")
TRAIN_MASK_DIR = Path(_train_mask_raw) if _train_mask_raw else None
SUP_MASK_DIR   = Path(_sup_mask_raw) if _sup_mask_raw else None

def _find_mask_files(mask_dir: Path, case_id: int):
    """Return the sorted image mask files in *mask_dir* that belong to *case_id*.

    A file belongs to the case when its name starts with the integer case id,
    the character right after the id is NOT a digit, and its extension is one
    of the supported image types (compared case-insensitively).

    The "not a digit" rule matters: a bare ``"{id}*"`` prefix match would also
    pick up masks of other cases whose id merely starts with the same digits
    (case 12 would match ``123.png`` or ``1234_0.png``).

    Returns an empty list when *mask_dir* is ``None`` or does not exist.
    """
    if mask_dir is None or (not mask_dir.exists()):
        return []
    cid = str(int(case_id))
    exts = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

    hits = []
    for p in mask_dir.glob(f"{cid}*"):
        if p.suffix.lower() not in exts:
            continue
        # A digit right after the id means the file belongs to a longer case id.
        if p.name[len(cid):][:1].isdigit():
            continue
        hits.append(p)
    return sorted(hits)

def load_gt_union(case_id: int):
    """Load the ground-truth mask of *case_id* as one boolean union array.

    Lookup order:
      1. ``{case_id}.npy`` in ``TRAIN_MASK_DIR`` then ``SUP_MASK_DIR``: a 2-D
         array is thresholded at ``> 0``; a 3-D array is thresholded and
         OR-reduced over axis 0. Arrays of any other rank are ignored.
      2. Otherwise, every image file found for the case in both directories
         is read as grayscale, thresholded at ``> 0`` and OR-ed together.

    Returns ``None`` when nothing usable is found. Image files that cannot be
    read or combined are skipped (best-effort).
    """
    # 1) union cache npy if exists
    for d in (TRAIN_MASK_DIR, SUP_MASK_DIR):
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if not npy.exists():
            continue
        a = np.load(npy, mmap_mode="r")
        if a.ndim == 2:
            return (np.asarray(a) > 0)
        if a.ndim == 3:
            return (np.asarray(a) > 0).any(axis=0)

    # 2) union PNGs (multi instances)
    files = []
    if TRAIN_MASK_DIR is not None:
        files += _find_mask_files(TRAIN_MASK_DIR, case_id)
    if SUP_MASK_DIR is not None:
        files += _find_mask_files(SUP_MASK_DIR, case_id)
    if not files:
        return None

    m = None
    for p in files:
        try:
            # Context manager closes the underlying file handle deterministically
            # instead of leaving it to garbage collection.
            with Image.open(p) as im:
                a = (np.asarray(im.convert("L")) > 0)
            m = a if m is None else (m | a)
        except Exception:
            # Deliberate best-effort: an unreadable or incompatible mask file
            # is skipped rather than aborting the whole evaluation.
            continue
    return m

# ----------------------------
# Pred mask loader (FULL-RES) from pred_ens
# ----------------------------
# Prefer the PRED_ENS directory recorded in the config; otherwise use the
# canonical cache location. Train predictions live in its "train" subfolder.
_cfg_paths = CFG.get("paths", {})
if "PRED_ENS_DIR" in _cfg_paths:
    PRED_ENS_DIR = Path(_cfg_paths["PRED_ENS_DIR"])
else:
    PRED_ENS_DIR = CACHE_DIR / "pred_ens"
PRED_TRAIN_DIR = PRED_ENS_DIR / "train"

def load_pred_union(case_id: int):
    """Load the predicted mask of *case_id* as a boolean array.

    Reads ``PRED_TRAIN_DIR / "{case_id}.npz"`` and thresholds its ``"mask"``
    entry at ``> 0``. Returns ``None`` when the file does not exist or has no
    ``"mask"`` entry.
    """
    p = PRED_TRAIN_DIR / f"{int(case_id)}.npz"
    if not p.exists():
        return None
    # The archive is opened as a context manager so its file handle is closed
    # right away; otherwise one handle per case stays open until collected.
    with np.load(p) as z:
        if "mask" not in z.files:
            return None
        return (z["mask"] > 0)

# ----------------------------
# CC instance filtering (FULL-RES)
# ----------------------------
def cc_union_filtered(mask_bool: np.ndarray):
    """Filter a binary mask by connected components and return their union.

    Steps, in order:
      * ``None`` input -> ``(None, empty stats)``.
      * An all-background mask is returned unchanged.
      * A mask whose foreground fraction exceeds ``MAX_AREA_FRAC`` is emptied.
      * Without SciPy the mask is treated as one instance: kept whole if its
        area is at least ``MIN_INST_PIX``, otherwise emptied.
      * With SciPy, 8-connected components smaller than ``MIN_INST_PIX`` are
        dropped and at most ``MAX_INST_KEEP`` of the largest ones are kept.

    Returns ``(mask, stats)`` where ``stats`` has ``"n_inst"`` (components
    kept) and ``"area"`` (foreground pixels kept).
    """
    def _nothing():
        return {"n_inst": 0, "area": 0}

    if mask_bool is None:
        return None, _nothing()

    m = mask_bool.astype(bool)
    H, W = m.shape
    area = int(m.sum())
    if area == 0:
        return m, _nothing()

    blank = np.zeros_like(m, dtype=bool)
    if (area / float(H * W)) > MAX_AREA_FRAC:
        return blank, _nothing()

    if not _HAS_SCIPY:
        # No labelling available: the whole mask counts as a single instance.
        if area >= MIN_INST_PIX:
            return m, {"n_inst": 1, "area": area}
        return blank, _nothing()

    # Full 3x3 structuring element => 8-connectivity.
    lab, n = ndi.label(m, structure=np.ones((3, 3), dtype=np.uint8))
    if n <= 0:
        return blank, _nothing()

    # Pixel count per label; entry i corresponds to label i + 1.
    areas = np.bincount(lab.ravel(), minlength=n + 1)[1:].astype(np.int64)
    keep = np.where(areas >= MIN_INST_PIX)[0]
    if keep.size == 0:
        return blank, _nothing()

    # Largest first, truncated to the configured maximum.
    keep = keep[np.argsort(areas[keep])[::-1]][:MAX_INST_KEEP]

    out = np.isin(lab, keep + 1)
    return out, {"n_inst": int(len(keep)), "area": int(out.sum())}

def dice_score(pr: np.ndarray, gt: np.ndarray):
    """Dice coefficient of two masks after casting both to bool.

    Two empty masks score 1.0; exactly one empty mask scores 0.0.
    """
    pred = pr.astype(bool)
    ref = gt.astype(bool)
    n_pred = int(pred.sum())
    n_ref = int(ref.sum())
    if n_pred == 0 or n_ref == 0:
        # equal counts here can only mean both are zero
        return 1.0 if n_pred == n_ref else 0.0
    overlap = int(np.logical_and(pred, ref).sum())
    return float((2.0 * overlap) / (n_pred + n_ref))

# ----------------------------
# 1) Train fold models -> OOF prob (raw)
# ----------------------------
# Raw (uncalibrated) out-of-fold probability per training row. Each row is written by the
# model of the fold in which that row is held out.
oof_prob_raw = np.zeros(N, dtype=np.float32)
# fold_models: saved model path per fold; fold_metrics: validation metrics per fold.
fold_models = []
fold_metrics = []

# Booster parameters and early-stopping patience are read from CFG["gate_model"] if present.
params = CFG.get("gate_model", {}).get("params", {})
early_rounds = int(CFG.get("gate_model", {}).get("early_stopping_rounds", 200))

t0 = time.time()
print("Model:", "LightGBM" if use_lgbm else "HistGradientBoosting (fallback)")
print("Folds:", n_folds, "| N:", N, "| pos_rate:", float(y.mean()))
print("PRED_TRAIN_DIR:", PRED_TRAIN_DIR)
print("InstanceSplit:", {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP})
print("-"*60)

if use_lgbm:
    for f in fold_ids:
        # rows of every other fold train the model; rows of fold f validate it
        tr_idx = np.where(folds != f)[0]
        va_idx = np.where(folds == f)[0]

        dtr = lgb.Dataset(X[tr_idx], label=y[tr_idx], feature_name=FEATURE_COLS, free_raw_data=True)
        dva = lgb.Dataset(X[va_idx], label=y[va_idx], feature_name=FEATURE_COLS, free_raw_data=True)

        # The boosting-round budget is taken from params["n_estimators"] (default 2000);
        # early stopping monitors the validation fold.
        booster = lgb.train(
            params=params,
            train_set=dtr,
            valid_sets=[dva],
            valid_names=["val"],
            num_boost_round=int(params.get("n_estimators", 2000)),
            callbacks=[lgb.early_stopping(stopping_rounds=early_rounds, verbose=False)]
        )

        # predict the held-out rows at the best iteration; clip away exact 0/1
        p = booster.predict(X[va_idx], num_iteration=booster.best_iteration)
        p = np.clip(p, 1e-6, 1-1e-6).astype(np.float32)
        oof_prob_raw[va_idx] = p

        model_path = MODEL_DIR / f"fold_{f}.txt"
        booster.save_model(str(model_path))

        # AUC is reported as NaN when the validation fold contains a single class
        auc = float(roc_auc_score(y[va_idx], p)) if len(np.unique(y[va_idx])) > 1 else float("nan")
        ll  = float(log_loss(y[va_idx], p, labels=[0,1]))
        pred05 = (p >= 0.5).astype(int)
        f1v = float(f1_score(y[va_idx], pred05))
        fold_metrics.append({"fold": int(f), "auc": auc, "logloss": ll, "f1@0.5": f1v, "best_iter": int(booster.best_iteration)})
        fold_models.append({"fold": int(f), "path": str(model_path), "best_iter": int(booster.best_iteration)})

        print(f"[fold {f}] auc={auc:.4f} logloss={ll:.4f} f1@0.5={f1v:.4f} iter={int(booster.best_iteration)}")

else:
    # Fallback path: scikit-learn gradient boosting with fixed hyper-parameters
    # (CFG params are not used here) and a per-fold random_state.
    from sklearn.ensemble import HistGradientBoostingClassifier
    for f in fold_ids:
        tr_idx = np.where(folds != f)[0]
        va_idx = np.where(folds == f)[0]

        clf = HistGradientBoostingClassifier(
            learning_rate=0.05,
            max_depth=None,
            max_leaf_nodes=63,
            min_samples_leaf=50,
            l2_regularization=1.0,
            max_iter=500,
            random_state=seed + int(f),
        )
        clf.fit(X[tr_idx], y[tr_idx])
        # probability of the positive class for the held-out rows
        p = clf.predict_proba(X[va_idx])[:, 1]
        p = np.clip(p, 1e-6, 1-1e-6).astype(np.float32)
        oof_prob_raw[va_idx] = p

        model_path = MODEL_DIR / f"fold_{f}.joblib"
        joblib.dump(clf, model_path)

        auc = float(roc_auc_score(y[va_idx], p)) if len(np.unique(y[va_idx])) > 1 else float("nan")
        ll  = float(log_loss(y[va_idx], p, labels=[0,1]))
        pred05 = (p >= 0.5).astype(int)
        f1v = float(f1_score(y[va_idx], pred05))
        fold_metrics.append({"fold": int(f), "auc": auc, "logloss": ll, "f1@0.5": f1v})
        fold_models.append({"fold": int(f), "path": str(model_path)})

        print(f"[fold {f}] auc={auc:.4f} logloss={ll:.4f} f1@0.5={f1v:.4f}")

# Pooled metrics over all out-of-fold predictions (threshold 0.5 for F1/precision/recall).
auc_all = float(roc_auc_score(y, oof_prob_raw)) if len(np.unique(y)) > 1 else float("nan")
ll_all  = float(log_loss(y, oof_prob_raw, labels=[0,1]))
pred05_all = (oof_prob_raw >= 0.5).astype(int)
f1_all = float(f1_score(y, pred05_all))
prec_all = float(precision_score(y, pred05_all, zero_division=0))
rec_all  = float(recall_score(y, pred05_all, zero_division=0))

print("-"*60)
print(f"OOF raw: auc={auc_all:.4f} logloss={ll_all:.4f} f1@0.5={f1_all:.4f} prec={prec_all:.4f} rec={rec_all:.4f}")
print("Train CV finished in", f"{time.time()-t0:.1f}s")

# ----------------------------
# 2) Calibration (fit on OOF)
# ----------------------------
# Calibration settings; enabled with the isotonic method when CFG has no "calibration" entry.
cal_cfg = CFG.get("calibration", {"enabled": True, "method": "isotonic"})
cal_enabled = bool(cal_cfg.get("enabled", True))
cal_method = str(cal_cfg.get("method", "isotonic")).lower()

# oof_prob stays equal to the raw probabilities unless a calibrator is fitted below.
oof_prob = oof_prob_raw.copy()
cal_pack = {"enabled": False}

if cal_enabled:
    if cal_method == "isotonic":
        # The calibrator is fitted and applied on the same out-of-fold predictions.
        # NOTE(review): the "calibrated" metrics printed below are therefore measured on the
        # calibrator's own training data.
        iso = IsotonicRegression(out_of_bounds="clip")
        iso.fit(oof_prob_raw, y)
        oof_prob = np.clip(iso.predict(oof_prob_raw).astype(np.float32), 1e-6, 1-1e-6)
        cal_pack = {"enabled": True, "method": "isotonic"}
        joblib.dump(iso, MODEL_DIR / "calibration_isotonic.joblib")
    elif cal_method in ["sigmoid", "platt"]:
        # Logistic regression on the logit of the raw probability; both method names
        # are recorded as "sigmoid".
        p = np.clip(oof_prob_raw, 1e-6, 1-1e-6)
        logit = np.log(p/(1-p)).reshape(-1,1)
        lr = LogisticRegression(solver="lbfgs", max_iter=200)
        lr.fit(logit, y)
        oof_prob = np.clip(lr.predict_proba(logit)[:,1].astype(np.float32), 1e-6, 1-1e-6)
        cal_pack = {"enabled": True, "method": "sigmoid"}
        joblib.dump(lr, MODEL_DIR / "calibration_sigmoid.joblib")
    else:
        print("WARNING: unknown calibration method:", cal_method, "-> skip calibration")
        cal_pack = {"enabled": False}

# Same pooled metrics as for the raw probabilities, now on oof_prob.
auc_cal = float(roc_auc_score(y, oof_prob)) if len(np.unique(y)) > 1 else float("nan")
ll_cal  = float(log_loss(y, oof_prob, labels=[0,1]))
pred05_cal = (oof_prob >= 0.5).astype(int)
f1_cal = float(f1_score(y, pred05_cal))

(MODEL_DIR / "calibration.json").write_text(json.dumps(cal_pack, indent=2))
print("-"*60)
print(f"OOF calibrated: auc={auc_cal:.4f} logloss={ll_cal:.4f} f1@0.5={f1_cal:.4f} | calibration={cal_pack}")

# save OOF arrays + csv
np.save(OOF_DIR / "oof_prob_raw.npy", oof_prob_raw)
np.save(OOF_DIR / "oof_prob.npy", oof_prob)

# The case_id column is re-read from the train table and paired with y / probabilities by
# position; only the row count is checked, so the table must be in the same row order.
df_train_tab = pd.read_parquet(train_table_pq)[["case_id"]].copy()
df_train_tab["case_id"] = df_train_tab["case_id"].astype(int)
if len(df_train_tab) != N:
    raise RuntimeError("train_table.parquet rows != train_y length (order mismatch). Rebuild train table.")
df_train_tab["y"] = y
df_train_tab["oof_prob_raw"] = oof_prob_raw
df_train_tab["oof_prob"] = oof_prob
df_train_tab.to_csv(OOF_DIR / "oof_prob.csv", index=False)

# ----------------------------
# 3) Dice-proxy arrays (FULL-RES, instance-split consistent)
# ----------------------------
# For every training case precompute the Dice obtained under both gate decisions, so the
# threshold sweep afterwards is a pure array lookup.
case_ids = df_train_tab["case_id"].tolist()

dice_use  = np.zeros(N, dtype=np.float32)  # if we "use" predicted mask
dice_empty= np.zeros(N, dtype=np.float32)  # if gate outputs "authentic" (empty)

t1 = time.time()
# counters: cases without GT, cases without prediction, cases whose shapes had to be aligned
miss_gt = miss_pr = bad_shape = 0

for i, cid in enumerate(case_ids):
    gt = load_gt_union(cid)
    pr = load_pred_union(cid)

    # GT missing => empty (safe)
    if gt is None:
        miss_gt += 1
        gt_mask = None
        gt_empty = True
    else:
        gt_mask = gt.astype(bool)
        gt_empty = (gt_mask.sum() == 0)

    # An empty submission scores 1 only when the (unfiltered) GT is empty.
    dice_empty[i] = 1.0 if gt_empty else 0.0

    if pr is None:
        # no cached prediction: treated like an empty prediction
        # (this `continue` also skips the progress print at the bottom of the loop)
        miss_pr += 1
        dice_use[i] = dice_empty[i]
        continue

    pr_mask = pr.astype(bool)

    # shape align if needed
    if gt_mask is not None and pr_mask.shape != gt_mask.shape:
        bad_shape += 1
        # resize pred to gt shape
        im = Image.fromarray((pr_mask.astype(np.uint8)*255))
        im = im.resize((gt_mask.shape[1], gt_mask.shape[0]), resample=Image.NEAREST)
        pr_mask = (np.asarray(im) > 127)

    # instance filtering
    pr_f, _ = cc_union_filtered(pr_mask)

    if gt_mask is None:
        dice_use[i] = 1.0 if (pr_f.sum() == 0) else 0.0
    else:
        # The GT goes through the same component filter as the prediction, whereas
        # dice_empty above is based on the unfiltered GT.
        gt_f, _ = cc_union_filtered(gt_mask)
        dice_use[i] = dice_score(pr_f, gt_f)

    if (i + 1) % 500 == 0:
        print(f"[dice-proxy] {i+1}/{N} | miss_gt={miss_gt} miss_pr={miss_pr} bad_shape={bad_shape} | {time.time()-t1:.1f}s")

print("-"*60)
print(f"Dice-proxy ready | miss_gt={miss_gt} miss_pr={miss_pr} bad_shape={bad_shape} | {time.time()-t1:.1f}s")

# ----------------------------
# 4) Threshold sweep (optimize proxy score)
# ----------------------------
# 201 evenly spaced gate thresholds covering [0, 1].
thr_grid = np.linspace(0.0, 1.0, 201, dtype=np.float32)
rows = []
for thr in thr_grid:
    # use[i] is True when the gate keeps the predicted mask of case i at this threshold
    use = (oof_prob >= thr)
    # proxy score: per-case Dice of the chosen branch, averaged over all cases
    score = np.where(use, dice_use, dice_empty).mean()

    # gate-only classification metrics at the same threshold
    pred = use.astype(int)
    f1v = f1_score(y, pred)
    prec = precision_score(y, pred, zero_division=0)
    rec  = recall_score(y, pred, zero_division=0)
    # both rates are fractions of ALL cases, not of the negative / positive class
    fp_rate = float(((pred==1) & (y==0)).mean())
    fn_rate = float(((pred==0) & (y==1)).mean())

    rows.append({
        "thr": float(thr),
        "score_dice_proxy": float(score),
        "f1_gate": float(f1v),
        "precision_gate": float(prec),
        "recall_gate": float(rec),
        "fp_rate": fp_rate,
        "fn_rate": fn_rate,
    })

df_thr = pd.DataFrame(rows)
# argmax returns the first maximum, i.e. the lowest threshold among tied best scores
best_i = int(df_thr["score_dice_proxy"].values.argmax())
best_thr = float(df_thr.loc[best_i, "thr"])
best_score = float(df_thr.loc[best_i, "score_dice_proxy"])

df_thr.to_csv(EVAL_DIR / "threshold_table.csv", index=False)
(EVAL_DIR / "best_threshold.json").write_text(json.dumps({
    "recommended_thr": best_thr,
    "best_score_dice_proxy": best_score,
    "best_row": df_thr.loc[best_i].to_dict(),
    "instance_split_fullres": {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP},
}, indent=2))

print("-"*60)
print("BEST (dice-proxy): thr =", best_thr, "| score =", best_score)
print("Saved:", EVAL_DIR / "threshold_table.csv")
print("Saved:", EVAL_DIR / "best_threshold.json")

# ----------------------------
# 5) Save fold metrics + portable bundle
# ----------------------------
# JSON summary: pooled OOF metrics (raw / calibrated), chosen threshold, per-fold metrics,
# saved model paths and the settings used for the Dice proxy.
summary = {
    "oof_raw": {"auc": auc_all, "logloss": ll_all, "f1@0.5": f1_all},
    "oof_cal": {"auc": auc_cal, "logloss": ll_cal, "f1@0.5": f1_cal},
    "recommended_thr": best_thr,
    "best_score_dice_proxy": best_score,
    "fold_metrics": fold_metrics,
    "model_paths": fold_models,
    "calibration": cal_pack,
    "feature_cols": FEATURE_COLS,
    "dice_proxy": {
        "uses_fullres_pred_ens_mask": True,
        "pred_train_dir": str(PRED_TRAIN_DIR),
        "gt_dirs": {"train": str(TRAIN_MASK_DIR) if TRAIN_MASK_DIR else None,
                    "supp":  str(SUP_MASK_DIR) if SUP_MASK_DIR else None},
        "instance_split_fullres": {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP},
    }
}
(EVAL_DIR / "fold_metrics.json").write_text(json.dumps(summary, indent=2))

# torch is used here only as a serializer: the bundle is a plain dict that references the
# fold models by file path (the model files themselves are not embedded).
import torch
bundle = {
    "cfg": CFG,
    "feature_cols": FEATURE_COLS,
    "fold_models": fold_models,       # file paths
    "calibration": cal_pack,
    "recommended_thr": best_thr,
    "instance_split_fullres": {"MIN_INST_PIX": MIN_INST_PIX, "MAX_AREA_FRAC": MAX_AREA_FRAC, "MAX_INST_KEEP": MAX_INST_KEEP},
    "notes": {
        "oof_prob_raw": str(OOF_DIR / "oof_prob_raw.npy"),
        "oof_prob": str(OOF_DIR / "oof_prob.npy"),
        "threshold_table": str(EVAL_DIR / "threshold_table.csv"),
        "fold_metrics": str(EVAL_DIR / "fold_metrics.json"),
        "dice_proxy_note": "mean Dice proxy using pred_ens full-res mask when gate predicts forged; empty when authentic; CC-filter applied to mimic submission instance filtering",
    }
}
torch.save(bundle, ART_DIR / "final_gate_model.pt")

print("-"*60)
print("Saved:", EVAL_DIR / "fold_metrics.json")
print("Saved:", ART_DIR / "final_gate_model.pt")
print("Saved:", OOF_DIR / "oof_prob.csv")


# Optimize Model & Hyperparameters (Iterative)

In [None]:
# ============================================================
# STAGE — Optimize Model & Hyperparameters (Iterative) (ONE CELL)
# Hybrid Model (OPSI-1): UNet(+ASPP) token-decoder + Gate Head (classification) in ONE NETWORK
# Input  : DINOv2 token-grid (Ht,Wt,D) + seed (Ht,Wt,1) from Robust Matching
# Output : prob_tok (seg) + p_forged (gate)
# Score  : Dice-proxy on VAL (if p_forged<thr => empty else fuse(seg,seed)->inst-split->dice vs GT tok)
#
# REQUIRE (expected from earlier stages):
# - train_table.parquet with columns (at minimum): case_id, fold, y, and a token path column:
#     tok_path / token_path / dino_path / feat_path / emb_path / token_npz / npz_path  (auto-detect; first match wins)
# - Robust Matching cache: match_cfg_* with match_manifest_train.parquet
# - paths.json for TRAIN_MASK_DIR / SUP_MASK_DIR (GT)
#
# OUTPUT:
# - /kaggle/working/recodai_luc_hybrid_opt/trials.csv
# - /kaggle/working/recodai_luc_hybrid_opt/best_config.json
# - /kaggle/working/recodai_luc_hybrid_opt/best_model.pt
# ============================================================

import os, json, time, math, random, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Config (safe defaults; tweak via env vars if needed)
# ----------------------------
# Output folder of this stage (created on import of the cell).
OUT_DIR = Path("/kaggle/working/recodai_luc_hybrid_opt")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Inputs written by earlier stages.
PROF_DIR  = Path("/kaggle/working/recodai_luc_prof")
CACHE_DIR = Path("/kaggle/working/recodai_luc/cache")

# Run-time knobs; each can be overridden through an environment variable of the same name.
SEED = int(os.environ.get("SEED", "42"))
MAX_TRIALS = int(os.environ.get("MAX_TRIALS", "12"))          # keep small; increase if time
TRIAL_EPOCHS = int(os.environ.get("TRIAL_EPOCHS", "6"))        # quick trials
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", "32"))           # token-grid is small
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "2"))
ACCUM_STEPS = int(os.environ.get("ACCUM_STEPS", "1"))
# USE_AMP is parsed as an integer flag ("0" / "1")
USE_AMP = bool(int(os.environ.get("USE_AMP", "1")))
EARLYSTOP_PATIENCE = int(os.environ.get("EARLYSTOP_PATIENCE", "2"))

# tune-space (reasonable)
# *_RANGE are (low, high) bounds sampled continuously; *_CHOICES are discrete options.
LR_RANGE = (2e-4, 2e-3)
WD_RANGE = (0.0, 0.05)
DROPOUT_RANGE = (0.0, 0.25)
BASE_CH_CHOICES = [64, 96, 128]
LAMBDA_SEG_RANGE = (0.6, 1.2)
LAMBDA_CLS_RANGE = (0.3, 0.9)
FOCAL_GAMMA_CHOICES = [0.0, 1.0, 2.0]  # 0 -> BCE only

# postprocess / fuse params
# T1: hard threshold on the token probability; T0: softer threshold applied inside the seed
T1_RANGE = (0.50, 0.65)
T0_RANGE = (0.25, 0.45)
SEED_DILATE_CHOICES = [0, 1, 2]
THR_GATE_RANGE = (0.20, 0.80)

# instance split token-space
MIN_TOK_AREA_CHOICES = [1, 2, 3]
MAX_TOK_AREA_FRAC_CHOICES = [0.70, 0.80, 0.90]
MAX_INST_KEEP_CHOICES = [4, 8, 12]

# guard
# a prediction is dropped when the match score AND the kept area are both below these values
MIN_PEAK_SCORE_KEEP_CHOICES = [5, 6, 7]
MIN_AREA_FRAC_KEEP_CHOICES = [0.0003, 0.0005, 0.0010]

# folds usage for fast opt (use 1 fold val each trial; rotate)
VAL_FOLD_ROTATE = True

# ----------------------------
# Repro
# ----------------------------
def seed_everything(s=42):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device) with ``s``."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(s)
seed_everything(SEED)

# Use the GPU when available; mixed precision is only reported as active on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("DEVICE:", device, "| AMP:", USE_AMP and device.type == "cuda")

# ----------------------------
# Load PATHS + train_table
# ----------------------------
paths_json = PROF_DIR / "paths.json"
if not paths_json.exists():
    raise FileNotFoundError(f"Missing {paths_json}")

PATHS = json.loads(paths_json.read_text())
# Ground-truth mask folders; None when the key is missing or empty in paths.json.
TRAIN_MASK_DIR = Path(PATHS.get("TRAIN_MASK_DIR","")) if PATHS.get("TRAIN_MASK_DIR") else None
SUP_MASK_DIR   = Path(PATHS.get("SUP_MASK_DIR","")) if PATHS.get("SUP_MASK_DIR") else None

# The first existing candidate is used as the training table.
table_candidates = [
    Path("/kaggle/working/recodai_luc_gate_artifacts/train_table.parquet"),
    Path("/kaggle/working/recodai_luc_hybrid_artifacts/train_table.parquet"),
    PROF_DIR / "train_table.parquet",
]
TRAIN_TABLE = None
for p in table_candidates:
    if p.exists():
        TRAIN_TABLE = p
        break
if TRAIN_TABLE is None:
    raise FileNotFoundError("Cannot find train_table.parquet in known locations. Run Build Training Table stage first.")

df = pd.read_parquet(TRAIN_TABLE).copy()
for need in ["case_id","fold","y"]:
    if need not in df.columns:
        raise ValueError(f"train_table missing required col: {need}")
df["case_id"] = df["case_id"].astype(int)
df["fold"] = df["fold"].astype(int)
df["y"] = df["y"].astype(int)

# token path column auto-detect
# (first matching name in this list wins; the list is longer than the one in the error below)
tok_col = None
for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","npz_path"]:
    if c in df.columns:
        tok_col = c
        break
if tok_col is None:
    raise ValueError("train_table.parquet must contain a token path column (tok_path/token_path/dino_path/feat_path/emb_path).")

# Rows whose token file is missing are dropped; only a warning is printed when more than
# 10% of the rows are affected.
# NOTE(review): if no token file exists, df becomes empty and the later df.iloc[0] fails.
df[tok_col] = df[tok_col].astype(str)
ok_tok = df[tok_col].map(lambda p: Path(p).exists())
if ok_tok.mean() < 0.90:
    print("WARNING: token path existence rate is low:", float(ok_tok.mean()))
df = df[ok_tok].reset_index(drop=True)

fold_ids = sorted(df["fold"].unique().tolist())
print("TRAIN_TABLE:", TRAIN_TABLE, "| rows:", len(df), "| folds:", fold_ids, "| pos_rate:", float(df["y"].mean()))
print("TOK_COL:", tok_col)

# ----------------------------
# Latest MATCH_ROOT + map case_id -> (match_npz, best_peak_score)
# ----------------------------
def pick_latest_match_root():
    """Return the ``match_cfg_*`` cache folder whose ``cfg.json`` was modified last.

    Only folders containing both ``cfg.json`` and ``match_manifest_train.parquet``
    qualify. Raises ``FileNotFoundError`` when no folder qualifies.
    """
    usable = [
        d for d in sorted(CACHE_DIR.glob("match_cfg_*"))
        if (d / "cfg.json").exists() and (d / "match_manifest_train.parquet").exists()
    ]
    if len(usable) == 0:
        raise FileNotFoundError("Cannot find match_cfg_* under /kaggle/working/recodai_luc/cache. Run Robust Matching stage first.")
    # max() returns the first of several equal mtimes, which is the same element a stable
    # descending sort would put at index 0
    return max(usable, key=lambda d: (d / "cfg.json").stat().st_mtime)

MATCH_ROOT = pick_latest_match_root()
MATCH_CFG = json.loads((MATCH_ROOT/"cfg.json").read_text())
# Patch size and token-grid height/width come from the matching config, each with an
# alternative key name and a default (14 / 37 / 37).
PATCH = int(MATCH_CFG.get("patch", MATCH_CFG.get("patch_size", 14)))
HTOK  = int(MATCH_CFG.get("Ht", MATCH_CFG.get("htok", 37)))
WTOK  = int(MATCH_CFG.get("Wt", MATCH_CFG.get("wtok", 37)))

mtrain_pq = MATCH_ROOT / "match_manifest_train.parquet"
df_m = pd.read_parquet(mtrain_pq).copy()
df_m["case_id"] = df_m["case_id"].astype(int)
if "match_npz" not in df_m.columns:
    raise ValueError("match_manifest_train.parquet missing match_npz column")

# pick best (prefer a score col if exists else latest mtime)
# Result: exactly one manifest row per case_id.
score_cols = [c for c in ["best_peak_score","peak_score_max","max_peak_score","score_max","best_score"] if c in df_m.columns]
if score_cols:
    sc = score_cols[0]
    # non-numeric / missing scores become -1 so they sort last
    df_m[sc] = pd.to_numeric(df_m[sc], errors="coerce").fillna(-1)
    df_m = df_m.sort_values(["case_id", sc], ascending=[True, False]).drop_duplicates("case_id", keep="first")
else:
    def _mtime(p):
        # modification time of the file, or -1 when it cannot be read
        try: return Path(p).stat().st_mtime
        except Exception: return -1
    df_m["_mtime"] = df_m["match_npz"].map(_mtime)
    df_m = df_m.sort_values(["case_id","_mtime"], ascending=[True, False]).drop_duplicates("case_id", keep="first")

# case_id -> path of the selected match npz (the score itself is not stored in the map)
match_map = df_m.set_index("case_id")["match_npz"].to_dict()
print("MATCH_ROOT:", MATCH_ROOT, "| PATCH/HTOK/WTOK:", PATCH, HTOK, WTOK)

# ----------------------------
# GT union loader + downsample to token grid
# ----------------------------
def _find_mask_files(mask_dir: Path, case_id: int):
    if mask_dir is None or (not mask_dir.exists()):
        return []
    cid = str(int(case_id))
    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    pats = [f"{cid}*.png", f"{cid}*.jpg", f"{cid}*.jpeg", f"{cid}*.tif", f"{cid}*.tiff", f"{cid}*.bmp",
            f"{cid}__*.png", f"{cid}_*.png"]
    out, seen = [], set()
    for pat in pats:
        for p in mask_dir.glob(pat):
            if p.suffix.lower() in exts:
                s = str(p)
                if s not in seen:
                    out.append(p); seen.add(s)
    return sorted(out)

def load_gt_union_full(case_id: int):
    """Return the full-resolution ground-truth mask of a case as a boolean array.

    Lookup order:
    1. ``<case_id>.npy`` in ``TRAIN_MASK_DIR`` then ``SUP_MASK_DIR``. The first file with
       2 or 3 dimensions is used (a 3-D array is reduced with ``any`` over axis 0).
    2. Otherwise the union (logical OR) of all image masks found for the case in both
       folders.

    Returns ``None`` when nothing is found or no image could be read.
    """
    # fast npy union if exists
    for d in [TRAIN_MASK_DIR, SUP_MASK_DIR]:
        if d is None or (not d.exists()):
            continue
        npy = d / f"{int(case_id)}.npy"
        if npy.exists():
            a = np.load(npy, mmap_mode="r")
            if a.ndim == 2:
                return (np.asarray(a) > 0)
            if a.ndim == 3:
                return (np.asarray(a) > 0).any(axis=0)
    # png union
    files = []
    if TRAIN_MASK_DIR is not None: files += _find_mask_files(TRAIN_MASK_DIR, case_id)
    if SUP_MASK_DIR is not None:   files += _find_mask_files(SUP_MASK_DIR, case_id)
    if not files:
        return None
    m = None
    for p in files:
        try:
            # close the image file deterministically (it can stay open for multi-frame
            # formats such as TIFF otherwise)
            with Image.open(p) as im:
                a = (np.asarray(im.convert("L")) > 0)
            m = a if m is None else (m | a)
        except Exception:
            # best effort: unreadable files and masks whose shape does not match the
            # running union are skipped
            continue
    return m

def downsample_bool_to_tok(mask_bool: np.ndarray):
    """Resize a boolean mask to the (HTOK, WTOK) token grid and return it as float32 0/1.

    ``None`` yields an all-zero grid. Resizing uses nearest-neighbour sampling.
    """
    if mask_bool is None:
        return np.zeros((HTOK,WTOK), dtype=np.float32)
    as_u8 = mask_bool.astype(np.uint8) * 255
    # PIL sizes are (width, height)
    small = Image.fromarray(as_u8).resize((WTOK, HTOK), resample=Image.NEAREST)
    on_grid = np.asarray(small) > 127
    return on_grid.astype(np.float32)

# ----------------------------
# Seed builder from match_npz (token union)
# ----------------------------
def load_seed_tok(case_id: int, topk=None):
    """Build the seed map of a case from its match file.

    Returns ``(seed, best_score)``:
    - ``seed``: float32 0/1 array, the union over matches of the ``src_masks`` and
      ``tgt_masks`` stacks; an all-zero (HTOK, WTOK) array when the case has no match
      file or the file holds no usable masks.
    - ``best_score``: ``int(max(peak_score))`` or 0 when there are no scores.

    When ``topk`` is given and there are more matches than that, only the ``topk``
    matches with the highest ``peak_score`` contribute to the seed.
    """
    p = match_map.get(int(case_id), None)
    if p is None or (not Path(p).exists()):
        return np.zeros((HTOK,WTOK), dtype=np.float32), 0
    # read everything needed inside the context manager so the npz archive is closed
    # right away (this runs once per sample in the data loader)
    with np.load(p) as z:
        scores = z["peak_score"] if "peak_score" in z.files else np.zeros((0,), np.int32)
        src = z["src_masks"] if "src_masks" in z.files else np.zeros((0,HTOK,WTOK), np.uint8)
        tgt = z["tgt_masks"] if "tgt_masks" in z.files else np.zeros((0,HTOK,WTOK), np.uint8)
    best_score = int(scores.max()) if len(scores) else 0
    if src.ndim != 3 or tgt.ndim != 3 or src.shape[0] == 0:
        return np.zeros((HTOK,WTOK), dtype=np.float32), best_score
    if topk is not None and src.shape[0] > topk:
        # keep topk by score (desc)
        idx = np.argsort(scores)[::-1][:topk]
        src = src[idx]; tgt = tgt[idx]
    seed = ((src>0) | (tgt>0)).any(axis=0).astype(np.float32)
    return seed, best_score

# ----------------------------
# Morphology / CC (token space)
# ----------------------------
# scipy is optional: when it is missing, the pure-NumPy fallbacks below are used.
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

def dilate_tok(x_bool, it=1):
    """Binary dilation of a 2-D mask, repeated ``it`` times.

    ``it <= 0`` returns the input object unchanged. Each step grows the mask by its four
    direct neighbours (up/down/left/right), which is the default structuring element of
    ``scipy.ndimage.binary_dilation``. The NumPy fallback uses the same neighbourhood so
    the result does not depend on whether scipy is installed.
    """
    if it <= 0: return x_bool
    x = x_bool.astype(bool)
    if _HAS_SCIPY:
        return ndi.binary_dilation(x, iterations=it)
    h, w = x.shape
    for _ in range(it):
        # pad with False so pixels outside the image never switch anything on
        xp = np.pad(x, 1, mode="constant", constant_values=False)
        x = (
            xp[1:1+h, 1:1+w]      # centre
            | xp[0:h, 1:1+w]      # neighbour above
            | xp[2:2+h, 1:1+w]    # neighbour below
            | xp[1:1+h, 0:w]      # neighbour to the left
            | xp[1:1+h, 2:2+w]    # neighbour to the right
        )
    return x

def label_cc(x_bool):
    """Label the 8-connected components of a 2-D mask.

    Returns ``(labels, count)``: an integer array with 0 for background and 1..count for
    the components, and the number of components. Uses scipy when available, otherwise a
    stack-based flood fill that visits seed pixels in row-major order.
    """
    grid = x_bool.astype(bool)
    if _HAS_SCIPY:
        labels, count = ndi.label(grid, structure=np.ones((3,3), dtype=np.uint8))
        return labels, int(count)

    n_rows, n_cols = grid.shape
    labels = np.zeros((n_rows, n_cols), dtype=np.int32)
    # the eight offsets around a pixel
    offsets = [(dr, dc) for dr in (-1, 0, 1) for dc in (-1, 0, 1) if (dr, dc) != (0, 0)]
    count = 0
    for r0, c0 in np.ndindex(n_rows, n_cols):
        if labels[r0, c0] != 0 or not grid[r0, c0]:
            continue
        # unlabelled foreground pixel: start a new component and flood it
        count += 1
        labels[r0, c0] = count
        pending = [(r0, c0)]
        while pending:
            r, c = pending.pop()
            for dr, dc in offsets:
                nr, nc = r + dr, c + dc
                if not (0 <= nr < n_rows and 0 <= nc < n_cols):
                    continue
                if grid[nr, nc] and labels[nr, nc] == 0:
                    labels[nr, nc] = count
                    pending.append((nr, nc))
    return labels, int(count)

def inst_split_union_tok(mask_bool, min_area=2, max_area_frac=0.8, max_keep=8):
    """Split a mask into connected components, filter them and return their union.

    Components with fewer than ``min_area`` pixels, or covering more than
    ``max_area_frac`` of the grid, are discarded. Of the rest, the ``max_keep`` largest
    are merged. Returns ``(union_mask, number_of_kept_components)``.
    """
    H, W = mask_bool.shape
    lab, n = label_cc(mask_bool)
    if n <= 0:
        return np.zeros((H,W), dtype=bool), 0

    # pixel count per label 1..n (index 0 of bincount is the background)
    sizes = np.bincount(np.asarray(lab).ravel(), minlength=n+1)[1:n+1]
    too_small = sizes < min_area
    too_large = (sizes / float(H*W)) > max_area_frac
    passed = ~(too_small | too_large)
    if not passed.any():
        return np.zeros((H,W), dtype=bool), 0

    kept_labels = np.flatnonzero(passed) + 1
    kept_sizes = sizes[passed]
    # largest first, capped at max_keep
    order = np.argsort(kept_sizes)[::-1][:max_keep]
    uni = np.isin(lab, kept_labels[order])
    return uni, int(len(order))

def dice(pr_bool, gt_bool):
    """Dice coefficient of two boolean masks (1.0 if both empty, 0.0 if only one is)."""
    n_pr = int(pr_bool.sum())
    n_gt = int(gt_bool.sum())
    if n_pr == 0 and n_gt == 0:
        return 1.0
    if n_pr == 0 or n_gt == 0:
        return 0.0
    shared = int((pr_bool & gt_bool).sum())
    return float((2.0 * shared) / (n_pr + n_gt))

# ----------------------------
# Dataset
# ----------------------------
def load_tok_npz(path_str: str):
    """Load a token feature grid and return it as float32 with layout (HTOK, WTOK, D).

    The array is taken from the first of the keys tok/tokens/grid/token_grid/x/feat/emb/f
    present in the archive, otherwise from the archive's first entry. A plain ``.npy``
    file (``np.load`` returns an array directly) is accepted as well.

    Layout handling for 3-D arrays:
    - (HTOK, WTOK, D) is returned as is;
    - (D, HTOK, WTOK) is transposed;
    - anything else is treated as (H, W, D) and every channel is resized bilinearly.

    Raises ``ValueError`` for an empty archive or a non 3-D array.
    """
    loaded = np.load(path_str)
    if isinstance(loaded, np.ndarray):
        a = loaded
    else:
        # close the npz archive as soon as the array has been read
        with loaded as z:
            # try common keys
            for k in ["tok","tokens","grid","token_grid","x","feat","emb","f"]:
                if k in z.files:
                    a = z[k]
                    break
            else:
                # fallback: first key
                keys = list(z.files)
                if not keys: raise ValueError("empty npz")
                a = z[keys[0]]
    a = np.asarray(a)
    # ensure (Ht,Wt,D)
    if a.ndim == 3 and a.shape[0] == HTOK and a.shape[1] == WTOK:
        return a.astype(np.float32)
    if a.ndim == 3 and a.shape[-2] == HTOK and a.shape[-1] == WTOK:
        # (D,Ht,Wt) -> (Ht,Wt,D)
        return np.transpose(a, (1,2,0)).astype(np.float32)
    # last resort: resize spatial to tok
    if a.ndim == 3:
        # assume (H,W,D)
        H,W,D = a.shape
        # resize each channel (slow but rare)
        out = np.zeros((HTOK,WTOK,D), np.float32)
        for d in range(D):
            im = Image.fromarray(a[:,:,d].astype(np.float32))
            im = im.resize((WTOK,HTOK), resample=Image.BILINEAR)
            out[:,:,d] = np.asarray(im).astype(np.float32)
        return out
    raise ValueError(f"Unknown token array shape: {a.shape}")

class HybridTokDS(Dataset):
    """Dataset yielding one training case per item.

    Each item is a dict with:
    - ``x``: token features with the seed map appended as the last channel, (C, Ht, Wt);
    - ``gt``: ground-truth mask on the token grid, (1, Ht, Wt);
    - ``y``: image-level label, shape (1,);
    - ``best_score``: best match score of the case, shape (1,);
    - ``case_id``: shape (1,), int64.
    """

    def __init__(self, df_in: pd.DataFrame, tok_col: str):
        self.tok_col = tok_col
        # positional indexing in __getitem__ needs a clean 0..n-1 index
        self.df = df_in.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        case = int(row["case_id"])
        label = float(row["y"])

        tokens = load_tok_npz(row[self.tok_col])  # (Ht,Wt,D)
        seed_map, peak = load_seed_tok(case, topk=8)
        full_gt = load_gt_union_full(case)
        tok_gt = downsample_bool_to_tok(full_gt)  # (Ht,Wt) float 0/1

        # channels first, seed map last
        channels_first = np.transpose(tokens, (2,0,1))
        stacked = np.concatenate([channels_first, seed_map[None,:,:]], axis=0).astype(np.float32)

        item = {}
        item["x"] = torch.from_numpy(stacked)
        item["gt"] = torch.from_numpy(tok_gt[None,:,:].astype(np.float32))
        item["y"] = torch.tensor([label], dtype=torch.float32)
        item["best_score"] = torch.tensor([float(peak)], dtype=torch.float32)
        item["case_id"] = torch.tensor([case], dtype=torch.int64)
        return item

# infer input dim
# Read the first row's token file to learn the feature width; the network input has one
# extra channel for the seed map.
tmp = load_tok_npz(df.iloc[0][tok_col])
D_IN = int(tmp.shape[-1]) + 1
print("Token dim:", int(tmp.shape[-1]), "| Input channels:", D_IN)

# ----------------------------
# Model: UNet-ish + ASPP + dual heads
# ----------------------------
class ConvBNAct(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> SiLU -> Dropout2d (only when ``drop > 0``)."""

    def __init__(self, c_in, c_out, k=3, s=1, p=1, drop=0.0):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
        if drop > 0:
            self.drop = nn.Dropout2d(drop)
        else:
            self.drop = nn.Identity()

    def forward(self, x):
        h = self.conv(x)
        h = self.bn(h)
        h = self.act(h)
        return self.drop(h)

class ASPP(nn.Module):
    """Parallel dilated 3x3 convolutions (one per rate) fused by a 1x1 projection.

    Every branch is Conv2d(dilation=r, padding=r, no bias) -> BatchNorm2d -> SiLU, so the
    spatial size is preserved. Branch outputs are concatenated on the channel axis and
    projected back to ``c_out`` channels.
    """

    def __init__(self, c_in, c_out, rates=(1,2,4)):
        super().__init__()
        branches = []
        for rate in rates:
            branches.append(
                nn.Sequential(
                    nn.Conv2d(c_in, c_out, 3, padding=rate, dilation=rate, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.SiLU(inplace=True),
                )
            )
        self.blocks = nn.ModuleList(branches)
        self.proj = nn.Sequential(
            nn.Conv2d(len(rates)*c_out, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(inplace=True),
        )

    def forward(self, x):
        branch_out = []
        for branch in self.blocks:
            branch_out.append(branch(x))
        stacked = torch.cat(branch_out, dim=1)
        return self.proj(stacked)

class HybridUNet(nn.Module):
    """Small U-Net with an ASPP bottleneck and two heads.

    ``forward(x)`` takes ``x`` of shape (B, c_in, H, W) and returns
    ``(seg_logit, cls_logit)``: per-position segmentation logits (B, 1, H, W) and one
    image-level logit (B, 1) computed from the bottleneck.

    Args:
        c_in: number of input channels.
        base_ch: width of the first encoder stage (stages use 1x, 2x, 3x this width).
        drop: dropout rate used in the conv blocks and in the classification head.
    """

    def __init__(self, c_in, base_ch=96, drop=0.1):
        super().__init__()
        c1, c2, c3 = base_ch, base_ch*2, base_ch*3

        # GroupNorm requires num_channels % num_groups == 0. A fixed 8 groups fails for
        # any c_in that is not a multiple of 8, so use the largest group count <= 8
        # that divides c_in (8 whenever that was valid before).
        self.in_norm = nn.GroupNorm(self._norm_groups(c_in, 8), c_in)

        self.e1 = nn.Sequential(ConvBNAct(c_in, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))
        self.p1 = nn.MaxPool2d(2)

        self.e2 = nn.Sequential(ConvBNAct(c1, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))
        self.p2 = nn.MaxPool2d(2)

        self.e3 = nn.Sequential(ConvBNAct(c2, c3, drop=drop), ConvBNAct(c3, c3, drop=drop))

        self.aspp = ASPP(c3, c3, rates=(1,2,4))

        # decoder
        self.u2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.d2 = nn.Sequential(ConvBNAct(c2+c2, c2, drop=drop), ConvBNAct(c2, c2, drop=drop))

        self.u1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.d1 = nn.Sequential(ConvBNAct(c1+c1, c1, drop=drop), ConvBNAct(c1, c1, drop=drop))

        # heads
        self.seg_head = nn.Conv2d(c1, 1, 1)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c3, c3//2),
            nn.SiLU(inplace=True),
            nn.Dropout(drop if drop>0 else 0.0),
            nn.Linear(c3//2, 1),
        )

    @staticmethod
    def _norm_groups(channels, max_groups=8):
        """Largest group count in [1, max_groups] that divides ``channels``."""
        for g in range(min(max_groups, channels), 0, -1):
            if channels % g == 0:
                return g
        return 1

    @staticmethod
    def _match_size(x, ref):
        """Resize ``x`` to the spatial size of ``ref`` when the two differ."""
        if x.shape[-2:] != ref.shape[-2:]:
            x = F.interpolate(x, size=ref.shape[-2:], mode="bilinear", align_corners=False)
        return x

    def forward(self, x):
        # x: (B,C,Ht,Wt)
        x = self.in_norm(x)
        e1 = self.e1(x)
        e2 = self.e2(self.p1(e1))
        e3 = self.e3(self.p2(e2))
        b  = self.aspp(e3)

        # cls from bottleneck
        cls_logit = self.cls_head(b)

        # MaxPool2d(2) floors odd sizes (e.g. 37 -> 18), so a stride-2 upsample can come
        # out one pixel smaller than the skip tensor; align before concatenating.
        d2 = self._match_size(self.u2(b), e2)
        d2 = self.d2(torch.cat([d2, e2], dim=1))

        d1 = self._match_size(self.u1(d2), e1)
        d1 = self.d1(torch.cat([d1, e1], dim=1))

        seg_logit = self.seg_head(d1)
        return seg_logit, cls_logit

# ----------------------------
# Loss: (BCE/Focal + Dice) + BCE cls
# ----------------------------
def dice_loss_from_logits(logits, target, eps=1e-6):
    """Soft Dice loss: 1 - mean over the first two axes of 2*sum(p*t) / (sum(p+t) + eps).

    ``logits`` and ``target`` are (B,1,H,W); ``p`` is ``sigmoid(logits)`` and the sums run
    over the two spatial axes.
    """
    probs = logits.sigmoid()
    spatial = (2, 3)
    overlap = torch.sum(probs * target, dim=spatial)
    total = torch.sum(probs + target, dim=spatial) + eps
    per_sample = (2.0 * overlap) / total
    return 1.0 - per_sample.mean()

def bce_focal_from_logits(logits, target, gamma=2.0):
    """Mean binary cross-entropy on logits, optionally focal-weighted.

    With ``gamma <= 0`` this is plain BCE. Otherwise every element is weighted by
    ``(1 - p_t) ** gamma`` where ``p_t`` is the probability assigned to the target.
    """
    per_elem = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    if gamma <= 0:
        return per_elem.mean()
    prob = torch.sigmoid(logits)
    prob_of_target = target * prob + (1-target) * (1-prob)
    focal_weight = torch.pow(1-prob_of_target, gamma)
    return torch.mean(focal_weight * per_elem)

# ----------------------------
# Fuse + score (token space)
# ----------------------------
@torch.no_grad()
def eval_model(model, dl, cfg_pp):
    """Score ``model`` on a data loader with the token-space Dice proxy.

    Per sample: if the gate probability is below ``thr_gate`` the prediction is empty.
    Otherwise the prediction is ``(prob >= T1) | (dilated seed & (prob >= T0))``, passed
    through ``inst_split_union_tok``; it is emptied again when both the match score and
    the kept area fraction are below their minimums. Returns the mean Dice against the
    ground truth (thresholded at 0.5), or 0.0 when the loader yields nothing.
    """
    model.eval()

    # Read every post-processing setting up front so a missing key fails before inference.
    hard_thr = cfg_pp["T1"]
    soft_thr = cfg_pp["T0"]
    seed_iters = cfg_pp["seed_dilate_it"]
    gate_thr = cfg_pp["thr_gate"]
    min_area = cfg_pp["min_tok_area"]
    max_area_frac = cfg_pp["max_tok_area_frac"]
    max_keep = cfg_pp["max_inst_keep"]
    peak_floor = cfg_pp["min_peak_score_keep"]
    area_floor = cfg_pp["min_area_frac_keep"]

    def build_mask(prob, seed_map, peak):
        # fuse the network probability with the (dilated) matching seed
        confident = prob >= hard_thr
        plausible = prob >= soft_thr
        grown_seed = dilate_tok(seed_map > 0.5, seed_iters)
        merged = confident | (grown_seed & plausible)

        # instance split + filter
        kept, _ = inst_split_union_tok(
            merged, min_area=min_area, max_area_frac=max_area_frac, max_keep=max_keep
        )
        kept_frac = float(kept.mean())

        # guard: weak match + tiny area -> drop
        if (peak < peak_floor) and (kept_frac < area_floor):
            return np.zeros((HTOK,WTOK), dtype=bool)
        return kept

    per_case = []
    for batch in dl:
        inputs = batch["x"].to(device, non_blocking=True)
        targets = batch["gt"].to(device, non_blocking=True)
        peaks = batch["best_score"].cpu().numpy().reshape(-1)

        seg_logit, cls_logit = model(inputs)
        gate_probs = torch.sigmoid(cls_logit).detach().cpu().numpy().reshape(-1)
        tok_probs = torch.sigmoid(seg_logit).detach().cpu().numpy()[:,0]  # (B,Ht,Wt)
        seeds = inputs.detach().cpu().numpy()[:,-1]  # last channel (B,Ht,Wt)
        gt_masks = targets.detach().cpu().numpy()[:,0] > 0.5

        for i in range(inputs.shape[0]):
            if gate_probs[i] < gate_thr:
                pred = np.zeros((HTOK,WTOK), dtype=bool)
            else:
                pred = build_mask(tok_probs[i], seeds[i], peaks[i])
            per_case.append(dice(pred, gt_masks[i]))

    if not per_case:
        return 0.0
    return float(np.mean(per_case))

# ----------------------------
# Train one trial (single val fold)
# ----------------------------
def sample_trial_cfg(trial_id: int, val_fold: int):
    """Draw one random-search configuration (training + post-processing).

    Values are sampled with the global NumPy RNG from the module-level search
    ranges / choice lists (``LR_RANGE``, ``T1_RANGE``, ...).

    Args:
        trial_id: identifier stored in the returned config.
        val_fold: validation fold stored in the returned config.

    Returns:
        Dict of plain Python ints/floats describing the trial.
    """
    def log_uniform(a, b):
        # Uniform in log-space between a and b.
        return float(np.exp(np.random.uniform(np.log(a), np.log(b))))

    lr = log_uniform(*LR_RANGE)
    wd = float(np.random.uniform(*WD_RANGE))
    drop = float(np.random.uniform(*DROPOUT_RANGE))
    base_ch = int(np.random.choice(BASE_CH_CHOICES))
    lam_seg = float(np.random.uniform(*LAMBDA_SEG_RANGE))
    lam_cls = float(np.random.uniform(*LAMBDA_CLS_RANGE))
    gamma = float(np.random.choice(FOCAL_GAMMA_CHOICES))

    T1 = float(np.random.uniform(*T1_RANGE))
    T0 = float(np.random.uniform(*T0_RANGE))
    if T0 > T1:
        # Enforce T0 <= T1 (soft <= hard): place T0 just below T1. The 0.05 floor
        # is capped by T1 itself, otherwise a very small T1 would still give T0 > T1.
        T0 = min(T1, max(0.05, T1 - 0.05))
    dil_it = int(np.random.choice(SEED_DILATE_CHOICES))
    thr_gate = float(np.random.uniform(*THR_GATE_RANGE))

    min_tok_area = int(np.random.choice(MIN_TOK_AREA_CHOICES))
    max_tok_area_frac = float(np.random.choice(MAX_TOK_AREA_FRAC_CHOICES))
    max_inst_keep = int(np.random.choice(MAX_INST_KEEP_CHOICES))

    min_peak_keep = int(np.random.choice(MIN_PEAK_SCORE_KEEP_CHOICES))
    min_area_frac_keep = float(np.random.choice(MIN_AREA_FRAC_KEEP_CHOICES))

    return {
        "trial_id": int(trial_id),
        "val_fold": int(val_fold),
        "lr": lr, "weight_decay": wd, "dropout": drop, "base_ch": base_ch,
        "lambda_seg": lam_seg, "lambda_cls": lam_cls, "focal_gamma": gamma,
        "T1": T1, "T0": T0, "seed_dilate_it": dil_it, "thr_gate": thr_gate,
        "min_tok_area": min_tok_area, "max_tok_area_frac": max_tok_area_frac, "max_inst_keep": max_inst_keep,
        "min_peak_score_keep": min_peak_keep, "min_area_frac_keep": min_area_frac_keep,
    }

def train_trial(cfg_trial):
    """Train one model for a sampled config and validate on a single fold.

    Rows of the global ``df`` whose ``fold`` equals ``cfg_trial["val_fold"]`` form
    the validation set; the rest is rebalanced and used for training. After each
    epoch the model is scored with ``eval_model`` and the best weights are kept;
    training stops after ``EARLYSTOP_PATIENCE`` epochs without improvement.

    Args:
        cfg_trial: config dict as produced by ``sample_trial_cfg``.

    Returns:
        ``(best_score, best_state)`` where ``best_state`` is a CPU copy of the
        state dict at the best epoch (``None`` if no epoch was evaluated).
    """
    val_fold = cfg_trial["val_fold"]
    df_tr = df[df["fold"] != val_fold].reset_index(drop=True)
    df_va = df[df["fold"] == val_fold].reset_index(drop=True)

    # Balance train set: equal numbers of negatives (subsampled, at most 3x the
    # positives) and positives (resampled with replacement to the same count).
    pos = df_tr[df_tr["y"]==1]
    neg = df_tr[df_tr["y"]==0]
    if len(pos) > 0 and len(neg) > 0:
        take = min(len(neg), len(pos)*3)
        neg_s = neg.sample(n=take, replace=False, random_state=SEED)
        pos_s = pos.sample(n=take, replace=True, random_state=SEED)
        df_tr_use = pd.concat([neg_s, pos_s], axis=0).sample(frac=1.0, random_state=SEED).reset_index(drop=True)
    else:
        df_tr_use = df_tr

    ds_tr = HybridTokDS(df_tr_use, tok_col=tok_col)
    ds_va = HybridTokDS(df_va, tok_col=tok_col)

    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
                       pin_memory=(device.type=="cuda"), drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
                       pin_memory=(device.type=="cuda"), drop_last=False)

    model = HybridUNet(c_in=D_IN, base_ch=cfg_trial["base_ch"], drop=cfg_trial["dropout"]).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=cfg_trial["lr"], weight_decay=cfg_trial["weight_decay"])
    scaler = torch.cuda.amp.GradScaler(enabled=(USE_AMP and device.type=="cuda"))

    best_score = -1.0
    best_state = None
    bad = 0

    for ep in range(1, TRIAL_EPOCHS+1):
        model.train()
        t0 = time.time()
        loss_meter = 0.0
        nsteps = 0

        for step, batch in enumerate(dl_tr, start=1):
            x = batch["x"].to(device, non_blocking=True)
            gt = batch["gt"].to(device, non_blocking=True)
            yb = batch["y"].to(device, non_blocking=True)

            with torch.cuda.amp.autocast(enabled=(USE_AMP and device.type=="cuda")):
                seg_logit, cls_logit = model(x)
                l_seg = bce_focal_from_logits(seg_logit, gt, gamma=cfg_trial["focal_gamma"]) + dice_loss_from_logits(seg_logit, gt)
                l_cls = F.binary_cross_entropy_with_logits(cls_logit, yb)
                loss = cfg_trial["lambda_seg"] * l_seg + cfg_trial["lambda_cls"] * l_cls
                loss = loss / ACCUM_STEPS  # so accumulated grads average over micro-batches

            scaler.scale(loss).backward()
            if step % ACCUM_STEPS == 0:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            loss_meter += float(loss.item()) * ACCUM_STEPS
            nsteps += 1

        # Flush a trailing, incomplete accumulation window so its gradients are
        # neither lost nor silently mixed into the first step of the next epoch.
        if nsteps % ACCUM_STEPS != 0:
            scaler.step(opt)
            scaler.update()
            opt.zero_grad(set_to_none=True)

        # eval
        score = eval_model(model, dl_va, cfg_trial)
        dt = time.time() - t0
        print(f"[trial {cfg_trial['trial_id']:02d} | fold {val_fold}] ep {ep}/{TRIAL_EPOCHS} "
              f"loss={loss_meter/max(1,nsteps):.4f} val_dice_proxy={score:.5f} time={dt:.1f}s")

        if score > best_score + 1e-5:
            best_score = score
            # clone(): on a CPU device .cpu() returns the live tensor itself, so
            # without a copy the "snapshot" would keep changing during training.
            best_state = {k: v.detach().cpu().clone() for k,v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= EARLYSTOP_PATIENCE:
                break

    return best_score, best_state

# ----------------------------
# Main loop
# ----------------------------
# All attempted trial configs (successful or failed), in run order.
trials = []
# Best successful trial so far: its score, config dict and CPU state dict.
global_best = {"score": -1.0, "cfg": None, "state": None}

t_all = time.time()


def _persist_trials():
    """Write the trial log and the current best config to OUT_DIR."""
    pd.DataFrame(trials).to_csv(OUT_DIR / "trials.csv", index=False)
    (OUT_DIR / "best_config.json").write_text(json.dumps(global_best["cfg"], indent=2) if global_best["cfg"] else "{}")


for t in range(1, MAX_TRIALS+1):
    # Either cycle through the folds or always validate on the first one.
    val_fold = fold_ids[(t-1) % len(fold_ids)] if VAL_FOLD_ROTATE else fold_ids[0]
    cfg_trial = sample_trial_cfg(t, val_fold)

    # run; a failing trial is recorded and the search continues
    try:
        score, state = train_trial(cfg_trial)
    except Exception as e:
        print(f"[trial {t}] FAILED:", repr(e))
        cfg_trial["score"] = float("nan")
        cfg_trial["status"] = "fail"
        trials.append(cfg_trial)
        # Persist right away so a failed trial is logged even if no later trial succeeds.
        _persist_trials()
        continue

    cfg_trial["score"] = float(score)
    cfg_trial["status"] = "ok"
    trials.append(cfg_trial)

    if score > global_best["score"]:
        global_best["score"] = float(score)
        global_best["cfg"] = cfg_trial
        global_best["state"] = state

    # persist trials each iteration
    _persist_trials()

    print("-"*60)
    if global_best["cfg"] is not None:
        print("CURRENT BEST:", global_best["score"], "| trial:", global_best["cfg"]["trial_id"], "| val_fold:", global_best["cfg"]["val_fold"])
    else:
        # No trial has beaten the initial score yet; avoid indexing into None.
        print("CURRENT BEST:", global_best["score"], "| trial: (none)")
    print("-"*60)

# save best model together with the metadata needed to rebuild its inputs
if global_best["state"] is not None:
    pack = {
        "model_type": "HybridUNet",
        "input_channels": D_IN,
        "HTOK": HTOK, "WTOK": WTOK, "PATCH": PATCH,
        "tok_col": tok_col,
        "match_root": str(MATCH_ROOT),
        "train_table": str(TRAIN_TABLE),
        "paths": {"TRAIN_MASK_DIR": str(TRAIN_MASK_DIR) if TRAIN_MASK_DIR else None,
                  "SUP_MASK_DIR": str(SUP_MASK_DIR) if SUP_MASK_DIR else None},
        "best_cfg": global_best["cfg"],
        "state_dict": global_best["state"],
    }
    torch.save(pack, OUT_DIR / "best_model.pt")

print("DONE in", f"{time.time()-t_all:.1f}s")
print("Saved:", OUT_DIR / "trials.csv")
print("Saved:", OUT_DIR / "best_config.json")
print("Saved:", OUT_DIR / "best_model.pt" if (OUT_DIR / "best_model.pt").exists() else "(no best_model.pt)")


# Final Training (Train on Full Data)

In [None]:
# ============================================================
# STAGE — Final Training (Train on Full Data) (ONE CELL) — HYBRID (OPSI-1)
# Train ONE NETWORK: UNet(+ASPP) token-decoder + Gate head
# Then:
# - Export mask-prob cache  : /kaggle/working/recodai_luc/cache/mask_prob_hybrid_<hash>/{case_id}.npz  (key: prob_tok, p_gate)
# - Sweep best gate threshold on TRAIN using Dice-proxy (token-space)
# - Save final model bundle : /kaggle/working/recodai_luc_hybrid_artifacts/final_hybrid_model.pt
#
# Needs:
# - /kaggle/working/recodai_luc_hybrid_opt/best_config.json (from Optimize stage)
# - train_table.parquet with: case_id, y, tok_path/token_path/dino_path/... (auto-detect)
# - (optional) test_table.parquet with: case_id, tok_path/... (auto-detect)
# - Robust Matching outputs: /kaggle/working/recodai_luc/cache/match_cfg_*/match_manifest_{train,test}.parquet
# - GT masks via /kaggle/working/recodai_luc_prof/paths.json (TRAIN_MASK_DIR / SUP_MASK_DIR)
# ============================================================

import os, json, time, math, random, hashlib, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# ----------------------------
# Paths
# ----------------------------
# Fixed locations: optimiser output (read), final artifacts (written),
# prep-stage profile dir (read) and the shared cache root.
OPT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_opt")
OUT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
CACHE_DIR= Path("/kaggle/working/recodai_luc/cache")
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Best hyper-parameters chosen by the Optimize stage; required.
# NOTE(review): only existence is checked here — an empty JSON object would
# surface later as a KeyError when individual keys of BEST are read.
best_cfg_path = OPT_DIR / "best_config.json"
if not best_cfg_path.exists():
    raise FileNotFoundError(f"Missing {best_cfg_path}. Run Optimize stage first.")
BEST = json.loads(best_cfg_path.read_text())

# Mask directories come from paths.json; a missing/empty entry becomes None.
paths_json = PROF_DIR / "paths.json"
if not paths_json.exists():
    raise FileNotFoundError(f"Missing {paths_json}")
PATHS = json.loads(paths_json.read_text())
TRAIN_MASK_DIR = Path(PATHS.get("TRAIN_MASK_DIR","")) if PATHS.get("TRAIN_MASK_DIR") else None
SUP_MASK_DIR   = Path(PATHS.get("SUP_MASK_DIR","")) if PATHS.get("SUP_MASK_DIR") else None

# ----------------------------
# Repro / device
# ----------------------------
# Seed is overridable through the SEED environment variable.
SEED = int(os.environ.get("SEED", "42"))


def seed_everything(s=42):
    """Seed Python's `random`, NumPy and torch (CPU and all CUDA devices) with `s`."""
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)


seed_everything(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# AMP is on only when requested via USE_AMP (default "1") and running on CUDA.
_amp_requested = bool(int(os.environ.get("USE_AMP", "1")))
USE_AMP = _amp_requested and (device.type == "cuda")
print("DEVICE:", device, "| AMP:", USE_AMP)

# ----------------------------
# Find train/test tables
# ----------------------------
# Candidate locations for the train/test tables, in priority order;
# the first existing file wins.
train_table_cands = [
    Path("/kaggle/working/recodai_luc_gate_artifacts/train_table.parquet"),
    Path("/kaggle/working/recodai_luc_hybrid_artifacts/train_table.parquet"),
    PROF_DIR / "train_table.parquet",
]
test_table_cands = [
    Path("/kaggle/working/recodai_luc_gate_artifacts/test_table.parquet"),
    Path("/kaggle/working/recodai_luc_hybrid_artifacts/test_table.parquet"),
    PROF_DIR / "test_table.parquet",
]
TRAIN_TABLE = next((p for p in train_table_cands if p.exists()), None)
TEST_TABLE  = next((p for p in test_table_cands if p.exists()), None)
# The train table is mandatory; the test table is optional.
if TRAIN_TABLE is None:
    raise FileNotFoundError("Cannot find train_table.parquet. Run Build Training Table stage first.")

df_tr = pd.read_parquet(TRAIN_TABLE).copy()
for need in ["case_id","y"]:
    if need not in df_tr.columns:
        raise ValueError(f"train_table missing required col: {need}")
df_tr["case_id"] = df_tr["case_id"].astype(int)
df_tr["y"] = df_tr["y"].astype(int)

# Auto-detect the column holding the token-file path: first known name present.
tok_col = None
for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","npz_path"]:
    if c in df_tr.columns:
        tok_col = c
        break
if tok_col is None:
    raise ValueError("train_table must contain a token path column (tok_path/token_path/dino_path/feat_path/emb_path).")

# Keep only rows whose token file actually exists on disk.
df_tr[tok_col] = df_tr[tok_col].astype(str)
df_tr = df_tr[df_tr[tok_col].map(lambda p: Path(p).exists())].reset_index(drop=True)

df_te = None
if TEST_TABLE is not None:
    df_te = pd.read_parquet(TEST_TABLE).copy()
    if "case_id" not in df_te.columns:
        raise ValueError("test_table missing case_id")
    df_te["case_id"] = df_te["case_id"].astype(int)
    if tok_col not in df_te.columns:
        # try detect token col on test too (for/else: raise when none matched)
        for c in ["tok_path","token_path","dino_path","feat_path","emb_path","token_npz","npz_path"]:
            if c in df_te.columns:
                tok_col_te = c
                break
        else:
            raise ValueError("test_table missing token path column.")
    else:
        tok_col_te = tok_col
    # Same existence filter as for the train table.
    df_te[tok_col_te] = df_te[tok_col_te].astype(str)
    df_te = df_te[df_te[tok_col_te].map(lambda p: Path(p).exists())].reset_index(drop=True)
else:
    # No test table: keep tok_col_te defined for the summary print below.
    tok_col_te = tok_col

print("TRAIN_TABLE:", TRAIN_TABLE, "| rows:", len(df_tr), "| pos_rate:", float(df_tr["y"].mean()))
print("TEST_TABLE :", TEST_TABLE if TEST_TABLE else "(none)")
print("TOK_COL    :", tok_col, "| TOK_COL_TEST:", tok_col_te)

# ----------------------------
# Latest MATCH_ROOT + Htok/Wtok/PATCH + match maps (train/test)
# ----------------------------
def pick_latest_match_root():
    """Return the newest usable ``match_cfg_*`` directory under CACHE_DIR.

    A directory is usable when it contains both ``cfg.json`` and
    ``match_manifest_train.parquet``; "newest" is by the mtime of ``cfg.json``.

    Raises:
        FileNotFoundError: if no usable directory exists.
    """
    usable = [
        d for d in sorted(CACHE_DIR.glob("match_cfg_*"))
        if (d / "cfg.json").exists() and (d / "match_manifest_train.parquet").exists()
    ]
    if not usable:
        raise FileNotFoundError("Cannot find match_cfg_* under /kaggle/working/recodai_luc/cache. Run Robust Matching stage first.")
    # max() returns the first maximal element, i.e. the name-sorted first on ties.
    return max(usable, key=lambda d: (d / "cfg.json").stat().st_mtime)

# Token-grid geometry is read from the matching stage's cfg.json; each value
# accepts two alternative key names and falls back to a hard-coded default
# (patch 14, grid 37x37) when neither key is present.
MATCH_ROOT = pick_latest_match_root()
MATCH_CFG = json.loads((MATCH_ROOT/"cfg.json").read_text())
PATCH = int(MATCH_CFG.get("patch", MATCH_CFG.get("patch_size", 14)))
HTOK  = int(MATCH_CFG.get("Ht", MATCH_CFG.get("htok", 37)))
WTOK  = int(MATCH_CFG.get("Wt", MATCH_CFG.get("wtok", 37)))

def build_match_map(pq: Path):
    """Build a ``{case_id: match_npz path}`` dict from a match manifest.

    When a case has several rows, the one with the highest score is kept (first
    known score column found); without any score column, the row whose npz file
    has the most recent mtime wins.

    Args:
        pq: path of the manifest parquet file.

    Returns:
        The mapping, or ``{}`` if the file is missing, lacks a ``match_npz``
        column, or has no non-null ``match_npz`` rows.
    """
    if not pq.exists():
        return {}

    table = pd.read_parquet(pq).copy()
    table["case_id"] = table["case_id"].astype(int)
    if "match_npz" not in table.columns:
        return {}

    table = table[table["match_npz"].notna()].copy()
    if len(table) == 0:
        return {}

    known_scores = ["best_peak_score", "peak_score_max", "max_peak_score", "score_max", "best_score"]
    rank_col = next((c for c in known_scores if c in table.columns), None)

    if rank_col is not None:
        # Unparseable / missing scores rank last.
        table[rank_col] = pd.to_numeric(table[rank_col], errors="coerce").fillna(-1)
    else:
        def _mtime_or_neg(path):
            try:
                return Path(path).stat().st_mtime
            except Exception:
                return -1

        rank_col = "_mtime"
        table[rank_col] = table["match_npz"].map(_mtime_or_neg)

    ranked = table.sort_values(["case_id", rank_col], ascending=[True, False])
    best_per_case = ranked.drop_duplicates("case_id", keep="first")
    return best_per_case.set_index("case_id")["match_npz"].to_dict()

# case_id -> match npz path for train and test; build_match_map returns {}
# when a manifest file is absent.
match_map_tr = build_match_map(MATCH_ROOT / "match_manifest_train.parquet")
match_map_te = build_match_map(MATCH_ROOT / "match_manifest_test.parquet")

print("MATCH_ROOT:", MATCH_ROOT)
print("TOK GRID  :", (HTOK, WTOK), "| PATCH:", PATCH)
print("match_map_tr:", len(match_map_tr), "| match_map_te:", len(match_map_te))

# ----------------------------
# Optional SciPy
# ----------------------------
# SciPy is optional: when the import fails, the morphology helpers below use
# their pure-NumPy / pure-Python fallbacks instead.
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

def dilate_tok(x_bool, it=1, use_scipy=None):
    """Binary-dilate a 2-D token mask `it` times.

    Both code paths use the 4-neighbourhood (cross) structuring element, which is
    what ``scipy.ndimage.binary_dilation`` applies by default, so the result does
    not depend on whether SciPy is installed. (The previous NumPy fallback used a
    full 3x3 square and therefore grew the seed faster than the SciPy path.)

    Args:
        x_bool: 2-D array, interpreted as boolean.
        it: number of dilation iterations; ``<= 0`` returns `x_bool` unchanged.
        use_scipy: force (True) or forbid (False) the SciPy path; ``None`` (default)
            follows the module-level ``_HAS_SCIPY`` flag.

    Returns:
        Boolean array of the same shape as `x_bool`.
    """
    if it <= 0:
        return x_bool
    x = x_bool.astype(bool)
    if use_scipy is None:
        use_scipy = _HAS_SCIPY
    if use_scipy:
        return ndi.binary_dilation(x, iterations=it)
    for _ in range(it):
        # Pad with False so the border behaves like SciPy's border_value=0.
        xp = np.pad(x, 1, mode="constant", constant_values=False)
        x = (
            xp[1:-1, 1:-1]      # centre
            | xp[:-2, 1:-1]     # neighbour above
            | xp[2:, 1:-1]      # neighbour below
            | xp[1:-1, :-2]     # neighbour left
            | xp[1:-1, 2:]      # neighbour right
        )
    return x

def label_cc(x_bool):
    """Label 8-connected components of a 2-D mask.

    Uses ``scipy.ndimage.label`` with a full 3x3 structure when SciPy is available,
    otherwise an explicit stack-based flood fill over the 8-neighbourhood.

    Returns:
        ``(labels, n)``: label image (0 = background, 1..n = components) and
        the number of components as a Python int.
    """
    mask = x_bool.astype(bool)
    if _HAS_SCIPY:
        labels, count = ndi.label(mask, structure=np.ones((3, 3), dtype=np.uint8))
        return labels, int(count)

    n_rows, n_cols = mask.shape
    labels = np.zeros((n_rows, n_cols), dtype=np.int32)
    offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]

    count = 0
    # argwhere yields foreground pixels in row-major order, like a raster scan.
    for r0, c0 in np.argwhere(mask):
        if labels[r0, c0] != 0:
            continue  # already absorbed by an earlier component
        count += 1
        labels[r0, c0] = count
        pending = [(r0, c0)]
        while pending:
            r, c = pending.pop()
            for dy, dx in offsets:
                rr, cc = r + dy, c + dx
                inside = 0 <= rr < n_rows and 0 <= cc < n_cols
                if inside and mask[rr, cc] and labels[rr, cc] == 0:
                    labels[rr, cc] = count
                    pending.append((rr, cc))
    return labels, int(count)

def inst_split_union_tok(mask_bool, min_area=2, max_area_frac=0.8, max_keep=8):
    """Split a mask into connected instances, filter them, and return their union.

    Components smaller than `min_area` tokens or covering more than
    `max_area_frac` of the grid are discarded; of the rest, the `max_keep`
    largest are merged back into one mask.

    Returns:
        ``(union_mask, n_kept)``; an all-False mask and 0 when nothing survives.
    """
    H, W = mask_bool.shape
    lab, n = label_cc(mask_bool)
    if n <= 0:
        return np.zeros((H, W), dtype=bool), 0

    grid_size = float(H * W)
    survivors = []
    for k in range(1, n + 1):
        comp = lab == k
        area = int(comp.sum())
        if area < min_area:
            continue
        if area / grid_size > max_area_frac:
            continue
        survivors.append((area, comp))

    if not survivors:
        return np.zeros((H, W), dtype=bool), 0

    # Largest components first, truncated to max_keep.
    sizes = np.asarray([area for area, _ in survivors])
    order = np.argsort(sizes)[::-1][:max_keep]

    uni = np.zeros((H, W), dtype=bool)
    for i in order:
        uni |= survivors[i][1]
    return uni, int(len(order))

def dice(pr_bool, gt_bool):
    """Dice coefficient of two boolean masks.

    Two empty masks score 1.0; exactly one empty mask scores 0.0.
    """
    n_pr = int(pr_bool.sum())
    n_gt = int(gt_bool.sum())
    if n_pr == 0 or n_gt == 0:
        return 1.0 if (n_pr == 0 and n_gt == 0) else 0.0
    overlap = int((pr_bool & gt_bool).sum())
    return float((2.0 * overlap) / (n_pr + n_gt))

# ----------------------------
# GT union loader -> token GT
# ----------------------------
def _find_mask_files(mask_dir: Path, case_id: int):
    """List image mask files in `mask_dir` that belong to `case_id`.

    A file belongs to the case when its stem is exactly the case id or starts
    with the case id followed by a non-digit (e.g. ``12.png``, ``12_0.png``,
    ``12__a.png``). The previous glob ``{cid}*.ext`` also matched other cases that
    merely share the prefix (``123.png`` for case 12), which merged foreign masks
    into the ground truth.

    Args:
        mask_dir: directory to scan; ``None`` or a non-directory yields ``[]``.
        case_id: numeric case identifier.

    Returns:
        Sorted list of matching ``Path`` objects (extension check is case-insensitive).
    """
    if mask_dir is None or (not mask_dir.is_dir()):
        return []
    cid = str(int(case_id))
    exts = (".png",".jpg",".jpeg",".tif",".tiff",".bmp")
    out = []
    for p in mask_dir.iterdir():
        if p.suffix.lower() not in exts:
            continue
        stem = p.stem
        if not stem.startswith(cid):
            continue
        rest = stem[len(cid):]
        # A following digit means a different, longer case id.
        if rest and rest[0].isdigit():
            continue
        out.append(p)
    return sorted(out)

def load_gt_union_full(case_id: int):
    """Load the full-resolution ground-truth mask (union of all instances).

    Lookup order:
      1. ``<case_id>.npy`` in TRAIN_MASK_DIR, then SUP_MASK_DIR. A 2-D array is
         thresholded at > 0; a 3-D array is thresholded and OR-ed over axis 0.
      2. Otherwise every image mask found for the case in both directories is
         read as grayscale, thresholded at > 0 and OR-ed together.

    Returns:
        Boolean 2-D array, or ``None`` when no mask source exists (or when
        every candidate image failed to load).
    """
    mask_dirs = [TRAIN_MASK_DIR, SUP_MASK_DIR]

    # 1) pre-computed .npy union / stack
    for root in mask_dirs:
        if root is None or (not root.exists()):
            continue
        npy_path = root / f"{int(case_id)}.npy"
        if not npy_path.exists():
            continue
        arr = np.load(npy_path, mmap_mode="r")
        if arr.ndim == 2:
            return (np.asarray(arr) > 0)
        if arr.ndim == 3:
            return (np.asarray(arr) > 0).any(axis=0)
        # any other rank: fall through to the next source

    # 2) union of individual image masks
    candidates = []
    for root in mask_dirs:
        if root is not None:
            candidates += _find_mask_files(root, case_id)
    if not candidates:
        return None

    union = None
    for path in candidates:
        try:
            fg = (np.asarray(Image.open(path).convert("L")) > 0)
            union = fg if union is None else (union | fg)
        except Exception:
            # Best effort: unreadable or shape-incompatible files are skipped.
            continue
    return union

def downsample_bool_to_tok(mask_bool: np.ndarray):
    """Resize a boolean mask to the (HTOK, WTOK) token grid.

    Uses nearest-neighbour resampling; ``None`` yields an all-zero grid.

    Returns:
        float32 array of shape (HTOK, WTOK) with values 0.0 / 1.0.
    """
    if mask_bool is None:
        return np.zeros((HTOK,WTOK), dtype=np.float32)
    as_u8 = mask_bool.astype(np.uint8) * 255
    small = Image.fromarray(as_u8).resize((WTOK, HTOK), resample=Image.NEAREST)
    return (np.asarray(small) > 127).astype(np.float32)

# ----------------------------
# Seed from match_npz (token union) + best_score
# ----------------------------
def load_seed_tok(case_id: int, is_test=False, topk=8):
    """Build the seed map for a case from its match npz.

    The seed is the union of the ``src_masks`` and ``tgt_masks`` of the `topk`
    best-scoring matches (by ``peak_score``).

    Args:
        case_id: case identifier looked up in the train or test match map.
        is_test: use ``match_map_te`` instead of ``match_map_tr``.
        topk: maximum number of matches to merge; ``None`` keeps all.

    Returns:
        ``(seed, best_score)``: float32 (HTOK, WTOK) map and the maximum peak
        score as int (0 when the case has no match file or no scores).
    """
    mm = match_map_te if is_test else match_map_tr
    p = mm.get(int(case_id), None)
    if p is None or (not Path(p).exists()):
        return np.zeros((HTOK,WTOK), dtype=np.float32), 0

    # Context manager closes the underlying zip file deterministically.
    with np.load(p) as z:
        scores = z["peak_score"] if "peak_score" in z.files else np.zeros((0,), np.int32)
        src = z["src_masks"] if "src_masks" in z.files else np.zeros((0,HTOK,WTOK), np.uint8)
        tgt = z["tgt_masks"] if "tgt_masks" in z.files else np.zeros((0,HTOK,WTOK), np.uint8)

    best_score = int(scores.max()) if len(scores) else 0
    if src.ndim != 3 or tgt.ndim != 3 or src.shape[0] == 0:
        return np.zeros((HTOK,WTOK), dtype=np.float32), best_score

    n_match = src.shape[0]
    if topk is not None and n_match > topk:
        if len(scores) == n_match:
            idx = np.argsort(scores)[::-1][:topk]
        else:
            # Scores missing / misaligned: ranking is impossible, so keep the
            # first topk entries instead of indexing with an empty argsort
            # (which used to wipe the seed entirely).
            idx = np.arange(topk)
        src = src[idx]
        if tgt.shape[0] == n_match:
            tgt = tgt[idx]

    # OR the two stacks separately so an absent tgt stack cannot empty the union.
    seed = (src > 0).any(axis=0)
    if tgt.shape[0] > 0:
        seed = seed | (tgt > 0).any(axis=0)
    return seed.astype(np.float32), best_score

# ----------------------------
# Token loader (npz/npy)
# ----------------------------
def load_tok_any(path_str: str):
    """Load a token feature grid from ``.npy`` or ``.npz`` as (HTOK, WTOK, D) float32.

    For npz files the first present key among a list of known names is used,
    otherwise the first array in the archive.

    Accepted 3-D layouts, checked in this order:
      * (HTOK, WTOK, D): returned as is;
      * (D, HTOK, WTOK): transposed to channels-last;
      * any other 3-D array: treated as (H, W, D) and resized channel by channel
        with bilinear interpolation (slow fallback).

    Raises:
        ValueError: for an empty npz or an array that is not 3-D.
    """
    p = Path(path_str)
    if p.suffix.lower() == ".npy":
        a = np.load(p, mmap_mode="r")
    else:
        # Context manager closes the archive as soon as the array is read.
        with np.load(p) as z:
            for k in ["tok","tokens","grid","token_grid","x","feat","emb","f"]:
                if k in z.files:
                    a = z[k]; break
            else:
                keys = list(z.files)
                if not keys: raise ValueError("empty npz")
                a = z[keys[0]]
    a = np.asarray(a)
    # normalize to (Ht,Wt,D)
    if a.ndim == 3 and a.shape[0] == HTOK and a.shape[1] == WTOK:
        return a.astype(np.float32)
    if a.ndim == 3 and a.shape[-2] == HTOK and a.shape[-1] == WTOK:
        return np.transpose(a, (1,2,0)).astype(np.float32)  # (D,Ht,Wt)->(Ht,Wt,D)
    if a.ndim == 3:
        # resize spatial to tok (slow fallback)
        H,W,D = a.shape
        out = np.zeros((HTOK,WTOK,D), np.float32)
        for d in range(D):
            im = Image.fromarray(a[:,:,d].astype(np.float32))
            im = im.resize((WTOK,HTOK), resample=Image.BILINEAR)
            out[:,:,d] = np.asarray(im).astype(np.float32)
        return out
    raise ValueError(f"Unknown token array shape: {a.shape}")

# infer input dim
# Probe the first training row to learn the token feature width D; the network
# input has one extra channel for the seed map.
# NOTE(review): raises IndexError if df_tr is empty after the existence filter.
tmp = load_tok_any(df_tr.iloc[0][tok_col])
DIN = int(tmp.shape[-1])
CIN = DIN + 1
print("Token D:", DIN, "| Input C:", CIN)

# ----------------------------
# Dataset / loaders
# ----------------------------
class TrainDS(Dataset):
    """Training dataset: token features + seed channel, token-space GT and label.

    Each item is a dict with:
      ``x``          (D+1, Ht, Wt) float32, seed map stacked as last channel
      ``gt``         (1, Ht, Wt) float32 token-space ground truth
      ``y``          (1,) float32 image-level label
      ``best_score`` (1,) float32 best match score
      ``case_id``    (1,) int64
    """

    def __init__(self, df_in: pd.DataFrame, tok_col: str):
        self.df = df_in.reset_index(drop=True)
        self.tok_col = tok_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        case_id = int(row["case_id"])
        label = float(row["y"])

        tokens = load_tok_any(row[self.tok_col])  # (Ht,Wt,D)
        seed_map, peak = load_seed_tok(case_id, is_test=False, topk=8)
        gt_tok = downsample_bool_to_tok(load_gt_union_full(case_id))  # (Ht,Wt) float

        channels_first = np.transpose(tokens, (2, 0, 1))  # (D,Ht,Wt)
        feats = np.concatenate([channels_first, seed_map[None, :, :]], axis=0).astype(np.float32)

        sample = {}
        sample["x"] = torch.from_numpy(feats)
        sample["gt"] = torch.from_numpy(gt_tok[None, :, :].astype(np.float32))
        sample["y"] = torch.tensor([label], dtype=torch.float32)
        sample["best_score"] = torch.tensor([float(peak)], dtype=torch.float32)
        sample["case_id"] = torch.tensor([case_id], dtype=torch.int64)
        return sample

class InferDS(Dataset):
    """Inference dataset: token features + seed channel, no ground truth.

    Each item is a dict with ``x`` (D+1, Ht, Wt) float32, ``best_score`` (1,)
    float32 and ``case_id`` (1,) int64. `is_test` selects which match map the
    seed is loaded from.
    """

    def __init__(self, df_in: pd.DataFrame, tok_col: str, is_test: bool):
        self.df = df_in.reset_index(drop=True)
        self.tok_col = tok_col
        self.is_test = is_test

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]
        case_id = int(row["case_id"])

        tokens = load_tok_any(row[self.tok_col])
        seed_map, peak = load_seed_tok(case_id, is_test=self.is_test, topk=8)

        channels_first = np.transpose(tokens, (2, 0, 1))
        feats = np.concatenate([channels_first, seed_map[None, :, :]], axis=0).astype(np.float32)

        sample = {}
        sample["x"] = torch.from_numpy(feats)
        sample["best_score"] = torch.tensor([float(peak)], dtype=torch.float32)
        sample["case_id"] = torch.tensor([case_id], dtype=torch.int64)
        return sample

# ----------------------------
# Model (odd-size safe: pool ceil_mode + interpolate up)
# ----------------------------
class ConvBNAct(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> SiLU -> optional Dropout2d."""

    def __init__(self, c_in, c_out, k=3, s=1, p=1, drop=0.0):
        super().__init__()
        self.conv = nn.Conv2d(c_in, c_out, kernel_size=k, stride=s, padding=p, bias=False)
        self.bn = nn.BatchNorm2d(c_out)
        self.act = nn.SiLU(inplace=True)
        # Dropout is replaced by a no-op when the rate is not positive.
        if drop > 0:
            self.drop = nn.Dropout2d(drop)
        else:
            self.drop = nn.Identity()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return self.drop(x)

class ASPP(nn.Module):
    """Parallel dilated 3x3 conv branches, concatenated and fused by a 1x1 conv.

    One branch is created per entry of `rates`; padding equals the dilation so
    every branch preserves the spatial size.
    """

    def __init__(self, c_in, c_out, rates=(1,2,4)):
        super().__init__()
        branches = []
        for rate in rates:
            branches.append(
                nn.Sequential(
                    nn.Conv2d(c_in, c_out, 3, padding=rate, dilation=rate, bias=False),
                    nn.BatchNorm2d(c_out),
                    nn.SiLU(inplace=True),
                )
            )
        self.blocks = nn.ModuleList(branches)

        fused_in = len(rates) * c_out
        self.proj = nn.Sequential(
            nn.Conv2d(fused_in, c_out, 1, bias=False),
            nn.BatchNorm2d(c_out),
            nn.SiLU(inplace=True),
        )

    def forward(self, x):
        branch_outs = [branch(x) for branch in self.blocks]
        return self.proj(torch.cat(branch_outs, dim=1))

class HybridUNet(nn.Module):
    """Two-level UNet with an ASPP bottleneck and an image-level gate head.

    ``forward`` returns ``(seg_logit, cls_logit)``: per-token segmentation logits
    at the input resolution, and one gate logit per sample computed from the
    globally pooled bottleneck. Pooling uses ``ceil_mode`` and the decoder
    upsamples to the exact encoder sizes, so odd grid sizes are supported.
    """

    def __init__(self, c_in, base_ch=96, drop=0.1):
        super().__init__()
        c1, c2, c3 = base_ch, base_ch*2, base_ch*3

        def double_conv(ch_in, ch_out):
            # Two stacked ConvBNAct blocks, the basic unit of every stage.
            return nn.Sequential(
                ConvBNAct(ch_in, ch_out, drop=drop),
                ConvBNAct(ch_out, ch_out, drop=drop),
            )

        # Encoder
        self.e1 = double_conv(c_in, c1)
        self.pool = nn.MaxPool2d(2, 2, ceil_mode=True)
        self.e2 = double_conv(c1, c2)
        self.e3 = double_conv(c2, c3)
        self.aspp = ASPP(c3, c3, rates=(1,2,4))

        # Decoder (inputs are upsampled features concatenated with the skip)
        self.d2 = double_conv(c3+c2, c2)
        self.d1 = double_conv(c2+c1, c1)

        # Heads
        self.seg_head = nn.Conv2d(c1, 1, 1)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(c3, c3//2),
            nn.SiLU(inplace=True),
            nn.Dropout(drop if drop>0 else 0.0),
            nn.Linear(c3//2, 1),
        )

    def forward(self, x):
        # Encoder path
        skip1 = self.e1(x)                        # (B,c1,H,W)
        skip2 = self.e2(self.pool(skip1))         # (B,c2,~H/2,~W/2)
        bottleneck = self.aspp(self.e3(self.pool(skip2)))

        # Gate head works on the bottleneck features.
        cls_logit = self.cls_head(bottleneck)

        # Decoder path: upsample to the skip's exact size, concat, refine.
        up2 = F.interpolate(bottleneck, size=skip2.shape[-2:], mode="bilinear", align_corners=False)
        dec2 = self.d2(torch.cat([up2, skip2], dim=1))

        up1 = F.interpolate(dec2, size=skip1.shape[-2:], mode="bilinear", align_corners=False)
        dec1 = self.d1(torch.cat([up1, skip1], dim=1))

        seg_logit = self.seg_head(dec1)
        return seg_logit, cls_logit

# ----------------------------
# Loss
# ----------------------------
def dice_loss_from_logits(logits, target, eps=1e-6):
    """Soft Dice loss: 1 - mean over (batch, channel) of the per-map Dice.

    Sums run over dims 2 and 3; `eps` is added to the denominator only.
    """
    prob = torch.sigmoid(logits)
    spatial = (2, 3)
    overlap = (prob * target).sum(dim=spatial)
    total = (prob + target).sum(dim=spatial) + eps
    per_map = (2.0 * overlap) / total
    return 1.0 - per_map.mean()

def bce_focal_from_logits(logits, target, gamma=2.0):
    """Focal-weighted binary cross-entropy on logits (mean reduction).

    Each element's BCE is scaled by ``(1 - p_t) ** gamma`` where ``p_t`` is the
    probability assigned to the target; ``gamma <= 0`` gives plain mean BCE.
    """
    per_elem = F.binary_cross_entropy_with_logits(logits, target, reduction="none")
    if gamma <= 0:
        return per_elem.mean()
    prob = torch.sigmoid(logits)
    p_true = target * prob + (1-target) * (1-prob)
    focal_weight = (1-p_true).pow(gamma)
    return (focal_weight * per_elem).mean()

# ----------------------------
# Train
# ----------------------------
# Run-time knobs, overridable via environment variables. The default epoch
# count is max(12, BEST["trial_epochs"] + 10), treating a missing value as 0.
EPOCHS = int(os.environ.get("EPOCHS_FINAL", str(max(12, int(BEST.get("trial_epochs", 0) or 0) + 10))))
BATCH  = int(os.environ.get("BATCH_SIZE", "32"))
NUM_WORKERS = int(os.environ.get("NUM_WORKERS", "2"))
ACCUM = int(os.environ.get("ACCUM_STEPS", "1"))

# Hyper-parameters taken from the best optimisation trial (KeyError if absent).
base_ch = int(BEST["base_ch"])
dropout = float(BEST["dropout"])
lr = float(BEST["lr"])
wd = float(BEST["weight_decay"])
lam_seg = float(BEST["lambda_seg"])
lam_cls = float(BEST["lambda_cls"])
focal_gamma = float(BEST["focal_gamma"])

# Class imbalance of the gate label is handled with pos_weight = neg / pos
# (no resampling of the training rows here).
pos = int(df_tr["y"].sum())
neg = int(len(df_tr) - pos)
pos_weight = float(neg / max(1, pos))
bce_cls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], dtype=torch.float32, device=device))

# NOTE(review): drop_last=True yields zero batches when len(df_tr) < BATCH,
# in which case the loop below runs without any optimisation step.
ds = TrainDS(df_tr, tok_col=tok_col)
dl = DataLoader(
    ds, batch_size=BATCH, shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=(device.type=="cuda"),
    drop_last=True
)

model = HybridUNet(c_in=CIN, base_ch=base_ch, drop=dropout).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
# Cosine schedule stepped once per epoch.
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max(1, EPOCHS))
scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

print("-"*60)
print("FINAL TRAIN:", {"epochs": EPOCHS, "batch": BATCH, "accum": ACCUM, "lr": lr, "wd": wd,
                      "base_ch": base_ch, "dropout": dropout, "pos_weight": pos_weight,
                      "lam_seg": lam_seg, "lam_cls": lam_cls, "focal_gamma": focal_gamma})
print("-"*60)

t0 = time.time()
model.train()
opt.zero_grad(set_to_none=True)

for ep in range(1, EPOCHS+1):
    loss_meter = 0.0
    nsteps = 0
    for step, batch in enumerate(dl, start=1):
        x  = batch["x"].to(device, non_blocking=True)
        gt = batch["gt"].to(device, non_blocking=True)
        yb = batch["y"].to(device, non_blocking=True)

        with torch.cuda.amp.autocast(enabled=USE_AMP):
            seg_logit, cls_logit = model(x)
            # Segmentation loss = focal BCE + soft Dice; gate loss = weighted BCE.
            l_seg = bce_focal_from_logits(seg_logit, gt, gamma=focal_gamma) + dice_loss_from_logits(seg_logit, gt)
            l_cls = bce_cls(cls_logit, yb)
            loss = lam_seg * l_seg + lam_cls * l_cls
            # Scale so accumulated gradients average over ACCUM micro-batches.
            loss = loss / ACCUM

        scaler.scale(loss).backward()
        # Optimiser step only every ACCUM micro-batches.
        # NOTE(review): when the number of batches per epoch is not a multiple
        # of ACCUM, the trailing gradients carry over into the next epoch.
        if step % ACCUM == 0:
            scaler.step(opt); scaler.update()
            opt.zero_grad(set_to_none=True)

        # Undo the ACCUM scaling for logging.
        loss_meter += float(loss.item()) * ACCUM
        nsteps += 1

    sched.step()
    # Log on the first, every second, and the last epoch.
    if ep == 1 or ep % 2 == 0 or ep == EPOCHS:
        print(f"[ep {ep:02d}/{EPOCHS}] loss={loss_meter/max(1,nsteps):.4f} lr={sched.get_last_lr()[0]:.6g} | {time.time()-t0:.1f}s")

print("TRAIN DONE |", f"{time.time()-t0:.1f}s")

# ----------------------------
# Export mask-prob cache (TRAIN+TEST) + p_gate
# ----------------------------
def cfg_hash(d):
    """Return a 10-hex-char MD5 fingerprint of `d` serialised as key-sorted JSON."""
    canonical = json.dumps(d, sort_keys=True)
    digest = hashlib.md5(canonical.encode()).hexdigest()
    return digest[:10]

# Everything that determines the exported probabilities; its hash names the
# cache directory so different configurations do not overwrite each other.
CFG_EXPORT = {
    "hybrid": True,
    "best_config": BEST,
    "match_root": str(MATCH_ROOT),
    "tok_grid": {"HTOK": HTOK, "WTOK": WTOK, "PATCH": PATCH},
    "train_table": str(TRAIN_TABLE),
    "test_table": str(TEST_TABLE) if TEST_TABLE else None,
}
CFG_ID = cfg_hash(CFG_EXPORT)
MASKPROB_DIR = CACHE_DIR / f"mask_prob_hybrid_{CFG_ID}"
MASKPROB_DIR.mkdir(parents=True, exist_ok=True)

@torch.no_grad()
def export_probs(df_in: pd.DataFrame, tok_col_use: str, is_test: bool, tag: str):
    """Run the trained model over `df_in` and cache its outputs per case.

    For every case a compressed ``<case_id>.npz`` is written to MASKPROB_DIR with
    ``prob_tok`` (token-space segmentation probabilities, float16) and
    ``p_gate`` (gate probability, float16).

    Args:
        df_in: rows to export; needs ``case_id`` and the token-path column.
        tok_col_use: name of the token-path column in `df_in`.
        is_test: select the test match map for the seed channel.
        tag: label used in progress messages.
    """
    ds_inf = InferDS(df_in, tok_col=tok_col_use, is_test=is_test)
    dl_inf = DataLoader(ds_inf, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS,
                        pin_memory=(device.type=="cuda"), drop_last=False)
    model.eval()

    started = time.time()
    wrote = 0
    for j, batch in enumerate(dl_inf, start=1):
        x = batch["x"].to(device, non_blocking=True)
        seg_logit, cls_logit = model(x)
        prob_tok = torch.sigmoid(seg_logit).detach().cpu().numpy()[:,0]  # (B,Ht,Wt)
        p_gate   = torch.sigmoid(cls_logit).detach().cpu().numpy().reshape(-1)  # (B,)
        cids = batch["case_id"].cpu().numpy().reshape(-1)

        # One archive per case, named after its id.
        for cid, prob_i, gate_i in zip(cids, prob_tok, p_gate):
            np.savez_compressed(
                MASKPROB_DIR / f"{int(cid)}.npz",
                prob_tok=prob_i.astype(np.float16),
                p_gate=np.float16(gate_i),
            )
            wrote += 1

        if j % 100 == 0:
            print(f"[export {tag}] {wrote}/{len(ds_inf)} | {time.time()-started:.1f}s")

    print(f"[export {tag}] done | wrote={wrote} | {time.time()-started:.1f}s")

# Export train (and, when available, test) probabilities into MASKPROB_DIR.
# NOTE(review): both calls write "<case_id>.npz" into the same directory —
# confirm train and test case ids cannot collide.
export_probs(df_tr[["case_id", tok_col]], tok_col, is_test=False, tag="train")
if df_te is not None:
    export_probs(df_te[["case_id", tok_col_te]], tok_col_te, is_test=True, tag="test")

# ----------------------------
# Threshold sweep on TRAIN using Dice-proxy (token-space)
# ----------------------------
# Post-processing parameters are fixed to the best trial's values; only the
# gate threshold is swept afterwards.
T1 = float(BEST["T1"]); T0 = float(BEST["T0"]); dil_it = int(BEST["seed_dilate_it"])
min_tok_area = int(BEST["min_tok_area"])
max_tok_area_frac = float(BEST["max_tok_area_frac"])
max_inst_keep = int(BEST["max_inst_keep"])
min_peak_keep = int(BEST["min_peak_score_keep"])
min_area_frac_keep = float(BEST["min_area_frac_keep"])

@torch.no_grad()
def build_dice_arrays_train():
    """Precompute per-case quantities needed to sweep the gate threshold.

    Runs the trained model over the whole training table (in table order) and
    applies the fixed post-processing to every case, independent of the gate.

    Returns:
        ``(p_gate_all, dice_use, dice_empty)``, three float32 arrays of length
        ``len(df_tr)``: the gate probability, the Dice obtained when the
        post-processed mask is used, and the Dice obtained when an empty mask
        is predicted (1.0 if the token GT is empty, else 0.0).
    """
    ds_inf = TrainDS(df_tr[["case_id","y",tok_col]].copy(), tok_col=tok_col)
    dl_inf = DataLoader(ds_inf, batch_size=BATCH, shuffle=False, num_workers=NUM_WORKERS,
                        pin_memory=(device.type=="cuda"), drop_last=False)
    model.eval()

    n_cases = len(ds_inf)
    p_gate_all = np.zeros(n_cases, np.float32)
    dice_use   = np.zeros(n_cases, np.float32)
    dice_empty = np.zeros(n_cases, np.float32)

    offset = 0  # index of the first case of the current batch
    started = time.time()
    for batch in dl_inf:
        x = batch["x"].to(device, non_blocking=True)
        gt_mask = batch["gt"].cpu().numpy()[:,0] > 0.5
        peak = batch["best_score"].cpu().numpy().reshape(-1)

        seg_logit, cls_logit = model(x)
        tok_prob = torch.sigmoid(seg_logit).cpu().numpy()[:,0]
        gate_prob = torch.sigmoid(cls_logit).cpu().numpy().reshape(-1)
        seed_map = batch["x"].cpu().numpy()[:,-1]  # last channel (B,Ht,Wt)

        n_batch = tok_prob.shape[0]
        for i in range(n_batch):
            row = offset + i
            p_gate_all[row] = gate_prob[i]

            # Predicting "nothing" is perfect only when the GT is empty.
            dice_empty[row] = 1.0 if (gt_mask[i].sum() == 0) else 0.0

            # Confident tokens always count; soft ones only near the dilated seed.
            grown_seed = dilate_tok(seed_map[i] > 0.5, dil_it)
            fused = (tok_prob[i] >= T1) | (grown_seed & (tok_prob[i] >= T0))

            kept, _ = inst_split_union_tok(
                fused,
                min_area=min_tok_area,
                max_area_frac=max_tok_area_frac,
                max_keep=max_inst_keep,
            )
            kept_frac = float(kept.mean())
            # Weak match combined with a tiny region -> discard.
            if (peak[i] < min_peak_keep) and (kept_frac < min_area_frac_keep):
                kept = np.zeros((HTOK,WTOK), dtype=bool)

            dice_use[row] = dice(kept, gt_mask[i])

        offset += n_batch
        if offset % 800 == 0:
            print(f"[dice-proxy] {offset}/{len(ds_inf)} | {time.time()-started:.1f}s")

    return p_gate_all, dice_use, dice_empty

p_gate_all, dice_use, dice_empty = build_dice_arrays_train()

# Sweep 201 evenly spaced gate thresholds over [0, 1]. At each threshold a case
# contributes dice_use when its gate probability reaches the threshold and
# dice_empty otherwise; the sweep score is the mean over all cases.
thr_grid = np.linspace(0.0, 1.0, 201, dtype=np.float32)
rows = []
best_i, best_score = 0, -1.0

for i, thr in enumerate(thr_grid):
    per_case = np.where(p_gate_all >= thr, dice_use, dice_empty)
    score = float(per_case.mean())
    rows.append({"thr": float(thr), "score_dice_proxy": score})
    # Strict ">" keeps the lowest threshold among equal scores.
    if score > best_score:
        best_i, best_score = i, score

df_thr = pd.DataFrame(rows)
best_thr = float(df_thr.loc[best_i, "thr"])

# Persist the full sweep table next to the other artifacts.
df_thr.to_csv(OUT_DIR / "threshold_table.csv", index=False)

# Post-processing parameters used while scoring, stored alongside the threshold.
_pp_used = {
    "T1": T1, "T0": T0, "seed_dilate_it": dil_it,
    "min_tok_area": min_tok_area, "max_tok_area_frac": max_tok_area_frac, "max_inst_keep": max_inst_keep,
    "min_peak_score_keep": min_peak_keep, "min_area_frac_keep": min_area_frac_keep,
}
_thr_report = {
    "recommended_thr": best_thr,
    "best_score_dice_proxy": float(best_score),
    "row": df_thr.loc[best_i].to_dict(),
    "pp": _pp_used,
}
(OUT_DIR / "best_threshold.json").write_text(json.dumps(_thr_report, indent=2))

print("BEST_THR:", best_thr, "| best_score:", best_score)

# ----------------------------
# Save final model bundle
# ----------------------------
# Detach every parameter/buffer and move it to CPU before it goes into the pack.
_cpu_state = {}
for _key, _tensor in model.state_dict().items():
    _cpu_state[_key] = _tensor.detach().cpu()

pack = {
    "model_type": "HybridUNet",
    "state_dict": _cpu_state,
    "best_config": BEST,
    "recommended_thr": best_thr,
    "cfg_export": CFG_EXPORT,
    "cfg_id": CFG_ID,
    "maskprob_dir": str(MASKPROB_DIR),
    "tok_grid": {"HTOK": HTOK, "WTOK": WTOK, "PATCH": PATCH},
    "input_channels": CIN,
    "token_dim": DIN,
}
_final_pt_path = OUT_DIR / "final_hybrid_model.pt"
torch.save(pack, _final_pt_path)

# Summary of everything this stage wrote into OUT_DIR.
print("-"*60)
print("SAVED:")
for _artifact in (
    _final_pt_path,
    OUT_DIR / "threshold_table.csv",
    OUT_DIR / "best_threshold.json",
):
    print(" -", _artifact)
print("MASKPROB_DIR:", MASKPROB_DIR)


# Finalize & Save Model Bundle (Reproducible)

In [None]:
# ============================================================
# STAGE — Finalize & Save Model Bundle (Reproducible) (ONE CELL) — HYBRID (OPSI-1)
# Bundle (portable):
# - final_hybrid_model.pt
# - best_threshold.json + threshold_table.csv
# - best_config.json (if exists)
# - paths.json + match_cfg.json (copy of the match cache's cfg.json, if resolvable)
# - manifest.json (sha256 + env + metadata)
# - ZIP: hybrid_model_bundle_<cfg_id>_<stamp>.zip
# ============================================================

import os, json, time, hashlib, shutil, platform
from pathlib import Path

import numpy as np
import pandas as pd

import torch

# ----------------------------
# Locate inputs
# ----------------------------
OPT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_opt")
OUT_DIR  = Path("/kaggle/working/recodai_luc_hybrid_artifacts")
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")

FINAL_PT = OUT_DIR / "final_hybrid_model.pt"
BEST_THR = OUT_DIR / "best_threshold.json"
THR_TAB  = OUT_DIR / "threshold_table.csv"
BEST_CFG = OPT_DIR / "best_config.json"
PATHS_JS = PROF_DIR / "paths.json"

# Only the model pack and the threshold file are mandatory; everything else is optional.
for p in (FINAL_PT, BEST_THR):
    if p.exists():
        continue
    raise FileNotFoundError(f"Missing {p}. Run Final Training stage first.")

# ----------------------------
# Read model pack for metadata
# ----------------------------
pack = torch.load(FINAL_PT, map_location="cpu")
cfg_id = str(pack.get("cfg_id", "unknown"))
maskprob_dir = str(pack.get("maskprob_dir", ""))

# Best-effort lookup of the match cache's cfg.json: any failure simply leaves it unset.
match_cfg_json = None
try:
    match_root = pack.get("cfg_export", {}).get("match_root", None)
    if match_root:
        _match_dir = Path(match_root)
        if _match_dir.exists() and (_match_dir / "cfg.json").exists():
            match_cfg_json = _match_dir / "cfg.json"
except Exception:
    match_cfg_json = None

# UTC timestamp makes each bundle directory (and later the ZIP name) unique per run.
stamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
BUNDLE_DIR = Path(f"/kaggle/working/recodai_luc_hybrid_bundle_{cfg_id}_{stamp}")
BUNDLE_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# Helpers
# ----------------------------
def sha256_file(p: Path, chunk=1<<20):
    """Return the hex SHA-256 digest of the file at ``p``.

    The file is read in binary mode in pieces of at most ``chunk`` bytes
    (1 MiB by default), so the whole file is never held in memory at once.
    """
    digest = hashlib.sha256()
    with open(p, "rb") as fh:
        # iter(callable, sentinel) keeps reading until read() returns b"" (EOF).
        for piece in iter(lambda: fh.read(chunk), b""):
            digest.update(piece)
    return digest.hexdigest()

def safe_copy(src: Path, dst_dir: Path, new_name: str = None):
    """Copy ``src`` into ``dst_dir`` if it exists; otherwise do nothing.

    Args:
        src: Source file path (``None`` is accepted and treated as missing).
        dst_dir: Destination directory.
        new_name: Optional file name to use in ``dst_dir``; when empty or
            ``None`` the source's own name is kept.

    Returns:
        The destination path, or ``None`` when ``src`` is ``None`` or does
        not exist.
    """
    if src is None:
        return None
    source = Path(src)
    if not source.exists():
        return None
    target = dst_dir / (new_name or source.name)
    # copy2 copies the file contents and attempts to preserve file metadata.
    shutil.copy2(source, target)
    return target

# ----------------------------
# Copy artifacts
# ----------------------------
# Maps bundle file name -> destination path (as string) for the manifest.
copied = {}

# Mandatory artifacts (existence was verified when the inputs were located).
for _src, _name in (
    (FINAL_PT, "final_hybrid_model.pt"),
    (BEST_THR, "best_threshold.json"),
):
    copied[_name] = str(safe_copy(_src, BUNDLE_DIR, _name))

# Optional artifacts: recorded only when the source is known and present.
for _src, _name in (
    (THR_TAB, "threshold_table.csv"),
    (BEST_CFG, "best_config.json"),
    (PATHS_JS, "paths.json"),
    (match_cfg_json, "match_cfg.json"),
):
    if _src is not None and Path(_src).exists():
        copied[_name] = str(safe_copy(_src, BUNDLE_DIR, _name))

# lightweight loader hint
_readme_lines = [
    "Hybrid (OPSI-1) bundle contents:\n",
    "- final_hybrid_model.pt : Torch state_dict + cfg + recommended_thr + maskprob_dir\n",
    "- best_threshold.json   : recommended_thr + postprocess params\n",
    "- threshold_table.csv   : sweep table (optional)\n",
    "- best_config.json      : HPO best trial (optional)\n",
    "- paths.json            : dataset paths used in build (optional)\n",
    "- match_cfg.json        : robust matching cfg used (optional)\n\n",
    "Load example:\n",
    "  import torch, json\n",
    "  pack = torch.load('final_hybrid_model.pt', map_location='cpu')\n",
    "  thr  = json.load(open('best_threshold.json'))['recommended_thr']\n",
]
readme = BUNDLE_DIR / "README.txt"
readme.write_text("".join(_readme_lines))

# ----------------------------
# Manifest (sha256 + env + metadata)
# ----------------------------
# Hash every file currently in the bundle; the manifest itself is excluded.
files = sorted(
    p for p in BUNDLE_DIR.glob("*")
    if p.is_file() and p.name != "manifest.json"
)
hashes = {}
for _bundle_file in files:
    hashes[_bundle_file.name] = sha256_file(_bundle_file)

# Fallback threshold from best_threshold.json (0.5 if the key is absent),
# used only when the pack does not carry "recommended_thr".
_thr_fallback = json.loads(BEST_THR.read_text()).get("recommended_thr", 0.5)

# Versions of the runtime that produced the bundle.
_env_info = {
    "python": platform.python_version(),
    "platform": platform.platform(),
    "torch": torch.__version__,
    "cuda_available": bool(torch.cuda.is_available()),
    "cuda_version": torch.version.cuda,
    "numpy": np.__version__,
    "pandas": pd.__version__,
}

meta = {
    "bundle_dir": str(BUNDLE_DIR),
    "created_utc": time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()),
    "cfg_id": cfg_id,
    "recommended_thr": float(pack.get("recommended_thr", _thr_fallback)),
    "tok_grid": pack.get("tok_grid", {}),
    "input_channels": int(pack.get("input_channels", -1)),
    "token_dim": int(pack.get("token_dim", -1)),
    "maskprob_dir": maskprob_dir,
    "env": _env_info,
    "files_copied": copied,
    "sha256": hashes,
}

manifest_path = BUNDLE_DIR / "manifest.json"
manifest_path.write_text(json.dumps(meta, indent=2))

# ----------------------------
# ZIP bundle
# ----------------------------
import zipfile

ZIP_PATH = Path(f"/kaggle/working/hybrid_model_bundle_{cfg_id}_{stamp}.zip")

# Regular files of the bundle directory in sorted order; the archive is written
# outside BUNDLE_DIR, so this listing is the same before and after zipping.
_bundle_files = [p for p in sorted(BUNDLE_DIR.glob("*")) if p.is_file()]

with zipfile.ZipFile(ZIP_PATH, "w", compression=zipfile.ZIP_DEFLATED) as z:
    for _member in _bundle_files:
        # Store each file flat under its own name, without directory components.
        z.write(_member, arcname=_member.name)

print("BUNDLE_DIR:", BUNDLE_DIR)
print("ZIP_PATH  :", ZIP_PATH)
print("Files:", [_member.name for _member in _bundle_files])
