In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the path of every file under /kaggle/input (the read-only Kaggle input mount).
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Train & Save: Gate Model + Calibration + Thresholds

In [None]:
# ============================================================
# STAGE 5 — DINOv2 Multi-Task (Classification + Segmentation) (ONE CELL, FULL REVISION v5)
# - Replace tabular gate with real DINOv2 training (multi-task)
# - Patch-grid segmentation (stable, fast) + classification head
# - CV OOF + calibration + threshold tuning (thr_forged, thr_mask, min_pred_patches)
# - Save ONLY trainable weights (heads + last N blocks) to keep checkpoint small
#
# Outputs:
# - /kaggle/working/recodai_luc/models/dinov2_mt_v5_*/checkpoints/fold_*.pt
# - model_config.json, thresholds.json, calibrator.joblib, oof_predictions.csv, report.json
# ============================================================

import os, json, time, math, ast, hashlib, random, gc
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from transformers import AutoModel
import joblib
from sklearn.metrics import roc_auc_score, average_precision_score
from sklearn.isotonic import IsotonicRegression
from sklearn.linear_model import LogisticRegression

# ----------------------------
# REQUIRE
# ----------------------------
# STAGE 1 must have left df_train_all in this notebook's global namespace.
if "df_train_all" not in globals():
    # message (Indonesian): "Missing df_train_all. Run STAGE 1 first."
    raise RuntimeError("Missing df_train_all. Jalankan STAGE 1 dulu.")
# Rebind to a copy so the normalisation below does not mutate the STAGE-1 frame object.
df_train_all = df_train_all.copy()

# Every column this cell reads; fail fast and report exactly which ones are absent.
need_cols = {"sample_id","case_id","variant","fold","y_forged","mask_paths","image_path"}
miss = need_cols - set(df_train_all.columns)
if miss:
    raise RuntimeError(f"df_train_all missing columns: {miss}")

# ----------------------------
# ID normalization (same idea as STAGE 4)
# ----------------------------
def _norm_one_id(x):
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return ""
    if isinstance(x, (np.integer, int)):
        return str(int(x))
    if isinstance(x, (np.floating, float)):
        if np.isfinite(x) and abs(x - round(x)) < 1e-9:
            return str(int(round(x)))
        return str(float(x))
    s = str(x)
    if s.endswith(".0"):
        head = s[:-2]
        if head.isdigit():
            return head
    return s

def norm_id_series(s: pd.Series) -> pd.Series:
    """Return ``s`` with every element passed through ``_norm_one_id``.

    ``Series.map`` keeps the original index and name; NaN/None cells are passed
    to the normalizer as well (``na_action=None``) so they become ``""``.
    """
    normalized = s.map(_norm_one_id, na_action=None)
    return normalized

# Canonicalize ID columns so they match the GT-cache file names built below
# (f"{sample_id}.npz"), whether the IDs were read as int, float or str.
for c in ["sample_id","case_id"]:
    df_train_all[c] = norm_id_series(df_train_all[c])

# Enforce the dtypes used downstream (fold / label as int for indexing and metrics).
df_train_all["variant"]  = df_train_all["variant"].astype(str)
df_train_all["fold"]     = df_train_all["fold"].astype(int)
df_train_all["y_forged"] = df_train_all["y_forged"].astype(int)
df_train_all["image_path"] = df_train_all["image_path"].astype(str)

# keep only fold >=0 (train CV)
# NOTE(review): a negative fold presumably marks rows excluded from CV — confirm with STAGE 1.
df_train_all = df_train_all[df_train_all["fold"] >= 0].reset_index(drop=True)
if len(df_train_all) == 0:
    # message (Indonesian): "No training data with fold>=0."
    raise RuntimeError("Tidak ada data train dengan fold>=0.")

# ----------------------------
# CONFIG (tuning-ready)
# ----------------------------
# All hyperparameters for this stage. The output directory name is an md5 hash of
# this dict (RUN_TAG below), so changing any value writes to a fresh run directory.
CFG = {
    # DINO dir (from STAGE 1/2)
    "dino_dir": str(globals().get("DINO_BASE_DIR", "/kaggle/input/dinov2/pytorch/base/1")),

    # image / patch grid
    "img_size": 560,           # MUST be multiple of patch_size
    "patch_size": 14,
    # NOTE(review): not read anywhere in this cell's visible code — confirm it is still needed.
    "use_center_crop_val": True,

    # training
    "seed": 2025,
    "epochs": 5,
    "batch_size": 4,
    "num_workers": 2,
    "amp": True,               # mixed precision; only takes effect on CUDA
    "grad_accum": 1,           # micro-batches per optimizer step
    "clip_grad": 1.0,          # max grad-norm; <= 0 disables clipping

    # finetune policy
    "unfreeze_last_n_blocks": 2,   # 0 => linear probe (heads only)
    "use_grad_ckpt": False,

    # optimizer
    "lr_backbone": 3e-5,       # lr for unfrozen backbone blocks
    "lr_heads": 3e-4,          # lr for cls_head / seg_head
    "weight_decay": 0.05,

    # losses
    "w_cls": 1.0,
    "w_seg_bce": 1.0,
    "w_seg_dice": 1.0,
    "dice_eps": 1e-6,

    # augmentation (light, scientific-friendly)
    "aug_hflip_p": 0.5,
    "aug_vflip_p": 0.1,
    "aug_rot90_p": 0.15,         # rotate by 0/90/180/270
    "aug_brightness": 0.10,      # max relative change, applied with prob 0.5
    "aug_contrast": 0.10,        # max relative change, applied with prob 0.5

    # early stopping
    "patience": 2,               # epochs without val-score improvement before stopping

    # output
    "out_root": "/kaggle/working/recodai_luc/models",
    "gt_cache_root": "/kaggle/working/recodai_luc/cache/gt_grid_union",
    "print_every": 200,          # progress-log cadence (GT cache rows and train iterations)
}

# Patch-token grid: images are resized to IMG_SIZE x IMG_SIZE, so the grid is
# square, GH = GW = IMG_SIZE // PATCH (560 // 14 = 40 with the defaults).
IMG_SIZE = int(CFG["img_size"])
PATCH    = int(CFG["patch_size"])
# message (Indonesian): "img_size must be a multiple of patch_size."
assert IMG_SIZE % PATCH == 0, f"img_size harus kelipatan patch_size. Got {IMG_SIZE} vs {PATCH}"
GH = IMG_SIZE // PATCH
GW = IMG_SIZE // PATCH

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# threads (CPU side)
# Use half of the logical CPUs for torch intra-op threads; failures are ignored on purpose.
try:
    torch.set_num_threads(max(1, (os.cpu_count() or 2)//2))
except Exception:
    pass

# seed
def seed_everything(seed=2025):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)
# Seed Python / NumPy / torch RNGs once for this cell, from CFG["seed"].
seed_everything(CFG["seed"])

# imagenet norm (DINOv2 typical)
# Per-channel RGB mean/std shaped (3,1,1) so they broadcast over CHW image tensors.
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).view(3,1,1)
IMNET_STD  = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).view(3,1,1)

# popcount LUT (for packed bits)
# POPCNT[b] = number of set bits in byte b (0..255); lets callers count positive
# grid cells of an np.packbits buffer without unpacking it.
POPCNT = np.array([bin(i).count("1") for i in range(256)], dtype=np.uint8)

# ----------------------------
# Mask path parsing + loader (robust)
# ----------------------------
def parse_mask_paths(val):
    """Turn a ``mask_paths`` cell into a list of non-empty path strings.

    Accepted inputs:
      * list / tuple / ndarray -> every element as ``str`` (empty strings dropped);
      * a string that looks like a bracketed list/tuple -> parsed with ``json``
        first, then ``ast.literal_eval``;
      * any other non-blank string -> a single-element list (stripped);
      * ``None``, NaN, blank / "nan" / "none" / "null" strings and any other
        type -> ``[]``.
    """
    def _as_str_list(items):
        return [str(item) for item in items if str(item)]

    if isinstance(val, (list, tuple)):
        return _as_str_list(val)

    if isinstance(val, np.ndarray):
        try:
            return _as_str_list(val.reshape(-1).tolist())
        except Exception:
            return []

    if not isinstance(val, str):
        # None, NaN floats and every other scalar carry no paths.
        return []

    text = val.strip()
    if not text or text.lower() in ("nan", "none", "null"):
        return []

    looks_like_sequence = (text[0], text[-1]) in (("[", "]"), ("(", ")"))
    if looks_like_sequence:
        for parser in (json.loads, ast.literal_eval):
            try:
                parsed = parser(text)
                if isinstance(parsed, (list, tuple)):
                    return _as_str_list(parsed)
            except Exception:
                pass
    return [text]

def _resize_bool_mask(mask: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Return ``mask`` as bool, nearest-neighbour resized to (target_h, target_w) if needed."""
    mask = np.asarray(mask, dtype=bool)
    if mask.shape == (target_h, target_w):
        return mask
    im = Image.fromarray(mask.astype(np.uint8) * 255).resize((target_w, target_h), resample=Image.NEAREST)
    return np.array(im) > 0


def load_mask_any_as_bool(path: str, target_h: int, target_w: int):
    """Load a mask file in one of several formats as a bool (target_h, target_w) array.

    Supported inputs:
      * raster images (.png/.jpg/.jpeg/.bmp/.tif/.tiff/.webp): grayscale, ``> 0``;
      * ``.npz`` with packed bits under ``mask_pack`` plus ``mask_h`` / ``mask_w``;
      * anything else ``np.load`` can read (typically ``.npy``) holding 2-D or
        3-D data. 3-D data is max-reduced over its channel/instance axis: the
        leading axis when it is smaller than the last one (``(N, H, W)`` stacks
        of instance masks), otherwise the last axis (``(H, W, C)``).

    Masks of a different size are resized with nearest neighbour (thresholding
    commutes with nearest resampling, so the result equals "resize, then > 0").

    Returns:
        bool ndarray, or ``None`` when the file is missing or its content is not
        a recognised layout. Read/parse errors propagate to the caller.
    """
    p = Path(str(path))
    if not p.exists():
        return None
    suf = p.suffix.lower()

    if suf in (".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"):
        # `with` releases the file handle immediately; convert() returns a loaded copy.
        with Image.open(p) as im:
            gray = im.convert("L")
        return _resize_bool_mask(np.array(gray) > 0, target_h, target_w)

    if suf == ".npz":
        with np.load(p, allow_pickle=False) as z:
            if not {"mask_pack", "mask_h", "mask_w"}.issubset(z.files):
                return None
            mh = int(z["mask_h"])
            mw = int(z["mask_w"])
            pack = z["mask_pack"].astype(np.uint8).reshape(-1)
        bits = np.unpackbits(pack, axis=None)[: mh * mw].reshape(mh, mw).astype(bool)
        return _resize_bool_mask(bits, target_h, target_w)

    # npy / others.
    # NOTE(review): the allow_pickle=True fallback can execute code embedded in a
    # crafted file; acceptable only because inputs are trusted competition data.
    try:
        arr = np.load(p, allow_pickle=False)
    except Exception:
        arr = np.load(p, allow_pickle=True)
        if np.ndim(arr) == 0 and hasattr(arr, "item"):
            arr = arr.item()

    a = np.squeeze(np.asarray(arr))
    if a.ndim == 2:
        return _resize_bool_mask(a > 0, target_h, target_w)

    if a.ndim == 3:
        # Reduce the channel / instance axis: (N,H,W) -> axis 0, (H,W,C) -> axis -1.
        reduce_axis = 0 if a.shape[0] < a.shape[-1] else -1
        return _resize_bool_mask(a.max(axis=reduce_axis) > 0, target_h, target_w)

    return None

# ----------------------------
# GT cache: store union mask on PATCH GRID (GH x GW) packed bits
# - super small files, super fast training
# ----------------------------
# The cache directory is keyed by image and patch size, so changing either produces
# a separate cache instead of reusing grids of the wrong shape.
GT_CACHE_ROOT = Path(CFG["gt_cache_root"]) / f"sz{IMG_SIZE}_p{PATCH}"
GT_CACHE_ROOT.mkdir(parents=True, exist_ok=True)

def mask_to_grid(mask_bool_hw: np.ndarray, gh: int, gw: int, patch: int) -> np.ndarray:
    """Max-pool a (gh*patch, gw*patch) mask into a (gh, gw) patch grid.

    A grid cell is set when any pixel inside its ``patch x patch`` tile is set.
    Raises AssertionError if the mask size is not exactly ``gh*patch x gw*patch``.
    """
    rows, cols = mask_bool_hw.shape
    assert rows == gh * patch and cols == gw * patch
    tiles = mask_bool_hw.reshape(gh, patch, gw, patch)
    return np.max(tiles, axis=(1, 3))

def pack_grid_bool(grid_bool: np.ndarray) -> np.ndarray:
    """Flatten a boolean grid in row-major order and pack it 8 cells per uint8 byte.

    The last byte is zero-padded when the cell count is not a multiple of 8.
    """
    bits = grid_bool.astype(np.uint8).ravel()
    packed = np.packbits(bits, axis=None)
    return packed.astype(np.uint8)

def unpack_grid_pack(pack_u8: np.ndarray, gh: int, gw: int) -> np.ndarray:
    """Inverse of ``pack_grid_bool``: unpack bytes into a (gh, gw) uint8 grid of 0/1.

    Padding bits beyond ``gh * gw`` are discarded.
    """
    n_cells = gh * gw
    flat_bits = np.unpackbits(pack_u8.ravel().astype(np.uint8), axis=None)
    return flat_bits[:n_cells].reshape(gh, gw).astype(np.uint8)

def build_gt_cache_for_row(sample_id: str, y_forged: int, mask_paths_val):
    """Ensure ``GT_CACHE_ROOT/<sample_id>.npz`` exists for one training row.

    The file stores the union of all of the sample's mask files, max-pooled to
    the (GH, GW) patch grid and bit-packed (``grid_pack``), together with
    ``gh``/``gw``, the positive-cell count ``pos`` and an ``empty`` flag.
    Authentic rows (``y_forged == 0``) and rows without mask paths get an
    all-zero grid. Mask files that ``load_mask_any_as_bool`` cannot read
    (returns None) are skipped, so a forged row whose masks all fail to load is
    cached as empty.

    Existing files are reused as-is (no invalidation). Always returns True;
    exceptions from reading masks propagate to the caller.
    """
    outp = GT_CACHE_ROOT / f"{sample_id}.npz"
    if outp.exists():
        return True

    grid = np.zeros((GH, GW), dtype=bool)
    # Authentic rows never look at their mask paths.
    paths = parse_mask_paths(mask_paths_val) if int(y_forged) != 0 else []
    if paths:
        union = np.zeros((IMG_SIZE, IMG_SIZE), dtype=bool)
        for p in paths:
            m = load_mask_any_as_bool(p, IMG_SIZE, IMG_SIZE)
            if m is not None:
                union |= m
        grid = mask_to_grid(union, GH, GW, PATCH)

    pos = int(grid.sum())
    # Write to a temp file, then atomically rename. An interrupted direct write
    # would leave a truncated .npz that the exists() check above trusts forever.
    tmp = outp.with_name(f"{outp.stem}.tmp.npz")
    np.savez_compressed(str(tmp),
                        grid_pack=pack_grid_bool(grid),
                        gh=np.int16(GH), gw=np.int16(GW),
                        pos=np.int32(pos),
                        empty=np.int8(1 if pos == 0 else 0))
    os.replace(tmp, outp)
    return True

# Build (or reuse) one patch-grid GT cache file per training row.
print("\n[GT CACHE] Building missing GT grid cache files (only if absent) ...")
t0 = time.time()
built = 0
fail = 0
fail_examples = []  # first few (sample_id, error) pairs, so failures are not silent
for i, r in enumerate(df_train_all.itertuples(index=False), start=1):
    sid = getattr(r, "sample_id")
    yfg = getattr(r, "y_forged")
    mpv = getattr(r, "mask_paths")
    try:
        ok = build_gt_cache_for_row(str(sid), int(yfg), mpv)
        built += 1 if ok else 0
    except Exception as e:
        # Best-effort per row: one bad sample must not abort the whole build,
        # but keep the reason so data problems remain visible.
        fail += 1
        if len(fail_examples) < 10:
            fail_examples.append((str(sid), f"{type(e).__name__}: {e}"))
    if (i % CFG["print_every"]) == 0:
        print(f"[GT CACHE] {i:,}/{len(df_train_all):,} processed | fail={fail:,}")
dt = time.time() - t0
print(f"[GT CACHE] Done. processed={built:,} fail={fail:,} elapsed_s={dt:.1f} | dir={GT_CACHE_ROOT}")
if fail_examples:
    print("[GT CACHE] Example failures (the Dataset uses an all-zero grid when no cache file exists):")
    for fail_sid, fail_err in fail_examples:
        print(f"  - sample_id={fail_sid}: {fail_err}")

# ----------------------------
# Dataset
# ----------------------------
def pil_load_rgb(path: str):
    """Open an image file and return it as an RGB ``PIL.Image``, or ``None``.

    Best-effort by design: a missing file or any decode error yields ``None`` so
    the caller can substitute a placeholder instead of crashing a worker.
    """
    p = Path(path)
    if not p.exists():
        return None
    try:
        # convert() returns a fully loaded copy, so the source file handle can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(p) as im:
            return im.convert("RGB")
    except Exception:
        return None

def pil_resize_sq(img: Image.Image, size: int):
    """Resize ``img`` to ``size x size`` pixels with bicubic resampling.

    The aspect ratio is deliberately not preserved: non-square inputs are stretched.
    """
    square = (size, size)
    resized = img.resize(square, resample=Image.BICUBIC)
    return resized

def img_to_tensor_norm(img: Image.Image):
    """Convert an H x W x 3 PIL image to a float32 CHW tensor normalized with ImageNet stats.

    Pixels are scaled to [0, 1], then standardized with ``IMNET_MEAN`` / ``IMNET_STD``.
    """
    hwc = torch.from_numpy(np.array(img, dtype=np.float32) / 255.0)
    chw = hwc.permute(2, 0, 1).contiguous()
    normalized = (chw - IMNET_MEAN) / IMNET_STD
    return normalized

def apply_aug(image_t: torch.Tensor, grid_u8: torch.Tensor, cfg: dict,
              mean: torch.Tensor = None, std: torch.Tensor = None):
    """Apply light, geometry-consistent augmentation to an image and its patch grid.

    Args:
        image_t: normalized image tensor ``[3, H, W]``.
        grid_u8: GT patch grid ``[GH, GW]``. It receives the same flips/rotations
            so it stays aligned with the image.
        cfg: reads ``aug_hflip_p``, ``aug_vflip_p``, ``aug_rot90_p``,
            ``aug_brightness`` and ``aug_contrast``.
        mean, std: per-channel normalization constants (``[3, 1, 1]``) that were
            used to produce ``image_t``; default to ``IMNET_MEAN`` / ``IMNET_STD``.

    Returns:
        ``(image_t, grid_u8)``. Photometric changes (brightness, contrast) touch
        only the image, and their results are clamped to ``[-5, 5]``.
    """
    if mean is None:
        mean = IMNET_MEAN
    if std is None:
        std = IMNET_STD

    # Geometric: flips / rot90 on both image and grid.
    if random.random() < float(cfg["aug_hflip_p"]):
        image_t = torch.flip(image_t, dims=[2])
        grid_u8 = torch.flip(grid_u8, dims=[1])
    if random.random() < float(cfg["aug_vflip_p"]):
        image_t = torch.flip(image_t, dims=[1])
        grid_u8 = torch.flip(grid_u8, dims=[0])

    if random.random() < float(cfg["aug_rot90_p"]):
        k = random.randint(0, 3)
        if k > 0:
            image_t = torch.rot90(image_t, k=k, dims=[1, 2])
            grid_u8 = torch.rot90(grid_u8, k=k, dims=[0, 1])

    # Photometric: image only.
    b = float(cfg["aug_brightness"])
    c = float(cfg["aug_contrast"])
    if b > 0 and random.random() < 0.5:
        factor = 1.0 + random.uniform(-b, b)
        # Brightness scales raw intensities: raw' = factor * raw. In normalized space
        # that is x' = factor * x + (factor - 1) * mean / std. Scaling x alone would
        # pull pixels toward the ImageNet mean instead of toward black.
        offset = (mean / std).to(dtype=image_t.dtype, device=image_t.device)
        image_t = torch.clamp(image_t * factor + (factor - 1.0) * offset, -5.0, 5.0)
    if c > 0 and random.random() < 0.5:
        # Contrast around each channel's own mean (an affine map, so it is the same
        # operation in raw and normalized space).
        ch_mean = image_t.mean(dim=(1, 2), keepdim=True)
        factor = 1.0 + random.uniform(-c, c)
        image_t = torch.clamp((image_t - ch_mean) * factor + ch_mean, -5.0, 5.0)

    return image_t, grid_u8

class DinoMTDataset(Dataset):
    """Dataset for multi-task training over rows of ``df_train_all``.

    Each item is ``(x, g, y, sid)``:
      * ``x``: image resized to IMG_SIZE x IMG_SIZE and normalized, ``[3, IMG, IMG]`` float32;
      * ``g``: cached GT patch grid ``[GH, GW]`` as float32 0/1 (all-zero when the
        sample has no cache file);
      * ``y``: forged label ``[1]`` float32;
      * ``sid``: the sample id string.
    When ``train`` is True, ``apply_aug`` is applied jointly to image and grid.
    """

    def __init__(self, df: pd.DataFrame, train: bool):
        self.df = df.reset_index(drop=True)
        self.train = bool(train)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        sid = str(r["sample_id"])
        yfg = int(r["y_forged"])
        ip = str(r["image_path"])

        img = pil_load_rgb(ip)
        if img is None:
            # Missing / unreadable image -> black placeholder so the batch still forms.
            img = Image.fromarray(np.zeros((IMG_SIZE, IMG_SIZE, 3), dtype=np.uint8))
        img = pil_resize_sq(img, IMG_SIZE)
        x = img_to_tensor_norm(img)  # [3,IMG,IMG]

        gt_path = GT_CACHE_ROOT / f"{sid}.npz"
        if gt_path.exists():
            # Context manager closes the NpzFile handle deterministically (workers
            # otherwise rely on garbage collection to release it).
            with np.load(gt_path, allow_pickle=False) as z:
                pack = z["grid_pack"].astype(np.uint8).reshape(-1)
            grid = unpack_grid_pack(pack, GH, GW)  # uint8 0/1
        else:
            grid = np.zeros((GH, GW), dtype=np.uint8)

        g = torch.from_numpy(grid).to(torch.uint8)  # [GH,GW]

        if self.train:
            x, g = apply_aug(x, g, CFG)

        y = torch.tensor([float(yfg)], dtype=torch.float32)
        g = g.to(torch.float32)  # seg target as float 0/1 for BCE / Dice

        return x, g, y, sid

# ----------------------------
# Build DINOv2 multitask model
# ----------------------------
def find_encoder_layers(model):
    """Return the transformer block container ``<owner>.encoder.layer`` of a model, or None.

    Candidate owners are checked in order: ``model``, ``model.vit``, ``model.backbone``.
    """
    candidates = (model, getattr(model, "vit", None), getattr(model, "backbone", None))
    for owner in candidates:
        encoder = getattr(owner, "encoder", None)
        if encoder is not None and hasattr(encoder, "layer"):
            return encoder.layer
    return None

def infer_num_layers_from_names(model):
    """Count transformer blocks by scanning parameter names for ``.layer.<idx>.``.

    Returns ``max(idx) + 1`` over all parseable indices, or 0 when none are found.
    """
    indices = set()
    for name, _param in model.named_parameters():
        _, marker, rest = name.partition(".layer.")
        if not marker:
            continue
        try:
            indices.add(int(rest.split(".")[0]))
        except ValueError:
            continue
    return max(indices) + 1 if indices else 0

def freeze_all(model):
    """Disable gradients for every parameter of ``model`` (in place)."""
    for param in model.parameters():
        param.requires_grad_(False)

def unfreeze_last_n_blocks(model, n_last: int):
    """Set ``requires_grad=True`` on the parameters of the last ``n_last`` encoder blocks.

    Uses ``find_encoder_layers`` when the block container is found. Otherwise it
    falls back to block indices parsed from ``.layer.<idx>.`` parameter names.

    Returns:
        Number of blocks actually unfrozen (clamped to the block count), or 0
        when ``n_last <= 0`` or no blocks can be identified.
    """
    if n_last <= 0:
        return 0

    blocks = find_encoder_layers(model)
    if blocks is not None:
        total = len(blocks)
        count = min(int(n_last), total)
        for block_idx in range(total - count, total):
            for param in blocks[block_idx].parameters():
                param.requires_grad = True
        return count

    # Fallback: identify blocks purely from parameter names.
    total = infer_num_layers_from_names(model)
    if total <= 0:
        return 0
    count = min(int(n_last), total)
    wanted = set(range(total - count, total))
    for name, param in model.named_parameters():
        if ".layer." not in name:
            continue
        try:
            if int(name.split(".layer.")[1].split(".")[0]) in wanted:
                param.requires_grad = True
        except Exception:
            pass
    return count

class DinoMultiTask(nn.Module):
    """Backbone plus an image-level classification head and a per-patch segmentation head.

    The backbone is called as ``backbone(pixel_values=x)`` and must return an
    object with ``last_hidden_state`` of shape ``[B, T, D]``. Token 0 is used as
    the CLS token. Patch tokens are taken as the trailing ``gh * gw`` tokens.
    """

    def __init__(self, backbone: nn.Module, gh: int, gw: int):
        super().__init__()
        self.backbone = backbone
        self.gh = int(gh)
        self.gw = int(gw)

        # Embedding width: read it from the (HF-style) config when available; otherwise
        # infer it with one dummy forward pass at the global IMG_SIZE.
        D = getattr(getattr(backbone, "config", None), "hidden_size", None)
        if D is None:
            with torch.no_grad():
                dummy = torch.zeros((1, 3, IMG_SIZE, IMG_SIZE), dtype=torch.float32)
                D = self.backbone(pixel_values=dummy).last_hidden_state.shape[-1]
        D = int(D)
        self.embed_dim = D

        # heads: LayerNorm + linear, one logit each
        self.cls_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, 1)
        )
        self.seg_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, 1)  # per patch token
        )

    def forward(self, x):
        """Return ``(cls_logit [B], seg_grid [B, gh, gw])``.

        When fewer than ``gh * gw`` patch tokens are present, a near-square grid
        is inferred instead (its shape then differs from ``(gh, gw)``).
        """
        h = self.backbone(pixel_values=x).last_hidden_state  # [B, T, D]
        cls_logit = self.cls_head(h[:, 0, :]).squeeze(-1)    # [B]

        seg_patch = self.seg_head(h[:, 1:, :]).squeeze(-1)   # [B, T-1]
        B, N = seg_patch.shape
        n_grid = self.gh * self.gw
        if N >= n_grid:
            # Patch tokens come last. Variants with register tokens lay the sequence
            # out as [CLS, registers..., patches], so keep the trailing gh*gw tokens.
            # Taking the leading ones would mix register tokens into the grid and
            # drop real patches.
            seg_grid = seg_patch[:, N - n_grid:].reshape(B, self.gh, self.gw)
        else:
            # Too few tokens for the configured grid: infer one close to square.
            side = int(round(math.sqrt(N)))
            gh = max(1, side)
            gw = max(1, N // gh)
            gh = max(1, N // gw)
            seg_grid = seg_patch[:, :gh * gw].reshape(B, gh, gw)

        return cls_logit, seg_grid

def bce_with_pos_weight(logits, targets, pos_weight: float):
    """Mean binary cross-entropy on logits, with the positive term weighted by ``pos_weight``."""
    weight = torch.full((1,), float(pos_weight), dtype=logits.dtype, device=logits.device)
    return F.binary_cross_entropy_with_logits(logits, targets, pos_weight=weight)

def dice_loss_from_logits(seg_logits, seg_targets, eps=1e-6):
    """Soft Dice loss (1 - Dice) per sample, averaged over the batch.

    ``seg_logits`` and ``seg_targets`` share a leading batch dimension and are
    flattened per sample. ``eps`` smooths both numerator and denominator, so a
    sample with empty target and near-zero prediction scores Dice ~ 1.
    """
    batch = seg_logits.size(0)
    probs = torch.sigmoid(seg_logits).reshape(batch, -1)
    target = seg_targets.reshape(seg_targets.size(0), -1)
    overlap = (probs * target).sum(dim=1)
    denom = probs.sum(dim=1) + target.sum(dim=1)
    dice = (2.0 * overlap + eps) / (denom + eps)
    return (1.0 - dice).mean()

def get_trainable_state_dict(model: nn.Module):
    """Snapshot trainable parameters plus all head tensors as CPU tensors.

    Included: every parameter with ``requires_grad`` and every parameter or
    buffer under ``cls_head.`` / ``seg_head.``. Floating tensors are stored in
    float16 to keep checkpoints small; non-floating buffers keep their dtype.
    """
    head_prefixes = ("cls_head.", "seg_head.")
    full_sd = model.state_dict()
    snapshot = {
        name: full_sd[name].detach().half().cpu()
        for name, param in model.named_parameters()
        if param.requires_grad or name.startswith(head_prefixes)
    }
    for name, buf in model.named_buffers():
        if not name.startswith(head_prefixes):
            continue
        if torch.is_floating_point(buf):
            snapshot[name] = buf.detach().half().cpu()
        else:
            snapshot[name] = buf.detach().cpu()
    return snapshot

# ----------------------------
# Prepare data splits + weights
# ----------------------------
folds_all = df_train_all["fold"].values.astype(int)
unique_folds = sorted(np.unique(folds_all).tolist())

y_all = df_train_all["y_forged"].values.astype(int)
pos = float((y_all == 1).sum()); neg = float((y_all == 0).sum())
# Class imbalance ratio neg/pos (1.0 when a class is absent), kept for reporting.
cls_neg_pos_ratio = (neg / max(pos, 1.0)) if (pos > 0 and neg > 0) else 1.0
# The training DataLoader draws rows with the class-balancing weights `sample_w`
# below (WeightedRandomSampler), so both classes already appear with ~equal
# frequency. Applying neg/pos again as BCE pos_weight would correct the imbalance
# twice and bias the classifier toward "forged", so the cls loss stays unweighted.
pos_weight_cls = 1.0
print(f"[CLS POS WEIGHT] neg/pos={cls_neg_pos_ratio:.3f} -> pos_weight_cls={pos_weight_cls:.3f} (balanced sampler)")

# seg pos_weight: estimate from cached GT grids
# (neg_pixels / pos_pixels) on grid space
print("\n[SEG POS WEIGHT] Estimating from cached GT grids ...")
pos_pix = 0
tot_pix = 0
for sid in df_train_all["sample_id"].tolist():
    p = GT_CACHE_ROOT / f"{sid}.npz"
    if not p.exists():
        continue
    with np.load(p, allow_pickle=False) as z:
        pack = z["grid_pack"].astype(np.uint8).reshape(-1)
    # packbits zero-pads the last byte, so the popcount equals the positive-cell count.
    cnt = int(POPCNT[pack].sum())
    pos_pix += cnt
    tot_pix += (GH * GW)
neg_pix = max(0, tot_pix - pos_pix)
pos_weight_seg = (neg_pix / max(pos_pix, 1)) if pos_pix > 0 else 1.0
pos_weight_seg = float(np.clip(pos_weight_seg, 1.0, 50.0))  # clamp to avoid extreme
print(f"[SEG POS WEIGHT] pos_pix={pos_pix:,} tot_pix={tot_pix:,} => pos_weight_seg={pos_weight_seg:.3f}")

# sampler weights (balance class): each class gets total weight (pos+neg)/2
w_pos = (pos + neg) / (2.0 * max(pos, 1.0))
w_neg = (pos + neg) / (2.0 * max(neg, 1.0))
sample_w = np.where(y_all == 1, w_pos, w_neg).astype(np.float32)

# ----------------------------
# Training / Eval helpers
# ----------------------------
def _per_sample_dice(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Hard Dice score per sample for binary uint8 arrays of shape [N, ...].

    A sample whose prediction and target are both empty scores 1.0.
    """
    n = len(pred)
    pred_sum = pred.reshape(n, -1).sum(axis=1).astype(np.float32)
    gt_sum = gt.reshape(len(gt), -1).sum(axis=1).astype(np.float32)
    inter = (pred & gt).reshape(n, -1).sum(axis=1).astype(np.float32)

    scores = np.ones(n, dtype=np.float32)
    nonempty = (pred_sum != 0) | (gt_sum != 0)
    scores[nonempty] = (2.0 * inter[nonempty]) / (pred_sum[nonempty] + gt_sum[nonempty] + 1e-6)
    return scores


@torch.no_grad()
def evaluate(model, loader):
    """Run ``model`` over ``loader`` and return predictions plus validation metrics.

    Returned dict (per-sample arrays are in loader order):
      * ``sid`` list[str], ``y`` int8 labels, ``p_cls`` float32 sigmoid probs,
        ``seg_prob`` float16 sigmoid grids, ``gt_grid`` uint8 GT grids;
      * ``auc`` / ``ap``: NaN when only one class is present;
      * ``dice_forg@0.5``: mean hard Dice at threshold 0.5 over forged samples
        (0.0 when there are none);
      * ``score = 0.65 * AUC (NaN -> 0) + 0.35 * dice_forg@0.5``, used for
        checkpoint selection.
    """
    model.eval()
    sids, labels, cls_chunks, seg_chunks, gt_chunks = [], [], [], [], []

    for x, g, y, sid in loader:
        cls_logit, seg_logit = model(x.to(device, non_blocking=True))
        cls_chunks.append(torch.sigmoid(cls_logit).detach().cpu().numpy().astype(np.float32))
        seg_chunks.append(torch.sigmoid(seg_logit).detach().cpu().numpy().astype(np.float16))
        sids.extend(sid)
        labels.append(y.detach().cpu().reshape(-1).numpy().astype(np.int8))
        gt_chunks.append(g.detach().cpu().numpy().astype(np.uint8))

    y_true = np.concatenate(labels, axis=0)
    p_cls = np.concatenate(cls_chunks, axis=0)
    seg_prob = np.concatenate(seg_chunks, axis=0)  # [N, gh, gw] float16
    gt_grid = np.concatenate(gt_chunks, axis=0)    # [N, gh, gw] uint8

    if np.unique(y_true).size < 2:
        auc = float("nan")
        ap = float("nan")
    else:
        auc = float(roc_auc_score(y_true, p_cls))
        ap = float(average_precision_score(y_true, p_cls))

    # Hard Dice at a fixed 0.5 threshold: a monitoring signal, not the tuned threshold.
    dice = _per_sample_dice((seg_prob >= 0.5).astype(np.uint8), (gt_grid > 0).astype(np.uint8))
    is_forged = y_true == 1
    dice_forg = float(dice[is_forged].mean()) if is_forged.any() else 0.0

    # Favour classification but keep segmentation quality in the selection score.
    score = float((0.65 * (auc if np.isfinite(auc) else 0.0)) + (0.35 * dice_forg))

    return {
        "sid": sids,
        "y": y_true,
        "p_cls": p_cls,
        "seg_prob": seg_prob,   # float16
        "gt_grid": gt_grid,     # uint8
        "auc": auc,
        "ap": ap,
        "dice_forg@0.5": dice_forg,
        "score": score,
    }

def build_optimizer(model: nn.Module):
    """Create AdamW with two param groups taken from CFG.

    Trainable backbone parameters use ``lr_backbone`` and come first. Trainable
    head parameters (``cls_head.*`` / ``seg_head.*``) use ``lr_heads``. Both
    groups use ``weight_decay``. Frozen parameters and empty groups are skipped.
    """
    head_params, backbone_params = [], []
    for name, param in model.named_parameters():
        if param.requires_grad:
            target = head_params if name.startswith(("cls_head.", "seg_head.")) else backbone_params
            target.append(param)

    groups = [
        {"params": params, "lr": float(CFG[lr_key]), "weight_decay": float(CFG["weight_decay"])}
        for params, lr_key in ((backbone_params, "lr_backbone"), (head_params, "lr_heads"))
        if params
    ]
    return torch.optim.AdamW(groups)

def cosine_lr(step, total, lr_max):
    """Half-cosine decay: ``lr_max`` at step 0 down to 0 at ``step >= total``.

    ``step`` is clamped into ``[0, total]``. With ``total <= 1`` the schedule is
    constant at ``lr_max``.
    """
    if total <= 1:
        return lr_max
    progress = step / total
    if progress < 0.0:
        progress = 0.0
    elif progress > 1.0:
        progress = 1.0
    return lr_max * 0.5 * (1.0 + math.cos(math.pi * progress))

# ----------------------------
# Output directory (versioned)
# ----------------------------
# Version the output directory by a hash of CFG (md5 used as a fingerprint, not for
# security). Code changes that leave CFG untouched reuse the same directory.
RUN_TAG = hashlib.md5(json.dumps(CFG, sort_keys=True).encode()).hexdigest()[:10]
OUT_DIR = Path(CFG["out_root"]) / f"dinov2_mt_v5_{RUN_TAG}"
CKPT_DIR = OUT_DIR / "checkpoints"
OUT_DIR.mkdir(parents=True, exist_ok=True)
CKPT_DIR.mkdir(parents=True, exist_ok=True)

print("\nOUT_DIR:", OUT_DIR)
print("device:", device, "| folds:", unique_folds)

# ----------------------------
# CV TRAIN
# ----------------------------
n_all = len(df_train_all)
oof_p_cls = np.zeros(n_all, dtype=np.float32)
oof_seg   = np.zeros((n_all, GH, GW), dtype=np.float16)
oof_gt    = np.zeros((n_all, GH, GW), dtype=np.uint8)
oof_y     = df_train_all["y_forged"].values.astype(np.int8)
oof_fold  = df_train_all["fold"].values.astype(int)
oof_sid   = df_train_all["sample_id"].astype(str).tolist()

fold_reports = {}
best_fold_paths = {}

# Validated once: grad_accum < 1 would otherwise divide by zero in the step check.
GRAD_ACCUM = max(1, int(CFG["grad_accum"]))
CLIP_GRAD = float(CFG["clip_grad"])


def _optimizer_step(model, opt, scaler, clip_grad: float):
    """Apply one optimizer update through the AMP GradScaler, then clear the grads.

    Gradients are unscaled before clipping so ``clip_grad`` bounds the true
    gradient norm. ``clip_grad <= 0`` disables clipping.
    """
    if clip_grad > 0:
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
    scaler.step(opt)
    scaler.update()
    opt.zero_grad(set_to_none=True)


for f in unique_folds:
    print(f"\n====================\nFOLD {f}\n====================")
    tr_idx = np.where(oof_fold != f)[0]
    va_idx = np.where(oof_fold == f)[0]
    if len(va_idx) == 0:
        print("Skip fold (no val).")
        continue

    # Release the previous fold's model/optimizer before loading a fresh backbone,
    # so two backbone copies (plus optimizer state) are never resident at once.
    model = backbone = opt = scaler = None
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

    df_tr = df_train_all.iloc[tr_idx].reset_index(drop=True)
    df_va = df_train_all.iloc[va_idx].reset_index(drop=True)

    ds_tr = DinoMTDataset(df_tr, train=True)
    ds_va = DinoMTDataset(df_va, train=False)

    # sampler for train (class-balanced draws with replacement)
    sw = sample_w[tr_idx]
    sampler = WeightedRandomSampler(weights=torch.from_numpy(sw), num_samples=len(sw), replacement=True)

    dl_tr = DataLoader(ds_tr, batch_size=int(CFG["batch_size"]), sampler=sampler,
                       num_workers=int(CFG["num_workers"]), pin_memory=(device.type=="cuda"))
    # Validation order must match df_va (shuffle=False): OOF rows are mapped by position.
    dl_va = DataLoader(ds_va, batch_size=int(CFG["batch_size"]), shuffle=False,
                       num_workers=int(CFG["num_workers"]), pin_memory=(device.type=="cuda"))

    # backbone (fresh pretrained weights for every fold)
    print("Loading DINO from:", CFG["dino_dir"])
    backbone = AutoModel.from_pretrained(str(CFG["dino_dir"]), local_files_only=True)
    backbone.eval()

    if CFG.get("use_grad_ckpt", False) and hasattr(backbone, "gradient_checkpointing_enable"):
        try:
            backbone.gradient_checkpointing_enable()
            print("gradient checkpointing: ON")
        except Exception:
            pass

    freeze_all(backbone)
    used_unfreeze = unfreeze_last_n_blocks(backbone, int(CFG["unfreeze_last_n_blocks"]))
    print("unfreeze_last_n_blocks used:", used_unfreeze)

    model = DinoMultiTask(backbone, gh=GH, gw=GW).to(device)
    # heads always trainable
    for p in model.cls_head.parameters():
        p.requires_grad = True
    for p in model.seg_head.parameters():
        p.requires_grad = True

    opt = build_optimizer(model)
    # Initial lr of each param group. The cosine schedule must scale THESE values.
    # Rescaling pg["lr"] in place (the previous behaviour) compounds the decay
    # every step and collapses the lr toward 0 early in training.
    base_lrs = [float(pg["lr"]) for pg in opt.param_groups]

    # AMP scaler
    use_amp = bool(CFG["amp"]) and (device.type == "cuda")
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # training steps for the cosine schedule (counted in micro-batches)
    steps_per_epoch = max(1, len(dl_tr))
    total_steps = int(CFG["epochs"]) * steps_per_epoch

    best_score = -1e9
    best_state = None
    best_epoch = -1
    bad_epochs = 0

    global_step = 0
    t0 = time.time()

    for epoch in range(int(CFG["epochs"])):
        model.train()
        running = 0.0
        n_seen = 0
        n_iters = 0

        for it, (x, g, y, sid) in enumerate(dl_tr, start=1):
            x = x.to(device, non_blocking=True)
            g = g.to(device, non_blocking=True)         # [B,GH,GW]
            y = y.to(device, non_blocking=True).view(-1)  # [B]

            # Cosine schedule applied to each group's initial lr.
            lr_scale = cosine_lr(global_step, total_steps, 1.0)
            for pg, base_lr in zip(opt.param_groups, base_lrs):
                pg["lr"] = base_lr * lr_scale

            with torch.cuda.amp.autocast(enabled=use_amp):
                cls_logit, seg_logit = model(x)

                # cls loss
                loss_cls = bce_with_pos_weight(cls_logit, y, pos_weight=pos_weight_cls)

                # seg loss on grid
                loss_seg_bce  = bce_with_pos_weight(seg_logit, g, pos_weight=pos_weight_seg)
                loss_seg_dice = dice_loss_from_logits(seg_logit, g, eps=float(CFG["dice_eps"]))

                loss = (float(CFG["w_cls"]) * loss_cls
                        + float(CFG["w_seg_bce"]) * loss_seg_bce
                        + float(CFG["w_seg_dice"]) * loss_seg_dice)

            loss = loss / float(GRAD_ACCUM)

            scaler.scale(loss).backward()

            if (it % GRAD_ACCUM) == 0:
                _optimizer_step(model, opt, scaler, CLIP_GRAD)

            running += float(loss.item()) * float(GRAD_ACCUM)
            n_seen += x.size(0)
            global_step += 1
            n_iters = it

            if (it % int(CFG["print_every"])) == 0:
                dt = time.time() - t0
                print(f"epoch={epoch+1}/{CFG['epochs']} it={it}/{len(dl_tr)} "
                      f"loss={running/max(1,it):.4f} seen={n_seen} step={global_step} elapsed_s={dt:.1f}")

        # Flush a trailing partial accumulation window so its gradients are applied
        # now instead of leaking into the next epoch's first optimizer step.
        if (n_iters % GRAD_ACCUM) != 0:
            _optimizer_step(model, opt, scaler, CLIP_GRAD)

        # eval
        ev = evaluate(model, dl_va)
        print(f"[VAL] epoch={epoch+1} score={ev['score']:.5f} AUC={ev['auc']:.5f} AP={ev['ap']:.5f} dice_forg@0.5={ev['dice_forg@0.5']:.5f}")

        if ev["score"] > best_score:
            best_score = ev["score"]
            best_epoch = epoch + 1
            bad_epochs = 0
            # save trainable state (fp16 CPU copy)
            best_state = get_trainable_state_dict(model)
        else:
            bad_epochs += 1
            if bad_epochs >= int(CFG["patience"]):
                print(f"Early stop on fold {f} at epoch {epoch+1}. Best epoch={best_epoch} best_score={best_score:.5f}")
                break

        gc.collect()
        if device.type == "cuda":
            torch.cuda.empty_cache()

    # load best state into model (for OOF extraction)
    if best_state is not None:
        cur = model.state_dict()
        for k, v in best_state.items():
            if k in cur:
                cur[k] = v.to(cur[k].dtype)
        model.load_state_dict(cur, strict=False)

    # final val for OOF storing
    ev = evaluate(model, dl_va)

    # write OOF to global arrays: val position j -> global row va_idx[j]
    for j, sid in enumerate(ev["sid"]):
        gi = va_idx[j]
        oof_p_cls[gi] = ev["p_cls"][j]
        # ensure shapes align
        sg = ev["seg_prob"][j]
        gt = ev["gt_grid"][j]
        if sg.shape != (GH, GW):
            # resize fallback via nearest on numpy (rare)
            sg = np.array(Image.fromarray(sg.astype(np.float32)).resize((GW, GH), resample=Image.NEAREST)).astype(np.float16)
        if gt.shape != (GH, GW):
            gt = np.array(Image.fromarray(gt.astype(np.uint8)*255).resize((GW, GH), resample=Image.NEAREST) > 0).astype(np.uint8)
        oof_seg[gi] = sg
        oof_gt[gi]  = gt

    # save fold checkpoint (trainable only)
    ckpt_path = CKPT_DIR / f"fold_{int(f)}.pt"
    torch.save({
        "fold": int(f),
        "best_epoch": int(best_epoch),
        "best_score": float(best_score),
        "trainable_state": best_state,
        "unfreeze_last_n_blocks": int(CFG["unfreeze_last_n_blocks"]),
        "img_size": int(IMG_SIZE),
        "patch_size": int(PATCH),
        "gh": int(GH),
        "gw": int(GW),
        "dino_dir": str(CFG["dino_dir"]),
    }, ckpt_path)

    best_fold_paths[int(f)] = str(ckpt_path)

    fold_reports[int(f)] = {
        "best_epoch": int(best_epoch),
        "best_score": float(best_score),
        "val_auc": float(ev["auc"]) if np.isfinite(ev["auc"]) else None,
        "val_ap": float(ev["ap"]) if np.isfinite(ev["ap"]) else None,
        "val_dice_forg@0.5": float(ev["dice_forg@0.5"]),
        "ckpt_path": str(ckpt_path),
    }

    print(f"[FOLD {f}] saved: {ckpt_path}")

# ----------------------------
# Calibration on OOF (cls only)
# ----------------------------
y = oof_y.astype(int)
p_raw = oof_p_cls.astype(np.float32)

calibrator = None
calib_kind = "none"

# Isotonic regression is flexible but needs both classes, enough rows and enough
# distinct scores. Otherwise (or if it fails) fall back to Platt scaling, i.e. a
# 1-D logistic regression on the raw probability.
if np.unique(y).size >= 2 and len(y) >= 200 and np.unique(p_raw).size >= 50:
    try:
        iso = IsotonicRegression(out_of_bounds="clip")
        iso.fit(p_raw, y)
        calibrator = iso
        calib_kind = "isotonic"
    except Exception as e:
        print(f"[CALIB] isotonic failed ({type(e).__name__}: {e}); falling back to Platt.")
else:
    print("[CALIB] not enough classes/rows/unique probs for isotonic; falling back to Platt.")

if calibrator is None:
    try:
        platt = LogisticRegression(solver="lbfgs", max_iter=4000)
        platt.fit(p_raw.reshape(-1,1), y)
        calibrator = platt
        calib_kind = "platt"
    except Exception as e:
        # Keep going uncalibrated, but say why.
        print(f"[CALIB] Platt failed ({type(e).__name__}: {e}); using raw probabilities.")
        calibrator = None
        calib_kind = "none"

def apply_calibrator(p):
    """Map raw probabilities through the fitted calibrator (module-level globals).

    Returns a float32 array. When no calibrator was fitted (``calib_kind == "none"``)
    the input is returned unchanged (as float32). Isotonic models use ``transform``;
    anything else is treated as a Platt/logistic model and uses the positive-class
    column of ``predict_proba`` on a single-feature matrix.
    """
    probs = np.asarray(p, dtype=np.float32)
    if calibrator is None or calib_kind == "none":
        return probs
    if calib_kind != "isotonic":
        # Platt: sklearn expects a 2-D (n_samples, 1) feature matrix
        return calibrator.predict_proba(probs.reshape(-1, 1))[:, 1].astype(np.float32)
    return calibrator.transform(probs).astype(np.float32)

# Calibrated OOF probabilities (also used by the threshold search below).
p_cal = apply_calibrator(p_raw)

# OOF cls metrics (AUC / AP are undefined with a single class -> reported as NaN)
if np.unique(y).size >= 2:
    auc_raw = float(roc_auc_score(y, p_raw))
    ap_raw  = float(average_precision_score(y, p_raw))
    auc_cal = float(roc_auc_score(y, p_cal))
    ap_cal  = float(average_precision_score(y, p_cal))
else:
    auc_raw = ap_raw = auc_cal = ap_cal = float("nan")

print("\n[OOF CLS] raw : AUC=%.5f AP=%.5f" % (auc_raw, ap_raw))
print("[OOF CLS] cal(%s): AUC=%.5f AP=%.5f" % (calib_kind, auc_cal, ap_cal))

# save calibrator as {"kind", "calibrator"}; STAGE 6 reloads exactly these two keys
joblib.dump({"kind": calib_kind, "calibrator": calibrator}, OUT_DIR / "calibrator.joblib")

# ----------------------------
# Threshold tuning (vectorized, FAST)
# - tune: thr_forged, thr_mask, min_pred_patches
# Objective (weighted per-fold), mirroring the STAGE 6 strict guard:
#   keep = pf AND pred_nonempty
#   if gt_empty:    score = 1 - keep
#   if gt_nonempty: score = keep * dice
# where pf = (p_cal >= thr_forged)
# dice computed on grid with threshold thr_mask
# pred_nonempty = (predicted patch count >= min_pred_patches)
# NOTE: min_pred_patches must gate forged samples too. Without that, the objective is
# monotone non-decreasing in min_pred_patches (it can only help on authentic images),
# so the search drifts to the largest candidate. The inference guard then rejects
# small true forgeries that the tuning never penalized.
# ----------------------------
gt = (oof_gt > 0).astype(np.uint8)                 # [N,GH,GW]
gt_sum = gt.reshape(n_all, -1).sum(axis=1).astype(np.int32)
gt_empty = (gt_sum == 0)

# balanced weights for objective (both classes contribute equally in total)
w_pos2 = (y.sum() + (y==0).sum()) / (2.0 * max(y.sum(), 1))
w_neg2 = (y.sum() + (y==0).sum()) / (2.0 * max((y==0).sum(), 1))
w_obj = np.where(y == 1, w_pos2, w_neg2).astype(np.float32)

thr_p_list = np.linspace(0.05, 0.95, 37).astype(np.float32)
thr_m_list = np.linspace(0.20, 0.80, 31).astype(np.float32)
min_patches_list = np.array([0, 1, 2, 4, 8, 16], dtype=np.int32)

# precompute per-fold masks
fold_ids = sorted(np.unique(oof_fold).tolist())
fold_masks = [(f, (oof_fold == f)) for f in fold_ids]

# precompute dice and pred_sum for each thr_mask
seg = oof_seg.astype(np.float32)  # [N,GH,GW] (float16 -> float32)

dice_by_tm = np.zeros((len(thr_m_list), n_all), dtype=np.float32)
psum_by_tm = np.zeros((len(thr_m_list), n_all), dtype=np.int32)

print("\n[THR] Precomputing dice / pred_sum across thr_mask ...")
t0 = time.time()
for i_tm, tm in enumerate(thr_m_list):
    pred = (seg >= float(tm)).astype(np.uint8)
    ps = pred.reshape(n_all, -1).sum(axis=1).astype(np.int32)
    inter = (pred & gt).reshape(n_all, -1).sum(axis=1).astype(np.int32)
    gs = gt_sum.astype(np.int32)

    dice = np.zeros(n_all, dtype=np.float32)
    both0 = (ps == 0) & (gs == 0)
    dice[both0] = 1.0
    m = ~both0
    dice[m] = (2.0 * inter[m].astype(np.float32)) / (ps[m].astype(np.float32) + gs[m].astype(np.float32) + 1e-6)

    dice_by_tm[i_tm] = dice
    psum_by_tm[i_tm] = ps

dt = time.time() - t0
print(f"[THR] Precompute done in {dt:.1f}s")

best = {
    "score": -1e9,
    "thr_forged": 0.5,
    "thr_mask": 0.5,
    "min_pred_patches": 0,
}

print("\n[THR] Grid search thr_forged x thr_mask x min_pred_patches (weighted fold-mean) ...")
t0 = time.time()

ne = ~gt_empty  # loop-invariant: samples with a non-empty GT mask

for tp in thr_p_list:
    pf = (p_cal >= float(tp))  # [N] bool
    for i_tm, tm in enumerate(thr_m_list):
        dice = dice_by_tm[i_tm]   # [N]
        ps   = psum_by_tm[i_tm]   # [N] int

        for mp in min_patches_list:
            pred_nonempty = (ps >= int(mp))  # [N] bool
            # a prediction survives only if BOTH gates pass (same as strict_guard at inference)
            keep = (pf & pred_nonempty).astype(np.float32)

            # score per sample:
            # gt_empty:    1 - keep
            # gt_nonempty: keep * dice   (a rejected prediction is an empty mask -> dice 0)
            s = np.zeros(n_all, dtype=np.float32)
            s[gt_empty] = 1.0 - keep[gt_empty]
            s[ne] = keep[ne] * dice[ne].astype(np.float32)

            # fold mean weighted
            fold_scores = []
            for f, m in fold_masks:
                if m.sum() == 0:
                    continue
                ww = w_obj[m]
                fold_scores.append(float(np.sum(s[m] * ww) / (np.sum(ww) + 1e-12)))
            mean_score = float(np.mean(fold_scores)) if fold_scores else -1e9

            # strict '>' keeps the first (lowest) grid point on ties
            if mean_score > best["score"]:
                best = {
                    "score": mean_score,
                    "thr_forged": float(tp),
                    "thr_mask": float(tm),
                    "min_pred_patches": int(mp),
                }

print(f"[THR] Search done in {time.time()-t0:.1f}s")
print("Best thresholds:")
print(json.dumps(best, indent=2))

# ----------------------------
# Save artifacts
# ----------------------------
# Thresholds consumed by STAGE 6 (it reads thr_forged/thr_p, thr_mask,
# min_pred_patches and patch_size from thresholds.json).
thresholds = {
    "thr_forged": best["thr_forged"],
    "thr_p": best["thr_forged"],          # alias for compatibility
    "thr_mask": best["thr_mask"],
    "min_pred_patches": best["min_pred_patches"],
    "grid_hw": [int(GH), int(GW)],
    "img_size": int(IMG_SIZE),
    "patch_size": int(PATCH),
    "calibration": calib_kind,
}
with open(OUT_DIR / "thresholds.json", "w") as f:
    json.dump(thresholds, f, indent=2)

# Model description: base DINO location plus which blocks the fold checkpoints override.
model_config = {
    "dino_dir": str(CFG["dino_dir"]),
    "img_size": int(IMG_SIZE),
    "patch_size": int(PATCH),
    "grid_hw": [int(GH), int(GW)],
    "unfreeze_last_n_blocks": int(CFG["unfreeze_last_n_blocks"]),
    "note": "Checkpoints store trainable_state only (heads + unfrozen blocks). Load base DINO from dino_dir, then apply trainable_state.",
    "fold_checkpoints": best_fold_paths,
}
with open(OUT_DIR / "model_config.json", "w") as f:
    json.dump(model_config, f, indent=2)

# OOF audit file (one row per training sample)
df_oof = pd.DataFrame({
    "sample_id": oof_sid,
    "fold": oof_fold,
    "y_forged": oof_y,
    "p_raw": p_raw.astype(np.float32),
    "p_cal": p_cal.astype(np.float32),
    "gt_sum": gt_sum.astype(np.int32),
})
df_oof.to_csv(OUT_DIR / "oof_predictions.csv", index=False)

# final report
# NOTE(review): json.dump raises TypeError if CFG holds non-JSON values (e.g. Path
# objects); fold_reports may also contain NaN, which json writes as non-standard NaN.
# Confirm CFG only contains plain str/int/float/bool/list/dict.
report = {
    "cfg": CFG,
    "out_dir": str(OUT_DIR),
    "device": str(device),
    "n_train": int(n_all),
    "forged_rate": float(oof_y.mean()),
    "oof_auc_raw": auc_raw if np.isfinite(auc_raw) else None,
    "oof_ap_raw": ap_raw if np.isfinite(ap_raw) else None,
    "oof_auc_cal": auc_cal if np.isfinite(auc_cal) else None,
    "oof_ap_cal": ap_cal if np.isfinite(ap_cal) else None,
    "thresholds": thresholds,
    "fold_reports": fold_reports,
}
with open(OUT_DIR / "report.json", "w") as f:
    json.dump(report, f, indent=2)

# print checkpoint sizes
sizes = {}
for k, p in best_fold_paths.items():
    pp = Path(p)
    if pp.exists():
        sizes[str(k)] = float(pp.stat().st_size / (1024**2))
print("\nCheckpoint sizes (MB) per fold:", sizes)

print("\nSAVED ->", OUT_DIR)
print("Files:", sorted([p.name for p in OUT_DIR.iterdir()]))

# Exported as a notebook global; STAGE 6 checks globals() for this name first.
DINO_MT_MODEL_DIR = str(OUT_DIR)
print("\nDONE. Exported: DINO_MT_MODEL_DIR =", DINO_MT_MODEL_DIR)


# Inference Strategy: Two-Pass + Smart Ensemble + Export RLE (Strict Guard)

In [None]:
# ============================================================
# STAGE 6 — DINOv2 Multi-Task Inference: Two-Pass + Smart Ensemble + Export RLE (Strict Guard)
# ONE CELL, REVISI FULL v3 (FAST + ROBUST + STRICT ORDER)
#
# Requirements:
# - Output STAGE 5 (DINO Multi-task) exists:
#   * DINO_MT_MODEL_DIR (or auto-find /kaggle/working/recodai_luc/models/dinov2_mt_v5_*)
#   * model_config.json, thresholds.json, calibrator.joblib, checkpoints/fold_*.pt
#
# Key upgrades:
# - Mean-weights ensemble across folds (FAST). Optional exact fold ensemble if needed.
# - PASS-1 small (448) + PASS-2 large (700) only for borderline.
# - Strict guard uses: thr_forged + thr_mask + min_pred_patches + anti-huge/fragment rules.
# - Cache PASS-1/PASS-2 masks as NPZ (mask_pack) => rerun is fast.
# - Export RLE strictly in sample_submission order.
# ============================================================

import os, gc, json, time, math, re
from pathlib import Path

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel
import joblib

# ----------------------------
# Optional SciPy for morphology / components
# ----------------------------
try:
    import scipy.ndimage as ndi
    _HAS_SCIPY = True
except Exception:
    # Without SciPy, filter_mask skips morphology / component filtering and
    # returns the mask unchanged.
    _HAS_SCIPY = False

# ----------------------------
# PATHS (fixed)
# ----------------------------
DATA_ROOT = Path("/kaggle/input/recodai-luc-scientific-image-forgery-detection")
TEST_IMAGES_DIR = DATA_ROOT / "test_images"
SAMPLE_SUB_PATH = DATA_ROOT / "sample_submission.csv"

# ----------------------------
# CACHE DIRS
# ----------------------------
# One NPZ per (case_id, pass size tag) is written here; existing files are reused on rerun.
P1_DIR = Path("/kaggle/working/recodai_luc/cache/dino_mt_p1")
P2_DIR = Path("/kaggle/working/recodai_luc/cache/dino_mt_p2")
P1_DIR.mkdir(parents=True, exist_ok=True)
P2_DIR.mkdir(parents=True, exist_ok=True)

OUT_SUB_PATH  = Path("/kaggle/working/submission.csv")
OUT_COPY_PATH = Path("/kaggle/working/recodai_luc/outputs/submission.csv")
OUT_COPY_PATH.parent.mkdir(parents=True, exist_ok=True)

# ----------------------------
# USER TUNABLE (safe defaults)
# ----------------------------
PASS1_SIZE = 448          # must be multiple of patch_size=14 (OK)
PASS2_SIZE = 700          # must be multiple of 14 (OK)
PASS1_BS   = 8            # auto-reduce if OOM
PASS2_BS   = 4
USE_PASS2  = True

BORDER_MARGIN = 0.08      # borderline if |p-thr| <= margin
MAX_BORDERLINE = None     # set int to cap PASS-2 (e.g., 1500)

# FAST ensemble: average weights across folds (recommended)
EXACT_FOLD_ENSEMBLE = False  # if True: run all folds and average outputs (slower)

# RLE order (if you already set global RLE_ORDER, it will use it)
# "F" flattens column-major, "C" row-major; any other value falls back to "F".
RLE_ORDER = globals().get("RLE_ORDER", "F")
if RLE_ORDER not in ("F","C"):
    RLE_ORDER = "F"

# ----------------------------
# Require sample submission
# ----------------------------
# The sample submission defines both the set and the order of output rows.
if not SAMPLE_SUB_PATH.exists():
    raise FileNotFoundError(f"sample_submission.csv not found: {SAMPLE_SUB_PATH}")

df_sample = pd.read_csv(SAMPLE_SUB_PATH)
# Column names are matched case-insensitively, then normalized to "case_id" / "annotation".
if not {"case_id","annotation"}.issubset({c.lower() for c in df_sample.columns}):
    raise ValueError(f"sample_submission must contain case_id, annotation. Found: {list(df_sample.columns)}")

col_case = [c for c in df_sample.columns if c.lower()=="case_id"][0]
col_ann  = [c for c in df_sample.columns if c.lower()=="annotation"][0]
df_sample = df_sample.rename(columns={col_case:"case_id", col_ann:"annotation"}).copy()
df_sample["case_id"] = df_sample["case_id"].astype(str)

# ----------------------------
# Resolve test images in STRICT sample order
# ----------------------------
IMG_EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp", ".webp"}

def build_caseid_map(folder: Path) -> dict:
    """Map each image file stem (case_id) under ``folder`` to its path.

    The search is recursive and the extension check is case-insensitive. Candidates
    are visited in sorted path order, and the first path seen for a stem wins, so
    duplicates in other formats or subfolders are ignored deterministically.
    """
    image_files = sorted(
        f for f in folder.rglob("*")
        if f.is_file() and f.suffix.lower() in IMG_EXTS
    )
    mapping = {}
    for f in image_files:
        mapping.setdefault(f.stem, f)
    return mapping

test_img_map = build_caseid_map(TEST_IMAGES_DIR)
df_test = pd.DataFrame({"case_id": df_sample["case_id"].astype(str).tolist()})
# Unresolved case_ids get image_path "" (not a real file).
df_test["image_path"] = df_test["case_id"].map(lambda x: str(test_img_map.get(str(x), "")))

# Path("") is Path("."), which exists, so an .exists() check would count unresolved
# rows as found. Require a non-empty path that points at a regular file.
n_ok = int(df_test["image_path"].map(lambda p: bool(p) and Path(p).is_file()).sum())
print(f"Test images resolved: {n_ok:,}/{len(df_test):,}")

# ----------------------------
# RLE encode
# ----------------------------
def rle_encode(mask: np.ndarray, order: str="F") -> str:
    """Run-length encode a binary mask as "start length start length ...".

    Starts are 1-based positions in the flattened mask. ``order="F"``
    (case-insensitive) flattens column-major; any other value flattens row-major.
    An all-zero mask encodes to the empty string.
    """
    binary = (mask > 0).astype(np.uint8)
    if not binary.any():
        return ""
    flat = binary.T.ravel() if order.upper() == "F" else binary.ravel()
    # Zero-pad both ends so every run has a rising and a falling edge.
    padded = np.pad(flat, 1)
    edges = np.flatnonzero(padded[1:] != padded[:-1]) + 1
    starts = edges[0::2]
    lengths = edges[1::2] - starts
    pairs = np.column_stack([starts, lengths]).ravel()
    return " ".join(str(v) for v in pairs)

# ----------------------------
# Auto-find DINO_MT_MODEL_DIR
# ----------------------------
def _auto_find_latest_dir(root: Path, pattern: str):
    root = Path(root)
    if not root.exists():
        return None
    cands = [p for p in root.glob(pattern) if p.is_dir()]
    if not cands:
        return None
    cands.sort(key=lambda p: p.stat().st_mtime, reverse=True)
    return cands[0]

# Prefer the directory exported by STAGE 5 in this session; otherwise pick the newest run dir.
if "DINO_MT_MODEL_DIR" in globals():
    DINO_MT_MODEL_DIR = Path(str(globals()["DINO_MT_MODEL_DIR"]))
else:
    DINO_MT_MODEL_DIR = _auto_find_latest_dir(Path("/kaggle/working/recodai_luc/models"), "dinov2_mt_v5_*")

# (error message, Indonesian: "Run STAGE 5 DINO Multi-task first.")
if DINO_MT_MODEL_DIR is None or (not DINO_MT_MODEL_DIR.exists()):
    raise FileNotFoundError("DINO_MT_MODEL_DIR not found. Jalankan STAGE 5 DINO Multi-task dulu.")

cfg_path = DINO_MT_MODEL_DIR / "model_config.json"
thr_path = DINO_MT_MODEL_DIR / "thresholds.json"
cal_path = DINO_MT_MODEL_DIR / "calibrator.joblib"
ckpt_dir = DINO_MT_MODEL_DIR / "checkpoints"

for p in [cfg_path, thr_path, cal_path, ckpt_dir]:
    if not p.exists():
        raise FileNotFoundError(f"Missing artifact: {p}")

model_cfg = json.loads(cfg_path.read_text())
thr_cfg   = json.loads(thr_path.read_text())
# NOTE(review): joblib.load unpickles; only load artifacts produced by this pipeline.
cal_pack  = joblib.load(cal_path)

calib_kind = cal_pack.get("kind", "none")
calibrator = cal_pack.get("calibrator", None)

# thresholds from stage 5 ("thr_p" is the legacy alias of "thr_forged")
thr_forged = float(thr_cfg.get("thr_forged", thr_cfg.get("thr_p", 0.5)))
thr_mask   = float(thr_cfg.get("thr_mask", 0.5))
min_pred_patches = int(thr_cfg.get("min_pred_patches", 0))

# PASS sizes must tile exactly into patches so the seg grid is side // patch_size.
patch_size = int(thr_cfg.get("patch_size", model_cfg.get("patch_size", 14)))
assert PASS1_SIZE % patch_size == 0 and PASS2_SIZE % patch_size == 0, "PASS sizes must be multiple of patch_size."

print("\nLoaded DINO multi-task artifacts:")
print(f"  model_dir          : {DINO_MT_MODEL_DIR}")
print(f"  dino_dir           : {model_cfg.get('dino_dir')}")
print(f"  patch_size         : {patch_size}")
print(f"  thr_forged         : {thr_forged}")
print(f"  thr_mask           : {thr_mask}")
print(f"  min_pred_patches   : {min_pred_patches}")
print(f"  calib_kind         : {calib_kind}")
print(f"  EXACT_FOLD_ENSEMBLE: {EXACT_FOLD_ENSEMBLE}")
print(f"  RLE_ORDER          : {RLE_ORDER}")

def apply_calibrator(p):
    """Apply the calibrator loaded from calibrator.joblib to raw probabilities.

    Returns float32. With no calibrator (``calib_kind == "none"``) the input is
    returned as float32. Isotonic models use ``transform``; any other kind is treated
    as Platt/logistic and uses the positive-class column of ``predict_proba``.
    """
    probs = np.asarray(p, dtype=np.float32)
    if calibrator is None or calib_kind == "none":
        return probs
    if calib_kind != "isotonic":
        # Platt: sklearn wants a (n_samples, 1) feature matrix
        return calibrator.predict_proba(probs.reshape(-1, 1))[:, 1].astype(np.float32)
    return calibrator.transform(probs).astype(np.float32)

# ----------------------------
# Build DINO multitask model (same heads as STAGE 5)
# ----------------------------
# ImageNet channel statistics, shaped (3, 1, 1) to broadcast over CHW tensors.
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32).reshape(3, 1, 1)
IMNET_STD  = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32).reshape(3, 1, 1)

def freeze_all(model):
    """Disable gradients for every parameter of ``model`` (in place, returns None)."""
    model.requires_grad_(False)

class DinoMultiTask(nn.Module):
    """DINO backbone with a CLS classification head and a per-patch segmentation head.

    The attribute names (``backbone``, ``cls_head``, ``seg_head``) and the Sequential
    layouts determine the state_dict keys that fold checkpoints' ``trainable_state``
    is applied to. NOTE(review): these must match the STAGE 5 model definition
    (not visible here); confirm before changing anything structural.
    """
    def __init__(self, backbone: nn.Module, patch: int):
        super().__init__()
        self.backbone = backbone
        self.patch = int(patch)

        # infer embed dim
        # (one dummy forward at PASS1_SIZE on whatever device the backbone is on;
        #  build_model_on_device constructs this before calling .to(device))
        with torch.no_grad():
            dummy = torch.zeros((1,3,PASS1_SIZE,PASS1_SIZE), dtype=torch.float32)
            out = self.backbone(pixel_values=dummy)
            D = int(out.last_hidden_state.shape[-1])
        self.embed_dim = D

        self.cls_head = nn.Sequential(nn.LayerNorm(D), nn.Linear(D, 1))
        self.seg_head = nn.Sequential(nn.LayerNorm(D), nn.Linear(D, 1))

    def forward(self, x):
        """Return (cls_logit [B], seg_grid [B, H//patch, W//patch]) logits for x [B,3,H,W]."""
        out = self.backbone(pixel_values=x)
        h = out.last_hidden_state  # [B,1+N,D]
        # NOTE(review): assumes token 0 is CLS and the rest are patch tokens in row-major
        # order (no register tokens) — confirm for the chosen DINO checkpoint.
        cls_tok = h[:, 0, :]
        ptok = h[:, 1:, :]
        cls_logit = self.cls_head(cls_tok).squeeze(-1)      # [B]
        seg_patch = self.seg_head(ptok).squeeze(-1)         # [B,N]

        # infer grid from input size (H and W are handled separately)
        B, _, H, W = x.shape
        gh = H // self.patch
        gw = W // self.patch
        Nexp = gh * gw
        if seg_patch.shape[1] < Nexp:
            # safety: pad
            pad = Nexp - seg_patch.shape[1]
            seg_patch = torch.cat([seg_patch, seg_patch.new_zeros((B, pad))], dim=1)
        # extra tokens beyond gh*gw (if any) are dropped
        seg_grid = seg_patch[:, :Nexp].reshape(B, gh, gw)
        return cls_logit, seg_grid

# ----------------------------
# Load fold checkpoints and build ensemble state
# ----------------------------
def list_fold_ckpts(ckpt_dir: Path):
    """Return the sorted ``fold_*.pt`` files in ``ckpt_dir``.

    Raises:
        FileNotFoundError: when no fold checkpoint is present.
    """
    found = sorted(ckpt_dir.glob("fold_*.pt"))
    if found:
        return found
    raise FileNotFoundError(f"No fold checkpoints found in {ckpt_dir}")

# Discovered once; both inference modes iterate this module-level list.
fold_ckpts = list_fold_ckpts(ckpt_dir)
print(f"\nFound fold checkpoints: {len(fold_ckpts)}")

def load_trainable_state(pt_path: Path):
    """Load the ``trainable_state`` dict of a fold checkpoint onto the CPU.

    Floating-point tensors are upcast to float32 so that states can be averaged
    safely; all other entries are returned untouched.

    Raises:
        RuntimeError: when the checkpoint has no ``trainable_state`` entry.
    """
    ckpt = torch.load(pt_path, map_location="cpu")
    state = ckpt.get("trainable_state", None)
    if state is None:
        raise RuntimeError(f"Checkpoint missing trainable_state: {pt_path}")

    def _as_float32(value):
        is_float_tensor = torch.is_tensor(value) and value.is_floating_point()
        return value.float() if is_float_tensor else value

    return {name: _as_float32(value) for name, value in state.items()}

def average_states(states: list):
    """Element-wise mean of several state dicts over the keys they all share.

    Floating-point tensors are averaged (the inputs are not modified). Any other
    entry (integer buffers, non-tensors) is taken from the first state. Keys are
    processed in sorted order.
    """
    common = set(states[0])
    for other in states[1:]:
        common.intersection_update(other)

    n_states = float(len(states))
    averaged = {}
    for key in sorted(common):
        first = states[0][key]
        if not (torch.is_tensor(first) and first.is_floating_point()):
            averaged[key] = first
            continue
        acc = first.clone()
        for other in states[1:]:
            acc += other[key]
        averaged[key] = acc / n_states
    return averaged

def apply_trainable_state(model: nn.Module, trainable_state: dict):
    """Overlay ``trainable_state`` onto ``model``'s current weights.

    Keys the model does not have are ignored. Tensors are cast to the dtype of the
    parameter/buffer they replace, so a half-precision model can receive float32
    weights. The load is non-strict.
    """
    merged = model.state_dict()
    for key, value in trainable_state.items():
        if key not in merged:
            continue
        target = merged[key]
        if torch.is_tensor(value) and torch.is_tensor(target):
            merged[key] = value.to(target.dtype)
        else:
            merged[key] = value
    model.load_state_dict(merged, strict=False)

# ----------------------------
# Image preprocessing (square, consistent with STAGE 5)
# ----------------------------
def load_image_square(path: str, size: int):
    """Load an image, resize it to ``size`` x ``size`` and ImageNet-normalize it.

    Args:
        path: image file path (may be "" for unresolved test ids).
        size: target square side in pixels.

    Returns:
        ``(x, meta)``:
        - ``x`` is a float32 CHW tensor of shape (3, size, size).
        - ``meta`` holds ``orig_h``, ``orig_w`` and ``side``.

        Returns ``(None, None)`` when ``path`` is not an existing regular file.
    """
    p = Path(path)
    # is_file(), not exists(): Path("") is Path("."), which exists, and opening it would crash.
    if not p.is_file():
        return None, None
    # Close the source file deterministically. Pillow only auto-closes single-frame
    # images after load(), so multi-page TIFFs would otherwise keep a handle open.
    with Image.open(p) as src:
        im = src.convert("RGB")
    orig_w, orig_h = im.size
    if im.size != (size, size):
        im = im.resize((size, size), resample=Image.BICUBIC)
    arr = (np.asarray(im, dtype=np.float32) / 255.0)  # HWC
    x = torch.from_numpy(arr).permute(2,0,1).contiguous()  # CHW
    x = (x - IMNET_MEAN) / IMNET_STD
    meta = {"orig_h": int(orig_h), "orig_w": int(orig_w), "side": int(size)}
    return x, meta

def grid_to_mask_orig(grid_u8: np.ndarray, meta: dict, patch: int):
    """Upsample a binary patch grid to a pixel mask at the original image size.

    Each grid cell becomes a ``patch`` x ``patch`` block. The result is cropped to the
    square ``meta["side"]``, then resized with nearest-neighbour to
    (``meta["orig_h"]``, ``meta["orig_w"]``) when that differs from the square.
    Returns a uint8 0/1 mask.
    """
    side = int(meta["side"])
    target_h = int(meta["orig_h"])
    target_w = int(meta["orig_w"])
    blocks = np.repeat(np.repeat(grid_u8.astype(np.uint8), patch, axis=0), patch, axis=1)
    square = blocks[:side, :side]
    if (target_h, target_w) == (side, side):
        return square.astype(np.uint8)
    resized = Image.fromarray((square * 255).astype(np.uint8)).resize(
        (target_w, target_h), resample=Image.NEAREST
    )
    return (np.array(resized) > 0).astype(np.uint8)

def pack_mask(mask_u8: np.ndarray):
    """Bit-pack a 2-D mask (row-major, nonzero -> 1); return (packed uint8, height, width)."""
    height, width = mask_u8.shape
    bits = np.asarray(mask_u8 > 0, dtype=np.uint8).ravel()
    packed = np.packbits(bits).astype(np.uint8)
    return packed, height, width

def unpack_mask(pack: np.ndarray, h: int, w: int):
    """Inverse of ``pack_mask``: rebuild an (h, w) uint8 0/1 mask from packed bits.

    Degenerate input (non-positive size, None or empty pack) yields an all-zero mask
    of shape (max(h, 1), max(w, 1)).
    """
    degenerate = h <= 0 or w <= 0 or pack is None or pack.size == 0
    if degenerate:
        return np.zeros((max(h, 1), max(w, 1)), dtype=np.uint8)
    n_bits = h * w
    flat = np.unpackbits(pack.astype(np.uint8))[:n_bits]
    return flat.reshape((h, w)).astype(np.uint8)

# ----------------------------
# Post-filter mask (anti-fragment, anti-noise)
# ----------------------------
def filter_mask(mask_u8: np.ndarray, min_area_frac=0.0, keep_topk=10, close_ks=3, open_ks=0):
    """Clean a binary mask: optional area gate, morphology, keep the largest components.

    Args:
        mask_u8: 2-D 0/1 mask.
        min_area_frac: if > 0 and the foreground fraction is below it, return all zeros.
        keep_topk: keep only the ``keep_topk`` largest connected components (<= 0 keeps all).
        close_ks: square kernel size for binary closing (<= 1 disables).
        open_ks: square kernel size for binary opening (<= 1 disables).

    Returns:
        ``(mask, stats)``:
        - ``mask`` is uint8.
        - ``stats`` holds ``n_comp`` (components kept) and ``largest`` (pixel area
          of the largest one).

        Without SciPy the mask is returned unchanged, with ``n_comp`` = 1.
    """
    if mask_u8.sum() == 0:
        return mask_u8, {"n_comp": 0, "largest": 0}

    H, W = mask_u8.shape
    area = int(mask_u8.sum())
    denom = float(H*W) + 1e-9
    if float(min_area_frac) > 0 and (area/denom) < float(min_area_frac):
        return np.zeros_like(mask_u8, dtype=np.uint8), {"n_comp": 0, "largest": 0}

    m = mask_u8.astype(bool)
    if _HAS_SCIPY:
        ck = int(close_ks) if close_ks else 0
        if ck > 1:
            st = np.ones((ck, ck), dtype=bool)
            # binary_closing erodes with border_value=0, which strips foreground that
            # touches the image edge. Zero-pad by the kernel size, close, then crop back,
            # so the closing stays extensive (output is a superset of the input).
            pad = ck
            padded = np.pad(m, pad, mode="constant", constant_values=False)
            m = ndi.binary_closing(padded, structure=st)[pad:-pad, pad:-pad]
        ok = int(open_ks) if open_ks else 0
        if ok > 1:
            st = np.ones((ok, ok), dtype=bool)
            m = ndi.binary_opening(m, structure=st)

        lab, n = ndi.label(m)
        if n <= 0:
            return np.zeros_like(mask_u8, dtype=np.uint8), {"n_comp": 0, "largest": 0}

        areas = np.bincount(lab.ravel())
        areas[0] = 0  # background label
        comps = np.where(areas > 0)[0]
        comps = comps[np.argsort(areas[comps])[::-1]]  # largest first
        if keep_topk and int(keep_topk) > 0:
            comps = comps[:int(keep_topk)]
        out = np.isin(lab, comps).astype(np.uint8)
        largest = int(areas[comps[0]]) if comps.size else 0
        return out, {"n_comp": int(comps.size), "largest": largest}

    # no scipy: minimal
    return mask_u8.astype(np.uint8), {"n_comp": 1, "largest": int(mask_u8.sum())}

# ----------------------------
# Strict guard + quality score (DINO multi-task)
# ----------------------------
def strict_guard(p_cal: float, pred_patches: int, area_frac: float):
    """Decide whether a prediction is accepted as forged.

    Returns False, in this order, when:
    - the calibrated probability is below ``thr_forged``;
    - fewer than ``min_pred_patches`` grid patches fired;
    - the mask covers > 65% of the image without a +0.12 confidence margin;
    - the mask covers < 0.015% of the image without a +0.10 margin.

    Otherwise returns True.
    """
    prob = float(p_cal)
    if prob < thr_forged:
        return False
    if int(pred_patches) < int(min_pred_patches):
        return False
    frac = float(area_frac)
    huge_but_unsure = frac > 0.65 and prob < thr_forged + 0.12
    tiny_but_unsure = frac < 0.00015 and prob < thr_forged + 0.10
    return not (huge_but_unsure or tiny_but_unsure)

def quality_score(p_cal: float, area_frac: float, n_comp: float, mean_conf_in_mask: float):
    """Heuristic mask quality (higher is better) used to compare PASS-1 vs PASS-2.

    The score is the calibrated probability scaled by three factors:
    - in-mask confidence: 0.55 + 0.45 * conf;
    - a fragmentation penalty of up to 35%, saturating at 35 components;
    - an oversize penalty of up to 45%, ramping in linearly when the area
      fraction goes from 0.35 to 0.70.
    """
    frac = float(area_frac)
    frag_pen = min(float(n_comp) / 35.0, 1.0) * 0.35
    huge_pen = min((frac - 0.35) / 0.35, 1.0) * 0.45 if frac > 0.35 else 0.0
    conf_term = 0.55 + 0.45 * float(mean_conf_in_mask)
    return float(float(p_cal) * conf_term * (1.0 - frag_pen) * (1.0 - huge_pen))

def mean_conf_inside(seg_prob_grid: np.ndarray, grid_bin: np.ndarray):
    """Mean seg probability over the cells selected by ``grid_bin`` (0.0 if none)."""
    if not grid_bin.sum():
        return 0.0
    selected = seg_prob_grid[grid_bin > 0]
    return float(selected.mean()) if selected.size > 0 else 0.0

# ----------------------------
# Cache IO (store final thresholded mask + stats)
# ----------------------------
def cache_path(cache_dir: Path, case_id: str, tag: str):
    """Cache file location for one (case_id, tag) pair: ``<cache_dir>/<case_id>_<tag>.npz``."""
    filename = "{}_{}.npz".format(case_id, tag)
    return cache_dir / filename

def load_cached(cache_dir: Path, case_id: str, tag: str):
    """Read a cached inference result written by ``save_cached``.

    Returns:
        ``(mask_pack, mask_h, mask_w, info)``, where ``info`` holds the scalar stats
        as floats. Returns None when the cache file is missing or unreadable
        (e.g. truncated by an interrupted session), so callers recompute it.
    """
    import zipfile
    import zlib

    p = cache_path(cache_dir, case_id, tag)
    if not p.exists():
        return None
    try:
        # NpzFile keeps the underlying zip open; the context manager closes it.
        # Everything needed is materialized before the file is closed.
        with np.load(p, allow_pickle=False) as z:
            pack = z["mask_pack"].astype(np.uint8).reshape(-1)
            mh = int(z["mask_h"]); mw = int(z["mask_w"])
            info = {k: float(z[k]) for k in z.files if k not in ("mask_pack","mask_h","mask_w")}
    except (OSError, ValueError, KeyError, EOFError, zipfile.BadZipFile, zlib.error):
        # Corrupt / partial cache: treat as a miss instead of aborting the whole run.
        return None
    return pack, mh, mw, info

def save_cached(cache_dir: Path, case_id: str, tag: str, mask_u8: np.ndarray, info: dict):
    """Persist a thresholded mask (bit-packed) plus its scalar stats as a compressed NPZ.

    The file is written to a temporary sibling and atomically renamed into place, so an
    interrupted session never leaves a truncated cache that would later fail to load.
    Missing ``info`` keys are stored as 0.0.
    """
    pack, mh, mw = pack_mask(mask_u8)
    payload = {
        "mask_pack": pack,
        "mask_h": np.int32(mh),
        "mask_w": np.int32(mw),
        "p_raw": np.float32(float(info.get("p_raw", 0.0))),
        "p_cal": np.float32(float(info.get("p_cal", 0.0))),
        "pred_patches": np.float32(float(info.get("pred_patches", 0.0))),
        "area_frac": np.float32(float(info.get("area_frac", 0.0))),
        "n_comp": np.float32(float(info.get("n_comp", 0.0))),
        "mean_conf": np.float32(float(info.get("mean_conf", 0.0))),
        "side": np.float32(float(info.get("side", 0.0))),
    }
    final_path = cache_path(cache_dir, case_id, tag)
    tmp_path = final_path.with_name(final_path.name + ".tmp")
    # Writing through a file object also stops NumPy from appending another ".npz" suffix.
    with open(tmp_path, "wb") as fh:
        np.savez_compressed(fh, **payload)
    os.replace(tmp_path, final_path)

# ----------------------------
# Inference core (supports mean-state or exact fold ensemble)
# ----------------------------
def build_model_on_device(dino_dir: str, device: torch.device):
    """Load the local DINO backbone (frozen, eval mode), wrap it with the heads, move it to ``device``."""
    backbone = AutoModel.from_pretrained(str(dino_dir), local_files_only=True)
    backbone.eval()
    freeze_all(backbone)
    return DinoMultiTask(backbone, patch=patch_size).eval().to(device)

def infer_batch(model, x_batch, use_amp=True):
    """Run one forward pass and return sigmoid probabilities as float32 NumPy arrays.

    Args:
        model: module returning ``(cls_logit [B], seg_logit [B, gh, gw])``.
        x_batch: input tensor; it is moved to the model's device.
        use_amp: enable CUDA autocast (ignored on CPU).

    Returns:
        ``(p_raw [B], seg_prob [B, gh, gw])``.
    """
    device = next(model.parameters()).device
    x_batch = x_batch.to(device, non_blocking=True)
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast; disabled on CPU.
    use_autocast = bool(use_amp and device.type == "cuda")
    with torch.inference_mode(), torch.autocast(device_type=device.type, enabled=use_autocast):
        cls_logit, seg_logit = model(x_batch)
    # Outputs of inference_mode never require grad, so no .detach() is needed.
    p_raw = torch.sigmoid(cls_logit).float().cpu().numpy().astype(np.float32)
    seg_prob = torch.sigmoid(seg_logit).float().cpu().numpy().astype(np.float32)  # [B,gh,gw]
    return p_raw, seg_prob

def infer_with_mean_state(image_paths, metas, side: int, batch_size: int, cache_dir: Path, tag: str):
    """Run one inference pass with fold-averaged weights and cache per-image results.

    A single model is built, and the trainable states of all ``fold_ckpts`` are averaged
    into it. Images are resized to ``side`` x ``side``. For each image:
    - threshold the seg grid at ``thr_mask``;
    - upsample the grid to the original size and post-filter it;
    - save the mask and its stats with ``save_cached`` under ``tag``.

    Args:
        image_paths: per-item image paths; unreadable ones are fed as all-zero inputs.
        metas: per-item dicts with "case_id", "orig_h", "orig_w", aligned with image_paths.
        side: square input size (a multiple of ``patch_size``).
        batch_size: images per forward pass.
        cache_dir: destination directory for the NPZ caches.
        tag: cache tag, also used as the log prefix.

    The model is released in ``finally`` even if the pass fails (e.g. CUDA OOM), so the
    caller's retry with a smaller batch does not compete with this copy.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model_on_device(model_cfg["dino_dir"], device)
    try:
        # build mean state once
        states = [load_trainable_state(p) for p in fold_ckpts]
        apply_trainable_state(model, average_states(states))
        del states

        # optional: half on cuda (inputs are cast to match below)
        if device.type == "cuda":
            model = model.half()

        n = len(image_paths)
        log_every = max(100, batch_size * 25)
        for i in range(0, n, batch_size):
            j = min(i + batch_size, n)
            batch = []
            for k in range(i, j):
                x, _ = load_image_square(image_paths[k], size=side)
                # unreadable/missing image -> zero input keeps the batch aligned with metas
                batch.append(x if x is not None else torch.zeros((3, side, side), dtype=torch.float32))
            xb = torch.stack(batch, dim=0)
            if device.type == "cuda":
                xb = xb.half()
            p_raw, seg_prob = infer_batch(model, xb, use_amp=True)
            p_cal = apply_calibrator(p_raw)

            # per item: threshold seg -> mask -> post-filter -> cache
            for t, meta in enumerate(metas[i:j]):
                cid = str(meta["case_id"])
                orig_h = int(meta["orig_h"]); orig_w = int(meta["orig_w"])
                seg = seg_prob[t]  # [gh,gw]
                grid = (seg >= thr_mask).astype(np.uint8)
                pred_patches = int(grid.sum())

                mask = grid_to_mask_orig(grid, meta={"orig_h": orig_h, "orig_w": orig_w, "side": side}, patch=patch_size)

                # post-filter: keep stable; set minimal tiny filter
                mask2, comp = filter_mask(mask, min_area_frac=0.0, keep_topk=10, close_ks=3, open_ks=0)

                area_frac = float(float(mask2.sum()) / (float(orig_h*orig_w) + 1e-9))
                # confidence is measured on the raw (pre-filter) grid
                mean_conf = mean_conf_inside(seg, grid)

                info = {
                    "p_raw": float(p_raw[t]),
                    "p_cal": float(p_cal[t]),
                    "pred_patches": float(pred_patches),
                    "area_frac": float(area_frac),
                    "n_comp": float(comp.get("n_comp", 0)),
                    "mean_conf": float(mean_conf),
                    "side": float(side),
                }
                save_cached(cache_dir, cid, tag, mask2.astype(np.uint8), info)

            if j % log_every == 0 or j == n:
                print(f"[{tag}] infer+cache {j:,}/{n:,}")
    finally:
        # cleanup (also on failure, see docstring)
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def infer_with_exact_fold_ensemble(image_paths, metas, side: int, batch_size: int, cache_dir: Path, tag: str):
    """Exact fold ensemble: run every fold checkpoint and average the outputs (slower).

    The raw cls probabilities and seg probability grids are averaged across
    ``fold_ckpts``, and the averaged probability is then calibrated. The per-image
    post-processing and caching are identical to ``infer_with_mean_state``.
    Arguments have the same meaning as in that function.

    The model is released in ``finally`` even if the pass fails (e.g. CUDA OOM).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = build_model_on_device(model_cfg["dino_dir"], device)
    try:
        if device.type == "cuda":
            model = model.half()

        n = len(image_paths)
        grid_side = side // patch_size
        # accumulators over folds
        p_raw_sum = np.zeros(n, dtype=np.float32)
        seg_sum = np.zeros((n, grid_side, grid_side), dtype=np.float32)

        for fi, ckpt_path in enumerate(fold_ckpts, 1):
            apply_trainable_state(model, load_trainable_state(ckpt_path))

            for i in range(0, n, batch_size):
                j = min(i + batch_size, n)
                batch = []
                for k in range(i, j):
                    x, _ = load_image_square(image_paths[k], size=side)
                    # unreadable/missing image -> zero input keeps the batch aligned with metas
                    batch.append(x if x is not None else torch.zeros((3, side, side), dtype=torch.float32))
                xb = torch.stack(batch, dim=0)
                if device.type == "cuda":
                    xb = xb.half()
                p_raw, seg_prob = infer_batch(model, xb, use_amp=True)
                p_raw_sum[i:j] += p_raw
                seg_sum[i:j] += seg_prob

            print(f"[{tag}] fold {fi}/{len(fold_ckpts)} done")

        n_folds = float(len(fold_ckpts))
        p_raw_avg = p_raw_sum / n_folds
        seg_avg = seg_sum / n_folds
        p_cal = apply_calibrator(p_raw_avg)

        # write cache
        for idx in range(n):
            meta = metas[idx]
            cid = str(meta["case_id"])
            orig_h = int(meta["orig_h"]); orig_w = int(meta["orig_w"])
            seg = seg_avg[idx]
            grid = (seg >= thr_mask).astype(np.uint8)
            pred_patches = int(grid.sum())

            mask = grid_to_mask_orig(grid, meta={"orig_h": orig_h, "orig_w": orig_w, "side": side}, patch=patch_size)
            mask2, comp = filter_mask(mask, min_area_frac=0.0, keep_topk=10, close_ks=3, open_ks=0)

            area_frac = float(float(mask2.sum()) / (float(orig_h*orig_w) + 1e-9))
            # confidence is measured on the raw (pre-filter) grid
            mean_conf = mean_conf_inside(seg, grid)

            info = {
                "p_raw": float(p_raw_avg[idx]),
                "p_cal": float(p_cal[idx]),
                "pred_patches": float(pred_patches),
                "area_frac": float(area_frac),
                "n_comp": float(comp.get("n_comp", 0)),
                "mean_conf": float(mean_conf),
                "side": float(side),
            }
            save_cached(cache_dir, cid, tag, mask2.astype(np.uint8), info)
    finally:
        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# ----------------------------
# Build meta list + load from cache (PASS-1)
# ----------------------------
t0 = time.time()

# Collect per-row metadata (original size) and the rows whose PASS-1 cache is missing.
metas = []
need_p1 = []
for i, row in df_test.iterrows():
    cid = str(row["case_id"])
    ip = str(row["image_path"])
    # Unresolved ids carry image_path "". Path("") is Path("."), which exists, so an
    # .exists() check would let them through and Image.open("") would crash.
    if not ip or not Path(ip).is_file():
        metas.append({"case_id": cid, "image_path": ip, "orig_h": 1, "orig_w": 1})
        continue
    # only the header is needed for the size; the context manager closes the file
    with Image.open(ip) as im:
        ow, oh = im.size
    metas.append({"case_id": cid, "image_path": ip, "orig_h": int(oh), "orig_w": int(ow)})

    c = load_cached(P1_DIR, cid, tag=f"s{PASS1_SIZE}")
    if c is None:
        need_p1.append(i)

print(f"\nPASS-1 cache check: need_compute={len(need_p1):,}/{len(df_test):,}")

# Compute missing PASS-1 caches
if len(need_p1) > 0:
    idxs = need_p1
    img_paths = [metas[i]["image_path"] for i in idxs]
    meta_sub  = [dict(metas[i], case_id=str(metas[i]["case_id"])) for i in idxs]
    # batch size auto reduce on OOM (the whole subset is recomputed on each retry)
    bs = int(PASS1_BS)
    while True:
        try:
            if EXACT_FOLD_ENSEMBLE:
                infer_with_exact_fold_ensemble(img_paths, meta_sub, side=PASS1_SIZE, batch_size=bs, cache_dir=P1_DIR, tag=f"s{PASS1_SIZE}")
            else:
                infer_with_mean_state(img_paths, meta_sub, side=PASS1_SIZE, batch_size=bs, cache_dir=P1_DIR, tag=f"s{PASS1_SIZE}")
            break
        except RuntimeError as e:
            if "out of memory" in str(e).lower() and bs > 1:
                bs = max(1, bs // 2)
                print(f"[OOM] Reduce PASS1_BS -> {bs}")
                gc.collect()
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                continue
            raise

# Load PASS-1 info for all rows (rows without a cache get an empty mask and zero stats)
p1_pack = [None]*len(df_test)
p1_hw   = [None]*len(df_test)
p1_info = [None]*len(df_test)

miss_p1 = 0
for i, row in df_test.iterrows():
    cid = str(row["case_id"])
    c = load_cached(P1_DIR, cid, tag=f"s{PASS1_SIZE}")
    if c is None:
        miss_p1 += 1
        p1_pack[i] = np.zeros((0,), dtype=np.uint8)
        p1_hw[i]   = (1,1)
        p1_info[i] = {"p_raw":0.0,"p_cal":0.0,"pred_patches":0.0,"area_frac":0.0,"n_comp":0.0,"mean_conf":0.0,"side":float(PASS1_SIZE)}
    else:
        pack, mh, mw, info = c
        p1_pack[i] = pack
        p1_hw[i]   = (mh, mw)
        p1_info[i] = info

print(f"PASS-1 loaded. missing_after_compute={miss_p1:,}")

# ----------------------------
# Borderline selection for PASS-2
# ----------------------------
borderline = []
if USE_PASS2:
    for i, row in df_test.iterrows():
        cid = str(row["case_id"])
        ip  = str(row["image_path"])
        # Unresolved ids have image_path "", and Path("") == Path(".") exists;
        # require a real file so they are never sent to PASS-2.
        if not ip or not Path(ip).is_file():
            continue

        p = float(p1_info[i].get("p_cal", 0.0))
        area = float(p1_info[i].get("area_frac", 0.0))
        pp = int(round(float(p1_info[i].get("pred_patches", 0.0))))
        # borderline conditions (smart):
        # - near_thr:     calibrated prob within BORDER_MARGIN of thr_forged
        # - repair_small: plausible forgery whose PASS-1 mask has too few patches / tiny area
        # - recover_fn:   slightly below thr_forged but PASS-1 still produced a non-trivial mask
        near_thr = (abs(p - thr_forged) <= float(BORDER_MARGIN))
        repair_small = (p >= (thr_forged - 0.06)) and (pp < max(min_pred_patches, 2)) and (area < 0.0012)
        recover_fn = (p >= (thr_forged - 0.10)) and (pp >= max(min_pred_patches-1, 1)) and (area >= 0.00025)
        if (near_thr or repair_small or recover_fn):
            borderline.append((cid, abs(p - thr_forged)))

    # closest to the threshold first, so an optional cap keeps the most uncertain cases
    borderline.sort(key=lambda x: x[1])
    if MAX_BORDERLINE is not None and len(borderline) > int(MAX_BORDERLINE):
        borderline = borderline[:int(MAX_BORDERLINE)]
    borderline_ids = [x[0] for x in borderline]
else:
    borderline_ids = []

print(f"Borderline for PASS-2: {len(borderline_ids):,}/{len(df_test):,}")

# Compute missing PASS-2 caches for borderline only
if USE_PASS2 and len(borderline_ids) > 0:
    need_p2 = []
    id_to_index = {str(df_test.iloc[i]["case_id"]): i for i in range(len(df_test))}
    for cid in borderline_ids:
        if load_cached(P2_DIR, cid, tag=f"s{PASS2_SIZE}") is None:
            need_p2.append(cid)

    print(f"PASS-2 cache check: need_compute={len(need_p2):,}/{len(borderline_ids):,}")

    if len(need_p2) > 0:
        idxs = [id_to_index[cid] for cid in need_p2]
        img_paths = [metas[i]["image_path"] for i in idxs]
        meta_sub  = [dict(metas[i], case_id=str(metas[i]["case_id"])) for i in idxs]
        # batch size auto reduce on OOM (the whole subset is recomputed on each retry)
        bs = int(PASS2_BS)
        while True:
            try:
                if EXACT_FOLD_ENSEMBLE:
                    infer_with_exact_fold_ensemble(img_paths, meta_sub, side=PASS2_SIZE, batch_size=bs, cache_dir=P2_DIR, tag=f"s{PASS2_SIZE}")
                else:
                    infer_with_mean_state(img_paths, meta_sub, side=PASS2_SIZE, batch_size=bs, cache_dir=P2_DIR, tag=f"s{PASS2_SIZE}")
                break
            except RuntimeError as e:
                if "out of memory" in str(e).lower() and bs > 1:
                    bs = max(1, bs // 2)
                    print(f"[OOM] Reduce PASS2_BS -> {bs}")
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                raise

# ----------------------------
# Smart ensemble P1 vs P2 + Strict guard + Export
# ----------------------------
def iou_masks(m1: np.ndarray, m2: np.ndarray):
    """Return the intersection-over-union of the non-zero pixels of two masks.

    Both masks are binarised with ``> 0``; they must have the same (or
    broadcast-compatible) shape. A 1e-9 epsilon is added to the union so two
    empty masks yield 0.0 instead of dividing by zero.
    """
    fg_a = m1 > 0
    fg_b = m2 > 0
    overlap = float(np.count_nonzero(fg_a & fg_b))
    combined = float(np.count_nonzero(fg_a | fg_b)) + 1e-9
    return overlap / combined

results = []
n_pred_forg = 0   # rows exported with a non-empty RLE mask
n_used_p2 = 0     # rows where PASS-2 was selected or combined with PASS-1
n_union = 0       # rows exported as the P1/P2 union mask
n_inter = 0       # rows exported as the P1/P2 intersection mask

# Built once: rebuilding set(borderline_ids) for every row made this loop O(n*m).
borderline_set = set(borderline_ids)

for i, row in df_test.iterrows():
    cid = str(row["case_id"])
    ip  = str(row["image_path"])

    # PASS-1 info (p1_info / p1_hw / p1_pack are looked up with the df_test index label i)
    info1 = p1_info[i]
    mh1, mw1 = p1_hw[i]
    pcal1 = float(info1.get("p_cal", 0.0))
    area1 = float(info1.get("area_frac", 0.0))
    pp1   = int(round(float(info1.get("pred_patches", 0.0))))
    nc1   = float(info1.get("n_comp", 0.0))
    mc1   = float(info1.get("mean_conf", 0.0))

    g1 = strict_guard(pcal1, pp1, area1)
    q1 = quality_score(pcal1, area1, nc1, mc1)

    # Default: keep PASS-1. final_mask=None means "unpack the PASS-1 mask at export".
    # final_source / final_pcal / final_q are bookkeeping only; the exported
    # annotation depends solely on final_guard and final_mask.
    final_source = "p1"
    final_guard = g1
    final_mask = None
    final_pcal = pcal1
    final_q = q1

    # PASS-2 available?
    use_p2 = False
    if USE_PASS2 and (cid in borderline_set) and Path(ip).exists():
        c2 = load_cached(P2_DIR, cid, tag=f"s{PASS2_SIZE}")
        if c2 is not None:
            use_p2 = True
            pack2, mh2, mw2, info2 = c2
            pcal2 = float(info2.get("p_cal", 0.0))
            area2 = float(info2.get("area_frac", 0.0))
            pp2   = int(round(float(info2.get("pred_patches", 0.0))))
            nc2   = float(info2.get("n_comp", 0.0))
            mc2   = float(info2.get("mean_conf", 0.0))

            g2 = strict_guard(pcal2, pp2, area2)
            q2 = quality_score(pcal2, area2, nc2, mc2)

            # Decision rules:
            #  - only P2 passes the guard -> take P2
            #  - only P1 passes           -> keep P1
            #  - both pass                -> compare masks (union / intersection / quality)
            #  - neither passes           -> row stays "authentic" either way
            if g2 and (not g1):
                final_source = "p2"; final_guard = True; final_pcal = pcal2; final_q = q2
                final_mask = unpack_mask(pack2, mh2, mw2)
                n_used_p2 += 1
            elif g1 and (not g2):
                final_source = "p1"
            elif g1 and g2:
                # compare masks
                m1 = unpack_mask(p1_pack[i], mh1, mw1)
                m2 = unpack_mask(pack2, mh2, mw2)
                iou = iou_masks(m1, m2)

                strong1 = (pcal1 >= thr_forged + 0.12) and (mc1 >= 0.60)
                strong2 = (pcal2 >= thr_forged + 0.12) and (mc2 >= 0.60)

                if iou >= 0.18 and strong1 and strong2:
                    # Confident and overlapping: merge both masks.
                    mu = np.logical_or(m1 > 0, m2 > 0).astype(np.uint8)
                    final_source = "union"; final_guard = True; final_pcal = max(pcal1, pcal2); final_q = max(q1, q2)
                    final_mask = mu
                    n_union += 1
                    n_used_p2 += 1
                elif iou <= 0.06 and strong1 and strong2:
                    # Confident but disagreeing: try keeping only the shared pixels.
                    mi = np.logical_and(m1 > 0, m2 > 0).astype(np.uint8)
                    # Re-check the guard after intersection. Only the area term is
                    # recomputed; the patch-count term reuses max(pp1, pp2).
                    # (A previously computed intersection patch count was never passed
                    # to strict_guard and has been removed.)
                    area_i = float(mi.sum()) / (float(mi.size) + 1e-9)
                    if strict_guard(max(pcal1, pcal2), max(pp1, pp2), area_i):
                        final_source = "inter"; final_guard = True; final_pcal = max(pcal1, pcal2); final_q = max(q1, q2)
                        final_mask = mi
                        n_inter += 1
                        n_used_p2 += 1
                    else:
                        # fallback: best quality with a 3% margin in favour of P1
                        if q2 > q1 * 1.03:
                            final_source = "p2"; final_guard = g2; final_pcal = pcal2; final_q = q2
                            final_mask = m2
                            n_used_p2 += 1
                        else:
                            final_source = "p1"
                else:
                    # Ambiguous overlap or not both strong: P2 only if >=5% better quality.
                    if q2 > q1 * 1.05:
                        final_source = "p2"; final_guard = g2; final_pcal = pcal2; final_q = q2
                        final_mask = m2
                        n_used_p2 += 1
                    else:
                        final_source = "p1"
            else:
                # Neither pass is guarded: final_guard stays False (g2 is False), so the
                # row is exported as "authentic" whichever source wins. The PASS-2 mask
                # would never be used, so it is not unpacked here.
                if q2 > q1 * 1.12:
                    final_source = "p2"; final_guard = g2; final_pcal = pcal2; final_q = q2
                    n_used_p2 += 1
                else:
                    final_source = "p1"

    # final annotation
    if final_guard:
        if final_mask is None:
            final_mask = unpack_mask(p1_pack[i], mh1, mw1)
        rle = rle_encode(final_mask.astype(np.uint8), order=RLE_ORDER)
        # An empty RLE (mask has no foreground) is exported as authentic.
        ann = rle if rle != "" else "authentic"
        if ann != "authentic":
            n_pred_forg += 1
    else:
        ann = "authentic"

    results.append({"case_id": cid, "annotation": ann})

# Align predictions to the sample-submission row order; any case missing from
# `results` falls back to "authentic".
# `columns=` keeps the merge valid even when `results` is empty. The join key is
# cast to str because `results` stores case_id as str while df_sample may carry a
# numeric case_id column, and pandas refuses to merge int64 keys with object keys.
df_sub = pd.DataFrame(results, columns=["case_id", "annotation"])
sample_keys = df_sample[["case_id"]].astype({"case_id": str})
df_sub = sample_keys.merge(df_sub, on="case_id", how="left")
df_sub["annotation"] = df_sub["annotation"].fillna("authentic").astype(str)

df_sub.to_csv(OUT_SUB_PATH, index=False)
df_sub.to_csv(OUT_COPY_PATH, index=False)

dt = time.time() - t0
print("\nDONE.")
print(f"submission.csv -> {OUT_SUB_PATH}")
print(f"copy          -> {OUT_COPY_PATH}")
print(f"elapsed_s     -> {dt:.1f}")
print(f"pred_forged   -> {n_pred_forg:,}/{len(df_sub):,}")
print(f"used_pass2    -> {n_used_p2:,} | union={n_union:,} | inter={n_inter:,}")
print(df_sub.head())
