In [10]:
# ==========================================
# Cell 1) Detection datasets discovery & inspection (FINAL++)
#   - How many datasets are found
#   - train/val image count per dataset
#   - Whether label cases (original/scale/side) exist
#   - Output estimated class count/names (multiclass-based)
#
# [UPDATED++]
#   ✅ Reflect split structure rules per dataset name
#   ✅ For Cell 2 acceleration
#       - Store confirmed train/val labels dir
#       - Calculate train/val label file count, box (line) count
#       - Store n_train_groups_est (=box count)
# ========================================== 

from __future__ import annotations

import os, sys, random
from pathlib import Path
from typing import List, Tuple, Optional, Dict

# -------------------------------------------------------------------------
# 0) Register PROJECT_MODULE_DIR
# -------------------------------------------------------------------------
PROJECT_MODULE_DIR = Path("/home/ISW/project/Project_Module")
if str(PROJECT_MODULE_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_MODULE_DIR))

# -------------------------------------------------------------------------
# 1) ultra_det_loader
# -------------------------------------------------------------------------
from ultra_det_loader import discover_det_datasets

# -------------------------------------------------------------------------
# 2) noisy_insection (use only scale/boundary jitter case list)
# -------------------------------------------------------------------------
try:
    from noisy_insection import UNIFORM_SCALING_FACTORS, JITTER_PATTERNS
except Exception:
    UNIFORM_SCALING_FACTORS = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]
    JITTER_PATTERNS = [1, 3, 5, 7, 9]

# -------------------------------------------------------------------------
# User config
# -------------------------------------------------------------------------
LOAD_DIR = "/home/ISW/project/datasets"
SEED = 42

# Image extensions
_IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def set_seed(seed: int = 42):
    """Seed Python's global ``random`` generator so later draws are repeatable.

    Only the standard-library ``random`` module is seeded by this helper.
    """
    random.seed(a=seed)

def list_images(dir_path: Optional[Path]) -> List[Path]:
    """Return every image file found recursively under ``dir_path``, sorted.

    A file counts as an image when its lower-cased suffix is in ``_IMG_EXTS``.
    ``None`` or a path that does not exist yields an empty list.
    """
    if dir_path is None:
        return []
    root = Path(dir_path)
    if not root.exists():
        return []
    return sorted(
        entry
        for entry in root.rglob("*")
        if entry.is_file() and entry.suffix.lower() in _IMG_EXTS
    )

def normalize_name(name: str) -> str:
    """Return a canonical dataset key.

    The name is trimmed and lower-cased, then every underscore and space is
    replaced by a hyphen (e.g. ``"Brain Tumor"`` -> ``"brain-tumor"``).
    """
    separators_to_dash = str.maketrans({"_": "-", " ": "-"})
    return name.strip().lower().translate(separators_to_dash)

# -------------------------------------------------------------------------
# Legacy heuristic (fallback)
# -------------------------------------------------------------------------
def _fallback_train_dir(images_root: Path) -> Path:
    """Pick ``images_root/train`` when that folder exists, else ``images_root`` itself."""
    candidate = images_root / "train"
    return candidate if candidate.is_dir() else images_root

def _fallback_val_dir(images_root: Path) -> Optional[Path]:
    """Return the validation image folder, preferring ``val`` over ``valid``.

    Returns ``None`` when neither sub-folder exists under ``images_root``.
    """
    for folder_name in ("val", "valid"):
        candidate = images_root / folder_name
        if candidate.is_dir():
            return candidate
    return None

# -------------------------------------------------------------------------
# ✅ Dataset-specific split rules
# -------------------------------------------------------------------------
# Dataset names, in the form produced by normalize_name(), that
# detect_split_dirs maps to images/train + images/val with no test split.
_SIMPLE_TRAIN_VAL = {
    "bccd",
    "brain-tumor",
    "custom-blood",
    "homeobjects-3k",
    "kitti",
    "medical-pills",
    "signature",
}

# Dataset names, in the form produced by normalize_name(), that
# detect_split_dirs maps to images/train + images/val + images/test.
_TRAIN_TEST_VAL = {
    "construction-ppe",
    "african-wildlife",
}

def detect_split_dirs(ds_root: Path) -> Dict[str, Optional[Path]]:
    """
    Map a dataset root to its image split folders using per-dataset name rules.

    The dataset is identified by ``normalize_name(ds_root.name)`` and every
    returned directory lives under ``ds_root / "images"``. Directories chosen
    by a name rule are NOT checked for existence; only the final fallback
    probes the filesystem.

    Returns a dict with these keys (in this order):
        train_img_dir : Path
        val_img_dir   : Path | None
        test_img_dir  : Path | None
        split_mode    : "explicit" | "sku_virtual_8_2" | "fallback"
        train_tag     : str
        val_tag       : str
    """
    key = normalize_name(ds_root.name)
    img_root = ds_root / "images"

    def _explicit(train_sub: str, val_sub: str, test_sub: Optional[str]):
        # All name-based rules share this shape; the tag is the sub-folder name.
        return {
            "train_img_dir": img_root / train_sub,
            "val_img_dir": img_root / val_sub,
            "test_img_dir": None if test_sub is None else img_root / test_sub,
            "split_mode": "explicit",
            "train_tag": train_sub,
            "val_tag": val_sub,
        }

    # 1) VOC: only the 2012 train/val folders are used
    if key == "voc":
        return _explicit("train2012", "val2012", None)

    # 2) COCO / LVIS: any name containing either word uses the 2017 folders
    if "coco" in key or "lvis" in key:
        return _explicit("train2017", "val2017", "test2017")

    # 3) plain train/val datasets
    if key in _SIMPLE_TRAIN_VAL:
        return _explicit("train", "val", None)

    # 4) train/val/test datasets
    if key in _TRAIN_TEST_VAL:
        return _explicit("train", "val", "test")

    # 5) SKU-110K: images sit directly in images/, so train and val both
    #    point at the same folder and the split is virtual
    if "sku" in key and "110k" in key:
        return {
            "train_img_dir": img_root,
            "val_img_dir": img_root,
            "test_img_dir": None,
            "split_mode": "sku_virtual_8_2",
            "train_tag": "virtual_8_2",
            "val_tag": "virtual_8_2",
        }

    # 6) unknown dataset: probe the filesystem
    fallback_train = _fallback_train_dir(img_root)
    fallback_val = _fallback_val_dir(img_root)
    return {
        "train_img_dir": fallback_train,
        "val_img_dir": fallback_val,
        "test_img_dir": None,
        "split_mode": "fallback",
        "train_tag": fallback_train.name if fallback_train else "unknown",
        "val_tag": fallback_val.name if fallback_val else "missing",
    }

# -------------------------------------------------------------------------
# Class name estimation
# -------------------------------------------------------------------------
def infer_class_names_from_labels(label_root: Path, max_files: int = 2000) -> List[str]:
    """Estimate placeholder class names from YOLO-style label files.

    Scans at most ``max_files`` ``*.txt`` files found recursively under
    ``label_root`` and reads the first whitespace-separated token of every
    non-blank line as a class id.

    Args:
        label_root: Directory holding the label files. ``None`` or a missing
            path is tolerated.
        max_files: Upper bound on the number of label files inspected.

    Returns:
        ``["class_0", ..., "class_<max_id>"]`` where ``max_id`` is the largest
        non-negative class id seen, or ``["class_0"]`` when nothing usable is
        found.
    """
    if label_root is None or not label_root.exists():
        return ["class_0"]

    # Directory traversal order is filesystem-dependent; sorting first makes
    # the inspected subset (and therefore the result) reproducible.
    txts = sorted(label_root.rglob("*.txt"))[:max_files]

    cls_ids = set()
    for t in txts:
        try:
            with open(t, "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue
                    try:
                        cid = int(float(parts[0]))
                    except (ValueError, OverflowError):
                        # One malformed line must not hide the valid lines
                        # that follow it in the same file.
                        continue
                    # Negative ids cannot index a class list; ignore them.
                    if cid >= 0:
                        cls_ids.add(cid)
        except (OSError, UnicodeDecodeError):
            # Unreadable file: keep whatever was collected and move on.
            continue

    if not cls_ids:
        return ["class_0"]

    return [f"class_{i}" for i in range(max(cls_ids) + 1)]

# -------------------------------------------------------------------------
# Label case detection
# -------------------------------------------------------------------------
def list_label_cases_for_dataset(ds_root: Path) -> List[Tuple[str, str]]:
    """List the label variants that exist as folders directly under ``ds_root``.

    Returns ``(case_tag, folder_name)`` pairs in a fixed order: the clean
    ``labels`` folder first, then one ``scale_<s>`` entry per value in
    ``UNIFORM_SCALING_FACTORS``, then one ``side_<k>`` entry per value in
    ``JITTER_PATTERNS``. Variants whose folder is missing are left out.
    """
    expected: List[Tuple[str, str]] = [("original", "labels")]
    expected += [
        (f"scale_{factor}", f"labels_uniform_scaling_{factor}")
        for factor in UNIFORM_SCALING_FACTORS
    ]
    expected += [
        (f"side_{pattern}", f"labels_boundary_jitter_{pattern}")
        for pattern in JITTER_PATTERNS
    ]
    return [(tag, folder) for tag, folder in expected if (ds_root / folder).is_dir()]

# -------------------------------------------------------------------------
# ✅ Determine train/val labels dir (used directly by Cell 2)
# -------------------------------------------------------------------------
def resolve_split_label_dirs(ds_root: Path, train_tag: str, val_tag: str) -> Tuple[Path, Path]:
    """Resolve the train and val label folders for a dataset.

    For each tag, ``ds_root/labels/<tag>`` is used when it is a directory;
    otherwise the whole ``ds_root/labels`` folder is returned for that split.
    """
    base = ds_root / "labels"

    def _pick(tag: str) -> Path:
        split_dir = base / tag
        return split_dir if split_dir.is_dir() else base

    return _pick(train_tag), _pick(val_tag)

# -------------------------------------------------------------------------
# ✅ Label statistics (quick group count estimation)
# -------------------------------------------------------------------------
def count_label_files_and_boxes(label_dir: Optional[Path]) -> Tuple[int, int]:
    """Count label files and boxes under ``label_dir`` (recursively).

    Returns ``(n_files, n_boxes)``: the number of ``*.txt`` entries found and
    the number of non-blank lines across them. ``None`` or a missing directory
    gives ``(0, 0)``. A file that cannot be read still counts as a file.
    """
    if label_dir is None:
        return 0, 0
    root = Path(label_dir)
    if not root.exists():
        return 0, 0

    file_total = 0
    box_total = 0
    for txt_path in root.rglob("*.txt"):
        file_total += 1
        try:
            with open(txt_path, "r", encoding="utf-8") as handle:
                for raw_line in handle:
                    # Counted one at a time so lines read before a mid-file
                    # error stay in the total.
                    if raw_line.strip():
                        box_total += 1
        except Exception:
            continue
    return file_total, box_total

# -------------------------------------------------------------------------
# SKU-110K virtual split count
# -------------------------------------------------------------------------
def compute_sku_virtual_counts(images_root: Path, seed: int = 42, ratio: float = 0.8) -> Tuple[int, int]:
    """Return ``(n_train, n_val)`` for a virtual split of the images under ``images_root``.

    ``n_train`` is ``int(n * ratio)`` and ``n_val`` is the remainder, where
    ``n`` is the number of images found by ``list_images``. An empty folder
    gives ``(0, 0)``.

    Args:
        images_root: Folder scanned recursively for images.
        seed: Kept for call compatibility only. The sizes do not depend on
            which images land in which split, so no shuffling is done here.
        ratio: Fraction of images assigned to the train split.
    """
    n = len(list_images(images_root))
    cut = int(n * ratio)
    return cut, n - cut

# -------------------------------------------------------------------------
# Discover dataset roots
# -------------------------------------------------------------------------
# Seed the global `random` generator before anything else in this cell runs.
set_seed(SEED)

# Each discovered spec exposes a `.root`; several specs can share one root,
# so the roots are de-duplicated while keeping first-seen order.
specs = discover_det_datasets(LOAD_DIR)
roots: List[Path] = []
for s in specs:
    r = Path(s.root)
    if r not in roots:
        roots.append(r)

# Banner with the number of unique roots and the resolved search directory.
print("=" * 80)
print(f"[DISCOVERY] Found {len(roots)} unique dataset roots under: {Path(LOAD_DIR).resolve()}")
print("=" * 80)

# -------------------------------------------------------------------------
# Per-dataset summary
# -------------------------------------------------------------------------
# One summary dict per usable dataset root; filled by the loop below.
dataset_summaries: List[Dict] = []

for ds_root in roots:
    ds_root = Path(ds_root)
    images_root = ds_root / "images"
    labels_root = ds_root / "labels"

    # A dataset is only inspected when both images/ and labels/ exist.
    if not images_root.is_dir() or not labels_root.is_dir():
        print(f"⏭️  Skip (missing images/labels): {ds_root}")
        continue

    # Name-based split rules (see detect_split_dirs).
    split_info = detect_split_dirs(ds_root)
    train_dir = split_info["train_img_dir"]
    val_dir   = split_info["val_img_dir"]
    split_mode = split_info["split_mode"]
    train_tag  = split_info.get("train_tag", "train")
    val_tag    = split_info.get("val_tag", "val")

    # --- Image counts ---
    # The virtual split has no per-split folders, so its sizes are derived
    # from the total image count instead of being listed per folder.
    if split_mode == "sku_virtual_8_2":
        n_train, n_val = compute_sku_virtual_counts(images_root, seed=SEED, ratio=0.8)
    else:
        n_train = len(list_images(train_dir))
        n_val   = len(list_images(val_dir)) if val_dir else 0

    # --- Resolve the per-split label folders ---
    # Falls back to the whole labels/ folder when labels/<tag> is missing.
    train_labels_dir, val_labels_dir = resolve_split_label_dirs(ds_root, train_tag, val_tag)

    # --- Label statistics ---
    # NOTE(review): when both splits resolve to the same folder (e.g. the
    # virtual split, where labels/virtual_8_2 is not expected to exist), the
    # train and val statistics are identical and each covers the full label
    # set, so n_train_groups_est is then an upper bound rather than a
    # train-only figure. Confirm this is what Cell 2 expects.
    n_train_label_files, n_train_boxes = count_label_files_and_boxes(train_labels_dir)
    n_val_label_files,   n_val_boxes   = count_label_files_and_boxes(val_labels_dir)

    # Group count estimate: one group per box (non-blank label line).
    n_train_groups_est = n_train_boxes

    # Label variants present on disk, and class names inferred from the clean
    # labels/ folder (infer_class_names_from_labels inspects a capped number
    # of files).
    cases = list_label_cases_for_dataset(ds_root)
    class_names = infer_class_names_from_labels(labels_root)
    nc = len(class_names)

    # Paths are stored as strings; everything else is a plain count, tag or list.
    info = {
        "dataset": ds_root.name,
        "root": str(ds_root),
        "images_root": str(images_root),
        "labels_root": str(labels_root),

        "train_dir": str(train_dir) if train_dir else None,
        "val_dir": str(val_dir) if val_dir else None,

        # Resolved label folders, stored for Cell 2
        "train_labels_dir": str(train_labels_dir) if train_labels_dir else None,
        "val_labels_dir": str(val_labels_dir) if val_labels_dir else None,

        "n_train": n_train,
        "n_val": n_val,

        # Label statistics
        "n_train_label_files": n_train_label_files,
        "n_val_label_files": n_val_label_files,
        "n_train_boxes": n_train_boxes,
        "n_val_boxes": n_val_boxes,
        "n_train_groups_est": n_train_groups_est,

        "split_mode": split_mode,
        "train_tag": train_tag,
        "val_tag": val_tag,
        "label_cases": [c[0] for c in cases],
        "nc_inferred": nc,
        "class_names_inferred": class_names,
    }
    dataset_summaries.append(info)

    # --- Human-readable report for this dataset ---
    print("\n" + "-" * 80)
    print(f"[Dataset] {ds_root.name}")
    print(f" - root        : {ds_root}")
    print(f" - split_mode  : {split_mode}")
    print(f" - train_dir   : {train_dir if train_dir else '(missing)'} | tag={train_tag} | n_train={n_train}")
    print(f" - val_dir     : {val_dir if val_dir else '(missing)'} | tag={val_tag} | n_val={n_val}")

    # The test split is only reported (never stored) and only when its folder exists.
    test_dir = split_info.get("test_img_dir", None)
    if test_dir and test_dir.is_dir():
        n_test = len(list_images(test_dir))
        print(f" - test_dir    : {test_dir} | n_test={n_test}")

    print(f" - train_labels_dir : {train_labels_dir}")
    print(f" - val_labels_dir   : {val_labels_dir}")
    print(f" - train label files/boxes/groups_est : {n_train_label_files} / {n_train_boxes} / {n_train_groups_est}")
    print(f" - val   label files/boxes           : {n_val_label_files} / {n_val_boxes}")

    print(f" - label_cases : {[c[0] for c in cases] if cases else '(none)'}")
    print(f" - inferred classes (multiclass-based): nc={nc}, names={class_names}")
    print("-" * 80)

# Final status lines for the cell.
print("\n✅ Cell 1 done.")
print(f"   -> dataset_summaries length = {len(dataset_summaries)}")
print("   -> roots variable is ready for Cell 2.")


[DISCOVERY] Found 13 unique dataset roots under: /home/ISW/project/datasets

--------------------------------------------------------------------------------
[Dataset] SKU-110K
 - root        : /home/ISW/project/datasets/SKU-110K
 - split_mode  : sku_virtual_8_2
 - train_dir   : /home/ISW/project/datasets/SKU-110K/images | tag=virtual_8_2 | n_train=9394
 - val_dir     : /home/ISW/project/datasets/SKU-110K/images | tag=virtual_8_2 | n_val=2349
 - train_labels_dir : /home/ISW/project/datasets/SKU-110K/labels
 - val_labels_dir   : /home/ISW/project/datasets/SKU-110K/labels
 - train label files/boxes/groups_est : 11743 / 1730996 / 1730996
 - val   label files/boxes           : 11743 / 1730996
 - label_cases : ['original', 'scale_0.6', 'scale_0.7', 'scale_0.8', 'scale_0.9', 'scale_1.1', 'scale_1.2', 'scale_1.3', 'scale_1.4', 'side_1', 'side_3', 'side_5', 'side_7', 'side_9']
 - inferred classes (multiclass-based): nc=1, names=['class_0']
------------------------------------------------------

In [11]:
# ==========================================
# Cell 2) ReBox Core — FINAL (PATCHED)
#   - Fix: Remove logcumsumexp_backward unimplemented error in AMP(fp16)
#          => Compute ListMLE/Mono/MSE loss in FP32 (autocast disabled)
#   - Save: Include model_state_dict in checkpoint + save exp/config/meta
#   - Path: OUT_ROOT/{dataset}/{exp}/{absn_####_ctx#}/{ablation_tag}/log.csv
#           OUT_ROOT/weights/{dataset}/{exp}/{absn_####_ctx#}/{ablation_tag}/best.pt
#   - Generate candidates from noisy-anchor labels + compute IoU score with clean-target labels
#   - Mixed anchor support: __case_map.json
# ==========================================

from __future__ import annotations
import os, json, math, random, hashlib, re, time
from dataclasses import dataclass, asdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision.transforms as T
import torchvision.models as tv_models
from PIL import Image

# -----------------------------
# Globals
# -----------------------------
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]

VAL_IMAGE_EXTS = {".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"}

# -----------------------------
# Stable hash
# -----------------------------
def stable_hash32(s: str) -> int:
    """Return a process-independent 32-bit hash of ``s``.

    The value is the first four bytes of the MD5 digest of the UTF-8 encoded
    string, read as a big-endian unsigned integer.
    """
    digest = hashlib.md5(s.encode("utf-8")).digest()
    return int.from_bytes(digest[:4], "big")

# -----------------------------
# YOLO label IO
# -----------------------------
def read_yolo_labels(txt_path: Path) -> List[Tuple[int, float, float, float, float]]:
    """Parse a YOLO label file into ``(cls, cx, cy, w, h)`` tuples.

    Lines with fewer than five whitespace-separated fields are skipped and
    fields beyond the fifth are ignored. A missing file gives ``[]``. If any
    kept line fails to parse (or the file cannot be read), the whole file is
    treated as empty and ``[]`` is returned.
    """
    if not txt_path.exists():
        return []

    records = []
    try:
        content = txt_path.read_text(encoding="utf-8", errors="ignore")
        for raw_line in content.splitlines():
            fields = raw_line.strip().split()
            if len(fields) >= 5:
                class_id = int(float(fields[0]))
                cx, cy, w, h = (float(token) for token in fields[1:5])
                records.append((class_id, cx, cy, w, h))
    except Exception:
        # All-or-nothing: a partially parsed file is never returned.
        return []
    return records

def _clamp01(x: float) -> float:
    """Limit ``x`` to the closed interval [0, 1]."""
    upper_bounded = min(1.0, x)
    return max(0.0, upper_bounded)

def apply_side_pattern_norm(cx, cy, w, h, dl, dr, dt, db, eps: float = 1e-6):
    """Move the four sides of a normalized ``(cx, cy, w, h)`` box independently.

    ``dl``/``dr``/``dt``/``db`` are signed fractions of the box width (left,
    right) or height (top, bottom); a positive value pushes that side outward.
    The result is clipped to the unit square and returned as
    ``(cx, cy, w, h)``, with width and height never below ``eps``.
    """
    def _perturb_axis(center, extent, d_near, d_far):
        # New extent and centre before clipping.
        grown = max(eps, extent * (1.0 + d_near + d_far))
        moved = center + (d_far - d_near) * extent * 0.5

        lo = _clamp01(moved - grown / 2)
        hi = _clamp01(moved + grown / 2)
        if hi <= lo:
            # Clipping collapsed the interval: rebuild an eps-wide one
            # around the clipped centre.
            pivot = _clamp01(moved)
            lo = _clamp01(pivot - eps / 2)
            hi = _clamp01(pivot + eps / 2)
        return (lo + hi) / 2, max(eps, hi - lo)

    out_cx, out_w = _perturb_axis(cx, w, dl, dr)
    out_cy, out_h = _perturb_axis(cy, h, dt, db)
    return out_cx, out_cy, out_w, out_h

def apply_scale_as_side(cx, cy, w, h, scale: float, eps: float = 1e-6):
    """Scale a box about its centre by ``scale``, expressed as a four-side move.

    Each side receives the same delta ``(scale - 1) / 2`` so the total extent
    is multiplied by ``scale``; clipping follows ``apply_side_pattern_norm``.
    """
    per_side = (float(scale) - 1.0) / 2.0
    return apply_side_pattern_norm(
        cx, cy, w, h, per_side, per_side, per_side, per_side, eps=eps
    )

# -----------------------------
# IoU (normalized cxcywh)
# -----------------------------
def _to_xyxy(cx, cy, w, h):
    """Convert a centre/size box to ``(x1, y1, x2, y2)`` corners."""
    half_w = w / 2
    half_h = h / 2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h

def box_iou(a: Tuple[float,float,float,float], b: Tuple[float,float,float,float], eps=1e-9) -> float:
    """Intersection-over-union of two ``(cx, cy, w, h)`` boxes.

    ``eps`` is added to the union so that two empty boxes give 0.0 instead of
    a division by zero.
    """
    ax1, ay1, ax2, ay2 = _to_xyxy(*a)
    bx1, by1, bx2, by2 = _to_xyxy(*b)

    # Overlap rectangle; a negative extent means no overlap on that axis.
    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    def _area(x1, y1, x2, y2):
        return max(0.0, x2 - x1) * max(0.0, y2 - y1)

    union = _area(ax1, ay1, ax2, ay2) + _area(bx1, by1, bx2, by2) - inter
    return float(inter / (union + eps))

# -----------------------------
# start-case map loader (mixed anchor support)
# -----------------------------
def load_case_map_if_exists(anchor_labels_dir: str) -> Optional[Dict[str, str]]:
    """Load ``__case_map.json`` from ``anchor_labels_dir`` when it is usable.

    Returns the mapping with keys and values converted to ``str``. Returns
    ``None`` when the file is missing, unreadable, not valid JSON, or does not
    hold a JSON object.
    """
    map_path = Path(anchor_labels_dir) / "__case_map.json"
    if not map_path.exists():
        return None
    try:
        payload = json.loads(map_path.read_text(encoding="utf-8"))
        if not isinstance(payload, dict):
            return None
        return {str(key): str(value) for key, value in payload.items()}
    except Exception:
        return None

def _parse_case_from_dirname(name: str) -> Optional[str]:
    """Translate a noisy-label folder name into a case tag.

    ``labels_uniform_scaling_<s>`` becomes ``scale_<s>`` (a ``p`` in ``<s>``
    is read as a decimal point, and the number is formatted with ``:g``);
    ``labels_boundary_jitter_<k>`` becomes ``side_<k>``. Any other name, or a
    scaling suffix that is not a number, gives ``None``.
    """
    m = re.match(r"^labels_uniform_scaling_(.+)$", name)
    if m:
        raw = m.group(1).replace("p", ".")
        try:
            return f"scale_{float(raw):g}"
        except ValueError:
            return None
    m = re.match(r"^labels_boundary_jitter_(\d+)$", name)
    if m:
        return f"side_{int(m.group(1))}"
    return None

def infer_case_tag_for_file(rel_txt: Path, anchor_labels_dir: str, case_map: Optional[Dict[str, str]]) -> str:
    """Determine which noise case produced the anchor label file ``rel_txt``.

    Lookup order:
      1. ``case_map[rel_txt.as_posix()]`` when present. The value is parsed as
         a folder name (its last path component first, then the raw string);
         if neither parses, the non-empty raw value is returned unchanged.
      2. The components of ``anchor_labels_dir``, scanned from the innermost
         outward, for the first one that parses as a case folder.
      3. ``"clean"`` when nothing matches.

    Args:
        rel_txt: Label path relative to the anchor labels directory.
        anchor_labels_dir: Directory that holds the anchor labels.
        case_map: Optional per-file mapping (see ``load_case_map_if_exists``).

    Returns:
        A tag such as ``"scale_0.6"``, ``"side_5"``, a raw case-map value, or
        ``"clean"``.
    """
    if case_map is not None:
        key = rel_txt.as_posix()
        if key in case_map:
            v = case_map[key]
            tag = _parse_case_from_dirname(Path(v).name) or _parse_case_from_dirname(v) or v
            if tag:
                return tag

    # Path.parts is a tuple of plain strings, so the components are used
    # directly (they have no ``.name`` attribute).
    for dn in reversed(Path(anchor_labels_dir).parts):
        tag = _parse_case_from_dirname(dn)
        if tag:
            return tag

    return "clean"

# -----------------------------
# Side-noise sampler (same concept as noisy_insection)
# -----------------------------
# Signed per-side deltas that sample_side_pattern_for_box can draw from.
# apply_side_pattern_norm treats them as fractions of the box width/height.
SIDE_LEVELS = [-0.40, -0.30, -0.20, -0.10, 0.10, 0.20, 0.30, 0.40]
# Inclusive (min, max) number of box sides perturbed per sampled pattern.
SIDE_NSIDES_RANGE = (2, 4)

def _severity_rank_of_K(K: int, jitter_patterns: List[int]) -> float:
    """Map a jitter pattern id ``K`` to a severity in [0, 1].

    The positive, de-duplicated values of ``jitter_patterns`` are sorted; the
    smallest maps to 0.0, the largest to 1.0 and the rest are spaced evenly by
    position. A ``K`` between two known values takes the position of the one
    below it. With no positive patterns the severity is 0.0.
    """
    known = sorted({int(p) for p in jitter_patterns if int(p) > 0})
    if not known:
        return 0.0
    if K <= known[0]:
        return 0.0
    if K >= known[-1]:
        return 1.0

    if K in known:
        position = known.index(K)
    else:
        below = len([p for p in known if p < K])
        position = max(0, below - 1)
    return float(position) / float(max(1, len(known) - 1))

def _levels_and_weights_for_K(K: int, jitter_patterns: List[int], side_levels: List[float]):
    """Choose the side deltas usable for pattern ``K`` and their sampling weights.

    Severity (see ``_severity_rank_of_K``) selects a ceiling among the
    distinct non-zero magnitudes in ``side_levels``; only deltas whose
    magnitude does not exceed it are allowed. Each allowed delta is weighted
    by ``(|delta| + 1e-9) ** gamma`` with ``gamma = 1 + 5 * severity``, so
    higher severity shifts weight toward larger magnitudes.

    Returns ``(allowed, weights)`` as two parallel lists. If ``side_levels``
    has no non-zero entry, all levels are returned with weight 1.0 each.
    """
    magnitudes = sorted({abs(float(v)) for v in side_levels if abs(float(v)) > 1e-12})
    if not magnitudes:
        everything = [float(v) for v in side_levels]
        return everything, [1.0] * len(everything)

    severity = _severity_rank_of_K(K, jitter_patterns)  # 0..1
    ceiling = magnitudes[int(round(severity * (len(magnitudes) - 1)))]

    permitted = [float(v) for v in side_levels if abs(float(v)) <= ceiling + 1e-12]
    if not permitted:
        permitted = [float(v) for v in side_levels]

    exponent = 1.0 + 5.0 * severity  # 1..6
    return permitted, [(abs(v) + 1e-9) ** exponent for v in permitted]

def sample_side_pattern_for_box(
    rng: random.Random,
    K: int,
    jitter_patterns: List[int],
    require_mixed_signs: bool = True,
) -> Tuple[float,float,float,float]:
    """Draw one boundary-jitter pattern ``(dl, dr, dt, db)`` from ``rng``.

    Between ``SIDE_NSIDES_RANGE[0]`` and ``SIDE_NSIDES_RANGE[1]`` sides are
    chosen at random; each chosen side gets a delta drawn from the levels and
    weights that ``_levels_and_weights_for_K`` allows for ``K``. Sides that
    were not chosen stay at 0.0.

    Args:
        rng: Random generator; every draw comes from it, so the result is a
            deterministic function of its state.
        K: Jitter pattern id controlling which levels are allowed.
        jitter_patterns: All known pattern ids (used to rank ``K``).
        require_mixed_signs: When true, retry until the drawn deltas contain
            both a positive and a negative value.

    Returns:
        ``(dl, dr, dt, db)`` deltas for the left, right, top and bottom sides.
    """
    # NOTE: the order of rng calls below (randint, sample, then one choices()
    # per chosen side per attempt) determines the output for a given seed;
    # callers that replay a seeded generator depend on it staying as is.
    m = rng.randint(SIDE_NSIDES_RANGE[0], SIDE_NSIDES_RANGE[1])
    chosen = rng.sample(["L","R","T","B"], k=m)

    allowed, weights = _levels_and_weights_for_K(K, jitter_patterns, SIDE_LEVELS)

    def pick_once():
        # One attempt: draw a delta for every chosen side.
        d = {"L":0.0,"R":0.0,"T":0.0,"B":0.0}
        vals = []
        for s in chosen:
            v = rng.choices(allowed, weights=weights, k=1)[0]
            d[s] = v
            vals.append(v)
        return d, vals

    # Up to 30 attempts to satisfy the mixed-sign requirement.
    for _ in range(30):
        d, vals = pick_once()
        if (not require_mixed_signs) or (any(v>0 for v in vals) and any(v<0 for v in vals)):
            return d["L"], d["R"], d["T"], d["B"]

    # Give up on the sign requirement and accept one more draw as it comes.
    d, _ = pick_once()
    return d["L"], d["R"], d["T"], d["B"]

# -----------------------------
# Inverse delta (analytical inverse of boundary-jitter-noise)
# -----------------------------
def inverse_deltas_of_side(dl, dr, dt, db, eps: float = 1e-9):
    """Return the side deltas that undo a ``(dl, dr, dt, db)`` perturbation.

    Applying the returned ``(dl_inv, dr_inv, dt_inv, db_inv)`` with
    ``apply_side_pattern_norm`` to a box produced by the forward deltas
    restores the original extent and centre on each axis (ignoring clipping).
    ``eps`` guards against a forward scale of zero or less.
    """
    def _invert_axis(d_near, d_far):
        forward_scale = max(eps, 1.0 + d_near + d_far)
        # Sum of the inverse deltas restores the extent; their difference
        # restores the centre.
        total = (1.0 / forward_scale) - 1.0
        skew = -(d_far - d_near) / forward_scale
        return 0.5 * (total - skew), 0.5 * (total + skew)

    dl_inv, dr_inv = _invert_axis(dl, dr)
    dt_inv, db_inv = _invert_axis(dt, db)
    return dl_inv, dr_inv, dt_inv, db_inv

# -----------------------------
# Candidate generation config
# -----------------------------
@dataclass
class CandidateCfg:
    """Settings read by build_candidates_for_one_box when proposing boxes."""
    # Scale factors applied to the anchor box, one candidate each.
    cand_uniform_scaling_factors: List[float]
    # Jitter pattern ids sampled for random border perturbations; only
    # positive values are used, and [1,3,5,7,9] is the fallback when none are.
    cand_side_ks: List[int]
    # Number of random border-perturbation candidates per box.
    num_border_perturb: int = 10
    # Whether the unmodified anchor box is itself a candidate.
    include_anchor: bool = True
    # Whether to add candidates that try to undo the known noise case.
    include_inverse: bool = True
    # Size of the jitter applied around the inverse candidates.
    inverse_jitter: float = 0.03
    # Passed to sample_side_pattern_for_box for the random perturbations.
    require_mixed_signs: bool = True
    # Hard cap on candidates kept per box (after de-duplication and sorting).
    max_candidates_per_group: int = 60

def auto_expand_uniform_scaling_factors_for_refine(anchor_uniform_scaling_factors: List[float], cap: float = 2.2) -> List[float]:
    """Build the candidate scale set from the scales used to make the anchors.

    Every positive anchor factor contributes itself and its reciprocal, and
    the identity scale 1.0 is always added. Values are limited to
    ``[0.2, cap]``, rounded to four decimals, de-duplicated and returned in
    ascending order. Non-positive factors are ignored.
    """
    pool = []
    for raw in anchor_uniform_scaling_factors:
        factor = float(raw)
        if factor <= 0:
            continue
        pool.extend((factor, 1.0 / factor))
    pool.append(1.0)

    bounded = [min(cap, max(0.2, float(value))) for value in pool]
    unique = {round(value, 4) for value in bounded}
    return sorted(unique)

def _dedup_keep_first(boxes: List[Tuple[float,float,float,float]],
                      dists: List[float],
                      tags: List[str],
                      tol: float = 1e-6):
    """Drop near-duplicate boxes, keeping the first occurrence of each.

    Two boxes are duplicates when all four coordinates round to the same
    multiple of ``tol``. The three input lists are parallel; the returned
    ``(boxes, dists, tags)`` keep the original order, with distances as
    ``float`` and tags as ``str``.
    """
    seen = set()
    kept_boxes, kept_dists, kept_tags = [], [], []
    for box, dist, tag in zip(boxes, dists, tags):
        bucket = tuple(int(round(coord / tol)) for coord in box)
        if bucket in seen:
            continue
        seen.add(bucket)
        kept_boxes.append(box)
        kept_dists.append(float(dist))
        kept_tags.append(str(tag))
    return kept_boxes, kept_dists, kept_tags

def build_candidates_for_one_box(
    anchor_box: Tuple[float,float,float,float],
    case_tag: str,
    clean_lbl_path: Path,   # hash basis for reproducing the noise generator
    box_index: int,
    seed: int,
    cand_cfg: CandidateCfg,
    jitter_patterns: List[int],
) -> Tuple[List[Tuple[float,float,float,float]], List[float], List[str]]:
    """Generate refinement candidates around one noisy anchor box.

    Candidates are produced in this order and then de-duplicated, sorted by
    distance and capped:
      0. the anchor itself (distance 0.0), if ``cand_cfg.include_anchor``;
      1. "inverse" boxes that try to undo the noise named by ``case_tag``
         (``scale_<s>`` or ``side_<k>``), plus jittered variants, if
         ``cand_cfg.include_inverse``;
      2. one box per factor in ``cand_cfg.cand_uniform_scaling_factors``;
      3. ``cand_cfg.num_border_perturb`` random border perturbations.

    All randomness is derived from ``stable_hash32(str(clean_lbl_path))``,
    ``seed`` and ``box_index``, so identical inputs give identical output.

    Args:
        anchor_box: Noisy box as normalized ``(cx, cy, w, h)``.
        case_tag: Noise case of the anchor (e.g. ``"scale_0.6"``,
            ``"side_5"``, ``"clean"``).
        clean_lbl_path: Path of the matching clean label file; only its
            string form is used, as a hash input.
        box_index: Index of this box within its label file.
        seed: Global seed mixed into every derived generator.
        cand_cfg: Candidate-generation settings.
        jitter_patterns: Known jitter pattern ids, forwarded to the sampler.

    Returns:
        ``(candidates, dists, tags)`` as three parallel lists sorted by
        ascending distance. Ties keep generation order, so an included anchor
        is always first.
    """
    cx, cy, w, h = anchor_box
    candidates: List[Tuple[float,float,float,float]] = []
    dists: List[float] = []
    tags: List[str] = []

    base_h = stable_hash32(str(clean_lbl_path))
    # Per-box generator; used only for the jitter around the inverse-side box.
    base_rng = random.Random((base_h ^ int(seed) ^ (box_index * 97531)) & 0xFFFFFFFF)

    # 0) anchor
    if cand_cfg.include_anchor:
        candidates.append((cx,cy,w,h)); dists.append(0.0); tags.append("anchor")

    # 1) inverse of a uniform-scaling case: rescale by 1/S (limited to [0.2, 2.2])
    if cand_cfg.include_inverse and case_tag.startswith("scale_"):
        try:
            S = float(case_tag.split("_",1)[1])
            inv = min(max(1.0 / max(1e-9, S), 0.2), 2.2)
            bc = apply_scale_as_side(cx,cy,w,h, inv)
            candidates.append(bc); dists.append(abs(inv-1.0)); tags.append(f"inv_scale({inv:.3f})")

            jit = float(cand_cfg.inverse_jitter)
            for mul in [1.0-jit, 1.0+jit]:
                inv2 = min(max(inv*mul, 0.2), 2.2)
                bc2 = apply_scale_as_side(cx,cy,w,h, inv2)
                candidates.append(bc2); dists.append(abs(inv2-1.0)); tags.append(f"inv_scale_jit({inv2:.3f})")
        except Exception:
            # Best effort: an unparsable tag simply yields no inverse candidates.
            pass

    # 1b) inverse of a boundary-jitter case
    if cand_cfg.include_inverse and case_tag.startswith("side_"):
        try:
            K = int(case_tag.split("_",1)[1])
            file_seed = (base_h ^ int(seed) ^ (int(K)*2654435761)) & 0xFFFFFFFF
            frnd = random.Random(file_seed)

            # Replay one draw per box up to and including this one from a
            # per-file generator; the last draw is the pattern for box_index.
            # NOTE(review): presumably mirrors how the noisy labels were
            # generated - confirm against the noise-injection module.
            dl=dr=dt=db=0.0
            for _i in range(box_index+1):
                dl, dr, dt, db = sample_side_pattern_for_box(
                    frnd, K=K, jitter_patterns=jitter_patterns,
                    require_mixed_signs=True
                )

            dl_inv, dr_inv, dt_inv, db_inv = inverse_deltas_of_side(dl,dr,dt,db)
            bc = apply_side_pattern_norm(cx,cy,w,h, dl_inv,dr_inv,dt_inv,db_inv)
            inv_dist = abs(dl_inv)+abs(dr_inv)+abs(dt_inv)+abs(db_inv)
            candidates.append(bc); dists.append(inv_dist); tags.append(f"inv_side(K={K})")

            jit = float(cand_cfg.inverse_jitter)
            for _ in range(3):
                jdl = dl_inv + base_rng.uniform(-jit, +jit)
                jdr = dr_inv + base_rng.uniform(-jit, +jit)
                jdt = dt_inv + base_rng.uniform(-jit, +jit)
                jdb = db_inv + base_rng.uniform(-jit, +jit)
                bc2 = apply_side_pattern_norm(cx,cy,w,h, jdl,jdr,jdt,jdb)
                candidates.append(bc2)
                dists.append(abs(jdl)+abs(jdr)+abs(jdt)+abs(jdb))
                tags.append("inv_side_jit")
        except Exception:
            # Best effort, as for the scale case.
            pass

    # 2) expanded isotropic resizing candidates (use all; cap handled by max_candidates)
    for s in cand_cfg.cand_uniform_scaling_factors:
        try:
            s = float(s)
            bc = apply_scale_as_side(cx,cy,w,h, s)
            candidates.append(bc)
            dists.append(abs(s-1.0))
            tags.append(f"scale({s:g})")
        except Exception:
            pass

    # 3) random border-wise perturbation candidates (severity-aware)
    ks_pool = [int(k) for k in cand_cfg.cand_side_ks if int(k) > 0] or [1,3,5,7,9]
    caseK = None
    if case_tag.startswith("side_"):
        try: caseK = int(case_tag.split("_",1)[1])
        except Exception: caseK = None

    for j in range(int(cand_cfg.num_border_perturb)):
        # Independent generator per (file, box, candidate index, seed).
        rr = random.Random((base_h ^ (box_index*19260817) ^ (j*104729) ^ int(seed)) & 0xFFFFFFFF)
        if caseK is not None and len(ks_pool) >= 2:
            # Favour pattern ids close to the anchor's own case.
            weights = [1.0 / (1.0 + abs(k - caseK)) for k in ks_pool]
            K = rr.choices(ks_pool, weights=weights, k=1)[0]
        else:
            K = rr.choice(ks_pool)

        dl, dr, dt, db = sample_side_pattern_for_box(
            rr, K=K, jitter_patterns=jitter_patterns,
            require_mixed_signs=cand_cfg.require_mixed_signs
        )
        bc = apply_side_pattern_norm(cx,cy,w,h, dl,dr,dt,db)
        candidates.append(bc)
        dists.append(abs(dl)+abs(dr)+abs(dt)+abs(db))
        tags.append(f"border_perturb(K={K})")

    # dedup + sort by dist (anchor first)
    candidates, dists, tags = _dedup_keep_first(candidates, dists, tags, tol=1e-6)
    # A stable sort keeps generation order among equal distances. The default
    # argsort kind is not stable, so another distance-0.0 candidate (e.g. a
    # clipped scale(1) box) could otherwise be placed ahead of the anchor.
    order = np.argsort(np.array(dists, dtype=np.float64), kind="stable").tolist()
    candidates = [candidates[i] for i in order]
    dists      = [float(dists[i]) for i in order]
    tags       = [tags[i] for i in order]

    # hard cap
    cap = int(cand_cfg.max_candidates_per_group)
    if len(candidates) > cap:
        candidates = candidates[:cap]
        dists      = dists[:cap]
        tags       = tags[:cap]

    return candidates, dists, tags

# -----------------------------
# TrainConfig
# -----------------------------
@dataclass
class TrainConfig:
    """Flat configuration record for one training run.

    NOTE(review): only the field declarations are visible in this part of the
    file; the per-group comments below restate what the names and defaults
    show and should be checked against the code that consumes them.
    """
    # --- required: run identity and locations ---
    data_root: str
    dataset_name: str
    exp_name: str
    seed: int
    device: str
    out_root: str

    # --- model ---
    backbone_key: str = "densenet"
    d_model: int = 512
    nhead: int = 8
    nlayers: int = 2

    # --- data budget ---
    budget_mode: str = "absn"
    absn: int = 100

    # --- ablation switches ---
    use_context: bool = True
    disable_dist_mlp: bool = False
    share_backbone: bool = False

    # --- loss weights ---
    w_listmle: float = 1.0
    w_mono: float = 0.0
    w_mse: float = 0.25

    # runtime
    num_workers: int = 4
    prefetch_factor: int = 2
    persistent_workers: bool = True
    pin_memory: bool = True
    use_amp: bool = True

    # --- optimisation schedule ---
    epochs: int = 10
    eval_every: int = 5
    batch_size_set: int = 2
    crop_size: int = 224
    lr: float = 2e-4
    weight_decay: float = 0.05

    # explicit dirs
    train_images_dir: str = ""
    valid_images_dir: str = ""
    train_labels_dir: str = ""          # anchor
    valid_labels_dir: str = ""          # anchor
    target_train_labels_dir: str = ""   # clean target
    target_valid_labels_dir: str = ""   # clean target

    # candidate controls (several share their names and defaults with CandidateCfg)
    cand_uniform_scaling_factors: Optional[List[float]] = None
    cand_side_ks: Optional[List[int]] = None
    jitter_patterns: Optional[List[int]] = None
    num_border_perturb: int = 10
    include_inverse: bool = True
    inverse_jitter: float = 0.03
    max_candidates_per_group: int = 60
    require_mixed_signs: bool = True

# -----------------------------
# Meta build
# -----------------------------
def iter_images_with_labels(images_dir: Path, labels_dir: Path) -> List[Path]:
    """Return the images under ``images_dir`` that have a matching label file.

    Images are found recursively (suffix in ``VAL_IMAGE_EXTS``) and sorted. An
    image is kept when ``labels_dir`` contains a ``.txt`` file at the same
    relative path.
    """
    images = sorted(
        entry
        for entry in images_dir.rglob("*")
        if entry.is_file() and entry.suffix.lower() in VAL_IMAGE_EXTS
    )

    def _has_label(image_path: Path) -> bool:
        relative = image_path.relative_to(images_dir)
        return (labels_dir / relative.with_suffix(".txt")).exists()

    return [image_path for image_path in images if _has_label(image_path)]

def build_meta_from_anchor_and_clean(
    images_dir: str,
    anchor_labels_dir: str,
    clean_labels_dir: str,
    seed: int,
    cand_cfg: CandidateCfg,
    jitter_patterns: List[int],
) -> List[Dict[str, Any]]:
    """Build one candidate-group record per (anchor box, clean box) pair.

    For every image that has an anchor label file, the anchor boxes are paired
    with the clean boxes from the file at the same relative path, candidate
    boxes are generated around each anchor, and every candidate is scored by
    its IoU with the paired clean box.

    Returns a list of dicts with the keys ``img_rel``, ``obj_idx``, ``cls``,
    ``anchor_box``, ``target_box``, ``cand_boxes``, ``cand_scores``,
    ``cand_dists`` and ``case_tag``. Images, boxes or pairs that cannot be
    processed are skipped silently.
    """
    images_dir = Path(images_dir)
    anchor_labels_dir = Path(anchor_labels_dir)
    clean_labels_dir = Path(clean_labels_dir)

    case_map = load_case_map_if_exists(str(anchor_labels_dir))
    img_list = iter_images_with_labels(images_dir, anchor_labels_dir)

    meta: List[Dict[str, Any]] = []
    for img_path in img_list:
        rel_img = img_path.relative_to(images_dir)
        rel_txt = rel_img.with_suffix(".txt")

        # Both label files must exist at the same relative path.
        a_txt = anchor_labels_dir / rel_txt
        c_txt = clean_labels_dir / rel_txt
        if (not a_txt.exists()) or (not c_txt.exists()):
            continue

        # Skip images where either file yields no boxes.
        a_lbls = read_yolo_labels(a_txt)
        c_lbls = read_yolo_labels(c_txt)
        if (not a_lbls) or (not c_lbls):
            continue

        # alignment fallback: if lengths differ, use class-wise greedy IoU match
        if len(a_lbls) != len(c_lbls):
            used = set()   # clean-box indices already claimed by an earlier anchor
            pairs = []
            for (acls, acx, acy, aw, ah) in a_lbls:
                best_j = None
                best_iou = -1.0   # below any IoU, so a same-class box matches even at IoU 0
                for j,(ccls, ccx, ccy, cw, ch) in enumerate(c_lbls):
                    if j in used: continue
                    if ccls != acls: continue
                    iou = box_iou((acx,acy,aw,ah),(ccx,ccy,cw,ch))
                    if iou > best_iou:
                        best_iou = iou; best_j = j
                if best_j is None:
                    # no unused clean box of this class: the anchor is dropped
                    continue
                used.add(best_j)
                pairs.append(((acls, acx, acy, aw, ah), c_lbls[best_j]))
        else:
            # equal counts: pair by position in the two files
            pairs = list(zip(a_lbls, c_lbls))

        case_tag = infer_case_tag_for_file(rel_txt, str(anchor_labels_dir), case_map)

        # NOTE(review): `j` is the position within `pairs`. In the length-mismatch
        # branch this is neither the anchor-file index nor the clean-file index
        # once an anchor has been dropped; confirm that is what `box_index` expects.
        for j, (a, c) in enumerate(pairs):
            acls, acx, acy, aw, ah = a
            ccls, ccx, ccy, cw, ch = c
            if acls != ccls:
                # positional pairing can line up boxes of different classes
                continue

            cand_boxes, cand_dists, cand_tags = build_candidates_for_one_box(
                anchor_box=(acx,acy,aw,ah),
                case_tag=case_tag,
                clean_lbl_path=c_txt,    # Important: per the original author, the noise generator hashes on the original (clean) labels path
                box_index=j,
                seed=seed,
                cand_cfg=cand_cfg,
                jitter_patterns=jitter_patterns,
            )
            if not cand_boxes:
                continue

            # Supervision target: IoU of every candidate with the clean box.
            target = (ccx,ccy,cw,ch)
            scores = [box_iou(b, target) for b in cand_boxes]

            # cand_tags is not stored in the record.
            meta.append({
                "img_rel": rel_img.as_posix(),
                "obj_idx": j,
                "cls": int(acls),
                "anchor_box": (acx,acy,aw,ah),
                "target_box": (ccx,ccy,cw,ch),
                "cand_boxes": cand_boxes,
                "cand_scores": scores,
                "cand_dists": cand_dists,
                "case_tag": case_tag,
            })

    return meta

# -----------------------------
# Dataset
# -----------------------------
def load_rgb(path: Path) -> Image.Image:
    """Open ``path`` and return a fully decoded RGB image.

    ``Image.open`` is lazy and keeps the file handle open until the pixel data
    is loaded (and for multi-frame formats, until the image is closed).
    Decoding inside a ``with`` block releases the handle deterministically,
    which matters when many DataLoader workers read images concurrently.
    ``convert("RGB")`` always returns a new, loaded image, so the result stays
    valid after the file is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")

def _make_crop_tf(size: int):
    """Build the preprocessing pipeline: square resize, tensor conversion, ImageNet normalisation."""
    steps = [
        T.Resize((size, size)),
        T.ToTensor(),
        T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return T.Compose(steps)

def crop_by_norm_bbox(img: Image.Image, box: Tuple[float,float,float,float], size: int, tf_crop) -> torch.Tensor:
    """Crop a normalised ``(cx, cy, w, h)`` box out of ``img`` and transform it.

    The box is converted to pixel corners, clamped to the image, and passed to
    ``img.crop``; the crop is then fed through ``tf_crop``.

    The crop is guaranteed to be at least 1 px wide and 1 px high. Very thin
    boxes can round to identical left/right (or top/bottom) coordinates, which
    previously produced a zero-area crop.

    ``size`` is unused here (resizing is done by ``tf_crop``); it is kept for
    interface compatibility.
    """
    W, H = img.size
    cx, cy, w, h = box

    left = int(round((cx - w / 2) * W))
    right = int(round((cx + w / 2) * W))
    top = int(round((cy - h / 2) * H))
    bottom = int(round((cy + h / 2) * H))

    # Clamp the top-left corner inside the image, then force the bottom-right
    # corner to lie strictly beyond it (left <= W-1, so left + 1 <= W).
    left = max(0, min(W - 1, left))
    top = max(0, min(H - 1, top))
    right = max(left + 1, min(W, right))
    bottom = max(top + 1, min(H, bottom))

    return tf_crop(img.crop((left, top, right, bottom)))

class GroupedCandidateDataset(Dataset):
    """Dataset whose items are whole candidate groups (one group per meta record).

    Each item is ``(Xc, Xctx, Y, D, gid, case_tag)``:
      * ``Xc``   - stacked candidate crops, one per entry of ``cand_boxes``
      * ``Xctx`` - the transformed full image, or zeros when context is disabled
      * ``Y``    - candidate scores as a float32 column, shape (N, 1)
      * ``D``    - candidate distances as float32, shape (N,)
      * ``gid``  - ``"<img_rel>|<obj_idx>"``
    """

    def __init__(self, images_dir: str, meta: List[Dict[str,Any]], crop_size: int, use_context: bool):
        self.images_dir = Path(images_dir)
        self.meta = meta
        self.crop_size = int(crop_size)
        self.use_context = bool(use_context)
        # Separate transform objects for crops and for the context image.
        self._tf_crop = _make_crop_tf(self.crop_size)
        self._tf_ctx  = _make_crop_tf(self.crop_size)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, i: int):
        m = self.meta[i]
        image = load_rgb(self.images_dir / m["img_rel"])

        crops = [
            crop_by_norm_bbox(image, box, self.crop_size, self._tf_crop)
            for box in m["cand_boxes"]
        ]
        Xc = torch.stack(crops, dim=0)

        if not self.use_context:
            side = self.crop_size
            Xctx = torch.zeros((3, side, side), dtype=Xc.dtype)
        else:
            Xctx = self._tf_ctx(image)

        Y = torch.tensor(m["cand_scores"], dtype=torch.float32).unsqueeze(1)  # (N,1)
        D = torch.tensor(m["cand_dists"], dtype=torch.float32)                # (N,)
        gid = f"{m['img_rel']}|{m['obj_idx']}"
        return Xc, Xctx, Y, D, gid, m["case_tag"]

def collate_pad(batch):
    """Collate variable-size candidate groups into zero-padded batch tensors.

    Returns ``(Xc, Xctx, Y, D, mask_valid, gids, case_tags)`` where
    ``mask_valid[b, n]`` is True for real candidates and False for padding.
    """
    counts = [item[0].shape[0] for item in batch]
    n_max = max(counts)
    n_groups = len(batch)

    first_crops = batch[0][0]
    chans, height, width = first_crops.shape[1:]

    Xc = torch.zeros((n_groups, n_max, chans, height, width), dtype=first_crops.dtype)
    Y = torch.zeros((n_groups, n_max, 1), dtype=torch.float32)
    D = torch.zeros((n_groups, n_max), dtype=torch.float32)

    # Copy each group's real candidates into the leading slots of its row.
    for row, (item, n) in enumerate(zip(batch, counts)):
        Xc[row, :n] = item[0]
        Y[row, :n] = item[2]
        D[row, :n] = item[3]

    # Slot n of row b is valid exactly when n < counts[b].
    mask_valid = torch.arange(n_max).unsqueeze(0) < torch.tensor(counts).unsqueeze(1)

    Xctx = torch.stack([item[1] for item in batch], dim=0)
    gids = [item[4] for item in batch]
    case_tags = [item[5] for item in batch]

    return Xc, Xctx, Y, D, mask_valid, gids, case_tags

# -----------------------------
# Backbone
# -----------------------------
def build_backbone(name: str):
    """Create an ImageNet-pretrained torchvision backbone with its classifier removed.

    ``name`` is matched case-insensitively against ``"resnet"``, ``"densenet"``,
    ``"efficient"`` and ``"vit"``. Returns ``(model, feat_dim)`` where
    ``feat_dim`` is the input width of the classifier layer that was replaced
    by ``nn.Identity``. Raises ``ValueError`` for any other name.
    """
    key = str(name).lower()

    if key == "resnet":
        net = tv_models.resnet50(weights=tv_models.ResNet50_Weights.IMAGENET1K_V2)
        width = net.fc.in_features
        net.fc = nn.Identity()
    elif key == "densenet":
        net = tv_models.densenet121(weights=tv_models.DenseNet121_Weights.IMAGENET1K_V1)
        width = net.classifier.in_features
        net.classifier = nn.Identity()
    elif key == "efficient":
        net = tv_models.efficientnet_v2_s(weights=tv_models.EfficientNet_V2_S_Weights.IMAGENET1K_V1)
        width = net.classifier[1].in_features
        net.classifier = nn.Identity()
    elif key == "vit":
        net = tv_models.vit_b_16(weights=tv_models.ViT_B_16_Weights.IMAGENET1K_V1)
        width = net.heads.head.in_features
        net.heads = nn.Identity()
    else:
        raise ValueError(key)

    return net, width

class ReBox(nn.Module):
    """Scores every candidate crop of a group with a transformer over the group.

    Candidate crops are encoded by a CNN/ViT backbone and projected to
    ``d_model``. Optionally a projected whole-image context feature and an
    MLP embedding of the candidate distance are added to each token before
    the transformer encoder. A small MLP head emits one score per candidate.
    """

    def __init__(self, backbone_key="densenet", d_model=512, nhead=8, nlayers=2,
                 use_context=True, disable_dist_mlp=False, share_backbone=False):
        super().__init__()
        self.use_context = bool(use_context)
        self.disable_dist_mlp = bool(disable_dist_mlp)
        self.share_backbone = bool(share_backbone)

        # Encoder for candidate crops.
        self.backbone_c, feat_dim = build_backbone(backbone_key)

        # Encoder for the context image: absent, shared with the crop encoder,
        # or a second independent instance.
        if not self.use_context:
            self.backbone_ctx = None
        elif self.share_backbone:
            self.backbone_ctx = self.backbone_c
        else:
            self.backbone_ctx, _ = build_backbone(backbone_key)

        self.proj_c = nn.Linear(feat_dim, d_model)
        self.proj_ctx = nn.Linear(feat_dim, d_model) if self.use_context else None

        # Scalar distance -> d_model embedding (optional).
        if self.disable_dist_mlp:
            self.dist_mlp = None
        else:
            self.dist_mlp = nn.Sequential(
                nn.Linear(1, d_model),
                nn.ReLU(inplace=True),
                nn.Linear(d_model, d_model),
            )

        layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, num_layers=nlayers)
        self.head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(inplace=True),
            nn.Linear(d_model, 1),
        )

    def forward(self, Xc, Xctx, D, mask_valid):
        """Return per-candidate scores of shape (B, N).

        ``mask_valid`` is True for real candidates and False for padding;
        padded positions are filled with the most negative finite value of the
        output dtype so they never win an argmax.
        """
        B, N, _, _, _ = Xc.shape
        flat_crops = Xc.reshape(B * N, 3, Xc.shape[-2], Xc.shape[-1])
        tokens = self.proj_c(self.backbone_c(flat_crops).view(B, N, -1))

        if self.use_context:
            ctx = self.proj_ctx(self.backbone_ctx(Xctx))   # (B,d)
            tokens = tokens + ctx.unsqueeze(1)             # broadcast over candidates

        if self.dist_mlp is not None:
            tokens = tokens + self.dist_mlp(D.unsqueeze(-1))

        is_pad = ~mask_valid  # True for padding
        encoded = self.enc(tokens, src_key_padding_mask=is_pad)
        scores = self.head(encoded).squeeze(-1)            # (B,N)

        floor = torch.finfo(scores.dtype).min
        return scores.masked_fill(is_pad, floor)

# -----------------------------
# Loss (FP32 safe)
# -----------------------------
def listmle_loss_fp32(scores_pred_f32, scores_true_f32, mask_valid):
    """ListMLE ranking loss, averaged over the groups of a batch.

    Args:
        scores_pred_f32: (B, N) float32 predicted scores.
        scores_true_f32: (B, N, 1) float32 ground-truth scores.
        mask_valid: (B, N) mask, True for real candidates.

    For each group the valid predictions are ordered by descending true score
    and the mean of ``logsumexp(pred[i:]) - pred[i]`` over positions is taken.
    Groups with fewer than two valid candidates are skipped.

    Returns a scalar tensor. When no group qualifies, a zero that is still
    attached to the graph is returned. The zero is built from the valid entries
    only: padded logits may sit at ``finfo.min`` and summing them overflows to
    ``-inf``, which would turn ``sum() * 0.0`` into NaN.
    """
    y = scores_true_f32.squeeze(-1)
    loss = scores_pred_f32.new_zeros(())
    cnt = 0
    for b in range(scores_pred_f32.shape[0]):
        idx = torch.nonzero(mask_valid[b], as_tuple=False).squeeze(-1)
        if idx.numel() < 2:
            continue
        order = torch.argsort(y[b, idx], descending=True)
        pb = scores_pred_f32[b, idx][order]
        # Suffix log-sum-exp: lse[i] = logsumexp(pb[i:]).
        lse = torch.logcumsumexp(pb.flip(0), dim=0).flip(0)
        loss = loss + (lse - pb).mean()
        cnt += 1
    if cnt == 0:
        return scores_pred_f32.masked_fill(~mask_valid.bool(), 0.0).sum() * 0.0
    return loss / float(cnt)

def monotone_hinge_by_dist_fp32(scores_pred_f32, D_f32, mask_valid, margin=0.0):
    """Pairwise hinge penalty that makes scores non-increasing in distance.

    Enforces: if ``dist_i > dist_j`` then ``pred_i <= pred_j`` (up to
    ``margin``). For every group the hinge ``relu(pred_i - pred_j + margin)``
    is averaged over all ordered pairs with ``dist_i > dist_j``; the result is
    averaged over groups that have at least two valid candidates.

    Args:
        scores_pred_f32: (B, N) float32 predicted scores.
        D_f32: (B, N) float32 candidate distances.
        mask_valid: (B, N) mask, True for real candidates.
        margin: hinge margin.

    Returns a scalar tensor. When no group qualifies, a graph-attached zero is
    returned, built from the valid entries only so that padded logits at
    ``finfo.min`` cannot overflow the sum and produce NaN.
    """
    loss = scores_pred_f32.new_zeros(())
    cnt = 0
    for b in range(scores_pred_f32.shape[0]):
        idx = torch.nonzero(mask_valid[b], as_tuple=False).squeeze(-1)
        if idx.numel() < 2:
            continue
        pb = scores_pred_f32[b, idx]  # (n,)
        db = D_f32[b, idx]            # (n,)

        # farther[i, j] == 1 when candidate i is strictly farther than j.
        farther = (db[:, None] > db[None, :]).float()
        hinge = F.relu(pb[:, None] - pb[None, :] + margin)
        # The epsilon keeps the ratio finite when all distances are equal.
        loss = loss + (hinge * farther).sum() / (farther.sum() + 1e-9)
        cnt += 1
    if cnt == 0:
        return scores_pred_f32.masked_fill(~mask_valid.bool(), 0.0).sum() * 0.0
    return loss / float(cnt)

# -----------------------------
# Eval (overall + per-case)
# -----------------------------
@torch.no_grad()
def evaluate(model, loader, device):
    """Evaluate candidate selection on ``loader`` and aggregate IoU statistics.

    For every group three IoUs are collected from the ground-truth scores:
    the best achievable one (max over valid candidates), the one of the
    candidate the model ranks highest, and the anchor's (candidate 0).

    Returns a dict with the overall means, the mean gain of the prediction
    over the anchor, and the same statistics per case tag under ``"by_case"``.
    The model is left in eval mode.
    """
    model.eval()
    overall = {"best": [], "pred": [], "anchor": []}
    per_case = {}

    for Xc, Xctx, Y, D, mask_valid, gids, case_tags in loader:
        Xc = Xc.to(device, non_blocking=True)
        Xctx = Xctx.to(device, non_blocking=True)
        Y = Y.to(device, non_blocking=True)
        D = D.to(device, non_blocking=True)
        mask_valid = mask_valid.to(device, non_blocking=True)

        scores = model(Xc, Xctx, D, mask_valid)   # (B,N)
        picked = torch.argmax(scores, dim=1)      # (B,)

        for b in range(scores.shape[0]):
            tag = case_tags[b]
            iou_all = Y[b].squeeze(-1)            # (N,)
            valid = mask_valid[b]
            iou_valid = iou_all[valid]
            if iou_valid.numel() == 0:
                continue

            best_iou = float(iou_valid.max().item())
            pred_iou = float(iou_all[picked[b]].item())
            # anchor is always first candidate (index 0) due to dist sorting
            if valid[0].item():
                anchor_iou = float(iou_all[0].item())
            else:
                anchor_iou = float(iou_valid[0].item())

            bucket = per_case.setdefault(tag, {"n":0, "best":[], "pred":[], "anchor":[]})
            bucket["n"] += 1
            for store in (overall, bucket):
                store["best"].append(best_iou)
                store["pred"].append(pred_iou)
                store["anchor"].append(anchor_iou)

    def _avg(values):
        return float(np.mean(values)) if len(values) else 0.0

    def _avg_gain(store):
        return _avg([p - a for p, a in zip(store["pred"], store["anchor"])])

    metrics = {
        "mean_best_iou": _avg(overall["best"]),
        "mean_pred_iou": _avg(overall["pred"]),
        "mean_anchor_iou": _avg(overall["anchor"]),
        "mean_gain_pred_minus_anchor": _avg_gain(overall),
    }
    metrics["by_case"] = {
        tag: {
            "n": int(store["n"]),
            "mean_best_iou": _avg(store["best"]),
            "mean_pred_iou": _avg(store["pred"]),
            "mean_anchor_iou": _avg(store["anchor"]),
            "mean_gain": _avg_gain(store),
        }
        for tag, store in per_case.items()
    }
    return metrics

# -----------------------------
# Tag/path helpers
# -----------------------------
def _fmt_w(x: float) -> str:
    # 0.25 -> 0p25
    s = f"{float(x):g}".replace(".", "p")
    return s

def build_budget_tag(cfg: TrainConfig) -> str:
    """Return the budget tag ``absn_NNNN`` (``cfg.absn`` zero-padded to four digits)."""
    budget = int(cfg.absn)
    return f"absn_{budget:04d}"

def build_ctx_tag(cfg: TrainConfig) -> str:
    """Return ``"ctx1"`` when context is enabled, otherwise ``"ctx0"``."""
    flag = 1 if cfg.use_context else 0
    return f"ctx{flag}"

def build_ablation_tag(cfg: TrainConfig) -> str:
    """Encode the ablation-relevant settings of ``cfg`` into one directory name.

    Covers the transformer size, the dist-MLP and backbone-sharing switches,
    the three loss weights, the border-perturbation count, the inverse
    settings and the candidate cap.
    """
    dist1 = int(not bool(cfg.disable_dist_mlp))   # 1 when the distance MLP is active
    share = int(bool(cfg.share_backbone))
    inv   = int(bool(cfg.include_inverse))

    pieces = [
        f"d{int(cfg.d_model)}_h{int(cfg.nhead)}_l{int(cfg.nlayers)}",
        f"_dist{dist1}_share{share}",
        f"_wL{_fmt_w(cfg.w_listmle)}_wM{_fmt_w(cfg.w_mono)}_wE{_fmt_w(cfg.w_mse)}",
        f"_rnd{int(cfg.num_border_perturb)}",
        f"_inv{inv}_jit{_fmt_w(cfg.inverse_jitter)}",
        f"_cap{int(cfg.max_candidates_per_group)}",
    ]
    return "".join(pieces)

def compute_run_paths(cfg: TrainConfig) -> Tuple[Path, Path]:
    """Return ``(out_dir, ckpt_path)`` for a run.

    Both live under ``cfg.out_root`` and share the sub-path
    ``<dataset>/<exp>/<budget>_<ctx>/<ablation>``; the checkpoint is
    additionally nested under ``weights/`` and is named ``best.pt``.
    Nothing is created on disk.
    """
    budget_tag = build_budget_tag(cfg)
    ctx_tag = build_ctx_tag(cfg)
    ablation_tag = build_ablation_tag(cfg)

    run_subpath = Path(cfg.dataset_name) / cfg.exp_name / f"{budget_tag}_{ctx_tag}" / ablation_tag
    root = Path(cfg.out_root)
    return root / run_subpath, root / "weights" / run_subpath / "best.pt"

# -----------------------------
# Train loop (FP16-safe)
# -----------------------------
def run_train_case(cfg: TrainConfig) -> Dict[str, Any]:
    """Train and validate one ReBox model as described by ``cfg``.

    Steps: build candidate meta for train/val, create loaders and the model,
    train for ``cfg.epochs`` epochs (validating every ``cfg.eval_every``
    epochs and on the last one), restore the weights with the best validation
    ``mean_pred_iou``, re-evaluate, and write ``best.pt``, ``config.json`` and
    ``log.csv``.

    Returns a summary dict. ``status`` is ``"ok"`` on success, or
    ``"failed_no_meta"`` when either split produced no candidate groups (in
    which case ``failed.json`` is written and no training happens).

    The forward pass may run under CUDA autocast; all loss terms are computed
    in FP32 with autocast disabled.
    """
    device = torch.device(cfg.device)
    # AMP is only meaningful on CUDA; decided once for the whole run.
    amp_ok = bool(cfg.use_amp) and device.type == "cuda"

    out_dir, ckpt_path = compute_run_paths(cfg)
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path.parent.mkdir(parents=True, exist_ok=True)

    # candidate cfg
    anchor_scales = cfg.cand_uniform_scaling_factors or [0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4]
    cand_scale = list(anchor_scales)
    cand_side_ks = cfg.cand_side_ks or [1,3,5,7,9]
    jitter_patterns = cfg.jitter_patterns or list(cand_side_ks)

    cand_cfg = CandidateCfg(
        cand_uniform_scaling_factors=list(cand_scale),
        cand_side_ks=list(cand_side_ks),
        num_border_perturb=int(cfg.num_border_perturb),
        include_anchor=True,
        include_inverse=bool(cfg.include_inverse),
        inverse_jitter=float(cfg.inverse_jitter),
        require_mixed_signs=bool(cfg.require_mixed_signs),
        max_candidates_per_group=int(cfg.max_candidates_per_group),
    )

    # meta
    t0 = time.time()
    train_meta = build_meta_from_anchor_and_clean(
        images_dir=cfg.train_images_dir,
        anchor_labels_dir=cfg.train_labels_dir,
        clean_labels_dir=cfg.target_train_labels_dir,
        seed=cfg.seed,
        cand_cfg=cand_cfg,
        jitter_patterns=list(jitter_patterns),
    )
    val_meta = build_meta_from_anchor_and_clean(
        images_dir=cfg.valid_images_dir,
        anchor_labels_dir=cfg.valid_labels_dir,
        clean_labels_dir=cfg.target_valid_labels_dir,
        seed=cfg.seed,
        cand_cfg=cand_cfg,
        jitter_patterns=list(jitter_patterns),
    )
    # Measure the meta build here: taking the timestamp after training (as
    # before) recorded the duration of the whole run under "build_seconds".
    build_seconds = float(time.time() - t0)

    if len(train_meta) == 0 or len(val_meta) == 0:
        payload = {"status":"failed_no_meta", "n_train":len(train_meta), "n_val":len(val_meta)}
        (out_dir / "failed.json").write_text(json.dumps(payload, indent=2), encoding="utf-8")
        return payload

    # dataset/loader
    train_ds = GroupedCandidateDataset(cfg.train_images_dir, train_meta, cfg.crop_size, cfg.use_context)
    val_ds   = GroupedCandidateDataset(cfg.valid_images_dir,  val_meta,  cfg.crop_size, cfg.use_context)

    dl_kwargs = dict(
        num_workers=int(cfg.num_workers),
        pin_memory=bool(cfg.pin_memory) and str(cfg.device).startswith("cuda"),
        persistent_workers=bool(cfg.persistent_workers) and int(cfg.num_workers)>0,
        prefetch_factor=int(cfg.prefetch_factor) if int(cfg.num_workers)>0 else None,
    )
    # prefetch_factor is only passed along when worker processes are used.
    if dl_kwargs["prefetch_factor"] is None:
        dl_kwargs.pop("prefetch_factor")

    train_loader = DataLoader(train_ds, batch_size=cfg.batch_size_set, shuffle=True, collate_fn=collate_pad, **dl_kwargs)
    val_loader   = DataLoader(val_ds,   batch_size=cfg.batch_size_set, shuffle=False, collate_fn=collate_pad, **dl_kwargs)

    # model
    model = ReBox(
        backbone_key=str(cfg.backbone_key),
        d_model=int(cfg.d_model),
        nhead=int(cfg.nhead),
        nlayers=int(cfg.nlayers),
        use_context=bool(cfg.use_context),
        disable_dist_mlp=bool(cfg.disable_dist_mlp),
        share_backbone=bool(cfg.share_backbone),
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))
    scaler = torch.cuda.amp.GradScaler(enabled=amp_ok)

    # logging (the header is only written when the file does not exist yet,
    # so repeated runs into the same directory append to the same log)
    log_path = out_dir / "log.csv"
    if not log_path.exists():
        log_path.write_text("epoch,train_loss,val_pred_iou,val_anchor_iou,val_gain\n", encoding="utf-8")

    best_metric = -1e9
    best_state = None
    best_epoch = 0

    for epoch in range(1, int(cfg.epochs)+1):
        model.train()
        losses = []

        for Xc, Xctx, Y, D, mask_valid, gids, case_tags in train_loader:
            Xc = Xc.to(device, non_blocking=True)
            Xctx = Xctx.to(device, non_blocking=True)
            Y = Y.to(device, non_blocking=True)
            D = D.to(device, non_blocking=True)
            mask_valid = mask_valid.to(device, non_blocking=True)

            opt.zero_grad(set_to_none=True)

            with torch.cuda.amp.autocast(enabled=amp_ok):
                pred = model(Xc, Xctx, D, mask_valid)  # (B,N) possibly fp16

            # Key: loss computed in FP32 (autocast OFF)
            with torch.cuda.amp.autocast(enabled=False):
                pred_f32 = pred.float()
                Y_f32 = Y.float()
                D_f32 = D.float()
                loss = pred_f32.new_zeros(())

                if float(cfg.w_listmle) > 0:
                    loss = loss + float(cfg.w_listmle) * listmle_loss_fp32(pred_f32, Y_f32, mask_valid)
                if float(cfg.w_mono) > 0:
                    loss = loss + float(cfg.w_mono) * monotone_hinge_by_dist_fp32(pred_f32, D_f32, mask_valid, margin=0.0)
                if float(cfg.w_mse) > 0:
                    # Regress sigmoid(score) onto the IoU of each valid candidate.
                    y = Y_f32.squeeze(-1)
                    p = torch.sigmoid(pred_f32)
                    mse = ((p - y)**2)[mask_valid].mean()
                    loss = loss + float(cfg.w_mse) * mse

            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            losses.append(float(loss.detach().cpu().item()))

        train_loss = float(np.mean(losses)) if losses else 0.0

        if (epoch % int(cfg.eval_every)) == 0 or epoch == int(cfg.epochs):
            metrics = evaluate(model, val_loader, device=device)
            key = float(metrics["mean_pred_iou"])
            if key > best_metric:
                best_metric = key
                best_epoch = epoch
                # Snapshot on CPU so the best weights do not occupy device memory.
                best_state = {k: v.detach().cpu() for k,v in model.state_dict().items()}

            # append log
            line = f"{epoch},{train_loss:.6f},{metrics['mean_pred_iou']:.6f},{metrics['mean_anchor_iou']:.6f},{metrics['mean_gain_pred_minus_anchor']:.6f}\n"
            with open(log_path, "a", encoding="utf-8") as f:
                f.write(line)

    # load best and final eval
    if best_state is not None:
        model.load_state_dict(best_state, strict=True)

    final_metrics = evaluate(model, val_loader, device=device)

    # save ckpt
    exp_meta = {
        "backbone": str(cfg.backbone_key),
        "backbone_key": str(cfg.backbone_key),
        "d_model": int(cfg.d_model),
        "nhead": int(cfg.nhead),
        "nlayers": int(cfg.nlayers),
        "use_context": bool(cfg.use_context),
        "disable_dist_mlp": bool(cfg.disable_dist_mlp),
        "share_backbone": bool(cfg.share_backbone),
        "crop_size": int(cfg.crop_size),
        "use_amp": amp_ok,
        "seed": int(cfg.seed),
        "absn": int(cfg.absn),
        "budget_tag": build_budget_tag(cfg),
        "ctx_tag": build_ctx_tag(cfg),
        "ablation_tag": build_ablation_tag(cfg),
        "cand_cfg": asdict(cand_cfg),
        "jitter_patterns": list(jitter_patterns),
        "build_seconds": build_seconds,
        "n_train_groups": int(len(train_meta)),
        "n_val_groups": int(len(val_meta)),
        "out_dir": str(out_dir),
        "log_csv": str(log_path),
    }

    ckpt = {
        "model_state_dict": model.state_dict(),   # weights included
        "best_epoch": int(best_epoch),
        "best_metrics": final_metrics,
        "exp": exp_meta,
        "config": asdict(cfg),
    }
    torch.save(ckpt, ckpt_path)

    # also save config.json for quick view
    (out_dir / "config.json").write_text(json.dumps({"exp": exp_meta, "config": asdict(cfg)}, indent=2, ensure_ascii=False), encoding="utf-8")

    return {
        "status":"ok",
        "dataset": cfg.dataset_name,
        "exp_name": cfg.exp_name,
        "best_epoch": int(best_epoch),
        "best_val_pred_iou": float(final_metrics["mean_pred_iou"]),
        "best_val_anchor_iou": float(final_metrics["mean_anchor_iou"]),
        "best_val_gain": float(final_metrics["mean_gain_pred_minus_anchor"]),
        "best_ckpt": str(ckpt_path),
        "out_dir": str(out_dir),
        "log_csv": str(log_path),
        "metrics_by_case_json": json.dumps(final_metrics.get("by_case", {}), ensure_ascii=False),
    }


In [12]:
# ==========================================
# Cell 2-A) CaseSpec with Candidate Control — FINAL
#   - CaseSpec class: Model settings + candidate generation settings combined
#   - resolve_candidate_config(): CaseSpec -> CandidateCfg conversion
#   - CASE_SPECS_DEFAULT: Define experiment cases
#
# [Experiment Design]
#   - Baseline: both, max=60 (currently generates 31)
#   - Exp1: both, max=15 (includes scale+side, limited to 15) <- half
#   - Exp2: isotropic_only, 15 (pure isotropic resizing candidates only)
#   - Exp3: borderwise_only, 15 (pure border-wise perturbation candidates only)
# ==========================================

from __future__ import annotations
from dataclasses import dataclass, asdict
from typing import List, Optional

# -----------------------------
# CaseSpec (Extended: includes Candidate Control)
# -----------------------------
@dataclass(frozen=True)
class CaseSpec:
    """
    Immutable definition of one experiment case: model structure, loss
    weights and candidate-generation settings.

    cand_mode (as interpreted by resolve_candidate_config):
        - "both": expanded scale factors plus `num_border_perturb`
          border-wise perturbation candidates
        - "isotropic_only": 11 evenly spaced scale factors in [0.6, 1.4],
          no border-wise perturbation candidates
        - "borderwise_only": scale factor 1.0 only, 9 border-wise
          perturbation candidates
        Any other value is treated like "both".
    """
    # Identification (NOTE(review): not read by the code in this cell;
    # presumably used by the orchestrator for naming and grouping)
    case_group: str
    case_name: str
    # Model structure
    backbone_key: str
    absn: int
    use_context: bool
    disable_dist_mlp: bool
    share_backbone: bool
    # Loss: a label plus the three term weights
    loss_name: str
    w_listmle: float
    w_mono: float
    w_mse: float
    d_model: int = 512
    nhead: int = 8
    nlayers: int = 2
    
    # Candidate control
    cand_mode: str = "both"           # "both" | "isotropic_only" | "borderwise_only"
    max_candidates: int = 60          # becomes CandidateCfg.max_candidates_per_group
    num_border_perturb: int = 10           # only honoured in "both" mode; the other modes hard-code 0 / 9
    include_inverse: bool = True      # forwarded to CandidateCfg.include_inverse


# -----------------------------
# Scale factors generation function
# -----------------------------
def generate_uniform_scaling_factors_for_count(count: int, min_s: float = 0.6, max_s: float = 1.4) -> List[float]:
    """
    Return `count` evenly spaced scale factors from `min_s` to `max_s`
    (both ends included), each rounded to two decimals.

    For `count <= 1` the identity scale `[1.0]` is returned.

    e.g.: count=11, 0.6~1.4 -> [0.6, 0.68, 0.76, ..., 1.4]
    """
    if count <= 1:
        return [1.0]

    step = (max_s - min_s) / (count - 1)
    factors: List[float] = []
    for i in range(count):
        factors.append(round(min_s + i * step, 2))
    return factors


# -----------------------------
# resolve_candidate_config: CaseSpec → CandidateCfg
# -----------------------------
def resolve_candidate_config(
    case: CaseSpec,
    anchor_uniform_scaling_factors: List[float],
    anchor_side_ks: List[int] = None,
) -> CandidateCfg:
    """
    Translate a CaseSpec into the CandidateCfg used for candidate generation.

    The scale factors and the number of border-wise perturbation candidates
    depend on `case.cand_mode`:
      - "isotropic_only":  11 evenly spaced scales in [0.6, 1.4], 0 perturbations
      - "borderwise_only": scale 1.0 only, 9 perturbations
      - anything else ("both"): scales expanded from
        `anchor_uniform_scaling_factors`, `case.num_border_perturb` perturbations

    According to the original author these settings yield 31 candidates for
    "both" and 15 for each of the other two modes; the actual totals are
    produced by the candidate builder and capped by `case.max_candidates`.

    `anchor_side_ks` defaults to [1, 3, 5, 7, 9]. `inverse_jitter` (0.03) and
    `require_mixed_signs` (True) are fixed here.
    """
    side_ks = [1, 3, 5, 7, 9] if anchor_side_ks is None else anchor_side_ks

    mode = case.cand_mode
    if mode == "isotropic_only":
        # Scale only: no border-wise perturbation.
        scales = generate_uniform_scaling_factors_for_count(11, 0.6, 1.4)
        n_perturb = 0
    elif mode == "borderwise_only":
        # Border-wise only: identity scale plus 9 perturbations.
        scales = [1.0]
        n_perturb = 9
    else:
        # "both" (default): expanded scales, case-defined perturbation count.
        scales = auto_expand_uniform_scaling_factors_for_refine(anchor_uniform_scaling_factors)
        n_perturb = case.num_border_perturb

    return CandidateCfg(
        cand_uniform_scaling_factors=scales,
        cand_side_ks=side_ks,
        num_border_perturb=n_perturb,
        include_anchor=True,
        include_inverse=case.include_inverse,
        inverse_jitter=0.03,
        require_mixed_signs=True,
        max_candidates_per_group=case.max_candidates,
    )


# -----------------------------
# CASE_SPECS_DEFAULT: Define experiment cases
# -----------------------------
# Budget used in every default case; taken from a global `n_data` when an
# earlier cell defined one, otherwise 10.
_n_data = globals().get("n_data", 10)

# The four candidate-ablation cases share the model and loss settings and
# differ only in the columns of the table below:
#   (case_name, cand_mode, max_candidates, num_border_perturb)
#   - baseline: both, max=60
#   - exp1:     both, max=15
#   - exp2:     isotropic_only, max=15, no border-wise perturbation
#   - exp3:     borderwise_only, max=15, 9 border-wise perturbations
CASE_SPECS_DEFAULT = [
    CaseSpec(
        case_group="CAND_ABLATION",
        case_name=_case_name,
        backbone_key="densenet",
        absn=_n_data,
        use_context=True,
        disable_dist_mlp=False,
        share_backbone=False,
        loss_name="L+E",
        w_listmle=1.0,
        w_mono=0.0,
        w_mse=1.0,
        cand_mode=_cand_mode,
        max_candidates=_max_candidates,
        num_border_perturb=_n_border,
        include_inverse=True,
    )
    for _case_name, _cand_mode, _max_candidates, _n_border in (
        (f"baseline_both_31_absn{_n_data}", "both", 60, 10),
        (f"exp1_both_15_absn{_n_data}", "both", 15, 10),
        (f"exp2_isotropic_only_15_absn{_n_data}", "isotropic_only", 15, 0),
        (f"exp3_borderwise_only_15_absn{_n_data}", "borderwise_only", 15, 9),
    )
]


# -----------------------------
# Calculate and print expected candidate counts
# -----------------------------
def calc_expected_candidates(case: CaseSpec) -> dict:
    """Return the expected per-source candidate counts for a CaseSpec.

    The counts are fixed per `cand_mode` (they are not derived from the
    candidate builder): the "both" scale count of 17 and the inverse counts
    are constants supplied by the original author. `final` is the raw total
    capped at `case.max_candidates`.

    NOTE(review): the inverse count does not consult `case.include_inverse`.
    """
    mode = case.cand_mode
    if mode == "isotropic_only":
        inv_cnt, scale_cnt, border_perturb = 3, 11, 0
    elif mode == "borderwise_only":
        inv_cnt, scale_cnt, border_perturb = 4, 1, 9
    else:  # both
        inv_cnt, scale_cnt, border_perturb = 3, 17, case.num_border_perturb

    anchor = 1
    raw_total = anchor + inv_cnt + scale_cnt + border_perturb
    cap = case.max_candidates

    return {
        "anchor": anchor,
        "inverse": inv_cnt,
        "scale": scale_cnt,
        "border_perturb": border_perturb,
        "raw_total": raw_total,
        "max_candidates": cap,
        "final": min(raw_total, cap),
    }

# Summary table of the default cases.
# The two text columns are sized from the data: the previous fixed widths
# (30 for the name, 10 for the mode) are too narrow for modes such as
# "isotropic_only" / "borderwise_only", which pushed the numeric columns
# out of line with the header.
_name_w = max([30] + [len(spec.case_name) for spec in CASE_SPECS_DEFAULT])
_mode_w = max([10] + [len(spec.cand_mode) for spec in CASE_SPECS_DEFAULT])
_header = f"{'Case Name':<{_name_w}} │ {'Mode':<{_mode_w}} │ Anc│ Inv│ Scl│ Sid│ Raw│ Max│ Final"
_rule_w = max(80, len(_header))

print("=" * _rule_w)
print("CASE_SPECS_DEFAULT - Candidate settings and expected counts (half of Baseline)")
print("=" * _rule_w)
print(_header)
print("-" * _rule_w)

for case in CASE_SPECS_DEFAULT:
    c = calc_expected_candidates(case)
    print(f"{case.case_name:<{_name_w}} │ {case.cand_mode:<{_mode_w}} │ {c['anchor']:2d} │ {c['inverse']:2d} │ {c['scale']:2d} │ {c['border_perturb']:2d} │ {c['raw_total']:2d} │ {c['max_candidates']:2d} │  {c['final']:2d}")

print("=" * _rule_w)
print()
# Fixed reminder of the intended design; these lines are not computed.
print("★ Verification:")
print("  - Baseline (both): 31 candidates")
print("  - Exp1 (both, max=15): 31 -> limited to 15 (half)")
print("  - Exp2 (isotropic_only): 1+3+11+0 = 15 (half)")
print("  - Exp3 (borderwise_only): 1+4+1+9 = 15 (half)")
print("=" * _rule_w)

CASE_SPECS_DEFAULT - Candidate settings and expected counts (half of Baseline)
Case Name                      │ Mode       │ Anc│ Inv│ Scl│ Sid│ Raw│ Max│ Final
--------------------------------------------------------------------------------
baseline_both_31_absn10        │ both       │  1 │  3 │ 17 │ 10 │ 31 │ 60 │  31
exp1_both_15_absn10            │ both       │  1 │  3 │ 17 │ 10 │ 31 │ 15 │  15
exp2_isotropic_only_15_absn10      │ isotropic_only │  1 │  3 │ 11 │  0 │ 15 │ 15 │  15
exp3_borderwise_only_15_absn10       │ borderwise_only  │  1 │  4 │  1 │  9 │ 15 │ 15 │  15

★ Verification:
  - Baseline (both): 31 candidates
  - Exp1 (both, max=15): 31 -> limited to 15 (half)
  - Exp2 (isotropic_only): 1+3+11+0 = 15 (half)
  - Exp3 (borderwise_only): 1+4+1+9 = 15 (half)


In [13]:
# ==========================================
# Cell 3) Orchestrator — ONE MODEL per dataset (13 start-cases MIXED) — FINAL (CSV ablation-friendly)
#  - ✅ Explicitly separate ablation comparison columns in CSV
#  - ✅ Remove best_pt(best_ckpt), out_dir columns (keep only log_csv in CSV)
#  - ✅ If best_val_gain == 0.0(±tol), delete outputs and retry once (fresh retry)
#  - (Optional) Filter AMP-deprecation / nested-tensor transformer warnings (the proper fix is to replace torch.cuda.amp with torch.amp in Cell 2)
# ==========================================

from __future__ import annotations
import os, shutil, random, hashlib, re, json, gc, warnings
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple

import pandas as pd
import torch

# Fail fast when the notebook state this cell relies on is missing:
#   - dataset_summaries: iterated by main() to pick the target datasets
#   - TrainConfig / run_train_case: used by build_cfg() and main()
#   - auto_expand_uniform_scaling_factors_for_refine: used by build_cfg()
# NOTE(review): presumably defined by the earlier cells (Cell 1 / Cell 2) — run those first.
assert "dataset_summaries" in globals()
assert "TrainConfig" in globals() and "run_train_case" in globals()
assert "auto_expand_uniform_scaling_factors_for_refine" in globals()

# -----------------------------
# (Optional) warning filters (fix by replacing torch.cuda.amp -> torch.amp in Cell2)
# -----------------------------
# Each entry is (warning category, message regex) for a warning to silence.
# Registration order matches the original three explicit calls.
for _warn_category, _warn_pattern in (
    (FutureWarning, r".*torch\.cuda\.amp\.GradScaler.*deprecated.*"),
    (FutureWarning, r".*torch\.cuda\.amp\.autocast.*deprecated.*"),
    (UserWarning, r".*nested tensors is in prototype stage.*"),
):
    warnings.filterwarnings(
        "ignore",
        category=_warn_category,
        message=_warn_pattern,
    )
# Do not leave the loop variables behind in the notebook namespace.
del _warn_category, _warn_pattern

# -----------------------------
# retry policy (gain==0 → purge & retry)
# -----------------------------
# main() treats a run as "zero gain" when abs(best_val_gain) <= ZERO_GAIN_TOL.
ZERO_GAIN_TOL = 1e-12          # tolerance for gain==0 check
# Upper bound on purge-and-retrain attempts per run in main()'s retry loop.
MAX_RETRY_ON_ZERO_GAIN = 1     # retry up to 1 time if gain is 0

# -----------------------------
# small utils
# -----------------------------
def _stable_hash32(s: str) -> int:
    h = hashlib.md5(s.encode("utf-8")).hexdigest()
    return int(h[:8], 16)

def _safe_symlink_or_copy(src: Path, dst: Path):
    dst.parent.mkdir(parents=True, exist_ok=True)
    if dst.exists():
        return
    try:
        os.symlink(src, dst)
    except OSError:
        shutil.copy2(src, dst)

def _safe_name(s: str) -> str:
    s = str(s).replace(os.sep, "_").replace("..","_")
    s = re.sub(r"[^a-zA-Z0-9_\-\.]+", "_", s)
    s = s.replace(".", "p")
    return s[:180]

def resolve_device(prefer="cuda:0"):
    """Pick the torch device string to run on.

    Returns "cpu" when CUDA is unavailable. Otherwise returns ``prefer`` if it
    has the form "cuda:<i>" with ``i`` a valid device index, and "cuda:0" in
    every other case.
    """
    if not torch.cuda.is_available():
        return "cpu"

    gpu_count = torch.cuda.device_count()
    if prefer.startswith("cuda:"):
        try:
            index = int(prefer.split(":")[1])
        except Exception:
            index = -1  # unparsable index: fall through to the default device
        if 0 <= index < gpu_count:
            return prefer
    return "cuda:0"

def _to_float(x, default=None):
    try:
        return float(x)
    except Exception:
        return default

def _infer_weights_dir_from_log_csv(out_root: Path, log_csv: str) -> Optional[Path]:
    """
    Cell2 convention:
      log_csv parent:  OUT_ROOT/<dataset>/<exp>/<absn_tag>/<ablation_tag>/
      weights dir:     OUT_ROOT/weights/<dataset>/<exp>/<absn_tag>/<ablation_tag>/
    """
    if not log_csv:
        return None
    try:
        run_dir = Path(log_csv).resolve().parent
        out_root_r = Path(out_root).resolve()
        rel = run_dir.relative_to(out_root_r)
        return out_root_r / "weights" / rel
    except Exception:
        return None

def _purge_run_artifacts(out_root: Path, out: Dict[str, Any]):
    """
    If gain==0, remove previous outputs (out_dir + weights) for fresh retry.
    - Delete run_dir based on out['log_csv']
    - Infer and delete weights/<...> based on log_csv
    - Use best_ckpt/best_pt from out if available
    """
    # 1) Remove run_dir (based on log_csv)
    log_csv = out.get("log_csv", "") if isinstance(out, dict) else ""
    if isinstance(log_csv, str) and log_csv:
        try:
            run_dir = Path(log_csv).resolve().parent
            if run_dir.exists():
                shutil.rmtree(run_dir, ignore_errors=True)
        except Exception:
            pass

        # 2) Infer and remove weights dir
        try:
            wdir = _infer_weights_dir_from_log_csv(Path(out_root), log_csv)
            if wdir and wdir.exists():
                shutil.rmtree(wdir, ignore_errors=True)
        except Exception:
            pass

    # 3) Also remove paths from best_ckpt/best_pt in out if present
    best_ckpt = ""
    if isinstance(out, dict):
        best_ckpt = out.get("best_ckpt", "") or out.get("best_pt", "") or out.get("best", "")
    if isinstance(best_ckpt, str) and best_ckpt:
        try:
            p = Path(best_ckpt).resolve()
            if p.exists():
                # delete only best.pt
                try:
                    p.unlink()
                except Exception:
                    pass
            # usually run-level folder, try cleaning parent too
            if p.parent.exists() and any(p.parent.iterdir()) is False:
                shutil.rmtree(p.parent, ignore_errors=True)
        except Exception:
            pass

    # 4) Memory cleanup (prevent OOM on retry)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


# -----------------------------
# Subset helpers (train/val)
# -----------------------------
VAL_IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def _select_subset_rels(ds_info: Dict[str, Any], split_kind: str, absn: int, seed: int) -> List[Path]:
    assert split_kind in ("train","val")
    images_dir = ds_info.get("train_dir") if split_kind=="train" else ds_info.get("val_dir")
    labels_dir = ds_info.get("train_labels_dir") if split_kind=="train" else ds_info.get("val_labels_dir")
    if not images_dir or not labels_dir:
        return []
    images_path = Path(images_dir)
    labels_path = Path(labels_dir)

    all_imgs = sorted([p for p in images_path.rglob("*") if p.is_file() and p.suffix.lower() in VAL_IMAGE_EXTS])
    imgs_with_label = []
    for img_path in all_imgs:
        rel = img_path.relative_to(images_path)
        if (labels_path / rel.with_suffix(".txt")).exists():
            imgs_with_label.append(img_path)
    if not imgs_with_label:
        return []

    salt = 0xA17C if split_kind=="train" else 0xB55D
    rng_seed = (salt ^ _stable_hash32(ds_info["dataset"]) ^ int(absn) ^ int(seed)) & 0xFFFFFFFF
    rng = random.Random(rng_seed)
    rng.shuffle(imgs_with_label)
    subset = imgs_with_label[:min(absn, len(imgs_with_label))]
    return [p.relative_to(images_path) for p in subset]

def prepare_subset_from_rels(out_root: Path, ds_info: Dict[str, Any], split_kind: str, rels: List[Path], absn: int, seed: int):
    """Materialise a subset of a split as symlinked (or copied) images and labels.

    The subset lives under
    ``out_root/_<split>_subsets/<dataset>/seed<seed>/absn_<absn:04d>/{images,labels}``.
    If the images folder already contains anything it is reused as is.
    Entries of ``rels`` whose image or label source is missing are skipped.

    Returns ``(images_dir, labels_dir)`` as strings; when the dataset has no
    image or label directory for the split, the unmodified ``ds_info`` values
    are returned instead.
    """
    assert split_kind in ("train","val")
    if split_kind == "train":
        images_dir, labels_dir = ds_info.get("train_dir"), ds_info.get("train_labels_dir")
    else:
        images_dir, labels_dir = ds_info.get("val_dir"), ds_info.get("val_labels_dir")
    if not (images_dir and labels_dir):
        return images_dir, labels_dir

    src_images = Path(images_dir)
    src_labels = Path(labels_dir)

    subset_root = out_root / f"_{split_kind}_subsets" / ds_info["dataset"] / f"seed{int(seed)}" / f"absn_{absn:04d}"
    img_out = subset_root / "images"
    lbl_out = subset_root / "labels"

    already_built = img_out.exists() and any(img_out.rglob("*"))
    if not already_built:
        img_out.mkdir(parents=True, exist_ok=True)
        lbl_out.mkdir(parents=True, exist_ok=True)
        for rel in rels:
            label_rel = rel.with_suffix(".txt")
            img_src = src_images / rel
            lbl_src = src_labels / label_rel
            if img_src.exists() and lbl_src.exists():
                _safe_symlink_or_copy(img_src, img_out / rel)
                _safe_symlink_or_copy(lbl_src, lbl_out / label_rel)

    return str(img_out), str(lbl_out)


# -----------------------------
# Start anchor labels (mixed) + save __case_map.json
# -----------------------------
def list_start_noise_cases_for_dataset(ds_info: Dict[str, Any], split_kind: str,
                                       uniform_scaling_factors: List[float], side_ks: List[int], exclude_scale_1=True) -> List[str]:
    """List the noisy-label folders that exist on disk for one split.

    For each scale factor, ``labels_uniform_scaling_<s>`` is looked up first
    in ``%g`` spelling and then with "." replaced by "p"; the first spelling
    that has a ``<split>`` subfolder under the dataset root is kept. For each
    ``k`` in ``side_ks``, ``labels_boundary_jitter_<k>`` is checked the same
    way. Factors within 1e-9 of 1.0 are skipped when ``exclude_scale_1`` is set.

    Returns folder names in discovery order without duplicates, or ``[]`` when
    the split has no clean labels directory.
    """
    labels_key = "train_labels_dir" if split_kind == "train" else "val_labels_dir"
    clean_labels_dir = ds_info.get(labels_key)
    if not clean_labels_dir:
        return []
    split_name = Path(clean_labels_dir).name
    ds_root = Path(ds_info["root"])

    def _has_split(folder: str) -> bool:
        return (ds_root / folder / split_name).exists()

    found = []
    for factor in uniform_scaling_factors:
        if exclude_scale_1 and abs(float(factor) - 1.0) < 1e-9:
            continue
        spellings = (
            f"labels_uniform_scaling_{factor:g}",
            f"labels_uniform_scaling_{str(factor).replace('.','p')}",
        )
        hit = next((name for name in spellings if _has_split(name)), None)
        if hit is not None:
            found.append(hit)

    jitter_names = (f"labels_boundary_jitter_{int(k)}" for k in side_ks)
    found.extend(name for name in jitter_names if _has_split(name))

    # dict.fromkeys drops duplicates while keeping first-seen order
    return list(dict.fromkeys(found))

def prepare_anchor_labels_dir(
    out_root: Path,
    ds_info: Dict[str, Any],
    split_kind: str,
    rels_from_images: Optional[List[Path]],
    seed: int,
    tag: str,
    start_anchor_mode: str,
    uniform_scaling_factors: List[float],
    side_ks: List[int],
    fixed_case: Optional[str] = None,
) -> str:
    """Build the directory of starting ("anchor") labels for one split.

    Modes:
      - "clean": use the clean labels directory directly.
      - "noisy_fixed": every file comes from the ``fixed_case`` folder.
      - "noisy_mix": each file's noise folder is picked from a hash of
        dataset name, relative path and seed.
      - "noisy_mix_balanced": files are shuffled (seeded) and assigned to
        the noise folders round-robin.

    For the noisy modes the labels are symlinked/copied into
    ``out_root/_start_anchor_labels/<dataset>/<mode>[/<fixed_case>]/seed<seed>/<tag>/<split>``
    and the file-to-folder assignment is saved as ``__case_map.json``
    (folder list in ``__case_list.json``). A directory that already holds
    ``.txt`` files and ``__case_map.json`` is reused. A file missing from its
    noise folder falls back to the clean label.

    Returns the anchor directory as a string. Falls back to the clean labels
    directory when no usable noise folder exists, and returns the unmodified
    (falsy) ``ds_info`` value when the split has no labels directory.
    """
    assert start_anchor_mode in ("noisy_mix_balanced","noisy_mix","noisy_fixed","clean")

    labels_key = "train_labels_dir" if split_kind == "train" else "val_labels_dir"
    clean_labels_dir = ds_info.get(labels_key)
    if not clean_labels_dir:
        return clean_labels_dir
    if start_anchor_mode == "clean":
        return str(clean_labels_dir)

    clean_labels_path = Path(clean_labels_dir)
    split_name = clean_labels_path.name
    ds_root = Path(ds_info["root"])

    # Output location; a fixed case gets its own sub-folder below the mode.
    mode_root = out_root / "_start_anchor_labels" / ds_info["dataset"] / start_anchor_mode
    if start_anchor_mode == "noisy_fixed" and fixed_case:
        mode_root = mode_root / _safe_name(fixed_case)
    start_root = mode_root / f"seed{int(seed)}" / tag / split_name
    case_map_file = start_root / "__case_map.json"

    # Reuse a previously completed build.
    if start_root.exists() and any(start_root.rglob("*.txt")) and case_map_file.exists():
        return str(start_root)
    start_root.mkdir(parents=True, exist_ok=True)

    # Label files to process, relative to the clean labels directory.
    if rels_from_images is None:
        rel_lbls = [
            found.relative_to(clean_labels_path)
            for found in sorted(clean_labels_path.rglob("*.txt"))
            if found.is_file()
        ]
    else:
        rel_lbls = [rel.with_suffix(".txt") for rel in rels_from_images]

    # Noise folders available for this mode; none means "use clean labels".
    if start_anchor_mode == "noisy_fixed":
        fixed_ok = bool(fixed_case) and (ds_root / fixed_case / split_name).exists()
        case_list = [fixed_case] if fixed_ok else []
    else:
        case_list = list_start_noise_cases_for_dataset(ds_info, split_kind, uniform_scaling_factors, side_ks, exclude_scale_1=True)
    if not case_list:
        return str(clean_labels_path)

    n_cases = len(case_list)
    if start_anchor_mode == "noisy_mix_balanced":
        shuffled = list(rel_lbls)
        salt = 0xC0FFEE if split_kind == "train" else 0xBEEF
        random.Random((_stable_hash32(ds_info["dataset"]) ^ int(seed) ^ salt) & 0xFFFFFFFF).shuffle(shuffled)
        rel_to_case = {rel.as_posix(): case_list[pos % n_cases] for pos, rel in enumerate(shuffled)}
    elif start_anchor_mode == "noisy_mix":
        rel_to_case = {}
        for rel in rel_lbls:
            posix = rel.as_posix()
            digest = (_stable_hash32(f"{ds_info['dataset']}::{posix}") ^ int(seed)) & 0xFFFFFFFF
            rel_to_case[posix] = case_list[digest % n_cases]
    else:
        rel_to_case = {rel.as_posix(): case_list[0] for rel in rel_lbls}

    # Link each label from its assigned noise folder (clean label as fallback).
    for rel in rel_lbls:
        clean_src = clean_labels_path / rel
        if not clean_src.exists():
            continue
        chosen = rel_to_case.get(rel.as_posix(), case_list[0])
        noisy_candidate = ds_root / chosen / split_name / rel
        source = noisy_candidate if noisy_candidate.exists() else clean_src
        _safe_symlink_or_copy(source, start_root / rel)

    # Record the assignment; failure to write is tolerated.
    try:
        case_map_file.write_text(json.dumps(rel_to_case, ensure_ascii=False, indent=2), encoding="utf-8")
        (start_root / "__case_list.json").write_text(json.dumps(case_list, ensure_ascii=False, indent=2), encoding="utf-8")
    except Exception:
        pass

    return str(start_root)


# -----------------------------
# CaseSpec (Extended: includes Candidate Control)
# -----------------------------
@dataclass(frozen=True)
class CaseSpec:
    """
    Experiment case definition: model structure + training settings + candidate generation settings

    cand_mode:
        - "both": generate both scale + border-wise perturbation candidates
        - "isotropic_only": generate only isotropic resizing candidates (no border-wise candidates)
        - "borderwise_only": generate only border-wise perturbation candidates (minimize scale)

    Raises:
        ValueError: if ``cand_mode`` is not one of the three values above.
            The candidate helpers treat any unrecognised mode as "both", so a
            misspelled mode would otherwise silently run the wrong experiment.
    """
    # Identity of the case (written to the summary CSV).
    case_group: str
    case_name: str
    # Model structure / training budget.
    backbone_key: str
    absn: int
    use_context: bool
    disable_dist_mlp: bool
    share_backbone: bool
    # Loss label and the individual loss weights.
    loss_name: str
    w_listmle: float
    w_mono: float
    w_mse: float
    d_model: int = 512
    nhead: int = 8
    nlayers: int = 2
    # ★ Candidate Control ★
    cand_mode: str = "both"           # "both" | "isotropic_only" | "borderwise_only"
    max_candidates: int = 60          # Maximum candidate count
    num_border_perturb: int = 10           # Border-wise perturbation candidate count
    include_inverse: bool = True      # Whether to include inverse candidates

    def __post_init__(self):
        # Reject unknown modes up front instead of letting them act as "both".
        valid_modes = ("both", "isotropic_only", "borderwise_only")
        if self.cand_mode not in valid_modes:
            raise ValueError(
                f"cand_mode must be one of {valid_modes}, got {self.cand_mode!r}"
            )


def generate_uniform_scaling_factors_for_count(count: int, min_s: float = 0.6, max_s: float = 1.4) -> List[float]:
    """Return ``count`` evenly spaced scale factors from ``min_s`` to ``max_s``.

    Values are rounded to two decimals. For ``count <= 1`` the result is the
    single identity factor ``[1.0]``.
    """
    if count <= 1:
        return [1.0]
    step = (max_s - min_s) / (count - 1)
    factors = []
    for index in range(count):
        factors.append(round(min_s + index * step, 2))
    return factors

# Short backbone key used in CaseSpec -> stem used inside the experiment name.
BACKBONE_KEY_TO_STEM = dict(
    resnet="resnet",
    densenet="densenet",
    efficient="efficientv2s",
    vit="vitb16",
)
def resolve_exp_name(backbone_key: str) -> str:
    """Return the experiment name ``set_<stem>_noisy2clean`` for a backbone key.

    Raises ``KeyError`` for a key that is not in ``BACKBONE_KEY_TO_STEM``.
    """
    stem = BACKBONE_KEY_TO_STEM[backbone_key]
    return "set_" + stem + "_noisy2clean"


# -----------------------------
# CSV helpers (ablation-friendly)
# -----------------------------
def _fmt_list(xs: Any, max_items: int = 18) -> str:
    if xs is None:
        return ""
    try:
        xs = list(xs)
    except Exception:
        return str(xs)
    if len(xs) <= max_items:
        return str(xs)
    head = xs[:max_items]
    return f"{head} ...(+{len(xs)-max_items})"

# Column order of the summary CSV written by main(). main() adds any column
# missing from the collected rows as an empty string and drops columns not
# listed here, so this list is the CSV schema.
CSV_COLUMNS = [
    # identity
    "dataset", "seed", "start_anchor_mode",
    "case_group", "case_name",
    # model config
    "exp_name", "backbone_key",
    "train_absn", "val_absn", "model_absn",
    "ctx_tag", "dist_tag", "share_tag",
    "use_context", "disable_dist_mlp", "share_backbone",
    "loss_name", "w_listmle", "w_mono", "w_mse",
    "d_model", "nhead", "nlayers",
    # candidate config
    "cand_uniform_scaling_factors", "cand_side_ks",
    "num_border_perturb", "include_inverse", "inverse_jitter", "max_candidates_per_group",
    # results
    "status", "best_epoch",
    "best_val_pred_iou", "best_val_anchor_iou", "best_val_gain",
    "log_csv",
]

def build_ablation_row(ds_info: Dict[str, Any], case: CaseSpec, cfg: Any, out: Dict[str, Any], start_anchor_mode: str, train_absn: int, val_absn: int) -> Dict[str, Any]:
    """Flatten one finished run into a summary-CSV row.

    Identity and loss name come from ``ds_info`` / ``case``; model and
    candidate settings are read from ``cfg`` (candidate attributes that are
    absent fall back to -1 / False / 0.0 / ""); results come from ``out``,
    where a missing or ``None`` metric is recorded as 0.
    """
    def _result(key: str, cast, fallback):
        # Missing or None -> fallback, otherwise the cast value.
        raw = out.get(key)
        return fallback if raw is None else cast(raw)

    # disable_dist_mlp=False -> dist1
    dist_on = int(not bool(cfg.disable_dist_mlp))
    ctx_on = int(bool(cfg.use_context))
    share_on = int(bool(cfg.share_backbone))

    row: Dict[str, Any] = {}

    # identity
    row.update(
        dataset=ds_info["dataset"],
        seed=int(cfg.seed),
        start_anchor_mode=str(start_anchor_mode),
        case_group=str(case.case_group),
        case_name=str(case.case_name),
    )

    # model config
    row.update(
        exp_name=str(cfg.exp_name),
        backbone_key=str(cfg.backbone_key),
        train_absn=int(train_absn),
        val_absn=int(val_absn),
        model_absn=int(cfg.absn),
        ctx_tag=f"ctx{ctx_on}",
        dist_tag=f"dist{dist_on}",
        share_tag=f"share{share_on}",
        use_context=bool(cfg.use_context),
        disable_dist_mlp=bool(cfg.disable_dist_mlp),
        share_backbone=bool(cfg.share_backbone),
        loss_name=str(case.loss_name),
        w_listmle=float(cfg.w_listmle),
        w_mono=float(cfg.w_mono),
        w_mse=float(cfg.w_mse),
        d_model=int(cfg.d_model),
        nhead=int(cfg.nhead),
        nlayers=int(cfg.nlayers),
    )

    # candidate config
    row.update(
        cand_uniform_scaling_factors=_fmt_list(getattr(cfg, "cand_uniform_scaling_factors", None)),
        cand_side_ks=_fmt_list(getattr(cfg, "cand_side_ks", None)),
        num_border_perturb=int(getattr(cfg, "num_border_perturb", -1)),
        include_inverse=bool(getattr(cfg, "include_inverse", False)),
        inverse_jitter=float(getattr(cfg, "inverse_jitter", 0.0)),
        max_candidates_per_group=int(getattr(cfg, "max_candidates_per_group", -1)),
    )

    # results
    row.update(
        status=str(out.get("status", "unknown")),
        best_epoch=_result("best_epoch", int, 0),
        best_val_pred_iou=_result("best_val_pred_iou", float, 0.0),
        best_val_anchor_iou=_result("best_val_anchor_iou", float, 0.0),
        best_val_gain=_result("best_val_gain", float, 0.0),
        log_csv=str(out.get("log_csv", "")),
    )
    return row



# -----------------------------
# Candidate config helper (based on CaseSpec cand_mode)
# -----------------------------
def _resolve_uniform_scaling_factors(case: CaseSpec, default_scales: List[float]) -> List[float]:
    """Determine scale factors based on CaseSpec's cand_mode"""
    if case.cand_mode == "isotropic_only":
        # 15 = anchor(1) + inv(3) + scale(11) -> scale 11
        return generate_uniform_scaling_factors_for_count(11, 0.6, 1.4)
    elif case.cand_mode == "borderwise_only":
        # scale is only 1.0 (identity)
        return [1.0]
    else:  # "both"
        return list(default_scales)

def _resolve_num_border_perturb(case: CaseSpec) -> int:
    """Determine num_border_perturb based on CaseSpec's cand_mode"""
    if case.cand_mode == "isotropic_only":
        return 0  # no side
    elif case.cand_mode == "borderwise_only":
        return 9  # 15 - 1(anchor) - 4(inv) - 1(scale) = 9
    else:  # "both"
        return int(case.num_border_perturb)


# -----------------------------
# build_cfg
# -----------------------------
def build_cfg(
    out_root: str,
    device: str,
    ds_info: Dict[str, Any],
    seed: int,
    case: CaseSpec,
    start_anchor_mode: str,
    train_absn: int,
    val_absn: int,
    anchor_uniform_scaling_factors: List[float],
    anchor_side_ks: List[int],
    num_workers: int,
    prefetch_factor: int,
    persistent_workers: bool,
    use_amp: bool,
    pin_memory: bool,
    eval_every: int,
    epochs: int,
    batch_size_set: int,
    crop_size: int,
    lr: float,
    weight_decay: float,
) -> Any:
    """Prepare the data folders for one run and assemble its ``TrainConfig``.

    Steps:
      1. Select and materialise the train/val subsets (``train_absn`` /
         ``val_absn`` images) under ``out_root``.
      2. Build the anchor (starting) label directories for both splits
         according to ``start_anchor_mode``.
      3. Construct ``TrainConfig`` with the model/loss settings from ``case``,
         the loader/optimiser settings passed in, and the candidate settings
         derived from ``case.cand_mode``.

    Returns the ``TrainConfig`` instance.
    """
    out_root_p = Path(out_root)

    # Subset selection + materialisation; the returned label dirs are the clean targets.
    rels_tr = _select_subset_rels(ds_info, "train", int(train_absn), seed=int(seed))
    tr_img, tr_lbl_clean = prepare_subset_from_rels(out_root_p, ds_info, "train", rels_tr, int(train_absn), seed=int(seed))

    rels_va = _select_subset_rels(ds_info, "val", int(val_absn), seed=int(seed))
    va_img, va_lbl_clean = prepare_subset_from_rels(out_root_p, ds_info, "val", rels_va, int(val_absn), seed=int(seed))

    # Anchor labels for the same files; the tag keeps different subset sizes apart.
    tr_lbl_anchor = prepare_anchor_labels_dir(
        out_root_p, ds_info, "train",
        rels_from_images=rels_tr,
        seed=int(seed),
        tag=f"train_absn{int(train_absn):04d}",
        start_anchor_mode=start_anchor_mode,
        uniform_scaling_factors=anchor_uniform_scaling_factors,
        side_ks=anchor_side_ks,
        fixed_case=None,
    )
    va_lbl_anchor = prepare_anchor_labels_dir(
        out_root_p, ds_info, "val",
        rels_from_images=rels_va,
        seed=int(seed),
        tag=f"val_absn{int(val_absn):04d}",
        start_anchor_mode=start_anchor_mode,
        uniform_scaling_factors=anchor_uniform_scaling_factors,
        side_ks=anchor_side_ks,
        fixed_case=None,
    )

    # Default candidate scales for cand_mode "both".
    # NOTE(review): presumably widens the anchor factors up to cap=2.2 — confirm against the helper's definition.
    cand_scale = auto_expand_uniform_scaling_factors_for_refine(anchor_uniform_scaling_factors, cap=2.2)

    cfg = TrainConfig(
        data_root=str(Path(ds_info["root"]).parent),
        dataset_name=ds_info["dataset"],
        exp_name=resolve_exp_name(case.backbone_key),
        seed=int(seed),
        device=str(device),
        out_root=str(out_root),

        backbone_key=str(case.backbone_key),
        d_model=int(case.d_model),
        nhead=int(case.nhead),
        nlayers=int(case.nlayers),

        # absn here comes from the case, independently of train_absn / val_absn
        budget_mode="absn",
        absn=int(case.absn),

        use_context=bool(case.use_context),
        disable_dist_mlp=bool(case.disable_dist_mlp),
        share_backbone=bool(case.share_backbone),

        w_listmle=float(case.w_listmle),
        w_mono=float(case.w_mono),
        w_mse=float(case.w_mse),

        # persistent workers need at least one worker; AMP and pinned memory
        # are only enabled when the device string starts with "cuda"
        num_workers=int(num_workers),
        prefetch_factor=int(prefetch_factor),
        persistent_workers=bool(persistent_workers) and int(num_workers)>0,
        use_amp=bool(use_amp) and str(device).startswith("cuda"),
        pin_memory=bool(pin_memory) and str(device).startswith("cuda"),

        eval_every=int(eval_every),
        epochs=int(epochs),
        batch_size_set=int(batch_size_set),
        crop_size=int(crop_size),
        lr=float(lr),
        weight_decay=float(weight_decay),

        train_images_dir=str(tr_img),
        valid_images_dir=str(va_img),

        train_labels_dir=str(tr_lbl_anchor),          # anchor
        valid_labels_dir=str(va_lbl_anchor),          # anchor
        target_train_labels_dir=str(tr_lbl_clean),    # clean target
        target_valid_labels_dir=str(va_lbl_clean),    # clean target

        # ★ Candidate config based on CaseSpec's cand_mode ★
        cand_uniform_scaling_factors=_resolve_uniform_scaling_factors(case, cand_scale),
        cand_side_ks=list(anchor_side_ks),
        jitter_patterns=list(anchor_side_ks),
        num_border_perturb=_resolve_num_border_perturb(case),
        include_inverse=bool(case.include_inverse),
        inverse_jitter=0.03,
        max_candidates_per_group=int(case.max_candidates),
        require_mixed_signs=True,
    )
    return cfg


def main(
    target_datasets: List[str],
    seeds: List[int] = (42,),
    case_specs: Optional[List[CaseSpec]] = None,
    start_anchor_mode: str = "noisy_mix_balanced",
    out_root_base: str = "./experiments_refine_noisy_to_clean",
    anchor_uniform_scaling_factors: Optional[List[float]] = None,
    anchor_side_ks: Optional[List[int]] = None,
    train_absn: int = 100,
    val_absn: int = 100,
    prefer_device: str = "cuda:0",
    num_workers: int = 8,
    prefetch_factor: int = 2,
    persistent_workers: bool = True,
    use_amp: bool = True,
    pin_memory: bool = True,
    eval_every: int = 5,
    epochs: int = 10,
    batch_size_set: int = 2,
    crop_size: int = 224,
    lr: float = 2e-4,
    weight_decay: float = 0.05,
):
    """Train every (dataset, case, seed) combination and save a summary CSV.

    Datasets are taken from the global ``dataset_summaries``; entries not
    named in ``target_datasets`` or lacking any of the four train/val
    image/label directories are dropped without a message.

    Each run is built with ``build_cfg`` and executed with ``run_train_case``.
    If the reported ``best_val_gain`` is within ``ZERO_GAIN_TOL`` of zero, the
    run's artifacts are purged and training is repeated, at most
    ``MAX_RETRY_ON_ZERO_GAIN`` times; retried runs get ``__retry<N>`` appended
    to their status.

    Returns:
        ``(df, save_path)``: the summary DataFrame with columns ``CSV_COLUMNS``
        and the path of the CSV written under
        ``<out_root_base>/_orchestrator_summary/``.
    """
    assert start_anchor_mode in ("noisy_mix_balanced","noisy_mix","noisy_fixed","clean")
    device = resolve_device(prefer_device)

    # Defaults for the noise settings and for the case list.
    if anchor_uniform_scaling_factors is None:
        anchor_uniform_scaling_factors = [round(0.6 + 0.1*i, 1) for i in range(9)]  # 0.6..1.4
    if anchor_side_ks is None:
        anchor_side_ks = [1,3,5,7,9]

    if case_specs is None:
        case_specs = [
            CaseSpec("SINGLE","SINGLE_densenet_absn100_ctx1_dist1_L+E","densenet",100,True,False,False,"L+E",1.0,0.0,0.25)
        ]

    out_root = Path(out_root_base).resolve()
    out_root.mkdir(parents=True, exist_ok=True)

    # pick datasets
    target_infos = [ds for ds in dataset_summaries if ds["dataset"] in set(target_datasets)]
    target_infos = [ds for ds in target_infos if ds.get("train_dir") and ds.get("val_dir") and ds.get("train_labels_dir") and ds.get("val_labels_dir")]

    all_rows = []
    run_idx = 0
    total = len(target_infos) * len(case_specs) * len(seeds)

    print("="*120)
    print("[ORCH] noisy-anchor -> clean-target refinement")
    print(f" - start_anchor_mode: {start_anchor_mode} (writes __case_map.json)")
    print(f" - device: {device}")
    print(f" - subsets: train_absn={train_absn}, val_absn={val_absn}")
    print(f" - total runs: {total}")
    print(f" - retry policy: gain==0(±{ZERO_GAIN_TOL}) -> purge & retry up to {MAX_RETRY_ON_ZERO_GAIN}x")
    print("="*120)

    for ds_info in target_infos:
        for case in case_specs:
            for seed in seeds:
                run_idx += 1
                print(f"\n▶️  [{run_idx}/{total}] {ds_info['dataset']} | {case.case_name} | seed={seed}")

                cfg = build_cfg(
                    out_root=str(out_root),
                    device=str(device),
                    ds_info=ds_info,
                    seed=int(seed),
                    case=case,
                    start_anchor_mode=start_anchor_mode,
                    train_absn=int(train_absn),
                    val_absn=int(val_absn),
                    anchor_uniform_scaling_factors=list(anchor_uniform_scaling_factors),
                    anchor_side_ks=list(anchor_side_ks),
                    num_workers=int(num_workers),
                    prefetch_factor=int(prefetch_factor),
                    persistent_workers=bool(persistent_workers),
                    use_amp=bool(use_amp),
                    pin_memory=bool(pin_memory),
                    eval_every=int(eval_every),
                    epochs=int(epochs),
                    batch_size_set=int(batch_size_set),
                    crop_size=int(crop_size),
                    lr=float(lr),
                    weight_decay=float(weight_decay),
                )

                # -----------------------------
                # ✅ run + retry on gain==0
                # -----------------------------
                retry_cnt = 0
                while True:
                    out = run_train_case(cfg)
                    # A non-dict result is replaced so the code below can rely on .get()
                    if not isinstance(out, dict):
                        out = {"status": "unknown"}

                    # gain is None when the key is missing or not convertible -> no retry
                    gain = _to_float(out.get("best_val_gain", None), default=None)

                    if (gain is not None) and (abs(gain) <= ZERO_GAIN_TOL) and (retry_cnt < MAX_RETRY_ON_ZERO_GAIN):
                        retry_cnt += 1
                        print(f"⚠️  best_val_gain≈0 detected → purge & retry fresh training ({retry_cnt}/{MAX_RETRY_ON_ZERO_GAIN})")
                        _purge_run_artifacts(out_root=out_root, out=out)

                        # (Optional) If Cell2 has resume_if_exists option, force fresh
                        if hasattr(cfg, "resume_if_exists"):
                            try:
                                cfg.resume_if_exists = False
                            except Exception:
                                pass

                        continue
                    break

                # Mark retried runs in the status column.
                if retry_cnt > 0:
                    out["status"] = f"{out.get('status','ok')}__retry{retry_cnt}"

                # ✅ CSV row (removed best_ckpt/out_dir)
                row = build_ablation_row(
                    ds_info=ds_info,
                    case=case,
                    cfg=cfg,
                    out=out,
                    start_anchor_mode=start_anchor_mode,
                    train_absn=int(train_absn),
                    val_absn=int(val_absn),
                )
                all_rows.append(row)
                print(f"✅ DONE: status={row.get('status')} pred_iou={row.get('best_val_pred_iou')} gain={row.get('best_val_gain')}")

    df = pd.DataFrame(all_rows)

    # ✅ Fixed column order + auto-create missing columns
    for c in CSV_COLUMNS:
        if c not in df.columns:
            df[c] = ""
    df = df[CSV_COLUMNS]

    # The file name encodes the start mode and subset sizes.
    save_dir = out_root / "_orchestrator_summary"
    save_dir.mkdir(parents=True, exist_ok=True)
    save_path = save_dir / f"summary__start_{start_anchor_mode}__train{train_absn}__val{val_absn}__ablation.csv"
    df.to_csv(save_path, index=False)
    print(f"\nSaved: {save_path} (rows={len(df)})")
    return df, save_path


# ✅ Example execution
# Dataset names passed to main(); they are matched against the "dataset"
# field of the entries in dataset_summaries.
TARGET_DATASETS_DEFAULT = [
    "kitti", "homeobjects-3K", "african-wildlife", "construction-ppe", 
    # "Custom_Blood",
    "brain-tumor", "BCCD", "signature", "medical-pills", "VOC",
]

# Shared size knob: used below as the absn of every case, as train_absn /
# val_absn, and inside the output root folder name.
n_data = 100

# -----------------------------
# CaseSpec (Candidate Ablation — ctx1, dist0, E1 baseline)
# Note:
#  - disable_dist_mlp=True  => dist OFF (dist0)
#  - Baseline: 31 candidates
#  - Exp1,2,3: 15 candidates (half)
# -----------------------------
# The four cases differ only in the candidate-generation columns listed in
# the table below; everything else (densenet, ctx1, dist0, loss "E" with
# w_mse=1.0, inverse candidates on) is shared.
CASE_SPECS_DEFAULT = [
    CaseSpec(
        case_group="CAND_ABLATION",
        case_name=f"{_prefix}_absn{n_data}_ctx1_dist0_E1",
        backbone_key="densenet",
        absn=n_data,
        use_context=True,           # ctx1
        disable_dist_mlp=True,      # dist0
        share_backbone=False,
        loss_name="E",
        w_listmle=0.0,
        w_mono=0.0,
        w_mse=1.0,
        cand_mode=_mode,
        max_candidates=_cap,
        num_border_perturb=_n_border,
        include_inverse=True,
    )
    for _prefix, _mode, _cap, _n_border in (
        # Baseline: both, 31 candidates (current setting)
        ("baseline_both_31", "both", 60, 10),
        # Exp 1: both, capped 31 -> 15 candidates (half)
        ("exp1_both_15", "both", 15, 10),
        # Exp 2: isotropic_only, anchor(1) + inv(3) + scale(11) = 15, no border-wise candidates
        ("exp2_isotropic_only_15", "isotropic_only", 15, 0),
        # Exp 3: borderwise_only, anchor(1) + inv(4) + scale(1) + border_perturb(9) = 15
        ("exp3_borderwise_only_15", "borderwise_only", 15, 9),
    )
]

# Preferred device; main() passes it through resolve_device(), which falls
# back to "cuda:0" or "cpu" when this index is not available.
DEVICE = "cuda:1"

# Run the full candidate ablation: every target dataset x every case in
# CASE_SPECS_DEFAULT x one seed. Executing this cell starts training.
# Cell 4's OUT_ROOT uses the same folder name as out_root_base here.
df_summary, saved_csv = main(
    target_datasets=TARGET_DATASETS_DEFAULT,
    seeds=[42],
    case_specs=CASE_SPECS_DEFAULT,
    start_anchor_mode="noisy_mix_balanced",
    out_root_base=f"./experiments_ablation(scale_side, start=noise, generation_rule_control)-({n_data})",
    anchor_uniform_scaling_factors=[round(0.6 + 0.1*i, 1) for i in range(9)],  # 0.6..1.4
    anchor_side_ks=[1,3,5,7,9],
    train_absn=n_data,
    val_absn=n_data,
    prefer_device=DEVICE,
    num_workers=8,
    prefetch_factor=2,
    persistent_workers=True,
    use_amp=True,
    pin_memory=True,
    eval_every=5,
    epochs=10,
    batch_size_set=2,
    crop_size=224,
    lr=2e-4,
    weight_decay=0.05,
)


[ORCH] noisy-anchor -> clean-target refinement
 - start_anchor_mode: noisy_mix_balanced (writes __case_map.json)
 - device: cuda:1
 - subsets: train_absn=100, val_absn=100
 - total runs: 36
 - retry policy: gain==0(±1e-12) -> purge & retry up to 1x

▶️  [1/36] kitti | baseline_both_31_absn100_ctx1_dist0_E1 | seed=42
✅ DONE: status=ok pred_iou=0.6700186338169234 gain=0.12158892204364141

▶️  [2/36] kitti | exp1_both_15_absn100_ctx1_dist0_E1 | seed=42
✅ DONE: status=ok pred_iou=0.6326518359070733 gain=0.08422212413379124

▶️  [3/36] kitti | exp2_isotropic_only_15_absn100_ctx1_dist0_E1 | seed=42
✅ DONE: status=ok pred_iou=0.6562314232829072 gain=0.10780171150962511

▶️  [4/36] kitti | exp3_borderwise_only_15_absn100_ctx1_dist0_E1 | seed=42
✅ DONE: status=ok pred_iou=0.6622663375096661 gain=0.11383662573638416

▶️  [5/36] homeobjects-3K | baseline_both_31_absn100_ctx1_dist0_E1 | seed=42
✅ DONE: status=ok pred_iou=0.7237678196016454 gain=0.12259807563410495

▶️  [6/36] homeobjects-3K | exp1

# Load trained model weights & perform refinement

In [14]:
# ==========================================
# Cell 4) Refinement Runner — FINAL (SEEDS-aware, PATCHED, aligned w/ Cell2/3)
#   - Seed list support: refine per-seed ckpts (from Cell3 summary CSV or weights scan)
#   - Prefer orchestrator summary CSV first (Cell3 ablation-friendly CSV)
#   - ONLY refine for selected cases = TARGET_CASE_NAMES OR CASE_SPECS_DEFAULT
#   - Robust OUT_ROOT auto-resolve if weights/ missing
#   - Split selection: startswith(val/valid) => includes VOC val2012
#   - ✅ Save refined labels under OUT_ROOT/refines/seedXXX/<dataset>/<case_id>/<noise>/<split>/...
#   - Candidate generation uses Cell2 build_candidates_for_one_box (includes inverse)
#   - If Cell3 CSV has no best_ckpt: infer best.pt via log_csv -> OUT_ROOT/weights/<rel>/best.pt
# ==========================================

from __future__ import annotations
import os, re, json, gc, csv, inspect, warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any

import numpy as np
import torch
from PIL import Image
from torchvision import transforms

# -------------------------
# USER CONFIG
# -------------------------
# Root directory that holds one sub-directory per dataset.
DATASETS_ROOT = Path("/home/ISW/project/datasets")

# Subset size: embedded in the OUT_ROOT folder name and in TARGET_CASE_NAMES,
# and also used as the per-split file cap (MAX_FILES_PER_SPLIT).
n_data = 100

# Experiment root; expected to contain weights/ (auto-resolved further below if it does not).
OUT_ROOT = Path(f"./experiments_ablation(scale_side, start=noise, generation_rule_control)-({n_data})")  # Recommended to use same root as Cell3

# Datasets to refine, in processing order. None/empty means nothing is processed.
TARGET_DATASETS: Optional[List[str]] = [
    "kitti",
    "homeobjects-3K",
    "african-wildlife",
    "construction-ppe",
    # "Custom_Blood",
    "brain-tumor",
    "BCCD",
    "signature",
    "medical-pills",
    "VOC",
]

# ✅ SEEDS (refinement is performed once per seed listed here)
SEEDS: List[int] = [42]

# Noise levels whose label folders are looked up under each dataset root:
#   labels_uniform_scaling_<s>  and  labels_boundary_jitter_<k>
UNIFORM_SCALING_FACTORS   = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]
JITTER_PATTERNS  = [1, 3, 5, 7, 9]  # side K folders

# Split names are matched case-insensitively by prefix, so e.g. "val2012" is selected by "val".
REFINE_SPLITS = ["train", "val", "valid"]
# Only the first N (sorted) label files of each split are refined.
MAX_FILES_PER_SPLIT: Optional[int] = n_data  # None=all
# When False, already-written refined label files and splits with a _DONE.json marker are skipped.
OVERWRITE_EXISTING = False
# When False, clean labels are not read and no IoU statistics are accumulated per file.
WRITE_METRICS_CSV = True

# ✅ seed-specific folders are separated under refines
# NOTE: this value is recomputed after OUT_ROOT has been auto-resolved.
REFINES_OUT_ROOT = OUT_ROOT / "refines"

# Image extensions tried (in this order, lower-case only) when pairing a label file with its image.
IMG_EXTS = [".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"]
# Matches the leading "<cls> <cx> <cy> <w> <h>" of a YOLO label line; anything after the
# fifth token is ignored. The numeric groups are permissive and are validated by float().
YOLO_RE = re.compile(r"^\s*(\d+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)")

# Per-channel normalisation constants passed to transforms.Normalize in build_tf().
IMAGENET_MEAN = [0.485,0.456,0.406]
IMAGENET_STD  = [0.229,0.224,0.225]

# -------------------------
# CASE SELECTION (core)
# -------------------------
# Exact case_name values to refine. When this list is non-empty it takes priority over
# CASE_SPECS_DEFAULT (see resolve_target_cases).
TARGET_CASE_NAMES = [
    # ★ Match with case_name in Cell3 CASE_SPECS_DEFAULT ★
    f"baseline_both_31_absn{n_data}_ctx1_dist0_E1",
    f"exp1_both_15_absn{n_data}_ctx1_dist0_E1",
    f"exp2_isotropic_only_15_absn{n_data}_ctx1_dist0_E1",
    f"exp3_borderwise_only_15_absn{n_data}_ctx1_dist0_E1",
]


# Only consulted when TARGET_CASE_NAMES is empty: fall back to a CASE_SPECS_DEFAULT
# global if one exists in the notebook namespace.
USE_CASE_SPECS_DEFAULT_IF_AVAILABLE = True  # ✅ Usually recommended True

# -------------------------
# DEVICE
# -------------------------
# Reuse a DEVICE already defined in the notebook namespace (e.g. by an earlier cell);
# otherwise default to the first CUDA device.
DEVICE = str(globals().get("DEVICE", "cuda:0"))

# -------------------------
# (Optional) warning filters
# -------------------------
# Silence two specific, known-noisy warnings by message pattern; everything else still surfaces.
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*torch\.cuda\.amp\.autocast.*deprecated.*")
warnings.filterwarnings("ignore", category=UserWarning, message=r".*nested tensors is in prototype stage.*")

# -------------------------
# REQUIRE: Cell2 loaded
# -------------------------
# Fail fast if the names this cell depends on have not been defined yet.
# NOTE(review): assert statements are stripped under `python -O`; acceptable for notebook use.
assert "ReBox" in globals(), "Cell 2 (ReBox) must be loaded first."
assert "CandidateCfg" in globals() and "build_candidates_for_one_box" in globals(), "Cell2 candidate generation functions are required."

# -------------------------
# Auto-resolve OUT_ROOT if needed
# -------------------------
def resolve_out_root(user_out: Path) -> Path:
    """Return an experiment root that contains a ``weights/`` directory.

    ``user_out`` is returned unchanged when it already has ``weights/``.
    Otherwise the current working directory is searched recursively for
    ``weights`` directories (ignoring paths that look like Python
    environments or caches). Among the parents found, roots that also have
    ``_orchestrator_summary/`` are preferred, then the most recently modified
    ``weights/``. If nothing is found, ``user_out`` is returned as given.
    """
    if (user_out / "weights").is_dir():
        return user_out

    skip_markers = ["site-packages", ".cache", "venv", "conda"]

    def _looks_like_environment(root: Path) -> bool:
        lowered = str(root).lower()
        return any(marker in lowered for marker in skip_markers)

    candidates = [
        weights_dir.parent
        for weights_dir in Path(".").glob("**/weights")
        if weights_dir.is_dir() and not _looks_like_environment(weights_dir.parent)
    ]

    if not candidates:
        print(f"⚠️ OUT_ROOT={user_out} has no weights/ and auto-search also failed.")
        return user_out

    def _rank(root: Path) -> Tuple[int, float]:
        # Higher is better: (has a summary folder, weights/ modification time).
        summary_present = int((root / "_orchestrator_summary").is_dir())
        try:
            modified = (root / "weights").stat().st_mtime
        except Exception:
            modified = 0.0
        return (summary_present, modified)

    # max() keeps the first of equally-ranked candidates, as a stable descending sort would.
    best = max(candidates, key=_rank)
    print(f"ℹ️ OUT_ROOT auto-corrected: {user_out} -> {best} (weights/ found)")
    return best

# Replace the configured root with one that actually contains weights/ (if one can be found).
OUT_ROOT = resolve_out_root(OUT_ROOT)

# ✅ Create refines root
# Recomputed here because OUT_ROOT may have just been changed by resolve_out_root().
REFINES_OUT_ROOT = OUT_ROOT / "refines"
REFINES_OUT_ROOT.mkdir(parents=True, exist_ok=True)

# -------------------------
# Target case filter resolve
# -------------------------
def _norm_name(s: str) -> str:
    """Return ``s`` as text with surrounding whitespace removed and lower-cased."""
    as_text = str(s)
    trimmed = as_text.strip()
    return trimmed.lower()

def resolve_target_cases() -> Tuple[Optional[set], Optional[List[Any]]]:
    """Decide which cases are refined.

    Returns a pair ``(target_case_name_set, target_case_specs)``:

      - ``TARGET_CASE_NAMES`` non-empty: the set of its non-blank entries and
        ``None`` for the specs (names take priority).
      - otherwise, if allowed and a non-empty ``CASE_SPECS_DEFAULT`` list
        exists in the notebook globals: the set of ``case_name`` attributes
        found on the specs (or ``None`` if no spec has one) plus the spec list,
        which is used for field-based fallback matching.
      - otherwise ``(None, None)``, i.e. no filtering.
    """
    if TARGET_CASE_NAMES:
        explicit_names = {str(entry) for entry in TARGET_CASE_NAMES if str(entry).strip()}
        return explicit_names, None

    if not USE_CASE_SPECS_DEFAULT_IF_AVAILABLE:
        return None, None
    if "CASE_SPECS_DEFAULT" not in globals():
        return None, None

    specs = globals().get("CASE_SPECS_DEFAULT", None)
    if not (isinstance(specs, list) and specs):
        return None, None

    spec_names = {str(spec.case_name) for spec in specs if hasattr(spec, "case_name")}
    if spec_names:
        return spec_names, specs
    return None, specs

# Resolve the case filter once; both values are reused by the main loop.
TARGET_CASE_NAME_SET, TARGET_CASE_SPECS = resolve_target_cases()

# Print which filter is active so a misconfigured run is visible before any work starts.
print("\n" + "="*110)
print("[REFINE] Case selection")
if TARGET_CASE_NAME_SET:
    # Name-based filter: list every selected case_name (sorted for stable output).
    print(f" - using case_name filter: {len(TARGET_CASE_NAME_SET)} cases")
    for x in sorted(list(TARGET_CASE_NAME_SET)):
        print(f"   • {x}")
elif TARGET_CASE_SPECS:
    # Specs exist but carry no case_name: matching falls back to comparing spec fields.
    print(f" - using CASE_SPECS_DEFAULT(spec-based fallback match): {len(TARGET_CASE_SPECS)} cases (no name set)")
else:
    print(" - ⚠️ No case filter set! (will refine ALL best.pt found per dataset)")
print(f"[REFINE] Seeds: {SEEDS}")
print(f"[REFINE] Output root: {REFINES_OUT_ROOT}  (create subfolders per seed)")
print("="*110)

# -------------------------
# Utils: dataset layout (robust)
# -------------------------
def resolve_dataset_root(dataset_name: str) -> Optional[Path]:
    """Locate the directory of ``dataset_name`` under ``DATASETS_ROOT``.

    An exact path match is tried first. If that fails, the sub-directories of
    ``DATASETS_ROOT`` are compared case-insensitively, which also covers
    spelling variants such as ``homeobjects-3K`` vs ``homeobjects-3k``.

    Returns:
        The matching path, or ``None`` when nothing matches or when
        ``DATASETS_ROOT`` itself is not a directory.
    """
    exact = DATASETS_ROOT / dataset_name
    if exact.exists():
        return exact

    # iterdir() raises when the datasets root is missing; report "not found"
    # instead so the caller can print its skip message and continue.
    if not DATASETS_ROOT.is_dir():
        return None

    wanted = dataset_name.lower()
    # Sorted so the choice is deterministic if several names differ only by case.
    for entry in sorted(DATASETS_ROOT.iterdir()):
        if entry.is_dir() and entry.name.lower() == wanted:
            return entry
    return None

def list_existing_splits(images_root: Path) -> List[str]:
    """Return the sorted sub-directory names of ``images_root``.

    ``["__flat__"]`` is returned when ``images_root`` does not exist or has no
    sub-directories, marking a dataset whose images are not split.
    """
    flat_marker = ["__flat__"]
    if not images_root.exists():
        return flat_marker

    split_names = [child.name for child in images_root.iterdir() if child.is_dir()]
    split_names.sort()
    return split_names or flat_marker

def _split_selected(split: str) -> bool:
    """Tell whether ``split`` should be refined.

    The flat pseudo-split is always selected, and so is every split when
    ``REFINE_SPLITS`` is empty. Otherwise a split is selected when its
    lower-cased name starts with any lower-cased entry of ``REFINE_SPLITS``
    (so ``val2012`` is selected by ``val``).
    """
    if split == "__flat__" or not REFINE_SPLITS:
        return True

    lowered = split.lower()
    # A prefix test subsumes exact equality and the dedicated val/valid checks.
    return any(lowered.startswith(str(prefix).lower()) for prefix in REFINE_SPLITS)

def resolve_split_dir(base: Path, split: str) -> Path:
    """Return ``base`` for the flat pseudo-split, else the ``split`` sub-directory of ``base``."""
    if split == "__flat__":
        return base
    return base / split

def find_image_for_label(images_dir: Path, rel_lbl_path: Path) -> Optional[Path]:
    """Find the image that belongs to a label file.

    The label's relative path (without its suffix) is combined with each
    extension in ``IMG_EXTS``, in order; the first existing file under
    ``images_dir`` is returned, or ``None`` if there is none.
    """
    relative_stem = rel_lbl_path.with_suffix("").as_posix()
    possible_images = (images_dir / f"{relative_stem}{ext}" for ext in IMG_EXTS)
    return next((image for image in possible_images if image.exists()), None)

def read_yolo_txt(p: Optional[Path]) -> List[Tuple[int,float,float,float,float]]:
    """Parse a YOLO label file into ``(cls, cx, cy, w, h)`` rows.

    Args:
        p: Label file path. ``None`` or a missing file yields an empty list.

    Returns:
        One tuple per line that matches ``YOLO_RE`` and whose fields convert
        cleanly to numbers. Lines that do not match, or whose numeric tokens
        are malformed, are skipped.
    """
    rows: List[Tuple[int,float,float,float,float]] = []
    if p is None or (not p.exists()):
        return rows

    for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines():
        m = YOLO_RE.match(ln)
        if not m:
            continue
        try:
            cls = int(m.group(1))
            cx, cy, w, h = (float(tok) for tok in m.groups()[1:])
        except ValueError:
            # The pattern is permissive (e.g. it accepts "." or "e" alone);
            # treat such a line like any other non-matching line.
            continue
        rows.append((cls, cx, cy, w, h))
    return rows

def write_yolo_txt(p: Path, rows: List[Tuple[int,float,float,float,float]]):
    """Write ``rows`` as YOLO label lines to ``p``, creating parent directories.

    Each of ``cx, cy, w, h`` is clamped to ``[0, 1]`` and written with six
    decimals. An empty ``rows`` produces an empty file.
    """
    def _clamp_unit(value):
        return min(max(value, 0.0), 1.0)

    p.parent.mkdir(parents=True, exist_ok=True)
    with open(p, "w", encoding="utf-8") as handle:
        for cls, cx, cy, w, h in rows:
            cx, cy, w, h = (_clamp_unit(v) for v in (cx, cy, w, h))
            handle.write(f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")

# -------------------------
# Metrics: IoU (normalized cxcywh)
# -------------------------
def _to_xyxy(cx, cy, w, h):
    """Convert a centre/size box ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``."""
    half_w = w/2
    half_h = h/2
    return cx - half_w, cy - half_h, cx + half_w, cy + half_h

def box_iou(a, b, eps=1e-9) -> float:
    """Intersection-over-union of two ``(cx, cy, w, h)`` boxes.

    ``eps`` is added to the union so two empty boxes give 0 rather than a
    division by zero.
    """
    ax1, ay1, ax2, ay2 = _to_xyxy(*a)
    bx1, by1, bx2, by2 = _to_xyxy(*b)

    # Overlap extent per axis, floored at zero for disjoint boxes.
    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = max(0.0, ax2-ax1) * max(0.0, ay2-ay1)
    area_b = max(0.0, bx2-bx1) * max(0.0, by2-by1)
    union = area_a + area_b - inter
    return float(inter / (union + eps))

# -------------------------
# ckpt parsing + model build
# -------------------------
def _get_state_dict_from_ckpt(ckpt: Dict[str, Any]) -> Dict[str, Any]:
    """Return the first non-empty dict stored under a known state-dict key.

    Keys are tried in the order ``model_state_dict``, ``model``,
    ``state_dict``, ``net``, ``weights``.

    Raises:
        KeyError: none of the keys holds a non-empty dict.
    """
    known_keys = ("model_state_dict", "model", "state_dict", "net", "weights")
    stored_values = (ckpt.get(key, None) for key in known_keys)
    for value in stored_values:
        if isinstance(value, dict) and value:
            return value
    raise KeyError("Cannot find model state dict in ckpt.")

def build_model_from_ckpt(ckpt: Dict[str, Any]) -> Tuple[torch.nn.Module, Dict[str, Any]]:
    """Instantiate ``ReBox`` from a checkpoint and load its weights.

    Hyper-parameters are read from the checkpoint's ``config`` and ``exp``
    dicts (``exp`` wins on conflicts). Only the arguments that
    ``ReBox.__init__`` actually declares are passed, so checkpoints remain
    loadable if the constructor lacks some of them.

    Returns:
        ``(model, meta)``: the model in eval mode (not yet moved to a device)
        and the merged metadata, extended with ``"__init_kwargs_used"``.

    Raises:
        AssertionError: neither ``backbone`` nor ``backbone_key`` is present.
        KeyError: the checkpoint has no recognisable state dict.
    """
    meta: Dict[str, Any] = {}
    meta.update(ckpt.get("config", {}) or {})
    meta.update(ckpt.get("exp", {}) or {})  # exp overrides config

    backbone_key = meta.get("backbone", None) or meta.get("backbone_key", None)
    assert backbone_key is not None, "Could not find backbone info in ckpt exp/config."

    cand_kwargs = {
        "backbone_key": backbone_key,
        "d_model": int(meta.get("d_model", 512)),
        "nhead": int(meta.get("nhead", 8)),
        "nlayers": int(meta.get("nlayers", 2)),
        "use_context": bool(meta.get("use_context", True)),
        "disable_dist_mlp": bool(meta.get("disable_dist_mlp", False)),
        "share_backbone": bool(meta.get("share_backbone", False)),
    }

    # Drop anything the constructor does not accept.
    accepted = set(inspect.signature(ReBox.__init__).parameters.keys()) - {"self"}
    kwargs = {name: value for name, value in cand_kwargs.items() if name in accepted}

    model = ReBox(**kwargs)
    load_report = model.load_state_dict(_get_state_dict_from_ckpt(ckpt), strict=False)
    mk = len(getattr(load_report, "missing_keys", []) or [])
    uk = len(getattr(load_report, "unexpected_keys", []) or [])
    if (mk + uk) > 0:
        # Non-strict loading succeeds silently on mismatches, so surface the counts.
        print(f"     ⚠️ load_state_dict(strict=False): missing={mk}, unexpected={uk}")

    model.eval()
    meta["__init_kwargs_used"] = kwargs
    return model, meta

# -------------------------
# OUT_ROOT summary CSV helpers
# -------------------------
def find_orchestrator_summary_csv(out_root: Path) -> Optional[Path]:
    """Pick the summary CSV inside ``out_root/_orchestrator_summary``.

    Files are ordered newest first (by modification time). The newest file
    whose lower-cased name contains ``summary__start_`` is preferred; if no
    file has that marker, the newest CSV overall is returned. ``None`` means
    the folder is missing or contains no CSV.
    """
    summary_dir = out_root / "_orchestrator_summary"
    if not summary_dir.is_dir():
        return None

    csv_files = sorted(summary_dir.glob("*.csv"))
    if not csv_files:
        return None

    # Stable sort: files with equal mtime keep their name order.
    csv_files.sort(key=lambda f: f.stat().st_mtime, reverse=True)
    marked = (f for f in csv_files if "summary__start_" in f.name.lower())
    return next(marked, csv_files[0])

def _best_pt_from_row(out_root: Path, row: Dict[str, str]) -> Optional[Path]:
    """Resolve the best checkpoint referenced by one summary CSV row.

    1. Use the first non-empty of ``best_ckpt`` / ``best_pt`` / ``best_model``
       (relative paths are taken relative to ``out_root``) if that file exists.
    2. Otherwise derive it from ``log_csv``: the log's directory, expressed
       relative to ``out_root``, is mirrored under ``out_root/weights`` and
       ``best.pt`` is looked up there.

    Returns ``None`` when neither route yields an existing file.
    """
    recorded = row.get("best_ckpt","") or row.get("best_pt","") or row.get("best_model","")
    if recorded:
        ckpt_path = Path(recorded)
        if not ckpt_path.is_absolute():
            ckpt_path = (out_root / recorded).resolve()
        if ckpt_path.exists():
            return ckpt_path

    log_csv = row.get("log_csv","")
    if not log_csv:
        return None

    log_path = Path(log_csv)
    if not log_path.is_absolute():
        log_path = (out_root / log_csv).resolve()
    run_dir = log_path.resolve().parent if log_path.exists() else log_path.parent

    try:
        # relative_to() raises when the log lives outside out_root.
        mirrored = run_dir.relative_to(out_root.resolve())
        best_pt = (out_root / "weights" / mirrored / "best.pt").resolve()
        found = best_pt.exists()
    except Exception:
        return None
    return best_pt if found else None

def _row_matches_target(row: Dict[str, str], target_names: Optional[set], target_specs: Optional[List[Any]]) -> bool:
    """Tell whether a summary CSV row belongs to a selected case.

    Args:
        row: One ``csv.DictReader`` row.
        target_names: If non-empty, the row's stripped ``case_name`` must be in
            this set; ``target_specs`` is then ignored.
        target_specs: Used only without names. The row matches when any spec
            agrees on backbone key, the three boolean switches and the three
            loss weights (tolerance ``1e-6``).

    Returns:
        ``True`` on a match, and also when no filter is given at all.
    """
    if target_names:
        return str(row.get("case_name","")).strip() in target_names

    if not target_specs:
        return True

    def _as_float(x, d=0.0):
        # CSV cells may be empty or non-numeric; fall back to the default then.
        try:
            return float(x)
        except (TypeError, ValueError, OverflowError):
            return float(d)

    def _as_bool(x):
        return str(x).strip().lower() in ("1", "true", "yes", "y", "t")

    rbk = str(row.get("backbone_key","")).strip()
    rctx = _as_bool(row.get("use_context","false"))
    rdist = _as_bool(row.get("disable_dist_mlp","false"))
    rshare = _as_bool(row.get("share_backbone","false"))
    rwl = _as_float(row.get("w_listmle", 0.0))
    rwm = _as_float(row.get("w_mono", 0.0))
    rwe = _as_float(row.get("w_mse", 0.0))

    for sp in target_specs:
        try:
            if str(getattr(sp,"backbone_key","")).strip() != rbk: continue
            if bool(getattr(sp,"use_context",True)) != rctx: continue
            if bool(getattr(sp,"disable_dist_mlp",False)) != rdist: continue
            if bool(getattr(sp,"share_backbone",False)) != rshare: continue
            if abs(float(getattr(sp,"w_listmle",0.0)) - rwl) > 1e-6: continue
            if abs(float(getattr(sp,"w_mono",0.0)) - rwm) > 1e-6: continue
            if abs(float(getattr(sp,"w_mse",0.0)) - rwe) > 1e-6: continue
            return True
        except (TypeError, ValueError):
            # A spec with non-numeric weights cannot match; try the next one.
            continue
    return False

def _parse_int(x, default=None):
    """Parse ``x`` as an integer via float (so ``"3.0"`` works); return ``default`` on failure."""
    try:
        cleaned = str(x).strip()
        return int(float(cleaned))
    except Exception:
        return default

def best_ckpts_from_summary(out_root: Path, dataset_name: str,
                           target_names: Optional[set],
                           target_specs: Optional[List[Any]],
                           seed: Optional[int] = None) -> List[Tuple[str, Path]]:
    """Collect ``(case_name, best.pt)`` pairs from the orchestrator summary CSV.

    A row is kept when its dataset matches ``dataset_name`` (trimmed,
    case-insensitive), its seed equals ``seed`` (if given), it passes the case
    filter, and a best checkpoint can be located on disk. Each checkpoint file
    appears at most once; the first row that references it wins.

    Returns an empty list when no summary exists or it cannot be read.
    """
    summ = find_orchestrator_summary_csv(out_root)
    if summ is None or (not summ.exists()):
        return []

    wanted_dataset = _norm_name(dataset_name)
    collected: List[Tuple[str, Path]] = []
    try:
        with open(summ, "r", encoding="utf-8", newline="") as f:
            for row in csv.DictReader(f):
                if _norm_name(row.get("dataset","")) != wanted_dataset:
                    continue

                if seed is not None:
                    row_seed = _parse_int(row.get("seed",""), default=None)
                    if row_seed is None or int(row_seed) != int(seed):
                        continue

                if not _row_matches_target(row, target_names, target_specs):
                    continue

                best_pt = _best_pt_from_row(out_root, row)
                if best_pt is None:
                    continue
                collected.append((str(row.get("case_name","")).strip(), best_pt))
    except Exception as e:
        print(f"⚠️ Failed to read summary csv: {summ} ({e})")
        return []

    # De-duplicate by resolved path; dict insertion order keeps the first occurrence.
    first_by_path: Dict[str, Tuple[str, Path]] = {}
    for case_name, ckpt_path in collected:
        first_by_path.setdefault(str(ckpt_path.resolve()), (case_name, ckpt_path))
    return list(first_by_path.values())

def scan_best_ckpts_for_dataset(out_root: Path, dataset_name: str) -> List[Path]:
    """List every ``best.pt`` below ``out_root/weights/<dataset>``, sorted.

    The dataset folder is matched case-insensitively; an empty list is
    returned when ``weights/`` or the dataset folder is missing.
    """
    weights_root = out_root / "weights"
    if not weights_root.exists():
        return []

    wanted = _norm_name(dataset_name)
    found: List[Path] = []
    for sub in weights_root.iterdir():
        if sub.is_dir() and _norm_name(sub.name) == wanted:
            found.extend(sub.rglob("best.pt"))
    return sorted(found)

def _float_close(a: float, b: float, tol: float = 1e-6) -> bool:
    """Return True when ``a`` and ``b`` (coerced to float) differ by at most ``tol``."""
    gap = abs(float(a) - float(b))
    return gap <= tol

def _match_ckpt_meta_to_spec(meta: Dict[str, Any], spec: Any, seed: Optional[int] = None) -> bool:
    """Tell whether checkpoint metadata describes the same run setup as ``spec``.

    Compared, in order: seed (only when ``seed`` is given; read from ``seed``
    or ``random_seed``), backbone key, the three boolean switches, and the
    three loss weights (tolerance ``1e-6``). Missing or unparsable seed or
    backbone information counts as a mismatch.
    """
    if seed is not None:
        try:
            ckpt_seed = int(meta.get("seed", meta.get("random_seed", -1)))
            seed_differs = ckpt_seed != int(seed)
        except Exception:
            return False
        if seed_differs:
            return False

    backbone = meta.get("backbone_key", None) or meta.get("backbone", None)
    if backbone is None:
        return False
    if str(backbone).strip() != str(getattr(spec, "backbone_key", "")).strip():
        return False

    def _flag(value, default=False):
        # Metadata may store booleans as text; unknown text maps to `default`.
        if isinstance(value, bool):
            return value
        text = str(value).strip().lower()
        if text in ["1","true","yes","y","t"]:
            return True
        if text in ["0","false","no","n","f"]:
            return False
        return bool(default)

    switches = (("use_context", True), ("disable_dist_mlp", False), ("share_backbone", False))
    for name, fallback in switches:
        if _flag(meta.get(name, fallback)) != bool(getattr(spec, name, fallback)):
            return False

    # Each weight may be stored under its plain name or with a "_loss" suffix.
    weight_names = ("w_listmle", "w_mono", "w_mse")
    ckpt_weights = [
        float(meta.get(name, meta.get(f"{name}_loss", 0.0)) or 0.0)
        for name in weight_names
    ]
    for name, ckpt_value in zip(weight_names, ckpt_weights):
        if not _float_close(ckpt_value, float(getattr(spec, name, 0.0)), 1e-6):
            return False
    return True

def _case_name_for_ckpt_meta(meta: Dict[str, Any],
                             target_specs: Optional[List[Any]],
                             seed: Optional[int]) -> Optional[str]:
    """Return the case name a checkpoint is selected under, or ``None`` if it is rejected."""
    if target_specs:
        for sp in target_specs:
            if _match_ckpt_meta_to_spec(meta, sp, seed=seed):
                return str(getattr(sp, "case_name", "")).strip()
        return None

    # No specs: only the seed (when requested) decides; the case name is unknown.
    if seed is not None and int(meta.get("seed", -999999)) != int(seed):
        return None
    return ""


def scan_best_ckpts_for_dataset_filtered(out_root: Path, dataset_name: str,
                                        target_specs: Optional[List[Any]],
                                        seed: Optional[int] = None) -> List[Tuple[str, Path]]:
    """Scan ``weights/<dataset>`` for ``best.pt`` files and filter them by metadata.

    Every checkpoint is loaded on CPU to read its merged ``config``/``exp``
    metadata (``exp`` wins on conflicts).

      - With ``target_specs``: a checkpoint is kept when it matches one of the
        specs (and ``seed``); it is tagged with that spec's ``case_name``.
      - Without specs: a checkpoint is kept when its ``seed`` equals ``seed``
        (or unconditionally if ``seed`` is ``None``); the case name is ``""``.

    Checkpoints that cannot be loaded or inspected are skipped with a printed
    warning, so a failed scan can be told apart from "nothing matched".

    Returns:
        ``(case_name, path)`` pairs, unique by resolved path, in sorted path order.
    """
    selected: List[Tuple[str, Path]] = []
    seen = set()
    for p in scan_best_ckpts_for_dataset(out_root, dataset_name):
        try:
            # NOTE(review): loads the full checkpoint just to read metadata;
            # fine for a fallback path, costly for very large checkpoints.
            ckpt = torch.load(p, map_location="cpu")
            meta: Dict[str, Any] = {}
            meta.update(ckpt.get("config", {}) or {})
            meta.update(ckpt.get("exp", {}) or {})
            case_name = _case_name_for_ckpt_meta(meta, target_specs, seed)
        except Exception as e:
            print(f"  ⚠️ Skip ckpt (failed to load/inspect): {p} ({e})")
            continue

        if case_name is None:
            continue
        key = str(p.resolve())
        if key in seen:
            continue
        seen.add(key)
        selected.append((case_name, p))
    return selected

def make_case_id(dataset: str, best_pt: Path) -> str:
    """Build a flat folder-safe identifier from a checkpoint's location.

    The checkpoint's directory is expressed relative to
    ``OUT_ROOT/weights/<dataset>`` when possible (otherwise the full parent
    directory is used), and path separators are replaced with ``__``.
    """
    try:
        case_dir = best_pt.relative_to(OUT_ROOT / "weights" / dataset).parent
    except Exception:
        # Checkpoint lives outside the expected weights tree.
        case_dir = best_pt.parent
    return case_dir.as_posix().replace("/", "__")

# -------------------------
# candidate + crop inference
# -------------------------
def build_tf(crop_size: int):
    """Compose the preprocessing pipeline: square resize, tensor conversion, ImageNet normalisation."""
    steps = [
        transforms.Resize((crop_size, crop_size)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]
    return transforms.Compose(steps)

def _parse_case_tag_from_noise_dirname(noise_name: str) -> str:
    """Map a noise label folder name to a case tag.

    ``labels_uniform_scaling_<s>`` becomes ``scale_<s>`` (``p`` in ``<s>`` is
    read as a decimal point), ``labels_boundary_jitter_<k>`` becomes
    ``side_<k>``, and anything else, including an unparsable scale, is
    ``clean``.
    """
    scaling = re.match(r"^labels_uniform_scaling_(.+)$", noise_name)
    if scaling is not None:
        try:
            factor = float(scaling.group(1).replace("p","."))
        except Exception:
            return "clean"
        return f"scale_{factor:g}"

    jitter = re.match(r"^labels_boundary_jitter_(\d+)$", noise_name)
    if jitter is not None:
        return f"side_{int(jitter.group(1))}"
    return "clean"

def iter_noise_label_dirs(dataset_root: Path) -> List[Tuple[str, Path]]:
    """List the noise label folders that exist under ``dataset_root``.

    Folder names are derived from ``UNIFORM_SCALING_FACTORS`` and
    ``JITTER_PATTERNS``; scaling folders come first, each group in
    configuration order. Returns ``(folder_name, path)`` pairs.
    """
    expected_names = [f"labels_uniform_scaling_{s:g}" for s in UNIFORM_SCALING_FACTORS]
    expected_names += [f"labels_boundary_jitter_{int(k)}" for k in JITTER_PATTERNS]

    named_paths = ((name, dataset_root / name) for name in expected_names)
    return [(name, path) for name, path in named_paths if path.exists()]

def free_cuda():
    """Release cached CUDA memory (when CUDA is available) and run the garbage collector."""
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()
    gc.collect()

@torch.inference_mode()
def refine_one_label_file(
    model: torch.nn.Module,
    img_path: Path,
    noisy_lbl_path: Path,
    clean_lbl_path: Optional[Path],
    out_lbl_path: Path,
    cand_cfg: Any,                    # CandidateCfg
    jitter_patterns: List[int],
    crop_size: int,
    use_context: bool,
    use_amp: bool,
    device: torch.device,
    case_tag: str,
) -> Dict[str, float]:
    """Refine every box of one noisy YOLO label file and write the result.

    For each noisy box, candidate boxes come from
    ``build_candidates_for_one_box``. Every candidate is cropped from the
    image, resized to ``crop_size`` and scored by ``model``; the candidate
    with the highest score replaces the noisy box. A box for which no
    candidates are produced is kept unchanged.

    Output rows are written in the same order as the noisy rows, so row ``i``
    of the refined file always corresponds to row ``i`` of the noisy file
    (and to row ``i`` of the clean file, which is how IoU pairs are formed).

    Args:
        model: Scoring model, called as ``model(Xc, Xctx, D, mask_valid)``.
            NOTE(review): assumed to return one score per candidate with
            shape (groups, candidates); confirm against Cell 2.
        img_path: Image the labels belong to.
        noisy_lbl_path: Label file to refine.
        clean_lbl_path: Optional clean labels, used for IoU statistics (only
            when ``WRITE_METRICS_CSV`` is set) and passed to candidate
            generation; the noisy path is passed instead when it is absent.
        out_lbl_path: Destination. Written only if it does not exist yet or
            ``OVERWRITE_EXISTING`` is set.
        cand_cfg: Candidate generation config; its ``seed`` attribute
            (default 42) is forwarded to candidate generation.
        jitter_patterns: Forwarded to candidate generation.
        crop_size: Side length of the square model input.
        use_context: Feed the resized full image as context; otherwise a
            zero tensor is passed.
        use_amp: Use CUDA autocast when running on a CUDA device.
        device: Device for all tensors.
        case_tag: Noise case tag forwarded to candidate generation.

    Returns:
        ``{"n_boxes", "sum_iou_noisy", "sum_iou_refined"}`` as floats, summed
        over the boxes that were scored and have a clean box at the same index.
    """
    noisy_rows = read_yolo_txt(noisy_lbl_path)
    if len(noisy_rows) == 0:
        if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
            write_yolo_txt(out_lbl_path, [])
        return {"n_boxes": 0.0, "sum_iou_noisy": 0.0, "sum_iou_refined": 0.0}

    has_clean = clean_lbl_path is not None and clean_lbl_path.exists()
    clean_rows = read_yolo_txt(clean_lbl_path) if (WRITE_METRICS_CSV and has_clean) else None
    # Path handed to candidate generation; identical for every box of this file.
    clean_for_hash = Path(clean_lbl_path if has_clean else noisy_lbl_path)
    cand_seed = int(getattr(cand_cfg, "seed", 42))
    tf = build_tf(crop_size)

    # convert() returns an independent in-memory copy, so the file handle can
    # be closed right away.
    with Image.open(img_path) as raw_img:
        img = raw_img.convert("RGB")
    W, H = img.size
    xctx = tf(img).to(device) if use_context else None

    # One slot per input row, so results land at their original index no
    # matter when their batch is flushed.
    refined_rows: List[Any] = [None] * len(noisy_rows)
    sum_iou_noisy = 0.0
    sum_iou_ref = 0.0
    n_boxes = 0

    GROUP_BS = 4  # number of boxes scored per forward pass
    groups_Xc, groups_D, groups_meta = [], [], []

    def flush():
        """Score all pending boxes in one padded batch and store their results."""
        nonlocal sum_iou_noisy, sum_iou_ref, n_boxes
        if not groups_Xc:
            return
        Bs = len(groups_Xc)
        maxN = max(x.shape[0] for x in groups_Xc)

        # Boxes can have different candidate counts: pad to maxN and mask.
        Xc = torch.zeros((Bs, maxN, 3, crop_size, crop_size), device=device)
        D = torch.zeros((Bs, maxN), device=device)
        mask_valid = torch.zeros((Bs, maxN), dtype=torch.bool, device=device)

        if use_context:
            Xctx = torch.stack([xctx for _ in range(Bs)], dim=0)
        else:
            Xctx = torch.zeros((Bs, 3, crop_size, crop_size), device=device)

        for b, (crops, dists) in enumerate(zip(groups_Xc, groups_D)):
            n = crops.shape[0]
            Xc[b, :n] = crops
            D[b, :n] = dists
            mask_valid[b, :n] = True

        amp_ok = (use_amp and device.type == "cuda" and torch.cuda.is_available())
        if amp_ok:
            with torch.amp.autocast(device_type="cuda", enabled=True):
                scores = model(Xc, Xctx, D, mask_valid)  # (B,N)
        else:
            scores = model(Xc, Xctx, D, mask_valid)

        for b, meta in enumerate(groups_meta):
            valid = mask_valid[b]
            if valid.sum() == 0:
                refined_rows[meta["idx"]] = (meta["cls"], *meta["noisy_box"])
                continue

            best_idx = int(torch.argmax(scores[b, valid].detach()).item())
            refined_box = meta["cand_boxes"][best_idx]
            refined_rows[meta["idx"]] = (meta["cls"], *refined_box)

            if meta["clean_box"] is not None:
                sum_iou_noisy += box_iou(meta["noisy_box"], meta["clean_box"])
                sum_iou_ref += box_iou(refined_box, meta["clean_box"])
                n_boxes += 1

        groups_Xc.clear(); groups_D.clear(); groups_meta.clear()

    for i, (cls, cx, cy, w, h) in enumerate(noisy_rows):
        cand_boxes, cand_dists, _ = build_candidates_for_one_box(
            anchor_box=(cx,cy,w,h),
            case_tag=case_tag,
            clean_lbl_path=clean_for_hash,
            box_index=i,
            seed=cand_seed,
            cand_cfg=cand_cfg,
            jitter_patterns=jitter_patterns,
        )

        if not cand_boxes:
            # Nothing to score: keep the noisy box, at its own position.
            refined_rows[i] = (cls, cx, cy, w, h)
            continue

        cand_tensors = []
        for (cx2, cy2, w2, h2) in cand_boxes:
            x1 = int(round((cx2 - w2/2) * W)); x2 = int(round((cx2 + w2/2) * W))
            y1 = int(round((cy2 - h2/2) * H)); y2 = int(round((cy2 + h2/2) * H))
            # Clamp to the image and force at least one pixel per axis, so a
            # degenerate or out-of-image candidate cannot produce an empty crop.
            x1 = max(0, min(W-1, x1)); x2 = max(x1 + 1, min(W, x2))
            y1 = max(0, min(H-1, y1)); y2 = max(y1 + 1, min(H, y2))
            crop = img.crop((x1, y1, x2, y2))
            cand_tensors.append(tf(crop).unsqueeze(0))

        Xc_group = torch.cat(cand_tensors, dim=0).to(device)
        D_group = torch.tensor(cand_dists, dtype=torch.float32, device=device)

        clean_box = None
        if clean_rows is not None and i < len(clean_rows):
            _, ccx, ccy, cw, ch = clean_rows[i]
            clean_box = (ccx, ccy, cw, ch)

        groups_Xc.append(Xc_group)
        groups_D.append(D_group)
        groups_meta.append({
            "idx": i,
            "cls": int(cls),
            "clean_box": clean_box,
            "noisy_box": (cx,cy,w,h),
            "cand_boxes": cand_boxes,
        })

        if len(groups_Xc) >= GROUP_BS:
            flush()

    flush()

    if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
        write_yolo_txt(out_lbl_path, refined_rows)

    return {"n_boxes": float(n_boxes), "sum_iou_noisy": float(sum_iou_noisy), "sum_iou_refined": float(sum_iou_ref)}

# -------------------------
# MAIN
# -------------------------
device = torch.device(DEVICE if (torch.cuda.is_available() and str(DEVICE).startswith("cuda")) else "cpu")
all_metrics_rows = []

for dataset in (TARGET_DATASETS or []):
    ds_root = resolve_dataset_root(dataset)
    if ds_root is None or (not ds_root.exists()):
        print(f"⚠️ Skip missing dataset root: {DATASETS_ROOT/dataset}")
        continue

    images_root = ds_root / "images"
    labels_root = ds_root / "labels"   # clean labels
    if not images_root.exists() or not labels_root.exists():
        print(f"⚠️ Skip (missing images/labels): {ds_root.name}")
        continue

    splits_all = list_existing_splits(images_root)
    splits = [s for s in splits_all if _split_selected(s)]
    if not splits:
        splits = ["__flat__"]

    noise_dirs = iter_noise_label_dirs(ds_root)
    if not noise_dirs:
        print(f"⚠️ No noise label dirs in {ds_root} (labels_uniform_scaling_*/labels_boundary_jitter_*). Skip.")
        continue

    print("\n" + "="*110)
    print(f"[DATASET] {ds_root.name} | noises={len(noise_dirs)} | splits={splits}")
    print(f"          OUT_ROOT={OUT_ROOT}")
    print(f"          output base => {REFINES_OUT_ROOT} (subfolders per seed)")
    print("="*110)

    # ✅ seed-specific refinement
    for seed_req in SEEDS:
        # ✅ Separate seed folder at top level
        SEED_REFINES_ROOT = REFINES_OUT_ROOT / f"seed{int(seed_req):03d}"
        SEED_REFINES_ROOT.mkdir(parents=True, exist_ok=True)

        # 1) summary csv first: (dataset + seed + selected cases) -> best.pt
        case_ckpts = best_ckpts_from_summary(
            out_root=OUT_ROOT,
            dataset_name=ds_root.name,
            target_names=TARGET_CASE_NAME_SET,
            target_specs=TARGET_CASE_SPECS,
            seed=int(seed_req),
        )

        # 2) fallback: scan weights + (if possible) spec matching + seed filter
        if not case_ckpts:
            case_ckpts = scan_best_ckpts_for_dataset_filtered(
                out_root=OUT_ROOT,
                dataset_name=ds_root.name,
                target_specs=TARGET_CASE_SPECS,
                seed=int(seed_req),
            )

        if not case_ckpts:
            print(f"  ⚠️ [seed={seed_req}] No matched best.pt found. Skip seed.")
            continue

        uniq, seen = [], set()
        for cname, p in case_ckpts:
            k = str(Path(p).resolve())
            if k in seen: continue
            seen.add(k); uniq.append((cname, p))
        case_ckpts = uniq

        print(f"\n  [SEED] {seed_req} | matched_ckpts={len(case_ckpts)}")
        print(f"         seed output => {SEED_REFINES_ROOT / ds_root.name}")

        for case_name_from_list, best_pt in case_ckpts:
            ckpt = torch.load(best_pt, map_location="cpu")
            model, meta = build_model_from_ckpt(ckpt)
            model = model.to(device)

            crop_size = int(meta.get("crop_size", 224))
            use_context = bool(meta.get("use_context", True))
            use_amp = bool(meta.get("use_amp", True))
            seed = int(meta.get("seed", seed_req))

            # Restore cand_cfg (Cell2 exp_meta["cand_cfg"])
            exp = ckpt.get("exp", {}) or {}
            cand_cfg_dict = (exp.get("cand_cfg", None) or meta.get("cand_cfg", None))
            if not isinstance(cand_cfg_dict, dict):
                cand_cfg_dict = {
                    "cand_uniform_scaling_factors": meta.get("cand_uniform_scaling_factors", [0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4]),
                    "cand_side_ks": meta.get("cand_side_ks", JITTER_PATTERNS),
                    "num_border_perturb": int(meta.get("num_border_perturb", 10)),
                    "include_anchor": True,
                    "include_inverse": True,
                    "inverse_jitter": float(meta.get("inverse_jitter", 0.03)),
                    "require_mixed_signs": True,
                    "max_candidates_per_group": int(meta.get("max_candidates_per_group", 60)),
                }
            cand_cfg = CandidateCfg(**{k: cand_cfg_dict[k] for k in cand_cfg_dict.keys()
                                      if k in inspect.signature(CandidateCfg).parameters})
            setattr(cand_cfg, "seed", int(seed))  # ✅ seed-specific candidate reproducibility

            jitter_patterns = list(exp.get("jitter_patterns", JITTER_PATTERNS))

            # case_id (no need for seed prefix since separated into seed folder)
            ds_key = ckpt.get("config", {}).get("dataset_name", ds_root.name) if isinstance(ckpt.get("config", {}), dict) else ds_root.name
            case_id = make_case_id(ds_key, best_pt)

            # ✅ Final refine_root: OUT_ROOT/refines/seedXXX/<dataset>/<case_id>
            refine_root = SEED_REFINES_ROOT / ds_root.name / case_id
            refine_root.mkdir(parents=True, exist_ok=True)

            pretty_case_name = case_name_from_list or str(meta.get("case_name", "")) or "(unknown_case)"
            print(f"\n    ▶ Case: {case_id}")
            print(f"       - seed_req={seed_req} | ckpt_seed={seed}")
            print(f"       - selected_case_name={pretty_case_name}")
            print(f"       - backbone={meta.get('backbone') or meta.get('backbone_key')} | crop={crop_size} | ctx={int(use_context)} | amp={int(use_amp)}")
            print(f"       - ckpt={best_pt}")
            print(f"       - out_root={refine_root}")

            for noise_name, noise_dir in noise_dirs:
                case_tag = _parse_case_tag_from_noise_dirname(noise_name)

                for split in splits:
                    img_dir = resolve_split_dir(images_root, split)
                    clean_lbl_dir = resolve_split_dir(labels_root, split)
                    noisy_lbl_dir = resolve_split_dir(noise_dir, split)

                    if not img_dir.exists() or not noisy_lbl_dir.exists() or not clean_lbl_dir.exists():
                        continue

                    out_lbl_dir = refine_root / noise_name / (split if split != "__flat__" else "")
                    done_marker = out_lbl_dir / "_DONE.json"
                    if done_marker.exists() and not OVERWRITE_EXISTING:
                        continue

                    lbl_files = sorted([p for p in noisy_lbl_dir.rglob("*.txt") if p.is_file()])
                    if MAX_FILES_PER_SPLIT is not None:
                        lbl_files = lbl_files[:int(MAX_FILES_PER_SPLIT)]
                    if not lbl_files:
                        continue

                    print(f"      - refine {noise_name} / {split} | files={len(lbl_files)}")
                    print(f"        -> save to: {out_lbl_dir}")

                    sum_iou_noisy = 0.0
                    sum_iou_ref = 0.0
                    n_boxes = 0
                    n_files_done = 0
                    n_files_skipped = 0

                    for lf in lbl_files:
                        rel = lf.relative_to(noisy_lbl_dir)
                        out_lf = out_lbl_dir / rel

                        if out_lf.exists() and (not OVERWRITE_EXISTING):
                            n_files_skipped += 1
                            continue

                        img_path = find_image_for_label(img_dir, rel)
                        if img_path is None:
                            continue

                        clean_lf = clean_lbl_dir / rel
                        m = refine_one_label_file(
                            model=model,
                            img_path=img_path,
                            noisy_lbl_path=lf,
                            clean_lbl_path=clean_lf if clean_lf.exists() else None,
                            out_lbl_path=out_lf,
                            cand_cfg=cand_cfg,
                            jitter_patterns=jitter_patterns,
                            crop_size=crop_size,
                            use_context=use_context,
                            use_amp=use_amp,
                            device=device,
                            case_tag=case_tag,
                        )
                        n_files_done += 1
                        sum_iou_noisy += m["sum_iou_noisy"]
                        sum_iou_ref   += m["sum_iou_refined"]
                        n_boxes       += int(m["n_boxes"])

                    out_lbl_dir.mkdir(parents=True, exist_ok=True)
                    payload = {
                        "dataset": ds_root.name,
                        "seed_req": int(seed_req),
                        "seed_ckpt": int(seed),
                        "seed_dir": f"seed{int(seed_req):03d}",
                        "case_id": case_id,
                        "selected_case_name": pretty_case_name,
                        "noise_name": noise_name,
                        "case_tag": case_tag,
                        "split": split,
                        "files_total": len(lbl_files),
                        "files_done": n_files_done,
                        "files_skipped": n_files_skipped,
                        "n_boxes": n_boxes,
                        "mean_iou_noisy": (sum_iou_noisy / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
                        "mean_iou_refined": (sum_iou_ref / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
                        "delta_iou": ((sum_iou_ref - sum_iou_noisy) / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
                        "best_pt": str(best_pt),
                        "out_dir": str(out_lbl_dir),
                    }
                    done_marker.write_text(json.dumps(payload, indent=2), encoding="utf-8")

                    if WRITE_METRICS_CSV:
                        all_metrics_rows.append(payload)

            del model
            free_cuda()

# =========================
# save metrics (per-seed + overall)
# =========================
if not (WRITE_METRICS_CSV and all_metrics_rows):
    # Nothing to persist: metrics export is disabled, or no row was collected.
    print("\n✅ Done (no metrics csv requested or no rows).")
else:
    import pandas as pd

    # Each entry of all_metrics_rows is one payload dict -> one CSV row.
    df_all = pd.DataFrame(all_metrics_rows)

    # (1) overall: a single CSV holding every collected row
    save_all = OUT_ROOT / "_refine_summary_metrics.csv"
    df_all.to_csv(save_all, index=False)
    print(f"\n✅ Saved refine metrics summary (overall): {save_all}")

    # (2) per-seed: one CSV per distinct requested seed (seed_req)
    if "seed_req" not in df_all.columns:
        print("⚠️ 'seed_req' column not found; cannot write per-seed CSVs.")
    else:
        for seed_val, df_s in df_all.groupby("seed_req"):
            seed_val_int = int(seed_val)
            # File name carries the zero-padded seed, e.g. ...__seed042.csv
            save_seed = OUT_ROOT / f"_refine_summary_metrics__seed{seed_val_int:03d}.csv"
            df_s.to_csv(save_seed, index=False)
            print(f"✅ Saved refine metrics summary (seed={seed_val_int}): {save_seed}")



[REFINE] Case selection
 - using case_name filter: 4 cases
   • baseline_both_31_absn100_ctx1_dist0_E1
   • exp1_both_15_absn100_ctx1_dist0_E1
   • exp2_isotropic_only_15_absn100_ctx1_dist0_E1
   • exp3_borderwise_only_15_absn100_ctx1_dist0_E1
[REFINE] Seeds: [42]
[REFINE] Output root: experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines  (create subfolders per seed)

[DATASET] kitti | noises=13 | splits=['train', 'val']
          OUT_ROOT=experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)
          output base => experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines (subfolders per seed)

  [SEED] 42 | matched_ckpts=4
         seed output => experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines/seed042/kitti


  ckpt = torch.load(best_pt, map_location="cpu")



    ▶ Case: __home__ISW__project__experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)__weights__kitti__set_densenet_noisy2clean__absn_0100_ctx1__d512_h8_l2_dist0_share0_wL0_wM0_wE1_rnd10_inv1_jit0p03_cap60
       - seed_req=42 | ckpt_seed=42
       - selected_case_name=baseline_both_31_absn100_ctx1_dist0_E1
       - backbone=densenet | crop=224 | ctx=1 | amp=1
       - ckpt=/home/ISW/project/experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/weights/kitti/set_densenet_noisy2clean/absn_0100_ctx1/d512_h8_l2_dist0_share0_wL0_wM0_wE1_rnd10_inv1_jit0p03_cap60/best.pt
       - out_root=experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines/seed042/kitti/__home__ISW__project__experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)__weights__kitti__set_densenet_noisy2clean__absn_0100_ctx1__d512_h8_l2_dist0_share0_wL0_wM0_wE1_rnd10_inv1_jit0p03_cap60
      - refine labels_uniform_scaling_0.6 /

      - refine labels_uniform_scaling_0.6 / val | files=100
        -> save to: experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines/seed042/kitti/__home__ISW__project__experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)__weights__kitti__set_densenet_noisy2clean__absn_0100_ctx1__d512_h8_l2_dist0_share0_wL0_wM0_wE1_rnd10_inv1_jit0p03_cap60/labels_uniform_scaling_0.6/val
      - refine labels_uniform_scaling_0.7 / train | files=100
        -> save to: experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)/refines/seed042/kitti/__home__ISW__project__experiments_ablation(scale_side, start=noise, generation_rule_control)-(100)__weights__kitti__set_densenet_noisy2clean__absn_0100_ctx1__d512_h8_l2_dist0_share0_wL0_wM0_wE1_rnd10_inv1_jit0p03_cap60/labels_uniform_scaling_0.7/train
      - refine labels_uniform_scaling_0.7 / val | files=100
        -> save to: experiments_ablation(scale_side, start=noise, generation_r

# Full noise dataset refinement

In [15]:
# # ==========================================
# # Cell 5) Full-dataset Refinement Writer — FINAL (NO seed/case_id in output path)
# #   - Refine ALL files under /datasets/<ds>/<noise_name>/...
# #   - Select ONE ckpt per dataset via TARGET_CASE_NAMES (+ SEEDS-aware)
# #   - ✅ Save refined labels under:
# #       /datasets/<ds>/refine(n_data)/<noise_name>/<split>/.../*.txt
# #   - No seedXXX / no case_id in output path
# #
# # [PATCHED]
# #   ✅ GPU-accelerated crop/resize/normalize via ROIAlign (removes PIL crop loop bottleneck)
# #   ✅ Context tensor also built on GPU (single image -> interpolate)
# #   ✅ Larger GROUP_BS default (tunable) for better GPU utilization
# # ==========================================

# from __future__ import annotations
# import os, re, json, gc, csv, inspect, warnings
# from pathlib import Path
# from typing import Dict, List, Tuple, Optional, Any

# import numpy as np
# import torch
# import torch.nn.functional as F
# from PIL import Image

# from torchvision import transforms  # kept (may be used elsewhere)
# import torchvision.transforms.functional as TVF
# from torchvision.ops import roi_align

# # -------------------------
# # USER CONFIG
# # -------------------------
# DATASETS_ROOT = Path("/home/ISW/project/datasets")

# n_data = 50  # Appears in folder names (e.g. refine(50), the OUT_ROOT suffix) and in TARGET_CASE_NAMES (absn{n_data})

# OUT_ROOT = Path(f"./experiments_ablation(scale_side, start=noise)-({n_data})").resolve()

# TARGET_DATASETS: Optional[List[str]] = [
#     "kitti",
#     "homeobjects-3K",
#     "african-wildlife",
#     "construction-ppe",
#     "brain-tumor",
#     "BCCD",
#     "signature",
#     "medical-pills",
#     "VOC",
# ]

# # ⚠️ The output path does not include the seed, so use only one seed per dataset to avoid overwriting results
# SEEDS: List[int] = [42]

# UNIFORM_SCALING_FACTORS   = [0.6, 0.8, 1.2, 1.4]
# JITTER_PATTERNS  = [1, 5, 9]

# # If None, refine every split that exists under images/
# REFINE_SPLITS: Optional[List[str]] = None

# # For FULL execution: recommend None
# MAX_FILES_PER_SPLIT: Optional[int] = None

# OVERWRITE_EXISTING = False
# WRITE_METRICS_CSV = True

# DEVICE = str(globals().get("DEVICE", "cuda:0"))

# IMG_EXTS = [".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"]
# YOLO_RE = re.compile(r"^\s*(\d+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)")

# IMAGENET_MEAN = [0.485,0.456,0.406]
# IMAGENET_STD  = [0.229,0.224,0.225]

# warnings.filterwarnings("ignore", category=FutureWarning, message=r".*torch\.cuda\.amp\.autocast.*deprecated.*")
# warnings.filterwarnings("ignore", category=UserWarning, message=r".*nested tensors is in prototype stage.*")

# # -------------------------
# # CASE SELECTION (core)
# # -------------------------
# TARGET_CASE_NAMES = [
#     f"SINGLE_densenet_absn{n_data}_ctx1_dist0_E1",
# ]
# USE_CASE_SPECS_DEFAULT_IF_AVAILABLE = True  # Recommend True

# # -------------------------
# # REQUIRE: Cell2 loaded
# # -------------------------
# assert "ReBox" in globals(), "Cell 2 (ReBox) must be loaded first."
# assert "CandidateCfg" in globals() and "build_candidates_for_one_box" in globals(), "Cell2 candidate generation functions are required."

# # -------------------------
# # Speed knobs (tunable)
# # -------------------------
# # ROIAlign GPU crop makes forward faster; increasing GROUP_BS helps GPU utilization.
# GROUP_BS_DEFAULT = 16   # was 4
# # If you see VRAM spikes, lower this (e.g., 8). If GPU still underutilized, raise (e.g., 32).


# # -------------------------
# # Auto-resolve OUT_ROOT if needed
# # -------------------------
# def resolve_out_root(user_out: Path) -> Path:
#     if (user_out / "weights").is_dir():
#         return user_out

#     candidates = []
#     for w in Path(".").glob("**/weights"):
#         if not w.is_dir():
#             continue
#         root = w.parent
#         if any(x in str(root).lower() for x in ["site-packages", ".cache", "venv", "conda"]):
#             continue
#         candidates.append(root)

#     if not candidates:
#         print(f"⚠️ OUT_ROOT={user_out} has no weights/ and auto-search also failed.")
#         return user_out

#     def score(r: Path) -> Tuple[int, float]:
#         has_summary = int((r / "_orchestrator_summary").is_dir())
#         try:
#             mtime = (r / "weights").stat().st_mtime
#         except Exception:
#             mtime = 0.0
#         return (has_summary, mtime)

#     best = sorted(candidates, key=score, reverse=True)[0]
#     print(f"ℹ️ OUT_ROOT auto-corrected: {user_out} -> {best} (weights/ found)")
#     return best

# OUT_ROOT = resolve_out_root(OUT_ROOT)

# # -------------------------
# # Target case filter resolve
# # -------------------------
# def _norm_name(s: str) -> str:
#     return str(s).strip().lower()

# def resolve_target_cases() -> Tuple[Optional[set], Optional[List[Any]]]:
#     if TARGET_CASE_NAMES:
#         names = [str(x) for x in TARGET_CASE_NAMES if str(x).strip()]
#         return set(names), None

#     if USE_CASE_SPECS_DEFAULT_IF_AVAILABLE and ("CASE_SPECS_DEFAULT" in globals()):
#         specs = globals().get("CASE_SPECS_DEFAULT", None)
#         if isinstance(specs, list) and specs:
#             names = []
#             for c in specs:
#                 if hasattr(c, "case_name"):
#                     names.append(str(c.case_name))
#             if names:
#                 return set(names), specs
#             return None, specs

#     return None, None

# TARGET_CASE_NAME_SET, TARGET_CASE_SPECS = resolve_target_cases()

# print("\n" + "="*110)
# print("[REFINE-FULL] Output rule (NO seed/case_id):")
# print("✅ /datasets/<ds>/refine(n_data)/<noise_name>/<split>/.../*.txt")
# print(f" - n_data={n_data}")
# print(f" - OUT_ROOT(ckpt)={OUT_ROOT}")
# print(f" - TARGET_CASE_NAMES={list(TARGET_CASE_NAME_SET or [])}")
# print(f" - SEEDS={SEEDS}")
# print("="*110)

# # -------------------------
# # Dataset utils
# # -------------------------
# def resolve_dataset_root(dataset_name: str) -> Optional[Path]:
#     cand = DATASETS_ROOT / dataset_name
#     if cand.exists():
#         return cand
#     lname = dataset_name.lower()
#     for d in DATASETS_ROOT.iterdir():
#         if d.is_dir() and d.name.lower() == lname:
#             return d
#     alt = dataset_name.replace("3K", "3k").replace("3k", "3K")
#     cand2 = DATASETS_ROOT / alt
#     if cand2.exists():
#         return cand2
#     for d in DATASETS_ROOT.iterdir():
#         if d.is_dir() and d.name.lower() == alt.lower():
#             return d
#     return None

# def list_existing_splits(images_root: Path) -> List[str]:
#     if not images_root.exists():
#         return ["__flat__"]
#     subs = sorted([p.name for p in images_root.iterdir() if p.is_dir()])
#     return subs if subs else ["__flat__"]

# def _split_selected(split: str) -> bool:
#     if split == "__flat__":
#         return True
#     if not REFINE_SPLITS:
#         return True
#     s = split.lower()
#     for t in REFINE_SPLITS:
#         tl = str(t).lower()
#         if s == tl or s.startswith(tl):
#             return True
#     if s.startswith("val") and any(str(t).lower()=="val" for t in REFINE_SPLITS):
#         return True
#     if s.startswith("valid") and any(str(t).lower()=="valid" for t in REFINE_SPLITS):
#         return True
#     return False

# def resolve_split_dir(base: Path, split: str) -> Path:
#     return base if split == "__flat__" else (base / split)

# def find_image_for_label(images_dir: Path, rel_lbl_path: Path) -> Optional[Path]:
#     stem = rel_lbl_path.with_suffix("").as_posix()
#     for ext in IMG_EXTS:
#         cand = images_dir / f"{stem}{ext}"
#         if cand.exists():
#             return cand
#     return None

# def read_yolo_txt(p: Optional[Path]) -> List[Tuple[int,float,float,float,float]]:
#     rows=[]
#     if p is None or (not p.exists()):
#         return rows
#     for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines():
#         m = YOLO_RE.match(ln)
#         if not m:
#             continue
#         cls = int(float(m.group(1)))
#         cx,cy,w,h = map(float, m.groups()[1:])
#         rows.append((cls,cx,cy,w,h))
#     return rows

# def write_yolo_txt(p: Path, rows: List[Tuple[int,float,float,float,float]]):
#     p.parent.mkdir(parents=True, exist_ok=True)
#     with open(p, "w", encoding="utf-8") as f:
#         for cls,cx,cy,w,h in rows:
#             cx = min(max(cx, 0.0), 1.0)
#             cy = min(max(cy, 0.0), 1.0)
#             w  = min(max(w,  0.0), 1.0)
#             h  = min(max(h,  0.0), 1.0)
#             f.write(f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")

# # -------------------------
# # Metrics: IoU
# # -------------------------
# def _to_xyxy(cx, cy, w, h):
#     return cx - w/2, cy - h/2, cx + w/2, cy + h/2

# def box_iou(a, b, eps=1e-9) -> float:
#     ax1, ay1, ax2, ay2 = _to_xyxy(*a)
#     bx1, by1, bx2, by2 = _to_xyxy(*b)
#     inter_x1 = max(ax1, bx1); inter_y1 = max(ay1, by1)
#     inter_x2 = min(ax2, bx2); inter_y2 = min(ay2, by2)
#     iw = max(0.0, inter_x2 - inter_x1); ih = max(0.0, inter_y2 - inter_y1)
#     inter = iw * ih
#     area_a = max(0.0, ax2-ax1) * max(0.0, ay2-ay1)
#     area_b = max(0.0, bx2-bx1) * max(0.0, by2-by1)
#     union = area_a + area_b - inter
#     return float(inter / (union + eps))

# # -------------------------
# # ckpt parsing + model build
# # -------------------------
# def _get_state_dict_from_ckpt(ckpt: Dict[str, Any]) -> Dict[str, Any]:
#     for k in ["model_state_dict", "model", "state_dict", "net", "weights"]:
#         sd = ckpt.get(k, None)
#         if isinstance(sd, dict) and sd:
#             return sd
#     raise KeyError("Cannot find model state dict in ckpt.")

# def build_model_from_ckpt(ckpt: Dict[str, Any]) -> Tuple[torch.nn.Module, Dict[str, Any]]:
#     exp = ckpt.get("exp", {}) or {}
#     cfg = ckpt.get("config", {}) or {}
#     meta = {}
#     meta.update(cfg)
#     meta.update(exp)

#     backbone_key = meta.get("backbone", None) or meta.get("backbone_key", None)
#     assert backbone_key is not None, "Could not find backbone info in ckpt exp/config."

#     use_context = bool(meta.get("use_context", True))
#     d_model  = int(meta.get("d_model", 512))
#     nhead    = int(meta.get("nhead", 8))
#     nlayers  = int(meta.get("nlayers", 2))
#     disable_dist_mlp = bool(meta.get("disable_dist_mlp", False))
#     share_backbone = bool(meta.get("share_backbone", False))

#     cand_kwargs = dict(
#         backbone_key=backbone_key,
#         d_model=d_model,
#         nhead=nhead,
#         nlayers=nlayers,
#         use_context=use_context,
#         disable_dist_mlp=disable_dist_mlp,
#         share_backbone=share_backbone,
#     )

#     sig = inspect.signature(ReBox.__init__)
#     allowed = set(sig.parameters.keys()) - {"self"}
#     kwargs = {k: v for k, v in cand_kwargs.items() if k in allowed}

#     model = ReBox(**kwargs)
#     sd = _get_state_dict_from_ckpt(ckpt)
#     inc = model.load_state_dict(sd, strict=False)
#     mk = len(getattr(inc, "missing_keys", []) or [])
#     uk = len(getattr(inc, "unexpected_keys", []) or [])
#     if (mk + uk) > 0:
#         print(f"     ⚠️ load_state_dict(strict=False): missing={mk}, unexpected={uk}")

#     model.eval()
#     meta["__init_kwargs_used"] = kwargs
#     return model, meta

# # -------------------------
# # OUT_ROOT summary CSV helpers
# # -------------------------
# def find_orchestrator_summary_csv(out_root: Path) -> Optional[Path]:
#     d = out_root / "_orchestrator_summary"
#     if not d.is_dir():
#         return None
#     cands = sorted(list(d.glob("*.csv")))
#     if not cands:
#         return None
#     cands = sorted(cands, key=lambda p: p.stat().st_mtime, reverse=True)
#     for p in cands:
#         if "summary__start_" in p.name.lower():
#             return p
#     return cands[0]

# def _parse_int(x, default=None):
#     try:
#         return int(float(str(x).strip()))
#     except Exception:
#         return default

# def _best_pt_from_row(out_root: Path, row: Dict[str, str]) -> Optional[Path]:
#     bp = row.get("best_ckpt","") or row.get("best_pt","") or row.get("best_model","")
#     if bp:
#         p = Path(bp)
#         if not p.is_absolute():
#             p = (out_root / bp).resolve()
#         if p.exists():
#             return p

#     log_csv = row.get("log_csv","")
#     if not log_csv:
#         return None
#     p = Path(log_csv)
#     if not p.is_absolute():
#         p = (out_root / log_csv).resolve()

#     run_dir = p.parent if (not p.exists()) else p.resolve().parent
#     try:
#         rel = run_dir.relative_to(out_root.resolve())
#         best_pt = (out_root / "weights" / rel / "best.pt").resolve()
#         if best_pt.exists():
#             return best_pt
#     except Exception:
#         pass
#     return None

# def best_ckpts_from_summary(out_root: Path, dataset_name: str,
#                            target_names: Optional[set],
#                            seed: Optional[int] = None) -> List[Tuple[str, Path]]:
#     summ = find_orchestrator_summary_csv(out_root)
#     if summ is None or (not summ.exists()):
#         return []

#     dsq = _norm_name(dataset_name)
#     out: List[Tuple[str, Path]] = []
#     with open(summ, "r", encoding="utf-8", newline="") as f:
#         rd = csv.DictReader(f)
#         for row in rd:
#             if _norm_name(row.get("dataset","")) != dsq:
#                 continue
#             if seed is not None:
#                 rseed = _parse_int(row.get("seed",""), default=None)
#                 if rseed is None or int(rseed) != int(seed):
#                     continue
#             if target_names:
#                 if str(row.get("case_name","")).strip() not in target_names:
#                     continue
#             best_pt = _best_pt_from_row(out_root, row)
#             if best_pt is None:
#                 continue
#             out.append((str(row.get("case_name","")).strip(), best_pt))

#     uniq, seen = [], set()
#     for cname, p in out:
#         k = str(p.resolve())
#         if k in seen: continue
#         seen.add(k); uniq.append((cname, p))
#     return uniq

# def scan_best_ckpts_for_dataset(out_root: Path, dataset_name: str) -> List[Path]:
#     wroot = out_root / "weights"
#     if not wroot.exists():
#         return []
#     dsq = _norm_name(dataset_name)
#     cand_dirs = []
#     for d in wroot.iterdir():
#         if d.is_dir() and _norm_name(d.name) == dsq:
#             cand_dirs.append(d)
#     if not cand_dirs:
#         return []
#     best_pts = []
#     for d in cand_dirs:
#         best_pts.extend(list(d.rglob("best.pt")))
#     return sorted(best_pts)

# def scan_best_ckpts_for_dataset_filtered(out_root: Path, dataset_name: str,
#                                         target_names: Optional[set],
#                                         seed: Optional[int] = None) -> List[Tuple[str, Path]]:
#     pts = scan_best_ckpts_for_dataset(out_root, dataset_name)
#     if not pts:
#         return []

#     matched: List[Tuple[str, Path]] = []
#     for p in pts:
#         try:
#             ckpt = torch.load(p, map_location="cpu")
#             exp = ckpt.get("exp", {}) or {}
#             cfg = ckpt.get("config", {}) or {}
#             meta = {}
#             meta.update(cfg); meta.update(exp)

#             if seed is not None:
#                 mseed = int(meta.get("seed", -999999))
#                 if mseed != int(seed):
#                     continue

#             case_name = str(meta.get("case_name", "")) if meta.get("case_name", "") is not None else ""
#             if target_names and case_name.strip() not in target_names:
#                 continue

#             matched.append((case_name.strip(), p))
#         except Exception:
#             continue

#     uniq, seen = [], set()
#     for cname, p in matched:
#         k = str(p.resolve())
#         if k in seen: continue
#         seen.add(k); uniq.append((cname, p))
#     return uniq

# # -------------------------
# # Candidate + inference helpers
# # -------------------------
# def _parse_case_tag_from_noise_dirname(noise_name: str) -> str:
#     m = re.match(r"^labels_uniform_scaling_(.+)$", noise_name)
#     if m:
#         raw = m.group(1).replace("p",".")
#         try:
#             s = float(raw)
#             return f"scale_{s:g}"
#         except Exception:
#             return "clean"
#     m = re.match(r"^labels_boundary_jitter_(\d+)$", noise_name)
#     if m:
#         return f"side_{int(m.group(1))}"
#     return "clean"

# def iter_noise_label_dirs(dataset_root: Path) -> List[Tuple[str, Path]]:
#     out = []
#     for s in UNIFORM_SCALING_FACTORS:
#         name = f"labels_uniform_scaling_{s:g}"
#         p = dataset_root / name
#         if p.exists():
#             out.append((name, p))
#     for k in JITTER_PATTERNS:
#         name = f"labels_boundary_jitter_{int(k)}"
#         p = dataset_root / name
#         if p.exists():
#             out.append((name, p))
#     return out

# def free_cuda():
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     gc.collect()

# @torch.inference_mode()
# def _forward_scores(model: torch.nn.Module, Xc, Xctx, D, mask_valid):
#     try:
#         return model(Xc, Xctx, D, mask_valid)
#     except TypeError:
#         AUX = torch.zeros((Xc.shape[0], Xc.shape[1], 0), device=Xc.device)
#         return model(Xc, Xctx, AUX, D, mask_valid)

# def _rois_from_cand_boxes(
#     cand_boxes: List[Tuple[float,float,float,float]],
#     W: float, H: float,
#     device: torch.device,
# ) -> torch.Tensor:
#     # ROIAlign expects: (batch_idx, x1, y1, x2, y2) in input coordinate space.
#     rois = []
#     for (cx2, cy2, w2, h2) in cand_boxes:
#         x1 = (cx2 - w2/2) * W
#         y1 = (cy2 - h2/2) * H
#         x2 = (cx2 + w2/2) * W
#         y2 = (cy2 + h2/2) * H

#         # clamp
#         x1 = float(max(0.0, min(W - 1.0, x1)))
#         y1 = float(max(0.0, min(H - 1.0, y1)))
#         x2 = float(max(x1 + 1.0, min(W * 1.0, x2)))
#         y2 = float(max(y1 + 1.0, min(H * 1.0, y2)))

#         rois.append([0.0, x1, y1, x2, y2])
#     return torch.tensor(rois, device=device, dtype=torch.float32)

# @torch.inference_mode()
# def refine_one_label_file(
#     model: torch.nn.Module,
#     img_path: Path,
#     noisy_lbl_path: Path,
#     clean_lbl_path: Optional[Path],
#     out_lbl_path: Path,
#     cand_cfg: Any,
#     jitter_patterns: List[int],
#     crop_size: int,
#     use_context: bool,
#     use_amp: bool,
#     device: torch.device,
#     case_tag: str,
#     group_bs: int = GROUP_BS_DEFAULT,
# ) -> Dict[str, float]:
#     noisy_rows = read_yolo_txt(noisy_lbl_path)
#     if len(noisy_rows) == 0:
#         if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
#             write_yolo_txt(out_lbl_path, [])
#         return {"n_boxes": 0.0, "sum_iou_noisy": 0.0, "sum_iou_refined": 0.0}

#     clean_rows = read_yolo_txt(clean_lbl_path) if (WRITE_METRICS_CSV and clean_lbl_path is not None and clean_lbl_path.exists()) else None

#     with Image.open(img_path).convert("RGB") as img:
#         # --- 1) Build GPU image tensor once (normalize on GPU) ---
#         # pil_to_tensor: uint8 (C,H,W) on CPU
#         img_u8 = TVF.pil_to_tensor(img)
#         img_f = img_u8.to(device=device, dtype=torch.float32) / 255.0  # (C,H,W) on GPU

#         mean = torch.tensor(IMAGENET_MEAN, device=device, dtype=torch.float32).view(3,1,1)
#         std  = torch.tensor(IMAGENET_STD,  device=device, dtype=torch.float32).view(3,1,1)
#         img_f = (img_f - mean) / std
#         img_bchw = img_f.unsqueeze(0)  # (1,3,H,W)

#         Ht = float(img_f.shape[-2])
#         Wt = float(img_f.shape[-1])

#         # --- 2) Context on GPU (single interpolate) ---
#         if use_context:
#             xctx = F.interpolate(
#                 img_bchw, size=(crop_size, crop_size),
#                 mode="bilinear", align_corners=False
#             ).squeeze(0)  # (3,crop,crop)
#         else:
#             xctx = None

#         refined_rows = []
#         sum_iou_noisy = 0.0
#         sum_iou_ref   = 0.0
#         n_boxes = 0

#         groups_Xc, groups_D, groups_meta = [], [], []

#         def flush():
#             nonlocal sum_iou_noisy, sum_iou_ref, n_boxes, refined_rows
#             if not groups_Xc:
#                 return
#             Bs = len(groups_Xc)
#             maxN = max(x.shape[0] for x in groups_Xc)

#             Xc = torch.zeros((Bs, maxN, 3, crop_size, crop_size), device=device, dtype=torch.float32)
#             D  = torch.zeros((Bs, maxN), device=device, dtype=torch.float32)
#             mask_valid = torch.zeros((Bs, maxN), dtype=torch.bool, device=device)

#             if use_context:
#                 # expand() is a zero-copy view; .contiguous() then materializes it as one (Bs,3,crop,crop) tensor
#                 Xctx = xctx.unsqueeze(0).expand(Bs, -1, -1, -1).contiguous()
#             else:
#                 Xctx = torch.zeros((Bs, 3, crop_size, crop_size), device=device, dtype=torch.float32)

#             for b in range(Bs):
#                 n = groups_Xc[b].shape[0]
#                 Xc[b, :n] = groups_Xc[b]
#                 D[b, :n]  = groups_D[b]
#                 mask_valid[b, :n] = True

#             amp_ok = (use_amp and device.type == "cuda" and torch.cuda.is_available())
#             if amp_ok:
#                 with torch.amp.autocast(device_type="cuda", enabled=True):
#                     scores = _forward_scores(model, Xc, Xctx, D, mask_valid)
#             else:
#                 scores = _forward_scores(model, Xc, Xctx, D, mask_valid)

#             for b in range(Bs):
#                 valid = mask_valid[b]
#                 if valid.sum() == 0:
#                     cls, noisy_box = groups_meta[b]["cls"], groups_meta[b]["noisy_box"]
#                     refined_rows.append((cls, *noisy_box))
#                     continue
#                 sb = scores[b, valid].detach()
#                 best_idx = int(torch.argmax(sb).item())
#                 cand_boxes = groups_meta[b]["cand_boxes"]
#                 refined_box = cand_boxes[best_idx]
#                 cls = groups_meta[b]["cls"]
#                 refined_rows.append((cls, *refined_box))

#                 if groups_meta[b]["clean_box"] is not None:
#                     clean_box = groups_meta[b]["clean_box"]
#                     noisy_box = groups_meta[b]["noisy_box"]
#                     sum_iou_noisy += box_iou(noisy_box, clean_box)
#                     sum_iou_ref   += box_iou(refined_box, clean_box)
#                     n_boxes += 1

#             groups_Xc.clear(); groups_D.clear(); groups_meta.clear()

#         # --- 3) per box: build candidates (CPU) + ROIAlign crops (GPU) ---
#         for i, (cls, cx, cy, w, h) in enumerate(noisy_rows):
#             clean_for_hash = (clean_lbl_path if (clean_lbl_path is not None and clean_lbl_path.exists()) else noisy_lbl_path)

#             cand_boxes, cand_dists, _ = build_candidates_for_one_box(
#                 anchor_box=(cx,cy,w,h),
#                 case_tag=case_tag,
#                 clean_lbl_path=Path(clean_for_hash),
#                 box_index=i,
#                 seed=int(getattr(cand_cfg, "seed", 42)) if hasattr(cand_cfg, "seed") else 42,
#                 cand_cfg=cand_cfg,
#                 jitter_patterns=jitter_patterns,
#             )

#             if not cand_boxes:
#                 refined_rows.append((int(cls), cx,cy,w,h))
#                 continue

#             # GPU crop all candidate boxes at once
#             rois_t = _rois_from_cand_boxes(cand_boxes, W=Wt, H=Ht, device=device)
#             # (N,3,crop,crop) on GPU, already normalized (input is normalized)
#             Xc_group = roi_align(
#                 img_bchw, rois_t,
#                 output_size=(crop_size, crop_size),
#                 spatial_scale=1.0,
#                 aligned=True
#             )
#             D_group  = torch.tensor(cand_dists, dtype=torch.float32, device=device)

#             clean_box = None
#             if clean_rows is not None and i < len(clean_rows):
#                 _, ccx, ccy, cw, ch = clean_rows[i]
#                 clean_box = (ccx, ccy, cw, ch)

#             groups_Xc.append(Xc_group)
#             groups_D.append(D_group)
#             groups_meta.append({
#                 "cls": int(cls),
#                 "clean_box": clean_box,
#                 "noisy_box": (cx,cy,w,h),
#                 "cand_boxes": cand_boxes,
#             })

#             if len(groups_Xc) >= int(group_bs):
#                 flush()

#         flush()

#     if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
#         write_yolo_txt(out_lbl_path, refined_rows)

#     return {"n_boxes": float(n_boxes), "sum_iou_noisy": float(sum_iou_noisy), "sum_iou_refined": float(sum_iou_ref)}

# # -------------------------
# # MAIN
# # -------------------------
# device = torch.device(DEVICE if (torch.cuda.is_available() and str(DEVICE).startswith("cuda")) else "cpu")

# # Optional: helps speed on fixed-size ROIAlign/conv pipelines
# if device.type == "cuda":
#     torch.backends.cudnn.benchmark = True

# all_metrics_rows = []

# # ⚠️ Since output path has no seed/case distinction, use only 1 seed per dataset
# if len(SEEDS) != 1:
#     print(f"⚠️ SEEDS={SEEDS} (len>1). Output path has NO seed partition -> will use ONLY first seed: {SEEDS[0]}")
# SEED_REQ = int(SEEDS[0]) if SEEDS else 42

# for dataset in (TARGET_DATASETS or []):
#     ds_root = resolve_dataset_root(dataset)
#     if ds_root is None or (not ds_root.exists()):
#         print(f"⚠️ Skip missing dataset root: {DATASETS_ROOT/dataset}")
#         continue

#     images_root = ds_root / "images"
#     labels_root = ds_root / "labels"
#     if not images_root.exists() or not labels_root.exists():
#         print(f"⚠️ Skip (missing images/labels): {ds_root.name}")
#         continue

#     splits_all = list_existing_splits(images_root)
#     splits = [s for s in splits_all if _split_selected(s)]
#     if not splits:
#         splits = ["__flat__"]

#     noise_dirs = iter_noise_label_dirs(ds_root)
#     if not noise_dirs:
#         print(f"⚠️ No noise label dirs in {ds_root} (labels_uniform_scaling_*/labels_boundary_jitter_*). Skip.")
#         continue

#     # dataset-level refine root
#     DATASET_REFINE_ROOT = ds_root / f"refine({n_data})"
#     DATASET_REFINE_ROOT.mkdir(parents=True, exist_ok=True)

#     print("\n" + "="*110)
#     print(f"[DATASET] {ds_root.name} | noises={len(noise_dirs)} | splits={splits}")
#     print("        refine INPUT  => datasets/<ds>/(labels_uniform_scaling_*/labels_boundary_jitter_*)")
#     print(f"        refine OUTPUT => {DATASET_REFINE_ROOT}/<noise_name>/<split>/.../*.txt   (NO seed/case)")
#     print("="*110)

#     # 1) summary csv first: (dataset + seed + selected cases) -> best.pt
#     case_ckpts = best_ckpts_from_summary(
#         out_root=OUT_ROOT,
#         dataset_name=ds_root.name,
#         target_names=TARGET_CASE_NAME_SET,
#         seed=SEED_REQ,
#     )

#     # 2) fallback: weights scan + (case_name match) + seed filter
#     if not case_ckpts:
#         case_ckpts = scan_best_ckpts_for_dataset_filtered(
#             out_root=OUT_ROOT,
#             dataset_name=ds_root.name,
#             target_names=TARGET_CASE_NAME_SET,
#             seed=SEED_REQ,
#         )

#     if not case_ckpts:
#         print(f"  ⚠️ [seed={SEED_REQ}] No matched best.pt found. Skip dataset.")
#         continue

#     # ⚠️ Since output path has no case distinction, use only 1 ckpt
#     if len(case_ckpts) > 1:
#         # Keep only the most recently modified best.pt (sorted by file mtime, newest first)
#         case_ckpts_sorted = sorted(case_ckpts, key=lambda x: x[1].stat().st_mtime if x[1].exists() else 0.0, reverse=True)
#         print(f"  ⚠️ matched_ckpts={len(case_ckpts)} (output has NO case partition) -> will use ONLY latest ckpt:")
#         for cname, p in case_ckpts_sorted[:3]:
#             print(f"     - {cname} | {p}")
#         case_ckpts = [case_ckpts_sorted[0]]

#     case_name_from_list, best_pt = case_ckpts[0]

#     ckpt = torch.load(best_pt, map_location="cpu")
#     model, meta = build_model_from_ckpt(ckpt)
#     model = model.to(device)

#     crop_size = int(meta.get("crop_size", 224))
#     use_context = bool(meta.get("use_context", True))
#     use_amp = bool(meta.get("use_amp", True))
#     seed_ckpt = int(meta.get("seed", SEED_REQ))

#     exp = ckpt.get("exp", {}) or {}
#     cand_cfg_dict = (exp.get("cand_cfg", None) or meta.get("cand_cfg", None))
#     if not isinstance(cand_cfg_dict, dict):
#         cand_cfg_dict = {
#             "cand_uniform_scaling_factors": meta.get("cand_uniform_scaling_factors", [0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3,1.4]),
#             "cand_side_ks": meta.get("cand_side_ks", JITTER_PATTERNS),
#             "num_border_perturb": int(meta.get("num_border_perturb", 10)),
#             "include_anchor": True,
#             "include_inverse": True,
#             "inverse_jitter": float(meta.get("inverse_jitter", 0.03)),
#             "require_mixed_signs": True,
#             "max_candidates_per_group": int(meta.get("max_candidates_per_group", 60)),
#         }
#     cand_cfg = CandidateCfg(**{k: cand_cfg_dict[k] for k in cand_cfg_dict.keys()
#                               if k in inspect.signature(CandidateCfg).parameters})
#     setattr(cand_cfg, "seed", int(seed_ckpt))  # candidate reproducibility

#     jitter_patterns = list(exp.get("jitter_patterns", JITTER_PATTERNS))
#     pretty_case_name = case_name_from_list or str(meta.get("case_name", "")) or "(unknown_case)"

#     print(f"\n  ▶ Using ONE ckpt for dataset={ds_root.name}")
#     print(f"     - device={device} | model_param_device={next(model.parameters()).device}")
#     print(f"     - seed_req={SEED_REQ} | ckpt_seed={seed_ckpt}")
#     print(f"     - selected_case_name={pretty_case_name}")
#     print(f"     - ckpt={best_pt}")
#     print(f"     - ROIAlign GPU crops: {'ON' if device.type=='cuda' else 'OFF (CPU device)'}")
#     print(f"     - GROUP_BS={GROUP_BS_DEFAULT}")

#     for noise_name, noise_dir in noise_dirs:
#         case_tag = _parse_case_tag_from_noise_dirname(noise_name)

#         for split in splits:
#             img_dir = resolve_split_dir(images_root, split)
#             clean_lbl_dir = resolve_split_dir(labels_root, split)
#             noisy_lbl_dir = resolve_split_dir(noise_dir, split)

#             if not img_dir.exists() or not noisy_lbl_dir.exists() or not clean_lbl_dir.exists():
#                 continue

#             # ✅ FINAL OUTPUT RULE (NO seed/case):
#             # /datasets/<ds>/refine(n_data)/<noise_name>/<split>/.../*.txt
#             out_lbl_dir = DATASET_REFINE_ROOT / noise_name / (split if split != "__flat__" else "")
#             done_marker = out_lbl_dir / "_DONE.json"
#             if done_marker.exists() and not OVERWRITE_EXISTING:
#                 continue

#             lbl_files = sorted([p for p in noisy_lbl_dir.rglob("*.txt") if p.is_file()])
#             if MAX_FILES_PER_SPLIT is not None:
#                 lbl_files = lbl_files[:int(MAX_FILES_PER_SPLIT)]
#             if not lbl_files:
#                 continue

#             print(f"    - refine {noise_name} / {split} | files={len(lbl_files)}")
#             print(f"      -> save to: {out_lbl_dir}")

#             sum_iou_noisy = 0.0
#             sum_iou_ref = 0.0
#             n_boxes = 0
#             n_files_done = 0
#             n_files_skipped = 0

#             for lf in lbl_files:
#                 rel = lf.relative_to(noisy_lbl_dir)
#                 out_lf = out_lbl_dir / rel

#                 if out_lf.exists() and (not OVERWRITE_EXISTING):
#                     n_files_skipped += 1
#                     continue

#                 img_path = find_image_for_label(img_dir, rel)
#                 if img_path is None:
#                     continue

#                 clean_lf = clean_lbl_dir / rel
#                 m = refine_one_label_file(
#                     model=model,
#                     img_path=img_path,
#                     noisy_lbl_path=lf,
#                     clean_lbl_path=clean_lf if clean_lf.exists() else None,
#                     out_lbl_path=out_lf,
#                     cand_cfg=cand_cfg,
#                     jitter_patterns=jitter_patterns,
#                     crop_size=crop_size,
#                     use_context=use_context,
#                     use_amp=use_amp,
#                     device=device,
#                     case_tag=case_tag,
#                     group_bs=GROUP_BS_DEFAULT,
#                 )
#                 n_files_done += 1
#                 sum_iou_noisy += m["sum_iou_noisy"]
#                 sum_iou_ref   += m["sum_iou_refined"]
#                 n_boxes       += int(m["n_boxes"])

#             out_lbl_dir.mkdir(parents=True, exist_ok=True)
#             payload = {
#                 "dataset": ds_root.name,
#                 "n_data_tag": int(n_data),
#                 "seed_req": int(SEED_REQ),
#                 "seed_ckpt": int(seed_ckpt),
#                 "selected_case_name": pretty_case_name,
#                 "noise_name": noise_name,
#                 "case_tag": case_tag,
#                 "split": split,
#                 "files_total": len(lbl_files),
#                 "files_done": n_files_done,
#                 "files_skipped": n_files_skipped,
#                 "n_boxes": n_boxes,
#                 "mean_iou_noisy": (sum_iou_noisy / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
#                 "mean_iou_refined": (sum_iou_ref / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
#                 "delta_iou": ((sum_iou_ref - sum_iou_noisy) / max(n_boxes, 1)) if WRITE_METRICS_CSV else None,
#                 "best_pt": str(best_pt),
#                 "out_dir": str(out_lbl_dir),
#                 "note": "output has NO seed/case_id partition | ROIAlign GPU crops enabled when CUDA",
#             }
#             done_marker.write_text(json.dumps(payload, indent=2), encoding="utf-8")

#             if WRITE_METRICS_CSV:
#                 all_metrics_rows.append(payload)

#     del model
#     free_cuda()

# # -------------------------
# # Save metrics CSV
# # -------------------------
# if WRITE_METRICS_CSV and all_metrics_rows:
#     import pandas as pd
#     df_all = pd.DataFrame(all_metrics_rows)

#     save_all = OUT_ROOT / "_refine_full_summary_metrics__no_seed_case_in_path.csv"
#     df_all.to_csv(save_all, index=False)
#     print(f"\n✅ Saved refine metrics summary: {save_all}")
# else:
#     print("\n✅ Done (no metrics csv requested or no rows).")


## Time Calculation

In [16]:
# # ==========================================
# # Cell X) ReBox-based Refinement Speed Benchmark (NO saving)
# #  - No saving logic
# #  - Random sample of up to TEST_NUMBER images per dataset (all images if fewer) -> speed measurement
# #  - Output format and measurement method are aligned with the previously presented SAM Speed Benchmark
# #    * Img Load(s): Image open/convert + np conversion + YOLO txt parsing
# #    * Refine(s)  : (ReBox) context tensor prep + candidate gen + crop+tf + model forward(score/argmax)
# #    * Includes per-section GPU sync via torch.cuda.synchronize()
# # ==========================================

# import time
# import random
# import statistics
# import torch
# import numpy as np
# import csv, re, gc, inspect
# from pathlib import Path
# from PIL import Image

# # -------------------------
# # [Configuration] (Same naming/format as SAM benchmark)
# # -------------------------
# DATASETS_ROOT = Path("/home/ISW/project/datasets")
# DEVICE = "cuda:1"

# TARGET_DATASETS = [
#     "kitti",
#     "homeobjects-3K",
#     "african-wildlife",
#     "construction-ppe",
#     # "Custom_Blood",
#     "brain-tumor",
#     "BCCD",
#     "signature",
#     "medical-pills",
#     "VOC",
# ]

# # Benchmark settings
# TEST_NUMBER = 100
# IMG_EXTS = [".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"]

# # -------------------------
# # ReBox ckpt selection (same conditions as Cell5)
# # -------------------------
# n_data = 10
# OUT_ROOT = Path(f"./experiments_ablation(scale_side, start=noise)-({n_data})").resolve()
# SEED_REQ = 43

# TARGET_CASE_NAMES = [
#     f"SINGLE_densenet_absn{n_data}_ctx1_dist0_E1",
# ]
# TARGET_CASE_NAME_SET = set([x for x in TARGET_CASE_NAMES if str(x).strip()])

# # Candidate generation default fallback
# JITTER_PATTERNS = [1, 3, 5, 7, 9]
# UNIFORM_SCALING_FACTORS  = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]

# # -------------------------
# # REQUIRE: Cell2 loaded
# # -------------------------
# assert "ReBox" in globals(), "Cell 2 (ReBox) must be loaded first."
# assert "CandidateCfg" in globals() and "build_candidates_for_one_box" in globals(), "Cell2 candidate generation functions are required."

# # -------------------------
# # 1) Helper Functions (maintain SAM benchmark style)
# # -------------------------
# def resolve_dataset_root(dataset_name: str) -> Path:
#     cand = DATASETS_ROOT / dataset_name
#     if cand.exists(): return cand
#     for d in DATASETS_ROOT.iterdir():
#         if d.is_dir() and d.name.lower() == dataset_name.lower():
#             return d
#     return cand  # fallback

# def read_yolo_txt_fast(p: Path):
#     rows = []
#     if not p.exists(): return rows
#     with open(p, "r") as f:
#         for line in f:
#             parts = line.strip().split()
#             if len(parts) >= 5:
#                 # exclude cls, only cx cy w h
#                 rows.append([float(x) for x in parts[1:5]])
#     return rows

# def yolo_to_xyxy_pix(cx, cy, w, h, W, H):
#     x1 = int(round((cx - w/2) * W)); x2 = int(round((cx + w/2) * W))
#     y1 = int(round((cy - h/2) * H)); y2 = int(round((cy + h/2) * H))
#     return max(0, x1), max(0, y1), min(W, x2), min(H, y2)

# # -------------------------
# # 2) OUT_ROOT & ckpt helpers (same logic as Cell5)
# # -------------------------
# def resolve_out_root(user_out: Path) -> Path:
#     if (user_out / "weights").is_dir():
#         return user_out

#     candidates = []
#     for w in Path(".").glob("**/weights"):
#         if not w.is_dir():
#             continue
#         root = w.parent
#         if any(x in str(root).lower() for x in ["site-packages", ".cache", "venv", "conda"]):
#             continue
#         candidates.append(root)

#     if not candidates:
#         print(f"⚠️ OUT_ROOT={user_out} has no weights/ and auto-search also failed.")
#         return user_out

#     def score(r: Path):
#         has_summary = int((r / "_orchestrator_summary").is_dir())
#         try:
#             mtime = (r / "weights").stat().st_mtime
#         except Exception:
#             mtime = 0.0
#         return (has_summary, mtime)

#     best = sorted(candidates, key=score, reverse=True)[0]
#     print(f"ℹ️ OUT_ROOT auto-corrected: {user_out} -> {best} (weights/ found)")
#     return best

# OUT_ROOT = resolve_out_root(OUT_ROOT)

# def _norm_name(s: str) -> str:
#     return str(s).strip().lower()

# def find_orchestrator_summary_csv(out_root: Path) -> Path | None:
#     d = out_root / "_orchestrator_summary"
#     if not d.is_dir():
#         return None
#     cands = sorted(list(d.glob("*.csv")))
#     if not cands:
#         return None
#     cands = sorted(cands, key=lambda p: p.stat().st_mtime, reverse=True)
#     for p in cands:
#         if "summary__start_" in p.name.lower():
#             return p
#     return cands[0]

# def _parse_int(x, default=None):
#     try:
#         return int(float(str(x).strip()))
#     except Exception:
#         return default

# def _best_pt_from_row(out_root: Path, row: dict) -> Path | None:
#     bp = row.get("best_ckpt","") or row.get("best_pt","") or row.get("best_model","")
#     if bp:
#         p = Path(bp)
#         if not p.is_absolute():
#             p = (out_root / bp).resolve()
#         if p.exists():
#             return p

#     log_csv = row.get("log_csv","")
#     if not log_csv:
#         return None
#     p = Path(log_csv)
#     if not p.is_absolute():
#         p = (out_root / log_csv).resolve()

#     run_dir = p.parent if (not p.exists()) else p.resolve().parent
#     try:
#         rel = run_dir.relative_to(out_root.resolve())
#         best_pt = (out_root / "weights" / rel / "best.pt").resolve()
#         if best_pt.exists():
#             return best_pt
#     except Exception:
#         pass
#     return None

# def best_ckpts_from_summary(out_root: Path, dataset_name: str,
#                            target_names: set | None,
#                            seed: int | None = None) -> list[tuple[str, Path]]:
#     summ = find_orchestrator_summary_csv(out_root)
#     if summ is None or (not summ.exists()):
#         return []

#     dsq = _norm_name(dataset_name)
#     out: list[tuple[str, Path]] = []
#     with open(summ, "r", encoding="utf-8", newline="") as f:
#         rd = csv.DictReader(f)
#         for row in rd:
#             if _norm_name(row.get("dataset","")) != dsq:
#                 continue
#             if seed is not None:
#                 rseed = _parse_int(row.get("seed",""), default=None)
#                 if rseed is None or int(rseed) != int(seed):
#                     continue
#             if target_names:
#                 if str(row.get("case_name","")).strip() not in target_names:
#                     continue
#             best_pt = _best_pt_from_row(out_root, row)
#             if best_pt is None:
#                 continue
#             out.append((str(row.get("case_name","")).strip(), best_pt))

#     uniq, seen = [], set()
#     for cname, p in out:
#         k = str(p.resolve())
#         if k in seen: continue
#         seen.add(k); uniq.append((cname, p))
#     return uniq

# def scan_best_ckpts_for_dataset(out_root: Path, dataset_name: str) -> list[Path]:
#     wroot = out_root / "weights"
#     if not wroot.exists():
#         return []
#     dsq = _norm_name(dataset_name)
#     cand_dirs = [d for d in wroot.iterdir() if d.is_dir() and _norm_name(d.name) == dsq]
#     if not cand_dirs:
#         return []
#     best_pts = []
#     for d in cand_dirs:
#         best_pts.extend(list(d.rglob("best.pt")))
#     return sorted(best_pts)

# def scan_best_ckpts_for_dataset_filtered(out_root: Path, dataset_name: str,
#                                         target_names: set | None,
#                                         seed: int | None = None) -> list[tuple[str, Path]]:
#     pts = scan_best_ckpts_for_dataset(out_root, dataset_name)
#     if not pts:
#         return []
#     matched: list[tuple[str, Path]] = []
#     for p in pts:
#         try:
#             ckpt = torch.load(p, map_location="cpu")
#             exp = ckpt.get("exp", {}) or {}
#             cfg = ckpt.get("config", {}) or {}
#             meta = {}
#             meta.update(cfg); meta.update(exp)

#             if seed is not None:
#                 mseed = int(meta.get("seed", -999999))
#                 if mseed != int(seed):
#                     continue

#             case_name = str(meta.get("case_name", "")) if meta.get("case_name", "") is not None else ""
#             if target_names and case_name.strip() not in target_names:
#                 continue

#             matched.append((case_name.strip(), p))
#         except Exception:
#             continue

#     uniq, seen = [], set()
#     for cname, p in matched:
#         k = str(p.resolve())
#         if k in seen: continue
#         seen.add(k); uniq.append((cname, p))
#     return uniq

# # -------------------------
# # 3) ckpt -> model builder
# # -------------------------
# def _get_state_dict_from_ckpt(ckpt: dict) -> dict:
#     for k in ["model_state_dict", "model", "state_dict", "net", "weights"]:
#         sd = ckpt.get(k, None)
#         if isinstance(sd, dict) and sd:
#             return sd
#     raise KeyError("Cannot find model state dict in ckpt.")

# def build_model_from_ckpt(ckpt: dict):
#     exp = ckpt.get("exp", {}) or {}
#     cfg = ckpt.get("config", {}) or {}
#     meta = {}
#     meta.update(cfg)
#     meta.update(exp)

#     backbone_key = meta.get("backbone", None) or meta.get("backbone_key", None)
#     assert backbone_key is not None, "Could not find backbone info in ckpt exp/config."

#     use_context = bool(meta.get("use_context", True))
#     d_model  = int(meta.get("d_model", 512))
#     nhead    = int(meta.get("nhead", 8))
#     nlayers  = int(meta.get("nlayers", 2))
#     disable_dist_mlp = bool(meta.get("disable_dist_mlp", False))
#     share_backbone = bool(meta.get("share_backbone", False))

#     cand_kwargs = dict(
#         backbone_key=backbone_key,
#         d_model=d_model,
#         nhead=nhead,
#         nlayers=nlayers,
#         use_context=use_context,
#         disable_dist_mlp=disable_dist_mlp,
#         share_backbone=share_backbone,
#     )
#     sig = inspect.signature(ReBox.__init__)
#     allowed = set(sig.parameters.keys()) - {"self"}
#     kwargs = {k: v for k, v in cand_kwargs.items() if k in allowed}

#     model = ReBox(**kwargs)
#     sd = _get_state_dict_from_ckpt(ckpt)
#     inc = model.load_state_dict(sd, strict=False)

#     mk = len(getattr(inc, "missing_keys", []) or [])
#     uk = len(getattr(inc, "unexpected_keys", []) or [])
#     if (mk + uk) > 0:
#         print(f"     ⚠️ load_state_dict(strict=False): missing={mk}, unexpected={uk}")

#     model.eval()
#     meta["__init_kwargs_used"] = kwargs
#     return model, meta

# # -------------------------
# # 4) Minimal transforms/candidate helpers
# # -------------------------
# IMAGENET_MEAN = [0.485,0.456,0.406]
# IMAGENET_STD  = [0.229,0.224,0.225]

# def build_tf(crop_size: int):
#     # Hand-written PIL -> normalized tensor conversion, so torchvision does not need to be imported again
#     # (Cell5 used transforms; this version is kept lightweight for the benchmark)
#     def _to_tensor_norm(pil_img: Image.Image):
#         arr = np.asarray(pil_img, dtype=np.float32) / 255.0  # HWC
#         if arr.ndim == 2:
#             arr = np.stack([arr,arr,arr], axis=-1)
#         arr = arr[..., :3]
#         arr = arr.transpose(2,0,1)  # CHW
#         t = torch.from_numpy(arr)
#         # resize with PIL, then convert the resized image again (simple/consistent);
#         # NOTE: this overwrites the full-resolution tensor `t` built just above
#         pil_rs = pil_img.resize((crop_size, crop_size), resample=Image.BILINEAR)
#         arr_rs = np.asarray(pil_rs, dtype=np.float32) / 255.0
#         arr_rs = arr_rs[..., :3].transpose(2,0,1)
#         t = torch.from_numpy(arr_rs)
#         # normalize
#         mean = torch.tensor(IMAGENET_MEAN).view(3,1,1)
#         std  = torch.tensor(IMAGENET_STD).view(3,1,1)
#         t = (t - mean) / std
#         return t
#     return _to_tensor_norm

# def _parse_case_tag_from_noise_dirname(noise_name: str) -> str:
#     m = re.match(r"^labels_uniform_scaling_(.+)$", noise_name)
#     if m:
#         raw = m.group(1).replace("p",".")
#         try:
#             s = float(raw)
#             return f"scale_{s:g}"
#         except Exception:
#             return "clean"
#     m = re.match(r"^labels_boundary_jitter_(\d+)$", noise_name)
#     if m:
#         return f"side_{int(m.group(1))}"
#     return "clean"

# @torch.inference_mode()
# def _forward_scores(model, Xc, Xctx, D, mask_valid):
#     try:
#         return model(Xc, Xctx, D, mask_valid)
#     except TypeError:
#         AUX = torch.zeros((Xc.shape[0], Xc.shape[1], 0), device=Xc.device)
#         return model(Xc, Xctx, AUX, D, mask_valid)

# @torch.inference_mode()
# def run_rebox_refine_timing(
#     model,
#     img_pil: Image.Image,
#     boxes_yolo: list[list[float]],
#     label_source_name: str,
#     cand_cfg,
#     jitter_patterns: list[int],
#     crop_size: int,
#     use_context: bool,
#     use_amp: bool,
#     device,
# ):
#     """
#     Work that is timed as part of the Refine(s) section:
#       - (optional) context tensor prep (=entire img 1x TF + GPU transfer)
#       - candidate gen + crop+tf for each box
#       - group batch flush -> model forward + argmax
#     """
#     tf = build_tf(crop_size)
#     W, H = img_pil.size

#     # context tensor (same as Cell5 logic "image 1x TF")
#     xctx = None
#     if use_context:
#         xctx = tf(img_pil).to(device)

#     GROUP_BS = 4
#     groups_Xc, groups_D = [], []

#     def flush():
#         if not groups_Xc:
#             return
#         Bs = len(groups_Xc)
#         maxN = max(x.shape[0] for x in groups_Xc)

#         Xc = torch.zeros((Bs, maxN, 3, crop_size, crop_size), device=device)
#         D  = torch.zeros((Bs, maxN), device=device)
#         mask_valid = torch.zeros((Bs, maxN), dtype=torch.bool, device=device)

#         if use_context and (xctx is not None):
#             Xctx = torch.stack([xctx for _ in range(Bs)], dim=0)
#         else:
#             Xctx = torch.zeros((Bs, 3, crop_size, crop_size), device=device)

#         for b in range(Bs):
#             n = groups_Xc[b].shape[0]
#             Xc[b, :n] = groups_Xc[b]
#             D[b, :n]  = groups_D[b]
#             mask_valid[b, :n] = True

#         amp_ok = (use_amp and device.type == "cuda" and torch.cuda.is_available())
#         if amp_ok:
#             with torch.amp.autocast(device_type="cuda", enabled=True):
#                 scores = _forward_scores(model, Xc, Xctx, D, mask_valid)
#         else:
#             scores = _forward_scores(model, Xc, Xctx, D, mask_valid)

#         # perform argmax only (no saving)
#         for b in range(Bs):
#             v = mask_valid[b]
#             if v.any():
#                 _ = int(torch.argmax(scores[b, v]).item())

#         groups_Xc.clear(); groups_D.clear()

#     # set case_tag based on label_source_name
#     case_tag = _parse_case_tag_from_noise_dirname(label_source_name)

#     # box loop
#     for i, (cx, cy, w, h) in enumerate(boxes_yolo):
#         # skip boxes narrower or shorter than 2 px (after clipping to the image), same as the SAM benchmark
#         x1, y1, x2, y2 = yolo_to_xyxy_pix(cx, cy, w, h, W, H)
#         if (x2 - x1) < 2 or (y2 - y1) < 2:
#             continue

#         # candidate generation
#         try:
#             cand_boxes, cand_dists, _ = build_candidates_for_one_box(
#                 anchor_box=(cx,cy,w,h),
#                 case_tag=case_tag,
#                 clean_lbl_path=Path("__bench__"),  # for hash (dummy to not depend on file)
#                 box_index=i,
#                 seed=int(getattr(cand_cfg, "seed", SEED_REQ)),
#                 cand_cfg=cand_cfg,
#                 jitter_patterns=jitter_patterns,
#             )
#         except Exception:
#             # minimum fallback
#             try:
#                 cand_boxes, cand_dists, _ = build_candidates_for_one_box(
#                     anchor_box=(cx,cy,w,h),
#                     case_tag=case_tag,
#                     clean_lbl_path=Path("__bench__"),
#                     box_index=i,
#                     seed=int(getattr(cand_cfg, "seed", SEED_REQ)),
#                 )
#             except Exception:
#                 continue

#         if not cand_boxes:
#             continue

#         # crop -> tensor
#         cand_tensors = []
#         for (cx2,cy2,w2,h2) in cand_boxes:
#             xx1 = int(round((cx2 - w2/2) * W)); xx2 = int(round((cx2 + w2/2) * W))
#             yy1 = int(round((cy2 - h2/2) * H)); yy2 = int(round((cy2 + h2/2) * H))
#             xx1 = max(0, min(W-1, xx1)); xx2 = max(1, min(W, xx2))
#             yy1 = max(0, min(H-1, yy1)); yy2 = max(1, min(H, yy2))
#             crop = img_pil.crop((xx1,yy1,xx2,yy2))
#             cand_tensors.append(tf(crop).unsqueeze(0))

#         Xc_group = torch.cat(cand_tensors, dim=0).to(device)
#         D_group  = torch.tensor(cand_dists, dtype=torch.float32, device=device)

#         groups_Xc.append(Xc_group)
#         groups_D.append(D_group)

#         if len(groups_Xc) >= GROUP_BS:
#             flush()

#     flush()

# def free_cuda():
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#         torch.cuda.ipc_collect()
#     gc.collect()

# # -------------------------
# # 5) Main Benchmark Logic (same output as SAM benchmark)
# # -------------------------
# def run_benchmark():
#     print(f"Loading ReBox ckpt per-dataset on {DEVICE}...")
#     device = torch.device(DEVICE if (torch.cuda.is_available() and str(DEVICE).startswith("cuda")) else "cpu")

#     dataset_times = {}

#     print(f"{'Dataset':<20} | {'Img Load(s)':<10} | {'Refine(s)':<12} | {'Objs/Img':<8} | {'Total/Img(s)':<12}")
#     print("-" * 75)

#     for ds_name in TARGET_DATASETS:
#         ds_root = resolve_dataset_root(ds_name)
#         if not ds_root.exists():
#             print(f"{ds_name:<20} | NOT FOUND")
#             continue

#         images_dir = ds_root / "images"
#         if not images_dir.exists():
#             print(f"{ds_name:<20} | NO IMAGES DIR")
#             continue

#         # label source (same as SAM benchmark): use the first labels_uniform_scaling_* dir in sorted order if any exists, otherwise labels
#         label_source = ds_root / "labels"
#         label_source_name = "labels"
#         try:
#             scale_dirs = sorted([d for d in ds_root.iterdir() if d.is_dir() and d.name.startswith("labels_uniform_scaling_")])
#         except Exception:
#             scale_dirs = []
#         if scale_dirs:
#             label_source = scale_dirs[0]
#             label_source_name = label_source.name

#         if not label_source.exists():
#             print(f"{ds_name:<20} | NO LABELS")
#             continue

#         # ---- ckpt select (same as Cell5) ----
#         case_ckpts = best_ckpts_from_summary(
#             out_root=OUT_ROOT,
#             dataset_name=ds_root.name,
#             target_names=TARGET_CASE_NAME_SET,
#             seed=SEED_REQ,
#         )
#         if not case_ckpts:
#             case_ckpts = scan_best_ckpts_for_dataset_filtered(
#                 out_root=OUT_ROOT,
#                 dataset_name=ds_root.name,
#                 target_names=TARGET_CASE_NAME_SET,
#                 seed=SEED_REQ,
#             )
#         if not case_ckpts:
#             print(f"{ds_name:<20} | SKIPPED (No ckpt)")
#             continue

#         case_ckpts_sorted = sorted(
#             case_ckpts,
#             key=lambda x: x[1].stat().st_mtime if x[1].exists() else 0.0,
#             reverse=True
#         )
#         _, best_pt = case_ckpts_sorted[0]

#         # model load
#         ckpt = torch.load(best_pt, map_location="cpu")
#         model, meta = build_model_from_ckpt(ckpt)
#         model.to(device)
#         model.eval()

#         crop_size = int(meta.get("crop_size", 224))
#         use_context = bool(meta.get("use_context", True))
#         use_amp = bool(meta.get("use_amp", True))

#         # CandidateCfg (restored from ckpt)
#         exp = ckpt.get("exp", {}) or {}
#         cand_cfg_dict = (exp.get("cand_cfg", None) or meta.get("cand_cfg", None))
#         if not isinstance(cand_cfg_dict, dict):
#             cand_cfg_dict = {
#                 "cand_uniform_scaling_factors": meta.get("cand_uniform_scaling_factors", UNIFORM_SCALING_FACTORS),
#                 "cand_side_ks": meta.get("cand_side_ks", JITTER_PATTERNS),
#                 "num_border_perturb": int(meta.get("num_border_perturb", 10)),
#                 "include_anchor": True,
#                 "include_inverse": True,
#                 "inverse_jitter": float(meta.get("inverse_jitter", 0.03)),
#                 "require_mixed_signs": True,
#                 "max_candidates_per_group": int(meta.get("max_candidates_per_group", 60)),
#             }
#         cand_cfg = CandidateCfg(**{k: cand_cfg_dict[k] for k in cand_cfg_dict.keys()
#                                   if k in inspect.signature(CandidateCfg).parameters})
#         # prevent AttributeError(include_anchor)
#         if not hasattr(cand_cfg, "include_anchor"):
#             setattr(cand_cfg, "include_anchor", True)
#         setattr(cand_cfg, "seed", int(meta.get("seed", SEED_REQ)))

#         jitter_patterns = list(exp.get("jitter_patterns", JITTER_PATTERNS))

#         # GPU warmup (similar to SAM benchmark)
#         if device.type == "cuda":
#             dummy = torch.zeros((1, 4, 3, crop_size, crop_size), device=device)
#             dctx  = torch.zeros((1, 3, crop_size, crop_size), device=device)
#             D     = torch.zeros((1, 4), device=device)
#             mask  = torch.ones((1, 4), dtype=torch.bool, device=device)
#             with torch.amp.autocast(device_type="cuda", enabled=bool(use_amp)):
#                 _ = _forward_scores(model, dummy, dctx, D, mask)
#             torch.cuda.synchronize()

#         # ---- collect file list (same as SAM benchmark) ----
#         all_imgs = []
#         for ext in IMG_EXTS:
#             all_imgs.extend(list(images_dir.rglob(f"*{ext}")))

#         if not all_imgs:
#             print(f"{ds_name:<20} | NO IMAGES")
#             del model
#             free_cuda()
#             continue

#         if len(all_imgs) > TEST_NUMBER:
#             samples = random.sample(all_imgs, TEST_NUMBER)
#         else:
#             samples = all_imgs

#         time_records = []  # (load_time, refine_time, num_objects)

#         for img_path in samples:
#             rel_path = img_path.relative_to(images_dir)
#             lbl_path = label_source / rel_path.with_suffix(".txt")

#             # --- [Measurement Start] ---
#             # 1) Image Loading & Preprocessing Time (same as SAM benchmark)
#             t_start_load = time.time()
#             try:
#                 im = Image.open(img_path).convert("RGB")
#                 img_np = np.array(im)  # matching same conditions
#                 H, W = img_np.shape[:2]
#                 boxes_yolo = read_yolo_txt_fast(lbl_path)
#             except Exception:
#                 try:
#                     im.close()
#                 except Exception:
#                     pass
#                 continue

#             if torch.cuda.is_available():
#                 torch.cuda.synchronize()
#             t_end_load = time.time()
#             load_dur = t_end_load - t_start_load

#             # 2) ReBox Refinement Time
#             t_start_refine = time.time()

#             # Refine(s) includes context prep + candidate gen + crop/tf + model forward
#             run_rebox_refine_timing(
#                 model=model,
#                 img_pil=im,                    # reuse the already-open PIL image (avoids opening the file a second time)
#                 boxes_yolo=boxes_yolo,
#                 label_source_name=label_source_name,
#                 cand_cfg=cand_cfg,
#                 jitter_patterns=jitter_patterns,
#                 crop_size=crop_size,
#                 use_context=use_context,
#                 use_amp=use_amp,
#                 device=device,
#             )

#             if torch.cuda.is_available():
#                 torch.cuda.synchronize()
#             t_end_refine = time.time()
#             refine_dur = t_end_refine - t_start_refine

#             try:
#                 im.close()
#             except Exception:
#                 pass

#             time_records.append((load_dur, refine_dur, len(boxes_yolo)))
#             # --- [Measurement End] ---

#         if not time_records:
#             print(f"{ds_name:<20} | SKIPPED (No valid labels/images)")
#             del model
#             free_cuda()
#             continue

#         avg_load = statistics.mean([t[0] for t in time_records])
#         avg_ref  = statistics.mean([t[1] for t in time_records])
#         avg_objs = statistics.mean([t[2] for t in time_records])
#         total_time_per_img = avg_load + avg_ref

#         dataset_times[ds_name] = total_time_per_img

#         print(f"{ds_name:<20} | {avg_load:.4f}s      | {avg_ref:.4f}s       | {avg_objs:>8.1f} | {total_time_per_img:.4f}s")

#         del model
#         free_cuda()

#     # Final Summary (same as SAM benchmark)
#     print("-" * 75)
#     if dataset_times:
#         global_avg = statistics.mean(dataset_times.values())
#         print(f"{'OVERALL AVERAGE':<20} | {'-':<10} | {'-':<12} | {'-':<8} | {global_avg:.4f}s")
#         print("=" * 75)
#         print(f"Note: Measured on {TEST_NUMBER} random images per dataset.")
#         print("Time includes: Image Load + (ReBox) candidate gen + crop/tf + model forward(scoring).")
#     else:
#         print("No datasets processed successfully.")

# if __name__ == "__main__":
#     if torch.cuda.is_available():
#         torch.cuda.empty_cache()
#     run_benchmark()


# noise_detection model training

In [17]:
# # ==========================================
# # Cell 1) Detection datasets discovery & inspection (FINAL)
# #   - How many datasets are found
# #   - train/val image count per dataset
# #   - Whether label cases (original/scale/side) exist
# #   - Output estimated class count/names (multiclass-based)
# #
# # [UPDATED]
# #   ✅ Reflect split structure rules per dataset name
# #     * BCCD, brain-tumor, Custom_Blood, homeobjects-3K, kitti, medical-pills, signature
# #         images/labels -> train, val
# #     * coco
# #         images/labels -> train2017, val2017, test2017
# #     * construction-ppe, african-wildlife
# #         images/labels -> train, val, test
# #     * lvis
# #         images/labels -> train2017, val2017, test2017
# #     * SKU-110K
# #         images/labels -> no subfolders
# #         -> seed-based 8:2 virtual split
# #     * VOC
# #         images/labels -> test2007, train2007, train2012, val2007, val2012
# #         -> use train2012 / val2012 only
# # ==========================================

# from __future__ import annotations

# import os, sys, random, shutil
# from pathlib import Path
# from typing import List, Tuple, Optional, Dict

# # -------------------------------------------------------------------------
# # 0) Register PROJECT_MODULE_DIR
# # -------------------------------------------------------------------------
# PROJECT_MODULE_DIR = Path("/home/ISW/project/Project_Module")
# if str(PROJECT_MODULE_DIR) not in sys.path:
#     sys.path.insert(0, str(PROJECT_MODULE_DIR))

# # -------------------------------------------------------------------------
# # 1) ultra_det_loader
# # -------------------------------------------------------------------------
# from ultra_det_loader import discover_det_datasets

# # -------------------------------------------------------------------------
# # 2) noisy_insection (use only scale/boundary jitter case list)
# # -------------------------------------------------------------------------
# try:
#     from noisy_insection import UNIFORM_SCALING_FACTORS, JITTER_PATTERNS
# except Exception:
#     UNIFORM_SCALING_FACTORS = [0.6, 0.8, 1.2, 1.4]
#     JITTER_PATTERNS = [3, 5, 7]

# # -------------------------------------------------------------------------
# # User config
# # -------------------------------------------------------------------------
# LOAD_DIR = "/home/ISW/project/datasets"
# SEED = 42

# # Image extensions
# _IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# def set_seed(seed: int = 42):
#     random.seed(seed)

# def list_images(dir_path: Optional[Path]) -> List[Path]:
#     if dir_path is None or not Path(dir_path).exists():
#         return []
#     dir_path = Path(dir_path)
#     imgs = []
#     for p in dir_path.rglob("*"):
#         if p.is_file() and p.suffix.lower() in _IMG_EXTS:
#             imgs.append(p)
#     return sorted(imgs)

# def normalize_name(name: str) -> str:
#     # Simple normalization to absorb case/separator differences
#     n = name.strip().lower()
#     n = n.replace("_", "-")
#     n = n.replace(" ", "-")
#     return n

# # -------------------------------------------------------------------------
# # Legacy heuristic (fallback)
# # -------------------------------------------------------------------------
# def _fallback_train_dir(images_root: Path) -> Path:
#     if (images_root / "train").is_dir():
#         return images_root / "train"
#     return images_root

# def _fallback_val_dir(images_root: Path) -> Optional[Path]:
#     if (images_root / "val").is_dir():
#         return images_root / "val"
#     if (images_root / "valid").is_dir():
#         return images_root / "valid"
#     return None

# # -------------------------------------------------------------------------
# # ✅ Dataset-specific split rules
# # -------------------------------------------------------------------------
# _SIMPLE_TRAIN_VAL = {
#     "bccd",
#     "brain-tumor",
#     "custom-blood",
#     "homeobjects-3k",
#     "kitti",
#     "medical-pills",
#     "signature",
# }

# _TRAIN_TEST_VAL = {
#     "construction-ppe",
#     "african-wildlife",
# }

# def detect_split_dirs(ds_root: Path) -> Dict[str, Optional[Path]]:
#     """
#     Interpret images/labels split structure based on ds_root.
#     Returns:
#         {
#           "train_img_dir": Path|None,
#           "val_img_dir": Path|None,
#           "test_img_dir": Path|None,
#           "split_mode": str,  # "explicit" | "sku_virtual_8_2" | "fallback"
#           "train_tag": str,
#           "val_tag": str,
#         }
#     """
#     ds_name = normalize_name(ds_root.name)
#     images_root = ds_root / "images"

#     # 1) VOC rule: use train2012/val2012 only
#     if ds_name == "voc":
#         return dict(
#             train_img_dir=images_root / "train2012",
#             val_img_dir=images_root / "val2012",
#             test_img_dir=None,
#             split_mode="explicit",
#             train_tag="train2012",
#             val_tag="val2012",
#         )

#     # 2) COCO/LVIS rule
#     #    - Match the exact name ("coco" / "lvis") or any dataset name that contains it
#     if ds_name == "coco" or "coco" in ds_name:
#         return dict(
#             train_img_dir=images_root / "train2017",
#             val_img_dir=images_root / "val2017",
#             test_img_dir=images_root / "test2017",
#             split_mode="explicit",
#             train_tag="train2017",
#             val_tag="val2017",
#         )

#     if ds_name == "lvis" or "lvis" in ds_name:
#         return dict(
#             train_img_dir=images_root / "train2017",
#             val_img_dir=images_root / "val2017",
#             test_img_dir=images_root / "test2017",
#             split_mode="explicit",
#             train_tag="train2017",
#             val_tag="val2017",
#         )

#     # 3) Explicit train/val structure
#     if ds_name in _SIMPLE_TRAIN_VAL:
#         return dict(
#             train_img_dir=images_root / "train",
#             val_img_dir=images_root / "val",
#             test_img_dir=None,
#             split_mode="explicit",
#             train_tag="train",
#             val_tag="val",
#         )

#     # 4) train/test/val structure (only train/val used for summary)
#     if ds_name in _TRAIN_TEST_VAL:
#         return dict(
#             train_img_dir=images_root / "train",
#             val_img_dir=images_root / "val",
#             test_img_dir=images_root / "test",
#             split_mode="explicit",
#             train_tag="train",
#             val_tag="val",
#         )

#     # 5) SKU-110K: no subfolders -> virtual split
#     #    - Handle name variations
#     if ds_name in {"sku-110k", "sku110k", "sku_110k"} or "sku" in ds_name and "110k" in ds_name:
#         return dict(
#             train_img_dir=images_root,  # Same physical folder since virtual split
#             val_img_dir=images_root,
#             test_img_dir=None,
#             split_mode="sku_virtual_8_2",
#             train_tag="virtual_8_2",
#             val_tag="virtual_8_2",
#         )

#     # 6) fallback
#     tr = _fallback_train_dir(images_root)
#     va = _fallback_val_dir(images_root)
#     return dict(
#         train_img_dir=tr,
#         val_img_dir=va,
#         test_img_dir=None,
#         split_mode="fallback",
#         train_tag=tr.name if tr else "unknown",
#         val_tag=va.name if va else "missing",
#     )

# # -------------------------------------------------------------------------
# # Class name estimation
# # -------------------------------------------------------------------------
# def infer_class_names_from_labels(label_root: Path, max_files: int = 2000) -> List[str]:
#     if label_root is None or not label_root.exists():
#         return ["class_0"]

#     txts = list(label_root.rglob("*.txt"))
#     if not txts:
#         return ["class_0"]

#     txts = txts[:max_files]
#     cls_ids = set()

#     for t in txts:
#         try:
#             with open(t, "r", encoding="utf-8") as f:
#                 for line in f:
#                     parts = line.strip().split()
#                     if len(parts) < 5:
#                         continue
#                     cid = int(float(parts[0]))
#                     cls_ids.add(cid)
#         except Exception:
#             continue

#     if not cls_ids:
#         return ["class_0"]

#     max_id = max(cls_ids)
#     return [f"class_{i}" for i in range(max_id + 1)]

# # -------------------------------------------------------------------------
# # Label case detection
# # -------------------------------------------------------------------------
# def list_label_cases_for_dataset(ds_root: Path) -> List[Tuple[str, str]]:
#     cases: List[Tuple[str, str]] = []

#     if (ds_root / "labels").is_dir():
#         cases.append(("original", "labels"))

#     for s in UNIFORM_SCALING_FACTORS:
#         d = f"labels_uniform_scaling_{s}"
#         if (ds_root / d).is_dir():
#             cases.append((f"scale_{s}", d))

#     for k in JITTER_PATTERNS:
#         d = f"labels_boundary_jitter_{k}"
#         if (ds_root / d).is_dir():
#             cases.append((f"side_{k}", d))

#     return cases

# # -------------------------------------------------------------------------
# # SKU-110K virtual split count
# # -------------------------------------------------------------------------
# def compute_sku_virtual_counts(images_root: Path, seed: int = 42, ratio: float = 0.8) -> Tuple[int, int]:
#     imgs = list_images(images_root)
#     n = len(imgs)
#     if n == 0:
#         return 0, 0
#     rnd = random.Random(seed)
#     idxs = list(range(n))
#     rnd.shuffle(idxs)
#     cut = int(n * ratio)
#     n_train = cut
#     n_val = n - cut
#     return n_train, n_val

# # -------------------------------------------------------------------------
# # Discover dataset roots
# # -------------------------------------------------------------------------
# set_seed(SEED)

# specs = discover_det_datasets(LOAD_DIR)
# roots: List[Path] = []
# for s in specs:
#     r = Path(s.root)
#     if r not in roots:
#         roots.append(r)

# print("=" * 80)
# print(f"[DISCOVERY] Found {len(roots)} unique dataset roots under: {Path(LOAD_DIR).resolve()}")
# print("=" * 80)

# # -------------------------------------------------------------------------
# # Per-dataset summary
# # -------------------------------------------------------------------------
# dataset_summaries: List[Dict] = []

# for ds_root in roots:
#     ds_root = Path(ds_root)
#     images_root = ds_root / "images"
#     labels_root = ds_root / "labels"

#     if not images_root.is_dir() or not labels_root.is_dir():
#         print(f"⏭️  Skip (missing images/labels): {ds_root}")
#         continue

#     split_info = detect_split_dirs(ds_root)
#     train_dir = split_info["train_img_dir"]
#     val_dir   = split_info["val_img_dir"]
#     split_mode = split_info["split_mode"]
#     train_tag  = split_info.get("train_tag", "train")
#     val_tag    = split_info.get("val_tag", "val")

#     # --- Calculate image count ---
#     if split_mode == "sku_virtual_8_2":
#         n_train, n_val = compute_sku_virtual_counts(images_root, seed=SEED, ratio=0.8)
#     else:
#         n_train = len(list_images(train_dir))
#         n_val   = len(list_images(val_dir)) if val_dir else 0

#     cases = list_label_cases_for_dataset(ds_root)
#     class_names = infer_class_names_from_labels(labels_root)
#     nc = len(class_names)

#     info = {
#         "dataset": ds_root.name,
#         "root": str(ds_root),
#         "images_root": str(images_root),
#         "labels_root": str(labels_root),
#         "train_dir": str(train_dir) if train_dir else None,
#         "val_dir": str(val_dir) if val_dir else None,
#         "n_train": n_train,
#         "n_val": n_val,
#         "split_mode": split_mode,
#         "train_tag": train_tag,
#         "val_tag": val_tag,
#         "label_cases": [c[0] for c in cases],
#         "nc_inferred": nc,
#         "class_names_inferred": class_names,
#     }
#     dataset_summaries.append(info)

#     print("\n" + "-" * 80)
#     print(f"[Dataset] {ds_root.name}")
#     print(f" - root        : {ds_root}")
#     print(f" - split_mode  : {split_mode}")
#     print(f" - train_dir   : {train_dir if train_dir else '(missing)'} | tag={train_tag} | n_train={n_train}")
#     print(f" - val_dir     : {val_dir if val_dir else '(missing)'} | tag={val_tag} | n_val={n_val}")

#     # If a test split directory exists, print it with its image count (the stored summary itself covers train/val only)
#     test_dir = split_info.get("test_img_dir", None)
#     if test_dir and test_dir.is_dir():
#         n_test = len(list_images(test_dir))
#         print(f" - test_dir    : {test_dir} | n_test={n_test}")

#     print(f" - label_cases : {[c[0] for c in cases] if cases else '(none)'}")
#     print(f" - inferred classes (multiclass-based): nc={nc}, names={class_names}")
#     print("-" * 80)

# print("\n✅ Cell 1 done.")
# print(f"   -> dataset_summaries length = {len(dataset_summaries)}")
# print("   -> roots variable is ready for Cell 2.")


In [18]:
# # ==========================================
# # Cell 2) Train & Validate — FINAL (Train on refines-only splits, Eval vs original labels)
# #   ✅ Requirements reflected:
# #   - Build the train/val image sets only from images that have a refined label under
# #     refines/<dataset>/<case_id>/<noise_name>/<split>/.../*.txt
# #   - Even if the original image exists, ✅ exclude it when it has no refined label (never create empty labels ❌)
# #   - Validation is always evaluated against the original labels (/home/ISW/project/datasets/{ds}/labels/...)
# #     -> Use each refined label's relative path to locate the txt with the same relative path under the original labels
# #   - Under refines, only the labels_uniform_scaling_* and labels_boundary_jitter_* noise_name folders are used
# #   - Save precision/recall/F1/mAP50/mAP50-95 to the summary
# # ==========================================

# from __future__ import annotations

# import os, json, shutil, random, csv, sys, time, hashlib
# from pathlib import Path
# from typing import List, Dict, Tuple, Optional
# from contextlib import contextmanager

# # ✅ Image integrity check libs
# try:
#     import cv2
#     from PIL import Image
#     from tqdm import tqdm
# except ImportError:
#     print("⚠️ Required libraries are missing. Please install: pip install opencv-python pillow tqdm")
#     sys.exit(1)

# # ✅ OOM fragmentation mitigation
# os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

# import torch
# from ultralytics import YOLO
# import logging
# logging.getLogger("ultralytics").setLevel(logging.ERROR)

# # -------------------------------------------------------------------------
# # Fixed grids (noise dirs)
# # -------------------------------------------------------------------------
# UNIFORM_SCALING_FACTORS = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]
# JITTER_PATTERNS = [1, 3, 5, 7, 9]

# # -------------------------------------------------------------------------
# # ✅ Refined label usage settings (Cell4 OUT_ROOT)
# # -------------------------------------------------------------------------
# REFINE_SRC_OUT_ROOT = Path(f"./experiments_ablation(scale_side, start=noise)-({n_data})")  # <-- modifiable
# REFINES_DIR = REFINE_SRC_OUT_ROOT / "refines"
# REFINE_METRICS_CSV = REFINE_SRC_OUT_ROOT / "_refine_summary_metrics.csv"

# # refined case selection policy:
# # - "best_delta": use only 1 case_id with max mean(delta_iou) per dataset
# # - "all_cases": use all case_ids
# REFINE_CASE_POLICY = "all_cases"   # "best_delta" | "all_cases"
# MAX_REFINED_NOISES_PER_DATASET: Optional[int] = None  # None=all, or e.g. 5

# # -------------------------------------------------------------------------
# # ✅ Train-case control flags
# # -------------------------------------------------------------------------
# TRAIN_USE_ORIGINAL    = True
# TRAIN_USE_UNIFORM_SCALING_NOISE = True
# TRAIN_USE_BOUNDARY_JITTER_NOISE  = True
# TRAIN_USE_REFINED     = True   # ✅ Enable refined training cases

# # -------------------------------------------------------------------------
# # Speed control
# # -------------------------------------------------------------------------
# TRAIN_FRACTION = 1.0
# TRAIN_MIN_IMAGES = 50
# NUM_WORKERS = min(8, os.cpu_count() or 4)

# # -------------------------------------------------------------------------
# # User config
# # -------------------------------------------------------------------------
# DATASETS_ROOT = Path("/home/ISW/project/datasets")

# IMG_SIZE = 640
# EPOCHS = 10
# BATCH = 32
# DEVICE = "0"
# SEED = 42

# # ✅ Detection training result save root (recommended to separate from Cell4 OUT_ROOT)
# OUT_ROOT = Path(f"/home/ISW/project/noise_object_detect-({n_data})")
# OUT_ROOT.mkdir(parents=True, exist_ok=True)

# RUNTIME_VROOT_BASE = Path(f"/home/ISW/project/_runtime_dataset_views-({n_data})")
# RUNTIME_VROOT_BASE.mkdir(parents=True, exist_ok=True)

# CORRUPT_BACKUP_DIR = Path(f"/home/ISW/project/_corrupt_files_backup-({n_data})")
# CORRUPT_BACKUP_DIR.mkdir(parents=True, exist_ok=True)

# CLEANUP_RUNTIME_VROOT = True
# SILENCE_ULTRA_OUTPUT = True

# CLASS_MODES = ["multiclass", "object_only"]

# TARGET_DATASETS: Optional[List[str]] = [
#     "kitti",
#     "homeobjects-3K",
#     "african-wildlife",
#     "construction-ppe",
#     "Custom_Blood",
#     "brain-tumor",
#     "BCCD",
#     "signature",
#     "medical-pills",
#     "VOC",
# ]
# TARGET_DATASETS_LOWER = (
#     {name.strip().lower() for name in TARGET_DATASETS}
#     if TARGET_DATASETS is not None
#     else None
# )

# @contextmanager
# def suppress_output(enabled: bool = True):
#     if not enabled:
#         yield
#         return
#     devnull = open(os.devnull, "w")
#     old_out, old_err = sys.stdout, sys.stderr
#     try:
#         sys.stdout, sys.stderr = devnull, devnull
#         yield
#     finally:
#         sys.stdout, sys.stderr = old_out, old_err
#         devnull.close()

# # -------------------------------------------------------------------------
# # Model specs
# # -------------------------------------------------------------------------
# YOLOV8N_CKPT_CANDIDATES = ["yolov8n.pt"]
# YOLO11N_CKPT_CANDIDATES = ["yolo11n.pt", "yolov11n.pt"]
# DETR_CKPT_CANDIDATES   = ["rtdetr-s.pt", "rtdetr-l.pt"]

# MODEL_SPECS = [
#     ("yolov8n", YOLOV8N_CKPT_CANDIDATES),
#     ("yolo11n", YOLO11N_CKPT_CANDIDATES),
#     ("detr",    DETR_CKPT_CANDIDATES),
# ]

# # -------------------------------------------------------------------------
# # Utils
# # -------------------------------------------------------------------------
# _IMG_EXTS = [".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"]

# def set_seed(seed: int = 42):
#     random.seed(seed)
#     torch.manual_seed(seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(seed)

# def list_images(dir_path: Optional[Path]) -> List[Path]:
#     if dir_path is None or not Path(dir_path).exists():
#         return []
#     dir_path = Path(dir_path)
#     imgs = []
#     for p in dir_path.rglob("*"):
#         if p.is_file() and p.suffix.lower() in set(_IMG_EXTS):
#             imgs.append(p)
#     return sorted(imgs)

# def _safe_symlink(src: Path, dst: Path):
#     dst.parent.mkdir(parents=True, exist_ok=True)
#     if dst.exists() or dst.is_symlink():
#         return
#     os.symlink(str(src), str(dst))

# def _safe_copytree(src: Path, dst: Path):
#     if not src.exists():
#         return
#     dst.parent.mkdir(parents=True, exist_ok=True)
#     shutil.copytree(src, dst, dirs_exist_ok=True)

# def _link_or_copy(src: Path, dst: Path, prefer_symlink: bool = True):
#     if not src.exists():
#         return
#     try:
#         if src.is_dir():
#             _safe_copytree(src, dst)
#             return
#         if prefer_symlink:
#             _safe_symlink(src, dst)
#             return
#         dst.parent.mkdir(parents=True, exist_ok=True)
#         shutil.copy2(src, dst)
#     except Exception:
#         try:
#             dst.parent.mkdir(parents=True, exist_ok=True)
#             shutil.copy2(src, dst)
#         except Exception:
#             pass

# def infer_class_names_from_labels(label_root: Path, max_files: int = 2000) -> List[str]:
#     if label_root is None or not label_root.exists():
#         return ["class_0"]
#     txts = list(label_root.rglob("*.txt"))
#     if not txts:
#         return ["class_0"]

#     txts = txts[:max_files]
#     cls_ids = set()
#     for t in txts:
#         try:
#             with open(t, "r", encoding="utf-8") as f:
#                 for line in f:
#                     parts = line.strip().split()
#                     if len(parts) < 5:
#                         continue
#                     cid = int(float(parts[0]))
#                     cls_ids.add(cid)
#         except Exception:
#             continue

#     if not cls_ids:
#         return ["class_0"]
#     max_id = max(cls_ids)
#     return [f"class_{i}" for i in range(max_id + 1)]

# def choose_model(ckpt_candidates: List[str]) -> YOLO:
#     last_err = None
#     for ckpt in ckpt_candidates:
#         try:
#             return YOLO(ckpt)
#         except Exception as e:
#             last_err = e
#     raise RuntimeError(f"Failed to load model weights: {ckpt_candidates}") from last_err

# def extract_metrics_dict(val_result) -> Dict:
#     if val_result is None:
#         return {}
#     if hasattr(val_result, "results_dict") and isinstance(val_result.results_dict, dict):
#         return dict(val_result.results_dict)
#     try:
#         d = dict(val_result.__dict__)
#         d.pop("plots", None)
#         d.pop("speed", None)
#         return d
#     except Exception:
#         return {}

# def compute_f1_from_metrics(metrics: Dict) -> Optional[float]:
#     p = metrics.get("metrics/precision(B)", None)
#     r = metrics.get("metrics/recall(B)", None)
#     try:
#         if p is None or r is None:
#             return None
#         p = float(p); r = float(r)
#         if (p + r) <= 1e-12:
#             return 0.0
#         return float(2.0 * p * r / (p + r))
#     except Exception:
#         return None

# # -------------------------------------------------------------------------
# # object-only label rewrite
# # -------------------------------------------------------------------------
# def rewrite_label_file_to_object_only(src: Path, dst: Path):
#     dst.parent.mkdir(parents=True, exist_ok=True)
#     if not src.exists():
#         # The "don't create empty labels" requirement exists so that an image whose refined label
#         # is missing is not silently included with an empty label. The object_only conversion is
#         # meant to be called only for targets that already have a label, so a missing source is safely skipped here.
#         return

#     try:
#         lines_out = []
#         with open(src, "r", encoding="utf-8") as f:
#             for line in f:
#                 parts = line.strip().split()
#                 if len(parts) < 5:
#                     continue
#                 parts[0] = "0"
#                 lines_out.append(" ".join(parts))
#         with open(dst, "w", encoding="utf-8") as f:
#             for ln in lines_out:
#                 f.write(ln + "\n")
#     except Exception:
#         pass

# # -------------------------------------------------------------------------
# # Corrupt Image Cleaner
# # -------------------------------------------------------------------------
# def scan_and_clean_images(dir_path: Path):
#     if not dir_path.exists():
#         return

#     images = list_images(dir_path)
#     if not images:
#         return

#     print(f"   🔍 Scanning integrity of {len(images)} images in {dir_path.name}...")
#     corrupt_count = 0
#     for img_path in tqdm(images, desc="Checking", leave=False):
#         is_corrupt = False

#         try:
#             with Image.open(img_path) as im:
#                 im.verify()
#         except Exception:
#             is_corrupt = True

#         if not is_corrupt:
#             try:
#                 img = cv2.imread(str(img_path))
#                 if img is None:
#                     is_corrupt = True
#                 else:
#                     _ = img.shape
#             except Exception:
#                 is_corrupt = True

#         if is_corrupt:
#             corrupt_count += 1
#             dest = CORRUPT_BACKUP_DIR / dir_path.name / img_path.name
#             dest.parent.mkdir(parents=True, exist_ok=True)

#             label_path = img_path.parent.parent / "labels" / img_path.parent.name / img_path.with_suffix(".txt").name
#             if not label_path.exists():
#                 label_path = img_path.with_suffix(".txt")

#             try:
#                 shutil.move(str(img_path), str(dest))
#                 if label_path.exists():
#                     shutil.move(str(label_path), str(dest.with_suffix(".txt")))
#             except Exception:
#                 pass

#     if corrupt_count > 0:
#         print(f"   ⚠️  Moved {corrupt_count} corrupt images to {CORRUPT_BACKUP_DIR}")
#     else:
#         print(f"   ✅  No corrupt images found.")

# # -------------------------------------------------------------------------
# # Split/Case helpers
# # -------------------------------------------------------------------------
# def normalize_name(name: str) -> str:
#     return name.strip().lower().replace("_", "-").replace(" ", "-")

# # Use Cell1 dataset_summaries when available
# _ds_map: Dict[str, Dict] = {}
# try:
#     for info in dataset_summaries:
#         key = normalize_name(info.get("dataset", ""))
#         if key:
#             _ds_map[key] = info
# except Exception:
#     _ds_map = {}

# def get_split_info(ds_root: Path) -> Dict[str, Optional[Path]]:
#     key = normalize_name(ds_root.name)
#     if key in _ds_map:
#         info = _ds_map[key]
#         return dict(
#             train_img_dir=Path(info["train_dir"]) if info.get("train_dir") else None,
#             val_img_dir=Path(info["val_dir"]) if info.get("val_dir") else None,
#             split_mode=info.get("split_mode", "fallback"),
#             train_tag=info.get("train_tag", "train"),
#             val_tag=info.get("val_tag", "val"),
#         )

#     images_root = ds_root / "images"
#     tr = images_root / "train" if (images_root / "train").is_dir() else images_root
#     va = images_root / "val" if (images_root / "val").is_dir() else (images_root / "valid" if (images_root / "valid").is_dir() else None)
#     return dict(
#         train_img_dir=tr,
#         val_img_dir=va,
#         split_mode="fallback",
#         train_tag=tr.name if tr else "unknown",
#         val_tag=va.name if va else "missing",
#     )

# # -------------------------------------------------------------------------
# # ✅ refined case discovery
# # -------------------------------------------------------------------------
# def _short_hash(s: str, n: int = 8) -> str:
#     return hashlib.md5(s.encode("utf-8")).hexdigest()[:n]

# def _read_refine_metrics_csv() -> Optional[List[Dict[str, str]]]:
#     if not REFINE_METRICS_CSV.exists():
#         return None
#     rows = []
#     try:
#         with open(REFINE_METRICS_CSV, "r", encoding="utf-8", newline="") as f:
#             rd = csv.DictReader(f)
#             for r in rd:
#                 rows.append(r)
#         return rows
#     except Exception:
#         return None

# def pick_best_case_id_for_dataset(ds_name: str) -> Optional[str]:
#     rows = _read_refine_metrics_csv()
#     if not rows:
#         ds_dir = REFINES_DIR / ds_name
#         if not ds_dir.is_dir():
#             return None
#         case_dirs = [p for p in ds_dir.iterdir() if p.is_dir()]
#         if not case_dirs:
#             return None
#         case_dirs = sorted(case_dirs, key=lambda p: p.stat().st_mtime, reverse=True)
#         return case_dirs[0].name

#     key = normalize_name(ds_name)
#     agg: Dict[str, List[float]] = {}
#     for r in rows:
#         if normalize_name(r.get("dataset", "")) != key:
#             continue
#         case_id = (r.get("case_id", "") or "").strip()
#         if not case_id:
#             continue
#         try:
#             d = float(r.get("delta_iou", "nan"))
#         except Exception:
#             continue
#         if not (d == d):  # NaN
#             continue
#         agg.setdefault(case_id, []).append(d)

#     if not agg:
#         return None

#     best_case = None
#     best_score = -1e18
#     for cid, vals in agg.items():
#         if not vals:
#             continue
#         sc = sum(vals) / max(1, len(vals))
#         if sc > best_score:
#             best_score = sc
#             best_case = cid
#     return best_case

# def list_refined_label_cases_for_dataset(ds_root: Path) -> List[Tuple[str, Path]]:
#     """
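#     Return a list of (case_tag, case_label_root) tuples, one per selected noise folder.
#     case_tag has the form "refined__<noise_name>__<first 8 hex chars of md5(case_id)>".
#     case_label_root is the REFINES_DIR/<dataset>/<case_id>/<noise_name> folder; the train*/val*
#     split directories are expected beneath it.
#     An empty list is returned when TRAIN_USE_REFINED is off or no refines folder matches the dataset.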
#     """
#     out: List[Tuple[str, Path]] = []
#     if not TRAIN_USE_REFINED:
#         return out
#     if not REFINES_DIR.is_dir():
#         return out

#     ds_dir = REFINES_DIR / ds_root.name
#     if not ds_dir.is_dir():
#         # case-insensitive fallback
#         for p in REFINES_DIR.iterdir():
#             if p.is_dir() and normalize_name(p.name) == normalize_name(ds_root.name):
#                 ds_dir = p
#                 break
#     if not ds_dir.is_dir():
#         return out

#     if REFINE_CASE_POLICY == "all_cases":
#         case_ids = sorted([p.name for p in ds_dir.iterdir() if p.is_dir()])
#     else:
#         best = pick_best_case_id_for_dataset(ds_root.name)
#         case_ids = [best] if best else []

#     for case_id in case_ids:
#         if not case_id:
#             continue
#         case_dir = ds_dir / case_id
#         if not case_dir.is_dir():
#             continue

#         noise_dirs = [p for p in case_dir.iterdir() if p.is_dir()]
#         noise_dirs = sorted(noise_dirs, key=lambda p: p.name)

#         # ✅ Requirement: use only labels_uniform_scaling_* / labels_boundary_jitter_*
#         noise_dirs = [
#             nd for nd in noise_dirs
#             if nd.name.startswith("labels_uniform_scaling_") or nd.name.startswith("labels_boundary_jitter_")
#         ]

#         if MAX_REFINED_NOISES_PER_DATASET is not None:
#             noise_dirs = noise_dirs[:int(MAX_REFINED_NOISES_PER_DATASET)]

#         case_hash = _short_hash(case_id, 8)
#         for nd in noise_dirs:
#             noise_name = nd.name
#             tag = f"refined__{noise_name}__{case_hash}"
#             out.append((tag, nd))
#     return out

# # -------------------------------------------------------------------------
# # label case listing
# # -------------------------------------------------------------------------
# def list_label_cases_for_dataset(ds_root: Path) -> List[Tuple[str, Path]]:
#     cases: List[Tuple[str, Path]] = []

#     if TRAIN_USE_ORIGINAL:
#         p = ds_root / "labels"
#         if p.is_dir():
#             cases.append(("original", p))

#     if TRAIN_USE_UNIFORM_SCALING_NOISE:
#         for s in UNIFORM_SCALING_FACTORS:
#             dirname = f"labels_uniform_scaling_{s}"
#             p = ds_root / dirname
#             if p.is_dir():
#                 cases.append((f"scale_{s}", p))

#     if TRAIN_USE_BOUNDARY_JITTER_NOISE:
#         for k in JITTER_PATTERNS:
#             dirname = f"labels_boundary_jitter_{k}"
#             p = ds_root / dirname
#             if p.is_dir():
#                 cases.append((f"side_{k}", p))

#     if TRAIN_USE_REFINED:
#         cases.extend(list_refined_label_cases_for_dataset(ds_root))

#     return cases

# # -------------------------------------------------------------------------
# # refined split base resolver (auto-search train*/val*)
# # -------------------------------------------------------------------------
# def _pick_split_dir(label_root: Path, preferred_tag: str, kind: str) -> Optional[Path]:
#     """
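#     Resolve the split sub-directory of label_root for one split.
#     kind: "train" selects the train split; any other value is handled as "val".
#     Priority:
#       1) label_root/preferred_tag (only when preferred_tag is non-empty)
#       2) first existing common name under label_root
#          (train: train, train2012, train2017, training;
#           val:   val, valid, val2012, val2017, validation)
#       3) a sub-directory of label_root whose lower-cased name starts with "train" / "val";
#          names containing preferred_tag are preferred, ties are broken by name order
#       4) None if nothing matches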
#     """
#     if preferred_tag:
#         p = label_root / preferred_tag
#         if p.is_dir():
#             return p

#     # common exact names
#     if kind == "train":
#         for cand in ["train", "train2012", "train2017", "training"]:
#             p = label_root / cand
#             if p.is_dir():
#                 return p
#     else:
#         for cand in ["val", "valid", "val2012", "val2017", "validation"]:
#             p = label_root / cand
#             if p.is_dir():
#                 return p

#     # prefix search
#     if label_root.is_dir():
#         subs = [d for d in label_root.iterdir() if d.is_dir()]
#         if kind == "train":
#             cands = [d for d in subs if d.name.lower().startswith("train")]
#         else:
#             cands = [d for d in subs if (d.name.lower().startswith("val") or d.name.lower().startswith("valid"))]

#         if cands:
#             # Prefer names similar to preferred_tag, otherwise first by name sort
#             pref = preferred_tag.lower() if preferred_tag else ""
#             cands = sorted(
#                 cands,
#                 key=lambda d: (0 if pref and pref in d.name.lower() else 1, d.name)
#             )
#             return cands[0]

#     return None

# def _find_image_by_rel(img_base: Path, rel_txt: Path) -> Optional[Path]:
#     """
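#     rel_txt: relative path with .txt suffix (e.g., a/b/0001.txt)
#     Strip the suffix, then try img_base/<relative stem><ext> for each ext in _IMG_EXTS and
#     return the first path that exists; return None when no image is found.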
#     """
#     stem_rel = rel_txt.with_suffix("")  # remove .txt
#     for ext in _IMG_EXTS:
#         p = img_base / (str(stem_rel) + ext)
#         if p.exists():
#             return p
#     return None

# def _collect_pairs_from_refined_label_base(
#     lbl_base: Path,
#     img_base: Path,
# ) -> List[Tuple[Path, Path, Path]]:
#     """
#     return list of (rel_txt, img_path, lbl_path)
#     - rel_txt is relative to lbl_base and endswith .txt
#     - img_path is resolved under img_base using same rel path (suffix replaced)
#     """
#     if lbl_base is None or not lbl_base.is_dir():
#         return []
#     if img_base is None or not img_base.is_dir():
#         return []

#     out = []
#     for lbl_path in lbl_base.rglob("*.txt"):
#         if not lbl_path.is_file():
#             continue
#         rel_txt = lbl_path.relative_to(lbl_base)
#         img_path = _find_image_by_rel(img_base, rel_txt)
#         if img_path is None:
#             continue
#         out.append((rel_txt, img_path, lbl_path))
#     return out

# def sample_train_pairs(pairs: List[Tuple[Path, Path, Path]], fraction: float, seed: int) -> Tuple[List[Tuple[Path, Path, Path]], int]:
#     n_total = len(pairs)
#     if n_total == 0:
#         return [], 0
#     if fraction >= 1.0:
#         return pairs, n_total
#     n_pick = int(n_total * fraction)
#     n_pick = max(TRAIN_MIN_IMAGES, n_pick)
#     n_pick = min(n_pick, n_total)
#     rnd = random.Random(seed)
#     idxs = rnd.sample(range(n_total), k=n_pick)
#     chosen = [pairs[i] for i in idxs]
#     return chosen, n_total

# # -------------------------------------------------------------------------
# # SKU virtual split helper (fallback only)
# # -------------------------------------------------------------------------
# def sku_virtual_split_images(images_root: Path, seed: int = 42, ratio: float = 0.8) -> Tuple[List[Path], List[Path]]:
#     imgs = list_images(images_root)
#     n = len(imgs)
#     if n == 0:
#         return [], []
#     rnd = random.Random(seed)
#     idxs = list(range(n))
#     rnd.shuffle(idxs)
#     cut = int(n * ratio)
#     train_imgs = [imgs[i] for i in idxs[:cut]]
#     val_imgs   = [imgs[i] for i in idxs[cut:]]
#     return train_imgs, val_imgs

# def sample_train_images(base_train_imgs: List[Path], fraction: float, seed: int) -> Tuple[List[Path], int]:
#     n_total = len(base_train_imgs)
#     if n_total == 0:
#         return [], 0
#     if fraction >= 1.0:
#         return base_train_imgs, n_total
#     n_pick = int(n_total * fraction)
#     n_pick = max(TRAIN_MIN_IMAGES, n_pick)
#     n_pick = min(n_pick, n_total)
#     rnd = random.Random(seed)
#     chosen = rnd.sample(base_train_imgs, k=n_pick)
#     return chosen, n_total

# # -------------------------------------------------------------------------
# # Runtime View Builder
# #   - ✅ refined cases: construct train/val with only "images with refined labels" (no empty label creation ❌)
# #   - ✅ val evaluation uses original labels
# # -------------------------------------------------------------------------
# def build_runtime_view_root(
#     ds_root: Path,
#     case_label_root: Path,
#     train_fraction: float,
#     seed: int,
#     case_tag: str,
#     class_mode: str = "multiclass",
# ) -> Tuple[Path, Path, int, int, int, int, str]:
#     """
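#     Build a runtime dataset view for one (dataset, label case, class_mode) under
#     RUNTIME_VROOT_BASE/<dataset>/case__<sanitized case_tag>__<class_mode>.
#     An existing view directory is removed first (removal errors are ignored).
#     Train labels are taken from case_label_root; val labels are always taken from the
#     original ds_root/labels. A data.yaml (path/train/val/nc/names) is written into the view.
#
#     return:
#       vroot, data_yaml,
#       train_used, train_total,
#       val_used, val_total,
#       split_mode
#
#     raise:
#       RuntimeError, e.g. when a required split directory cannot be resolved or the
#       resulting train/val set is empty.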
#     """

#     assert class_mode in ("multiclass", "object_only")
#     images_root     = ds_root / "images"
#     orig_label_root = ds_root / "labels"
#     assert case_label_root.is_dir(), f"case_label_root missing: {case_label_root}"

#     split_info = get_split_info(ds_root)
#     train_img_dir = Path(split_info["train_img_dir"]) if split_info["train_img_dir"] else None
#     val_img_dir   = Path(split_info["val_img_dir"]) if split_info["val_img_dir"] else None
#     split_mode    = split_info["split_mode"]
#     train_tag     = split_info.get("train_tag", "train")
#     val_tag       = split_info.get("val_tag", "val")

#     # Include case_tag in vroot key (abbreviate with hash if name is long)
#     safe_case = case_tag.replace("/", "__").replace("\\", "__")
#     if len(safe_case) > 120:
#         safe_case = safe_case[:80] + "__" + _short_hash(safe_case, 8)

#     vroot    = RUNTIME_VROOT_BASE / ds_root.name / f"case__{safe_case}__{class_mode}"
#     v_images = vroot / "images"
#     v_labels = vroot / "labels"

#     if vroot.exists():
#         try:
#             shutil.rmtree(vroot)
#         except Exception:
#             pass

#     (v_images / "train").mkdir(parents=True, exist_ok=True)
#     (v_images / "val").mkdir(parents=True, exist_ok=True)
#     (v_labels / "train").mkdir(parents=True, exist_ok=True)
#     (v_labels / "val").mkdir(parents=True, exist_ok=True)

#     # ---------------------------------------------------------------------
#     # ✅ (1) refined cases: construct train/val from refines "only those with labels"
#     # ---------------------------------------------------------------------
#     is_refined = case_tag.startswith("refined__")

#     if is_refined:
#         # Image base selection: prefer standard split, fallback to images_root if not found
#         if train_img_dir is None or not train_img_dir.is_dir():
#             train_img_dir = images_root
#         if val_img_dir is None or not val_img_dir.is_dir():
#             val_img_dir = images_root

#         # Search refined split label base (auto train*/val*)
#         case_train_lbl_base = _pick_split_dir(case_label_root, train_tag, kind="train")
#         case_val_lbl_base   = _pick_split_dir(case_label_root, val_tag,   kind="val")

#         if case_train_lbl_base is None or case_val_lbl_base is None:
#             raise RuntimeError(
#                 f"Refined split dirs missing: train_base={case_train_lbl_base}, val_base={case_val_lbl_base} "
#                 f"(label_root={case_label_root})"
#             )

#         # Original val label base (for evaluation)
#         orig_val_lbl_base = _pick_split_dir(orig_label_root, val_tag, kind="val")
#         if orig_val_lbl_base is None:
#             # When val_tag is missing or structure differs, additionally search labels/val, labels/valid, etc.
#             orig_val_lbl_base = _pick_split_dir(orig_label_root, "", kind="val")
#         if orig_val_lbl_base is None:
#             raise RuntimeError(f"Original val labels dir not found under {orig_label_root}")

#         # Collect image/label matches based on refined label existence
#         train_pairs_all = _collect_pairs_from_refined_label_base(case_train_lbl_base, train_img_dir)
#         val_pairs_all   = _collect_pairs_from_refined_label_base(case_val_lbl_base,   val_img_dir)

#         # ✅ val must be evaluated with "original labels", so keep only those with original labels
#         val_pairs = []
#         for rel_txt, img_path, refined_lbl in val_pairs_all:
#             orig_lbl = orig_val_lbl_base / rel_txt
#             if orig_lbl.exists():
#                 val_pairs.append((rel_txt, img_path, orig_lbl))
#         val_total = len(val_pairs)

#         # train already has only those with refined labels
#         train_pairs_sampled, train_total = sample_train_pairs(train_pairs_all, train_fraction, seed)
#         train_used = len(train_pairs_sampled)

#         if train_used < max(1, TRAIN_MIN_IMAGES):
#             raise RuntimeError(
#                 f"Refined train coverage too low: train_used={train_used}, train_total={train_total} "
#                 f"(base={case_train_lbl_base})"
#             )
#         if val_total == 0:
#             raise RuntimeError(
#                 f"Refined val coverage is zero after matching original labels (base={case_val_lbl_base}, orig={orig_val_lbl_base})"
#             )

#         # ---- runtime view: train (images + refined labels)
#         for rel_txt, img_path, refined_lbl_path in train_pairs_sampled:
#             # image rel path uses same rel structure; use actual image suffix
#             rel_img = rel_txt.with_suffix(img_path.suffix)
#             dst_img = v_images / "train" / rel_img
#             try:
#                 _safe_symlink(img_path, dst_img)
#             except Exception:
#                 _link_or_copy(img_path, dst_img, prefer_symlink=False)

#             # label
#             dst_lbl = v_labels / "train" / rel_txt
#             if class_mode == "multiclass":
#                 try:
#                     _safe_symlink(refined_lbl_path, dst_lbl)
#                 except Exception:
#                     _link_or_copy(refined_lbl_path, dst_lbl, prefer_symlink=False)
#             else:
#                 rewrite_label_file_to_object_only(refined_lbl_path, dst_lbl)

#         # ---- runtime view: val (images + ORIGINAL labels)
#         val_used = 0
#         for rel_txt, img_path, orig_lbl_path in val_pairs:
#             rel_img = rel_txt.with_suffix(img_path.suffix)
#             dst_img = v_images / "val" / rel_img
#             try:
#                 _safe_symlink(img_path, dst_img)
#             except Exception:
#                 _link_or_copy(img_path, dst_img, prefer_symlink=False)

#             dst_lbl = v_labels / "val" / rel_txt
#             if class_mode == "multiclass":
#                 try:
#                     _safe_symlink(orig_lbl_path, dst_lbl)
#                 except Exception:
#                     _link_or_copy(orig_lbl_path, dst_lbl, prefer_symlink=False)
#             else:
#                 rewrite_label_file_to_object_only(orig_lbl_path, dst_lbl)
#             val_used += 1

#         # classes
#         if class_mode == "multiclass":
#             names = infer_class_names_from_labels(orig_label_root)
#             nc = len(names)
#         else:
#             names = ["object"]
#             nc = 1

#         data_yaml = vroot / "data.yaml"
#         with open(data_yaml, "w", encoding="utf-8") as f:
#             f.write(
#                 f"path: {str(vroot)}\n"
#                 f"train: images/train\n"
#                 f"val: images/val\n"
#                 f"nc: {nc}\n"
#                 f"names: {names}\n"
#             )

#         rt_n_train = len(list_images(v_images / "train"))
#         rt_n_val   = len(list_images(v_images / "val"))
#         if rt_n_train == 0:
#             raise RuntimeError(f"Runtime train images empty (refined): {v_images/'train'}")
#         if rt_n_val == 0:
#             raise RuntimeError(f"Runtime val images empty (refined): {v_images/'val'}")

#         return vroot, data_yaml, train_used, train_total, val_used, val_total, split_mode

#     # ---------------------------------------------------------------------
#     # (2) Non-refined cases (legacy logic)
#     #     - Legacy flow is kept, except that images without a label file are skipped
#     #       instead of getting an empty label file created
#     # ---------------------------------------------------------------------

#     # SKU virtual split
#     if split_mode == "sku_virtual_8_2":
#         base_train_imgs, base_val_imgs = sku_virtual_split_images(images_root, seed=seed, ratio=0.8)
#         chosen_train_imgs, n_total_train = sample_train_images(base_train_imgs, train_fraction, seed)

#         for img in chosen_train_imgs:
#             rel = img.relative_to(images_root)
#             dst = v_images / "train" / rel
#             try:
#                 _safe_symlink(img, dst)
#             except Exception:
#                 _link_or_copy(img, dst, prefer_symlink=False)

#         for img in base_val_imgs:
#             rel = img.relative_to(images_root)
#             dst = v_images / "val" / rel
#             try:
#                 _safe_symlink(img, dst)
#             except Exception:
#                 _link_or_copy(img, dst, prefer_symlink=False)

#         # labels (train: case, val: original)
#         case_train_lbl_base = case_label_root / train_tag if (case_label_root / train_tag).is_dir() else case_label_root
#         orig_val_lbl_base   = orig_label_root / val_tag if (orig_label_root / val_tag).is_dir() else orig_label_root

#         # NOTE: Older versions created an empty label file when a train label was missing; here the image
#         #       stays linked but no label file is written (see the skip below).
#         for img in chosen_train_imgs:
#             rel = img.relative_to(images_root)
#             src_lbl = case_train_lbl_base / rel.with_suffix(".txt")
#             if not src_lbl.exists():
#                 continue  # Changed to "skip if not found" for non-refined too (safe)
#             dst_lbl = v_labels / "train" / rel.with_suffix(".txt")
#             if class_mode == "multiclass":
#                 try:
#                     _safe_symlink(src_lbl, dst_lbl)
#                 except Exception:
#                     _link_or_copy(src_lbl, dst_lbl, prefer_symlink=False)
#             else:
#                 rewrite_label_file_to_object_only(src_lbl, dst_lbl)

#         val_used = 0
#         for img in base_val_imgs:
#             rel = img.relative_to(images_root)
#             src_lbl = orig_val_lbl_base / rel.with_suffix(".txt")
#             if not src_lbl.exists():
#                 continue
#             dst_lbl = v_labels / "val" / rel.with_suffix(".txt")
#             if class_mode == "multiclass":
#                 try:
#                     _safe_symlink(src_lbl, dst_lbl)
#                 except Exception:
#                     _link_or_copy(src_lbl, dst_lbl, prefer_symlink=False)
#             else:
#                 rewrite_label_file_to_object_only(src_lbl, dst_lbl)
#             val_used += 1

#         if class_mode == "multiclass":
#             names = infer_class_names_from_labels(orig_label_root)
#             nc = len(names)
#         else:
#             names = ["object"]
#             nc = 1

#         data_yaml = vroot / "data.yaml"
#         with open(data_yaml, "w", encoding="utf-8") as f:
#             f.write(
#                 f"path: {str(vroot)}\n"
#                 f"train: images/train\n"
#                 f"val: images/val\n"
#                 f"nc: {nc}\n"
#                 f"names: {names}\n"
#             )

#         train_used = len(list_images(v_images / "train"))
#         train_total = len(base_train_imgs)
#         val_total = len(base_val_imgs)

#         if train_used == 0:
#             raise RuntimeError(f"Runtime train images empty (SKU): {v_images/'train'}")
#         if val_used == 0:
#             raise RuntimeError(f"Runtime val images empty (SKU): {v_images/'val'}")

#         return vroot, data_yaml, train_used, train_total, val_used, val_total, split_mode

#     # Standard split
#     if train_img_dir is None or not Path(train_img_dir).is_dir():
#         raise RuntimeError(f"No train images dir resolved for {ds_root.name}")
#     if val_img_dir is None or not Path(val_img_dir).is_dir():
#         raise RuntimeError(f"No val images dir resolved for {ds_root.name}")

#     all_train_imgs = list_images(train_img_dir)
#     chosen_train_imgs, n_total_train = sample_train_images(all_train_imgs, train_fraction, seed)
#     if n_total_train == 0 or len(chosen_train_imgs) == 0:
#         raise RuntimeError(f"No train images for {ds_root.name}")

#     for img in chosen_train_imgs:
#         rel = img.relative_to(train_img_dir)
#         dst = v_images / "train" / rel
#         try:
#             _safe_symlink(img, dst)
#         except Exception:
#             _link_or_copy(img, dst, prefer_symlink=False)

#     _safe_copytree(Path(val_img_dir), v_images / "val")

#     # labels: train case, val original
#     case_train_lbl_base = case_label_root / train_tag if (case_label_root / train_tag).is_dir() else case_label_root

#     # train labels: skip if not found (was empty creation before, but skip for safety)
#     for img in chosen_train_imgs:
#         rel = img.relative_to(train_img_dir)
#         src_lbl = case_train_lbl_base / rel.with_suffix(".txt")
#         if not src_lbl.exists():
#             continue
#         dst_lbl = v_labels / "train" / rel.with_suffix(".txt")
#         if class_mode == "multiclass":
#             try:
#                 _safe_symlink(src_lbl, dst_lbl)
#             except Exception:
#                 _link_or_copy(src_lbl, dst_lbl, prefer_symlink=False)
#         else:
#             rewrite_label_file_to_object_only(src_lbl, dst_lbl)

#     # val labels: original
#     orig_val_lbl_base = _pick_split_dir(orig_label_root, val_tag, kind="val") or orig_label_root
#     val_used = 0
#     for src_lbl in orig_val_lbl_base.rglob("*.txt"):
#         rel = src_lbl.relative_to(orig_val_lbl_base)
#         dst_lbl = v_labels / "val" / rel
#         if class_mode == "multiclass":
#             try:
#                 _safe_symlink(src_lbl, dst_lbl)
#             except Exception:
#                 _link_or_copy(src_lbl, dst_lbl, prefer_symlink=False)
#         else:
#             rewrite_label_file_to_object_only(src_lbl, dst_lbl)
#         val_used += 1

#     if class_mode == "multiclass":
#         names = infer_class_names_from_labels(orig_label_root)
#         nc = len(names)
#     else:
#         names = ["object"]
#         nc = 1

#     data_yaml = vroot / "data.yaml"
#     with open(data_yaml, "w", encoding="utf-8") as f:
#         f.write(
#             f"path: {str(vroot)}\n"
#             f"train: images/train\n"
#             f"val: images/val\n"
#             f"nc: {nc}\n"
#             f"names: {names}\n"
#         )

#     rt_n_train = len(list_images(v_images / "train"))
#     rt_n_val   = len(list_images(v_images / "val"))
#     if rt_n_train == 0:
#         raise RuntimeError(f"Runtime train images empty: {v_images/'train'}")
#     if rt_n_val == 0:
#         raise RuntimeError(f"Runtime val images empty: {v_images/'val'}")

#     train_used = rt_n_train
#     train_total = n_total_train
#     val_total = rt_n_val

#     return vroot, data_yaml, train_used, train_total, val_used, val_total, split_mode

# # -------------------------------------------------------------------------
# # OOM-safe train wrapper
# # -------------------------------------------------------------------------
# def train_with_auto_oom(model: YOLO, data_yaml: Path, project_dir: Path, name_dir: str, model_tag: str):
#     if model_tag == "detr":
#         candidates = [(4, 640), (2, 640), (2, 512), (1, 512)]
#     else:
#         candidates = [(BATCH, IMG_SIZE)]

#     last_err = None
#     for b, sz in candidates:
#         try:
#             with suppress_output(SILENCE_ULTRA_OUTPUT):
#                 model.train(
#                     data=str(data_yaml),
#                     epochs=EPOCHS,
#                     imgsz=sz,
#                     batch=b,
#                     device=DEVICE,
#                     project=str(project_dir),
#                     name=name_dir,
#                     exist_ok=True,
#                     verbose=False,
#                     workers=NUM_WORKERS,
#                     amp=True,
#                 )
#             return True, b, sz, None
#         except RuntimeError as e:
#             msg = str(e).lower()
#             last_err = e
#             if "out of memory" in msg or "cuda out of memory" in msg:
#                 if torch.cuda.is_available():
#                     torch.cuda.empty_cache()
#                     torch.cuda.ipc_collect()
#                 continue
#             break
#         except Exception as e:
#             last_err = e
#             break
#     return False, None, None, last_err

# # -------------------------------------------------------------------------
# # roots fallback
# # -------------------------------------------------------------------------
# def discover_roots_from_datasets_root(datasets_root: Path) -> List[Path]:
#     if not datasets_root.is_dir():
#         return []
#     return sorted([p for p in datasets_root.iterdir() if p.is_dir()])

# # -------------------------------------------------------------------------
# # Train & validate loop
# # -------------------------------------------------------------------------
# set_seed(SEED)

# print("=" * 80)
# print("[TRAIN/EVAL] Start (refines-only train/val images + eval vs original labels)")
# print(f" - DATASETS_ROOT         : {DATASETS_ROOT}")
# print(f" - OUT_ROOT (det exp)    : {OUT_ROOT}")
# print(f" - REFINE_SRC_OUT_ROOT   : {REFINE_SRC_OUT_ROOT}")
# print(f" - REFINES_DIR           : {REFINES_DIR}")
# print(f" - REFINE_CASE_POLICY    : {REFINE_CASE_POLICY}")
# print(f" - TRAIN_USE_REFINED     : {TRAIN_USE_REFINED}")
# print(f" - TRAIN_USE_ORIGINAL    : {TRAIN_USE_ORIGINAL}")
# print(f" - TRAIN_USE_UNIFORM_SCALING_NOISE : {TRAIN_USE_UNIFORM_SCALING_NOISE}")
# print(f" - TRAIN_USE_BOUNDARY_JITTER_NOISE  : {TRAIN_USE_BOUNDARY_JITTER_NOISE}")
# print(f" - CLASS_MODES           : {CLASS_MODES}")
# print("=" * 80)

# summary_rows: List[Dict] = []

# # Prepare roots
# try:
#     _ = roots
# except NameError:
#     roots = discover_roots_from_datasets_root(DATASETS_ROOT)

# for ds_root in roots:
#     ds_root = Path(ds_root)

#     if TARGET_DATASETS_LOWER is not None:
#         ds_name_lower = ds_root.name.strip().lower()
#         if ds_name_lower not in TARGET_DATASETS_LOWER:
#             print(f"⏭️  Skip (not in TARGET_DATASETS): {ds_root.name}")
#             continue

#     images_root = ds_root / "images"
#     labels_root = ds_root / "labels"
#     if not images_root.is_dir() or not labels_root.is_dir():
#         print(f"⏭️  Skip (missing images/labels): {ds_root}")
#         continue

#     print(f"\n[Integrity Check] {ds_root.name}")
#     scan_and_clean_images(images_root)

#     cases = list_label_cases_for_dataset(ds_root)
#     if not cases:
#         print(f"⏭️  Skip (no target label cases after flags): {ds_root.name}")
#         continue

#     sp = get_split_info(ds_root)
#     print("\n" + "-" * 80)
#     print(f"[Dataset] {ds_root.name}")
#     print(f" - split_mode : {sp.get('split_mode')}")
#     print(f" - Cases      : {[c[0] for c in cases]}")
#     print(f" - CLASS_MODES: {CLASS_MODES}")
#     print("-" * 80)

#     for case_tag, case_label_root in cases:
#         for class_mode in CLASS_MODES:
#             vroot = None
#             try:
#                 vroot, data_yaml, tr_used, tr_total, va_used, va_total, split_mode = build_runtime_view_root(
#                     ds_root=ds_root,
#                     case_label_root=case_label_root,
#                     train_fraction=TRAIN_FRACTION,
#                     seed=SEED,
#                     case_tag=case_tag,
#                     class_mode=class_mode,
#                 )
#                 tr_pct = (tr_used / max(1, tr_total)) * 100.0
#                 print(
#                     f"  [Subset] case={case_tag} | class_mode={class_mode} | split_mode={split_mode} | "
#                     f"train_used={tr_used}/{tr_total} ({tr_pct:.1f}%) | val_used={va_used}/{va_total}"
#                 )
#             except Exception as e:
#                 print(f"  ⏭️  Skip build failed: case={case_tag} | class_mode={class_mode} | err={e}")
#                 if vroot and CLEANUP_RUNTIME_VROOT:
#                     try:
#                         shutil.rmtree(vroot)
#                     except Exception:
#                         pass
#                 continue

#             for model_tag, ckpt_candidates in MODEL_SPECS:
#                 print(f"\n  [Train] case={case_tag} | class_mode={class_mode} | model={model_tag}")

#                 try:
#                     model = choose_model(ckpt_candidates)
#                 except Exception as e:
#                     print(f"    ❌ Model load failed: {ckpt_candidates} | err={e}")
#                     continue

#                 project_dir = OUT_ROOT / ds_root.name
#                 frac_tag = f"tr{int(TRAIN_FRACTION*100)}"
#                 safe_case = case_tag.replace("/", "__").replace("\\", "__")
#                 if len(safe_case) > 120:
#                     safe_case = safe_case[:80] + "__" + _short_hash(safe_case, 8)
#                 name_dir = f"{model_tag}__{safe_case}__{class_mode}__{frac_tag}"
#                 project_dir.mkdir(parents=True, exist_ok=True)

#                 ok, used_b, used_sz, err = train_with_auto_oom(model, data_yaml, project_dir, name_dir, model_tag)
#                 if not ok:
#                     print(f"    ❌ Train failed: {err}")
#                     try:
#                         del model
#                     except Exception:
#                         pass
#                     if torch.cuda.is_available():
#                         torch.cuda.empty_cache()
#                         torch.cuda.ipc_collect()
#                     continue

#                 # ✅ val is evaluated based on data_yaml's labels/val (=configured with original labels)
#                 try:
#                     with suppress_output(SILENCE_ULTRA_OUTPUT):
#                         val_res = model.val(
#                             data=str(data_yaml),
#                             imgsz=used_sz if used_sz else IMG_SIZE,
#                             device=DEVICE,
#                             split="val",
#                             verbose=False,
#                             workers=NUM_WORKERS,
#                         )
#                 except Exception as e:
#                     print(f"    ❌ Val failed: {e}")
#                     val_res = None

#                 metrics = extract_metrics_dict(val_res)
#                 f1 = compute_f1_from_metrics(metrics)

#                 metrics_out = project_dir / name_dir / "metrics_eval.json"
#                 try:
#                     metrics_out.parent.mkdir(parents=True, exist_ok=True)
#                     with open(metrics_out, "w", encoding="utf-8") as f:
#                         json.dump(
#                             {
#                                 "dataset": ds_root.name,
#                                 "root": str(ds_root),
#                                 "case_tag": case_tag,
#                                 "case_label_root": str(case_label_root),
#                                 "class_mode": class_mode,
#                                 "model_tag": model_tag,
#                                 "ckpt_candidates": ckpt_candidates,
#                                 "train_fraction": TRAIN_FRACTION,
#                                 "train_used": tr_used,
#                                 "train_total": tr_total,
#                                 "val_used": va_used,
#                                 "val_total": va_total,
#                                 "data_yaml": str(data_yaml),
#                                 "runtime_view_root": str(vroot),
#                                 "split_mode": split_mode,
#                                 "effective_batch": used_b,
#                                 "effective_imgsz": used_sz,
#                                 "metrics": metrics,
#                                 "f1_from_pr": f1,
#                             },
#                             f,
#                             ensure_ascii=False,
#                             indent=2,
#                         )
#                     print(f"    ✅ Saved metrics: {metrics_out}")
#                 except Exception:
#                     pass

#                 row = {
#                     "dataset": ds_root.name,
#                     "model": model_tag,
#                     "case": case_tag,
#                     "case_label_root": str(case_label_root),
#                     "class_mode": class_mode,
#                     "split_mode": split_mode,
#                     "train_fraction": TRAIN_FRACTION,
#                     "train_used": tr_used,
#                     "train_total": tr_total,
#                     "val_used": va_used,
#                     "val_total": va_total,
#                     "effective_batch": used_b,
#                     "effective_imgsz": used_sz,
#                     "precision": metrics.get("metrics/precision(B)", None),
#                     "recall": metrics.get("metrics/recall(B)", None),
#                     "f1": f1,
#                     "mAP50": metrics.get("metrics/mAP50(B)", None),
#                     "mAP50-95": metrics.get("metrics/mAP50-95(B)", None),
#                 }
#                 summary_rows.append(row)

#                 try:
#                     del model
#                 except Exception:
#                     pass
#                 if torch.cuda.is_available():
#                     torch.cuda.empty_cache()
#                     torch.cuda.ipc_collect()

#             if CLEANUP_RUNTIME_VROOT:
#                 try:
#                     shutil.rmtree(vroot)
#                 except Exception:
#                     pass

# # -------------------------------------------------------------------------
# # Save summary CSV
# # -------------------------------------------------------------------------
# out_csv = OUT_ROOT / "summary_refines_vs_original.csv"
# try:
#     cols = set()
#     for r in summary_rows:
#         cols.update(r.keys())

#     base_cols = [
#         "dataset", "model", "case", "case_label_root",
#         "class_mode", "split_mode",
#         "train_fraction", "train_used", "train_total",
#         "val_used", "val_total",
#         "effective_batch", "effective_imgsz",
#         "precision", "recall", "f1", "mAP50", "mAP50-95",
#     ]
#     extra_cols = sorted([c for c in cols if c not in set(base_cols)])
#     cols = base_cols + extra_cols

#     with open(out_csv, "w", newline="", encoding="utf-8") as f:
#         w = csv.DictWriter(f, fieldnames=cols)
#         w.writeheader()
#         for r in summary_rows:
#             w.writerow(r)

#     print("\n" + "=" * 80)
#     print(f"✅ Saved summary CSV: {out_csv}")
#     print(f"✅ Total runs: {len(summary_rows)}")
#     print("=" * 80)

# except Exception as e:
#     print(f"⚠️  Summary CSV save failed: {e}")

# print("\n✅ Cell 2 done.")
