<a href="https://colab.research.google.com/github/719csy/AI-for-skin-cancer/blob/main/ISIC_Task1%2B3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from google.colab import drive
import os

# Unmount the drive if it's currently mounted
drive.flush_and_unmount()

# Ensure the mountpoint is clean by removing and recreating the directory
if os.path.exists('/content/drive'):
    # os.rmdir('/content/drive')  # This only works for empty directories
    # More robust way to remove a directory and its contents
    !rm -rf /content/drive/*
    !rmdir /content/drive
os.makedirs('/content/drive', exist_ok=True)

# Now, mount the drive with force_remount=True
drive.mount('/content/drive', force_remount=True)

Drive not mounted, so nothing to flush and unmount.
Mounted at /content/drive


In [None]:
import gc, torch
gc.collect()
torch.cuda.empty_cache()


In [None]:
import pandas as pd
df = pd.read_csv(CLS_LABEL_CSV)
print(df.columns.tolist()[:30])
print(df.head(2))

['image', 'MEL', 'NV', 'BCC', 'AKIEC', 'BKL', 'DF', 'VASC']
          image  MEL   NV  BCC  AKIEC  BKL   DF  VASC
0  ISIC_0024306  0.0  1.0  0.0    0.0  0.0  0.0   0.0
1  ISIC_0024307  0.0  1.0  0.0    0.0  0.0  0.0   0.0


In [None]:
import gc, torch
gc.collect()
torch.cuda.empty_cache()

In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# Hybrid CNN–ViT–U-Net (Seg + Cls) with NEW AMP API + resume
# ISIC 2018 Task 1 (Seg) + Task 3 (Cls)
# Physical size: 15mm × 15mm → 256 × 256 px
#
# Speed controls:
# - steps_per_epoch (cap per epoch steps)
# - MC Dropout only every mc_every epochs (otherwise fast MC)
# - validation caps (val_max_batches_seg/cls)
# - workers=2
#
# Resume:
# - load best.pt if exists else latest epoch_XXX.pt
# - continue from ckpt_epoch + 1
# =========================================================

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import time, math, random, gc, re
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt
from scipy import ndimage as ndi

# =========================================================
# Performance settings
# =========================================================
cv2.setNumThreads(0)
torch.backends.cudnn.benchmark = True

# Optional: keep current TF32 style (ok). If you want to silence new TF32 warnings,
# you can REMOVE old flags. We'll only set matmul precision (recommended).
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")

# =========================================================
# Paths
# =========================================================
SEG_IMAGE_DIR = Path("/content/drive/My Drive/isic2018/images")
SEG_MASK_DIR  = Path("/content/drive/My Drive/isic2018/masks")

CLS_IMAGE_DIR = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_Input".strip())
CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv".strip())

CHECKPOINT_DIR = Path("/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net".strip())
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
EXPORT_DIR = CHECKPOINT_DIR.parent / "exports"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# =========================================================
# Physical size definition
# =========================================================
FIELD_SIZE_MM = 15.0
IMAGE_SIZE = 256
MM_PER_PIXEL = FIELD_SIZE_MM / IMAGE_SIZE

# =========================================================
# Utils
# =========================================================
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def save_map_with_mm_axis(path_png: Path, arr: np.ndarray, title: str):
    arr = np.asarray(arr)
    # robust shape handling (avoid imshow crash)
    if arr.ndim == 1:
        if arr.size == IMAGE_SIZE * IMAGE_SIZE:
            arr = arr.reshape(IMAGE_SIZE, IMAGE_SIZE)
        else:
            x_mm = np.linspace(0, FIELD_SIZE_MM, arr.size)
            plt.figure(figsize=(5, 3), dpi=150)
            plt.plot(x_mm, arr)
            plt.xlabel("x (mm)")
            plt.ylabel("value")
            plt.title(title + f" (1D, shape={arr.shape})")
            plt.tight_layout()
            plt.savefig(path_png, bbox_inches="tight")
            plt.close()
            return
    if arr.ndim != 2:
        print(f"[WARN] Skip saving {title}: invalid shape={arr.shape}")
        return

    plt.figure(figsize=(5, 4), dpi=150)
    plt.imshow(arr, extent=(0, FIELD_SIZE_MM, FIELD_SIZE_MM, 0))
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.xlabel("x (mm)")
    plt.ylabel("y (mm)")
    plt.title(title + f" (shape={arr.shape})")
    plt.tight_layout()
    plt.savefig(path_png, bbox_inches="tight")
    plt.close()

def latest_epoch_ckpt(ckpt_dir: Path) -> Path | None:
    # match epoch_XXX.pt
    pts = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not pts:
        return None
    def key(p: Path):
        m = re.search(r"epoch_(\d+)\.pt$", p.name)
        return int(m.group(1)) if m else -1
    pts = sorted(pts, key=key)
    return pts[-1] if key(pts[-1]) >= 0 else None

# =========================================================
# Dataset — Segmentation (ISIC 2018 Task 1)
# =========================================================
class ISICSegDataset(Dataset):
    def __init__(self, image_dir: Path, mask_dir: Path, image_size: int = 256):
        self.image_paths = sorted(image_dir.glob("*.jpg"))
        self.mask_paths = [mask_dir / f"{p.stem}_segmentation.png" for p in self.image_paths]
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask: {mask_path}")

        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

        img = img.astype(np.float32) / 255.0
        mask = (mask > 0).astype(np.float32)

        img = torch.from_numpy(img).permute(2, 0, 1)     # [3,H,W]
        mask = torch.from_numpy(mask).unsqueeze(0)       # [1,H,W]
        return {"image": img, "mask": mask}

# =========================================================
# Dataset — Classification (ISIC 2018 Task 3) supports one-hot CSV
# =========================================================
CLS_MAP_DX = {
    "MEL": 2, "BCC": 2,
    "AKIEC": 1,
    "NV": 0, "BKL": 0, "DF": 0, "VASC": 0
}

class ISICClsDataset(Dataset):
    """
    A) one-hot CSV:
       columns: image, MEL, NV, BCC, AKIEC, BKL, DF, VASC
    B) dx CSV:
       columns: image_id, dx
    """
    def __init__(self, image_dir: Path, label_csv: Path, image_size: int = 256):
        self.image_dir = image_dir
        self.df = pd.read_csv(label_csv)
        self.image_size = image_size

        self.df.columns = self.df.columns.str.strip()
        cols = set(self.df.columns)

        onehot_classes = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
        has_onehot = ("image" in cols) and all(c in cols for c in onehot_classes)
        has_dx = ("image_id" in cols and "dx" in cols)

        if not (has_onehot or has_dx):
            raise ValueError(
                "Unsupported CSV columns. Expect either:\n"
                "A) one-hot: image + MEL,NV,BCC,AKIEC,BKL,DF,VASC\n"
                "or B) dx: image_id + dx\n"
                f"Got columns: {list(self.df.columns)}"
            )

        self.records = []
        if has_onehot:
            for _, row in self.df.iterrows():
                img_id = str(row["image"]).strip()
                malignant = (row["MEL"] == 1) or (row["BCC"] == 1)
                intermediate = (row["AKIEC"] == 1)
                benign = (row["NV"] == 1) or (row["BKL"] == 1) or (row["DF"] == 1) or (row["VASC"] == 1)

                if malignant:
                    label = 2
                elif intermediate:
                    label = 1
                elif benign:
                    label = 0
                else:
                    continue
                self.records.append((img_id, label))
        else:
            for _, row in self.df.iterrows():
                dx = str(row["dx"]).strip()
                if dx not in CLS_MAP_DX:
                    continue
                img_id = str(row["image_id"]).strip()
                self.records.append((img_id, CLS_MAP_DX[dx]))

        if len(self.records) == 0:
            raise ValueError("No valid records parsed from CSV. Check labels/image ids.")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        img_id, label = self.records[idx]
        img_path = self.image_dir / f"{img_id}.jpg"
        img = cv2.imread(str(img_path))
        if img is None:
            img_path2 = self.image_dir / f"{img_id}.png"
            img = cv2.imread(str(img_path2))
            if img is None:
                raise FileNotFoundError(f"Cannot find image for id={img_id}: {img_path} or {img_path2}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)
        return {"image": img, "label": torch.tensor(label, dtype=torch.long)}

# =========================================================
# Model blocks
# =========================================================
class ConvBNAct(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU()
        )
    def forward(self, x): return self.net(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(ConvBNAct(in_ch, out_ch), ConvBNAct(out_ch, out_ch))
    def forward(self, x): return self.net(x)

class ConvEncoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        self.e1 = ConvBlock(in_ch, base_ch)            # H
        self.e2 = ConvBlock(base_ch, base_ch * 2)      # H/2
        self.e3 = ConvBlock(base_ch * 2, base_ch * 4)  # H/4
        self.e4 = ConvBlock(base_ch * 4, base_ch * 8)  # H/8
        self.pool = nn.MaxPool2d(2)
    def forward(self, x):
        f1 = self.e1(x)
        f2 = self.e2(self.pool(f1))
        f3 = self.e3(self.pool(f2))
        f4 = self.e4(self.pool(f3))
        return [f1, f2, f3, f4]

class ViTBottleneck(nn.Module):
    def __init__(self, dim, depth=2, heads=4, drop=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim*4,
            dropout=drop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, depth)
    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # [B,N,C]
        tokens = self.encoder(tokens)
        return tokens.transpose(1, 2).reshape(b, c, h, w)

class UNetDecoder(nn.Module):
    def __init__(self, base_ch=32, out_ch=1, dropout_p=0.10):
        super().__init__()
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, 2)
        self.d3  = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, 2)
        self.d2  = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, 2)
        self.d1  = ConvBlock(base_ch*2, base_ch)
        self.drop = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()
        self.out  = nn.Conv2d(base_ch, out_ch, 1)
    def forward(self, feats):
        f1, f2, f3, f4 = feats
        x = self.up3(f4)
        x = self.d3(torch.cat([x, f3], dim=1))
        x = self.up2(x)
        x = self.d2(torch.cat([x, f2], dim=1))
        x = self.up1(x)
        x = self.d1(torch.cat([x, f1], dim=1))
        x = self.drop(x)
        return self.out(x)

class HybridMTLModel(nn.Module):
    def __init__(self, base_ch=32, num_classes=3, seg_dropout=0.10):
        super().__init__()
        self.encoder = ConvEncoder(3, base_ch)
        self.bottleneck = ViTBottleneck(base_ch*8, depth=2, heads=4, drop=0.1)
        self.decoder = UNetDecoder(base_ch, out_ch=1, dropout_p=seg_dropout)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_ch*8, base_ch*8),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(base_ch*8, num_classes),
        )
    def forward(self, x):
        feats = self.encoder(x)
        feats[-1] = self.bottleneck(feats[-1])
        seg_logit = self.decoder(feats)
        cls_logit = self.cls_head(feats[-1])
        return seg_logit, cls_logit

# =========================================================
# Metrics (seg)
# =========================================================
def mask_to_boundary(mask: np.ndarray) -> np.ndarray:
    mask_u8 = (mask > 0).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    er = cv2.erode(mask_u8, kernel, iterations=1)
    bd = (mask_u8 - er) > 0
    return bd.astype(np.uint8)

def bf_score(pred_mask: np.ndarray, gt_mask: np.ndarray, tol_px: int = 2) -> float:
    pb = mask_to_boundary(pred_mask)
    gb = mask_to_boundary(gt_mask)
    if pb.sum() == 0 and gb.sum() == 0:
        return 1.0
    if pb.sum() == 0 or gb.sum() == 0:
        return 0.0
    dt_g = ndi.distance_transform_edt(1 - gb)
    dt_p = ndi.distance_transform_edt(1 - pb)
    prec = (dt_g[pb.astype(bool)] <= tol_px).mean() if pb.sum() > 0 else 0.0
    rec  = (dt_p[gb.astype(bool)] <= tol_px).mean() if gb.sum() > 0 else 0.0
    return float(2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0

def hd95(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    pb = mask_to_boundary(pred_mask).astype(bool)
    gb = mask_to_boundary(gt_mask).astype(bool)
    if pb.sum() == 0 and gb.sum() == 0:
        return 0.0
    if pb.sum() == 0 or gb.sum() == 0:
        h, w = pred_mask.shape
        return float(math.sqrt(h*h + w*w))
    dt_g = ndi.distance_transform_edt(~gb)
    dt_p = ndi.distance_transform_edt(~pb)
    d1 = dt_g[pb]
    d2 = dt_p[gb]
    all_d = np.concatenate([d1, d2], axis=0)
    return float(np.percentile(all_d, 95))

# =========================================================
# Metrics (cls)
# =========================================================
def softmax_np(logits: np.ndarray) -> np.ndarray:
    x = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / (e.sum(axis=1, keepdims=True) + 1e-12)

def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == labels).astype(np.float32)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        m = (conf > bins[i]) & (conf <= bins[i+1])
        if m.any():
            ece += abs(acc[m].mean() - conf[m].mean()) * m.mean()
    return float(ece)

def nll_score(probs: np.ndarray, labels: np.ndarray) -> float:
    p = probs[np.arange(len(labels)), labels]
    return float((-np.log(p + 1e-12)).mean())

def brier_score(probs: np.ndarray, labels: np.ndarray) -> float:
    n, k = probs.shape
    y = np.zeros((n, k), dtype=np.float32)
    y[np.arange(n), labels] = 1.0
    return float(((probs - y) ** 2).sum(axis=1).mean())

def predictive_entropy(p: np.ndarray, eps=1e-12) -> np.ndarray:
    if p.ndim >= 2 and p.shape[-1] > 1:
        return -(p * np.log(p + eps)).sum(axis=-1)
    return -(p*np.log(p+eps) + (1-p)*np.log(1-p+eps))

def risk_coverage_curve(probs: np.ndarray, labels: np.ndarray, uncertainty: np.ndarray):
    n = len(labels)
    order = np.argsort(uncertainty)
    probs_s = probs[order]
    labels_s = labels[order]
    pred = probs_s.argmax(axis=1)
    err = (pred != labels_s).astype(np.float32)
    coverages = np.linspace(1/n, 1.0, n)
    risks = np.cumsum(err) / (np.arange(n) + 1)
    return coverages, risks

def aurc(coverages: np.ndarray, risks: np.ndarray) -> float:
    # numpy 2.0 prefers trapezoid
    return float(np.trapezoid(risks, coverages))

# =========================================================
# MC Dropout helpers (seg)
# =========================================================
def enable_dropout_only(model: nn.Module):
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()

def mutual_information(mc_probs: np.ndarray, eps=1e-12) -> np.ndarray:
    mean_p = mc_probs.mean(axis=0)
    H_mean = predictive_entropy(mean_p, eps=eps)
    H_each = predictive_entropy(mc_probs, eps=eps)
    return H_mean - H_each.mean(axis=0)

def compute_boundary_prob_from_plesion(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy)
    g = g / (g.max() + 1e-12)
    return g.astype(np.float32)

def transition_width_map_mm(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy) + 1e-12
    width_px = 0.6 / g
    width_mm = width_px * MM_PER_PIXEL
    return np.clip(width_mm, 0.0, 10.0).astype(np.float32)

# =========================================================
# Losses
# =========================================================
bce = nn.BCEWithLogitsLoss()
ce  = nn.CrossEntropyLoss()

def dice_loss_with_logits(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    num = 2 * (probs * targets).sum(dim=(2, 3))
    den = (probs + targets).sum(dim=(2, 3)) + eps
    return (1 - (num / den)).mean()

# =========================================================
# Config (workers=2 + speed controls + resume)
# =========================================================
@dataclass
class TrainConfig:
    image_size: int = 256
    batch_size: int = 8
    epochs: int = 60

    lr: float = 3e-4
    weight_decay: float = 1e-2

    num_workers: int = 2   # ✅ requested

    seg_w: float = 1.0
    cls_w: float = 0.5
    grad_clip: float = 1.0
    val_ratio: float = 0.15
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # speed controls
    steps_per_epoch: int = 350
    mc_samples: int = 8
    mc_samples_fast: int = 2
    mc_every: int = 5
    val_max_batches_seg: int = 120
    val_max_batches_cls: int = 120

    patience: int = 10
    min_delta: float = 1e-4

    amp: bool = True
    resume: bool = True
    resume_prefer_best: bool = True

cfg = TrainConfig()
seed_everything(cfg.seed)

def make_split(dataset: Dataset, val_ratio: float, seed: int):
    n = len(dataset)
    n_val = int(round(n * val_ratio))
    n_tr = n - n_val
    g = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_tr, n_val], generator=g)

# =========================================================
# Build datasets/loaders
# =========================================================
seg_all = ISICSegDataset(SEG_IMAGE_DIR, SEG_MASK_DIR, image_size=cfg.image_size)
cls_all = ISICClsDataset(CLS_IMAGE_DIR, CLS_LABEL_CSV, image_size=cfg.image_size)

seg_tr, seg_va = make_split(seg_all, cfg.val_ratio, cfg.seed)
cls_tr, cls_va = make_split(cls_all, cfg.val_ratio, cfg.seed)

common_loader_kwargs = dict(
    num_workers=cfg.num_workers,
    pin_memory=True,
    persistent_workers=True if cfg.num_workers > 0 else False,
    prefetch_factor=2 if cfg.num_workers > 0 else None
)

seg_train_loader = DataLoader(seg_tr, batch_size=cfg.batch_size, shuffle=True, **common_loader_kwargs)
seg_val_loader   = DataLoader(seg_va, batch_size=1, shuffle=False, **common_loader_kwargs)

cls_train_loader = DataLoader(cls_tr, batch_size=cfg.batch_size, shuffle=True, **common_loader_kwargs)
cls_val_loader   = DataLoader(cls_va, batch_size=cfg.batch_size, shuffle=False, **common_loader_kwargs)

# =========================================================
# Model / Optim / AMP scaler (NEW API)
# =========================================================
model = HybridMTLModel(base_ch=32, num_classes=3, seg_dropout=0.10).to(cfg.device)
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

# ✅ NEW AMP API (no torch.cuda.amp)
scaler = torch.amp.GradScaler("cuda", enabled=(cfg.amp and cfg.device.startswith("cuda")))

# =========================================================
# Checkpoint save/load
# =========================================================
def save_checkpoint(path: Path, model: nn.Module, optimizer, epoch: int, metrics: dict):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "metrics": metrics,
        "time": time.time(),
        "mm_per_pixel": MM_PER_PIXEL,
        "field_size_mm": FIELD_SIZE_MM,
        "image_size": IMAGE_SIZE,
    }, str(path))

def try_resume():
    if not cfg.resume:
        return 1, -1e18, -1

    best_path = CHECKPOINT_DIR / "best.pt"
    last_path = latest_epoch_ckpt(CHECKPOINT_DIR)
    pick = None

    if cfg.resume_prefer_best and best_path.exists():
        pick = best_path
    elif last_path is not None:
        pick = last_path
    elif best_path.exists():
        pick = best_path

    if pick is None:
        print("[Resume] No checkpoint found. Start from epoch 1.")
        return 1, -1e18, -1

    ckpt = torch.load(pick, map_location=cfg.device)
    model.load_state_dict(ckpt["model"], strict=True)
    if ckpt.get("optimizer") is not None:
        try:
            optimizer.load_state_dict(ckpt["optimizer"])
        except Exception as e:
            print("[Resume] Optimizer state not loaded (ok). Reason:", str(e))

    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_score = float(ckpt.get("metrics", {}).get("val_score", -1e18))
    best_epoch = int(ckpt.get("epoch", -1))

    print(f"[Resume] Loaded: {pick.name} | start_epoch={start_epoch} | best_score={best_score:.4f} (ep{best_epoch})")
    return start_epoch, best_score, best_epoch

# =========================================================
# Train / Val
# =========================================================
def train_one_epoch(epoch: int):
    model.train()
    seg_iter = iter(seg_train_loader)
    cls_iter = iter(cls_train_loader)

    steps = cfg.steps_per_epoch
    running = {"loss": 0.0, "loss_seg": 0.0, "loss_cls": 0.0}

    pbar = tqdm(range(steps), desc=f"[Train] Epoch {epoch}", leave=False)
    for i in pbar:
        optimizer.zero_grad(set_to_none=True)

        # recycle iterators
        try:
            bseg = next(seg_iter)
        except StopIteration:
            seg_iter = iter(seg_train_loader)
            bseg = next(seg_iter)

        try:
            bcls = next(cls_iter)
        except StopIteration:
            cls_iter = iter(cls_train_loader)
            bcls = next(cls_iter)

        # ✅ NEW autocast API
        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            # seg forward
            x = bseg["image"].to(cfg.device, non_blocking=True)
            y = bseg["mask"].to(cfg.device, non_blocking=True)
            seg_logit, _ = model(x)
            loss_seg = bce(seg_logit, y) + dice_loss_with_logits(seg_logit, y)

            # cls forward
            x2 = bcls["image"].to(cfg.device, non_blocking=True)
            y2 = bcls["label"].to(cfg.device, non_blocking=True)
            _, cls_logit = model(x2)
            loss_cls = ce(cls_logit, y2)

            loss = cfg.seg_w * loss_seg + cfg.cls_w * loss_cls

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        running["loss"] += float(loss.item())
        running["loss_seg"] += float(loss_seg.item())
        running["loss_cls"] += float(loss_cls.item())

        pbar.set_postfix({
            "loss": running["loss"]/(i+1),
            "seg": running["loss_seg"]/(i+1),
            "cls": running["loss_cls"]/(i+1),
            "lr": scheduler.get_last_lr()[0]
        })

    scheduler.step()
    return {k: v/steps for k, v in running.items()}

@torch.no_grad()
def validate_segmentation_mc(epoch: int, export_n: int = 2):
    do_full = (epoch % cfg.mc_every == 0)
    mc_T = cfg.mc_samples if do_full else cfg.mc_samples_fast

    enable_dropout_only(model)

    dices, bfs, hd95s_px = [], [], []
    entropies, mis = [], []

    export_dir = EXPORT_DIR / f"epoch_{epoch:03d}"
    export_dir.mkdir(parents=True, exist_ok=True)

    for idx, b in enumerate(tqdm(seg_val_loader, desc=f"[Val-Seg MC(T={mc_T})] Epoch {epoch}", leave=False)):
        if idx >= cfg.val_max_batches_seg:
            break

        x = b["image"].to(cfg.device)
        y = b["mask"].to(cfg.device)

        mc_probs = []
        for _ in range(mc_T):
            with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
                seg_logit, _ = model(x)
                prob = torch.sigmoid(seg_logit).squeeze(0).squeeze(0)  # expected [H,W]
            if prob.ndim != 2:
                print("[WARN] prob shape unexpected:", tuple(prob.shape), "seg_logit:", tuple(seg_logit.shape))
                continue
            mc_probs.append(prob.detach().float().cpu().numpy())

        if len(mc_probs) == 0:
            continue

        mc_probs = np.stack(mc_probs, axis=0)  # [T,H,W]
        p_lesion = mc_probs.mean(axis=0).astype(np.float32)
        if p_lesion.ndim != 2:
            print("[WARN] p_lesion shape unexpected:", p_lesion.shape)
            continue

        pred_mask = (p_lesion > 0.5).astype(np.uint8)
        gt_mask = (y.squeeze().detach().cpu().numpy() > 0.5).astype(np.uint8)

        inter = (pred_mask & gt_mask).sum()
        union = pred_mask.sum() + gt_mask.sum()
        dice = (2 * inter / (union + 1e-6)) if union > 0 else 1.0
        dices.append(float(dice))

        bfs.append(bf_score(pred_mask, gt_mask, tol_px=2))
        hd95s_px.append(hd95(pred_mask, gt_mask))

        ent_map = predictive_entropy(p_lesion)
        mi_map = mutual_information(mc_probs) if mc_T > 1 else np.zeros_like(ent_map)

        entropies.append(float(ent_map.mean()))
        mis.append(float(mi_map.mean()))

        if do_full and idx < export_n:
            p_boundary = compute_boundary_prob_from_plesion(p_lesion)
            tw_mm = transition_width_map_mm(p_lesion)

            np.save(export_dir / f"val{idx:03d}_p_lesion.npy", p_lesion)
            np.save(export_dir / f"val{idx:03d}_p_boundary.npy", p_boundary)
            np.save(export_dir / f"val{idx:03d}_entropy.npy", ent_map.astype(np.float32))
            np.save(export_dir / f"val{idx:03d}_mi.npy", mi_map.astype(np.float32))
            np.save(export_dir / f"val{idx:03d}_transition_width_mm.npy", tw_mm)

            save_map_with_mm_axis(export_dir / f"val{idx:03d}_p_lesion.png", p_lesion, "p_lesion (mean prob)")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_p_boundary.png", p_boundary, "p_boundary (|∇p| normalized)")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_entropy.png", ent_map, "Predictive entropy")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_mi.png", mi_map, "Mutual information")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_transition_width_mm.png", tw_mm, "Transition width (mm)")

    val_dice = float(np.mean(dices)) if dices else 0.0
    val_bf   = float(np.mean(bfs)) if bfs else 0.0
    val_hd95_mm = float(np.mean(hd95s_px) * MM_PER_PIXEL) if hd95s_px else 0.0
    mean_ent = float(np.mean(entropies)) if entropies else 0.0
    mean_mi  = float(np.mean(mis)) if mis else 0.0

    return {
        "val_dice": val_dice,
        "val_bf": val_bf,
        "val_hd95_mm": val_hd95_mm,
        "val_pred_entropy_mean": mean_ent,
        "val_mi_mean": mean_mi,
        "mc_T": mc_T,
        "mc_full": bool(do_full),
    }

@torch.no_grad()
def validate_classification(epoch: int):
    model.eval()
    all_logits, all_labels = [], []

    for bi, b in enumerate(tqdm(cls_val_loader, desc=f"[Val-Cls] Epoch {epoch}", leave=False)):
        if bi >= cfg.val_max_batches_cls:
            break
        x = b["image"].to(cfg.device)
        y = b["label"].to(cfg.device)
        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            _, logits = model(x)
        all_logits.append(logits.detach().float().cpu().numpy())
        all_labels.append(y.detach().cpu().numpy())

    if not all_logits:
        return {"val_acc": 0.0, "val_ece": 0.0, "val_aurc": 0.0, "val_nll": 0.0, "val_brier": 0.0}

    logits = np.concatenate(all_logits, axis=0)
    labels = np.concatenate(all_labels, axis=0)

    probs = softmax_np(logits)
    pred = probs.argmax(axis=1)
    acc = float((pred == labels).mean())

    ece = ece_score(probs, labels, n_bins=15)
    nll = nll_score(probs, labels)
    br  = brier_score(probs, labels)

    unc = predictive_entropy(probs)
    cov, risk = risk_coverage_curve(probs, labels, unc)
    A = aurc(cov, risk)

    export_dir = EXPORT_DIR / f"epoch_{epoch:03d}"
    export_dir.mkdir(parents=True, exist_ok=True)

    np.save(export_dir / "risk_coverage_coverage.npy", cov.astype(np.float32))
    np.save(export_dir / "risk_coverage_risk.npy", risk.astype(np.float32))

    plt.figure(figsize=(4, 3), dpi=150)
    plt.plot(cov, risk)
    plt.xlabel("Coverage")
    plt.ylabel("Risk (error rate)")
    plt.title(f"Risk–Coverage (AURC={A:.4f})")
    plt.tight_layout()
    plt.savefig(export_dir / "risk_coverage_curve.png", bbox_inches="tight")
    plt.close()

    return {"val_acc": acc, "val_ece": float(ece), "val_aurc": float(A), "val_nll": float(nll), "val_brier": float(br)}

# =========================================================
# Early stopping
# =========================================================
class EarlyStopper:
    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = -1e18
        self.bad = 0
    def step(self, score: float) -> bool:
        if score > self.best + self.min_delta:
            self.best = score
            self.bad = 0
            return False
        self.bad += 1
        return self.bad >= self.patience

def composite_val_score(seg_m: dict, cls_m: dict) -> float:
    dice = seg_m["val_dice"]
    bf = seg_m["val_bf"]
    hd = seg_m["val_hd95_mm"]
    acc = cls_m["val_acc"]
    ece = cls_m["val_ece"]
    aurc_v = cls_m["val_aurc"]
    nll = cls_m["val_nll"]
    brier = cls_m["val_brier"]

    score = (
        2.0 * dice +
        1.0 * bf +
        1.0 * acc -
        0.25 * math.log(1.0 + hd) -
        0.5 * ece -
        0.5 * aurc_v -
        0.1 * nll -
        0.1 * brier
    )
    return float(score)

stopper = EarlyStopper(patience=cfg.patience, min_delta=cfg.min_delta)

# =========================================================
# Clear cache (safe if rerunning cells)
# =========================================================
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# =========================================================
# Resume
# =========================================================
start_epoch, best_score, best_epoch = try_resume()

# If resumed from a later epoch, keep scheduler aligned (optional)
# For cosine schedule, you can step scheduler to start_epoch-1:
for _ in range(max(0, start_epoch - 1)):
    scheduler.step()

# =========================================================
# Train loop
# =========================================================
for epoch in range(start_epoch, cfg.epochs + 1):
    train_m = train_one_epoch(epoch)
    seg_m = validate_segmentation_mc(epoch, export_n=2)
    cls_m = validate_classification(epoch)

    val_score = composite_val_score(seg_m, cls_m)

    metrics = {**train_m, **seg_m, **cls_m,
               "val_score": val_score,
               "epoch": epoch,
               "mm_per_pixel": MM_PER_PIXEL,
               "field_size_mm": FIELD_SIZE_MM,
               "image_size": IMAGE_SIZE,
               "amp": cfg.amp,
               "steps_per_epoch": cfg.steps_per_epoch}

    # log
    with open(CHECKPOINT_DIR.parent / "train_log.jsonl", "a", encoding="utf-8") as f:
        f.write(pd.Series(metrics).to_json() + "\n")

    save_checkpoint(CHECKPOINT_DIR / f"epoch_{epoch:03d}.pt", model, optimizer, epoch, metrics)

    if val_score > best_score:
        best_score = val_score
        best_epoch = epoch
        save_checkpoint(CHECKPOINT_DIR / "best.pt", model, optimizer, epoch, metrics)

    print(
        f"[Epoch {epoch:03d}] "
        f"loss={train_m['loss']:.4f} | "
        f"Dice={seg_m['val_dice']:.4f} BF={seg_m['val_bf']:.4f} HD95(mm)={seg_m['val_hd95_mm']:.3f} "
        f"(MC T={seg_m['mc_T']}, full={seg_m['mc_full']}) | "
        f"Acc={cls_m['val_acc']:.4f} ECE={cls_m['val_ece']:.4f} AURC={cls_m['val_aurc']:.4f} "
        f"NLL={cls_m['val_nll']:.4f} Brier={cls_m['val_brier']:.4f} | "
        f"val_score={val_score:.4f} best={best_score:.4f} (ep{best_epoch})"
    )

    if stopper.step(val_score):
        print(f"Early stopping at epoch {epoch}. Best epoch={best_epoch}, best_score={best_score:.4f}")
        break

print(f"Training finished. Best epoch: {best_epoch}, best score: {best_score:.4f}")
print(f"Exports saved to: {EXPORT_DIR}")




[Resume] Loaded: best.pt | start_epoch=4 | best_score=2.2798 (ep3)




[Epoch 004] loss=0.6947 | Dice=0.8014 BF=0.3535 HD95(mm)=1.878 (MC T=2, full=False) | Acc=0.8187 ECE=0.0932 AURC=0.0585 NLL=0.5091 Brier=0.2846 | val_score=2.3556 best=2.3556 (ep4)




[Epoch 005] loss=0.6638 | Dice=0.8107 BF=0.3649 HD95(mm)=1.662 (MC T=8, full=True) | Acc=0.8208 ECE=0.0419 AURC=0.0514 NLL=0.4161 Brier=0.2490 | val_score=2.4491 best=2.4491 (ep5)




[Epoch 006] loss=0.6048 | Dice=0.8230 BF=0.3462 HD95(mm)=1.770 (MC T=2, full=False) | Acc=0.8156 ECE=0.0467 AURC=0.0510 NLL=0.4322 Brier=0.2571 | val_score=2.4354 best=2.4491 (ep5)




[Epoch 007] loss=0.5900 | Dice=0.7898 BF=0.3128 HD95(mm)=1.559 (MC T=2, full=False) | Acc=0.8271 ECE=0.0352 AURC=0.0524 NLL=0.4170 Brier=0.2456 | val_score=2.3744 best=2.4491 (ep5)




[Epoch 008] loss=0.5892 | Dice=0.7905 BF=0.2804 HD95(mm)=2.027 (MC T=2, full=False) | Acc=0.8271 ECE=0.0306 AURC=0.0489 NLL=0.4070 Brier=0.2434 | val_score=2.3067 best=2.4491 (ep5)




[Epoch 009] loss=0.5489 | Dice=0.8174 BF=0.3611 HD95(mm)=1.819 (MC T=2, full=False) | Acc=0.8146 ECE=0.0346 AURC=0.0494 NLL=0.4115 Brier=0.2470 | val_score=2.4435 best=2.4491 (ep5)




[Epoch 010] loss=0.5544 | Dice=0.8288 BF=0.3770 HD95(mm)=1.621 (MC T=8, full=True) | Acc=0.8219 ECE=0.0312 AURC=0.0461 NLL=0.3936 Brier=0.2396 | val_score=2.5137 best=2.5137 (ep10)




[Epoch 011] loss=0.5369 | Dice=0.8406 BF=0.4154 HD95(mm)=1.541 (MC T=2, full=False) | Acc=0.8240 ECE=0.0360 AURC=0.0450 NLL=0.4048 Brier=0.2356 | val_score=2.5828 best=2.5828 (ep11)




[Epoch 012] loss=0.5320 | Dice=0.8384 BF=0.4126 HD95(mm)=1.528 (MC T=2, full=False) | Acc=0.8333 ECE=0.0253 AURC=0.0458 NLL=0.3937 Brier=0.2381 | val_score=2.5921 best=2.5921 (ep12)




[Epoch 013] loss=0.5109 | Dice=0.8190 BF=0.4067 HD95(mm)=1.576 (MC T=2, full=False) | Acc=0.8302 ECE=0.0613 AURC=0.0490 NLL=0.4551 Brier=0.2532 | val_score=2.5124 best=2.5921 (ep12)




[Epoch 014] loss=0.5068 | Dice=0.8442 BF=0.3740 HD95(mm)=1.508 (MC T=2, full=False) | Acc=0.8313 ECE=0.0362 AURC=0.0478 NLL=0.4089 Brier=0.2403 | val_score=2.5569 best=2.5921 (ep12)




[Epoch 015] loss=0.5010 | Dice=0.8448 BF=0.3805 HD95(mm)=1.428 (MC T=8, full=True) | Acc=0.8510 ECE=0.0527 AURC=0.0404 NLL=0.3886 Brier=0.2230 | val_score=2.5917 best=2.5921 (ep12)




[Epoch 016] loss=0.4871 | Dice=0.8305 BF=0.3768 HD95(mm)=1.540 (MC T=2, full=False) | Acc=0.8458 ECE=0.0412 AURC=0.0426 NLL=0.3885 Brier=0.2293 | val_score=2.5470 best=2.5921 (ep12)




[Epoch 017] loss=0.4731 | Dice=0.8444 BF=0.4054 HD95(mm)=1.467 (MC T=2, full=False) | Acc=0.7979 ECE=0.0393 AURC=0.0527 NLL=0.4263 Brier=0.2493 | val_score=2.5528 best=2.5921 (ep12)




[Epoch 018] loss=0.4696 | Dice=0.8257 BF=0.4114 HD95(mm)=1.639 (MC T=2, full=False) | Acc=0.8417 ECE=0.0274 AURC=0.0437 NLL=0.3860 Brier=0.2272 | val_score=2.5651 best=2.5921 (ep12)




[Epoch 019] loss=0.4697 | Dice=0.8415 BF=0.4132 HD95(mm)=1.545 (MC T=2, full=False) | Acc=0.8417 ECE=0.0398 AURC=0.0437 NLL=0.3892 Brier=0.2249 | val_score=2.6012 best=2.6012 (ep19)




[Epoch 020] loss=0.4648 | Dice=0.8497 BF=0.4196 HD95(mm)=1.354 (MC T=8, full=True) | Acc=0.8344 ECE=0.0217 AURC=0.0426 NLL=0.3751 Brier=0.2227 | val_score=2.6474 best=2.6474 (ep20)




[Epoch 021] loss=0.4512 | Dice=0.8582 BF=0.4124 HD95(mm)=1.337 (MC T=2, full=False) | Acc=0.8385 ECE=0.0432 AURC=0.0460 NLL=0.4023 Brier=0.2309 | val_score=2.6472 best=2.6474 (ep20)




[Epoch 022] loss=0.4438 | Dice=0.8377 BF=0.3584 HD95(mm)=1.545 (MC T=2, full=False) | Acc=0.8406 ECE=0.0368 AURC=0.0436 NLL=0.3850 Brier=0.2280 | val_score=2.5395 best=2.6474 (ep20)




[Epoch 023] loss=0.4421 | Dice=0.8381 BF=0.3712 HD95(mm)=1.537 (MC T=2, full=False) | Acc=0.8542 ECE=0.0344 AURC=0.0382 NLL=0.3672 Brier=0.2079 | val_score=2.5751 best=2.6474 (ep20)




[Epoch 024] loss=0.4291 | Dice=0.8398 BF=0.3562 HD95(mm)=1.436 (MC T=2, full=False) | Acc=0.8583 ECE=0.0431 AURC=0.0356 NLL=0.3622 Brier=0.2029 | val_score=2.5755 best=2.6474 (ep20)




[Epoch 025] loss=0.4234 | Dice=0.8428 BF=0.4052 HD95(mm)=1.369 (MC T=8, full=True) | Acc=0.8292 ECE=0.0294 AURC=0.0415 NLL=0.3832 Brier=0.2278 | val_score=2.6079 best=2.6474 (ep20)




[Epoch 026] loss=0.4238 | Dice=0.8448 BF=0.4137 HD95(mm)=1.364 (MC T=2, full=False) | Acc=0.8333 ECE=0.0245 AURC=0.0430 NLL=0.3817 Brier=0.2274 | val_score=2.6268 best=2.6474 (ep20)




[Epoch 027] loss=0.4131 | Dice=0.8477 BF=0.4045 HD95(mm)=1.487 (MC T=2, full=False) | Acc=0.8479 ECE=0.0375 AURC=0.0395 NLL=0.3799 Brier=0.2155 | val_score=2.6220 best=2.6474 (ep20)




[Epoch 028] loss=0.3991 | Dice=0.8338 BF=0.4297 HD95(mm)=1.449 (MC T=2, full=False) | Acc=0.8281 ECE=0.0354 AURC=0.0432 NLL=0.3885 Brier=0.2258 | val_score=2.6007 best=2.6474 (ep20)




[Epoch 029] loss=0.3922 | Dice=0.8323 BF=0.4045 HD95(mm)=1.373 (MC T=2, full=False) | Acc=0.8531 ECE=0.0386 AURC=0.0365 NLL=0.3485 Brier=0.2001 | val_score=2.6137 best=2.6474 (ep20)




[Epoch 030] loss=0.3872 | Dice=0.8589 BF=0.4208 HD95(mm)=1.318 (MC T=8, full=True) | Acc=0.8490 ECE=0.0302 AURC=0.0378 NLL=0.3710 Brier=0.2075 | val_score=2.6856 best=2.6856 (ep30)




[Epoch 031] loss=0.3688 | Dice=0.8474 BF=0.3973 HD95(mm)=1.442 (MC T=2, full=False) | Acc=0.8490 ECE=0.0601 AURC=0.0407 NLL=0.4014 Brier=0.2224 | val_score=2.6051 best=2.6856 (ep30)




[Epoch 032] loss=0.3849 | Dice=0.8617 BF=0.3920 HD95(mm)=1.317 (MC T=2, full=False) | Acc=0.8344 ECE=0.0429 AURC=0.0410 NLL=0.3785 Brier=0.2224 | val_score=2.6376 best=2.6856 (ep30)




[Epoch 033] loss=0.3709 | Dice=0.8356 BF=0.4013 HD95(mm)=1.543 (MC T=2, full=False) | Acc=0.8469 ECE=0.0544 AURC=0.0392 NLL=0.3992 Brier=0.2174 | val_score=2.5775 best=2.6856 (ep30)




KeyboardInterrupt: 

In [None]:
from pathlib import Path

CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv".strip())

if CLS_LABEL_CSV.exists():
    print(f"The file {CLS_LABEL_CSV.name} exists at: {CLS_LABEL_CSV}")
else:
    print(f"Error: The file {CLS_LABEL_CSV.name} does NOT exist at: {CLS_LABEL_CSV}")
    print("Please check the path and filename.")

# The rest of the original code should follow here
# ... (original content of the cell should be appended here)

The file ISIC2018_Task3_Training_GroundTruth.csv exists at: /content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv


In [None]:


from pathlib import Path

CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_Input".strip())

if CLS_LABEL_CSV.exists():
    print(f"The file {CLS_LABEL_CSV.name} exists at: {CLS_LABEL_CSV}")
else:
    print(f"Error: The file {CLS_LABEL_CSV.name} does NOT exist at: {CLS_LABEL_CSV}")
    print("Please check the path and filename.")

# The rest of the original code should follow here
# ... (original content of the cell should be appended here)

The file ISIC2018_Task3_Training_Input exists at: /content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_Input


In [None]:
from pathlib import Path

CLS_LABEL_CSV = Path("/content/drive/My Drive/isic2018/images".strip())

if CLS_LABEL_CSV.exists():
    print(f"The file {CLS_LABEL_CSV.name} exists at: {CLS_LABEL_CSV}")
else:
    print(f"Error: The file {CLS_LABEL_CSV.name} does NOT exist at: {CLS_LABEL_CSV}")
    print("Please check the path and filename.")

# The rest of the original code should follow here
# ... (original content of the cell should be appended here)

The file images exists at: /content/drive/My Drive/isic2018/images


In [None]:
!ls "/content/drive/My Drive/"

In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# Hybrid CNN–ViT–U-Net (Seg + Cls) with NEW AMP API + resume + finetune
# ISIC 2018 Task 1 (Seg) + Task 3 (Cls)
# Physical size: 15mm × 15mm → 256 × 256 px
#
# Fine-tune:
# - LR ↓ (e.g., 3e-4 -> 1e-4 or 5e-5)
# - cls_w ↓ (e.g., 0.5 -> 0.3)
# - classification loss: weighted CE OR focal
#
# Fix scheduler warning:
# - save scheduler into checkpoint
# - load scheduler when resume
# - NO manual scheduler.step() loop on resume
# =========================================================

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import time, math, random, gc, re
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt
from scipy import ndimage as ndi

# =========================================================
# Performance settings
# =========================================================
cv2.setNumThreads(0)
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")

# =========================================================
# Paths
# =========================================================
SEG_IMAGE_DIR = Path("/content/drive/My Drive/isic2018/images")
SEG_MASK_DIR  = Path("/content/drive/My Drive/isic2018/masks")

CLS_IMAGE_DIR = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_Input".strip())
CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Training_GroundTruth/ISIC2018_Task3_Training_GroundTruth.csv".strip())

CHECKPOINT_DIR = Path("/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/checkpoints".strip())
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
EXPORT_DIR = CHECKPOINT_DIR.parent / "exports"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)

# =========================================================
# Physical size definition
# =========================================================
FIELD_SIZE_MM = 15.0
IMAGE_SIZE = 256
MM_PER_PIXEL = FIELD_SIZE_MM / IMAGE_SIZE

# =========================================================
# Utils
# =========================================================
def seed_everything(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def save_map_with_mm_axis(path_png: Path, arr: np.ndarray, title: str):
    arr = np.asarray(arr)
    if arr.ndim == 1:
        if arr.size == IMAGE_SIZE * IMAGE_SIZE:
            arr = arr.reshape(IMAGE_SIZE, IMAGE_SIZE)
        else:
            x_mm = np.linspace(0, FIELD_SIZE_MM, arr.size)
            plt.figure(figsize=(5, 3), dpi=150)
            plt.plot(x_mm, arr)
            plt.xlabel("x (mm)")
            plt.ylabel("value")
            plt.title(title + f" (1D, shape={arr.shape})")
            plt.tight_layout()
            plt.savefig(path_png, bbox_inches="tight")
            plt.close()
            return
    if arr.ndim != 2:
        print(f"[WARN] Skip saving {title}: invalid shape={arr.shape}")
        return

    plt.figure(figsize=(5, 4), dpi=150)
    plt.imshow(arr, extent=(0, FIELD_SIZE_MM, FIELD_SIZE_MM, 0))
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.xlabel("x (mm)")
    plt.ylabel("y (mm)")
    plt.title(title + f" (shape={arr.shape})")
    plt.tight_layout()
    plt.savefig(path_png, bbox_inches="tight")
    plt.close()

def latest_epoch_ckpt(ckpt_dir: Path) -> Path | None:
    pts = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not pts:
        return None
    def key(p: Path):
        m = re.search(r"epoch_(\d+)\.pt$", p.name)
        return int(m.group(1)) if m else -1
    pts = sorted(pts, key=key)
    return pts[-1] if key(pts[-1]) >= 0 else None

# =========================================================
# Dataset — Segmentation (ISIC 2018 Task 1)
# =========================================================
class ISICSegDataset(Dataset):
    def __init__(self, image_dir: Path, mask_dir: Path, image_size: int = 256):
        self.image_paths = sorted(image_dir.glob("*.jpg"))
        self.mask_paths = [mask_dir / f"{p.stem}_segmentation.png" for p in self.image_paths]
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask: {mask_path}")

        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

        img = img.astype(np.float32) / 255.0
        mask = (mask > 0).astype(np.float32)

        img = torch.from_numpy(img).permute(2, 0, 1)     # [3,H,W]
        mask = torch.from_numpy(mask).unsqueeze(0)       # [1,H,W]
        return {"image": img, "mask": mask}

# =========================================================
# Dataset — Classification (ISIC 2018 Task 3) supports one-hot CSV
# =========================================================
CLS_MAP_DX = {
    "MEL": 2, "BCC": 2,
    "AKIEC": 1,
    "NV": 0, "BKL": 0, "DF": 0, "VASC": 0
}

class ISICClsDataset(Dataset):
    """
    A) one-hot CSV:
       columns: image, MEL, NV, BCC, AKIEC, BKL, DF, VASC
    B) dx CSV:
       columns: image_id, dx
    """
    def __init__(self, image_dir: Path, label_csv: Path, image_size: int = 256):
        self.image_dir = image_dir
        self.df = pd.read_csv(label_csv)
        self.image_size = image_size

        self.df.columns = self.df.columns.str.strip()
        cols = set(self.df.columns)

        onehot_classes = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
        has_onehot = ("image" in cols) and all(c in cols for c in onehot_classes)
        has_dx = ("image_id" in cols and "dx" in cols)

        if not (has_onehot or has_dx):
            raise ValueError(
                "Unsupported CSV columns. Expect either:\n"
                "A) one-hot: image + MEL,NV,BCC,AKIEC,BKL,DF,VASC\n"
                "or B) dx: image_id + dx\n"
                f"Got columns: {list(self.df.columns)}"
            )

        self.records = []
        if has_onehot:
            for _, row in self.df.iterrows():
                img_id = str(row["image"]).strip()
                malignant = (row["MEL"] == 1) or (row["BCC"] == 1)
                intermediate = (row["AKIEC"] == 1)
                benign = (row["NV"] == 1) or (row["BKL"] == 1) or (row["DF"] == 1) or (row["VASC"] == 1)

                if malignant:
                    label = 2
                elif intermediate:
                    label = 1
                elif benign:
                    label = 0
                else:
                    continue
                self.records.append((img_id, label))
        else:
            for _, row in self.df.iterrows():
                dx = str(row["dx"]).strip()
                if dx not in CLS_MAP_DX:
                    continue
                img_id = str(row["image_id"]).strip()
                self.records.append((img_id, CLS_MAP_DX[dx]))

        if len(self.records) == 0:
            raise ValueError("No valid records parsed from CSV. Check labels/image ids.")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        img_id, label = self.records[idx]
        img_path = self.image_dir / f"{img_id}.jpg"
        img = cv2.imread(str(img_path))
        if img is None:
            img_path2 = self.image_dir / f"{img_id}.png"
            img = cv2.imread(str(img_path2))
            if img is None:
                raise FileNotFoundError(f"Cannot find image for id={img_id}: {img_path} or {img_path2}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32) / 255.0
        img = torch.from_numpy(img).permute(2, 0, 1)
        return {"image": img, "label": torch.tensor(label, dtype=torch.long)}

# =========================================================
# Model blocks
# =========================================================
class ConvBNAct(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU()
        )
    def forward(self, x): return self.net(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(ConvBNAct(in_ch, out_ch), ConvBNAct(out_ch, out_ch))
    def forward(self, x): return self.net(x)

class ConvEncoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        self.e1 = ConvBlock(in_ch, base_ch)            # H
        self.e2 = ConvBlock(base_ch, base_ch * 2)      # H/2
        self.e3 = ConvBlock(base_ch * 2, base_ch * 4)  # H/4
        self.e4 = ConvBlock(base_ch * 4, base_ch * 8)  # H/8
        self.pool = nn.MaxPool2d(2)
    def forward(self, x):
        f1 = self.e1(x)
        f2 = self.e2(self.pool(f1))
        f3 = self.e3(self.pool(f2))
        f4 = self.e4(self.pool(f3))
        return [f1, f2, f3, f4]

class ViTBottleneck(nn.Module):
    def __init__(self, dim, depth=2, heads=4, drop=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim*4,
            dropout=drop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, depth)
    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)  # [B,N,C]
        tokens = self.encoder(tokens)
        return tokens.transpose(1, 2).reshape(b, c, h, w)

class UNetDecoder(nn.Module):
    def __init__(self, base_ch=32, out_ch=1, dropout_p=0.10):
        super().__init__()
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, 2)
        self.d3  = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, 2)
        self.d2  = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, 2)
        self.d1  = ConvBlock(base_ch*2, base_ch)
        self.drop = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()
        self.out  = nn.Conv2d(base_ch, out_ch, 1)
    def forward(self, feats):
        f1, f2, f3, f4 = feats
        x = self.up3(f4)
        x = self.d3(torch.cat([x, f3], dim=1))
        x = self.up2(x)
        x = self.d2(torch.cat([x, f2], dim=1))
        x = self.up1(x)
        x = self.d1(torch.cat([x, f1], dim=1))
        x = self.drop(x)
        return self.out(x)

class HybridMTLModel(nn.Module):
    def __init__(self, base_ch=32, num_classes=3, seg_dropout=0.10):
        super().__init__()
        self.encoder = ConvEncoder(3, base_ch)
        self.bottleneck = ViTBottleneck(base_ch*8, depth=2, heads=4, drop=0.1)
        self.decoder = UNetDecoder(base_ch, out_ch=1, dropout_p=seg_dropout)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_ch*8, base_ch*8),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(base_ch*8, num_classes),
        )
    def forward(self, x):
        feats = self.encoder(x)
        feats[-1] = self.bottleneck(feats[-1])
        seg_logit = self.decoder(feats)
        cls_logit = self.cls_head(feats[-1])
        return seg_logit, cls_logit

# =========================================================
# Metrics (seg)
# =========================================================
def mask_to_boundary(mask: np.ndarray) -> np.ndarray:
    mask_u8 = (mask > 0).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    er = cv2.erode(mask_u8, kernel, iterations=1)
    bd = (mask_u8 - er) > 0
    return bd.astype(np.uint8)

def bf_score(pred_mask: np.ndarray, gt_mask: np.ndarray, tol_px: int = 2) -> float:
    pb = mask_to_boundary(pred_mask)
    gb = mask_to_boundary(gt_mask)
    if pb.sum() == 0 and gb.sum() == 0:
        return 1.0
    if pb.sum() == 0 or gb.sum() == 0:
        return 0.0
    dt_g = ndi.distance_transform_edt(1 - gb)
    dt_p = ndi.distance_transform_edt(1 - pb)
    prec = (dt_g[pb.astype(bool)] <= tol_px).mean() if pb.sum() > 0 else 0.0
    rec  = (dt_p[gb.astype(bool)] <= tol_px).mean() if gb.sum() > 0 else 0.0
    return float(2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0

def hd95(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    pb = mask_to_boundary(pred_mask).astype(bool)
    gb = mask_to_boundary(gt_mask).astype(bool)
    if pb.sum() == 0 and gb.sum() == 0:
        return 0.0
    if pb.sum() == 0 or gb.sum() == 0:
        h, w = pred_mask.shape
        return float(math.sqrt(h*h + w*w))
    dt_g = ndi.distance_transform_edt(~gb)
    dt_p = ndi.distance_transform_edt(~pb)
    d1 = dt_g[pb]
    d2 = dt_p[gb]
    all_d = np.concatenate([d1, d2], axis=0)
    return float(np.percentile(all_d, 95))

# =========================================================
# Metrics (cls)
# =========================================================
def softmax_np(logits: np.ndarray) -> np.ndarray:
    x = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / (e.sum(axis=1, keepdims=True) + 1e-12)

def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == labels).astype(np.float32)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        m = (conf > bins[i]) & (conf <= bins[i+1])
        if m.any():
            ece += abs(acc[m].mean() - conf[m].mean()) * m.mean()
    return float(ece)

def nll_score(probs: np.ndarray, labels: np.ndarray) -> float:
    p = probs[np.arange(len(labels)), labels]
    return float((-np.log(p + 1e-12)).mean())

def brier_score(probs: np.ndarray, labels: np.ndarray) -> float:
    n, k = probs.shape
    y = np.zeros((n, k), dtype=np.float32)
    y[np.arange(n), labels] = 1.0
    return float(((probs - y) ** 2).sum(axis=1).mean())

def predictive_entropy(p: np.ndarray, eps=1e-12) -> np.ndarray:
    if p.ndim >= 2 and p.shape[-1] > 1:
        return -(p * np.log(p + eps)).sum(axis=-1)
    return -(p*np.log(p+eps) + (1-p)*np.log(1-p+eps))

def risk_coverage_curve(probs: np.ndarray, labels: np.ndarray, uncertainty: np.ndarray):
    n = len(labels)
    order = np.argsort(uncertainty)
    probs_s = probs[order]
    labels_s = labels[order]
    pred = probs_s.argmax(axis=1)
    err = (pred != labels_s).astype(np.float32)
    coverages = np.linspace(1/n, 1.0, n)
    risks = np.cumsum(err) / (np.arange(n) + 1)
    return coverages, risks

def aurc(coverages: np.ndarray, risks: np.ndarray) -> float:
    return float(np.trapezoid(risks, coverages))

# =========================================================
# MC Dropout helpers (seg)
# =========================================================
def enable_dropout_only(model: nn.Module):
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()

def mutual_information(mc_probs: np.ndarray, eps=1e-12) -> np.ndarray:
    mean_p = mc_probs.mean(axis=0)
    H_mean = predictive_entropy(mean_p, eps=eps)
    H_each = predictive_entropy(mc_probs, eps=eps)
    return H_mean - H_each.mean(axis=0)

def compute_boundary_prob_from_plesion(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy)
    g = g / (g.max() + 1e-12)
    return g.astype(np.float32)

def transition_width_map_mm(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy) + 1e-12
    width_px = 0.6 / g
    width_mm = width_px * MM_PER_PIXEL
    return np.clip(width_mm, 0.0, 10.0).astype(np.float32)

# =========================================================
# Losses (Seg + Cls)
# =========================================================
bce = nn.BCEWithLogitsLoss()

def dice_loss_with_logits(logits, targets, eps=1e-6):
    probs = torch.sigmoid(logits)
    num = 2 * (probs * targets).sum(dim=(2, 3))
    den = (probs + targets).sum(dim=(2, 3)) + eps
    return (1 - (num / den)).mean()

def compute_class_weights_from_loader(loader: DataLoader, device: str, num_classes: int = 3):
    counts = np.zeros((num_classes,), dtype=np.int64)
    for b in loader:
        y = b["label"].numpy()
        for c in range(num_classes):
            counts[c] += (y == c).sum()
    counts = np.maximum(counts, 1)
    inv = 1.0 / counts.astype(np.float32)
    w = inv / inv.sum() * num_classes
    return torch.tensor(w, dtype=torch.float32, device=device)

class FocalLoss(nn.Module):
    """Multi-class focal loss."""
    def __init__(self, gamma: float = 2.0, alpha: torch.Tensor | None = None, reduction: str = "mean"):
        super().__init__()
        self.gamma = gamma
        self.alpha = alpha  # tensor shape [C] or None
        self.reduction = reduction

    def forward(self, logits: torch.Tensor, targets: torch.Tensor):
        logp = torch.log_softmax(logits, dim=1)                  # [B,C]
        p = torch.exp(logp)                                      # [B,C]
        pt = p.gather(1, targets.view(-1, 1)).squeeze(1)         # [B]
        logpt = logp.gather(1, targets.view(-1, 1)).squeeze(1)   # [B]
        loss = -((1 - pt) ** self.gamma) * logpt

        if self.alpha is not None:
            at = self.alpha.gather(0, targets)
            loss = loss * at

        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss

# =========================================================
# Config
# =========================================================
@dataclass
class TrainConfig:
    image_size: int = 256
    batch_size: int = 8
    epochs: int = 60

    # base lr (pretrain stage)
    lr: float = 3e-4
    weight_decay: float = 1e-2

    num_workers: int = 2
    grad_clip: float = 1.0
    val_ratio: float = 0.15
    seed: int = 42
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp: bool = True

    # task weights
    seg_w: float = 1.0
    cls_w: float = 0.5

    # speed controls
    steps_per_epoch: int = 350
    mc_samples: int = 8
    mc_samples_fast: int = 2
    mc_every: int = 5
    val_max_batches_seg: int = 120
    val_max_batches_cls: int = 120

    # early stopping
    patience: int = 10
    min_delta: float = 1e-4

    # resume
    resume: bool = True
    resume_prefer_best: bool = True

    # ---- fine-tune switch ----
    finetune: bool = True
    finetune_lr: float = 1e-4       # ✅ LR 降一档（也可 5e-5）
    finetune_cls_w: float = 0.30    # ✅ cls_w 降一点

    # ---- classification loss choice ----
    # "weighted_ce" or "focal"
    cls_loss: str = "focal"
    focal_gamma: float = 2.0

cfg = TrainConfig()
seed_everything(cfg.seed)

# =========================================================
# Split helper
# =========================================================
def make_split(dataset: Dataset, val_ratio: float, seed: int):
    n = len(dataset)
    n_val = int(round(n * val_ratio))
    n_tr = n - n_val
    g = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_tr, n_val], generator=g)

# =========================================================
# Build datasets/loaders
# =========================================================
seg_all = ISICSegDataset(SEG_IMAGE_DIR, SEG_MASK_DIR, image_size=cfg.image_size)
cls_all = ISICClsDataset(CLS_IMAGE_DIR, CLS_LABEL_CSV, image_size=cfg.image_size)

seg_tr, seg_va = make_split(seg_all, cfg.val_ratio, cfg.seed)
cls_tr, cls_va = make_split(cls_all, cfg.val_ratio, cfg.seed)

common_loader_kwargs = dict(
    num_workers=cfg.num_workers,
    pin_memory=True,
    persistent_workers=True if cfg.num_workers > 0 else False,
    prefetch_factor=2 if cfg.num_workers > 0 else None
)

seg_train_loader = DataLoader(seg_tr, batch_size=cfg.batch_size, shuffle=True, **common_loader_kwargs)
seg_val_loader   = DataLoader(seg_va, batch_size=1, shuffle=False, **common_loader_kwargs)

cls_train_loader = DataLoader(cls_tr, batch_size=cfg.batch_size, shuffle=True, **common_loader_kwargs)
cls_val_loader   = DataLoader(cls_va, batch_size=cfg.batch_size, shuffle=False, **common_loader_kwargs)

# =========================================================
# Model / Optim / Scheduler / AMP scaler
# =========================================================
model = HybridMTLModel(base_ch=32, num_classes=3, seg_dropout=0.10).to(cfg.device)

# --- classification weights from training set ---
cls_weights = compute_class_weights_from_loader(cls_train_loader, cfg.device, num_classes=3)

# --- choose classification loss ---
if cfg.cls_loss.lower() == "weighted_ce":
    cls_criterion = nn.CrossEntropyLoss(weight=cls_weights)
elif cfg.cls_loss.lower() == "focal":
    cls_criterion = FocalLoss(gamma=cfg.focal_gamma, alpha=cls_weights, reduction="mean")
else:
    raise ValueError("cfg.cls_loss must be 'weighted_ce' or 'focal'")

# --- optimizer / scheduler ---
optimizer = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.epochs)

# --- NEW AMP API ---
scaler = torch.amp.GradScaler("cuda", enabled=(cfg.amp and cfg.device.startswith("cuda")))

# =========================================================
# Checkpoint save/load (SAVE scheduler + LOAD scheduler)
# =========================================================
def save_checkpoint(path: Path, model: nn.Module, optimizer, scheduler, epoch: int, metrics: dict):
    torch.save({
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict() if optimizer is not None else None,
        "scheduler": scheduler.state_dict() if scheduler is not None else None,
        "metrics": metrics,
        "time": time.time(),
        "mm_per_pixel": MM_PER_PIXEL,
        "field_size_mm": FIELD_SIZE_MM,
        "image_size": IMAGE_SIZE,
        "cfg": vars(cfg),
    }, str(path))

def try_resume():
    if not cfg.resume:
        return 1, -1e18, -1

    best_path = CHECKPOINT_DIR / "best.pt"
    last_path = latest_epoch_ckpt(CHECKPOINT_DIR)
    pick = None

    if cfg.resume_prefer_best and best_path.exists():
        pick = best_path
    elif last_path is not None:
        pick = last_path
    elif best_path.exists():
        pick = best_path

    if pick is None:
        print("[Resume] No checkpoint found. Start from epoch 1.")
        return 1, -1e18, -1

    ckpt = torch.load(pick, map_location=cfg.device)
    model.load_state_dict(ckpt["model"], strict=True)

    if ckpt.get("optimizer") is not None:
        try:
            optimizer.load_state_dict(ckpt["optimizer"])
        except Exception as e:
            print("[Resume] Optimizer state not loaded (ok). Reason:", str(e))

    # ✅ load scheduler state (fix warning)
    if ckpt.get("scheduler") is not None:
        try:
            scheduler.load_state_dict(ckpt["scheduler"])
        except Exception as e:
            print("[Resume] Scheduler state not loaded (ok). Reason:", str(e))

    start_epoch = int(ckpt.get("epoch", 0)) + 1
    best_score = float(ckpt.get("metrics", {}).get("val_score", -1e18))
    best_epoch = int(ckpt.get("epoch", -1))

    print(f"[Resume] Loaded: {pick.name} | start_epoch={start_epoch} | best_score={best_score:.4f} (ep{best_epoch})")
    return start_epoch, best_score, best_epoch

# =========================================================
# Train / Val
# =========================================================
def train_one_epoch(epoch: int):
    model.train()
    seg_iter = iter(seg_train_loader)
    cls_iter = iter(cls_train_loader)

    steps = cfg.steps_per_epoch
    running = {"loss": 0.0, "loss_seg": 0.0, "loss_cls": 0.0}

    pbar = tqdm(range(steps), desc=f"[Train] Epoch {epoch}", leave=False)
    for i in pbar:
        optimizer.zero_grad(set_to_none=True)

        try:
            bseg = next(seg_iter)
        except StopIteration:
            seg_iter = iter(seg_train_loader)
            bseg = next(seg_iter)

        try:
            bcls = next(cls_iter)
        except StopIteration:
            cls_iter = iter(cls_train_loader)
            bcls = next(cls_iter)

        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            # seg
            x = bseg["image"].to(cfg.device, non_blocking=True)
            y = bseg["mask"].to(cfg.device, non_blocking=True)
            seg_logit, _ = model(x)
            loss_seg = bce(seg_logit, y) + dice_loss_with_logits(seg_logit, y)

            # cls
            x2 = bcls["image"].to(cfg.device, non_blocking=True)
            y2 = bcls["label"].to(cfg.device, non_blocking=True)
            _, cls_logit = model(x2)
            loss_cls = cls_criterion(cls_logit, y2)

            loss = cfg.seg_w * loss_seg + cfg.cls_w * loss_cls

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        scaler.step(optimizer)
        scaler.update()

        running["loss"] += float(loss.item())
        running["loss_seg"] += float(loss_seg.item())
        running["loss_cls"] += float(loss_cls.item())

        pbar.set_postfix({
            "loss": running["loss"]/(i+1),
            "seg": running["loss_seg"]/(i+1),
            "cls": running["loss_cls"]/(i+1),
            "lr": scheduler.get_last_lr()[0]
        })

    # ✅ safe scheduler order: optimizer steps already happened inside loop
    scheduler.step()
    return {k: v/steps for k, v in running.items()}

@torch.no_grad()
def validate_segmentation_mc(epoch: int, export_n: int = 2):
    do_full = (epoch % cfg.mc_every == 0)
    mc_T = cfg.mc_samples if do_full else cfg.mc_samples_fast

    enable_dropout_only(model)

    dices, bfs, hd95s_px = [], [], []
    entropies, mis = [], []

    export_dir = EXPORT_DIR / f"epoch_{epoch:03d}"
    export_dir.mkdir(parents=True, exist_ok=True)

    for idx, b in enumerate(tqdm(seg_val_loader, desc=f"[Val-Seg MC(T={mc_T})] Epoch {epoch}", leave=False)):
        if idx >= cfg.val_max_batches_seg:
            break

        x = b["image"].to(cfg.device)
        y = b["mask"].to(cfg.device)

        mc_probs = []
        for _ in range(mc_T):
            with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
                seg_logit, _ = model(x)
                prob = torch.sigmoid(seg_logit).squeeze(0).squeeze(0)
            if prob.ndim != 2:
                continue
            mc_probs.append(prob.detach().float().cpu().numpy())

        if len(mc_probs) == 0:
            continue

        mc_probs = np.stack(mc_probs, axis=0)  # [T,H,W]
        p_lesion = mc_probs.mean(axis=0).astype(np.float32)

        pred_mask = (p_lesion > 0.5).astype(np.uint8)
        gt_mask = (y.squeeze().detach().cpu().numpy() > 0.5).astype(np.uint8)

        inter = (pred_mask & gt_mask).sum()
        union = pred_mask.sum() + gt_mask.sum()
        dice = (2 * inter / (union + 1e-6)) if union > 0 else 1.0
        dices.append(float(dice))

        bfs.append(bf_score(pred_mask, gt_mask, tol_px=2))
        hd95s_px.append(hd95(pred_mask, gt_mask))

        ent_map = predictive_entropy(p_lesion)
        mi_map = mutual_information(mc_probs) if mc_T > 1 else np.zeros_like(ent_map)

        entropies.append(float(ent_map.mean()))
        mis.append(float(mi_map.mean()))

        if do_full and idx < export_n:
            p_boundary = compute_boundary_prob_from_plesion(p_lesion)
            tw_mm = transition_width_map_mm(p_lesion)

            np.save(export_dir / f"val{idx:03d}_p_lesion.npy", p_lesion)
            np.save(export_dir / f"val{idx:03d}_p_boundary.npy", p_boundary)
            np.save(export_dir / f"val{idx:03d}_entropy.npy", ent_map.astype(np.float32))
            np.save(export_dir / f"val{idx:03d}_mi.npy", mi_map.astype(np.float32))
            np.save(export_dir / f"val{idx:03d}_transition_width_mm.npy", tw_mm)

            save_map_with_mm_axis(export_dir / f"val{idx:03d}_p_lesion.png", p_lesion, "p_lesion (mean prob)")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_p_boundary.png", p_boundary, "p_boundary (|∇p| normalized)")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_entropy.png", ent_map, "Predictive entropy")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_mi.png", mi_map, "Mutual information")
            save_map_with_mm_axis(export_dir / f"val{idx:03d}_transition_width_mm.png", tw_mm, "Transition width (mm)")

    val_dice = float(np.mean(dices)) if dices else 0.0
    val_bf   = float(np.mean(bfs)) if bfs else 0.0
    val_hd95_mm = float(np.mean(hd95s_px) * MM_PER_PIXEL) if hd95s_px else 0.0
    mean_ent = float(np.mean(entropies)) if entropies else 0.0
    mean_mi  = float(np.mean(mis)) if mis else 0.0

    return {
        "val_dice": val_dice,
        "val_bf": val_bf,
        "val_hd95_mm": val_hd95_mm,
        "val_pred_entropy_mean": mean_ent,
        "val_mi_mean": mean_mi,
        "mc_T": mc_T,
        "mc_full": bool(do_full),
    }

@torch.no_grad()
def validate_classification(epoch: int):
    model.eval()
    all_logits, all_labels = [], []

    for bi, b in enumerate(tqdm(cls_val_loader, desc=f"[Val-Cls] Epoch {epoch}", leave=False)):
        if bi >= cfg.val_max_batches_cls:
            break
        x = b["image"].to(cfg.device)
        y = b["label"].to(cfg.device)

        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            _, logits = model(x)

        all_logits.append(logits.detach().float().cpu().numpy())
        all_labels.append(y.detach().cpu().numpy())

    if not all_logits:
        return {"val_acc": 0.0, "val_ece": 0.0, "val_aurc": 0.0, "val_nll": 0.0, "val_brier": 0.0}

    logits = np.concatenate(all_logits, axis=0)
    labels = np.concatenate(all_labels, axis=0)

    probs = softmax_np(logits)
    pred = probs.argmax(axis=1)
    acc = float((pred == labels).mean())

    ece = ece_score(probs, labels, n_bins=15)
    nll = nll_score(probs, labels)
    br  = brier_score(probs, labels)

    unc = predictive_entropy(probs)
    cov, risk = risk_coverage_curve(probs, labels, unc)
    A = aurc(cov, risk)

    export_dir = EXPORT_DIR / f"epoch_{epoch:03d}"
    export_dir.mkdir(parents=True, exist_ok=True)

    np.save(export_dir / "risk_coverage_coverage.npy", cov.astype(np.float32))
    np.save(export_dir / "risk_coverage_risk.npy", risk.astype(np.float32))

    plt.figure(figsize=(4, 3), dpi=150)
    plt.plot(cov, risk)
    plt.xlabel("Coverage")
    plt.ylabel("Risk (error rate)")
    plt.title(f"Risk–Coverage (AURC={A:.4f})")
    plt.tight_layout()
    plt.savefig(export_dir / "risk_coverage_curve.png", bbox_inches="tight")
    plt.close()

    return {"val_acc": acc, "val_ece": float(ece), "val_aurc": float(A), "val_nll": float(nll), "val_brier": float(br)}

# =========================================================
# Early stopping
# =========================================================
class EarlyStopper:
    def __init__(self, patience=10, min_delta=1e-4):
        self.patience = patience
        self.min_delta = min_delta
        self.best = -1e18
        self.bad = 0
    def step(self, score: float) -> bool:
        if score > self.best + self.min_delta:
            self.best = score
            self.bad = 0
            return False
        self.bad += 1
        return self.bad >= self.patience

def composite_val_score(seg_m: dict, cls_m: dict) -> float:
    dice = seg_m["val_dice"]
    bf = seg_m["val_bf"]
    hd = seg_m["val_hd95_mm"]
    acc = cls_m["val_acc"]
    ece = cls_m["val_ece"]
    aurc_v = cls_m["val_aurc"]
    nll = cls_m["val_nll"]
    brier = cls_m["val_brier"]

    score = (
        2.0 * dice +
        1.0 * bf +
        1.0 * acc -
        0.25 * math.log(1.0 + hd) -
        0.5 * ece -
        0.5 * aurc_v -
        0.1 * nll -
        0.1 * brier
    )
    return float(score)

stopper = EarlyStopper(patience=cfg.patience, min_delta=cfg.min_delta)

# =========================================================
# Clear cache
# =========================================================
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# =========================================================
# Resume
# =========================================================
start_epoch, best_score, best_epoch = try_resume()

# =========================================================
# Fine-tune switch:
# - When resuming (usually plateau), you can enable finetune config
# =========================================================
if cfg.finetune:
    # ✅ LR down
    for pg in optimizer.param_groups:
        pg["lr"] = cfg.finetune_lr
    # ✅ cls_w down
    cfg.cls_w = cfg.finetune_cls_w
    print(f"[Finetune] Enabled: lr={cfg.finetune_lr} cls_w={cfg.cls_w} cls_loss={cfg.cls_loss}")

# =========================================================
# Train loop
# =========================================================
for epoch in range(start_epoch, cfg.epochs + 1):
    train_m = train_one_epoch(epoch)
    seg_m = validate_segmentation_mc(epoch, export_n=2)
    cls_m = validate_classification(epoch)

    val_score = composite_val_score(seg_m, cls_m)

    metrics = {**train_m, **seg_m, **cls_m,
               "val_score": val_score,
               "epoch": epoch,
               "mm_per_pixel": MM_PER_PIXEL,
               "field_size_mm": FIELD_SIZE_MM,
               "image_size": IMAGE_SIZE,
               "amp": cfg.amp,
               "steps_per_epoch": cfg.steps_per_epoch,
               "cls_w": cfg.cls_w,
               "lr": optimizer.param_groups[0]["lr"],
               "cls_loss": cfg.cls_loss}

    # log
    with open(CHECKPOINT_DIR.parent / "train_log.jsonl", "a", encoding="utf-8") as f:
        f.write(pd.Series(metrics).to_json() + "\n")

    save_checkpoint(CHECKPOINT_DIR / f"epoch_{epoch:03d}.pt", model, optimizer, scheduler, epoch, metrics)

    if val_score > best_score:
        best_score = val_score
        best_epoch = epoch
        save_checkpoint(CHECKPOINT_DIR / "best.pt", model, optimizer, scheduler, epoch, metrics)

    print(
        f"[Epoch {epoch:03d}] "
        f"loss={train_m['loss']:.4f} | "
        f"Dice={seg_m['val_dice']:.4f} BF={seg_m['val_bf']:.4f} HD95(mm)={seg_m['val_hd95_mm']:.3f} "
        f"(MC T={seg_m['mc_T']}, full={seg_m['mc_full']}) | "
        f"Acc={cls_m['val_acc']:.4f} ECE={cls_m['val_ece']:.4f} AURC={cls_m['val_aurc']:.4f} "
        f"NLL={cls_m['val_nll']:.4f} Brier={cls_m['val_brier']:.4f} | "
        f"val_score={val_score:.4f} best={best_score:.4f} (ep{best_epoch}) | "
        f"lr={optimizer.param_groups[0]['lr']:.1e} cls_w={cfg.cls_w:.2f}"
    )

    if stopper.step(val_score):
        print(f"Early stopping at epoch {epoch}. Best epoch={best_epoch}, best_score={best_score:.4f}")
        break

print(f"Training finished. Best epoch: {best_epoch}, best score: {best_score:.4f}")
print(f"Exports saved to: {EXPORT_DIR}")




[Resume] Loaded: best.pt | start_epoch=31 | best_score=2.6856 (ep30)
[Finetune] Enabled: lr=0.0001 cls_w=0.3 cls_loss=focal




[Epoch 031] loss=0.2164 | Dice=0.8469 BF=0.4001 HD95(mm)=1.411 (MC T=2, full=False) | Acc=0.7865 ECE=0.0320 AURC=0.0575 NLL=0.4625 Brier=0.2820 | val_score=2.5412 best=2.6856 (ep30) | lr=1.0e-04 cls_w=0.30




[Epoch 032] loss=0.2102 | Dice=0.8541 BF=0.4268 HD95(mm)=1.382 (MC T=2, full=False) | Acc=0.7719 ECE=0.0416 AURC=0.0634 NLL=0.4975 Brier=0.3037 | val_score=2.5572 best=2.6856 (ep30) | lr=1.0e-04 cls_w=0.30




[Epoch 033] loss=0.2039 | Dice=0.8514 BF=0.4127 HD95(mm)=1.324 (MC T=2, full=False) | Acc=0.7635 ECE=0.0482 AURC=0.0681 NLL=0.5441 Brier=0.3256 | val_score=2.5230 best=2.6856 (ep30) | lr=9.9e-05 cls_w=0.30




[Epoch 034] loss=0.2009 | Dice=0.8457 BF=0.4180 HD95(mm)=1.379 (MC T=2, full=False) | Acc=0.7260 ECE=0.0490 AURC=0.0847 NLL=0.5690 Brier=0.3439 | val_score=2.4606 best=2.6856 (ep30) | lr=9.9e-05 cls_w=0.30




[Epoch 035] loss=0.1957 | Dice=0.8485 BF=0.3853 HD95(mm)=1.389 (MC T=8, full=True) | Acc=0.7427 ECE=0.0595 AURC=0.0774 NLL=0.5853 Brier=0.3436 | val_score=2.4459 best=2.6856 (ep30) | lr=9.8e-05 cls_w=0.30




[Epoch 036] loss=0.1928 | Dice=0.8534 BF=0.4336 HD95(mm)=1.328 (MC T=2, full=False) | Acc=0.7490 ECE=0.0659 AURC=0.0828 NLL=0.5297 Brier=0.3232 | val_score=2.5184 best=2.6856 (ep30) | lr=9.8e-05 cls_w=0.30




[Epoch 037] loss=0.1899 | Dice=0.8535 BF=0.4269 HD95(mm)=1.334 (MC T=2, full=False) | Acc=0.7500 ECE=0.0426 AURC=0.0801 NLL=0.5273 Brier=0.3214 | val_score=2.5258 best=2.6856 (ep30) | lr=9.7e-05 cls_w=0.30




[Epoch 038] loss=0.1866 | Dice=0.8595 BF=0.4113 HD95(mm)=1.261 (MC T=2, full=False) | Acc=0.7198 ECE=0.0501 AURC=0.0922 NLL=0.6242 Brier=0.3673 | val_score=2.4759 best=2.6856 (ep30) | lr=9.6e-05 cls_w=0.30




[Epoch 039] loss=0.1777 | Dice=0.8526 BF=0.4222 HD95(mm)=1.390 (MC T=2, full=False) | Acc=0.7448 ECE=0.0537 AURC=0.0777 NLL=0.5420 Brier=0.3291 | val_score=2.5015 best=2.6856 (ep30) | lr=9.5e-05 cls_w=0.30




[Epoch 040] loss=0.1811 | Dice=0.8584 BF=0.4269 HD95(mm)=1.260 (MC T=8, full=True) | Acc=0.7229 ECE=0.0445 AURC=0.0874 NLL=0.5670 Brier=0.3456 | val_score=2.5056 best=2.6856 (ep30) | lr=9.3e-05 cls_w=0.30




[Epoch 041] loss=0.1778 | Dice=0.8513 BF=0.3975 HD95(mm)=1.301 (MC T=2, full=False) | Acc=0.7500 ECE=0.0347 AURC=0.0748 NLL=0.5853 Brier=0.3442 | val_score=2.4942 best=2.6856 (ep30) | lr=9.2e-05 cls_w=0.30




KeyboardInterrupt: 

In [None]:
#递归遍历整个目录（含子文件夹），把所有 .npy 转成同名 .csv，并放回原目录，适用于 Google Colab + Drive

import os
import numpy as np
import pandas as pd

ROOT_DIR = "/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net"

def npy_to_csv(npy_path):
    data = np.load(npy_path)

    # 处理不同维度
    if data.ndim == 1:
        df = pd.DataFrame(data, columns=["value"])
    elif data.ndim == 2:
        df = pd.DataFrame(data)
    else:
        # 高维数据：flatten 成 (N, C)
        data_flat = data.reshape(data.shape[0], -1)
        df = pd.DataFrame(data_flat)

    csv_path = npy_path.replace(".npy", ".csv")
    df.to_csv(csv_path, index=False)
    print(f"✔ Converted: {npy_path} → {csv_path}")

# 递归遍历
for root, _, files in os.walk(ROOT_DIR):
    for file in files:
        if file.endswith(".npy"):
            npy_to_csv(os.path.join(root, file))

print("✅ All .npy files have been converted to .csv")


✔ Converted: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_001/risk_coverage_coverage.npy → /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_001/risk_coverage_coverage.csv
✔ Converted: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_001/risk_coverage_risk.npy → /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_001/risk_coverage_risk.csv
✔ Converted: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_002/risk_coverage_coverage.npy → /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_002/risk_coverage_coverage.csv
✔ Converted: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_002/risk_coverage_risk.npy → /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports/epoch_002/risk_coverage_risk.csv
✔ Converted: /content/drive/My Drive/seg_cla

In [None]:
import os
import pandas as pd

ROOT_DIR = "/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net"

dfs = []

for root, _, files in os.walk(ROOT_DIR):
    for file in files:
        if file.endswith(".jsonl"):
            jsonl_path = os.path.join(root, file)
            try:
                df = pd.read_json(jsonl_path, lines=True)
                df["source_file"] = jsonl_path  # 标记来源
                dfs.append(df)
                print(f"✔ Loaded: {jsonl_path}")
            except Exception as e:
                print(f"✘ Failed to load {jsonl_path}: {e}")

# 合并所有日志
if len(dfs) > 0:
    all_logs_df = pd.concat(dfs, ignore_index=True)
    print("✅ DataFrame created successfully")
    display(all_logs_df.head())
else:
    print("⚠️ No .jsonl files found")


✔ Loaded: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/train_log.jsonl
✅ DataFrame created successfully


Unnamed: 0,loss,loss_seg,loss_cls,val_dice,val_bf,val_hd95_mm,val_pred_entropy_mean,val_mi_mean,mc_T,mc_full,...,epoch,mm_per_pixel,field_size_mm,image_size,amp,steps_per_epoch,cls_w,lr,cls_loss,source_file
0,1.1919,0.90773,0.568339,0.755712,0.258117,2.076796,53.147146,0.174274,2,False,...,1,0.058594,15,256,True,350,,,,/content/drive/My Drive/seg_class_skin_cancer/...
1,0.925326,0.653624,0.543404,0.770972,0.270507,1.933957,35.667731,0.198982,2,False,...,2,0.058594,15,256,True,350,,,,/content/drive/My Drive/seg_class_skin_cancer/...
2,0.780355,0.526947,0.506816,0.784604,0.309778,2.037588,28.413414,0.132802,2,False,...,3,0.058594,15,256,True,350,,,,/content/drive/My Drive/seg_class_skin_cancer/...
3,0.725394,0.476267,0.498256,0.756719,0.256083,2.018875,16.868947,0.100238,2,False,...,4,0.058594,15,256,True,350,,,,/content/drive/My Drive/seg_class_skin_cancer/...
4,0.694706,0.446966,0.49548,0.80143,0.353455,1.877534,15.144863,0.11554,2,False,...,4,0.058594,15,256,True,350,,,,/content/drive/My Drive/seg_class_skin_cancer/...




In [None]:
# -*- coding: utf-8 -*-
# =========================================================
# VALIDATION-ONLY Inference + Export (NO TRAINING)
# Hybrid CNN–ViT–U-Net (Seg + Cls)
# Uses OFFICIAL ISIC 2018 VALIDATION folders (Task1 + Task3)
# Physical size: 15mm × 15mm → 256 × 256 px
#
# Exports:
#  Seg: p_lesion, p_boundary, entropy, MI, transition_width_mm (npy + png with mm axis)
#  Cls: logits/probs/labels/uncertainty(entropy) + risk-coverage curve (AURC)
#
# Speed:
#  - workers=2
#  - AMP autocast
# =========================================================

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import math, random, re
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from scipy import ndimage as ndi


# -----------------------------
# Performance settings
# -----------------------------
cv2.setNumThreads(0)
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")


# -----------------------------
# Your VALIDATION paths (✅ use these)
# -----------------------------
SEG_IMAGE_DIR = Path("/content/drive/My Drive/isic2018/Validation")
SEG_MASK_DIR  = Path("/content/drive/My Drive/isic2018/ValidationTruth")

CLS_IMAGE_DIR = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Validation_Input".strip())
CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Validation_GroundTruth/ISIC2018_Task3_Validation_GroundTruth.csv".strip())

CHECKPOINT_DIR = Path("/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/checkpoints".strip())
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

EXPORT_DIR = CHECKPOINT_DIR.parent / "exports_validation"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)


# -----------------------------
# Physical size
# -----------------------------
FIELD_SIZE_MM = 15.0
IMAGE_SIZE = 256
MM_PER_PIXEL = FIELD_SIZE_MM / IMAGE_SIZE


# -----------------------------
# Utils
# -----------------------------
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)

def pick_checkpoint(ckpt_dir: Path) -> Path:
    best_path = ckpt_dir / "best.pt"
    if best_path.exists():
        return best_path

    pts = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not pts:
        raise FileNotFoundError(f"No checkpoint found in: {ckpt_dir}")

    def key(p: Path):
        m = re.search(r"epoch_(\d+)\.pt$", p.name)
        return int(m.group(1)) if m else -1

    pts = sorted(pts, key=key)
    return pts[-1]

def save_map_with_mm_axis(path_png: Path, arr: np.ndarray, title: str):
    arr = np.asarray(arr)
    if arr.ndim != 2:
        print(f"[WARN] Skip saving {title}: invalid shape={arr.shape}")
        return

    plt.figure(figsize=(5, 4), dpi=150)
    plt.imshow(arr, extent=(0, FIELD_SIZE_MM, FIELD_SIZE_MM, 0))
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.xlabel("x (mm)")
    plt.ylabel("y (mm)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path_png, bbox_inches="tight")
    plt.close()


# -----------------------------
# Dataset helpers
# -----------------------------
def list_images_any_ext(folder: Path):
    exts = ["*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG"]
    paths = []
    for e in exts:
        paths.extend(folder.glob(e))
    return sorted(paths)


# =========================================================
# Dataset — Segmentation (Task 1 VALIDATION)
# =========================================================
class ISICSegValDataset(Dataset):
    """
    Uses official validation folders.
    Mask naming varies by pack, so we implement:
      - exact stem match: <stem>_segmentation.png (common in training)
      - or <stem>.png in ValidationTruth (common in official val)
    """
    def __init__(self, image_dir: Path, mask_dir: Path, image_size: int = 256):
        self.image_paths = list_images_any_ext(image_dir)
        self.mask_dir = mask_dir
        self.image_size = image_size

        self.pairs = []
        for p in self.image_paths:
            stem = p.stem

            cand1 = mask_dir / f"{stem}_segmentation.png"
            cand2 = mask_dir / f"{stem}.png"
            cand3 = mask_dir / f"{stem}_segmentation.PNG"
            cand4 = mask_dir / f"{stem}.PNG"

            if cand1.exists():
                mp = cand1
            elif cand2.exists():
                mp = cand2
            elif cand3.exists():
                mp = cand3
            elif cand4.exists():
                mp = cand4
            else:
                # skip if no mask
                continue
            self.pairs.append((p, mp))

        if len(self.pairs) == 0:
            raise ValueError(f"No (image,mask) pairs found. Check mask naming in: {mask_dir}")

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask: {mask_path}")

        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

        img = img.astype(np.float32) / 255.0
        mask = (mask > 0).astype(np.float32)

        img_t = torch.from_numpy(img).permute(2, 0, 1)     # [3,H,W]
        mask_t = torch.from_numpy(mask).unsqueeze(0)       # [1,H,W]
        return {"image": img_t, "mask": mask_t, "id": img_path.stem}


# =========================================================
# Dataset — Classification (Task 3 VALIDATION one-hot CSV)
# =========================================================
class ISICClsValDataset(Dataset):
    """
    Expects one-hot CSV:
      columns: image, MEL, NV, BCC, AKIEC, BKL, DF, VASC
    Converts to 3-class:
      malignant: MEL/BCC -> 2
      intermediate: AKIEC -> 1
      benign: NV/BKL/DF/VASC -> 0
    """
    def __init__(self, image_dir: Path, label_csv: Path, image_size: int = 256):
        self.image_dir = image_dir
        self.df = pd.read_csv(label_csv)
        self.df.columns = self.df.columns.str.strip()
        self.image_size = image_size

        onehot_classes = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
        if ("image" not in self.df.columns) or (not all(c in self.df.columns for c in onehot_classes)):
            raise ValueError(f"CSV missing expected columns. Got: {list(self.df.columns)}")

        self.records = []
        for _, row in self.df.iterrows():
            img_id = str(row["image"]).strip()

            malignant = (row["MEL"] == 1) or (row["BCC"] == 1)
            intermediate = (row["AKIEC"] == 1)
            benign = (row["NV"] == 1) or (row["BKL"] == 1) or (row["DF"] == 1) or (row["VASC"] == 1)

            if malignant:
                label = 2
            elif intermediate:
                label = 1
            elif benign:
                label = 0
            else:
                continue

            self.records.append((img_id, label))

        if len(self.records) == 0:
            raise ValueError("No valid records in validation CSV.")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        img_id, label = self.records[idx]

        # support jpg/png
        jpg = self.image_dir / f"{img_id}.jpg"
        png = self.image_dir / f"{img_id}.png"
        jpeg = self.image_dir / f"{img_id}.jpeg"

        img_path = jpg if jpg.exists() else (png if png.exists() else (jpeg if jpeg.exists() else None))
        if img_path is None:
            raise FileNotFoundError(f"Cannot find image for id={img_id} in {self.image_dir}")

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32) / 255.0
        img_t = torch.from_numpy(img).permute(2, 0, 1)

        return {"image": img_t, "label": torch.tensor(label, dtype=torch.long), "id": img_id}


# =========================================================
# Model blocks (same as training)
# =========================================================
class ConvBNAct(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU()
        )
    def forward(self, x): return self.net(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(ConvBNAct(in_ch, out_ch), ConvBNAct(out_ch, out_ch))
    def forward(self, x): return self.net(x)

class ConvEncoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        self.e1 = ConvBlock(in_ch, base_ch)
        self.e2 = ConvBlock(base_ch, base_ch * 2)
        self.e3 = ConvBlock(base_ch * 2, base_ch * 4)
        self.e4 = ConvBlock(base_ch * 4, base_ch * 8)
        self.pool = nn.MaxPool2d(2)
    def forward(self, x):
        f1 = self.e1(x)
        f2 = self.e2(self.pool(f1))
        f3 = self.e3(self.pool(f2))
        f4 = self.e4(self.pool(f3))
        return [f1, f2, f3, f4]

class ViTBottleneck(nn.Module):
    def __init__(self, dim, depth=2, heads=4, drop=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim*4,
            dropout=drop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, depth)
    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.encoder(tokens)
        return tokens.transpose(1, 2).reshape(b, c, h, w)

class UNetDecoder(nn.Module):
    def __init__(self, base_ch=32, out_ch=1, dropout_p=0.10):
        super().__init__()
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, 2)
        self.d3  = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, 2)
        self.d2  = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, 2)
        self.d1  = ConvBlock(base_ch*2, base_ch)
        self.drop = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()
        self.out  = nn.Conv2d(base_ch, out_ch, 1)
    def forward(self, feats):
        f1, f2, f3, f4 = feats
        x = self.up3(f4)
        x = self.d3(torch.cat([x, f3], dim=1))
        x = self.up2(x)
        x = self.d2(torch.cat([x, f2], dim=1))
        x = self.up1(x)
        x = self.d1(torch.cat([x, f1], dim=1))
        x = self.drop(x)
        return self.out(x)

class HybridMTLModel(nn.Module):
    def __init__(self, base_ch=32, num_classes=3, seg_dropout=0.10):
        super().__init__()
        self.encoder = ConvEncoder(3, base_ch)
        self.bottleneck = ViTBottleneck(base_ch*8, depth=2, heads=4, drop=0.1)
        self.decoder = UNetDecoder(base_ch, out_ch=1, dropout_p=seg_dropout)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_ch*8, base_ch*8),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(base_ch*8, num_classes),
        )
    def forward(self, x):
        feats = self.encoder(x)
        feats[-1] = self.bottleneck(feats[-1])
        seg_logit = self.decoder(feats)
        cls_logit = self.cls_head(feats[-1])
        return seg_logit, cls_logit


# =========================================================
# Export helpers
# =========================================================
def predictive_entropy(p: np.ndarray, eps=1e-12) -> np.ndarray:
    if p.ndim >= 2 and p.shape[-1] > 1:
        return -(p * np.log(p + eps)).sum(axis=-1)
    return -(p*np.log(p+eps) + (1-p)*np.log(1-p+eps))

def mutual_information(mc_probs: np.ndarray, eps=1e-12) -> np.ndarray:
    mean_p = mc_probs.mean(axis=0)
    H_mean = predictive_entropy(mean_p, eps=eps)
    H_each = predictive_entropy(mc_probs, eps=eps)
    return H_mean - H_each.mean(axis=0)

def compute_boundary_prob_from_plesion(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy)
    g = g / (g.max() + 1e-12)
    return g.astype(np.float32)

def transition_width_map_mm(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy) + 1e-12
    width_px = 0.6 / g
    width_mm = width_px * MM_PER_PIXEL
    return np.clip(width_mm, 0.0, 10.0).astype(np.float32)

def softmax_np(logits: np.ndarray) -> np.ndarray:
    x = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / (e.sum(axis=1, keepdims=True) + 1e-12)

def risk_coverage_curve(probs: np.ndarray, labels: np.ndarray, uncertainty: np.ndarray):
    n = len(labels)
    order = np.argsort(uncertainty)
    probs_s = probs[order]
    labels_s = labels[order]
    pred = probs_s.argmax(axis=1)
    err = (pred != labels_s).astype(np.float32)
    coverages = np.linspace(1/n, 1.0, n)
    risks = np.cumsum(err) / (np.arange(n) + 1)
    return coverages, risks

def aurc(coverages: np.ndarray, risks: np.ndarray) -> float:
    return float(np.trapezoid(risks, coverages))

def enable_dropout_only(model: nn.Module):
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()


# =========================================================
# Config
# =========================================================
@dataclass
class InferCfg:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp: bool = True
    workers: int = 2

    seg_mc_T: int = 8
    seg_batch: int = 1
    cls_batch: int = 8

cfg = InferCfg()


# =========================================================
# Build loaders (NO split; use full validation folders)
# =========================================================
seg_val_ds = ISICSegValDataset(SEG_IMAGE_DIR, SEG_MASK_DIR, image_size=IMAGE_SIZE)
cls_val_ds = ISICClsValDataset(CLS_IMAGE_DIR, CLS_LABEL_CSV, image_size=IMAGE_SIZE)

print(f"[VAL-SEG] pairs found = {len(seg_val_ds)}")
print(f"[VAL-CLS] samples found = {len(cls_val_ds)}")

common_loader_kwargs = dict(
    num_workers=cfg.workers,
    pin_memory=True,
    persistent_workers=True if cfg.workers > 0 else False,
    prefetch_factor=2 if cfg.workers > 0 else None
)

seg_val_loader = DataLoader(seg_val_ds, batch_size=cfg.seg_batch, shuffle=False, **common_loader_kwargs)
cls_val_loader = DataLoader(cls_val_ds, batch_size=cfg.cls_batch, shuffle=False, **common_loader_kwargs)


# =========================================================
# Load model
# =========================================================
model = HybridMTLModel(base_ch=32, num_classes=3, seg_dropout=0.10).to(cfg.device)
ckpt_path = pick_checkpoint(CHECKPOINT_DIR)
ckpt = torch.load(ckpt_path, map_location=cfg.device)
model.load_state_dict(ckpt["model"], strict=True)
ep = int(ckpt.get("epoch", 0))
print(f"[Load] {ckpt_path.name} | epoch={ep}")

OUT_DIR = EXPORT_DIR / f"epoch_{ep:03d}"
OUT_DIR.mkdir(parents=True, exist_ok=True)


# =========================================================
# Export SEG (full validation)
# =========================================================
@torch.no_grad()
def export_seg_full_validation(mc_T: int):
    enable_dropout_only(model)
    out_dir = OUT_DIR / f"seg_val_MC{mc_T}"
    out_dir.mkdir(parents=True, exist_ok=True)

    for i, b in enumerate(tqdm(seg_val_loader, desc=f"[Export-Seg VAL FULL MC(T={mc_T})]", leave=True)):
        x = b["image"].to(cfg.device)
        sid = b["id"][0] if isinstance(b["id"], (list, tuple)) else str(b["id"])

        mc_probs = []
        for _ in range(mc_T):
            with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
                seg_logit, _ = model(x)
                prob = torch.sigmoid(seg_logit).squeeze(0).squeeze(0)  # [H,W]
            if prob.ndim == 2:
                mc_probs.append(prob.detach().float().cpu().numpy())

        if len(mc_probs) == 0:
            continue

        mc_probs = np.stack(mc_probs, axis=0)
        p_lesion = mc_probs.mean(axis=0).astype(np.float32)
        p_boundary = compute_boundary_prob_from_plesion(p_lesion)
        ent_map = predictive_entropy(p_lesion).astype(np.float32)
        mi_map  = mutual_information(mc_probs).astype(np.float32) if mc_T > 1 else np.zeros_like(ent_map)
        tw_mm   = transition_width_map_mm(p_lesion)

        np.save(out_dir / f"{sid}_p_lesion.npy", p_lesion)
        np.save(out_dir / f"{sid}_p_boundary.npy", p_boundary)
        np.save(out_dir / f"{sid}_entropy.npy", ent_map)
        np.save(out_dir / f"{sid}_mi.npy", mi_map)
        np.save(out_dir / f"{sid}_transition_width_mm.npy", tw_mm)

        save_map_with_mm_axis(out_dir / f"{sid}_p_lesion.png", p_lesion, "p_lesion")
        save_map_with_mm_axis(out_dir / f"{sid}_p_boundary.png", p_boundary, "p_boundary")
        save_map_with_mm_axis(out_dir / f"{sid}_entropy.png", ent_map, "Predictive entropy")
        save_map_with_mm_axis(out_dir / f"{sid}_mi.png", mi_map, "Mutual information")
        save_map_with_mm_axis(out_dir / f"{sid}_transition_width_mm.png", tw_mm, "Transition width (mm)")

    print(f"[Seg Export] done -> {out_dir}")


# =========================================================
# Export CLS (full validation)
# =========================================================
@torch.no_grad()
def export_cls_full_validation():
    model.eval()
    out_dir = OUT_DIR / "cls_val"
    out_dir.mkdir(parents=True, exist_ok=True)

    all_logits, all_labels = [], []
    for b in tqdm(cls_val_loader, desc="[Export-Cls VAL FULL]", leave=True):
        x = b["image"].to(cfg.device)
        y = b["label"].to(cfg.device)
        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            _, logits = model(x)
        all_logits.append(logits.detach().float().cpu().numpy())
        all_labels.append(y.detach().cpu().numpy())

    logits = np.concatenate(all_logits, axis=0) if all_logits else np.zeros((0,3), np.float32)
    labels = np.concatenate(all_labels, axis=0) if all_labels else np.zeros((0,), np.int64)

    probs = softmax_np(logits) if len(labels) else np.zeros_like(logits)
    unc  = predictive_entropy(probs).astype(np.float32) if len(labels) else np.zeros((0,), np.float32)

    np.save(out_dir / "logits.npy", logits.astype(np.float32))
    np.save(out_dir / "probs.npy", probs.astype(np.float32))
    np.save(out_dir / "labels.npy", labels.astype(np.int64))
    np.save(out_dir / "uncertainty_entropy.npy", unc.astype(np.float32))

    if len(labels) > 0:
        cov, risk = risk_coverage_curve(probs, labels, unc)
        A = aurc(cov, risk)
        np.save(out_dir / "risk_coverage_coverage.npy", cov.astype(np.float32))
        np.save(out_dir / "risk_coverage_risk.npy", risk.astype(np.float32))

        plt.figure(figsize=(4, 3), dpi=150)
        plt.plot(cov, risk)
        plt.xlabel("Coverage")
        plt.ylabel("Risk (error rate)")
        plt.title(f"VAL Risk–Coverage (AURC={A:.4f})")
        plt.tight_layout()
        plt.savefig(out_dir / "risk_coverage_curve.png", bbox_inches="tight")
        plt.close()

        pred = probs.argmax(axis=1)
        acc = float((pred == labels).mean())
        print(f"[VAL-Cls] n={len(labels)} acc={acc:.4f} AURC={A:.4f}")

    print(f"[Cls Export] done -> {out_dir}")


# =========================================================
# Run full validation exports
# =========================================================
export_seg_full_validation(mc_T=cfg.seg_mc_T)
export_cls_full_validation()

print(f"✅ Exports saved to: {OUT_DIR}")

[VAL-SEG] pairs found = 100
[VAL-CLS] samples found = 193




[Load] best.pt | epoch=30


[Export-Seg VAL FULL MC(T=8)]:   0%|          | 0/100 [00:00<?, ?it/s]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   1%|          | 1/100 [00:01<02:20,  1.42s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   2%|▏         | 2/100 [00:02<02:03,  1.26s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   3%|▎         | 3/100 [00:03<01:53,  1.17s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   4%|▍         | 4/100 [00:04<01:45,  1.10s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   5%|▌         | 5/100 [00:06<02:06,  1.33s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   6%|▌         | 6/100 [00:09<02:48,  1.79s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   7%|▋         | 7/100 [00:10<02:31,  1.63s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   8%|▊         | 8/100 [00:11<02:27,  1.60s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:   9%|▉         | 9/100 [00:13<02:15,  1.49s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  10%|█         | 10/100 [00:14<02:07,  1.42s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  11%|█         | 11/100 [00:15<02:10,  1.47s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  12%|█▏        | 12/100 [00:17<02:03,  1.41s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  13%|█▎        | 13/100 [00:18<01:59,  1.37s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  14%|█▍        | 14/100 [00:20<02:01,  1.41s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  15%|█▌        | 15/100 [00:21<02:11,  1.54s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  16%|█▌        | 16/100 [00:23<02:06,  1.51s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  17%|█▋        | 17/100 [00:24<01:51,  1.34s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  18%|█▊        | 18/100 [00:25<01:39,  1.22s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  19%|█▉        | 19/100 [00:26<01:39,  1.23s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  20%|██        | 20/100 [00:27<01:32,  1.15s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  21%|██        | 21/100 [00:28<01:28,  1.12s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  22%|██▏       | 22/100 [00:29<01:23,  1.07s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  23%|██▎       | 23/100 [00:30<01:21,  1.06s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  24%|██▍       | 24/100 [00:31<01:18,  1.03s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  25%|██▌       | 25/100 [00:32<01:16,  1.02s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  26%|██▌       | 26/100 [00:33<01:18,  1.05s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  27%|██▋       | 27/100 [00:35<01:27,  1.20s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  28%|██▊       | 28/100 [00:36<01:36,  1.35s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  29%|██▉       | 29/100 [00:37<01:27,  1.23s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  30%|███       | 30/100 [00:38<01:21,  1.16s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  31%|███       | 31/100 [00:39<01:17,  1.12s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  32%|███▏      | 32/100 [00:40<01:15,  1.12s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  33%|███▎      | 33/100 [00:41<01:11,  1.07s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  34%|███▍      | 34/100 [00:42<01:11,  1.08s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  35%|███▌      | 35/100 [00:43<01:08,  1.05s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  36%|███▌      | 36/100 [00:44<01:05,  1.03s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  37%|███▋      | 37/100 [00:45<01:04,  1.02s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  38%|███▊      | 38/100 [00:47<01:05,  1.06s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  39%|███▉      | 39/100 [00:48<01:11,  1.17s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  40%|████      | 40/100 [00:50<01:21,  1.35s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  41%|████      | 41/100 [00:51<01:14,  1.26s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  42%|████▏     | 42/100 [00:52<01:07,  1.17s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  43%|████▎     | 43/100 [00:53<01:05,  1.15s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  44%|████▍     | 44/100 [00:54<01:01,  1.10s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  45%|████▌     | 45/100 [00:55<00:59,  1.08s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  46%|████▌     | 46/100 [00:56<00:57,  1.06s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  47%|████▋     | 47/100 [00:57<00:57,  1.08s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  48%|████▊     | 48/100 [00:58<00:55,  1.07s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  49%|████▉     | 49/100 [00:59<00:53,  1.05s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  50%|█████     | 50/100 [01:01<00:59,  1.19s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  51%|█████     | 51/100 [01:02<01:03,  1.29s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  52%|█████▏    | 52/100 [01:03<01:01,  1.28s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  53%|█████▎    | 53/100 [01:04<00:57,  1.22s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  54%|█████▍    | 54/100 [01:06<00:59,  1.29s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  55%|█████▌    | 55/100 [01:07<00:53,  1.20s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  56%|█████▌    | 56/100 [01:08<00:51,  1.18s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  57%|█████▋    | 57/100 [01:09<00:49,  1.14s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  58%|█████▊    | 58/100 [01:10<00:48,  1.15s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  59%|█████▉    | 59/100 [01:11<00:47,  1.16s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  60%|██████    | 60/100 [01:12<00:44,  1.11s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  61%|██████    | 61/100 [01:14<00:47,  1.21s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  62%|██████▏   | 62/100 [01:15<00:49,  1.29s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  63%|██████▎   | 63/100 [01:17<00:49,  1.34s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  64%|██████▍   | 64/100 [01:18<00:44,  1.24s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  65%|██████▌   | 65/100 [01:19<00:41,  1.18s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  66%|██████▌   | 66/100 [01:20<00:38,  1.13s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  67%|██████▋   | 67/100 [01:21<00:35,  1.09s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  68%|██████▊   | 68/100 [01:22<00:35,  1.12s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  69%|██████▉   | 69/100 [01:23<00:36,  1.18s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  70%|███████   | 70/100 [01:24<00:34,  1.14s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  71%|███████   | 71/100 [01:25<00:32,  1.11s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  72%|███████▏  | 72/100 [01:27<00:38,  1.38s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  73%|███████▎  | 73/100 [01:29<00:39,  1.47s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  74%|███████▍  | 74/100 [01:30<00:35,  1.37s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  75%|███████▌  | 75/100 [01:31<00:31,  1.26s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  76%|███████▌  | 76/100 [01:32<00:29,  1.21s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  77%|███████▋  | 77/100 [01:33<00:26,  1.16s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  78%|███████▊  | 78/100 [01:34<00:24,  1.11s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  79%|███████▉  | 79/100 [01:35<00:23,  1.10s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  80%|████████  | 80/100 [01:37<00:21,  1.08s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  81%|████████  | 81/100 [01:38<00:20,  1.06s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  82%|████████▏ | 82/100 [01:39<00:18,  1.05s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  83%|████████▎ | 83/100 [01:40<00:18,  1.11s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  84%|████████▍ | 84/100 [01:41<00:20,  1.27s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  85%|████████▌ | 85/100 [01:43<00:20,  1.36s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  86%|████████▌ | 86/100 [01:44<00:18,  1.30s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  87%|████████▋ | 87/100 [01:45<00:15,  1.22s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  88%|████████▊ | 88/100 [01:46<00:14,  1.17s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  89%|████████▉ | 89/100 [01:47<00:12,  1.11s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  90%|█████████ | 90/100 [01:48<00:10,  1.07s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  91%|█████████ | 91/100 [01:49<00:09,  1.07s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  92%|█████████▏| 92/100 [01:50<00:08,  1.05s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  93%|█████████▎| 93/100 [01:51<00:07,  1.04s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  94%|█████████▍| 94/100 [01:53<00:07,  1.32s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  95%|█████████▌| 95/100 [01:55<00:06,  1.36s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  96%|█████████▌| 96/100 [01:56<00:05,  1.37s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  97%|█████████▋| 97/100 [01:57<00:03,  1.28s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  98%|█████████▊| 98/100 [01:58<00:02,  1.20s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]:  99%|█████████▉| 99/100 [01:59<00:01,  1.14s/it]

[WARN] Skip saving Predictive entropy: invalid shape=(256,)
[WARN] Skip saving Mutual information: invalid shape=(256,)


[Export-Seg VAL FULL MC(T=8)]: 100%|██████████| 100/100 [02:00<00:00,  1.21s/it]


[Seg Export] done -> /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports_validation/epoch_030/seg_val_MC8


[Export-Cls VAL FULL]: 100%|██████████| 25/25 [00:58<00:00,  2.35s/it]


[VAL-Cls] n=193 acc=0.8394 AURC=0.0434
[Cls Export] done -> /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports_validation/epoch_030/cls_val
✅ Exports saved to: /content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/exports_validation/epoch_030


In [None]:
# training set

# -*- coding: utf-8 -*-
# =========================================================
# TEST-SET Evaluation + Export (NO TRAINING)
# Hybrid CNN–ViT–U-Net (Seg + Cls)
# ISIC 2018 Task 1 (Seg Test) + Task 3 (Cls Test)
# Physical size: 15mm × 15mm → 256 × 256 px
#
# Metrics on TEST:
#  Seg: Dice, BF-score, HD95(mm), mean entropy/MI (optional)
#  Cls: Acc, ECE, Risk–Coverage, AURC, NLL, Brier
#
# Exports:
#  Seg: p_lesion, p_boundary, entropy, MI, transition_width_mm (npy + png with mm axis)
#  Cls: logits/probs/labels/uncertainty(entropy) + risk-coverage curve
#
# Settings:
#  - workers=2
#  - NEW AMP API
#  - load best.pt else latest epoch_XXX.pt
# =========================================================

from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
import math, random, re
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
from scipy import ndimage as ndi


# -----------------------------
# Performance settings
# -----------------------------
cv2.setNumThreads(0)
torch.backends.cudnn.benchmark = True
if torch.cuda.is_available():
    torch.set_float32_matmul_precision("high")


# =========================================================
# TEST paths (✅ your setting)
# =========================================================
SEG_IMAGE_DIR = Path("/content/drive/My Drive/isic2018/Test")
SEG_MASK_DIR  = Path("/content/drive/My Drive/isic2018/TestTruth")

CLS_IMAGE_DIR = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Test_Input".strip())
CLS_LABEL_CSV = Path("/content/drive/My Drive/ISIC_2018_T3/ISIC2018_Task3_Test_GroundTruth/ISIC2018_Task3_Test_GroundTruth.csv".strip())

CHECKPOINT_DIR = Path("/content/drive/My Drive/seg_class_skin_cancer/Conv_encoder_ViT_U-Net/checkpoints".strip())
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

EXPORT_DIR = CHECKPOINT_DIR.parent / "exports_test"
EXPORT_DIR.mkdir(parents=True, exist_ok=True)


# -----------------------------
# Physical size
# -----------------------------
FIELD_SIZE_MM = 15.0
IMAGE_SIZE = 256
MM_PER_PIXEL = FIELD_SIZE_MM / IMAGE_SIZE


# -----------------------------
# Utils
# -----------------------------
def seed_everything(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

seed_everything(42)

def list_images_any_ext(folder: Path):
    exts = ["*.jpg", "*.jpeg", "*.png", "*.JPG", "*.JPEG", "*.PNG"]
    paths = []
    for e in exts:
        paths.extend(folder.glob(e))
    return sorted(paths)

def pick_checkpoint(ckpt_dir: Path) -> Path:
    best_path = ckpt_dir / "best.pt"
    if best_path.exists():
        return best_path
    pts = sorted(ckpt_dir.glob("epoch_*.pt"))
    if not pts:
        raise FileNotFoundError(f"No checkpoint found in: {ckpt_dir}")
    def key(p: Path):
        m = re.search(r"epoch_(\d+)\.pt$", p.name)
        return int(m.group(1)) if m else -1
    pts = sorted(pts, key=key)
    return pts[-1]

def save_map_with_mm_axis(path_png: Path, arr: np.ndarray, title: str):
    arr = np.asarray(arr)
    if arr.ndim != 2:
        print(f"[WARN] Skip saving {title}: invalid shape={arr.shape}")
        return
    plt.figure(figsize=(5, 4), dpi=150)
    plt.imshow(arr, extent=(0, FIELD_SIZE_MM, FIELD_SIZE_MM, 0))
    plt.colorbar(fraction=0.046, pad=0.04)
    plt.xlabel("x (mm)")
    plt.ylabel("y (mm)")
    plt.title(title)
    plt.tight_layout()
    plt.savefig(path_png, bbox_inches="tight")
    plt.close()


# =========================================================
# Dataset — Segmentation (Task 1 TEST)
# =========================================================
class ISICSegTestDataset(Dataset):
    """
    Robust mask matching:
      - <stem>_segmentation.png
      - <stem>.png
      - also supports .PNG
    """
    def __init__(self, image_dir: Path, mask_dir: Path, image_size: int = 256):
        self.image_size = image_size
        image_paths = list_images_any_ext(image_dir)

        pairs = []
        for p in image_paths:
            stem = p.stem
            cand = [
                mask_dir / f"{stem}_segmentation.png",
                mask_dir / f"{stem}.png",
                mask_dir / f"{stem}_segmentation.PNG",
                mask_dir / f"{stem}.PNG",
            ]
            mp = None
            for c in cand:
                if c.exists():
                    mp = c
                    break
            if mp is None:
                continue
            pairs.append((p, mp))
        if len(pairs) == 0:
            raise ValueError(f"No (image,mask) pairs found. Check masks in: {mask_dir}")

        self.pairs = pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        img_path, mask_path = self.pairs[idx]

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Cannot read mask: {mask_path}")

        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        mask = cv2.resize(mask, (self.image_size, self.image_size), interpolation=cv2.INTER_NEAREST)

        img = img.astype(np.float32) / 255.0
        mask = (mask > 0).astype(np.float32)

        img_t = torch.from_numpy(img).permute(2, 0, 1)
        mask_t = torch.from_numpy(mask).unsqueeze(0)
        return {"image": img_t, "mask": mask_t, "id": img_path.stem}


# =========================================================
# Dataset — Classification (Task 3 TEST one-hot CSV)
# =========================================================
class ISICClsTestDataset(Dataset):
    """
    Expects one-hot CSV:
      image, MEL, NV, BCC, AKIEC, BKL, DF, VASC
    Converts to 3-class: benign=0, intermediate=1, malignant=2
    """
    def __init__(self, image_dir: Path, label_csv: Path, image_size: int = 256):
        self.image_dir = image_dir
        self.df = pd.read_csv(label_csv)
        self.df.columns = self.df.columns.str.strip()
        self.image_size = image_size

        onehot = ["MEL", "NV", "BCC", "AKIEC", "BKL", "DF", "VASC"]
        if ("image" not in self.df.columns) or (not all(c in self.df.columns for c in onehot)):
            raise ValueError(f"CSV missing expected columns. Got: {list(self.df.columns)}")

        self.records = []
        for _, row in self.df.iterrows():
            img_id = str(row["image"]).strip()
            malignant = (row["MEL"] == 1) or (row["BCC"] == 1)
            intermediate = (row["AKIEC"] == 1)
            benign = (row["NV"] == 1) or (row["BKL"] == 1) or (row["DF"] == 1) or (row["VASC"] == 1)
            if malignant:
                label = 2
            elif intermediate:
                label = 1
            elif benign:
                label = 0
            else:
                continue
            self.records.append((img_id, label))

        if len(self.records) == 0:
            raise ValueError("No valid records in test CSV.")

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        img_id, label = self.records[idx]
        # support jpg/png/jpeg
        cand = [
            self.image_dir / f"{img_id}.jpg",
            self.image_dir / f"{img_id}.png",
            self.image_dir / f"{img_id}.jpeg",
            self.image_dir / f"{img_id}.JPG",
            self.image_dir / f"{img_id}.PNG",
            self.image_dir / f"{img_id}.JPEG",
        ]
        img_path = None
        for c in cand:
            if c.exists():
                img_path = c
                break
        if img_path is None:
            raise FileNotFoundError(f"Cannot find image for id={img_id} in {self.image_dir}")

        img = cv2.imread(str(img_path))
        if img is None:
            raise FileNotFoundError(f"Cannot read image: {img_path}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        img = cv2.resize(img, (self.image_size, self.image_size), interpolation=cv2.INTER_LINEAR)
        img = img.astype(np.float32) / 255.0
        img_t = torch.from_numpy(img).permute(2, 0, 1)
        return {"image": img_t, "label": torch.tensor(label, dtype=torch.long), "id": img_id}


# =========================================================
# Model (same as training)
# =========================================================
class ConvBNAct(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.GELU()
        )
    def forward(self, x): return self.net(x)

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.net = nn.Sequential(ConvBNAct(in_ch, out_ch), ConvBNAct(out_ch, out_ch))
    def forward(self, x): return self.net(x)

class ConvEncoder(nn.Module):
    def __init__(self, in_ch=3, base_ch=32):
        super().__init__()
        self.e1 = ConvBlock(in_ch, base_ch)
        self.e2 = ConvBlock(base_ch, base_ch * 2)
        self.e3 = ConvBlock(base_ch * 2, base_ch * 4)
        self.e4 = ConvBlock(base_ch * 4, base_ch * 8)
        self.pool = nn.MaxPool2d(2)
    def forward(self, x):
        f1 = self.e1(x)
        f2 = self.e2(self.pool(f1))
        f3 = self.e3(self.pool(f2))
        f4 = self.e4(self.pool(f3))
        return [f1, f2, f3, f4]

class ViTBottleneck(nn.Module):
    def __init__(self, dim, depth=2, heads=4, drop=0.1):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=dim, nhead=heads, dim_feedforward=dim*4,
            dropout=drop, activation="gelu", batch_first=True, norm_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, depth)
    def forward(self, x):
        b, c, h, w = x.shape
        tokens = x.flatten(2).transpose(1, 2)
        tokens = self.encoder(tokens)
        return tokens.transpose(1, 2).reshape(b, c, h, w)

class UNetDecoder(nn.Module):
    def __init__(self, base_ch=32, out_ch=1, dropout_p=0.10):
        super().__init__()
        self.up3 = nn.ConvTranspose2d(base_ch*8, base_ch*4, 2, 2)
        self.d3  = ConvBlock(base_ch*8, base_ch*4)
        self.up2 = nn.ConvTranspose2d(base_ch*4, base_ch*2, 2, 2)
        self.d2  = ConvBlock(base_ch*4, base_ch*2)
        self.up1 = nn.ConvTranspose2d(base_ch*2, base_ch, 2, 2)
        self.d1  = ConvBlock(base_ch*2, base_ch)
        self.drop = nn.Dropout2d(dropout_p) if dropout_p > 0 else nn.Identity()
        self.out  = nn.Conv2d(base_ch, out_ch, 1)
    def forward(self, feats):
        f1, f2, f3, f4 = feats
        x = self.up3(f4)
        x = self.d3(torch.cat([x, f3], dim=1))
        x = self.up2(x)
        x = self.d2(torch.cat([x, f2], dim=1))
        x = self.up1(x)
        x = self.d1(torch.cat([x, f1], dim=1))
        x = self.drop(x)
        return self.out(x)

class HybridMTLModel(nn.Module):
    def __init__(self, base_ch=32, num_classes=3, seg_dropout=0.10):
        super().__init__()
        self.encoder = ConvEncoder(3, base_ch)
        self.bottleneck = ViTBottleneck(base_ch*8, depth=2, heads=4, drop=0.1)
        self.decoder = UNetDecoder(base_ch, out_ch=1, dropout_p=seg_dropout)
        self.cls_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(base_ch*8, base_ch*8),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(base_ch*8, num_classes),
        )
    def forward(self, x):
        feats = self.encoder(x)
        feats[-1] = self.bottleneck(feats[-1])
        seg_logit = self.decoder(feats)
        cls_logit = self.cls_head(feats[-1])
        return seg_logit, cls_logit


# =========================================================
# Metrics (Seg)
# =========================================================
def mask_to_boundary(mask: np.ndarray) -> np.ndarray:
    mask_u8 = (mask > 0).astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    er = cv2.erode(mask_u8, kernel, iterations=1)
    bd = (mask_u8 - er) > 0
    return bd.astype(np.uint8)

def bf_score(pred_mask: np.ndarray, gt_mask: np.ndarray, tol_px: int = 2) -> float:
    pb = mask_to_boundary(pred_mask)
    gb = mask_to_boundary(gt_mask)
    if pb.sum() == 0 and gb.sum() == 0:
        return 1.0
    if pb.sum() == 0 or gb.sum() == 0:
        return 0.0
    dt_g = ndi.distance_transform_edt(1 - gb)
    dt_p = ndi.distance_transform_edt(1 - pb)
    prec = (dt_g[pb.astype(bool)] <= tol_px).mean() if pb.sum() > 0 else 0.0
    rec  = (dt_p[gb.astype(bool)] <= tol_px).mean() if gb.sum() > 0 else 0.0
    return float(2 * prec * rec / (prec + rec)) if (prec + rec) > 0 else 0.0

def hd95(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    pb = mask_to_boundary(pred_mask).astype(bool)
    gb = mask_to_boundary(gt_mask).astype(bool)
    if pb.sum() == 0 and gb.sum() == 0:
        return 0.0
    if pb.sum() == 0 or gb.sum() == 0:
        h, w = pred_mask.shape
        return float(math.sqrt(h*h + w*w))
    dt_g = ndi.distance_transform_edt(~gb)
    dt_p = ndi.distance_transform_edt(~pb)
    d1 = dt_g[pb]
    d2 = dt_p[gb]
    all_d = np.concatenate([d1, d2], axis=0)
    return float(np.percentile(all_d, 95))

def dice_score(pred_mask: np.ndarray, gt_mask: np.ndarray) -> float:
    inter = (pred_mask & gt_mask).sum()
    union = pred_mask.sum() + gt_mask.sum()
    return float((2 * inter / (union + 1e-6)) if union > 0 else 1.0)


# =========================================================
# Metrics (Cls)
# =========================================================
def softmax_np(logits: np.ndarray) -> np.ndarray:
    x = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / (e.sum(axis=1, keepdims=True) + 1e-12)

def ece_score(probs: np.ndarray, labels: np.ndarray, n_bins: int = 15) -> float:
    conf = probs.max(axis=1)
    pred = probs.argmax(axis=1)
    acc = (pred == labels).astype(np.float32)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for i in range(n_bins):
        m = (conf > bins[i]) & (conf <= bins[i+1])
        if m.any():
            ece += abs(acc[m].mean() - conf[m].mean()) * m.mean()
    return float(ece)

def nll_score(probs: np.ndarray, labels: np.ndarray) -> float:
    p = probs[np.arange(len(labels)), labels]
    return float((-np.log(p + 1e-12)).mean())

def brier_score(probs: np.ndarray, labels: np.ndarray) -> float:
    n, k = probs.shape
    y = np.zeros((n, k), dtype=np.float32)
    y[np.arange(n), labels] = 1.0
    return float(((probs - y) ** 2).sum(axis=1).mean())

def predictive_entropy(p: np.ndarray, eps=1e-12) -> np.ndarray:
    # for multi-class probs: [N,K]
    return -(p * np.log(p + eps)).sum(axis=-1)

def risk_coverage_curve(probs: np.ndarray, labels: np.ndarray, uncertainty: np.ndarray):
    n = len(labels)
    order = np.argsort(uncertainty)
    probs_s = probs[order]
    labels_s = labels[order]
    pred = probs_s.argmax(axis=1)
    err = (pred != labels_s).astype(np.float32)
    coverages = np.linspace(1/n, 1.0, n)
    risks = np.cumsum(err) / (np.arange(n) + 1)
    return coverages, risks

def aurc(coverages: np.ndarray, risks: np.ndarray) -> float:
    return float(np.trapezoid(risks, coverages))


# =========================================================
# MC Dropout (Seg)
# =========================================================
def enable_dropout_only(model: nn.Module):
    model.eval()
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d, nn.Dropout3d)):
            m.train()

def mutual_information(mc_probs: np.ndarray, eps=1e-12) -> np.ndarray:
    mean_p = mc_probs.mean(axis=0)
    H_mean = -(mean_p*np.log(mean_p+eps) + (1-mean_p)*np.log(1-mean_p+eps))
    H_each = -(mc_probs*np.log(mc_probs+eps) + (1-mc_probs)*np.log(1-mc_probs+eps))
    return H_mean - H_each.mean(axis=0)

def compute_boundary_prob_from_plesion(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy)
    g = g / (g.max() + 1e-12)
    return g.astype(np.float32)

def transition_width_map_mm(p_lesion: np.ndarray) -> np.ndarray:
    gy, gx = np.gradient(p_lesion)
    g = np.sqrt(gx*gx + gy*gy) + 1e-12
    width_px = 0.6 / g
    width_mm = width_px * MM_PER_PIXEL
    return np.clip(width_mm, 0.0, 10.0).astype(np.float32)


# =========================================================
# Config
# =========================================================
@dataclass
class TestCfg:
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    amp: bool = True
    workers: int = 2

    seg_mc_T: int = 8
    seg_export_png: bool = True
    seg_export_npy: bool = True

    cls_export_curve: bool = True

cfg = TestCfg()


# =========================================================
# Build loaders
# =========================================================
seg_ds = ISICSegTestDataset(SEG_IMAGE_DIR, SEG_MASK_DIR, image_size=IMAGE_SIZE)
cls_ds = ISICClsTestDataset(CLS_IMAGE_DIR, CLS_LABEL_CSV, image_size=IMAGE_SIZE)

print(f"[TEST-SEG] pairs found = {len(seg_ds)}")
print(f"[TEST-CLS] samples found = {len(cls_ds)}")

common_loader_kwargs = dict(
    num_workers=cfg.workers,
    pin_memory=True,
    persistent_workers=True if cfg.workers > 0 else False,
    prefetch_factor=2 if cfg.workers > 0 else None
)

seg_loader = DataLoader(seg_ds, batch_size=1, shuffle=False, **common_loader_kwargs)
cls_loader = DataLoader(cls_ds, batch_size=8, shuffle=False, **common_loader_kwargs)


# =========================================================
# Load model
# =========================================================
model = HybridMTLModel(base_ch=32, num_classes=3, seg_dropout=0.10).to(cfg.device)
ckpt_path = pick_checkpoint(CHECKPOINT_DIR)
ckpt = torch.load(ckpt_path, map_location=cfg.device)
model.load_state_dict(ckpt["model"], strict=True)
epoch_loaded = int(ckpt.get("epoch", 0))
print(f"[Load] {ckpt_path.name} | epoch={epoch_loaded}")

OUT_DIR = EXPORT_DIR / f"epoch_{epoch_loaded:03d}"
OUT_DIR.mkdir(parents=True, exist_ok=True)


# =========================================================
# TEST evaluation + export
# =========================================================
@torch.no_grad()
def eval_export_seg_test(mc_T: int):
    enable_dropout_only(model)

    out_dir = OUT_DIR / f"seg_test_MC{mc_T}"
    out_dir.mkdir(parents=True, exist_ok=True)

    dices, bfs, hd95s_px = [], [], []
    ent_means, mi_means = [], []

    for b in tqdm(seg_loader, desc=f"[TEST Seg MC(T={mc_T})]", leave=True):
        x = b["image"].to(cfg.device)
        y = b["mask"].to(cfg.device)
        sid = b["id"][0] if isinstance(b["id"], (list, tuple)) else str(b["id"])

        mc_probs = []
        for _ in range(mc_T):
            with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
                seg_logit, _ = model(x)
                prob = torch.sigmoid(seg_logit).squeeze(0).squeeze(0)  # [H,W]
            if prob.ndim == 2:
                mc_probs.append(prob.detach().float().cpu().numpy())

        if len(mc_probs) == 0:
            continue

        mc_probs = np.stack(mc_probs, axis=0)
        p_lesion = mc_probs.mean(axis=0).astype(np.float32)

        pred_mask = (p_lesion > 0.5).astype(np.uint8)
        gt_mask = (y.squeeze().detach().cpu().numpy() > 0.5).astype(np.uint8)

        dices.append(dice_score(pred_mask, gt_mask))
        bfs.append(bf_score(pred_mask, gt_mask, tol_px=2))
        hd95s_px.append(hd95(pred_mask, gt_mask))

        ent_map = -(p_lesion*np.log(p_lesion+1e-12) + (1-p_lesion)*np.log(1-p_lesion+1e-12))
        mi_map = mutual_information(mc_probs) if mc_T > 1 else np.zeros_like(ent_map)

        ent_means.append(float(ent_map.mean()))
        mi_means.append(float(mi_map.mean()))

        # exports
        p_boundary = compute_boundary_prob_from_plesion(p_lesion)
        tw_mm = transition_width_map_mm(p_lesion)

        if cfg.seg_export_npy:
            np.save(out_dir / f"{sid}_p_lesion.npy", p_lesion)
            np.save(out_dir / f"{sid}_p_boundary.npy", p_boundary)
            np.save(out_dir / f"{sid}_entropy.npy", ent_map.astype(np.float32))
            np.save(out_dir / f"{sid}_mi.npy", mi_map.astype(np.float32))
            np.save(out_dir / f"{sid}_transition_width_mm.npy", tw_mm.astype(np.float32))

        if cfg.seg_export_png:
            save_map_with_mm_axis(out_dir / f"{sid}_p_lesion.png", p_lesion, "p_lesion")
            save_map_with_mm_axis(out_dir / f"{sid}_p_boundary.png", p_boundary, "p_boundary")
            save_map_with_mm_axis(out_dir / f"{sid}_entropy.png", ent_map, "Predictive entropy")
            save_map_with_mm_axis(out_dir / f"{sid}_mi.png", mi_map, "Mutual information")
            save_map_with_mm_axis(out_dir / f"{sid}_transition_width_mm.png", tw_mm, "Transition width (mm)")

    metrics = {
        "dice_mean": float(np.mean(dices)) if dices else 0.0,
        "bf_mean": float(np.mean(bfs)) if bfs else 0.0,
        "hd95_mm_mean": float(np.mean(hd95s_px) * MM_PER_PIXEL) if hd95s_px else 0.0,
        "entropy_mean": float(np.mean(ent_means)) if ent_means else 0.0,
        "mi_mean": float(np.mean(mi_means)) if mi_means else 0.0,
        "n": int(len(dices)),
        "mc_T": int(mc_T),
    }
    with open(out_dir / "seg_test_metrics.json", "w", encoding="utf-8") as f:
        f.write(pd.Series(metrics).to_json(indent=2))
    print("[TEST-SEG]", metrics)
    return metrics

@torch.no_grad()
def eval_export_cls_test():
    model.eval()
    out_dir = OUT_DIR / "cls_test"
    out_dir.mkdir(parents=True, exist_ok=True)

    all_logits, all_labels = [], []
    for b in tqdm(cls_loader, desc="[TEST Cls]", leave=True):
        x = b["image"].to(cfg.device)
        y = b["label"].to(cfg.device)
        with torch.amp.autocast(device_type="cuda", enabled=(cfg.amp and cfg.device.startswith("cuda"))):
            _, logits = model(x)
        all_logits.append(logits.detach().float().cpu().numpy())
        all_labels.append(y.detach().cpu().numpy())

    logits = np.concatenate(all_logits, axis=0) if all_logits else np.zeros((0, 3), np.float32)
    labels = np.concatenate(all_labels, axis=0) if all_labels else np.zeros((0,), np.int64)

    probs = softmax_np(logits) if len(labels) else np.zeros_like(logits)
    pred = probs.argmax(axis=1) if len(labels) else np.zeros((0,), np.int64)

    acc = float((pred == labels).mean()) if len(labels) else 0.0
    ece = ece_score(probs, labels, n_bins=15) if len(labels) else 0.0
    nll = nll_score(probs, labels) if len(labels) else 0.0
    br  = brier_score(probs, labels) if len(labels) else 0.0

    unc = predictive_entropy(probs).astype(np.float32) if len(labels) else np.zeros((0,), np.float32)
    cov, risk = risk_coverage_curve(probs, labels, unc) if len(labels) else (np.zeros((0,),), np.zeros((0,),))
    A = aurc(cov, risk) if len(labels) else 0.0

    # save arrays
    np.save(out_dir / "logits.npy", logits.astype(np.float32))
    np.save(out_dir / "probs.npy", probs.astype(np.float32))
    np.save(out_dir / "labels.npy", labels.astype(np.int64))
    np.save(out_dir / "uncertainty_entropy.npy", unc.astype(np.float32))
    if len(labels):
        np.save(out_dir / "risk_coverage_coverage.npy", cov.astype(np.float32))
        np.save(out_dir / "risk_coverage_risk.npy", risk.astype(np.float32))

    if cfg.cls_export_curve and len(labels):
        plt.figure(figsize=(4, 3), dpi=150)
        plt.plot(cov, risk)
        plt.xlabel("Coverage")
        plt.ylabel("Risk (error rate)")
        plt.title(f"TEST Risk–Coverage (AURC={A:.4f})")
        plt.tight_layout()
        plt.savefig(out_dir / "risk_coverage_curve.png", bbox_inches="tight")
        plt.close()

    metrics = {
        "acc": float(acc),
        "ece": float(ece),
        "aurc": float(A),
        "nll": float(nll),
        "brier": float(br),
        "n": int(len(labels)),
    }
    with open(out_dir / "cls_test_metrics.json", "w", encoding="utf-8") as f:
        f.write(pd.Series(metrics).to_json(indent=2))
    print("[TEST-CLS]", metrics)
    return metrics


# =========================================================
# Run TEST
# =========================================================
seg_metrics = eval_export_seg_test(mc_T=cfg.seg_mc_T)
cls_metrics = eval_export_cls_test()

summary = {
    "checkpoint": ckpt_path.name,
    "epoch": epoch_loaded,
    "seg": seg_metrics,
    "cls": cls_metrics,
    "image_size": IMAGE_SIZE,
    "field_size_mm": FIELD_SIZE_MM,
    "mm_per_pixel": MM_PER_PIXEL,
}
with open(OUT_DIR / "test_summary.json", "w", encoding="utf-8") as f:
    f.write(pd.Series(summary).to_json(indent=2))

print("✅ TEST evaluation complete.")
print("📁 Saved to:", OUT_DIR)

  _C._set_float32_matmul_precision(precision)


[TEST-SEG] pairs found = 1000
[TEST-CLS] samples found = 1512




[Load] best.pt | epoch=30


[TEST Seg MC(T=8)]:   1%|          | 6/1000 [00:14<34:02,  2.06s/it]