# In [None]:
from __future__ import annotations

# =========================
# Determinism (set BEFORE importing torch)
# =========================
import os  # [DET]

os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")  # [DET]

import re
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Tuple, Dict, Optional

import numpy as np
import random
import hashlib  # [DET]
import warnings

warnings.filterwarnings(
    "ignore", message=".*gpu_hist.*deprecated.*", category=UserWarning
)
warnings.filterwarnings(
    "ignore",
    message='.*Parameters: { "predictor" } are not used.*',
    category=UserWarning,
)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

import timm

# =========================
# Config
# =========================
# Root folder scanned by discover_items_tree(); it looks for "Healthy" and "Patient" subfolders.
DATA_ROOT = Path(r"DataSet")
# Destination for the per-draw, per-fold backbone checkpoints.
OUTDIR = Path("checkpoints_subject5cv_per_draw_tiles_WINS")

# Canonical draw names; every per-draw dictionary below is keyed by these.
DRAW_TYPES = ["circle", "meander", "spiral"]

# Per-draw fine-tuning hyperparameters read by train_one_draw_and_eval():
EPOCHS_FN = {"circle": 5, "meander": 10, "spiral": 5}  # number of training epochs
LR_INIT = {"circle": 3.16e-4, "meander": 1e-3, "spiral": 3.16e-4}  # AdamW learning rate
BATCH_SIZE = {"circle": 64, "meander": 16, "spiral": 16}  # training-loader batch size
LS_EPS = {"circle": 0.05, "meander": 0.05, "spiral": 0.0}  # CrossEntropyLoss label smoothing
WEIGHT_DECAY = {"circle": 1e-4, "meander": 0.0, "spiral": 1e-4}  # AdamW weight decay
NUM_WORKERS = 4  # DataLoader worker processes
SEED = 42  # base seed for set_seed(), the loader generators and the fold seeds

PRINT_PATHS = False  # not read by any code visible in this file

# ---- Tiling
TILE_GRID = (2, 2)  # 2×2 tiles per image
TILE_RESIZE_BASE = 448  # resize to 448, then crop 4×(224×224)
FINAL_SIZE = 224  # not read by any code visible in this file

# ---- Feature backbone mode (your best runs used resnet features)
# NOTE(review): this value is passed down as `feat_mode`, but train_one_draw_and_eval()
# selects the backbone per draw and never consults it — confirm which is intended.
FEAT_MODE = "resnet"

# ---- WINNING CLASSIFIER SPECS (fixed; from your “best” logs)
# Circle (KNN): scaler=StandardScaler, no PCA, n_neighbors=1, p=1, weights=uniform, leaf_size=15
# Meander (DT):  gini, max_depth=None, min_samples_split=2, min_samples_leaf=1, max_features=None, class_weight=None
# Spiral (RF):   n_estimators=100, max_depth=None, min_samples_split=2, min_samples_leaf=1, max_features='sqrt',
#                class_weight=None, n_jobs=-1
#
# NOTE(review): the "spiral" entry below is a 1-NN classifier (no scaler), not the
# RandomForest described above — confirm which one is intended before relying on results.
#
# Maps each draw type to the scikit-learn pipeline that is fitted on that draw's
# per-image features in train_one_draw_and_eval().
WINNERS = {
    # Standardised features → 1-nearest-neighbour with Manhattan distance (p=1).
    "circle": Pipeline(
        [
            ("scaler", StandardScaler(with_mean=True, with_std=True)),
            (
                "clf",
                KNeighborsClassifier(
                    n_neighbors=1,
                    weights="uniform",
                    p=1,
                    metric="minkowski",
                    leaf_size=15,
                ),
            ),
        ]
    ),
    # Unscaled features → fully grown decision tree, seeded with SEED.
    "meander": Pipeline(
        [
            (
                "clf",
                DecisionTreeClassifier(
                    random_state=SEED,
                    criterion="gini",
                    max_depth=None,
                    min_samples_split=2,
                    min_samples_leaf=1,
                    max_features=None,
                    class_weight=None,
                ),
            ),
        ]
    ),
    # Unscaled features → 1-nearest-neighbour with Manhattan distance (see NOTE above).
    "spiral": Pipeline(
        [
            (
                "clf",
                KNeighborsClassifier(
                    n_neighbors=1,
                    weights="uniform",
                    p=1,
                    metric="minkowski",
                    leaf_size=15,
                ),
            ),
        ]
    ),
}

# =========================
# Repro / determinism
# =========================
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and put PyTorch into deterministic mode.

    Args:
        seed: Value handed unchanged to every seeding function.
    """
    # Seed every random source the pipeline draws from.
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)

    # Disable autotuning and reduced-precision TF32 paths, then request
    # deterministic kernels globally.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # [DET]
    cudnn.deterministic = True  # [DET]
    cudnn.allow_tf32 = False  # [DET]
    torch.backends.cuda.matmul.allow_tf32 = False  # [DET]
    torch.use_deterministic_algorithms(True)  # [DET]


# Seed all RNGs and enable deterministic mode as soon as the module is executed.
set_seed(SEED)


# ---------- Deterministic DataLoader helpers ----------
def _seed_from_key(base_seed: int, *parts: str) -> int:  # [DET]
    """Derive a stable 64-bit seed from a base seed and any number of key parts.

    The string forms of ``base_seed`` and each part are concatenated, hashed with
    SHA-256, and the first eight digest bytes are read as a little-endian integer.
    """
    payload = "".join(str(piece) for piece in (base_seed, *parts)).encode()
    digest = hashlib.sha256(payload).digest()
    return int.from_bytes(digest[:8], "little")


def _worker_init_fn(worker_id: int):  # [DET]
    """Reseed NumPy and Python's ``random`` from the current torch initial seed.

    Intended for use as a DataLoader ``worker_init_fn``. ``worker_id`` is accepted
    for that signature but not used; the seed is torch's initial seed reduced to
    32 bits.
    """
    worker_seed = torch.initial_seed() & 0xFFFFFFFF  # same as modulo 2**32
    for reseed in (np.random.seed, random.seed):
        reseed(worker_seed)


def _make_loader(
    dataset,
    *,
    shuffle: bool,
    num_workers: int,
    base_seed: int,
    key: str,
    batch_size: int,
):  # [DET]
    """Build a DataLoader whose sampling generator is seeded from ``(base_seed, key)``.

    Args:
        dataset: Dataset to iterate.
        shuffle: Forwarded to ``DataLoader``.
        num_workers: Forwarded to ``DataLoader``.
        base_seed: Base seed mixed with ``key`` by ``_seed_from_key``.
        key: String that distinguishes this loader's random stream from others.
        batch_size: Forwarded to ``DataLoader``.

    Returns:
        A ``DataLoader`` that keeps the final partial batch, pins memory, does not
        keep workers alive between epochs, and reseeds workers via
        ``_worker_init_fn``.
    """
    # A dedicated CPU generator keeps this loader's shuffling independent of the
    # global torch RNG state.
    rng = torch.Generator(device="cpu")
    rng.manual_seed(_seed_from_key(base_seed, key))

    loader_kwargs = dict(
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
        worker_init_fn=_worker_init_fn,
        generator=rng,
        persistent_workers=False,
        drop_last=False,
    )
    return DataLoader(dataset, **loader_kwargs)


# ---------- Fold-level reseed to lock init per draw×fold ----------
def make_fold_seed(*parts: str) -> int:
    """Return a stable 64-bit seed derived from the global ``SEED`` and ``parts``.

    Produces the same SHA-256-based value as ``_seed_from_key(SEED, *parts)``,
    which it delegates to.
    """
    return _seed_from_key(SEED, *parts)


def reseed_for_fold(draw: str, fold_idx: int):
    """Reseed Python, NumPy and torch RNGs with a seed specific to ``draw`` × ``fold_idx``.

    The seed comes from ``make_fold_seed("fold", draw, str(fold_idx))``. NumPy
    receives it reduced modulo 2**32.
    """
    fold_seed = make_fold_seed("fold", draw, str(fold_idx))

    torch.manual_seed(fold_seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(fold_seed)
    np.random.seed(fold_seed % (2**32))
    random.seed(fold_seed)


# ------------------------------------------------------
# Use CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# =========================
# Data discovery (DETERMINISTIC)
# =========================
# Subject tag "-H<digits>" / "-P<digits>" (case-insensitive, leading zeros skipped):
# first anchored at the end of the file stem, then anywhere followed by a word boundary.
_SUBJ_PAT = re.compile(r"(?i)-([HP])0*(\d+)$")
_SUBJ_ANY = re.compile(r"(?i)-([HP])0*(\d+)\b")


def extract_subject_id(fp: str) -> str:
    """Parse a normalised subject id such as ``"H7"`` or ``"P12"`` from a file path.

    Only the file stem is inspected. The end-anchored pattern is tried first and
    the looser anywhere-in-stem pattern second.

    Args:
        fp: Path (or file name) of an image.

    Returns:
        Upper-cased group letter followed by the subject number without leading zeros.

    Raises:
        ValueError: If neither pattern matches the stem.
    """
    stem = Path(fp).stem
    for pattern in (_SUBJ_PAT, _SUBJ_ANY):
        hit = pattern.search(stem)
        if hit is not None:
            group_letter, digits = hit.groups()
            return f"{group_letter.upper()}{int(digits)}"
    raise ValueError(f"[subject_id] Could not parse subject from: {fp}")


# Canonical draw name → substrings that identify a folder as that draw.
_DRAW_MAP = {
    "circle": ["circle", "circ"],
    "meander": ["meander", "meand", "wave", "wav"],
    "spiral": ["spiral", "spir", "sp"],
}


def _norm_draw_folder(name: str):
    """Map a folder name to its canonical draw type, or ``None`` when nothing matches.

    Matching is a case-insensitive substring test. Draw types are tried in
    ``_DRAW_MAP`` order and the first hit wins.
    """
    lowered = name.lower()
    return next(
        (
            canon
            for canon, tokens in _DRAW_MAP.items()
            if any(tok in lowered for tok in tokens)
        ),
        None,
    )


def discover_items_tree(data_root: Path) -> List[Tuple[str, int, str, str]]:
    """Walk ``data_root/{Healthy,Patient}/<draw folder>/**`` and list every image once.

    Args:
        data_root: Folder expected to contain ``Healthy`` (label 0) and
            ``Patient`` (label 1) subfolders. A missing one is reported and skipped.

    Returns:
        ``(resolved_path, label, draw, subject_id)`` tuples sorted by
        ``(label, draw, subject_id, path)``.

    Raises:
        ValueError: Propagated from ``extract_subject_id`` when a file name
            carries no subject tag.
    """
    image_suffixes = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    found: List[Tuple[str, int, str, str]] = []
    visited = set()  # resolved paths already recorded

    def _image_files(folder):
        # Recursive listing in case-insensitive path order.
        candidates = [
            p
            for p in folder.rglob("*")
            if p.is_file() and p.suffix.lower() in image_suffixes
        ]
        candidates.sort(key=lambda p: str(p).lower())
        return candidates

    for class_dirname, label in (("Healthy", 0), ("Patient", 1)):
        class_dir = data_root / class_dirname
        if not class_dir.exists():
            print(f"[warn] Missing folder: {class_dir}")
            continue

        draw_dirs = sorted(class_dir.iterdir(), key=lambda p: p.name.lower())
        for draw_dir in draw_dirs:
            # Ignore plain files and folders whose name maps to no draw type.
            draw = _norm_draw_folder(draw_dir.name) if draw_dir.is_dir() else None
            if draw is None:
                continue
            for img_path in _image_files(draw_dir):
                resolved = str(img_path.resolve())
                if resolved not in visited:
                    visited.add(resolved)
                    found.append(
                        (resolved, label, draw, extract_subject_id(resolved))
                    )

    # Final ordering is independent of filesystem enumeration order.
    found.sort(key=lambda rec: (rec[1], rec[2], rec[3], rec[0]))

    print(f"[info] Found {len(found)} images")
    if found:
        print("  by label    :", Counter(rec[1] for rec in found))
        print("  by draw type:", Counter(rec[2] for rec in found))
        print("  subjects    :", len({rec[3] for rec in found}))
    return found


# =========================
# Transforms / Tiling (per-image max only; NO mean/std)
# =========================
class PerImageMaxNormalize:
    """Divide a 3-D tensor by its maximum taken over the last two axes.

    The maximum is computed separately for every index along axis 0 (the channel
    axis for ``ToTensor`` output), so each channel is scaled by its own peak.
    A small epsilon keeps all-zero inputs finite.
    """

    def __call__(self, img: torch.Tensor):
        peak = img.amax(dim=(1, 2), keepdim=True)
        return img / (peak + 1e-8)


# Shared transform instances, built once when the module is executed.
# Bilinear resize to a TILE_RESIZE_BASE × TILE_RESIZE_BASE square before tiling.
_RESIZE_448 = transforms.Resize(
    (TILE_RESIZE_BASE, TILE_RESIZE_BASE), interpolation=InterpolationMode.BILINEAR
)
_TO_TENSOR = transforms.ToTensor()
_TO_PIL = transforms.ToPILImage()  # not referenced by the code visible in this file
_NORM_PERIMG = PerImageMaxNormalize()


def _resize_to_base(pil_img: Image.Image) -> Image.Image:
    """Resize ``pil_img`` with the shared ``_RESIZE_448`` transform and return the result."""
    return _RESIZE_448(pil_img)


def _crop_tile_from_resized(
    pil_img_448: Image.Image, tile_index: int, rows: int = 2, cols: int = 2
) -> Image.Image:
    """Crop one tile out of an image that is split into a ``rows`` × ``cols`` grid.

    Tiles are numbered row-major starting at 0 (top-left). Tile size is the image
    size floor-divided by the grid, so leftover pixels on the right or bottom
    edge are not covered when the size is not a multiple of the grid.

    Args:
        pil_img_448: Image exposing ``.size`` as ``(width, height)`` and ``.crop(box)``.
        tile_index: Row-major tile number in ``[0, rows * cols)``.
        rows: Number of grid rows (must be positive).
        cols: Number of grid columns (must be positive).

    Returns:
        Whatever ``pil_img_448.crop`` returns for the tile's bounding box.

    Raises:
        ValueError: If the grid is not positive or ``tile_index`` is out of range.
            Previously such an index silently produced a crop outside the image.
    """
    if rows <= 0 or cols <= 0:
        raise ValueError(f"rows and cols must be positive, got rows={rows}, cols={cols}")
    if not 0 <= tile_index < rows * cols:
        raise ValueError(
            f"tile_index {tile_index} out of range for a {rows}x{cols} grid"
        )

    width, height = pil_img_448.size
    tile_w, tile_h = width // cols, height // rows
    row, col = divmod(tile_index, cols)
    left, top = col * tile_w, row * tile_h
    return pil_img_448.crop((left, top, left + tile_w, top + tile_h))


def _tile_to_tensor(pil_tile: Image.Image) -> torch.Tensor:
    """Convert a PIL tile with ``_TO_TENSOR`` and rescale it with ``_NORM_PERIMG``."""
    # ToTensor yields values in [0, 1]; the normaliser then divides by the per-channel max.
    return _NORM_PERIMG(_TO_TENSOR(pil_tile))


# =========================
# Datasets (NO AUGMENTATION)
# =========================
class TrainTilesDataset(Dataset):
    """Training dataset yielding ``(tile_tensor, label)``, one sample per tile.

    Every image contributes ``TILE_GRID[0] * TILE_GRID[1]`` samples. On access the
    image is resized to the base resolution and cropped to the requested tile.
    No augmentation, no repeats.

    Args:
        items: Iterable of ``(file_path, label, draw, subject_id)`` tuples.
    """

    def __init__(self, items):
        n_tiles = TILE_GRID[0] * TILE_GRID[1]
        # (fp, lab, draw, sid, tile_index)
        self.records = [
            (fp, int(lab), d, sid, tile_idx)
            for fp, lab, d, sid in items
            for tile_idx in range(n_tiles)
        ]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        fp, lab, _d, _sid, tile_idx = self.records[i]
        # Close the file handle deterministically; convert() returns an
        # independent in-memory image, so it is safe to use after the block.
        with Image.open(fp) as raw:
            pil_img = raw.convert("RGB")
        img448 = _resize_to_base(pil_img)
        pil_tile = _crop_tile_from_resized(img448, tile_idx, *TILE_GRID)
        x = _tile_to_tensor(pil_tile)
        return x, lab


class TilesWithMetaDataset(Dataset):
    """Inference dataset yielding ``(tile_tensor, label, subject_id, file_path)`` per tile.

    Every image contributes ``TILE_GRID[0] * TILE_GRID[1]`` samples. The file path
    is returned so callers can regroup tiles by source image.

    Args:
        items: Iterable of ``(file_path, label, draw, subject_id)`` tuples.
    """

    def __init__(self, items):
        n_tiles = TILE_GRID[0] * TILE_GRID[1]
        # (fp, lab, sid, tile_index)
        self.records = [
            (fp, int(lab), sid, tile_idx)
            for fp, lab, _d, sid in items
            for tile_idx in range(n_tiles)
        ]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        fp, lab, sid, tile_idx = self.records[i]
        # Close the file handle deterministically; convert() returns an
        # independent in-memory image, so it is safe to use after the block.
        with Image.open(fp) as raw:
            pil_img = raw.convert("RGB")
        img448 = _resize_to_base(pil_img)
        pil_tile = _crop_tile_from_resized(img448, tile_idx, *TILE_GRID)
        x = _tile_to_tensor(pil_tile)
        return x, lab, sid, fp


# =========================
# Backbone model (feature extractor)
# =========================
class ResNetPVTRep(nn.Module):
    """Two-class tile classifier built on a ResNet-18 and/or PVT-v2-b1 backbone.

    The selected backbones produce pooled feature vectors that are concatenated
    and passed through a small MLP head ending in two logits.

    Args:
        mode: ``"resnet"``, ``"pvt"`` or ``"both"`` — which backbone(s) to build.
        resnet_pretrained: Load ImageNet-1K V1 weights into the ResNet-18.
        pvt_pretrained: Passed as ``pretrained`` to ``timm.create_model``.
        hidden1: Width of the first head layer.
        hidden2: Width of an optional second head layer; ``None`` omits it.
        dropout: Dropout probability after the first ReLU; a falsy or
            non-positive value adds no dropout layer.
        freeze_backbones: Set ``requires_grad = False`` on all backbone parameters.
    """

    def __init__(
        self,
        mode: str = "both",
        *,
        resnet_pretrained: bool = True,
        pvt_pretrained: bool = False,
        hidden1: int = 512,
        hidden2: Optional[int] = 512,
        dropout: float = 0.0,
        freeze_backbones: bool = False,
    ):
        super().__init__()
        assert mode in {"resnet", "pvt", "both"}
        self.mode = mode
        use_resnet = mode != "pvt"
        use_pvt = mode != "resnet"

        # ResNet-18 with its classification layer replaced by an identity.
        self.resnet = None
        res_dim = 0
        if use_resnet:
            weights = ResNet18_Weights.IMAGENET1K_V1 if resnet_pretrained else None
            self.resnet = resnet18(weights=weights)
            res_dim = self.resnet.fc.in_features  # 512
            self.resnet.fc = nn.Identity()

        # PVT-v2-b1 created without a classifier (num_classes=0), average-pooled.
        self.pvt = None
        pvt_dim = 0
        if use_pvt:
            self.pvt = timm.create_model(
                "pvt_v2_b1",
                pretrained=pvt_pretrained,
                num_classes=0,
                global_pool="avg",
            )
            pvt_dim = getattr(self.pvt, "num_features", 512)

        # MLP head on the concatenated features.
        head_layers = [nn.Linear(res_dim + pvt_dim, hidden1), nn.ReLU(inplace=True)]
        if dropout and dropout > 0:
            head_layers.append(nn.Dropout(dropout))
        if hidden2 is None:
            head_layers.append(nn.Linear(hidden1, 2))
        else:
            head_layers += [
                nn.Linear(hidden1, hidden2),
                nn.ReLU(inplace=True),
                nn.Linear(hidden2, 2),
            ]
        self.head = nn.Sequential(*head_layers)

        if freeze_backbones:
            for backbone in (self.resnet, self.pvt):
                if backbone is None:
                    continue
                for param in backbone.parameters():
                    param.requires_grad = False

    def _backbone_features(self, x):
        # Run every enabled backbone (ResNet first) and join outputs on the feature axis.
        outputs = [
            backbone(x) for backbone in (self.resnet, self.pvt) if backbone is not None
        ]
        if len(outputs) == 1:
            return outputs[0]
        return torch.cat(outputs, dim=1)

    def forward(self, x):
        """Return the two-class logits for a batch of tiles."""
        return self.head(self._backbone_features(x))

    @torch.no_grad()
    def tile_features(self, x):
        """Return the concatenated backbone features (before the head), without gradients."""
        return self._backbone_features(x)


def make_backbone_model(
    mode: str = "both",
    *,
    hidden1: int = 128,
    hidden2: Optional[int] = None,
    dropout: float = 0.0,
    freeze_backbones: bool = False,
    resnet_pretrained: bool = True,
    pvt_pretrained: bool = True,
) -> nn.Module:
    """Construct a ``ResNetPVTRep`` with this project's default head configuration.

    All arguments are forwarded unchanged. The defaults here differ from the
    class's own: ``hidden1=128``, ``hidden2=None`` (single hidden layer) and
    ``pvt_pretrained=True``.

    Returns:
        A new, untrained-head ``ResNetPVTRep`` instance (not moved to any device).
    """
    return ResNetPVTRep(
        mode=mode,
        resnet_pretrained=resnet_pretrained,
        pvt_pretrained=pvt_pretrained,
        hidden1=hidden1,
        hidden2=hidden2,
        dropout=dropout,
        freeze_backbones=freeze_backbones,
    )


# =========================
# Inference helpers (tile → image features)
# =========================
@torch.no_grad()
def collect_image_features(
    model: ResNetPVTRep, items: List[Tuple[str, int, str, str]], device
):
    """Compute one feature vector per image by averaging its tile features.

    The model is switched to eval mode (and left there). Tiles are processed in
    batches of 32 without shuffling and regrouped by source file path.

    Args:
        model: Network providing ``tile_features``.
        items: ``(file_path, label, draw, subject_id)`` tuples.
        device: Device the tile batches are moved to.

    Returns:
        ``(X, y, sid_list, key_list)`` aligned by index and ordered by sorted file
        path: ``X`` holds the mean tile feature per image (shape ``(0, 1)`` when
        there are no items), ``y`` the int64 labels, ``sid_list`` the subject ids
        and ``key_list`` the file paths.
    """
    loader = _make_loader(
        TilesWithMetaDataset(items),
        batch_size=32,
        shuffle=False,
        num_workers=NUM_WORKERS,
        base_seed=SEED,
        key="FEATS",
    )
    model.eval()

    tile_feats: Dict[str, List[np.ndarray]] = defaultdict(list)
    label_of: Dict[str, int] = {}
    subject_of: Dict[str, str] = {}

    for tiles, labels, sids, img_keys in loader:
        tiles = tiles.to(device, non_blocking=True)
        batch_feats = model.tile_features(tiles).detach().cpu().numpy()
        for row, lab, sid, img_key in zip(batch_feats, labels.numpy(), sids, img_keys):
            tile_feats[img_key].append(row)
            label_of[img_key] = int(lab)
            subject_of[img_key] = str(sid)

    ordered_keys = sorted(tile_feats.keys())
    if ordered_keys:
        # Stack each image's tiles as [num_tiles, D], then average over tiles.
        X = np.stack(
            [np.stack(tile_feats[k], axis=0).mean(axis=0) for k in ordered_keys],
            axis=0,
        )
    else:
        X = np.empty((0, 1))
    y = np.asarray([label_of[k] for k in ordered_keys], dtype=np.int64)
    sid_list = [subject_of[k] for k in ordered_keys]
    return X, y, sid_list, ordered_keys


# =========================
# Subject helpers
# =========================
def majority_vote(int_list: List[int]):
    """Return the majority label of a list of 0/1 votes.

    Returns ``None`` for an empty list, ``1`` when ones strictly outnumber zeros,
    and ``0`` otherwise (ties go to 0). Values other than 0 and 1 are not counted.
    """
    if not int_list:
        return None
    ones = sum(1 for vote in int_list if vote == 1)
    zeros = sum(1 for vote in int_list if vote == 0)
    return int(ones > zeros)


def subject_labels_from_items(items_all):
    """Build a ``subject_id → label`` mapping and verify each subject has one label.

    Args:
        items_all: Iterable of ``(file_path, label, draw, subject_id)`` tuples.

    Returns:
        Dict mapping every subject id to its label, in first-seen order.

    Raises:
        AssertionError: If a subject appears with two different labels. Raised
            explicitly (not via ``assert``) so the check still runs under
            ``python -O``.
    """
    sid2lab = {}
    for _fp, lab, _d, sid in items_all:
        # setdefault returns the label stored on first sight of this subject.
        known = sid2lab.setdefault(sid, lab)
        if known != lab:
            raise AssertionError(f"Subject {sid} has mixed labels!")
    return sid2lab


# =========================
# Train one draw (tiles), then extract features and fit WINNER clf
# =========================
def train_one_draw_and_eval(
    train_items, test_items, draw_name, fold_idx, device, feat_mode=FEAT_MODE
):
    """Fine-tune the tile model for one draw/fold, then fit and score its WINNER classifier.

    Steps: reseed for this draw × fold, train the backbone on tiles with the
    per-draw hyperparameters, save the weights, extract mean-over-tiles image
    features for train and test, fit a fresh copy of ``WINNERS[draw_name]`` and
    evaluate it per image.

    Args:
        train_items: ``(file_path, label, draw, subject_id)`` tuples for training.
        test_items: Same layout, for evaluation; must share no subject with training.
        draw_name: Key into the per-draw configuration dictionaries.
        fold_idx: Fold number, used for seeding, logging and the checkpoint name.
        device: Torch device for the model and batches.
        feat_mode: Accepted for interface compatibility but NOT consulted — the
            backbone mode is chosen per draw ("both" for spiral, "resnet" otherwise).

    Returns:
        ``(per_image, counts)`` where ``per_image`` is a list of
        ``(subject_id, predicted_label, true_label)`` and ``counts`` is a dict
        with ``tp``/``fp``/``tn``/``fn`` image-level totals.

    Raises:
        AssertionError: If a subject appears in both train and test items.
    """
    # Function-scope import: scikit-learn is already a dependency of this file.
    from sklearn.base import clone

    train_sids = {sid for _fp, _lab, _d, sid in train_items}
    test_sids = {sid for _fp, _lab, _d, sid in test_items}
    if not train_sids.isdisjoint(test_sids):
        # Explicit raise so the leak check is not stripped by `python -O`.
        raise AssertionError("Leak: TRAIN and TEST share subjects!")

    # Must happen before the model is built so its initial weights are fold-specific.
    reseed_for_fold(draw_name, fold_idx)

    n_epochs = EPOCHS_FN[draw_name]
    lr = LR_INIT[draw_name]
    bs = BATCH_SIZE[draw_name]
    ls = LS_EPS[draw_name]
    wd = WEIGHT_DECAY[draw_name]
    # The tag is part of the loader seed key, so its text must stay stable.
    hp_tag = f"ep{n_epochs}_lr{lr:.4g}_bs{bs}_ls{ls}_wd{wd}"

    dl_tr = _make_loader(
        TrainTilesDataset(train_items),  # <-- no augmentation
        shuffle=True,
        num_workers=NUM_WORKERS,
        base_seed=SEED,
        key=f"TILE-TRAIN|{draw_name}|fold{fold_idx}|{hp_tag}",
        batch_size=bs,
    )

    feat_mode_per_draw = "both" if draw_name == "spiral" else "resnet"

    # Fine-tune tile model (same as original settings)
    model = make_backbone_model(
        mode=feat_mode_per_draw,
        hidden1=128,
        hidden2=None,
        dropout=0.0,
        freeze_backbones=False,
        resnet_pretrained=True,
        pvt_pretrained=(feat_mode_per_draw == "both"),  # enable PVT weights for spiral
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)
    crit = nn.CrossEntropyLoss(label_smoothing=ls)

    for epoch in range(1, n_epochs + 1):
        model.train()
        corr = tot = 0
        for xb, yb in dl_tr:
            xb = xb.to(device, non_blocking=True)
            yb = torch.as_tensor(yb, device=device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            corr += (logits.argmax(1) == yb).sum().item()
            tot += yb.numel()
        print(
            f"[{draw_name:7s}][Fold {fold_idx:02d}][Ep {epoch:02d}] "
            f"train_tile_acc={corr/max(1,tot):.3f} | lr={opt.param_groups[0]['lr']:.6f}"
        )
        sched.step()

    # Save backbone (optional)
    (OUTDIR / draw_name).mkdir(parents=True, exist_ok=True)
    torch.save(
        model.state_dict(), OUTDIR / draw_name / f"WINS_{draw_name}_fold{fold_idx}.pth"
    )

    # ---- Extract per-image features
    Xtr, ytr, _sid_tr, _keys_tr = collect_image_features(model, train_items, device)
    Xte, yte, sid_te, _keys_te = collect_image_features(model, test_items, device)
    print(
        f"[{draw_name:7s}][Fold {fold_idx:02d}] Feature dim: {Xtr.shape[1]}  (Xtr={Xtr.shape}, Xte={Xte.shape})"
    )

    # ---- Fit WINNER classifier (fixed params)
    # Fit an unfitted copy so the shared module-level template is never mutated
    # and no fitted state is carried from one fold to the next.
    clf = clone(WINNERS[draw_name])
    clf.fit(Xtr, ytr)
    yhat = clf.predict(Xte)

    # ---- Per-draw image metrics
    tp = int(((yhat == 1) & (yte == 1)).sum())
    fp = int(((yhat == 1) & (yte == 0)).sum())
    tn = int(((yhat == 0) & (yte == 0)).sum())
    fn = int(((yhat == 0) & (yte == 1)).sum())
    acc = (tp + tn) / max(1, (tp + fp + tn + fn))
    P, R, F1 = _prf(tp, fp, fn)
    print(
        f"[{draw_name:7s}][Fold {fold_idx:02d}] WINNER TEST (image): acc={acc:.3f}  prec={P:.3f}  recall={R:.3f}  f1={F1:.3f}  "
        f"[tp={tp} fp={fp} tn={tn} fn={fn}]"
    )

    # package predictions per image with sid for subject voting
    per_image = [
        (sid, int(p), int(t)) for sid, p, t in zip(sid_te, yhat.tolist(), yte.tolist())
    ]
    return per_image, {"tp": tp, "fp": fp, "tn": tn, "fn": fn}


# =========================
# 5-fold driver (subject-wise) with WINNERS
# =========================
def _prf(tp: int, fp: int, fn: int):
    """Return ``(precision, recall, f1)`` from confusion counts.

    Any ratio whose denominator is not positive is reported as ``0.0``.
    """

    def _ratio(num, den):
        return num / den if den > 0 else 0.0

    prec = _ratio(tp, tp + fp)
    rec = _ratio(tp, tp + fn)
    if (prec + rec) > 0:
        f1 = 2 * prec * rec / (prec + rec)
    else:
        f1 = 0.0
    return prec, rec, f1


def _tally_confusion(counts, pred, true):
    """Increment the cell of ``counts`` (tp/fp/tn/fn) matching one (pred, true) pair."""
    if pred == 1 and true == 1:
        counts["tp"] += 1
    elif pred == 1 and true == 0:
        counts["fp"] += 1
    elif pred == 0 and true == 0:
        counts["tn"] += 1
    else:
        counts["fn"] += 1


def _majority_of_binary(votes):
    """Return 1 when non-zero votes strictly outnumber zero votes, else 0 (ties → 0)."""
    zeros = votes.count(0)
    return 1 if (len(votes) - zeros) > zeros else 0


def _fmt_confusion(c):
    """Format acc/prec/recall/f1 and the raw counts of a tp/fp/tn/fn dict for printing."""
    P, R, F1 = _prf(c["tp"], c["fp"], c["fn"])
    acc = (c["tp"] + c["tn"]) / max(1, (c["tp"] + c["fp"] + c["tn"] + c["fn"]))
    return (
        f"acc={acc:.3f}  prec={P:.3f}  recall={R:.3f}  f1={F1:.3f}  "
        f"[tp={c['tp']} fp={c['fp']} tn={c['tn']} fn={c['fn']}]"
    )


def run_subjectwise_5fold_winners(items_all, device, seed=42, feat_mode=FEAT_MODE):
    """Run subject-wise stratified 5-fold CV with the fixed WINNER classifier per draw.

    For every fold and every draw type the tile model is trained and evaluated by
    ``train_one_draw_and_eval``. Image predictions are then combined per test
    subject in two ways: "flat" (majority over all of the subject's images) and
    "group" (majority inside each draw, then majority across draws). Per-fold and
    final macro/micro reports are printed.

    Args:
        items_all: ``(file_path, label, draw, subject_id)`` tuples for the whole dataset.
        device: Torch device forwarded to training.
        seed: Random state of the subject split. Only the split uses it; model and
            loader seeding still derive from the module-level ``SEED``.
        feat_mode: Forwarded to ``train_one_draw_and_eval``.

    Returns:
        Dict with the pooled confusion counts and per-fold accuracies:
        ``per_draw_counts``, ``per_draw_fold_accs``, ``subject_flat_counts``,
        ``subject_group_counts``, ``subject_flat_fold_accs``,
        ``subject_group_fold_accs``.
    """
    cells = ("tp", "fp", "tn", "fn")

    sid2lab = subject_labels_from_items(items_all)
    subjects = sorted(sid2lab.keys())
    y_subject = np.array([sid2lab[s] for s in subjects], dtype=int)

    # Honour the caller's seed (previously the global SEED was used regardless).
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed % (2**32 - 1))

    # ---------- MICRO accumulators (pooled over folds) ----------
    per_draw_counts_global = {d: dict.fromkeys(cells, 0) for d in DRAW_TYPES}
    subj_flat_counts_global = dict.fromkeys(cells, 0)
    subj_group_counts_global = dict.fromkeys(cells, 0)
    subj_flat_tp = subj_flat_tot = 0
    subj_group_tp = subj_group_tot = 0

    # ---------- MACRO accumulators (one accuracy per fold) ----------
    per_draw_fold_accs = {d: [] for d in DRAW_TYPES}
    subj_flat_macro_accs = []
    subj_group_macro_accs = []

    for fidx, (tr_idx, te_idx) in enumerate(skf.split(subjects, y_subject), 1):
        print("\n" + "#" * 80)
        print(
            f"############## SUBJECT-WISE CV — FOLD {fidx:02d} (WINNERS) ##############"
        )
        print("#" * 80 + "\n")

        train_sids = {subjects[i] for i in tr_idx}
        test_sids = {subjects[i] for i in te_idx}
        assert train_sids.isdisjoint(test_sids)

        fold_preds_by_draw = {d: [] for d in DRAW_TYPES}  # (sid, pred, true)

        # ----- Train/eval each draw with its fixed best classifier
        for d in DRAW_TYPES:
            tr_items = [it for it in items_all if it[2] == d and it[3] in train_sids]
            te_items = [it for it in items_all if it[2] == d and it[3] in test_sids]
            if len({lab for _fp, lab, _dr, _sid in tr_items}) < 2 or len(te_items) == 0:
                print(
                    f"[{d:7s}][Fold {fidx:02d}] Skipped (class collapse or empty test)."
                )
                continue

            preds, counts = train_one_draw_and_eval(
                tr_items, te_items, d, fidx, device, feat_mode=feat_mode
            )
            fold_preds_by_draw[d].extend(preds)

            # ---- MICRO accumulation
            for cell in cells:
                per_draw_counts_global[d][cell] += counts[cell]

            # ---- MACRO: per-fold image accuracy for this draw
            _tot = counts["tp"] + counts["fp"] + counts["tn"] + counts["fn"]
            if _tot > 0:
                per_draw_fold_accs[d].append((counts["tp"] + counts["tn"]) / _tot)

        # ----- Subject-level voting (flat + group)
        preds_by_sid_all = defaultdict(list)
        preds_by_sid_by_draw = defaultdict(lambda: defaultdict(list))
        for d in DRAW_TYPES:
            for sid, pred, _true in fold_preds_by_draw[d]:
                preds_by_sid_all[sid].append(pred)
                preds_by_sid_by_draw[sid][d].append(pred)

        flat_counts = dict.fromkeys(cells, 0)
        group_counts = dict.fromkeys(cells, 0)

        for sid in sorted(test_sids):
            true_lab = sid2lab[sid]

            # flat-majority across all images (all draws pooled)
            pooled = preds_by_sid_all[sid]
            if pooled:
                flat_pred = _majority_of_binary(pooled)
                subj_flat_tp += int(flat_pred == true_lab)
                subj_flat_tot += 1
                _tally_confusion(flat_counts, flat_pred, true_lab)

            # group-majority (majority inside each draw → majority across draws)
            draw_votes = [
                _majority_of_binary(preds_by_sid_by_draw[sid][d])
                for d in DRAW_TYPES
                if preds_by_sid_by_draw[sid][d]
            ]
            if draw_votes:
                group_pred = _majority_of_binary(draw_votes)
                subj_group_tp += int(group_pred == true_lab)
                subj_group_tot += 1
                _tally_confusion(group_counts, group_pred, true_lab)

        # Pool this fold's subject confusion counts. These totals were never
        # updated before, so the final subject-level PRF report printed zeros.
        for cell in cells:
            subj_flat_counts_global[cell] += flat_counts[cell]
            subj_group_counts_global[cell] += group_counts[cell]

        # ----- Per-fold subject-level report
        print(
            "\n--- SUBJECT-LEVEL (Flat) METRICS (Fold {:02d}) — WINNERS ---".format(
                fidx
            )
        )
        print(_fmt_confusion(flat_counts))

        print(
            "\n--- SUBJECT-LEVEL (Group) METRICS (Fold {:02d}) — WINNERS ---".format(
                fidx
            )
        )
        print(_fmt_confusion(group_counts))

        # ----- MACRO: record this fold's subject accuracies
        flat_den = sum(flat_counts.values())
        group_den = sum(group_counts.values())
        if flat_den > 0:
            subj_flat_macro_accs.append(
                (flat_counts["tp"] + flat_counts["tn"]) / flat_den
            )
        if group_den > 0:
            subj_group_macro_accs.append(
                (group_counts["tp"] + group_counts["tn"]) / group_den
            )

    # ================= FINAL REPORTS =================
    # ---- PER-DRAW IMAGE ACCURACIES — MACRO
    print("\n=== PER-DRAW IMAGE ACCURACIES — WINNERS (MACRO over folds) ===")
    for d in DRAW_TYPES:
        if per_draw_fold_accs[d]:
            macro_acc = float(np.mean(per_draw_fold_accs[d]))
            print(
                f"{d:8s}: macro_img_acc={macro_acc:.4f}  (from {len(per_draw_fold_accs[d])} folds)"
            )
        else:
            print(f"{d:8s}: n/a (no valid folds)")

    # ---- PER-DRAW IMAGE ACCURACIES — MICRO (pooled)
    print("\n=== PER-DRAW IMAGE ACCURACIES — WINNERS (MICRO pooled) ===")
    for d in DRAW_TYPES:
        c = per_draw_counts_global[d]
        correct = c["tp"] + c["tn"]
        total = sum(c.values())
        if total == 0:
            print(f"{d:8s}: n/a (no valid folds)")
        else:
            print(f"{d:8s}: micro_img_acc={correct/total:.4f}  [{correct}/{total}]")

    # ---- PER-DRAW IMAGE PRF — MICRO (pooled)
    print("\n=== PER-DRAW IMAGE PRF — WINNERS (MICRO pooled) ===")
    for d in DRAW_TYPES:
        c = per_draw_counts_global[d]
        if sum(c.values()) == 0:
            print(f"{d:8s}: n/a")
            continue
        print(f"{d:8s}: " + _fmt_confusion(c))

    # ---- SUBJECT-LEVEL ACCURACIES — MACRO
    print("\n=== SUBJECT-LEVEL ACCURACIES — WINNERS (MACRO over folds) ===")
    if subj_flat_macro_accs:
        flat_macro = float(np.mean(subj_flat_macro_accs))
        print(f"Flat-majority (macro):   {flat_macro:.4f}")
    else:
        print("Flat-majority (macro):   n/a")
    if subj_group_macro_accs:
        group_macro = float(np.mean(subj_group_macro_accs))
        print(f"Group-majority (macro):  {group_macro:.4f}")
    else:
        print("Group-majority (macro):  n/a")

    # ---- SUBJECT-LEVEL ACCURACIES — MICRO (pooled)
    flat_acc = subj_flat_tp / max(1, subj_flat_tot)
    group_acc = subj_group_tp / max(1, subj_group_tot)
    print("\n=== SUBJECT-LEVEL ACCURACIES — WINNERS (MICRO pooled) ===")
    print(f"Flat-majority (micro):   {flat_acc:.4f}  [{subj_flat_tp}/{subj_flat_tot}]")
    print(
        f"Group-majority (micro):  {group_acc:.4f}  [{subj_group_tp}/{subj_group_tot}]"
    )

    # ---- SUBJECT-LEVEL PRF — MICRO
    print("\n=== SUBJECT-LEVEL PRF — WINNERS (MICRO pooled) ===")
    print("Flat-majority:  " + _fmt_confusion(subj_flat_counts_global))
    print("Group-majority: " + _fmt_confusion(subj_group_counts_global))

    # Hand the aggregates back so callers can use them (previously returned None).
    return {
        "per_draw_counts": per_draw_counts_global,
        "per_draw_fold_accs": per_draw_fold_accs,
        "subject_flat_counts": subj_flat_counts_global,
        "subject_group_counts": subj_group_counts_global,
        "subject_flat_fold_accs": subj_flat_macro_accs,
        "subject_group_fold_accs": subj_group_macro_accs,
    }


# =========================
# Kickoff
# =========================
# Entry point: runs only when the file is executed as the main module, not on import.
if __name__ == "__main__":
    # Make sure the checkpoint root exists before any fold tries to save into it.
    OUTDIR.mkdir(parents=True, exist_ok=True)
    # Index the dataset, then run the full subject-wise 5-fold evaluation.
    items_all = discover_items_tree(DATA_ROOT)
    results = run_subjectwise_5fold_winners(
        items_all, device, seed=SEED, feat_mode=FEAT_MODE
    )