# MIFOCAT K-Fold Cross-Validation Training Notebook
## Cardiac Myocardium Segmentation with TensorFlow/Keras

This notebook orchestrates end-to-end k-fold cross-validation training for cardiac segmentation using the MIFOCAT loss function.

**Key Steps:**
1. Setup and imports
2. Load and split dataset with patient-level stratification
3. Initialize data loaders and trainer
4. Train all k folds with early stopping and checkpointing
5. Aggregate and visualize fold results

In [1]:
import sys
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add project to path
project_root = Path.cwd()
sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")

Project root: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation
Python version: 3.11.14 (main, Oct 21 2025, 18:27:30) [Clang 20.1.8 ]
NumPy version: 1.26.4


In [2]:
try:
    import tensorflow as tf
    from tensorflow import keras
    print(f"TensorFlow version: {tf.__version__}")
    
    # Check for Apple Silicon GPU availability
    if tf.config.list_physical_devices('GPU'):
        print(f"✓ GPU detected: {tf.config.list_physical_devices('GPU')}")
    else:
        print(" No GPU detected. Using CPU.")
        print(f"  Available devices: {tf.config.list_physical_devices()}")
    
except ImportError as e:
    print(f"TensorFlow import failed: {e}")
    print("Install with: pip install tensorflow keras")

try:
    from split_data import CardiacDataSplitter
    from custom_datagen import FoldAwareDataLoader
    from train_kfold_wrapper import KFoldTrainer
    print("✓ Custom modules imported successfully")
except ImportError as e:
    print(f"Custom module import failed: {e}")
    print("Ensure split_data.py, custom_datagen.py, and train_kfold_wrapper.py are in the project root.")

TensorFlow version: 2.15.0
✓ GPU detected: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
✓ Custom modules imported successfully


## Configuration
Set up paths, hyperparameters, and k-fold settings.

In [3]:
# ============ CONFIGURATION ============

# Data paths - ACDC2017 Dataset Structure
DATA_ROOT = Path.cwd() / "acdc2017" / "Data 2D" / "ED" / "Data Per Pasien Training 2D"
OUTPUT_DIR = Path.cwd() / "kfold_results"
OUTPUT_DIR.mkdir(exist_ok=True)

# K-Fold parameters
N_SPLITS = 5                    # Number of folds
VAL_RATIO = 0.1               # Validation ratio per fold
RANDOM_SEED = 42

# Training parameters
MODEL_TYPE = "unet"            # Options: "unet", "transunet", "resnet", etc.
EPOCHS_PER_FOLD = 50
BATCH_SIZE = 32
EARLY_STOP_PATIENCE = 10
IMAGE_SUBDIR = "images"        # Subdirectory for images within patient folders
MASK_SUBDIR = "groundtruth"    # Subdirectory for masks within patient folders (ACDC2017 uses 'groundtruth')

# Model checkpoint output
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Print configuration
print("=" * 50)
print("K-FOLD TRAINING CONFIGURATION")
print("=" * 50)
print(f"Data root: {DATA_ROOT}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Number of folds: {N_SPLITS}")
print(f"Validation ratio: {VAL_RATIO}")
print(f"Model type: {MODEL_TYPE}")
print(f"Epochs per fold: {EPOCHS_PER_FOLD}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Early stopping patience: {EARLY_STOP_PATIENCE}")
print(f"Mask subdirectory: {MASK_SUBDIR}")
print("=" * 50)

K-FOLD TRAINING CONFIGURATION
Data root: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
Output directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results
Number of folds: 5
Validation ratio: 0.1
Model type: unet
Epochs per fold: 50
Batch size: 32
Early stopping patience: 10
Mask subdirectory: groundtruth


## Step 1: Load and Split Dataset
Use `CardiacDataSplitter` to generate patient-level stratified k-fold splits and create metadata.

In [4]:
# Initialize the data splitter
splitter = CardiacDataSplitter(
    input_folder=str(DATA_ROOT),
    output_folder=str(OUTPUT_DIR)
)

# Run k-fold split
# This generates kfold_metadata.json with patient-to-fold assignments
print(f"\nExecuting k-fold split with n_splits={N_SPLITS}, val_ratio={VAL_RATIO}...")
print(f"Data directory: {DATA_ROOT}")
print(f"Data directory exists: {DATA_ROOT.exists()}")

splitter.kfold_split(
    n_splits=N_SPLITS,
    val_ratio=VAL_RATIO
)
print("✓ K-fold split complete")

# Load the metadata to inspect fold structure
metadata_path = OUTPUT_DIR / "kfold_metadata.json"
with open(metadata_path, 'r') as f:
    metadata = json.load(f)

print(f"\nK-Fold Metadata Summary:")
print(f"  Total patients: {metadata['total_patients']}")
print(f"  Total folds: {metadata['n_splits']}")
for fold_data in metadata['folds']:
    fold_id = fold_data.get('fold_id', '?')
    train_patients = len(fold_data.get('train', []))
    val_patients = len(fold_data.get('val', []))
    print(f"  Fold {fold_id}: {train_patients} train, {val_patients} val patients")


Executing k-fold split with n_splits=5, val_ratio=0.1...
Data directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
Data directory exists: True
[K-FOLD] Found 99 patients: ['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99']
[K-FOLD] Creating 5-fold split with val_ratio=0.1
[K-FOLD] Fold 0: train=72, val=20, test=7 patients
[K-FOLD] Fold 1: train=72, val=20, test=7 patients
[K-FOLD] Fold 2: train=72, va

## Step 2: Initialize Data Loader and Trainer
Set up the fold-aware data loader and k-fold trainer.

In [5]:
# Initialize the k-fold trainer
# NOTE: KFoldTrainer.get_model() is a placeholder; you must implement it with your model architecture
trainer = KFoldTrainer(
    base_data_dir=str(DATA_ROOT),
    fold_metadata_path=str(metadata_path),
    output_dir=str(OUTPUT_DIR)
)

print("✓ KFoldTrainer initialized")
print(f"  Metadata: {metadata_path}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Checkpoint directory: {CHECKPOINT_DIR}")

[KFoldTrainer] Loaded metadata for 5-fold CV
[KFoldTrainer] Output directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results
[KFoldTrainer] Initializing data loader...
[KFoldTrainer] Base data directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[KFoldTrainer] Base data directory exists: True
✓ KFoldTrainer initialized
  Metadata: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/kfold_metadata.json
  Output directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results
  Checkpoint directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/checkpoints


## Step 3: Train All K-Folds
Run training for each fold with early stopping and checkpointing. This may take a long time.

**Note:** Before running training, ensure you have implemented the `get_model()` method in `train_kfold_wrapper.py` to instantiate and compile your model with the MIFOCAT loss function.

In [6]:
# DIAGNOSTIC: Test data loading before training
import sys
sys.stdout.flush()

print("=" * 70)
print("DIAGNOSTIC: Testing Data Loader")
print("=" * 70)

# Force import reload to get latest debugging code
import importlib
import custom_datagen
importlib.reload(custom_datagen)
from custom_datagen import FoldAwareDataLoader

# Test creating a loader with the same parameters
test_loader = FoldAwareDataLoader(
    str(DATA_ROOT),
    str(metadata_path)
)

print("\n✓ Test loader created successfully")
print(f"Base directory: {test_loader.base_dir}")
print(f"Base directory exists: {test_loader.base_dir.exists()}")

# Try to get generators for fold 0
print("\n" + "-" * 70)
print("Testing fold 0 generator creation...")
print("-" * 70)
sys.stdout.flush()

try:
    test_train_gen, test_val_gen, test_train_steps, test_val_steps = test_loader.get_generators(
        fold_id=0,
        batch_size=BATCH_SIZE
    )
    print(f"\n✓✓✓ SUCCESS! Data loading works!")
    print(f"    Training steps: {test_train_steps}")
    print(f"    Validation steps: {test_val_steps}")
    print(f"\nIf you see this, the data can be loaded. Proceed with training.")
except Exception as e:
    print(f"\n✗✗✗ FAILED! Error: {e}")
    import traceback
    traceback.print_exc()
    print("\nPlease fix the error above before running training.")

sys.stdout.flush()
print("=" * 70)

DIAGNOSTIC: Testing Data Loader

✓ Test loader created successfully
Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
Base directory exists: True

----------------------------------------------------------------------
Testing fold 0 generator creation...
----------------------------------------------------------------------

[FoldAwareDataLoader.get_generators] ===== STARTING FOLD 0 =====
[FoldAwareDataLoader.get_generators] Image subdir: images, Mask subdir: groundtruth
[FoldAwareDataLoader.get_generators] Target size: (256, 256)
[FoldAwareDataLoader.get_generators] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader] Fold 0: 72 train patients, 20 val patients
[FoldAwareDataLoader] Train patients: ['1', '10', '11', '12', '13']...
[FoldAwareDataLoader] Val patients: ['16', '17', '2', '27', '29']...
[FoldAwareDataLoader.get_file_lis

In [7]:
import time

# Run all folds
start_time = time.time()

print("\n" + "=" * 70)
print("STARTING K-FOLD CROSS-VALIDATION TRAINING")
print("=" * 70)

try:
    trainer.run_all_folds(
        model_type=MODEL_TYPE,
        epochs=EPOCHS_PER_FOLD,
        batch_size=BATCH_SIZE,
        train_only=False  # Set to False if you want evaluation per fold
    )
    print("\n✓ Training completed successfully!")
except Exception as e:
    print(f"\n✗ Training failed with error: {e}")
    import traceback
    traceback.print_exc()

elapsed_time = time.time() - start_time
hours = int(elapsed_time // 3600)
minutes = int((elapsed_time % 3600) // 60)
seconds = int(elapsed_time % 60)
print(f"\nTotal training time: {hours}h {minutes}m {seconds}s")


STARTING K-FOLD CROSS-VALIDATION TRAINING

######################################################################
# K-FOLD CROSS-VALIDATION: 5 Folds
# Metadata: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/kfold_metadata.json
# Data dir: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
# Output dir: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results
# Started: 2026-01-26 15:55:37
######################################################################

[KFoldTrainer] Loading U-Net model with MIFOCAT loss...


2026-01-26 15:55:37.910955: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2026-01-26 15:55:37.910982: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2026-01-26 15:55:37.910985: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2026-01-26 15:55:37.911012: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-01-26 15:55:37.911027: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


[KFoldTrainer] Model compiled with MIFOCAT loss

[FOLD 0] Starting training
[FOLD 0] Loading data generators...

[FoldAwareDataLoader.get_generators] ===== STARTING FOLD 0 =====
[FoldAwareDataLoader.get_generators] Image subdir: images, Mask subdir: groundtruth
[FoldAwareDataLoader.get_generators] Target size: (256, 256)
[FoldAwareDataLoader.get_generators] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader] Fold 0: 72 train patients, 20 val patients
[FoldAwareDataLoader] Train patients: ['1', '10', '11', '12', '13']...
[FoldAwareDataLoader] Val patients: ['16', '17', '2', '27', '29']...
[FoldAwareDataLoader.get_file_list] Starting file list collection
[FoldAwareDataLoader.get_file_list] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader.get_file_list] Base directory exists: True
[FoldAwareDataLoader.get_file_l

2026-01-26 15:55:40.431499: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2026-01-26 15:55:40.650037: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 1: val_loss improved from inf to 6.93290, saving model to /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/fold_0/fold_0_best_model.h5


  saving_api.save_model(


Epoch 2/50
Epoch 2: val_loss did not improve from 6.93290
Epoch 3/50
Epoch 3: val_loss did not improve from 6.93290
Epoch 4/50
Epoch 4: val_loss did not improve from 6.93290
Epoch 5/50
Epoch 5: val_loss did not improve from 6.93290
Epoch 6/50
Epoch 6: val_loss did not improve from 6.93290
Epoch 7/50
Epoch 7: val_loss did not improve from 6.93290
Epoch 8/50
Epoch 8: val_loss improved from 6.93290 to 0.31819, saving model to /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/fold_0/fold_0_best_model.h5
Epoch 9/50
Epoch 9: val_loss did not improve from 0.31819
Epoch 10/50
Epoch 10: val_loss improved from 0.31819 to 0.03404, saving model to /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/fold_0/fold_0_best_model.h5
Epoch 11/50
Epoch 11: val_loss did not improve from 0.03404
Epoch 12/50
Epoch 12: val_loss improved from 0.03404 to 0.02406, saving model to /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/kfold_results/fold_0/fold_0_best_model.h5
Ep



[KFoldTrainer] Model compiled with MIFOCAT loss

[FOLD 1] Starting training
[FOLD 1] Loading data generators...

[FoldAwareDataLoader.get_generators] ===== STARTING FOLD 1 =====
[FoldAwareDataLoader.get_generators] Image subdir: images, Mask subdir: groundtruth
[FoldAwareDataLoader.get_generators] Target size: (256, 256)
[FoldAwareDataLoader.get_generators] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader] Fold 1: 72 train patients, 20 val patients
[FoldAwareDataLoader] Train patients: ['1', '10', '11', '12', '14']...
[FoldAwareDataLoader] Val patients: ['13', '19', '24', '30', '32']...
[FoldAwareDataLoader.get_file_list] Starting file list collection
[FoldAwareDataLoader.get_file_list] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader.get_file_list] Base directory exists: True
[FoldAwareDataLoader.get_file_

2026-01-26 16:40:46.520995: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


 1/21 [>.............................] - ETA: 1:49 - loss: 1.8432 - accuracy: 0.2467 - mean_iou: 0.0165 - dice_score: 0.0324

KeyboardInterrupt: 

## Step 4: Load and Display Aggregated Results
After training completes, load the aggregated metrics across all folds.

In [None]:
# Load aggregated results
fold_results_path = OUTPUT_DIR / "fold_results.json"
aggregated_results_path = OUTPUT_DIR / "aggregated_results.json"

if fold_results_path.exists():
    with open(fold_results_path, 'r') as f:
        fold_results = json.load(f)
    print("Fold-specific results:")
    for fold_id, metrics in fold_results.items():
        print(f"\n  Fold {fold_id}:")
        for metric_name, value in metrics.items():
            if isinstance(value, float):
                print(f"    {metric_name}: {value:.4f}")
            else:
                print(f"    {metric_name}: {value}")
else:
    print(f"Fold results not found at {fold_results_path}")
    fold_results = {}

if aggregated_results_path.exists():
    with open(aggregated_results_path, 'r') as f:
        aggregated_results = json.load(f)
    print("\n" + "=" * 50)
    print("AGGREGATED RESULTS ACROSS ALL FOLDS")
    print("=" * 50)
    for metric_name, stats in aggregated_results.items():
        print(f"\n{metric_name}:")
        for stat_type, value in stats.items():
            if isinstance(value, float):
                print(f"  {stat_type}: {value:.4f}")
            else:
                print(f"  {stat_type}: {value}")
else:
    print(f"Aggregated results not found at {aggregated_results_path}")
    aggregated_results = {}

## Step 5: Visualize Training History
Plot training and validation metrics per fold to assess convergence and overfitting.

In [None]:
sns.set_style("whitegrid")
plt.rcParams['figure.figsize'] = (15, 10)

# Collect all training histories
all_histories = {}
history_dir = OUTPUT_DIR
for fold_id in range(N_SPLITS):
    history_file = history_dir / f"fold_{fold_id}_history.json"
    if history_file.exists():
        with open(history_file, 'r') as f:
            all_histories[fold_id] = json.load(f)

if all_histories:
    # Plot metrics per fold
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle('Training History Across All Folds', fontsize=16, fontweight='bold')
    
    # Determine available metrics (from first fold)
    first_fold_history = all_histories[0]
    metrics_to_plot = [k for k in first_fold_history.keys() if 'loss' in k or 'metric' in k][:4]
    
    for idx, metric in enumerate(metrics_to_plot):
        ax = axes.flatten()[idx]
        for fold_id, history in all_histories.items():
            if metric in history:
                ax.plot(history[metric], marker='o', label=f'Fold {fold_id}', alpha=0.7)
        ax.set_title(metric, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(metric)
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
    plt.show()
    print(f"✓ Training history plot saved to {OUTPUT_DIR / 'training_history.png'}")
else:
    print("No training history files found. Training may not have completed.")

## Step 6: Summary and Recommendations
Display final summary and provide next steps.

In [None]:
print("\n" + "=" * 70)
print("K-FOLD CROSS-VALIDATION TRAINING SUMMARY")
print("=" * 70)

summary = f"""
✓ Training Configuration:
  - Number of folds: {N_SPLITS}
  - Model type: {MODEL_TYPE}
  - Epochs per fold: {EPOCHS_PER_FOLD}
  - Batch size: {BATCH_SIZE}
  - Early stopping patience: {EARLY_STOP_PATIENCE}

✓ Output Artifacts:
  - Checkpoints: {CHECKPOINT_DIR}
  - Results: {OUTPUT_DIR}
  - Metadata: {OUTPUT_DIR / 'kfold_metadata.json'}
  - Fold results: {OUTPUT_DIR / 'fold_results.json'}
  - Aggregated results: {OUTPUT_DIR / 'aggregated_results.json'}

✓ Next Steps:
  1. Review aggregated_results.json for final performance metrics
  2. Inspect individual fold checkpoints for the best model per fold
  3. Load best checkpoint and run inference on test set (see hitung_evaluasi_metrik.py)
  4. For publication: Include aggregated metrics (mean ± std) in results tables
  5. Optionally: Run ablation studies on MIFOCAT loss components (L_MI, L_FO, L_CAT)

✓ Reproducibility:
  - Random seed: {RANDOM_SEED}
  - K-fold metadata: kfold_metadata.json (patient-to-fold assignments)
  - Training logs: Individual fold_*_history.json files
"""

print(summary)

# Display aggregated metrics summary table if available
if aggregated_results:
    print("\n" + "=" * 70)
    print("AGGREGATED METRICS SUMMARY")
    print("=" * 70)
    df_summary = pd.DataFrame(aggregated_results).T
    print(df_summary.to_string())

## Appendix: Optional - Load Best Model and Run Inference
Uncomment and modify this section to load a trained model and run inference on sample data.

In [None]:
# Example: Load best model from a fold and run inference
# Uncomment to use:

# from keras.models import load_model
# import cv2
# from pathlib import Path

# # Select the best fold (e.g., Fold 0)
# best_fold = 0
# best_model_path = CHECKPOINT_DIR / f"fold_{best_fold}_best_model.h5"

# if best_model_path.exists():
#     # Load model with custom objects if needed
#     custom_objects = {}  # Add custom layers/losses if using TransUNet or custom MIFOCAT loss
#     model = load_model(best_model_path, custom_objects=custom_objects)
#     print(f"✓ Loaded model from {best_model_path}")
    
#     # Example inference on a test image
#     test_image_path = DATA_ROOT / "Pasien_001" / "images" / "image_001.png"
#     if test_image_path.exists():
#         img = cv2.imread(str(test_image_path), cv2.IMREAD_GRAYSCALE)
#         img = cv2.resize(img, (224, 224))  # Adjust to your model input size
#         img_normalized = img.astype('float32') / 255.0
#         img_batch = np.expand_dims(np.expand_dims(img_normalized, axis=0), axis=-1)
        
#         # Run prediction
#         prediction = model.predict(img_batch, verbose=0)
#         print(f"Prediction shape: {prediction.shape}")
#         print(f"Prediction range: [{prediction.min():.4f}, {prediction.max():.4f}]")
        
#         # Visualize
#         fig, axes = plt.subplots(1, 3, figsize=(15, 5))
#         axes[0].imshow(img, cmap='gray')
#         axes[0].set_title('Input Image')
#         axes[0].axis('off')
        
#         axes[1].imshow(prediction[0, :, :, 0], cmap='hot')
#         axes[1].set_title('Model Output')
#         axes[1].axis('off')
        
#         axes[2].imshow(prediction[0, :, :, 0] > 0.5, cmap='gray')
#         axes[2].set_title('Binary Prediction (threshold=0.5)')
#         axes[2].axis('off')
        
#         plt.tight_layout()
#         plt.show()
# else:
#     print(f"Model checkpoint not found at {best_model_path}")