In [3]:
import sys
import os
from pathlib import Path

# Make project-level modules (e.g. utils, linear_registration) importable from this notebook.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard so re-running this cell does not keep appending duplicate entries to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)

# Data locations, relative to the notebook's working directory.
GNDTRUTH = Path('../data/raw/COMMON_images_masks/')
RAWIMGS = Path('../data/raw/GROUP_images/')
DATA = RAWIMGS.parent.parent  # ../data
MASKS = DATA / 'masks'
LINREGS = DATA / 'linear_registrations'

import SimpleITK as sitk
from utils import *
from linear_registration import register_linear

In [4]:
# Moving (group) images and the common reference images they are registered to.
mov_names = ['g1_54', 'g1_55', 'g1_56']
fix_names = ['common_40', 'common_41', 'common_42']
fix_imgs = [sitk.ReadImage(GNDTRUTH/(name+'_image.nii.gz')) for name in fix_names]
fix_msks = [sitk.ReadImage(MASKS/(name+'_regmask.nii.gz')) for name in fix_names]
mov_imgs = [sitk.ReadImage(RAWIMGS/(name+'_image.nii.gz')) for name in mov_names]
mov_msks = [sitk.ReadImage(MASKS/(name+'_mask.nii.gz')) for name in mov_names]

# NumPy arrays of every loaded volume (SimpleITK reverses the axis order: z, y, x).
# Each list is built from its own image list, so the fixed and moving lists may differ
# in length without arrays being silently dropped or an IndexError being raised.
fix_imgs_arr = [sitk.GetArrayFromImage(im) for im in fix_imgs]
fix_msks_arr = [sitk.GetArrayFromImage(im) for im in fix_msks]
mov_imgs_arr = [sitk.GetArrayFromImage(im) for im in mov_imgs]
mov_msks_arr = [sitk.GetArrayFromImage(im) for im in mov_msks]

In [5]:
def est_lin_transf_multires(im_ref, im_mov, mask_ref=None, mask_mov=None, verbose=False,
                            sampling_seed=None):
    """
    Estimate an affine transform that aligns ``im_mov`` to ``im_ref``.

    Uses a 3-level multi-resolution pyramid, Mattes mutual information with random
    voxel sampling, optional metric masks and a gradient-descent optimizer.

    Parameters
    ----------
    im_ref : sitk.Image
        Fixed (reference) image.
    im_mov : sitk.Image
        Moving image.
    mask_ref, mask_mov : sitk.Image, optional
        Masks handed to ``SetMetricFixedMask`` / ``SetMetricMovingMask`` to restrict
        where the metric is evaluated.
    verbose : bool
        Print the final transform and optimizer diagnostics.
    sampling_seed : int, optional
        Seed for the random metric sampling. ``None`` (default) keeps SimpleITK's
        wall-clock seed, so results differ between runs. Pass an int to make the
        sampling deterministic.

    Returns
    -------
    sitk.Transform
        The optimized transform. Because the initial transform is set with
        ``inPlace=False``, this is a composite transform wrapping the affine.
    """
    # Cast once and reuse the float32 images everywhere. Besides avoiding a second cast
    # at Execute time, this guarantees the initializer receives two images of the same
    # pixel type even when im_ref and im_mov differ.
    fixed = sitk.Cast(im_ref, sitk.sitkFloat32)
    moving = sitk.Cast(im_mov, sitk.sitkFloat32)

    # Initial affine that aligns the geometric centers of the two images.
    init_transform = sitk.CenteredTransformInitializer(
        fixed,
        moving,
        sitk.AffineTransform(fixed.GetDimension()),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )

    reg = sitk.ImageRegistrationMethod()
    # inPlace=False: init_transform is not modified; Execute returns a new transform.
    reg.SetInitialTransform(init_transform, inPlace=False)

    # --- Metric ---
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    if sampling_seed is None:
        # SimpleITK default seed (wall clock): sampled voxels change from run to run.
        reg.SetMetricSamplingPercentage(0.2)  # 20% voxels
    else:
        reg.SetMetricSamplingPercentage(0.2, sampling_seed)  # 20% voxels, fixed seed

    if mask_ref is not None:
        reg.SetMetricFixedMask(mask_ref)
    if mask_mov is not None:
        reg.SetMetricMovingMask(mask_mov)

    # --- Multi-resolution pyramid ---
    reg.SetShrinkFactorsPerLevel([4, 2, 1])
    reg.SetSmoothingSigmasPerLevel([2, 1, 0])
    reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # --- Optimizer ---
    reg.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=500,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10
    )
    # Balance rotation/scale vs. translation parameters by their physical effect.
    reg.SetOptimizerScalesFromPhysicalShift()

    # --- Interpolator ---
    reg.SetInterpolator(sitk.sitkLinear)

    # --- Execute ---
    final_transform = reg.Execute(fixed, moving)

    if verbose:
        print("Final transform:")
        print(final_transform)
        print("Optimizer stop condition:", reg.GetOptimizerStopConditionDescription())
        print("Iterations:", reg.GetOptimizerIteration())
        print("Final metric value:", reg.GetMetricValue())

    return final_transform

In [6]:
def _auto_bone_mask(im, label, hu_min, hu_max, closing_kernel_size, verbose):
    """Build a cleaned HU-threshold bone mask for ``im`` (used when no mask was supplied)."""
    if verbose: print(f'Creating bone mask for {label} image...')
    mask = mask_from_hu(im, hu_min=hu_min, hu_max=hu_max,
                        closing_kernel_size=closing_kernel_size, verbose=verbose)
    # Basic cleaning to remove some irrelevant highly attenuating elements.
    return clean_bone_mask(mask)


def register_linear_multires(fix_im, mov_im, fix_mask=None, mov_mask=None, verbose=False,
                             hu_min=185, hu_max=500, closing_kernel_size=10, side='r'):
    """
    Affinely register ``mov_im`` onto ``fix_im``, estimating the transform on ROI crops.

    Missing masks are generated by HU thresholding (``mask_from_hu`` + ``clean_bone_mask``).
    An image whose mask was generated this way also gets a lateral crop of its ROI.
    The affine is estimated on the cropped images, with the metric restricted to the
    cropped fixed mask only. It is then applied to the full-size moving image and
    moving mask, with ``fix_im`` passed as the reference to ``apply_lin_transf``.

    Parameters
    ----------
    fix_im, mov_im : sitk.Image
        Fixed and moving images.
    fix_mask, mov_mask : sitk.Image, optional
        ROI masks. Generated from HU thresholds when None.
    verbose : bool
        Print progress messages.
    hu_min, hu_max, closing_kernel_size : number
        Forwarded to ``mask_from_hu`` when a mask has to be generated.
    side : str
        Side passed to ``crop_lateral`` for images with auto-generated masks.

    Returns
    -------
    dict
        Always contains ``registered_image``, ``registered_mask`` and ``affine_tfm``.
        Contains ``fixed_mask`` (the uncropped generated mask) only when the fixed
        mask was auto-generated.
    """
    # A generated mask additionally needs a lateral crop to isolate one side of the ROI.
    fix_mask_generated = fix_mask is None
    mov_mask_generated = mov_mask is None
    if fix_mask_generated:
        fix_mask = _auto_bone_mask(fix_im, 'fixed', hu_min, hu_max, closing_kernel_size, verbose)
    if mov_mask_generated:
        mov_mask = _auto_bone_mask(mov_im, 'moving', hu_min, hu_max, closing_kernel_size, verbose)

    # Crop fixed and moving images to ROI around right pelvis and femur
    if verbose: print('Cropping fixed image to ROI...')
    fix_im_roi, fix_mask_roi = crop_roi_from_mask_multi(fix_im, fix_mask)
    if fix_mask_generated:
        fix_im_roi = crop_lateral(fix_im_roi, side)
        fix_mask_roi = crop_lateral(fix_mask_roi, side)
    if verbose: print('Cropping moving image to ROI...')
    # The moving ROI mask is not used: the metric is only masked on the fixed side.
    mov_im_roi, _ = crop_roi_from_mask_multi(mov_im, mov_mask)
    if mov_mask_generated:
        mov_im_roi = crop_lateral(mov_im_roi, side)

    if verbose: print('Estimating affine transformation...')
    lin_tfm = est_lin_transf_multires(fix_im_roi, mov_im_roi, mask_ref=fix_mask_roi, verbose=verbose)
    if verbose: print('Applying affine transformation...')
    mov_im_reg = apply_lin_transf(mov_im, lin_tfm, fix_im)
    # NOTE(review): a label mask should be resampled with nearest-neighbour interpolation;
    # confirm apply_lin_transf does so (or accepts an interpolator) for mov_mask.
    mov_mask_reg = apply_lin_transf(mov_mask, lin_tfm, fix_im)
    if verbose: print('Linear registration finished.')

    result = {
        'registered_image': mov_im_reg,
        'registered_mask': mov_mask_reg,
        'affine_tfm': lin_tfm,
    }
    if fix_mask_generated:
        # Expose the generated (uncropped) fixed mask so callers can reuse it.
        result['fixed_mask'] = fix_mask
    return result

In [7]:
# Register every moving image against one shared fixed reference, fix_imgs[fix_idx].
fix_idx = 1
registrations = []
for idx, (mov_im, mov_msk) in enumerate(zip(mov_imgs, mov_msks)):
    print(f'---- Image {idx} ----')
    reg_result = register_linear_multires(
        fix_imgs[fix_idx],
        mov_im,
        fix_mask=fix_msks[fix_idx],
        mov_mask=mov_msk,
        verbose=True,
    )
    registrations.append(reg_result)

---- Image 0 ----
Cropping fixed image to ROI...
Cropping moving image to ROI...
Estimating affine transformation...
Final transform:
itk::simple::CompositeTransform
 CompositeTransform (0x281546e0)
   RTTI typeinfo:   itk::CompositeTransform<double, 3u>
   Reference Count: 1
   Modified Time: 492586
   Debug: Off
   Object Name: 
   Observers: 
     none
   TransformQueue: 
   >>>>>>>>>
   AffineTransform (0x29541b40)
     RTTI typeinfo:   itk::AffineTransform<double, 3u>
     Reference Count: 1
     Modified Time: 492432
     Debug: Off
     Object Name: 
     Observers: 
       none
     Matrix: 
       0.979047 0.036998 -0.0295563 
       0.0682675 0.943021 -0.00916353 
       0.0163164 0.0112084 0.987607 
     Offset: [-17.6611, 13.5086, -409.865]
     Center: [-90.0312, -145.062, -348.697]
     Translation: [-10.8355, 18.8232, -408.638]
     Inverse: 
       1.02372 -0.0405239 0.0302612 
       -0.0742659 1.06325 0.00764277 
       -0.0160703 -0.0113973 1.01196 
     Singular: 0


In [8]:
# Split the per-image result dicts into parallel lists (same order as mov_imgs).
mov_imgs_reg = list(map(lambda res: res['registered_image'], registrations))
mov_msks_reg = list(map(lambda res: res['registered_mask'], registrations))
lin_tfms = list(map(lambda res: res['affine_tfm'], registrations))

In [9]:
# NumPy arrays of the registered images and masks, for visualization.
mov_imgs_reg_data = list(map(sitk.GetArrayFromImage, mov_imgs_reg))
mov_msks_reg_data = list(map(sitk.GetArrayFromImage, mov_msks_reg))

In [11]:
# Overlay the fixed reference with one registered moving image (index into mov_names).
# The trailing ';' suppresses the echoed return value (the wrapped plotting function's repr).
mov_idx = 2
show_interactive_overlay(fix_imgs_arr[fix_idx], mov_imgs_reg_data[mov_idx], 'sagital');

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function utils.visualization.show_sagital_overlay(fix_arr, mov_arr, slc)>