In [1]:
import sys
import os
from pathlib import Path

# Make the project root (parent of the notebook's working directory)
# importable. Guarded so re-running this cell does not keep appending
# duplicate entries to sys.path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
if project_root not in sys.path:
    sys.path.append(project_root)

# Data layout, relative to the notebook's working directory:
#   ../data/raw/COMMON_images_masks/  reference (fixed) images
#   ../data/raw/GROUP_images/         group images (moving cases)
#   ../data/masks/                    masks for both fixed and moving cases
DATA = Path('../data')
GNDTRUTH = DATA / 'raw' / 'COMMON_images_masks'
RAWIMGS = DATA / 'raw' / 'GROUP_images'
MASKS = DATA / 'masks'

import SimpleITK as sitk

In [2]:
from ipywidgets import interact, IntSlider, fixed
import matplotlib.pyplot as plt
import numpy as np

def show_coronal_slice(arr, slc, mask=None):
    """Display one slice along axis 1 of a volume, optionally with a label overlay.

    arr:  3-D array; the displayed image is arr[:, slc, :]. For arrays from
          sitk.GetArrayFromImage (ordered z, y, x) this is a coronal view.
    slc:  index along axis 1.
    mask: optional array indexed the same way as `arr`. Zero voxels are left
          transparent; non-zero values are drawn with the 'hsv' colormap over
          the fixed range [0, 5] at 40 % opacity.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(arr[:, slc, :], cmap='gray')

    if mask is not None:
        mask_plane = mask[:, slc, :]
        overlay = np.ma.masked_where(mask_plane == 0, mask_plane)
        ax.imshow(overlay, cmap='hsv', vmin=0, vmax=5, alpha=0.4)

    ax.axis('off')
    ax.set_title(f'Slice {slc}')
    plt.show()

def show_interactive(arr, fn, mask=None):
    """Drive a slice viewer `fn(arr, slc, mask)` with an ipywidgets slider.

    The slider spans every index of axis 1 of `arr` and starts at the middle.
    `arr` and `mask` are wrapped in `fixed` so only `slc` gets a widget.
    Returns whatever `interact` returns.
    """
    n_slices = arr.shape[1]
    slice_slider = IntSlider(min=0, max=n_slices - 1, step=1, value=n_slices // 2)
    return interact(fn, slc=slice_slider, arr=fixed(arr), mask=fixed(mask))

def show_coronal_overlay(fixed_img, moving_img, slc):
    """Overlay slice `slc` (axis 1) of two volumes for a visual alignment check.

    fixed_img is drawn first with the 'Blues' colormap; moving_img is drawn on
    top with 'Reds' at 30 % opacity. Both are indexed as [:, slc, :], so they
    are assumed to share that shape (e.g. moving image already resampled onto
    the fixed grid).
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    layers = ((fixed_img, 'Blues', None), (moving_img, 'Reds', 0.3))
    for volume, colormap, opacity in layers:
        ax.imshow(volume[:, slc, :], cmap=colormap, alpha=opacity)

    ax.axis('off')
    ax.set_title(f'Coronal slice {slc}')
    plt.show()

def show_interactive_overlay(fixed_img, moving_img):
    """Attach a slider over axis 1 of `fixed_img` to show_coronal_overlay.

    Both volumes are frozen with `fixed`; only `slc` is interactive, starting
    at the middle slice. Returns whatever `interact` returns.
    """
    depth = fixed_img.shape[1]
    slice_slider = IntSlider(min=0, max=depth - 1, step=1, value=depth // 2)
    return interact(
        show_coronal_overlay,
        slc=slice_slider,
        fixed_img=fixed(fixed_img),
        moving_img=fixed(moving_img),
    )

In [3]:
def bone_mask_from_hu(img, hu_min=206, closing_radius=(4, 4, 4)):
    """
    Build a binary bone mask from a CT image by thresholding in HU.

    img: SimpleITK image with CT intensities in Hounsfield units.
    hu_min: minimum HU value for a voxel to be counted as bone.
    closing_radius: kernel radius, in voxels, of the morphological closing
        applied after thresholding. Either one int per image axis or a single
        int used for every axis. The default (4, 4, 4) is the value that was
        previously hard-coded.
    returns: SimpleITK binary mask (UInt8, 1 = above threshold, 0 = background).
    """
    if isinstance(closing_radius, int):
        closing_radius = [closing_radius] * img.GetDimension()

    # Threshold HU and store the result as an 8-bit binary image.
    mask = sitk.Cast(img >= hu_min, sitk.sitkUInt8)
    # Bridge small gaps in the thresholded bone, then fill enclosed holes.
    mask = sitk.BinaryMorphologicalClosing(mask, [int(r) for r in closing_radius])
    mask = sitk.BinaryFillhole(mask)

    return mask

def est_lin_transf(im_ref, im_mov, mask_ref=None, mask_mov=None,
                   n_bins=50, n_iter=300, learning_rate=1.0, verbose=True):
    """
    Estimate an affine transform that aligns im_mov to im_ref.

    The transform is initialised by aligning the geometric centres of the two
    images. It is then optimised by gradient descent on Mattes mutual
    information, computed on float32 copies of the images. Optional masks
    restrict the metric to a region of interest (e.g. a bone mask).

    im_ref: fixed SimpleITK image.
    im_mov: moving SimpleITK image.
    mask_ref: optional mask in the fixed image's space (SetMetricFixedMask).
    mask_mov: optional mask in the moving image's space (SetMetricMovingMask).
    n_bins: number of histogram bins for Mattes mutual information.
    n_iter: maximum number of gradient-descent iterations.
    learning_rate: gradient-descent learning rate.
    verbose: if True, print the final transform and optimizer diagnostics.
    returns: the SimpleITK transform produced by ImageRegistrationMethod.Execute.
    """
    # Initial transform: align geometric image centres (no intensity information).
    init_transform = sitk.CenteredTransformInitializer(
        im_ref,
        im_mov,
        sitk.AffineTransform(im_ref.GetDimension()),
        sitk.CenteredTransformInitializerFilter.GEOMETRY
    )

    reg = sitk.ImageRegistrationMethod()
    # inPlace=False: init_transform is left untouched; Execute returns a new transform.
    reg.SetInitialTransform(init_transform, inPlace=False)

    # Metric, optionally restricted by masks
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=n_bins)
    if mask_ref is not None:
        reg.SetMetricFixedMask(mask_ref)
    if mask_mov is not None:
        reg.SetMetricMovingMask(mask_mov)

    # Optimizer. Scales are estimated from physical shifts so that the matrix and
    # translation parameters take comparably sized steps.
    reg.SetOptimizerAsGradientDescent(
        learningRate=learning_rate,
        numberOfIterations=n_iter,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    reg.SetInterpolator(sitk.sitkLinear)

    final_transform = reg.Execute(sitk.Cast(im_ref, sitk.sitkFloat32),
                                  sitk.Cast(im_mov, sitk.sitkFloat32))

    if verbose:
        print(final_transform)
        print("--------")
        print("Optimizer stop condition: {0}".format(reg.GetOptimizerStopConditionDescription()))
        print("Number of iterations: {0}".format(reg.GetOptimizerIteration()))
        print("--------")

    return final_transform

def apply_lin_transf(im_mov, lin_xfm, im_ref, is_mask=False, default_value=0):
    """
    Resample im_mov onto im_ref's grid using the transform lin_xfm.

    im_mov: image (or mask) to resample.
    lin_xfm: transform, e.g. as returned by est_lin_transf.
    im_ref: reference image defining the output grid (origin, spacing, size, direction).
    is_mask: if True, use nearest-neighbour interpolation so label values are
        not blended; otherwise use linear interpolation.
    default_value: value assigned to output voxels that map outside im_mov.
        Defaults to 0, the previously hard-coded value. For CT intensities in
        HU, 0 is water density; an air value such as -1024 avoids an
        artificial border.
    returns: resampled SimpleITK image with im_mov's pixel type.
    """
    interpolator = sitk.sitkNearestNeighbor if is_mask else sitk.sitkLinear
    return sitk.Resample(
        im_mov,
        im_ref,
        lin_xfm,
        interpolator,
        default_value,
        im_mov.GetPixelID()
    )


In [4]:
# Fixed case (common_40, from the ground-truth folder) and moving case
# (g1_54, from the group images), together with their masks.
fix_im = sitk.ReadImage(GNDTRUTH / 'common_40_image.nii.gz')
mov_im = sitk.ReadImage(RAWIMGS / 'g1_54_image.nii.gz')
fix_msk = sitk.ReadImage(MASKS / 'common_40_regmask.nii.gz')
mov_msk = sitk.ReadImage(MASKS / 'g1_54_mask.nii.gz')

# numpy arrays of the volumes, used for plotting
fix_im_data, fix_msk_data, mov_im_data, mov_msk_data = (
    sitk.GetArrayFromImage(volume)
    for volume in (fix_im, fix_msk, mov_im, mov_msk)
)


In [5]:
# Bone mask of the fixed image. 170 HU is used here instead of the
# function's default of 206.
fix_bonemask = bone_mask_from_hu(fix_im, 170)
fix_bonemask_data = sitk.GetArrayFromImage(fix_bonemask)
# Trailing ';' hides the repr of the function that `interact` returns.
show_interactive(fix_im_data, show_coronal_slice, fix_bonemask_data);

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.show_coronal_slice(arr, slc, mask=None)>

In [6]:
# Bone mask of the moving image, same threshold as for the fixed image.
mov_bonemask = bone_mask_from_hu(mov_im, 170)
mov_bonemask_data = sitk.GetArrayFromImage(mov_bonemask)
# Trailing ';' hides the repr of the function that `interact` returns.
show_interactive(mov_im_data, show_coronal_slice, mov_bonemask_data);

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.show_coronal_slice(arr, slc, mask=None)>

In [7]:
lin_tfm = est_lin_transf(im_ref=fix_im, im_mov=mov_im, mask_ref=fix_bonemask)  # affine g1_54 -> common_40, metric on fixed bone mask

itk::simple::CompositeTransform
 CompositeTransform (0x2238e730)
   RTTI typeinfo:   itk::CompositeTransform<double, 3u>
   Reference Count: 1
   Modified Time: 672581
   Debug: Off
   Object Name: 
   Observers: 
     none
   TransformQueue: 
   >>>>>>>>>
   AffineTransform (0x222dcac0)
     RTTI typeinfo:   itk::AffineTransform<double, 3u>
     Reference Count: 1
     Modified Time: 672571
     Debug: Off
     Object Name: 
     Observers: 
       none
     Matrix: 
       1.07289 -0.00984917 0.0167226 
       0.0674793 1.04891 -0.05118 
       -0.0296269 0.0202447 1.00511 
     Offset: [-39.3569, 50.221, -1597.27]
     Center: [1, -129, 842.698]
     Translation: [-23.9214, 0.849444, -1595.61]
     Inverse: 
       0.931083 0.00903287 -0.015031 
       -0.0585024 0.951865 0.0494422 
       0.0286232 -0.018906 0.993479 
     Singular: 0
   TransformsToOptimizeFlags: 
           1 
   TransformsToOptimizeQueue: 
   PreviousTransformsToOptimizeUpdateTime: 0

--------
Optimizer stop con

In [8]:
# Bring g1_54 onto the fixed image's grid with the estimated affine transform.
mov_im_resampled = apply_lin_transf(im_mov=mov_im, lin_xfm=lin_tfm, im_ref=fix_im)
mov_im_resampled_data = sitk.GetArrayFromImage(mov_im_resampled)

In [9]:
show_interactive_overlay(fix_im_data, mov_im_resampled_data);  # ';' hides the returned function repr

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.show_coronal_overlay(fixed_img, moving_img, slc)>

In [10]:
# Second moving case (g1_55). Its array gets its own name so that
# `mov_im_data` (the g1_54 array, shown with the g1_54 bone mask) is not
# silently replaced by a different patient's volume.
mov_im2 = sitk.ReadImage(RAWIMGS/'g1_55_image.nii.gz')
mov_im2_data = sitk.GetArrayFromImage(mov_im2)

In [11]:
lin_tfm2 = est_lin_transf(im_ref=fix_im, im_mov=mov_im2, mask_ref=fix_bonemask)  # affine g1_55 -> common_40, metric on fixed bone mask

itk::simple::CompositeTransform
 CompositeTransform (0x243b7960)
   RTTI typeinfo:   itk::CompositeTransform<double, 3u>
   Reference Count: 1
   Modified Time: 690614
   Debug: Off
   Object Name: 
   Observers: 
     none
   TransformQueue: 
   >>>>>>>>>
   AffineTransform (0x2231b430)
     RTTI typeinfo:   itk::AffineTransform<double, 3u>
     Reference Count: 1
     Modified Time: 690460
     Debug: Off
     Object Name: 
     Observers: 
       none
     Matrix: 
       1.0373 0.00880889 -0.0265391 
       0.0978759 1.02437 0.0276055 
       -0.0178563 -0.00768567 0.918477 
     Offset: [-4.62458, -16.2011, -1550.95]
     Center: [1, -129, 842.698]
     Translation: [-28.0881, 4.01599, -1618.67]
     Inverse: 
       0.965286 -0.00808972 0.0281348 
       -0.0927154 0.976766 -0.0320364 
       0.0179905 0.00801615 1.08904 
     Singular: 0
   TransformsToOptimizeFlags: 
           1 
   TransformsToOptimizeQueue: 
   PreviousTransformsToOptimizeUpdateTime: 0

--------
Optimizer st

In [12]:
# Bring g1_55 onto the fixed image's grid with its affine transform.
mov_im2_resampled = apply_lin_transf(im_mov=mov_im2, lin_xfm=lin_tfm2, im_ref=fix_im)
mov_im2_resampled_data = sitk.GetArrayFromImage(mov_im2_resampled)

In [13]:
show_interactive_overlay(fix_im_data, mov_im2_resampled_data);  # ';' hides the returned function repr

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.show_coronal_overlay(fixed_img, moving_img, slc)>

# Non-linear registration (Demons, then B-spline)

In [51]:
def est_nl_transf(im_ref, im_mov, mask_ref=None, n_iter=50, smoothing=1.0,
                  field_sigma=1.0, max_rms_error=0.01):
    """
    Estimate a non-linear (Demons) transform that aligns im_mov to im_ref.

    im_ref: fixed SimpleITK image.
    im_mov: moving SimpleITK image. In this notebook it is the image already
        affinely resampled onto im_ref's grid.
    mask_ref: ignored. The Demons filter is run without a mask; the parameter
        only keeps the call signature aligned with the B-spline variant.
    n_iter: number of Demons iterations.
    smoothing: sigma of the Gaussian pre-smoothing applied to both images.
        None or a value <= 0 disables pre-smoothing.
    field_sigma: standard deviation used by the filter to smooth the
        displacement field (default 1.0).
    max_rms_error: RMS-change threshold passed to SetMaximumRMSError (default 0.01).
    returns: sitk.DisplacementFieldTransform built from the Demons field.
    """
    # Demons works on real-valued intensities.
    im_ref_f = sitk.Cast(im_ref, sitk.sitkFloat32)
    im_mov_f = sitk.Cast(im_mov, sitk.sitkFloat32)

    # Optional pre-smoothing. A non-positive sigma is not a valid Gaussian
    # width, so it is treated as "no smoothing" instead of being passed on.
    if smoothing is not None and smoothing > 0:
        im_ref_f = sitk.SmoothingRecursiveGaussian(im_ref_f, smoothing)
        im_mov_f = sitk.SmoothingRecursiveGaussian(im_mov_f, smoothing)

    demons = sitk.DemonsRegistrationFilter()
    demons.SetNumberOfIterations(n_iter)
    demons.SetStandardDeviations(field_sigma)
    demons.SetSmoothDisplacementField(True)
    demons.SetMaximumRMSError(max_rms_error)

    disp_field = demons.Execute(im_ref_f, im_mov_f)

    return sitk.DisplacementFieldTransform(disp_field)

def apply_nl_transf(im_mov, nl_xfm, ref_img, is_mask=False, default_value=0):
    """
    Resample im_mov onto ref_img's grid through the transform nl_xfm.

    im_mov: image (or mask) to warp.
    nl_xfm: transform, e.g. as returned by est_nl_transf.
    ref_img: reference image defining the output grid.
    is_mask: if True, use nearest-neighbour interpolation so labels are not
        blended; otherwise use linear interpolation.
    default_value: value for output voxels that map outside im_mov (default 0,
        matching the filter's default and apply_lin_transf's previous behaviour).
    returns: resampled SimpleITK image.
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref_img)
    resampler.SetTransform(nl_xfm)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_mask else sitk.sitkLinear)
    resampler.SetDefaultPixelValue(default_value)

    return resampler.Execute(im_mov)

In [None]:
# Demons estimate, kept under its own name. The B-spline section below assigns
# `nl_tfm`, and sharing that name would let either result feed apply_nl_transf
# depending on which cells happened to run. (The Demons est_nl_transf does not
# use the mask argument.)
nl_tfm_demons = est_nl_transf(fix_im, mov_im2_resampled, fix_bonemask)

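## With a B-spline transform

The next cell redefines `est_nl_transf` and `apply_nl_transf`. From here on, those names refer to the B-spline versions, not the Demons ones above.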

In [18]:
def est_nl_transf(im_ref, im_mov, mask_ref=None, grid_spacing=40, n_iter=50, verbose=True):
    """
    Estimate a non-linear B-spline transform that aligns im_mov to im_ref.

    im_ref: fixed SimpleITK image; its extent defines the B-spline mesh domain.
    im_mov: moving SimpleITK image.
    mask_ref: optional fixed-image mask restricting the metric.
    grid_spacing: approximate control-point spacing, in the image's physical
        spacing units. Per axis the mesh has int(extent / grid_spacing) cells,
        with a minimum of one.
    n_iter: maximum number of L-BFGS-B iterations.
    verbose: if True, print the final transform and optimizer diagnostics.
    returns: the transform produced by ImageRegistrationMethod.Execute.
    """
    im_ref = sitk.Cast(im_ref, sitk.sitkFloat32)
    im_mov = sitk.Cast(im_mov, sitk.sitkFloat32)

    # Mesh size per axis from the physical extent of the fixed image. Clamp
    # to >= 1 so an axis shorter than grid_spacing does not yield an empty mesh.
    spacing = im_ref.GetSpacing()
    size = im_ref.GetSize()
    grid_size = [
        max(1, int(n_vox * vox_size / grid_spacing))
        for n_vox, vox_size in zip(size, spacing)
    ]

    initial_transform = sitk.BSplineTransformInitializer(
        image1=im_ref,
        transformDomainMeshSize=grid_size,
        order=3
    )

    reg = sitk.ImageRegistrationMethod()
    reg.SetInitialTransform(initial_transform, inPlace=False)

    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    if mask_ref is not None:
        reg.SetMetricFixedMask(mask_ref)

    reg.SetOptimizerAsLBFGSB(
        gradientConvergenceTolerance=1e-5,
        numberOfIterations=n_iter,
        maximumNumberOfCorrections=5,
        maximumNumberOfFunctionEvaluations=1000
    )

    reg.SetInterpolator(sitk.sitkLinear)

    final_transform = reg.Execute(im_ref, im_mov)

    if verbose:
        print(final_transform)
        print("--------")
        print("Optimizer stop condition: {0}".format(reg.GetOptimizerStopConditionDescription()))
        print("Number of iterations: {0}".format(reg.GetOptimizerIteration()))
        print("--------")

    return final_transform

def apply_nl_transf(im_mov, nl_xfm, ref_img, is_mask=False, default_value=0):
    """
    Resample im_mov onto ref_img's grid through the transform nl_xfm.

    im_mov: image (or mask) to warp.
    nl_xfm: transform, e.g. as returned by est_nl_transf.
    ref_img: reference image defining the output grid.
    is_mask: if True, use nearest-neighbour interpolation so labels are not
        blended; otherwise use linear interpolation.
    default_value: value for output voxels that map outside im_mov (default 0,
        matching the filter's default and apply_lin_transf's previous behaviour).
    returns: resampled SimpleITK image.
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(ref_img)
    resampler.SetTransform(nl_xfm)
    resampler.SetInterpolator(sitk.sitkNearestNeighbor if is_mask else sitk.sitkLinear)
    resampler.SetDefaultPixelValue(default_value)

    return resampler.Execute(im_mov)

In [19]:
nl_tfm = est_nl_transf(im_ref=fix_im, im_mov=mov_im2_resampled, mask_ref=fix_bonemask)  # non-linear refinement of the affinely aligned g1_55

In [20]:
# Warp the affinely aligned g1_55 image with the non-linear transform, on the fixed grid.
mov_im2_nlreg = apply_nl_transf(im_mov=mov_im2_resampled, nl_xfm=nl_tfm, ref_img=fix_im)
mov_im2_nlreg_data = sitk.GetArrayFromImage(mov_im2_nlreg)

In [21]:
show_interactive_overlay(fix_im_data, mov_im2_nlreg_data);  # ';' hides the returned function repr

interactive(children=(IntSlider(value=256, description='slc', max=511), Output()), _dom_classes=('widget-inter…

<function __main__.show_coronal_overlay(fixed_img, moving_img, slc)>