# Imports

In [1]:

import os
import sys
import time
import numpy as np
import torch

from utils_display import \
    numpy_overview, torch_overview, numpy2torch, torch2numpy, \
    load_nifti_to_array, process_dvf_greedy_oasis, reorient_OASIS_to_RAS, \
    mk_grid_img_3d, JDet, \
    SpatialTransformer, \
    get_slice_obj, \
    cmap_seg_oasis, norm_seg_oasis, \
    plot_registration_results

# Compute device: CPU by default.
# To warp the moving image and moving seg on a GPU instead, uncomment:
# os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# device = torch.device('cuda')
device = torch.device('cpu')

# Expected volume shape; the spatial transformer below is built for exactly this shape.
img_size = (160, 224, 192)

# Operators shared by every displacement field evaluated later in the notebook.
spatial_trans = SpatialTransformer(img_size)
spatial_trans = spatial_trans.to(device)

jdet_calculator = JDet()


# Inputs and Outputs

In [2]:
### Outputs
dir_results = './results'
os.makedirs(dir_results, exist_ok=True)


### load image data
dir_data = './data/OASIS'
fixed_id  = '0004'
moving_id = '0236'
fname_disp = f"disp_{fixed_id}_{moving_id}.nii.gz"


def load_oasis_volume(kind, subject_id):
    """Load `{kind}{subject_id}.nii.gz` from `dir_data`, reorient it to RAS and make it C-contiguous.

    Args:
        kind: file-name prefix, 'img' (intensity image) or 'seg' (segmentation).
        subject_id: OASIS subject id string, e.g. '0004'.

    Returns:
        numpy array holding the reoriented volume.
    """
    path = os.path.join(dir_data, f'{kind}{subject_id}.nii.gz')
    # NOTE(review): reorientation presumably returns a flipped/transposed view;
    # make it contiguous before handing it to torch.
    return np.ascontiguousarray(reorient_OASIS_to_RAS(load_nifti_to_array(path)))


# File names are derived from the ids, so changing fixed_id / moving_id loads the matching pair.
img_fixed = load_oasis_volume('img', fixed_id)
seg_fixed = load_oasis_volume('seg', fixed_id)
img_moving = load_oasis_volume('img', moving_id)
seg_moving = load_oasis_volume('seg', moving_id)

# Every volume must match the shape the SpatialTransformer was built for.
for _name, _vol in [('img_fixed', img_fixed), ('seg_fixed', seg_fixed),
                    ('img_moving', img_moving), ('seg_moving', seg_moving)]:
    assert _vol.shape == tuple(img_size), f"{_name} has shape {_vol.shape}, expected {img_size}"

# DEBUG
numpy_overview(img_fixed, 'img_fixed')
numpy_overview(seg_fixed, 'seg_fixed')
numpy_overview(img_moving, 'img_moving')
numpy_overview(seg_moving, 'seg_moving')

### convert to torch tensor
img_fixed_tensor = numpy2torch(img_fixed).to(device)
seg_fixed_tensor = numpy2torch(seg_fixed).to(device)
img_moving_tensor = numpy2torch(img_moving).to(device)
seg_moving_tensor = numpy2torch(seg_moving).to(device)

# DEBUG
torch_overview(img_fixed_tensor, 'img_fixed_tensor')
torch_overview(seg_fixed_tensor, 'seg_fixed_tensor')
torch_overview(img_moving_tensor, 'img_moving_tensor')
torch_overview(seg_moving_tensor, 'seg_moving_tensor')


### paths for displacement fields
dir_disp = './data/disp'

# (column title, displacement file-name suffix): one entry per compared method.
# Keeping title and file together guarantees the column titles stay aligned
# with the displacement fields they describe.
method_specs = [
    ('Greedy',           'bs0_greedy'),
    ('VoxelMorph',       'bs1_vxm'),
    ('TransMorph',       'bs2_tm'),
    ('VFA',              'bs3_vfa'),
    ('SITReg',           'bs4_sitreg'),
    ('(a)DP-Conv-1MF',   'fedda-1mf'),
    ('(a)DP-Conv-2MFC',  'fedda-2mfc'),
    ('(a)DP-Conv-3C',    'fedda-3c'),
    ('(b)DP-ConvIC-1MF', 'feddb-1mf'),
    ('(b)DP-ConvIC-3C',  'feddb-3c'),
    ('(c)DP-VFA',        'feddc'),
]

list_method_titles = [title for title, _ in method_specs]

# e.g. ./data/disp/disp_0004_0236_bs0_greedy.nii.gz
list_path_disp = [
    os.path.join(dir_disp, f"disp_{fixed_id}_{moving_id}_{suffix}.nii.gz")
    for _, suffix in method_specs
]

list_all_titles = [
    'Fixed',
    'Moving',
] + list_method_titles

num_methods = len(list_method_titles)

print(f"Number of comparing methods: {num_methods}")

img_fixed: float64, 3D, shape=(160, 224, 192), min=0.0, max=0.8627451062202454
seg_fixed: float64, 3D, shape=(160, 224, 192), min=0.0, max=35.0
img_moving: float64, 3D, shape=(160, 224, 192), min=0.0, max=0.8549019694328308
seg_moving: float64, 3D, shape=(160, 224, 192), min=0.0, max=35.0
img_fixed_tensor: torch.float32, 5D, size=torch.Size([1, 1, 160, 224, 192]), device=cpu, min=0.0, max=0.8627451062202454
seg_fixed_tensor: torch.float32, 5D, size=torch.Size([1, 1, 160, 224, 192]), device=cpu, min=0.0, max=35.0
img_moving_tensor: torch.float32, 5D, size=torch.Size([1, 1, 160, 224, 192]), device=cpu, min=0.0, max=0.8549019694328308
seg_moving_tensor: torch.float32, 5D, size=torch.Size([1, 1, 160, 224, 192]), device=cpu, min=0.0, max=35.0
Number of comparing methods: 11


# Display
### Known mismatches with ITK-SNAP (orientation flips / slice numbering) — TODO: fix later
1. The coronal view seems to be off by one slice.
2. In the sagittal view, the slice order is flipped.

In [3]:
# Local helper: orient every 2D slice the same way before it is plotted
def process_slice(slice_data):
    """
    Orient a slice for display: transpose it, then mirror it left-to-right.

    Args:
        slice_data: numpy array holding a 2D slice

    Returns:
        numpy array: the transposed slice with its columns reversed
    """
    transposed = np.transpose(slice_data)
    return np.fliplr(transposed)


In [4]:

# Define the configurations for each axis
axis_configs = [
    {'axis': 0, 'slc_idx_func': lambda shape: shape[0]//2 - 10,  'description': 'sag'}, # offset to avoid the gap in central brain
    {'axis': 1, 'slc_idx_func': lambda shape: shape[1]//2,       'description': 'cor'},
    {'axis': 2, 'slc_idx_func': lambda shape: shape[2]//2,       'description': 'axi'},
]

# --- Figure settings shared by all views ---
subplot_width = 3
fontscaler = subplot_width * 0.5
grid_step = 8  # spacing (voxels) of the synthetic grid warped to visualise each deformation


def _subplot_aspect_ratio(axis, shape):
    """Return the width/height ratio of a subplot showing a slice taken along `axis`.

    The shape indices are reversed relative to the slice axes because
    `process_slice` transposes every slice before plotting.
    """
    ratios = {
        0: shape[2] / shape[1],
        1: shape[2] / shape[0],
        2: shape[1] / shape[0],
    }
    return ratios.get(axis, 1.0)  # default for unexpected axes


def _new_view(config, np_fixed, np_fixed_seg, np_moving, np_moving_seg, np_diff_moving):
    """Create the per-view state for one axis config, pre-filled with the Fixed/Moving columns.

    Returns a dict with the axis, slice index, slice object and one list of
    processed 2D slices per plotted row.
    """
    axis = config['axis']
    slc_idx = config['slc_idx_func'](img_size)
    slice_obj = get_slice_obj(np_fixed, axis, slc_idx)  # reused for every volume in this view
    return {
        'axis': axis,
        'slc_idx': slc_idx,
        'description': config['description'],
        'slice_obj': slice_obj,
        # Image/label rows start with the Fixed and Moving columns;
        # the diff row starts with Moving - Fixed.
        'images': [process_slice(np_fixed[slice_obj]), process_slice(np_moving[slice_obj])],
        'labels': [process_slice(np_fixed_seg[slice_obj]), process_slice(np_moving_seg[slice_obj])],
        'diffs': [process_slice(np_diff_moving[slice_obj])],
        'grids': [],
        'jdets': [],
    }


def _append_placeholders(view):
    """Append all-zero slices for a missing displacement field so the columns stay aligned."""
    dummy_slice_shape = view['images'][0].shape
    for key in ('images', 'labels', 'diffs', 'grids', 'jdets'):
        view[key].append(np.zeros(dummy_slice_shape))


def _load_disp(path_disp):
    """Load a displacement field from NIfTI and return it as a tensor on `device`."""
    disp_np = load_nifti_to_array(path_disp)
    if 'greedy' in path_disp:  # Assuming 'greedy' is a unique identifier in the path
        disp_np = process_dvf_greedy_oasis(disp_np, channel_last=False)  # special orientation handling for greedy
    return numpy2torch(disp_np, device=device)


def _compute_jdet_padded(disp, path_disp):
    """Return the Jacobian determinant of `disp`, padded by 2 voxels per side with the value 1.

    Prints the error and returns None if the calculation fails (best effort,
    so one bad field does not abort the whole figure).
    """
    # jdet_calculator expects (Batch, Channels, D, H, W); the batch dim is kept.
    # NOTE(review): the 2-voxel padding assumes jdet_calculator returns a 3D map
    # cropped by 2 voxels per side. Confirm against utils_display.JDet.
    try:
        jacdet, non_pos_jacdet, stdjacdet = jdet_calculator(disp.cpu().numpy())
        return np.pad(jacdet, pad_width=((2, 2), (2, 2), (2, 2)), mode='constant', constant_values=1)
    except Exception as e:
        print(f"Error in Jdet calculation for {path_disp}: {e}")
        return None


def _jdet_slice(jacdet_padded, slice_obj, placeholder_like, path_disp):
    """Return the processed Jdet slice, or zeros shaped like `placeholder_like` if it is unavailable."""
    if jacdet_padded is not None:
        try:
            return process_slice(jacdet_padded[slice_obj])
        except Exception as e:
            print(f"Error in Jdet calculation for {path_disp}: {e}")
    return np.zeros_like(placeholder_like)


def _build_data_rows(images, diffs, labels, grids, jdets):
    """Return the per-row data and display settings consumed by `plot_registration_results`.

    Every row uses origin='lower', so all slices (oriented by `process_slice`)
    are drawn with the same vertical orientation.
    """
    return [
        {
            'category': 'deformed',
            'content_description': 'Deformed Images',
            'row_title': 'Deformed',
            'data': images,
            'disp_settings': [{'plot_type': 'imshow', 'imshow_params': {'origin': 'lower', 'cmap': 'gray', 'vmin': 0, 'vmax': 0.75}}] * len(images),
        },
        {
            'category': 'diff',
            'content_description': 'Difference Maps (Deformed - Fixed)',
            'row_title': 'Difference',
            'data': diffs,
            'disp_settings': [{'plot_type': 'imshow', 'imshow_params': {'origin': 'lower', 'cmap': 'coolwarm', 'vmin': -0.1, 'vmax': 0.1}}] * len(diffs),
            'shared_row_colorbar': {
                'show': True,
                'label': 'Difference',
                'tick_values_normalized': [-0.10, -0.05, 0, 0.05, 0.10],
                'tick_labels_actual': ['-0.10', '-0.05', '0', '0.05', '0.10'],
            },
        },
        {
            'category': 'deformed',
            'content_description': 'Deformed Labels Slice',
            'row_title': 'Deformed \nLabels (Slice)',
            'data': labels,
            'disp_settings': [{'plot_type': 'imshow', 'imshow_params': {'origin': 'lower', 'cmap': cmap_seg_oasis, 'norm': norm_seg_oasis}}] * len(labels),
        },
        {
            'category': 'deformation',
            'content_description': 'Deformation grid backward',
            'row_title': 'Deformation \nGrid (backward)',
            'data': grids,
            # origin='lower' matches the other rows; without it the grid would be
            # drawn flipped if the plotting helper falls back to Matplotlib's default origin='upper'
            'disp_settings': [{'plot_type': 'imshow', 'imshow_params': {'origin': 'lower', 'cmap': 'gray', 'vmin': 0, 'vmax': 1}}] * len(grids),
        },
        {
            'category': 'deformation',
            'content_description': 'Deformation Jdet',
            'row_title': 'Deformation \nJdet',
            'data': jdets,
            'disp_settings': [{
                'plot_type': 'imshow',
                'imshow_params': {'origin': 'lower', 'cmap': 'bwr', 'vmin': -1, 'vmax': 3},
                'contour_overlay': {
                    'show': True, 'level': 0.0, 'color': 'black',
                    'linewidth': 0.7, 'linestyle': '-',
                }
            }] * len(jdets),
            'shared_row_colorbar': {
                'show': True,
                'label': 'Jacobian\ndeterminant',
                'tick_values_normalized': [-1, 0, 1, 2, 3],
                'tick_labels_actual': ['-1', '0', '1', '2', '3'],
            },
        },
    ]


# --- Fixed and Moving volumes (shared by all views) ---
np_fixed = torch2numpy(img_fixed_tensor)
np_fixed_seg = torch2numpy(seg_fixed_tensor)
np_moving = torch2numpy(img_moving_tensor)
np_moving_seg = torch2numpy(seg_moving_tensor)
np_diff_moving = np_moving - np_fixed

views = [
    _new_view(config, np_fixed, np_fixed_seg, np_moving, np_moving_seg, np_diff_moving)
    for config in axis_configs
]

# --- Process Deformations ---
# Each displacement field is loaded, warped and its Jdet computed ONCE, then
# sliced for every view. Only the synthetic grid depends on the view axis.
for path_disp in list_path_disp:

    # Add placeholders for missing disp file (to keep the plot structure consistent)
    if not os.path.exists(path_disp):
        print(f"Warning: Displacement field not found at {path_disp}. Skipping.")
        for view in views:
            _append_placeholders(view)
        continue

    disp = _load_disp(path_disp)
    np_deformed = torch2numpy(spatial_trans(img_moving_tensor, disp))
    np_deformed_diff = np_deformed - np_fixed
    np_deformed_seg = torch2numpy(spatial_trans(seg_moving_tensor, disp, mode='nearest'))
    jacdet_padded = _compute_jdet_padded(disp, path_disp)
    grid_size = tuple(disp.shape[-3:])

    for view in views:
        slice_obj = view['slice_obj']
        view['images'].append(process_slice(np_deformed[slice_obj]))
        view['diffs'].append(process_slice(np_deformed_diff[slice_obj]))
        view['labels'].append(process_slice(np_deformed_seg[slice_obj]))

        # The grid is generated for the view's slicing axis, so it is warped per view
        grid_axis = mk_grid_img_3d(grid_step=grid_step, grid_size=grid_size, axis_to_slice=view['axis'])
        np_deformed_grid_axis = torch2numpy(spatial_trans(grid_axis.to(device), disp))  # Ensure grid is on correct device
        view['grids'].append(process_slice(np_deformed_grid_axis[slice_obj]))

        # Must come after the grid append: the zero placeholder mirrors the grid slice
        view['jdets'].append(_jdet_slice(jacdet_padded, slice_obj, view['grids'][-1], path_disp))

# --- Plot one figure per view ---
for view in views:
    current_axis = view['axis']
    slc_idx = view['slc_idx']
    slc_description = view['description']

    suptitle_text = f"Fixed: {fixed_id}, Moving: {moving_id}, Axis: {current_axis}, Slice: {slc_idx+1}/{img_size[current_axis]}"
    fname = f"f{fixed_id}_m{moving_id}_axis{current_axis}_slc{str(slc_idx+1).zfill(3)}of{img_size[current_axis]}.png"
    output_filename = os.path.join(dir_results, fname)

    print(f"Plotting {slc_description} view (axis: {current_axis}), slice index: {slc_idx+1}/{img_size[current_axis]}, 3D shape {img_size}, output to {output_filename}")

    subplot_aspect_ratio = _subplot_aspect_ratio(current_axis, img_size)
    data_rows = _build_data_rows(view['images'], view['diffs'], view['labels'], view['grids'], view['jdets'])

    plot_registration_results(
        col_title_fontsize=16 * fontscaler,
        row_title_fontsize=16 * fontscaler,
        tick_label_fontsize=14 * fontscaler,
        suptitle_fontsize=20 * fontscaler,
        data_rows=data_rows,
        num_methods=num_methods,
        suptitle_text=suptitle_text,
        col_titles=list_all_titles,
        subplot_width=subplot_width,
        subplot_aspect_ratio=subplot_aspect_ratio,
        dpi=300,
        wspace_factor=0.1,
        hspace_factor=0.05,
        top_margin_factor=0.90,
        right_margin_factor=0.90,
        output_filename=output_filename,
        show_figure=False
    )


Plotting sag view (axis: 0), slice index: 71/160, 3D shape (160, 224, 192), output to ./results/f0004_m0236_axis0_slc071of160.png
Figure saved to ./results/f0004_m0236_axis0_slc071of160.png
Plotting cor view (axis: 1), slice index: 113/224, 3D shape (160, 224, 192), output to ./results/f0004_m0236_axis1_slc113of224.png
Figure saved to ./results/f0004_m0236_axis1_slc113of224.png
Plotting axi view (axis: 2), slice index: 97/192, 3D shape (160, 224, 192), output to ./results/f0004_m0236_axis2_slc097of192.png
Figure saved to ./results/f0004_m0236_axis2_slc097of192.png
