# Dataset Loading & Manifest Building

In [1]:
# ============================================================
# STAGE 0 — Dataset Loading & Manifest Building (ONE CELL)
# - English stage naming (compatible with next stages in our plan)
# - Auto-detect Kaggle dataset root from folder structure shown
# - Builds:
#     df_train_all : uid, case_id, y, img_path, mask_path, is_supplemental, split
#     df_test      : uid, case_id, img_path
# - Saves:
#     /kaggle/working/recodai_luc_prof/paths.json
#     /kaggle/working/recodai_luc_prof/train_manifest.parquet
#     /kaggle/working/recodai_luc_prof/test_manifest.parquet
#     /kaggle/working/recodai_luc_prof/sample_submission.csv (copy)
# - Paper-ready figures (Stage 0):
#     /kaggle/working/recodai_luc_prof/figures/stage0/*.png
# ============================================================

import os, json, math, random, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Repro
# ----------------------------
SEED = 42
random.seed(SEED)
np.random.seed(SEED)

# ----------------------------
# 1) Auto-detect dataset root
# ----------------------------
def find_dataset_root(base="/kaggle/input"):
    base = Path(base)
    need_dirs = ["train_images", "train_masks", "test_images"]
    need_file = "sample_submission.csv"
    cands = []
    for d in base.iterdir():
        if not d.is_dir():
            continue
        ok = True
        for nd in need_dirs:
            if not (d / nd).exists():
                ok = False
                break
        if ok and (d / need_file).exists():
            cands.append(d)
    # prefer the largest (usually main competition dataset)
    if not cands:
        raise RuntimeError(
            "Dataset root not found under /kaggle/input. "
            "Expected folders: train_images/, train_masks/, test_images/ and sample_submission.csv"
        )
    cands = sorted(cands, key=lambda p: sum(1 for _ in p.rglob("*")), reverse=True)
    return cands[0]

DATA_ROOT = find_dataset_root()
print("DATA_ROOT:", DATA_ROOT)

TRAIN_IMG_DIR = DATA_ROOT / "train_images"
TRAIN_AUTH_DIR = TRAIN_IMG_DIR / "authentic"
TRAIN_FORG_DIR = TRAIN_IMG_DIR / "forged"
TRAIN_MASK_DIR = DATA_ROOT / "train_masks"

TEST_IMG_DIR = DATA_ROOT / "test_images"

SUP_IMG_DIR = DATA_ROOT / "supplemental_images"
SUP_MASK_DIR = DATA_ROOT / "supplemental_masks"

SAMPLE_SUB_PATH = DATA_ROOT / "sample_submission.csv"

for p in [TRAIN_AUTH_DIR, TRAIN_FORG_DIR, TRAIN_MASK_DIR, TEST_IMG_DIR, SAMPLE_SUB_PATH]:
    if not p.exists():
        raise RuntimeError(f"Missing required path: {p}")

# ----------------------------
# 2) Output dirs (keep consistent for next stages)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
PROF_DIR.mkdir(parents=True, exist_ok=True)

FIG_DIR = PROF_DIR / "figures" / "stage0"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Copy sample submission for reference
pd.read_csv(SAMPLE_SUB_PATH).to_csv(PROF_DIR / "sample_submission.csv", index=False)

# ----------------------------
# 3) Helpers
# ----------------------------
def list_pngs(d: Path):
    if (not d.exists()) or (not d.is_dir()):
        return []
    return sorted([p for p in d.glob("*.png")])

def read_mask_npy(mask_path: Path):
    m = np.load(mask_path)
    if m.ndim == 3:
        m = m.max(axis=0)
    m = (m > 0).astype(np.uint8)
    return m

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32) / 255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m, 0] = (1 - alpha) * out[m, 0] + alpha * 1.0
    out[m, 1] = (1 - alpha) * out[m, 1] + alpha * 0.0
    out[m, 2] = (1 - alpha) * out[m, 2] + alpha * 0.0
    return (out * 255).clip(0, 255).astype(np.uint8)

def safe_case_id_from_uid(uid: str):
    # filenames look like "10.png" -> uid "10"
    # Keep as string for exact match with sample_submission
    return str(uid)

# ----------------------------
# 4) Build manifests
# ----------------------------
auth_files = list_pngs(TRAIN_AUTH_DIR)
forg_files = list_pngs(TRAIN_FORG_DIR)

rows = []
for p in auth_files:
    uid = p.stem
    rows.append(dict(
        uid=uid,
        case_id=safe_case_id_from_uid(uid),
        y=0,
        split="train",
        img_path=str(p),
        mask_path="",
        is_supplemental=0,
        source="train_authentic",
    ))

for p in forg_files:
    uid = p.stem
    mp = TRAIN_MASK_DIR / f"{uid}.npy"
    rows.append(dict(
        uid=uid,
        case_id=safe_case_id_from_uid(uid),
        y=1,
        split="train",
        img_path=str(p),
        mask_path=str(mp) if mp.exists() else "",
        is_supplemental=0,
        source="train_forged",
    ))

# Supplemental (if exists)
if SUP_IMG_DIR.exists() and SUP_MASK_DIR.exists():
    sup_files = list_pngs(SUP_IMG_DIR)
    for p in sup_files:
        uid = p.stem
        mp = SUP_MASK_DIR / f"{uid}.npy"
        rows.append(dict(
            uid=uid,
            case_id=safe_case_id_from_uid(uid),
            y=1,
            split="train",
            img_path=str(p),
            mask_path=str(mp) if mp.exists() else "",
            is_supplemental=1,
            source="supplemental_forged",
        ))

df_train_all = pd.DataFrame(rows)

# Resolve duplicate uids safely:
# If uid appears more than once (e.g., both authentic & forged), keep the row with an existing mask,
# otherwise prefer forged, else keep the first.
dup = df_train_all["uid"].duplicated(keep=False)
if dup.any():
    ddf = df_train_all[dup].copy()
    print(f"[WARN] Duplicate uid detected: {ddf['uid'].nunique()} uids (show up to 10):", ddf["uid"].unique()[:10])

    def pick_best(group: pd.DataFrame):
        # prefer: has mask_path existing file
        g = group.copy()
        g["mask_exists"] = g["mask_path"].apply(lambda x: Path(x).exists() if str(x) else False)
        # ranking: mask_exists desc, y desc (prefer forged), is_supplemental desc, keep stable by source
        g = g.sort_values(["mask_exists", "y", "is_supplemental"], ascending=[False, False, False])
        return g.iloc[0]

    df_train_all = (
        df_train_all.groupby("uid", as_index=False, sort=False)
        .apply(pick_best)
        .reset_index(drop=True)
    )

# Basic checks
df_train_all["mask_path"] = df_train_all["mask_path"].fillna("").astype(str)
missing_mask_forged = df_train_all[(df_train_all["y"] == 1) & (df_train_all["mask_path"] == "")]
if len(missing_mask_forged) > 0:
    print(f"[WARN] Forged samples missing mask_path: {len(missing_mask_forged)} (first 5 uids):",
          missing_mask_forged["uid"].head(5).tolist())

# Test manifest
test_files = list_pngs(TEST_IMG_DIR)
df_test = pd.DataFrame([dict(
    uid=p.stem,
    case_id=safe_case_id_from_uid(p.stem),
    split="test",
    img_path=str(p),
) for p in test_files])

# Sort by numeric id if possible (nice & stable)
def to_int_or_nan(x):
    try:
        return int(str(x))
    except:
        return np.nan

df_train_all["_cid_int"] = df_train_all["case_id"].map(to_int_or_nan)
df_test["_cid_int"] = df_test["case_id"].map(to_int_or_nan)
df_train_all = df_train_all.sort_values(["_cid_int", "case_id"]).drop(columns=["_cid_int"]).reset_index(drop=True)
df_test = df_test.sort_values(["_cid_int", "case_id"]).drop(columns=["_cid_int"]).reset_index(drop=True)

print("train_manifest:", df_train_all.shape, "| forged% =", round(df_train_all["y"].mean() * 100, 3))
print("test_manifest :", df_test.shape)

# ----------------------------
# 5) Save artifacts (parquet preferred)
# ----------------------------
train_out = PROF_DIR / "train_manifest.parquet"
test_out  = PROF_DIR / "test_manifest.parquet"
df_train_all.to_parquet(train_out, index=False)
df_test.to_parquet(test_out, index=False)

# Save paths.json for downstream stages
# (We keep these names stable so next stages can read them without errors.)
paths = dict(
    DATA_ROOT=str(DATA_ROOT),
    TRAIN_AUTH_DIR=str(TRAIN_AUTH_DIR),
    TRAIN_FORG_DIR=str(TRAIN_FORG_DIR),
    TRAIN_MASK_DIR=str(TRAIN_MASK_DIR),
    TEST_IMG_DIR=str(TEST_IMG_DIR),
    SUP_IMG_DIR=str(SUP_IMG_DIR) if SUP_IMG_DIR.exists() else "",
    SUP_MASK_DIR=str(SUP_MASK_DIR) if SUP_MASK_DIR.exists() else "",
    SAMPLE_SUBMISSION=str(SAMPLE_SUB_PATH),
    PROF_DIR=str(PROF_DIR),
    TRAIN_MANIFEST=str(train_out),
    TEST_MANIFEST=str(test_out),
)
(Path(PROF_DIR) / "paths.json").write_text(json.dumps(paths, indent=2))

# ----------------------------
# 6) Stage-0 paper figures (safe + lightweight)
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

# Fig0-2: class distribution
vc = df_train_all["y"].value_counts().sort_index()
plt.figure(figsize=(5,4))
plt.bar([str(k) for k in vc.index], vc.values)
plt.xlabel("Class (0=authentic, 1=forged)")
plt.ylabel("Count")
plt.title("Class distribution")
savefig(FIG_DIR / "Fig0-2_class_distribution.png")

# Fig0-3: image size scatter (sample to keep fast)
sample_n = min(800, len(df_train_all))
rows_s = df_train_all.sample(sample_n, random_state=SEED)
ws, hs = [], []
for p in rows_s["img_path"].tolist():
    im = Image.open(p)
    w, h = im.size
    ws.append(w); hs.append(h)

plt.figure(figsize=(6,4))
plt.scatter(ws, hs, s=8)
plt.xlabel("Width"); plt.ylabel("Height")
plt.title("Image size scatter (sample)")
savefig(FIG_DIR / "Fig0-3_image_size_scatter.png")

# Fig0-4: GT mask area histogram (forged with mask exists)
forged_with_mask = df_train_all[(df_train_all["y"]==1) & (df_train_all["mask_path"]!="")].copy()
areas = []
if len(forged_with_mask) > 0:
    rows_m = forged_with_mask.sample(min(1200, len(forged_with_mask)), random_state=SEED)
    for mp in rows_m["mask_path"].tolist():
        m = read_mask_npy(Path(mp))
        areas.append(float(m.mean()))
plt.figure(figsize=(6,4))
plt.hist(areas, bins=40)
plt.xlabel("Mask area (fraction of pixels)")
plt.ylabel("Count")
plt.title("GT mask area distribution (forged)")
savefig(FIG_DIR / "Fig0-4_gt_mask_area_hist.png")

# Fig0-1: sample grid (authentic + forged overlay)
def make_sample_grid(n_auth=6, n_forg=6):
    items, titles = [], []
    a = df_train_all[df_train_all["y"]==0].sample(min(n_auth, (df_train_all["y"]==0).sum()), random_state=SEED)
    f = df_train_all[(df_train_all["y"]==1) & (df_train_all["mask_path"]!="")].sample(
        min(n_forg, len(forged_with_mask)), random_state=SEED
    ) if len(forged_with_mask) else pd.DataFrame()

    # authentic
    for _, r in a.iterrows():
        img = np.array(Image.open(r["img_path"]).convert("RGB"))
        items.append(img); titles.append(f'{r["case_id"]} | authentic')

    # forged with overlay
    for _, r in f.iterrows():
        img = np.array(Image.open(r["img_path"]).convert("RGB"))
        m = read_mask_npy(Path(r["mask_path"]))
        ov = overlay_mask(img, m, alpha=0.45)
        items.append(ov); titles.append(f'{r["case_id"]} | forged (GT overlay)')

    # plot
    n = len(items)
    if n == 0:
        return
    ncols = 4
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=(14, 10))
    for i in range(n):
        plt.subplot(nrows, ncols, i+1)
        plt.imshow(items[i])
        plt.axis("off")
        plt.title(titles[i], fontsize=9)
    savefig(FIG_DIR / "Fig0-1_samples_grid.png")

make_sample_grid()

print("\n[OK] Saved manifests & figures:")
print(" -", train_out)
print(" -", test_out)
print(" -", PROF_DIR / "paths.json")
print(" - figures:", FIG_DIR)

# Keep globals for next stages
globals().update(dict(
    DATA_ROOT=DATA_ROOT,
    TRAIN_AUTH_DIR=TRAIN_AUTH_DIR,
    TRAIN_FORG_DIR=TRAIN_FORG_DIR,
    TRAIN_MASK_DIR=TRAIN_MASK_DIR,
    TEST_IMG_DIR=TEST_IMG_DIR,
    SUP_IMG_DIR=SUP_IMG_DIR if SUP_IMG_DIR.exists() else None,
    SUP_MASK_DIR=SUP_MASK_DIR if SUP_MASK_DIR.exists() else None,
    SAMPLE_SUB_PATH=SAMPLE_SUB_PATH,
    PROF_DIR=PROF_DIR,
    FIG_DIR=FIG_DIR,
    df_train_all=df_train_all,
    df_test=df_test,
))
# ============================================================

DATA_ROOT: /kaggle/input/recodai-luc-scientific-image-forgery-detection
[WARN] Duplicate uid detected: 2377 uids (show up to 10): ['10' '10015' '10017' '10030' '10070' '1008' '10138' '10139' '10147'
 '10152']
train_manifest: (2795, 9) | forged% = 100.0
test_manifest : (1, 4)

[OK] Saved manifests & figures:
 - /kaggle/working/recodai_luc_prof/train_manifest.parquet
 - /kaggle/working/recodai_luc_prof/test_manifest.parquet
 - /kaggle/working/recodai_luc_prof/paths.json
 - figures: /kaggle/working/recodai_luc_prof/figures/stage0


# Data Profiling & Sanity Checks

In [2]:
# ============================================================
# STAGE 1 — Data Profiling & Sanity Checks (ONE CELL)
# - Continues from STAGE 0 (uses df_train_all, df_test, PROF_DIR, paths.json)
# - Outputs:
#   * /kaggle/working/recodai_luc_prof/artifacts/profiles/stage1_profile.json
#   * /kaggle/working/recodai_luc_prof/artifacts/profiles/bad_cases.csv
#   * /kaggle/working/recodai_luc_prof/artifacts/profiles/mask_stats.parquet
#   * Paper figures: /kaggle/working/recodai_luc_prof/figures/stage1/*.png
# - Sanity checks:
#   * image readable, size, mode
#   * mask loadable, ndim (2D/3D), shape match vs image (incl transpose match)
#   * forged missing masks
#   * tiny masks & ambiguous masks sampling for article figures
# ============================================================

import os, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require Stage 0 outputs (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
if "df_train_all" not in globals() or "df_test" not in globals():
    paths_path = PROF_DIR / "paths.json"
    if not paths_path.exists():
        raise RuntimeError("Missing Stage 0 artifacts. Run STAGE 0 first.")
    paths = json.loads(paths_path.read_text())
    df_train_all = pd.read_parquet(paths["TRAIN_MANIFEST"])
    df_test = pd.read_parquet(paths["TEST_MANIFEST"])
else:
    paths_path = PROF_DIR / "paths.json"
    paths = json.loads(paths_path.read_text()) if paths_path.exists() else {}

ART_DIR = PROF_DIR / "artifacts" / "profiles"
ART_DIR.mkdir(parents=True, exist_ok=True)

FIG_DIR = PROF_DIR / "figures" / "stage1"
FIG_DIR.mkdir(parents=True, exist_ok=True)

print("Loaded:", "train", df_train_all.shape, "| test", df_test.shape)
print("PROF_DIR:", PROF_DIR)
print("FIG_DIR :", FIG_DIR)
print("ART_DIR :", ART_DIR)

# ----------------------------
# 1) Helpers
# ----------------------------
def safe_open_image(path):
    try:
        im = Image.open(path)
        im.load()
        return im, None
    except Exception as e:
        return None, str(e)

def safe_load_mask(path):
    try:
        m = np.load(path)
        return m, None
    except Exception as e:
        return None, str(e)

def binarize_mask(m):
    if m.ndim == 3:
        m2 = m.max(axis=0)
    else:
        m2 = m
    m2 = (m2 > 0).astype(np.uint8)
    return m2

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32) / 255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m, 0] = (1 - alpha) * out[m, 0] + alpha * 1.0
    out[m, 1] = (1 - alpha) * out[m, 1] + alpha * 0.0
    out[m, 2] = (1 - alpha) * out[m, 2] + alpha * 0.0
    return (out * 255).clip(0, 255).astype(np.uint8)

def savefig(path, dpi=300):
    import matplotlib.pyplot as plt
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def show_grid(items, titles=None, ncols=4, figsize=(14,10)):
    import matplotlib.pyplot as plt
    n = len(items)
    if n == 0:
        return
    ncols = min(ncols, n)
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=figsize)
    for i, x in enumerate(items):
        plt.subplot(nrows, ncols, i+1)
        if x.ndim == 2:
            plt.imshow(x, cmap="gray")
        else:
            plt.imshow(x)
        plt.axis("off")
        if titles is not None:
            plt.title(str(titles[i]), fontsize=9)
    plt.tight_layout()

# ----------------------------
# 2) Fast global checks
# ----------------------------
profile = {}
profile["n_train"] = int(len(df_train_all))
profile["n_test"] = int(len(df_test))
profile["train_forged_pct"] = float(df_train_all["y"].mean() * 100.0) if "y" in df_train_all else None
profile["n_supplemental"] = int(df_train_all.get("is_supplemental", pd.Series([0]*len(df_train_all))).sum())

# duplicates by uid / case_id
profile["dup_uid_count"] = int(df_train_all["uid"].duplicated().sum())
profile["dup_case_id_count"] = int(df_train_all["case_id"].duplicated().sum())

# forged missing masks
forged = df_train_all[df_train_all["y"] == 1].copy()
missing_mask = forged[forged["mask_path"].fillna("").astype(str) == ""]
profile["forged_missing_mask_count"] = int(len(missing_mask))
profile["forged_missing_mask_examples"] = missing_mask["uid"].head(10).tolist()

print("forged_missing_mask_count:", profile["forged_missing_mask_count"])

# ----------------------------
# 3) Sample-based integrity scan (images + masks)
# ----------------------------
SCAN_MAX = 2500  # keep it safe for Kaggle CPU
df_scan = df_train_all.sample(min(SCAN_MAX, len(df_train_all)), random_state=42).reset_index(drop=True)

bad_rows = []
mask_stats_rows = []

img_size_records = []

for i, r in df_scan.iterrows():
    uid = r["uid"]
    ip = Path(r["img_path"])
    mp = Path(r["mask_path"]) if str(r.get("mask_path","")) else None

    im, err_img = safe_open_image(ip)
    if err_img is not None:
        bad_rows.append(dict(uid=uid, case_id=r["case_id"], kind="image_open_fail", path=str(ip), error=err_img))
        continue

    w, h = im.size
    mode = im.mode
    img_size_records.append((w, h, mode, int(r["y"]), int(r.get("is_supplemental", 0))))

    # mask checks for forged with mask
    if int(r["y"]) == 1 and mp is not None and mp.exists():
        m_raw, err_m = safe_load_mask(mp)
        if err_m is not None:
            bad_rows.append(dict(uid=uid, case_id=r["case_id"], kind="mask_load_fail", path=str(mp), error=err_m))
            continue

        m_ndim = int(m_raw.ndim)
        m_shape = tuple(map(int, m_raw.shape))
        m_bin = binarize_mask(m_raw)
        mh, mw = m_bin.shape[:2]

        match = "match" if (mh == h and mw == w) else ("transpose_match" if (mh == w and mw == h) else "mismatch")
        area_px = int(m_bin.sum())
        area_frac = float(area_px / (mh * mw + 1e-9))

        # light intensity stats inside mask (for "ambiguous" examples)
        # (works even without original)
        gray = np.array(im.convert("L"), dtype=np.float32)
        if match == "transpose_match":
            mb = m_bin.T
        else:
            mb = m_bin

        if mb.shape != gray.shape:
            inside_mean = np.nan
            inside_std  = np.nan
        else:
            if mb.sum() == 0:
                inside_mean = 0.0
                inside_std  = 0.0
            else:
                vals = gray[mb.astype(bool)]
                inside_mean = float(vals.mean())
                inside_std  = float(vals.std())

        mask_stats_rows.append(dict(
            uid=uid,
            case_id=r["case_id"],
            is_supplemental=int(r.get("is_supplemental", 0)),
            img_w=int(w), img_h=int(h), img_mode=str(mode),
            mask_path=str(mp),
            mask_ndim=m_ndim,
            mask_shape=str(m_shape),
            grid_match=match,
            mask_w=int(mw), mask_h=int(mh),
            area_px=area_px,
            area_frac=area_frac,
            inside_mean=inside_mean,
            inside_std=inside_std,
        ))

        if match == "mismatch":
            bad_rows.append(dict(uid=uid, case_id=r["case_id"], kind="mask_image_shape_mismatch",
                                 path=str(mp), error=f"img(h,w)=({h},{w}) mask(h,w)=({mh},{mw}) raw_shape={m_shape}"))

# summarize image size
img_sizes = pd.DataFrame(img_size_records, columns=["w","h","mode","y","is_supplemental"])
profile["scan_n"] = int(len(df_scan))
profile["scan_image_fail_count"] = int(sum(1 for b in bad_rows if b["kind"]=="image_open_fail"))
profile["scan_mask_fail_count"] = int(sum(1 for b in bad_rows if b["kind"]=="mask_load_fail"))
profile["scan_shape_mismatch_count"] = int(sum(1 for b in bad_rows if b["kind"]=="mask_image_shape_mismatch"))

# save bad cases
bad_df = pd.DataFrame(bad_rows)
bad_path = ART_DIR / "bad_cases.csv"
bad_df.to_csv(bad_path, index=False)
print("Saved:", bad_path, "| rows:", len(bad_df))

# save mask stats
mask_stats_df = pd.DataFrame(mask_stats_rows)
mask_stats_path = ART_DIR / "mask_stats.parquet"
mask_stats_df.to_parquet(mask_stats_path, index=False)
print("Saved:", mask_stats_path, "| rows:", len(mask_stats_df))

# add more profile stats
if len(img_sizes) > 0:
    profile["img_mode_counts"] = img_sizes["mode"].value_counts().to_dict()
    profile["img_w_minmax"] = [int(img_sizes["w"].min()), int(img_sizes["w"].max())]
    profile["img_h_minmax"] = [int(img_sizes["h"].min()), int(img_sizes["h"].max())]

if len(mask_stats_df) > 0:
    profile["mask_ndim_counts"] = mask_stats_df["mask_ndim"].value_counts().to_dict()
    profile["mask_match_counts"] = mask_stats_df["grid_match"].value_counts().to_dict()
    profile["mask_area_frac_minmax"] = [float(mask_stats_df["area_frac"].min()), float(mask_stats_df["area_frac"].max())]

# save profile json
profile_path = ART_DIR / "stage1_profile.json"
profile_path.write_text(json.dumps(profile, indent=2))
print("Saved:", profile_path)

# ----------------------------
# 4) Paper figures (Stage 1)
# ----------------------------
import matplotlib.pyplot as plt

# Fig1-A: Image size histogram (w and h)
if len(img_sizes) > 0:
    plt.figure(figsize=(6,4))
    plt.hist(img_sizes["w"].values, bins=40, alpha=0.7, label="width")
    plt.hist(img_sizes["h"].values, bins=40, alpha=0.7, label="height")
    plt.xlabel("Pixels")
    plt.ylabel("Count")
    plt.title("Image size distribution (sample)")
    plt.legend()
    savefig(FIG_DIR / "Fig1-A_image_size_hist.png")

# Fig1-B: Mask area fraction histogram (log-friendly view)
if len(mask_stats_df) > 0:
    vals = mask_stats_df["area_frac"].values
    plt.figure(figsize=(6,4))
    plt.hist(vals, bins=50)
    plt.xlabel("Mask area (fraction of pixels)")
    plt.ylabel("Count")
    plt.title("GT mask area distribution (forged, sample)")
    savefig(FIG_DIR / "Fig1-B_mask_area_frac_hist.png")

# Fig1-1: Tiny masks examples (overlay) — good for paper
# define tiny = bottom 1% by area_px or area_frac
if len(mask_stats_df) > 0:
    ms = mask_stats_df.sort_values("area_px").reset_index(drop=True)
    take = min(12, len(ms))
    tiny = ms.head(take)

    items, titles = [], []
    for _, rr in tiny.iterrows():
        ip = df_train_all.loc[df_train_all["uid"]==rr["uid"], "img_path"].iloc[0]
        img = np.array(Image.open(ip).convert("RGB"))
        mp = Path(rr["mask_path"])
        m = binarize_mask(np.load(mp))
        # fix transpose match
        if rr["grid_match"] == "transpose_match":
            m = m.T
        ov = overlay_mask(img, m, alpha=0.45)
        items.append(ov)
        titles.append(f'{rr["case_id"]} | area_px={int(rr["area_px"])}')
    show_grid(items, titles=titles, ncols=4, figsize=(14,10))
    savefig(FIG_DIR / "Fig1-1_tiny_mask_examples.png")

# Fig1-2: Coverage examples (small/medium/large)
if len(mask_stats_df) > 0:
    ms = mask_stats_df.sort_values("area_frac").reset_index(drop=True)
    idxs = []
    for q in [0.05, 0.50, 0.95]:
        idxs.append(int(q * (len(ms)-1)))
    pick = ms.iloc[idxs].drop_duplicates("uid")

    items, titles = [], []
    for _, rr in pick.iterrows():
        ip = df_train_all.loc[df_train_all["uid"]==rr["uid"], "img_path"].iloc[0]
        img = np.array(Image.open(ip).convert("RGB"))
        mp = Path(rr["mask_path"])
        m = binarize_mask(np.load(mp))
        if rr["grid_match"] == "transpose_match":
            m = m.T
        ov = overlay_mask(img, m, alpha=0.45)
        items.append(ov)
        titles.append(f'{rr["case_id"]} | area_frac={rr["area_frac"]:.4f}')
    show_grid(items, titles=titles, ncols=3, figsize=(12,6))
    savefig(FIG_DIR / "Fig1-2_mask_coverage_examples.png")

print("\n[OK] Stage 1 complete.")
print("Profile  :", profile_path)
print("Bad cases:", bad_path)
print("Mask stats:", mask_stats_path)
print("Figures  :", FIG_DIR)

# Keep globals for next stages
globals().update(dict(
    STAGE1_PROFILE=profile,
    mask_stats_df=mask_stats_df,
    bad_cases_df=bad_df,
    STAGE1_ART_DIR=ART_DIR,
    STAGE1_FIG_DIR=FIG_DIR,
))
# ============================================================


Loaded: train (2795, 9) | test (1, 4)
PROF_DIR: /kaggle/working/recodai_luc_prof
FIG_DIR : /kaggle/working/recodai_luc_prof/figures/stage1
ART_DIR : /kaggle/working/recodai_luc_prof/artifacts/profiles
forged_missing_mask_count: 0
Saved: /kaggle/working/recodai_luc_prof/artifacts/profiles/bad_cases.csv | rows: 0
Saved: /kaggle/working/recodai_luc_prof/artifacts/profiles/mask_stats.parquet | rows: 2500
Saved: /kaggle/working/recodai_luc_prof/artifacts/profiles/stage1_profile.json

[OK] Stage 1 complete.
Profile  : /kaggle/working/recodai_luc_prof/artifacts/profiles/stage1_profile.json
Bad cases: /kaggle/working/recodai_luc_prof/artifacts/profiles/bad_cases.csv
Mask stats: /kaggle/working/recodai_luc_prof/artifacts/profiles/mask_stats.parquet
Figures  : /kaggle/working/recodai_luc_prof/figures/stage1


# DINOv2 Feature Extraction

In [None]:
# ============================================================
# STAGE 2 — DINOv2 Feature Extraction (Token-Grid Embedding) (ONE CELL)
# - Embedding backbone: DINOv2-Base from /kaggle/input/dinov2/pytorch/base/1
# - Produces per-image token-grid cache (.npz) for TRAIN + TEST
# - Keeps artifacts consistent for next stages (mask training & inference)
#
# Outputs:
#   /kaggle/working/recodai_luc/cache/dinov2_base_518_cfg_<hash>/
#       cfg.json
#       train/{uid}.npz   (tokens: float16 [Htok,Wtok,D])
#       test/{uid}.npz
#       tokens_manifest_train.parquet
#       tokens_manifest_test.parquet
#
# Paper figures:
#   /kaggle/working/recodai_luc_prof/figures/stage2/Fig2-2_patch_grid_overlay.png
#   /kaggle/working/recodai_luc_prof/figures/stage2/Fig2-3_token_norm_heatmap.png
# ============================================================

import os, gc, json, math, time, hashlib, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require STAGE 0 artifacts (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")

paths = json.loads(paths_path.read_text())
df_train_all = globals().get("df_train_all", None)
df_test = globals().get("df_test", None)
if df_train_all is None:
    df_train_all = pd.read_parquet(paths["TRAIN_MANIFEST"])
if df_test is None:
    df_test = pd.read_parquet(paths["TEST_MANIFEST"])

FIG_DIR = PROF_DIR / "figures" / "stage2"
FIG_DIR.mkdir(parents=True, exist_ok=True)

print("Loaded manifests:", df_train_all.shape, df_test.shape)
print("FIG_DIR:", FIG_DIR)

# ----------------------------
# 1) Locate DINOv2 local model directory (prefer user path)
# ----------------------------
PREFERRED_MODEL_DIR = Path("/kaggle/input/dinov2/pytorch/base/1")

def find_local_dinov2_dir():
    # fallback: search /kaggle/input for a HF-style directory containing pytorch_model.bin + config.json
    base = Path("/kaggle/input")
    cands = []
    for d in base.rglob("pytorch_model.bin"):
        parent = d.parent
        if (parent / "config.json").exists():
            # heuristic: name contains "dino" somewhere
            if "dino" in str(parent).lower():
                cands.append(parent)
    # prefer shortest path / most direct
    cands = sorted(cands, key=lambda p: (len(str(p)), str(p)))
    return cands[0] if cands else None

MODEL_DIR = PREFERRED_MODEL_DIR if PREFERRED_MODEL_DIR.exists() else find_local_dinov2_dir()
if MODEL_DIR is None or (not MODEL_DIR.exists()):
    raise RuntimeError(
        "Could not find local DINOv2 model directory.\n"
        "Expected: /kaggle/input/dinov2/pytorch/base/1\n"
        "Or a directory containing pytorch_model.bin + config.json under /kaggle/input."
    )

print("MODEL_DIR:", MODEL_DIR)

# ----------------------------
# 2) Load model + processor
# ----------------------------
import torch

try:
    from transformers import AutoImageProcessor, AutoModel
except Exception as e:
    raise RuntimeError(
        "transformers is not available in this environment. "
        "Please ensure 'transformers' is installed in Kaggle."
    ) from e

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

processor = AutoImageProcessor.from_pretrained(str(MODEL_DIR))
model = AutoModel.from_pretrained(str(MODEL_DIR))
model.eval().to(device)

# enforce fixed input size (important for stable token grid)
INPUT_SIZE = 518  # recommended for patch=14 => 37x37 grid
# Many processors support these attributes; set defensively.
for attr, val in [
    ("do_resize", True),
    ("do_center_crop", True),
]:
    if hasattr(processor, attr):
        setattr(processor, attr, val)

# set resize/crop size if supported
# processor.size may be int or dict; we enforce dict
if hasattr(processor, "size"):
    try:
        processor.size = {"height": INPUT_SIZE, "width": INPUT_SIZE}
    except Exception:
        pass
if hasattr(processor, "crop_size"):
    try:
        processor.crop_size = {"height": INPUT_SIZE, "width": INPUT_SIZE}
    except Exception:
        pass

# Patch size (best effort)
patch_size = None
if hasattr(model, "config") and hasattr(model.config, "patch_size"):
    patch_size = int(model.config.patch_size)
elif hasattr(model, "config") and hasattr(model.config, "vit_patch_size"):
    patch_size = int(model.config.vit_patch_size)
else:
    patch_size = 14  # dinov2-base default

# ----------------------------
# 3) Cache directory (hash config for reproducibility)
# ----------------------------
CFG = {
    "model_dir": str(MODEL_DIR),
    "input_size": int(INPUT_SIZE),
    "patch_size": int(patch_size),
    "dtype": "float16",
    "tokens": "patch_tokens_only (exclude CLS)",
    "processor_class": processor.__class__.__name__,
    "model_class": model.__class__.__name__,
}
cfg_id = hashlib.md5(json.dumps(CFG, sort_keys=True).encode()).hexdigest()[:12]

CACHE_ROOT = Path("/kaggle/working/recodai_luc/cache") / f"dinov2_base_518_cfg_{cfg_id}"
TRAIN_OUT = CACHE_ROOT / "train"
TEST_OUT  = CACHE_ROOT / "test"
TRAIN_OUT.mkdir(parents=True, exist_ok=True)
TEST_OUT.mkdir(parents=True, exist_ok=True)
(CACHE_ROOT / "cfg.json").write_text(json.dumps(CFG, indent=2))

print("CACHE_ROOT:", CACHE_ROOT)

# ----------------------------
# 4) Dataset + DataLoader
# ----------------------------
from torch.utils.data import Dataset, DataLoader

class ImgDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    def __len__(self):
        return len(self.df)
    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        uid = str(r["uid"])
        ip = str(r["img_path"])
        return uid, r.get("case_id", uid), ip

def collate_fn(batch):
    uids, case_ids, img_paths = zip(*batch)
    images = [Image.open(p).convert("RGB") for p in img_paths]
    inputs = processor(images=images, return_tensors="pt")
    return list(uids), list(case_ids), list(img_paths), inputs

BATCH_SIZE = 8 if device.type == "cuda" else 4
NUM_WORKERS = 2  # safe default on Kaggle
PIN_MEMORY = (device.type == "cuda")

# ----------------------------
# 5) Extract + Save tokens (resume-safe)
# ----------------------------
@torch.no_grad()
def extract_split(df, out_dir: Path, split_name: str):
    ds = ImgDataset(df)
    dl = DataLoader(
        ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn,
        drop_last=False,
    )

    rows = []
    n_total = len(df)
    n_done = 0
    t0 = time.time()

    for uids, case_ids, img_paths, inputs in dl:
        # skip existing
        to_run = []
        run_idx = []
        for i, uid in enumerate(uids):
            op = out_dir / f"{uid}.npz"
            if op.exists():
                # still record manifest row
                rows.append(dict(
                    uid=uid, case_id=str(case_ids[i]), split=split_name,
                    img_path=str(img_paths[i]), tokens_path=str(op),
                ))
            else:
                to_run.append(i)
                run_idx.append(i)

        if len(to_run) > 0:
            # subset inputs
            pix = inputs["pixel_values"][to_run].to(device, non_blocking=True)

            out = model(pixel_values=pix)
            # transformers ViT-style: last_hidden_state [B, 1+N, D]
            hs = out.last_hidden_state
            patch = hs[:, 1:, :]  # exclude CLS
            B, N, D = patch.shape
            g = int(round(math.sqrt(N)))
            if g * g != N:
                # fallback: derive from input size / patch size
                g2 = INPUT_SIZE // patch_size
                if g2 * g2 == N:
                    g = g2
                else:
                    raise RuntimeError(
                        f"Token count N={N} is not a perfect square and cannot be resolved. "
                        f"Check preprocessing/INPUT_SIZE. (g~{g}, g2={g2}, patch={patch_size})"
                    )
            Htok = Wtok = g

            tok = patch.reshape(B, Htok, Wtok, D).detach().float().cpu().numpy().astype(np.float16)

            for j, bi in enumerate(to_run):
                uid = uids[bi]
                op = out_dir / f"{uid}.npz"
                np.savez_compressed(op, tokens=tok[j], Htok=Htok, Wtok=Wtok, D=D)
                rows.append(dict(
                    uid=str(uid), case_id=str(case_ids[bi]), split=split_name,
                    img_path=str(img_paths[bi]), tokens_path=str(op),
                ))

        n_done += len(uids)
        if n_done % max(64, BATCH_SIZE * 8) == 0:
            dt = time.time() - t0
            print(f"[{split_name}] {n_done}/{n_total} done | {dt/60:.1f} min")
        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    dfm = pd.DataFrame(rows).drop_duplicates(subset=["uid"], keep="last").reset_index(drop=True)
    # attach token meta from one sample
    if len(dfm) > 0:
        # read one file
        s = np.load(dfm["tokens_path"].iloc[0])
        dfm["tok_h"] = int(s["Htok"])
        dfm["tok_w"] = int(s["Wtok"])
        dfm["tok_d"] = int(s["D"])
    return dfm

print("Extracting TRAIN tokens...")
man_train = extract_split(df_train_all, TRAIN_OUT, "train")

print("Extracting TEST tokens...")
man_test = extract_split(df_test, TEST_OUT, "test")

# Save manifests
man_train_path = CACHE_ROOT / "tokens_manifest_train.parquet"
man_test_path  = CACHE_ROOT / "tokens_manifest_test.parquet"
man_train.to_parquet(man_train_path, index=False)
man_test.to_parquet(man_test_path, index=False)

print("\n[OK] Saved token manifests:")
print(" -", man_train_path, "| rows:", len(man_train))
print(" -", man_test_path,  "| rows:", len(man_test))

# ----------------------------
# 6) Stage 2 paper figures (grid overlay + token norm heatmap)
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

# pick one sample (prefer forged with mask, else any)
pick_uid = None
if "y" in df_train_all.columns and (df_train_all["y"]==1).any():
    pick_uid = df_train_all[df_train_all["y"]==1]["uid"].iloc[0]
else:
    pick_uid = df_train_all["uid"].iloc[0]

pick_row = df_train_all[df_train_all["uid"]==pick_uid].iloc[0]
img = Image.open(pick_row["img_path"]).convert("RGB")

# Create a resized copy to INPUT_SIZE for visualization
img_resized = img.resize((INPUT_SIZE, INPUT_SIZE), Image.BICUBIC)
img_np = np.array(img_resized)

# Determine grid size from cached token
tok_npz = np.load(TRAIN_OUT / f"{pick_uid}.npz")
Htok = int(tok_npz["Htok"])
Wtok = int(tok_npz["Wtok"])
tokens = tok_npz["tokens"].astype(np.float32)  # [H,W,D]

# Fig2-2: patch grid overlay
plt.figure(figsize=(7,7))
plt.imshow(img_np)
for i in range(1, Wtok):
    x = i * (INPUT_SIZE / Wtok)
    plt.plot([x, x], [0, INPUT_SIZE], linewidth=0.7)
for j in range(1, Htok):
    y = j * (INPUT_SIZE / Htok)
    plt.plot([0, INPUT_SIZE], [y, y], linewidth=0.7)
plt.axis("off")
plt.title(f"Patch-token grid overlay ({Htok}×{Wtok}) | uid={pick_uid}")
savefig(FIG_DIR / "Fig2-2_patch_grid_overlay.png")

# Fig2-3: token norm heatmap
norm_map = np.linalg.norm(tokens, axis=-1)  # [H,W]
plt.figure(figsize=(7,6))
plt.imshow(norm_map, cmap="magma")
plt.colorbar(fraction=0.046, pad=0.04)
plt.title(f"Token L2-norm heatmap | uid={pick_uid}")
plt.axis("off")
savefig(FIG_DIR / "Fig2-3_token_norm_heatmap.png")

print("\n[OK] Stage 2 figures saved to:", FIG_DIR)

# ----------------------------
# 7) Keep globals for next stages
# ----------------------------
globals().update(dict(
    DINO_MODEL_DIR=MODEL_DIR,
    DINO_PROCESSOR=processor,
    DINO_MODEL=model,
    DINO_CFG=CFG,
    CFG_ID=cfg_id,
    CACHE_ROOT=CACHE_ROOT,
    TOKENS_MANIFEST_TRAIN=man_train_path,
    TOKENS_MANIFEST_TEST=man_test_path,
    man_train=man_train,
    man_test=man_test,
    TOK_H=int(man_train["tok_h"].iloc[0]) if len(man_train) else None,
    TOK_W=int(man_train["tok_w"].iloc[0]) if len(man_train) else None,
    TOK_D=int(man_train["tok_d"].iloc[0]) if len(man_train) else None,
))
# ============================================================

Loaded manifests: (2795, 9) (1, 4)
FIG_DIR: /kaggle/working/recodai_luc_prof/figures/stage2
MODEL_DIR: /kaggle/input/dinov2/pytorch/base/1


2026-01-15 13:06:10.383285: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1768482370.659493      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1768482370.741634      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1768482371.350164      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768482371.350253      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1768482371.350257      55 computation_placer.cc:177] computation placer alr

Device: cpu
CACHE_ROOT: /kaggle/working/recodai_luc/cache/dinov2_base_518_cfg_3be27f67a798
Extracting TRAIN tokens...
[train] 64/2795 done | 3.7 min
[train] 128/2795 done | 7.4 min
[train] 192/2795 done | 11.1 min
[train] 256/2795 done | 14.8 min
[train] 320/2795 done | 18.5 min


# Ground Truth Preparation on Token Grid

In [None]:
# ============================================================
# STAGE 3 — Ground Truth Preparation on Token Grid (ONE CELL)
# - Continues from STAGE 0 + STAGE 2 (uses train_manifest + tokens cache)
# - Creates token-grid GT masks (Htok x Wtok) aligned to each image/mask
#
# Outputs:
#   /kaggle/working/recodai_luc/cache/dinov2_base_518_cfg_<hash>/grid_masks_train/{uid}.npy
#   /kaggle/working/recodai_luc/cache/dinov2_base_518_cfg_<hash>/gridmask_manifest_train.parquet
#   /kaggle/working/recodai_luc_prof/train_manifest_with_gridmask.parquet
#   /kaggle/working/recodai_luc_prof/artifacts/profiles/stage3_gridmask_profile.json
#
# Paper figures:
#   /kaggle/working/recodai_luc_prof/figures/stage3/Fig3-1_full_vs_grid_examples.png
#   /kaggle/working/recodai_luc_prof/figures/stage3/Fig3-2_grid_area_hist.png
# ============================================================

import os, json, math, warnings, hashlib
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require STAGE 0 artifacts (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

df_train_all = globals().get("df_train_all", None)
if df_train_all is None:
    df_train_all = pd.read_parquet(paths["TRAIN_MANIFEST"])

FIG_DIR = PROF_DIR / "figures" / "stage3"
FIG_DIR.mkdir(parents=True, exist_ok=True)

ART_PROF_DIR = PROF_DIR / "artifacts" / "profiles"
ART_PROF_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) Locate token cache root from STAGE 2 (robust auto-pick)
# ----------------------------
def pick_token_cache_root():
    # Prefer globals from Stage2 if present
    if "CACHE_ROOT" in globals():
        cr = Path(globals()["CACHE_ROOT"])
        if (cr / "tokens_manifest_train.parquet").exists():
            return cr
    # Else search /kaggle/working/recodai_luc/cache
    base = Path("/kaggle/working/recodai_luc/cache")
    if not base.exists():
        raise RuntimeError("Token cache base not found. Run STAGE 2 first.")
    cands = []
    for d in base.glob("dinov2_base_518_cfg_*"):
        if (d / "tokens_manifest_train.parquet").exists():
            # score by number of token files in train/
            nfiles = len(list((d / "train").glob("*.npz")))
            cands.append((nfiles, d))
    if not cands:
        raise RuntimeError("No valid token cache found. Run STAGE 2 first.")
    cands.sort(reverse=True, key=lambda x: x[0])
    return cands[0][1]

CACHE_ROOT = pick_token_cache_root()
TOK_MAN_TRAIN = CACHE_ROOT / "tokens_manifest_train.parquet"
man_train = pd.read_parquet(TOK_MAN_TRAIN)

print("CACHE_ROOT:", CACHE_ROOT)
print("Token manifest train:", man_train.shape)

# Read token grid shape from one sample
sample_npz = np.load(man_train["tokens_path"].iloc[0])
Htok = int(sample_npz["Htok"])
Wtok = int(sample_npz["Wtok"])
Dtok = int(sample_npz["D"])
print(f"Token grid: Htok={Htok}, Wtok={Wtok}, D={Dtok}")

# ----------------------------
# 2) Grid-mask output dir (resume-safe)
# ----------------------------
GRID_DIR = CACHE_ROOT / "grid_masks_train"
GRID_DIR.mkdir(parents=True, exist_ok=True)

GRID_CFG = dict(
    cache_root=str(CACHE_ROOT),
    grid_dir=str(GRID_DIR),
    Htok=Htok, Wtok=Wtok,
    method="full_mask -> align (transpose if needed) -> resize to (Wtok,Htok) with NEAREST -> binarize",
)
(CACHE_ROOT / "gridmask_cfg.json").write_text(json.dumps(GRID_CFG, indent=2))

# ----------------------------
# 3) Helpers
# ----------------------------
def load_mask_as_bin(mask_path: Path):
    m = np.load(mask_path)
    if m.ndim == 3:
        m = m.max(axis=0)
    return (m > 0).astype(np.uint8)

def align_mask_to_image(mask01: np.ndarray, img_w: int, img_h: int):
    """
    Returns aligned mask (H,W) to match image orientation as best as possible.
    """
    mh, mw = mask01.shape[:2]
    if (mh == img_h) and (mw == img_w):
        return mask01, "match"
    if (mh == img_w) and (mw == img_h):
        return mask01.T, "transpose_match"
    # fallback: resize mask to image size (nearest)
    m_img = Image.fromarray((mask01 * 255).astype(np.uint8), mode="L")
    m_res = m_img.resize((img_w, img_h), resample=Image.NEAREST)
    return (np.array(m_res) > 0).astype(np.uint8), "resized_to_match"

def downsample_to_grid(mask01_hw: np.ndarray, Wtok: int, Htok: int):
    m_img = Image.fromarray((mask01_hw * 255).astype(np.uint8), mode="L")
    m_small = m_img.resize((Wtok, Htok), resample=Image.NEAREST)
    return (np.array(m_small) > 0).astype(np.uint8)

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32) / 255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m, 0] = (1 - alpha) * out[m, 0] + alpha * 1.0
    out[m, 1] = (1 - alpha) * out[m, 1] + alpha * 0.0
    out[m, 2] = (1 - alpha) * out[m, 2] + alpha * 0.0
    return (out * 255).clip(0, 255).astype(np.uint8)

# ----------------------------
# 4) Build grid masks for all train rows (resume-safe)
# ----------------------------
df = df_train_all.copy().reset_index(drop=True)
df["mask_path"] = df.get("mask_path", "").fillna("").astype(str)

rows = []
stats = {
    "n_train": int(len(df)),
    "n_forged": int((df["y"] == 1).sum()),
    "n_authentic": int((df["y"] == 0).sum()),
    "n_forged_missing_mask": int(((df["y"]==1) & (df["mask_path"]=="")).sum()),
    "align_counts": {"match":0, "transpose_match":0, "resized_to_match":0, "no_mask":0},
    "written": 0,
    "skipped_existing": 0,
}

t0 = time.time() if "time" in globals() else None

for i, r in df.iterrows():
    uid = str(r["uid"])
    case_id = str(r["case_id"])
    y = int(r["y"])
    ip = Path(r["img_path"])
    mp = Path(r["mask_path"]) if (y == 1 and str(r["mask_path"])) else None

    out_p = GRID_DIR / f"{uid}.npy"
    if out_p.exists():
        stats["skipped_existing"] += 1
        # still record
        m_grid = np.load(out_p)
        rows.append(dict(
            uid=uid, case_id=case_id, y=y,
            gridmask_path=str(out_p),
            grid_area_px=int(m_grid.sum()),
            grid_area_frac=float(m_grid.mean()),
            align="cached",
        ))
        continue

    # open image size (needed for alignment)
    try:
        im = Image.open(ip)
        w, h = im.size
    except Exception:
        # if image unreadable, fallback to zero mask
        m_grid = np.zeros((Htok, Wtok), dtype=np.uint8)
        np.save(out_p, m_grid)
        stats["written"] += 1
        rows.append(dict(
            uid=uid, case_id=case_id, y=y,
            gridmask_path=str(out_p),
            grid_area_px=0, grid_area_frac=0.0,
            align="image_open_fail->zero",
        ))
        continue

    if y == 0 or mp is None or (not mp.exists()):
        # authentic or missing mask => all-zero
        m_grid = np.zeros((Htok, Wtok), dtype=np.uint8)
        np.save(out_p, m_grid)
        stats["align_counts"]["no_mask"] += 1
        stats["written"] += 1
        rows.append(dict(
            uid=uid, case_id=case_id, y=y,
            gridmask_path=str(out_p),
            grid_area_px=0, grid_area_frac=0.0,
            align="no_mask",
        ))
        continue

    # forged with mask
    try:
        m = load_mask_as_bin(mp)
    except Exception:
        m_grid = np.zeros((Htok, Wtok), dtype=np.uint8)
        np.save(out_p, m_grid)
        stats["written"] += 1
        rows.append(dict(
            uid=uid, case_id=case_id, y=y,
            gridmask_path=str(out_p),
            grid_area_px=0, grid_area_frac=0.0,
            align="mask_load_fail->zero",
        ))
        continue

    m_aligned, how = align_mask_to_image(m, w, h)
    stats["align_counts"][how] = stats["align_counts"].get(how, 0) + 1

    m_grid = downsample_to_grid(m_aligned, Wtok=Wtok, Htok=Htok)

    np.save(out_p, m_grid.astype(np.uint8))
    stats["written"] += 1

    rows.append(dict(
        uid=uid, case_id=case_id, y=y,
        gridmask_path=str(out_p),
        grid_area_px=int(m_grid.sum()),
        grid_area_frac=float(m_grid.mean()),
        align=how,
    ))

    if (i+1) % 500 == 0:
        print(f"[gridmask] {i+1}/{len(df)}")

gridman = pd.DataFrame(rows)
gridman_path = CACHE_ROOT / "gridmask_manifest_train.parquet"
gridman.to_parquet(gridman_path, index=False)

# Merge into train manifest for downstream stages
df_train_with_grid = df.merge(gridman[["uid","gridmask_path","grid_area_px","grid_area_frac","align"]],
                              on="uid", how="left")

train_with_grid_path = PROF_DIR / "train_manifest_with_gridmask.parquet"
df_train_with_grid.to_parquet(train_with_grid_path, index=False)

# Save profile
stats["grid_area_frac_minmax"] = [float(gridman["grid_area_frac"].min()), float(gridman["grid_area_frac"].max())]
stats["grid_area_px_minmax"] = [int(gridman["grid_area_px"].min()), int(gridman["grid_area_px"].max())]
profile_path = ART_PROF_DIR / "stage3_gridmask_profile.json"
profile_path.write_text(json.dumps(stats, indent=2))

print("\n[OK] Saved:")
print(" -", gridman_path)
print(" -", train_with_grid_path)
print(" -", profile_path)

# ----------------------------
# 5) Paper figures (Stage 3)
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def show_grid(items, titles=None, ncols=3, figsize=(14,10)):
    n = len(items)
    if n == 0:
        return
    ncols = min(ncols, n)
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=figsize)
    for i, x in enumerate(items):
        plt.subplot(nrows, ncols, i+1)
        if x.ndim == 2:
            plt.imshow(x, cmap="gray")
        else:
            plt.imshow(x)
        plt.axis("off")
        if titles is not None:
            plt.title(str(titles[i]), fontsize=9)
    plt.tight_layout()

# Fig3-2: grid area histogram (for forged)
g_f = gridman[gridman["y"]==1].copy()
plt.figure(figsize=(6,4))
plt.hist(g_f["grid_area_frac"].values, bins=40)
plt.xlabel("Grid-mask area (fraction of cells)")
plt.ylabel("Count")
plt.title("Grid GT mask area distribution (forged)")
savefig(FIG_DIR / "Fig3-2_grid_area_hist.png")

# Fig3-1: full-res vs grid-res qualitative examples (pick tiny/medium/large by FULL area proxy)
# We use grid_area_frac for selection (fast & consistent)
cand = df_train_with_grid[(df_train_with_grid["y"]==1) & (df_train_with_grid["mask_path"]!="")].copy()
if len(cand) > 0:
    cand = cand.sort_values("grid_area_frac").reset_index(drop=True)
    picks = []
    for q in [0.02, 0.08, 0.25, 0.50, 0.75, 0.95]:
        picks.append(int(q * (len(cand)-1)))
    picks = sorted(set(picks))
    sel = cand.iloc[picks].copy()

    items, titles = [], []
    for _, r in sel.iterrows():
        uid = str(r["uid"])
        ip = Path(r["img_path"])
        mp = Path(r["mask_path"])
        img = np.array(Image.open(ip).convert("RGB"))
        w, h = Image.open(ip).size

        # full-res mask aligned
        m = load_mask_as_bin(mp)
        m_aligned, how = align_mask_to_image(m, w, h)

        # grid mask + upsample back to image size for visualization
        mg = np.load(Path(r["gridmask_path"]))
        mg_img = Image.fromarray((mg*255).astype(np.uint8), mode="L").resize((w, h), resample=Image.NEAREST)
        mg_up = (np.array(mg_img) > 0).astype(np.uint8)

        items.append(img)
        titles.append(f"{uid} | image")

        items.append(overlay_mask(img, m_aligned, alpha=0.45))
        titles.append(f"{uid} | full GT ({how})")

        items.append(overlay_mask(img, mg_up, alpha=0.45))
        titles.append(f"{uid} | grid GT upsampled")

    show_grid(items, titles=titles, ncols=3, figsize=(14, 18))
    savefig(FIG_DIR / "Fig3-1_full_vs_grid_examples.png")

print("[OK] Stage 3 figures saved to:", FIG_DIR)

# ----------------------------
# 6) Keep globals for next stages
# ----------------------------
globals().update(dict(
    CACHE_ROOT=CACHE_ROOT,
    GRID_DIR=GRID_DIR,
    GRIDMAN_TRAIN=gridman_path,
    TRAIN_MANIFEST_WITH_GRID=train_with_grid_path,
    Htok=Htok, Wtok=Wtok, Dtok=Dtok,
    df_train_with_grid=df_train_with_grid,
))
# ============================================================


# Leakage-Safe Cross-Validation Split

In [None]:
# ============================================================
# STAGE 4 — Leakage-Safe Cross-Validation Split (ONE CELL)
# - GroupKFold by case_id (leakage-safe)
# - Writes fold assignment back into train manifest (with gridmask paths)
#
# Outputs:
#   /kaggle/working/recodai_luc_prof/folds.parquet
#   /kaggle/working/recodai_luc_prof/train_manifest_with_gridmask_folds.parquet
#   /kaggle/working/recodai_luc_prof/artifacts/folds/fold_summary.json
#
# Paper figure:
#   /kaggle/working/recodai_luc_prof/figures/stage4/Fig4-1_fold_distribution.png
# ============================================================

import os, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd

warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require STAGE 3 outputs (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

train_with_grid_path = Path(globals().get("TRAIN_MANIFEST_WITH_GRID", PROF_DIR / "train_manifest_with_gridmask.parquet"))
if not train_with_grid_path.exists():
    raise RuntimeError("Missing train_manifest_with_gridmask.parquet. Run STAGE 3 first.")
df_train = pd.read_parquet(train_with_grid_path)

FIG_DIR = PROF_DIR / "figures" / "stage4"
FIG_DIR.mkdir(parents=True, exist_ok=True)

ART_DIR = PROF_DIR / "artifacts" / "folds"
ART_DIR.mkdir(parents=True, exist_ok=True)

print("Loaded:", df_train.shape, "from", train_with_grid_path)

# ----------------------------
# 1) Ensure group key exists
# ----------------------------
if "case_id" not in df_train.columns:
    raise RuntimeError("case_id missing in train manifest. Ensure STAGE 0 created case_id.")

df_train["case_id"] = df_train["case_id"].astype(str)
df_train["uid"] = df_train["uid"].astype(str)

# If you want a stronger grouping rule, you can transform case_id here.
# Current dataset shown uses numeric filenames => case_id == uid => still safe.

# ----------------------------
# 2) Create GroupKFold split
# ----------------------------
N_FOLDS = 5  # adjust if needed

# If there are fewer groups than folds, reduce folds
n_groups = df_train["case_id"].nunique()
if n_groups < N_FOLDS:
    N_FOLDS = max(2, n_groups)
    print(f"[WARN] Reduced N_FOLDS to {N_FOLDS} because n_groups={n_groups}")

try:
    from sklearn.model_selection import GroupKFold
except Exception as e:
    raise RuntimeError("scikit-learn not available. Please add it to your Kaggle environment.") from e

gkf = GroupKFold(n_splits=N_FOLDS)

fold = np.full(len(df_train), -1, dtype=int)
X_dummy = np.zeros((len(df_train), 1), dtype=np.float32)
y_dummy = df_train["y"].values if "y" in df_train.columns else np.zeros(len(df_train))

for k, (_, va_idx) in enumerate(gkf.split(X_dummy, y_dummy, groups=df_train["case_id"].values)):
    fold[va_idx] = k

df_train["fold"] = fold
if (df_train["fold"] < 0).any():
    raise RuntimeError("Internal error: some rows did not receive a fold assignment.")

# ----------------------------
# 3) Fold summary + leakage checks
# ----------------------------
summary = []
for k in range(N_FOLDS):
    d = df_train[df_train["fold"] == k]
    summary.append(dict(
        fold=int(k),
        n=int(len(d)),
        forged_pct=float(d["y"].mean() * 100.0) if "y" in d.columns and len(d) else 0.0,
        n_groups=int(d["case_id"].nunique()),
    ))

# leakage check: no group appears in more than one fold
grp_to_fold = df_train.groupby("case_id")["fold"].nunique()
leak_groups = grp_to_fold[grp_to_fold > 1]
if len(leak_groups) > 0:
    # should never happen with GroupKFold
    raise RuntimeError(f"Leakage detected: {len(leak_groups)} groups appear in multiple folds!")

fold_summary = dict(
    n_folds=int(N_FOLDS),
    n_train=int(len(df_train)),
    n_groups=int(n_groups),
    folds=summary,
    forged_pct_overall=float(df_train["y"].mean() * 100.0) if "y" in df_train.columns else None,
)

(ART_DIR / "fold_summary.json").write_text(json.dumps(fold_summary, indent=2))

# Save folds file
df_folds = df_train[["uid", "case_id", "fold"]].copy()
folds_path = PROF_DIR / "folds.parquet"
df_folds.to_parquet(folds_path, index=False)

# Save updated train manifest
train_with_folds_path = PROF_DIR / "train_manifest_with_gridmask_folds.parquet"
df_train.to_parquet(train_with_folds_path, index=False)

print("\n[OK] Saved:")
print(" -", folds_path)
print(" -", train_with_folds_path)
print(" -", ART_DIR / "fold_summary.json")

# ----------------------------
# 4) Paper figure: fold distribution
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

fold_df = pd.DataFrame(summary)

plt.figure(figsize=(7,4))
plt.bar(fold_df["fold"].astype(str), fold_df["n"].values)
plt.xlabel("Fold")
plt.ylabel("Samples")
plt.title("Fold distribution (samples per fold)")
savefig(FIG_DIR / "Fig4-1_fold_counts.png")

plt.figure(figsize=(7,4))
plt.plot(fold_df["fold"].values, fold_df["forged_pct"].values, marker="o")
plt.xlabel("Fold")
plt.ylabel("Forged %")
plt.title("Fold distribution (forged ratio per fold)")
savefig(FIG_DIR / "Fig4-1_forged_ratio.png")

print("[OK] Figures saved to:", FIG_DIR)

# Keep globals for next stages
globals().update(dict(
    N_FOLDS=N_FOLDS,
    df_train_folds=df_train,
    FOLDS_PATH=folds_path,
    TRAIN_MANIFEST_WITH_FOLDS=train_with_folds_path,
    FOLD_SUMMARY=fold_summary,
))
# ============================================================


# Train Segmentation Decoder

In [None]:
# ============================================================
# STAGE 5 — Train Segmentation Decoder (Model A) (ONE CELL)
# - Uses cached DINOv2 token-grid embeddings from STAGE 2
# - Uses token-grid GT masks from STAGE 3
# - Uses leakage-safe folds from STAGE 4 (GroupKFold by case_id)
#
# Model A:
#   tokens [Htok,Wtok,D] -> (permute to [D,H,W]) -> small CNN decoder -> logits [Htok,Wtok]
# Loss:
#   BCEWithLogits + Dice (imbalance-friendly)
#
# Outputs:
#   /kaggle/working/recodai_luc_prof/artifacts/mask_model/
#       cfg.json
#       fold_{k}.pt
#       history_fold_{k}.csv
#       valid_preds_fold_{k}.npz   (uids + prob grid)
#
# Paper figures:
#   /kaggle/working/recodai_luc_prof/figures/stage5/
#       Fig5-1_curves_fold_{k}.png
#       Fig5-2_qualitative_fold_{k}.png
# ============================================================

import os, gc, json, math, time, warnings, hashlib
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require prior stages (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

train_with_folds_path = Path(globals().get("TRAIN_MANIFEST_WITH_FOLDS", PROF_DIR / "train_manifest_with_gridmask_folds.parquet"))
if not train_with_folds_path.exists():
    raise RuntimeError("Missing train_manifest_with_gridmask_folds.parquet. Run STAGE 4 first.")
df_train = pd.read_parquet(train_with_folds_path)

# pick token cache root (same helper logic as Stage 3)
def pick_token_cache_root():
    if "CACHE_ROOT" in globals():
        cr = Path(globals()["CACHE_ROOT"])
        if (cr / "tokens_manifest_train.parquet").exists():
            return cr
    base = Path("/kaggle/working/recodai_luc/cache")
    cands = []
    for d in base.glob("dinov2_base_518_cfg_*"):
        if (d / "tokens_manifest_train.parquet").exists():
            nfiles = len(list((d / "train").glob("*.npz")))
            cands.append((nfiles, d))
    if not cands:
        raise RuntimeError("No valid token cache found. Run STAGE 2 first.")
    cands.sort(reverse=True, key=lambda x: x[0])
    return cands[0][1]

CACHE_ROOT = pick_token_cache_root()
man_train = pd.read_parquet(CACHE_ROOT / "tokens_manifest_train.parquet")

# token grid meta
s = np.load(man_train["tokens_path"].iloc[0])
Htok = int(s["Htok"]); Wtok = int(s["Wtok"]); Dtok = int(s["D"])
print(f"CACHE_ROOT: {CACHE_ROOT}")
print(f"Token grid: {Htok}x{Wtok}x{Dtok}")
print("Train rows:", df_train.shape, " | folds:", sorted(df_train["fold"].unique().tolist()))

# output dirs
ART_DIR = PROF_DIR / "artifacts" / "mask_model"
ART_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR = PROF_DIR / "figures" / "stage5"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# ----------------------------
# 1) Join paths (tokens + gridmask) robustly
# ----------------------------
df = df_train.copy()
df["uid"] = df["uid"].astype(str)

# tokens path from man_train
man_train2 = man_train[["uid","tokens_path"]].copy()
man_train2["uid"] = man_train2["uid"].astype(str)
df = df.merge(man_train2, on="uid", how="left")

# gridmask path should already exist in df from Stage 3
if "gridmask_path" not in df.columns:
    raise RuntimeError("gridmask_path missing. Run STAGE 3 first.")

# filter missing paths (fail-fast but keep report)
miss_tok = df["tokens_path"].isna() | (df["tokens_path"].astype(str)=="")
miss_gm  = df["gridmask_path"].isna() | (df["gridmask_path"].astype(str)=="")
if miss_tok.any():
    bad = df.loc[miss_tok, "uid"].head(10).tolist()
    raise RuntimeError(f"Missing tokens_path for {miss_tok.sum()} rows. Examples: {bad}")
if miss_gm.any():
    bad = df.loc[miss_gm, "uid"].head(10).tolist()
    raise RuntimeError(f"Missing gridmask_path for {miss_gm.sum()} rows. Examples: {bad}")

# Ensure files exist
def exists_all(series):
    return series.map(lambda p: Path(str(p)).exists()).values

tok_exist = exists_all(df["tokens_path"])
gm_exist  = exists_all(df["gridmask_path"])
if (~tok_exist).any():
    bad = df.loc[~tok_exist, ["uid","tokens_path"]].head(10).values.tolist()
    raise RuntimeError(f"Some token files not found. Examples: {bad}")
if (~gm_exist).any():
    bad = df.loc[~gm_exist, ["uid","gridmask_path"]].head(10).values.tolist()
    raise RuntimeError(f"Some gridmask files not found. Examples: {bad}")

# ----------------------------
# 2) Torch setup
# ----------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = (device.type == "cuda")
print("Device:", device, "| AMP:", USE_AMP)

# ----------------------------
# 3) Dataset
# ----------------------------
class TokenGridMaskDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        uid = str(r["uid"])
        tok_path = str(r["tokens_path"])
        gm_path  = str(r["gridmask_path"])

        npz = np.load(tok_path)
        tok = npz["tokens"].astype(np.float32)  # [H,W,D]
        gm  = np.load(gm_path).astype(np.float32)  # [H,W] 0/1

        # safety: enforce shapes
        if tok.shape[0] != Htok or tok.shape[1] != Wtok or tok.shape[2] != Dtok:
            raise RuntimeError(f"Token shape mismatch uid={uid}: {tok.shape} expected {(Htok,Wtok,Dtok)}")
        if gm.shape[0] != Htok or gm.shape[1] != Wtok:
            raise RuntimeError(f"Gridmask shape mismatch uid={uid}: {gm.shape} expected {(Htok,Wtok)}")

        # torch: tokens -> [D,H,W], mask -> [1,H,W]
        x = torch.from_numpy(tok).permute(2,0,1).contiguous()
        y = torch.from_numpy(gm)[None, ...].contiguous()
        return uid, x, y

# ----------------------------
# 4) Model A — small CNN decoder
# ----------------------------
class SmallSegDecoder(nn.Module):
    def __init__(self, in_ch, mid1=256, mid2=128, mid3=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, mid1, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid1, mid2, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid2, mid3, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid3, 1, 1, padding=0, bias=True),
        )

    def forward(self, x):
        return self.net(x)  # [B,1,H,W]

# ----------------------------
# 5) Losses + metrics
# ----------------------------
def dice_loss_with_logits(logits, targets, eps=1e-6):
    # logits/targets: [B,1,H,W]
    probs = torch.sigmoid(logits)
    num = 2.0 * (probs * targets).sum(dim=(1,2,3)) + eps
    den = (probs + targets).sum(dim=(1,2,3)) + eps
    dice = num / den
    return 1.0 - dice.mean()

@torch.no_grad()
def dice_score_from_logits(logits, targets, thr=0.5, eps=1e-6):
    probs = torch.sigmoid(logits)
    pred = (probs > thr).float()
    num = 2.0 * (pred * targets).sum(dim=(1,2,3)) + eps
    den = (pred + targets).sum(dim=(1,2,3)) + eps
    return (num / den).mean().item()

# ----------------------------
# 6) Training config (safe defaults)
# ----------------------------
N_FOLDS = int(df["fold"].nunique())
EPOCHS = 25 if device.type == "cpu" else 35
PATIENCE = 6
LR = 3e-4
WEIGHT_DECAY = 1e-4

BATCH_SIZE = 8 if device.type == "cpu" else 16
NUM_WORKERS = 2
PIN_MEMORY = (device.type == "cuda")

CFG = dict(
    stage="stage5_train_seg_decoder",
    cache_root=str(CACHE_ROOT),
    Htok=Htok, Wtok=Wtok, Dtok=Dtok,
    epochs=int(EPOCHS),
    patience=int(PATIENCE),
    lr=float(LR),
    weight_decay=float(WEIGHT_DECAY),
    batch_size=int(BATCH_SIZE),
    loss="BCEWithLogits(pos_weight fold-wise) + Dice",
    model="SmallSegDecoder(in=Dtok, 256-128-64-1)",
)
(ART_DIR / "cfg.json").write_text(json.dumps(CFG, indent=2))

# ----------------------------
# 7) Utilities for figures
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32) / 255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m, 0] = (1 - alpha) * out[m, 0] + alpha * 1.0
    out[m, 1] = (1 - alpha) * out[m, 1] + alpha * 0.0
    out[m, 2] = (1 - alpha) * out[m, 2] + alpha * 0.0
    return (out * 255).clip(0, 255).astype(np.uint8)

def plot_curves(hist_df, out_path):
    plt.figure(figsize=(7,4))
    plt.plot(hist_df["epoch"], hist_df["train_loss"], label="train_loss")
    plt.plot(hist_df["epoch"], hist_df["valid_loss"], label="valid_loss")
    plt.xlabel("Epoch"); plt.ylabel("Loss"); plt.title("Training curves (loss)")
    plt.legend()
    savefig(out_path)

    plt.figure(figsize=(7,4))
    plt.plot(hist_df["epoch"], hist_df["valid_dice"], label="valid_dice")
    plt.xlabel("Epoch"); plt.ylabel("Dice"); plt.title("Validation Dice")
    plt.legend()
    out2 = Path(out_path).with_name(Path(out_path).stem.replace("curves","dice") + ".png")
    savefig(out2)

# ----------------------------
# 8) Train per fold
# ----------------------------
all_fold_metrics = []

scaler = torch.cuda.amp.GradScaler(enabled=USE_AMP)

for fold_id in range(N_FOLDS):
    print("\n" + "="*60)
    print(f"Fold {fold_id}/{N_FOLDS-1}")

    trn = df[df["fold"] != fold_id].reset_index(drop=True)
    val = df[df["fold"] == fold_id].reset_index(drop=True)

    # fold-wise pos_weight from mean grid area fraction (already stored per sample)
    # pos_frac ~ E[mask cell==1]; clamp to avoid extreme
    if "grid_area_frac" in trn.columns:
        pos_frac = float(np.clip(trn["grid_area_frac"].mean(), 1e-6, 0.2))  # cap 0.2
    else:
        pos_frac = 1e-3
    pos_weight = float(np.clip((1.0 - pos_frac) / max(pos_frac, 1e-6), 1.0, 30.0))
    bce = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([pos_weight], device=device))

    print(f"Train size={len(trn)} | Valid size={len(val)} | pos_frac≈{pos_frac:.6f} | pos_weight={pos_weight:.2f}")

    ds_tr = TokenGridMaskDataset(trn)
    ds_va = TokenGridMaskDataset(val)

    dl_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)
    dl_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False,
                       num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY, drop_last=False)

    model = SmallSegDecoder(in_ch=Dtok).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_dice = -1.0
    best_epoch = -1
    best_path = ART_DIR / f"fold_{fold_id}.pt"
    patience_ctr = 0

    history = []

    for epoch in range(1, EPOCHS + 1):
        # ---- train
        model.train()
        tr_losses = []

        for _, x, y in dl_tr:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logits = model(x)
                loss_bce = bce(logits, y)
                loss_dice = dice_loss_with_logits(logits, y)
                loss = loss_bce + loss_dice

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            tr_losses.append(loss.item())

        # ---- valid
        model.eval()
        va_losses = []
        va_dices = []

        with torch.no_grad():
            for _, x, y in dl_va:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                with torch.cuda.amp.autocast(enabled=USE_AMP):
                    logits = model(x)
                    loss_bce = bce(logits, y)
                    loss_dice = dice_loss_with_logits(logits, y)
                    loss = loss_bce + loss_dice
                va_losses.append(loss.item())
                va_dices.append(dice_score_from_logits(logits, y, thr=0.5))

        train_loss = float(np.mean(tr_losses)) if tr_losses else float("nan")
        valid_loss = float(np.mean(va_losses)) if va_losses else float("nan")
        valid_dice = float(np.mean(va_dices)) if va_dices else 0.0

        history.append(dict(epoch=epoch, train_loss=train_loss, valid_loss=valid_loss, valid_dice=valid_dice))

        print(f"Epoch {epoch:03d} | train_loss={train_loss:.5f} | valid_loss={valid_loss:.5f} | valid_dice@0.5={valid_dice:.4f}")

        # early stop / best save
        if valid_dice > best_dice + 1e-5:
            best_dice = valid_dice
            best_epoch = epoch
            torch.save({
                "state_dict": model.state_dict(),
                "fold": fold_id,
                "Htok": Htok, "Wtok": Wtok, "Dtok": Dtok,
                "pos_weight": pos_weight,
                "cfg": CFG,
            }, best_path)
            patience_ctr = 0
        else:
            patience_ctr += 1
            if patience_ctr >= PATIENCE:
                print(f"Early stopping at epoch {epoch} (best_epoch={best_epoch}, best_dice={best_dice:.4f})")
                break

        if device.type == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

    # save history
    hist_df = pd.DataFrame(history)
    hist_path = ART_DIR / f"history_fold_{fold_id}.csv"
    hist_df.to_csv(hist_path, index=False)
    print("Saved:", hist_path)
    print("Saved best:", best_path, f"(best_dice={best_dice:.4f} @ epoch {best_epoch})")

    # ---- save valid predictions (for postprocess tuning later)
    # predict probs for all valid rows and store grid probs
    ckpt = torch.load(best_path, map_location=device)
    model.load_state_dict(ckpt["state_dict"], strict=True)
    model.eval()

    uids_va = []
    probs_va = []

    with torch.no_grad():
        for uids, x, y in dl_va:
            x = x.to(device, non_blocking=True)
            with torch.cuda.amp.autocast(enabled=USE_AMP):
                logits = model(x)
                probs = torch.sigmoid(logits).detach().float().cpu().numpy()  # [B,1,H,W]
            uids_va.extend([str(u) for u in uids])
            probs_va.append(probs)

    probs_va = np.concatenate(probs_va, axis=0) if len(probs_va) else np.zeros((0,1,Htok,Wtok), np.float32)
    pred_path = ART_DIR / f"valid_preds_fold_{fold_id}.npz"
    np.savez_compressed(pred_path, uids=np.array(uids_va, dtype=object), probs=probs_va.astype(np.float16))
    print("Saved:", pred_path)

    # ---- paper figures (curves + qualitative)
    try:
        plot_curves(hist_df, FIG_DIR / f"Fig5-1_curves_fold_{fold_id}.png")
    except Exception as e:
        print("[WARN] Could not save curve figure:", e)

    # qualitative examples: image (resized 518), GT overlay (resized), Pred overlay (upsampled)
    try:
        # pick 4 forged + 4 authentic from valid (if available)
        val_df = val.copy()
        v_f = val_df[val_df["y"]==1].head(4)
        v_a = val_df[val_df["y"]==0].head(4)
        sel = pd.concat([v_f, v_a], axis=0).head(8)

        # map uid->prob
        uid2prob = {uids_va[i]: probs_va[i,0] for i in range(len(uids_va))}

        items = []
        titles = []
        for _, rr in sel.iterrows():
            uid = str(rr["uid"])
            ip = Path(rr["img_path"])
            img = np.array(Image.open(ip).convert("RGB"))
            # For visualization we show the same 518x518 resolution used for tokens
            img_518 = np.array(Image.fromarray(img).resize((518,518), Image.BICUBIC))

            # GT full mask if exists -> resize to 518
            gt_overlay = img_518
            if int(rr["y"]) == 1 and str(rr.get("mask_path","")):
                mp = Path(rr["mask_path"])
                if mp.exists():
                    m = np.load(mp)
                    if m.ndim == 3:
                        m = m.max(axis=0)
                    m = (m > 0).astype(np.uint8)
                    # align to image orientation if possible
                    w0, h0 = Image.open(ip).size
                    mh, mw = m.shape
                    if (mh == h0 and mw == w0):
                        m_al = m
                    elif (mh == w0 and mw == h0):
                        m_al = m.T
                    else:
                        m_al = np.array(Image.fromarray((m*255).astype(np.uint8)).resize((w0,h0), Image.NEAREST)) > 0
                        m_al = m_al.astype(np.uint8)
                    m_518 = np.array(Image.fromarray((m_al*255).astype(np.uint8)).resize((518,518), Image.NEAREST)) > 0
                    gt_overlay = overlay_mask(img_518, m_518.astype(np.uint8), alpha=0.45)

            # Pred grid prob -> upsample to 518
            prob = uid2prob.get(uid, np.zeros((Htok,Wtok), np.float32))
            prob_518 = np.array(Image.fromarray((prob*255).astype(np.uint8)).resize((518,518), Image.BILINEAR)).astype(np.float32)/255.0
            pred_bin = (prob_518 > 0.5).astype(np.uint8)
            pred_overlay = overlay_mask(img_518, pred_bin, alpha=0.45)

            items.extend([img_518, gt_overlay, pred_overlay])
            titles.extend([f"{uid} | image",
                           f"{uid} | GT (resized)",
                           f"{uid} | Pred@0.5"])

        # plot grid 3 columns
        n = len(items)
        ncols = 3
        nrows = int(math.ceil(n / ncols))
        plt.figure(figsize=(12, 3.8*nrows))
        for i in range(n):
            plt.subplot(nrows, ncols, i+1)
            plt.imshow(items[i])
            plt.axis("off")
            plt.title(titles[i], fontsize=9)
        savefig(FIG_DIR / f"Fig5-2_qualitative_fold_{fold_id}.png")
    except Exception as e:
        print("[WARN] Could not save qualitative figure:", e)

    all_fold_metrics.append(dict(fold=fold_id, best_epoch=int(best_epoch), best_dice=float(best_dice)))
    del model
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

# Save fold metrics summary
metrics_df = pd.DataFrame(all_fold_metrics)
metrics_path = ART_DIR / "fold_metrics.csv"
metrics_df.to_csv(metrics_path, index=False)
print("\n[OK] Training complete.")
print("Saved:", metrics_path)
print("Models in:", ART_DIR)
print("Figures in:", FIG_DIR)

# Keep globals for next stages
globals().update(dict(
    MASK_MODEL_DIR=ART_DIR,
    MASK_MODELS=[str(ART_DIR / f"fold_{k}.pt") for k in range(N_FOLDS)],
    MASK_FOLD_METRICS=all_fold_metrics,
))
# ============================================================


# Full-Resolution Reconstruction & Post-Processing

In [None]:
# ============================================================
# STAGE 6 — Full-Resolution Reconstruction & Post-Processing (ONE CELL)
# - Uses STAGE 5 saved VALID grid probabilities (valid_preds_fold_k.npz)
# - Reconstructs full-res probability maps, applies post-processing, tunes params on CV-valid
#
# Outputs:
#   /kaggle/working/recodai_luc_prof/artifacts/postprocess/postprocess_cfg.json
#   /kaggle/working/recodai_luc_prof/artifacts/postprocess/threshold_sweep.csv
#   /kaggle/working/recodai_luc_prof/artifacts/postprocess/valid_postproc_summary.json
#   /kaggle/working/recodai_luc_prof/figures/stage6/Fig6-1_before_after_examples.png
#   /kaggle/working/recodai_luc_prof/figures/stage6/Fig6-2_fp_suppression_examples.png
#
# Keeps context for next stage:
#   POSTPROCESS_CFG, POSTPROCESS_CFG_PATH
# ============================================================

import os, gc, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile, ImageFilter

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require prior stages (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

train_with_folds_path = Path(globals().get("TRAIN_MANIFEST_WITH_FOLDS", PROF_DIR / "train_manifest_with_gridmask_folds.parquet"))
if not train_with_folds_path.exists():
    raise RuntimeError("Missing train_manifest_with_gridmask_folds.parquet. Run STAGE 4 first.")
df_train = pd.read_parquet(train_with_folds_path).copy()

MASK_MODEL_DIR = Path(globals().get("MASK_MODEL_DIR", PROF_DIR / "artifacts" / "mask_model"))
if not MASK_MODEL_DIR.exists():
    raise RuntimeError("Missing mask_model dir. Run STAGE 5 first.")

ART_DIR = PROF_DIR / "artifacts" / "postprocess"
ART_DIR.mkdir(parents=True, exist_ok=True)

FIG_DIR = PROF_DIR / "figures" / "stage6"
FIG_DIR.mkdir(parents=True, exist_ok=True)

print("Loaded train:", df_train.shape, "| folds:", sorted(df_train["fold"].unique().tolist()))
print("MASK_MODEL_DIR:", MASK_MODEL_DIR)
print("ART_DIR:", ART_DIR)
print("FIG_DIR:", FIG_DIR)

# ----------------------------
# 1) Load all fold valid predictions (grid probs)
# ----------------------------
pred_files = sorted(MASK_MODEL_DIR.glob("valid_preds_fold_*.npz"))
if not pred_files:
    raise RuntimeError("No valid_preds_fold_*.npz found. Ensure STAGE 5 saved validation predictions.")

pred_rows = []
for pf in pred_files:
    z = np.load(pf, allow_pickle=True)
    uids = z["uids"].astype(object)
    probs = z["probs"]  # [N,1,H,W] float16
    # infer fold id from filename
    fold_id = int(pf.stem.split("_")[-1])
    if probs.ndim != 4 or probs.shape[1] != 1:
        raise RuntimeError(f"Bad probs shape in {pf}: {probs.shape}")
    for i, uid in enumerate(uids):
        pred_rows.append(dict(
            uid=str(uid),
            fold=int(fold_id),
            prob_grid=probs[i,0].astype(np.float32),  # store as float32 in-memory
        ))

df_pred = pd.DataFrame(pred_rows)
print("Loaded valid grid preds:", df_pred.shape)

# Join to train rows (must match each uid exactly once on its fold)
df_train["uid"] = df_train["uid"].astype(str)
df_pred["uid"] = df_pred["uid"].astype(str)

dfv = df_train.merge(df_pred, on=["uid", "fold"], how="left")
if dfv["prob_grid"].isna().any():
    miss = dfv[dfv["prob_grid"].isna()][["uid","fold"]].head(20).values.tolist()
    raise RuntimeError(f"Missing validation predictions after merge. Examples: {miss}")
print("Validation merged:", dfv.shape)

# infer grid shape
g0 = dfv["prob_grid"].iloc[0]
Htok, Wtok = int(g0.shape[0]), int(g0.shape[1])
print(f"Grid size: {Htok}x{Wtok}")

# ----------------------------
# 2) Optional SciPy for better morphology/filters (fallback to PIL-only)
# ----------------------------
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False

print("SciPy available:", _HAS_SCIPY)

# ----------------------------
# 3) Helpers (alignment, reconstruction, postprocess, metrics, figures)
# ----------------------------
def load_mask_as_bin(mask_path: Path):
    m = np.load(mask_path)
    if m.ndim == 3:
        m = m.max(axis=0)
    return (m > 0).astype(np.uint8)

def align_mask_to_image(mask01: np.ndarray, img_w: int, img_h: int):
    mh, mw = mask01.shape[:2]
    if (mh == img_h) and (mw == img_w):
        return mask01, "match"
    if (mh == img_w) and (mw == img_h):
        return mask01.T, "transpose_match"
    # fallback: resize to image size
    m_img = Image.fromarray((mask01*255).astype(np.uint8), mode="L")
    m_res = m_img.resize((img_w, img_h), resample=Image.NEAREST)
    return (np.array(m_res) > 0).astype(np.uint8), "resized_to_match"

def upsample_grid_prob(prob_grid: np.ndarray, w: int, h: int):
    # prob_grid: [Htok,Wtok] float [0..1]
    img = Image.fromarray(np.clip(prob_grid*255.0, 0, 255).astype(np.uint8), mode="L")
    up = img.resize((w, h), resample=Image.BILINEAR)
    return (np.array(up).astype(np.float32) / 255.0).clip(0, 1)

def sobel_grad(x: np.ndarray):
    # x: float32 HxW
    if _HAS_SCIPY:
        gx = ndi.sobel(x, axis=1)
        gy = ndi.sobel(x, axis=0)
        g = np.sqrt(gx*gx + gy*gy)
        g = g / (g.max() + 1e-8)
        return g.astype(np.float32)
    # fallback: simple finite diff
    gx = np.zeros_like(x)
    gy = np.zeros_like(x)
    gx[:,1:] = np.abs(x[:,1:] - x[:,:-1])
    gy[1:,:] = np.abs(x[1:,:] - x[:-1,:])
    g = (gx + gy)
    g = g / (g.max() + 1e-8)
    return g.astype(np.float32)

def gaussian_blur(x: np.ndarray, sigma: float):
    if sigma <= 0:
        return x
    if _HAS_SCIPY:
        return ndi.gaussian_filter(x, sigma=sigma).astype(np.float32)
    # PIL fallback: radius approx sigma
    img = Image.fromarray(np.clip(x*255,0,255).astype(np.uint8), mode="L")
    img = img.filter(ImageFilter.GaussianBlur(radius=float(sigma)))
    return (np.array(img).astype(np.float32)/255.0).clip(0,1)

def morph_close_open(mask01: np.ndarray, k_close=5, k_open=3):
    if not _HAS_SCIPY:
        return mask01  # safe fallback (no morphology)
    st_close = np.ones((k_close, k_close), dtype=bool)
    st_open  = np.ones((k_open,  k_open),  dtype=bool)
    m = ndi.binary_closing(mask01.astype(bool), structure=st_close)
    m = ndi.binary_opening(m, structure=st_open)
    return m.astype(np.uint8)

def postprocess_prob_to_mask(prob_up: np.ndarray, img_gray: np.ndarray, *,
                             alpha_grad=0.35, blur_sigma=1.0, k_std=0.30,
                             thr_min=0.20, thr_max=0.90,
                             do_morph=True, k_close=5, k_open=3,
                             min_area=400, min_mean_inside=0.30):
    """
    Returns: mask01, info dict (thr, area, mean_inside)
    """
    grad = sobel_grad(prob_up)
    enh = (1.0 - alpha_grad) * prob_up + alpha_grad * grad
    enh = enh.clip(0, 1)
    enh = gaussian_blur(enh, sigma=blur_sigma)

    mu = float(enh.mean())
    sd = float(enh.std())
    thr = float(np.clip(mu + k_std * sd, thr_min, thr_max))

    mask = (enh > thr).astype(np.uint8)

    if do_morph:
        mask = morph_close_open(mask, k_close=k_close, k_open=k_open)

    area = int(mask.sum())
    if area == 0:
        return mask, {"thr": thr, "area": 0, "mean_inside": 0.0}

    # mean intensity inside mask (grayscale)
    if area > 0:
        vals = img_gray[mask.astype(bool)]
        mean_inside = float(vals.mean()) / 255.0
    else:
        mean_inside = 0.0

    # anti-FP filtering
    if (area < int(min_area)) or (mean_inside < float(min_mean_inside)):
        mask[:] = 0
        return mask, {"thr": thr, "area": 0, "mean_inside": mean_inside, "filtered": True}

    return mask, {"thr": thr, "area": area, "mean_inside": mean_inside, "filtered": False}

def f1_binary(pred01: np.ndarray, gt01: np.ndarray, eps=1e-9):
    # pixel-level F1
    pred = pred01.astype(bool)
    gt = gt01.astype(bool)
    tp = (pred & gt).sum()
    fp = (pred & (~gt)).sum()
    fn = ((~pred) & gt).sum()
    return float((2*tp + eps) / (2*tp + fp + fn + eps))

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32)/255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m,0] = (1-alpha)*out[m,0] + alpha*1.0
    out[m,1] = (1-alpha)*out[m,1] + alpha*0.0
    out[m,2] = (1-alpha)*out[m,2] + alpha*0.0
    return (out*255).clip(0,255).astype(np.uint8)

import matplotlib.pyplot as plt
def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def show_grid(items, titles=None, ncols=3, figsize=(14,10)):
    n = len(items)
    if n == 0:
        return
    ncols = min(ncols, n)
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=figsize)
    for i, x in enumerate(items):
        plt.subplot(nrows, ncols, i+1)
        if x.ndim == 2:
            plt.imshow(x, cmap="gray")
        else:
            plt.imshow(x)
        plt.axis("off")
        if titles is not None:
            plt.title(str(titles[i]), fontsize=9)
    plt.tight_layout()

# ----------------------------
# 4) Prepare GT loader (full-res)
# ----------------------------
dfv["mask_path"] = dfv.get("mask_path","").fillna("").astype(str)
dfv["img_path"]  = dfv["img_path"].astype(str)

def get_gt_mask(uid, img_path, y, mask_path):
    im = Image.open(img_path).convert("RGB")
    w, h = im.size
    if int(y) == 1 and mask_path and Path(mask_path).exists():
        m = load_mask_as_bin(Path(mask_path))
        m_al, _ = align_mask_to_image(m, w, h)
        return np.array(im), w, h, m_al
    # authentic or missing mask -> empty GT
    return np.array(im), w, h, np.zeros((h, w), dtype=np.uint8)

# ----------------------------
# 5) Tune postprocess parameters on a subset of validation
#     (pixel-level F1 proxy; stable & fast; prevents crashes)
# ----------------------------
TUNE_MAX = 600  # safe compute
# balanced sampling: forged first then authentic
forged_idx = dfv[dfv["y"]==1].index.tolist()
auth_idx   = dfv[dfv["y"]==0].index.tolist()
np.random.RandomState(42).shuffle(forged_idx)
np.random.RandomState(42).shuffle(auth_idx)

take_f = min(len(forged_idx), TUNE_MAX//2)
take_a = min(len(auth_idx),   TUNE_MAX - take_f)
tune_idx = forged_idx[:take_f] + auth_idx[:take_a]
df_tune = dfv.loc[tune_idx].reset_index(drop=True)
print("Tuning subset:", df_tune.shape, "| forged:", int((df_tune["y"]==1).sum()), "| authentic:", int((df_tune["y"]==0).sum()))

# small grid (safe runtime)
GRID = []
for alpha_grad in [0.0, 0.35]:
    for k_std in [0.20, 0.30]:
        for blur_sigma in [0.0, 1.0]:
            for min_area in [0, 400]:
                for min_mean in [0.0, 0.30]:
                    GRID.append((alpha_grad, k_std, blur_sigma, min_area, min_mean))

def eval_params(params):
    alpha_grad, k_std, blur_sigma, min_area, min_mean = params
    scores = []
    n_filtered = 0
    n_pred_pos = 0
    for _, r in df_tune.iterrows():
        img_rgb, w, h, gt = get_gt_mask(r["uid"], r["img_path"], r["y"], r["mask_path"])
        gray = np.array(Image.fromarray(img_rgb).convert("L"), dtype=np.float32)
        prob_up = upsample_grid_prob(r["prob_grid"], w=w, h=h)
        mask, info = postprocess_prob_to_mask(
            prob_up, gray,
            alpha_grad=alpha_grad,
            blur_sigma=blur_sigma,
            k_std=k_std,
            do_morph=True,
            min_area=min_area,
            min_mean_inside=min_mean
        )
        if info.get("filtered", False):
            n_filtered += 1
        if mask.sum() > 0:
            n_pred_pos += 1
        scores.append(f1_binary(mask, gt))
    return float(np.mean(scores)), int(n_pred_pos), int(n_filtered)

results = []
best = None

print("Grid search candidates:", len(GRID))
for i, p in enumerate(GRID, 1):
    mean_f1, n_pos, n_filt = eval_params(p)
    results.append(dict(
        alpha_grad=p[0], k_std=p[1], blur_sigma=p[2], min_area=p[3], min_mean_inside=p[4],
        mean_f1=mean_f1, n_pred_pos=n_pos, n_filtered=n_filt
    ))
    if best is None or mean_f1 > best["mean_f1"]:
        best = results[-1]
    if i % 8 == 0:
        print(f"[tune] {i}/{len(GRID)} best_mean_f1={best['mean_f1']:.4f}")

sweep_df = pd.DataFrame(results).sort_values("mean_f1", ascending=False).reset_index(drop=True)
sweep_path = ART_DIR / "threshold_sweep.csv"
sweep_df.to_csv(sweep_path, index=False)
print("Saved sweep:", sweep_path)
print("Best params:", best)

# final postprocess config for next stages
POSTPROCESS_CFG = dict(
    stage="stage6_postprocess_tuned",
    grid_size=[int(Htok), int(Wtok)],
    # tuned params
    alpha_grad=float(best["alpha_grad"]),
    k_std=float(best["k_std"]),
    blur_sigma=float(best["blur_sigma"]),
    do_morph=True,
    k_close=5,
    k_open=3,
    thr_min=0.20,
    thr_max=0.90,
    min_area=int(best["min_area"]),
    min_mean_inside=float(best["min_mean_inside"]),
    # notes
    tuned_metric="pixel_F1_proxy_on_CV_valid",
    tuned_subset_size=int(len(df_tune)),
    mask_empty_label="authentic",
)

POSTPROCESS_CFG_PATH = ART_DIR / "postprocess_cfg.json"
POSTPROCESS_CFG_PATH.write_text(json.dumps(POSTPROCESS_CFG, indent=2))
print("Saved:", POSTPROCESS_CFG_PATH)

# ----------------------------
# 6) Apply tuned postprocess to ALL validation rows (for summary + figures)
# ----------------------------
all_scores = []
all_info = []
for _, r in dfv.iterrows():
    img_rgb, w, h, gt = get_gt_mask(r["uid"], r["img_path"], r["y"], r["mask_path"])
    gray = np.array(Image.fromarray(img_rgb).convert("L"), dtype=np.float32)
    prob_up = upsample_grid_prob(r["prob_grid"], w=w, h=h)
    mask, info = postprocess_prob_to_mask(
        prob_up, gray,
        alpha_grad=POSTPROCESS_CFG["alpha_grad"],
        blur_sigma=POSTPROCESS_CFG["blur_sigma"],
        k_std=POSTPROCESS_CFG["k_std"],
        thr_min=POSTPROCESS_CFG["thr_min"],
        thr_max=POSTPROCESS_CFG["thr_max"],
        do_morph=POSTPROCESS_CFG["do_morph"],
        k_close=POSTPROCESS_CFG["k_close"],
        k_open=POSTPROCESS_CFG["k_open"],
        min_area=POSTPROCESS_CFG["min_area"],
        min_mean_inside=POSTPROCESS_CFG["min_mean_inside"],
    )
    score = f1_binary(mask, gt)
    all_scores.append(score)
    all_info.append(dict(uid=r["uid"], fold=int(r["fold"]), y=int(r["y"]),
                         f1=score, area=int(mask.sum()), thr=float(info["thr"]),
                         filtered=bool(info.get("filtered", False))))

summary = {
    "n_valid_total": int(len(dfv)),
    "mean_pixel_f1_proxy": float(np.mean(all_scores)),
    "forged_mean_pixel_f1_proxy": float(np.mean([a["f1"] for a in all_info if a["y"]==1])) if any(a["y"]==1 for a in all_info) else 0.0,
    "auth_mean_pixel_f1_proxy": float(np.mean([a["f1"] for a in all_info if a["y"]==0])) if any(a["y"]==0 for a in all_info) else 0.0,
    "pred_positive_frac": float(np.mean([1.0 if a["area"]>0 else 0.0 for a in all_info])),
    "filtered_count": int(sum(1 for a in all_info if a["filtered"])),
}
summary_path = ART_DIR / "valid_postproc_summary.json"
summary_path.write_text(json.dumps(summary, indent=2))
print("Saved:", summary_path)
print("Summary:", summary)

# ----------------------------
# 7) Paper figures (before/after + FP suppression examples)
# ----------------------------
# pick examples
df_info = pd.DataFrame(all_info)
# "hard FP": authentic predicted positive before filtering would require raw; we approximate by selecting authentic with area=0 (filtered away)
# We'll generate 2 panels:
#  - mixed forged examples before/after
#  - authentic examples where final is empty (show how postprocess outputs empty)

# Helper to build before/after visuals for a uid
def build_before_after(uid):
    r = dfv[dfv["uid"]==uid].iloc[0]
    img_rgb, w, h, gt = get_gt_mask(r["uid"], r["img_path"], r["y"], r["mask_path"])
    gray = np.array(Image.fromarray(img_rgb).convert("L"), dtype=np.float32)
    prob_up = upsample_grid_prob(r["prob_grid"], w=w, h=h)

    # raw bin at fixed 0.5 (baseline)
    raw_bin = (prob_up > 0.5).astype(np.uint8)

    # tuned postprocess
    mask, info = postprocess_prob_to_mask(
        prob_up, gray,
        alpha_grad=POSTPROCESS_CFG["alpha_grad"],
        blur_sigma=POSTPROCESS_CFG["blur_sigma"],
        k_std=POSTPROCESS_CFG["k_std"],
        thr_min=POSTPROCESS_CFG["thr_min"],
        thr_max=POSTPROCESS_CFG["thr_max"],
        do_morph=POSTPROCESS_CFG["do_morph"],
        k_close=POSTPROCESS_CFG["k_close"],
        k_open=POSTPROCESS_CFG["k_open"],
        min_area=POSTPROCESS_CFG["min_area"],
        min_mean_inside=POSTPROCESS_CFG["min_mean_inside"],
    )

    # visuals (resize to max 640 for figure speed)
    max_side = 640
    scale = min(1.0, max_side / max(w, h))
    ww, hh = int(w*scale), int(h*scale)

    img_v = np.array(Image.fromarray(img_rgb).resize((ww, hh), Image.BICUBIC))
    prob_v = np.array(Image.fromarray((prob_up*255).astype(np.uint8)).resize((ww, hh), Image.BILINEAR))
    raw_v  = np.array(Image.fromarray((raw_bin*255).astype(np.uint8)).resize((ww, hh), Image.NEAREST)) > 0
    post_v = np.array(Image.fromarray((mask*255).astype(np.uint8)).resize((ww, hh), Image.NEAREST)) > 0
    gt_v   = np.array(Image.fromarray((gt*255).astype(np.uint8)).resize((ww, hh), Image.NEAREST)) > 0

    return {
        "img": img_v,
        "prob": prob_v,
        "raw_ov": overlay_mask(img_v, raw_v.astype(np.uint8), alpha=0.45),
        "post_ov": overlay_mask(img_v, post_v.astype(np.uint8), alpha=0.45),
        "gt_ov": overlay_mask(img_v, gt_v.astype(np.uint8), alpha=0.45),
        "title": f"{uid} | y={int(r['y'])} | thr≈{info['thr']:.3f} | area={int(mask.sum())}"
    }

# Fig6-1: mixed forged examples
forged_uids = df_info[df_info["y"]==1].sort_values("f1").head(4)["uid"].tolist()
# if too few forged
if len(forged_uids) < 4:
    forged_uids = df_info[df_info["y"]==1]["uid"].head(4).tolist()

items, titles = [], []
for uid in forged_uids[:4]:
    ex = build_before_after(uid)
    items += [ex["img"], ex["gt_ov"], ex["raw_ov"], ex["post_ov"]]
    titles += [f"{uid} image", f"{uid} GT overlay", f"{uid} raw@0.5", ex["title"] + " post"]

show_grid(items, titles=titles, ncols=4, figsize=(16, 10))
savefig(FIG_DIR / "Fig6-1_before_after_examples.png")

# Fig6-2: authentic examples (should end empty after filtering)
auth_uids = df_info[(df_info["y"]==0)].sample(min(4, int((df_info["y"]==0).sum())), random_state=42)["uid"].tolist()
items, titles = [], []
for uid in auth_uids[:4]:
    ex = build_before_after(uid)
    items += [ex["img"], ex["raw_ov"], ex["post_ov"]]
    titles += [f"{uid} image", f"{uid} raw@0.5", ex["title"] + " post"]
show_grid(items, titles=titles, ncols=3, figsize=(15, 8))
savefig(FIG_DIR / "Fig6-2_fp_suppression_examples.png")

print("[OK] Stage 6 figures saved to:", FIG_DIR)

# ----------------------------
# 8) Keep globals for next stage
# ----------------------------
globals().update(dict(
    POSTPROCESS_CFG=POSTPROCESS_CFG,
    POSTPROCESS_CFG_PATH=str(POSTPROCESS_CFG_PATH),
    POSTPROCESS_SUMMARY=summary,
))
# ============================================================


# Test Inference (Fold Ensemble) & Submission Generation

In [None]:
# ============================================================
# STAGE 7 — Test Inference (Fold Ensemble) & Submission Generation (ONE CELL)
# - Uses:
#   * STAGE 0: df_test + sample_submission
#   * STAGE 2: token caches for TEST (tokens_manifest_test.parquet)
#   * STAGE 5: fold_{k}.pt models
#   * STAGE 6: postprocess_cfg.json (tuned)
#
# Produces:
#   /kaggle/working/submission.csv
#   /kaggle/working/recodai_luc_prof/artifacts/submission/submission.csv
#   /kaggle/working/recodai_luc_prof/artifacts/submission/test_pred_stats.json
#   /kaggle/working/recodai_luc_prof/figures/stage7/Fig7-1_test_overlays.png
#   /kaggle/working/recodai_luc_prof/figures/stage7/Fig7-2_output_distributions.png
#
# Notes:
# - Ensemble: average probs across folds on token-grid
# - Reconstruction: upsample to original size
# - Postprocess: tuned Stage 6 (morphology uses SciPy if available)
# - Output format:
#     "authentic" OR JSON list string "[start, length, ...]"
# ============================================================

import os, gc, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile, ImageFilter

ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore")

# ----------------------------
# 0) Require prior stages (fallback load)
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

# test manifest
df_test = globals().get("df_test", None)
if df_test is None:
    df_test = pd.read_parquet(paths["TEST_MANIFEST"]).copy()
df_test["uid"] = df_test["uid"].astype(str)
df_test["case_id"] = df_test["case_id"].astype(str)
df_test["img_path"] = df_test["img_path"].astype(str)

# sample submission (for order)
sample_sub_path = Path(paths.get("SAMPLE_SUBMISSION", PROF_DIR / "sample_submission.csv"))
if not sample_sub_path.exists():
    # fallback to copied one
    sample_sub_path = PROF_DIR / "sample_submission.csv"
sample_sub = pd.read_csv(sample_sub_path)
if "case_id" not in sample_sub.columns:
    # try common Kaggle naming
    if "id" in sample_sub.columns:
        sample_sub = sample_sub.rename(columns={"id":"case_id"})
    else:
        raise RuntimeError("sample_submission.csv does not contain 'case_id' or 'id' column.")

sample_sub["case_id"] = sample_sub["case_id"].astype(str)

# token cache root + manifest test
def pick_token_cache_root():
    if "CACHE_ROOT" in globals():
        cr = Path(globals()["CACHE_ROOT"])
        if (cr / "tokens_manifest_test.parquet").exists():
            return cr
    base = Path("/kaggle/working/recodai_luc/cache")
    cands = []
    for d in base.glob("dinov2_base_518_cfg_*"):
        if (d / "tokens_manifest_test.parquet").exists():
            nfiles = len(list((d / "test").glob("*.npz")))
            cands.append((nfiles, d))
    if not cands:
        raise RuntimeError("No valid token cache found. Run STAGE 2 first.")
    cands.sort(reverse=True, key=lambda x: x[0])
    return cands[0][1]

CACHE_ROOT = pick_token_cache_root()
man_test = pd.read_parquet(CACHE_ROOT / "tokens_manifest_test.parquet")
man_test["uid"] = man_test["uid"].astype(str)

# models
MASK_MODEL_DIR = Path(globals().get("MASK_MODEL_DIR", PROF_DIR / "artifacts" / "mask_model"))
if not MASK_MODEL_DIR.exists():
    raise RuntimeError("Missing mask_model dir. Run STAGE 5 first.")
model_paths = sorted(MASK_MODEL_DIR.glob("fold_*.pt"))
if not model_paths:
    raise RuntimeError("No fold_*.pt models found. Run STAGE 5 first.")

# postprocess cfg
pp_path = Path(globals().get("POSTPROCESS_CFG_PATH", PROF_DIR / "artifacts" / "postprocess" / "postprocess_cfg.json"))
if not pp_path.exists():
    raise RuntimeError("Missing postprocess_cfg.json. Run STAGE 6 first.")
POSTPROCESS_CFG = json.loads(pp_path.read_text())

# output dirs
ART_DIR = PROF_DIR / "artifacts" / "submission"
ART_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR = PROF_DIR / "figures" / "stage7"
FIG_DIR.mkdir(parents=True, exist_ok=True)

print("CACHE_ROOT:", CACHE_ROOT)
print("Models:", len(model_paths), "->", [p.name for p in model_paths])
print("Postprocess cfg:", pp_path)
print("Test rows:", df_test.shape, "| Tokens test rows:", man_test.shape)

# ----------------------------
# 1) Torch + Model definition (must match STAGE 5)
# ----------------------------
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
USE_AMP = (device.type == "cuda")
print("Device:", device, "| AMP:", USE_AMP)

class SmallSegDecoder(nn.Module):
    def __init__(self, in_ch, mid1=256, mid2=128, mid3=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, mid1, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid1, mid2, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid2, mid3, 3, padding=1, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid3, 1, 1, padding=0, bias=True),
        )
    def forward(self, x):
        return self.net(x)

# read grid meta from first model ckpt
ck0 = torch.load(model_paths[0], map_location="cpu")
Htok = int(ck0.get("Htok", POSTPROCESS_CFG["grid_size"][0]))
Wtok = int(ck0.get("Wtok", POSTPROCESS_CFG["grid_size"][1]))
Dtok = int(ck0.get("Dtok", 768))
print(f"Model grid: {Htok}x{Wtok} | Dtok={Dtok}")

# load all fold models
models = []
for mp in model_paths:
    ck = torch.load(mp, map_location="cpu")
    m = SmallSegDecoder(in_ch=Dtok)
    m.load_state_dict(ck["state_dict"], strict=True)
    m.eval().to(device)
    models.append(m)

# ----------------------------
# 2) SciPy optional for morphology/filters (same as Stage 6)
# ----------------------------
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    _HAS_SCIPY = False
print("SciPy available:", _HAS_SCIPY)

# ----------------------------
# 3) Helpers: load tokens, reconstruct, postprocess, RLE
# ----------------------------
def load_tokens(uid: str):
    # map uid -> tokens_path
    # build dict once for speed
    return None

uid2tok = dict(zip(man_test["uid"].tolist(), man_test["tokens_path"].tolist()))
missing_u = [u for u in df_test["uid"].head(50).tolist() if u not in uid2tok]
if missing_u:
    raise RuntimeError(f"Some test uids missing in token manifest. Example: {missing_u[:5]}")

def load_token_grid(uid: str):
    p = Path(uid2tok[uid])
    npz = np.load(p)
    tok = npz["tokens"].astype(np.float32)  # [H,W,D]
    if tok.shape[0] != Htok or tok.shape[1] != Wtok or tok.shape[2] != Dtok:
        raise RuntimeError(f"Token shape mismatch uid={uid}: {tok.shape} expected {(Htok,Wtok,Dtok)}")
    # torch [D,H,W]
    x = torch.from_numpy(tok).permute(2,0,1).contiguous()
    return x

def sobel_grad(x: np.ndarray):
    if _HAS_SCIPY:
        gx = ndi.sobel(x, axis=1)
        gy = ndi.sobel(x, axis=0)
        g = np.sqrt(gx*gx + gy*gy)
        g = g / (g.max() + 1e-8)
        return g.astype(np.float32)
    gx = np.zeros_like(x)
    gy = np.zeros_like(x)
    gx[:,1:] = np.abs(x[:,1:] - x[:,:-1])
    gy[1:,:] = np.abs(x[1:,:] - x[:-1,:])
    g = gx + gy
    g = g / (g.max() + 1e-8)
    return g.astype(np.float32)

def gaussian_blur(x: np.ndarray, sigma: float):
    if sigma <= 0:
        return x
    if _HAS_SCIPY:
        return ndi.gaussian_filter(x, sigma=sigma).astype(np.float32)
    img = Image.fromarray(np.clip(x*255,0,255).astype(np.uint8), mode="L")
    img = img.filter(ImageFilter.GaussianBlur(radius=float(sigma)))
    return (np.array(img).astype(np.float32)/255.0).clip(0,1)

def morph_close_open(mask01: np.ndarray, k_close=5, k_open=3):
    if not _HAS_SCIPY:
        return mask01
    st_close = np.ones((k_close, k_close), dtype=bool)
    st_open  = np.ones((k_open,  k_open),  dtype=bool)
    m = ndi.binary_closing(mask01.astype(bool), structure=st_close)
    m = ndi.binary_opening(m, structure=st_open)
    return m.astype(np.uint8)

def upsample_grid_prob(prob_grid: np.ndarray, w: int, h: int):
    img = Image.fromarray(np.clip(prob_grid*255.0, 0, 255).astype(np.uint8), mode="L")
    up = img.resize((w, h), resample=Image.BILINEAR)
    return (np.array(up).astype(np.float32) / 255.0).clip(0, 1)

def postprocess_prob_to_mask(prob_up: np.ndarray, img_gray: np.ndarray, cfg: dict):
    grad = sobel_grad(prob_up)
    enh = (1.0 - float(cfg["alpha_grad"])) * prob_up + float(cfg["alpha_grad"]) * grad
    enh = enh.clip(0,1)
    enh = gaussian_blur(enh, sigma=float(cfg["blur_sigma"]))

    mu = float(enh.mean())
    sd = float(enh.std())
    thr = float(np.clip(mu + float(cfg["k_std"]) * sd, float(cfg["thr_min"]), float(cfg["thr_max"])))

    mask = (enh > thr).astype(np.uint8)

    if bool(cfg.get("do_morph", True)):
        mask = morph_close_open(mask, k_close=int(cfg.get("k_close",5)), k_open=int(cfg.get("k_open",3)))

    area = int(mask.sum())
    if area == 0:
        return mask, {"thr": thr, "area": 0, "mean_inside": 0.0, "filtered": False}

    vals = img_gray[mask.astype(bool)]
    mean_inside = float(vals.mean()) / 255.0 if vals.size else 0.0

    if (area < int(cfg.get("min_area",0))) or (mean_inside < float(cfg.get("min_mean_inside",0.0))):
        mask[:] = 0
        return mask, {"thr": thr, "area": 0, "mean_inside": mean_inside, "filtered": True}

    return mask, {"thr": thr, "area": area, "mean_inside": mean_inside, "filtered": False}

def rle_json_list(mask01: np.ndarray):
    """
    Encode binary mask to JSON list [start, length, ...]
    Uses Fortran-order flatten (transpose then C-flatten) which is common for Kaggle RLE.
    """
    m = mask01.astype(np.uint8)
    flat = m.T.flatten()  # column-major
    # pad
    padded = np.concatenate([[0], flat, [0]])
    changes = np.where(padded[1:] != padded[:-1])[0] + 1
    runs = []
    for i in range(0, len(changes), 2):
        start = int(changes[i])
        length = int(changes[i+1] - changes[i])
        runs.extend([start, length])
    return json.dumps(runs)

# ----------------------------
# 4) Batched inference on token grid + fold ensemble
# ----------------------------
@torch.no_grad()
def predict_prob_grid_batch(x_batch: torch.Tensor):
    # x_batch: [B, D, H, W]
    probs = None
    with torch.cuda.amp.autocast(enabled=USE_AMP):
        for m in models:
            logit = m(x_batch)  # [B,1,H,W]
            p = torch.sigmoid(logit)
            probs = p if probs is None else (probs + p)
    probs = probs / float(len(models))
    return probs  # torch [B,1,H,W]

BATCH_SIZE = 16 if device.type == "cuda" else 8

# ----------------------------
# 5) Run inference + build submission rows
# ----------------------------
sub_rows = []
stats = {
    "n_test": int(len(df_test)),
    "n_pred_mask": 0,
    "n_pred_authentic": 0,
    "mean_area_if_mask": None,
    "filtered_count": 0,
}

areas = []

# iterate in stable order matching df_test
uids = df_test["uid"].tolist()
case_ids = df_test["case_id"].tolist()
img_paths = df_test["img_path"].tolist()

for i0 in range(0, len(uids), BATCH_SIZE):
    batch_uids = uids[i0:i0+BATCH_SIZE]
    batch_case = case_ids[i0:i0+BATCH_SIZE]
    batch_imgs = img_paths[i0:i0+BATCH_SIZE]

    # load tokens
    xb = torch.stack([load_token_grid(u) for u in batch_uids], dim=0).to(device, non_blocking=True)  # [B,D,H,W]

    # predict grid probs (ensemble)
    pb = predict_prob_grid_batch(xb).detach().float().cpu().numpy()  # [B,1,H,W]

    # per sample: reconstruct + postprocess + encode
    for j in range(len(batch_uids)):
        uid = batch_uids[j]
        cid = batch_case[j]
        ip = batch_imgs[j]

        im = Image.open(ip).convert("RGB")
        w, h = im.size
        img_rgb = np.array(im)
        gray = np.array(im.convert("L"), dtype=np.float32)

        prob_grid = pb[j,0]  # [Htok,Wtok]
        prob_up = upsample_grid_prob(prob_grid, w=w, h=h)
        mask, info = postprocess_prob_to_mask(prob_up, gray, POSTPROCESS_CFG)

        if info.get("filtered", False):
            stats["filtered_count"] += 1

        area = int(mask.sum())
        if area == 0:
            ann = "authentic"
            stats["n_pred_authentic"] += 1
        else:
            ann = rle_json_list(mask)
            stats["n_pred_mask"] += 1
            areas.append(area)

        sub_rows.append(dict(case_id=str(cid), annotation=ann))

    if (i0 + BATCH_SIZE) % max(256, BATCH_SIZE*8) == 0:
        print(f"[test] {min(i0+BATCH_SIZE, len(uids))}/{len(uids)}")

    if device.type == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

# finalize stats
if len(areas) > 0:
    stats["mean_area_if_mask"] = float(np.mean(areas))
    stats["median_area_if_mask"] = float(np.median(areas))
else:
    stats["mean_area_if_mask"] = 0.0
    stats["median_area_if_mask"] = 0.0

# ----------------------------
# 6) Build submission.csv aligned to sample_submission order
# ----------------------------
sub_df = pd.DataFrame(sub_rows)
sub_df["case_id"] = sub_df["case_id"].astype(str)

# ensure uniqueness
dup = sub_df["case_id"].duplicated()
if dup.any():
    # keep last (should not happen)
    sub_df = sub_df.drop_duplicates("case_id", keep="last")

# align to sample submission order
sub_df = sample_sub[["case_id"]].merge(sub_df, on="case_id", how="left")

if sub_df["annotation"].isna().any():
    miss = sub_df[sub_df["annotation"].isna()]["case_id"].head(20).tolist()
    raise RuntimeError(f"Submission missing annotations for {sub_df['annotation'].isna().sum()} case_id. Examples: {miss}")

# Save
OUT_SUB = Path("/kaggle/working/submission.csv")
sub_df.to_csv(OUT_SUB, index=False)

OUT_SUB2 = ART_DIR / "submission.csv"
sub_df.to_csv(OUT_SUB2, index=False)

stats_path = ART_DIR / "test_pred_stats.json"
stats_path.write_text(json.dumps(stats, indent=2))

print("\n[OK] Saved submission:")
print(" -", OUT_SUB)
print(" -", OUT_SUB2)
print("Stats:", stats)

# ----------------------------
# 7) Paper figures (Stage 7): overlays + output distributions
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def overlay_mask(img_rgb, mask01, alpha=0.45):
    img = img_rgb.astype(np.float32)/255.0
    m = mask01.astype(bool)
    out = img.copy()
    out[m,0] = (1-alpha)*out[m,0] + alpha*1.0
    out[m,1] = (1-alpha)*out[m,1] + alpha*0.0
    out[m,2] = (1-alpha)*out[m,2] + alpha*0.0
    return (out*255).clip(0,255).astype(np.uint8)

def show_grid(items, titles=None, ncols=4, figsize=(14,10)):
    n = len(items)
    if n == 0:
        return
    ncols = min(ncols, n)
    nrows = int(math.ceil(n / ncols))
    plt.figure(figsize=figsize)
    for i, x in enumerate(items):
        plt.subplot(nrows, ncols, i+1)
        plt.imshow(x)
        plt.axis("off")
        if titles is not None:
            plt.title(str(titles[i]), fontsize=9)
    plt.tight_layout()

# Recompute a few overlays for figures (cheap sample)
# choose 8 samples: some predicted mask, some authentic
pred_mask_ids = sub_df[sub_df["annotation"]!="authentic"]["case_id"].head(4).tolist()
pred_auth_ids = sub_df[sub_df["annotation"]=="authentic"]["case_id"].head(4).tolist()
pick_ids = pred_mask_ids + pred_auth_ids
pick_df = df_test[df_test["case_id"].isin(pick_ids)].head(8)

# Need to regenerate masks for picked ones (safe)
items, titles = [], []
for _, r in pick_df.iterrows():
    uid = str(r["uid"])
    cid = str(r["case_id"])
    ip = str(r["img_path"])

    im = Image.open(ip).convert("RGB")
    w, h = im.size
    img_rgb = np.array(im)
    gray = np.array(im.convert("L"), dtype=np.float32)

    x = load_token_grid(uid)[None, ...].to(device)
    with torch.no_grad():
        p = predict_prob_grid_batch(x).detach().float().cpu().numpy()[0,0]
    prob_up = upsample_grid_prob(p, w=w, h=h)
    mask, info = postprocess_prob_to_mask(prob_up, gray, POSTPROCESS_CFG)

    max_side = 640
    scale = min(1.0, max_side / max(w, h))
    ww, hh = int(w*scale), int(h*scale)
    img_v = np.array(Image.fromarray(img_rgb).resize((ww, hh), Image.BICUBIC))
    mask_v = np.array(Image.fromarray((mask*255).astype(np.uint8)).resize((ww, hh), Image.NEAREST)) > 0
    ov = overlay_mask(img_v, mask_v.astype(np.uint8), alpha=0.45)

    items.append(ov)
    titles.append(f"{cid} | area={int(mask.sum())}")

show_grid(items, titles=titles, ncols=4, figsize=(14,8))
savefig(FIG_DIR / "Fig7-1_test_overlays.png")

# Output distribution plots
plt.figure(figsize=(6,4))
plt.bar(["authentic", "mask"], [stats["n_pred_authentic"], stats["n_pred_mask"]])
plt.ylabel("Count")
plt.title("Test output distribution")
savefig(FIG_DIR / "Fig7-2_output_distribution.png")

if len(areas) > 0:
    plt.figure(figsize=(6,4))
    plt.hist(areas, bins=40)
    plt.xlabel("Pred mask area (pixels)")
    plt.ylabel("Count")
    plt.title("Pred mask area histogram (non-empty)")
    savefig(FIG_DIR / "Fig7-2_mask_area_hist.png")

print("[OK] Figures saved to:", FIG_DIR)

# ----------------------------
# 8) Keep globals
# ----------------------------
globals().update(dict(
    submission_df=sub_df,
    SUBMISSION_PATH=str(OUT_SUB),
    SUB_STATS=stats,
))
# ============================================================


# Submission Quality Assurance (QA)

In [None]:
# ============================================================
# STAGE 8 — Submission Quality Assurance (QA) (ONE CELL)
# - Validates submission.csv format + coverage
# - Checks:
#   * all case_id present and unique
#   * annotation is either "authentic" OR valid JSON list [start,length,...]
#   * RLE runs valid (even length, positive lengths, monotonic starts)
#   * Optional: decode a small sample to ensure dimensions match image sizes
#
# Outputs:
#   /kaggle/working/recodai_luc_prof/artifacts/submission/qa_report.json
#   /kaggle/working/recodai_luc_prof/artifacts/submission/qa_bad_rows.csv
#   /kaggle/working/recodai_luc_prof/figures/stage8/Fig8-1_QA_summary.png
# ============================================================

import os, json, math, warnings
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image

warnings.filterwarnings("ignore")

# ----------------------------
# 0) Load required artifacts
# ----------------------------
PROF_DIR = Path("/kaggle/working/recodai_luc_prof")
paths_path = PROF_DIR / "paths.json"
if not paths_path.exists():
    raise RuntimeError("Missing paths.json. Run STAGE 0 first.")
paths = json.loads(paths_path.read_text())

ART_DIR = PROF_DIR / "artifacts" / "submission"
ART_DIR.mkdir(parents=True, exist_ok=True)

FIG_DIR = PROF_DIR / "figures" / "stage8"
FIG_DIR.mkdir(parents=True, exist_ok=True)

# submission
sub_path = Path(globals().get("SUBMISSION_PATH", "/kaggle/working/submission.csv"))
if not sub_path.exists():
    sub_path = ART_DIR / "submission.csv"
if not sub_path.exists():
    raise RuntimeError("submission.csv not found. Run STAGE 7 first.")
sub_df = pd.read_csv(sub_path)

# sample submission (truth for ordering/coverage)
sample_sub_path = Path(paths.get("SAMPLE_SUBMISSION", PROF_DIR / "sample_submission.csv"))
if not sample_sub_path.exists():
    sample_sub_path = PROF_DIR / "sample_submission.csv"
sample_sub = pd.read_csv(sample_sub_path)
if "case_id" not in sample_sub.columns:
    if "id" in sample_sub.columns:
        sample_sub = sample_sub.rename(columns={"id":"case_id"})
    else:
        raise RuntimeError("sample_submission.csv does not contain 'case_id' or 'id'.")

sample_sub["case_id"] = sample_sub["case_id"].astype(str)
sub_df["case_id"] = sub_df["case_id"].astype(str)

# test manifest (for optional decode checks)
df_test = globals().get("df_test", None)
if df_test is None:
    df_test = pd.read_parquet(paths["TEST_MANIFEST"]).copy()
df_test["case_id"] = df_test["case_id"].astype(str)
df_test["img_path"] = df_test["img_path"].astype(str)

print("Loaded submission:", sub_path, "| rows:", len(sub_df))
print("Loaded sample_submission rows:", len(sample_sub))

# ----------------------------
# 1) Basic schema checks
# ----------------------------
required_cols = ["case_id", "annotation"]
missing_cols = [c for c in required_cols if c not in sub_df.columns]
if missing_cols:
    raise RuntimeError(f"submission missing columns: {missing_cols}")

report = {
    "submission_path": str(sub_path),
    "n_rows": int(len(sub_df)),
    "n_expected": int(len(sample_sub)),
    "columns": list(sub_df.columns),
    "errors": [],
    "warnings": [],
}

# duplicates
dup = sub_df["case_id"].duplicated()
report["duplicate_case_id_count"] = int(dup.sum())

# coverage vs sample submission
sub_set = set(sub_df["case_id"].tolist())
sam_set = set(sample_sub["case_id"].tolist())
missing = sorted(list(sam_set - sub_set))
extra   = sorted(list(sub_set - sam_set))
report["missing_case_id_count"] = int(len(missing))
report["extra_case_id_count"] = int(len(extra))
report["missing_case_id_examples"] = missing[:20]
report["extra_case_id_examples"] = extra[:20]

if report["missing_case_id_count"] > 0:
    report["errors"].append("Missing case_id(s) compared to sample_submission.")
if report["duplicate_case_id_count"] > 0:
    report["errors"].append("Duplicate case_id(s) detected in submission.")
if report["extra_case_id_count"] > 0:
    report["warnings"].append("Extra case_id(s) present beyond sample_submission (usually ignored but not ideal).")

# reorder check (not error, just note)
is_same_order = (sub_df["case_id"].tolist() == sample_sub["case_id"].tolist()) if len(sub_df)==len(sample_sub) else False
report["matches_sample_order"] = bool(is_same_order)

# ----------------------------
# 2) Annotation format checks
# ----------------------------
def is_valid_rle_json_list(s):
    try:
        runs = json.loads(s)
        if not isinstance(runs, list):
            return False, "not_list"
        if len(runs) == 0:
            return True, "empty_list_ok"
        if len(runs) % 2 != 0:
            return False, "odd_length"
        # must be ints >= 1 for starts, >=1 for lengths (Kaggle style)
        prev_start = 0
        for i in range(0, len(runs), 2):
            st = runs[i]
            ln = runs[i+1]
            if not (isinstance(st, int) and isinstance(ln, int)):
                return False, "non_int"
            if st < 1 or ln < 1:
                return False, "non_positive"
            if st < prev_start:
                return False, "non_monotonic"
            prev_start = st
        return True, "ok"
    except Exception:
        return False, "json_parse_fail"

bad_rows = []
n_auth = 0
n_rle = 0
n_empty_list = 0

for idx, r in sub_df.iterrows():
    ann = r["annotation"]
    if isinstance(ann, float) and np.isnan(ann):
        bad_rows.append(dict(case_id=r["case_id"], issue="annotation_nan", annotation=""))
        continue
    ann = str(ann)
    if ann == "authentic":
        n_auth += 1
        continue
    ok, reason = is_valid_rle_json_list(ann)
    if not ok:
        bad_rows.append(dict(case_id=r["case_id"], issue=f"bad_rle_{reason}", annotation=ann[:120]))
    else:
        n_rle += 1
        if reason == "empty_list_ok":
            n_empty_list += 1

report["n_authentic"] = int(n_auth)
report["n_rle"] = int(n_rle)
report["n_empty_rle_list"] = int(n_empty_list)
report["bad_annotation_count"] = int(len(bad_rows))

if report["bad_annotation_count"] > 0:
    report["errors"].append("Some annotations are not valid ('authentic' or valid JSON RLE list).")

bad_path = ART_DIR / "qa_bad_rows.csv"
pd.DataFrame(bad_rows).to_csv(bad_path, index=False)
print("Saved bad rows:", bad_path, "| rows:", len(bad_rows))

# ----------------------------
# 3) Optional decode sanity check (small sample)
# ----------------------------
# Decode a handful of RLEs and verify pixel count fits image dimensions.
# This does not guarantee correctness, but catches obvious indexing errors.
def decode_rle_json_list(rle_json, h, w):
    runs = json.loads(rle_json)
    mask = np.zeros(h*w, dtype=np.uint8)
    for i in range(0, len(runs), 2):
        st = runs[i] - 1  # 1-indexed to 0-indexed
        ln = runs[i+1]
        mask[st:st+ln] = 1
    # reshape back from column-major encoding (we used transpose flatten)
    mask = mask.reshape((w, h)).T  # invert flatten m.T.flatten
    return mask

decode_checks = []
# take up to 20 masks
mask_cases = sub_df[sub_df["annotation"]!="authentic"].head(20)
for _, rr in mask_cases.iterrows():
    cid = rr["case_id"]
    ann = rr["annotation"]
    # find image
    m = df_test[df_test["case_id"]==cid]
    if len(m)==0:
        decode_checks.append(dict(case_id=cid, ok=False, reason="case_id_not_in_test_manifest"))
        continue
    ip = m["img_path"].iloc[0]
    try:
        im = Image.open(ip)
        w, h = im.size
        mk = decode_rle_json_list(ann, h=h, w=w)
        ok = (mk.shape[0]==h and mk.shape[1]==w)
        decode_checks.append(dict(case_id=cid, ok=bool(ok), h=int(h), w=int(w), area=int(mk.sum())))
    except Exception as e:
        decode_checks.append(dict(case_id=cid, ok=False, reason=f"decode_fail: {str(e)[:120]}"))

report["decode_check_n"] = int(len(decode_checks))
report["decode_check_fail_count"] = int(sum(1 for x in decode_checks if not x.get("ok", False)))
report["decode_check_examples"] = [x for x in decode_checks if not x.get("ok", False)][:10]

if report["decode_check_fail_count"] > 0:
    report["warnings"].append("Some RLE decode checks failed (possible encoding/index issue).")

# ----------------------------
# 4) Save QA report
# ----------------------------
qa_path = ART_DIR / "qa_report.json"
qa_path.write_text(json.dumps(report, indent=2))
print("Saved QA report:", qa_path)
print("QA errors:", report["errors"])
print("QA warnings:", report["warnings"])

# ----------------------------
# 5) Paper figure: QA summary
# ----------------------------
import matplotlib.pyplot as plt

def savefig(path, dpi=300):
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

plt.figure(figsize=(7,4))
plt.bar(["authentic", "rle_mask"], [report["n_authentic"], report["n_rle"]])
plt.ylabel("Count")
plt.title("Submission QA: output type counts")
savefig(FIG_DIR / "Fig8-1_QA_summary.png")

print("[OK] QA figure saved:", FIG_DIR / "Fig8-1_QA_summary.png")

# ----------------------------
# 6) Final status print
# ----------------------------
if len(report["errors"]) == 0:
    print("\n[OK] Submission QA PASSED.")
else:
    print("\n[FAIL] Submission QA found errors. Check qa_report.json and qa_bad_rows.csv.")

# Keep globals
globals().update(dict(
    QA_REPORT=report,
    QA_REPORT_PATH=str(qa_path),
    QA_BAD_ROWS_PATH=str(bad_path),
))
# ============================================================
