In [None]:
import os

ROOT = "/kaggle/input/indoor-small-object-dataset/ISOD"


def print_dir_tree(root_dir, max_depth=4):
    """Print the directory hierarchy under ``root_dir`` as an indented tree.

    Only directories are listed (files are ignored). ``root_dir`` itself is
    depth 0. Directories deeper than ``max_depth`` are neither printed nor
    walked: their names are pruned from ``os.walk``'s ``dirnames`` list in
    place. Siblings are visited in sorted order so the output is reproducible.
    """
    root_dir = os.path.normpath(root_dir)
    for current, subdirs, _files in os.walk(root_dir):
        rel = os.path.relpath(current, root_dir)
        # Depth from the relative path; str.replace() would also strip any
        # later repetition of the root string inside a sub-path.
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        subdirs.sort()
        if level >= max_depth:
            subdirs[:] = []  # stop os.walk from descending any further
        indent = "│   " * level
        print(f"{indent}├── {os.path.basename(current)}/")


# Show ROOT plus up to four nested directory levels.
print_dir_tree(ROOT, max_depth=4)


In [None]:
import os
import re
import json
from pathlib import Path
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from PIL import Image, ImageFile
import matplotlib.pyplot as plt

# Keep Pillow strict: a truncated file must fail to load so that it is
# reported as unreadable instead of being silently padded.
ImageFile.LOAD_TRUNCATED_IMAGES = False

# Input / output locations for the Phase 1 integrity and profile outputs.
KAGGLE_INPUT_ROOT = Path("/kaggle/input")
WORK_ROOT = Path("/kaggle/working")
RESULTS_DIR = WORK_ROOT.joinpath("results", "01_data_integrity_and_profile")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# Lower-case image suffixes recognised when listing modality folders.
IMG_EXTS = set(".png .jpg .jpeg .bmp .tif .tiff".split())

def find_isod_data_root(input_root=None):
    """Locate the ``ISOD/data`` directory of the mounted dataset.

    Shallow layouts are checked first: ``<root>/ISOD/data`` and
    ``<root>/<dataset>/ISOD/data``. Only if none of those exist is the whole
    tree searched recursively for an ``ISOD`` folder containing ``data``. This
    avoids a full ``rglob`` over every mounted dataset on each run. Among
    matches at the same stage, the lexicographically smallest path wins, so the
    choice is deterministic.

    Args:
        input_root: directory to search; defaults to ``KAGGLE_INPUT_ROOT``.

    Returns:
        Path to the chosen ``ISOD/data`` directory.

    Raises:
        FileNotFoundError: if no ``ISOD/data`` directory exists under the root.
    """
    base = KAGGLE_INPUT_ROOT if input_root is None else Path(input_root)

    shallow = [base / "ISOD" / "data"]
    if base.is_dir():
        shallow.extend(p / "ISOD" / "data" for p in base.iterdir() if p.is_dir())
    found = sorted({c for c in shallow if c.is_dir()}, key=str)

    if not found:
        # Fallback: arbitrary nesting depth (slow on large input mounts).
        found = sorted(
            {p / "data" for p in base.rglob("ISOD") if p.is_dir() and (p / "data").is_dir()},
            key=str,
        )

    if not found:
        raise FileNotFoundError(f"Could not find ISOD/data under {base}")
    return found[0]

# Resolve the dataset root once; WORK_DATA_ROOT is kept as an alias of it
# (it is not read anywhere else in this cell).
DATA_ROOT = WORK_DATA_ROOT = find_isod_data_root()

def list_files_in_dir(d: Path, exts=None):
    """List image files directly inside ``d``, sorted by file name.

    Matching is case-insensitive on the suffix, so ``.png``, ``.PNG`` and
    ``.Png`` are all accepted. Sub-directories are never returned, even if
    their names end in an image suffix. The folder is listed once instead of
    being globbed once per extension/case.

    Args:
        d: directory to list (str or Path). A missing directory gives ``[]``.
        exts: iterable of allowed suffixes including the dot; defaults to
            ``IMG_EXTS``.

    Returns:
        list[Path] sorted by ``Path.name``.
    """
    d = Path(d)
    if not d.is_dir():
        return []
    allowed = {e.lower() for e in (IMG_EXTS if exts is None else exts)}
    files = [p for p in d.iterdir() if p.is_file() and p.suffix.lower() in allowed]
    return sorted(files, key=lambda p: p.name)

def stem_id(p: Path):
    """Return the file name of ``p`` without its final suffix.

    This is the key used to pair the rgb/depth/mask files of one sample.
    """
    stem = p.stem
    return stem

def infer_bitdepth_pil(img: Image.Image, arr: np.ndarray):
    """Infer bits per channel of a decoded image.

    The NumPy dtype of ``arr`` decides first when it is one of the recognised
    types. Otherwise, or when ``arr`` is None, the PIL mode string of ``img``
    is looked up.

    Returns:
        int: 1, 8, 16, 32 or 64; -1 when neither the dtype nor the mode is known.
    """
    if arr is not None:
        dtype_bits = (
            (np.uint8, 8),
            (np.uint16, 16),
            (np.uint32, 32),
            (np.int32, 32),
            (np.float32, 32),
            (np.float64, 64),
        )
        for known_dtype, bits in dtype_bits:
            if arr.dtype == known_dtype:
                return bits

    bits_by_mode = {
        "1": 1,
        "L": 8, "P": 8, "RGB": 8, "RGBA": 8, "CMYK": 8, "YCbCr": 8,
        "I;16": 16, "I;16B": 16, "I;16L": 16, "I;16N": 16,
        "I": 32,
        "F": 32,
    }
    return bits_by_mode.get(img.mode, -1)

def safe_read_image(path: Path):
    """Fully decode an image file and report whether that succeeded.

    Returns:
        tuple ``(ok, shape, bitdepth, note)``:
        ``(True, array_shape, bits, "")`` on success, or
        ``(False, None, None, "<ExceptionType>: <first 200 chars of message>")``
        if opening or decoding raised any exception.
    """
    try:
        with Image.open(path) as handle:
            handle.load()
            pixels = np.array(handle)
            bits = infer_bitdepth_pil(handle, pixels)
    except Exception as exc:
        message = str(exc)[:200]
        return False, None, None, f"{type(exc).__name__}: {message}"
    return True, pixels.shape, bits, ""

def file_size_kb(path: Path):
    """Size of ``path`` in KiB (bytes / 1024), or NaN if it cannot be stat'ed."""
    try:
        size_bytes = os.path.getsize(path)
    except Exception:
        return np.nan
    return size_bytes / 1024.0

def mask_object_area(mask_arr: np.ndarray):
    """Count foreground pixels of a mask.

    A pixel is foreground when its value is > 0. For a 3-D (multi-channel)
    mask, that holds when any channel is > 0. Returns an int, or NaN for
    ``None``.
    """
    if mask_arr is None:
        return np.nan
    foreground = mask_arr > 0
    if foreground.ndim == 3:
        foreground = foreground.any(axis=2)
    return int(np.count_nonzero(foreground))

def depth_valid_ratio(depth_arr: np.ndarray):
    """Fraction of depth pixels that carry a measurement.

    A pixel is valid when it is finite and strictly positive; zeros, NaN and
    +/-inf are invalid. This matches the definition used by the Phase 2
    alignment cell, so the two phases report comparable ratios.

    Returns:
        float in [0, 1], or NaN for ``None`` or an empty array.
    """
    if depth_arr is None:
        return np.nan
    total = depth_arr.size
    if total == 0:
        return np.nan
    valid = np.count_nonzero((depth_arr > 0) & np.isfinite(depth_arr))
    return float(valid) / float(total)

def deterministic_sample_id(site_id: str, stem: str):
    """Build the sample key ``"<site_id>__<stem>"`` (unique across sites)."""
    return "{}__{}".format(site_id, stem)

def try_parse_labels(label_dir: Path, stem: str, label_parse_notes: list):
    """Best-effort extraction of one class label for a sample.

    Candidates are tried in order:

    * ``<stem>.txt``: the first whitespace-separated token (e.g. the class id
      of the first row of a YOLO-style file). An empty file yields nothing.
    * ``<stem>.json``: the first of ``class``/``label``/``category``/
      ``category_id``/``cls`` found in a top-level dict, or in the first
      element of a top-level list of dicts, returned as ``str``.

    A file that fails to read or parse is recorded in ``label_parse_notes`` as
    ``"<path>: <ExceptionType>"``, and the next candidate is tried.

    Returns:
        str label, or None if the folder is missing or nothing matched.
    """
    if not label_dir.exists():
        return None
    label_keys = ("class", "label", "category", "category_id", "cls")
    for ext in (".txt", ".json"):
        lp = label_dir / f"{stem}{ext}"
        if not lp.exists():
            continue
        try:
            text = lp.read_text(encoding="utf-8", errors="ignore")
            if ext == ".txt":
                tokens = text.split()
                if tokens:
                    return tokens[0]
            else:
                obj = json.loads(text)
                if isinstance(obj, dict):
                    record = obj
                elif isinstance(obj, list) and obj:
                    record = obj[0]
                else:
                    record = None
                if isinstance(record, dict):
                    for key in label_keys:
                        if key in record:
                            return str(record[key])
        except Exception as e:
            label_parse_notes.append(f"{lp}: {type(e).__name__}")
    return None

def safe_percentile(x, q):
    """``q``-th percentile of ``x`` as a float; NaN when ``x`` is empty."""
    n_values = len(x)
    return np.nan if n_values == 0 else float(np.percentile(np.array(x), q))

# Shared font sizes for all Phase 1 figures.
plt.rcParams.update(
    {
        **{key: 10 for key in ("font.size", "axes.labelsize")},
        "axes.titlesize": 11,
        **{key: 9 for key in ("xtick.labelsize", "ytick.labelsize", "legend.fontsize")},
    }
)

def save_fig(path: Path, dpi=300):
    """Tighten the layout, write the current pyplot figure to ``path``, close it.

    Parent directories of ``path`` are created if needed.
    """
    out_dir = path.parent
    out_dir.mkdir(exist_ok=True, parents=True)
    plt.tight_layout()
    plt.savefig(path, bbox_inches="tight", dpi=dpi)
    plt.close()

# Every sub-directory of DATA_ROOT is treated as one capture site.
site_dirs = sorted((p for p in DATA_ROOT.iterdir() if p.is_dir()), key=lambda p: p.name)
if len(site_dirs) == 0:
    raise RuntimeError(f"No site folders found under: {DATA_ROOT}")

# Per-sample manifest rows and problem rows.
records, issues_rows = [], []

# Dataset-wide accumulators; the scan loop fills them from valid triplets only.
global_rgb_shapes, global_depth_shapes, global_mask_shapes = [], [], []
global_rgb_bitdepths, global_depth_bitdepths, global_mask_bitdepths = [], [], []
global_object_areas, global_depth_valid_ratios = [], []

# Per-site accumulators keyed by site folder name. site_total_counts is
# incremented by the scan loop but not reported anywhere in this cell.
site_valid_counts, site_total_counts = Counter(), Counter()
site_object_areas, site_depth_valid_ratios = defaultdict(list), defaultdict(list)

# Label statistics and label-file parse failures.
label_class_counter = Counter()
label_parse_notes = []

# Main integrity scan: for each site folder, pair rgb/depth/mask files by
# filename stem, check that every file exists and decodes, and accumulate
# manifest records, issue rows and statistics.
for site_path in site_dirs:
    site_id = site_path.name

    rgb_dir = site_path / "rgb"
    depth_dir = site_path / "depth"
    mask_dir = site_path / "mask"
    label_dir = site_path / "label"

    rgb_files = list_files_in_dir(rgb_dir)
    depth_files = list_files_in_dir(depth_dir)
    mask_files = list_files_in_dir(mask_dir)

    # Stem -> path per modality. If one folder holds several files with the
    # same stem (e.g. x.png and x.jpg), the later one in name order overwrites
    # the earlier one in these dicts.
    rgb_map = {stem_id(p): p for p in rgb_files}
    depth_map = {stem_id(p): p for p in depth_files}
    mask_map = {stem_id(p): p for p in mask_files}

    # Union of stems, so samples missing any modality are still reported.
    all_stems = sorted(set(rgb_map.keys()) | set(depth_map.keys()) | set(mask_map.keys()))

    for stem in all_stems:
        sample_id = deterministic_sample_id(site_id, stem)
        site_total_counts[site_id] += 1

        rgb_path = rgb_map.get(stem, None)
        depth_path = depth_map.get(stem, None)
        mask_path = mask_map.get(stem, None)

        # `status` keeps the FIRST problem found (checked RGB, then depth, then
        # mask); `notes` accumulates every problem.
        notes = []
        status = "valid"

        if rgb_path is None:
            status = "missing_rgb"
            notes.append("RGB missing")
        if depth_path is None:
            if status == "valid":
                status = "missing_depth"
            notes.append("Depth missing")
        if mask_path is None:
            if status == "valid":
                status = "missing_mask"
            notes.append("Mask missing")

        rgb_shape = depth_shape = mask_shape = None
        rgb_bit = depth_bit = mask_bit = None
        rgb_ok = depth_ok = mask_ok = False

        rgb_size_kb = depth_size_kb = mask_size_kb = np.nan

        # Decode each present file to verify readability and record its
        # array shape and bit depth.
        if rgb_path is not None:
            rgb_size_kb = file_size_kb(rgb_path)
            rgb_ok, rgb_shape, rgb_bit, rgb_note = safe_read_image(rgb_path)
            if not rgb_ok:
                if status == "valid":
                    status = "unreadable_rgb"
                notes.append(f"RGB unreadable: {rgb_note}")

        if depth_path is not None:
            depth_size_kb = file_size_kb(depth_path)
            depth_ok, depth_shape, depth_bit, depth_note = safe_read_image(depth_path)
            if not depth_ok:
                if status == "valid":
                    status = "unreadable_depth"
                notes.append(f"Depth unreadable: {depth_note}")

        if mask_path is not None:
            mask_size_kb = file_size_kb(mask_path)
            mask_ok, mask_shape, mask_bit, mask_note = safe_read_image(mask_path)
            if not mask_ok:
                if status == "valid":
                    status = "unreadable_mask"
                notes.append(f"Mask unreadable: {mask_note}")

        obj_area = np.nan
        d_valid_ratio = np.nan

        # Readable masks / depth maps are opened and decoded a second time here
        # to compute foreground area and valid-depth ratio.
        if mask_ok and mask_path is not None:
            with Image.open(mask_path) as mimg:
                mimg.load()
                marr = np.array(mimg)
            obj_area = mask_object_area(marr)

        if depth_ok and depth_path is not None:
            with Image.open(depth_path) as dimg:
                dimg.load()
                darr = np.array(dimg)
            d_valid_ratio = depth_valid_ratio(darr)

        # Labels are looked up for every stem, including incomplete or
        # unreadable triplets.
        parsed_label = try_parse_labels(label_dir, stem, label_parse_notes)
        if parsed_label is not None:
            label_class_counter[parsed_label] += 1

        triplet_exists = (rgb_path is not None) and (depth_path is not None) and (mask_path is not None)
        triplet_readable = rgb_ok and depth_ok and mask_ok

        # Only complete, fully decodable triplets count as valid and feed the
        # global / per-site statistics; everything else becomes an issue row.
        if triplet_exists and triplet_readable:
            status = "valid"
            site_valid_counts[site_id] += 1

            if rgb_shape is not None:
                global_rgb_shapes.append(rgb_shape)
                global_rgb_bitdepths.append(rgb_bit)
            if depth_shape is not None:
                global_depth_shapes.append(depth_shape)
                global_depth_bitdepths.append(depth_bit)
            if mask_shape is not None:
                global_mask_shapes.append(mask_shape)
                global_mask_bitdepths.append(mask_bit)

            if not np.isnan(obj_area):
                global_object_areas.append(obj_area)
                site_object_areas[site_id].append(obj_area)

            if not np.isnan(d_valid_ratio):
                global_depth_valid_ratios.append(d_valid_ratio)
                site_depth_valid_ratios[site_id].append(d_valid_ratio)
        else:
            # Defensive fallbacks: `status`/`notes` normally already name the problem.
            issue_type = status if status != "valid" else "incomplete_or_unreadable_triplet"
            details = "; ".join(notes) if notes else "Triplet incomplete or unreadable"
            issues_rows.append({
                "sample_id": sample_id,
                "site_id": site_id,
                "issue_type": issue_type,
                "details": details
            })

        # Every stem gets a manifest row; missing values are written as "".
        records.append({
            "sample_id": sample_id,
            "site_id": site_id,
            "rgb_path": str(rgb_path) if rgb_path is not None else "",
            "depth_path": str(depth_path) if depth_path is not None else "",
            "mask_path": str(mask_path) if mask_path is not None else "",
            "rgb_shape": str(rgb_shape) if rgb_shape is not None else "",
            "depth_shape": str(depth_shape) if depth_shape is not None else "",
            "mask_shape": str(mask_shape) if mask_shape is not None else "",
            "rgb_bitdepth": int(rgb_bit) if rgb_bit is not None else "",
            "depth_bitdepth": int(depth_bit) if depth_bit is not None else "",
            "mask_bitdepth": int(mask_bit) if mask_bit is not None else "",
            "file_size_kb_rgb": float(rgb_size_kb) if not np.isnan(rgb_size_kb) else "",
            "file_size_kb_depth": float(depth_size_kb) if not np.isnan(depth_size_kb) else "",
            "file_size_kb_mask": float(mask_size_kb) if not np.isnan(mask_size_kb) else "",
            "status": status,
            "notes": "; ".join(notes)
        })

# Explicit column lists keep the CSV schema stable and let the empty case work:
# pd.DataFrame([]) has no "sample_id" column, so sort_values would raise KeyError.
MANIFEST_COLUMNS = [
    "sample_id", "site_id",
    "rgb_path", "depth_path", "mask_path",
    "rgb_shape", "depth_shape", "mask_shape",
    "rgb_bitdepth", "depth_bitdepth", "mask_bitdepth",
    "file_size_kb_rgb", "file_size_kb_depth", "file_size_kb_mask",
    "status", "notes",
]
ISSUE_COLUMNS = ["sample_id", "site_id", "issue_type", "details"]

# Stable (mergesort) ordering by sample_id for reproducible outputs.
df_manifest = (
    pd.DataFrame(records, columns=MANIFEST_COLUMNS)
    .sort_values("sample_id", kind="mergesort")
    .reset_index(drop=True)
)
df_issues = (
    pd.DataFrame(issues_rows, columns=ISSUE_COLUMNS)
    .sort_values("sample_id", kind="mergesort")
    .reset_index(drop=True)
)

TAB_01_MANIFEST = RESULTS_DIR / "tab_01_dataset_manifest_full.csv"
TAB_04_ISSUES = RESULTS_DIR / "tab_04_corrupt_or_missing_samples.csv"
df_manifest.to_csv(TAB_01_MANIFEST, index=False)
df_issues.to_csv(TAB_04_ISSUES, index=False)

# Valid-only manifest consumed by the later phases.
df_valid = df_manifest[df_manifest["status"] == "valid"].copy()
TAB_01_MANIFEST_VALID = RESULTS_DIR / "tab_01_dataset_manifest_valid_only.csv"
df_valid.to_csv(TAB_01_MANIFEST_VALID, index=False)

total_sites = len(site_dirs)
total_samples_union = len(df_manifest)
total_valid = int(df_manifest["status"].eq("valid").sum())

# A missing modality is stored as an empty path string in the manifest.
missing_rgb, missing_depth, missing_mask = (
    int(df_manifest[column].eq("").sum())
    for column in ("rgb_path", "depth_path", "mask_path")
)

# Decode failures are only recorded in the free-text notes column.
unreadable_rgb, unreadable_depth, unreadable_mask = (
    int(df_manifest["notes"].str.contains(marker, na=False).sum())
    for marker in ("RGB unreadable", "Depth unreadable", "Mask unreadable")
)

# Object-area statistics over valid triplets (NaN when there are none).
if global_object_areas:
    obj_mean = float(np.mean(global_object_areas))
    obj_median = float(np.median(global_object_areas))
else:
    obj_mean = obj_median = np.nan
obj_p90, obj_p95, obj_p99 = (safe_percentile(global_object_areas, q) for q in (90, 95, 99))

# Depth validity statistics over valid triplets (NaN when there are none).
if global_depth_valid_ratios:
    depth_valid_mean = float(np.mean(global_depth_valid_ratios))
    depth_valid_median = float(np.median(global_depth_valid_ratios))
else:
    depth_valid_mean = depth_valid_median = np.nan

def shape_to_hw(shape_str_or_tuple):
    """Extract ``(height, width)`` from a shape tuple or its string form.

    Strings such as ``"(480, 640, 3)"`` are parsed by pulling out all digit
    runs. Returns None for None, ``""``, or anything with fewer than two
    dimensions.
    """
    value = shape_str_or_tuple
    if value is None or value == "":
        return None
    if isinstance(value, tuple):
        dims = value
    else:
        numbers = re.findall(r"\d+", str(value))
        if len(numbers) < 2:
            return None
        dims = tuple(map(int, numbers))
    return (int(dims[0]), int(dims[1])) if len(dims) >= 2 else None

# (height, width) per valid sample; unparseable shapes are dropped.
rgb_hw = [hw for hw in map(shape_to_hw, [s for s in global_rgb_shapes if s is not None]) if hw is not None]
depth_hw = [hw for hw in map(shape_to_hw, [s for s in global_depth_shapes if s is not None]) if hw is not None]

rgb_res_counter, depth_res_counter = Counter(rgb_hw), Counter(depth_hw)
rgb_bit_counter = Counter(b for b in global_rgb_bitdepths if b is not None)
depth_bit_counter = Counter(b for b in global_depth_bitdepths if b is not None)

# Two-column (metric, value) summary; the list order is the row order of the CSV.
summary_rows = [
    ("data_root", str(DATA_ROOT)),
    # Sample counts: union of stems vs. complete, decodable triplets.
    ("sites_count", total_sites),
    ("samples_total_union", total_samples_union),
    ("valid_triplets", total_valid),
    # Per-modality problem counts (a sample can appear in several of these).
    ("missing_rgb", missing_rgb),
    ("missing_depth", missing_depth),
    ("missing_mask", missing_mask),
    ("unreadable_rgb", unreadable_rgb),
    ("unreadable_depth", unreadable_depth),
    ("unreadable_mask", unreadable_mask),
    # Mask foreground area statistics (pixels), valid triplets only.
    ("global_mean_object_area_px", obj_mean),
    ("global_median_object_area_px", obj_median),
    ("global_p90_object_area_px", obj_p90),
    ("global_p95_object_area_px", obj_p95),
    ("global_p99_object_area_px", obj_p99),
    # Depth validity statistics, valid triplets only.
    ("global_mean_depth_valid_ratio", depth_valid_mean),
    ("global_median_depth_valid_ratio", depth_valid_median),
    # Label parsing covers every stem, not only valid triplets.
    ("label_parse_unique_classes", len(label_class_counter)),
    ("label_parse_total_labeled", int(sum(label_class_counter.values()))),
]

df_summary = pd.DataFrame(summary_rows, columns=["metric", "value"])
TAB_02_SUMMARY = RESULTS_DIR / "tab_02_integrity_and_profile_summary.csv"
df_summary.to_csv(TAB_02_SUMMARY, index=False)

# One row per discovered site folder, including sites with no valid triplet
# (their statistics are NaN and num_samples is 0).
site_profile_rows = []
for site_name in sorted(p.name for p in site_dirs):
    site_areas = site_object_areas.get(site_name, [])
    site_dvrs = site_depth_valid_ratios.get(site_name, [])
    has_areas = len(site_areas) > 0
    site_profile_rows.append({
        "site_id": site_name,
        "num_samples": int(site_valid_counts.get(site_name, 0)),
        "mean_object_area_px": float(np.mean(site_areas)) if has_areas else np.nan,
        "median_object_area_px": float(np.median(site_areas)) if has_areas else np.nan,
        "p90_object_area_px": safe_percentile(site_areas, 90),
        "mean_depth_valid_ratio": float(np.mean(site_dvrs)) if len(site_dvrs) else np.nan,
    })

df_site = (
    pd.DataFrame(site_profile_rows)
    .sort_values("site_id", kind="mergesort")
    .reset_index(drop=True)
)
TAB_03_SITE = RESULTS_DIR / "tab_03_site_profile_table.csv"
df_site.to_csv(TAB_03_SITE, index=False)

# Bar chart of integrity counts. The bars overlap conceptually: a sample
# missing two modalities is counted under both "missing_*" bars.
counts = {
    "valid": total_valid,
    "missing_rgb": missing_rgb,
    "missing_depth": missing_depth,
    "missing_mask": missing_mask,
    "unreadable_rgb": unreadable_rgb,
    "unreadable_depth": unreadable_depth,
    "unreadable_mask": unreadable_mask,
}
labels = list(counts.keys())
vals = [counts[k] for k in labels]

plt.figure(figsize=(9, 4.5))
plt.bar(labels, vals)
plt.title("ISOD file health overview")
plt.ylabel("Count")
plt.xticks(rotation=25, ha="right")
# Value labels sit just above each bar: offset of at least 1, or 1% of the tallest bar.
mx = max(vals) if len(vals) else 0
for i, v in enumerate(vals):
    plt.text(i, v + max(1, 0.01 * mx), str(v), ha="center", va="bottom")
save_fig(RESULTS_DIR / "fig_01_file_health_overview.png", dpi=300)

def counter_to_sorted_xy(counter: Counter):
    """Turn a Counter into parallel ``(labels, counts)`` lists for a bar plot.

    Tuple keys such as ``(height, width)`` are ordered by their first two
    elements; other keys by their natural order. Labels are ``str(key)``.
    """
    def _order(item):
        key = item[0]
        return (key[0], key[1]) if isinstance(key, tuple) else key

    ordered = sorted(counter.items(), key=_order)
    return [str(key) for key, _count in ordered], [count for _key, count in ordered]

# 2x2 panel: resolution histograms (top) and bit-depth histograms (bottom)
# for RGB (left) and depth (right), over valid triplets only.
rgb_res_x, rgb_res_y = counter_to_sorted_xy(rgb_res_counter)
depth_res_x, depth_res_y = counter_to_sorted_xy(depth_res_counter)

rgb_bit_items = sorted(rgb_bit_counter.items(), key=lambda kv: kv[0])
depth_bit_items = sorted(depth_bit_counter.items(), key=lambda kv: kv[0])

plt.figure(figsize=(11, 7))
# Resolution bars are drawn at integer positions and relabelled with the
# "(h, w)" strings so long labels can be rotated.
ax1 = plt.subplot(2, 2, 1)
ax1.bar(range(len(rgb_res_x)), rgb_res_y)
ax1.set_title("RGB resolution")
ax1.set_ylabel("Count")
ax1.set_xticks(range(len(rgb_res_x)))
ax1.set_xticklabels(rgb_res_x, rotation=45, ha="right")

ax2 = plt.subplot(2, 2, 2)
ax2.bar(range(len(depth_res_x)), depth_res_y)
ax2.set_title("Depth resolution")
ax2.set_ylabel("Count")
ax2.set_xticks(range(len(depth_res_x)))
ax2.set_xticklabels(depth_res_x, rotation=45, ha="right")

# Bit depths are plotted as categorical string labels.
ax3 = plt.subplot(2, 2, 3)
ax3.bar([str(k) for k, _ in rgb_bit_items], [v for _, v in rgb_bit_items])
ax3.set_title("RGB bit depth")
ax3.set_ylabel("Count")
ax3.set_xlabel("Bit depth")

ax4 = plt.subplot(2, 2, 4)
ax4.bar([str(k) for k, _ in depth_bit_items], [v for _, v in depth_bit_items])
ax4.set_title("Depth bit depth")
ax4.set_ylabel("Count")
ax4.set_xlabel("Bit depth")

save_fig(RESULTS_DIR / "fig_02_resolution_and_bitdepth_profile.png", dpi=300)

# Plot every discovered site, in the same order as the site profile table.
# Taking the sites from the Counter's keys would silently drop sites that have
# zero valid triplets, because a Counter only holds keys that were incremented.
sites_sorted = sorted(p.name for p in site_dirs)
site_counts = [int(site_valid_counts.get(s, 0)) for s in sites_sorted]

plt.figure(figsize=(12, 4.8))
plt.bar(sites_sorted, site_counts)
plt.title("Valid triplet samples per site")
plt.ylabel("Count")
plt.xticks(rotation=45, ha="right")
save_fig(RESULTS_DIR / "fig_03_samples_per_site.png", dpi=300)

# Histogram of mask foreground areas (valid triplets) with p50/p90/p95 markers.
areas = np.asarray(global_object_areas, dtype=np.float64)
plt.figure(figsize=(9, 4.8))
plt.title("Mask object area distribution")
if areas.size:
    a_min = max(1.0, float(np.min(areas)))
    a_max = float(np.max(areas))
    use_log = a_max / a_min > 200
    if use_log:
        # Log-spaced bins start at >= 1 px, so zero-area masks cannot be
        # shown; say how many were left out instead of dropping them silently.
        bins = np.logspace(np.log10(a_min), np.log10(a_max), 40)
        plt.hist(areas[areas > 0], bins=bins)
        plt.xscale("log")
        n_zero = int(np.count_nonzero(areas <= 0))
        xlabel = "Object area in pixels (log scale)"
        if n_zero:
            xlabel += f"; {n_zero} zero-area masks not shown"
        plt.xlabel(xlabel)
    else:
        plt.hist(areas, bins=40)
        plt.xlabel("Object area in pixels")

    # Percentiles over all areas (same population as the summary table).
    # Each line carries its own label. A bare plt.legend([...]) would pair
    # the names with the first artists drawn, i.e. histogram bars.
    drawn_any = False
    for q in (50, 90, 95):
        value = float(np.percentile(areas, q))
        if use_log and value <= 0:
            continue  # not representable on a log axis
        plt.axvline(value, linestyle="--", linewidth=1, label=f"p{q}")
        drawn_any = True
    plt.ylabel("Count")
    if drawn_any:
        plt.legend()
else:
    plt.text(0.5, 0.5, "No valid mask areas found", ha="center", va="center")
    plt.axis("off")

save_fig(RESULTS_DIR / "fig_04_object_size_distribution.png", dpi=300)



In [None]:
import os
import re
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import cv2

WORK_ROOT = Path("/kaggle/working")
PHASE1_DIR = WORK_ROOT / "results" / "01_data_integrity_and_profile"
PHASE2_DIR = WORK_ROOT / "results" / "02_sensor_alignment_and_quality"
PHASE2_DIR.mkdir(parents=True, exist_ok=True)

MANIFEST_VALID_PATH = PHASE1_DIR / "tab_01_dataset_manifest_valid_only.csv"
if not MANIFEST_VALID_PATH.exists():
    raise FileNotFoundError(f"Missing Phase 1 valid manifest: {MANIFEST_VALID_PATH}")

# Identifier and path columns come from folder/file names. Read them as strings
# so e.g. a site folder named "001" is not type-inferred to the integer 1.
df_valid = pd.read_csv(
    MANIFEST_VALID_PATH,
    dtype={c: str for c in ("sample_id", "site_id", "rgb_path", "depth_path", "mask_path")},
)
df_valid = df_valid.sort_values("sample_id", kind="mergesort").reset_index(drop=True)

# Shared font sizes for all Phase 2 figures.
plt.rcParams.update(
    {
        **{key: 10 for key in ("font.size", "axes.labelsize")},
        "axes.titlesize": 11,
        **{key: 9 for key in ("xtick.labelsize", "ytick.labelsize", "legend.fontsize")},
    }
)

def save_fig(path: Path, dpi=300):
    """Tighten the layout, write the current pyplot figure to ``path``, close it.

    Parent directories of ``path`` are created if needed.
    """
    out_dir = path.parent
    out_dir.mkdir(exist_ok=True, parents=True)
    plt.tight_layout()
    plt.savefig(path, bbox_inches="tight", dpi=dpi)
    plt.close()

def read_img_rgb(path: str):
    """Load an image as an ``H x W x 3`` array for RGB-based processing.

    * The file is closed before returning (context manager).
    * Palette ("P") images are expanded to true colour. A bare ``np.array`` on
      them would return palette indices, not colours.
    * 2-D grayscale is replicated to three channels.
    * 2-channel images (e.g. luminance + alpha) replicate the first channel.
    * Channels beyond the third (e.g. alpha) are dropped.

    The pixel dtype is whatever PIL decodes; it is not forced to uint8.
    """
    with Image.open(path) as img:
        img.load()
        src = img.convert("RGB") if img.mode == "P" else img
        arr = np.array(src)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=2)
    elif arr.shape[2] < 3:
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    elif arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr

def read_img_depth(path: str):
    """Load a depth map as a 2-D array in its decoded dtype.

    Multi-channel files keep only the first channel. The file handle is closed
    before returning (context manager); the original left it open.
    """
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    return arr

def read_img_mask(path: str):
    """Load a mask as a 2-D uint8 array of 0/1.

    A pixel is foreground (1) when its value is > 0; for multi-channel files,
    when any channel is > 0. The file handle is closed before returning.
    """
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    foreground = arr > 0
    if foreground.ndim == 3:
        foreground = foreground.any(axis=2)
    return foreground.astype(np.uint8)

def mask_boundary(mask01: np.ndarray):
    """Thin boundary of a binary mask as a 0/1 uint8 array.

    Uses a morphological gradient (dilation minus erosion) with a 3x3
    elliptical structuring element.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    gradient = cv2.morphologyEx(mask01.astype(np.uint8), cv2.MORPH_GRADIENT, kernel)
    return np.where(gradient > 0, 1, 0).astype(np.uint8)

def canny_edges(rgb: np.ndarray, lower_frac: float = 0.66, upper_frac: float = 1.33):
    """Binary Canny edge map with median-based automatic thresholds.

    Args:
        rgb: ``H x W x C`` (C >= 3 uses the first three channels as RGB;
            C < 3 uses channel 0) or ``H x W`` grayscale image.
        lower_frac, upper_frac: hysteresis thresholds are these fractions of
            the grayscale median, clipped to [0, 255]. The defaults reproduce
            the previous fixed values.

    Non-uint8 input is mapped into 0..255 first. Values already within that
    range are cast directly; larger ranges (e.g. 16-bit) are rescaled by their
    maximum rather than wrapping modulo 256 as a bare ``astype(np.uint8)`` does.

    Returns:
        uint8 array of 0/1 with shape ``H x W``.
    """
    img = rgb
    if img.dtype != np.uint8:
        f = np.nan_to_num(img.astype(np.float32), nan=0.0, posinf=0.0, neginf=0.0)
        peak = float(f.max()) if f.size else 0.0
        if peak > 255.0:
            f = f * (255.0 / peak)
        img = np.clip(f, 0.0, 255.0).astype(np.uint8)

    if img.ndim == 2:
        gray = np.ascontiguousarray(img)
    elif img.shape[2] < 3:
        gray = np.ascontiguousarray(img[:, :, 0])
    else:
        gray = cv2.cvtColor(np.ascontiguousarray(img[:, :, :3]), cv2.COLOR_RGB2GRAY)

    v = np.median(gray)
    lo = int(max(0, lower_frac * v))
    hi = int(min(255, upper_frac * v))
    e = cv2.Canny(gray, lo, hi, L2gradient=True)
    return (e > 0).astype(np.uint8)

def depth_grad_edges(depth: np.ndarray, edge_percentile: float = 90.0):
    """Binary map of strong depth discontinuities.

    Valid depth (finite and > 0) is min-max normalised to [0, 1]; invalid
    pixels are set to 0. The Sobel gradient magnitude is thresholded at the
    ``edge_percentile``-th percentile of its values over valid pixels.

    Only strictly positive gradients count as edges. Without this, a constant
    depth map, or one where most valid gradients are 0, gives a threshold of
    0, and ``gm >= 0`` marks every pixel as an edge. That made depth/boundary
    agreement look perfect on featureless depth.

    Returns:
        uint8 array of 0/1 with the shape of ``depth``. It is all zeros when
        fewer than 10 valid pixels exist or the valid depth is constant.
    """
    d = depth.astype(np.float32)
    valid = (d > 0) & np.isfinite(d)
    if np.count_nonzero(valid) < 10:
        return np.zeros_like(d, dtype=np.uint8)
    vv = d[valid]
    mn, mx = float(np.min(vv)), float(np.max(vv))
    if mx - mn < 1e-6:
        # Constant valid depth carries no gradient information.
        return np.zeros_like(d, dtype=np.uint8)
    dn = np.where(valid, (d - mn) / (mx - mn), 0.0).astype(np.float32)
    gx = cv2.Sobel(dn, cv2.CV_32F, 1, 0, ksize=3)
    gy = cv2.Sobel(dn, cv2.CV_32F, 0, 1, ksize=3)
    gm = cv2.magnitude(gx, gy)
    thr = float(np.percentile(gm[valid], edge_percentile))
    return ((gm >= thr) & (gm > 0)).astype(np.uint8)

def dilate_binary(b01: np.ndarray, r: int):
    """Dilate a binary map by a disc of radius ``r`` pixels.

    The input is binarised as ``> 0`` first, so both branches return a 0/1
    uint8 array. Previously ``r <= 0`` returned the raw values, and
    ``astype(np.uint8)`` could wrap non-zero values such as 256 to 0.
    ``r <= 0`` means no dilation.
    """
    binary = (np.asarray(b01) > 0).astype(np.uint8)
    if r <= 0:
        return binary
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2 * r + 1, 2 * r + 1))
    return (cv2.dilate(binary, k, iterations=1) > 0).astype(np.uint8)

def depth_valid_ratio(depth: np.ndarray):
    """Fraction of pixels with a finite, strictly positive depth value.

    Returns NaN for an empty array.
    """
    is_valid = np.isfinite(depth) & (depth > 0)
    n_total = depth.size
    if n_total == 0:
        return np.nan
    return float(np.count_nonzero(is_valid)) / float(n_total)

def depth_hole_rate(depth: np.ndarray):
    """Complement of ``depth_valid_ratio``: fraction of invalid depth pixels.

    Returns NaN when the valid ratio is NaN (empty input).
    """
    ratio = depth_valid_ratio(depth)
    return np.nan if np.isnan(ratio) else float(1.0 - ratio)

def hole_components(depth: np.ndarray):
    """Count connected regions of invalid depth and the size of the largest.

    Invalid means ``<= 0`` or non-finite. Regions use 8-connectivity.

    Returns:
        ``(num_holes, largest_hole_px)``; ``(0, 0)`` for empty input or when
        there are no invalid pixels.
    """
    hole_mask = (~np.isfinite(depth) | (depth <= 0)).astype(np.uint8)
    if hole_mask.size == 0:
        return 0, 0
    n_labels, label_img = cv2.connectedComponents(hole_mask, connectivity=8)
    if n_labels <= 1:
        return 0, 0
    # Label 0 is the background (valid pixels); drop it.
    component_sizes = np.bincount(label_img.ravel())[1:]
    if component_sizes.size == 0:
        return 0, 0
    return int(component_sizes.size), int(component_sizes.max())

def agreement(boundary01: np.ndarray, cue01: np.ndarray, tol: int):
    """Share of boundary pixels lying within ``tol`` px of a cue pixel.

    The cue map is dilated by ``tol`` and sampled at every boundary pixel.
    Returns NaN when the boundary is empty.
    """
    cue_dilated = dilate_binary(cue01, tol)
    on_boundary = boundary01.astype(np.uint8) > 0
    n_boundary = int(np.count_nonzero(on_boundary))
    if n_boundary == 0:
        return np.nan
    hits = np.count_nonzero(cue_dilated[on_boundary] > 0)
    return float(hits) / float(n_boundary)

def alignment_error(edge_agree: float, depth_agree: float):
    """Combine two agreement scores into an error in [0, 1] (lower is better).

    Error is ``1 - mean`` of whichever scores are not NaN, or NaN if both are.
    """
    scores = [s for s in (edge_agree, depth_agree) if not np.isnan(s)]
    if not scores:
        return np.nan
    if len(scores) == 1:
        return float(1.0 - scores[0])
    return float(1.0 - 0.5 * (scores[0] + scores[1]))

def parse_shape_str(s):
    """Parse a shape string such as ``"(480, 640, 3)"`` into a tuple of ints.

    Returns None for None, NaN, blank strings, or fewer than two numbers.
    """
    if s is None:
        return None
    if isinstance(s, float) and np.isnan(s):
        return None
    text = str(s)
    if not text.strip():
        return None
    digits = re.findall(r"\d+", text)
    return tuple(map(int, digits)) if len(digits) >= 2 else None

# Pixel tolerance for counting a mask-boundary pixel as matched by a cue.
TOL_PX = 3

# Per valid triplet: compare the mask boundary with RGB Canny edges and with
# depth-gradient edges, and record depth validity / hole statistics.
rows = []
for i in range(len(df_valid)):
    sample_id = df_valid.loc[i, "sample_id"]
    site_id = df_valid.loc[i, "site_id"]
    rgb_path = df_valid.loc[i, "rgb_path"]
    depth_path = df_valid.loc[i, "depth_path"]
    mask_path = df_valid.loc[i, "mask_path"]

    notes = []
    # Any read failure records a NaN metrics row tagged with the exception type.
    try:
        rgb = read_img_rgb(rgb_path)
        depth = read_img_depth(depth_path)
        mask = read_img_mask(mask_path)
    except Exception as e:
        rows.append({
            "sample_id": sample_id,
            "site_id": site_id,
            "alignment_error": np.nan,
            "edge_agreement": np.nan,
            "depth_gradient_agreement": np.nan,
            "depth_valid_ratio": np.nan,
            "notes": f"read_error:{type(e).__name__}"
        })
        continue

    # On a height/width mismatch, all three images are cropped to their common
    # top-left region (no resampling) and the sample is flagged.
    if rgb.shape[0] != depth.shape[0] or rgb.shape[1] != depth.shape[1] or rgb.shape[0] != mask.shape[0] or rgb.shape[1] != mask.shape[1]:
        notes.append("shape_mismatch")
        h = min(rgb.shape[0], depth.shape[0], mask.shape[0])
        w = min(rgb.shape[1], depth.shape[1], mask.shape[1])
        rgb = rgb[:h, :w]
        depth = depth[:h, :w]
        mask = mask[:h, :w]

    bnd = mask_boundary(mask)
    rgb_e = canny_edges(rgb)
    dep_e = depth_grad_edges(depth)

    edge_agree = agreement(bnd, rgb_e, TOL_PX)
    depth_agree = agreement(bnd, dep_e, TOL_PX)
    aerr = alignment_error(edge_agree, depth_agree)

    dvr = depth_valid_ratio(depth)
    n_holes, max_hole = hole_components(depth)
    if n_holes > 0:
        notes.append(f"holes={n_holes}")
        notes.append(f"max_hole_px={max_hole}")

    rows.append({
        "sample_id": sample_id,
        "site_id": site_id,
        "alignment_error": aerr,
        "edge_agreement": edge_agree,
        "depth_gradient_agreement": depth_agree,
        "depth_valid_ratio": dvr,
        "notes": ";".join(notes)
    })

# Full per-sample alignment metrics, stably ordered by sample_id.
df_align = pd.DataFrame(rows).sort_values("sample_id", kind="mergesort").reset_index(drop=True)

TAB_01 = PHASE2_DIR / "tab_01_alignment_metrics_full.csv"
df_align.to_csv(TAB_01, index=False)

# Per-site aggregates. num_samples counts all rows (including read-error rows
# with NaN metrics); the mean/std columns skip NaN values.
df_site = df_align.groupby("site_id", sort=True).agg(
    num_samples=("sample_id", "count"),
    mean_alignment_error=("alignment_error", "mean"),
    std_alignment_error=("alignment_error", "std"),
    mean_depth_valid_ratio=("depth_valid_ratio", "mean"),
)
df_site["depth_hole_rate"] = 1.0 - df_site["mean_depth_valid_ratio"]
df_site = df_site.reset_index().sort_values("site_id", kind="mergesort").reset_index(drop=True)

TAB_02 = PHASE2_DIR / "tab_02_site_quality_summary.csv"
df_site.to_csv(TAB_02, index=False)

# Samples with a finite error, worst first; ties broken by sample_id so the
# selection is deterministic.
df_align_clean = (
    df_align.dropna(subset=["alignment_error"])
    .sort_values(["alignment_error", "sample_id"], ascending=[False, True], kind="mergesort")
    .reset_index(drop=True)
)

TOPK_OUTLIERS = 16
df_out = df_align_clean.head(TOPK_OUTLIERS).copy()

def likely_cause(row):
    """Heuristic tag explaining a high alignment error for one metrics row.

    Checked in order: depth valid ratio below 0.85 -> depth holes; a
    "shape_mismatch" note -> resolution mismatch / crop; otherwise ->
    misalignment or weak edges. ``row`` only needs a dict-like ``.get``.
    """
    notes_text = str(row.get("notes", ""))
    ratio = row.get("depth_valid_ratio", np.nan)
    ratio_is_number = isinstance(ratio, (float, np.floating)) and not np.isnan(ratio)
    if ratio_is_number and ratio < 0.85:
        return "depth_holes_or_invalid_depth"
    return (
        "sensor_resolution_mismatch_or_crop"
        if "shape_mismatch" in notes_text
        else "sensor_misalignment_or_edge_weakness"
    )

# Tag each outlier with a heuristic cause and save the compact outlier table.
df_out["likely_cause"] = [likely_cause(row) for _, row in df_out.iterrows()]
TAB_03 = PHASE2_DIR / "tab_03_alignment_outliers_table.csv"
df_out.loc[:, ["sample_id", "site_id", "alignment_error", "depth_valid_ratio", "likely_cause"]].to_csv(TAB_03, index=False)

# Best-to-worst ordering with deterministic ties for picking figure samples.
df_align_clean_sorted = df_align_clean.sort_values(["alignment_error", "sample_id"], ascending=[True, True], kind="mergesort").reset_index(drop=True)
n = len(df_align_clean_sorted)
# Up to 4 best, 4 around the median (quantiles 0.45-0.60) and 4 worst.
# Out-of-range indices (only possible when n == 0) and duplicates are removed.
idx_low = list(range(0, min(4, n)))
idx_mid = [min(n - 1, int(round((n - 1) * q))) for q in [0.45, 0.50, 0.55, 0.60]]
idx_high = list(range(max(0, n - 4), n))
sel_idx = [i for i in dict.fromkeys(idx_low + idx_mid + idx_high) if 0 <= i < n]
df_rep = df_align_clean_sorted.iloc[sel_idx].copy()
df_rep = df_rep.sort_values("sample_id", kind="mergesort").reset_index(drop=True)

overlay_ids = df_rep["sample_id"].tolist()
outlier_ids = df_out["sample_id"].tolist()

# The outlier figure draws only the first 8 of the top-K outliers
# (outlier_ids[:8]). Record exactly those ids so this table matches the image.
OUTLIER_FIG_MAX_PANELS = 8
outlier_fig_ids = outlier_ids[:OUTLIER_FIG_MAX_PANELS]

TAB_04 = PHASE2_DIR / "tab_04_figure_sample_ids.csv"
df_figids = pd.DataFrame([
    {"figure_name": "fig_01_rgb_depth_mask_overlay_grid.png", "sample_id_list": "|".join(overlay_ids), "selection_rule": "4 best + 4 mid + 4 worst by alignment_error with deterministic ties"},
    {"figure_name": "fig_04_alignment_outliers_visual.png", "sample_id_list": "|".join(outlier_fig_ids), "selection_rule": f"first {OUTLIER_FIG_MAX_PANELS} of top {TOPK_OUTLIERS} worst alignment_error with deterministic ties"},
])
df_figids.to_csv(TAB_04, index=False)

def build_lookup(df_manifest_valid: pd.DataFrame):
    """Map ``sample_id`` to its site and modality paths.

    Rows are iterated positionally, so any index works. The previous
    ``df.loc[i, ...]`` over ``range(len(df))`` silently required a default
    RangeIndex. If a sample_id repeats, the last row wins.

    Returns:
        dict: ``sample_id -> {"site_id", "rgb_path", "depth_path", "mask_path"}``.
    """
    cols = ["sample_id", "site_id", "rgb_path", "depth_path", "mask_path"]
    lut = {}
    for sid, site, rgb, depth, mask in df_manifest_valid[cols].itertuples(index=False, name=None):
        lut[sid] = {
            "site_id": site,
            "rgb_path": rgb,
            "depth_path": depth,
            "mask_path": mask,
        }
    return lut

# sample_id -> site / file paths, used to reload triplets for the figures.
lut = build_lookup(df_manifest_valid=df_valid)

def draw_boundary_overlay(ax, base_img, boundary01, title, is_depth=False):
    """Show ``base_img`` on ``ax`` with the mask boundary traced in red.

    Depth images use a gray colormap. The boundary is drawn as the 0.5 contour
    of the 0/1 boundary map. Axis ticks and frame are hidden.
    """
    imshow_kwargs = {"cmap": "gray"} if is_depth else {}
    ax.imshow(base_img, **imshow_kwargs)
    contour_src = boundary01.astype(np.uint8)
    ax.contour(contour_src, levels=[0.5], colors=["red"], linewidths=1.0)
    ax.set_title(title)
    ax.axis("off")

def load_triplet_by_sample_id(sample_id):
    """Reload one sample's images via ``lut`` for plotting.

    RGB, depth and mask are read in that order and cropped to their common
    top-left region, as in the metrics loop.

    Returns:
        ``(site_id, rgb, depth, boundary01)``.
    """
    entry = lut[sample_id]
    rgb = read_img_rgb(entry["rgb_path"])
    depth = read_img_depth(entry["depth_path"])
    mask = read_img_mask(entry["mask_path"])
    images = (rgb, depth, mask)
    h = min(img.shape[0] for img in images)
    w = min(img.shape[1] for img in images)
    rgb, depth, mask = (img[:h, :w] for img in images)
    return entry["site_id"], rgb, depth, mask_boundary(mask)

# Grid of representative samples: each sample takes two stacked panels
# (RGB with mask boundary, then depth with the same boundary).
grid_ids = overlay_ids
m = len(grid_ids)
cols = 4
# At least one row, so an empty selection still produces a valid placeholder
# figure instead of a zero-height canvas.
rows_grid = max(1, int(np.ceil(m / cols)))
plt.figure(figsize=(4 * cols, 3 * rows_grid * 2))

if m == 0:
    plt.text(0.5, 0.5, "No samples with a finite alignment error", ha="center", va="center")
    plt.axis("off")

for k, sid in enumerate(grid_ids):
    site_id, rgb, depth, bnd = load_triplet_by_sample_id(sid)
    r = k // cols
    c = k % cols
    # Subplot rows 2r (RGB) and 2r+1 (depth) in column c.
    ax1 = plt.subplot(rows_grid * 2, cols, r * 2 * cols + c + 1)
    draw_boundary_overlay(ax1, rgb, bnd, f"{sid}", is_depth=False)
    ax2 = plt.subplot(rows_grid * 2, cols, (r * 2 + 1) * cols + c + 1)
    draw_boundary_overlay(ax2, depth, bnd, "depth", is_depth=True)

save_fig(PHASE2_DIR / "fig_01_rgb_depth_mask_overlay_grid.png", dpi=300)

# Sites ordered from worst to best mean alignment error (ties by site_id);
# both per-site bar charts below use this order.
df_site_plot = df_site.sort_values(["mean_alignment_error", "site_id"], ascending=[False, True], kind="mergesort").reset_index(drop=True)

plt.figure(figsize=(12, 4.8))
plt.bar(df_site_plot["site_id"].tolist(), df_site_plot["mean_alignment_error"].tolist())
plt.title("Mean alignment error by site")
plt.ylabel("Mean alignment error")
plt.xticks(rotation=45, ha="right")
save_fig(PHASE2_DIR / "fig_02_alignment_error_by_site.png", dpi=300)

# Same site order as above; the y-axis is fixed to [0, 1] because this is a ratio.
plt.figure(figsize=(12, 4.8))
plt.bar(df_site_plot["site_id"].tolist(), df_site_plot["mean_depth_valid_ratio"].tolist())
plt.title("Mean depth valid ratio by site")
plt.ylabel("Mean depth valid ratio")
plt.ylim(0, 1.0)
plt.xticks(rotation=45, ha="right")
save_fig(PHASE2_DIR / "fig_03_depth_validity_by_site.png", dpi=300)

# Visualise up to 8 of the worst-aligned samples (RGB + mask boundary),
# titled with their error, depth valid ratio and heuristic cause.
vis_ids = outlier_ids[:min(8, len(outlier_ids))]
cols = 2
# At least one row, so an empty outlier list still yields a valid placeholder
# figure instead of a zero-height canvas.
rows_grid = max(1, int(np.ceil(len(vis_ids) / cols)))
plt.figure(figsize=(9 * cols, 4.5 * rows_grid))

align_map = {r["sample_id"]: r for r in df_out.to_dict(orient="records")}

if not vis_ids:
    plt.text(0.5, 0.5, "No alignment outliers to display", ha="center", va="center")
    plt.axis("off")

for k, sid in enumerate(vis_ids):
    site_id, rgb, depth, bnd = load_triplet_by_sample_id(sid)
    met = align_map.get(sid, {})
    aerr = met.get("alignment_error", np.nan)
    dvr = met.get("depth_valid_ratio", np.nan)
    cause = met.get("likely_cause", "")
    ax = plt.subplot(rows_grid, cols, k + 1)
    ax.imshow(rgb)
    ax.contour(bnd.astype(np.uint8), levels=[0.5], colors=["red"], linewidths=1.0)
    ax.set_title(f"{sid}\nerr={aerr:.3f} dvr={dvr:.3f}\n{cause}")
    ax.axis("off")

save_fig(PHASE2_DIR / "fig_04_alignment_outliers_visual.png", dpi=300)



In [None]:
import os
import json
import math
import re
from pathlib import Path
from collections import Counter, defaultdict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import cv2

WORK_ROOT = Path("/kaggle/working")
_RESULTS_ROOT = WORK_ROOT / "results"
PHASE1_DIR = _RESULTS_ROOT / "01_data_integrity_and_profile"
PHASE2_DIR = _RESULTS_ROOT / "02_sensor_alignment_and_quality"
PHASE3_DIR = _RESULTS_ROOT / "03_preprocess_and_split_protocol"
PHASE3_DIR.mkdir(parents=True, exist_ok=True)

# Inputs produced by Phase 1 (valid-sample manifest) and Phase 2 (alignment metrics).
MANIFEST_VALID_PATH = PHASE1_DIR / "tab_01_dataset_manifest_valid_only.csv"
ALIGNMENT_PATH = PHASE2_DIR / "tab_01_alignment_metrics_full.csv"

for _required_path, _missing_label in (
    (MANIFEST_VALID_PATH, "Missing Phase 1 valid manifest"),
    (ALIGNMENT_PATH, "Missing Phase 2 alignment table"),
):
    if not _required_path.exists():
        raise FileNotFoundError(f"{_missing_label}: {_required_path}")


def _read_sorted_by_sample(path):
    """Read a CSV and return it stably sorted by ``sample_id`` with a fresh RangeIndex."""
    return pd.read_csv(path).sort_values("sample_id", kind="mergesort").reset_index(drop=True)


df_valid = _read_sorted_by_sample(MANIFEST_VALID_PATH)
df_align = _read_sorted_by_sample(ALIGNMENT_PATH)

# Attach Phase 2 depth validity / notes per sample; overlapping right-hand columns get "_align".
# Ids are normalised to strings so later joins and dict lookups use a single key type.
df = (
    df_valid
    .merge(
        df_align[["sample_id", "depth_valid_ratio", "notes"]],
        on="sample_id",
        how="left",
        suffixes=("", "_align"),
    )
    .astype({"site_id": str, "sample_id": str})
)

_FIGURE_FONT_SIZES = {
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
}
for _rc_key, _rc_value in _FIGURE_FONT_SIZES.items():
    plt.rcParams[_rc_key] = _rc_value

def save_fig(path: Path, dpi=300):
    """Save the current pyplot figure to ``path`` and close it.

    Parent directories are created as needed. The layout is tightened before
    saving, and the saved bounding box is trimmed (``bbox_inches="tight"``).
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def read_rgb(path: str):
    """Load an image file as an ``(H, W, 3)`` uint8 array.

    - 2-D (single-channel) images are replicated to three channels.
    - Images with fewer than three channels (e.g. luminance + alpha) have their
      first channel replicated to three.
    - Channels beyond the third (e.g. alpha in RGBA) are dropped.

    Values are cast with ``astype(np.uint8)`` without rescaling.
    NOTE(review): inputs with more than 8 bits per channel would wrap; confirm
    that all RGB files are 8-bit.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=2)
    elif arr.shape[2] < 3:
        # e.g. "LA": keep luminance only and replicate it, so callers always get 3 channels.
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    if arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr.astype(np.uint8)

def read_depth(path: str):
    """Load a depth map as a 2-D array in the file's native dtype (no rescaling).

    Multi-channel files are reduced to their first channel.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    return arr

def read_mask(path: str):
    """Load a mask as a 2-D uint8 {0, 1} array.

    A pixel is foreground (1) if any of its channels is > 0.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 3:
        return np.any(arr > 0, axis=2).astype(np.uint8)
    return (arr > 0).astype(np.uint8)

def mask_area(mask01: np.ndarray):
    """Count strictly positive (foreground) pixels in ``mask01``; returned as a Python int."""
    foreground = mask01 > 0
    return int(foreground.sum())

def depth_valid_ratio(depth: np.ndarray):
    """Fraction of pixels that are finite and strictly positive; NaN for an empty array."""
    is_good = np.logical_and(depth > 0, np.isfinite(depth))
    n_total = depth.size
    if n_total == 0:
        return np.nan
    return float(np.count_nonzero(is_good)) / float(n_total)

def fill_depth_holes(depth: np.ndarray):
    """Inpaint invalid depth pixels (<= 0 or non-finite) with OpenCV TELEA, radius 3.

    Returns ``(depth_float32, note)``:

    - ``note == "too_few_valid"`` when fewer than 10 pixels are valid; the input is then
      returned unfilled (cast to float32).
    - Otherwise ``note == ""``. Any non-finite values in the inpainted result are set to 0.
    """
    depth_f = depth.astype(np.float32)
    good = np.isfinite(depth_f) & (depth_f > 0)
    if np.count_nonzero(good) < 10:
        return depth_f, "too_few_valid"
    holes = ~good
    # Zero the holes so NaN/inf never reach the inpainting kernel.
    seeded = np.where(holes, np.float32(0.0), depth_f)
    result = cv2.inpaint(seeded, holes.astype(np.uint8), 3, cv2.INPAINT_TELEA)
    result[~np.isfinite(result)] = 0.0
    return result, ""

def depth_to_uint8_for_viz(depth: np.ndarray):
    """Map depth to a uint8 image for display, using a robust p1-p99 window over valid pixels.

    A pixel is valid when it is finite and > 0. Valid pixels are linearly mapped from
    [p1, p99] to [0, 255] and clipped. Invalid pixels (<= 0, NaN, +/-inf) are forced
    to 0. Without this, +inf would clip to 255 and the NaN -> uint8 cast is undefined.

    Returns all zeros when fewer than 10 pixels are valid, or when the p1-p99 window is
    degenerate (< 1e-6 wide).
    """
    d = depth.astype(np.float32)
    valid = (d > 0) & np.isfinite(d)
    if np.count_nonzero(valid) < 10:
        return np.zeros_like(d, dtype=np.uint8)
    vv = d[valid]
    mn, mx = float(np.percentile(vv, 1)), float(np.percentile(vv, 99))
    if mx - mn < 1e-6:
        return np.zeros_like(d, dtype=np.uint8)
    dn = np.zeros_like(d)
    dn[valid] = np.clip((vv - mn) / (mx - mn), 0.0, 1.0)
    return (dn * 255.0).astype(np.uint8)

def resize_triplet(rgb: np.ndarray, depth: np.ndarray, mask01: np.ndarray, out_hw):
    """Resize an (rgb, depth, mask) triplet to ``out_hw = (height, width)``.

    Interpolation:
    - RGB: bilinear.
    - Depth: cast to float32, nearest-neighbour, so no interpolated depth values are invented.
    - Mask: nearest-neighbour, re-binarised to uint8 {0, 1}.

    Returns ``(rgb, depth, mask01, info)``:
    - ``info["resize_mode"]`` is ``"none"`` or ``"resize"``.
    - ``info["scale"]`` is the height scale (kept for compatibility).
    - ``info["scale_h"]`` / ``info["scale_w"]`` give the per-axis scales.

    When all three inputs already have the target height/width they are returned
    unchanged (depth keeps its original dtype in that case).
    """
    oh, ow = int(out_hw[0]), int(out_hw[1])
    rh, rw = rgb.shape[0], rgb.shape[1]
    # Check every modality, not only RGB, so a shape mismatch is never passed through.
    if all(a.shape[:2] == (oh, ow) for a in (rgb, depth, mask01)):
        return rgb, depth, mask01, {"resize_mode": "none", "scale": 1.0, "scale_h": 1.0, "scale_w": 1.0}
    rgb_r = cv2.resize(rgb, (ow, oh), interpolation=cv2.INTER_LINEAR)
    depth_r = cv2.resize(depth.astype(np.float32), (ow, oh), interpolation=cv2.INTER_NEAREST)
    mask_r = cv2.resize(mask01.astype(np.uint8), (ow, oh), interpolation=cv2.INTER_NEAREST)
    scale_h = float(oh) / float(rh) if rh > 0 else np.nan
    scale_w = float(ow) / float(rw) if rw > 0 else np.nan
    info = {"resize_mode": "resize", "scale": scale_h, "scale_h": scale_h, "scale_w": scale_w}
    return rgb_r, depth_r, (mask_r > 0).astype(np.uint8), info

def choose_target_hw(df_in: pd.DataFrame):
    """Pick the most common (height, width) from the ``rgb_shape`` text column.

    Each entry's first two integers are read as (h, w); entries with fewer than two
    integers are ignored. Ties are broken by the smaller h, then the smaller w.
    Falls back to (480, 640) when nothing parses.
    """
    shape_counts = Counter()
    for text in df_in["rgb_shape"].astype(str).tolist():
        digits = re.findall(r"\d+", text)
        if len(digits) < 2:
            continue
        shape_counts[(int(digits[0]), int(digits[1]))] += 1
    if not shape_counts:
        return (480, 640)
    (best_h, best_w), _ = min(shape_counts.items(), key=lambda item: (-item[1], item[0][0], item[0][1]))
    return (int(best_h), int(best_w))

TARGET_HW = choose_target_hw(df_valid)

# One row per preprocessing step, in execution order. The Phase 4 cell re-parses the
# "target_hw=(H, W)" text in the resize row, so keep that format stable.
_CONFIG_COLUMNS = ["step", "modality", "operation", "parameters", "rationale"]
_CONFIG_ROWS = [
    ("load", "rgb", "PIL_load_to_numpy_uint8",
     "keep_first_3_channels",
     "standardize channel layout"),
    ("load", "depth", "PIL_load_to_numpy",
     "use_first_channel_if_multichannel",
     "standardize depth as single channel"),
    ("load", "mask", "PIL_load_to_binary_uint8",
     "mask=any(channel>0)",
     "standardize mask as binary"),
    ("align_shape", "all", "crop_to_min_common_hw",
     "h=min(h_rgb,h_depth,h_mask), w=min(w_rgb,w_depth,w_mask)",
     "avoid shape mismatch failures"),
    ("resize", "all", "resize_to_target_hw",
     f"target_hw={TARGET_HW}, rgb=bilinear, depth=nearest, mask=nearest",
     "consistent input size for all models"),
    ("depth_clean", "depth", "inpaint_invalid_depth",
     "invalid=(depth<=0 or nonfinite), method=TELEA, radius=3",
     "reduce holes and stabilize gradients"),
    ("depth_scale", "depth", "robust_minmax_for_visualization",
     "viz only: p1-p99 to uint8",
     "paper figures and debugging"),
]
PREPROCESS_CONFIG = [dict(zip(_CONFIG_COLUMNS, row)) for row in _CONFIG_ROWS]

TAB_01 = PHASE3_DIR / "tab_01_preprocess_config_table.csv"
pd.DataFrame(PREPROCESS_CONFIG, columns=_CONFIG_COLUMNS).to_csv(TAB_01, index=False)

# Apply the Phase 3 pipeline to every valid sample and record depth validity / mask
# area before and after. The first N_EXAMPLES ids (by sample_id) keep their arrays
# for the before/after figure.
stats_rows = []
example_rows = []

N_EXAMPLES = 8
EXAMPLE_SAMPLE_IDS = df.sort_values("sample_id", kind="mergesort").head(N_EXAMPLES)["sample_id"].tolist()
_example_id_set = set(EXAMPLE_SAMPLE_IDS)


def _stats_row(sample_id, site_id, dvr_before, dvr_after, area_before, area_after, notes):
    """Build one tab_02 record (column order fixed here)."""
    return {
        "sample_id": sample_id,
        "site_id": site_id,
        "depth_valid_ratio_before": dvr_before,
        "depth_valid_ratio_after": dvr_after,
        "mask_area_px_before": area_before,
        "mask_area_px_after": area_after,
        "notes": notes,
    }


_loop_columns = ["sample_id", "site_id", "rgb_path", "depth_path", "mask_path"]
for sample_id, site_id, rgb_path, depth_path, mask_path in df[_loop_columns].itertuples(index=False, name=None):
    try:
        rgb0 = read_rgb(rgb_path)
        depth0 = read_depth(depth_path)
        mask0 = read_mask(mask_path)
    except Exception as exc:
        # Unreadable sample: keep a row so it is visible in the stats table.
        stats_rows.append(_stats_row(
            sample_id, site_id,
            dvr_before=np.nan, dvr_after=np.nan,
            area_before=np.nan, area_after=np.nan,
            notes=f"read_error:{type(exc).__name__}",
        ))
        continue

    # Crop every modality to the common top-left window before anything else.
    h = min(a.shape[0] for a in (rgb0, depth0, mask0))
    w = min(a.shape[1] for a in (rgb0, depth0, mask0))
    rgb0, depth0, mask0 = rgb0[:h, :w], depth0[:h, :w], mask0[:h, :w]

    dvr_before, area_before = depth_valid_ratio(depth0), mask_area(mask0)

    rgb1, depth1, mask1, rz_info = resize_triplet(rgb0, depth0, mask0, TARGET_HW)
    depth2, note_fill = fill_depth_holes(depth1)

    dvr_after, area_after = depth_valid_ratio(depth2), mask_area(mask1)

    notes = [note_fill] if note_fill else []
    if isinstance(rz_info, dict) and rz_info.get("resize_mode", "") != "none":
        notes.append(f"resized_to={TARGET_HW}")

    stats_rows.append(_stats_row(
        sample_id, site_id,
        dvr_before=dvr_before, dvr_after=dvr_after,
        area_before=area_before, area_after=area_after,
        notes=";".join(notes),
    ))

    if sample_id in _example_id_set:
        example_rows.append((sample_id, site_id, rgb0, depth0, mask0, rgb1, depth1, mask1, depth2))

df_stats = pd.DataFrame(stats_rows).sort_values("sample_id", kind="mergesort").reset_index(drop=True)
TAB_02 = PHASE3_DIR / "tab_02_preprocess_stats_full.csv"
df_stats.to_csv(TAB_02, index=False)

# Site-disjoint split. Site ids are shuffled with a fixed seed and cut by fraction, so
# no site contributes samples to more than one split (prevents scene leakage).
SPLIT_SEED = 1337
TRAIN_FRAC = 0.70
VAL_FRAC = 0.15
TEST_FRAC = 0.15

sites = sorted(df["site_id"].unique().tolist())
rng = np.random.default_rng(SPLIT_SEED)
sites_shuffled = sites.copy()
rng.shuffle(sites_shuffled)

n_sites = len(sites_shuffled)
n_train = int(round(TRAIN_FRAC * n_sites))
n_val = int(round(VAL_FRAC * n_sites))
n_test = n_sites - n_train - n_val
# Guarantee at least one test site, taking it from val when val can spare one.
if n_test < 1:
    n_test = 1
    if n_val > 1:
        n_val -= 1
    else:
        n_train -= 1
# With few sites, rounding can leave val empty (e.g. 3 sites -> 2/0/1), which makes
# validation-based model selection impossible. Give val one site, taken from the
# larger of train/test. For >= 4 sites at the default fractions this never triggers.
if n_sites >= 3 and n_val < 1:
    n_val = 1
    if n_train >= n_test:
        n_train -= 1
    else:
        n_test -= 1

train_sites = sorted(sites_shuffled[:n_train])
val_sites = sorted(sites_shuffled[n_train:n_train + n_val])
test_sites = sorted(sites_shuffled[n_train + n_val:])

site_to_split = {s: "train" for s in train_sites}
site_to_split.update({s: "val" for s in val_sites})
site_to_split.update({s: "test" for s in test_sites})

df_split = df[["sample_id", "site_id"]].copy()
df_split["split"] = df_split["site_id"].map(site_to_split).astype(str)
df_split = df_split.sort_values("sample_id", kind="mergesort").reset_index(drop=True)

TAB_03 = PHASE3_DIR / "tab_03_split_manifest.csv"
df_split.to_csv(TAB_03, index=False)

# Per-sample preprocessing stats annotated with their split label.
df_stats_m = df_stats.merge(df_split, on=["sample_id", "site_id"], how="left")

# Evaluation protocol as (item, value) pairs, written as tab_04.
_PROTOCOL_ITEMS = [
    ("leakage_control", "site_disjoint_splits"),
    ("split_seed", str(SPLIT_SEED)),
    ("split_fractions_by_site", f"train={TRAIN_FRAC}, val={VAL_FRAC}, test={TEST_FRAC}"),
    ("target_hw", str(TARGET_HW)),
    ("mask_threshold", "mask_binary_nonzero"),
    ("depth_invalid_definition", "depth<=0 or nonfinite"),
    ("depth_hole_fill", "inpaint_TELEA_radius3"),
    ("alignment_metric_reference", "Phase2 alignment_error with tol_px=3"),
    ("primary_metrics", "IoU, Dice, Boundary_F1"),
    ("aux_metrics", "MAE_depth_valid_ratio, outlier_rate"),
    ("postprocessing", "keep_largest_component optional, threshold=0.5"),
    ("threshold_policy", "fixed 0.5 for masks, report sensitivity in ablation"),
]
PROTOCOL = [{"item": item, "value": value} for item, value in _PROTOCOL_ITEMS]

TAB_04 = PHASE3_DIR / "tab_04_protocol_metrics_table.csv"
pd.DataFrame(PROTOCOL, columns=["item", "value"]).to_csv(TAB_04, index=False)

def _mask_boundary(mask01):
    """Return a uint8 {0, 1} edge map of ``mask01`` (morphological gradient, 3x3 ellipse)."""
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    gradient = cv2.morphologyEx(mask01.astype(np.uint8), cv2.MORPH_GRADIENT, kernel)
    return (gradient > 0).astype(np.uint8)


def _draw_rgb_with_boundary(ax, rgb, mask01, title):
    """Show ``rgb`` on ``ax`` with the mask boundary as a red contour, axes hidden."""
    ax.imshow(rgb)
    ax.contour(_mask_boundary(mask01), levels=[0.5], colors=["red"], linewidths=1.0)
    ax.set_title(title)
    ax.axis("off")


# Grid of example samples, `cols` per row. Each sample occupies two stacked cells:
# the raw cropped input on top and the resized result below.
ex = example_rows
cols = 4
rows_grid = int(math.ceil(len(ex) / cols))
plt.figure(figsize=(4 * cols, 3.2 * rows_grid * 2))

for k, (sample_id, site_id, rgb0, depth0, mask0, rgb1, depth1, mask1, depth2) in enumerate(ex):
    r, c = divmod(k, cols)
    ax1 = plt.subplot(rows_grid * 2, cols, r * 2 * cols + c + 1)
    _draw_rgb_with_boundary(ax1, rgb0, mask0, f"before {sample_id}")
    ax2 = plt.subplot(rows_grid * 2, cols, (r * 2 + 1) * cols + c + 1)
    _draw_rgb_with_boundary(ax2, rgb1, mask1, "after")

save_fig(PHASE3_DIR / "fig_01_preprocessing_before_after_examples.png", dpi=300)

# Pixel-level depth distribution for (at most) the first 300 samples by id:
# raw valid depth values vs. values after hole filling.
depth_before_vals = []
depth_after_vals = []

sample_for_hist = df.sort_values("sample_id", kind="mergesort").head(300)["sample_id"].tolist()
lut = dict(zip(df["sample_id"], df["depth_path"]))


def _positive_finite_values(arr):
    """Return the finite, strictly positive entries of ``arr`` as float32, or None if there are none."""
    keep = (arr > 0) & np.isfinite(arr)
    if np.count_nonzero(keep) == 0:
        return None
    return arr[keep].astype(np.float32)


for sid in sample_for_hist:
    dp = lut[sid]
    try:
        d0 = read_depth(dp)
        before_part = _positive_finite_values(d0)
        if before_part is not None:
            depth_before_vals.append(before_part)
        d1, _ = fill_depth_holes(d0)
        after_part = _positive_finite_values(d1)
        if after_part is not None:
            depth_after_vals.append(after_part)
    except Exception:
        # Best effort: files that fail to load or fill are left out of the histogram.
        continue

depth_before_vals = np.concatenate(depth_before_vals) if depth_before_vals else np.array([], dtype=np.float32)
depth_after_vals = np.concatenate(depth_after_vals) if depth_after_vals else np.array([], dtype=np.float32)

_depth_hist_title = "Depth distribution before and after cleaning"
plt.figure(figsize=(9, 4.8))
if depth_before_vals.size and depth_after_vals.size:
    # Shared p1-p99 range so both histograms use identical bins.
    p1b, p99b = np.percentile(depth_before_vals, [1, 99])
    p1a, p99a = np.percentile(depth_after_vals, [1, 99])
    lo, hi = float(min(p1b, p1a)), float(max(p99b, p99a))
    bins = 60
    for series, label in ((depth_before_vals, "before"), (depth_after_vals, "after")):
        plt.hist(series, bins=bins, range=(lo, hi), alpha=0.6, label=label)
    plt.title(_depth_hist_title)
    plt.xlabel("Depth value")
    plt.ylabel("Count")
    plt.legend()
else:
    plt.title(_depth_hist_title)
    plt.text(0.5, 0.5, "Insufficient depth values for histogram", ha="center", va="center")
    plt.axis("off")

save_fig(PHASE3_DIR / "fig_02_depth_distribution_before_after.png", dpi=300)

# Stacked bars: number of samples per site, coloured by split. Because the split is
# site-disjoint, each bar should contain a single colour.
split_order = ["train", "val", "test"]
site_order = sorted(df_split["site_id"].unique().tolist())
counts_by_site_split = df_split.groupby(["site_id", "split"], sort=True).size().reset_index(name="count")
pivot = (
    counts_by_site_split
    .pivot(index="site_id", columns="split", values="count")
    .reindex(site_order)
    .fillna(0)
    .reindex(columns=split_order)
    .fillna(0)
)

x = np.arange(len(pivot.index))
stack_heights = pivot[split_order].to_numpy(dtype=np.float64)
plt.figure(figsize=(12, 4.8))
bottom = np.zeros(len(pivot.index), dtype=np.float64)
for col_idx, sp in enumerate(split_order):
    segment = stack_heights[:, col_idx]
    plt.bar(x, segment, bottom=bottom, label=sp)
    bottom = bottom + segment
plt.title("Train val test sample counts by site")
plt.ylabel("Count")
plt.xticks(x, pivot.index.tolist(), rotation=45, ha="right")
plt.legend()
save_fig(PHASE3_DIR / "fig_03_split_distribution_by_site.png", dpi=300)

# Distribution of post-preprocessing mask area per split. Rows with no area
# (read errors) are dropped.
df_stats_m2 = df_stats_m.dropna(subset=["mask_area_px_after"]).copy()
split_area_vals = {
    sp: df_stats_m2.loc[df_stats_m2["split"] == sp, "mask_area_px_after"].to_numpy(dtype=np.float64)
    for sp in ["train", "val", "test"]
}

plt.figure(figsize=(9, 4.8))
bins = 60
_non_empty_areas = [v for v in split_area_vals.values() if v.size]
if _non_empty_areas:
    # One set of bin edges for all splits. With bins=60 per split, each histogram
    # would get its own edges and the overlaid bars would not be comparable.
    area_bin_edges = np.histogram_bin_edges(np.concatenate(_non_empty_areas), bins=bins)
    for sp, vals in split_area_vals.items():
        if vals.size:
            plt.hist(vals, bins=area_bin_edges, alpha=0.6, label=sp)
    plt.legend()
else:
    plt.text(0.5, 0.5, "No mask areas available", ha="center", va="center")
plt.title("Object area distribution across splits")
plt.xlabel("Mask object area in pixels after preprocessing")
plt.ylabel("Count")
save_fig(PHASE3_DIR / "fig_04_object_size_by_split.png", dpi=300)

# Leakage guard: the three site sets must be pairwise disjoint.
train_sites_set = set(train_sites)
val_sites_set = set(val_sites)
test_sites_set = set(test_sites)
assert train_sites_set.isdisjoint(val_sites_set), f"train/val share sites: {sorted(train_sites_set & val_sites_set)}"
assert train_sites_set.isdisjoint(test_sites_set), f"train/test share sites: {sorted(train_sites_set & test_sites_set)}"
assert val_sites_set.isdisjoint(test_sites_set), f"val/test share sites: {sorted(val_sites_set & test_sites_set)}"
# Every sample must have one of the three labels. An unmapped site would show up as "nan"
# after the .map(...).astype(str) in the split step.
assert df_split["split"].isin(["train", "val", "test"]).all(), "some samples have no split assignment"



In [None]:
import os
import time
import json
import math
import random
import re
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

try:
    import timm
    _HAS_TIMM = True
except Exception:
    _HAS_TIMM = False

WORK_ROOT = Path("/kaggle/working")
PHASE1_DIR = WORK_ROOT / "results" / "01_data_integrity_and_profile"
PHASE2_DIR = WORK_ROOT / "results" / "02_sensor_alignment_and_quality"
PHASE3_DIR = WORK_ROOT / "results" / "03_preprocess_and_split_protocol"
PHASE4_DIR = WORK_ROOT / "results" / "04_training_baselines_and_vit_fusion"
PHASE4_DIR.mkdir(parents=True, exist_ok=True)

MANIFEST_VALID_PATH = PHASE1_DIR / "tab_01_dataset_manifest_valid_only.csv"
SPLIT_MANIFEST_PATH = PHASE3_DIR / "tab_03_split_manifest.csv"
PREPROCESS_CONFIG_PATH = PHASE3_DIR / "tab_01_preprocess_config_table.csv"
FIG_SAMPLE_IDS_PATH = PHASE2_DIR / "tab_04_figure_sample_ids.csv"  # not existence-checked here

if not MANIFEST_VALID_PATH.exists():
    raise FileNotFoundError(str(MANIFEST_VALID_PATH))
if not SPLIT_MANIFEST_PATH.exists():
    raise FileNotFoundError(str(SPLIT_MANIFEST_PATH))
if not PREPROCESS_CONFIG_PATH.exists():
    raise FileNotFoundError(str(PREPROCESS_CONFIG_PATH))

df_valid = pd.read_csv(MANIFEST_VALID_PATH).sort_values("sample_id", kind="mergesort").reset_index(drop=True)
df_split = pd.read_csv(SPLIT_MANIFEST_PATH).sort_values("sample_id", kind="mergesort").reset_index(drop=True)
df = df_valid.merge(df_split, on=["sample_id", "site_id"], how="inner")
df = df.sort_values("sample_id", kind="mergesort").reset_index(drop=True)
if len(df) != len(df_valid):
    # The inner join silently drops samples whose (sample_id, site_id) key is missing from the split.
    print(f"WARNING: merged manifest has {len(df)} rows vs {len(df_valid)} valid samples; check sample_id/site_id keys")

# Recover the resolution chosen in Phase 3 from its "target_hw=(H, W), ..." parameter text.
TARGET_HW = None
_target_hw_error = None
try:
    cfg = pd.read_csv(PREPROCESS_CONFIG_PATH)
    for s in cfg["parameters"].astype(str).tolist():
        if "target_hw=" in s:
            t = s.split("target_hw=")[1].strip()
            m = re.search(r"\((\d+)\s*,\s*(\d+)\)", t)
            if m:
                TARGET_HW = (int(m.group(1)), int(m.group(2)))
                break
except Exception as exc:
    TARGET_HW = None
    _target_hw_error = exc
if TARGET_HW is None:
    # Keep the best-effort fallback, but make it visible. Training at a different resolution
    # than Phase 3 preprocessing should never happen silently.
    TARGET_HW = (422, 640)
    if _target_hw_error is not None:
        _reason = f"{type(_target_hw_error).__name__}: {_target_hw_error}"
    else:
        _reason = "no target_hw=(H, W) entry found"
    print(f"WARNING: could not read target_hw from {PREPROCESS_CONFIG_PATH} ({_reason}); falling back to {TARGET_HW}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 1337


def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch (CPU and all CUDA devices), and make cuDNN deterministic."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(SEED)

# Uniform small font sizes for all Phase 4 figures.
_FIG_FONT_SIZES = {
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
}
plt.rcParams.update(_FIG_FONT_SIZES)

def save_fig(path: Path, dpi=300):
    """Save the current pyplot figure to ``path`` and close it.

    Parent directories are created as needed. The layout is tightened before
    saving, and the saved bounding box is trimmed (``bbox_inches="tight"``).
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    plt.tight_layout()
    plt.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close()

def read_rgb(path: str):
    """Load an image file as an ``(H, W, 3)`` uint8 array.

    - 2-D (single-channel) images are replicated to three channels.
    - Images with fewer than three channels (e.g. luminance + alpha) have their
      first channel replicated to three.
    - Channels beyond the third (e.g. alpha in RGBA) are dropped.

    Values are cast with ``astype(np.uint8)`` without rescaling.
    NOTE(review): inputs with more than 8 bits per channel would wrap; confirm
    that all RGB files are 8-bit.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise, which leaks
    # descriptors in long-running DataLoader workers.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=2)
    elif arr.shape[2] < 3:
        # e.g. "LA": keep luminance only and replicate it, so callers always get 3 channels.
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    if arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr.astype(np.uint8)

def read_depth(path: str):
    """Load a depth map as a 2-D array in the file's native dtype (no rescaling).

    Multi-channel files are reduced to their first channel.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    return arr

def read_mask(path: str):
    """Load a mask as a 2-D uint8 {0, 1} array.

    A pixel is foreground (1) if any of its channels is > 0.
    """
    # The context manager guarantees the file handle is released. Pillow can keep
    # multi-frame files (e.g. TIFF) open after load() otherwise.
    with Image.open(path) as img:
        img.load()
        arr = np.array(img)
    if arr.ndim == 3:
        return np.any(arr > 0, axis=2).astype(np.uint8)
    return (arr > 0).astype(np.uint8)

def crop_to_min_hw(rgb, depth, mask):
    """Crop all three arrays to their common top-left window (min height x min width).

    Returns a ``(rgb, depth, mask)`` tuple of views.
    """
    arrays = (rgb, depth, mask)
    common_h = min(a.shape[0] for a in arrays)
    common_w = min(a.shape[1] for a in arrays)
    return tuple(a[:common_h, :common_w] for a in arrays)

def resize_triplet(rgb, depth, mask01, out_hw):
    """Resize an (rgb, depth, mask) triplet to ``out_hw = (height, width)``.

    - RGB: bilinear.
    - Depth: cast to float32, nearest-neighbour.
    - Mask: nearest-neighbour, returned as uint8 {0, 1}.
    """
    dsize = (int(out_hw[1]), int(out_hw[0]))  # OpenCV expects (width, height)
    rgb_out = cv2.resize(rgb, dsize, interpolation=cv2.INTER_LINEAR)
    depth_out = cv2.resize(depth.astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST)
    mask_out = cv2.resize(mask01.astype(np.uint8), dsize, interpolation=cv2.INTER_NEAREST) > 0
    return rgb_out, depth_out, mask_out.astype(np.uint8)

def fill_depth_holes(depth: np.ndarray):
    """Inpaint invalid depth pixels (<= 0 or non-finite) with OpenCV TELEA, radius 3.

    Returns a float32 array. With fewer than 10 valid pixels the input is returned
    unfilled (cast to float32). Non-finite values in the inpainted result are set to 0.
    """
    depth_f = depth.astype(np.float32)
    good = np.isfinite(depth_f) & (depth_f > 0)
    if np.count_nonzero(good) < 10:
        return depth_f
    holes = ~good
    # Zero the holes so NaN/inf never reach the inpainting kernel.
    seeded = np.where(holes, np.float32(0.0), depth_f)
    result = cv2.inpaint(seeded, holes.astype(np.uint8), 3, cv2.INPAINT_TELEA)
    result[~np.isfinite(result)] = 0.0
    return result

def robust_depth_scale(depth: np.ndarray):
    """Scale depth to [0, 1] float32 using the p1-p99 range of valid (finite, > 0) pixels.

    - Invalid pixels are set to 0.
    - Fewer than 10 valid pixels: returns all zeros.
    - Degenerate window (< 1e-6 wide): valid pixels become 0.5.
    """
    d = depth.astype(np.float32)
    valid = np.isfinite(d) & (d > 0)
    if np.count_nonzero(valid) < 10:
        return np.zeros_like(d, dtype=np.float32)
    lo, hi = (float(np.percentile(d[valid], q)) for q in (1, 99))
    span = hi - lo
    if span < 1e-6:
        flat = np.zeros_like(d, dtype=np.float32)
        flat[valid] = 0.5
        return flat
    scaled = np.clip((d - lo) / span, 0.0, 1.0)
    scaled[~valid] = 0.0
    return scaled.astype(np.float32)

def augment_pair(rgb_u8, depth_f, mask_u8, rng: np.random.Generator):
    if rng.random() < 0.5:
        rgb_u8 = np.ascontiguousarray(rgb_u8[:, ::-1])
        depth_f = np.ascontiguousarray(depth_f[:, ::-1])
        mask_u8 = np.ascontiguousarray(mask_u8[:, ::-1])
    if rng.random() < 0.25:
        factor = 0.8 + 0.4 * rng.random()
        rgb_f = rgb_u8.astype(np.float32) * factor
        rgb_u8 = np.clip(rgb_f, 0, 255).astype(np.uint8)
    if rng.random() < 0.15:
        noise = rng.normal(0, 3.0, size=rgb_u8.shape).astype(np.float32)
        rgb_u8 = np.clip(rgb_u8.astype(np.float32) + noise, 0, 255).astype(np.uint8)
    return rgb_u8, depth_f, mask_u8

class ISODSegDataset(Dataset):
    """Segmentation dataset over one split of the ISOD manifest.

    Each item is ``(x, mask, sample_id)``:
    - ``x``: the RGB image scaled to [0, 1] as a ``(3, H, W)`` float tensor, with the
      robust-scaled depth appended as a 4th channel when ``use_depth`` is True.
    - ``mask``: a ``(1, H, W)`` float tensor of 0/1.
    - ``(H, W)`` is ``target_hw``.

    Random augmentation is applied only when ``split == "train"``.

    Args:
        df_in: manifest with ``split``, ``sample_id``, ``rgb_path``, ``depth_path`` and
            ``mask_path`` columns.
        split: split label to keep (rows are sorted by ``sample_id``).
        target_hw: output ``(height, width)``.
        seed: base seed for augmentation randomness.
        use_depth: whether to concatenate the depth channel to the input.
    """
    def __init__(self, df_in: pd.DataFrame, split: str, target_hw: Tuple[int,int], seed: int, use_depth: bool):
        self.df = df_in[df_in["split"] == split].sort_values("sample_id", kind="mergesort").reset_index(drop=True)
        self.target_hw = target_hw
        self.use_depth = use_depth
        self.rng = np.random.default_rng(seed + (0 if split == "train" else 999))
        self.split = split
        # Worker-local augmentation RNG, created lazily inside each DataLoader worker.
        self._seed = int(seed)
        self._worker_rng = None
        self._worker_rng_key = None

    def __len__(self):
        return len(self.df)

    def _augment_rng(self) -> np.random.Generator:
        """Return the RNG to use for augmentation in the current process.

        In the main process (``num_workers=0``) this is ``self.rng``.

        DataLoader workers each receive a copy of the dataset. Reusing ``self.rng``
        there would give every worker the same random stream. With non-persistent
        workers (re-created each epoch from the unchanged parent copy), every epoch
        would also replay that stream.

        Inside a worker, a generator is therefore seeded from
        ``(seed, worker_info.seed)``. PyTorch derives ``worker_info.seed`` per worker
        from the main-process torch RNG, so streams differ across workers and epochs
        yet stay reproducible under ``torch.manual_seed``.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return self.rng
        if self._worker_rng is None or self._worker_rng_key != info.seed:
            self._worker_rng = np.random.default_rng([self._seed, int(info.seed)])
            self._worker_rng_key = info.seed
        return self._worker_rng

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        rgb = read_rgb(row["rgb_path"])
        depth = read_depth(row["depth_path"])
        mask = read_mask(row["mask_path"])
        rgb, depth, mask = crop_to_min_hw(rgb, depth, mask)
        rgb, depth, mask = resize_triplet(rgb, depth, mask, self.target_hw)
        depth = fill_depth_holes(depth)
        depth_scaled = robust_depth_scale(depth)
        if self.split == "train":
            rgb, depth_scaled, mask = augment_pair(rgb, depth_scaled, mask, self._augment_rng())
        rgb_t = torch.from_numpy(rgb.astype(np.float32) / 255.0).permute(2,0,1)
        mask_t = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)
        dep_t = torch.from_numpy(depth_scaled.astype(np.float32)).unsqueeze(0)
        if self.use_depth:
            x = torch.cat([rgb_t, dep_t], dim=0)
        else:
            x = rgb_t
        return x, mask_t, str(row["sample_id"])

def sigmoid_thresh(x, thr=0.5):
    """Binarise logits: 1.0 where sigmoid(x) >= thr, else 0.0 (float32 tensor)."""
    probs = torch.sigmoid(x)
    return probs.ge(thr).float()

def dice_coeff(pred01, gt01, eps=1e-6):
    """Mean per-sample Dice of ``(N, C, H, W)`` binary tensors, as a Python float.

    ``eps`` smooths the ratio, so two empty masks score 1.0.
    """
    per_sample_dims = (1, 2, 3)
    overlap = (pred01 * gt01).sum(dim=per_sample_dims)
    total = pred01.sum(dim=per_sample_dims) + gt01.sum(dim=per_sample_dims)
    per_sample = (2.0 * overlap + eps) / (total + eps)
    return per_sample.mean().item()

def iou_score(pred01, gt01, eps=1e-6):
    """Mean per-sample IoU of ``(N, C, H, W)`` binary tensors, as a Python float.

    ``eps`` smooths the ratio, so two empty masks score 1.0.
    """
    per_sample_dims = (1, 2, 3)
    overlap = (pred01 * gt01).sum(dim=per_sample_dims)
    covered = (pred01 + gt01).gt(0).float().sum(dim=per_sample_dims)
    per_sample = (overlap + eps) / (covered + eps)
    return per_sample.mean().item()

class DiceLoss(nn.Module):
    """Soft Dice loss on logits: ``1 - soft Dice``, averaged over the batch.

    Sums run over dims (1, 2, 3), i.e. per sample; ``eps`` smooths the ratio.
    """
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = logits.sigmoid()
        reduce_dims = (1, 2, 3)
        numerator = 2.0 * (probs * targets).sum(dim=reduce_dims) + self.eps
        denominator = probs.sum(dim=reduce_dims) + targets.sum(dim=reduce_dims) + self.eps
        return (1.0 - numerator / denominator).mean()

class ConvBlock(nn.Module):
    """Two stages of 3x3 conv (no bias) -> BatchNorm -> ReLU.

    ``padding=1`` keeps the spatial size. The layers live in ``self.net``
    (state_dict keys ``net.0``..``net.5``).
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def center_crop_or_pad_to(x, ref):
    """Center-crop and/or replicate-pad 4-D ``x`` so its (H, W) matches ``ref``.

    - Oversized dims are cropped symmetrically; an odd extra pixel is dropped from
      the bottom/right.
    - Undersized dims are edge-replicated symmetrically; an odd extra pixel is added
      on the bottom/right.
    """
    _, _, h, w = x.shape
    _, _, hr, wr = ref.shape
    if (h, w) == (hr, wr):
        return x
    if h > hr or w > wr:
        top = max(0, h - hr) // 2
        left = max(0, w - wr) // 2
        x = x[:, :, top:top + hr, left:left + wr]
    pad_h = max(0, hr - x.shape[2])
    pad_w = max(0, wr - x.shape[3])
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2), mode="replicate")
    return x

class SimpleUNet(nn.Module):
    """Three-level U-Net baseline producing 1-channel logits at the input resolution.

    Channel widths are base, 2*base, 4*base, with an 8*base bottleneck. Pooling uses
    ceil_mode, so odd sizes are kept. Skip tensors are center-cropped/padded to the
    upsampled size before concatenation, and the head output is bilinearly resized
    back to the input (H, W).
    """
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8
        # Encoder (attribute order fixes parameter init order and state_dict keys).
        self.enc1 = ConvBlock(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc2 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc3 = ConvBlock(c2, c3)
        self.pool3 = nn.MaxPool2d(2, ceil_mode=True)
        self.bott = ConvBlock(c3, c4)
        # Decoder: upsample x2, concatenate the matching skip (doubling channels), refine.
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = ConvBlock(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = ConvBlock(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = ConvBlock(c2, c1)
        self.head = nn.Conv2d(c1, 1, 1)

    def forward(self, x):
        skips = []
        feat = x
        for enc, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3)):
            feat = enc(feat)
            skips.append(feat)
            feat = pool(feat)
        feat = self.bott(feat)

        stages = zip((self.up3, self.up2, self.up1), (self.dec3, self.dec2, self.dec1), reversed(skips))
        for up, dec, skip in stages:
            feat = up(feat)
            aligned_skip = center_crop_or_pad_to(skip, feat)
            feat = dec(torch.cat([feat, aligned_skip], dim=1))

        logits = self.head(feat)
        return F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)

class ViTBackbone(nn.Module):
    """Pretrained timm feature extractor that returns only the out_indices=(3,) feature map.

    ``out_ch`` is the channel count of that map, as reported by timm's ``feature_info``.
    Raises RuntimeError when timm is not installed.
    """
    def __init__(self, name: str, in_chans: int):
        super().__init__()
        if not _HAS_TIMM:
            raise RuntimeError("timm not available")
        # NOTE(review): plain timm ViTs are built for a fixed img_size (often 224) and may
        # reject other input sizes unless created with e.g. dynamic_img_size=True.
        # Confirm the chosen model accepts TARGET_HW-sized inputs.
        self.model = timm.create_model(name, pretrained=True, in_chans=in_chans, features_only=True, out_indices=(3,))
        channels = self.model.feature_info.channels()
        self.out_ch = channels[-1]

    def forward(self, x):
        return self.model(x)[-1]

class CrossAttentionFusion(nn.Module):
    """Residual cross-attention from RGB features (query) to depth features (key/value).

    - Tokens come from 1x1-conv projections flattened over H*W. The query tokens are
      LayerNorm-ed.
    - The attended result is reshaped to the RGB map's (H, W), passed through a 1x1
      conv, and added to ``frgb``.
    - The depth map's spatial size may differ from the RGB map's.
    - ``ch`` must be divisible by ``heads`` (``nn.MultiheadAttention`` requirement).
    """
    def __init__(self, ch: int, heads: int = 8):
        super().__init__()
        self.q = nn.Conv2d(ch, ch, 1, bias=False)
        self.k = nn.Conv2d(ch, ch, 1, bias=False)
        self.v = nn.Conv2d(ch, ch, 1, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim=ch, num_heads=heads, batch_first=True)
        self.proj = nn.Conv2d(ch, ch, 1, bias=False)
        self.norm = nn.LayerNorm(ch)

    def forward(self, frgb, fdep):
        b, c, h, w = frgb.shape

        def to_tokens(fmap):
            # (B, C, H, W) -> (B, H*W, C)
            return fmap.flatten(2).transpose(1, 2)

        query = self.norm(to_tokens(self.q(frgb)))
        key = to_tokens(self.k(fdep))
        value = to_tokens(self.v(fdep))
        attended, _ = self.attn(query, key, value, need_weights=False)
        fused = self.proj(attended.transpose(1, 2).reshape(b, c, h, w))
        return frgb + fused

class LiteDecoder(nn.Module):
    """Lightweight decoder from a low-resolution feature map to 1-channel logits.

    Three x2 transposed-conv upsamplings (x8 total) with conv refinement, channels
    in_ch -> 256 -> 128 -> 64 -> 32 -> 1. The result is then bilinearly resized to
    ``out_hw``.
    """
    def __init__(self, in_ch: int, out_hw: Tuple[int,int]):
        super().__init__()
        self.out_hw = out_hw
        # Attribute order fixes parameter init order and state_dict keys.
        self.conv1 = ConvBlock(in_ch, 256)
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = ConvBlock(128, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv3 = ConvBlock(64, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv4 = ConvBlock(32, 32)
        self.head = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        pipeline = (self.conv1, self.up1, self.conv2, self.up2, self.conv3, self.up3, self.conv4, self.head)
        for layer in pipeline:
            x = layer(x)
        return F.interpolate(x, size=self.out_hw, mode="bilinear", align_corners=False)

class ProposedViTFusion(nn.Module):
    """Two-stream RGB / depth backbone model with cross-attention fusion and a lite decoder.

    - The depth features are 1x1-projected to the RGB channel count when the two differ.
    - RGB features attend to depth features (8 heads).
    - The fused map is decoded to 1-channel logits of size ``img_hw``.
    """
    def __init__(self, rgb_vit_name: str, depth_vit_name: str, img_hw: Tuple[int,int]):
        super().__init__()
        if not _HAS_TIMM:
            raise RuntimeError("timm not available")
        self.rgb_enc = ViTBackbone(rgb_vit_name, in_chans=3)
        self.dep_enc = ViTBackbone(depth_vit_name, in_chans=1)
        fused_ch = self.rgb_enc.out_ch
        if self.dep_enc.out_ch == fused_ch:
            self.dep_proj = nn.Identity()
        else:
            self.dep_proj = nn.Conv2d(self.dep_enc.out_ch, fused_ch, 1, bias=False)
        self.fuse = CrossAttentionFusion(fused_ch, heads=8)
        self.dec = LiteDecoder(fused_ch, out_hw=img_hw)

    def forward(self, rgb3, dep1):
        rgb_feat = self.rgb_enc(rgb3)
        dep_feat = self.dep_proj(self.dep_enc(dep1))
        return self.dec(self.fuse(rgb_feat, dep_feat))

def count_params_m(model: nn.Module):
    """Total number of parameters (trainable or not), in millions."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total / 1e6

def try_flops_g(model: nn.Module, example_inputs: Tuple[torch.Tensor, ...]):
    """Best-effort GFLOPs estimate via fvcore's FlopCountAnalysis.

    Returns NaN when fvcore is unavailable or the analysis fails for any reason.
    """
    try:
        from fvcore.nn import FlopCountAnalysis
        gflops = float(FlopCountAnalysis(model, example_inputs).total()) / 1e9
    except Exception:
        gflops = np.nan
    return gflops

def get_gpu_name():
    """Name of CUDA device 0; "cpu" without CUDA, "cuda" if the name cannot be read."""
    if torch.cuda.is_available():
        try:
            return torch.cuda.get_device_name(0)
        except Exception:
            return "cuda"
    return "cpu"

@dataclass
class TrainConfig:
    """Hyper-parameters for one training run; passed as ``cfg`` to ``train_one_model``."""
    epochs: int  # number of epochs (training loop runs range(1, epochs + 1))
    batch_size: int  # DataLoader batch size, used for both the train and val loaders
    lr: float  # optimizer learning rate
    weight_decay: float  # weight decay passed to Adam / AdamW
    optimizer: str  # "adamw" (case-insensitive) selects AdamW; any other value selects Adam
    augmentations: str  # free-text label; NOTE(review): not read in the visible training code -- confirm it is only recorded
    seed: int  # NOTE(review): not read in the visible part of train_one_model -- confirm where it is applied
    grad_accum: int  # gradient-accumulation steps; the loss is divided by max(1, grad_accum)

def make_loader(split: str, batch_size: int, use_depth: bool):
    """Build ``(dataset, DataLoader)`` for one split of the global manifest ``df``.

    Only the train split is shuffled. Two worker processes are used, and the last
    partial batch is kept.

    Memory pinning is enabled only when training on a CUDA device. On CPU-only runs
    pinning has no benefit, and PyTorch warns when it is requested without an
    accelerator.
    """
    ds = ISODSegDataset(df, split=split, target_hw=TARGET_HW, seed=SEED, use_depth=use_depth)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=2,
        pin_memory=(DEVICE.type == "cuda"),
        drop_last=False,
    )
    return ds, dl

def train_one_model(model_key: str, model: nn.Module, use_depth_input: bool, cfg: TrainConfig):
    """Train one segmentation model, checkpoint every epoch and track the best val IoU.

    The loss is ``0.5 * BCEWithLogits + 0.5 * DiceLoss``. Mixed precision and a
    GradScaler are enabled when CUDA is available. Gradients are accumulated
    over ``cfg.grad_accum`` batches; values below 1 are treated as 1. After
    every epoch the model is evaluated on the "val" split with a 0.5 threshold,
    and a checkpoint is written to ``<PHASE4_DIR>/<model_key>/checkpoints``.

    Reported losses and metrics are per-sample averages: each batch is weighted
    by its size, so the smaller final batch (``drop_last=False``) does not count
    as much as a full batch.

    Args:
        model_key: output sub-directory under ``PHASE4_DIR``. The key
            ``"proposed_vit_fusion"`` splits the loader tensor into RGB
            (channels 0-2) and depth (channel 3) and calls ``model(rgb, dep)``.
            Any other key calls ``model(x)`` on the full tensor.
        model: network to train; it is moved to ``DEVICE``.
        use_depth_input: whether the loaders append the depth channel.
        cfg: training hyper-parameters.

    Returns:
        ``(info, df_logs)``. ``info`` holds the output dir, log path, best
        checkpoint path/epoch/metric (val IoU), wall-clock minutes and peak CUDA
        memory in GiB. ``df_logs`` is the per-epoch log, also written to
        ``epoch_logs.csv``.
    """
    out_dir = PHASE4_DIR / model_key
    out_dir.mkdir(parents=True, exist_ok=True)
    ckpt_dir = out_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)

    model = model.to(DEVICE)

    _, dl_tr = make_loader("train", cfg.batch_size, use_depth=use_depth_input)
    _, dl_va = make_loader("val", cfg.batch_size, use_depth=use_depth_input)

    if cfg.optimizer.lower() == "adamw":
        opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    else:
        opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    bce = nn.BCEWithLogitsLoss()
    dice = DiceLoss()
    use_amp = torch.cuda.is_available()
    scaler = torch.amp.GradScaler("cuda", enabled=use_amp)
    # Normalise once: the step condition below uses a modulo, which would raise
    # ZeroDivisionError for grad_accum == 0.
    accum = max(1, int(cfg.grad_accum))

    def forward_and_loss(x, y):
        """Run the model on one loader batch under autocast; return (logits, loss)."""
        with torch.amp.autocast("cuda", enabled=use_amp):
            if model_key == "proposed_vit_fusion":
                rgb = x[:, :3].to(DEVICE, non_blocking=True)
                dep = x[:, 3:4].to(DEVICE, non_blocking=True)
                logits = model(rgb, dep)
            else:
                logits = model(x.to(DEVICE, non_blocking=True))
            loss = 0.5 * bce(logits, y) + 0.5 * dice(logits, y)
        return logits, loss

    best_metric = -1.0
    best_epoch = -1
    best_path = ""

    logs = []

    start_time = time.time()
    if torch.cuda.is_available():
        torch.cuda.reset_peak_memory_stats()

    n_train_batches = len(dl_tr)
    for epoch in range(1, cfg.epochs + 1):
        # ---- train ----
        model.train()
        tr_loss_sum = 0.0
        tr_count = 0
        opt.zero_grad(set_to_none=True)

        for step, (x, y, _) in enumerate(dl_tr, start=1):
            y = y.to(DEVICE, non_blocking=True)
            _, loss = forward_and_loss(x, y)
            scaler.scale(loss / accum).backward()

            # Step every `accum` batches, and always on the last batch so no
            # accumulated gradient is left over at the end of the epoch.
            if (step % accum) == 0 or step == n_train_batches:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            bs = y.shape[0]
            tr_loss_sum += loss.item() * bs
            tr_count += bs

        tr_loss = tr_loss_sum / max(1, tr_count)

        # ---- validate ----
        model.eval()
        va_loss_sum = 0.0
        dice_sum = 0.0
        iou_sum = 0.0
        va_count = 0

        with torch.no_grad():
            for x, y, _ in dl_va:
                y = y.to(DEVICE, non_blocking=True)
                logits, loss = forward_and_loss(x, y)
                pred01 = sigmoid_thresh(logits, 0.5)
                bs = y.shape[0]
                # dice_coeff / iou_score return batch means; re-weight by batch size.
                va_loss_sum += loss.item() * bs
                dice_sum += dice_coeff(pred01, y) * bs
                iou_sum += iou_score(pred01, y) * bs
                va_count += bs

        va_loss = va_loss_sum / max(1, va_count)
        va_dice = float(dice_sum / va_count) if va_count else np.nan
        va_iou = float(iou_sum / va_count) if va_count else np.nan

        logs.append({
            "model": model_key,
            "epoch": epoch,
            "train_loss": float(tr_loss),
            "val_loss": float(va_loss),
            "val_miou": float(va_iou),
            "val_dice": float(va_dice),
        })

        ckpt_path = str(ckpt_dir / f"epoch_{epoch:03d}.pt")
        torch.save({
            "model_key": model_key,
            "epoch": epoch,
            "state_dict": model.state_dict(),
            "cfg": cfg.__dict__,
            "target_hw": TARGET_HW,
            "seed": cfg.seed
        }, ckpt_path)

        if (not np.isnan(va_iou)) and (va_iou > best_metric):
            best_metric = va_iou
            best_epoch = epoch
            best_path = ckpt_path

    train_time_min = (time.time() - start_time) / 60.0
    peak_vram_gb = (torch.cuda.max_memory_allocated() / (1024**3)) if torch.cuda.is_available() else 0.0

    df_logs = pd.DataFrame(logs)
    df_logs_path = out_dir / "epoch_logs.csv"
    df_logs.to_csv(df_logs_path, index=False)

    return {
        "model": model_key,
        "out_dir": str(out_dir),
        "logs_path": str(df_logs_path),
        "best_ckpt_path": best_path,
        "best_epoch": int(best_epoch),
        "best_metric": float(best_metric),
        "train_time_min": float(train_time_min),
        "peak_vram_gb": float(peak_vram_gb),
    }, df_logs

def predict_on_samples(model_key: str, model: nn.Module, ckpt_path: str, sample_ids: List[str]):
    """Load a checkpoint into `model` and predict binary masks for `sample_ids`.

    Each sample goes through the same deterministic preprocessing as the
    dataset (crop to the common size, resize to TARGET_HW, fill depth holes,
    robust depth scaling; no augmentation). The model output is thresholded at
    sigmoid >= 0.5.

    Returns:
        A list of ``(sample_id, rgb, gt_mask_uint8, pred_mask_uint8)`` tuples in
        the order of `sample_ids`. A missing id raises ``KeyError``.
    """
    net = model.to(DEVICE)
    state = torch.load(ckpt_path, map_location=DEVICE)
    net.load_state_dict(state["state_dict"], strict=True)
    net.eval()

    rows_by_id = {}
    for i in range(len(df)):
        rows_by_id[df.loc[i, "sample_id"]] = df.loc[i]

    results = []
    for sid in sample_ids:
        row = rows_by_id[sid]
        rgb, depth, mask = crop_to_min_hw(
            read_rgb(row["rgb_path"]),
            read_depth(row["depth_path"]),
            read_mask(row["mask_path"]),
        )
        rgb, depth, mask = resize_triplet(rgb, depth, mask, TARGET_HW)
        depth_scaled = robust_depth_scale(fill_depth_holes(depth))

        rgb_t = torch.from_numpy(rgb.astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        dep_t = torch.from_numpy(depth_scaled.astype(np.float32))[None, None].to(DEVICE)

        with torch.no_grad():
            if model_key == "proposed_vit_fusion":
                logits = net(rgb_t, dep_t)
            elif model_key == "baseline_a_rgb":
                logits = net(rgb_t)
            else:
                # Every other baseline takes RGB + depth stacked as 4 channels.
                logits = net(torch.cat([rgb_t, dep_t], dim=1))
            prob = torch.sigmoid(logits)[0, 0].detach().cpu().numpy()
            pred = (prob >= 0.5).astype(np.uint8)

        results.append((sid, rgb, mask.astype(np.uint8), pred))
    return results

def save_qual_grid(sample_ids: List[str], predA, predB, predP, path: Path):
    """Save a rows x 5 grid: RGB | GT | Baseline A | Baseline B | Proposed per sample.

    `predA`/`predB`/`predP` are lists of ``(sample_id, rgb, gt, pred)`` tuples as
    produced by ``predict_on_samples``. RGB and GT are taken from `predA`.
    """
    by_id_a = {sid: (rgb, gt, pr) for sid, rgb, gt, pr in predA}
    by_id_b = {sid: pr for sid, rgb, gt, pr in predB}
    by_id_p = {sid: pr for sid, rgb, gt, pr in predP}
    n_cols = 5
    n_rows = len(sample_ids)
    plt.figure(figsize=(3.2 * n_cols, 2.8 * n_rows))
    for row_idx, sid in enumerate(sample_ids):
        rgb, gt, pred_a = by_id_a[sid]
        # (image, colormap, title) for each column, left to right.
        panels = [
            (rgb, None, f"{sid}"),
            (gt, "gray", "GT"),
            (pred_a, "gray", "Baseline A"),
            (by_id_b[sid], "gray", "Baseline B"),
            (by_id_p[sid], "gray", "Proposed"),
        ]
        for col_idx, (image, cmap, title) in enumerate(panels, start=1):
            ax = plt.subplot(n_rows, n_cols, row_idx * n_cols + col_idx)
            if cmap is None:
                ax.imshow(image)
            else:
                ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
    save_fig(path, dpi=300)

def draw_architecture_diagram(path: Path):
    """Draw a block diagram of the proposed RGB/depth ViT-fusion model and save it to `path`."""
    import matplotlib.patches as patches
    plt.figure(figsize=(12, 4.2))
    ax = plt.gca()
    ax.set_xlim(0, 12)
    ax.set_ylim(0, 4.2)
    ax.axis("off")

    # (x, y, width, height, label) for every stage box.
    stage_boxes = [
        (0.4, 2.6, 2.0, 0.9, "RGB\n(3ch)"),
        (0.4, 0.6, 2.0, 0.9, "Depth\n(1ch)"),
        (2.8, 2.6, 2.2, 0.9, "ViT encoder\nRGB"),
        (2.8, 0.6, 2.2, 0.9, "Transformer encoder\nDepth"),
        (5.4, 1.6, 2.1, 1.0, "Cross-attention\nfusion"),
        (8.0, 1.6, 2.0, 1.0, "Decoder\nupsample"),
        (10.4, 1.6, 1.2, 1.0, "Mask\nlogits"),
    ]
    for bx, by, bw, bh, label in stage_boxes:
        ax.add_patch(patches.FancyBboxPatch((bx, by), bw, bh, boxstyle="round,pad=0.02,rounding_size=0.08", linewidth=1.2, facecolor="white"))
        ax.text(bx + bw / 2, by + bh / 2, label, ha="center", va="center")

    # (start, end) point of every connecting arrow.
    connections = [
        ((2.4, 3.05), (2.8, 3.05)),
        ((2.4, 1.05), (2.8, 1.05)),
        ((5.0, 3.05), (5.4, 2.1)),
        ((5.0, 1.05), (5.4, 2.1)),
        ((7.5, 2.1), (8.0, 2.1)),
        ((10.0, 2.1), (10.4, 2.1)),
    ]
    for start, end in connections:
        ax.annotate("", xy=end, xytext=start, arrowprops=dict(arrowstyle="->", linewidth=1.2))

    ax.text(0.4, 4.0, "Proposed multimodal ViT fusion segmentation model", ha="left", va="center", fontsize=12)
    save_fig(path, dpi=300)

# --- Phase-4 training budget: the same hyper-parameters for all three models ---
BASELINE_EPOCHS = 30
PROPOSED_EPOCHS = 30
BATCH_SIZE = 8
LR = 2e-4
WD = 1e-4
OPT = "adamw"  # train_one_model picks AdamW for "adamw", Adam otherwise
AUG = "hflip, brightness, rgb_noise"  # descriptive label copied into the budget table
GRAD_ACCUM = 1  # batches per optimizer step

baseline_a_cfg = TrainConfig(epochs=BASELINE_EPOCHS, batch_size=BATCH_SIZE, lr=LR, weight_decay=WD, optimizer=OPT, augmentations=AUG, seed=SEED, grad_accum=GRAD_ACCUM)
baseline_b_cfg = TrainConfig(epochs=BASELINE_EPOCHS, batch_size=BATCH_SIZE, lr=LR, weight_decay=WD, optimizer=OPT, augmentations=AUG, seed=SEED, grad_accum=GRAD_ACCUM)
proposed_cfg   = TrainConfig(epochs=PROPOSED_EPOCHS, batch_size=BATCH_SIZE, lr=LR, weight_decay=WD, optimizer=OPT, augmentations=AUG, seed=SEED, grad_accum=GRAD_ACCUM)

# Baseline A: RGB only (3 input channels). Baseline B: RGB + depth stacked (4 channels).
baseline_a = SimpleUNet(in_ch=3, base=32)
baseline_b = SimpleUNet(in_ch=4, base=32)

if not _HAS_TIMM:
    raise RuntimeError("timm is required for proposed model")

# The same ViT architecture name is used for both the RGB and the depth encoder.
proposed = ProposedViTFusion("vit_small_patch16_224", "vit_small_patch16_224", TARGET_HW)

gpu_type = get_gpu_name()

# Random example inputs at the training resolution; passed to record_model_size()
# below for FLOP counting.
ex_rgb = torch.randn(1, 3, TARGET_HW[0], TARGET_HW[1], device=DEVICE)
ex_rgbd = torch.randn(1, 4, TARGET_HW[0], TARGET_HW[1], device=DEVICE)
ex_dep = torch.randn(1, 1, TARGET_HW[0], TARGET_HW[1], device=DEVICE)

size_rows = []  # rows appended by record_model_size(); becomes the model size/compute table
def record_model_size(model_key: str, model_obj: nn.Module, example_inputs: Tuple[torch.Tensor, ...], notes: str):
    """Append a size/compute row for `model_obj` to the module-level `size_rows` list.

    The row holds parameters in millions, GFLOPs (NaN when counting fails), the
    training input resolution and a free-text note. The model is moved to
    ``DEVICE`` as a side effect.
    """
    params_m = float(count_params_m(model_obj))
    flops_g = try_flops_g(model_obj.to(DEVICE), example_inputs)
    size_rows.append(dict(
        model=model_key,
        params_m=params_m,
        flops_g=np.nan if np.isnan(flops_g) else float(flops_g),
        input_resolution=str(TARGET_HW),
        notes=notes,
    ))

# Parameter counts and best-effort FLOPs (NaN when try_flops_g fails) for each model.
record_model_size("baseline_a_rgb", baseline_a, (ex_rgb,), "SimpleUNet base=32, rgb only")
record_model_size("baseline_b_rgbd", baseline_b, (ex_rgbd,), "SimpleUNet base=32, early fusion 4ch")
record_model_size("proposed_vit_fusion", proposed, (ex_rgb, ex_dep), "dual ViT + cross-attention + lite decoder")

# Train the three models sequentially. The train loaders shuffle using the global
# torch RNG (no explicit generator is passed), so the call order affects the results.
mA_info, mA_logs = train_one_model("baseline_a_rgb", baseline_a, use_depth_input=False, cfg=baseline_a_cfg)
mB_info, mB_logs = train_one_model("baseline_b_rgbd", baseline_b, use_depth_input=True, cfg=baseline_b_cfg)
mP_info, mP_logs = train_one_model("proposed_vit_fusion", proposed, use_depth_input=True, cfg=proposed_cfg)

# Training-budget table: configured hyper-parameters plus measured time and peak VRAM.
budget_rows = []
for info, cfg, tag in [(mA_info, baseline_a_cfg, "baseline_a_rgb"),
                       (mB_info, baseline_b_cfg, "baseline_b_rgbd"),
                       (mP_info, proposed_cfg, "proposed_vit_fusion")]:
    budget_rows.append({
        "model": tag,
        "epochs": int(cfg.epochs),
        "batch_size": int(cfg.batch_size),
        "lr": float(cfg.lr),
        "optimizer": cfg.optimizer,
        "augmentations": cfg.augmentations,
        "gpu_type": gpu_type,
        "train_time_min": float(info["train_time_min"]),
        "peak_vram_gb": float(info["peak_vram_gb"])
    })

df_budget = pd.DataFrame(budget_rows)
df_size = pd.DataFrame(size_rows)
# Best checkpoint per model, as selected by validation IoU inside train_one_model.
df_best = pd.DataFrame([
    {"model": "baseline_a_rgb", "checkpoint_path": mA_info["best_ckpt_path"], "val_best_epoch": mA_info["best_epoch"], "val_best_metric": mA_info["best_metric"], "seed": int(SEED)},
    {"model": "baseline_b_rgbd", "checkpoint_path": mB_info["best_ckpt_path"], "val_best_epoch": mB_info["best_epoch"], "val_best_metric": mB_info["best_metric"], "seed": int(SEED)},
    {"model": "proposed_vit_fusion", "checkpoint_path": mP_info["best_ckpt_path"], "val_best_epoch": mP_info["best_epoch"], "val_best_metric": mP_info["best_metric"], "seed": int(SEED)},
])
# Per-epoch logs of all three models stacked into one table.
df_logs_all = pd.concat([mA_logs, mB_logs, mP_logs], axis=0).reset_index(drop=True)

TAB_01_BUDGET = PHASE4_DIR / "tab_01_training_budget_table.csv"
TAB_02_SIZE = PHASE4_DIR / "tab_02_model_size_compute_table.csv"
TAB_03_BEST = PHASE4_DIR / "tab_03_best_checkpoint_index.csv"
# NOTE(review): despite the file name, df_logs_all holds one row per epoch, not per batch.
TAB_04_LOGS = PHASE4_DIR / "tab_04_batch_level_logs_summary.csv"

df_budget.to_csv(TAB_01_BUDGET, index=False)
df_size.to_csv(TAB_02_SIZE, index=False)
df_best.to_csv(TAB_03_BEST, index=False)
df_logs_all.to_csv(TAB_04_LOGS, index=False)

def plot_loss(df_logs: pd.DataFrame, models: List[str], path: Path, title: str):
    """Plot per-epoch train and val loss for every model in `models`; save to `path`."""
    _, ax = plt.subplots(figsize=(10.5, 4.8))
    for model_name in models:
        curve = df_logs.loc[df_logs["model"] == model_name].sort_values("epoch")
        for column, suffix in (("train_loss", "train"), ("val_loss", "val")):
            ax.plot(curve["epoch"], curve[column], label=f"{model_name} {suffix}")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend()
    save_fig(path, dpi=300)

def plot_metrics(df_logs: pd.DataFrame, models: List[str], path: Path, title: str):
    """Plot per-epoch validation mIoU and Dice (y-axis fixed to [0, 1]); save to `path`."""
    _, ax = plt.subplots(figsize=(10.5, 4.8))
    for model_name in models:
        curve = df_logs.loc[df_logs["model"] == model_name].sort_values("epoch")
        for column, suffix in (("val_miou", "mIoU"), ("val_dice", "Dice")):
            ax.plot(curve["epoch"], curve[column], label=f"{model_name} {suffix}")
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Metric")
    ax.set_ylim(0, 1.0)
    ax.legend()
    save_fig(path, dpi=300)

plot_loss(df_logs_all, ["baseline_a_rgb", "baseline_b_rgbd"], PHASE4_DIR / "fig_01_training_curves_baselines.png", "Baselines loss curves")
plot_metrics(df_logs_all, ["baseline_a_rgb", "baseline_b_rgbd"], PHASE4_DIR / "fig_01_training_curves_baselines_metrics.png", "Baselines validation metrics")
plot_loss(df_logs_all, ["proposed_vit_fusion"], PHASE4_DIR / "fig_02_training_curves_proposed.png", "Proposed loss curves")
plot_metrics(df_logs_all, ["proposed_vit_fusion"], PHASE4_DIR / "fig_02_training_curves_proposed_metrics.png", "Proposed validation metrics")

# --- Sample ids for the qualitative comparison grid (at most 8) ---
# Preferred source: the ids of the phase-2 overlay figure. Fallback: the first
# 8 test-split ids in sample_id order.
def _default_qual_ids(n: int = 8) -> List[str]:
    """First `n` test-split sample ids, sorted stably by sample_id."""
    test_rows = df[df["split"] == "test"].sort_values("sample_id", kind="mergesort")
    return test_rows.head(n)["sample_id"].tolist()

qual_ids = None
if FIG_SAMPLE_IDS_PATH.exists():
    df_figids = pd.read_csv(FIG_SAMPLE_IDS_PATH)
    pick = df_figids[df_figids["figure_name"].astype(str).str.contains("fig_01_rgb_depth_mask_overlay_grid", na=False)]
    if len(pick):
        ids_str = str(pick.iloc[0]["sample_id_list"])
        # Strip each token so whitespace around "|" cannot break the id lookup below.
        qual_ids = [tok.strip() for tok in ids_str.split("|") if tok.strip() != ""]
if qual_ids is None:
    qual_ids = _default_qual_ids()

# predict_on_samples looks every id up in `df`. Ids that are absent, or that were
# parsed with a different type (e.g. str vs int), would raise KeyError there.
# Keep only known ids (deduplicated, order preserved) before truncating.
known_ids = set(df["sample_id"].tolist())
qual_ids = [sid for sid in dict.fromkeys(qual_ids) if sid in known_ids][:8]
if len(qual_ids) < 4:
    qual_ids = _default_qual_ids()

# Fresh model instances; weights come from each model's best checkpoint.
mA = SimpleUNet(in_ch=3, base=32)
mB = SimpleUNet(in_ch=4, base=32)
mP = ProposedViTFusion("vit_small_patch16_224", "vit_small_patch16_224", TARGET_HW)

ckA = df_best[df_best["model"] == "baseline_a_rgb"]["checkpoint_path"].iloc[0]
ckB = df_best[df_best["model"] == "baseline_b_rgbd"]["checkpoint_path"].iloc[0]
ckP = df_best[df_best["model"] == "proposed_vit_fusion"]["checkpoint_path"].iloc[0]

predA = predict_on_samples("baseline_a_rgb", mA, ckA, qual_ids)
predB = predict_on_samples("baseline_b_rgbd", mB, ckB, qual_ids)
predP = predict_on_samples("proposed_vit_fusion", mP, ckP, qual_ids)

save_qual_grid(qual_ids, predA, predB, predP, PHASE4_DIR / "fig_03_qualitative_comparison_grid.png")
draw_architecture_diagram(PHASE4_DIR / "fig_04_model_architecture_diagram.png")




In [None]:
import os
import time
import math
import random
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import timm

WORK_ROOT = Path("/kaggle/working")
# One results sub-directory per pipeline phase; this notebook section writes to phase 4.
PHASE1_DIR, PHASE2_DIR, PHASE3_DIR, PHASE4_DIR = (
    WORK_ROOT / "results" / stage
    for stage in (
        "01_data_integrity_and_profile",
        "02_sensor_alignment_and_quality",
        "03_preprocess_and_split_protocol",
        "04_training_baselines_and_vit_fusion",
    )
)
PHASE4_DIR.mkdir(parents=True, exist_ok=True)

# Inputs produced by earlier phases.
MANIFEST_VALID_PATH = PHASE1_DIR / "tab_01_dataset_manifest_valid_only.csv"
SPLIT_MANIFEST_PATH = PHASE3_DIR / "tab_03_split_manifest.csv"
PREPROCESS_CONFIG_PATH = PHASE3_DIR / "tab_01_preprocess_config_table.csv"
FIG_SAMPLE_IDS_PATH = PHASE2_DIR / "tab_04_figure_sample_ids.csv"

# Valid-sample manifest joined with the split assignment; stable sample_id order throughout.
df_valid = (
    pd.read_csv(MANIFEST_VALID_PATH)
    .sort_values("sample_id", kind="mergesort")
    .reset_index(drop=True)
)
df_split = (
    pd.read_csv(SPLIT_MANIFEST_PATH)
    .sort_values("sample_id", kind="mergesort")
    .reset_index(drop=True)
)
df = (
    df_valid.merge(df_split, on=["sample_id", "site_id"], how="inner")
    .sort_values("sample_id", kind="mergesort")
    .reset_index(drop=True)
)

TARGET_HW = (422, 640)  # (height, width) every sample is resized to

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SEED = 1337

def set_seed(seed: int):
    """Seed Python, NumPy and torch (CPU + all CUDA devices) and force deterministic cuDNN."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

set_seed(SEED)

# Compact font sizes for publication-style figures.
plt.rcParams.update(
    {
        "font.size": 10,
        "axes.titlesize": 11, "axes.labelsize": 10,
        "xtick.labelsize": 9, "ytick.labelsize": 9,
        "legend.fontsize": 9,
    }
)

def save_fig(path: Path, dpi=300):
    """Save the current matplotlib figure to `path` (creating parent dirs), then close it."""
    path.parent.mkdir(parents=True, exist_ok=True)
    fig = plt.gcf()
    fig.tight_layout()
    fig.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

def read_rgb(path: str):
    """Read an image file as an (H, W, 3) uint8 array.

    Handling by mode:
      * Grayscale images are replicated to three channels.
      * Grayscale+alpha ("LA") keeps the grey channel only, replicated to three.
      * Extra channels beyond RGB (e.g. alpha) are dropped.
      * Palette, bilevel and non-RGB colour spaces ("P", "PA", "1", "CMYK",
        "YCbCr", "LAB", "HSV") are converted to RGB by PIL first. Palette
        indices or other colour-space values would otherwise be treated as
        intensities.
    """
    with Image.open(path) as img:
        img.load()
        if img.mode in ("P", "PA", "1", "CMYK", "YCbCr", "LAB", "HSV"):
            img = img.convert("RGB")
        arr = np.array(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=2)
    elif arr.shape[2] == 2:
        # "LA": previously fell through every branch and returned 2 channels.
        arr = np.repeat(arr[:, :, :1], 3, axis=2)
    elif arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr.astype(np.uint8)

def read_depth(path: str):
    """Read a depth image as a 2-D array in its stored dtype (first channel if multi-channel)."""
    img = Image.open(path)
    img.load()
    raw = np.array(img)
    return raw[:, :, 0] if raw.ndim == 3 else raw

def read_mask(path: str):
    """Read a mask image as a 2-D uint8 array: 1 where any channel is > 0, else 0."""
    img = Image.open(path)
    img.load()
    raw = np.array(img)
    foreground = np.any(raw > 0, axis=2) if raw.ndim == 3 else (raw > 0)
    return foreground.astype(np.uint8)

def crop_to_min_hw(rgb, depth, mask):
    """Top-left crop all three arrays to their common (minimum) height and width."""
    arrays = (rgb, depth, mask)
    height = min(a.shape[0] for a in arrays)
    width = min(a.shape[1] for a in arrays)
    return tuple(a[:height, :width] for a in arrays)

def resize_triplet(rgb, depth, mask01, out_hw):
    """Resize RGB, depth and mask to ``out_hw = (height, width)``.

    RGB uses bilinear interpolation. Depth (cast to float32) and the mask use
    nearest-neighbour interpolation, which does not blend zero/invalid depth
    into valid neighbours. The mask is re-binarised to 0/1 uint8.
    """
    height, width = int(out_hw[0]), int(out_hw[1])
    dsize = (width, height)  # OpenCV takes (width, height)
    rgb_out = cv2.resize(rgb, dsize, interpolation=cv2.INTER_LINEAR)
    depth_out = cv2.resize(depth.astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST)
    mask_small = cv2.resize(mask01.astype(np.uint8), dsize, interpolation=cv2.INTER_NEAREST)
    return rgb_out, depth_out, (mask_small > 0).astype(np.uint8)

def fill_depth_holes(depth: np.ndarray):
    """In-paint invalid depth pixels (<= 0 or non-finite) with OpenCV's Telea method.

    Returns a float32 array. If fewer than 10 pixels are valid, the input is
    returned unchanged (as float32), because there is too little to in-paint
    from. Any non-finite values in the result are set to 0.
    """
    d = depth.astype(np.float32)
    valid = np.isfinite(d) & (d > 0)
    if np.count_nonzero(valid) < 10:
        return d
    holes = ~valid
    seed_img = np.where(holes, np.float32(0.0), d)
    filled = cv2.inpaint(seed_img, holes.astype(np.uint8), 3, cv2.INPAINT_TELEA)
    filled[~np.isfinite(filled)] = 0.0
    return filled

def robust_depth_scale(depth: np.ndarray):
    """Map valid depth to [0, 1] using its 1st-99th percentile range; invalid pixels become 0.

    Valid means > 0 and finite. With fewer than 10 valid pixels the result is all
    zeros. If the percentile range is degenerate (< 1e-6), valid pixels are set to 0.5.
    """
    d = depth.astype(np.float32)
    valid = np.isfinite(d) & (d > 0)
    if np.count_nonzero(valid) < 10:
        return np.zeros_like(d, dtype=np.float32)
    samples = d[valid]
    lo = float(np.percentile(samples, 1))
    hi = float(np.percentile(samples, 99))
    span = hi - lo
    if span < 1e-6:
        return np.where(valid, np.float32(0.5), np.float32(0.0)).astype(np.float32)
    scaled = np.clip((d - lo) / span, 0.0, 1.0)
    scaled[~valid] = 0.0
    return scaled.astype(np.float32)

def augment_pair(rgb_u8, depth_f, mask_u8, rng: np.random.Generator):
    if rng.random() < 0.5:
        rgb_u8 = np.ascontiguousarray(rgb_u8[:, ::-1])
        depth_f = np.ascontiguousarray(depth_f[:, ::-1])
        mask_u8 = np.ascontiguousarray(mask_u8[:, ::-1])
    if rng.random() < 0.25:
        factor = 0.8 + 0.4 * rng.random()
        rgb_f = rgb_u8.astype(np.float32) * factor
        rgb_u8 = np.clip(rgb_f, 0, 255).astype(np.uint8)
    if rng.random() < 0.15:
        noise = rng.normal(0, 3.0, size=rgb_u8.shape).astype(np.float32)
        rgb_u8 = np.clip(rgb_u8.astype(np.float32) + noise, 0, 255).astype(np.uint8)
    return rgb_u8, depth_f, mask_u8

class ISODSegDataset(Dataset):
    """ISOD RGB-D segmentation samples for one split.

    Each item is ``(x, mask, sample_id)``:
      * ``x``: float32 tensor of RGB scaled to [0, 1] (3 channels). When
        ``use_depth`` is True the robustly scaled depth channel is appended
        (4 channels).
      * ``mask``: float32 tensor of shape (1, H, W) with 0/1 values.
      * ``sample_id``: the row's sample id as ``str``.

    Every sample is cropped to its common size, resized to ``target_hw`` and
    has its depth holes in-painted. The "train" split additionally gets random
    flip / brightness / RGB-noise augmentation.

    Augmentation randomness: with ``num_workers > 0`` every DataLoader worker
    gets a copy of this object. Non-persistent workers are re-created each
    epoch from the parent's never-advanced copy. Drawing from ``self.rng``
    inside workers would therefore make all workers, in every epoch, replay the
    same random stream. Inside a worker we instead use a generator seeded from
    ``seed`` plus ``torch.initial_seed()``. PyTorch sets that seed to a distinct
    value per worker and per epoch, derived from the global torch seed, so runs
    stay reproducible. In the main process (``num_workers=0``) ``self.rng`` is
    used as before.
    """

    def __init__(self, df_in: pd.DataFrame, split: str, target_hw: Tuple[int,int], seed: int, use_depth: bool, vit_hw: Tuple[int,int] = None):
        self.df = df_in[df_in["split"] == split].sort_values("sample_id", kind="mergesort").reset_index(drop=True)
        self.target_hw = target_hw
        self.use_depth = use_depth
        self._base_seed = seed + (0 if split == "train" else 999)
        self.rng = np.random.default_rng(self._base_seed)
        self.split = split
        # Stored only; this class never reads it (the model resizes to its ViT input
        # size itself). Optional so call sites that omit it keep working.
        self.vit_hw = vit_hw
        self._worker_rng = None
        self._worker_rng_key = None

    def _get_rng(self) -> np.random.Generator:
        """Return the augmentation RNG: ``self.rng`` in the main process, a per-worker one otherwise."""
        info = torch.utils.data.get_worker_info()
        if info is None:
            return self.rng
        key = torch.initial_seed()  # base_seed + worker_id; re-drawn for every new iterator
        if self._worker_rng is None or self._worker_rng_key != key:
            self._worker_rng = np.random.default_rng([self._base_seed, key, info.id])
            self._worker_rng_key = key
        return self._worker_rng

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        rgb = read_rgb(row["rgb_path"])
        depth = read_depth(row["depth_path"])
        mask = read_mask(row["mask_path"])
        rgb, depth, mask = crop_to_min_hw(rgb, depth, mask)
        rgb, depth, mask = resize_triplet(rgb, depth, mask, self.target_hw)
        depth = fill_depth_holes(depth)
        depth_scaled = robust_depth_scale(depth)
        if self.split == "train":
            rgb, depth_scaled, mask = augment_pair(rgb, depth_scaled, mask, self._get_rng())
        rgb_t = torch.from_numpy(rgb.astype(np.float32) / 255.0).permute(2,0,1)
        mask_t = torch.from_numpy(mask.astype(np.float32)).unsqueeze(0)
        dep_t = torch.from_numpy(depth_scaled.astype(np.float32)).unsqueeze(0)
        x = torch.cat([rgb_t, dep_t], dim=0) if self.use_depth else rgb_t
        return x, mask_t, str(row["sample_id"])

def sigmoid_thresh(x, thr=0.5):
    """Binarise logits: 1.0 where sigmoid(x) >= thr, else 0.0 (float tensor)."""
    probs = torch.sigmoid(x)
    return probs.ge(thr).float()

def dice_coeff(pred01, gt01, eps=1e-6):
    """Mean per-sample Dice of binary (N, C, H, W) tensors, as a Python float."""
    dims = (1, 2, 3)
    overlap = (pred01 * gt01).sum(dim=dims)
    total = pred01.sum(dim=dims) + gt01.sum(dim=dims)
    per_sample = (2.0 * overlap + eps) / (total + eps)
    return per_sample.mean().item()

def iou_score(pred01, gt01, eps=1e-6):
    """Mean per-sample IoU of binary (N, C, H, W) tensors, as a Python float."""
    dims = (1, 2, 3)
    overlap = (pred01 * gt01).sum(dim=dims)
    union = (pred01 + gt01).gt(0).float().sum(dim=dims)
    return ((overlap + eps) / (union + eps)).mean().item()

class DiceLoss(nn.Module):
    """Soft Dice loss on sigmoid probabilities: mean over the batch of ``1 - Dice``."""
    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps  # smoothing added to numerator and denominator

    def forward(self, logits, targets):
        probs = logits.sigmoid()
        axes = (1, 2, 3)
        numer = 2.0 * (probs * targets).sum(dim=axes) + self.eps
        denom = probs.sum(dim=axes) + targets.sum(dim=axes) + self.eps
        per_sample = 1.0 - numer / denom
        return per_sample.mean()

class ConvBlock(nn.Module):
    """Two (3x3 conv, no bias) -> BatchNorm -> ReLU stages; spatial size is preserved."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Kept as a single Sequential named `net` so state_dict keys stay net.0.*, net.1.*, ...
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def center_crop_or_pad_to(x, ref):
    """Center-crop and/or replicate-pad 4-D tensor `x` to the spatial size of `ref`.

    Each spatial dim larger than the reference is center-cropped. Each dim that
    is still smaller afterwards is padded, split as evenly as possible (extra
    pixel at the end). `x` is returned as-is when the sizes already match.
    """
    _, _, h, w = x.shape
    _, _, hr, wr = ref.shape
    if (h, w) == (hr, wr):
        return x
    if h > hr or w > wr:
        top = max(0, h - hr) // 2
        left = max(0, w - wr) // 2
        x = x[:, :, top:top + hr, left:left + wr]
    ph = max(0, hr - x.shape[2])
    pw = max(0, wr - x.shape[3])
    if ph or pw:
        x = F.pad(x, (pw // 2, pw - pw // 2, ph // 2, ph - ph // 2), mode="replicate")
    return x

class SimpleUNet(nn.Module):
    """Three-level U-Net producing single-channel logits at the input resolution.

    Pooling uses ``ceil_mode=True``, and skip connections are center-cropped or
    padded to the upsampled size, so arbitrary (odd) input sizes work. The head
    output is bilinearly resized back to the exact input H x W.
    """
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        c1, c2, c3, c4 = base, base * 2, base * 4, base * 8
        self.enc1 = ConvBlock(in_ch, c1)
        self.pool1 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc2 = ConvBlock(c1, c2)
        self.pool2 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc3 = ConvBlock(c2, c3)
        self.pool3 = nn.MaxPool2d(2, ceil_mode=True)
        self.bott = ConvBlock(c3, c4)
        # Each decoder block sees upsampled features concatenated with the skip (2x channels).
        self.up3 = nn.ConvTranspose2d(c4, c3, 2, stride=2)
        self.dec3 = ConvBlock(c4, c3)
        self.up2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)
        self.dec2 = ConvBlock(c3, c2)
        self.up1 = nn.ConvTranspose2d(c2, c1, 2, stride=2)
        self.dec1 = ConvBlock(c2, c1)
        self.head = nn.Conv2d(c1, 1, 1)

    def forward(self, x):
        skips = []
        feat = x
        for enc, pool in ((self.enc1, self.pool1), (self.enc2, self.pool2), (self.enc3, self.pool3)):
            feat = enc(feat)
            skips.append(feat)
            feat = pool(feat)
        feat = self.bott(feat)
        decoder = zip((self.up3, self.up2, self.up1), (self.dec3, self.dec2, self.dec1), reversed(skips))
        for up, dec, skip in decoder:
            feat = up(feat)
            feat = dec(torch.cat([feat, center_crop_or_pad_to(skip, feat)], dim=1))
        logits = self.head(feat)
        return F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)

class ViTEncoderFeature(nn.Module):
    """Pretrained timm ViT used as a single-level feature extractor.

    The backbone is built with ``features_only=True`` and ``out_indices=(3,)``.
    ``forward`` returns the last (here: only) feature map, whose channel count
    is exposed as ``out_ch``.

    NOTE(review): in timm's ViT feature extraction a tuple of out_indices
    appears to select explicit block indices, which would make (3,) the 4th
    transformer block rather than the final one. Confirm this shallow tap is
    intended.
    """
    def __init__(self, model_name: str, in_chans: int, img_hw: Tuple[int,int]):
        super().__init__()
        timm_kwargs = dict(
            pretrained=True,
            in_chans=in_chans,
            img_size=img_hw,
            dynamic_img_size=True,
            dynamic_img_pad=True,
            features_only=True,
            out_indices=(3,),
        )
        self.backbone = timm.create_model(model_name, **timm_kwargs)
        self.out_ch = self.backbone.feature_info.channels()[-1]

    def forward(self, x):
        *_, last = self.backbone(x)
        return last

class CrossAttentionFusion(nn.Module):
    """Residual cross-attention: RGB tokens attend to depth tokens.

    Both inputs are (B, C, H, W) maps of equal shape. Queries come from RGB and
    keys/values from depth, each via a 1x1 conv. Queries are LayerNorm-ed. The
    attended tokens are 1x1-projected and added back onto the RGB map.

    NOTE(review): only the queries are normalised; keys and values enter the
    attention un-normalised. Confirm this is intended.
    """
    def __init__(self, ch: int, heads: int = 8):
        super().__init__()
        # Registration order kept so parameter init under a fixed seed is unchanged.
        self.q = nn.Conv2d(ch, ch, 1, bias=False)
        self.k = nn.Conv2d(ch, ch, 1, bias=False)
        self.v = nn.Conv2d(ch, ch, 1, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim=ch, num_heads=heads, batch_first=True)
        self.proj = nn.Conv2d(ch, ch, 1, bias=False)
        self.norm = nn.LayerNorm(ch)

    def forward(self, frgb, fdep):
        b, c, h, w = frgb.shape

        def to_tokens(t):
            # (B, C, H, W) -> (B, H*W, C)
            return t.flatten(2).transpose(1, 2)

        query = self.norm(to_tokens(self.q(frgb)))
        key = to_tokens(self.k(fdep))
        value = to_tokens(self.v(fdep))
        attended, _ = self.attn(query, key, value, need_weights=False)
        fused = self.proj(attended.transpose(1, 2).reshape(b, c, h, w))
        return frgb + fused

class LiteDecoder(nn.Module):
    """Lightweight decoder: ConvBlock, three (2x transpose-conv + ConvBlock) stages, 1x1 head.

    The single-channel logits are bilinearly resized to ``out_hw`` at the end, so
    the output size does not depend on the feature-map stride.
    """
    def __init__(self, in_ch: int, out_hw: Tuple[int,int]):
        super().__init__()
        self.out_hw = out_hw
        self.conv1 = ConvBlock(in_ch, 256)
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = ConvBlock(128, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv3 = ConvBlock(64, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv4 = ConvBlock(32, 32)
        self.head = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        feat = self.conv1(x)
        for upsample, refine in ((self.up1, self.conv2), (self.up2, self.conv3), (self.up3, self.conv4)):
            feat = refine(upsample(feat))
        return F.interpolate(self.head(feat), size=self.out_hw, mode="bilinear", align_corners=False)

class ProposedViTFusion(nn.Module):
    """Dual-ViT RGB-D segmentation model.

    RGB (3 channels) and depth (1 channel) are bilinearly resized to ``vit_hw``
    and encoded by separate ``ViTEncoderFeature`` backbones. Depth features are
    1x1-projected to the RGB width when the two differ. The maps are fused by
    ``CrossAttentionFusion`` and decoded by ``LiteDecoder`` to logits at ``full_hw``.
    """
    def __init__(self, rgb_name: str, dep_name: str, full_hw: Tuple[int,int], vit_hw: Tuple[int,int]):
        super().__init__()
        self.full_hw = full_hw
        self.vit_hw = vit_hw
        self.rgb_enc = ViTEncoderFeature(rgb_name, in_chans=3, img_hw=vit_hw)
        self.dep_enc = ViTEncoderFeature(dep_name, in_chans=1, img_hw=vit_hw)
        fused_ch = self.rgb_enc.out_ch
        depth_ch = self.dep_enc.out_ch
        if depth_ch == fused_ch:
            self.dep_proj = nn.Identity()
        else:
            self.dep_proj = nn.Conv2d(depth_ch, fused_ch, 1, bias=False)
        self.fuse = CrossAttentionFusion(fused_ch, heads=8)
        self.dec = LiteDecoder(fused_ch, out_hw=full_hw)

    def forward(self, rgb3, dep1):
        def to_vit_size(t):
            return F.interpolate(t, size=self.vit_hw, mode="bilinear", align_corners=False)

        rgb_feat = self.rgb_enc(to_vit_size(rgb3))
        depth_feat = self.dep_proj(self.dep_enc(to_vit_size(dep1)))
        return self.dec(self.fuse(rgb_feat, depth_feat))

@dataclass
class TrainConfig:
    """Hyper-parameters for one training run (positional field order is part of the interface)."""
    epochs: int  # last epoch number; train_one_model iterates start_epoch..epochs inclusive
    batch_size: int  # batch size for the train and val DataLoaders
    lr: float  # optimizer learning rate
    weight_decay: float  # weight_decay passed to AdamW / Adam
    optimizer: str  # "adamw" (case-insensitive) selects AdamW; any other value selects Adam
    augmentations: str  # descriptive label; not read by the visible training code
    seed: int  # recorded value; make_loader seeds the dataset from the global SEED instead
    grad_accum: int  # presumably batches per optimizer step (gradient accumulation) — confirm in train_one_model

def make_loader(split: str, batch_size: int, use_depth: bool, vit_hw: Tuple[int,int], num_workers: int = 2):
    """Build the ISOD dataset for `split` and wrap it in a DataLoader.

    Only the "train" split is shuffled, and the last partial batch is kept.

    Args:
        split: split name stored in ``df["split"]`` (e.g. "train", "val").
        batch_size: DataLoader batch size.
        use_depth: append the depth channel to the RGB input tensor.
        vit_hw: forwarded to the dataset.
        num_workers: DataLoader worker processes (previously hard-coded to 2).

    Returns:
        ``(dataset, dataloader)``.
    """
    ds = ISODSegDataset(df, split=split, target_hw=TARGET_HW, seed=SEED, use_depth=use_depth, vit_hw=vit_hw)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=(split == "train"),
        num_workers=num_workers,
        # Pinned host memory only helps host->GPU copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )
    return ds, dl

def dice_loss(logits, targets, eps=1e-6):
    """Functional soft Dice loss: batch mean of ``1 - Dice(sigmoid(logits), targets)``."""
    probs = logits.sigmoid()
    axes = (1, 2, 3)
    numer = 2.0 * (probs * targets).sum(dim=axes) + eps
    denom = probs.sum(dim=axes) + targets.sum(dim=axes) + eps
    return (1.0 - numer / denom).mean()

def train_one_model(model_key: str, model: nn.Module, use_depth_input: bool, cfg: TrainConfig, vit_hw: Tuple[int,int], resume_ckpt: str = ""):
    """Train one segmentation model with per-epoch checkpoints; resumable from a checkpoint.

    The loss is 0.5 * BCE-with-logits + 0.5 * soft Dice. Validation mIoU at threshold 0.5
    selects the best epoch. Outputs are written to PHASE4_DIR/<model_key>/:
    ``epoch_logs.csv`` and ``checkpoints/epoch_XXX.pt``.

    Args:
        model_key: run name. "proposed_vit_fusion" is called as ``model(rgb, depth)``, with
            input channels 0-2 as RGB and channel 3 as depth. Every other key is called as
            ``model(x)`` on the full input tensor.
        model: network to train; moved to DEVICE.
        use_depth_input: forwarded to the dataset via make_loader.
        cfg: training hyperparameters.
        vit_hw: forwarded to the dataset via make_loader.
        resume_ckpt: optional checkpoint written by this function. When the file exists, the
            following are restored and training continues from the next epoch:
            model weights, optimizer state, grad-scaler state (if saved) and the best-so-far
            bookkeeping.

    Returns:
        dict with model, out_dir, logs_path, best_ckpt_path, best_epoch, best_metric,
        train_time_min and peak_vram_gb (both measured for this call only), and vit_hw.
    """
    out_dir = PHASE4_DIR / model_key
    ckpt_dir = out_dir / "checkpoints"
    ckpt_dir.mkdir(parents=True, exist_ok=True)  # parents=True also creates out_dir
    logs_path = out_dir / "epoch_logs.csv"

    use_cuda = torch.cuda.is_available()
    # grad_accum <= 0 would otherwise divide by zero in the `step % accum` check below.
    accum = max(1, int(cfg.grad_accum))

    model = model.to(DEVICE)

    _, dl_tr = make_loader("train", cfg.batch_size, use_depth=use_depth_input, vit_hw=vit_hw)
    _, dl_va = make_loader("val", cfg.batch_size, use_depth=use_depth_input, vit_hw=vit_hw)

    opt_cls = torch.optim.AdamW if cfg.optimizer.lower() == "adamw" else torch.optim.Adam
    opt = opt_cls(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=use_cuda)

    start_epoch = 1
    best_metric = -1.0
    best_epoch = -1
    best_path = ""

    if resume_ckpt and Path(resume_ckpt).exists():
        ck = torch.load(resume_ckpt, map_location=DEVICE)
        model.load_state_dict(ck["state_dict"], strict=True)
        if ck.get("opt_state") is not None:
            opt.load_state_dict(ck["opt_state"])
        if ck.get("scaler_state"):
            # Empty dict when the scaler was disabled (CPU run); older checkpoints lack the key.
            scaler.load_state_dict(ck["scaler_state"])
        start_epoch = int(ck.get("epoch", 0)) + 1
        best_metric = float(ck.get("best_metric", -1.0))
        best_epoch = int(ck.get("best_epoch", -1))
        best_path = str(ck.get("best_path", ""))

    # Keep only rows for epochs that are already checkpointed. Rows at or after start_epoch
    # come from one of two sources, and would otherwise be duplicated in the log:
    #   - a run interrupted between logging and checkpointing;
    #   - a restart from scratch.
    logs = []
    if logs_path.exists():
        logs = [
            row for row in pd.read_csv(logs_path).to_dict(orient="records")
            if int(row["epoch"]) < start_epoch
        ]

    def _forward_loss(x, y):
        """Forward one batch on DEVICE under autocast; return (logits, BCE + Dice loss)."""
        with torch.amp.autocast("cuda", enabled=use_cuda):
            if model_key == "proposed_vit_fusion":
                rgb = x[:, :3].to(DEVICE, non_blocking=True)
                dep = x[:, 3:4].to(DEVICE, non_blocking=True)
                logits = model(rgb, dep)
            else:
                logits = model(x.to(DEVICE, non_blocking=True))
            loss = 0.5 * F.binary_cross_entropy_with_logits(logits, y) + 0.5 * dice_loss(logits, y)
        return logits, loss

    start_time = time.time()
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    n_train_batches = len(dl_tr)
    for epoch in range(start_epoch, cfg.epochs + 1):
        # ---- train ----
        model.train()
        tr_loss_sum = 0.0
        tr_steps = 0
        opt.zero_grad(set_to_none=True)

        for step, (x, y, _sids) in enumerate(dl_tr, start=1):
            y = y.to(DEVICE, non_blocking=True)
            _, loss = _forward_loss(x, y)
            # Divide so that accumulated gradients average over `accum` batches.
            scaler.scale(loss / accum).backward()

            if step % accum == 0 or step == n_train_batches:
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)

            tr_loss_sum += loss.item()
            tr_steps += 1

        tr_loss = tr_loss_sum / max(1, tr_steps)

        # ---- validate ----
        model.eval()
        va_loss_sum = 0.0
        va_steps = 0
        dice_list = []
        iou_list = []

        with torch.no_grad():
            for x, y, _sids in dl_va:
                y = y.to(DEVICE, non_blocking=True)
                logits, loss = _forward_loss(x, y)
                va_loss_sum += loss.item()
                va_steps += 1
                pred01 = (torch.sigmoid(logits) >= 0.5).float()
                dice_list.append(dice_coeff(pred01, y))
                iou_list.append(iou_score(pred01, y))

        va_loss = va_loss_sum / max(1, va_steps)
        va_dice = float(np.mean(dice_list)) if dice_list else np.nan
        va_iou = float(np.mean(iou_list)) if iou_list else np.nan

        logs.append({
            "model": model_key,
            "epoch": int(epoch),
            "train_loss": float(tr_loss),
            "val_loss": float(va_loss),
            "val_miou": float(va_iou),
            "val_dice": float(va_dice),
        })
        pd.DataFrame(logs).to_csv(logs_path, index=False)

        ckpt_path = str(ckpt_dir / f"epoch_{epoch:03d}.pt")

        # Update the best-so-far record BEFORE saving, so the checkpoint's bookkeeping includes
        # this epoch. Otherwise, resuming right after a new best would forget it.
        if (not np.isnan(va_iou)) and (va_iou > best_metric):
            best_metric = va_iou
            best_epoch = epoch
            best_path = ckpt_path

        torch.save({
            "model_key": model_key,
            "epoch": int(epoch),
            "state_dict": model.state_dict(),
            "opt_state": opt.state_dict(),
            "scaler_state": scaler.state_dict(),
            "cfg": cfg.__dict__,
            "target_hw": TARGET_HW,
            "vit_hw": vit_hw,
            "seed": cfg.seed,
            "best_metric": float(best_metric),
            "best_epoch": int(best_epoch),
            "best_path": str(best_path),
        }, ckpt_path)

    train_time_min = (time.time() - start_time) / 60.0
    peak_vram_gb = (torch.cuda.max_memory_allocated() / (1024**3)) if use_cuda else 0.0

    return {
        "model": model_key,
        "out_dir": str(out_dir),
        "logs_path": str(logs_path),
        "best_ckpt_path": str(best_path),
        "best_epoch": int(best_epoch),
        "best_metric": float(best_metric),
        "train_time_min": float(train_time_min),
        "peak_vram_gb": float(peak_vram_gb),
        "vit_hw": vit_hw,
    }

def latest_checkpoint(ckpt_dir: Path):
    """Return the highest-numbered ``epoch_XXX.pt`` in `ckpt_dir` as a string, or "" if none.

    Files whose name does not parse as ``epoch_<digits>.pt`` rank as epoch -1.
    """
    if not ckpt_dir.exists():
        return ""
    epoch_re = re.compile(r"epoch_(\d+)\.pt$")

    def epoch_of(p):
        found = epoch_re.search(p.name)
        return int(found.group(1)) if found else -1

    candidates = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not candidates:
        return ""
    # max() keeps the first maximum it meets; scanning in reverse name order therefore picks
    # the last file by name among equal epoch numbers.
    return str(max(reversed(candidates), key=epoch_of))

def plot_loss(df_logs: pd.DataFrame, models: List[str], path: Path, title: str):
    """Plot per-epoch train and validation loss for each model in `models` and save to `path`."""
    plt.figure(figsize=(10.5, 4.8))
    series = (("train_loss", "train"), ("val_loss", "val"))
    for name in models:
        run = df_logs[df_logs["model"] == name].sort_values("epoch")
        for column, split in series:
            plt.plot(run["epoch"], run[column], label=f"{name} {split}")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    save_fig(path, dpi=300)

def plot_metrics(df_logs: pd.DataFrame, models: List[str], path: Path, title: str):
    """Plot per-epoch validation mIoU and Dice for each model in `models` (y in [0, 1]) and save."""
    plt.figure(figsize=(10.5, 4.8))
    series = (("val_miou", "mIoU"), ("val_dice", "Dice"))
    for name in models:
        run = df_logs[df_logs["model"] == name].sort_values("epoch")
        for column, metric in series:
            plt.plot(run["epoch"], run[column], label=f"{name} {metric}")
    plt.title(title)
    plt.xlabel("Epoch")
    plt.ylabel("Metric")
    plt.ylim(0, 1.0)
    plt.legend()
    save_fig(path, dpi=300)

# Training recipes: the three runs currently use identical hyperparameters.
baseline_a_cfg = TrainConfig(epochs=30, batch_size=8, lr=2e-4, weight_decay=1e-4, optimizer="adamw", augmentations="hflip, brightness, rgb_noise", seed=SEED, grad_accum=1)
baseline_b_cfg = TrainConfig(epochs=30, batch_size=8, lr=2e-4, weight_decay=1e-4, optimizer="adamw", augmentations="hflip, brightness, rgb_noise", seed=SEED, grad_accum=1)
proposed_cfg = TrainConfig(epochs=30, batch_size=8, lr=2e-4, weight_decay=1e-4, optimizer="adamw", augmentations="hflip, brightness, rgb_noise", seed=SEED, grad_accum=1)

# Baselines: RGB-only UNet (3 input channels) and early-fusion RGB-D UNet (4 channels).
# Keep construction order fixed: constructors draw random initial weights.
baseline_a = SimpleUNet(in_ch=3, base=32)
baseline_b = SimpleUNet(in_ch=4, base=32)

# ViT input size: ProposedViTFusion resizes RGB and depth to this before encoding.
vit_hw = (224, 224)
proposed = ProposedViTFusion("vit_small_patch16_224", "vit_small_patch16_224", full_hw=TARGET_HW, vit_hw=vit_hw)

baseline_a_dir = PHASE4_DIR / "baseline_a_rgb" / "checkpoints"
baseline_b_dir = PHASE4_DIR / "baseline_b_rgbd" / "checkpoints"
proposed_dir = PHASE4_DIR / "proposed_vit_fusion" / "checkpoints"

# Latest per-run checkpoint ("" when none exists, i.e. train from scratch).
resume_a = latest_checkpoint(baseline_a_dir)
resume_b = latest_checkpoint(baseline_b_dir)
resume_p = latest_checkpoint(proposed_dir)

# Train sequentially; each call returns a summary dict (best checkpoint, time, peak VRAM).
mA = train_one_model("baseline_a_rgb", baseline_a, use_depth_input=False, cfg=baseline_a_cfg, vit_hw=vit_hw, resume_ckpt=resume_a)
mB = train_one_model("baseline_b_rgbd", baseline_b, use_depth_input=True, cfg=baseline_b_cfg, vit_hw=vit_hw, resume_ckpt=resume_b)
mP = train_one_model("proposed_vit_fusion", proposed, use_depth_input=True, cfg=proposed_cfg, vit_hw=vit_hw, resume_ckpt=resume_p)

# Combine the per-run epoch logs (an empty frame with the log schema if none exist).
logs_all = []
for k in ["baseline_a_rgb", "baseline_b_rgbd", "proposed_vit_fusion"]:
    p = PHASE4_DIR / k / "epoch_logs.csv"
    if p.exists():
        logs_all.append(pd.read_csv(p))
df_logs_all = pd.concat(logs_all, axis=0).reset_index(drop=True) if len(logs_all) else pd.DataFrame(columns=["model","epoch","train_loss","val_loss","val_miou","val_dice"])

# Output table paths for this phase.
TAB_01_BUDGET = PHASE4_DIR / "tab_01_training_budget_table.csv"
TAB_02_SIZE = PHASE4_DIR / "tab_02_model_size_compute_table.csv"
TAB_03_BEST = PHASE4_DIR / "tab_03_best_checkpoint_index.csv"
TAB_04_LOGS = PHASE4_DIR / "tab_04_batch_level_logs_summary.csv"

# GPU name, intended for the training-budget report.
gpu_type = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"

def best_from_dir(model_key: str):
    """Best-epoch record for one run, derived from its epoch log.

    Reads PHASE4_DIR/<model_key>/epoch_logs.csv, drops rows without ``val_miou`` and picks the
    highest ``val_miou``; ties go to the earliest epoch. If the log is missing or has no
    scored rows, the returned record has an empty checkpoint path, epoch -1 and metric -1.0.
    """
    log_path = PHASE4_DIR / model_key / "epoch_logs.csv"
    placeholder = {"model": model_key, "checkpoint_path": "", "val_best_epoch": -1, "val_best_metric": -1.0, "seed": int(SEED)}
    if not log_path.exists():
        return placeholder
    scored = pd.read_csv(log_path).dropna(subset=["val_miou"])
    if scored.empty:
        return placeholder
    top = scored.sort_values(["val_miou", "epoch"], ascending=[False, True]).iloc[0]
    best_ep = int(top["epoch"])
    ckpt = PHASE4_DIR / model_key / "checkpoints" / f"epoch_{best_ep:03d}.pt"
    return {
        "model": model_key,
        "checkpoint_path": str(ckpt),
        "val_best_epoch": best_ep,
        "val_best_metric": float(top["val_miou"]),
        "seed": int(SEED),
    }

# Best-checkpoint index (read by the evaluation phase) and the combined epoch logs.
df_best = pd.DataFrame([
    best_from_dir("baseline_a_rgb"),
    best_from_dir("baseline_b_rgbd"),
    best_from_dir("proposed_vit_fusion"),
])
df_best.to_csv(TAB_03_BEST, index=False)

df_logs_all.to_csv(TAB_04_LOGS, index=False)


def _params_m(module: nn.Module, trainable_only: bool = False) -> float:
    """Parameter count of `module` in millions (optionally only parameters with requires_grad)."""
    return sum(t.numel() for t in module.parameters() if t.requires_grad or not trainable_only) / 1e6


# (training summary, config, model) for each Phase-4 run, in training order.
_phase4_runs = [
    (mA, baseline_a_cfg, baseline_a),
    (mB, baseline_b_cfg, baseline_b),
    (mP, proposed_cfg, proposed),
]

# Model-size table. The evaluation phase reads `params_m` from this file; without it, the
# main results table reports NaN parameter counts.
pd.DataFrame([
    {
        "model": summary["model"],
        "params_m": _params_m(net),
        "trainable_params_m": _params_m(net, trainable_only=True),
    }
    for summary, _run_cfg, net in _phase4_runs
]).to_csv(TAB_02_SIZE, index=False)

# Training-budget table. train_time_min / peak_vram_gb cover only the latest
# train_one_model call, so a resumed run reports just the resumed epochs.
pd.DataFrame([
    {
        "model": summary["model"],
        "epochs": run_cfg.epochs,
        "batch_size": run_cfg.batch_size,
        "grad_accum": run_cfg.grad_accum,
        "effective_batch": run_cfg.batch_size * max(1, run_cfg.grad_accum),
        "lr": run_cfg.lr,
        "weight_decay": run_cfg.weight_decay,
        "optimizer": run_cfg.optimizer,
        "augmentations": run_cfg.augmentations,
        "seed": run_cfg.seed,
        "gpu_type": gpu_type,
        "train_time_min": summary["train_time_min"],
        "peak_vram_gb": summary["peak_vram_gb"],
    }
    for summary, run_cfg, _net in _phase4_runs
]).to_csv(TAB_01_BUDGET, index=False)

plot_loss(df_logs_all, ["baseline_a_rgb","baseline_b_rgbd"], PHASE4_DIR / "fig_01_training_curves_baselines.png", "Baselines loss curves")
plot_metrics(df_logs_all, ["baseline_a_rgb","baseline_b_rgbd"], PHASE4_DIR / "fig_01_training_curves_baselines_metrics.png", "Baselines validation metrics")
plot_loss(df_logs_all, ["proposed_vit_fusion"], PHASE4_DIR / "fig_02_training_curves_proposed.png", "Proposed loss curves")
plot_metrics(df_logs_all, ["proposed_vit_fusion"], PHASE4_DIR / "fig_02_training_curves_proposed_metrics.png", "Proposed validation metrics")



In [None]:
import os
import time
import math
import random
import json
import re
from pathlib import Path
from dataclasses import dataclass
from typing import Tuple, List, Dict

import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

import timm

# Result directories of the earlier phases (read) and of this evaluation phase (written).
WORK_ROOT = Path("/kaggle/working")
PHASE1_DIR = WORK_ROOT / "results" / "01_data_integrity_and_profile"
PHASE2_DIR = WORK_ROOT / "results" / "02_sensor_alignment_and_quality"
PHASE3_DIR = WORK_ROOT / "results" / "03_preprocess_and_split_protocol"
PHASE4_DIR = WORK_ROOT / "results" / "04_training_baselines_and_vit_fusion"
PHASE5_DIR = WORK_ROOT / "results" / "05_eval_ablation_robust_explain"
PHASE5_DIR.mkdir(parents=True, exist_ok=True)

# Inputs produced by earlier phases.
MANIFEST_VALID_PATH = PHASE1_DIR / "tab_01_dataset_manifest_valid_only.csv"
SPLIT_MANIFEST_PATH = PHASE3_DIR / "tab_03_split_manifest.csv"
PHASE4_BEST_PATH = PHASE4_DIR / "tab_03_best_checkpoint_index.csv"
PHASE4_SIZE_PATH = PHASE4_DIR / "tab_02_model_size_compute_table.csv"
FIG_SAMPLE_IDS_PATH = PHASE2_DIR / "tab_04_figure_sample_ids.csv"

# (H, W) that every test triplet is resized to, and the ViT encoder input size.
TARGET_HW = (422, 640)
VIT_HW = (224, 224)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 1337

# Seed Python, NumPy and torch RNGs and request deterministic cuDNN kernels.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Compact font sizes shared by all figures in this phase.
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
})

def save_fig(path: Path, dpi=300):
    """Save the current matplotlib figure to `path` (creating parent dirs), then close it."""
    fig = plt.gcf()
    path.parent.mkdir(parents=True, exist_ok=True)
    fig.tight_layout()
    fig.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

def read_rgb(path: str):
    """Load an image file as an (H, W, 3) uint8 array.

    Single-channel images are replicated to three channels and channels beyond the third
    (e.g. alpha) are dropped. Palette ("P"), grayscale+alpha ("LA"), bilevel ("1"), CMYK and
    YCbCr images are converted to RGB by PIL first. Without conversion, their raw array
    values would be palette indices, a 2-channel array, 0/1 values, or non-RGB components.
    """
    with Image.open(path) as img:  # closes the file even if decoding fails
        if img.mode in ("P", "LA", "1", "CMYK", "YCbCr"):
            img = img.convert("RGB")
        arr = np.array(img)
    if arr.ndim == 2:
        arr = np.stack([arr, arr, arr], axis=2)
    if arr.shape[2] > 3:
        arr = arr[:, :, :3]
    return arr.astype(np.uint8)

def read_depth(path: str):
    """Load a depth image as a 2-D array with its stored dtype (first channel if multi-channel)."""
    # The context manager closes the file deterministically, including when decoding raises
    # or the file is multi-frame (e.g. TIFF), where load() leaves the handle open.
    with Image.open(path) as img:
        arr = np.array(img)
    if arr.ndim == 3:
        arr = arr[:, :, 0]
    return arr

def read_mask(path: str):
    """Load a segmentation mask as an (H, W) uint8 array of {0, 1}.

    A pixel is foreground when any of its colour channels is non-zero. The alpha channel of
    RGBA / LA / PA masks is ignored. An opaque alpha (255 everywhere) would otherwise mark
    every pixel as foreground. NOTE(review): masks that encode foreground *only* in alpha
    would read as empty — confirm the dataset does not use that encoding.
    """
    with Image.open(path) as img:
        has_alpha = img.mode in ("RGBA", "LA", "PA", "RGBa", "La")
        arr = np.array(img)
    if arr.ndim == 3:
        if has_alpha and arr.shape[2] > 1:
            arr = arr[:, :, :-1]
        arr = np.any(arr > 0, axis=2).astype(np.uint8)
    else:
        arr = (arr > 0).astype(np.uint8)
    return arr

def crop_to_min_hw(rgb, depth, mask):
    """Crop the three arrays to their common top-left overlap (minimum height and width)."""
    arrays = (rgb, depth, mask)
    h = min(a.shape[0] for a in arrays)
    w = min(a.shape[1] for a in arrays)
    return tuple(a[:h, :w] for a in arrays)

def resize_triplet(rgb, depth, mask01, out_hw):
    """Resize an (rgb, depth, mask) triplet to ``out_hw = (H, W)``.

    RGB uses bilinear interpolation. Depth is cast to float32 and resized with nearest
    neighbour, so no values are invented across depth edges. The mask uses nearest
    neighbour and is re-binarized to uint8 {0, 1}.
    """
    height, width = int(out_hw[0]), int(out_hw[1])
    dsize = (width, height)  # OpenCV takes (width, height)
    rgb_out = cv2.resize(rgb, dsize, interpolation=cv2.INTER_LINEAR)
    depth_out = cv2.resize(depth.astype(np.float32), dsize, interpolation=cv2.INTER_NEAREST)
    mask_small = cv2.resize(mask01.astype(np.uint8), dsize, interpolation=cv2.INTER_NEAREST)
    mask_out = (mask_small > 0).astype(np.uint8)
    return rgb_out, depth_out, mask_out

def fill_depth_holes(depth: np.ndarray):
    """Inpaint invalid depth pixels with OpenCV Telea inpainting (radius 3).

    Pixels that are <= 0 or non-finite are invalid. Returns float32. With fewer than 10 valid
    pixels, the float32 copy is returned unchanged, including any non-finite values.
    Non-finite values in the inpainted result are set to 0.
    """
    d = depth.astype(np.float32)
    bad = ~((d > 0) & np.isfinite(d))
    if np.count_nonzero(~bad) < 10:
        return d
    src = np.where(bad, np.float32(0.0), d)
    filled = cv2.inpaint(src, bad.astype(np.uint8), 3, cv2.INPAINT_TELEA)
    filled[~np.isfinite(filled)] = 0.0
    return filled

def robust_depth_scale(depth: np.ndarray):
    """Scale depth to [0, 1] using the 1st/99th percentiles of the valid pixels.

    Valid pixels are finite and > 0; invalid pixels are 0 in the output. The result is all
    zeros when fewer than 10 pixels are valid. When the percentile range is below 1e-6,
    valid pixels are set to 0.5. Returns float32 with the input's shape.
    """
    d = depth.astype(np.float32)
    ok = np.isfinite(d) & (d > 0)
    if np.count_nonzero(ok) < 10:
        return np.zeros_like(d, dtype=np.float32)
    lo, hi = (float(np.percentile(d[ok], q)) for q in (1, 99))
    span = hi - lo
    if span < 1e-6:
        return np.where(ok, 0.5, 0.0).astype(np.float32)
    scaled = np.clip((d - lo) / span, 0.0, 1.0)
    return np.where(ok, scaled, 0.0).astype(np.float32)

def metric_one(pred01, gt01, eps=1e-6):
    """Foreground IoU (reported as "miou"), Dice, precision, recall and F1 for one mask pair.

    Both inputs are cast to uint8. Pixels equal to 1 count as foreground for TP / FP / FN /
    union, while the Dice denominator sums the raw uint8 values. `eps` smooths every ratio,
    so two empty masks score 1.0 on every metric.
    """
    p = pred01.astype(np.uint8)
    g = gt01.astype(np.uint8)
    p_on, g_on = p == 1, g == 1
    p_off, g_off = p == 0, g == 0
    tp = np.logical_and(p_on, g_on).sum()
    fp = np.logical_and(p_on, g_off).sum()
    fn = np.logical_and(p_off, g_on).sum()
    union = np.logical_or(p_on, g_on).sum()
    iou = (tp + eps) / (union + eps)
    dice = (2 * tp + eps) / (p.sum() + g.sum() + eps)
    precision = (tp + eps) / (tp + fp + eps)
    recall = (tp + eps) / (tp + fn + eps)
    f1 = (2 * precision * recall + eps) / (precision + recall + eps)
    return tuple(float(v) for v in (iou, dice, precision, recall, f1))

def apply_corruption(rgb_u8, dep_scaled, condition: str, severity: int, rng: np.random.Generator):
    """Apply one synthetic corruption to an (RGB, scaled-depth) pair; inputs are not modified.

    Args:
        rgb_u8: uint8 RGB image.
        dep_scaled: depth already scaled to [0, 1] (0 = invalid).
        condition: "clean", "depth_missing", "depth_noise", "depth_holes", "rgb_lowlight"
            or "rgb_blur".
        severity: 1-based index into the condition's level list (1..3); ignored for
            "clean" and "depth_missing".
        rng: generator used by the stochastic depth corruptions.

    Returns:
        (rgb, depth) copies with the corruption applied.

    Raises:
        ValueError: for an unknown condition or an out-of-range severity. Previously, an
            unknown condition silently returned clean inputs, and severity 0 silently chose
            the strongest level.
    """
    levels = {
        "depth_noise": [0.02, 0.05, 0.10],  # Gaussian sigma, in scaled-depth units
        "depth_holes": [0.05, 0.15, 0.30],  # fraction of pixels set to 0
        "rgb_lowlight": [1.6, 2.0, 2.4],    # gamma exponent on [0, 1] intensities (>1 darkens)
        "rgb_blur": [3, 5, 7],              # Gaussian kernel size
    }
    rgb = rgb_u8.copy()
    dep = dep_scaled.copy()
    if condition == "clean":
        return rgb, dep
    if condition == "depth_missing":
        return rgb, np.zeros_like(dep, dtype=np.float32)
    if condition not in levels:
        raise ValueError(f"Unknown corruption condition: {condition!r}")
    table = levels[condition]
    sev = int(severity)
    if not 1 <= sev <= len(table):
        raise ValueError(f"Severity for {condition!r} must be in 1..{len(table)}, got {severity!r}")
    level = table[sev - 1]

    if condition == "depth_noise":
        noise = rng.normal(0, level, size=dep.shape).astype(np.float32)
        return rgb, np.clip(dep + noise, 0.0, 1.0)
    if condition == "depth_holes":
        dep[rng.random(dep.shape) < level] = 0.0
        return rgb, dep
    if condition == "rgb_lowlight":
        rgb_f = (rgb.astype(np.float32) / 255.0) ** level
        return np.clip(rgb_f * 255.0, 0, 255).astype(np.uint8), dep
    # condition == "rgb_blur"
    return cv2.GaussianBlur(rgb, (level, level), 0), dep

class ConvBlock(nn.Module):
    """Two stages of 3x3 conv (no bias) -> BatchNorm -> ReLU; padding=1 keeps H and W."""
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for stage_in in (in_ch, out_ch):
            layers += [
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Sequential indices (net.0, net.1, net.3, net.4) are the state_dict keys in checkpoints.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

def center_crop_or_pad_to(x, ref):
    """Make x's spatial size (dims 2, 3) equal ref's.

    Dimensions that are too large are center-cropped. Dimensions that are too small are then
    replicate-padded. When a difference is odd, the extra pixel goes to the bottom/right.
    Returns `x` itself when the sizes already match.
    """
    _, _, cur_h, cur_w = x.shape
    _, _, tgt_h, tgt_w = ref.shape
    if (cur_h, cur_w) == (tgt_h, tgt_w):
        return x
    if cur_h > tgt_h or cur_w > tgt_w:
        top = max(0, cur_h - tgt_h) // 2
        left = max(0, cur_w - tgt_w) // 2
        x = x[:, :, top:top + tgt_h, left:left + tgt_w]
    _, _, cur_h, cur_w = x.shape
    short_h = max(0, tgt_h - cur_h)
    short_w = max(0, tgt_w - cur_w)
    if short_h or short_w:
        pad = (short_w // 2, short_w - short_w // 2, short_h // 2, short_h - short_h // 2)
        x = F.pad(x, pad, mode="replicate")
    return x

class SimpleUNet(nn.Module):
    """Three-level UNet producing a single-channel logit map at the input's spatial size.

    Pooling uses ceil_mode, so odd sizes are allowed. Skip tensors are center-cropped or
    replicate-padded to match each upsampled map, and the head output is bilinearly resized
    back to the input H x W.
    """
    def __init__(self, in_ch=3, base=32):
        super().__init__()
        # Attribute names and creation order define state_dict keys and init RNG draws; keep them.
        self.enc1 = ConvBlock(in_ch, base)
        self.pool1 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc2 = ConvBlock(base, base*2)
        self.pool2 = nn.MaxPool2d(2, ceil_mode=True)
        self.enc3 = ConvBlock(base*2, base*4)
        self.pool3 = nn.MaxPool2d(2, ceil_mode=True)
        self.bott = ConvBlock(base*4, base*8)
        self.up3 = nn.ConvTranspose2d(base*8, base*4, 2, stride=2)
        self.dec3 = ConvBlock(base*8, base*4)
        self.up2 = nn.ConvTranspose2d(base*4, base*2, 2, stride=2)
        self.dec2 = ConvBlock(base*4, base*2)
        self.up1 = nn.ConvTranspose2d(base*2, base, 2, stride=2)
        self.dec1 = ConvBlock(base*2, base)
        self.head = nn.Conv2d(base, 1, 1)

    def forward(self, x):
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool1(skip1))
        skip3 = self.enc3(self.pool2(skip2))
        y = self.bott(self.pool3(skip3))
        stages = (
            (self.up3, self.dec3, skip3),
            (self.up2, self.dec2, skip2),
            (self.up1, self.dec1, skip1),
        )
        for up, dec, skip in stages:
            y = up(y)
            y = dec(torch.cat([y, center_crop_or_pad_to(skip, y)], dim=1))
        logits = self.head(y)
        return F.interpolate(logits, size=(x.shape[2], x.shape[3]), mode="bilinear", align_corners=False)

class ViTEncoderFeature(nn.Module):
    """Pretrained timm backbone wrapped to return one feature map.

    The backbone is created with ``features_only=True`` and ``out_indices=(3,)``. forward
    returns the last (here the only) map in timm's feature list, and `out_ch` is its
    channel count.
    NOTE(review): for ViT models, index 3 presumably refers to an early transformer block
    rather than the final one — confirm this depth is intended.
    """
    def __init__(self, model_name: str, in_chans: int, img_hw: Tuple[int,int]):
        super().__init__()
        create_kwargs = {
            "pretrained": True,
            "in_chans": in_chans,
            "img_size": img_hw,
            # Let timm accept inputs whose size differs from img_size.
            "dynamic_img_size": True,
            "dynamic_img_pad": True,
            "features_only": True,
            "out_indices": (3,),
        }
        self.backbone = timm.create_model(model_name, **create_kwargs)
        self.out_ch = self.backbone.feature_info.channels()[-1]

    def forward(self, x):
        return self.backbone(x)[-1]

class CrossAttentionFusion(nn.Module):
    """Residual cross-attention: RGB tokens attend over depth tokens.

    1x1 convs produce query (from RGB), key and value (from depth). Only the query tokens
    are LayerNorm-ed. The attended tokens are reshaped to the RGB grid, projected by a
    1x1 conv and added to the RGB input.
    """
    def __init__(self, ch: int, heads: int = 8):
        super().__init__()
        # Creation order q, k, v, attn, proj, norm is kept (state_dict layout, init RNG draws).
        self.q = nn.Conv2d(ch, ch, 1, bias=False)
        self.k = nn.Conv2d(ch, ch, 1, bias=False)
        self.v = nn.Conv2d(ch, ch, 1, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim=ch, num_heads=heads, batch_first=True)
        self.proj = nn.Conv2d(ch, ch, 1, bias=False)
        self.norm = nn.LayerNorm(ch)

    @staticmethod
    def _tokens(t):
        """(B, C, H, W) -> (B, H*W, C)."""
        return t.flatten(2).transpose(1, 2)

    def forward(self, frgb, fdep):
        b, c, h, w = frgb.shape
        query = self.norm(self._tokens(self.q(frgb)))
        key = self._tokens(self.k(fdep))
        value = self._tokens(self.v(fdep))
        attended, _ = self.attn(query, key, value, need_weights=False)
        attended = attended.transpose(1, 2).reshape(b, c, h, w)
        return frgb + self.proj(attended)

class LiteDecoder(nn.Module):
    """Light decoder with three 2x transposed-conv upsamplings and a 1-channel head.

    The layer stack is ConvBlock(in_ch -> 256), then three rounds of 2x upsampling each
    followed by a ConvBlock (128, 64, 32 channels), then a 1x1 head to 1 channel. The
    result is bilinearly resized to `out_hw`.
    """
    def __init__(self, in_ch: int, out_hw: Tuple[int,int]):
        super().__init__()
        self.out_hw = out_hw
        # Attribute names and creation order are kept: they define state_dict keys and init RNG draws.
        self.conv1 = ConvBlock(in_ch, 256)
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = ConvBlock(128, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv3 = ConvBlock(64, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv4 = ConvBlock(32, 32)
        self.head = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        stack = (self.conv1, self.up1, self.conv2, self.up2, self.conv3, self.up3, self.conv4, self.head)
        for layer in stack:
            x = layer(x)
        return F.interpolate(x, size=self.out_hw, mode="bilinear", align_corners=False)

class ProposedViTFusion(nn.Module):
    """Two-stream ViT RGB-D segmenter with switchable fusion modes (used for evaluation and ablations).

    Args:
        rgb_name, dep_name: timm model names for the RGB (3-channel) and depth (1-channel) encoders.
        full_hw: output logit map size (H, W).
        vit_hw: size both inputs are bilinearly resized to before encoding.
        depth_tokens_stride: if > 1, depth features are subsampled by this stride and then
            bilinearly resized back to the RGB feature grid.
        mode: one of VALID_MODES:
            "fusion"      cross-attention of RGB queries over depth (default),
            "late_fusion" 1x1 conv + ReLU over concatenated RGB and depth features,
            "rgb_only"    RGB features only,
            "no_fusion"   currently identical to "rgb_only" (NOTE(review): confirm intent),
            "depth_only"  depth features only.

    Raises:
        ValueError: if `mode` is not one of VALID_MODES. Previously, a typo silently ran the
            "fusion" path.
    """
    VALID_MODES = ("fusion", "no_fusion", "late_fusion", "rgb_only", "depth_only")

    def __init__(self, rgb_name: str, dep_name: str, full_hw: Tuple[int,int], vit_hw: Tuple[int,int], depth_tokens_stride: int = 1, mode: str = "fusion"):
        super().__init__()
        if mode not in self.VALID_MODES:
            raise ValueError(f"Unknown ProposedViTFusion mode {mode!r}; expected one of {self.VALID_MODES}")
        self.full_hw = full_hw
        self.vit_hw = vit_hw
        self.mode = mode
        self.depth_tokens_stride = depth_tokens_stride
        # Submodule names/order must match the training model's state_dict; late_gate is extra.
        self.rgb_enc = ViTEncoderFeature(rgb_name, in_chans=3, img_hw=vit_hw)
        self.dep_enc = ViTEncoderFeature(dep_name, in_chans=1, img_hw=vit_hw)
        ch = self.rgb_enc.out_ch
        self.dep_proj = nn.Conv2d(self.dep_enc.out_ch, ch, 1, bias=False) if self.dep_enc.out_ch != ch else nn.Identity()
        self.fuse = CrossAttentionFusion(ch, heads=8)
        self.dec = LiteDecoder(ch, out_hw=full_hw)
        self.late_gate = nn.Sequential(nn.Conv2d(ch * 2, ch, 1, bias=False), nn.ReLU(inplace=True))

    def forward_feats(self, rgb3, dep1):
        """Return (rgb_features, depth_features) on the same spatial grid."""
        rgb_small = F.interpolate(rgb3, size=self.vit_hw, mode="bilinear", align_corners=False)
        dep_small = F.interpolate(dep1, size=self.vit_hw, mode="bilinear", align_corners=False)
        fr = self.rgb_enc(rgb_small)
        fd = self.dep_proj(self.dep_enc(dep_small))
        if self.depth_tokens_stride > 1:
            # Subsample, then resize back, so fusion still sees the RGB token grid.
            fd = fd[:, :, ::self.depth_tokens_stride, ::self.depth_tokens_stride]
            fd = F.interpolate(fd, size=fr.shape[-2:], mode="bilinear", align_corners=False)
        return fr, fd

    def forward(self, rgb3, dep1):
        if self.mode in ("rgb_only", "no_fusion"):
            # These modes discard depth features, so skip the depth encoder entirely.
            rgb_small = F.interpolate(rgb3, size=self.vit_hw, mode="bilinear", align_corners=False)
            return self.dec(self.rgb_enc(rgb_small))
        fr, fd = self.forward_feats(rgb3, dep1)
        if self.mode == "depth_only":
            f = fd
        elif self.mode == "late_fusion":
            f = self.late_gate(torch.cat([fr, fd], dim=1))
        elif self.mode == "fusion":
            f = self.fuse(fr, fd)
        else:
            raise ValueError(f"Unknown ProposedViTFusion mode {self.mode!r}; expected one of {self.VALID_MODES}")
        return self.dec(f)

def load_ckpt(model: nn.Module, ckpt_path: str, strict: bool = False):
    """Load ``ck["state_dict"]`` into `model`, move it to DEVICE and switch it to eval mode.

    Args:
        model: target module (modified in place).
        ckpt_path: checkpoint file holding a dict with a "state_dict" entry.
        strict: forwarded to ``load_state_dict``. The default (False) tolerates architecture
            differences, such as modules absent from the training model; those parameters
            keep their initial values.

    Returns:
        dict with "missing_keys" and "unexpected_keys" lists.
    """
    import warnings

    ck = torch.load(ckpt_path, map_location=DEVICE)
    missing, unexpected = model.load_state_dict(ck["state_dict"], strict=strict)
    missing, unexpected = list(missing), list(unexpected)
    if missing or unexpected:
        # Surface partial loads: with strict=False a mismatch would otherwise silently
        # evaluate randomly initialised weights.
        warnings.warn(
            f"load_ckpt({ckpt_path}): {len(missing)} missing key(s) {missing[:5]}, "
            f"{len(unexpected)} unexpected key(s) {unexpected[:5]}",
            RuntimeWarning,
            stacklevel=2,
        )
    model.to(DEVICE)
    model.eval()
    return {"missing_keys": missing, "unexpected_keys": unexpected}

def get_best_ckpt_paths(path: Path = None):
    """Map model name -> best checkpoint path from the Phase-4 best-checkpoint index.

    Args:
        path: CSV to read; defaults to PHASE4_BEST_PATH.

    Returns:
        dict of model -> checkpoint path string. Rows without a checkpoint map to "".
        pandas reads an empty CSV field as NaN, which is truthy. Left as NaN, it would make
        the caller's ``Path(ck)`` raise TypeError instead of the intended FileNotFoundError.
    """
    csv_path = PHASE4_BEST_PATH if path is None else Path(path)
    df_best = pd.read_csv(csv_path)
    ckpts = df_best["checkpoint_path"].fillna("").astype(str)
    return dict(zip(df_best["model"], ckpts))

def get_params_lookup():
    """Map model name -> parameter count (millions) from the Phase-4 size table; {} if it is absent."""
    if not PHASE4_SIZE_PATH.exists():
        return {}
    size_table = pd.read_csv(PHASE4_SIZE_PATH)
    return {name: float(value) for name, value in zip(size_table["model"], size_table["params_m"])}

# Test split: valid-manifest rows inner-joined with the split manifest on (sample_id, site_id).
# Stable mergesort on sample_id keeps the row order deterministic.
df_valid = pd.read_csv(MANIFEST_VALID_PATH).sort_values("sample_id", kind="mergesort").reset_index(drop=True)
df_split = pd.read_csv(SPLIT_MANIFEST_PATH).sort_values("sample_id", kind="mergesort").reset_index(drop=True)
df_all = df_valid.merge(df_split, on=["sample_id","site_id"], how="inner").sort_values("sample_id", kind="mergesort").reset_index(drop=True)
df_test = df_all[df_all["split"] == "test"].sort_values("sample_id", kind="mergesort").reset_index(drop=True)

# Best-epoch checkpoints chosen in Phase 4, plus optional parameter counts for the results table.
best_paths = get_best_ckpt_paths()
params_lookup = get_params_lookup()

ck_a = best_paths.get("baseline_a_rgb","")
ck_b = best_paths.get("baseline_b_rgbd","")
ck_p = best_paths.get("proposed_vit_fusion","")

# Fail fast if any main model has no usable checkpoint.
# NOTE(review): pandas reads an empty checkpoint_path cell as NaN by default. NaN is
# truthy, so `Path(nan)` would raise TypeError here rather than FileNotFoundError —
# confirm the index never contains empty paths.
if not ck_a or not Path(ck_a).exists():
    raise FileNotFoundError("Missing baseline_a_rgb checkpoint in tab_03_best_checkpoint_index.csv")
if not ck_b or not Path(ck_b).exists():
    raise FileNotFoundError("Missing baseline_b_rgbd checkpoint in tab_03_best_checkpoint_index.csv")
if not ck_p or not Path(ck_p).exists():
    raise FileNotFoundError("Missing proposed_vit_fusion checkpoint in tab_03_best_checkpoint_index.csv")

# Architectures mirror the constructors used in the Phase-4 training cell.
baseline_a = SimpleUNet(in_ch=3, base=32)
baseline_b = SimpleUNet(in_ch=4, base=32)
proposed_fusion = ProposedViTFusion("vit_small_patch16_224", "vit_small_patch16_224", full_hw=TARGET_HW, vit_hw=VIT_HW, depth_tokens_stride=1, mode="fusion")

# info_* hold the missing/unexpected key lists reported by load_ckpt; they are not checked here.
info_a = load_ckpt(baseline_a, ck_a)
info_b = load_ckpt(baseline_b, ck_b)
info_p = load_ckpt(proposed_fusion, ck_p)

# Shared generator for the stochastic corruptions (depth noise / holes). Corrupted inputs
# therefore depend on the order in which evaluations are run.
rng_global = np.random.default_rng(SEED)

def infer_mask(model_key: str, model: nn.Module, rgb_u8, dep_scaled):
    """Run one model on a single sample and return its binary mask (uint8, sigmoid >= 0.5).

    The input signature depends on `model_key`:
    - "baseline_a_rgb" gets RGB only;
    - "baseline_b_rgbd" gets RGB and depth concatenated as 4 channels;
    - any other key is called as model(rgb, depth).
    """
    rgb_t = torch.from_numpy(rgb_u8.astype(np.float32) / 255.0).permute(2, 0, 1)[None].to(DEVICE)
    dep_t = torch.from_numpy(dep_scaled.astype(np.float32))[None, None].to(DEVICE)
    with torch.no_grad():
        if model_key == "baseline_a_rgb":
            logits = model(rgb_t)
        elif model_key == "baseline_b_rgbd":
            logits = model(torch.cat([rgb_t, dep_t], dim=1))
        else:
            logits = model(rgb_t, dep_t)
    prob = torch.sigmoid(logits)[0, 0].detach().cpu().numpy()
    return (prob >= 0.5).astype(np.uint8)

def _load_test_sample(row):
    """Read one manifest row and return (rgb_u8, dep_scaled, mask01) at TARGET_HW.

    The three files are cropped to their common size and resized. Depth holes are then
    inpainted and depth is percentile-scaled to [0, 1].
    """
    rgb = read_rgb(row["rgb_path"])
    dep = read_depth(row["depth_path"])
    msk = read_mask(row["mask_path"])
    rgb, dep, msk = crop_to_min_hw(rgb, dep, msk)
    rgb, dep, msk = resize_triplet(rgb, dep, msk, TARGET_HW)
    dep_scaled = robust_depth_scale(fill_depth_holes(dep))
    return rgb, dep_scaled, msk


def eval_models_on_test(models: Dict[str, nn.Module], condition: str = "clean", severity: int = 1):
    """Per-sample test metrics for every model in `models` under one input condition.

    Each test sample is corrupted once (via apply_corruption with the shared rng_global), and
    all models see the same corrupted input.

    Returns:
        DataFrame with one row per (model, sample), sorted by model then sample_id.
        Columns: sample_id, site_id, model, miou, dice, precision, recall, f1 and notes
        ("<condition>|<severity>"). The columns are present even when there are no rows;
        previously an empty test set made sort_values raise KeyError.
    """
    columns = ["sample_id", "site_id", "model", "miou", "dice", "precision", "recall", "f1", "notes"]
    rows = []
    for _, r in df_test.iterrows():
        sid = str(r["sample_id"])
        site = str(r["site_id"])
        rgb, dep_scaled, msk = _load_test_sample(r)
        rgb_c, dep_c = apply_corruption(rgb, dep_scaled, condition, severity, rng_global)
        for mk, mo in models.items():
            pred = infer_mask(mk, mo, rgb_c, dep_c)
            miou, dice, prec, rec, f1 = metric_one(pred, msk)
            rows.append({
                "sample_id": sid,
                "site_id": site,
                "model": mk,
                "miou": miou,
                "dice": dice,
                "precision": prec,
                "recall": rec,
                "f1": f1,
                "notes": f"{condition}|{severity}"
            })
    result = pd.DataFrame(rows, columns=columns)
    return result.sort_values(["model", "sample_id"], kind="mergesort").reset_index(drop=True)

# The three trained models compared in the main table; the keys select the input signature
# used by infer_mask.
models_main = {
    "baseline_a_rgb": baseline_a,
    "baseline_b_rgbd": baseline_b,
    "proposed_vit_fusion": proposed_fusion
}

# Output tables and figures of this phase.
TAB_01_TEST_RESULTS = PHASE5_DIR / "tab_01_test_results_full.csv"
TAB_02_MAIN = PHASE5_DIR / "tab_02_main_results_table.csv"
TAB_03_ABL = PHASE5_DIR / "tab_03_ablation_table.csv"
TAB_04_ROB = PHASE5_DIR / "tab_04_robustness_table.csv"

FIG_01_MAIN = PHASE5_DIR / "fig_01_main_results_bar.png"
FIG_02_ABL = PHASE5_DIR / "fig_02_ablation_impact_plot.png"
FIG_03_ROB = PHASE5_DIR / "fig_03_robustness_curves.png"
FIG_04_EXPLAIN = PHASE5_DIR / "fig_04_explainability_examples.png"

# Per-sample clean-test metrics for all main models.
df_clean_full = eval_models_on_test(models_main, condition="clean", severity=1)
df_clean_full.to_csv(TAB_01_TEST_RESULTS, index=False)

def summarize_main(df_full: pd.DataFrame, params_lookup: Dict[str,float]):
    """Per-model mean of the per-sample test metrics, one row per model in sorted name order.

    "modalities" is "rgb" for baseline_a_rgb and "rgb+depth" otherwise. "params_m" comes from
    `params_lookup` and is NaN when the model is missing from it.
    """
    metric_cols = ["miou", "dice", "precision", "recall", "f1"]

    def _summary_row(mk):
        sub = df_full[df_full["model"] == mk]
        row = {"method": mk, "modalities": "rgb" if mk == "baseline_a_rgb" else "rgb+depth"}
        row.update({col: float(sub[col].mean()) for col in metric_cols})
        row["params_m"] = float(params_lookup.get(mk, np.nan))
        return row

    return pd.DataFrame([_summary_row(mk) for mk in sorted(df_full["model"].unique())])

# Aggregate the per-sample clean metrics into the main results table.
df_main = summarize_main(df_clean_full, params_lookup)
df_main.to_csv(TAB_02_MAIN, index=False)

def fig_main_bar(df_main: pd.DataFrame, path: Path):
    """Grouped bar chart of test mIoU and Dice per method, saved to `path`."""
    plt.figure(figsize=(8.2, 4.2))
    positions = np.arange(len(df_main))
    for offset, column, label in ((-0.15, "miou", "mIoU"), (0.15, "dice", "Dice")):
        plt.bar(positions + offset, df_main[column].values, width=0.3, label=label)
    plt.xticks(positions, df_main["method"].values, rotation=20, ha="right")
    plt.ylim(0, 1.0)
    plt.ylabel("Score")
    plt.title("Test performance")
    plt.legend()
    save_fig(path, dpi=300)

# Bar chart of clean-test mIoU / Dice per method.
fig_main_bar(df_main, FIG_01_MAIN)

def build_ablation_models():
    """Construct one ProposedViTFusion per ablation variant, keyed by ablation name.

    All variants use the same ViT backbones and sizes and differ only in `mode` and
    `depth_tokens_stride`. They are built in the listed order.
    """
    backbone = "vit_small_patch16_224"
    variants = [
        ("proposed_fusion", 1, "fusion"),
        ("no_fusion", 1, "no_fusion"),
        ("late_fusion", 1, "late_fusion"),
        ("rgb_only_vit", 1, "rgb_only"),
        ("depth_only", 1, "depth_only"),
        ("reduced_depth_tokens_x2", 2, "fusion"),
    ]
    return {
        name: ProposedViTFusion(backbone, backbone, TARGET_HW, VIT_HW, depth_tokens_stride=stride, mode=mode)
        for name, stride, mode in variants
    }

# Every ablation variant is initialised from the same proposed-fusion checkpoint (no retraining).
# NOTE(review): the training-cell ProposedViTFusion has no `late_gate`, so for "late_fusion"
# those weights keep their random init (load_ckpt uses strict=False). The missing keys are
# recorded in abl_load_infos.
abl_models = build_ablation_models()
abl_load_infos = {}
for k in list(abl_models.keys()):
    abl_load_infos[k] = load_ckpt(abl_models[k], ck_p)

def eval_ablation(abl_models: Dict[str, nn.Module]):
    """Clean-test mIoU / Dice per ablation variant, with the mIoU delta to a reference.

    The reference is the "proposed_fusion" row when present, otherwise the best mIoU.
    Rows are sorted by ablation name, and the "notes" column is left empty.
    """
    summary = []
    for name, model in abl_models.items():
        per_sample = eval_models_on_test({name: model}, condition="clean", severity=1)
        summary.append({
            "ablation_name": name,
            "miou": float(per_sample["miou"].mean()),
            "dice": float(per_sample["dice"].mean()),
        })
    df_ab = pd.DataFrame(summary).sort_values("ablation_name", kind="mergesort").reset_index(drop=True)
    is_ref = df_ab["ablation_name"] == "proposed_fusion"
    if is_ref.any():
        base = float(df_ab.loc[is_ref, "miou"].iloc[0])
    else:
        base = float(df_ab["miou"].max())
    df_ab["delta_miou"] = df_ab["miou"].astype(float) - base
    df_ab["notes"] = ""
    return df_ab

# Clean-test ablation summary: mIoU, Dice and mIoU delta vs proposed_fusion.
df_ablation = eval_ablation(abl_models)
df_ablation.to_csv(TAB_03_ABL, index=False)

def fig_ablation(df_ab: pd.DataFrame, path: Path):
    """Bar chart of clean-test mIoU per ablation variant (y in [0, 1]), saved to `path`."""
    names = df_ab["ablation_name"].values
    scores = df_ab["miou"].values
    ticks = np.arange(len(df_ab))
    plt.figure(figsize=(9.0, 4.2))
    plt.bar(ticks, scores)
    plt.xticks(ticks, names, rotation=25, ha="right")
    plt.ylim(0, 1.0)
    plt.ylabel("mIoU")
    plt.title("Ablation impact on mIoU (test)")
    save_fig(path, dpi=300)

fig_ablation(df_ablation, FIG_02_ABL)

# Corruptions for the robustness sweep as (condition, severities). A severity indexes the
# per-condition level lists in apply_corruption; depth_missing has a single level.
robust_conditions = [
    ("depth_missing", [1]),
    ("depth_noise", [1,2,3]),
    ("depth_holes", [1,2,3]),
    ("rgb_lowlight", [1,2,3]),
    ("rgb_blur", [1,2,3]),
]

def robustness_table(models: Dict[str, nn.Module]):
    """Evaluate every model under each corruption in `robust_conditions`.

    A clean evaluation is run first to get each model's baseline mIoU; every
    (condition, severity) pair is then evaluated and compared to that baseline.

    Args:
        models: mapping of model key -> model, passed to `eval_models_on_test`.

    Returns:
        DataFrame with columns condition, severity, model, miou, dice,
        delta_from_clean, notes, sorted by (model, condition, severity).
    """
    clean_df = eval_models_on_test(models, condition="clean", severity=1)
    clean_miou = {}
    for key in models.keys():
        clean_miou[key] = float(clean_df[clean_df["model"] == key]["miou"].mean())

    records = []
    for condition, severities in robust_conditions:
        for severity in severities:
            cond_df = eval_models_on_test(models, condition=condition, severity=severity)
            for key in models.keys():
                subset = cond_df[cond_df["model"] == key]
                miou = float(subset["miou"].mean())
                records.append({
                    "condition": condition,
                    "severity": int(severity),
                    "model": key,
                    "miou": miou,
                    "dice": float(subset["dice"].mean()),
                    "delta_from_clean": float(miou - clean_miou[key]),
                    "notes": "",
                })

    table = pd.DataFrame(records)
    return table.sort_values(["model", "condition", "severity"], kind="mergesort").reset_index(drop=True)

# Evaluate the main models under every corruption in robust_conditions and save the table.
df_rob = robustness_table(models_main)
df_rob.to_csv(TAB_04_ROB, index=False)

def fig_robustness(df_rob: pd.DataFrame, path: Path, model_key="proposed_vit_fusion"):
    """Line plot of one model's mIoU across the corruption conditions.

    Points follow the order of `robust_conditions`; (condition, severity) pairs
    that have no row in `df_rob` for `model_key` are skipped.

    Args:
        df_rob: robustness table with model, condition, severity and miou columns.
        path: output image file (300 dpi).
        model_key: value of the `model` column to plot.
    """
    plt.figure(figsize=(9.2, 4.8))
    model_rows = df_rob[df_rob["model"] == model_key].copy()

    points = []
    for cond, sevs in robust_conditions:
        for sev in sevs:
            hit = model_rows[(model_rows["condition"] == cond) & (model_rows["severity"] == sev)]
            if not hit.empty:
                points.append((f"{cond}:{sev}", float(hit["miou"].iloc[0])))

    labels = [label for label, _ in points]
    ys = [value for _, value in points]
    xs = np.arange(len(labels))

    plt.plot(xs, ys, marker="o", label=f"{model_key} mIoU")
    plt.xticks(xs, labels, rotation=30, ha="right")
    plt.ylim(0, 1.0)
    plt.ylabel("mIoU")
    plt.title("Robustness under corruptions")
    plt.legend()
    save_fig(path, dpi=300)

# Plot the robustness curve for the proposed model.
fig_robustness(df_rob, FIG_03_ROB, model_key="proposed_vit_fusion")

def explain_map_from_fused_feats(model: ProposedViTFusion, rgb_u8, dep_scaled):
    """Min-max normalized map of fused-feature magnitude, resized to TARGET_HW.

    Runs the model's encoders (`forward_feats`) and fusion module (`fuse`), but not
    the decoder. It then averages |feature| over channels and min-max normalizes
    to roughly [0, 1].

    Args:
        model: the fusion model; must expose `forward_feats` and `fuse`.
        rgb_u8: RGB image as an array, presumably HxWx3 uint8 (divided by 255 and
            permuted to CHW here).
        dep_scaled: 2-D depth array already scaled by the caller.

    Returns:
        float array of shape (TARGET_HW[0], TARGET_HW[1]).
    """
    rgb_t = torch.from_numpy(rgb_u8.astype(np.float32) / 255.0).permute(2, 0, 1)[None].to(DEVICE)
    dep_t = torch.from_numpy(dep_scaled.astype(np.float32))[None, None].to(DEVICE)

    with torch.no_grad():
        feats_rgb, feats_dep = model.forward_feats(rgb_t, dep_t)
        fused = model.fuse(feats_rgb, feats_dep)
        heat = fused.abs().mean(dim=1)[0].cpu().numpy()

    lo = heat.min()
    # The +1e-6 keeps the division finite for a constant map.
    heat = (heat - lo) / (heat.max() - lo + 1e-6)
    return cv2.resize(heat, (TARGET_HW[1], TARGET_HW[0]), interpolation=cv2.INTER_LINEAR)

def pick_explain_ids(k=8):
    """Choose `k` test-split sample ids for the explainability figure.

    The function prefers the ids recorded for the "outliers" figure in
    FIG_SAMPLE_IDS_PATH (a "|"-separated `sample_id_list`). Only ids that also
    exist in `df_test["sample_id"]` are kept.

    If the file or row is missing, or yields fewer than `k` test ids, it falls
    back to the first `k` test ids in sorted order.

    Returns:
        A list of at most `k` ids. Each id has the same type as the values in
        `df_test["sample_id"]`, so it can be used directly as a lookup key into
        `df_test`.
    """
    # Ids read from CSV are text. Map str(id) -> original df_test value so that,
    # for example, integer ids still match.
    test_ids = {str(s): s for s in df_test["sample_id"].tolist()}
    if FIG_SAMPLE_IDS_PATH.exists():
        # dtype=str keeps a single id such as "0007" from being parsed as the number 7.
        df_fig = pd.read_csv(FIG_SAMPLE_IDS_PATH, dtype=str)
        pick = df_fig[df_fig["figure_name"].astype(str).str.contains("outliers", na=False)]
        if len(pick):
            ids_str = str(pick.iloc[0]["sample_id_list"])
            # Ids recorded by another phase need not belong to the test split. Drop
            # them here instead of failing later in the figure's per-id lookup.
            ids = [test_ids[x.strip()] for x in ids_str.split("|") if x.strip() in test_ids]
            if len(ids) >= k:
                return ids[:k]
    return df_test.sort_values("sample_id", kind="mergesort").head(k)["sample_id"].tolist()

# Ids for the rows of the explainability figure (k=8).
EXPLAIN_IDS = pick_explain_ids(k=8)

def fig_explainability(sample_ids: List[str], path: Path):
    """Render a grid with one row per test sample: RGB | GT | Pred | explanation overlay.

    Each sample from `df_test` goes through crop-to-common-size, resize to
    TARGET_HW, depth hole filling and robust depth scaling before inference with
    `proposed_fusion`.

    The last column blends a JET heat map of the fused-feature magnitude (from
    `explain_map_from_fused_feats`) onto the RGB image.

    Args:
        sample_ids: ids to plot; each must be a value of `df_test["sample_id"]`.
        path: output image file (300 dpi).

    Raises:
        KeyError: if any id is not present in `df_test["sample_id"]`.
    """
    # Key rows by their sample_id rather than by .loc[i] for i in range(len(df_test)).
    # The latter silently assumes a 0..n-1 RangeIndex.
    lut = {row["sample_id"]: row for _, row in df_test.iterrows()}
    missing = [sid for sid in sample_ids if sid not in lut]
    if missing:
        raise KeyError(f"sample ids not found in df_test: {missing[:5]}")

    rows = len(sample_ids)
    cols = 4
    plt.figure(figsize=(3.2 * cols, 2.8 * rows))
    for r, sid in enumerate(sample_ids):
        row = lut[sid]
        rgb = read_rgb(row["rgb_path"])
        dep = read_depth(row["depth_path"])
        msk = read_mask(row["mask_path"])
        rgb, dep, msk = crop_to_min_hw(rgb, dep, msk)
        rgb, dep, msk = resize_triplet(rgb, dep, msk, TARGET_HW)
        dep = fill_depth_holes(dep)
        dep_scaled = robust_depth_scale(dep)
        pred = infer_mask("proposed_vit_fusion", proposed_fusion, rgb, dep_scaled)
        att = explain_map_from_fused_feats(proposed_fusion, rgb, dep_scaled)

        heat = (att * 255.0).astype(np.uint8)
        # applyColorMap returns BGR. Convert to RGB so it blends with the RGB image
        # and displays correctly in imshow (otherwise hot/cold colors are swapped).
        heat = cv2.cvtColor(cv2.applyColorMap(heat, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        overlay = np.clip(0.65 * rgb.astype(np.float32) + 0.35 * heat.astype(np.float32), 0, 255).astype(np.uint8)

        panels = [
            (rgb, None, f"{sid}"),
            (msk, "gray", "GT"),
            (pred, "gray", "Pred"),
            (overlay, None, "Explain"),
        ]
        for c, (img, cmap, title) in enumerate(panels):
            ax = plt.subplot(rows, cols, r * cols + c + 1)
            ax.imshow(img, cmap=cmap)
            ax.set_title(title)
            ax.axis("off")
    save_fig(path, dpi=300)

# Render the explainability grid for the selected ids.
fig_explainability(EXPLAIN_IDS, FIG_04_EXPLAIN)



In [None]:
# Re-render the explainability figure using only ids that exist in the test split.
# Ids are normalized to str on both sides so CSV-sourced ids match df_test's keys.
EXPLAIN_IDS = [str(x) for x in EXPLAIN_IDS]
# NOTE: this converts the global df_test's sample_id column to str in place; later cells see str ids.
df_test["sample_id"] = df_test["sample_id"].astype(str)

test_ids = set(df_test["sample_id"].tolist())
explain_ids_ok = [sid for sid in EXPLAIN_IDS if sid in test_ids]

# Fallback: if none of the chosen ids are test samples, use the first 8 test ids in sorted order.
if len(explain_ids_ok) == 0:
    explain_ids_ok = df_test.sort_values("sample_id", kind="mergesort").head(8)["sample_id"].astype(str).tolist()

# Record which ids the figure uses and the rule that selected them.
TAB_05_EXPLAIN_IDS = PHASE5_DIR / "tab_05_explainability_sample_ids.csv"
pd.DataFrame({
    "figure_name": ["fig_04_explainability_examples.png"],
    "sample_id_list": ["|".join(explain_ids_ok)],
    "selection_rule": ["phase2_ids_intersect_test_else_first8_test_sorted"]
}).to_csv(TAB_05_EXPLAIN_IDS, index=False)

EXPLAIN_IDS = explain_ids_ok

# Overwrites FIG_04_EXPLAIN produced earlier.
fig_explainability(EXPLAIN_IDS, FIG_04_EXPLAIN)



In [None]:
import os
import time
from pathlib import Path
from typing import Tuple

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm

# Phase-6 configuration: result directories, input sizes, device, seeds and plot style.
WORK_ROOT = Path("/kaggle/working")
RESULTS_ROOT = WORK_ROOT / "results"
# Phase 4 provides the best-checkpoint index; phase 5 provides result tables; phase 6 is written here.
PHASE4_DIR = RESULTS_ROOT / "04_training_baselines_and_vit_fusion"
PHASE5_DIR = RESULTS_ROOT / "05_eval_ablation_robust_explain"
PHASE6_DIR = RESULTS_ROOT / "06_edge_and_final_ieee_package"
PHASE6_DIR.mkdir(parents=True, exist_ok=True)

# (H, W) of the model's input/output maps, and the (H, W) each ViT encoder sees after downsampling.
TARGET_HW = (422, 640)
VIT_HW = (224, 224)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
SEED = 1337

# Seed numpy/torch and request deterministic cuDNN kernels so reruns are more repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Shared font sizes for all phase-6 figures.
plt.rcParams.update({
    "font.size": 10,
    "axes.titlesize": 11,
    "axes.labelsize": 10,
    "xtick.labelsize": 9,
    "ytick.labelsize": 9,
    "legend.fontsize": 9,
})

def save_fig(path: Path, dpi=300):
    """Tight-layout, write and close the current matplotlib figure.

    Parent directories of `path` are created as needed.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    fig = plt.gcf()
    fig.tight_layout()
    fig.savefig(path, dpi=dpi, bbox_inches="tight")
    plt.close(fig)

def file_size_mb(p: Path):
    """Size of file `p` in MiB (bytes / 1024**2); NaN when `p` does not exist."""
    return float(p.stat().st_size) / (1024**2) if p.exists() else float("nan")

def params_m(model: nn.Module):
    """Total number of parameters in `model`, in millions (float)."""
    total = 0
    for param in model.parameters():
        total += param.numel()
    return float(total) / 1e6

class ConvBlock(nn.Module):
    """Two stages of (3x3 conv -> BatchNorm -> ReLU); padding=1 keeps H and W unchanged.

    The first stage maps in_ch -> out_ch and the second maps out_ch -> out_ch.
    Layers live in `self.net` at indices 0..5, so state_dict keys stay
    `net.0.weight`, `net.1.*`, ..., `net.4.*`.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # Layers are created in conv, BN, ReLU order per stage. Creation order fixes
        # both the Sequential indices and the RNG draws used for initialization.
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        net = self.net
        return net(x)

class ViTEncoderFeature(nn.Module):
    """Wrap a timm backbone in `features_only` mode and return one feature map.

    The backbone is created pretrained with `in_chans` input channels, at
    `img_size=img_hw`, with dynamic image size and padding enabled, and asks for
    the single feature level `out_indices=(3,)`. `out_ch` is that level's channel
    count as reported by timm's `feature_info`.
    """
    def __init__(self, model_name: str, in_chans: int, img_hw: Tuple[int,int]):
        super().__init__()
        backbone_kwargs = dict(
            pretrained=True,
            in_chans=in_chans,
            img_size=img_hw,
            dynamic_img_size=True,
            dynamic_img_pad=True,
            features_only=True,
            out_indices=(3,),
        )
        self.backbone = timm.create_model(model_name, **backbone_kwargs)
        self.out_ch = self.backbone.feature_info.channels()[-1]

    def forward(self, x):
        # features_only backbones return a list of maps; keep the last (here the only) one.
        return self.backbone(x)[-1]

class CrossAttentionFusion(nn.Module):
    """Fuse depth into RGB features with multi-head cross-attention and a residual.

    RGB features give the queries (LayerNorm applied to queries only); depth
    features give keys and values. The attended tokens are reshaped back to
    (B, C, H, W) of `frgb`, passed through a 1x1 conv, and added onto `frgb`.
    """
    def __init__(self, ch: int, heads: int = 8):
        super().__init__()
        # Names and creation order are fixed: they define state_dict keys and the
        # order in which parameters draw from the RNG at init.
        self.q = nn.Conv2d(ch, ch, 1, bias=False)
        self.k = nn.Conv2d(ch, ch, 1, bias=False)
        self.v = nn.Conv2d(ch, ch, 1, bias=False)
        self.attn = nn.MultiheadAttention(embed_dim=ch, num_heads=heads, batch_first=True)
        self.proj = nn.Conv2d(ch, ch, 1, bias=False)
        self.norm = nn.LayerNorm(ch)

    @staticmethod
    def _to_tokens(fmap):
        # (B, C, H, W) -> (B, H*W, C) token sequence for batch_first attention.
        return fmap.flatten(2).transpose(1, 2)

    def forward(self, frgb, fdep):
        batch, chans, height, width = frgb.shape
        queries = self.norm(self._to_tokens(self.q(frgb)))
        keys = self._to_tokens(self.k(fdep))
        values = self._to_tokens(self.v(fdep))
        attended, _ = self.attn(queries, keys, values, need_weights=False)
        attended = attended.transpose(1, 2).reshape(batch, chans, height, width)
        return frgb + self.proj(attended)

class LiteDecoder(nn.Module):
    """Lightweight decoder from a feature map to a 1-channel map at `out_hw`.

    It uses ConvBlocks interleaved with three stride-2 transposed convs (each
    doubles H and W), then a 1x1 head. A final bilinear resize brings the output
    to exactly `out_hw`.
    """
    def __init__(self, in_ch: int, out_hw: Tuple[int,int]):
        super().__init__()
        self.out_hw = out_hw
        self.conv1 = ConvBlock(in_ch, 256)
        self.up1 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.conv2 = ConvBlock(128, 128)
        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv3 = ConvBlock(64, 64)
        self.up3 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.conv4 = ConvBlock(32, 32)
        self.head = nn.Conv2d(32, 1, 1)

    def forward(self, x):
        stages = (
            self.conv1, self.up1,
            self.conv2, self.up2,
            self.conv3, self.up3,
            self.conv4, self.head,
        )
        for stage in stages:
            x = stage(x)
        return F.interpolate(x, size=self.out_hw, mode="bilinear", align_corners=False)

class ProposedViTFusion(nn.Module):
    """Two-stream ViT segmentation model producing a 1-channel map at `full_hw`.

    Both inputs are resized to `vit_hw` and encoded by separate ViT feature
    extractors (RGB: 3 channels, depth: 1 channel). Depth features are projected
    to the RGB channel count if needed, combined according to `mode`, and decoded
    by `LiteDecoder`.

    Modes (all share the same submodules, so one checkpoint loads into every mode):
      - "fusion": cross-attention fusion of RGB and depth features (default).
      - "rgb_only" and "no_fusion": decode RGB features only (identical here).
      - "depth_only": decode projected depth features only.
      - "late_fusion": 1x1 conv + ReLU over concatenated RGB/depth features.

    `depth_tokens_stride` > 1 subsamples the depth feature grid by that stride and
    bilinearly resizes it back to the RGB grid.

    Raises:
        ValueError: for an unknown `mode` or `depth_tokens_stride` < 1.
    """
    # Modes understood by forward(). Anything else is rejected at construction,
    # instead of silently running the "fusion" branch (which would mislabel ablations).
    _VALID_MODES = ("fusion", "rgb_only", "depth_only", "no_fusion", "late_fusion")

    def __init__(self, rgb_name: str, dep_name: str, full_hw: Tuple[int,int], vit_hw: Tuple[int,int], depth_tokens_stride: int = 1, mode: str = "fusion"):
        super().__init__()
        if mode not in self._VALID_MODES:
            raise ValueError(f"unknown mode {mode!r}; expected one of {self._VALID_MODES}")
        if depth_tokens_stride < 1:
            raise ValueError(f"depth_tokens_stride must be >= 1, got {depth_tokens_stride}")
        self.full_hw = full_hw
        self.vit_hw = vit_hw
        self.mode = mode
        self.depth_tokens_stride = depth_tokens_stride
        # Submodule creation order is unchanged: it determines state_dict keys and init RNG draws.
        self.rgb_enc = ViTEncoderFeature(rgb_name, in_chans=3, img_hw=vit_hw)
        self.dep_enc = ViTEncoderFeature(dep_name, in_chans=1, img_hw=vit_hw)
        ch = self.rgb_enc.out_ch
        self.dep_proj = nn.Conv2d(self.dep_enc.out_ch, ch, 1, bias=False) if self.dep_enc.out_ch != ch else nn.Identity()
        self.fuse = CrossAttentionFusion(ch, heads=8)
        self.dec = LiteDecoder(ch, out_hw=full_hw)
        self.late_gate = nn.Sequential(nn.Conv2d(ch * 2, ch, 1, bias=False), nn.ReLU(inplace=True))

    def forward_feats(self, rgb3, dep1):
        """Encode RGB and depth at `vit_hw`; return (rgb_feats, depth_feats) on the same grid."""
        rgb_small = F.interpolate(rgb3, size=self.vit_hw, mode="bilinear", align_corners=False)
        dep_small = F.interpolate(dep1, size=self.vit_hw, mode="bilinear", align_corners=False)
        fr = self.rgb_enc(rgb_small)
        fd = self.dep_proj(self.dep_enc(dep_small))
        if self.depth_tokens_stride > 1:
            # Reduce depth token resolution, then resample back so shapes match the RGB features.
            fd = fd[:, :, ::self.depth_tokens_stride, ::self.depth_tokens_stride]
            fd = F.interpolate(fd, size=fr.shape[-2:], mode="bilinear", align_corners=False)
        return fr, fd

    def forward(self, rgb3, dep1):
        fr, fd = self.forward_feats(rgb3, dep1)
        if self.mode == "rgb_only":
            f = fr
        elif self.mode == "depth_only":
            f = fd
        elif self.mode == "no_fusion":
            f = fr
        elif self.mode == "late_fusion":
            f = self.late_gate(torch.cat([fr, fd], dim=1))
        else:
            f = self.fuse(fr, fd)
        return self.dec(f)

def load_best_ckpt_path():
    """Return the checkpoint path recorded for "proposed_vit_fusion" in the phase-4 index.

    Raises:
        FileNotFoundError: if the index CSV or the referenced checkpoint is missing.
        RuntimeError: if the index has no "proposed_vit_fusion" row.
    """
    index_csv = PHASE4_DIR / "tab_03_best_checkpoint_index.csv"
    if not index_csv.exists():
        raise FileNotFoundError(str(index_csv))

    index_df = pd.read_csv(index_csv)
    matches = index_df[index_df["model"] == "proposed_vit_fusion"]
    if len(matches) == 0:
        raise RuntimeError("proposed_vit_fusion not found in tab_03_best_checkpoint_index.csv")

    ckpt = str(matches.iloc[0]["checkpoint_path"])
    if ckpt and Path(ckpt).exists():
        return ckpt
    raise FileNotFoundError(ckpt)

def load_ckpt_non_strict(model: nn.Module, ckpt_path: str):
    """Load `ckpt["state_dict"]` into `model` non-strictly, then move it to DEVICE in eval mode.

    Non-strict loading tolerates architecture differences between training and this
    notebook, but it silently skips weights. Any missing or unexpected keys are
    therefore reported with a warning, so partially-initialized models are visible.

    Args:
        model: module to load into (modified in place).
        ckpt_path: path to a torch checkpoint dict containing a "state_dict" entry.

    Returns:
        The full checkpoint object returned by `torch.load`.
    """
    import warnings

    ck = torch.load(ckpt_path, map_location=DEVICE)
    result = model.load_state_dict(ck["state_dict"], strict=False)
    if result.missing_keys or result.unexpected_keys:
        warnings.warn(
            f"load_ckpt_non_strict({ckpt_path}): {len(result.missing_keys)} missing keys "
            f"(e.g. {result.missing_keys[:3]}), {len(result.unexpected_keys)} unexpected keys "
            f"(e.g. {result.unexpected_keys[:3]})"
        )
    model.to(DEVICE)
    model.eval()
    return ck

# Rebuild the proposed model (fusion mode) and load the best phase-4 checkpoint into it.
ck_p = load_best_ckpt_path()
proposed = ProposedViTFusion("vit_small_patch16_224", "vit_small_patch16_224", full_hw=TARGET_HW, vit_hw=VIT_HW, depth_tokens_stride=1, mode="fusion")
_ = load_ckpt_non_strict(proposed, ck_p)

# Output directory and file names for the exported ONNX artifacts.
ART_DIR = PHASE6_DIR / "artifacts"
ART_DIR.mkdir(parents=True, exist_ok=True)

ONNX_FP32 = ART_DIR / "proposed_vit_fusion_fp32.onnx"
ONNX_FP16 = ART_DIR / "proposed_vit_fusion_fp16.onnx"
ONNX_INT8 = ART_DIR / "proposed_vit_fusion_int8_dynamic.onnx"

# Fixed random dummy inputs (batch 1, full TARGET_HW resolution) on DEVICE, shared by ONNX
# export tracing and latency profiling; the fp16 copies feed the half-precision path.
x_rgb_fp32 = torch.randn(1, 3, TARGET_HW[0], TARGET_HW[1], device=DEVICE, dtype=torch.float32)
x_dep_fp32 = torch.randn(1, 1, TARGET_HW[0], TARGET_HW[1], device=DEVICE, dtype=torch.float32)
x_rgb_fp16 = x_rgb_fp32.half()
x_dep_fp16 = x_dep_fp32.half()

def export_onnx_fp32(model: nn.Module, path: Path, opset: int = 17):
    """Export `model` in float32 to ONNX at `path`, traced with the fixed dummy inputs.

    The graph has inputs "rgb" and "depth" and output "logits", with static shapes
    (no dynamic axes). The model is moved to DEVICE and set to eval/float32 in place.
    """
    net = model.to(DEVICE).eval().float()
    export_kwargs = dict(
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["rgb", "depth"],
        output_names=["logits"],
        dynamic_axes=None,
    )
    torch.onnx.export(net, (x_rgb_fp32, x_dep_fp32), str(path), **export_kwargs)

def export_onnx_fp16(model: nn.Module, path: Path, opset: int = 17):
    """Export a half-precision copy of `model` to ONNX at `path`.

    The export traces the fp16 dummy inputs and produces graph inputs "rgb" and
    "depth", output "logits", with static shapes.

    `.half()` casts parameters in place, so the export works on a deep copy. The
    caller's model keeps its dtype; previously it was left in fp16 after export.
    """
    import copy

    m = copy.deepcopy(model).to(DEVICE).eval().half()
    torch.onnx.export(
        m,
        (x_rgb_fp16, x_dep_fp16),
        str(path),
        export_params=True,
        opset_version=opset,
        do_constant_folding=True,
        input_names=["rgb", "depth"],
        output_names=["logits"],
        dynamic_axes=None,
    )

# Best-effort exports: each failure is written to a text file in ART_DIR instead of aborting
# the cell. The export_ok_* / quant_ok flags feed the export table below.
export_ok_fp32 = True
export_ok_fp16 = True

# Drop artifacts left by an earlier run first. The export table and reproducibility checklist
# only test whether these files exist, so a stale file would hide a failed export.
for _stale_artifact in (ONNX_FP32, ONNX_FP16, ONNX_INT8):
    if _stale_artifact.exists():
        _stale_artifact.unlink()

try:
    export_onnx_fp32(proposed, ONNX_FP32, opset=17)
except Exception as e:
    export_ok_fp32 = False
    with open(ART_DIR / "export_fp32_error.txt", "w") as f:
        f.write(str(e))

try:
    export_onnx_fp16(proposed, ONNX_FP16, opset=17)
except Exception as e:
    export_ok_fp16 = False
    with open(ART_DIR / "export_fp16_error.txt", "w") as f:
        f.write(str(e))

# Dynamic INT8 quantization of the FP32 graph via onnxruntime (only if it exported successfully).
quant_ok = False
quant_notes = ""
try:
    import onnx  # dependency check: quantization needs the onnx package
    from onnxruntime.quantization import quantize_dynamic, QuantType
    if export_ok_fp32 and ONNX_FP32.exists():
        quantize_dynamic(str(ONNX_FP32), str(ONNX_INT8), weight_type=QuantType.QInt8)
        quant_ok = True
    else:
        # Give the export table a reason instead of an empty note.
        quant_notes = "skipped: fp32 ONNX export not available"
except Exception as e:
    quant_notes = str(e)
    with open(ART_DIR / "export_int8_error.txt", "w") as f:
        f.write(quant_notes)

def torch_profile_latency(model: nn.Module, precision: str, warmup: int = 20, repeats: int = 100):
    """Measure forward-pass latency and memory of `model` on the fixed dummy inputs.

    Args:
        model: network called as model(rgb, depth). It is moved to DEVICE, set to
            eval and cast in place ("fp16" leaves it in half precision).
        precision: one of the following; any other value means plain fp32.
            - "fp16": half weights and half inputs.
            - "fp16_amp": fp32 weights under CUDA autocast. Without CUDA this
              silently runs plain fp32.
        warmup: untimed forward passes before measuring.
        repeats: timed forward passes; must be >= 1.

    Returns:
        A tuple (latency_ms_mean, latency_ms_std, fps, peak_ram_mb, peak_vram_mb).
        - peak_ram_mb: largest process-RSS increase over the pre-run baseline,
          sampled after every forward pass (never negative).
        - peak_vram_mb: torch's peak allocated CUDA memory (0.0 without CUDA).

    Raises:
        ValueError: if repeats < 1, since mean latency and FPS would be undefined.
    """
    import contextlib

    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    has_cuda = torch.cuda.is_available()
    proc = psutil.Process(os.getpid())
    rss0 = proc.memory_info().rss
    rss_max = rss0

    if has_cuda:
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.synchronize()

    if precision == "fp16":
        m = model.to(DEVICE).eval().half()
        rgb, dep = x_rgb_fp16, x_dep_fp16
    else:
        m = model.to(DEVICE).eval().float()
        rgb, dep = x_rgb_fp32, x_dep_fp32
    use_amp = precision == "fp16_amp" and has_cuda

    def _amp_ctx():
        # A fresh context per forward pass; autocast only when requested and CUDA exists.
        return torch.amp.autocast("cuda", enabled=True) if use_amp else contextlib.nullcontext()

    times = []
    with torch.no_grad():
        for _ in range(warmup):
            with _amp_ctx():
                _ = m(rgb, dep)
            rss_max = max(rss_max, proc.memory_info().rss)
        if has_cuda:
            torch.cuda.synchronize()

        for _ in range(repeats):
            t0 = time.perf_counter()
            with _amp_ctx():
                _ = m(rgb, dep)
            if has_cuda:
                # Wait for queued kernels so the wall-clock time covers the whole forward pass.
                torch.cuda.synchronize()
            times.append((time.perf_counter() - t0) * 1000.0)
            # Sample RSS outside the timed region so it does not inflate latency.
            rss_max = max(rss_max, proc.memory_info().rss)

    rss_max = max(rss_max, proc.memory_info().rss)
    peak_ram_mb = max(0, rss_max - rss0) / (1024**2)
    peak_vram_mb = (torch.cuda.max_memory_allocated() / (1024**2)) if has_cuda else 0.0

    lat_mean = float(np.mean(times))
    lat_std = float(np.std(times))
    fps = float(1000.0 / max(1e-6, lat_mean))
    return lat_mean, lat_std, fps, float(peak_ram_mb), float(peak_vram_mb)

# Profile the PyTorch model at each precision (fp32 first, then fp16_amp) and save the table.
profiles = []
for _precision in ("fp32", "fp16_amp"):
    lat, std, fps, ram, vram = torch_profile_latency(proposed, _precision, warmup=20, repeats=100)
    profiles.append({
        "variant": "pytorch",
        "precision": _precision,
        "input_resolution": f"{TARGET_HW[0]}x{TARGET_HW[1]}",
        "latency_ms_mean": lat,
        "latency_ms_std": std,
        "fps": fps,
        "peak_ram_mb": ram,
        "peak_vram_mb": vram,
    })

TAB_01_EDGE = PHASE6_DIR / "tab_01_edge_inference_profile.csv"
df_edge = pd.DataFrame(profiles)
df_edge.to_csv(TAB_01_EDGE, index=False)

# Catalogue the checkpoint and ONNX artifacts with their on-disk sizes.
# file_size_mb gives NaN for a missing file; "opset" is blank when the corresponding export failed.
exports = []
exports.append({
    "artifact_name": "checkpoint_best",
    "format": "pt",
    "path": str(Path(ck_p)),
    "file_size_mb": file_size_mb(Path(ck_p)),
    "opset": "",
    "quantized": False,
    "notes": "best checkpoint from phase4 index"
})
exports.append({
    "artifact_name": "onnx_fp32",
    "format": "onnx",
    "path": str(ONNX_FP32),
    "file_size_mb": file_size_mb(ONNX_FP32),
    "opset": 17 if export_ok_fp32 else "",
    "quantized": False,
    "notes": "two-input onnx: rgb, depth"
})
exports.append({
    "artifact_name": "onnx_fp16",
    "format": "onnx",
    "path": str(ONNX_FP16),
    "file_size_mb": file_size_mb(ONNX_FP16),
    "opset": 17 if export_ok_fp16 else "",
    "quantized": False,
    "notes": "exported from half model"
})
# The note carries at most the first 160 characters of the quantization error/skip reason.
exports.append({
    "artifact_name": "onnx_int8_dynamic",
    "format": "onnx",
    "path": str(ONNX_INT8),
    "file_size_mb": file_size_mb(ONNX_INT8),
    "opset": "",
    "quantized": bool(quant_ok),
    "notes": "onnxruntime dynamic quantization" if quant_ok else f"not created: {quant_notes[:160]}"
})

TAB_02_EXPORTS = PHASE6_DIR / "tab_02_export_artifacts_table.csv"
df_exports = pd.DataFrame(exports)
df_exports.to_csv(TAB_02_EXPORTS, index=False)

# Figure 1: mean latency (bars, left axis) and FPS (markers, right twin axis) per profiled variant.
FIG_01_LAT = PHASE6_DIR / "fig_01_latency_fps_tradeoff.png"
plt.figure(figsize=(8.4, 4.6))
x = np.arange(len(df_edge))
plt.bar(x - 0.2, df_edge["latency_ms_mean"].values, width=0.4, label="Latency (ms)")
ax1 = plt.gca()
ax2 = ax1.twinx()
ax2.plot(x + 0.2, df_edge["fps"].values, marker="o", label="FPS")
ax1.set_xticks(x)
ax1.set_xticklabels((df_edge["variant"] + "|" + df_edge["precision"]).values, rotation=15, ha="right")
ax1.set_ylabel("Latency (ms)")
ax2.set_ylabel("FPS")
ax1.set_title("Latency and FPS tradeoff")
# Merge legend entries from both axes into one legend.
lines0, labels0 = ax1.get_legend_handles_labels()
lines1, labels1 = ax2.get_legend_handles_labels()
ax1.legend(lines0 + lines1, labels0 + labels1, loc="upper right")
save_fig(FIG_01_LAT, dpi=300)

# Figure 2: peak CUDA memory and RAM change per variant; checkpoint size is shown in the title.
FIG_02_MEM = PHASE6_DIR / "fig_02_memory_and_modelsize.png"
model_sz = file_size_mb(Path(ck_p))
plt.figure(figsize=(8.4, 4.6))
x = np.arange(len(df_edge))
plt.bar(x - 0.2, df_edge["peak_vram_mb"].values, width=0.4, label="Peak VRAM (MB)")
plt.bar(x + 0.2, df_edge["peak_ram_mb"].values, width=0.4, label="Peak RAM delta (MB)")
plt.xticks(x, (df_edge["variant"] + "|" + df_edge["precision"]).values, rotation=15, ha="right")
plt.ylabel("MB")
plt.title(f"Memory profile and model size (ckpt {model_sz:.1f} MB)")
plt.legend()
save_fig(FIG_02_MEM, dpi=300)

# Figure 3: schematic of the five project phases as boxes joined left-to-right by arrows.
FIG_03_PIPE = PHASE6_DIR / "fig_03_final_pipeline_overview.png"
plt.figure(figsize=(11.5, 3.4))
ax = plt.gca()
ax.axis("off")
# Each entry: (label, x0, y0, width, height) in axes coordinates (0..1).
boxes = [
    ("ISOD dataset\nPhase 1–2", 0.02, 0.35, 0.16, 0.35),
    ("Preprocess +\nsite split\nPhase 3", 0.22, 0.35, 0.16, 0.35),
    ("Train baselines +\nproposed ViT fusion\nPhase 4", 0.42, 0.35, 0.20, 0.35),
    ("Eval + robustness +\nexplainability\nPhase 5", 0.66, 0.35, 0.18, 0.35),
    ("ONNX export +\nprofiling package\nPhase 6", 0.86, 0.35, 0.12, 0.35),
]
for text, x0, y0, w, h in boxes:
    ax.add_patch(plt.Rectangle((x0, y0), w, h, fill=False, linewidth=2))
    ax.text(x0 + w/2, y0 + h/2, text, ha="center", va="center")
# Arrow from the right edge of each box to the left edge of the next, at mid-height.
for i in range(len(boxes)-1):
    x0 = boxes[i][1] + boxes[i][3]
    y0 = boxes[i][2] + boxes[i][4]/2
    x1 = boxes[i+1][1]
    y1 = boxes[i+1][2] + boxes[i+1][4]/2
    ax.annotate("", xy=(x1, y1), xytext=(x0, y0), arrowprops=dict(arrowstyle="->", lw=2))
save_fig(FIG_03_PIPE, dpi=300)

# Reload the phase-5 main and robustness tables from disk (empty frames if absent) for the snapshot.
# NOTE: this rebinds TAB_04_ROB and df_rob, replacing the values computed earlier in the notebook.
TAB_02_MAIN = PHASE5_DIR / "tab_02_main_results_table.csv"
TAB_04_ROB = PHASE5_DIR / "tab_04_robustness_table.csv"
df_main = pd.read_csv(TAB_02_MAIN) if TAB_02_MAIN.exists() else pd.DataFrame()
df_rob = pd.read_csv(TAB_04_ROB) if TAB_04_ROB.exists() else pd.DataFrame()

def key_snapshot(df_main: pd.DataFrame, df_rob: pd.DataFrame, df_edge: pd.DataFrame):
    """Collect headline numbers for the proposed model into a flat dict.

    Metric keys are only added when the source table has a matching row:
    - From df_main: clean_miou and clean_dice.
    - From df_rob: depth_missing_miou, taken at the lowest severity.
    - From df_edge: lat_ms_fp32, fps_fp32, lat_ms_fp16 and fps_fp16 (the last
      two from the "fp16_amp" row).

    Size keys are always present: ckpt_mb, onnx_fp32_mb and onnx_fp16_mb (NaN if
    that export failed), plus params_m.
    """
    snapshot = {}

    if len(df_main) > 0:
        proposed_rows = df_main[df_main["method"] == "proposed_vit_fusion"]
        if len(proposed_rows) > 0:
            first = proposed_rows.iloc[0]
            snapshot["clean_miou"] = float(first["miou"])
            snapshot["clean_dice"] = float(first["dice"])

    if len(df_rob) > 0:
        is_target = (df_rob["model"] == "proposed_vit_fusion") & (df_rob["condition"] == "depth_missing")
        depth_missing = df_rob[is_target]
        if len(depth_missing) > 0:
            snapshot["depth_missing_miou"] = float(depth_missing.sort_values("severity").iloc[0]["miou"])

    if len(df_edge) > 0:
        for precision_value, suffix in (("fp32", "fp32"), ("fp16_amp", "fp16")):
            sel = df_edge[(df_edge["variant"] == "pytorch") & (df_edge["precision"] == precision_value)]
            if len(sel) > 0:
                snapshot[f"lat_ms_{suffix}"] = float(sel.iloc[0]["latency_ms_mean"])
                snapshot[f"fps_{suffix}"] = float(sel.iloc[0]["fps"])

    snapshot["ckpt_mb"] = file_size_mb(Path(ck_p))
    snapshot["onnx_fp32_mb"] = file_size_mb(ONNX_FP32) if export_ok_fp32 else float("nan")
    snapshot["onnx_fp16_mb"] = file_size_mb(ONNX_FP16) if export_ok_fp16 else float("nan")
    snapshot["params_m"] = params_m(proposed)
    return snapshot

# Figure 4: text-only panel of headline numbers; keys missing from `snap` print as nan.
snap = key_snapshot(df_main, df_rob, df_edge)

FIG_04_SNAP = PHASE6_DIR / "fig_04_key_results_snapshot.png"
plt.figure(figsize=(10.5, 4.8))
ax = plt.gca()
ax.axis("off")
lines = [
    f"Proposed (ViT fusion) key snapshot",
    f"Clean test: mIoU={snap.get('clean_miou', float('nan')):.4f}  Dice={snap.get('clean_dice', float('nan')):.4f}",
    f"Robustness depth-missing: mIoU={snap.get('depth_missing_miou', float('nan')):.4f}",
    f"Edge PyTorch FP32: {snap.get('lat_ms_fp32', float('nan')):.2f} ms  {snap.get('fps_fp32', float('nan')):.1f} FPS",
    f"Edge PyTorch FP16 AMP: {snap.get('lat_ms_fp16', float('nan')):.2f} ms  {snap.get('fps_fp16', float('nan')):.1f} FPS",
    f"Model: params={snap.get('params_m', float('nan')):.2f} M  ckpt={snap.get('ckpt_mb', float('nan')):.1f} MB  onnx_fp32={snap.get('onnx_fp32_mb', float('nan')):.1f} MB",
]
ax.text(0.02, 0.85, "\n".join(lines), va="top", ha="left", fontsize=12)
save_fig(FIG_04_SNAP, dpi=300)

def build_artifact_index(results_root: Path):
    """List every file under each phase directory of `results_root`.

    Only subdirectories of `results_root` are scanned (recursively); loose files
    directly in `results_root` are ignored. The phase label is the part of the
    directory name before the first "_".

    The artifact type is chosen by lower-cased extension:
    - figure: .png, .jpg, .jpeg, .pdf
    - table_csv: .csv
    - model_artifact: .pt, .pth, .ckpt, .safetensors, .onnx
    - other: anything else

    Returns:
        DataFrame with columns phase, artifact_type, filename, relative_path,
        description (blank). It is empty when nothing is found.
    """
    type_by_ext = {}
    for ext in (".png", ".jpg", ".jpeg", ".pdf"):
        type_by_ext[ext] = "figure"
    type_by_ext[".csv"] = "table_csv"
    for ext in (".pt", ".pth", ".ckpt", ".safetensors", ".onnx"):
        type_by_ext[ext] = "model_artifact"

    records = []
    phase_dirs = [d for d in sorted(results_root.glob("*")) if d.is_dir()]
    for phase_dir in phase_dirs:
        phase = phase_dir.name.split("_")[0]
        files = [f for f in sorted(phase_dir.rglob("*")) if not f.is_dir()]
        for f in files:
            records.append({
                "phase": phase,
                "artifact_type": type_by_ext.get(f.suffix.lower(), "other"),
                "filename": f.name,
                "relative_path": str(f.relative_to(results_root)),
                "description": "",
            })
    return pd.DataFrame(records)

# Index every file under results/. The index is built before it and the checklist below are
# written, so on a fresh run neither of those two CSVs appears in it.
TAB_03_INDEX = PHASE6_DIR / "tab_03_final_artifact_index.csv"
df_index = build_artifact_index(RESULTS_ROOT)
df_index.to_csv(TAB_03_INDEX, index=False)

# Reproducibility checklist: each item is "ok" iff the expected file or directory exists.
TAB_04_CHECK = PHASE6_DIR / "tab_04_reproducibility_checklist.csv"
check_rows = [
    {"item": "Dataset manifest and integrity logs available", "status": "ok" if (RESULTS_ROOT/"01_data_integrity_and_profile").exists() else "missing", "details": "Phase 1 outputs directory present"},
    {"item": "Sensor alignment metrics available", "status": "ok" if (RESULTS_ROOT/"02_sensor_alignment_and_quality").exists() else "missing", "details": "Phase 2 outputs directory present"},
    {"item": "Preprocess config and site split manifest available", "status": "ok" if (RESULTS_ROOT/"03_preprocess_and_split_protocol"/"tab_03_split_manifest.csv").exists() else "missing", "details": "Phase 3 split manifest"},
    {"item": "Best checkpoint index available", "status": "ok" if (PHASE4_DIR/"tab_03_best_checkpoint_index.csv").exists() else "missing", "details": "Phase 4 checkpoint index"},
    {"item": "Test metrics CSV source available", "status": "ok" if (PHASE5_DIR/"tab_01_test_results_full.csv").exists() else "missing", "details": "Phase 5 full test results"},
    {"item": "Edge profile CSV source available", "status": "ok" if (PHASE6_DIR/"tab_01_edge_inference_profile.csv").exists() else "missing", "details": "Phase 6 edge inference profile"},
    {"item": "ONNX export available", "status": "ok" if (ONNX_FP32.exists() or ONNX_FP16.exists()) else "missing", "details": "Phase 6 ONNX artifacts"},
    {"item": "Final artifact index table available", "status": "ok" if TAB_03_INDEX.exists() else "missing", "details": "Phase 6 index CSV"},
]
df_check = pd.DataFrame(check_rows)
df_check.to_csv(TAB_04_CHECK, index=False)

# NOTE(review): these reassign the same paths already set earlier in this cell; they have no effect.
TAB_01_EDGE = PHASE6_DIR / "tab_01_edge_inference_profile.csv"
TAB_02_EXPORTS = PHASE6_DIR / "tab_02_export_artifacts_table.csv"



In [None]:
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Output directory for the per-sample image-operation panels.
RESULT_DIR = Path("/kaggle/working/derived_image_operations")
RESULT_DIR.mkdir(parents=True, exist_ok=True)

MANIFEST = Path("/kaggle/working/results/01_data_integrity_and_profile/tab_01_dataset_manifest_valid_only.csv")
df = pd.read_csv(MANIFEST).reset_index(drop=True)
if df.empty:
    # Fail with a clear message rather than pd.concat's "No objects to concatenate".
    raise ValueError(f"manifest has no rows: {MANIFEST}")

SEED = 1337
MAX_SAMPLES = 10  # upper bound on the number of sites visualized
rng = np.random.default_rng(SEED)

# Pick one random row per site. Each per-site random_state is drawn from `rng`, so the
# whole selection is reproducible from SEED. `integers` takes an integer bound.
samples = [g.sample(n=1, random_state=int(rng.integers(10**9))) for _, g in df.groupby("site_id")]
SAMPLES = pd.concat(samples, ignore_index=True)

if len(SAMPLES) > MAX_SAMPLES:
    SAMPLES = SAMPLES.sample(n=MAX_SAMPLES, random_state=SEED).reset_index(drop=True)

def read_rgb(p):
    """Read an image file as an RGB uint8 array (OpenCV loads BGR; converted here).

    Raises:
        FileNotFoundError: if OpenCV cannot read `p` (missing or undecodable file).
            cv2.imread returns None in that case instead of raising.
    """
    bgr = cv2.imread(str(p))
    if bgr is None:
        raise FileNotFoundError(f"could not read RGB image: {p}")
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

def read_depth(p):
    """Read a depth image unchanged (native bit depth) and min-max normalize it to about [0, 1].

    Returns a float32 array. The +1e-6 avoids division by zero for a constant map,
    which then becomes all zeros.

    Raises:
        FileNotFoundError: if OpenCV cannot read `p` (cv2.imread returns None).
    """
    raw = cv2.imread(str(p), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise FileNotFoundError(f"could not read depth image: {p}")
    d = raw.astype(np.float32)
    d = (d - d.min()) / (d.max() - d.min() + 1e-6)
    return d

def read_mask(p):
    """Read a mask as grayscale and return a boolean array (True where the pixel is > 0).

    Raises:
        FileNotFoundError: if OpenCV cannot read `p` (cv2.imread returns None).
    """
    gray = cv2.imread(str(p), cv2.IMREAD_GRAYSCALE)
    if gray is None:
        raise FileNotFoundError(f"could not read mask image: {p}")
    return gray > 0

# For each sampled row, render a 2x5 panel of derived image operations and save it as a PNG.
for _, row in SAMPLES.iterrows():
    rgb = read_rgb(row["rgb_path"])
    depth = read_depth(row["depth_path"])
    mask = read_mask(row["mask_path"])

    # RGB, depth and mask files can differ in resolution (the training pipeline crops them
    # to a common size). Align depth and mask to the RGB size so the per-pixel blends and
    # boolean indexing below are valid. Nearest-neighbour avoids mixing depth values or labels.
    h, w = rgb.shape[:2]
    if depth.shape[:2] != (h, w):
        depth = cv2.resize(depth, (w, h), interpolation=cv2.INTER_NEAREST)
    if mask.shape[:2] != (h, w):
        mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST) > 0

    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    blur = cv2.GaussianBlur(rgb, (9, 9), 0)
    edges = cv2.Canny(gray, 80, 160)

    # Simulated low light: scale intensities to 40%.
    low_light = np.clip(rgb * 0.4, 0, 255).astype(np.uint8)

    # Global min-max contrast stretch to the full 0..255 range.
    contrast = cv2.normalize(rgb, None, 0, 255, cv2.NORM_MINMAX)

    # applyColorMap returns BGR. Convert to RGB so it blends with the RGB image and
    # displays correctly in imshow (otherwise red and blue are swapped).
    depth_color = cv2.cvtColor(
        cv2.applyColorMap((depth * 255).astype(np.uint8), cv2.COLORMAP_INFERNO),
        cv2.COLOR_BGR2RGB,
    )

    mask_overlay = rgb.copy()
    mask_overlay[mask] = [255, 0, 0]

    rgb_depth_overlay = (0.6 * rgb + 0.4 * depth_color).astype(np.uint8)

    # Contrast-stretched RGB blended with colorized depth, then re-stretched.
    final_input = cv2.normalize(
        (0.7 * contrast + 0.3 * depth_color).astype(np.uint8),
        None, 0, 255, cv2.NORM_MINMAX
    )

    panels = [
        (rgb, "RGB"),
        (gray, "Grayscale"),
        (blur, "Gaussian Blur"),
        (edges, "Edges"),
        (low_light, "Low-light"),
        (contrast, "Contrast Stretch"),
        (depth_color, "Depth (Color)"),
        (mask_overlay, "Mask Overlay"),
        (rgb_depth_overlay, "RGB + Depth"),
        (final_input, "Final Input"),
    ]

    fig = plt.figure(figsize=(14, 6))
    for i, (im, t) in enumerate(panels):
        plt.subplot(2, 5, i + 1)
        if im.ndim == 2:
            plt.imshow(im, cmap="gray")
        else:
            plt.imshow(im)
        plt.title(t, fontsize=9)
        plt.axis("off")

    out_path = RESULT_DIR / f"{row['sample_id']}_operations.png"
    plt.tight_layout()
    plt.savefig(out_path, dpi=300)
    plt.close(fig)
