In [None]:
from __future__ import annotations


import os  # [DET]
os.environ.setdefault("CUBLAS_WORKSPACE_CONFIG", ":16:8")  # [DET]

import re
from pathlib import Path
from collections import Counter, defaultdict
from typing import List, Tuple, Dict, Optional

import numpy as np
import random
import hashlib  # [DET]
import warnings
warnings.filterwarnings("ignore", message=".*gpu_hist.*deprecated.*", category=UserWarning)
warnings.filterwarnings("ignore", message='.*Parameters: { "predictor" } are not used.*', category=UserWarning)

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torchvision.models import resnet18, ResNet18_Weights
from PIL import Image

from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier

import timm

# =========================
# Config
# =========================
# Dataset root (expected to hold "Healthy" / "Patient" sub-folders) and the
# directory that receives one checkpoint sub-folder per draw type.
DATA_ROOT = Path(r"DATA/PD_Detection_NewHandPD_Dataset")
OUTDIR = Path("checkpoints_img5cv_per_draw_tiles_WINS")

# Canonical draw-type names; every per-draw table below is keyed by these.
DRAW_TYPES = ["circle", "meander", "spiral"]

# ---- Per-draw fine-tuning hyperparameters
# circle / spiral keep their earlier values; only the meander column was updated.
EPOCHS_FN = {
    "circle": 5,
    "meander": 8,
    "spiral": 3,
}
LR_INIT = {
    "circle": 3.16e-4,
    "meander": 3.16e-4,
    "spiral": 1e-4,
}
BATCH_SIZE = {
    "circle": 64,
    "meander": 16,
    "spiral": 8,
}
LS_EPS = {  # label-smoothing epsilon passed to the cross-entropy loss
    "circle": 0.05,
    "meander": 0.0,
    "spiral": 0.05,
}
WEIGHT_DECAY = {
    "circle": 1e-4,
    "meander": 0.0,
    "spiral": 0.0,
}

# ---- Runtime
NUM_WORKERS = 4
SEED = 42

PRINT_PATHS = False

# ---- Tiling
TILE_GRID = (2, 2)        # 2×2 tiles per image (4 tiles)
TILE_RESIZE_BASE = 448    # resize to 448, then crop 4×(224×224)
FINAL_SIZE = 224

# ---- Backbone mode
FEAT_MODE = "resnet"  # "resnet" | "pvt" | "both"

# ---- Deterministic full-image augmentation repeats
REPEATS_BY_DRAW = {
    "circle": 4,   # fixed rotations over 360°
    "meander": 2,  # deterministic Gaussian noise draws
    "spiral": 2,   # deterministic Gaussian noise draws
}

# ---- Winning classical heads (fixed)
def _l1_nearest_neighbour():
    """Return a new 1-nearest-neighbour classifier (Minkowski metric with p=1)."""
    return KNeighborsClassifier(
        n_neighbors=1,
        weights="uniform",
        p=1,
        metric="minkowski",
        leaf_size=15,
    )


# One classical head per draw type, keyed like DRAW_TYPES:
#   circle  -> standardised features + 1-NN
#   meander -> fully grown decision tree (seeded with SEED)
#   spiral  -> 1-NN on the raw features
WINNERS = {
    "circle": Pipeline([
        ("scaler", StandardScaler(with_mean=True, with_std=True)),
        ("clf", _l1_nearest_neighbour()),
    ]),
    "meander": Pipeline([
        (
            "clf",
            DecisionTreeClassifier(
                random_state=SEED,
                criterion="gini",
                max_depth=None,
                min_samples_split=2,
                min_samples_leaf=1,
                max_features=None,
                class_weight=None,
            ),
        ),
    ]),
    "spiral": Pipeline([
        ("clf", _l1_nearest_neighbour()),
    ]),
}

# =========================
# Repro / determinism
# =========================
def set_seed(seed: int = 42):
    """Seed Python, NumPy and PyTorch RNGs and put PyTorch into deterministic mode.

    Args:
        seed: Value handed unchanged to every seeding call.
    """
    # --- RNG state ---------------------------------------------------------
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    # --- Backend switches ([DET]) -------------------------------------------
    # No cuDNN autotuning, deterministic cuDNN kernels, and no TF32 math.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.allow_tf32 = False
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.use_deterministic_algorithms(True)
# Apply the global seed and deterministic backend flags as soon as this code runs.
set_seed(SEED)

def _seed_from_key(base_seed: int, *parts: str) -> int:  # [DET]
    """Derive a stable seed from ``base_seed`` plus any number of key parts.

    ``base_seed`` and each part are stringified, encoded, and hashed in order
    (no separator between them) with SHA-256.

    Returns:
        The first 8 digest bytes interpreted as a little-endian unsigned int.
    """
    payload = b"".join(str(piece).encode() for piece in (base_seed, *parts))
    return int.from_bytes(hashlib.sha256(payload).digest()[:8], "little")

def _worker_init_fn(worker_id: int):  # [DET]
    """Seed ``random`` and NumPy inside a DataLoader worker from torch's seed.

    Args:
        worker_id: Index supplied by the DataLoader; not used here.
    """
    # Fold torch's seed into 32 bits so it fits np.random.seed.
    seed32 = torch.initial_seed() % (1 << 32)
    random.seed(seed32)
    np.random.seed(seed32)

def _make_loader(dataset, *, shuffle: bool, num_workers: int, base_seed: int, key: str, batch_size: int):  # [DET]
    """Build a DataLoader whose sampling RNG is derived from ``base_seed`` and ``key``.

    Args:
        dataset: Dataset to iterate over.
        shuffle: Whether the loader shuffles the dataset each epoch.
        num_workers: Number of worker processes.
        base_seed: Base seed mixed into the generator seed.
        key: Free-form tag mixed into the generator seed, so loaders with
            different keys get different shuffling streams.
        batch_size: Samples per batch; the last, smaller batch is kept.

    Returns:
        A ``DataLoader`` using a dedicated CPU generator and ``_worker_init_fn``.
    """
    sampler_rng = torch.Generator(device="cpu")
    sampler_rng.manual_seed(_seed_from_key(base_seed, key))
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # Pinned host memory only speeds up host-to-GPU copies; asking for it
        # on a CPU-only machine is pointless and makes PyTorch emit a warning.
        pin_memory=torch.cuda.is_available(),
        worker_init_fn=_worker_init_fn,
        generator=sampler_rng,
        persistent_workers=False,
        drop_last=False,
    )

def make_fold_seed(*parts: str) -> int:
    """Derive a stable seed from the global ``SEED`` and the given key parts.

    Delegates to ``_seed_from_key`` so the hashing scheme is defined in one
    place; the result is identical to hashing ``str(SEED)`` followed by each
    stringified part with SHA-256 and reading the first 8 digest bytes as a
    little-endian integer.
    """
    return _seed_from_key(SEED, *parts)

def reseed_for_fold(draw: str, fold_idx: int):
    """Re-seed ``random``, NumPy and torch with a seed specific to (draw, fold).

    Args:
        draw: Draw-type name mixed into the seed.
        fold_idx: Fold number mixed into the seed.
    """
    fold_seed = make_fold_seed("fold", draw, str(fold_idx))

    random.seed(fold_seed)
    np.random.seed(fold_seed % (2**32))  # reduced to 32 bits for NumPy
    torch.manual_seed(fold_seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(fold_seed)

# ------------------------------------------------------
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)

# =========================
# Data discovery (DETERMINISTIC)
# =========================
# Subject tag: a dash, the letter H or P (any case), optional zero padding, digits.
# _SUBJ_PAT requires the tag at the very end of the file stem;
# _SUBJ_ANY accepts it anywhere as long as the digits end on a word boundary.
_SUBJ_PAT = re.compile(r"(?i)-([HP])0*(\d+)$")
_SUBJ_ANY = re.compile(r"(?i)-([HP])0*(\d+)\b")


def extract_subject_id(fp: str) -> str:
    """Parse the subject identifier out of an image path.

    Args:
        fp: Path to an image file; only the stem (name without the final
            suffix) is inspected.

    Returns:
        The upper-cased group letter followed by the number without leading
        zeros, e.g. ``"H1"`` or ``"P12"``.

    Raises:
        ValueError: If neither pattern matches. Parsing is deliberately strict
            even though the image-wise evaluation does not use the subject.
    """
    stem = Path(fp).stem
    for pattern in (_SUBJ_PAT, _SUBJ_ANY):  # end-anchored form takes priority
        found = pattern.search(stem)
        if found is not None:
            letter, digits = found.groups()
            return f"{letter.upper()}{int(digits)}"
    raise ValueError(f"[subject_id] Could not parse subject from: {fp}")

# Canonical draw type -> substrings that identify it inside a folder name.
_DRAW_MAP = {
    "circle": ["circle", "circ"],
    "meander": ["meander", "meand", "wave", "wav"],
    "spiral": ["spiral", "spir", "sp"],
}


def _norm_draw_folder(name: str):
    """Map a folder name to its canonical draw type, or ``None`` if unrecognised.

    Matching is case-insensitive and by substring; draw types are tried in the
    order they appear in ``_DRAW_MAP`` and the first hit wins.
    """
    lowered = name.lower()
    return next(
        (
            canon
            for canon, tokens in _DRAW_MAP.items()
            if any(token in lowered for token in tokens)
        ),
        None,
    )

def discover_items_tree(data_root: Path) -> List[Tuple[str, int, str, str]]:
    """Scan ``data_root`` for images and return them in a deterministic order.

    Looks under ``data_root / "Healthy"`` (label 0) and ``data_root / "Patient"``
    (label 1). Each immediate sub-directory whose name maps to a draw type is
    searched recursively for image files. A missing label folder only produces
    a warning line.

    Args:
        data_root: Root directory of the dataset.

    Returns:
        A list of ``(resolved_path, label, draw_type, subject_id)`` tuples,
        de-duplicated by resolved path and sorted by label, draw type,
        subject id and path.

    Raises:
        ValueError: Propagated from ``extract_subject_id`` when a file name
            carries no parsable subject tag.
    """
    image_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff"}
    found: List[Tuple[str, int, str, str]] = []
    visited: set[str] = set()

    for folder_name, label in (("Healthy", 0), ("Patient", 1)):
        label_dir = data_root / folder_name
        if not label_dir.exists():
            print(f"[warn] Missing folder: {label_dir}")
            continue

        # Case-insensitive name order keeps the traversal independent of the OS listing order.
        for draw_dir in sorted(label_dir.iterdir(), key=lambda p: p.name.lower()):
            if not draw_dir.is_dir():
                continue
            draw = _norm_draw_folder(draw_dir.name)
            if draw is None:
                continue

            candidates = [
                p for p in draw_dir.rglob("*")
                if p.is_file() and p.suffix.lower() in image_exts
            ]
            candidates.sort(key=lambda p: str(p).lower())

            for img_path in candidates:
                resolved = str(img_path.resolve())
                if resolved in visited:
                    continue
                visited.add(resolved)
                found.append((resolved, label, draw, extract_subject_id(resolved)))

    found.sort(key=lambda rec: (rec[1], rec[2], rec[3], rec[0]))

    print(f"[info] Found {len(found)} images")
    if found:
        print("  by label    :", Counter(rec[1] for rec in found))
        print("  by draw type:", Counter(rec[2] for rec in found))
        print("  subjects    :", len({rec[3] for rec in found}))
    return found

# =========================
# Transforms / Tiling (per-image max only; NO mean/std)
# =========================
class PerImageMaxNormalize:
    """Callable transform that divides a tensor by its own maximum.

    The maximum is taken over dims 1 and 2 separately for every index of
    dim 0 (presumably channels of a C×H×W image — confirm against callers).
    A small epsilon keeps all-zero slices from dividing by zero.
    """

    def __call__(self, img: torch.Tensor):
        peak = img.amax(dim=(1, 2), keepdim=True)
        return img / (peak + 1e-8)

# Transform instances built once at module level and shared by all callers below.
_RESIZE_448 = transforms.Resize(
    (TILE_RESIZE_BASE, TILE_RESIZE_BASE),
    interpolation=InterpolationMode.BILINEAR,
)
_TO_TENSOR = transforms.ToTensor()
_TO_PIL = transforms.ToPILImage()
_NORM_PERIMG = PerImageMaxNormalize()


def _resize_to_base(pil_img: Image.Image) -> Image.Image:
    """Resize ``pil_img`` to ``TILE_RESIZE_BASE`` × ``TILE_RESIZE_BASE`` (bilinear)."""
    resized = _RESIZE_448(pil_img)
    return resized

def _crop_tile_from_resized(pil_img_448: Image.Image, tile_index: int, rows: int = 2, cols: int = 2) -> Image.Image:
    """Cut one tile out of an image divided into a ``rows`` × ``cols`` grid.

    Tiles are numbered row-major, starting at 0 in the top-left corner. Tile
    width/height use integer division of the image size.

    Args:
        pil_img_448: Image to crop (must expose ``.size`` and ``.crop``).
        tile_index: Row-major index of the wanted tile.
        rows: Number of grid rows.
        cols: Number of grid columns.

    Returns:
        Whatever ``pil_img_448.crop`` returns for the tile's bounding box.
    """
    width, height = pil_img_448.size
    tile_w = width // cols
    tile_h = height // rows
    row, col = divmod(tile_index, cols)
    x0 = col * tile_w
    y0 = row * tile_h
    return pil_img_448.crop((x0, y0, x0 + tile_w, y0 + tile_h))

def _tile_to_tensor(pil_tile: Image.Image) -> torch.Tensor:
    """Convert a PIL tile to a tensor and apply the per-image max normalisation."""
    # ToTensor yields values in [0, 1]; the normaliser then rescales by the tile's own maximum.
    return _NORM_PERIMG(_TO_TENSOR(pil_tile))

# =========================
# Deterministic full-image augmentation (applied BEFORE tiling)
# =========================
def _hash_int(s: str) -> int:
    """Hash ``s`` with SHA-256 and return the first 8 digest bytes as a little-endian int."""
    digest = hashlib.sha256(s.encode()).digest()
    return int.from_bytes(digest[:8], "little")

def _angle_for_repeat(total_repeats: int, repeat_idx: int, min_angle: float = 0.0, max_angle: float = 360.0) -> float:
    """Return the rotation angle assigned to one augmentation repeat.

    The range ``[min_angle, max_angle)`` is split into ``total_repeats`` equal
    steps and ``repeat_idx`` (taken modulo ``total_repeats``) selects one of
    them. With at most one repeat the angle is always ``0.0``.
    """
    if total_repeats > 1:
        step = (max_angle - min_angle) / total_repeats
        slot = repeat_idx % total_repeats
        return min_angle + step * slot
    return 0.0

def augment_full_image_deterministic(pil_img: Image.Image, draw: str, fp: str, repeat_idx: int) -> Image.Image:
    """Resize a full image and apply the draw-specific, reproducible augmentation.

    * ``"circle"``: rotate by the angle assigned to ``repeat_idx`` (white fill,
      canvas size unchanged).
    * any other draw with more than one configured repeat: add Gaussian noise
      (std 0.003 on the [0, 1] scale) from a generator seeded by a hash of
      ``SEED``, the draw, the file path and the repeat index.
    * otherwise: return the resized image unchanged.

    Args:
        pil_img: Source image.
        draw: Draw-type name.
        fp: File path of the image; only used as part of the noise seed key.
        repeat_idx: Index of the augmentation repeat.

    Returns:
        The augmented image at the base tiling resolution.
    """
    base = _resize_to_base(pil_img)

    if draw == "circle":
        n_rotations = REPEATS_BY_DRAW.get("circle", 1)
        return transforms.functional.rotate(
            base,
            _angle_for_repeat(n_rotations, repeat_idx),
            interpolation=InterpolationMode.BILINEAR,
            expand=False,
            fill=255,
        )

    if REPEATS_BY_DRAW.get(draw, 1) <= 1:
        return base

    # Same (seed, draw, path, repeat) -> same generator state -> same noise.
    rng = torch.Generator(device="cpu")
    rng.manual_seed(_hash_int(f"{SEED}|noise|{draw}|{fp}|rep{repeat_idx}"))

    pixels = _TO_TENSOR(base)
    noise_std = 0.003
    jitter = torch.randn(pixels.size(), generator=rng) * noise_std
    noisy = (pixels + jitter).clamp(0.0, 1.0)
    return _TO_PIL(noisy)

# =========================
# Datasets
# =========================
def _prf(tp: int, fp: int, fn: int):
    """Compute precision, recall and F1 from confusion counts.

    Any ratio whose denominator is zero is reported as ``0.0``.

    Returns:
        ``(precision, recall, f1)``.
    """
    predicted_pos = tp + fp
    actual_pos = tp + fn

    prec = tp / predicted_pos if predicted_pos > 0 else 0.0
    rec = tp / actual_pos if actual_pos > 0 else 0.0

    pr_sum = prec + rec
    f1 = (2*prec*rec/pr_sum) if pr_sum > 0 else 0.0
    return prec, rec, f1

class TrainTilesAugFirstDataset(Dataset):
    """Training dataset yielding one augmented tile per (image, repeat, tile).

    Each image is expanded into ``REPEATS_BY_DRAW[draw]`` augmentation repeats
    (at least one) times ``TILE_GRID[0] * TILE_GRID[1]`` tiles. Augmentation is
    applied to the full image first; the tile is cropped afterwards.
    """

    def __init__(self, items):
        """Index every (repeat, tile) combination of the given items.

        Args:
            items: Iterable of ``(path, label, draw, subject_id)`` tuples.
        """
        n_tiles = TILE_GRID[0] * TILE_GRID[1]
        self.records = []  # (fp, lab, draw, sid, repeat_idx, tile_index)
        for fp, lab, d, sid in items:
            repeats = max(1, int(REPEATS_BY_DRAW.get(d, 1)))
            self.records.extend(
                (fp, int(lab), d, sid, rep, tile_idx)
                for rep in range(repeats)
                for tile_idx in range(n_tiles)
            )

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        """Return ``(tile_tensor, label)`` for record ``i``."""
        fp, lab, d, sid, rep, tile_idx = self.records[i]
        # Close the underlying file deterministically; convert() returns an
        # independent in-memory image that stays valid after the file is closed.
        with Image.open(fp) as src:
            pil_img = src.convert("RGB")
        aug_img = augment_full_image_deterministic(pil_img, d, fp, rep)
        pil_tile = _crop_tile_from_resized(aug_img, tile_idx, *TILE_GRID)
        x = _tile_to_tensor(pil_tile)
        return x, lab

class TilesWithMetaDataset(Dataset):
    """Evaluation dataset yielding un-augmented tiles together with metadata.

    Each image contributes ``TILE_GRID[0] * TILE_GRID[1]`` records; no
    augmentation is applied, only the base resize before cropping.
    """

    def __init__(self, items):
        """Index every tile of the given items.

        Args:
            items: Iterable of ``(path, label, draw, subject_id)`` tuples; the
                draw entry is not stored.
        """
        n_tiles = TILE_GRID[0] * TILE_GRID[1]
        self.records = [  # (fp, lab, sid, tile_index)
            (fp, int(lab), sid, tile_idx)
            for fp, lab, _d, sid in items
            for tile_idx in range(n_tiles)
        ]

    def __len__(self):
        return len(self.records)

    def __getitem__(self, i):
        """Return ``(tile_tensor, label, subject_id, path)`` for record ``i``."""
        fp, lab, sid, tile_idx = self.records[i]
        # Close the underlying file deterministically; convert() returns an
        # independent in-memory image that stays valid after the file is closed.
        with Image.open(fp) as src:
            pil_img = src.convert("RGB")
        img448 = _resize_to_base(pil_img)
        pil_tile = _crop_tile_from_resized(img448, tile_idx, *TILE_GRID)
        x = _tile_to_tensor(pil_tile)
        return x, lab, sid, fp

# =========================
# Backbone model (feature extractor)
# =========================
class ResNetPVTRep(nn.Module):
    """Tile classifier: ResNet-18 and/or PVT-v2-b1 features fed into an MLP head.

    ``mode`` selects which backbone(s) are built. Their pooled feature vectors
    are concatenated (when both are present) and passed through a small MLP
    that ends in a 2-way linear layer.
    """

    def __init__(
        self,
        mode: str = "both",
        *,
        resnet_pretrained: bool = True,
        pvt_pretrained: bool = False,
        hidden1: int = 512,
        hidden2: Optional[int] = 512,
        dropout: float = 0.0,
        freeze_backbones: bool = False,
    ):
        """Build the selected backbones and the classification head.

        Args:
            mode: ``"resnet"``, ``"pvt"`` or ``"both"``.
            resnet_pretrained: Load ImageNet weights into the ResNet-18.
            pvt_pretrained: Ask timm for pretrained PVT weights.
            hidden1: Width of the first hidden layer of the head.
            hidden2: Width of an optional second hidden layer; ``None`` omits it.
            dropout: Dropout probability after the first hidden layer (0 disables).
            freeze_backbones: If true, backbone parameters get ``requires_grad=False``.

        Raises:
            ValueError: If ``mode`` is not one of the supported values.
        """
        super().__init__()
        # Explicit raise instead of ``assert``: assertions vanish under ``python -O``,
        # which would let an invalid mode build a head with zero input features.
        if mode not in {"resnet", "pvt", "both"}:
            raise ValueError(f"mode must be 'resnet', 'pvt' or 'both', got {mode!r}")
        self.mode = mode

        # NOTE: backbones are created before the head, in this order, so that
        # random initialisation consumes the RNG stream in a fixed sequence.
        self.resnet = None
        res_dim = 0
        if mode in {"resnet", "both"}:
            self.resnet = resnet18(
                weights=ResNet18_Weights.IMAGENET1K_V1 if resnet_pretrained else None
            )
            res_dim = self.resnet.fc.in_features  # 512
            self.resnet.fc = nn.Identity()  # expose pooled features instead of ImageNet logits

        self.pvt = None
        pvt_dim = 0
        if mode in {"pvt", "both"}:
            self.pvt = timm.create_model(
                "pvt_v2_b1", pretrained=pvt_pretrained, num_classes=0, global_pool="avg",
            )
            pvt_dim = getattr(self.pvt, "num_features", 512)

        fused_dim = res_dim + pvt_dim
        layers = [nn.Linear(fused_dim, hidden1), nn.ReLU(inplace=True)]
        if dropout and dropout > 0:
            layers.append(nn.Dropout(dropout))
        if hidden2 is None:
            layers.append(nn.Linear(hidden1, 2))
        else:
            layers.extend([nn.Linear(hidden1, hidden2), nn.ReLU(inplace=True), nn.Linear(hidden2, 2)])
        self.head = nn.Sequential(*layers)

        if freeze_backbones:
            for backbone in (self.resnet, self.pvt):
                if backbone is None:
                    continue
                for p in backbone.parameters():
                    p.requires_grad = False

    def _backbone_features(self, x):
        """Run every active backbone on ``x`` and concatenate the outputs along dim 1."""
        feats = [net(x) for net in (self.resnet, self.pvt) if net is not None]
        return feats[0] if len(feats) == 1 else torch.cat(feats, dim=1)

    def forward(self, x):
        """Return the 2-way logits for a batch of tiles."""
        return self.head(self._backbone_features(x))

    @torch.no_grad()
    def tile_features(self, x):
        """Return the (concatenated) backbone features without tracking gradients."""
        return self._backbone_features(x)

def make_backbone_model(
    mode: str = "both",
    *,
    hidden1: int = 128,
    hidden2: Optional[int] = None,
    dropout: float = 0.0,
    freeze_backbones: bool = False,
    resnet_pretrained: bool = True,
    pvt_pretrained: bool = True,
) -> nn.Module:
    """Construct a ``ResNetPVTRep`` using this factory's own defaults.

    All arguments are forwarded unchanged. The defaults here (128-unit single
    hidden layer, pretrained PVT) differ from those of the class constructor.
    """
    options = {
        "resnet_pretrained": resnet_pretrained,
        "pvt_pretrained": pvt_pretrained,
        "hidden1": hidden1,
        "hidden2": hidden2,
        "dropout": dropout,
        "freeze_backbones": freeze_backbones,
    }
    return ResNetPVTRep(mode=mode, **options)

# =========================
# Image-level feature extraction (mean over tiles)
# =========================
@torch.no_grad()
def collect_image_features(model: ResNetPVTRep, items: List[Tuple[str,int,str,str]], device):
    """Compute one feature vector per image by averaging its tile features.

    Args:
        model: Network providing ``tile_features``; it is switched to eval mode.
        items: ``(path, label, draw, subject_id)`` tuples to featurise.
        device: Device the tile batches are moved to.

    Returns:
        ``(X, y, key_list)`` where ``X`` stacks the per-image mean features
        (``np.empty((0, 1))`` when there are no images), ``y`` holds the int64
        labels and ``key_list`` the image paths, all ordered by sorted path.
    """
    dataset = TilesWithMetaDataset(items)
    loader = _make_loader(
        dataset, batch_size=32, shuffle=False, num_workers=NUM_WORKERS, base_seed=SEED, key="FEATS",
    )
    model.eval()

    tiles_by_image: Dict[str, List[np.ndarray]] = defaultdict(list)
    label_by_image: Dict[str, int] = {}
    for tiles, labels, _subjects, paths in loader:
        tiles = tiles.to(device, non_blocking=True)
        embeddings = model.tile_features(tiles).detach().cpu().numpy()
        # The image path is the grouping key for tiles of the same image.
        for vec, lab, path in zip(embeddings, labels.numpy(), paths):
            tiles_by_image[path].append(vec)
            label_by_image[path] = int(lab)

    key_list = sorted(tiles_by_image)
    # [num_tiles, D] per image, reduced to D by the mean over tiles.
    X_list = [np.stack(tiles_by_image[k], axis=0).mean(axis=0) for k in key_list]
    y_list = [label_by_image[k] for k in key_list]

    X = np.stack(X_list, axis=0) if X_list else np.empty((0, 1))
    y = np.asarray(y_list, dtype=np.int64)
    return X, y, key_list

# =========================
# One-draw train & eval (image-wise)
# =========================
def train_one_draw_and_eval_imagewise(
    train_items: List[Tuple[str,int,str,str]],
    test_items:  List[Tuple[str,int,str,str]],
    draw_name: str, fold_idx: int, device, feat_mode=FEAT_MODE
):
    """Fine-tune the tile model for one draw/fold and score its classical head.

    Steps: reseed for (draw, fold); fine-tune the backbone on augmented tiles;
    save the weights under ``OUTDIR / draw_name``; extract per-image features
    (mean over tiles) for train and test; fit a fresh copy of the draw's
    ``WINNERS`` pipeline on the train features; evaluate on the test images.

    Args:
        train_items: ``(path, label, draw, subject_id)`` tuples used for training.
        test_items: Tuples of the same shape used for evaluation.
        draw_name: Draw type; selects hyperparameters, backbone and head.
        fold_idx: Fold number, used for seeding, logging and the checkpoint name.
        device: Torch device for training and feature extraction.
        feat_mode: Accepted for interface compatibility but NOT consulted; the
            backbone is chosen per draw ("both" for spiral, "resnet" otherwise).

    Returns:
        Dict with the image-level confusion counts ``tp``, ``fp``, ``tn``,
        ``fn`` and the accuracy ``acc``.
    """
    # Local import: only this function needs it (sklearn is already a file dependency).
    from sklearn.base import clone

    reseed_for_fold(draw_name, fold_idx)

    ep = EPOCHS_FN[draw_name]
    lr = LR_INIT[draw_name]
    bs = BATCH_SIZE[draw_name]
    ls = LS_EPS[draw_name]
    wd = WEIGHT_DECAY[draw_name]
    # The hyperparameter tag is part of the loader key, so the shuffle stream
    # changes whenever the hyperparameters do.
    hp_tag = f"ep{ep}_lr{lr:.4g}_bs{bs}_ls{ls}_wd{wd}"

    dl_tr = _make_loader(
        TrainTilesAugFirstDataset(train_items),
        shuffle=True, num_workers=NUM_WORKERS, base_seed=SEED,
        key=f"TILE-TRAIN|{draw_name}|fold{fold_idx}|{hp_tag}", batch_size=bs
    )

    # Backbone
    feat_mode_per_draw = "both" if draw_name == "spiral" else "resnet"

    # Fine-tune tile model (same as original settings)
    model = make_backbone_model(
        mode=feat_mode_per_draw,
        hidden1=128,
        hidden2=None,
        dropout=0.0,
        freeze_backbones=False,
        resnet_pretrained=True,
        pvt_pretrained=(feat_mode_per_draw == "both"),  # enable PVT weights for spiral
    ).to(device)

    opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    sched = torch.optim.lr_scheduler.StepLR(opt, step_size=7, gamma=0.1)
    crit = nn.CrossEntropyLoss(label_smoothing=ls)

    for ep_i in range(1, ep + 1):
        model.train()
        corr = tot = 0
        for xb, yb in dl_tr:
            xb, yb = xb.to(device, non_blocking=True), torch.as_tensor(yb, device=device)
            opt.zero_grad(set_to_none=True)
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            corr += (logits.argmax(1) == yb).sum().item()
            tot  += yb.numel()
        print(f"[{draw_name:7s}][Fold {fold_idx:02d}][Ep {ep_i:02d}] train_tile_acc={corr/max(1,tot):.3f} | lr={opt.param_groups[0]['lr']:.6f}")
        sched.step()

    # Save backbone
    (OUTDIR / draw_name).mkdir(parents=True, exist_ok=True)
    torch.save(model.state_dict(), OUTDIR / draw_name / f"WINSIMG_{draw_name}_fold{fold_idx}.pth")

    # Image features
    Xtr, ytr, _keys_tr = collect_image_features(model, train_items, device)
    Xte, yte,  _keys_te = collect_image_features(model, test_items,  device)
    print(f"[{draw_name:7s}][Fold {fold_idx:02d}] Feature dim: {Xtr.shape[1]}  (Xtr={Xtr.shape}, Xte={Xte.shape})")

    # Winner head: fit an unfitted copy so the shared module-level template in
    # WINNERS is never mutated and no fitted state carries over between folds.
    clf = clone(WINNERS[draw_name])
    clf.fit(Xtr, ytr)
    yhat = clf.predict(Xte)

    # Image-level metrics
    tp = int(((yhat == 1) & (yte == 1)).sum())
    fp = int(((yhat == 1) & (yte == 0)).sum())
    tn = int(((yhat == 0) & (yte == 0)).sum())
    fn = int(((yhat == 0) & (yte == 1)).sum())
    acc = (tp + tn) / max(1, (tp + fp + tn + fn))
    P, R, F1 = _prf(tp, fp, fn)  # zero-safe precision / recall / F1

    print(f"[{draw_name:7s}][Fold {fold_idx:02d}] IMAGE TEST: acc={acc:.3f}  prec={P:.3f}  recall={R:.3f}  f1={F1:.3f}  [tp={tp} fp={fp} tn={tn} fn={fn}]")
    return {"tp": tp, "fp": fp, "tn": tn, "fn": fn, "acc": acc}

# =========================
# 5-fold driver — IMAGE-WISE (leaky-by-subject on purpose)
# =========================
def run_imagewise_5fold_winners(items_all, device, seed=SEED, feat_mode=FEAT_MODE):
    """Run stratified 5-fold, image-wise cross-validation for every draw type.

    Folds are stratified by label only; subjects are ignored, so images of one
    subject may land in both train and test (leaky by subject on purpose).

    Args:
        items_all: ``(path, label, draw, subject_id)`` tuples for all draws.
        device: Torch device passed on to training.
        seed: Seed for the fold splitter.
        feat_mode: Forwarded to the per-fold training routine.

    Returns:
        ``(per_draw_fold_accs, per_draw_counts)``: per draw, the list of fold
        accuracies and the confusion counts pooled over folds.
    """
    print("\n=== IMAGE-WISE 5-FOLD CV (SUBJECTS IGNORED) ===")

    count_keys = ("tp", "fp", "tn", "fn")
    per_draw_fold_accs = {d: [] for d in DRAW_TYPES}                      # macro (mean across folds)
    per_draw_counts = {d: dict.fromkeys(count_keys, 0) for d in DRAW_TYPES}  # micro pooled

    banner = "#" * 80
    for d in DRAW_TYPES:
        # Image-wise split: take every item of this draw.
        draw_items = [item for item in items_all if item[2] == d]
        if not draw_items:
            print(f"[{d}] No items, skipping.")
            continue

        # Stratify by label, NOT by subject.
        y_img = np.array([item[1] for item in draw_items], dtype=int)
        skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed % (2**32 - 1))

        print("\n" + banner)
        print(f"############## IMAGE-WISE CV — DRAW {d.upper()} ##############")
        print(banner + "\n")

        fold_accs = per_draw_fold_accs[d]
        pooled = per_draw_counts[d]
        splits = skf.split(np.arange(len(draw_items)), y_img)
        for fidx, (tr_idx, te_idx) in enumerate(splits, 1):
            tr_items = [draw_items[i] for i in tr_idx]
            te_items = [draw_items[i] for i in te_idx]

            # Both classes must be present in the training fold.
            train_labels = {item[1] for item in tr_items}
            if len(train_labels) < 2:
                print(f"[{d:7s}][Fold {fidx:02d}] Skipped (class collapse).")
                continue

            counts = train_one_draw_and_eval_imagewise(tr_items, te_items, d, fidx, device, feat_mode)
            fold_accs.append(counts["acc"])
            for k in count_keys:
                pooled[k] += counts[k]

        # Per-draw summary: macro accuracy over folds.
        if fold_accs:
            macro_acc = float(np.mean(fold_accs))
            print(f"\n=== {d.upper()} — IMAGE ACCURACY (MACRO over folds) ===")
            print(f"{d:8s}: macro_img_acc={macro_acc:.4f}  (from {len(fold_accs)} folds)")
        else:
            print(f"\n=== {d.upper()} — IMAGE ACCURACY (MACRO) ===")
            print(f"{d:8s}: n/a (no valid folds)")

        # Per-draw summary: metrics on counts pooled over folds.
        c = pooled
        tot = c["tp"] + c["fp"] + c["tn"] + c["fn"]
        if tot <= 0:
            print(f"\n=== {d.upper()} — IMAGE PRF (MICRO) ===")
            print(f"{d:8s}: n/a")
        else:
            P = c["tp"] / max(1, (c["tp"] + c["fp"]))
            R = c["tp"] / max(1, (c["tp"] + c["fn"]))
            F1 = (2*P*R) / max(1e-12, (P+R))
            acc = (c["tp"] + c["tn"]) / tot
            print(f"\n=== {d.upper()} — IMAGE PRF (MICRO pooled) ===")
            print(f"{d:8s}: acc={acc:.3f}  prec={P:.3f}  recall={R:.3f}  f1={F1:.3f}  "
                  f"[tp={c['tp']} fp={c['fp']} tn={c['tn']} fn={c['fn']}]")

    return per_draw_fold_accs, per_draw_counts

# =========================
# Kickoff
# =========================
if __name__ == "__main__":
    # Create the checkpoint root up front; per-draw sub-folders are made during training.
    OUTDIR.mkdir(parents=True, exist_ok=True)
    # Scan the dataset tree once; the CV driver filters this list per draw type.
    items_all = discover_items_tree(DATA_ROOT)
    # Returns (per-draw fold accuracies, per-draw pooled confusion counts);
    # all reporting is printed by the driver itself.
    _macro, _micro = run_imagewise_5fold_winners(items_all, device, seed=SEED, feat_mode=FEAT_MODE)

### _