In [1]:
# Import libraries
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import tensorflow as tf
import os
import sys
import cv2
import re
import logging
from collections import defaultdict
from mpl_toolkits.axes_grid1 import make_axes_locatable
import scipy
import skimage
import tqdm
import os
import numpy as np
import matplotlib.pyplot as plt
import sys
import tensorflow as tf
import json
from datetime import datetime


# Set up logging
# Root logger at INFO with timestamp/level prefix; used by the logging.* calls in later cells.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure Neurite backend for VoxelMorph
# The two imports below intentionally live after this assignment instead of in the
# import block at the top of the cell.
os.environ['NEURITE_BACKEND'] = 'tensorflow'  # Must be set BEFORE importing neurite/voxelmorph
import neurite
from data.voxelmorph import voxelmorph as vxm



# Verify installed packages
# Printed so that a saved notebook records the environment the outputs were produced with.
print("\n--- Package Versions ---")
print(f"- Python: {sys.version.split()[0]}")
print(f"- TensorFlow: {tf.__version__}")
# The VoxelMorph copy is imported from the project's data/ folder, hence the 'custom' fallback.
print(f"- VoxelMorph: {vxm.__version__ if hasattr(vxm, '__version__') else 'custom'}")
# NOTE(review): unlike the VoxelMorph line, this access is not guarded with hasattr().
print(f"- Neurite: {neurite.__version__}")
print(f"- OpenCV: {cv2.__version__}")
print(f"- scikit-image: {skimage.__version__}")
print(f"- Matplotlib: {matplotlib.__version__}")
# NOTE(review): `tqdm` is the module here; a later cell runs `from tqdm import tqdm`,
# which rebinds the name, so re-running this cell afterwards may fail - confirm.
print(f"- tqdm: {tqdm.__version__}")
print(f"- SciPy: {scipy.__version__}")


# Set up environment paths
# Diagnostic output only: nothing below modifies the environment.
print("\n--- Environment Configuration ---")
print(f"Neurite backend: {neurite.backend}")
print(f"VoxelMorph path: {os.path.dirname(vxm.__file__)}")
print(f"Python path: {sys.path[:3]}...")



# Define data directories based on environment
# Relative to the notebook's working directory; consumed by create_path_structure() in the next cell.
LOCAL_DATA_DIR = './data'  # Default data directory for local execution
print(f"Using data directory: {LOCAL_DATA_DIR}")



--- Package Versions ---
- Python: 3.10.9
- TensorFlow: 2.19.0
- VoxelMorph: 0.2
- Neurite: 0.2
- OpenCV: 4.11.0
- scikit-image: 0.25.2
- Matplotlib: 3.10.1
- tqdm: 4.67.1
- SciPy: 1.15.2

--- Environment Configuration ---
Neurite backend: tensorflow
VoxelMorph path: d:\study\graduation_project\VXM\data\voxelmorph\voxelmorph
Python path: ['c:\\Users\\AliBadran\\AppData\\Local\\Programs\\Python\\Python310\\python310.zip', 'c:\\Users\\AliBadran\\AppData\\Local\\Programs\\Python\\Python310\\DLLs', 'c:\\Users\\AliBadran\\AppData\\Local\\Programs\\Python\\Python310\\lib']...
Using data directory: ./data


In [2]:
def create_path_structure(base_path):
    """Build the dictionary of data/model locations rooted at ``base_path``.

    Args:
        base_path (str): Root data directory.

    Returns:
        dict: Maps path names to strings. Every folder entry is
        ``os.path.join(base_path, <folder>)``; ``ACDC_BASE`` and
        ``SUNNYBROOK_BASE`` are empty strings; the three ``SIMULATED_*``
        entries alias the simulated test/mask/displacement folders.
    """
    def under_base(folder_name):
        return os.path.join(base_path, folder_name)

    # Base locations first, then the (currently unset) dataset roots.
    paths = {
        'BASE_DATA_PATH': base_path,
        'MODELS_BASE_PATH': under_base('Models'),
        'ACDC_BASE': '',
        'SUNNYBROOK_BASE': '',
    }

    # (dictionary key, folder name below base_path) - insertion order is kept.
    folder_table = (
        # Regular data
        ('train_data', 'train'),
        ('val_data', 'val'),
        ('test_data', 'test'),
        ('mask_data', 'ACDC-Masks-1'),
        ('MODEL_TESTING_PATH', 'model_testing'),
        # Simulated data
        ('train_simulated_data', 'Simulated_train'),
        ('val_simulated_data', 'Simulated_val'),
        ('test_simulated_data', 'Simulated_test'),
        ('mask_simulated_data', 'Simulated_masks'),
        ('displacement_simulated_data', 'Simulated_displacements'),
    )
    for path_key, folder_name in folder_table:
        paths[path_key] = under_base(folder_name)

    # Convenience aliases pointing at already-registered simulated folders.
    alias_table = (
        ('SIMULATED_DATA_PATH', 'test_simulated_data'),
        ('SIMULATED_MASK_PATH', 'mask_simulated_data'),
        ('SIMULATED_DISP_PATH', 'displacement_simulated_data'),
    )
    for alias, source_key in alias_table:
        paths[alias] = paths[source_key]

    return paths

def check_paths(paths):
    """Report which of the given paths exist on disk.

    Entries whose value is not a non-empty string are skipped silently.
    One status line per checked entry is printed.

    Args:
        paths (dict): Maps a display name to a filesystem path.

    Returns:
        tuple[list, list]: ``(existing_paths, missing_paths)`` - the path
        values (not the names), in the order they were checked.
    """
    found, absent = [], []

    print("\nChecking data paths:")
    for label, location in paths.items():
        # Only non-empty strings are real candidates.
        if not (isinstance(location, str) and location):
            continue

        present = os.path.exists(location)
        print(f"  {'✓' if present else '✗'} {label}: {location}")
        (found if present else absent).append(location)

    return found, absent

def validate_environment(paths):
    """Check that the data folders needed by the notebook are present.

    Delegates the existence test to ``check_paths`` and, when something is
    missing, prints a checklist of the folders expected below the base
    data directory. Missing folders are reported, not raised.

    Args:
        paths (dict): Path dictionary containing the data-folder keys and
            ``BASE_DATA_PATH``.

    Returns:
        bool: True when every checked path exists, False otherwise.
    """
    # (display label, key in `paths`) - order defines the printed order.
    labelled_keys = (
        ('Simulated Training', 'train_simulated_data'),
        ('Simulated Validation', 'val_simulated_data'),
        ('Simulated Testing', 'test_simulated_data'),
        ('Simulated Masks', 'mask_simulated_data'),
        ('Simulated Displacements', 'displacement_simulated_data'),
        ('Train Data', 'train_data'),
        ('Validation Data', 'val_data'),
        ('Test Data', 'test_data'),
        ('Mask Data', 'mask_data'),
    )
    paths_to_check = {label: paths[key] for label, key in labelled_keys}

    _, missing = check_paths(paths_to_check)
    if not missing:
        return True

    print("\n⚠️ Missing paths detected!")
    base_dir = paths['BASE_DATA_PATH']
    print(f"Please ensure your local data directory ({base_dir}) contains:")
    for checklist_line in (
        "- Simulated_train/Simulated_val/Simulated_test folders",
        "- Simulated_masks folder",
        "- Simulated_displacements folder",
        "- ACDC-Masks-1 folder",
        "- model_testing",
        "- train/val/test folders",
    ):
        print(checklist_line)
    # To enforce strict checking, raise here instead of returning:
    #     raise FileNotFoundError("Missing required data paths")
    return False

def create_model_config():
    """Return the model variants that are currently enabled.

    Returns:
        dict: Maps a variant key to its settings (``name`` plus the
        ``use_mask`` / ``use_mse_mask`` / ``use_smoothness_mask`` flags).
        Only ``both_masks`` is active.
    """
    # Variants that have been switched off (add them to the returned dict
    # to re-enable):
    #   'no_mask'         -> name 'voxelmorph_no_mask',
    #                        use_mask=False, use_mse_mask=False, use_smoothness_mask=False
    #   'mse_mask'        -> name 'voxelmorph_mse_mask',
    #                        use_mask=True,  use_mse_mask=True,  use_smoothness_mask=False
    #   'smoothness_mask' -> name 'voxelmorph_smoothness_mask',
    #                        use_mask=True,  use_mse_mask=False, use_smoothness_mask=True
    both_masks = dict(
        name='voxelmorph_both_masks',
        use_mask=True,
        use_mse_mask=True,
        use_smoothness_mask=True,
    )
    return {'both_masks': both_masks}

def create_kernel_configs():
    """Return the named kernel-size layouts.

    Every layout has four single-entry encoder levels, four single-entry
    decoder levels and a three-entry ``final`` list. The layouts differ
    only in the leading encoder levels; everything else is 3.

    Returns:
        dict: ``{'default', 'first5', 'first7_second5'}`` -> layout dict
        with keys ``encoder``, ``decoder`` and ``final``.
    """
    levels = 4

    def layout(*leading_encoder_kernels):
        # Leading encoder levels take the given sizes, the rest default to 3.
        encoder = [[size] for size in leading_encoder_kernels]
        encoder.extend([3] for _ in range(levels - len(encoder)))
        return {
            'encoder': encoder,
            'decoder': [[3] for _ in range(levels)],
            'final': [3, 3, 3],
        }

    return {
        'default': layout(),
        'first5': layout(5),            # first layer 5, rest 3
        'first7_second5': layout(7, 5),
    }

def setup_model_directories(model_config, kernel_configs, lambdas, models_base_path):
    """Register and create one output folder per (model, kernel, lambda) combination.

    For each combination this
      * records ``'<model>_kernel_<kernel>_lambda_<lambda>'`` ->
        ``'vm_model_<...>'`` (dots replaced by underscores) in the returned map,
      * stores ``{'folder': <path>}`` in ``model_config[<model>]`` under
        ``'kernel_<kernel>_lambda_<lambda>'`` (``model_config`` is mutated), and
      * creates ``weights``, ``results`` and ``logs`` sub-folders on disk.

    Args:
        model_config (dict): Variant key -> settings dict containing ``'name'``.
        kernel_configs (iterable): Kernel configuration keys (only iterated,
            so a dict or a list of keys both work).
        lambdas (iterable): Lambda values, formatted with three decimals.
        models_base_path (str): Directory under which the folders are created.

    Returns:
        dict: Configuration key -> variable-style name.
    """
    model_var_map = {}
    run_subfolders = ('weights', 'results', 'logs')

    for model_key in model_config:
        settings = model_config[model_key]
        for kernel_key in kernel_configs:
            for lambda_val in lambdas:
                combo = f"kernel_{kernel_key}_lambda_{lambda_val:.3f}"

                # Identifier-friendly name: every '.' becomes '_'.
                model_var_map[f"{model_key}_{combo}"] = (
                    f"vm_model_{model_key}_{combo}".replace('.', '_')
                )

                # Remember where this combination lives.
                folder_path = os.path.join(models_base_path, f"{settings['name']}_{combo}")
                settings[combo] = {'folder': folder_path}

                for subfolder in run_subfolders:
                    os.makedirs(os.path.join(folder_path, subfolder), exist_ok=True)

    return model_var_map

# Initialize everything
# Define lambda values
# Each value is used as the smoothness-loss weight and, formatted with three
# decimals, as part of the per-run folder name.
# LAMBDAS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0]
LAMBDAS = [0.016, 0.033, 0.066, 0.1]
# Alternative lambdas: [0.016, 0.033, 0.066, 0.1, 0.3, 0.5]

# Setup paths
DATA_PATHS = create_path_structure(LOCAL_DATA_DIR)
# Expose all paths as individual variables for backward compatibility
# (every key of DATA_PATHS becomes a module-level name; the loop variables
# `key` and `value` also remain defined afterwards).
for key, value in DATA_PATHS.items():
    globals()[key] = value

# Create and validate model configurations
MODEL_CONFIG = create_model_config()
KERNEL_CONFIGS = create_kernel_configs()
KERNEL_KEYS = list(KERNEL_CONFIGS.keys())

# Validate environment
# Only reports missing folders; the result is stored but nothing below branches on it.
is_valid = validate_environment(DATA_PATHS)

# Setup model directories and create variable mappings
# Side effects: creates folders under MODELS_BASE_PATH and adds one
# 'kernel_<k>_lambda_<l>' entry per combination to MODEL_CONFIG[<variant>].
# The list of kernel keys is passed because the function only iterates over it.
MODEL_VAR_MAP = setup_model_directories(MODEL_CONFIG, KERNEL_KEYS, LAMBDAS, DATA_PATHS['MODELS_BASE_PATH'])

# Read by create_voxelmorph_model(): when True the selected KERNEL_CONFIGS entry is
# passed to VxmDense as unet_kernel_sizes, otherwise None is passed.
USE_CUSTOM_VXM = True  # Set to True if using custom VoxelMorph implementation


Checking data paths:
  ✓ Simulated Training: ./data\Simulated_train
  ✓ Simulated Validation: ./data\Simulated_val
  ✓ Simulated Testing: ./data\Simulated_test
  ✓ Simulated Masks: ./data\Simulated_masks
  ✓ Simulated Displacements: ./data\Simulated_displacements
  ✓ Train Data: ./data\train
  ✓ Validation Data: ./data\val
  ✓ Test Data: ./data\test
  ✓ Mask Data: ./data\ACDC-Masks-1


In [3]:
#### Model Creation
##### MSE Loss
class MSE:
    """
    Sigma-weighted, mask-weighted mean squared error for image reconstruction.

    The mask is read from channel 1 of ``y_true`` and multiplied into the
    squared error before reduction.

    Implemented with ``tf.*`` ops only, so the class does not depend on a
    Keras-backend alias (``K``) being defined by some other cell.
    """

    def __init__(self, image_sigma=1.0):
        # The final loss is scaled by 1 / image_sigma**2.
        self.image_sigma = image_sigma

    def mse(self, y_true, y_pred):
        """Element-wise squared difference (no reduction)."""
        return tf.square(y_true - y_pred)

    def loss(self, y_true, y_pred, reduce='mean'):
        """
        Masked squared error scaled by ``1 / image_sigma**2``.

        Args:
            y_true: Target tensor; channel 1 (``y_true[..., 1]``) is used as the mask.
            y_pred: Predicted tensor, broadcast against ``y_true``.
            reduce: ``'mean'``, ``'max'`` or ``None`` (no reduction).

        Returns:
            Scalar tensor for ``'mean'``/``'max'``; the unreduced masked
            squared-error tensor for ``None``.

        Raises:
            ValueError: If ``reduce`` is any other value.
        """
        # NOTE(review): the difference is taken against *all* channels of
        # y_true (including the mask channels themselves), not only the image
        # channel - confirm this is intended before changing it, since it
        # affects the trained models.
        mse = self.mse(y_true, y_pred)

        mask = y_true[..., 1]  # Second channel for fixed mask
        # Re-add the channel axis so the mask broadcasts over every channel.
        mse = mse * tf.expand_dims(mask, axis=-1)

        if reduce == 'mean':
            mse = tf.reduce_mean(mse)
        elif reduce == 'max':
            mse = tf.reduce_max(mse)
        elif reduce is not None:
            raise ValueError(f'Unknown MSE reduction type: {reduce}')

        return 1.0 / (self.image_sigma ** 2) * mse

##### Smoothness Loss
class Grad:
    """
    N-D gradient (smoothness) loss, weighted by a mask taken from ``y_true``.

    loss_mult can be used to scale the loss value - this is recommended if
    the gradient is computed on a downsampled vector field (where loss_mult
    is equal to the downsample factor).

    Implemented with ``tf.*`` ops only, so the class does not depend on a
    Keras-backend alias (``K``) being defined by some other cell.
    """

    def __init__(self, penalty='l1', loss_mult=None, vox_weight=None):
        self.penalty = penalty
        self.loss_mult = loss_mult
        # Only used by _diffs() when no explicit weight is passed; loss()
        # always passes the mask it extracts from y_true.
        self.vox_weight = vox_weight

    def _diffs(self, y, vox_weight=None):
        """
        Forward differences of ``y`` along every spatial axis.

        Args:
            y: Tensor laid out as [batch, *spatial, channels].
            vox_weight: Optional weight tensor with the same axis layout;
                falls back to ``self.vox_weight`` when omitted.

        Returns:
            list: One difference tensor per spatial axis, in the original
            axis order.
        """
        if vox_weight is None:
            vox_weight = self.vox_weight

        # Everything between the batch axis and the channel axis is spatial.
        ndims = len(y.shape) - 2

        df = [None] * ndims
        for i in range(ndims):
            d = i + 1
            # permute dimensions to put the ith dimension first
            r = [d, *range(d), *range(d + 1, ndims + 2)]
            yp = tf.transpose(y, perm=r)
            dfi = yp[1:, ...] - yp[:-1, ...]

            if vox_weight is not None:
                w = tf.transpose(vox_weight, perm=r)
                # TODO: Need to add square root, since for non-0/1 weights this is bad.
                dfi = w[1:, ...] * dfi

            # permute back
            # note: this might not be necessary for this loss specifically,
            # since the results are just summed over anyway.
            r = [*range(1, d + 1), 0, *range(d + 1, ndims + 2)]
            df[i] = tf.transpose(dfi, perm=r)

        return df

    def loss(self, y_true, y_pred):
        """
        Mask-weighted gradient penalty of ``y_pred``.

        The mask is channel 1 of ``y_true``, bilinearly resized to the
        spatial size of ``y_pred`` (axes 1 and 2, i.e. 2-D fields). It is
        passed to ``_diffs`` explicitly and is not stored on the instance.

        returns Tensor of size [bs]
        """
        mask = tf.expand_dims(y_true[..., 1], -1)  # add the channel axis back

        # Resize the mask to match the spatial dimensions of y_pred
        target_size = tf.shape(y_pred)[1:3]
        # Use bilinear interpolation for continuous values
        mask = tf.image.resize(mask, size=target_size, method="bilinear")

        diffs = self._diffs(y_pred, vox_weight=mask)

        if self.penalty == 'l1':
            dif = [tf.abs(f) for f in diffs]
        else:
            assert self.penalty == 'l2', 'penalty can only be l1 or l2. Got: %s' % self.penalty
            dif = [f * f for f in diffs]

        # Flatten everything but the batch axis, then average per sample.
        df = [tf.reduce_mean(tf.reshape(f, [tf.shape(f)[0], -1]), axis=-1) for f in dif]
        grad = tf.add_n(df) / len(df)

        if self.loss_mult is not None:
            grad *= self.loss_mult

        return grad

    def mean_loss(self, y_true, y_pred):
        """
        returns Tensor of size ()
        """
        return tf.reduce_mean(self.loss(y_true, y_pred))
##### Model params
def create_voxelmorph_model(use_mse_mask=False, use_smoothness_mask=False, kernel_config='default', lambda_val=0.1):
    """Build and compile a VxmDense registration model for 128x128 inputs.

    Args:
        use_mse_mask (bool): Use the notebook's masked ``MSE`` loss instead of
            ``vxm.losses.MSE`` for the similarity term.
        use_smoothness_mask (bool): Use the notebook's masked ``Grad`` loss
            instead of ``vxm.losses.Grad`` for the smoothness term.
        kernel_config (str): Key into ``KERNEL_CONFIGS``; only looked up when
            ``USE_CUSTOM_VXM`` is true.
        lambda_val (float): Weight of the smoothness loss (similarity weight is 1).

    Returns:
        The compiled model (Adam, learning rate 1e-4).
    """
    spatial_shape = (128, 128)
    moving_channels = 1  # Moving image has 1 channel
    # Fixed image + mask channels when any masked loss is active.
    fixed_channels = 3 if (use_mse_mask or use_smoothness_mask) else 1

    moving_in = tf.keras.Input(shape=(*spatial_shape, moving_channels), name='source_input')
    fixed_in = tf.keras.Input(shape=(*spatial_shape, fixed_channels), name='target_input')

    unet_features = [
        [16, 32, 32, 32],              # encoder
        [32, 32, 32, 32, 32, 16, 16],  # decoder
    ]

    # Per-layer kernel sizes are only passed on to the custom VoxelMorph build.
    kernel_sizes = KERNEL_CONFIGS[kernel_config] if USE_CUSTOM_VXM else None

    vm_model = vxm.networks.VxmDense(
        inshape=spatial_shape,
        nb_unet_features=unet_features,
        unet_kernel_sizes=kernel_sizes,
        src_feats=moving_channels,
        trg_feats=fixed_channels,
        input_model=tf.keras.Model(inputs=[moving_in, fixed_in], outputs=[moving_in, fixed_in]),
        int_steps=5,
        reg_field='warp',
    )

    # Output 1: similarity loss, output 2: smoothness loss (each optionally masked).
    similarity_loss = MSE().loss if use_mse_mask else vxm.losses.MSE().loss
    smoothness_loss = Grad('l2').loss if use_smoothness_mask else vxm.losses.Grad('l2').loss

    vm_model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
        loss=[similarity_loss, smoothness_loss],
        loss_weights=[1, lambda_val],
    )
    return vm_model



In [4]:
def load_model_for_eval(config, kernel_key, lambda_val, load_best=True, epoch=None):
    """
    Rebuild the model architecture and load either the lowest-loss weights
    found in the run's ``weights`` folder or the weights of a specific epoch.

    Parameters:
    - config: Model configuration dictionary; must contain
      ``'kernel_<kernel_key>_lambda_<lambda_val:.3f>' -> {'folder': ...}`` plus the
      ``use_mse_mask`` / ``use_smoothness_mask`` flags.
    - kernel_key: Kernel configuration key.
    - lambda_val: Lambda value for smoothness loss.
    - load_best: If True, pick the file with the lowest loss encoded in its name
      (``best_model_val_loss_*`` and ``epoch*_loss*.weights.h5`` are compared on the
      same scale). Files without a loss in the name (e.g. ``final_model.keras``)
      are used only when no file carries a loss value.
    - epoch: Epoch number to load (required if load_best=False, ignored otherwise).

    Returns:
    - Loaded model or None if loading fails.
    """
    model_dir = os.path.join(
        config[f'kernel_{kernel_key}_lambda_{lambda_val:.3f}']['folder'],
        'weights'
    )

    if not os.path.exists(model_dir):
        print(f"⚠️ Directory not found: {model_dir}")
        return None

    if load_best:
        # Candidates: any .keras file, best_model_val_loss_*.weights.h5 and
        # epoch*_loss*.weights.h5. Sorted so the choice is deterministic.
        epoch_file_re = r'epoch\d+_loss\d+\.\d+\.weights\.h5'
        model_files = sorted(
            f for f in os.listdir(model_dir)
            if f.endswith('.keras')
            or (f.startswith('best_model_val_loss_') and f.endswith('.weights.h5'))
            or (f.endswith('.weights.h5') and re.match(epoch_file_re, f))
        )

        if not model_files:
            print(f"⛔ No model files found in {model_dir}")
            return None

        best_model = None
        lowest_loss = float('inf')
        loss_type = None  # 'val_loss', 'train_loss' or 'unknown'
        unranked_files = []  # files whose name carries no loss value

        for model_file in model_files:
            match = None
            loss_type_candidate = 'unknown'
            if model_file.startswith('best_model_val_loss_'):
                match = re.search(r'best_model_val_loss_(\d+\.\d+)', model_file)
                if match:
                    loss_type_candidate = 'val_loss'
            elif model_file.startswith('epoch'):
                match = re.search(r'epoch\d+_loss(\d+\.\d+)\.weights\.h5', model_file)
                if match:
                    loss_type_candidate = 'train_loss'

            if match is None:
                print(f"⚠️ No loss value found in filename {model_file}. Deprioritizing this file.")
                unranked_files.append(model_file)
                continue

            loss = float(match.group(1))
            if loss < lowest_loss:
                lowest_loss = loss
                best_model = model_file
                loss_type = loss_type_candidate

        # Fall back to a file without loss information only when it is the
        # only kind available (a comparison against inf can never select it).
        if best_model is None:
            best_model = unranked_files[0]
            loss_type = 'unknown'

        file_extension = '.keras' if best_model.endswith('.keras') else '.weights.h5'
        model_path = os.path.join(model_dir, best_model)
        if loss_type == 'unknown':
            print(f"Loading best model: {best_model} (no loss value available in filename)")
        else:
            print(f"Loading best model: {best_model} with {loss_type}={lowest_loss}")

    else:
        # Load specific epoch weights (format: epoch{epoch:02d}_loss{loss:.5f}.weights.h5)
        if epoch is None:
            print("⛔ Epoch number must be specified when load_best=False")
            return None

        # Raw f-string: the backslashes belong to the regex, not to Python.
        epoch_pattern = rf'epoch{epoch:02d}_loss[0-9]+\.[0-9]+\.weights\.h5'
        model_files = sorted(f for f in os.listdir(model_dir) if re.match(epoch_pattern, f))
        if not model_files:
            print(f"⛔ No weight files found for epoch {epoch} in {model_dir}")
            return None

        # There should be only one file matching the epoch
        if len(model_files) > 1:
            print(f"⚠️ Multiple files found for epoch {epoch}: {model_files}. Using the first one.")

        best_model = model_files[0]
        model_path = os.path.join(model_dir, best_model)
        file_extension = '.weights.h5'
        print(f"Loading epoch-specific model: {best_model}")

    try:
        # Recreate model architecture first, then load the weights into it.
        model = create_voxelmorph_model(
            use_mse_mask=config['use_mse_mask'],
            use_smoothness_mask=config['use_smoothness_mask'],
            kernel_config=kernel_key,
            lambda_val=lambda_val
        )
        model.load_weights(model_path)
        print(f"✅ Successfully loaded {best_model}")
        return model

    except Exception as e:
        print(f"❌ Loading weights failed: {str(e)}")
        print("Trying fallback load method...")

        # A full-model load only makes sense for .keras files.
        if file_extension != '.keras':
            print(f"⛔ Critical load failure: {str(e)}")
            return None

        try:
            # Built lazily: only the full-model load needs the custom objects.
            custom_objects = {
                'Grad': Grad,
                'MSE': MSE,
                'Adam': tf.keras.optimizers.Adam,
                'vxm': vxm.losses  # If using original voxelmorph losses
            }
            return tf.keras.models.load_model(
                model_path,
                custom_objects=custom_objects
            )
        except Exception as e2:
            print(f"⛔ Critical load failure: {str(e2)}")
            return None
        
import os
import gc
import logging
import time
import numpy as np
import tensorflow as tf
from docx import Document
from docx.shared import Inches
from docx.oxml.shared import qn
from tqdm import tqdm
from contextlib import contextmanager

@contextmanager
def memory_cleanup():
    """Context manager that tidies up TensorFlow/Python memory on exit.

    On leaving the ``with`` body (normally or through an exception) it
    clears the Keras session, runs the garbage collector and, when a GPU is
    listed, requests memory growth on every GPU. Failures of the GPU step
    are logged as warnings and never raised.
    """
    try:
        yield
    finally:
        # Drop Keras' global state, then let Python reclaim unreachable objects.
        tf.keras.backend.clear_session()
        gc.collect()

        gpu_listed = tf.config.list_physical_devices('GPU')
        if gpu_listed:
            try:
                # NOTE(review): this only enables memory growth; whether it
                # releases memory already held by the process is not shown here.
                for device in tf.config.experimental.list_physical_devices('GPU'):
                    tf.config.experimental.set_memory_growth(device, True)
            except Exception as e:
                logging.warning(f"GPU memory cleanup warning: {e}")

def load_model_for_eval_robust(config, kernel_key, lambda_val, load_best=True, epoch=None, max_retries=3):
    """
    Load a trained model with retries and logging.

    Weight files (``.weights.h5``) are loaded into a freshly built
    architecture; ``.keras`` files are loaded as full models.

    Args:
        config (dict): Model configuration; must contain
            ``'kernel_<kernel_key>_lambda_<lambda_val:.3f>' -> {'folder': ...}``
            and the ``use_mse_mask`` / ``use_smoothness_mask`` flags.
        kernel_key (str): Kernel configuration key.
        lambda_val (float): Smoothness-loss weight.
        load_best (bool): Pick the file chosen by ``_find_best_model`` when True,
            otherwise the file of ``epoch`` via ``_find_epoch_model``.
        epoch (int | None): Epoch to load; required when ``load_best`` is False.
        max_retries (int): Number of load attempts (1 s pause between attempts).

    Returns:
        The loaded model, or None when the folder/file is missing, ``epoch`` is
        not given for an epoch load, or every attempt failed.
    """
    model_dir = os.path.join(
        config[f'kernel_{kernel_key}_lambda_{lambda_val:.3f}']['folder'],
        'weights'
    )

    if not os.path.exists(model_dir):
        logging.error(f"Model directory not found: {model_dir}")
        return None

    # Fail fast: without this, the missing epoch surfaced as a formatting
    # TypeError inside the retry loop and was retried (and slept on) in vain.
    if not load_best and epoch is None:
        logging.error("Epoch number must be specified when load_best=False")
        return None

    for attempt in range(max_retries):
        try:
            if load_best:
                model_path, file_extension, loss_info = _find_best_model(model_dir)
                if not model_path:
                    logging.error(f"No valid model files found in {model_dir}")
                    return None
                logging.info(f"Loading best model: {os.path.basename(model_path)} {loss_info}")
            else:
                model_path, file_extension = _find_epoch_model(model_dir, epoch)
                if not model_path:
                    logging.error(f"No model file found for epoch {epoch} in {model_dir}")
                    return None
                logging.info(f"Loading epoch model: {os.path.basename(model_path)}")

            if file_extension == '.weights.h5':
                # Weights only: the architecture has to be rebuilt first.
                model = create_voxelmorph_model(
                    use_mse_mask=config['use_mse_mask'],
                    use_smoothness_mask=config['use_smoothness_mask'],
                    kernel_config=kernel_key,
                    lambda_val=lambda_val
                )
                model.load_weights(model_path)
            else:  # .keras file: full model, no need to build the architecture
                custom_objects = {
                    'Grad': Grad,
                    'MSE': MSE,
                    'Adam': tf.keras.optimizers.Adam,
                }
                model = tf.keras.models.load_model(model_path, custom_objects=custom_objects)

            logging.info(f"✅ Successfully loaded model (attempt {attempt + 1})")
            return model

        except Exception as e:
            logging.warning(f"Loading attempt {attempt + 1} failed: {str(e)}")
            if attempt == max_retries - 1:
                logging.error(f"All {max_retries} loading attempts failed")
                return None
            time.sleep(1)  # Brief pause before retry

    return None

def _find_best_model(model_dir):
    """Find the best model file based on validation loss."""
    import re
    
    model_files = [f for f in os.listdir(model_dir) 
                   if f.endswith(('.weights.h5', '.keras'))]
    
    if not model_files:
        return None, None, None

    best_model = None
    lowest_loss = float('inf')
    file_extension = None
    loss_info = ""

    for model_file in model_files:
        loss = float('inf')
        loss_type = 'unknown'

        if model_file.startswith('best_model_val_loss_'):
            match = re.search(r'best_model_val_loss_(\d+\.\d+)', model_file)
            if match:
                loss = float(match.group(1))
                loss_type = 'val_loss'
        elif model_file.startswith('epoch'):
            match = re.search(r'epoch\d+_loss(\d+\.\d+)\.weights\.h5', model_file)
            if match:
                loss = float(match.group(1))
                loss_type = 'train_loss'

        if loss < lowest_loss:
            lowest_loss = loss
            best_model = model_file
            file_extension = '.keras' if model_file.endswith('.keras') else '.weights.h5'
            loss_info = f"({loss_type}={lowest_loss:.6f})" if loss_type != 'unknown' else "(no loss info)"

    if best_model:
        return os.path.join(model_dir, best_model), file_extension, loss_info
    return None, None, None

def _find_epoch_model(model_dir, epoch):
    """Find model file for specific epoch."""
    import re
    
    epoch_pattern = f'epoch{epoch:02d}_loss[0-9]+\.[0-9]+\.weights\.h5'
    model_files = [f for f in os.listdir(model_dir) if re.match(epoch_pattern, f)]
    
    if model_files:
        return os.path.join(model_dir, model_files[0]), '.weights.h5'
    return None, None


In [5]:
def enforce_full_principal_strain_order(Ep1All, Ep2All, Ep3All=None):
    """
    Sort three principal-strain fields per pixel so that
    result 1 >= result 2 >= result 3 everywhere.

    When ``Ep3All`` is omitted a field of zeros takes its place and takes
    part in the ordering like any other value (so a zero can end up in any
    of the three outputs).

    Args:
        Ep1All (np.ndarray): First principal strain field (shape: H, W).
        Ep2All (np.ndarray): Second principal strain field (shape: H, W).
        Ep3All (np.ndarray, optional): Third principal strain field (shape: H, W).

    Returns:
        Ep1_sorted (np.ndarray): Largest value per pixel (shape: H, W).
        Ep2_sorted (np.ndarray): Middle value per pixel (shape: H, W).
        Ep3_sorted (np.ndarray): Smallest value per pixel (shape: H, W).

    Raises:
        ValueError: If the first two inputs differ in shape or are not 2-D,
            or if ``Ep3All`` is given with a different shape.
    """
    first_two_ok = Ep1All.shape == Ep2All.shape and Ep1All.ndim == 2
    if not first_two_ok:
        raise ValueError(f"Invalid input shapes: Ep1All={Ep1All.shape}, Ep2All={Ep2All.shape}, expected 2D arrays")

    if Ep3All is None:
        # 2-D case: use a zero field as the third component.
        Ep3All = np.zeros_like(Ep1All)
    elif Ep3All.shape != Ep1All.shape:
        raise ValueError(f"Invalid Ep3All shape: {Ep3All.shape}, expected {Ep1All.shape}")

    # np.sort is ascending along the stacked axis: index 0 is the smallest,
    # index 2 the largest value at each pixel.
    ascending = np.sort(np.stack((Ep1All, Ep2All, Ep3All), axis=0), axis=0)
    smallest, middle, largest = ascending[0], ascending[1], ascending[2]

    return largest, middle, smallest

def limit_strain_range(FrameDisplX, FrameDisplY, deltaX=1, deltaY=1):
    """
    Compute the in-plane principal strains (E1 >= E2) and the
    incompressibility strain (E3) from a 2D displacement field.

    The strain components are built from the displacement gradients as
    ``E = (grad u + grad u^T - grad u^T grad u) / 2``; E1/E2 are the
    eigenvalues of that 2x2 tensor and ``E3 = 1 / ((1 + E1)(1 + E2)) - 1``.

    E1 and E2 are taken directly from the eigenvalue formula (which already
    guarantees E1 >= E2). They are deliberately NOT passed through a 3-way
    sort together with a zero placeholder: doing so replaced E2 by 0 wherever
    the strains had opposite signs and E1 by 0 wherever both were negative.

    Args:
        FrameDisplX (np.ndarray): X displacement field (shape: H, W).
        FrameDisplY (np.ndarray): Y displacement field (shape: H, W).
        deltaX (float): Sample spacing used for the derivative along array axis 0.
        deltaY (float): Sample spacing used for the derivative along array axis 1.

    Returns:
        dx (None): Placeholder for compatibility.
        dy (None): Placeholder for compatibility.
        initial_strain_tensor (dict): Strain fields under keys 'E1', 'E2', 'E3'.
        final_strain_tensor (dict): Same fields as the initial tensor (no
            limiting step is applied in this implementation).
        max_initial_strain (float): Maximum of E1.
        max_strain (float): Maximum of E1.
        min_initial_strain (float): Minimum of E2.
        min_strain (float): Minimum of E2.

    Raises:
        ValueError: If the two fields differ in shape or are not 2D.
    """
    FrameDisplX = np.asarray(FrameDisplX)
    FrameDisplY = np.asarray(FrameDisplY)

    if FrameDisplX.shape != FrameDisplY.shape or FrameDisplX.ndim != 2:
        raise ValueError(f"Invalid displacement shapes: FrameDisplX={FrameDisplX.shape}, FrameDisplY={FrameDisplY.shape}, expected 2D arrays")

    # Spatial gradients: first result is the derivative along axis 0
    # (spacing deltaX), second along axis 1 (spacing deltaY).
    UXx, UXy = np.gradient(FrameDisplX, deltaX, deltaY, axis=(0, 1))
    UYx, UYy = np.gradient(FrameDisplY, deltaX, deltaY, axis=(0, 1))

    # Strain tensor components
    ExxAll = (2 * UXx - (UXx**2 + UYx**2)) / 2
    ExyAll = (UXy + UYx - (UXx * UXy + UYx * UYy)) / 2
    EyyAll = (2 * UYy - (UXy**2 + UYy**2)) / 2

    # Eigenvalues of the symmetric 2x2 tensor: centre +/- radius (radius >= 0,
    # hence Ep1All >= Ep2All at every pixel).
    centre = (ExxAll + EyyAll) / 2
    radius = np.sqrt(((ExxAll - EyyAll) / 2)**2 + ExyAll**2)
    Ep1All = centre + radius
    Ep2All = centre - radius

    # Incompressibility strain
    Ep3All = 1 / ((1 + Ep1All) * (1 + Ep2All)) - 1

    # Two separate dicts (sharing the same arrays) to keep the return layout.
    initial_strain_tensor = {'E1': Ep1All, 'E2': Ep2All, 'E3': Ep3All}
    final_strain_tensor = {'E1': Ep1All, 'E2': Ep2All, 'E3': Ep3All}

    max_initial_strain = np.max(Ep1All)
    max_strain = np.max(Ep1All)
    min_initial_strain = np.min(Ep2All)
    min_strain = np.min(Ep2All)

    return None, None, initial_strain_tensor, final_strain_tensor, max_initial_strain, max_strain, min_initial_strain, min_strain

In [6]:
def extract_patient_frame_data(simulated_data_path, simulated_mask_path, simulated_displacement_path,
                              patient_number, frame_number, slice_number):
    """
    Extract every simulated frame pair for one patient, frame (t) and slice (z).

    Image files in ``<simulated_data_path>/patientNNN`` are selected when their name
    ends with ``.npy`` and contains ``_t{tt}_z{zz}#``. The text between ``#`` and the
    first ``.`` identifies the pair: ``<k>_1`` is the base frame of pair ``k`` and
    ``<k>`` is its target frame. Masks are read from the mask directory under the
    same file names; displacements are read from the displacement directory as
    ``patientNNN_t{tt}_z{zz}#{k}_x.npy`` and ``..._y.npy``.

    A pair is skipped with a warning (never aborting the whole extraction) when its
    base or target image is missing, when a mask or displacement file is missing,
    or when any of its files cannot be loaded.

    Args:
        simulated_data_path (str): Path to simulated data directory
        simulated_mask_path (str): Path to simulated mask directory
        simulated_displacement_path (str): Path to simulated displacement directory
        patient_number (int): Patient number (e.g., 1 for patient001)
        frame_number (int): Frame number (t value)
        slice_number (int): Slice number (z value)

    Returns:
        dict: Mapping ``k -> entry`` sorted by ``k``, where each entry holds float32
            arrays under 'first_frame', 'target_frame', 'first_mask', 'target_mask',
            'displacement_x', 'displacement_y' plus 'frame_number' (== ``k``).
            Empty when no image file matches the requested t/z.

    Raises:
        ValueError: If the patient's data, mask or displacement folder is missing.
    """
    patient_id = f"patient{patient_number:03d}"
    t_str = f"{frame_number:02d}"
    z_str = f"{slice_number:02d}"

    patient_folder = os.path.join(simulated_data_path, patient_id)
    mask_folder = os.path.join(simulated_mask_path, patient_id)
    disp_folder = os.path.join(simulated_displacement_path, patient_id)

    # isdir (not exists): a plain file at this path would only fail later in listdir.
    if not os.path.isdir(patient_folder):
        raise ValueError(f"Patient folder does not exist: {patient_folder}")
    if not os.path.isdir(mask_folder):
        raise ValueError(f"Mask folder does not exist: {mask_folder}")
    if not os.path.isdir(disp_folder):
        raise ValueError(f"Displacement folder does not exist: {disp_folder}")

    # Select the image files that belong to the requested t/z combination.
    marker = f"_t{t_str}_z{z_str}#"
    matching_files = [
        fname for fname in os.listdir(patient_folder)
        if fname.endswith('.npy') and marker in fname
    ]

    if not matching_files:
        logging.warning(f"No files found for patient {patient_id}, t={frame_number}, z={slice_number}")
        return {}

    # pair number -> {'base': file name, 'target': file name}
    frame_groups = defaultdict(dict)
    for fname in matching_files:
        try:
            frame_part = fname.split('#')[1].split('.')[0]
            if frame_part.endswith('_1'):
                frame_groups[int(frame_part[:-2])]['base'] = fname
            else:
                frame_groups[int(frame_part)]['target'] = fname
        except (ValueError, IndexError) as e:
            logging.warning(f"Error parsing filename {fname}: {str(e)}")

    result_data = {}
    for frame_num, group in frame_groups.items():
        base_file = group.get('base')
        target_file = group.get('target')
        if not base_file or not target_file:
            logging.warning(f"Missing base or target file for frame {frame_num}")
            continue

        base_mask_path = os.path.join(mask_folder, base_file)
        if not os.path.exists(base_mask_path):
            logging.warning(f"Base mask file not found: {base_mask_path}")
            continue

        target_mask_path = os.path.join(mask_folder, target_file)
        if not os.path.exists(target_mask_path):
            logging.warning(f"Target mask file not found: {target_mask_path}")
            continue

        disp_x_path = os.path.join(disp_folder, f"{patient_id}_t{t_str}_z{z_str}#{frame_num}_x.npy")
        disp_y_path = os.path.join(disp_folder, f"{patient_id}_t{t_str}_z{z_str}#{frame_num}_y.npy")
        if not os.path.exists(disp_x_path) or not os.path.exists(disp_y_path):
            logging.warning(f"Displacement files not found: {disp_x_path} or {disp_y_path}")
            continue

        sources = {
            'first_frame': os.path.join(patient_folder, base_file),
            'target_frame': os.path.join(patient_folder, target_file),
            'first_mask': base_mask_path,
            'target_mask': target_mask_path,
            'displacement_x': disp_x_path,
            'displacement_y': disp_y_path,
        }

        try:
            # astype() already returns a fresh array, so no extra copy is needed.
            entry = {key: np.load(path).astype(np.float32) for key, path in sources.items()}
        except (ValueError, IndexError, OSError, EOFError) as e:
            # np.load raises EOFError for an empty file and OSError/ValueError for
            # unreadable or malformed ones; skip this pair instead of aborting.
            logging.warning(f"Error processing frame {frame_num}: {str(e)}")
            continue

        entry['frame_number'] = frame_num
        result_data[frame_num] = entry

    logging.info(f"Extracted data for {len(result_data)} frame pairs for patient {patient_id}, t={frame_number}, z={slice_number}")

    # Return the pairs ordered by pair number.
    return dict(sorted(result_data.items(), key=lambda item: item[0]))

In [7]:
# Example usage: extract every available frame pair for patient 74, frame t=1, slice z=2.
# The three *_simulated_data path variables are expected to be defined in an earlier cell.
data = extract_patient_frame_data(
    simulated_data_path= test_simulated_data,
    simulated_mask_path= mask_simulated_data, 
    simulated_displacement_path= displacement_simulated_data,
    patient_number=74,
    frame_number=1,
    slice_number=2
)

# Pull out the individual arrays of frame pair #10, if that pair was extracted
if 10 in data:
    first_frame = data[10]['first_frame']      # Base frame (_1)
    target_frame = data[10]['target_frame']     # Frame #10
    first_mask = data[10]['first_mask']         # Base frame mask
    target_mask = data[10]['target_mask']       # Frame #10 mask
    disp_x = data[10]['displacement_x']         # X displacement
    disp_y = data[10]['displacement_y']         # Y displacement

2025-06-19 01:58:13,366 - INFO - Extracted data for 24 frame pairs for patient patient074, t=1, z=2


In [8]:
def add_colorbar(fig, ax, im, label):
    """Attach a colorbar for ``im`` to the right-hand side of ``ax``.

    A new axes (size "5%", pad 0.05) is appended to the right of ``ax`` and the
    colorbar is drawn into it; ``label`` becomes its y-label, rotated by 270
    degrees with a label padding of 15. Nothing is returned.
    """
    colorbar_axes = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.05)
    colorbar = fig.colorbar(im, cax=colorbar_axes)
    colorbar.ax.set_ylabel(label, rotation=270, labelpad=15)
    
def _rescale_to_unit_range(image):
    """Min-max rescale ``image`` to [0, 1]; the epsilon avoids dividing by zero on constant images."""
    return (image - image.min()) / (np.ptp(image) + 1e-8)


def _draw_colormapped_panel(fig, ax, values, title, cmap, vmin, vmax, cbar_label, background=None):
    """Draw ``values`` on ``ax`` with a colorbar, optionally blended over a grayscale ``background``."""
    blend = {}
    if background is not None:
        ax.imshow(background, cmap='gray', alpha=0.95)
        blend['alpha'] = 0.5
    im = ax.imshow(values, cmap=cmap, vmin=vmin, vmax=vmax, **blend)
    ax.set_title(title, fontsize=28)
    ax.axis('off')
    add_colorbar(fig, ax, im, label=cbar_label)


def save_images_for_skip(moving, fixed, warped, disp, save_dir, patient_id, frame_number, target_disp):
    """Render and save a 3x5 diagnostic figure for every sample of a registration batch.

    Figure layout (row, column):
        row 0: moving, fixed and warped images in grayscale, then two RGB composites
               whose red channel is the warped (col 3) or moving (col 4) image and
               whose green and blue channels are the fixed image.
        row 1: predicted E1 and E2 strain blended over the fixed image, target E2
               strain blended over the fixed image, then |fixed - warped| and
               |fixed - moving| computed on min-max normalised images.
        row 2: predicted E1 and E2 strain heatmaps, target E1 strain blended over
               the fixed image, then the absolute target-minus-predicted difference
               of E1 and of E2.

    Strain maps are taken from the 'E1'/'E2' entries of the fourth value returned
    by ``limit_strain_range`` and are drawn on a fixed [-0.5, 0.5] colour scale.

    Args:
        moving, fixed, warped (np.ndarray): 4D arrays; channel 0 of each sample is plotted.
        disp, target_disp (np.ndarray or list): 4D arrays with at least two channels;
            channels 0 and 1 are handed to ``limit_strain_range``. A list is
            accepted too, in which case its first element is used.
        save_dir (str): Output directory, created if missing.
        patient_id: Only used in the figure title and the log message.
        frame_number: Used in the figure title and in the output file name.

    Output files:
        ``frame_{frame_number}.png`` for a single-sample batch. For larger batches
        each sample gets ``frame_{frame_number}_sample_{i}.png`` so that samples do
        not overwrite one another.

    Raises:
        ValueError: If an input is not a 4D numpy array with non-zero batch/spatial
            sizes, or a displacement field has fewer than two channels.
        Exception: Any other failure is logged and re-raised after the open figure
            (if any) has been closed.
    """
    fig = None
    try:
        logging.info(f"Processing patient {patient_id}, frame {frame_number}")

        # Unwrap list-wrapped fields BEFORE validating; the ndarray check below
        # would otherwise reject them and the unwrapping could never run.
        if isinstance(disp, list):
            disp = disp[0]
        if isinstance(target_disp, list):
            target_disp = target_disp[0]

        inputs = (('moving', moving), ('fixed', fixed), ('warped', warped),
                  ('disp', disp), ('target_disp', target_disp))
        for name, arr in inputs:
            if not isinstance(arr, np.ndarray):
                raise ValueError(f"{name} is not a numpy array: {type(arr)}")
            if arr.ndim != 4 or arr.shape[0] == 0 or any(dim == 0 for dim in arr.shape[1:3]):
                raise ValueError(f"Invalid {name} shape: {arr.shape}, expected 4D with non-zero dimensions")

        if disp.shape[-1] < 2:
            raise ValueError(f"Unexpected displacement shape: {disp.shape}, expected (batch, H, W, >=2)")
        if target_disp.shape[-1] < 2:
            raise ValueError(f"Unexpected target displacement shape: {target_disp.shape}, expected (batch, H, W, >=2)")

        batch_size = moving.shape[0]
        strain_vmin, strain_vmax = -0.5, 0.5  # fixed symmetric strain colour scale

        for i in range(batch_size):
            # Only channel 0 is plotted, whatever the channel count of the input.
            moving_img = moving[i, ..., 0].squeeze()
            warped_img = warped[i, ..., 0].squeeze()
            fixed_img = fixed[i, ..., 0].squeeze()

            try:
                # Index 3 of the returned tuple is the final strain tensor dict.
                pred_strain = limit_strain_range(disp[i, ..., 0], disp[i, ..., 1], 1)[3]
                target_strain = limit_strain_range(target_disp[i, ..., 0], target_disp[i, ..., 1], 1)[3]
            except Exception as strain_error:
                logging.error(f"Error calculating strain for frame {frame_number}: {str(strain_error)}")
                raise  # Re-raise to skip the frame instead of using zero tensors

            pred_e1, pred_e2 = pred_strain['E1'], pred_strain['E2']
            target_e1, target_e2 = target_strain['E1'], target_strain['E2']

            # The difference panels show magnitudes, sharing one colour scale.
            abs_e1_diff = np.abs(target_e1 - pred_e1)
            abs_e2_diff = np.abs(target_e2 - pred_e2)
            diff_vmax = max(np.max(abs_e1_diff), np.max(abs_e2_diff))

            warped_norm = _rescale_to_unit_range(warped_img)
            fixed_norm = _rescale_to_unit_range(fixed_img)
            moving_norm = _rescale_to_unit_range(moving_img)

            # Intensity error before (moving) and after (warped) registration,
            # on a shared colour scale so the two panels are comparable.
            error_warped = np.abs(fixed_norm - warped_norm)
            error_moving = np.abs(fixed_norm - moving_norm)
            error_vmax = max(np.max(error_warped), np.max(error_moving))

            fig, axes = plt.subplots(3, 5, figsize=(40, 21), constrained_layout=True)
            fig.suptitle(f"{patient_id} - Frame {frame_number} Analysis", fontsize=34, y=1.02)

            # --- Row 0: core images and RGB composites ---
            core_panels = (
                (moving_img, "Moving Image"),
                (fixed_img, "Fixed Image"),
                (warped_img, "Warped Image"),
            )
            for col, (img, title) in enumerate(core_panels):
                axes[0, col].imshow(img, cmap='gray')
                axes[0, col].set_title(title, fontsize=28)
                axes[0, col].axis('off')

            composite_panels = (
                (3, warped_norm, "Warped (Red) over Fixed (RGB)"),
                (4, moving_norm, "Moving (Red) over Fixed (RGB)"),
            )
            for col, red_channel, title in composite_panels:
                axes[0, col].imshow(np.stack([red_channel, fixed_norm, fixed_norm], axis=-1))
                axes[0, col].set_title(title, fontsize=28)
                axes[0, col].axis('off')

            # --- Rows 1 and 2: strain and error maps ---
            # (row, col, values, title, cmap, vmin, vmax, colorbar label, background)
            mapped_panels = (
                (1, 0, pred_e1, "E1 Strain Overlay", 'jet', strain_vmin, strain_vmax, "", fixed_img),
                (1, 1, pred_e2, "E2 Strain Overlay", 'jet', strain_vmin, strain_vmax, "Strain (unitless)", fixed_img),
                (1, 2, target_e2, "Target E2 Strain Overlay", 'jet', strain_vmin, strain_vmax, "Strain (unitless)", fixed_img),
                (1, 3, error_warped, "F-W Local Registration Error Heatmap", 'hot', 0, error_vmax, "", None),
                (1, 4, error_moving, "F-M Local Registration Error Heatmap", 'hot', 0, error_vmax, "Absolute Intensity Difference", None),
                (2, 0, pred_e1, "Final E1 Strain", 'jet', strain_vmin, strain_vmax, "", None),
                (2, 1, pred_e2, "Final E2 Strain", 'jet', strain_vmin, strain_vmax, "Strain (unitless)", None),
                (2, 2, target_e1, "Target E1 Strain Overlay", 'jet', strain_vmin, strain_vmax, "Strain (unitless)", fixed_img),
                (2, 3, abs_e1_diff, "Target E1 - Pred E1 Difference", 'hot', 0, diff_vmax, "Strain Difference", None),
                (2, 4, abs_e2_diff, "Target E2 - Pred E2 Difference", 'hot', 0, diff_vmax, "Strain Difference", None),
            )
            for row, col, values, title, cmap, vmin, vmax, cbar_label, background in mapped_panels:
                _draw_colormapped_panel(fig, axes[row, col], values, title, cmap,
                                        vmin, vmax, cbar_label, background)

            # Save figure; multi-sample batches get a per-sample suffix so they
            # do not overwrite each other.
            os.makedirs(save_dir, exist_ok=True)
            sample_tag = f"_sample_{i}" if batch_size > 1 else ""
            save_path = os.path.join(save_dir, f"frame_{frame_number}{sample_tag}.png")
            fig.savefig(save_path, bbox_inches='tight', dpi=100)  # Reduced DPI for faster saving
            plt.close(fig)
            fig = None

            logging.info(f"Saved visualization: {save_path}")

    except Exception as e:
        logging.error(f"Error in save_images_for_skip for frame {frame_number}: {str(e)}")
        if fig is not None:
            plt.close(fig)
        raise

In [9]:
def _visualize_frame_pair(model, frame_num, sample, output_dir, use_mask, patient_id):
    """Predict and plot one frame pair; return True on success, False when it was skipped."""
    try:
        # Add batch and channel axes (comments assume 2D (H, W) inputs).
        moving = sample['first_frame'][None, ..., None]        # (1, H, W, 1)
        target_frame = sample['target_frame'][None, ..., None]  # (1, H, W, 1)
        first_mask = sample['first_mask'][None, ..., None]      # (1, H, W, 1)
        target_mask = sample['target_mask'][None, ..., None]    # (1, H, W, 1)

        # Ground-truth displacement: x and y stacked on the last axis -> (1, H, W, 2)
        target_disp = np.stack([sample['displacement_x'], sample['displacement_y']], axis=-1)[None, ...]

        if not validate_batch_shapes(moving, target_frame, first_mask, target_mask, target_disp, use_mask):
            logging.warning(f"Invalid shapes for frame_num {frame_num}")
            return False

        if use_mask:
            # Fixed input carries the target frame plus both masks -> (1, H, W, 3)
            fixed_input = np.concatenate([target_frame, target_mask, first_mask], axis=-1)
        else:
            fixed_input = target_frame

        try:
            warped, pred_disp = model.predict([moving, fixed_input], verbose=0, batch_size=1)
            if pred_disp.shape[-1] != 2:
                logging.warning(f"Unexpected prediction shape: {pred_disp.shape}")
                return False
            logging.info(f"Prediction successful for frame_num {frame_num}")
        except Exception as pred_error:
            logging.warning(f"Prediction failed for frame_num {frame_num}: {pred_error}")
            return False

        try:
            save_images_for_skip(
                moving=moving,
                fixed=fixed_input[..., :1],  # image channel only, masks dropped
                warped=warped,
                disp=pred_disp,
                save_dir=output_dir,
                patient_id=patient_id,
                frame_number=frame_num,
                target_disp=target_disp
            )
        except Exception as vis_error:
            logging.warning(f"Visualization failed for frame_num {frame_num}: {vis_error}")
            return False

        return True

    except Exception as e:
        logging.warning(f"Error processing frame {frame_num}: {str(e)}")
        return False


def visualize_simulated_model(model, test_data, output_dir, use_mask=False, patient_id=74):
    """
    Predict displacements for every frame pair and save one strain figure per pair.

    Each pair is processed independently: a shape mismatch, a failed prediction or
    a failed visualization is logged as a warning and the pair is skipped.

    Args:
        model: Trained model. ``model.predict([moving, fixed])`` must return a
            ``(warped, displacement)`` pair whose displacement has 2 channels.
        test_data (dict): Mapping ``frame number -> sample``. Each sample is a dict
            with the arrays 'first_frame', 'target_frame', 'first_mask',
            'target_mask', 'displacement_x' and 'displacement_y'.
        output_dir (str): Directory the figures are saved to (created if missing).
        use_mask (bool): If True, the fixed model input is the target frame
            concatenated with the target mask and the first-frame mask.
        patient_id: Identifier shown in figure titles and logs. Defaults to 74,
            the value that used to be hard-coded here.

    Returns:
        dict: ``{'processed_samples': int, 'skipped_cases': list}`` where
            'skipped_cases' holds the frame numbers that could not be visualized.
    """
    os.makedirs(output_dir, exist_ok=True)

    processed = 0
    skipped = []
    for frame_num, sample in test_data.items():
        if _visualize_frame_pair(model, frame_num, sample, output_dir, use_mask, patient_id):
            processed += 1
        else:
            skipped.append(frame_num)

    return {'processed_samples': processed, 'skipped_cases': skipped}


def validate_batch_shapes(moving, fixed, moving_mask, fixed_mask, target_disp, use_mask):
    """Check that all batch tensors agree on spatial size and channel count.

    The reference spatial size is ``moving.shape[1:3]``. The moving frame, fixed
    frame and both masks must have that spatial size and a last dimension of 1;
    the target displacement must have that spatial size and a last dimension of 2.
    The first failing tensor is reported with a warning. ``use_mask`` is accepted
    but not consulted.

    Returns:
        bool: True if every check passes; False on the first mismatch or if any
        error occurs while inspecting the shapes.
    """
    try:
        expected_spatial = moving.shape[1:3]

        # (label used in the warning, tensor, required channel count), in check order
        requirements = (
            ("moving", moving, 1),
            ("fixed", fixed, 1),
            ("moving mask", moving_mask, 1),
            ("fixed mask", fixed_mask, 1),
            ("displacement", target_disp, 2),
        )

        for label, tensor, channels in requirements:
            if tensor.shape[1:3] != expected_spatial or tensor.shape[-1] != channels:
                logging.warning(f"Invalid {label} shape: {tensor.shape}, expected: (..., {expected_spatial}, {channels})")
                return False

        return True

    except Exception as shape_error:
        logging.warning(f"Shape validation error: {shape_error}")
        return False


In [10]:
def visualize_simulated_test(models_config, lambdas, kernel_keys, patient_data):
    """
    Run the simulated-data visualization for every model / kernel / lambda combination.

    For each combination the model is loaded with ``load_model_for_eval_robust``
    and handed to ``visualize_simulated_model`` together with ``patient_data``.
    Figures go to ``<folder>/results/Visualizing_Simulated`` where ``<folder>`` is
    ``config['kernel_{kernel}_lambda_{lambda:.3f}']['folder']``. Each combination
    runs inside ``memory_cleanup()``; a failure to load or any exception marks the
    combination as failed and the sweep continues. A summary is printed at the end.

    Args:
        models_config (dict): Model key -> config dict. Each config needs 'name'
            and one 'kernel_<kernel>_lambda_<lambda>' entry (with a 'folder') per
            combination; 'use_mask' is optional and defaults to False.
        lambdas (iterable): Lambda values to sweep.
        kernel_keys (iterable): Kernel keys to sweep.
        patient_data (dict): Frame pairs passed on as ``test_data``.

    Returns:
        None
    """
    print("\nStarting Simulated Test Visualization")
    print("=" * 50)

    # Track visualization statistics
    vis_stats = {
        'total_models': 0,
        'successful_visualizations': 0,
        'failed_visualizations': 0,
        'start_time': time.time()
    }

    for model_key in models_config:
        config = models_config[model_key]

        for kernel_key in kernel_keys:
            for lambda_val in lambdas:
                vis_stats['total_models'] += 1

                print(f"\n{'='*60}")
                print(f"Evaluating {config['name']} (kernel={kernel_key}, λ={lambda_val})")
                print(f"Progress: {vis_stats['successful_visualizations'] + vis_stats['failed_visualizations']}/{vis_stats['total_models']}")

                with memory_cleanup():
                    try:
                        model = load_model_for_eval_robust(config, kernel_key, lambda_val)
                        if not model:
                            logging.error(f"Failed to load model for {config['name']} (kernel={kernel_key}, λ={lambda_val})")
                            vis_stats['failed_visualizations'] += 1
                            continue

                        save_dir = os.path.join(
                            config[f'kernel_{kernel_key}_lambda_{lambda_val:.3f}']['folder'], 'results', 'Visualizing_Simulated'
                        )

                        visualize_simulated_model(
                            model=model,
                            test_data=patient_data,
                            output_dir=save_dir,
                            use_mask=config.get('use_mask', False)
                        )

                        logging.info("✅ visualization completed successfully")
                        vis_stats['successful_visualizations'] += 1

                    except Exception as e:
                        logging.error(f"❌ visualization failed: {str(e)}")
                        vis_stats['failed_visualizations'] += 1

    # Final summary
    elapsed_time = time.time() - vis_stats['start_time']
    total_models = vis_stats['total_models']
    print(f"\n{'='*60}")
    print("visualization SUMMARY")
    print(f"{'='*60}")
    print(f"Total models: {total_models}")
    print(f"Successful: {vis_stats['successful_visualizations']}")
    print(f"Failed: {vis_stats['failed_visualizations']}")
    if total_models:
        print(f"Success rate: {vis_stats['successful_visualizations']/total_models*100:.1f}%")
        print(f"Total time: {elapsed_time/60:.1f} minutes")
        print(f"Average time per model: {elapsed_time/total_models:.1f} seconds")
    else:
        # Empty sweep: the rate and the per-model average are undefined, and
        # computing them would divide by zero.
        print("Success rate: n/a (no models evaluated)")
        print(f"Total time: {elapsed_time/60:.1f} minutes")


In [11]:
# Run the visualization sweep over every MODEL_CONFIG entry, kernel key and lambda,
# using the frame pairs held in `data`. MODEL_CONFIG, LAMBDAS and KERNEL_KEYS are
# expected to be defined in an earlier cell.
visualize_simulated_test(models_config=MODEL_CONFIG,
                        lambdas=LAMBDAS,
                        kernel_keys=KERNEL_KEYS,
                        patient_data=data
                        )

2025-06-19 01:58:13,506 - INFO - Loading best model: best_model_val_loss_0.895370.weights.h5 (val_loss=0.895370)



Starting Simulated Test Visualization

Evaluating voxelmorph_both_masks (kernel=default, λ=0.016)
Progress: 0/1




  saveable.load_own_variables(weights_store.get(inner_path))
2025-06-19 01:58:16,769 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 01:58:18,150 - INFO - Prediction successful for frame_num 0
2025-06-19 01:58:18,151 - INFO - Processing patient 74, frame 0
2025-06-19 01:58:21,020 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.016\results\Visualizing_Simulated\frame_0.png
2025-06-19 01:58:21,116 - INFO - Prediction successful for frame_num 2
2025-06-19 01:58:21,117 - INFO - Processing patient 74, frame 2
2025-06-19 01:58:23,807 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.016\results\Visualizing_Simulated\frame_2.png
2025-06-19 01:58:23,892 - INFO - Prediction successful for frame_num 3
2025-06-19 01:58:23,893 - INFO - Processing patient 74, frame 3
2025-06-19 01:58:26,774 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.016\results\Visualizing_Si


Evaluating voxelmorph_both_masks (kernel=default, λ=0.033)
Progress: 1/2


2025-06-19 01:59:28,996 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 01:59:30,291 - INFO - Prediction successful for frame_num 0
2025-06-19 01:59:30,291 - INFO - Processing patient 74, frame 0
2025-06-19 01:59:32,979 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.033\results\Visualizing_Simulated\frame_0.png
2025-06-19 01:59:33,063 - INFO - Prediction successful for frame_num 2
2025-06-19 01:59:33,064 - INFO - Processing patient 74, frame 2
2025-06-19 01:59:35,804 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.033\results\Visualizing_Simulated\frame_2.png
2025-06-19 01:59:35,892 - INFO - Prediction successful for frame_num 3
2025-06-19 01:59:35,892 - INFO - Processing patient 74, frame 3
2025-06-19 01:59:38,671 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.033\results\Visualizing_Simulated\frame_3.png
2025-06-19 01:59:38,755 - INFO - Predictio


Evaluating voxelmorph_both_masks (kernel=default, λ=0.066)
Progress: 2/3


2025-06-19 02:00:42,768 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:00:44,051 - INFO - Prediction successful for frame_num 0
2025-06-19 02:00:44,052 - INFO - Processing patient 74, frame 0
2025-06-19 02:00:46,823 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.066\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:00:46,909 - INFO - Prediction successful for frame_num 2
2025-06-19 02:00:46,910 - INFO - Processing patient 74, frame 2
2025-06-19 02:00:49,653 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.066\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:00:49,740 - INFO - Prediction successful for frame_num 3
2025-06-19 02:00:49,741 - INFO - Processing patient 74, frame 3
2025-06-19 02:00:52,560 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.066\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:00:52,648 - INFO - Predictio


Evaluating voxelmorph_both_masks (kernel=default, λ=0.1)
Progress: 3/4


2025-06-19 02:01:56,686 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:01:57,974 - INFO - Prediction successful for frame_num 0
2025-06-19 02:01:57,974 - INFO - Processing patient 74, frame 0
2025-06-19 02:02:00,850 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.100\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:02:00,941 - INFO - Prediction successful for frame_num 2
2025-06-19 02:02:00,941 - INFO - Processing patient 74, frame 2
2025-06-19 02:02:03,924 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.100\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:02:04,010 - INFO - Prediction successful for frame_num 3
2025-06-19 02:02:04,011 - INFO - Processing patient 74, frame 3
2025-06-19 02:02:06,841 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_default_lambda_0.100\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:02:06,926 - INFO - Predictio


Evaluating voxelmorph_both_masks (kernel=first5, λ=0.016)
Progress: 4/5


2025-06-19 02:03:12,637 - INFO - Loading best model: best_model_val_loss_0.881871.weights.h5 (val_loss=0.881871)
2025-06-19 02:03:13,217 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:03:14,561 - INFO - Prediction successful for frame_num 0
2025-06-19 02:03:14,561 - INFO - Processing patient 74, frame 0
2025-06-19 02:03:17,283 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.016\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:03:17,370 - INFO - Prediction successful for frame_num 2
2025-06-19 02:03:17,371 - INFO - Processing patient 74, frame 2
2025-06-19 02:03:20,168 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.016\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:03:20,254 - INFO - Prediction successful for frame_num 3
2025-06-19 02:03:20,254 - INFO - Processing patient 74, frame 3
2025-06-19 02:03:23,026 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_k


Evaluating voxelmorph_both_masks (kernel=first5, λ=0.033)
Progress: 5/6


2025-06-19 02:04:28,357 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:04:29,686 - INFO - Prediction successful for frame_num 0
2025-06-19 02:04:29,687 - INFO - Processing patient 74, frame 0
2025-06-19 02:04:32,451 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.033\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:04:32,535 - INFO - Prediction successful for frame_num 2
2025-06-19 02:04:32,536 - INFO - Processing patient 74, frame 2
2025-06-19 02:04:35,321 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.033\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:04:35,416 - INFO - Prediction successful for frame_num 3
2025-06-19 02:04:35,418 - INFO - Processing patient 74, frame 3
2025-06-19 02:04:38,327 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.033\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:04:38,421 - INFO - Prediction s


Evaluating voxelmorph_both_masks (kernel=first5, λ=0.066)
Progress: 6/7


2025-06-19 02:05:44,093 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:05:45,432 - INFO - Prediction successful for frame_num 0
2025-06-19 02:05:45,433 - INFO - Processing patient 74, frame 0
2025-06-19 02:05:48,199 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.066\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:05:48,287 - INFO - Prediction successful for frame_num 2
2025-06-19 02:05:48,288 - INFO - Processing patient 74, frame 2
2025-06-19 02:05:51,084 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.066\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:05:51,172 - INFO - Prediction successful for frame_num 3
2025-06-19 02:05:51,173 - INFO - Processing patient 74, frame 3
2025-06-19 02:05:54,078 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.066\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:05:54,165 - INFO - Prediction s


Evaluating voxelmorph_both_masks (kernel=first5, λ=0.1)
Progress: 7/8


2025-06-19 02:07:00,707 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:07:02,262 - INFO - Prediction successful for frame_num 0
2025-06-19 02:07:02,263 - INFO - Processing patient 74, frame 0
2025-06-19 02:07:05,105 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.100\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:07:05,191 - INFO - Prediction successful for frame_num 2
2025-06-19 02:07:05,192 - INFO - Processing patient 74, frame 2
2025-06-19 02:07:08,023 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.100\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:07:08,108 - INFO - Prediction successful for frame_num 3
2025-06-19 02:07:08,109 - INFO - Processing patient 74, frame 3
2025-06-19 02:07:10,933 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first5_lambda_0.100\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:07:11,020 - INFO - Prediction s


Evaluating voxelmorph_both_masks (kernel=first7_second5, λ=0.016)
Progress: 8/9


2025-06-19 02:08:49,033 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:08:50,402 - INFO - Prediction successful for frame_num 0
2025-06-19 02:08:50,403 - INFO - Processing patient 74, frame 0
2025-06-19 02:09:10,180 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.016\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:09:10,361 - INFO - Prediction successful for frame_num 2
2025-06-19 02:09:10,362 - INFO - Processing patient 74, frame 2
2025-06-19 02:09:29,673 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.016\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:09:29,823 - INFO - Prediction successful for frame_num 3
2025-06-19 02:09:29,824 - INFO - Processing patient 74, frame 3
2025-06-19 02:09:46,005 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.016\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:09:46,1


Evaluating voxelmorph_both_masks (kernel=first7_second5, λ=0.033)
Progress: 9/10


2025-06-19 02:13:59,437 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:14:00,987 - INFO - Prediction successful for frame_num 0
2025-06-19 02:14:00,988 - INFO - Processing patient 74, frame 0
2025-06-19 02:14:04,374 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.033\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:14:04,471 - INFO - Prediction successful for frame_num 2
2025-06-19 02:14:04,471 - INFO - Processing patient 74, frame 2
2025-06-19 02:14:07,332 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.033\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:14:07,428 - INFO - Prediction successful for frame_num 3
2025-06-19 02:14:07,429 - INFO - Processing patient 74, frame 3
2025-06-19 02:14:10,284 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.033\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:14:10,3


Evaluating voxelmorph_both_masks (kernel=first7_second5, λ=0.066)
Progress: 10/11


2025-06-19 02:15:20,883 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:15:22,231 - INFO - Prediction successful for frame_num 0
2025-06-19 02:15:22,232 - INFO - Processing patient 74, frame 0
2025-06-19 02:15:25,111 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.066\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:15:25,209 - INFO - Prediction successful for frame_num 2
2025-06-19 02:15:25,210 - INFO - Processing patient 74, frame 2
2025-06-19 02:15:28,089 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.066\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:15:28,191 - INFO - Prediction successful for frame_num 3
2025-06-19 02:15:28,192 - INFO - Processing patient 74, frame 3
2025-06-19 02:15:31,062 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.066\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:15:31,1


Evaluating voxelmorph_both_masks (kernel=first7_second5, λ=0.1)
Progress: 11/12


2025-06-19 02:16:42,257 - INFO - ✅ Successfully loaded model (attempt 1)
2025-06-19 02:16:43,652 - INFO - Prediction successful for frame_num 0
2025-06-19 02:16:43,652 - INFO - Processing patient 74, frame 0
2025-06-19 02:16:46,444 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.100\results\Visualizing_Simulated\frame_0.png
2025-06-19 02:16:46,537 - INFO - Prediction successful for frame_num 2
2025-06-19 02:16:46,538 - INFO - Processing patient 74, frame 2
2025-06-19 02:16:49,978 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.100\results\Visualizing_Simulated\frame_2.png
2025-06-19 02:16:50,075 - INFO - Prediction successful for frame_num 3
2025-06-19 02:16:50,075 - INFO - Processing patient 74, frame 3
2025-06-19 02:16:52,985 - INFO - Saved visualization: ./data\Models\voxelmorph_both_masks_kernel_first7_second5_lambda_0.100\results\Visualizing_Simulated\frame_3.png
2025-06-19 02:16:53,0


visualization SUMMARY
Total models: 12
Successful: 12
Failed: 0
Success rate: 100.0%
Total time: 19.9 minutes
Average time per model: 99.3 seconds
