# TransUNet K-Fold Cross-Validation Training

This notebook implements k-fold cross-validation training and metrics measurement for the **TransUNet model** with **MIFOCAT loss**.

## Overview

- **Model**: TransUNet (Transformer-based U-Net)
- **Loss**: MIFOCAT (MSE + Focal + Categorical Cross-Entropy)
- **Dataset**: ACDC 2017 Cardiac MRI
- **Validation**: 5-fold cross-validation
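
The overview describes MIFOCAT as a combination of MSE, focal, and categorical cross-entropy terms. The sketch below shows one way such a combined loss could be computed with NumPy on one-hot targets and softmax probabilities. The function name `mifocat_loss_sketch`, the equal default weights, and `gamma=2.0` are assumptions made for illustration; the actual loss in `transunet_model` is not shown in this notebook and may differ.

```python
import numpy as np

def mifocat_loss_sketch(y_true, y_pred, weights=(1.0, 1.0, 1.0), gamma=2.0, eps=1e-7):
    """Illustrative weighted sum of MSE, focal, and categorical cross-entropy.

    y_true: one-hot labels, shape (..., n_classes)
    y_pred: softmax probabilities, same shape
    weights, gamma: hypothetical values, not taken from the project code
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    # Clip probabilities so that log() never sees 0
    y_pred = np.clip(np.asarray(y_pred, dtype=np.float64), eps, 1.0 - eps)

    # Mean squared error over every element
    mse = np.mean((y_true - y_pred) ** 2)

    # Categorical cross-entropy, averaged over pixels
    cce = np.mean(-np.sum(y_true * np.log(y_pred), axis=-1))

    # Focal term: cross-entropy down-weighted for well-classified pixels
    focal = np.mean(
        -np.sum(y_true * (1.0 - y_pred) ** gamma * np.log(y_pred), axis=-1)
    )

    w_mse, w_focal, w_cce = weights
    return w_mse * mse + w_focal * focal + w_cce * cce


# Two pixels, four classes (background plus three cardiac structures)
y_true = np.array([[1, 0, 0, 0], [0, 0, 1, 0]])
good = np.array([[0.97, 0.01, 0.01, 0.01], [0.01, 0.01, 0.97, 0.01]])
poor = np.array([[0.25, 0.25, 0.25, 0.25], [0.40, 0.30, 0.20, 0.10]])
print(mifocat_loss_sketch(y_true, good) < mifocat_loss_sketch(y_true, poor))
```

A confident, correct prediction should score lower than an uncertain or wrong one, which is what the final comparison checks.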

## Workflow

1. Setup and imports
2. Configuration
3. Data splitting (optional)
4. Initialize k-fold trainer
5. Run k-fold training
6. Results analysis
7. Visualization
8. Load and test a specific fold model

## 1. Setup and Imports

In [1]:
import sys
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Add project to path
project_root = Path.cwd()
sys.path.insert(0, str(project_root))

print(f"Project root: {project_root}")
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")

Project root: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation
Python version: 3.11.14 (main, Oct 21 2025, 18:27:30) [Clang 20.1.8 ]
NumPy version: 1.26.4


In [2]:
try:
    import tensorflow as tf
    from tensorflow import keras
    print(f"TensorFlow version: {tf.__version__}")
    
    # Check for GPU availability
    if tf.config.list_physical_devices('GPU'):
        print(f"✓ GPU detected: {tf.config.list_physical_devices('GPU')}")
    else:
        print("⚠ No GPU detected. Using CPU.")
        print(f"  Available devices: {tf.config.list_physical_devices()}")
    
except ImportError as e:
    print(f"TensorFlow import failed: {e}")
    print("Install with: pip install tensorflow keras")

try:
    from split_data import CardiacDataSplitter
    from custom_datagen import FoldAwareDataLoader
    from train_kfold_wrapper import KFoldTrainer
    from transunet_model import build_transunet_mifocat, get_custom_objects
    print("✓ Custom modules imported successfully")
except ImportError as e:
    print(f"Custom module import failed: {e}")
    print("Ensure all required modules are in the project root.")

TensorFlow version: 2.15.0
✓ GPU detected: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
✓ Custom modules imported successfully


## 2. Configuration

In [3]:
# ============ CONFIGURATION ============

# Data paths - ACDC2017 Dataset Structure
DATA_ROOT = Path.cwd() / "acdc2017" / "Data 2D" / "ED" / "Data Per Pasien Training 2D"
OUTPUT_DIR = Path.cwd() / "transunet_kfold_results"
OUTPUT_DIR.mkdir(exist_ok=True)

# K-Fold parameters
N_SPLITS = 5                    # Number of folds
VAL_RATIO = 0.1                 # Validation ratio per fold
RANDOM_SEED = 42

# Training parameters
MODEL_TYPE = "transunet"        # TransUNet model
EPOCHS_PER_FOLD = 50
BATCH_SIZE = 32
EARLY_STOP_PATIENCE = 10
IMAGE_SUBDIR = "images"         # Subdirectory for images within patient folders
MASK_SUBDIR = "groundtruth"     # Subdirectory for masks within patient folders (ACDC2017 uses 'groundtruth')

# Model checkpoint output
CHECKPOINT_DIR = OUTPUT_DIR / "checkpoints"
CHECKPOINT_DIR.mkdir(exist_ok=True)

# Print configuration
print("=" * 60)
print("TRANSUNET K-FOLD TRAINING CONFIGURATION")
print("=" * 60)
print(f"Data root: {DATA_ROOT}")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Number of folds: {N_SPLITS}")
print(f"Validation ratio: {VAL_RATIO}")
print(f"Model type: {MODEL_TYPE}")
print(f"Epochs per fold: {EPOCHS_PER_FOLD}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Early stopping patience: {EARLY_STOP_PATIENCE}")
print(f"Mask subdirectory: {MASK_SUBDIR}")
print("=" * 60)

TRANSUNET K-FOLD TRAINING CONFIGURATION
Data root: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
Output directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results
Number of folds: 5
Validation ratio: 0.1
Model type: transunet
Epochs per fold: 50
Batch size: 32
Early stopping patience: 10
Mask subdirectory: groundtruth
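
`EARLY_STOP_PATIENCE = 10` implies that a fold stops training once the validation loss has failed to improve for 10 consecutive epochs. How `KFoldTrainer` consumes this setting is not shown here, so the following is only a sketch of the usual patience rule (the same idea as the Keras `EarlyStopping` callback with `min_delta=0`). The helper `stopping_epoch` and the loss values are made up for illustration.

```python
def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which patience-based early stopping would halt,
    or None if training would run through every epoch."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            # New best validation loss: reset the counter
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


# Made-up validation losses, one per epoch
history = [5.0, 3.0, 2.0, 2.2, 2.1, 2.3, 1.9, 2.0]
print(stopping_epoch(history, patience=3))   # stops before the later improvement
print(stopping_epoch(history, patience=10))  # never triggers on this short history
```

With a small patience the run halts before the late improvement at epoch 7, which is the trade-off a larger patience value is meant to avoid.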


## 3. Data Splitting (Optional)

If the k-fold metadata (`kfold_metadata.json`) has not been generated yet, uncomment and run this cell once. Otherwise, skip to the next section.

In [4]:
# ## Uncomment to generate k-fold split metadata
# ## This only needs to be run once

# splitter = CardiacDataSplitter(
#     input_folder=str(DATA_ROOT),
#     output_folder=str(OUTPUT_DIR)
# )

# print(f"\nExecuting k-fold split with n_splits={N_SPLITS}, val_ratio={VAL_RATIO}...")
# print(f"Data directory: {DATA_ROOT}")
# print(f"Data directory exists: {DATA_ROOT.exists()}")

# splitter.kfold_split(
#     n_splits=N_SPLITS,
#     val_ratio=VAL_RATIO
# )
# print("✓ K-fold split complete")
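
The internals of `CardiacDataSplitter.kfold_split` are not shown in this notebook. As a rough sketch of what a patient-level k-fold split with a validation ratio could look like, the code below shuffles patient IDs, assigns each patient to exactly one test fold, and carves a validation subset out of the remaining patients. The function `kfold_patient_split`, the dictionary keys, and the placeholder IDs are assumptions for illustration only.

```python
import random

def kfold_patient_split(patient_ids, n_splits=5, val_ratio=0.1, seed=42):
    """Split patient IDs into per-fold train/val/test lists without overlap."""
    ids = sorted(patient_ids)
    random.Random(seed).shuffle(ids)
    # Deal patients round-robin so test-fold sizes differ by at most one
    test_folds = [ids[i::n_splits] for i in range(n_splits)]

    folds = []
    for fold_id, test_ids in enumerate(test_folds):
        held_out = set(test_ids)
        remaining = [p for p in ids if p not in held_out]
        # Validation patients are taken from the non-test patients
        n_val = max(1, round(len(remaining) * val_ratio))
        folds.append({
            "fold_id": fold_id,
            "train": remaining[n_val:],
            "val": remaining[:n_val],
            "test": test_ids,
        })
    return folds


# 100 placeholder patient IDs ("1" .. "100")
folds = kfold_patient_split([str(i) for i in range(1, 101)])
for fold in folds:
    print(fold["fold_id"], len(fold["train"]), len(fold["val"]), len(fold["test"]))
```

Splitting by patient rather than by slice keeps all 2D slices of one patient in the same subset, which avoids leakage between training and evaluation data.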

## 4. Initialize K-Fold Trainer

In [5]:
# Path to k-fold metadata (generated by split_data.py or above cell)
FOLD_METADATA_PATH = OUTPUT_DIR / "kfold_metadata.json"

if not FOLD_METADATA_PATH.exists():
    print(f"✗ ERROR: Fold metadata not found: {FOLD_METADATA_PATH}")
    print("Run the data splitting step first or provide existing metadata.")
else:
    print(f"✓ Fold metadata found: {FOLD_METADATA_PATH}")
    
    # Initialize trainer
    trainer = KFoldTrainer(
        fold_metadata_path=str(FOLD_METADATA_PATH),
        base_data_dir=str(DATA_ROOT),
        output_dir=str(OUTPUT_DIR),
        seed=RANDOM_SEED
    )
    
    print("✓ KFoldTrainer initialized successfully")

✓ Fold metadata found: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results/kfold_metadata.json
[KFoldTrainer] Loaded metadata for 5-fold CV
[KFoldTrainer] Output directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results
[KFoldTrainer] Initializing data loader...
[KFoldTrainer] Base data directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[KFoldTrainer] Base data directory exists: True
✓ KFoldTrainer initialized successfully


## 5. Run K-Fold Training

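This cell trains the TransUNet model on every fold and saves per-fold checkpoints and metrics. In the run recorded below, training was interrupted (`KeyboardInterrupt`) at the start of epoch 2 of fold 0, so the later cells show no output.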

In [6]:
# Run k-fold cross-validation
print("\n🚀 Starting k-fold cross-validation training for TransUNet...\n")

results = trainer.run_all_folds(
    model_type=MODEL_TYPE,
    epochs=EPOCHS_PER_FOLD,
    batch_size=BATCH_SIZE,
    train_only=False,  # Set to True to skip test evaluation
    start_fold=0       # Set to resume from a specific fold
)

print("\n✓ K-fold cross-validation completed!")


🚀 Starting k-fold cross-validation training for TransUNet...


######################################################################
# K-FOLD CROSS-VALIDATION: 5 Folds
# Metadata: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results/kfold_metadata.json
# Data dir: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
# Output dir: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results
# Model type: transunet
# Starting from fold: 0
# Started: 2026-01-28 17:00:16
######################################################################

[KFoldTrainer] Loading TransUNet model with MIFOCAT loss...


2026-01-28 17:00:16.693497: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M3 Pro
2026-01-28 17:00:16.693526: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 18.00 GB
2026-01-28 17:00:16.693529: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 6.00 GB
2026-01-28 17:00:16.693826: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:306] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2026-01-28 17:00:16.694115: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:272] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


[KFoldTrainer] TransUNet model compiled with MIFOCAT loss

[FOLD 0] Starting training
[FOLD 0] Loading data generators...

[FoldAwareDataLoader.get_generators] ===== STARTING FOLD 0 =====
[FoldAwareDataLoader.get_generators] Image subdir: images, Mask subdir: groundtruth
[FoldAwareDataLoader.get_generators] Target size: (256, 256)
[FoldAwareDataLoader.get_generators] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader] Fold 0: 72 train patients, 20 val patients
[FoldAwareDataLoader] Train patients: ['1', '10', '11', '12', '13']...
[FoldAwareDataLoader] Val patients: ['16', '17', '2', '27', '29']...
[FoldAwareDataLoader.get_file_list] Starting file list collection
[FoldAwareDataLoader.get_file_list] Base directory: /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/acdc2017/Data 2D/ED/Data Per Pasien Training 2D
[FoldAwareDataLoader.get_file_list] Base directory exists: True
[FoldAwareDataLoader.

2026-01-28 17:00:22.222627: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.
2026-01-28 17:00:23.066538: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] model_pruner failed: INVALID_ARGUMENT: Graph does not contain terminal node Adam/AssignAddVariableOp.


Epoch 1: val_loss improved from inf to 15.42576, saving model to /Users/iganarendra/Downloads/Code-Cardiac-Segmentation/transunet_kfold_results/fold_0/fold_0_best_model.h5


  saving_api.save_model(


Epoch 2/50

KeyboardInterrupt: 
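
Because `run_all_folds` accepts a `start_fold` argument, an interrupted run like the one above can be resumed without repeating finished folds. The sketch below picks the first fold with no recorded result. It assumes that `fold_results.json` is a list of records carrying a `fold_id` field (as the analysis cells read it) and that the file is updated after each completed fold, which this notebook does not confirm. The helper `first_unfinished_fold` is hypothetical.

```python
import json
import tempfile
from pathlib import Path

def first_unfinished_fold(results_path, n_splits):
    """Return the lowest fold index with no recorded result (n_splits if all are done)."""
    path = Path(results_path)
    if not path.exists():
        # Nothing recorded yet: start from the first fold
        return 0
    with open(path, "r") as f:
        records = json.load(f)
    done = {r.get("fold_id") for r in records}
    for fold_id in range(n_splits):
        if fold_id not in done:
            return fold_id
    return n_splits


# Demonstration on a temporary file that records folds 0 and 1 as finished
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "fold_results.json"
    demo_path.write_text(json.dumps([{"fold_id": 0}, {"fold_id": 1}]))
    resume_from = first_unfinished_fold(demo_path, n_splits=5)
print(resume_from)
```

Under those assumptions, the returned index could be passed as `start_fold` in the training cell. Note that the presence of a `fold_N_best_model.h5` checkpoint alone does not indicate a finished fold, since the checkpoint is written as soon as the first epoch improves on the initial loss.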

## 6. Results Analysis

In [None]:
# Load and display aggregated results
aggregated_results_path = OUTPUT_DIR / "aggregated_results.json"
fold_results_path = OUTPUT_DIR / "fold_results.json"

if aggregated_results_path.exists():
    with open(aggregated_results_path, 'r') as f:
        aggregated = json.load(f)
    
    print("\n" + "="*60)
    print("AGGREGATED RESULTS (ACROSS ALL FOLDS)")
    print("="*60)
    for key, value in aggregated.items():
        if isinstance(value, float):
            print(f"{key:.<40} {value:.6f}")
        else:
            print(f"{key:.<40} {value}")
    print("="*60)

if fold_results_path.exists():
    with open(fold_results_path, 'r') as f:
        fold_results = json.load(f)
    
    # Create DataFrame for easier visualization
    df = pd.DataFrame(fold_results)
    print("\n" + "="*60)
    print("PER-FOLD RESULTS")
    print("="*60)
    print(df.to_string(index=False))
    print("="*60)
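
If `aggregated_results.json` is missing (for example after a partial run), summary statistics can be recomputed from the per-fold records. The sketch below assumes each record is a dictionary with a `final_val_loss` entry, as the visualization cell expects. The helper `summarize_metric` is hypothetical and the demo values are placeholders, not results from this notebook.

```python
import statistics

def summarize_metric(fold_results, key):
    """Mean, sample standard deviation, and range of one metric across folds."""
    # Skip folds where the metric was not recorded
    values = [r[key] for r in fold_results if r.get(key) is not None]
    if not values:
        return None
    return {
        "n_folds": len(values),
        "mean": statistics.mean(values),
        "std": statistics.stdev(values) if len(values) > 1 else 0.0,
        "min": min(values),
        "max": max(values),
    }


# Placeholder values for illustration only
demo_results = [
    {"fold_id": 0, "final_val_loss": 0.50},
    {"fold_id": 1, "final_val_loss": 0.40},
    {"fold_id": 2, "final_val_loss": None},
    {"fold_id": 3, "final_val_loss": 0.60},
]
summary = summarize_metric(demo_results, "final_val_loss")
print(summary)
```

Reporting the standard deviation alongside the mean makes it easier to judge whether differences between models exceed the fold-to-fold variation.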

## 7. Visualization

In [None]:
# Visualize validation loss across folds
if fold_results_path.exists():
    with open(fold_results_path, 'r') as f:
        fold_results = json.load(f)
    
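    # Keep folds with a recorded loss; `is not None` avoids dropping a legitimate 0.0
    val_losses = [r.get('final_val_loss') for r in fold_results if r.get('final_val_loss') is not None]
    fold_ids = [r.get('fold_id') for r in fold_results if r.get('final_val_loss') is not None]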
    
    if val_losses:
        plt.figure(figsize=(10, 6))
        plt.bar(fold_ids, val_losses, color='steelblue', alpha=0.7)
        plt.axhline(y=np.mean(val_losses), color='red', linestyle='--', label=f'Mean: {np.mean(val_losses):.4f}')
        plt.xlabel('Fold ID', fontsize=12)
        plt.ylabel('Validation Loss', fontsize=12)
        plt.title('TransUNet Validation Loss Across Folds', fontsize=14, fontweight='bold')
        plt.legend()
        plt.grid(axis='y', alpha=0.3)
        plt.tight_layout()
        plt.savefig(OUTPUT_DIR / 'transunet_validation_loss.png', dpi=300)
        plt.show()
        
        print(f"\n✓ Visualization saved to: {OUTPUT_DIR / 'transunet_validation_loss.png'}")

## 8. Load and Test a Specific Fold Model

In [None]:
# Example: Load best model from fold 0
fold_id = 0
model_path = OUTPUT_DIR / f"fold_{fold_id}" / f"fold_{fold_id}_best_model.h5"

if model_path.exists():
    print(f"Loading model from: {model_path}")
    
    # Load with custom objects
    custom_objs = get_custom_objects()
    model = keras.models.load_model(str(model_path), custom_objects=custom_objs, compile=False)
    
    print("✓ Model loaded successfully")
    print("\nModel summary:")
    model.summary()
else:
    print(f"✗ Model not found: {model_path}")

## Summary

This notebook implements k-fold cross-validation training for TransUNet with MIFOCAT loss, mirroring the workflow from `train_kfold_notebook.ipynb` but specifically for the TransUNet architecture.

**Key features:**
- TransUNet model with transformer layers
- MIFOCAT unified loss function
- 5-fold cross-validation
- Per-fold checkpointing
- Aggregated metrics across folds
- Results visualization

**Next steps:**
1. Run detailed evaluation metrics (Dice, IoU, Hausdorff, MCC) on test sets
2. Compare TransUNet results with U-Net baseline
3. Generate prediction visualizations
4. Statistical significance testing
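
As a starting point for the first next step, per-class Dice and IoU can be computed directly from integer label maps. This is a minimal NumPy sketch, not the project's evaluation code. The function `dice_and_iou` is hypothetical, the tiny arrays are toy data, and scoring a class that is absent from both maps as NaN is one possible convention.

```python
import numpy as np

def dice_and_iou(y_true, y_pred, n_classes):
    """Per-class Dice and IoU for integer label maps of identical shape."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    scores = {}
    for c in range(n_classes):
        t = y_true == c
        p = y_pred == c
        intersection = np.logical_and(t, p).sum()
        total = t.sum() + p.sum()
        union = total - intersection
        # A class absent from both maps is scored as NaN
        dice = 2.0 * intersection / total if total > 0 else float("nan")
        iou = intersection / union if union > 0 else float("nan")
        scores[c] = {"dice": float(dice), "iou": float(iou)}
    return scores


# Toy 2x4 label maps with classes 0..3 (class 3 does not occur)
truth = np.array([[0, 0, 1, 1],
                  [0, 2, 2, 1]])
pred = np.array([[0, 0, 1, 1],
                 [0, 2, 1, 1]])
scores = dice_and_iou(truth, pred, n_classes=4)
for c, s in scores.items():
    print(c, s)
```

For a trained model, `pred` would typically come from taking the arg-max over the class axis of the network output before calling the function.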