In [2]:
"""
UNIVERSAL OCT DATASET CLEANER (SHA-256 ONLY)

What it does (NO pHash / visual dedupe):
1) If dataset has NO train/val/test split and has a CSV with patient_id (e.g., NEH/OCTDL):
   - Build manifest from CSV (or scan if CSV not provided)
   - SHA-256 exact dedupe (drop label-conflicts)
   - Patient-wise stratified split 70/15/15 (no leakage) if patient IDs exist
   - Materialize OUT/train|val|test/<label>/
2) If dataset ALREADY has train/val/test split (e.g., C8):
   - Scan folders root/{train,val,test}/{label}/...
   - SHA-256 exact dedupe with split priority (KEEP: test > val > train)
   - Materialize OUT with same split structure (does NOT reshuffle)
3) If dataset has NO CSV:
   - If already split: scan folders
   - If not split: scan ROOT/<label>/... and do IMAGE-wise split (warns: no patient leakage protection)

Extra:
- Prints detailed dataset stats BEFORE and AFTER:
  * total images
  * images per split
  * images per class
  * images per split x class
  * #unique patients (if available)
  * #images removed as duplicates, #removed as label-conflicts, #bad/unreadable
- Writes reports to: OUT_DIR/reports/
"""

import os
import sys
import shutil
import hashlib
import random
from pathlib import Path
from collections import Counter
from typing import Optional, List, Tuple, Dict

# -----------------------------
# Auto-install deps (optional)
# -----------------------------
def _ensure_packages():
    """Install any required third-party package that cannot be imported.

    Every required module is imported by name; the pip names of the ones that
    fail to import are collected and installed with a single ``pip install``
    run through the current interpreter (``sys.executable``).

    Raises:
        subprocess.CalledProcessError: If the ``pip install`` command fails.
    """
    import importlib, subprocess

    # import name -> pip distribution name
    required = {
        "pandas": "pandas",
        "numpy": "numpy",
        "tqdm": "tqdm",
        "sklearn": "scikit-learn",
    }

    def _importable(module_name):
        # Any failure while importing counts as "not available".
        try:
            importlib.import_module(module_name)
        except Exception:
            return False
        return True

    missing = [
        pip_name
        for module_name, pip_name in required.items()
        if not _importable(module_name)
    ]
    if not missing:
        return

    print("[INFO] Installing missing packages:", missing)
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing])

_ensure_packages()

import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

# ============================================================
# CONFIG (EDIT THESE)
# ============================================================

# Dataset root (input). Scanned recursively, or used as the base directory
# for relative paths found in the CSV.
ROOT_DIR = Path(r"D:\AIUB\DSP\Code\Datasets\C8\RetinalOCT_Dataset")  # change per dataset

# Optional CSV (set to None if dataset has no CSV).
# main() only uses the CSV when this path exists on disk; otherwise it scans ROOT_DIR.
CSV_PATH = None
# Examples:
# CSV_PATH = Path(r"D:\AIUB\DSP\Code\Datasets\OCTDL\OCTDL_labels.csv")
# CSV_PATH = Path(r"D:\AIUB\DSP\Code\Datasets\NEH_UT_2021RetinalOCTDataset\V2\data_information.csv")

# Output root (will be created). Reports are written to OUT_DIR/reports.
OUT_DIR = Path(r"D:\AIUB\DSP\Code\Datasets\C8\RetinalOCT_Dataset_CLEAN_SHAONLY")

# Mode: "auto" (recommended), "preserve_splits", "build_splits"
#   auto            -> use whatever detect_has_splits(ROOT_DIR) reports
#   preserve_splits -> treat the dataset as already split
#   build_splits    -> ignore existing split folders and build new splits
MODE = "auto"

# Split ratios (only used when building splits); main() requires them to sum to 1.0
RANDOM_SEED = 42
SPLIT_RATIOS = {"train": 0.70, "val": 0.15, "test": 0.15}

# Split folder names we recognize (compared case-insensitively)
SPLITS = {"train", "val", "valid", "validation", "test"}
# Aliases folded into the canonical "val" split name
SPLIT_CANON = {"valid": "val", "validation": "val"}

# If duplicates span splits, keep higher-priority split:
# test > val > train
# (lower number = higher priority: duplicates are sorted ascending by this
# value and the first row is the one that is kept)
SPLIT_PRIORITY = {"test": 0, "val": 1, "train": 2, "unknown": 3}

# File extensions (lower-case, with leading dot) accepted as images
VALID_EXT = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}

# Output method used by place_file()
COPY_MODE = "hardlink"  # "copy" | "hardlink" | "symlink"

# If you have patient IDs and want extra conservative exact dedupe:
# True => if same sha appears across different patients (same label), drop all those copies
DROP_CROSS_PATIENT_EXACT_DUPES = False

# CSV profile: "auto", "NEH", "OCTDL", "GENERIC"
# ("auto" picks the profile from the CSV's column names)
CSV_PROFILE = "auto"

# NEH-specific (only used when CSV_PROFILE resolves to NEH)
# FILTER_MODE: "all" or "worstcase" (keep only rows where Class == Label)
FILTER_MODE = "all"

# Ignore these directories during scans
IGNORE_DIRS = {"reports", "dup_reports", ".git", "__pycache__", ".ipynb_checkpoints"}


# ============================================================
# Helpers
# ============================================================

def safe_mkdir(p: Path):
    """Create directory ``p`` together with any missing parents.

    Does nothing when the directory already exists.
    """
    os.makedirs(p, exist_ok=True)

def write_csv_always(df: Optional[pd.DataFrame], path: Path, columns: List[str]):
    """Write ``df`` to ``path`` as CSV, always producing a file with ``columns``.

    Args:
        df: Frame to write; ``None`` or an empty frame yields a header-only file.
        path: Destination CSV path; its parent directory is created if needed.
        columns: Columns to write, in this order. Columns missing from ``df``
            are written as empty strings; extra columns are left out.
    """
    safe_mkdir(path.parent)

    has_rows = df is not None and not df.empty
    if has_rows:
        # Select/reorder the requested columns; absent ones are filled with "".
        frame = df.reindex(columns=columns, fill_value="")
    else:
        frame = pd.DataFrame(columns=columns)
    frame.to_csv(path, index=False)

def normalize_label(x: str) -> str:
    """Return ``x`` as a canonical class label.

    The value is converted to ``str``, stripped and upper-cased; the label
    ``HEALTHY`` is mapped to ``NORMAL``.
    """
    cleaned = str(x).strip().upper()
    return "NORMAL" if cleaned == "HEALTHY" else cleaned

def normalize_split(s: str) -> str:
    """Return the canonical split name for ``s``.

    The name is lower-cased and stripped, then translated through
    ``SPLIT_CANON``; names without an alias are returned as cleaned.
    A falsy ``s`` yields an empty string.
    """
    key = s.lower().strip() if s else ""
    return SPLIT_CANON.get(key, key)

def sha256_file(path: Path, chunk_size: int = 1024 * 1024) -> str:
    """Return the hex SHA-256 digest of the file at ``path``.

    The file is read in binary mode, ``chunk_size`` bytes at a time, so the
    whole file never has to be held in memory.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        while True:
            block = stream.read(chunk_size)
            if not block:  # empty bytes => end of file
                break
            digest.update(block)
    return digest.hexdigest()

def stable_mode(items: List[str]) -> str:
    """Return the most frequent item; ties go to the smallest item.

    Raises:
        ValueError: If ``items`` is empty.
    """
    tally = Counter(items)
    top_count = max(tally.values())
    # Among equally frequent items, pick the one that sorts first.
    return min(item for item, count in tally.items() if count == top_count)

def detect_has_splits(root: Path) -> bool:
    """Report whether ``root`` directly contains a split folder.

    Returns ``True`` when at least one immediate sub-directory is named
    train / val / test / valid / validation (case-insensitive), and ``False``
    otherwise or when ``root`` does not exist.
    """
    if not root.exists():
        return False

    split_names = ("train", "val", "test", "valid", "validation")
    for child in root.iterdir():
        if child.is_dir() and child.name.lower() in split_names:
            return True
    return False

def infer_split_label_from_path(dataset_root: Path, img_path: Path) -> Tuple[str, str, Path]:
    """Derive ``(split, label, relative_path)`` from an image's folder layout.

    Only directory components are inspected, so the image's own file name is
    never mistaken for a split or a label.

    * ``<...>/<split>/<label>/<...>/img`` -> ``(canonical split, LABEL, rel)``
      where ``<split>`` is the first directory whose lower-cased name is in
      ``SPLITS``.
    * Otherwise the split is ``"unknown"`` and the label is the first directory
      under ``dataset_root``, or ``"UNKNOWN"`` when the image sits directly in
      ``dataset_root``.

    Args:
        dataset_root: Root directory the image path is made relative to.
        img_path: Path of the image; must be located under ``dataset_root``.

    Returns:
        Tuple of split name, normalized label, and path relative to the root.

    Raises:
        ValueError: If ``img_path`` is not under ``dataset_root``.
    """
    rel = img_path.relative_to(dataset_root)
    dir_parts = rel.parts[:-1]  # drop the file name itself

    for i, part in enumerate(dir_parts):
        lowered = part.lower()
        if lowered not in SPLITS:
            continue
        # A split folder only counts when a label folder follows it.
        if i + 1 < len(dir_parts):
            return normalize_split(lowered), normalize_label(dir_parts[i + 1]), rel
        break

    label = normalize_label(dir_parts[0]) if dir_parts else "UNKNOWN"
    return "unknown", label, rel

def iter_images_scan(root: Path):
    """Yield every image file found under ``root``, in sorted path order.

    A path is yielded when its lower-cased suffix is in ``VALID_EXT``, it is a
    regular file, and none of the directories between ``root`` and the file is
    listed in ``IGNORE_DIRS``.

    The ignore check looks only at the part of the path below ``root``; the
    location of ``root`` itself (for example a dataset stored somewhere under
    a folder called ``reports``) therefore never hides its images.

    Paths are sorted before being yielded so the resulting manifest order, and
    with it any seeded split built from that order, does not depend on the
    directory listing order of the operating system.
    """
    for p in sorted(root.rglob("*")):
        # Cheap suffix test first; is_file() needs a filesystem call.
        if p.suffix.lower() not in VALID_EXT:
            continue
        if not p.is_file():
            continue
        parent_dirs = p.relative_to(root).parts[:-1]
        if any(part in IGNORE_DIRS for part in parent_dirs):
            continue
        yield p

def canonical_sort(df_group: pd.DataFrame, has_splits: bool) -> pd.DataFrame:
    """Return a sorted copy of a duplicate group; the first row is the keeper.

    With splits, rows are ordered by split priority (``SPLIT_PRIORITY``,
    unknown names rank 3) and then by ``rel_path``; the priority is stored in
    an added ``prio`` column. Without splits, rows are ordered by ``path``.
    The input frame is not modified.
    """
    if not has_splits:
        return df_group.copy().sort_values(["path"], ascending=True)

    priority = df_group["split"].map(lambda s: SPLIT_PRIORITY.get(str(s), 3))
    ranked = df_group.assign(prio=priority)
    return ranked.sort_values(["prio", "rel_path"], ascending=True)

def place_file(src: Path, dst: Path):
    """Materialize ``src`` at ``dst`` using the method selected by ``COPY_MODE``.

    The parent directory of ``dst`` is created first. If ``dst`` already
    exists nothing else happens. Otherwise the file is copied (with metadata),
    hard-linked or symlinked.

    Raises:
        ValueError: If ``COPY_MODE`` is not "copy", "hardlink" or "symlink".
    """
    safe_mkdir(dst.parent)
    if dst.exists():
        return

    writers = {
        "copy": shutil.copy2,
        "hardlink": os.link,
        "symlink": os.symlink,
    }
    writer = writers.get(COPY_MODE)
    if writer is None:
        raise ValueError(f"Unknown COPY_MODE: {COPY_MODE}")
    writer(str(src), str(dst))

def detect_extensions_from_disk(source_root: Path, class_names: List[str]) -> List[str]:
    """Return the sorted, lower-cased file suffixes present in the class folders.

    Only files directly inside ``source_root/<class_name>`` are considered;
    class folders that do not exist are skipped. When no suffix is found at
    all, the sorted contents of ``VALID_EXT`` are returned instead.
    """
    class_dirs = (source_root / name for name in class_names)
    found = {
        entry.suffix.lower()
        for folder in class_dirs
        if folder.exists()
        for entry in folder.iterdir()
        if entry.is_file()
    }
    return sorted(found or VALID_EXT)

def resolve_octdl_style_path(source_root: Path, disease: str, file_name: str, exts: List[str]) -> Optional[Path]:
    """Locate an image stored as ``source_root/<DISEASE>/<file_name>[.ext]``.

    Args:
        source_root: Dataset root containing one folder per class.
        disease: Class name; normalized with ``normalize_label`` before use.
        file_name: File name from the CSV, with or without an extension.
        exts: Extensions to try, in order, when ``file_name`` has none.

    Returns:
        The resolved path, or ``None`` when the class folder is missing, the
        file cannot be found, or the glob fallback is not a unique match.
    """
    cls_dir = source_root / normalize_label(disease)
    if not cls_dir.exists():
        return None

    def _only_match(pattern: str) -> Optional[Path]:
        # Accept a glob result only when it is unambiguous.
        hits = list(cls_dir.glob(pattern))
        if len(hits) == 1:
            return hits[0]
        return None

    fn = str(file_name).strip()
    given = Path(fn)

    # Name already carries an extension: exact path first, then glob by name.
    if given.suffix:
        direct = cls_dir / fn
        return direct if direct.exists() else _only_match(given.name)

    # No extension: try each known extension, then any extension.
    for ext in exts:
        direct = cls_dir / f"{fn}{ext}"
        if direct.exists():
            return direct
    return _only_match(f"{fn}.*")

def guess_csv_profile(df: pd.DataFrame) -> str:
    """Guess the CSV layout from its column names.

    Returns:
        ``"NEH"`` when the columns Patient ID / Class / Label / Directory are
        all present (exact names); ``"OCTDL"`` when file_name / disease /
        patient_id are all present (case-insensitive); ``"GENERIC"`` otherwise.
    """
    exact_names = set(df.columns)
    if exact_names >= {"Patient ID", "Class", "Label", "Directory"}:
        return "NEH"

    lowered_names = {str.lower(name) for name in df.columns}
    if lowered_names >= {"file_name", "disease", "patient_id"}:
        return "OCTDL"

    return "GENERIC"


# ============================================================
# Stats helpers (BEFORE/AFTER reporting)
# ============================================================

def summarize_manifest(m: pd.DataFrame, title: str):
    """Print a statistics report for manifest ``m`` under the heading ``title``.

    Printed sections: total image count; images per split and per class (for
    whichever of the ``split`` / ``label`` columns exist, missing values shown
    as ``unknown`` / ``UNKNOWN``); the split x class table when both columns
    exist; and the number of distinct non-empty ``patient_uid`` values when
    that column exists.
    """
    rule = "=" * 90
    print("\n" + rule)
    print(f"[STATS] {title}")
    print(rule)

    print(f"Total images: {len(m):,}")

    has_split = "split" in m.columns
    has_label = "label" in m.columns
    split_col = m["split"].fillna("unknown").astype(str) if has_split else None
    label_col = m["label"].fillna("UNKNOWN").astype(str) if has_label else None

    if has_split:
        print("\nImages per split:")
        print(split_col.value_counts().to_string())

    if has_label:
        print("\nImages per class:")
        print(label_col.value_counts().to_string())

    if has_split and has_label:
        table = (
            m.assign(split=split_col, label=label_col)
             .groupby(["split", "label"]).size()
             .unstack(fill_value=0)
        )
        print("\nImages per split x class:")
        print(table)

    if "patient_uid" not in m.columns:
        return

    uids = m["patient_uid"].astype(str).str.strip()
    named = uids[uids.ne("")]
    if named.empty:
        print("\nUnique patients (patient_uid): N/A (not provided)")
    else:
        print(f"\nUnique patients (patient_uid): {named.nunique():,}")

def write_stats_csvs(m: pd.DataFrame, out_dir: Path, prefix: str):
    """Write image-count tables for manifest ``m`` into ``out_dir``.

    Three CSV files are produced, each named with ``prefix``:
    ``<prefix>_counts_by_split.csv``, ``<prefix>_counts_by_label.csv`` and
    ``<prefix>_counts_by_split_label.csv``. A missing ``split`` / ``label``
    column is treated as all ``"unknown"`` / ``"UNKNOWN"``.
    """
    safe_mkdir(out_dir)

    frame = m.copy()
    frame["split"] = frame.get("split", "unknown")
    frame["label"] = frame.get("label", "UNKNOWN")

    # One single-column count table per dimension.
    for column in ("split", "label"):
        counts = frame[column].astype(str).value_counts().reset_index()
        counts.columns = [column, "count"]
        counts.to_csv(out_dir / f"{prefix}_counts_by_{column}.csv", index=False)

    # Long-format split x label table.
    by_both = frame.groupby(["split", "label"]).size().reset_index(name="count")
    by_both.to_csv(out_dir / f"{prefix}_counts_by_split_label.csv", index=False)


# ============================================================
# Manifest builders
# ============================================================

def _locate_under_root(root_dir: Path, img_path: Path, has_splits: bool) -> Optional[Tuple[str, str]]:
    """Return ``(rel_path, split)`` for ``img_path``, or ``None`` if it is not under ``root_dir``."""
    try:
        rel = img_path.relative_to(root_dir)
    except ValueError:
        return None
    split = "unknown"
    if has_splits:
        split, _, rel = infer_split_label_from_path(root_dir, img_path)
    return str(rel), split


def _neh_manifest_rows(df: pd.DataFrame, root_dir: Path, has_splits: bool) -> Tuple[List[dict], List[dict]]:
    """Build manifest rows from a NEH-style CSV (Patient ID / Class / Label / Directory)."""
    required = {"Patient ID", "Class", "Label", "Directory"}
    missing = required - set(df.columns)
    if missing:
        raise RuntimeError(f"NEH CSV missing columns: {missing}")

    df = df.copy()
    df["Label"] = df["Label"].apply(normalize_label)
    df["Class"] = df["Class"].apply(normalize_label)

    if FILTER_MODE.lower() == "worstcase":
        df = df[df["Class"] == df["Label"]].copy()

    rows, bad = [], []
    for _, r in df.iterrows():
        # "Directory" holds a path relative to the root; drop leading slashes.
        rel = str(r["Directory"]).lstrip("/\\")
        img_path = root_dir / rel
        if img_path.suffix.lower() not in VALID_EXT:
            continue
        if not img_path.exists():
            bad.append({"row": rel, "error": "missing_on_disk"})
            continue

        located = _locate_under_root(root_dir, img_path, has_splits)
        if located is None:
            bad.append({"row": rel, "error": "outside_root_dir"})
            continue
        rel_path, split = located

        patient_id = str(r["Patient ID"]).strip()
        patient_class = str(r["Class"]).strip()

        rows.append({
            "path": str(img_path),
            "rel_path": rel_path,
            "label": str(r["Label"]),
            "split": split,
            "patient_id": patient_id,
            # Patient IDs are made unique per patient class.
            "patient_uid": f"{patient_class}_{patient_id}",
            "patient_class": patient_class,
            "source": "csv(NEH)",
        })
    return rows, bad


def _octdl_manifest_rows(df: pd.DataFrame, root_dir: Path) -> Tuple[List[dict], List[dict]]:
    """Build manifest rows from an OCTDL-style CSV (file_name / disease / patient_id)."""
    cols_lower = {c.lower(): c for c in df.columns}
    need = {"file_name", "disease", "patient_id"}
    if not need.issubset(set(cols_lower.keys())):
        raise RuntimeError(f"OCTDL CSV missing columns. Need {need}, got {set(cols_lower.keys())}")

    diseases = df[cols_lower["disease"]].apply(normalize_label)
    file_names = df[cols_lower["file_name"]].astype(str).str.strip()
    patient_ids = df[cols_lower["patient_id"]].astype(str).str.strip()

    class_names = sorted(diseases.unique().tolist())
    exts = detect_extensions_from_disk(root_dir, class_names)

    rows, bad = [], []
    # Iterate the cleaned columns directly instead of looking values up by
    # attribute name on row tuples.
    for disease, file_name, patient_id in zip(diseases, file_names, patient_ids):
        img_path = resolve_octdl_style_path(root_dir, disease, file_name, exts)
        if img_path is None or not img_path.exists():
            bad.append({"row": f"{disease}/{file_name}", "error": "missing_or_unresolved"})
            continue
        if img_path.suffix.lower() not in VALID_EXT:
            continue

        rows.append({
            "path": str(img_path),
            "rel_path": str(img_path.relative_to(root_dir)),
            "label": disease,
            "split": "unknown",
            "patient_id": patient_id,
            "patient_uid": patient_id,
            "patient_class": "",
            "source": "csv(OCTDL)",
        })
    return rows, bad


def _generic_manifest_rows(df: pd.DataFrame, root_dir: Path, has_splits: bool) -> Tuple[List[dict], List[dict]]:
    """Build manifest rows from any CSV with recognizable path and label columns."""
    cols = list(df.columns)
    cols_lower = {c.lower(): c for c in cols}

    def _first_present(candidates: List[str]) -> Optional[str]:
        # First candidate (case-insensitive) that exists as a column, else None.
        for cand in candidates:
            if cand in cols_lower:
                return cols_lower[cand]
        return None

    path_col = _first_present(
        ["path", "file_path", "filepath", "image_path", "directory", "file", "filename", "file_name"]
    )
    label_col = _first_present(["label", "class", "disease", "diagnosis", "category"])
    patient_col = _first_present(
        ["patient_id", "patient id", "subject_id", "subject id", "case_id", "case id"]
    )

    if path_col is None or label_col is None:
        raise RuntimeError(f"GENERIC CSV needs path+label columns. Got columns: {cols}")

    rows, bad = [], []
    for _, r in df.iterrows():
        raw_path = str(r[path_col]).strip()
        label = normalize_label(str(r[label_col]))

        img_path = Path(raw_path)
        if not img_path.is_absolute():
            img_path = root_dir / raw_path

        if img_path.suffix.lower() not in VALID_EXT:
            continue
        if not img_path.exists():
            bad.append({"row": raw_path, "error": "missing_on_disk"})
            continue

        # An absolute path that points outside the dataset root is reported
        # rather than aborting the whole run.
        located = _locate_under_root(root_dir, img_path, has_splits)
        if located is None:
            bad.append({"row": raw_path, "error": "outside_root_dir"})
            continue
        rel_path, split = located

        # A missing patient cell means "no patient", not a patient called "nan".
        raw_pid = r[patient_col] if patient_col else ""
        pid = "" if pd.isna(raw_pid) else str(raw_pid).strip()

        rows.append({
            "path": str(img_path),
            "rel_path": rel_path,
            "label": label,
            "split": split,
            "patient_id": pid,
            "patient_uid": pid,
            "patient_class": "",
            "source": "csv(GENERIC)",
        })
    return rows, bad


def build_manifest_from_csv(root_dir: Path, csv_path: Path, has_splits: bool) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Build the image manifest from a label CSV.

    The CSV layout is taken from ``CSV_PROFILE`` or, when that is ``"auto"``,
    guessed from the column names. Rows whose file has an extension outside
    ``VALID_EXT`` are skipped silently; rows whose file cannot be found (or
    lies outside ``root_dir``) are collected in the second return value.

    Args:
        root_dir: Dataset root that relative CSV paths refer to.
        csv_path: Path of the CSV file to read.
        has_splits: When true, the split of each image is inferred from its
            folder path (NEH and GENERIC profiles); otherwise it is "unknown".

    Returns:
        ``(manifest, bad)``: the manifest with one row per distinct image path
        (columns path, rel_path, label, split, patient_id, patient_uid,
        patient_class, source) and a frame of unresolved rows (row, error).

    Raises:
        RuntimeError: If the CSV lacks the columns its profile requires.
    """
    df = pd.read_csv(csv_path)
    profile = guess_csv_profile(df) if CSV_PROFILE == "auto" else CSV_PROFILE

    if profile == "NEH":
        rows, bad = _neh_manifest_rows(df, root_dir, has_splits)
    elif profile == "OCTDL":
        rows, bad = _octdl_manifest_rows(df, root_dir)
    else:
        rows, bad = _generic_manifest_rows(df, root_dir, has_splits)

    manifest = pd.DataFrame(rows).drop_duplicates(subset=["path"]).reset_index(drop=True)
    bad_df = pd.DataFrame(bad)
    return manifest, bad_df

def build_manifest_by_scanning(root_dir: Path, has_splits: bool) -> pd.DataFrame:
    """Build the image manifest by walking ``root_dir``.

    Split and label come from ``infer_split_label_from_path``. For datasets
    without splits, a patient ID is taken from the folder directly below the
    label folder, i.e. the layout ``<label>/<patient>/.../<image>``. Images
    stored directly in the label folder (``<label>/<image>``) get no patient
    ID, so a flat dataset is later split image-wise instead of treating every
    file name as its own patient.
    NOTE(review): the second-level folder is assumed to identify a patient;
    confirm this for datasets that nest folders for other reasons.

    Args:
        root_dir: Dataset root to scan.
        has_splits: Whether the dataset is treated as already split; when
            false, every row gets split "unknown".

    Returns:
        Manifest with one row per distinct image path (columns path, rel_path,
        label, split, patient_id, patient_uid, patient_class, source).
    """
    rows = []
    for p in iter_images_scan(root_dir):
        split, label, rel = infer_split_label_from_path(root_dir, p)

        patient_id = ""
        patient_uid = ""
        patient_class = ""

        # Need label + patient folder + file name: at least three components.
        if not has_splits and len(rel.parts) >= 3:
            patient_id = str(rel.parts[1])
            patient_uid = patient_id
            patient_class = label

        rows.append({
            "path": str(p),
            "rel_path": str(rel),
            "label": label,
            "split": split if has_splits else "unknown",
            "patient_id": patient_id,
            "patient_uid": patient_uid,
            "patient_class": patient_class,
            "source": "scan",
        })
    return pd.DataFrame(rows).drop_duplicates(subset=["path"]).reset_index(drop=True)


# ============================================================
# Dedupe (SHA-256 ONLY)
# ============================================================

def dedupe_sha256(manifest: pd.DataFrame, has_splits: bool, reports_dir: Path) -> Tuple[pd.DataFrame, Dict[str,int]]:
    """Remove byte-identical images from ``manifest`` using SHA-256.

    Every file is hashed; files that cannot be read are dropped and reported.
    Rows are then grouped by hash:

    * a hash seen once is kept;
    * a hash carrying more than one label is dropped entirely (label conflict);
    * with ``DROP_CROSS_PATIENT_EXACT_DUPES``, a hash shared by more than one
      patient is dropped entirely;
    * otherwise the first row in ``canonical_sort`` order is kept and the
      remaining rows are reported as duplicates.

    Reports written to ``reports_dir``: ``bad_files_hashing.csv``,
    ``removed_conflicts_sha256.csv``, ``removed_duplicates_sha256.csv`` and
    ``manifest_after_sha256.csv``.

    Args:
        manifest: Manifest with at least path, label, split and patient_uid.
        has_splits: Passed to ``canonical_sort`` to pick the ordering rule.
        reports_dir: Directory that receives the report CSVs.

    Returns:
        ``(kept, stats)``: the surviving rows with an added ``sha256`` column,
        and a dict of counters (bad reads, duplicate groups, removed
        duplicates, removed conflicts, kept rows).
    """
    print("[INFO] Computing SHA-256 (exact duplicates only)...")
    sha_list = []
    bad = []

    for p in tqdm(manifest["path"].tolist(), desc="sha256", unit="img"):
        try:
            sha_list.append(sha256_file(Path(p)))
        except Exception as e:
            # Best effort: an unreadable file is reported, not fatal.
            sha_list.append(None)
            bad.append({"path": p, "stage": "sha256", "error": str(e)})

    m = manifest.copy()
    m["sha256"] = sha_list
    m = m.dropna(subset=["sha256"]).reset_index(drop=True)

    write_csv_always(pd.DataFrame(bad), reports_dir / "bad_files_hashing.csv",
                     ["path", "stage", "error"])

    conflict_rows = []
    removed_rows = []
    keep_rows = []

    dup_groups = 0
    removed_dupes = 0
    removed_conflicts = 0

    for sha, g in m.groupby("sha256", sort=False):
        if len(g) == 1:
            keep_rows.append(g.iloc[0].to_dict())
            continue

        dup_groups += 1
        labels = set(g["label"].tolist())

        # If same bytes map to multiple labels => drop all (conservative)
        if len(labels) > 1:
            removed_conflicts += len(g)
            for _, rr in g.iterrows():
                conflict_rows.append({**rr.to_dict(), "reason": "sha256_label_conflict_drop_all"})
            continue

        patients = {x for x in g["patient_uid"].tolist() if str(x).strip() != ""}
        if DROP_CROSS_PATIENT_EXACT_DUPES and len(patients) > 1:
            # Counted with the conflicts, listed in the duplicates report.
            removed_conflicts += len(g)
            for _, rr in g.iterrows():
                removed_rows.append({
                    "sha256": sha,
                    "canonical_path": "",
                    "removed_path": rr["path"],
                    "canonical_split": "",
                    "removed_split": rr["split"],
                    "label": rr["label"],
                    "reason": "sha256_cross_patient_drop_all",
                })
            continue

        g_sorted = canonical_sort(g, has_splits)
        canon = g_sorted.iloc[0].to_dict()
        # "prio" is only a sort key added by canonical_sort; keep it out of
        # the manifest so it does not show up in downstream CSVs.
        canon.pop("prio", None)
        keep_rows.append(canon)

        for i in range(1, len(g_sorted)):
            rr = g_sorted.iloc[i].to_dict()
            removed_dupes += 1
            removed_rows.append({
                "sha256": sha,
                "canonical_path": canon["path"],
                "removed_path": rr["path"],
                "canonical_split": canon.get("split", "unknown"),
                "removed_split": rr.get("split", "unknown"),
                "label": canon["label"],
                "reason": "sha256_exact_duplicate_keep_one",
            })

    kept = pd.DataFrame(keep_rows).reset_index(drop=True)

    write_csv_always(pd.DataFrame(conflict_rows), reports_dir / "removed_conflicts_sha256.csv",
                     ["path","rel_path","split","label","patient_uid","sha256","reason","source"])
    write_csv_always(pd.DataFrame(removed_rows), reports_dir / "removed_duplicates_sha256.csv",
                     ["sha256","canonical_path","removed_path","canonical_split","removed_split","label","reason"])
    write_csv_always(kept, reports_dir / "manifest_after_sha256.csv",
                     ["path","rel_path","split","label","patient_uid","patient_id","patient_class","sha256","source"])

    stats = {
        "bad_hash_reads": len(bad),
        "duplicate_groups_sha256": dup_groups,
        "removed_duplicates_sha256": removed_dupes,
        "removed_conflicts_sha256": removed_conflicts,
        "kept_after_sha256": len(kept),
    }

    print(f"[INFO] After SHA-256: kept {len(kept):,} images")
    print("[INFO] SHA-256 dedupe stats:", stats)
    return kept, stats


# ============================================================
# Splitting + Materialization
# ============================================================

def _three_way_split(items: list, strata: Optional[list]) -> Tuple[list, list, list]:
    """Split ``items`` into train / val / test lists according to ``SPLIT_RATIOS``.

    ``items`` must be unique and hashable. ``strata`` holds one class per item
    for stratification, or ``None`` for a plain random split. The held-out
    pool is stratified as well, unless one of its classes has fewer than two
    members, in which case that second split falls back to a plain random one.
    """
    train_items, temp_items = train_test_split(
        items,
        test_size=(1.0 - SPLIT_RATIOS["train"]),
        random_state=RANDOM_SEED,
        stratify=strata,
    )

    val_ratio = SPLIT_RATIOS["val"] / (SPLIT_RATIOS["val"] + SPLIT_RATIOS["test"])

    temp_strata = None
    if strata is not None:
        strata_of = dict(zip(items, strata))
        temp_strata = [strata_of[item] for item in temp_items]
        # A class with a single member cannot be divided between val and test.
        if pd.Series(temp_strata).value_counts().min() < 2:
            print("[WARN] Too few samples per class in the val/test pool; "
                  "val/test split is not stratified.")
            temp_strata = None

    val_items, test_items = train_test_split(
        temp_items,
        test_size=(1.0 - val_ratio),
        random_state=RANDOM_SEED,
        stratify=temp_strata,
    )
    return train_items, val_items, test_items


def assign_splits(m: pd.DataFrame, has_splits: bool, reports_dir: Path) -> pd.DataFrame:
    """Return a copy of ``m`` whose ``split`` column is final.

    * ``has_splits``: existing split names are only canonicalized (missing
      values become "unknown"); nothing is reshuffled.
    * Patient IDs available: patients are split train/val/test, stratified by
      each patient's most frequent label when every class has at least two
      patients, and all images of a patient share one split. A patient x class
      summary is written to ``reports_dir/split_summary_patients.csv``.
    * No patient IDs: images are split individually, stratified by label when
      possible; a warning is printed because leakage cannot be ruled out.

    Args:
        m: Manifest with ``label``, ``split`` and ``patient_uid`` columns.
        has_splits: Whether the dataset is treated as already split.
        reports_dir: Existing directory for the patient summary report.

    Returns:
        New DataFrame with the ``split`` column set (and ``patient_class``
        overwritten in the patient-wise case).

    Raises:
        RuntimeError: If a row ends up without a split or a patient ends up in
            more than one split.
    """
    out = m.copy()

    if has_splits:
        out["split"] = out["split"].fillna("unknown").map(lambda s: normalize_split(str(s)))
        return out

    def _can_stratify(classes: list) -> bool:
        # Needs at least two classes with at least two members each.
        counts = pd.Series(classes).value_counts()
        return (len(counts) >= 2) and (counts.min() >= 2)

    have_patient = out["patient_uid"].astype(str).str.strip().ne("").any()

    if have_patient:
        # One class per patient: its most frequent label (ties broken by name).
        patient_primary = (
            out.groupby("patient_uid")["label"]
               .apply(lambda s: stable_mode(s.tolist()))
               .to_dict()
        )
        out["patient_class"] = out["patient_uid"].map(patient_primary)

        patients = list(patient_primary.keys())
        strat = list(patient_primary.values())

        if not _can_stratify(strat):
            print("[WARN] Not enough patients per class for stratify; falling back to non-stratified patient split.")
            strat = None

        train_pat, val_pat, test_pat = _three_way_split(patients, strat)

        split_map = {pid: "train" for pid in train_pat}
        split_map.update({pid: "val" for pid in val_pat})
        split_map.update({pid: "test" for pid in test_pat})

        out["split"] = out["patient_uid"].map(split_map)

        # Explicit checks (not asserts) so they also run under ``python -O``.
        if out["split"].isna().sum() != 0:
            raise RuntimeError("Some rows did not get a split assignment.")
        if out.groupby("patient_uid")["split"].nunique().max() != 1:
            raise RuntimeError("Patient leakage detected.")

        pat_summary = (
            out[["patient_uid", "split", "patient_class"]]
            .drop_duplicates()
            .groupby(["split", "patient_class"])
            .size()
            .unstack(fill_value=0)
        )
        pat_summary.to_csv(reports_dir / "split_summary_patients.csv")
        print("[INFO] Patient-wise split done (no leakage).")
        return out

    print("[WARN] No patient IDs available; doing IMAGE-wise stratified split (leakage possible).")
    labels = out["label"].tolist()
    idxs = list(range(len(out)))

    strat = labels if _can_stratify(labels) else None
    train_idx, val_idx, test_idx = _three_way_split(idxs, strat)

    # Positions (not index labels) are used, matching the order of ``labels``.
    split = [""] * len(out)
    for name, positions in (("train", train_idx), ("val", val_idx), ("test", test_idx)):
        for i in positions:
            split[i] = name
    out["split"] = split
    return out

def materialize(m: pd.DataFrame, out_dir: Path, reports_dir: Path):
    """Write the cleaned dataset to ``out_dir/<split>/<label>/``.

    Every manifest row is placed (see ``place_file``) under a file name built
    from the patient UID (when present), the first 8 hex characters of the
    SHA-256 and the original file name. Failed writes are collected in
    ``reports_dir/bad_writes.csv`` and do not stop the run.

    Also written: ``out_dir/final_manifest.csv`` (the manifest plus a
    ``new_path`` column, left empty for rows whose write failed) and
    ``out_dir/split_summary_images.csv``.

    Args:
        m: Final manifest with ``path``, ``split`` and ``label`` columns.
        out_dir: Output dataset root.
        reports_dir: Directory for the bad-writes report.
    """
    print("[INFO] Writing output dataset...")
    safe_mkdir(out_dir)
    safe_mkdir(reports_dir)

    # Create the full split x label grid up front, even for empty cells.
    splits = ["train", "val", "test"]
    labels = sorted(set(m["label"].tolist()))
    for sp in splits:
        for lab in labels:
            safe_mkdir(out_dir / sp / lab)

    # Exactly one entry per manifest row, so it lines up with the manifest
    # even when some writes fail.
    new_paths = []
    bad_writes = []

    for _, r in tqdm(m.iterrows(), total=len(m), desc="writing", unit="img"):
        src = Path(r["path"])
        split = str(r["split"])
        label = str(r["label"])

        sha8 = str(r.get("sha256", ""))[:8] if pd.notna(r.get("sha256", "")) else "nosha"
        puid = str(r.get("patient_uid", "")).strip()
        # Path separators in a patient UID would create sub-folders.
        puid_safe = puid.replace("/", "_").replace("\\", "_") if puid else ""

        if puid_safe:
            dst_name = f"{puid_safe}__{sha8}__{src.name}"
        else:
            dst_name = f"{sha8}__{src.name}"

        dst = out_dir / split / label / dst_name

        try:
            place_file(src, dst)
        except Exception as e:
            bad_writes.append({"src": str(src), "dst": str(dst), "error": str(e)})
            new_paths.append("")
        else:
            new_paths.append(str(dst))

    m2 = m.copy().reset_index(drop=True)
    m2["new_path"] = new_paths

    write_csv_always(pd.DataFrame(bad_writes), reports_dir / "bad_writes.csv",
                     ["src", "dst", "error"])
    if bad_writes:
        print(f"[WARN] {len(bad_writes)} file(s) could not be written (see bad_writes.csv).")

    m2.to_csv(out_dir / "final_manifest.csv", index=False)

    img_summary = m2.groupby(["split", "label"]).size().unstack(fill_value=0)
    img_summary.to_csv(out_dir / "split_summary_images.csv")

    print("\n[INFO] Images per split/label:\n", img_summary)
    print("\n[DONE] Output root:", out_dir)
    print("[DONE] Reports:", reports_dir)


# ============================================================
# Main
# ============================================================

def main():
    """Run the cleaning pipeline configured by the module-level constants.

    Steps: build the manifest (from the CSV when ``CSV_PATH`` exists, else by
    scanning ``ROOT_DIR``), print and save the "before" statistics, remove
    exact duplicates by SHA-256, assign or preserve splits, drop rows whose
    split is still "unknown", print and save the "after" statistics, and
    materialize the result under ``OUT_DIR``.

    Raises:
        ValueError: If ``SPLIT_RATIOS`` does not sum to 1.0 or ``MODE`` is not
            one of the supported values.
        RuntimeError: If no images are found or resolved.
    """
    # Explicit check (not an assert) so it also runs under ``python -O``.
    if abs(sum(SPLIT_RATIOS.values()) - 1.0) >= 1e-9:
        raise ValueError("Split ratios must sum to 1.0")
    random.seed(RANDOM_SEED)
    np.random.seed(RANDOM_SEED)

    safe_mkdir(OUT_DIR)
    reports_dir = OUT_DIR / "reports"
    safe_mkdir(reports_dir)

    has_splits = detect_has_splits(ROOT_DIR)

    if MODE == "preserve_splits":
        has_splits = True
    elif MODE == "build_splits":
        has_splits = False
    elif MODE != "auto":
        raise ValueError("MODE must be one of: auto, preserve_splits, build_splits")

    print("[INFO] ROOT_DIR:", ROOT_DIR)
    print("[INFO] CSV_PATH:", CSV_PATH)
    print("[INFO] MODE:", MODE, "| detected_has_splits:", has_splits)

    # 1) Build manifest
    bad_csv = pd.DataFrame()
    csv_path = Path(CSV_PATH) if CSV_PATH is not None else None
    if csv_path is not None and not csv_path.exists():
        # A configured but missing CSV changes how the manifest is built
        # (no CSV labels / patient IDs), so say so instead of staying silent.
        print(f"[WARN] CSV_PATH does not exist: {csv_path} -- falling back to folder scan.")

    if csv_path is not None and csv_path.exists():
        manifest, bad_csv = build_manifest_from_csv(ROOT_DIR, csv_path, has_splits)
    else:
        manifest = build_manifest_by_scanning(ROOT_DIR, has_splits)

    if manifest.empty:
        raise RuntimeError("No images found/resolved. Check ROOT_DIR/CSV_PATH and folder structure.")

    write_csv_always(manifest, reports_dir / "manifest_raw.csv",
                     ["path","rel_path","split","label","patient_uid","patient_id","patient_class","source"])
    if not bad_csv.empty:
        write_csv_always(bad_csv, reports_dir / "csv_missing_or_unresolved.csv", list(bad_csv.columns))
        print(f"[WARN] CSV unresolved rows: {len(bad_csv)} (see reports)")

    # ---- BEFORE stats ----
    summarize_manifest(manifest, "BEFORE DEDUPLICATION (raw manifest)")
    write_stats_csvs(manifest, reports_dir, "before")

    # 2) SHA-256 exact dedupe ONLY
    m, sha_stats = dedupe_sha256(manifest, has_splits=has_splits, reports_dir=reports_dir)

    # 3) Assign / preserve splits
    m = assign_splits(m, has_splits=has_splits, reports_dir=reports_dir)

    # Drop unknown split rows (optional but keeps train/val/test clean)
    unknown = m["split"].astype(str).str.lower().eq("unknown")
    if unknown.any():
        dropped = int(unknown.sum())
        print(f"[WARN] Dropping {dropped} rows with split='unknown' (path not under train/val/test).")
        write_csv_always(m[unknown], reports_dir / "dropped_unknown_split.csv",
                         ["path","rel_path","split","label","sha256","source"])
        m = m[~unknown].copy().reset_index(drop=True)

    # ---- AFTER stats (post-dedupe + split assignment) ----
    summarize_manifest(m, "AFTER SHA-256 DEDUPLICATION (and split assignment)")
    write_stats_csvs(m, reports_dir, "after")

    # Summary line items
    print("\n" + "="*90)
    print("[SUMMARY] KEY NUMBERS")
    print("="*90)
    print(f"Before: {len(manifest):,} images")
    print(f"After : {len(m):,} images")
    print(f"Removed exact duplicates (kept one per SHA group): {sha_stats['removed_duplicates_sha256']:,}")
    print(f"Removed label-conflict groups (dropped all in group): {sha_stats['removed_conflicts_sha256']:,}")
    print(f"Bad/unreadable during hashing: {sha_stats['bad_hash_reads']:,}")
    print(f"Duplicate groups found (SHA): {sha_stats['duplicate_groups_sha256']:,}")
    print("="*90 + "\n")

    # 4) Materialize output dataset
    materialize(m, OUT_DIR, reports_dir)

if __name__ == "__main__":
    main()


[INFO] ROOT_DIR: D:\AIUB\DSP\Code\Datasets\C8\RetinalOCT_Dataset
[INFO] CSV_PATH: None
[INFO] MODE: auto | detected_has_splits: True

[STATS] BEFORE DEDUPLICATION (raw manifest)
Total images: 24,000

Images per split:
split
train    18400
test      2800
val       2800

Images per class:
label
AMD       3000
CNV       3000
CSR       3000
DME       3000
DR        3000
DRUSEN    3000
MH        3000
NORMAL    3000

Images per split x class:
label   AMD   CNV   CSR   DME    DR  DRUSEN    MH  NORMAL
split                                                    
test    350   350   350   350   350     350   350     350
train  2300  2300  2300  2300  2300    2300  2300    2300
val     350   350   350   350   350     350   350     350

Unique patients (patient_uid): N/A (not provided)
[INFO] Computing SHA-256 (exact duplicates only)...


sha256: 100%|██████████████████████████████████████████████████████████████████| 24000/24000 [01:49<00:00, 219.84img/s]


[INFO] After SHA-256: kept 23,850 images
[INFO] SHA-256 dedupe stats: {'bad_hash_reads': 0, 'duplicate_groups_sha256': 144, 'removed_duplicates_sha256': 138, 'removed_conflicts_sha256': 12, 'kept_after_sha256': 23850}

[STATS] AFTER SHA-256 DEDUPLICATION (and split assignment)
Total images: 23,850

Images per split:
split
train    18260
test      2797
val       2793

Images per class:
label
AMD       3000
CSR       3000
MH        3000
DR        3000
NORMAL    2996
DME       2970
CNV       2954
DRUSEN    2930

Images per split x class:
label   AMD   CNV   CSR   DME    DR  DRUSEN    MH  NORMAL
split                                                    
test    350   350   350   349   350     349   350     349
train  2300  2257  2300  2272  2300    2233  2300    2298
val     350   347   350   349   350     348   350     349

Unique patients (patient_uid): N/A (not provided)

[SUMMARY] KEY NUMBERS
Before: 24,000 images
After : 23,850 images
Removed exact duplicates (kept one per SHA group): 

writing: 100%|████████████████████████████████████████████████████████████████| 23850/23850 [00:10<00:00, 2384.96img/s]



[INFO] Images per split/label:
 label   AMD   CNV   CSR   DME    DR  DRUSEN    MH  NORMAL
split                                                    
test    350   350   350   349   350     349   350     349
train  2300  2257  2300  2272  2300    2233  2300    2298
val     350   347   350   349   350     348   350     349

[DONE] Output root: D:\AIUB\DSP\Code\Datasets\C8\RetinalOCT_Dataset_CLEAN_SHAONLY
[DONE] Reports: D:\AIUB\DSP\Code\Datasets\C8\RetinalOCT_Dataset_CLEAN_SHAONLY\reports
