In [1]:
# ==========================================
# Cell 1) Detection datasets discovery & inspection (FINAL++)
#   - How many datasets are found
#   - train/val image count per dataset
#   - Whether label cases (original/scale/side) exist
#   - Output estimated class count/names (multiclass-based)
#
# [UPDATED++]
#   ✅ Reflect split structure rules per dataset name
#   ✅ For Cell 2 acceleration
#       - Store confirmed train/val labels dir
#       - Calculate train/val label file count, box (line) count
#       - Store n_train_groups_est (=box count)
# ========================================== 

from __future__ import annotations

import os, sys, random
from pathlib import Path
from typing import List, Tuple, Optional, Dict

# -------------------------------------------------------------------------
# 0) Register PROJECT_MODULE_DIR
#    Prepend the project module folder to sys.path (only once) so the
#    project-local modules imported below can be resolved.
# -------------------------------------------------------------------------
PROJECT_MODULE_DIR = Path("/home/ISW/project/Project_Module")
if str(PROJECT_MODULE_DIR) not in sys.path:
    sys.path.insert(0, str(PROJECT_MODULE_DIR))

# -------------------------------------------------------------------------
# 1) ultra_det_loader
#    Project-local module; discover_det_datasets(LOAD_DIR) is used below to
#    enumerate dataset specs (each spec exposes a `.root` attribute).
# -------------------------------------------------------------------------
from ultra_det_loader import discover_det_datasets

# -------------------------------------------------------------------------
# 2) noisy_insection (use only scale/boundary jitter case list)
#    Reuse the noise-case lists from that module when it is importable;
#    otherwise fall back to the hard-coded defaults below.
#    NOTE(review): the fallbacks must match the folder suffixes actually
#    written to disk (labels_uniform_scaling_<s>, labels_boundary_jitter_<k>).
# -------------------------------------------------------------------------
try:
    from noisy_insection import UNIFORM_SCALING_FACTORS, JITTER_PATTERNS
except Exception:
    UNIFORM_SCALING_FACTORS = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]
    JITTER_PATTERNS = [1, 3, 5, 7, 9]

# -------------------------------------------------------------------------
# User config
#   LOAD_DIR : root folder scanned by discover_det_datasets
#   SEED     : passed to set_seed() and to the SKU virtual-split count
# -------------------------------------------------------------------------
LOAD_DIR = "/home/ISW/project/datasets"
SEED = 42

# Image extensions (compared against the lower-cased file suffix in list_images)
_IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

def set_seed(seed: int = 42):
    """Seed Python's global ``random`` module so sampling is reproducible."""
    random.seed(a=seed)

def list_images(dir_path: Optional[Path]) -> List[Path]:
    """Return every image file below ``dir_path`` (recursive), sorted by path.

    A file counts as an image when its lower-cased suffix is in ``_IMG_EXTS``.
    ``None`` or a non-existent directory yields an empty list.
    """
    if dir_path is None:
        return []
    root = Path(dir_path)
    if not root.exists():
        return []
    return sorted(
        candidate
        for candidate in root.rglob("*")
        if candidate.is_file() and candidate.suffix.lower() in _IMG_EXTS
    )

def normalize_name(name: str) -> str:
    """Canonicalize a dataset folder name for rule matching.

    Strips surrounding whitespace, lower-cases, and maps both '_' and ' '
    to '-' (e.g. ``" Brain_Tumor "`` -> ``"brain-tumor"``).
    """
    dash_map = str.maketrans({"_": "-", " ": "-"})
    return name.strip().lower().translate(dash_map)

# -------------------------------------------------------------------------
# Legacy heuristic (fallback)
# -------------------------------------------------------------------------
def _fallback_train_dir(images_root: Path) -> Path:
    if (images_root / "train").is_dir():
        return images_root / "train"
    return images_root

def _fallback_val_dir(images_root: Path) -> Optional[Path]:
    if (images_root / "val").is_dir():
        return images_root / "val"
    if (images_root / "valid").is_dir():
        return images_root / "valid"
    return None

# -------------------------------------------------------------------------
# ✅ Dataset-specific split rules
#    Names are compared after normalize_name() (lower-case, '_'/' ' -> '-').
# -------------------------------------------------------------------------
# Datasets whose images/ folder has explicit train/ and val/ subfolders
# (detect_split_dirs reports no test split for these).
_SIMPLE_TRAIN_VAL = {
    "bccd",
    "brain-tumor",
    "custom-blood",
    "homeobjects-3k",
    "kitti",
    "medical-pills",
    "signature",
}

# Datasets whose images/ folder has explicit train/, val/ and test/ subfolders.
_TRAIN_TEST_VAL = {
    "construction-ppe",
    "african-wildlife",
}

def detect_split_dirs(ds_root: Path) -> Dict[str, Optional[Path]]:
    """
    Interpret the images/ split layout of a dataset from its folder name.

    The folder name is normalized with normalize_name() and matched against
    these rules, first match wins:
      - "voc"                           -> images/train2012 + images/val2012
      - contains "coco" or "lvis"       -> images/train2017 + val2017 (+ test2017)
      - in _SIMPLE_TRAIN_VAL            -> images/train + images/val
      - in _TRAIN_TEST_VAL              -> images/train + images/val (+ images/test)
      - contains both "sku" and "110k"  -> flat images/ for both (virtual 8:2 split)
      - anything else                   -> _fallback_train_dir / _fallback_val_dir

    Explicit-rule paths are returned without checking that they exist.

    Returns:
        {
          "train_img_dir": Path|None,
          "val_img_dir": Path|None,
          "test_img_dir": Path|None,
          "split_mode": str,  # "explicit" | "sku_virtual_8_2" | "fallback"
          "train_tag": str,
          "val_tag": str,
        }
    """
    ds_name = normalize_name(ds_root.name)
    images_root = ds_root / "images"

    def explicit(train: str, val: str, test: Optional[str] = None) -> Dict[str, Optional[Path]]:
        # Fixed subfolder layout; the tags are the subfolder names themselves.
        return dict(
            train_img_dir=images_root / train,
            val_img_dir=images_root / val,
            test_img_dir=None if test is None else images_root / test,
            split_mode="explicit",
            train_tag=train,
            val_tag=val,
        )

    if ds_name == "voc":
        return explicit("train2012", "val2012")
    if "coco" in ds_name or "lvis" in ds_name:
        return explicit("train2017", "val2017", "test2017")
    if ds_name in _SIMPLE_TRAIN_VAL:
        return explicit("train", "val")
    if ds_name in _TRAIN_TEST_VAL:
        return explicit("train", "val", "test")

    if "sku" in ds_name and "110k" in ds_name:
        # SKU-110K ships one flat images/ folder; train and val share it.
        return dict(
            train_img_dir=images_root,
            val_img_dir=images_root,
            test_img_dir=None,
            split_mode="sku_virtual_8_2",
            train_tag="virtual_8_2",
            val_tag="virtual_8_2",
        )

    train_guess = _fallback_train_dir(images_root)
    val_guess = _fallback_val_dir(images_root)
    return dict(
        train_img_dir=train_guess,
        val_img_dir=val_guess,
        test_img_dir=None,
        split_mode="fallback",
        train_tag=train_guess.name if train_guess else "unknown",
        val_tag=val_guess.name if val_guess else "missing",
    )

# -------------------------------------------------------------------------
# Class name estimation
# -------------------------------------------------------------------------
def infer_class_names_from_labels(label_root: Path, max_files: int = 2000) -> List[str]:
    """Estimate placeholder class names from YOLO label files.

    Scans up to ``max_files`` ``*.txt`` files under ``label_root`` (recursively,
    in sorted path order so the sample is the same on every run and filesystem)
    and collects the integer class id in the first column of each line.

    Returns:
        ``["class_0", ..., "class_<max_id>"]``, or ``["class_0"]`` when the
        directory is missing or empty, or no valid non-negative id is found.
    """
    if label_root is None or not Path(label_root).exists():
        return ["class_0"]

    # Sorting makes the max_files sample deterministic (rglob order is not).
    txts = sorted(Path(label_root).rglob("*.txt"))[:max_files]
    cls_ids = set()

    for t in txts:
        try:
            with open(t, "r", encoding="utf-8") as f:
                for line in f:
                    parts = line.split()
                    if not parts:
                        continue
                    try:
                        cid = int(float(parts[0]))
                    except (ValueError, OverflowError):
                        # One malformed line must not discard the rest of the file.
                        continue
                    if cid >= 0:
                        # Negative ids are invalid in YOLO and would make range() empty.
                        cls_ids.add(cid)
        except (OSError, UnicodeDecodeError):
            # Unreadable file: skip it, keep ids already collected from it.
            continue

    if not cls_ids:
        return ["class_0"]

    return [f"class_{i}" for i in range(max(cls_ids) + 1)]

# -------------------------------------------------------------------------
# Label case detection
# -------------------------------------------------------------------------
def list_label_cases_for_dataset(ds_root: Path) -> List[Tuple[str, str]]:
    """List the label-variant folders present directly under ``ds_root``.

    Candidates, in order: ``labels`` (case "original"), then
    ``labels_uniform_scaling_<s>`` for each s in UNIFORM_SCALING_FACTORS
    (case "scale_<s>"), then ``labels_boundary_jitter_<k>`` for each k in
    JITTER_PATTERNS (case "side_<k>"). Only existing directories are returned,
    as ``(case_name, folder_name)`` tuples.
    """
    candidates: List[Tuple[str, str]] = [("original", "labels")]
    candidates += [(f"scale_{s}", f"labels_uniform_scaling_{s}") for s in UNIFORM_SCALING_FACTORS]
    candidates += [(f"side_{k}", f"labels_boundary_jitter_{k}") for k in JITTER_PATTERNS]
    return [(case, folder) for case, folder in candidates if (ds_root / folder).is_dir()]

# -------------------------------------------------------------------------
# ✅ Determine train/val labels dir (used directly by Cell 2)
# -------------------------------------------------------------------------
def resolve_split_label_dirs(ds_root: Path, train_tag: str, val_tag: str) -> Tuple[Path, Path]:
    """Map train/val split tags to label directories under ``ds_root/labels``.

    Each tag resolves to ``labels/<tag>`` when that directory exists; otherwise
    it falls back to ``labels/`` itself (flat layout).

    Returns:
        ``(train_labels_dir, val_labels_dir)``.
    """
    labels_root = ds_root / "labels"

    def pick(tag: str) -> Path:
        tagged = labels_root / tag
        return tagged if tagged.is_dir() else labels_root

    return pick(train_tag), pick(val_tag)

# -------------------------------------------------------------------------
# ✅ Label statistics (quick group count estimation)
# -------------------------------------------------------------------------
def count_label_files_and_boxes(label_dir: Optional[Path]) -> Tuple[int, int]:
    """Count ``*.txt`` label files under ``label_dir`` (recursive) and their non-blank lines.

    Each non-blank line is counted as one box. A file that fails to read still
    counts as a file, and lines read before the failure stay in the box total.

    Returns:
        ``(n_files, n_boxes)``; ``(0, 0)`` when ``label_dir`` is None or missing.
    """
    if label_dir is None or not Path(label_dir).exists():
        return 0, 0
    label_files = sorted(Path(label_dir).rglob("*.txt"))
    boxes = 0
    for label_file in label_files:
        try:
            with open(label_file, "r", encoding="utf-8") as handle:
                for raw in handle:
                    if not raw.strip():
                        continue
                    boxes += 1
        except Exception:
            continue
    return len(label_files), boxes

# -------------------------------------------------------------------------
# SKU-110K virtual split count
# -------------------------------------------------------------------------
def compute_sku_virtual_counts(images_root: Path, seed: int = 42, ratio: float = 0.8) -> Tuple[int, int]:
    """Return ``(n_train, n_val)`` image counts for a virtual split of a flat folder.

    The counts depend only on the number of images ``n`` found by list_images
    and on ``ratio``: ``n_train = int(n * ratio)`` and ``n_val = n - n_train``.
    Which individual images land in which side is not decided here, so no
    shuffling is performed; ``seed`` is kept only for call compatibility and
    does not affect the result.

    Raises:
        ValueError: if ``ratio`` is outside ``[0, 1]`` (it would yield a
            negative split size).
    """
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"ratio must be in [0, 1], got {ratio!r}")
    n = len(list_images(images_root))
    n_train = int(n * ratio)
    return n_train, n - n_train

# -------------------------------------------------------------------------
# Discover dataset roots
# -------------------------------------------------------------------------
# Seed Python's `random` module (set_seed only seeds `random`).
set_seed(SEED)

# Ask the project loader for dataset specs and keep each distinct `.root`
# once, in first-seen order (several specs may share a root —
# NOTE(review): presumably one spec per split/variant; confirm in ultra_det_loader).
# The loop variables `s` / `r` remain as notebook globals.
specs = discover_det_datasets(LOAD_DIR)
roots: List[Path] = []
for s in specs:
    r = Path(s.root)
    if r not in roots:
        roots.append(r)

print("=" * 80)
print(f"[DISCOVERY] Found {len(roots)} unique dataset roots under: {Path(LOAD_DIR).resolve()}")
print("=" * 80)

# -------------------------------------------------------------------------
# Per-dataset summary
# -------------------------------------------------------------------------
dataset_summaries: List[Dict] = []

# Build one summary dict per dataset root (consumed by later cells) and print it.
for ds_root in roots:
    ds_root = Path(ds_root)
    images_root = ds_root / "images"
    labels_root = ds_root / "labels"

    if not images_root.is_dir() or not labels_root.is_dir():
        print(f"⏭️  Skip (missing images/labels): {ds_root}")
        continue

    split_info = detect_split_dirs(ds_root)
    train_dir = split_info["train_img_dir"]
    val_dir   = split_info["val_img_dir"]
    split_mode = split_info["split_mode"]
    train_tag  = split_info.get("train_tag", "train")
    val_tag    = split_info.get("val_tag", "val")

    # --- Calculate image count ---
    if split_mode == "sku_virtual_8_2":
        # Flat images/ folder: only the sizes of the 8:2 virtual split are known here.
        n_train, n_val = compute_sku_virtual_counts(images_root, seed=SEED, ratio=0.8)
    else:
        n_train = len(list_images(train_dir))
        n_val   = len(list_images(val_dir)) if val_dir else 0

    # --- Confirm labels split dir ---
    train_labels_dir, val_labels_dir = resolve_split_label_dirs(ds_root, train_tag, val_tag)
    if val_labels_dir == labels_root and val_dir != images_root:
        # The val labels fell back to the whole labels/ root although val images
        # are not stored flat in images/ (val split missing, or no labels/<val_tag>).
        # That root also holds the train labels, so counting it would report train
        # labels as val labels. Report no val label dir instead.
        val_labels_dir = None

    # --- Label statistics ---
    n_train_label_files, n_train_boxes = count_label_files_and_boxes(train_labels_dir)
    if val_labels_dir is not None and val_labels_dir == train_labels_dir:
        # Same directory (e.g. SKU-110K virtual split): reuse the counts instead of
        # re-reading every label file. These numbers cover the whole directory,
        # not one side of the split.
        n_val_label_files, n_val_boxes = n_train_label_files, n_train_boxes
    else:
        n_val_label_files, n_val_boxes = count_label_files_and_boxes(val_labels_dir)

    # group count estimate (for practical absn conversion): one group per train box
    n_train_groups_est = n_train_boxes

    cases = list_label_cases_for_dataset(ds_root)
    class_names = infer_class_names_from_labels(labels_root)
    nc = len(class_names)

    info = {
        "dataset": ds_root.name,
        "root": str(ds_root),
        "images_root": str(images_root),
        "labels_root": str(labels_root),

        "train_dir": str(train_dir) if train_dir else None,
        "val_dir": str(val_dir) if val_dir else None,

        # ✅ Used directly by Cell 2
        "train_labels_dir": str(train_labels_dir) if train_labels_dir else None,
        "val_labels_dir": str(val_labels_dir) if val_labels_dir else None,

        "n_train": n_train,
        "n_val": n_val,

        # ✅ Label statistics
        "n_train_label_files": n_train_label_files,
        "n_val_label_files": n_val_label_files,
        "n_train_boxes": n_train_boxes,
        "n_val_boxes": n_val_boxes,
        "n_train_groups_est": n_train_groups_est,

        "split_mode": split_mode,
        "train_tag": train_tag,
        "val_tag": val_tag,
        "label_cases": [c[0] for c in cases],
        "nc_inferred": nc,
        "class_names_inferred": class_names,
    }
    dataset_summaries.append(info)

    print("\n" + "-" * 80)
    print(f"[Dataset] {ds_root.name}")
    print(f" - root        : {ds_root}")
    print(f" - split_mode  : {split_mode}")
    print(f" - train_dir   : {train_dir if train_dir else '(missing)'} | tag={train_tag} | n_train={n_train}")
    print(f" - val_dir     : {val_dir if val_dir else '(missing)'} | tag={val_tag} | n_val={n_val}")

    test_dir = split_info.get("test_img_dir", None)
    if test_dir and test_dir.is_dir():
        n_test = len(list_images(test_dir))
        print(f" - test_dir    : {test_dir} | n_test={n_test}")

    print(f" - train_labels_dir : {train_labels_dir}")
    print(f" - val_labels_dir   : {val_labels_dir if val_labels_dir else '(missing)'}")
    print(f" - train label files/boxes/groups_est : {n_train_label_files} / {n_train_boxes} / {n_train_groups_est}")
    print(f" - val   label files/boxes           : {n_val_label_files} / {n_val_boxes}")

    print(f" - label_cases : {[c[0] for c in cases] if cases else '(none)'}")
    print(f" - inferred classes (multiclass-based): nc={nc}, names={class_names}")
    print("-" * 80)

print("\n✅ Cell 1 done.")
print(f"   -> dataset_summaries length = {len(dataset_summaries)}")
print("   -> roots variable is ready for Cell 2.")


[DISCOVERY] Found 13 unique dataset roots under: /home/ISW/project/datasets

--------------------------------------------------------------------------------
[Dataset] SKU-110K
 - root        : /home/ISW/project/datasets/SKU-110K
 - split_mode  : sku_virtual_8_2
 - train_dir   : /home/ISW/project/datasets/SKU-110K/images | tag=virtual_8_2 | n_train=9394
 - val_dir     : /home/ISW/project/datasets/SKU-110K/images | tag=virtual_8_2 | n_val=2349
 - train_labels_dir : /home/ISW/project/datasets/SKU-110K/labels
 - val_labels_dir   : /home/ISW/project/datasets/SKU-110K/labels
 - train label files/boxes/groups_est : 11743 / 1730996 / 1730996
 - val   label files/boxes           : 11743 / 1730996
 - label_cases : ['original', 'scale_0.6', 'scale_0.7', 'scale_0.8', 'scale_0.9', 'scale_1.1', 'scale_1.2', 'scale_1.3', 'scale_1.4', 'side_1', 'side_3', 'side_5', 'side_7', 'side_9']
 - inferred classes (multiclass-based): nc=1, names=['class_0']
------------------------------------------------------

# Load SAM model weights & perform refinement

In [7]:
###### Not refining the full dataset with SAM: sampling a subset to verify that refinement works for each noise type ######

# ==========================================
# Cell 2) SAM-based Refinement Runner — FINAL (RESUME + PROGRESS + INCREMENTAL METRICS)
#  - Refine YOLO boxes using SAM segmentation (box prompt)
#  - ONLY covers label_cases: labels_uniform_scaling_* + labels_boundary_jitter_*  (original labels/ is OFF by default)
#  - Save refined labels UNDER OUT_ROOT/refines/<dataset>/<case_id>/<noise_name>/<split>/.../*.txt
#  - Evaluate IoU vs ORIGINAL clean labels (/datasets/<ds>/labels/...) for files/boxes that exist
#  - Do NOT create empty labels for missing files (skip if missing)
#
# [NEW: RESUME SAFE]
#  - If out_lbl_dir has _DONE.json and it's complete => skip that noise/split
#  - If interrupted mid-way (no _DONE.json) => resume by skipping existing out txt files
#  - Write _PROGRESS.json every N files to track where it stopped
#  - Metrics CSV is appended per split (survives crashes), with duplicate-key guard
# ==========================================

from __future__ import annotations
import os, re, json, gc, hashlib, sys, zipfile, shutil, time, csv
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import torch
from PIL import Image

# -------------------------
# USER CONFIG
# -------------------------
# Root searched by resolve_dataset_root (exact folder name first, then case-insensitive).
DATASETS_ROOT = Path("/home/ISW/project/datasets")

# Relative path: resolved against the notebook's current working directory.
OUT_ROOT = Path("./SAM_refine_final")  # Desired experiment root (refines/ save location)
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Torch device string for SAM ("cuda:1" = second GPU; change if unavailable).
DEVICE = "cuda:1"

# Dataset folder names to process (matched via resolve_dataset_root, so case may
# differ from disk). NOTE(review): Optional suggests None means "all" — confirm in the runner loop.
TARGET_DATASETS: Optional[List[str]] = [
    "kitti",
    "homeobjects-3K",
    "african-wildlife",
    "construction-ppe",
    # "Custom_Blood",
    "brain-tumor",
    "BCCD",
    "signature",
    "medical-pills",
    "VOC",
]

# Target noise label folders for refinement.
# These redefine the Cell 1 values; iter_label_case_dirs formats them as
# labels_uniform_scaling_{s:g} and labels_boundary_jitter_{int(k)}.
UNIFORM_SCALING_FACTORS   = [0.6, 0.7, 0.8, 0.9, 1.1, 1.2, 1.3, 1.4]
JITTER_PATTERNS  = [1, 3, 5, 7, 9]

# Set True to refine original labels/ as well
INCLUDE_ORIGINAL_LABELS = False

# Split selection: case-insensitive prefix match in _split_selected, so "val"
# also selects "val2012" and "valid".
REFINE_SPLITS = ["train", "val", "valid"]   # includes train/val/valid prefix
MAX_FILES_PER_SPLIT: Optional[int] = None   # None=all (recommend limiting for large datasets)
# Rewrite refined label files even if they already exist.
OVERWRITE_EXISTING = False
# Clean labels are only loaded for IoU evaluation when this is True.
WRITE_METRICS_CSV = True

# -------------------------
# RESUME / PROGRESS (NEW)
# -------------------------
RESUME_MODE = True                 # If True, skip files whose out_lbl_path already exists (resume execution)
VERIFY_EXISTING_OUT_LABELS = False # If True, regenerate an existing out_lbl_path that fails YOLO validation (corrupt file)
PROGRESS_EVERY = 50                # Update _PROGRESS.json every N files processed
WRITE_METRICS_INCREMENTAL = True   # Append to metrics CSV after each split completes (rows survive a mid-run stop)
METRICS_CSV_PATH = OUT_ROOT / "_refine_summary_metrics.csv"
METRICS_KEY_COLS = ("dataset", "case_id", "noise_name", "split")  # Columns forming the duplicate-row key (load_existing_metric_keys)

# SAM config
SAM_MODEL_TYPE = "vit_h"  # Key into sam_model_registry: "vit_h" | "vit_l" | "vit_b"
SAM_CKPT_PATH  = "/home/ISW/project/checkpoints/sam_vit_h_4b8939.pth"  # <-- Modify to your own path

# refinement knobs (all hashed into make_sam_case_id, so changing one changes the case_id)
BOX_EXPAND_RATIO = 0.0     # Fraction of box w/h added on each side of the SAM box prompt (0 = none)
MIN_BOX_WH_PIX   = 2       # Boxes narrower/shorter than this (pixels) keep their noisy coordinates
MASK_TO_BOX_PAD  = 0       # Padding pixels for mask->bbox conversion (recommend 0)
USE_MULTIMASK    = True    # Passed as multimask_output to SamPredictor.predict
PICK_BY          = "score" # "score" | "area" (multimask selection criterion)

REFINES_OUT_ROOT = OUT_ROOT / "refines"
REFINES_OUT_ROOT.mkdir(parents=True, exist_ok=True)

# Image suffixes tried (in this order) when pairing a label file with its image.
IMG_EXTS = [".jpg",".jpeg",".png",".bmp",".tif",".tiff",".webp"]
# One YOLO row: integer class id + 4 numeric fields. The numeric character class is
# permissive (it also matches tokens like "." alone), so a match is not a guarantee
# that float() will succeed.
YOLO_RE = re.compile(r"^\s*(\d+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)\s+([\d\.eE+-]+)")

# -------------------------
# EASY SAM LIB LOADER (NO GIT REQUIRED)
#   ensure_segment_anything_importable() tries, in order: an installed package,
#   LOCAL_SAM_DIR_CANDIDATES, LOCAL_SAM_ZIP_CANDIDATES, then (if enabled) a
#   download of SAM_GITHUB_ZIP_URL into SAM_DEPS_DIR.
# -------------------------
SAM_DEPS_DIR = OUT_ROOT / "_deps"
SAM_DEPS_DIR.mkdir(parents=True, exist_ok=True)

# Zip archives looked up relative to the current working directory.
LOCAL_SAM_ZIP_CANDIDATES = [
    Path("segment-anything-main.zip"),
    Path("segment-anything.zip"),
]

# Repo folders that must contain a segment_anything/ package directory.
LOCAL_SAM_DIR_CANDIDATES = [
    Path("segment-anything-main"),
    Path("segment-anything"),
    SAM_DEPS_DIR / "segment-anything-main",
]

AUTO_DOWNLOAD_SAM_ZIP = True
SAM_GITHUB_ZIP_URL = "https://github.com/facebookresearch/segment-anything/archive/refs/heads/main.zip"

def _add_syspath(p: Path):
    p = p.resolve()
    if str(p) not in sys.path:
        sys.path.insert(0, str(p))

def _extract_zip(zip_path: Path, out_dir: Path) -> Path:
    out_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as zf:
        zf.extractall(out_dir)

    cand = []
    for p in out_dir.rglob("segment_anything"):
        if p.is_dir():
            cand.append(p.parent)
    if not cand:
        raise RuntimeError(f"Zip extracted but cannot find 'segment_anything/' inside: {zip_path}")

    repo_root = sorted(cand, key=lambda x: len(str(x)))[0]
    return repo_root

def ensure_segment_anything_importable() -> None:
    """Make ``import segment_anything`` succeed, trying progressively heavier options.

    Order:
      1. already importable (installed or on sys.path) -> return immediately
      2. a repo folder from LOCAL_SAM_DIR_CANDIDATES containing segment_anything/
      3. a zip from LOCAL_SAM_ZIP_CANDIDATES, extracted into SAM_DEPS_DIR
      4. if AUTO_DOWNLOAD_SAM_ZIP, download SAM_GITHUB_ZIP_URL into SAM_DEPS_DIR
         and extract it

    The download is written to a ``.part`` file and renamed only when complete,
    so an interrupted download cannot leave a truncated zip that later runs
    would keep reusing. A cached zip that turns out to be corrupt is deleted so
    the next run downloads it again.

    Raises:
        RuntimeError: if none of the options yields an importable package
            (the underlying error is chained as ``__cause__``).
    """
    import importlib

    try:
        import segment_anything  # noqa
        return
    except Exception:
        pass

    for d in LOCAL_SAM_DIR_CANDIDATES:
        if d.is_dir() and (d / "segment_anything").is_dir():
            _add_syspath(d)
            # Path finders cache directory listings; refresh them for new entries.
            importlib.invalidate_caches()
            try:
                import segment_anything  # noqa
                print(f"[SAM] using local repo dir: {d}")
                return
            except Exception:
                continue

    for z in LOCAL_SAM_ZIP_CANDIDATES:
        if z.is_file():
            try:
                repo_root = _extract_zip(z, SAM_DEPS_DIR)
                _add_syspath(repo_root)
                importlib.invalidate_caches()
                import segment_anything  # noqa
                print(f"[SAM] extracted from local zip: {z} -> {repo_root}")
                return
            except Exception as e:
                raise RuntimeError(f"[SAM] found local zip but failed to use it: {z}\nReason: {e}") from e

    if AUTO_DOWNLOAD_SAM_ZIP:
        zip_path = SAM_DEPS_DIR / "segment-anything-main.zip"
        try:
            import urllib.request
            if not zip_path.exists():
                print(f"[SAM] downloading zip: {SAM_GITHUB_ZIP_URL}")
                part_path = zip_path.with_name(zip_path.name + ".part")
                urllib.request.urlretrieve(SAM_GITHUB_ZIP_URL, str(part_path))
                # Publish the zip only after the download finished.
                os.replace(part_path, zip_path)
                print(f"[SAM] downloaded: {zip_path}")

            repo_root = _extract_zip(zip_path, SAM_DEPS_DIR)
            _add_syspath(repo_root)
            importlib.invalidate_caches()
            import segment_anything  # noqa
            print(f"[SAM] extracted from downloaded zip -> {repo_root}")
            return
        except Exception as e:
            if isinstance(e, zipfile.BadZipFile):
                # Corrupt cached archive: remove it so the next run re-downloads.
                try:
                    zip_path.unlink()
                except OSError:
                    pass
            raise RuntimeError(
                "[SAM] segment_anything import failed + auto download/install also failed.\n"
                "Alternatives:\n"
                "  (A) Download segment-anything-main.zip directly in internet-enabled environment and place in current folder\n"
                "  (B) Extract zip and place segment-anything-main/ folder in current folder\n"
                "  (C) Or request git installation from administrator\n"
                f"Cause: {e}"
            ) from e

    raise RuntimeError(
        "[SAM] segment_anything import failed.\n"
        "Solution:\n"
        "  - Place segment-anything-main/ folder (containing segment_anything/) in current directory, or\n"
        "  - Place segment-anything-main.zip in current directory and run again.\n"
        "  - Or set AUTO_DOWNLOAD_SAM_ZIP=True."
    )

# Make segment_anything importable first (may extract or download it; see
# ensure_segment_anything_importable), then import the SAM entry points.
ensure_segment_anything_importable()
from segment_anything import sam_model_registry, SamPredictor  # noqa: E402

# -------------------------
# Utils: dataset layout (robust)
# -------------------------
def resolve_dataset_root(dataset_name: str, root: Optional[Path] = None) -> Optional[Path]:
    """Locate the folder for ``dataset_name`` under ``root`` (default: DATASETS_ROOT).

    Tries the exact name first, then a case-insensitive match among the
    immediate subdirectories (scanned in sorted order for determinism). That
    scan already covers case variants such as "homeobjects-3K" vs
    "homeobjects-3k", so no separate "3K"/"3k" rewrite is needed.

    Returns:
        The matching path, or None when nothing matches or ``root`` does not exist.
    """
    base = DATASETS_ROOT if root is None else Path(root)
    exact = base / dataset_name
    if exact.exists():
        return exact
    if not base.is_dir():
        # iterdir() would raise on a missing root; report "not found" instead.
        return None
    wanted = dataset_name.lower()
    for d in sorted(base.iterdir()):
        if d.is_dir() and d.name.lower() == wanted:
            return d
    return None

def list_existing_splits(images_root: Path) -> List[str]:
    """Return the sorted names of the immediate subfolders of ``images_root``.

    A missing folder, or one with no subfolders, is reported as the single
    pseudo-split ``"__flat__"``.
    """
    if images_root.exists():
        split_names = sorted(child.name for child in images_root.iterdir() if child.is_dir())
    else:
        split_names = []
    return split_names or ["__flat__"]

def _split_selected(split: str) -> bool:
    if split == "__flat__":
        return True
    if not REFINE_SPLITS:
        return True
    s = split.lower()
    for t in REFINE_SPLITS:
        tl = str(t).lower()
        if s == tl or s.startswith(tl):
            return True
    if s.startswith("val") and any(str(t).lower()=="val" for t in REFINE_SPLITS):
        return True
    if s.startswith("valid") and any(str(t).lower()=="valid" for t in REFINE_SPLITS):
        return True
    return False

def resolve_split_dir(base: Path, split: str) -> Path:
    """Map a split name to its folder under ``base``; ``"__flat__"`` means ``base`` itself."""
    if split == "__flat__":
        return base
    return base / split

def find_image_for_label(images_dir: Path, rel_lbl_path: Path, exts: Optional[List[str]] = None) -> Optional[Path]:
    """Find the image that pairs with a label file.

    The label's relative path without its suffix is joined to ``images_dir``
    and each image extension is tried. All extensions are tried as given
    first (so existing matches are unchanged), then their upper-case forms
    (``.JPG``, ``.PNG``, ...), which would otherwise be missed on
    case-sensitive filesystems.

    Args:
        images_dir: image split folder.
        rel_lbl_path: label path relative to the label split folder.
        exts: extensions to try; defaults to the module-level IMG_EXTS.

    Returns:
        The first existing image path, or None.
    """
    stem = rel_lbl_path.with_suffix("").as_posix()
    base_exts = list(IMG_EXTS if exts is None else exts)
    upper_exts = [e.upper() for e in base_exts if e.upper() != e]
    for ext in base_exts + upper_exts:
        cand = images_dir / f"{stem}{ext}"
        if cand.exists():
            return cand
    return None

def read_yolo_txt(p: Optional[Path]) -> List[Tuple[int,float,float,float,float]]:
    """Parse a YOLO label file into ``(cls, cx, cy, w, h)`` tuples.

    Lines that do not match YOLO_RE are ignored. So are lines that match but
    whose numeric fields are not valid floats. A missing file (or ``p=None``)
    yields an empty list, and undecodable bytes are ignored.
    """
    rows: List[Tuple[int, float, float, float, float]] = []
    if p is None or not p.exists():
        return rows
    for ln in p.read_text(encoding="utf-8", errors="ignore").splitlines():
        m = YOLO_RE.match(ln)
        if not m:
            continue
        try:
            cx, cy, w, h = map(float, m.groups()[1:])
        except ValueError:
            # YOLO_RE's numeric class [\d.eE+-]+ also matches tokens such as "."
            # or "1e" that float() rejects; skip the line rather than crash the run.
            continue
        # Group 1 is \d+, so int() parses it directly.
        rows.append((int(m.group(1)), cx, cy, w, h))
    return rows

def write_yolo_txt(p: Path, rows: List[Tuple[int,float,float,float,float]]):
    """Write YOLO rows to ``p`` atomically, clamping each field into [0, 1].

    The text goes to ``<p>.tmp`` first and is moved over ``p`` with
    ``os.replace`` only after it is fully written. Resume logic treats an
    existing output file as done, so a crash or interrupt mid-write must never
    leave a truncated ``p`` behind. On any failure the temp file is removed and
    the exception re-raised. Parent folders are created as needed.
    """
    p.parent.mkdir(parents=True, exist_ok=True)
    lines = []
    for cls, cx, cy, w, h in rows:
        cx = min(max(cx, 0.0), 1.0)
        cy = min(max(cy, 0.0), 1.0)
        w  = min(max(w,  0.0), 1.0)
        h  = min(max(h,  0.0), 1.0)
        lines.append(f"{cls} {cx:.6f} {cy:.6f} {w:.6f} {h:.6f}\n")

    tmp = p.with_name(p.name + ".tmp")
    try:
        with open(tmp, "w", encoding="utf-8") as f:
            f.writelines(lines)
        os.replace(tmp, p)
    except BaseException:
        # Also covers KeyboardInterrupt (notebook "stop"), the typical resume trigger.
        try:
            tmp.unlink()
        except OSError:
            pass
        raise

def is_valid_yolo_file(p: Path) -> bool:
    """Return True if ``p`` exists and every non-blank line is a parseable YOLO row.

    An existing empty file is valid. A line must match YOLO_RE and its four
    numeric fields must also convert with ``float()``, since the regex alone
    accepts tokens such as "." that float() rejects. Unreadable files are
    reported as invalid.
    """
    try:
        if not p.exists():
            return False
        text = p.read_text(encoding="utf-8", errors="ignore")
    except OSError:
        return False
    for ln in text.splitlines():
        ln = ln.strip()
        if not ln:
            continue
        m = YOLO_RE.match(ln)
        if m is None:
            return False
        try:
            for field in m.groups()[1:]:
                float(field)
        except ValueError:
            return False
    return True

# -------------------------
# Metrics: IoU (normalized cxcywh)
# -------------------------
def _to_xyxy(cx, cy, w, h):
    return cx - w/2, cy - h/2, cx + w/2, cy + h/2

def box_iou(a, b, eps=1e-9) -> float:
    """Intersection-over-union of two normalized ``(cx, cy, w, h)`` boxes.

    Boxes with negative width/height contribute zero area. ``eps`` is added
    to the union so two degenerate boxes give 0.0 instead of dividing by zero.
    """
    def corners(cx, cy, w, h):
        return cx - w/2, cy - h/2, cx + w/2, cy + h/2

    ax1, ay1, ax2, ay2 = corners(*a)
    bx1, by1, bx2, by2 = corners(*b)

    overlap_w = max(0.0, min(ax2, bx2) - max(ax1, bx1))
    overlap_h = max(0.0, min(ay2, by2) - max(ay1, by1))
    inter = overlap_w * overlap_h

    area_a = max(0.0, ax2 - ax1) * max(0.0, ay2 - ay1)
    area_b = max(0.0, bx2 - bx1) * max(0.0, by2 - by1)
    return float(inter / (area_a + area_b - inter + eps))

# -------------------------
# label cases: labels_uniform_scaling_* + labels_boundary_jitter_* (and optionally labels/)
# -------------------------
def iter_label_case_dirs(dataset_root: Path) -> List[Tuple[str, Path]]:
    """List the noisy-label folders of a dataset that should be refined.

    Candidates, in order: ``labels`` (only when INCLUDE_ORIGINAL_LABELS),
    ``labels_uniform_scaling_{s:g}`` for each s in UNIFORM_SCALING_FACTORS,
    then ``labels_boundary_jitter_{int(k)}`` for each k in JITTER_PATTERNS.
    Only existing directories are returned, as ``(folder_name, path)`` tuples.
    """
    folder_names: List[str] = ["labels"] if INCLUDE_ORIGINAL_LABELS else []
    folder_names.extend(f"labels_uniform_scaling_{s:g}" for s in UNIFORM_SCALING_FACTORS)
    folder_names.extend(f"labels_boundary_jitter_{int(k)}" for k in JITTER_PATTERNS)
    return [
        (folder, dataset_root / folder)
        for folder in folder_names
        if (dataset_root / folder).is_dir()
    ]

# -------------------------
# case_id for SAM run (no best.pt)
# -------------------------
def make_sam_case_id() -> str:
    """Build a folder-safe id naming this SAM refinement configuration.

    The id has a readable part (model type, box expansion, multimask flag,
    pick rule) plus the first 10 hex chars of an MD5 over all refinement knobs
    (serialized as sorted-key JSON). Any knob change therefore yields a new id.
    """
    settings = dict(
        sam_type=SAM_MODEL_TYPE,
        box_expand=BOX_EXPAND_RATIO,
        multimask=bool(USE_MULTIMASK),
        pick_by=PICK_BY,
        mask_to_box_pad=MASK_TO_BOX_PAD,
        min_box_wh_pix=MIN_BOX_WH_PIX,
    )
    serialized = json.dumps(settings, sort_keys=True).encode("utf-8")
    short_hash = hashlib.md5(serialized).hexdigest()[:10]
    readable = f"SAM_{SAM_MODEL_TYPE}__bexp{BOX_EXPAND_RATIO:g}__mm{int(USE_MULTIMASK)}__pick{PICK_BY}"
    return f"{readable}__{short_hash}"

# -------------------------
# SAM helper: yolo->pixel, expand, mask->box
# -------------------------
def yolo_to_xyxy_pix(cx, cy, w, h, W, H) -> Tuple[int,int,int,int]:
    """Convert a normalized YOLO box to integer pixel corners inside a W x H image.

    Corners are rounded with ``round``. x1/y1 are clamped to ``[0, W-1]`` /
    ``[0, H-1]`` and x2/y2 to ``[1, W]`` / ``[1, H]``.
    """
    def clamp(value, low, high):
        return max(low, min(high, value))

    left   = clamp(int(round((cx - w/2) * W)), 0, W - 1)
    top    = clamp(int(round((cy - h/2) * H)), 0, H - 1)
    right  = clamp(int(round((cx + w/2) * W)), 1, W)
    bottom = clamp(int(round((cy + h/2) * H)), 1, H)
    return left, top, right, bottom

def expand_xyxy(x1,y1,x2,y2,W,H, ratio: float) -> Tuple[int,int,int,int]:
    """Grow a pixel box on every side by ``ratio`` times its width/height.

    The result is clipped to the image and kept at least 1 px wide and tall.
    ``ratio <= 0`` returns the box unchanged.
    """
    if ratio <= 0:
        return x1, y1, x2, y2
    grow_x = int(round((x2 - x1) * ratio))
    grow_y = int(round((y2 - y1) * ratio))
    left = max(0, x1 - grow_x)
    top = max(0, y1 - grow_y)
    right = max(min(W, x2 + grow_x), left + 1)
    bottom = max(min(H, y2 + grow_y), top + 1)
    return left, top, right, bottom

def xyxy_pix_to_yolo(x1,y1,x2,y2,W,H) -> Tuple[float,float,float,float]:
    """Convert pixel corners to a normalized YOLO ``(cx, cy, w, h)`` box.

    Corners are clamped first: x1/y1 into ``[0, W-1]`` / ``[0, H-1]``,
    x2/y2 into ``[1, W]`` / ``[1, H]``.
    """
    img_w, img_h = float(W), float(H)
    left, top = max(0, min(W-1, x1)), max(0, min(H-1, y1))
    right, bottom = max(1, min(W, x2)), max(1, min(H, y2))
    center_x = ((left + right) / 2.0) / img_w
    center_y = ((top + bottom) / 2.0) / img_h
    width = (right - left) / img_w
    height = (bottom - top) / img_h
    return float(center_x), float(center_y), float(width), float(height)

def mask_to_xyxy(mask: np.ndarray, pad: int = 0) -> Optional[Tuple[int,int,int,int]]:
    """Tight pixel bbox ``(x1, y1, x2, y2)`` of a 2-D mask's truthy pixels, grown by ``pad``.

    x2/y2 are exclusive (last index + 1). The result is not clipped to the
    image. Returns None for an all-false mask.

    Uses per-row/per-column ``any`` reductions instead of ``np.where``, which
    would materialize coordinate arrays for every foreground pixel of a
    full-resolution SAM mask.

    Raises:
        ValueError: if ``mask`` is not 2-D.
    """
    m = np.asarray(mask).astype(bool, copy=False)
    if m.ndim != 2:
        raise ValueError(f"mask_to_xyxy expects a 2-D mask, got shape {m.shape}")
    cols = np.flatnonzero(m.any(axis=0))
    if cols.size == 0:
        return None
    rows = np.flatnonzero(m.any(axis=1))
    x1 = int(cols[0]) - pad
    x2 = int(cols[-1]) + 1 + pad
    y1 = int(rows[0]) - pad
    y2 = int(rows[-1]) + 1 + pad
    return x1, y1, x2, y2

# -------------------------
# Resume helpers (NEW)
# -------------------------
def _safe_read_json(p: Path) -> Optional[Dict[str, Any]]:
    try:
        if p.exists():
            return json.loads(p.read_text(encoding="utf-8"))
    except Exception:
        return None
    return None

def _done_marker_seems_complete(done_payload: Optional[Dict[str, Any]]) -> bool:
    if not done_payload:
        return False
    ft = int(done_payload.get("files_total", 0) or 0)
    fd = int(done_payload.get("files_done", 0) or 0)
    fs = int(done_payload.get("files_skipped", 0) or 0)
    # At minimum, if total exists and done+skipped >= total, consider complete
    return (ft > 0) and ((fd + fs) >= ft)

def load_existing_metric_keys(csv_path: Path, key_cols: Optional[Tuple[str, ...]] = None) -> set:
    """Collect the duplicate-guard keys already present in the metrics CSV.

    Each key is a tuple of the row's values in ``key_cols`` (default:
    METRICS_KEY_COLS). Values are read as plain strings with the ``csv``
    module. pandas type inference could turn fields such as "007" into 7, or
    empty cells into NaN, so the keys would never compare equal to the string
    keys they are checked against.

    Returns:
        The set of key tuples. Returns an empty set if the file is missing,
        lacks a key column, or cannot be parsed (the duplicate check is then
        skipped and duplicates are possible, but the run continues).
    """
    cols = tuple(METRICS_KEY_COLS if key_cols is None else key_cols)
    keys: set = set()
    if not csv_path.exists():
        return keys
    try:
        with open(csv_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            if reader.fieldnames is None or any(c not in reader.fieldnames for c in cols):
                return set()
            for row in reader:
                keys.add(tuple(row[c] for c in cols))
    except (OSError, UnicodeDecodeError, csv.Error):
        return set()
    return keys

def append_metrics_row_csv(csv_path: Path, row: Dict[str, Any]):
    """Append one metrics row to ``csv_path``, keeping columns aligned with the header.

    - Missing or empty file: the header is written from ``row``'s key order.
    - Existing header: the row is written in that header's column order.
      Absent fields become "" and keys with no column are dropped with a
      warning, instead of being written misaligned.
    - Unreadable header: the row's own key order is used and no header is added.
    Parent folders are created as needed.
    """
    csv_path.parent.mkdir(parents=True, exist_ok=True)
    has_content = csv_path.exists() and csv_path.stat().st_size > 0

    header: Optional[List[str]] = None
    if has_content:
        try:
            with open(csv_path, "r", newline="", encoding="utf-8") as f:
                header = next(csv.reader(f), None) or None
        except (OSError, UnicodeDecodeError, csv.Error):
            header = None

    fieldnames = header if header else list(row.keys())
    extra = [k for k in row if k not in fieldnames]
    if extra:
        print(f"[WARN] metrics CSV {csv_path} has no columns for {extra}; those values are dropped.")

    with open(csv_path, "a", newline="", encoding="utf-8") as f:
        w = csv.DictWriter(f, fieldnames=fieldnames, restval="", extrasaction="ignore")
        if not has_content:
            w.writeheader()
        w.writerow(row)

# -------------------------
# SAM build
# -------------------------
def build_sam_predictor(device: torch.device) -> SamPredictor:
    """Load the SAM checkpoint at SAM_CKPT_PATH and wrap it in a SamPredictor.

    The model type comes from SAM_MODEL_TYPE. The model is moved to ``device``
    and put in eval mode.

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
    """
    checkpoint = Path(SAM_CKPT_PATH)
    if not checkpoint.exists():
        raise FileNotFoundError(f"SAM checkpoint not found: {checkpoint}")

    model = sam_model_registry[SAM_MODEL_TYPE](checkpoint=str(checkpoint))
    model.to(device=device)
    model.eval()
    return SamPredictor(model)

def free_cuda():
    """Collect Python garbage, then return PyTorch's cached CUDA blocks to the driver.

    ``gc.collect()`` runs first. Tensors kept alive only by reference cycles are
    freed there, and only then can ``torch.cuda.empty_cache()`` release their
    now-unused cached memory. In the reverse order that memory stays reserved
    until the next call. The CUDA step is skipped when CUDA is unavailable.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# -------------------------
# Core: refine one file using SAM
# -------------------------
def _sam_refine_box_yolo(
    predictor: SamPredictor,
    cx: float,
    cy: float,
    w: float,
    h: float,
    W: int,
    H: int,
) -> Optional[Tuple[float, float, float, float]]:
    """Refine one YOLO box with SAM; return the refined (cx, cy, w, h) or None.

    None means "keep the noisy box". That happens when:
    - the noisy box is smaller than MIN_BOX_WH_PIX;
    - SAM returns no mask;
    - mask_to_xyxy gives no box for the chosen mask;
    - the mask box, after clamping to the image, is smaller than MIN_BOX_WH_PIX.
    Expects predictor.set_image() to have been called for the current image.
    """
    x1, y1, x2, y2 = yolo_to_xyxy_pix(cx, cy, w, h, W, H)
    if (x2 - x1) < MIN_BOX_WH_PIX or (y2 - y1) < MIN_BOX_WH_PIX:
        return None

    # Prompt SAM with the (optionally expanded) noisy box.
    x1e, y1e, x2e, y2e = expand_xyxy(x1, y1, x2, y2, W, H, BOX_EXPAND_RATIO)
    box = np.array([x1e, y1e, x2e, y2e], dtype=np.float32)
    masks, scores, _ = predictor.predict(
        box=box,
        multimask_output=bool(USE_MULTIMASK),
    )
    if masks is None or len(masks) == 0:
        return None

    # Choose one mask: largest pixel area when PICK_BY == "area", else best score.
    if PICK_BY == "area":
        j = int(np.argmax([float(m.sum()) for m in masks]))
    else:
        j = int(np.argmax(scores)) if scores is not None else 0

    xyxy = mask_to_xyxy(masks[j], pad=int(MASK_TO_BOX_PAD))
    if xyxy is None:
        return None

    # Clamp the mask-derived box to the image bounds.
    mx1, my1, mx2, my2 = xyxy
    mx1 = max(0, min(W - 1, mx1))
    my1 = max(0, min(H - 1, my1))
    mx2 = max(1, min(W, mx2))
    my2 = max(1, min(H, my2))
    if (mx2 - mx1) < MIN_BOX_WH_PIX or (my2 - my1) < MIN_BOX_WH_PIX:
        return None

    rcx, rcy, rw, rh = xyxy_pix_to_yolo(mx1, my1, mx2, my2, W, H)
    return (rcx, rcy, rw, rh)


@torch.inference_mode()
def refine_one_label_file_sam(
    predictor: SamPredictor,
    img_path: Path,
    noisy_lbl_path: Path,
    clean_lbl_path: Optional[Path],
    out_lbl_path: Path,
) -> Dict[str, float]:
    """Refine every box of one noisy YOLO label file with SAM and save the result.

    Each noisy box is used as a SAM box prompt, and the chosen mask's bounding
    box replaces it. When refinement is not possible (see _sam_refine_box_yolo),
    the noisy box is kept unchanged. The output file is written only when
    OVERWRITE_EXISTING is set or out_lbl_path does not exist yet.

    When WRITE_METRICS_CSV is set and the clean label file exists, box i of the
    noisy file is compared with row i of the clean file. For each pair it
    accumulates the IoU of the noisy box and of the refined box against the
    clean box.
    NOTE(review): this assumes both files list boxes in the same order; confirm.

    Args:
        predictor: SAM predictor; set_image() is called on it for img_path.
        img_path: image matching the label file.
        noisy_lbl_path: YOLO label file to refine.
        clean_lbl_path: reference labels for metrics, or None.
        out_lbl_path: destination of the refined YOLO labels.

    Returns:
        {"n_boxes_eval", "sum_iou_noisy", "sum_iou_refined"} as floats.
    """
    noisy_rows = read_yolo_txt(noisy_lbl_path)
    if len(noisy_rows) == 0:
        # Empty noisy label file: an empty refined file is the correct output.
        if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
            write_yolo_txt(out_lbl_path, [])
        return {"n_boxes_eval": 0.0, "sum_iou_noisy": 0.0, "sum_iou_refined": 0.0}

    clean_rows = (
        read_yolo_txt(clean_lbl_path)
        if (WRITE_METRICS_CSV and clean_lbl_path is not None and clean_lbl_path.exists())
        else None
    )

    # Manage the opened image itself: `with Image.open(p).convert(...)` only
    # manages the converted copy and leaves closing the source to the GC.
    with Image.open(img_path) as im:
        img_np = np.array(im.convert("RGB"))
    H, W = img_np.shape[0], img_np.shape[1]

    # The image encoder runs once per file; each box is then decoded separately.
    predictor.set_image(img_np)

    refined_rows: List[Tuple[int, float, float, float, float]] = []
    sum_iou_noisy = 0.0
    sum_iou_ref = 0.0
    n_boxes_eval = 0

    for i, (cls, cx, cy, w, h) in enumerate(noisy_rows):
        noisy_box = (cx, cy, w, h)
        refined_box = _sam_refine_box_yolo(predictor, cx, cy, w, h, W, H)
        if refined_box is None:
            refined_box = noisy_box  # fallback: keep the noisy box unchanged
        refined_rows.append((cls, *refined_box))

        if clean_rows is not None and i < len(clean_rows):
            _, ccx, ccy, cw, ch = clean_rows[i]
            clean_box = (ccx, ccy, cw, ch)
            sum_iou_noisy += box_iou(noisy_box, clean_box)
            sum_iou_ref += box_iou(refined_box, clean_box)
            n_boxes_eval += 1

    if OVERWRITE_EXISTING or (not out_lbl_path.exists()):
        write_yolo_txt(out_lbl_path, refined_rows)

    return {
        "n_boxes_eval": float(n_boxes_eval),
        "sum_iou_noisy": float(sum_iou_noisy),
        "sum_iou_refined": float(sum_iou_ref),
    }

# -------------------------
# MAIN
# -------------------------
# Model and run identity shared by every dataset/noise/split processed below.
# case_id names the output subfolder REFINES_OUT_ROOT/<dataset>/<case_id>.
device = torch.device(DEVICE) if torch.cuda.is_available() else torch.device("cpu")
predictor = build_sam_predictor(device=device)

case_id = make_sam_case_id()
all_metrics_rows: List[Dict[str, Any]] = []  # filled only in non-incremental metrics mode

# Keys of rows already in METRICS_CSV_PATH, used to avoid appending duplicate
# rows when a run is resumed (incremental metrics mode only).
if WRITE_METRICS_CSV and WRITE_METRICS_INCREMENTAL:
    existing_metric_keys = load_existing_metric_keys(METRICS_CSV_PATH)
else:
    existing_metric_keys = set()

# Main refinement loop: dataset -> noisy-label case dir -> selected split.
# Refined labels go to REFINES_OUT_ROOT/<dataset>/<case_id>/<noise_name>/<split>/
# (no split subfolder for the "__flat__" layout). Each finished split gets a
# _DONE.json marker; _PROGRESS.json is written periodically while it runs.
for dataset in (TARGET_DATASETS or []):
    ds_root = resolve_dataset_root(dataset)
    if ds_root is None or (not ds_root.exists()):
        print(f"⚠️ Skip missing dataset root: {DATASETS_ROOT/dataset}")
        continue

    images_root = ds_root / "images"
    labels_root = ds_root / "labels"  # clean labels (metric reference)
    if not images_root.exists() or not labels_root.exists():
        print(f"⚠️ Skip (missing images/labels): {ds_root.name}")
        continue

    # Keep only the splits accepted by _split_selected; fall back to a flat
    # layout (no split subdirectories) when none remain.
    splits_all = list_existing_splits(images_root)
    splits = [s for s in splits_all if _split_selected(s)]
    if not splits:
        splits = ["__flat__"]

    # (noise_name, noise_dir) pairs; presumably the labels_uniform_scaling_* /
    # labels_boundary_jitter_* directories named in the warning below.
    label_case_dirs = iter_label_case_dirs(ds_root)
    if not label_case_dirs:
        print(f"⚠️ No noise label dirs in {ds_root} (labels_uniform_scaling_* / labels_boundary_jitter_*). Skip.")
        continue

    refine_root = REFINES_OUT_ROOT / ds_root.name / case_id
    refine_root.mkdir(parents=True, exist_ok=True)

    print("\n" + "="*110)
    print(f"[DATASET] {ds_root.name} | case_id={case_id} | noise_cases={len(label_case_dirs)} | splits={splits}")
    print(f"          output => {refine_root}")
    print("="*110)

    for noise_name, noise_dir in label_case_dirs:
        for split in splits:
            img_dir = resolve_split_dir(images_root, split)
            clean_lbl_dir = resolve_split_dir(labels_root, split)
            noisy_lbl_dir = resolve_split_dir(noise_dir, split)

            # Requirement: if no data, skip instead of creating empty label
            if (not img_dir.exists()) or (not noisy_lbl_dir.exists()) or (not clean_lbl_dir.exists()):
                continue

            # Joining "" leaves the path unchanged, so a flat layout writes
            # directly under <noise_name>/.
            out_lbl_dir = refine_root / noise_name / (split if split != "__flat__" else "")
            done_marker = out_lbl_dir / "_DONE.json"
            progress_marker = out_lbl_dir / "_PROGRESS.json"

            # (A) Skip the whole split if its _DONE.json looks complete (rule in
            #     _done_marker_seems_complete). Before skipping, append that
            #     marker's metrics row if the CSV does not already contain its key.
            if done_marker.exists() and (not OVERWRITE_EXISTING):
                dm = _safe_read_json(done_marker)
                if _done_marker_seems_complete(dm):
                    if WRITE_METRICS_CSV and WRITE_METRICS_INCREMENTAL and dm is not None:
                        key = tuple(dm.get(k) for k in METRICS_KEY_COLS)
                        if (None not in key) and (key not in existing_metric_keys):
                            # Supplement n_boxes compatible column (older markers may only have n_boxes_eval)
                            if "n_boxes" not in dm and "n_boxes_eval" in dm:
                                dm["n_boxes"] = dm["n_boxes_eval"]
                            append_metrics_row_csv(METRICS_CSV_PATH, dm)
                            existing_metric_keys.add(key)
                    continue
                else:
                    print(f"  ⚠️ found {done_marker} but seems incomplete -> resume {noise_name}/{split}")

            # Noisy label files drive the work list (recursive, sorted for a stable order).
            lbl_files = sorted([p for p in noisy_lbl_dir.rglob("*.txt") if p.is_file()])
            if MAX_FILES_PER_SPLIT is not None:
                lbl_files = lbl_files[:int(MAX_FILES_PER_SPLIT)]
            if not lbl_files:
                continue

            out_lbl_dir.mkdir(parents=True, exist_ok=True)

            print(f"  - refine {noise_name} / {split} | files={len(lbl_files)}")
            print(f"    -> save to: {out_lbl_dir}")

            # Per-split accumulators (IoU sums only cover files refined in this run).
            sum_iou_noisy = 0.0
            sum_iou_ref = 0.0
            n_boxes_eval = 0
            n_files_done = 0
            n_files_skipped = 0
            n_files_noimg = 0
            n_files_noclean = 0

            # (B) Resume core: skip file if out_lf already exists
            #     (and, when VERIFY_EXISTING_OUT_LABELS is set, is a valid YOLO file).
            #     NOTE(review): skipped files add nothing to the IoU sums, so on a
            #     resumed run the mean IoUs in the payload below reflect only the
            #     files refined in this run; confirm that is intended.
            for idx, lf in enumerate(lbl_files, start=1):
                rel = lf.relative_to(noisy_lbl_dir)
                out_lf = out_lbl_dir / rel

                should_skip = False
                if RESUME_MODE and out_lf.exists() and (not OVERWRITE_EXISTING):
                    if (not VERIFY_EXISTING_OUT_LABELS) or is_valid_yolo_file(out_lf):
                        should_skip = True

                if should_skip:
                    n_files_skipped += 1
                else:
                    img_path = find_image_for_label(img_dir, rel)
                    if img_path is None:
                        n_files_noimg += 1
                        # NOTE(review): this `continue` also bypasses the
                        # progress-marker write (C) for this idx.
                        continue

                    # Without a clean counterpart the file is still refined, but
                    # contributes no IoU metrics.
                    clean_lf = clean_lbl_dir / rel
                    clean_arg = clean_lf if clean_lf.exists() else None
                    if clean_arg is None:
                        n_files_noclean += 1

                    m = refine_one_label_file_sam(
                        predictor=predictor,
                        img_path=img_path,
                        noisy_lbl_path=lf,
                        clean_lbl_path=clean_arg,
                        out_lbl_path=out_lf,
                    )
                    n_files_done += 1
                    sum_iou_noisy += m["sum_iou_noisy"]
                    sum_iou_ref   += m["sum_iou_refined"]
                    n_boxes_eval  += int(m["n_boxes_eval"])

                # (C) Save progress state (for mid-termination) every PROGRESS_EVERY files
                if (PROGRESS_EVERY is not None) and (PROGRESS_EVERY > 0) and (idx % int(PROGRESS_EVERY) == 0):
                    prog = {
                        "dataset": ds_root.name,
                        "case_id": case_id,
                        "noise_name": noise_name,
                        "split": split,
                        "files_total": len(lbl_files),
                        "files_done": n_files_done,
                        "files_skipped": n_files_skipped,
                        "files_noimg": n_files_noimg,
                        "files_noclean": n_files_noclean,
                        "n_boxes_eval": int(n_boxes_eval),
                        "last_rel": rel.as_posix(),
                        "updated_at": time.strftime("%Y-%m-%d %H:%M:%S"),
                        "out_dir": str(out_lbl_dir),
                    }
                    progress_marker.write_text(json.dumps(prog, indent=2), encoding="utf-8")

            # Mean IoUs stay None (empty in the CSV, null in JSON) when no box was evaluated.
            mean_iou_noisy = None
            mean_iou_refined = None
            delta_iou = None
            if WRITE_METRICS_CSV and n_boxes_eval > 0:
                mean_iou_noisy = (sum_iou_noisy / n_boxes_eval)
                mean_iou_refined = (sum_iou_ref / n_boxes_eval)
                delta_iou = ((sum_iou_ref - sum_iou_noisy) / n_boxes_eval)

            # One record per (dataset, case, noise, split): used both as the CSV
            # metrics row and as the _DONE.json content.
            payload = {
                "dataset": ds_root.name,
                "case_id": case_id,
                "noise_name": noise_name,
                "split": split,
                "files_total": len(lbl_files),
                "files_done": n_files_done,
                "files_skipped": n_files_skipped,
                "files_noimg": n_files_noimg,
                "files_noclean": n_files_noclean,
                # Also record n_boxes column for compatibility
                "n_boxes": int(n_boxes_eval),
                "n_boxes_eval": int(n_boxes_eval),
                "mean_iou_noisy": mean_iou_noisy,
                "mean_iou_refined": mean_iou_refined,
                "delta_iou": delta_iou,
                "sam_model_type": SAM_MODEL_TYPE,
                "sam_ckpt": str(SAM_CKPT_PATH),
                "box_expand_ratio": float(BOX_EXPAND_RATIO),
                "multimask": bool(USE_MULTIMASK),
                "pick_by": str(PICK_BY),
                "mask_to_box_pad": int(MASK_TO_BOX_PAD),
                "min_box_wh_pix": int(MIN_BOX_WH_PIX),
                "out_dir": str(out_lbl_dir),
            }

            # (D) Metrics append (for mid-termination): incremental mode writes
            #     now (deduplicated by key); otherwise rows are saved after the loop.
            if WRITE_METRICS_CSV:
                if WRITE_METRICS_INCREMENTAL:
                    key = tuple(payload.get(k) for k in METRICS_KEY_COLS)
                    if (None not in key) and (key not in existing_metric_keys):
                        append_metrics_row_csv(METRICS_CSV_PATH, payload)
                        existing_metric_keys.add(key)
                else:
                    all_metrics_rows.append(payload)

            # (E) Record completion marker; step (A) decides on the next run
            #     whether it is complete enough to skip this noise/split.
            done_marker.write_text(json.dumps(payload, indent=2), encoding="utf-8")

            # Progress marker is deleted on completion (best effort; failures ignored)
            try:
                if progress_marker.exists():
                    progress_marker.unlink()
            except Exception:
                pass

free_cuda()

# Final report. In incremental mode every split's row was already appended to
# METRICS_CSV_PATH; otherwise the collected rows are written once here.
if not WRITE_METRICS_CSV:
    print("\n✅ Done (no metrics csv requested).")
elif WRITE_METRICS_INCREMENTAL:
    print(f"\n✅ Metrics are appended incrementally: {METRICS_CSV_PATH}")
elif not all_metrics_rows:
    print("\n✅ Done (no metrics rows).")
else:
    import pandas as pd
    df = pd.DataFrame(all_metrics_rows)
    save_path = OUT_ROOT / "_refine_summary_metrics.csv"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    df.to_csv(save_path, index=False)
    print(f"\n✅ Saved refine metrics summary: {save_path}")



[DATASET] kitti | case_id=SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8 | noise_cases=13 | splits=['train', 'val']
          output => SAM_refine_final/refines/kitti/SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8
  - refine labels_uniform_scaling_0.6 / train | files=5985
    -> save to: SAM_refine_final/refines/kitti/SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8/labels_uniform_scaling_0.6/train
  - refine labels_uniform_scaling_0.6 / val | files=1496
    -> save to: SAM_refine_final/refines/kitti/SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8/labels_uniform_scaling_0.6/val
  - refine labels_uniform_scaling_0.7 / train | files=5985
    -> save to: SAM_refine_final/refines/kitti/SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8/labels_uniform_scaling_0.7/train
  - refine labels_uniform_scaling_0.7 / val | files=1496
    -> save to: SAM_refine_final/refines/kitti/SAM_vit_h__bexp0__mm1__pickscore__a195a0a8e8/labels_uniform_scaling_0.7/val
  - refine labels_uniform_scaling_0.8 / train | files=5985
  

## Time Calculation

In [None]:
# ==========================================
# Cell 3) SAM-based Refinement Speed Benchmark
#  - No saving logic
#  - Random TEST_NUMBER samples per dataset (default 50) -> speed measurement
#  - Includes Image Encoding time + Box Prompting time
# ==========================================

import time
import random
import statistics
import torch
import numpy as np
from pathlib import Path
from PIL import Image
from tqdm import tqdm

# [Configuration] Keep same as previous cell or modify if needed
# Paths / device / SAM checkpoint used by the benchmark.
# NOTE: in a shared kernel these module-level names overwrite the same-named
# globals used by the refinement cell (DEVICE, TARGET_DATASETS, ...).
DATASETS_ROOT = Path("/home/ISW/project/datasets")
DEVICE = "cuda:1"
SAM_MODEL_TYPE = "vit_h"
SAM_CKPT_PATH = "/home/ISW/project/checkpoints/sam_vit_h_4b8939.pth"

# Datasets to benchmark (directory names under DATASETS_ROOT).
TARGET_DATASETS = [
    "kitti",
    "homeobjects-3K",
    "african-wildlife",
    "construction-ppe",
    # "Custom_Blood",  # currently disabled
    "brain-tumor",
    "BCCD",
    "signature",
    "medical-pills",
    "VOC",
]

# Benchmark settings
TEST_NUMBER = 50  # max number of randomly sampled images per dataset
BOX_EXPAND_RATIO = 0.0  # relative padding added to each prompt box (0 = none)
USE_MULTIMASK = True  # passed as multimask_output to SamPredictor.predict
IMG_EXTS = ".jpg .jpeg .png .bmp .tif .tiff .webp".split()

# -------------------------------------------------
# 1. Helper Functions (Reused for consistency)
# -------------------------------------------------
def resolve_dataset_root(dataset_name: str, root=None) -> Path:
    """Resolve a dataset directory under the datasets root.

    Tries the exact name first, then a case-insensitive match among the root's
    immediate subdirectories (checked in sorted order for determinism). If
    nothing matches, or the root itself does not exist, the exact-name
    candidate is returned anyway, so callers must still check ``.exists()``.

    Args:
        dataset_name: directory name, e.g. an entry of TARGET_DATASETS.
        root: datasets root to search (Path or str); defaults to DATASETS_ROOT.

    Returns:
        The matching directory, or ``<root>/<dataset_name>`` as a fallback.
    """
    base = DATASETS_ROOT if root is None else Path(root)
    cand = base / dataset_name
    if cand.exists():
        return cand
    # iterdir() raises on a missing root; skipping the scan keeps the caller's
    # "NOT FOUND" reporting working instead of aborting the whole benchmark.
    if base.is_dir():
        wanted = dataset_name.lower()
        for d in sorted(base.iterdir()):
            if d.is_dir() and d.name.lower() == wanted:
                return d
    return cand

def read_yolo_txt_fast(p: Path):
    """Parse a YOLO label file into ``[[cx, cy, w, h], ...]`` (class id dropped).

    Lines with fewer than 5 whitespace-separated tokens are ignored; tokens
    after the 5th are ignored. A missing file yields an empty list. A
    non-numeric coordinate raises ValueError.
    """
    if not p.exists():
        return []
    boxes = []
    with open(p, "r") as fh:
        for raw in fh:
            tokens = raw.split()
            if len(tokens) < 5:
                continue
            boxes.append(list(map(float, tokens[1:5])))
    return boxes

def yolo_to_xyxy_pix(cx, cy, w, h, W, H):
    """Convert a normalized YOLO box to integer pixel corners clamped to [0, W] x [0, H]."""
    half_w = w / 2
    half_h = h / 2
    left = int(round((cx - half_w) * W))
    right = int(round((cx + half_w) * W))
    top = int(round((cy - half_h) * H))
    bottom = int(round((cy + half_h) * H))
    return max(0, left), max(0, top), min(W, right), min(H, bottom)

def expand_xyxy(x1, y1, x2, y2, W, H, ratio):
    """Pad a pixel box by ``ratio`` of its width/height per side, clamped to the image.

    A non-positive ratio returns the box unchanged.
    """
    if ratio <= 0:
        return x1, y1, x2, y2
    pad_x = int((x2 - x1) * ratio)
    pad_y = int((y2 - y1) * ratio)
    new_x1 = max(0, x1 - pad_x)
    new_y1 = max(0, y1 - pad_y)
    new_x2 = min(W, x2 + pad_x)
    new_y2 = min(H, y2 + pad_y)
    return new_x1, new_y1, new_x2, new_y2

# SAM loader: reuse segment_anything if it is already importable (e.g. made so by
# an earlier cell); otherwise add the local copy under SAM_refine_final/_deps to
# sys.path and import again.
try:
    from segment_anything import sam_model_registry, SamPredictor
except ImportError:
    # The fallback path is resolved relative to the current working directory, so
    # it only helps when the notebook runs from the folder containing SAM_refine_final/.
    import sys
    sys.path.append(str(Path("./SAM_refine_final/_deps/segment-anything-main").resolve()))
    from segment_anything import sam_model_registry, SamPredictor

# -------------------------------------------------
# 2. Main Benchmark Logic
# -------------------------------------------------
def _bench_sync(device: torch.device) -> None:
    """Block until queued work on ``device`` has finished (no-op on CPU)."""
    if device.type == "cuda":
        # Pass the device explicitly: a bare torch.cuda.synchronize() waits on the
        # *current* CUDA device (cuda:0 unless changed), not necessarily DEVICE.
        torch.cuda.synchronize(device)


def _bench_pick_label_source(ds_root: Path) -> Path:
    """Return the first (sorted) labels_uniform_scaling_* dir, else the clean labels/ dir."""
    noisy_dirs = sorted(
        d for d in ds_root.iterdir()
        if d.is_dir() and d.name.startswith("labels_uniform_scaling_")
    )
    return noisy_dirs[0] if noisy_dirs else ds_root / "labels"


def _bench_collect_images(images_dir: Path) -> list:
    """Return all image files below ``images_dir`` (suffix matched case-insensitively), sorted."""
    if not images_dir.is_dir():
        return []
    exts = {e.lower() for e in IMG_EXTS}
    return sorted(
        p for p in images_dir.rglob("*")
        if p.suffix.lower() in exts and p.is_file()
    )


def _bench_time_one_image(predictor, device: torch.device, img_path: Path, lbl_path: Path):
    """Time one image; return ``(load_s, infer_s, n_label_boxes)`` or None if unreadable.

    load_s covers image decode + RGB conversion + YOLO label parsing.
    infer_s covers predictor.set_image (encoder) + one predictor.predict per box.
    n_label_boxes counts every parsed label row, including boxes narrower or
    shorter than 2 px that are not sent to SAM.
    """
    t0 = time.perf_counter()
    try:
        # Manage the opened file itself; convert() returns a new Image.
        with Image.open(img_path) as im:
            img_np = np.array(im.convert("RGB"))
        boxes_yolo = read_yolo_txt_fast(lbl_path)
    except Exception:
        # Best effort: an unreadable image or malformed label just drops this sample.
        return None
    H, W = img_np.shape[:2]
    _bench_sync(device)
    load_dur = time.perf_counter() - t0

    t1 = time.perf_counter()
    predictor.set_image(img_np)  # image encoder (heavy)
    # One prompt per box, following the refinement cell's loop (no batching).
    for cx, cy, w, h in boxes_yolo:
        x1, y1, x2, y2 = yolo_to_xyxy_pix(cx, cy, w, h, W, H)
        if (x2 - x1) < 2 or (y2 - y1) < 2:
            continue
        ex1, ey1, ex2, ey2 = expand_xyxy(x1, y1, x2, y2, W, H, BOX_EXPAND_RATIO)
        predictor.predict(
            box=np.array([ex1, ey1, ex2, ey2]),
            multimask_output=USE_MULTIMASK,
        )
    _bench_sync(device)
    infer_dur = time.perf_counter() - t1
    return load_dur, infer_dur, len(boxes_yolo)


def run_benchmark():
    """Time SAM box-prompted refinement on a random sample of images per dataset.

    For each dataset in TARGET_DATASETS, up to TEST_NUMBER images are sampled.
    Each image gets two timed phases (see _bench_time_one_image):
    - load: decode + label parsing;
    - infer: encoder + per-box decoding.
    Labels come from the first (sorted) labels_uniform_scaling_* directory when
    one exists, otherwise from the clean labels/ directory. Results are only
    printed; nothing is saved.

    The model is placed on DEVICE when CUDA is available, else on CPU (same rule
    as the refinement cell). GPU work is synchronized on that device before each
    timestamp, so asynchronous kernels are charged to the phase that launched them.
    """
    device = torch.device(DEVICE if torch.cuda.is_available() else "cpu")

    # 1) Model load
    print(f"Loading SAM ({SAM_MODEL_TYPE}) on {device}...")
    sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CKPT_PATH)
    sam.to(device=device)
    sam.eval()
    predictor = SamPredictor(sam)

    # Warmup so one-time start-up costs are not charged to the first dataset.
    print("Warming up GPU...")
    predictor.set_image(np.zeros((512, 512, 3), dtype=np.uint8))
    predictor.predict(box=np.array([10, 10, 100, 100]))
    _bench_sync(device)
    print("Warmup done.\n")

    header = (
        f"{'Dataset':<20} | {'Img Load(s)':>11} | {'SAM Refine(s)':>13} | "
        f"{'Objs/Img':>8} | {'Total/Img(s)':>12}"
    )
    print(header)
    print("-" * len(header))

    dataset_times = {}
    for ds_name in TARGET_DATASETS:
        ds_root = resolve_dataset_root(ds_name)
        if not ds_root.exists():
            print(f"{ds_name:<20} | NOT FOUND")
            continue

        images_dir = ds_root / "images"
        label_source = _bench_pick_label_source(ds_root)

        all_imgs = _bench_collect_images(images_dir)
        if not all_imgs:
            print(f"{ds_name:<20} | NO IMAGES")
            continue

        samples = random.sample(all_imgs, TEST_NUMBER) if len(all_imgs) > TEST_NUMBER else all_imgs

        time_records = []  # (load_s, infer_s, num_label_boxes)
        for img_path in samples:
            lbl_path = label_source / img_path.relative_to(images_dir).with_suffix(".txt")
            rec = _bench_time_one_image(predictor, device, img_path, lbl_path)
            if rec is not None:
                time_records.append(rec)

        if not time_records:
            print(f"{ds_name:<20} | SKIPPED (no readable images)")
            continue

        avg_load = statistics.mean(t[0] for t in time_records)
        avg_infer = statistics.mean(t[1] for t in time_records)
        avg_objs = statistics.mean(t[2] for t in time_records)
        total_time_per_img = avg_load + avg_infer
        dataset_times[ds_name] = total_time_per_img

        print(
            f"{ds_name:<20} | {avg_load:>10.4f}s | {avg_infer:>12.4f}s | "
            f"{avg_objs:>8.1f} | {total_time_per_img:>11.4f}s"
        )

    # Final summary
    print("-" * len(header))
    if dataset_times:
        global_avg = statistics.mean(dataset_times.values())
        print(f"{'OVERALL AVERAGE':<20} | {'-':>11} | {'-':>13} | {'-':>8} | {global_avg:>11.4f}s")
        print("=" * len(header))
        print(f"Note: Measured on up to {TEST_NUMBER} random images per dataset.")
        print("Time includes: Image Load + SAM Encoder(set_image) + Mask Decoding(predict loop).")
    else:
        print("No datasets processed successfully.")

if __name__ == "__main__":
    # Return cached-but-unused GPU blocks left over from earlier cells before timing.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    run_benchmark()

Loading SAM (vit_h) on cuda:1...
Warming up GPU...
Warmup done.

Dataset              | Img Load(s) | SAM Refine(s) | Objs/Img | Total/Img(s)
---------------------------------------------------------------------------
kitti                | 0.0020s      | 0.5671s       |      4.4 | 0.5691s
homeobjects-3K       | 0.0093s      | 0.5969s       |      9.6 | 0.6062s
african-wildlife     | 0.0020s      | 0.5634s       |      1.2 | 0.5653s
construction-ppe     | 0.0032s      | 0.5776s       |      6.8 | 0.5808s
brain-tumor          | 0.0007s      | 0.5621s       |      1.0 | 0.5628s
BCCD                 | 0.0013s      | 0.5947s       |     12.6 | 0.5960s
signature            | 0.0076s      | 0.5699s       |      1.0 | 0.5775s
medical-pills        | 0.0075s      | 0.6218s       |     16.8 | 0.6294s
VOC                  | 0.0019s      | 0.5693s       |      1.8 | 0.5712s
---------------------------------------------------------------------------
OVERALL AVERAGE      | -          | -            