## Setup and Imports

In [None]:
import os
from pathlib import Path

import itk
import numpy as np
import pyvista as pv
from itk import TubeTK as ttk

# Import from PhysioMotion4D package
from physiomotion4d import (
    ContourTools,
    HeartModelToPatientWorkflow,
    SegmentChestTotalSegmentator,
)


## Define File Paths

In [None]:
# All inputs live under data/ in the working directory's grandparent.
data_root = Path.cwd().parent / '..' / 'data'

# Patient CT image (defines the fixed coordinate frame) and its hand-edited heart mask
patient_data_dir = data_root / 'Slicer-Heart-CT'
patient_ct_path = patient_data_dir / 'slice_007.mha'
patient_ct_heart_mask_path = patient_data_dir / 'slice_007_heart_mask3.nii.gz'

# Atlas template model (moving). Note: the file is a .vtk file despite the
# `_vtu_` in the variable name.
atlas_data_dir = data_root / 'KCL-Heart-Model'
atlas_vtu_path = atlas_data_dir / 'average_mesh.vtk'

# PCA shape model shipped alongside the atlas (read via set_pca_data_from_slicersalt)
pca_data_dir = atlas_data_dir / 'pca'
pca_json_path = pca_data_dir / 'pca.json'
pca_group_key = 'All'
pca_n_modes = 10

# Output directory; created (with any missing parents) if it does not exist
output_dir = Path.cwd() / 'results'
output_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Load the patient CT (the fixed image) and keep a compressed copy of the
# unmodified input next to the other results.
fixed_image = itk.imread(str(patient_ct_path))
patient_image_out = output_dir / 'patient_image.mha'
itk.imwrite(fixed_image, str(patient_image_out), compression=True)

In [None]:
# Heart mask for the patient CT.
# With run_segmentation = True, a draft mask is generated with TotalSegmentator
# and written to output_dir/fixed_mask_draft.mha. That draft is hand-edited
# offline and saved as patient_ct_heart_mask_path
# (slice_007_heart_mask3.nii.gz in patient_data_dir), which is what is loaded
# by default.
run_segmentation = False

if run_segmentation:
    segmentator = SegmentChestTotalSegmentator()
    segmentator.contrast_threshold = 500
    fixed_data = segmentator.segment(fixed_image, contrast_enhanced_study=False)
    labelmap_image = fixed_data["labelmap"]
    lung_mask_image = fixed_data["lung"]
    heart_mask_image = fixed_data["heart"]
    major_vessels_mask_image = fixed_data["major_vessels"]
    bone_mask_image = fixed_data["bone"]
    soft_tissue_mask_image = fixed_data["soft_tissue"]
    other_mask_image = fixed_data["other"]
    contrast_mask_image = fixed_data["contrast"]

    itk.imwrite(labelmap_image, str(output_dir / 'fixed_labelmap.mha'), compression=True)

    # Only the heart label is used. To also include the contrast region, build
    # the mask from ((heart_arr + itk.GetArrayFromImage(contrast_mask_image)) > 0).
    heart_arr = itk.GetArrayFromImage(heart_mask_image)
    mask_arr = (heart_arr > 0).astype(np.uint8)
    fixed_mask_image = itk.GetImageFromArray(mask_arr)
    # The array carries no geometry; copy origin/spacing/direction from the CT.
    fixed_mask_image.CopyInformation(fixed_image)

    itk.imwrite(fixed_mask_image, str(output_dir / 'fixed_mask_draft.mha'), compression=True)
else:
    fixed_mask_image = itk.imread(str(patient_ct_heart_mask_path))

In [None]:
# Reorient the fixed image and mask toward an identity direction matrix.
# Each axis whose diagonal direction cosine (read from the mask) is negative
# is flipped about the physical origin in BOTH images. The direction is then
# overwritten with identity. Results are saved for inspection.
mask_direction = np.array(fixed_mask_image.GetDirection())
flip0, flip1, flip2 = (mask_direction[axis, axis] < 0 for axis in range(3))


def _flip_and_reset_direction(image, flip_axes, direction):
    """Flip ``image`` about the origin along ``flip_axes``, then set ``direction``."""
    flipper = itk.FlipImageFilter.New(Input=image)
    flipper.SetFlipAxes(flip_axes)
    flipper.SetFlipAboutOrigin(True)
    flipper.Update()
    flipped = flipper.GetOutput()
    flipped.SetDirection(direction)
    return flipped


if any((flip0, flip1, flip2)):
    axes_to_flip = [int(flip0), int(flip1), int(flip2)]
    identity_direction = itk.Matrix[itk.D, 3, 3]()
    identity_direction.SetIdentity()

    print("Flipping fixed image...")
    print(flip0, flip1, flip2)
    fixed_image = _flip_and_reset_direction(fixed_image, axes_to_flip, identity_direction)
    itk.imwrite(fixed_image, str(output_dir / 'fixed_image.mha'), compression=True)

    print("Flipping fixed mask image...")
    fixed_mask_image = _flip_and_reset_direction(fixed_mask_image, axes_to_flip, identity_direction)
    itk.imwrite(fixed_mask_image, str(output_dir / 'fixed_mask.mha'), compression=True)


In [None]:

# Fixed surface: contours of the patient heart mask. The mesh is saved and
# re-read, so the in-memory object is exactly what was written to disk.
fixed_mesh_file = str(output_dir / 'fixed_mesh.vtp')
ContourTools().extract_contours(fixed_mask_image).save(fixed_mesh_file)
fixed_mesh = pv.read(fixed_mesh_file)

# Moving surface: outer surface of the atlas template mesh, round-tripped the same way.
moving_original_mesh = pv.read(str(atlas_vtu_path))
moving_mesh_file = str(output_dir / 'moving_mesh.vtp')
moving_original_mesh.extract_surface().save(moving_mesh_file)
moving_mesh = pv.read(moving_mesh_file)

In [None]:
# Configure the heart-model-to-patient workflow. The atlas surface is the
# moving mesh; the patient CT and its heart-mask surface are the fixed targets.
registrar = HeartModelToPatientWorkflow(
    moving_mesh=moving_mesh,
    fixed_image=fixed_image,
    fixed_meshes=[fixed_mesh],
)

# Only the fixed side is given a mask; no moving mask is supplied here.
registrar.set_masks(moving_mask_image=None, fixed_mask_image=fixed_mask_image)

# Dilation margins (mm) for the masks and for the surrounding ROI.
registrar.set_mask_dilation_mm(5)
registrar.set_roi_dilation_mm(25)

# PCA shape model loaded from the SlicerSALT JSON export.
registrar.set_pca_data_from_slicersalt(
    json_filename=pca_json_path,
    group_key=pca_group_key,
    n_modes=pca_n_modes,
)

# Use (and save) the registrar's copy of the fixed image; it may differ from
# the input if the workflow preprocesses it (not visible here).
fixed_image = registrar.fixed_image
itk.imwrite(fixed_image, str(output_dir / 'fixed_image.mha'), compression=True)

In [None]:
# Rough alignment of the atlas surface to the patient surface using ICP.
reg_results = registrar.register_mesh_to_mesh_icp()
icp_phi_FM, icp_phi_MF = reg_results['phi_FM'], reg_results['phi_MF']
moving_icp_mesh = reg_results['moving_mesh']
moving_icp_mask_image, moving_icp_mask_roi_image = (
    reg_results['moving_mask_image'],
    reg_results['moving_mask_roi_image'],
)

# Masks and ROIs as held by the registrar after this step.
fixed_roi_image = registrar.fixed_mask_roi_image
moving_mask_image = registrar.moving_mask_image
moving_roi_image = registrar.moving_mask_roi_image

# Save masks for inspection
for mask_img, file_name in (
    (moving_mask_image, 'moving_mask.mha'),
    (moving_roi_image, 'moving_roi.mha'),
    (fixed_roi_image, 'fixed_roi.mha'),
):
    itk.imwrite(mask_img, str(output_dir / file_name), compression=True)

print("New center =", moving_icp_mesh.center)
print(" Rough alignment using ICP completed.")

for mask_img, file_name in (
    (moving_icp_mask_image, "moving_icp_mask.nii.gz"),
    (moving_icp_mask_roi_image, "moving_icp_mask_roi.nii.gz"),
):
    itk.imwrite(mask_img, str(output_dir / file_name), compression=True)
moving_icp_mesh.save(str(output_dir / "moving_icp_mesh.vtp"))

In [None]:
# For verbose PhysioMotion4D logging during this step, uncomment:
# import logging
# from physiomotion4d import PhysioMotion4DBase
# PhysioMotion4DBase.set_log_level(logging.DEBUG)

# Refine the alignment using the PCA shape model.
reg_results = registrar.register_mesh_to_mesh_pca()
pca_rigid_transform, pca_coefficients = (
    reg_results['pre_phi_FM'],
    reg_results['pca_coefficients_FM'],
)
moving_pca_mesh = reg_results['moving_mesh']
moving_pca_mask_image = reg_results["moving_mask_image"]
moving_pca_mask_roi_image = reg_results["moving_mask_roi_image"]

# Save the PCA-stage masks and mesh.
for mask_img, file_name in (
    (moving_pca_mask_image, "moving_pca_mask.nii.gz"),
    (moving_pca_mask_roi_image, "moving_pca_mask_roi.nii.gz"),
):
    itk.imwrite(mask_img, str(output_dir / file_name), compression=True)
moving_pca_mesh.save(str(output_dir / "moving_pca_mesh.vtp"))


## Deformable Mask Alignment (Mask-to-Mask and Mask-to-Image)

In [None]:
# Deformable mask-to-mask registration.
print("Starting deformable mask-to-mask registration...")

reg_results = registrar.register_mask_to_mask()
m2m_phi_FM = reg_results['phi_FM']
m2m_phi_MF = reg_results['phi_MF']
moving_m2m_mesh, moving_m2m_mask_image, moving_m2m_mask_roi_image = (
    reg_results['moving_mesh'],
    reg_results['moving_mask_image'],
    reg_results['moving_mask_roi_image'],
)

print("Registration complete!")

# Write both transform directions (.hdf), then the warped masks and mesh.
for transform, file_name in (
    (m2m_phi_FM, "m2m_phi_FM.hdf"),
    (m2m_phi_MF, "m2m_phi_MF.hdf"),
):
    itk.transformwrite([transform], str(output_dir / file_name), compression=True)
for mask_img, file_name in (
    (moving_m2m_mask_image, "moving_m2m_mask.nii.gz"),
    (moving_m2m_mask_roi_image, "moving_m2m_mask_roi.nii.gz"),
):
    itk.imwrite(mask_img, str(output_dir / file_name), compression=True)
moving_m2m_mesh.save(str(output_dir / "moving_m2m_mesh.vtp"))

In [None]:
# Icon-based deformable mask-to-image registration.
# This is the most computationally intensive step (requires GPU).
print("Starting deformable registration...")
print("This may take several minutes depending on GPU availability.")

reg_results = registrar.register_mask_to_image()
m2i_phi_FM = reg_results['phi_FM']
m2i_phi_MF = reg_results['phi_MF']
moving_m2i_mesh, moving_m2i_mask_image, moving_m2i_mask_roi_image = (
    reg_results['moving_mesh'],
    reg_results['moving_mask_image'],
    reg_results['moving_mask_roi_image'],
)

print("\nRegistration complete!")

# Write both transform directions (.hdf), then the warped masks and mesh.
for transform, file_name in (
    (m2i_phi_FM, "m2i_phi_FM.hdf"),
    (m2i_phi_MF, "m2i_phi_MF.hdf"),
):
    itk.transformwrite([transform], str(output_dir / file_name), compression=True)
for mask_img, file_name in (
    (moving_m2i_mask_image, "moving_m2i_mask.nii.gz"),
    (moving_m2i_mask_roi_image, "moving_m2i_mask_roi.nii.gz"),
):
    itk.imwrite(mask_img, str(output_dir / file_name), compression=True)
moving_m2i_mesh.save(str(output_dir / "moving_m2i_mesh.vtp"))

In [None]:
def _as_itk_point(x, y, z):
    """Return a new ``itk.Point[itk.D, 3]`` holding the coordinates (x, y, z)."""
    point = itk.Point[itk.D, 3]()
    point[0], point[1], point[2] = x, y, z
    return point


# Warp a copy of the registrar's surface mesh through every stage, one vertex
# at a time, in this order: ICP, PCA model, mask-to-mask, then mask-to-image.
# Each stage receives a fresh itk.Point built from the previous stage's output.
moving_registered_surface_mesh = registrar.moving_mesh.copy(deep=True)
new_points = moving_registered_surface_mesh.points
for row in range(new_points.shape[0]):
    current = _as_itk_point(
        float(new_points[row, 0]),
        float(new_points[row, 1]),
        float(new_points[row, 2]),
    )
    for apply_stage in (
        registrar.icp_phi_FM.TransformPoint,
        registrar.registrar_pca.transform_point,
        registrar.m2m_phi_FM.TransformPoint,
        registrar.m2i_phi_FM.TransformPoint,
    ):
        warped = apply_stage(current)
        current = _as_itk_point(warped[0], warped[1], warped[2])
    new_points[row, 0], new_points[row, 1], new_points[row, 2] = current[0], current[1], current[2]

moving_registered_surface_mesh.points = new_points
moving_registered_surface_mesh.save(str(output_dir / "moving_registered_surface_mesh.vtp"))

In [None]:
# Apply the registration transforms to the original atlas mesh; include_m2i=True
# presumably adds the mask-to-image stage as well (confirm in the workflow docs).
new_original_mesh = registrar.apply_transforms_to_original_mesh(include_m2i=True)
registered_original_file = output_dir / "moving_registered_original_mesh.vtu"
new_original_mesh.save(str(registered_original_file))

## Visualize Final Results

In [None]:
# Meshes held by the registrar.
# Note: `moving_mesh` is intentionally NOT rebound here. It must stay the
# extracted surface passed to HeartModelToPatientWorkflow; otherwise re-running
# the setup cell would hand the workflow the original atlas mesh instead.
aligned_mesh = registrar.moving_icp_mesh
registered_mesh = registrar.moving_m2i_mesh
fixed_mesh = registrar.fixed_mesh

# Side-by-side comparison with linked cameras
plotter = pv.Plotter(shape=(1, 2))

# Left: after rough (ICP) alignment
plotter.subplot(0, 0)
plotter.add_mesh(fixed_mesh, color='red', opacity=1.0, label='Patient')
plotter.add_mesh(aligned_mesh, color='green', opacity=0.6, label='After ICP')
plotter.add_legend()  # without this the label= arguments are never displayed
plotter.add_title('Rough Alignment')

# Right: after deformable registration
plotter.subplot(0, 1)
plotter.add_mesh(fixed_mesh, color='red', opacity=0.6, label='Patient')
plotter.add_mesh(registered_mesh, color='blue', opacity=0.6, label='Registered')
plotter.add_legend()
plotter.add_title('Final Registration')

plotter.link_views()
plotter.show()

## Visualize Deformation Magnitude

In [None]:
# Show and summarize per-point deformation magnitude on the mask-to-image
# result mesh, if that array is present in its point data.
if 'DeformationMagnitude' in moving_m2i_mesh.point_data:
    plotter = pv.Plotter()
    plotter.add_mesh(
        moving_m2i_mesh,
        scalars='DeformationMagnitude',
        cmap='jet',
        show_scalar_bar=True,
        scalar_bar_args={'title': 'Deformation (mm)'}
    )
    plotter.add_title('Deformation Magnitude')
    plotter.show()

    # Read from the same mesh and the same (point) data that was checked above.
    # `mesh[name]` could resolve to another object or to cell data instead.
    deformation = moving_m2i_mesh.point_data['DeformationMagnitude']
    print("Deformation statistics:")
    print(f"  Min: {deformation.min():.2f} mm")
    print(f"  Max: {deformation.max():.2f} mm")
    print(f"  Mean: {deformation.mean():.2f} mm")
    print(f"  Std: {deformation.std():.2f} mm")
else:
    print("DeformationMagnitude not found in mesh point data")