In [None]:
from google.colab import drive

# Mount Google Drive at /content/drive (Colab only) so Drive-hosted files can be read
drive.mount('/content/drive')


In [None]:
#!git clone https://github.com/Hikarukurosawa123/TUPIL_Kidney.git
#!git pull origin main


In [None]:
import os
import math
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import KBinsDiscretizer, MinMaxScaler
from sklearn.model_selection import KFold, StratifiedKFold
from tensorflow.keras import layers, models
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score, accuracy_score, confusion_matrix, classification_report, roc_curve, auc, roc_auc_score
from collections import defaultdict
from scipy.io import loadmat
import h5py
import re
from sklearn.utils import class_weight

# Input configuration - lists which inputs the rest of the notebook uses.
# QUS options:   'ESD', 'EAC', 'SI', 'SS', 'MBF'
# Image options: 'B_mode' (greyscale B-mode images)
# Any combination of the options above may be listed.
INPUT_TYPES = ['ESD', 'EAC', 'SI', 'SS', 'MBF', 'B_mode']  # All QUS parameters + B-mode
# Alternative selections:
# INPUT_TYPES = ['ESD']  # Only ESD
# INPUT_TYPES = ['ESD', 'EAC', 'B_mode']  # ESD + EAC + B-mode
# INPUT_TYPES = ['ESD', 'EAC', 'SI', 'SS', 'MBF']  # All QUS only
# INPUT_TYPES = ['B_mode']  # Only B-mode

# Environment detection
def is_google_colab():
    """Report whether the `google.colab` package can be imported here."""
    try:
        import google.colab  # noqa: F401 -- imported only to probe availability
    except ImportError:
        in_colab = False
    else:
        in_colab = True
    return in_colab

def is_google_drive_mounted():
    """Return True when the Drive mount point `/content/drive/MyDrive` exists."""
    drive_root = '/content/drive/MyDrive'
    return os.path.exists(drive_root)

# Pick the data locations that match the detected runtime environment
if not is_google_colab():
    # Local environment: paths are relative to the working directory
    QUS_DATA_DIR = 'data/QUS_resized'
    SAMPLE_ID_FILE = 'data/QUS_combined/sample_id_combined.mat'
    CSV_FILE = 'csv/patient_eGFR_at_pocus_2025_Jul_polynomial_estimation.csv'
    MODEL_WEIGHTS_PATH = 'data/model_weights/RadImageNet-ResNet50_notop.h5'
    B_MODE_IMAGE_FOLDER = 'data/lanczos_shape_corrected_only_nc_resized_images'
    print("Running locally")
elif is_google_drive_mounted():
    # Google Colab with Google Drive mounted: everything lives in the Drive workspace
    QUS_DATA_DIR = '/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney/data/QUS_resized'
    SAMPLE_ID_FILE = '/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney/data/QUS_combined/sample_id_combined.mat'
    CSV_FILE = '/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney/csv/patient_eGFR_at_pocus_2025_Jul_polynomial_estimation.csv'
    MODEL_WEIGHTS_PATH = '/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney/data/model_weights/RadImageNet-ResNet50_notop.h5'
    B_MODE_IMAGE_FOLDER = '/content/drive/MyDrive/Hikaru_Colab_Workspace/TUPIL_Kidney/data/Bmode_resize'
    print("Running on Google Colab with Google Drive mounted")
else:
    # Google Colab without Google Drive mounted: files are expected under /content
    QUS_DATA_DIR = '/content/QUS_resized'
    SAMPLE_ID_FILE = '/content/QUS_combined/sample_id_combined.mat'
    CSV_FILE = '/content/patient_eGFR_at_pocus_2025_Jul_polynomial_estimation.csv'
    MODEL_WEIGHTS_PATH = '/content/model_weights/RadImageNet-ResNet50_notop.h5'
    B_MODE_IMAGE_FOLDER = '/content/lanczos_shape_corrected_only_nc_resized_images'
    print("Running on Google Colab without Google Drive mounted")

# Echo the effective configuration
print(
    f"Selected QUS input types: {INPUT_TYPES}",
    f"QUS_DATA_DIR: {QUS_DATA_DIR}",
    f"CSV_FILE: {CSV_FILE}",
    f"MODEL_WEIGHTS_PATH: {MODEL_WEIGHTS_PATH}",
    sep="\n",
)

# Shared run settings
BATCH_SIZE, EPOCHS, SEED = 16, 160, 42


In [None]:
import os
import re
import numpy as np
import pandas as pd
from scipy.io import loadmat
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import tensorflow as tf

# =============================
# Split the selected inputs into QUS parameter maps and image inputs
# =============================
QUS_TYPES = [name for name in INPUT_TYPES if name in ('ESD', 'EAC', 'SI', 'SS', 'MBF')]
IMAGE_TYPES = [name for name in INPUT_TYPES if name == 'B_mode']

print(
    f"INPUT_TYPES: {INPUT_TYPES}",
    f"QUS_TYPES: {QUS_TYPES}",
    f"IMAGE_TYPES: {IMAGE_TYPES}",
    sep="\n",
)

# =============================
# Helper Functions
# =============================

def extract_patient_id(sample_id):
    """Return the integer after the first 'P<digits>' token in sample_id, or None if absent."""
    found = re.search(r'P(\d+)', str(sample_id))
    return int(found.group(1)) if found else None

def extract_matlab_string(cell_item):
    """Extract a Python string from a MATLAB cell array element.

    Args:
        cell_item: A value taken from a loaded MATLAB variable: a (possibly
            nested) NumPy array, a bytes object, or any other scalar.

    Returns:
        "" for an empty array; otherwise the text of the first element.
        Nested arrays are unwrapped recursively. Byte strings are decoded as
        UTF-8 (undecodable bytes are replaced) so the result never carries a
        "b'...'" wrapper, which plain str() would produce.
    """
    if isinstance(cell_item, np.ndarray):
        if cell_item.size == 0:
            return ""
        # Only the first element is used; it may itself be an array, bytes or a scalar
        return extract_matlab_string(cell_item.flat[0])
    if isinstance(cell_item, bytes):  # also covers np.bytes_, a bytes subclass
        return cell_item.decode('utf-8', errors='replace')
    return str(cell_item)

# =============================
# Load QUS matrices
# =============================

print("Loading QUS matrices...")
qus_matrices = {}
# Iterate QUS_TYPES, not INPUT_TYPES: 'B_mode' is an image input read from
# B_MODE_IMAGE_FOLDER, so looking for a 'B_mode.npy' here would raise
# FileNotFoundError whenever B-mode is selected.
for qus_name in QUS_TYPES:
    npy_file = os.path.join(QUS_DATA_DIR, f'{qus_name}.npy')
    if not os.path.exists(npy_file):
        raise FileNotFoundError(f"QUS file not found: {npy_file}")
    qus_matrices[qus_name] = np.load(npy_file)
    print(f"  Loaded {qus_name}: shape {qus_matrices[qus_name].shape}")

# =============================
# Report NaN / Inf / exact-zero counts for every loaded matrix
# =============================
print("\n=== Checking QUS matrices for NaN/Inf/zero values ===")
for qus_name, qus_array in qus_matrices.items():
    total_elements = qus_array.size
    n_nan = np.isnan(qus_array).sum()
    n_inf = np.isinf(qus_array).sum()
    n_zero = (qus_array == 0).sum()
    print(
        f"{qus_name}: shape={qus_array.shape}, NaN={n_nan}, Inf={n_inf}, "
        f"zeros={n_zero}/{total_elements}"
    )

# =============================
# Plot one example case
# =============================
def plot_qus_case(qus_array, case_idx=0, qus_name="QUS"):
    """Show one case's map beside its NaN mask and its exact-zero mask.

    Args:
        qus_array: Array indexed as [:, :, case] (cases along the last axis).
        case_idx: Index along the last axis of the case to display.
        qus_name: Label used in the three panel titles.
    """
    img = qus_array[:, :, case_idx]

    # (panel data, colormap, title suffix, draw a colorbar?) for each subplot
    panels = (
        (img, 'viridis', 'Original', True),
        (np.isnan(img), 'Reds', 'NaNs', False),
        (img == 0, 'Blues', 'Zeros', False),
    )

    plt.figure(figsize=(12, 4))
    for position, (panel, cmap, suffix, with_colorbar) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(panel, cmap=cmap)
        plt.title(f"{qus_name} - Case {case_idx} {suffix}")
        if with_colorbar:
            plt.colorbar()

    plt.tight_layout()
    plt.show()

# Example: show case 0 of the first selected QUS type (skipped when no QUS type is selected)
if QUS_TYPES:
    first_qus_name = QUS_TYPES[0]
    plot_qus_case(qus_matrices[first_qus_name], qus_name=first_qus_name, case_idx=0)

# =============================
# Load sample IDs
# =============================
print("\nLoading sample IDs...")
sample_id_data = loadmat(SAMPLE_ID_FILE, struct_as_record=False, squeeze_me=True)
# Entries whose names start with '__' are skipped; the rest are data variables
sample_id_keys = [key for key in sample_id_data if not key.startswith('__')]
if not sample_id_keys:
    raise ValueError("No data found in sample_id_combined.mat")

# The first remaining variable is taken to hold the sample IDs
sample_ids_var = sample_id_data[sample_id_keys[0]]

# Normalise whatever container came back into a flat list of Python strings
is_array = isinstance(sample_ids_var, np.ndarray)
if is_array and sample_ids_var.dtype == object:
    # Object (cell-like) array: 1-D arrays are read element-wise, other ranks via column 0
    n_entries = sample_ids_var.shape[0] if sample_ids_var.ndim > 0 else 1
    take_first_column = sample_ids_var.ndim != 1
    sample_ids = [
        extract_matlab_string(sample_ids_var[i, 0] if take_first_column else sample_ids_var[i])
        for i in range(n_entries)
    ]
elif is_array and sample_ids_var.dtype.kind in ['U', 'S']:
    sample_ids = list(map(str, sample_ids_var))
elif isinstance(sample_ids_var, (list, tuple)):
    sample_ids = list(map(str, sample_ids_var))
else:
    # Anything else is treated as a single ID
    sample_ids = [str(sample_ids_var)]

print(f"Loaded {len(sample_ids)} sample IDs")
print(f"Sample IDs (first 5): {sample_ids[:5]}")

# Confirm every selected QUS matrix stacks the same number of cases (axis 2)
if QUS_TYPES:
    n_cases_list = [qus_matrices[qus_name].shape[2] for qus_name in QUS_TYPES]
    if len(set(n_cases_list)) != 1:
        raise ValueError(f"QUS matrices have different number of cases: {n_cases_list}")
    n_cases = n_cases_list[0]
    print(f"\nAll QUS matrices have {n_cases} cases")

# =============================
# Load eGFR data
# =============================
print("\nLoading eGFR data...")
egfr_df = pd.read_csv(CSV_FILE)
# patient ID (int) -> eGFR (float); rows with a missing eGFR are left out
egfr_dict = {}
for raw_id, egfr_value in zip(egfr_df['Patient ID'], egfr_df['eGFR (abs/closest)']):
    patient_id = int(raw_id)
    if pd.isna(egfr_value):
        continue
    egfr_dict[patient_id] = float(egfr_value)
print(f"Loaded eGFR for {len(egfr_dict)} patients")


In [None]:
# Flexible Patient class for QUS data
class Patient:
    """One patient's label, eGFR value and per-input-type case references.

    Attributes:
        patient_id: Identifier, stored as given.
        egfr: Binary label (0 or 1).
        egfr_val: Actual eGFR value.
        case_indices_dict: Maps an input type name to its list of entries.
    """

    def __init__(self, patient_id, egfr, egfr_val, case_indices_dict):
        self.patient_id, self.egfr = patient_id, egfr
        self.egfr_val = egfr_val
        self.case_indices_dict = case_indices_dict

    def get_case_indices(self, qus_type):
        """Return the entries stored for qus_type, or a new empty list when none exist."""
        if qus_type in self.case_indices_dict:
            return self.case_indices_dict[qus_type]
        return []

    def has_all_inputs(self, required_qus_types):
        """Return True only if every required type has at least one entry."""
        for qus_type in required_qus_types:
            if len(self.get_case_indices(qus_type)) == 0:
                return False
        return True


def load_b_mode_images(b_mode_folder):
    """Group the B-mode image paths in a folder by the patient ID in each file name.

    Only .png/.jpg/.jpeg files (case-insensitive) named like
    `Patient_<id>_...` are kept; every other entry is skipped silently.

    Args:
        b_mode_folder: Directory to scan (its direct entries only).

    Returns:
        Mapping of integer patient ID -> list of full paths, in sorted
        file-name order. An empty dict if the folder does not exist.
    """
    if not os.path.exists(b_mode_folder):
        print(f"Warning: B-mode folder not found: {b_mode_folder}")
        return {}

    image_suffixes = ('.png', '.jpg', '.jpeg')
    patient_image_map = defaultdict(list)

    for filename in sorted(os.listdir(b_mode_folder)):
        if not filename.lower().endswith(image_suffixes):
            continue
        # Expected format: Patient_100_Resized_Image_1.png
        tokens = filename.split('_')
        if len(tokens) < 2 or tokens[0] != 'Patient':
            continue
        try:
            patient_id = int(tokens[1])
        except ValueError:
            # Second token is not a number -> not a patient image
            continue
        patient_image_map[patient_id].append(os.path.join(b_mode_folder, filename))

    print(f"Loaded B-mode images for {len(patient_image_map)} patients")
    return patient_image_map

def load_patients_from_qus(qus_types, qus_matrices, sample_ids, egfr_dict, b_mode_image_map=None):
    """
    Build Patient objects from QUS cases and/or B-mode images that have an eGFR value.

    Args:
        qus_types: List of QUS types to use (e.g., ['ESD', 'EAC', 'SI', 'SS', 'MBF'])
        qus_matrices: Dictionary of QUS matrices (not read by this function)
        sample_ids: Sample IDs; position i is the ID of case index i
        egfr_dict: Dictionary mapping patient_id -> eGFR value
        b_mode_image_map: Dictionary mapping patient_id -> list of B-mode image paths

    Returns:
        List of Patient objects. A patient is kept only when it has at least
        one entry for every QUS type and, if 'B_mode' is in the module-level
        INPUT_TYPES, at least one B-mode image. The binary label is 1 when
        eGFR >= 60, else 0.
    """
    # patient ID -> {input type -> entries}; new patients start with an empty list per QUS type
    patient_case_map = defaultdict(lambda: {qus_type: [] for qus_type in qus_types})

    # Attach each case index to the patient named in its sample ID
    if len(qus_types) > 0 and len(sample_ids) > 0:
        for case_idx, sample_id in enumerate(sample_ids):
            patient_id = extract_patient_id(sample_id)
            if patient_id is None or patient_id not in egfr_dict:
                continue
            per_type = patient_case_map[patient_id]
            # The same case index applies to every QUS parameter
            for qus_type in qus_types:
                per_type[qus_type].append(case_idx)

    # Attach B-mode image paths for patients with a known eGFR
    if b_mode_image_map is not None:
        for patient_id, image_paths in b_mode_image_map.items():
            if patient_id in egfr_dict:
                patient_case_map[patient_id]['B_mode'] = image_paths

    all_required_types = (qus_types + ['B_mode']) if 'B_mode' in INPUT_TYPES else qus_types

    patient_objects = []
    for patient_id, case_indices_dict in patient_case_map.items():
        # Drop patients missing any required input type
        if any(len(case_indices_dict.get(req_type, [])) == 0 for req_type in all_required_types):
            continue
        egfr_value = egfr_dict[patient_id]
        patient_objects.append(
            Patient(patient_id, 1 if egfr_value >= 60 else 0, egfr_value, case_indices_dict)
        )

    return patient_objects

# Index the B-mode images only when that input type is selected
b_mode_image_map = None
if 'B_mode' in INPUT_TYPES:
    print("\nLoading B-mode images...")
    b_mode_image_map = load_b_mode_images(B_MODE_IMAGE_FOLDER)

# Build the patient list from whichever inputs are selected
print("\nLoading patients from QUS data and/or B-mode images...")
all_patients = load_patients_from_qus(
    qus_types=QUS_TYPES,
    qus_matrices=qus_matrices,
    sample_ids=sample_ids,
    egfr_dict=egfr_dict,
    b_mode_image_map=b_mode_image_map,
)
print(f"Total patients with all required input types: {len(all_patients)}")

def summarize_patients_qus(patients, input_types, qus_types):
    """Print per-input-type entry counts for a patient list and return its size.

    Args:
        patients: Objects exposing get_case_indices(type_name) -> list.
        input_types: All selected input types; a B-mode section is printed
            when 'B_mode' is among them.
        qus_types: QUS types to report, one section each.

    Returns:
        Number of patients. With an empty patient list the totals are printed
        as 0 and the min/max/avg line is replaced by a note (min() and max()
        raise ValueError on an empty sequence).
    """
    num_patients = len(patients)

    print(f"Number of patients: {num_patients}")
    print(f"Input types: {input_types}")

    # QUS types report "cases" and B-mode reports "images"; the rest is identical
    reports = [(qus_type, f"Total {qus_type} cases") for qus_type in qus_types]
    if 'B_mode' in input_types:
        reports.append(('B_mode', "Total B_mode images"))

    for type_name, total_label in reports:
        counts = [len(patient.get_case_indices(type_name)) for patient in patients]
        print(f"{total_label}: {sum(counts)}")
        if counts:
            print(f"{type_name} - Min: {min(counts)}, Max: {max(counts)}, Avg: {np.mean(counts):.2f}")
        else:
            print(f"{type_name} - no patients to summarize")

    return num_patients

summarize_patients_qus(all_patients, INPUT_TYPES, QUS_TYPES)


In [None]:
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import MinMaxScaler
import cv2

def load_b_mode_image_greyscale(image_path):
    """Read an image as a single greyscale channel, float32, shaped (224, 224, 1).

    Pixel values are divided by 255; this is plain image normalization and no
    MinMaxScaler is involved.

    Raises:
        ValueError: If the file cannot be read as an image.
    """
    import cv2
    target_size = (224, 224)
    grey = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if grey is None:
        raise ValueError(f"Could not load image: {image_path}")
    # Only resample when the stored size differs from the target
    if grey.shape != target_size:
        grey = cv2.resize(grey, target_size, interpolation=cv2.INTER_LINEAR)
    scaled = grey.astype(np.float32) / 255.0
    return scaled[:, :, np.newaxis]  # append the channel axis

def create_dataset_from_patients_qus(patients, input_types, qus_types, qus_matrices, scalers=None, augment=False, batch_size=4):
    """
    Build a batched tf.data.Dataset of channel-stacked QUS maps and/or B-mode images.

    For each patient, the i-th entry of every selected input type is paired
    into one sample, so only the first min(entries per type) entries are used.
    Non-finite QUS pixels (NaN, +Inf, -Inf) are replaced with zero before
    min-max scaling; leaving Inf to np.nan_to_num's default would turn it into
    ~1.8e308 and squash every real value towards 0 after scaling.

    NOTE(review): the dataset is batched but never shuffled, so samples stay
    grouped by patient -- confirm this is intended for the training set.

    Args:
        patients: list of Patient objects
        input_types: list of all input types (QUS + B_mode)
        qus_types: list of QUS types only
        qus_matrices: dict of QUS matrices indexed as [:, :, case]
        scalers: optional dict of fitted MinMaxScalers (QUS types only);
            when None, one scaler per QUS type is fitted on this data
        augment: whether to apply random flip/rotation/zoom to each batch
        batch_size: dataset batch size

    Returns:
        (dataset, scalers): dataset yields (inputs, labels) batches with
        inputs of shape (batch, H, W, n_channels), QUS channels first in
        qus_types order and B-mode last; scalers is the dict that was used.

    Raises:
        ValueError: If no sample can be built from `patients`.
    """
    use_b_mode = 'B_mode' in input_types
    qus_data_lists = {qus_type: [] for qus_type in qus_types}
    b_mode_data_list = []
    labels = []

    # Collect data per patient
    for patient in patients:
        # Inputs are paired by position, so the scarcest input type limits the sample count
        all_counts = [
            len(patient.get_case_indices(input_type))
            for input_type in input_types
            if input_type in qus_types or input_type == 'B_mode'
        ]
        min_cases = min(all_counts) if all_counts else 0

        for i in range(min_cases):
            for qus_type in qus_types:
                case_idx = patient.get_case_indices(qus_type)[i]
                qus_map = qus_matrices[qus_type][:, :, case_idx]
                # Zero out every non-finite pixel, not just NaN
                qus_map = np.nan_to_num(qus_map, nan=0.0, posinf=0.0, neginf=0.0)
                qus_data_lists[qus_type].append(qus_map)

            if use_b_mode:
                # For B-mode the stored "case indices" are image paths
                b_mode_path = patient.get_case_indices('B_mode')[i]
                b_mode_data_list.append(load_b_mode_image_greyscale(b_mode_path))

            labels.append(patient.egfr)

    if len(labels) == 0:
        # Fail early with a clear message instead of an obscure NumPy error further down
        raise ValueError("No samples available: patients list is empty or lacks the selected input types")

    # Convert to numpy arrays and add channel dimension for QUS
    for qus_type in qus_types:
        qus_data_lists[qus_type] = np.array(qus_data_lists[qus_type])[..., np.newaxis]  # (n_samples, H, W, 1)

    # B-mode images already carry a channel dimension and are in [0, 1]; no MinMaxScaler needed
    if use_b_mode:
        b_mode_data = np.array(b_mode_data_list)  # (n_samples, H, W, 1)
        print(f"B_mode images: shape={b_mode_data.shape}, range=[{np.min(b_mode_data):.4f}, {np.max(b_mode_data):.4f}] (no scaling applied)")

    # Fit scalers if not provided (only for QUS types)
    if scalers is None:
        scalers = {}
        for qus_type in qus_types:
            scaler = MinMaxScaler()
            # One global min/max per QUS type: all pixels form a single feature column
            scaler.fit(qus_data_lists[qus_type].reshape(-1, 1))
            scalers[qus_type] = scaler
            print(f"Fitted scaler for {qus_type}: min={scaler.data_min_[0]:.4f}, max={scaler.data_max_[0]:.4f}")

    # Apply scaling (only for QUS types)
    scaled_qus_data = []
    for qus_type in qus_types:
        data = qus_data_lists[qus_type]
        scaled = scalers[qus_type].transform(data.reshape(-1, 1)).reshape(data.shape)
        scaled_qus_data.append(scaled)
        print(f"Scaled {qus_type}: shape={scaled.shape}, range=[{np.min(scaled):.4f}, {np.max(scaled):.4f}]")

    # Concatenate all inputs along channels: scaled QUS first, then B-mode
    all_inputs = list(scaled_qus_data)
    if use_b_mode:
        all_inputs.append(b_mode_data)

    if len(all_inputs) == 1:
        combined_input = all_inputs[0]
    else:
        combined_input = np.concatenate(all_inputs, axis=-1)  # (n_samples, H, W, num_channels)
    labels = np.array(labels)
    print(f"Combined input shape: {combined_input.shape}")

    # Create tf.data.Dataset
    dataset = tf.data.Dataset.from_tensor_slices((combined_input, labels)).batch(batch_size)

    # Data augmentation, applied to whole batches
    if augment:
        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.25),
            tf.keras.layers.RandomZoom(0.1),
        ])

        def augment_fn(inputs, label):
            augmented = data_augmentation(inputs, training=True)
            return augmented, label

        dataset = dataset.map(augment_fn, num_parallel_calls=tf.data.AUTOTUNE)

    return dataset.prefetch(tf.data.AUTOTUNE), scalers


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def build_resnet_model_flexible_qus(input_types, with_transfer_learning=True):
    """
    Build a ResNet-based binary classifier with one single-channel input per input type.

    Every input is repeated to 3 channels and passed through the same backbone
    instance; the pooled features of all inputs are concatenated and fed to a
    dense head with a sigmoid output.

    Args:
        input_types: List of input types (e.g., ['ESD', 'EAC', 'SI', 'SS', 'MBF', 'B_mode'])
        with_transfer_learning: If True, the backbone is loaded from
            MODEL_WEIGHTS_PATH; otherwise an untrained ResNet50 is created

    Returns:
        Uncompiled Keras model. With one input type it takes a single tensor,
        otherwise a list of tensors in input_types order.
    """
    single_input = len(input_types) == 1

    # One (224, 224, 1) input per input type
    input_layers = [
        layers.Input(shape=(224, 224, 1), name=f'{input_type}_input')
        for input_type in input_types
    ]

    # The backbone takes 3 channels, so each single channel is repeated three times
    expanded_inputs = [
        layers.Concatenate()([tensor, tensor, tensor])
        for tensor in input_layers
    ]

    # Backbone: saved model or a freshly initialised ResNet50
    if with_transfer_learning:
        base_model = models.load_model(MODEL_WEIGHTS_PATH, compile=False)
    else:
        base_model = tf.keras.applications.ResNet50(
            weights=None,
            include_top=False,
            input_shape=(224, 224, 3)
        )
    base_model.trainable = True

    # Same backbone instance for every input, followed by global average pooling
    pooled_features = [
        layers.GlobalAveragePooling2D()(base_model(expanded))
        for expanded in expanded_inputs
    ]
    x = pooled_features[0] if single_input else layers.Concatenate()(pooled_features)

    # Dense classification head: (units, dropout rate) per stage
    for units, drop_rate in ((4096, 0.3), (2048, 0.2), (1024, 0.2)):
        x = layers.Dense(units, activation='relu')(x)
        x = layers.Dropout(drop_rate)(x)

    # Final binary classification
    output = layers.Dense(1, activation='sigmoid')(x)

    model_inputs = input_layers[0] if single_input else input_layers
    return models.Model(inputs=model_inputs, outputs=output)


def build_resnet_model_flexible_concatenated_qus(input_types, qus_types, with_transfer_learning=False):
    """
    Alternative architecture: all inputs arrive as channels of one tensor and go through a single backbone.

    Args:
        input_types: List of all selected input types; with exactly one, the
            regular single-input model is returned instead
        qus_types: List of QUS types (e.g., ['ESD', 'EAC', 'SI', 'SS', 'MBF'])
        with_transfer_learning: Whether to load the backbone from
            MODEL_WEIGHTS_PATH. Only honoured when the combined input has
            exactly 3 channels; otherwise an untrained ResNet50 is built

    Returns:
        Uncompiled Keras model taking a (224, 224, num_channels) input, where
        num_channels = len(qus_types) plus 1 if 'B_mode' is selected.
    """
    if len(input_types) == 1:
        # Single input - same as regular model
        return build_resnet_model_flexible_qus(input_types, with_transfer_learning)

    # The dataset already stacks the inputs along the channel axis, so one input layer suffices
    num_channels = len(qus_types) + (1 if 'B_mode' in input_types else 0)
    input_layer = layers.Input(shape=(224, 224, num_channels), name='combined_input')

    # Always bind base_model: the previous version left it unassigned when
    # with_transfer_learning was False and then failed with UnboundLocalError.
    if with_transfer_learning and num_channels == 3:
        base_model = models.load_model(MODEL_WEIGHTS_PATH, compile=False)
    else:
        # The saved backbone is only used for 3-channel input; any other case gets an untrained ResNet50
        base_model = tf.keras.applications.ResNet50(
            weights=None,
            include_top=False,
            input_shape=(224, 224, num_channels)
        )

    base_model.trainable = True

    # Backbone -> pooling -> dense head
    x = base_model(input_layer)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(4096, activation='relu')(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(2048, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    x = layers.Dense(1024, activation='relu')(x)
    x = layers.Dropout(0.2)(x)

    # Final binary classification
    output = layers.Dense(1, activation='sigmoid')(x)

    model = models.Model(inputs=input_layer, outputs=output)
    return model


In [None]:
def _plot_roc_curve(fpr, tpr, auc_value, level, color, input_types):
    """Draw one ROC curve plus the chance diagonal; `level` is e.g. 'Image' or 'Patient'."""
    plt.figure(figsize=(8, 6))
    plt.plot(fpr, tpr, label=f"{level}-level ROC (AUC = {auc_value:.4f})", color=color)
    plt.plot([0, 1], [0, 1], 'r--', label="Random Guess")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{level}-level ROC Curve - {', '.join(input_types)}")
    plt.legend(loc="lower right")
    plt.grid(True)
    plt.show()


def plotAndReturnValidationTesting_flexible_qus(val_patients, model, input_types, qus_types, qus_matrices, scalers):
    """Predict on each patient, plot image-level and patient-level ROC curves, and return the raw results.

    Args:
        val_patients: Patient objects to evaluate
        model: Model whose predict() returns one probability per sample
        input_types: All input types used to build each patient's dataset
        qus_types: QUS types used to build each patient's dataset
        qus_matrices: Dict of QUS matrices
        scalers: Fitted scalers passed through to dataset creation

    Returns:
        (patient_to_true_labels, patient_to_predicted_probs_list): dicts keyed
        by patient_id holding the patient's label and the list of predicted
        probabilities for all of that patient's samples.
    """
    patient_to_true_labels = {}
    patient_to_predicted_probs_list = {}

    all_true_labels_individual = []
    all_predicted_probs_individual = []

    # ================= Collect predictions =================
    for patient in val_patients:
        # Use the caller-supplied input_types/qus_types rather than module globals
        patient_dataset, _ = create_dataset_from_patients_qus(
            [patient], input_types, qus_types, qus_matrices,
            scalers=scalers, augment=False, batch_size=BATCH_SIZE
        )
        patient_probs_all_batches = []
        for inputs, labels in patient_dataset:
            predictions = model.predict(inputs, verbose=0).flatten()

            # Every sample of a patient carries the same label
            patient_to_true_labels[patient.patient_id] = labels[0].numpy()
            # Accumulate across batches so patients with more than one batch keep all predictions
            patient_probs_all_batches.extend(predictions.tolist())

            # Collect image-level predictions for metrics
            all_true_labels_individual.extend(labels.numpy().flatten())
            all_predicted_probs_individual.extend(predictions)

        if patient_probs_all_batches:
            patient_to_predicted_probs_list[patient.patient_id] = patient_probs_all_batches

    # ================= Image-level ROC =================
    fpr_img, tpr_img, _ = roc_curve(all_true_labels_individual, all_predicted_probs_individual)
    auc_img = roc_auc_score(all_true_labels_individual, all_predicted_probs_individual)
    _plot_roc_curve(fpr_img, tpr_img, auc_img, "Image", 'blue', input_types)

    # ================= Patient-level ROC =================
    patient_true = []
    patient_probs = []
    for pid, probs in patient_to_predicted_probs_list.items():
        patient_true.append(patient_to_true_labels[pid])
        patient_probs.append(np.mean(probs))  # average across the patient's samples

    fpr_pat, tpr_pat, _ = roc_curve(patient_true, patient_probs)
    auc_pat = roc_auc_score(patient_true, patient_probs)
    _plot_roc_curve(fpr_pat, tpr_pat, auc_pat, "Patient", 'green', input_types)

    return patient_to_true_labels, patient_to_predicted_probs_list


# Updated hyperparameter tuning configurations for flexible QUS input
def resnet_flexible_branched_qus(input_types):
    """Build the per-input-branch model with the saved backbone and compile it for binary classification."""
    model = build_resnet_model_flexible_qus(input_types, with_transfer_learning=True)
    optimizer = tf.optimizers.Adam(learning_rate=0.0001)
    tracked_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=tracked_metrics)
    return model

def resnet_flexible_concatenated_qus(input_types, qus_types):
    """Build the channel-concatenated model without transfer learning and compile it for binary classification."""
    model = build_resnet_model_flexible_concatenated_qus(
        input_types, qus_types, with_transfer_learning=False
    )
    optimizer = tf.optimizers.Adam(learning_rate=0.0001)
    tracked_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=tracked_metrics)
    return model

def resnet_flexible_no_transfer_qus(input_types):
    """Build the per-input-branch model with an untrained backbone and compile it for binary classification."""
    model = build_resnet_model_flexible_qus(input_types, with_transfer_learning=False)
    optimizer = tf.optimizers.Adam(learning_rate=0.0001)
    tracked_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
    model.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=tracked_metrics)
    return model

# Create hyperparameter configurations based on selected input types
def create_hyperparameter_configs_qus(input_types, qus_types):
    """Map configuration names to zero-argument model builders for the selected inputs.

    Args:
        input_types: All selected input types; joined with '_' to form the
            suffix of every configuration name
        qus_types: QUS types, forwarded to the concatenated-channel builder

    Returns:
        Dict of name -> callable. Nothing is built until a callable is invoked.
    """
    # Joined outside the f-strings: reusing the enclosing quote character inside
    # a replacement field is a SyntaxError before Python 3.12.
    suffix = '_'.join(input_types)
    return {
        f'resnet_flexible_branched_{suffix}': lambda: resnet_flexible_branched_qus(input_types),
        f'resnet_flexible_concatenated_{suffix}': lambda: resnet_flexible_concatenated_qus(input_types, qus_types),
        f'resnet_flexible_no_transfer_{suffix}': lambda: resnet_flexible_no_transfer_qus(input_types),
    }


In [None]:
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight
from sklearn.metrics import roc_auc_score, classification_report
import numpy as np
import tensorflow as tf

# Repeated hold-out evaluation: each run re-splits the patients with a
# different seed, trains a fresh model, and records validation / test AUC.
N_RUNS = 5
val_aucs, test_aucs = [], []

for run_idx in range(N_RUNS):
    print(f"\n{'='*40}")
    print(f"▶️ HOLD-OUT RUN {run_idx+1}/{N_RUNS}")
    print(f"{'='*40}\n")

    # Split into train/val/test with a new random state each run:
    # 10% of patients are held out for test, then 20% of the remainder for
    # validation.
    SEED_RUN = SEED + run_idx  # different seed each run
    train_and_val_patients, test_patients = train_test_split(
        all_patients, test_size=0.1, random_state=SEED_RUN
    )
    train_patients, val_patients = train_test_split(
        train_and_val_patients, test_size=0.2, random_state=SEED_RUN
    )

    print(f"Training patients: {len(train_patients)}")
    print(f"Validation patients: {len(val_patients)}")
    print(f"Test patients: {len(test_patients)}")

    # Create datasets with min-max scaling.
    # Scalers are fitted on the training data only (scalers=None) and then
    # reused for validation and test so no statistics leak from held-out data.
    print("\nCreating training dataset and fitting scalers...")
    train_dataset, train_scalers = create_dataset_from_patients_qus(
        train_patients, INPUT_TYPES, QUS_TYPES, qus_matrices, scalers=None, augment=True, batch_size=BATCH_SIZE
    )

    print("\nCreating validation dataset (using training scalers)...")
    val_dataset, _ = create_dataset_from_patients_qus(
        val_patients, INPUT_TYPES, QUS_TYPES, qus_matrices, scalers=train_scalers, augment=False, batch_size=BATCH_SIZE
    )

    print("\nCreating test dataset (using training scalers)...")
    test_dataset, _ = create_dataset_from_patients_qus(
        test_patients, INPUT_TYPES, QUS_TYPES, qus_matrices, scalers=train_scalers, augment=False, batch_size=BATCH_SIZE
    )

    # Compute class weights from the training labels only.
    # NOTE(review): assumes p.egfr already holds the integer class label
    # (the models are compiled with binary cross-entropy) — confirm.
    train_labels = np.array([p.egfr for p in train_patients]).astype(int)
    present_classes = np.unique(train_labels)
    weights = class_weight.compute_class_weight(
        class_weight='balanced',
        classes=present_classes,
        y=train_labels
    )
    # Key each weight by its actual class label rather than by its position:
    # enumerate() would mislabel the weights whenever the labels present in
    # this split are not exactly 0..k-1 (e.g. a split containing one class).
    class_weights_dict = {
        int(label): float(weight)
        for label, weight in zip(present_classes, weights)
    }
    print(f"Class weights: {class_weights_dict}")

    # Get model configuration
    #hyperparameter_configs = create_hyperparameter_configs_qus(INPUT_TYPES, QUS_TYPES)
    #MODEL_CONFIG_NAME = f"resnet_flexible_branched_{'_'.join(INPUT_TYPES)}"
    #print(f"Using model configuration: {MODEL_CONFIG_NAME}")

    # Build and compile a fresh model for this run
    model = resnet_flexible_concatenated_qus(INPUT_TYPES, QUS_TYPES)

    # Early stopping on validation AUC; the best weights are restored when
    # training stops.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor='val_auc', mode='max', patience=20, restore_best_weights=True
    )
    # def step_decay(epoch, lr):
    #     drop_rate = 0.5
    #     drop_every = 15
    #     if epoch > 0 and epoch % drop_every == 0:
    #         return lr * drop_rate
    #     return lr
    # lr_scheduler = tf.keras.callbacks.LearningRateScheduler(step_decay, verbose=0)

    print("Starting training...")
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        callbacks=[early_stopping],
        class_weight=class_weights_dict,
        verbose=1
    )

    # === Evaluate Validation ===
    # AUC is computed over individual predictions: every predicted
    # probability is paired with the true label of the patient it belongs to.
    print("\nEvaluating on validation set...")
    val_true, val_pred = [], []
    patient_to_true_labels, patient_to_predicted_probs = plotAndReturnValidationTesting_flexible_qus(val_patients, model, INPUT_TYPES, QUS_TYPES, qus_matrices, train_scalers)
    for patient_id in patient_to_predicted_probs:
        for prob in patient_to_predicted_probs[patient_id]:
            val_true.append(patient_to_true_labels[patient_id])
            val_pred.append(prob)
    val_auc = roc_auc_score(val_true, val_pred)
    val_aucs.append(val_auc)
    print(f"Validation AUC (Run {run_idx+1}): {val_auc:.4f}")

    # === Evaluate Test ===
    print("\nEvaluating on test set...")
    test_true, test_pred = [], []
    patient_to_true_labels, patient_to_predicted_probs = plotAndReturnValidationTesting_flexible_qus(test_patients, model, INPUT_TYPES, QUS_TYPES, qus_matrices, train_scalers)
    for patient_id in patient_to_predicted_probs:
        for prob in patient_to_predicted_probs[patient_id]:
            test_true.append(patient_to_true_labels[patient_id])
            test_pred.append(prob)
    test_auc = roc_auc_score(test_true, test_pred)
    test_aucs.append(test_auc)
    print(f"Test AUC (Run {run_idx+1}): {test_auc:.4f}")

    # Threshold probabilities at 0.5 for the per-class report.
    print(classification_report(test_true, [1 if p >= 0.5 else 0 for p in test_pred], digits=4))


# ===========================
# ✅ Final Summary of 5 Runs
# ===========================
# Banner, then the per-run AUCs followed by their mean ± std across runs.
summary_rule = "=" * 60
print("\n" + summary_rule)
print("FINAL SUMMARY OVER MULTIPLE HOLD-OUT RUNS")
print(summary_rule)

val_auc_strings = [format(score, '.4f') for score in val_aucs]
print(f"Validation AUCs: {val_auc_strings}")
test_auc_strings = [format(score, '.4f') for score in test_aucs]
print(f"Test AUCs:       {test_auc_strings}")

val_mean, val_std = np.mean(val_aucs), np.std(val_aucs)
print(f"\nAverage Validation AUC: {val_mean:.4f} ± {val_std:.4f}")
test_mean, test_std = np.mean(test_aucs), np.std(test_aucs)
print(f"Average Test AUC:       {test_mean:.4f} ± {test_std:.4f}")
