In [None]:
# Mount Google Drive when running on Colab. Outside Colab the import fails, so skip the
# mount instead of aborting "Run All"; the path-configuration cell falls back to a local
# base directory in that case.
try:
    from google.colab import drive
except ImportError:
    drive = None
    print("google.colab not available; skipping Google Drive mount.")

if drive is not None:
    drive.mount('/content/drive')

In [None]:
import os
import random
import re
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchvision
from PIL import Image
from scipy.io import loadmat
from sklearn.metrics import classification_report, roc_auc_score, roc_curve
from sklearn.model_selection import KFold, train_test_split
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import models
from torchvision.transforms import InterpolationMode
from torchvision.transforms import functional as F

plt.style.use("seaborn-v0_8")
print(f"PyTorch version: {torch.__version__}")
# `torchvision.models` has no `__version__`; the version is an attribute of the package.
print(f"Torchvision version: {torchvision.__version__}")


In [None]:
# Environment detection and path configuration

def is_google_colab() -> bool:
    """Return True when the ``google.colab`` package is importable (i.e. running in Colab)."""
    try:
        import google.colab  # type: ignore  # noqa: F401
    except ImportError:
        return False
    return True


def is_google_drive_mounted() -> bool:
    """Return True if the Colab Drive mount point ``/content/drive/MyDrive`` exists."""
    drive_root = "/content/drive/MyDrive"
    return os.path.exists(drive_root)


def configure_paths() -> Tuple[str, str, str, str, str, str]:
    """Resolve the project base directory and the data paths derived from it.

    The base directory is selected by environment:
      * Colab with Google Drive mounted -> the project folder on Drive;
      * Colab without Drive             -> ``/content``;
      * anywhere else (local run)       -> ``$TUPIL_KIDNEY_BASE_DIR`` if set, otherwise
        the hard-coded local checkout path.

    Returns:
        ``(base, csv_file, qus_dir, sample_id_file, b_mode_folder, model_weights_path)``.
        Paths are only joined here; their existence is not checked.
    """
    if is_google_colab() and is_google_drive_mounted():
        base = "/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney"
        print("Running on Google Colab with Google Drive mounted")
    elif is_google_colab():
        base = "/content"
        print("Running on Google Colab without Google Drive mounted")
    else:
        # Allow other machines to point at their checkout without editing the notebook.
        base = os.environ.get("TUPIL_KIDNEY_BASE_DIR", "/Users/hikaru/Desktop/TUPIL/Code/TUPIL_Kidney")
        print("Running locally")

    csv_file = os.path.join(base, "csv", "patient_eGFR_at_pocus_2025_Jul_polynomial_estimation.csv")
    qus_dir = os.path.join(base, "data", "QUS_resized")
    sample_id_file = os.path.join(base, "data", "QUS_combined", "sample_id_combined.mat")
    b_mode_folder = os.path.join(base, "data", "Bmode_resized")
    # NOTE(review): this RadImageNet ResNet50 .h5 file is not loaded by the PyTorch
    # ResNet18 pipeline in this notebook; kept for reference only.
    model_weights_path = os.path.join(base, "data", "model_weights", "RadImageNet-ResNet50_notop.h5")

    return base, csv_file, qus_dir, sample_id_file, b_mode_folder, model_weights_path


BASE_DIR, CSV_FILE, QUS_DATA_DIR, SAMPLE_ID_FILE, B_MODE_IMAGE_FOLDER, MODEL_WEIGHTS_PATH = configure_paths()
# Image modality -> folder; B-mode is the only image modality used here.
IMAGE_FOLDERS = {"B_mode": B_MODE_IMAGE_FOLDER}

# Echo the resolved configuration so the run log records where the data came from.
for _label, _value in (
    ("BASE_DIR", BASE_DIR),
    ("CSV file", CSV_FILE),
    ("QUS data directory", QUS_DATA_DIR),
    ("Sample ID file", SAMPLE_ID_FILE),
    ("B-mode folder", B_MODE_IMAGE_FOLDER),
    ("Model weights path (if needed)", MODEL_WEIGHTS_PATH),
):
    print(f"{_label}: {_value}")
del _label, _value


In [None]:
# Names of the QUS parametric maps that may be selected as model inputs.
ALL_QUS_TYPES = {"ESD", "EAC", "SI", "SS", "MBF"}


@dataclass
class Patient:
    """One patient's eGFR target plus the samples available for each input modality.

    ``b_mode_paths`` lists the patient's B-mode image files; ``qus_case_indices`` maps
    each QUS type to indices into the third axis of the corresponding QUS array.
    """

    patient_id: int
    egfr_label: int  # binary class derived from egfr_value when the record is built
    egfr_value: float
    b_mode_paths: List[str]
    qus_case_indices: Dict[str, List[int]]

    def count(self, input_type: str) -> int:
        """Number of samples available for ``input_type`` (0 for an unknown QUS type)."""
        if input_type != "B_mode":
            return len(self.qus_case_indices.get(input_type, []))
        return len(self.b_mode_paths)

    def min_samples(self, input_types: List[str]) -> int:
        """Smallest sample count across ``input_types``; 0 when the list is empty."""
        return min((self.count(kind) for kind in input_types), default=0)


def extract_patient_id_from_sample(sample_id: str) -> int:
    """Return the integer following the first ``P`` + digits in ``sample_id``.

    Raises:
        ValueError: if ``sample_id`` contains no ``P<digits>`` token.
    """
    found = re.search(r"P(\d+)", sample_id)
    if found is None:
        raise ValueError(f"Unable to parse patient ID from sample ID: {sample_id}")
    return int(found.group(1))


def extract_matlab_string(cell_item) -> str:
    """Convert one element loaded by ``scipy.io.loadmat`` into a Python string.

    Non-arrays are passed to ``str``. For arrays, an empty array gives ``""``, an
    object (cell) array is unwrapped recursively through its first element, and any
    other array yields ``str`` of its first element.
    """
    if not isinstance(cell_item, np.ndarray):
        return str(cell_item)
    if cell_item.size == 0:
        return ""
    first = cell_item.flat[0]
    if cell_item.dtype == object:
        return extract_matlab_string(first)
    return str(first)


def load_sample_ids(sample_id_file: str) -> List[str]:
    """Load the per-case sample IDs stored in a MATLAB ``.mat`` file.

    The first non-metadata variable (name not starting with ``__``) is used. Position
    ``i`` of the returned list is paired with QUS case ``i`` by the patient builder
    (it enumerates this list as case indices), so entries are never dropped: empty or
    blank IDs are returned as ``""`` to keep that alignment intact. Surrounding
    whitespace is stripped.

    Raises:
        FileNotFoundError: if ``sample_id_file`` does not exist.
        ValueError: if the file contains no data variables.
    """
    if not os.path.exists(sample_id_file):
        raise FileNotFoundError(f"Sample ID file not found: {sample_id_file}")

    mat_data = loadmat(sample_id_file, struct_as_record=False, squeeze_me=True)
    keys = [key for key in mat_data.keys() if not key.startswith("__")]
    if not keys:
        raise ValueError("No data found in sample ID file")

    sample_var = mat_data[keys[0]]
    if isinstance(sample_var, np.ndarray):
        raw_items = np.ravel(sample_var)
    elif isinstance(sample_var, (list, tuple)):
        raw_items = sample_var
    else:
        # squeeze_me=True collapses a single-entry cell array to a scalar.
        raw_items = [sample_var]

    # Keep one entry per case (even if empty) so list position == QUS case index.
    return [extract_matlab_string(item).strip() for item in raw_items]


def load_qus_arrays(qus_dir: str, qus_types: List[str]) -> Dict[str, np.ndarray]:
    """Load ``<qus_dir>/<type>.npy`` for every requested QUS type, printing each shape.

    Raises:
        FileNotFoundError: as soon as one of the ``.npy`` files is missing.
    """
    loaded: Dict[str, np.ndarray] = {}
    for name in qus_types:
        path = os.path.join(qus_dir, f"{name}.npy")
        if not os.path.exists(path):
            raise FileNotFoundError(f"QUS file not found: {path}")
        array = np.load(path)
        loaded[name] = array
        print(f"Loaded {name} with shape {array.shape}")
    return loaded


def load_b_mode_image_map(image_folder: str) -> Dict[int, List[str]]:
    image_map: Dict[int, List[str]] = defaultdict(list)
    if not os.path.exists(image_folder):
        print(f"Warning: B-mode folder not found: {image_folder}")
        return image_map

    for filename in sorted(os.listdir(image_folder)):
        if not filename.lower().endswith((".png", ".jpg", ".jpeg")):
            continue
        parts = filename.split("_")
        if len(parts) < 2 or parts[0] != "Patient":
            continue
        try:
            patient_id = int(parts[1])
        except ValueError:
            continue
        image_map[patient_id].append(os.path.join(image_folder, filename))
    return image_map


def _read_egfr_table(csv_file: str) -> pd.DataFrame:
    """Read the eGFR CSV and return it indexed by integer ``patient_id``.

    The columns ``"Patient ID"`` and ``"eGFR (abs/closest)"`` are renamed to
    ``patient_id`` and ``eGFR``.
    """
    egfr_df = pd.read_csv(csv_file).rename(
        columns={"Patient ID": "patient_id", "eGFR (abs/closest)": "eGFR"}
    )
    egfr_df["patient_id"] = egfr_df["patient_id"].astype(int)
    return egfr_df.set_index("patient_id")


def _map_qus_cases(
    selected_qus_types: List[str],
    qus_arrays: Dict[str, np.ndarray],
    sample_ids: List[str],
    known_patient_ids: pd.Index,
) -> Dict[int, Dict[str, List[int]]]:
    """Group QUS case indices (third axis of each QUS array) by patient ID.

    Case ``i`` belongs to the patient parsed from ``sample_ids[i]``. Cases whose ID
    cannot be parsed, or whose patient is not in ``known_patient_ids``, are skipped.

    Raises:
        ValueError: if QUS types are selected but no sample IDs are given, if the QUS
            arrays disagree on the number of cases, or if the number of sample IDs
            differs from the number of cases.
    """
    qus_case_map: Dict[int, Dict[str, List[int]]] = defaultdict(
        lambda: {qus_type: [] for qus_type in selected_qus_types}
    )
    if not selected_qus_types:
        return qus_case_map
    if not sample_ids:
        raise ValueError("Sample IDs are required when QUS inputs are selected")

    num_cases_expected = None
    for qus_type in selected_qus_types:
        num_cases = qus_arrays[qus_type].shape[2]
        if num_cases_expected is None:
            num_cases_expected = num_cases
        elif num_cases != num_cases_expected:
            raise ValueError("All QUS arrays must have the same number of cases")
    if num_cases_expected and len(sample_ids) != num_cases_expected:
        raise ValueError("Number of sample IDs does not match QUS cases")

    for case_idx, sample_id in enumerate(sample_ids):
        try:
            patient_id = extract_patient_id_from_sample(sample_id)
        except ValueError:
            continue  # unparseable (or empty) IDs contribute no cases
        if patient_id not in known_patient_ids:
            continue
        for qus_type in selected_qus_types:
            qus_case_map[patient_id][qus_type].append(case_idx)
    return qus_case_map


def _has_all_inputs(
    input_types: List[str],
    b_mode_paths: List[str],
    qus_case_indices: Dict[str, List[int]],
) -> bool:
    """True when at least one sample exists for every entry of ``input_types``."""
    for input_type in input_types:
        available = b_mode_paths if input_type == "B_mode" else qus_case_indices.get(input_type)
        if not available:
            return False
    return True


def build_patients(
    input_types: List[str],
    image_folders: Dict[str, str],
    csv_file: str,
    qus_arrays: Dict[str, np.ndarray],
    sample_ids: List[str],
    egfr_threshold: float = 60.0,
) -> List[Patient]:
    """Build ``Patient`` records for every CSV patient that has all requested modalities.

    Args:
        input_types: modalities to require (``"B_mode"`` and/or names in ``ALL_QUS_TYPES``).
        image_folders: modality -> folder; only the ``"B_mode"`` entry is read.
        csv_file: eGFR CSV with ``"Patient ID"`` and ``"eGFR (abs/closest)"`` columns.
        qus_arrays: QUS type -> array whose third axis indexes cases.
        sample_ids: one ID per QUS case, in case order.
        egfr_threshold: eGFR values ``>=`` this get label 1, lower values label 0.

    Patients whose eGFR is missing (NaN) are skipped, as are patients lacking any
    requested modality.

    Raises:
        ValueError: propagated from the QUS case mapping (see ``_map_qus_cases``).
    """
    egfr_df = _read_egfr_table(csv_file)
    selected_qus_types = [input_type for input_type in input_types if input_type in ALL_QUS_TYPES]

    qus_case_map = _map_qus_cases(selected_qus_types, qus_arrays, sample_ids, egfr_df.index)

    b_mode_map: Dict[int, List[str]] = defaultdict(list)
    if "B_mode" in input_types:
        b_mode_map = load_b_mode_image_map(image_folders.get("B_mode", ""))

    patients: List[Patient] = []
    skipped_missing_egfr = 0
    for patient_id in egfr_df.index:
        egfr_value = float(egfr_df.loc[patient_id, "eGFR"])
        if np.isnan(egfr_value):
            # `nan >= threshold` is False, so without this check an unlabeled patient
            # would silently be assigned class 0.
            skipped_missing_egfr += 1
            continue
        egfr_label = 1 if egfr_value >= egfr_threshold else 0

        patient_qus_indices = qus_case_map.get(patient_id, {qus_type: [] for qus_type in selected_qus_types})
        patient_b_mode_paths = b_mode_map.get(patient_id, [])
        if not _has_all_inputs(input_types, patient_b_mode_paths, patient_qus_indices):
            continue

        patients.append(
            Patient(
                patient_id=patient_id,
                egfr_label=egfr_label,
                egfr_value=egfr_value,
                b_mode_paths=patient_b_mode_paths,
                qus_case_indices=patient_qus_indices,
            )
        )

    if skipped_missing_egfr:
        print(f"Skipped {skipped_missing_egfr} patient(s) with missing eGFR")
    return patients


def summarize_patients(patients: List[Patient], input_types: List[str]) -> None:
    """Print the patient count and per-modality sample statistics (total/min/max/mean)."""
    print(f"Number of patients: {len(patients)}")
    if not patients:
        # No patients means no per-modality counts to report.
        return
    for input_type in input_types:
        per_patient = [p.count(input_type) for p in patients]
        total, lowest, highest = sum(per_patient), min(per_patient), max(per_patient)
        average = np.mean(per_patient)
        print(
            f"{input_type}: total={total}, min={lowest}, max={highest}, avg={average:.2f}"
        )


# Notebook-level state. The "Prepare modality resources" cell fills the QUS arrays and
# sample IDs; the dataset class reads/writes the training scalers.
GLOBAL_QUS_ARRAYS: Dict[str, np.ndarray] = {}
GLOBAL_SAMPLE_IDS: List[str] = []
GLOBAL_TRAIN_SCALERS: Dict[str, Tuple[float, float]] = {}


def load_patients(
    input_types: List[str],
    image_folders: Dict[str, str],
    csv_file: str,
    qus_arrays: Dict[str, np.ndarray] | None = None,
    sample_ids: List[str] | None = None,
) -> List[Patient]:
    """Call ``build_patients``, defaulting QUS arrays / sample IDs to the notebook globals."""
    if qus_arrays is None:
        qus_arrays = GLOBAL_QUS_ARRAYS
    if sample_ids is None:
        sample_ids = GLOBAL_SAMPLE_IDS
    return build_patients(input_types, image_folders, csv_file, qus_arrays, sample_ids)



In [None]:
# Legacy TensorFlow-style patient helpers (not used).

In [None]:
# Input configuration and experiment hyperparameters
#
# Available QUS parametric maps: 'ESD', 'EAC', 'SI', 'SS', 'MBF'
# Available image modality:      'B_mode'
# Any combination may be listed; each selected modality contributes 3 input channels.
INPUT_TYPES = ['ESD', 'EAC', 'SI', 'SS', 'MBF', 'B_mode']
# Other example selections:
#   INPUT_TYPES = ['ESD']
#   INPUT_TYPES = ['ESD', 'EAC', 'B_mode']
#   INPUT_TYPES = ['B_mode']

# Data and optimisation settings
IMAGE_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 100
EARLY_STOPPING_PATIENCE = 20  # epochs without validation-AUC improvement before stopping
LEARNING_RATE = 1e-4
STEP_LR_EVERY = 15            # StepLR step_size, in epochs
STEP_LR_GAMMA = 0.5           # StepLR multiplicative decay
N_RUNS = 5                    # number of repeated hold-out runs
BASE_SEED = 42                # per-run / per-fold seeds are derived from this

# Split the selection into QUS maps and image modalities.
QUS_TYPES = list(filter(ALL_QUS_TYPES.__contains__, INPUT_TYPES))
IMAGE_TYPES = [modality for modality in INPUT_TYPES if modality == 'B_mode']

print(f"Selected input types: {INPUT_TYPES}")
print(f"QUS types: {QUS_TYPES}")
print(f"Image types: {IMAGE_TYPES}")

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")



In [None]:
# (Deprecated intermediate definition of FlexiblePatientDataset removed in favor of the final PyTorch implementation below.)


In [None]:
# Prepare modality resources for downstream loading
#
# QUS arrays and their sample IDs are loaded only when at least one QUS map is selected;
# training scalers are reset so no stale values leak into later dataset construction.

SELECTED_QUS_TYPES = [kind for kind in INPUT_TYPES if kind in ALL_QUS_TYPES]
USE_B_MODE = "B_mode" in INPUT_TYPES
GLOBAL_TRAIN_SCALERS = {}

if not SELECTED_QUS_TYPES:
    GLOBAL_SAMPLE_IDS = []
    GLOBAL_QUS_ARRAYS = {}
    print("No QUS inputs selected; proceeding with image-only data.")
else:
    GLOBAL_SAMPLE_IDS = load_sample_ids(SAMPLE_ID_FILE)
    print(f"Loaded {len(GLOBAL_SAMPLE_IDS)} sample IDs")
    GLOBAL_QUS_ARRAYS = load_qus_arrays(QUS_DATA_DIR, SELECTED_QUS_TYPES)



In [None]:
# Legacy TensorFlow-style dataset placeholder (not used).

In [None]:
# PyTorch dataset for flexible multimodal inputs (the only FlexiblePatientDataset definition in this notebook)

class FlexiblePatientDataset(Dataset):
    """Per-sample dataset that stacks one 3-channel tensor per selected modality.

    Each item is ``(image, label, patient_id)``. ``image`` concatenates, in the order
    of ``input_types``, one tensor per modality:

    * ``"B_mode"``: the RGB image resized to ``image_size`` and normalised with the
      ImageNet mean/std;
    * QUS types: the ``[:, :, case_idx]`` slice of the QUS array with NaNs set to 0,
      bilinearly resized, min-max scaled with ``self.scalers``, clamped to [0, 1] and
      repeated to 3 channels.

    ``label`` is ``patient.egfr_label`` as a float32 scalar tensor.

    A patient contributes ``patient.min_samples(input_types)`` items; item ``k`` pairs
    the ``k``-th B-mode path with the ``k``-th QUS case index of every QUS type.
    NOTE(review): this assumes B-mode file order and QUS case order correspond — confirm.

    When ``scalers`` is not given, scalers are resolved through the module-level
    ``GLOBAL_TRAIN_SCALERS``. An augmenting (training) dataset computes scalers from its
    own samples and publishes them there. A non-augmenting dataset reuses the published
    scalers if any exist, and otherwise computes them from its own samples.
    """

    def __init__(
        self,
        patients: List[Patient],
        input_types: List[str],
        qus_arrays: Dict[str, np.ndarray] | None = None,
        scalers: Dict[str, Tuple[float, float]] | None = None,
        image_size: int = 224,
        augment: bool = False,
    ) -> None:
        self.patients = patients
        self.input_types = input_types
        self.image_size = image_size
        self.augment = augment
        # Fall back to the arrays loaded into notebook-global state.
        self.qus_arrays = qus_arrays if qus_arrays is not None else GLOBAL_QUS_ARRAYS
        self.selected_qus_types = [input_type for input_type in input_types if input_type in ALL_QUS_TYPES]
        self.use_b_mode = "B_mode" in input_types

        # One record per usable (patient, sample index) pair.
        self._records: List[Tuple[Patient, int]] = [
            (patient, sample_idx)
            for patient in patients
            for sample_idx in range(patient.min_samples(input_types))
        ]

        self.samples = [(patient, patient.egfr_label, patient.patient_id) for patient, _ in self._records]
        self.labels = [patient.egfr_label for patient, _ in self._records]
        # ImageNet channel statistics; applied to the B-mode RGB tensor only.
        self.base_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        self.base_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        self.scalers = self._resolve_scalers(scalers, augment)

    def _resolve_scalers(
        self,
        scalers: Dict[str, Tuple[float, float]] | None,
        augment: bool,
    ) -> Dict[str, Tuple[float, float]]:
        """Pick the QUS min/max scalers for this dataset (see the class docstring)."""
        global GLOBAL_TRAIN_SCALERS
        if scalers is not None:
            return scalers
        if GLOBAL_TRAIN_SCALERS and not augment:
            return GLOBAL_TRAIN_SCALERS
        if augment:
            # A new training dataset invalidates any previously published scalers.
            GLOBAL_TRAIN_SCALERS = {}
        computed = self._compute_qus_scalers()
        if self.selected_qus_types and augment:
            GLOBAL_TRAIN_SCALERS = computed
        return computed

    def _compute_qus_scalers(self) -> Dict[str, Tuple[float, float]]:
        """Return the (min, max) of each selected QUS map over this dataset's samples.

        NaNs count as 0. With no samples the range falls back to (0, 1); a degenerate
        range is widened by 1e-6 so later division is safe.
        """
        scalers: Dict[str, Tuple[float, float]] = {}
        for qus_type in self.selected_qus_types:
            min_val = float("inf")
            max_val = float("-inf")
            for patient, sample_idx in self._records:
                case_idx = patient.qus_case_indices[qus_type][sample_idx]
                qus_map = np.nan_to_num(self.qus_arrays[qus_type][:, :, case_idx], nan=0.0)
                min_val = min(min_val, float(np.min(qus_map)))
                max_val = max(max_val, float(np.max(qus_map)))
            if min_val == float("inf") or max_val == float("-inf"):
                min_val, max_val = 0.0, 1.0
            if max_val - min_val < 1e-6:
                max_val = min_val + 1e-6
            scalers[qus_type] = (min_val, max_val)
        return scalers

    def __len__(self) -> int:
        return len(self._records)

    def _load_b_mode_tensor(self, path: str) -> torch.Tensor:
        """Load an image as an RGB float tensor in [0, 1], resized to ``image_size``."""
        target_size = (self.image_size, self.image_size)
        # `convert` returns a new image, so the file can be closed as soon as it is read.
        with Image.open(path) as source:
            image = source.convert("RGB")
        if image.size != target_size:
            image = image.resize(target_size)
        return F.to_tensor(image)

    def _load_qus_tensor(self, qus_type: str, patient: Patient, sample_idx: int) -> torch.Tensor:
        """Load one QUS case as a (1, image_size, image_size) tensor scaled to [0, 1]."""
        case_idx = patient.qus_case_indices[qus_type][sample_idx]
        qus_map = np.nan_to_num(self.qus_arrays[qus_type][:, :, case_idx], nan=0.0)
        tensor = torch.from_numpy(qus_map).float().unsqueeze(0)
        if tensor.shape[1] != self.image_size or tensor.shape[2] != self.image_size:
            # interpolate expects a batch dimension: (1, 1, H, W) -> resize -> (1, H, W).
            tensor = torch.nn.functional.interpolate(
                tensor.unsqueeze(0),
                size=(self.image_size, self.image_size),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
        min_val, max_val = self.scalers.get(qus_type, (0.0, 1.0))
        if max_val > min_val:
            tensor = (tensor - min_val) / (max_val - min_val)
        return tensor.clamp(0.0, 1.0)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, int]:
        patient, sample_idx = self._records[index]

        modality_tensors: Dict[str, torch.Tensor] = {}
        if self.use_b_mode:
            b_mode_path = patient.b_mode_paths[sample_idx]
            modality_tensors["B_mode"] = self._load_b_mode_tensor(b_mode_path)
        for qus_type in self.selected_qus_types:
            modality_tensors[qus_type] = self._load_qus_tensor(qus_type, patient, sample_idx)

        if self.augment:
            # Draw the flip/rotation/zoom once and apply it to every modality so all
            # channels stay spatially aligned.
            do_flip = random.random() < 0.5
            angle_deg = random.uniform(-14.0, 14.0)
            zoom = random.uniform(0.9, 1.1)
            for key, tensor in modality_tensors.items():
                if do_flip:
                    tensor = torch.flip(tensor, dims=[2])
                tensor = F.affine(
                    tensor,
                    angle=angle_deg,
                    translate=(0, 0),
                    scale=zoom,
                    shear=0.0,
                    interpolation=InterpolationMode.BILINEAR,
                )
                modality_tensors[key] = tensor

        processed_channels: List[torch.Tensor] = []
        for input_type in self.input_types:
            tensor = modality_tensors[input_type]
            if input_type == "B_mode":
                tensor = (tensor - self.base_mean) / self.base_std
            else:
                # Re-clamp after augmentation, then repeat to 3 channels like RGB.
                tensor = tensor.clamp(0.0, 1.0).repeat(3, 1, 1)
            processed_channels.append(tensor)

        concatenated = torch.cat(processed_channels, dim=0)
        label = torch.tensor(patient.egfr_label, dtype=torch.float32)
        return concatenated, label, patient.patient_id



In [None]:
# Model builder: ResNet18 with flexible input channels

def build_resnet18(num_input_channels: int, device: torch.device) -> nn.Module:
    """Create an ImageNet-pretrained ResNet18 with a single-logit head, moved to ``device``.

    When ``num_input_channels`` is not 3, the stem convolution is replaced by a new one
    with the same geometry and Kaiming-normal initialisation. Its pretrained weights are
    not reused.
    """
    network = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

    if num_input_channels != 3:
        pretrained_stem = network.conv1
        has_bias = pretrained_stem.bias is not None
        new_stem = nn.Conv2d(
            num_input_channels,
            pretrained_stem.out_channels,
            kernel_size=pretrained_stem.kernel_size,
            stride=pretrained_stem.stride,
            padding=pretrained_stem.padding,
            bias=has_bias,
        )
        nn.init.kaiming_normal_(new_stem.weight, mode="fan_out", nonlinearity="relu")
        if has_bias:
            nn.init.zeros_(new_stem.bias)
        network.conv1 = new_stem

    network.fc = nn.Linear(network.fc.in_features, 1)
    return network.to(device)


def count_trainable_parameters(model: nn.Module) -> int:
    """Total number of elements across parameters with ``requires_grad=True``."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


In [None]:
# Training utilities

def set_seeds(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG, torch CPU and all CUDA devices."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


def compute_pos_weight(labels: List[int]) -> float:
    """Return the BCE ``pos_weight``: negatives / positives, floored at 1.0.

    Returns 1.0 when there are no positive labels (including an empty list).
    """
    n_pos = sum(labels)
    if n_pos == 0:
        return 1.0
    n_neg = len(labels) - n_pos
    ratio = n_neg / max(n_pos, 1)
    return max(1.0, ratio)


def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """Run one optimisation pass over ``dataloader``.

    Batches are ``(images, labels, ids)``; model outputs are squeezed on dim 1 to one
    logit per sample. Returns ``(mean loss per sample, AUC over the epoch)``. The AUC is
    NaN when only one class was seen.
    """
    model.train()
    running_loss = 0.0
    seen_labels: List[float] = []
    seen_probs: List[float] = []

    for batch_images, batch_labels, _ in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        batch_logits = model(batch_images).squeeze(1)
        batch_loss = criterion(batch_logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch loss is a per-sample mean.
        running_loss += batch_loss.item() * batch_images.size(0)
        seen_probs += torch.sigmoid(batch_logits).detach().cpu().numpy().tolist()
        seen_labels += batch_labels.detach().cpu().numpy().tolist()

    mean_loss = running_loss / len(dataloader.dataset)
    if len(set(seen_labels)) > 1:
        auc = roc_auc_score(seen_labels, seen_probs)
    else:
        auc = float("nan")
    return mean_loss, auc


def evaluate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float, Dict[int, List[float]], Dict[int, int]]:
    """Score ``dataloader`` without gradient tracking.

    Returns ``(mean loss per sample, image-level AUC, probabilities grouped by patient
    ID, label per patient ID)``. The AUC is NaN when only one class is present. The
    per-patient label is the label of the last sample seen for that patient.
    """
    model.eval()
    loss_sum = 0.0
    flat_labels: List[float] = []
    flat_probs: List[float] = []
    probs_by_patient: Dict[int, List[float]] = defaultdict(list)
    label_by_patient: Dict[int, int] = {}

    with torch.no_grad():
        for batch_images, batch_labels, batch_ids in dataloader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            batch_logits = model(batch_images).squeeze(1)
            loss_sum += criterion(batch_logits, batch_labels).item() * batch_images.size(0)

            probs_np = torch.sigmoid(batch_logits).cpu().numpy()
            labels_np = batch_labels.cpu().numpy()
            flat_probs += probs_np.tolist()
            flat_labels += labels_np.tolist()

            for pid, prob, lab in zip(batch_ids.numpy(), probs_np, labels_np):
                key = int(pid)
                probs_by_patient[key].append(float(prob))
                label_by_patient[key] = int(lab)

    mean_loss = loss_sum / len(dataloader.dataset)
    if len(set(flat_labels)) > 1:
        auc = roc_auc_score(flat_labels, flat_probs)
    else:
        auc = float("nan")
    return mean_loss, auc, probs_by_patient, label_by_patient


def plot_roc_curves(
    probs: List[float],
    labels: List[int],
    patient_probs: Dict[int, List[float]],
    patient_labels: Dict[int, int],
    title_suffix: str,
) -> Tuple[float, float]:
    """Plot image-level and patient-level ROC curves and return their AUCs.

    Patient-level scores are the mean of each patient's image probabilities. A curve is
    drawn only when both classes are present; otherwise its AUC is returned as NaN.

    Returns:
        ``(image_auc, patient_auc)``.
    """

    def _draw(fpr, tpr, curve_label: str, color: str, level: str) -> None:
        # Single ROC figure with a chance diagonal for reference.
        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, label=curve_label, color=color)
        plt.plot([0, 1], [0, 1], "k--", label="Random")
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title(f"{level}-level ROC {title_suffix}")
        plt.legend()
        plt.grid(True)
        plt.show()

    auc_img = float("nan")
    if len(set(labels)) > 1:
        fpr_img, tpr_img, _ = roc_curve(labels, probs)
        auc_img = roc_auc_score(labels, probs)
        _draw(fpr_img, tpr_img, f"Image ROC (AUC={auc_img:.4f})", "tab:blue", "Image")

    patient_ids = list(patient_probs)
    patient_mean_probs = [np.mean(patient_probs[pid]) for pid in patient_ids]
    patient_true = [patient_labels[pid] for pid in patient_ids]

    auc_pat = float("nan")
    if len(set(patient_true)) > 1:
        fpr_pat, tpr_pat, _ = roc_curve(patient_true, patient_mean_probs)
        auc_pat = roc_auc_score(patient_true, patient_mean_probs)
        _draw(fpr_pat, tpr_pat, f"Patient ROC (AUC={auc_pat:.4f})", "tab:green", "Patient")

    return auc_img, auc_pat


In [None]:
# Load and summarize patient data
#
# Only patients with at least one sample of every selected modality are kept.

all_patients = load_patients(
    input_types=INPUT_TYPES,
    image_folders=IMAGE_FOLDERS,
    csv_file=CSV_FILE,
    qus_arrays=GLOBAL_QUS_ARRAYS,
    sample_ids=GLOBAL_SAMPLE_IDS,
)
if len(all_patients) == 0:
    raise RuntimeError("No patients found with the specified input types.")

summarize_patients(all_patients, INPUT_TYPES)


In [None]:
# Hold-out runs with training, validation, and testing
#
# Each run splits the Patient list 90/10 into train+val / test, then 80/20 into
# train / val. Splits are made over Patient objects, so all samples of a patient stay
# in one split. A fresh ResNet18 is trained with early stopping on validation AUC, the
# best-validation weights are restored, and image- and patient-level AUCs are reported.

results = []

for run_idx in range(N_RUNS):
    # Drop scalers cached by a previous run. This cell runs at notebook top level, so a
    # plain assignment rebinds the module global (no `global` statement is needed).
    GLOBAL_TRAIN_SCALERS = {}

    seed = BASE_SEED + run_idx
    set_seeds(seed)

    print("=" * 60)
    print(f"Run {run_idx + 1}/{N_RUNS} (seed={seed})")
    print("=" * 60)

    train_val_patients, test_patients = train_test_split(all_patients, test_size=0.1, random_state=seed)
    train_patients, val_patients = train_test_split(train_val_patients, test_size=0.2, random_state=seed)

    print(f"Train patients: {len(train_patients)}")
    print(f"Validation patients: {len(val_patients)}")
    print(f"Test patients: {len(test_patients)}")

    # QUS scalers are fitted on the training split only and reused for val/test.
    train_dataset = FlexiblePatientDataset(train_patients, INPUT_TYPES, image_size=IMAGE_SIZE, augment=True)
    train_scalers = train_dataset.scalers
    GLOBAL_TRAIN_SCALERS = train_scalers
    val_dataset = FlexiblePatientDataset(val_patients, INPUT_TYPES, scalers=train_scalers, image_size=IMAGE_SIZE, augment=False)
    test_dataset = FlexiblePatientDataset(test_patients, INPUT_TYPES, scalers=train_scalers, image_size=IMAGE_SIZE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Each modality contributes 3 channels to the stacked input.
    model = build_resnet18(num_input_channels=3 * len(INPUT_TYPES), device=device)
    print(f"Trainable parameters: {count_trainable_parameters(model):,}")

    train_labels = train_dataset.labels
    pos_weight_value = compute_pos_weight(train_labels)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight_value, device=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_LR_EVERY, gamma=STEP_LR_GAMMA)

    best_val_auc = -1.0
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_auc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_auc, val_patient_probs, val_patient_labels = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        print(
            f"Epoch {epoch:3d} | Train Loss: {train_loss:.4f} AUC: {train_auc:.4f} | "
            f"Val Loss: {val_loss:.4f} AUC: {val_auc:.4f}"
        )

        # A NaN validation AUC (single-class validation set) never counts as improvement.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # state_dict() returns tensors that share storage with the live parameters
            # and buffers. Clone them, otherwise later optimizer steps overwrite this
            # snapshot and the final-epoch weights would be restored.
            best_state = {
                "model": {name: tensor.detach().clone() for name, tensor in model.state_dict().items()},
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

    if best_state is None:
        raise RuntimeError("Training did not yield a valid model state.")

    model.load_state_dict(best_state["model"])

    val_loss, val_auc, val_patient_probs, val_patient_labels = evaluate(model, val_loader, criterion, device)
    val_probs_flat = [prob for probs in val_patient_probs.values() for prob in probs]
    val_labels_flat = [val_patient_labels[pid] for pid, probs in val_patient_probs.items() for _ in probs]
    val_img_auc, val_patient_auc = plot_roc_curves(
        val_probs_flat,
        val_labels_flat,
        val_patient_probs,
        val_patient_labels,
        title_suffix=f"(Validation Run {run_idx + 1})",
    )

    test_loss, test_auc, test_patient_probs, test_patient_labels = evaluate(model, test_loader, criterion, device)
    test_probs_flat = [prob for probs in test_patient_probs.values() for prob in probs]
    test_labels_flat = [test_patient_labels[pid] for pid, probs in test_patient_probs.items() for _ in probs]
    test_img_auc, test_patient_auc = plot_roc_curves(
        test_probs_flat,
        test_labels_flat,
        test_patient_probs,
        test_patient_labels,
        title_suffix=f"(Test Run {run_idx + 1})",
    )

    preds_binary = [1 if prob >= 0.5 else 0 for prob in test_probs_flat]
    print("Test classification report (image-level):")
    print(classification_report(test_labels_flat, preds_binary, digits=4))

    results.append(
        {
            "val_auc": val_auc,
            "test_auc": test_auc,
            "val_img_auc": val_img_auc,
            "val_patient_auc": val_patient_auc,
            "test_img_auc": test_img_auc,
            "test_patient_auc": test_patient_auc,
        }
    )


In [None]:
# Aggregate results across runs
#
# Reports the per-run image-level validation/test AUCs and their mean ± std.

if not results:
    print("No runs executed yet.")
else:
    banner = "=" * 60
    print("\n" + banner)
    print("FINAL SUMMARY OVER MULTIPLE HOLD-OUT RUNS")
    print(banner)
    val_aucs = [run["val_auc"] for run in results]
    test_aucs = [run["test_auc"] for run in results]
    print(f"Validation AUCs: {[f'{auc:.4f}' for auc in val_aucs]}")
    print(f"Test AUCs:       {[f'{auc:.4f}' for auc in test_aucs]}")
    print(f"\nAverage Validation AUC: {np.mean(val_aucs):.4f} ± {np.std(val_aucs):.4f}")
    print(f"Average Test AUC:       {np.mean(test_aucs):.4f} ± {np.std(test_aucs):.4f}")


In [None]:
# Cross-validation testing (patient-level K-Fold)
#
# Each fold holds out 1/N_FOLDS of the Patient list as test data. The remaining patients
# are shuffled and split 20/80 into validation / training. Training, early stopping on
# validation AUC and best-weight restoration follow the hold-out cell.

N_FOLDS = 5
cv_results = []

kf = KFold(n_splits=N_FOLDS, shuffle=True, random_state=BASE_SEED)

for fold_idx, (train_val_idx, test_idx) in enumerate(kf.split(all_patients)):
    # Drop scalers cached by a previous fold. This cell runs at notebook top level, so a
    # plain assignment rebinds the module global (no `global` statement is needed).
    GLOBAL_TRAIN_SCALERS = {}

    print("=" * 60)
    print(f"Cross-validation fold {fold_idx + 1}/{N_FOLDS}")
    print("=" * 60)

    test_patients = [all_patients[i] for i in test_idx]
    train_val_patients = [all_patients[i] for i in train_val_idx]

    rng = np.random.default_rng(BASE_SEED + fold_idx)
    rng.shuffle(train_val_patients)
    val_split = max(1, int(len(train_val_patients) * 0.2))
    val_patients = train_val_patients[:val_split]
    train_patients = train_val_patients[val_split:]

    print(f"Training patients: {len(train_patients)}")
    print(f"Validation patients: {len(val_patients)}")
    print(f"Test patients: {len(test_patients)}")

    # QUS scalers are fitted on the training split only and reused for val/test.
    train_dataset = FlexiblePatientDataset(train_patients, INPUT_TYPES, image_size=IMAGE_SIZE, augment=True)
    train_scalers = train_dataset.scalers
    GLOBAL_TRAIN_SCALERS = train_scalers
    val_dataset = FlexiblePatientDataset(val_patients, INPUT_TYPES, scalers=train_scalers, image_size=IMAGE_SIZE, augment=False)
    test_dataset = FlexiblePatientDataset(test_patients, INPUT_TYPES, scalers=train_scalers, image_size=IMAGE_SIZE, augment=False)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    # Each modality contributes 3 channels to the stacked input.
    model = build_resnet18(num_input_channels=3 * len(INPUT_TYPES), device=device)
    print(f"Trainable parameters: {count_trainable_parameters(model):,}")

    train_labels = train_dataset.labels
    pos_weight_value = compute_pos_weight(train_labels)
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight_value, device=device))
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=STEP_LR_EVERY, gamma=STEP_LR_GAMMA)

    best_val_auc = -1.0
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(1, EPOCHS + 1):
        train_loss, train_auc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_auc, val_patient_probs, val_patient_labels = evaluate(model, val_loader, criterion, device)
        scheduler.step()

        print(
            f"Epoch {epoch:3d} | Train Loss: {train_loss:.4f} AUC: {train_auc:.4f} | "
            f"Val Loss: {val_loss:.4f} AUC: {val_auc:.4f}"
        )

        # A NaN validation AUC (single-class validation set) never counts as improvement.
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            # state_dict() returns tensors that share storage with the live parameters
            # and buffers. Clone them, otherwise later optimizer steps overwrite this
            # snapshot and the final-epoch weights would be restored.
            best_state = {
                "model": {name: tensor.detach().clone() for name, tensor in model.state_dict().items()},
            }
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
            print("Early stopping triggered.")
            break

    if best_state is None:
        raise RuntimeError("Training did not yield a valid model state.")

    model.load_state_dict(best_state["model"])

    val_loss, val_auc, val_patient_probs, val_patient_labels = evaluate(model, val_loader, criterion, device)
    val_probs_flat = [prob for probs in val_patient_probs.values() for prob in probs]
    val_labels_flat = [val_patient_labels[pid] for pid, probs in val_patient_probs.items() for _ in probs]
    val_img_auc, val_patient_auc = plot_roc_curves(
        val_probs_flat,
        val_labels_flat,
        val_patient_probs,
        val_patient_labels,
        title_suffix=f"(Validation Fold {fold_idx + 1})",
    )

    test_loss, test_auc, test_patient_probs, test_patient_labels = evaluate(model, test_loader, criterion, device)
    test_probs_flat = [prob for probs in test_patient_probs.values() for prob in probs]
    test_labels_flat = [test_patient_labels[pid] for pid, probs in test_patient_probs.items() for _ in probs]
    test_img_auc, test_patient_auc = plot_roc_curves(
        test_probs_flat,
        test_labels_flat,
        test_patient_probs,
        test_patient_labels,
        title_suffix=f"(Test Fold {fold_idx + 1})",
    )

    preds_binary = [1 if prob >= 0.5 else 0 for prob in test_probs_flat]
    print("Test classification report (image-level):")
    print(classification_report(test_labels_flat, preds_binary, digits=4))

    cv_results.append(
        {
            "val_auc": val_auc,
            "test_auc": test_auc,
            "val_img_auc": val_img_auc,
            "val_patient_auc": val_patient_auc,
            "test_img_auc": test_img_auc,
            "test_patient_auc": test_patient_auc,
        }
    )


In [None]:
# Cross-validation summary
#
# Reports the per-fold image-level validation/test AUCs and their mean ± std.

if not cv_results:
    print("Cross-validation has not been run yet.")
else:
    separator = "=" * 60
    print("\n" + separator)
    print("FINAL CROSS-VALIDATION SUMMARY")
    print(separator)
    cv_val_aucs = [fold["val_auc"] for fold in cv_results]
    cv_test_aucs = [fold["test_auc"] for fold in cv_results]
    print(f"Validation AUCs: {[f'{auc:.4f}' for auc in cv_val_aucs]}")
    print(f"Test AUCs:       {[f'{auc:.4f}' for auc in cv_test_aucs]}")
    print(f"\nAverage Validation AUC: {np.mean(cv_val_aucs):.4f} ± {np.std(cv_val_aucs):.4f}")
    print(f"Average Test AUC:       {np.mean(cv_test_aucs):.4f} ± {np.std(cv_test_aucs):.4f}")
