
# Image Registration and Deformation Gradient Analysis with SimpleElastix

This notebook demonstrates how to use **SimpleElastix** (via SimpleITK) for:
- Creating synthetic 3D image data
- Applying known deformations (translation, affine)
- Performing image registration to recover transformations
- Computing deformation gradient and tensor quantities
- Visualizing results with PyVista

We proceed step by step to show how image registration and deformation analysis work.


In [1]:
# Install required packages (IPython `!` shell escape: run inside Jupyter, once per environment)
!pip install SimpleITK-SimpleElastix pyvista xmltodict



In [2]:
import numpy as np
import pyvista as pv
import SimpleITK as sitk
from pathlib import Path


## Helper Functions

We define utility functions for visualization.  
The visualization functions use **PyVista** to plot images and vector fields in 3D.


In [None]:
def PlotImage(Image, Spacing, Sampling=1):
    """Render a 3D scalar volume with PyVista in an off-screen plotter.

    Parameters
    ----------
    Image : np.ndarray or SimpleITK.Image
        Volume to render. A SimpleITK image is converted to a NumPy array
        and transposed so that the array axes are ordered (x, y, z).
    Spacing : sequence(float)
        Voxel size along each axis. The axis titles display
        ``shape * Spacing * 1e3`` labelled "mm", so Spacing is presumably
        given in metres.
    Sampling : int
        Keep every ``Sampling``-th voxel along each axis to speed rendering.
        The axis titles always report the full, non-subsampled extent.

    Returns
    -------
    None
    """
    # isinstance also accepts subclasses of sitk.Image, unlike an exact
    # type comparison.
    if isinstance(Image, sitk.Image):
        Array = sitk.GetArrayFromImage(Image).T
    else:
        Array = np.asarray(Image)

    # Physical extent is computed from the full array, before subsampling.
    Shape = np.round(np.array(Array.shape) * Spacing * 1e3, 1)
    Array = Array[::Sampling, ::Sampling, ::Sampling]

    Args = dict(font_family='times',
                font_size=30,
                location='outer',
                show_xlabels=False,
                show_ylabels=False,
                show_zlabels=False,
                all_edges=True,
                fmt='%i',
                xtitle=f'{Shape[0]} mm',
                ytitle=f'{Shape[1]} mm',
                ztitle=f'{Shape[2]} mm',
                use_3d_text=False
                )

    # Plot using pyvista
    Plot = pv.Plotter(off_screen=True)
    Volume = Plot.add_volume(Array, cmap='bone', show_scalar_bar=False,
                             opacity='sigmoid_7')
    Volume.prop.interpolation_type = 'linear'
    Plot.camera_position = 'xz'
    Plot.camera.roll = 0
    Plot.camera.elevation = 30
    Plot.camera.azimuth = -60
    Plot.show_bounds(**Args)
    Plot.add_axes()
    Plot.show()

def PlotVectors(Vectors, Spacing, Sampling=8):
    """Render a 3D vector field as arrow glyphs with PyVista (off-screen).

    Parameters
    ----------
    Vectors : np.ndarray or SimpleITK.Image
        Vector field. A NumPy array must have shape (X, Y, Z, 3). A SimpleITK
        vector image is converted with GetArrayFromImage, which returns
        (Z, Y, X, 3); only the three spatial axes are reversed so that the
        vector components stay on the last axis.
    Spacing : sequence(float) or float
        Voxel size along each axis. Glyphs are placed at physical positions
        ``index * Spacing``, so the vectors are expected in the same physical
        units as Spacing (e.g. a transformix deformation field, which is
        expressed in physical units). The axis titles display
        ``shape * Spacing * 1e3`` labelled "mm".
    Sampling : int
        Keep every ``Sampling``-th voxel along each axis.

    Returns
    -------
    None
    """
    if isinstance(Vectors, sitk.Image):
        # A plain .T would also move the component axis to the front,
        # giving (3, X, Y, Z) and scrambling the per-point vectors below.
        Array = np.transpose(sitk.GetArrayFromImage(Vectors), (2, 1, 0, 3))
    else:
        Array = np.asarray(Vectors)

    VoxelSize = np.broadcast_to(np.asarray(Spacing, dtype=float), (3,))
    # Physical extent is computed from the full field, before subsampling.
    Shape = np.round(np.array(Array.shape[:-1]) * VoxelSize * 1e3, 1)
    Array = Array[::Sampling, ::Sampling, ::Sampling]

    # Place glyphs at physical coordinates so that the arrow lengths (scaled
    # by the vector magnitude) are commensurate with the grid spacing.
    Step = VoxelSize * Sampling
    nX, nY, nZ = Array.shape[:-1]
    X, Y, Z = np.meshgrid(np.arange(nX) * Step[0],
                          np.arange(nY) * Step[1],
                          np.arange(nZ) * Step[2],
                          indexing="ij")
    Points = np.column_stack((X.ravel(), Y.ravel(), Z.ravel()))

    PointCloud = pv.PolyData(Points)
    PointCloud['Vectors'] = Array.reshape(-1, 3)
    Arrows = PointCloud.glyph(orient='Vectors', scale=True, factor=1.0)

    Args = dict(font_family='times',
                font_size=30,
                location='outer',
                show_xlabels=False,
                show_ylabels=False,
                show_zlabels=False,
                all_edges=True,
                fmt='%i',
                xtitle=f'{Shape[0]} mm',
                ytitle=f'{Shape[1]} mm',
                ztitle=f'{Shape[2]} mm',
                use_3d_text=False
                )

    sArgs = dict(font_family='times',
                 width=0.05,
                 height=0.75,
                 vertical=True,
                 position_x=0.9,
                 position_y=0.125,
                 title_font_size=30,
                 label_font_size=20,
                 title='Displacement')

    # Plot using pyvista
    Plot = pv.Plotter(off_screen=True)
    Plot.add_mesh(Arrows, cmap='jet', show_scalar_bar=True, scalar_bar_args=sArgs)
    Plot.camera_position = 'xz'
    Plot.camera.roll = 0
    Plot.camera.elevation = 30
    Plot.camera.azimuth = -60
    Plot.show_bounds(**Args)
    Plot.add_axes()
    Plot.show()

def PlotTensor(Image, Spacing, Title='Value', Sampling=1):
    """Volume-render a 3D scalar field (e.g. a tensor invariant) with a colour bar.

    Parameters
    ----------
    Image : np.ndarray or SimpleITK.Image
        Scalar field to render, such as the Jacobian determinant det(F).
        A SimpleITK image is converted to a NumPy array and transposed to
        (x, y, z) axis order.
    Spacing : sequence(float)
        Voxel size along each axis. The axis titles display
        ``shape * Spacing * 1e3`` labelled "mm".
    Title : str
        Title shown on the scalar bar.
    Sampling : int
        Keep every ``Sampling``-th voxel along each axis to speed rendering.
        The axis titles always report the full, non-subsampled extent.

    Returns
    -------
    None
    """
    # isinstance also accepts subclasses of sitk.Image, unlike an exact
    # type comparison.
    if isinstance(Image, sitk.Image):
        Array = sitk.GetArrayFromImage(Image).T
    else:
        Array = np.asarray(Image)

    # Physical extent is computed from the full array, before subsampling.
    Shape = np.round(np.array(Array.shape) * Spacing * 1e3, 1)
    Array = Array[::Sampling, ::Sampling, ::Sampling]

    Args = dict(font_family='times',
                font_size=30,
                location='outer',
                show_xlabels=False,
                show_ylabels=False,
                show_zlabels=False,
                all_edges=True,
                fmt='%i',
                xtitle=f'{Shape[0]} mm',
                ytitle=f'{Shape[1]} mm',
                ztitle=f'{Shape[2]} mm',
                use_3d_text=False
                )

    sArgs = dict(font_family='times',
                 width=0.05,
                 height=0.75,
                 vertical=True,
                 position_x=0.9,
                 position_y=0.125,
                 title_font_size=30,
                 label_font_size=20,
                 title=Title)

    # Plot using pyvista; low constant opacity keeps the interior visible.
    Plot = pv.Plotter(off_screen=True)
    Volume = Plot.add_volume(Array, cmap='jet', show_scalar_bar=True,
                             scalar_bar_args=sArgs, opacity=0.05)
    Volume.prop.interpolation_type = 'linear'
    Plot.camera_position = 'xz'
    Plot.camera.roll = 0
    Plot.camera.elevation = 30
    Plot.camera.azimuth = -60
    Plot.show_bounds(**Args)
    Plot.add_axes()
    Plot.show()

## Step 1: Create Synthetic 3D Data

We generate a 3D volume filled with low-amplitude random noise, then insert a cubic block of brighter, randomly valued voxels that serves as the structure to register.  

In [None]:
# Define image parameters
Size = (64, 64, 64)
# Voxel size; the plotting helpers label Spacing * 1e3 as mm, so presumably metres
Spacing = (1.0e-3, 1.0e-3, 1.0e-3)
NoiseLevel = 0.3
# Seed the global NumPy RNG so the synthetic data (and the noise drawn later
# with np.random) is reproducible from run to run
np.random.seed(0)
Cube = np.random.rand(*Size) * NoiseLevel

# Define cube coordinates: a dense sub-voxel sampling (step 0.2) of
# [Start, Stop] along each axis. linspace includes both endpoints exactly,
# whereas a float-step arange can overshoot Stop and shift the centre.
Start, Stop, Step = 20, 44, 0.2
Coords = np.linspace(Start, Stop, int(round((Stop - Start) / Step)) + 1)
X, Y, Z = np.meshgrid(Coords, Coords, Coords, indexing='ij')
Coords = np.stack([X.ravel(), Y.ravel(), Z.ravel()], axis=0)  # shape (3, N)
Center = np.mean(Coords, axis=1, keepdims=True)  # shape (3, 1)
VoxelValues = np.random.rand(Coords.shape[1])

# Place cube in image: each sample is rounded to its nearest voxel; several
# samples share a voxel, so only one of their values is kept
rX, rY, rZ = np.round(Coords).astype(int)
Cube[rX, rY, rZ] = VoxelValues
PlotImage(Cube, Spacing)

## Create Synthetic Deformed Cube

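Define a deformation gradient $F$ and a translation $t$, and apply them to the cube's voxel coordinates to create the deformed (moving) image.  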

In [None]:
# Define deformation gradient (identity here, so the moving image is a pure
# translation of the fixed one; edit this matrix to add stretch or shear)
F = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])

# Define translation in voxels
t = np.array([10, 0, 0])

# Apply transform about the cube centre: x' = F (x - c) + t + c
tCoords = F @ (Coords-Center) + t[:, None]
tCoords += Center
# Keep mapped coordinates inside the image so the indexing below stays in bounds
tCoords = np.clip(tCoords, 0, np.array(Size)[:, None]-1)
# Nearest-voxel scatter; points landing on the same voxel keep only one value
tX, tY, tZ = np.rint(tCoords).astype(int)
# Moving image: fresh background noise with the cube's values moved to the new positions
dCube = np.random.rand(*Size)*NoiseLevel
dCube[tX, tY, tZ] = VoxelValues

PlotImage(dCube, Spacing)

## Step 2: Registration

We register the fixed and moving images using a translation transform only.  
This allows us to recover the translation vector applied earlier.


In [None]:
# Convert the numpy arrays to SimpleITK images. The arrays are indexed
# (x, y, z) while GetImageFromArray expects (z, y, x), hence the transpose.
Fixed = sitk.GetImageFromArray(Cube.T)
Moving = sitk.GetImageFromArray(dCube.T)
Fixed.SetSpacing(Spacing)
Moving.SetSpacing(Spacing)

# Define parameter map (translation only)
ParameterMap = sitk.GetDefaultParameterMap('translation')
# Step-size (gain) sequence parameters of the elastix optimizer
ParameterMap['SP_alpha'] = ['0.6']
ParameterMap['SP_A'] = ['100']
ParameterMap['MaximumNumberOfIterations'] = ['256']
ParameterMap['NewSamplesEveryIteration'] = ['true']

# Multi-resolution pyramid: shrink factors 4, 2, 1 applied to every axis.
# elastix reads NumberOfResolutions * dimension schedule entries. The default
# map uses 4 resolutions, so a 3-level schedule would be incomplete and elastix
# would fall back to its default pyramid. Keep the two settings consistent.
Levels = [4, 2, 1]
Schedule = np.repeat(Levels, Fixed.GetDimension())
ParameterMap['NumberOfResolutions'] = [str(len(Levels))]
ParameterMap['FixedImagePyramidSchedule'] = [str(S) for S in Schedule]
ParameterMap['MovingImagePyramidSchedule'] = [str(S) for S in Schedule]
sitk.PrintParameterMap(ParameterMap)

# Perform registration and plot results
Elastix = sitk.ElastixImageFilter()
Elastix.SetFixedImage(Fixed)
Elastix.SetMovingImage(Moving)
Elastix.SetParameterMap(ParameterMap)
Result = Elastix.Execute()
PlotImage(Result, Spacing)

## Step 3: Deformation Field Computation

We use Transformix to compute the deformation field produced by the registration, together with its spatial gradient (the full spatial Jacobian, i.e. the deformation gradient $F$).

In [None]:
# Compute deformation field and full spatial Jacobian with Transformix.
# The folder name is kept as-is because later cells read files from it.
OutputDirectory = Path('TransformixOuputs')
# Make sure the output folder exists before transformix writes its files into it
OutputDirectory.mkdir(parents=True, exist_ok=True)
Transformix = sitk.TransformixImageFilter()
TransformParameterMap = Elastix.GetTransformParameterMap()
Transformix.SetTransformParameterMap(TransformParameterMap)
Transformix.ComputeDeformationFieldOn()
Transformix.ComputeSpatialJacobianOn()
Transformix.SetMovingImage(Moving)
Transformix.SetOutputDirectory(str(OutputDirectory))
Transformed = Transformix.Execute()
PlotImage(Transformed, Spacing)

# Extract translation parameters. They are stored as strings in physical
# units; dividing by the spacing expresses them in voxels.
Parameters = TransformParameterMap[0]
t_reg = np.array([float(v) for v in Parameters['TransformParameters'][:3]])
print("Recovered translation (voxels):", np.round(t_reg/np.array(Spacing),1))


## Combined Registration

Sometimes the registration must be performed in multiple stages (e.g. rigid followed by B-spline) to recover the deformation field accurately. This is done by passing a vector of parameter maps to Elastix; below we chain a translation stage and an affine stage.

In [None]:
# Parameter maps: a translation stage followed by an affine stage
ParameterMap1 = sitk.GetDefaultParameterMap('translation')
ParameterMap1['MaximumNumberOfIterations'] = ['256']
ParameterMap1['NewSamplesEveryIteration'] = ['true']

ParameterMap2 = sitk.GetDefaultParameterMap('affine')
ParameterMap2['MaximumNumberOfIterations'] = ['256']
ParameterMap2['NewSamplesEveryIteration'] = ['true']

# Stages are listed in the order they are appended (translation, then affine)
ParameterMapVector = sitk.VectorOfParameterMap()
ParameterMapVector.append(ParameterMap1)
ParameterMapVector.append(ParameterMap2)

# Registration (reuses the Fixed/Moving images built in the translation-only step)
Elastix = sitk.ElastixImageFilter()
Elastix.SetFixedImage(Fixed)
Elastix.SetMovingImage(Moving)
Elastix.SetParameterMap(ParameterMapVector)
Result = Elastix.Execute()
PlotImage(Result, Spacing)

In [None]:
# Again, use Transformix to compute the deformation field and the full spatial
# Jacobian, now for the composed translation + affine transform.
# The folder name is kept as-is because later cells read files from it.
OutputDirectory = Path('TransformixOuputs')
# Make sure the output folder exists before transformix writes its files into it
OutputDirectory.mkdir(parents=True, exist_ok=True)
Transformix = sitk.TransformixImageFilter()
TransformParameterMap = Elastix.GetTransformParameterMap()
Transformix.SetTransformParameterMap(TransformParameterMap)
Transformix.ComputeDeformationFieldOn()
Transformix.ComputeSpatialJacobianOn()
Transformix.SetMovingImage(Moving)
Transformix.SetOutputDirectory(str(OutputDirectory))
Transformed = Transformix.Execute()
PlotImage(Transformed, Spacing)

## Extract Parameters

Note that the variable `TransformParameterMap` now contains two elements: the first holds the result of the translation stage and the second the result of the affine stage.

In [None]:
# Extract translation transform parameters (stage 1). Values are stored as
# strings in physical units; dividing by the spacing gives voxels.
Parameters1 = TransformParameterMap[0]
t_reg1 = np.array([float(v) for v in Parameters1['TransformParameters'][:3]])
print("Recovered translation, translation stage (voxels):", np.round(t_reg1/np.array(Spacing),1))

# Extract affine transform parameters (stage 2): the first 9 values are the
# 3x3 matrix in row-major order, the next 3 the translation.
Parameters2 = TransformParameterMap[1]
AffineParameters = np.array([float(v) for v in Parameters2['TransformParameters']])
F_reg = AffineParameters[:9].reshape(3, 3)
t_reg2 = AffineParameters[9:12]

print("Recovered affine F:\n", F_reg.round(3))
print("Recovered translation, affine stage (voxels):", np.round(t_reg2/np.array(Spacing),1))

## Deformation Field and Deformation Gradient

Transformix writes two files to the output directory: the deformation field (`deformationField.nii`) and the deformation gradient field (`fullSpatialJacobian.nii`).
From this, we compute:
- Determinant of F (Jacobian determinant)
- Isochoric part of F
- Right Cauchy–Green tensor

In [None]:
# Load the resulting deformation field written by Transformix and reorder the
# spatial axes from SimpleITK's (z, y, x, 3) to (x, y, z, 3); the vector
# components stay on the last axis.
File = Path('TransformixOuputs') / 'deformationField.nii'
DefField = np.transpose(sitk.GetArrayFromImage(sitk.ReadImage(File)), (2, 1, 0, 3))
PlotVectors(DefField, Spacing, Sampling=8)

In [None]:
# Read deformation gradient (full spatial Jacobian written by Transformix)
File = Path('TransformixOuputs/fullSpatialJacobian.nii')
Jacobian = sitk.ReadImage(File)
Jacobian = sitk.GetArrayFromImage(Jacobian)
# Reorder the spatial axes only: (z, y, x, 9) -> (x, y, z, 9)
Jacobian = Jacobian.transpose((2,1,0,3))

# Build deformation gradient field F[..., i, j] = dT_i/dx_j.
# The 9 components are assumed to be the ITK matrix stored row-major
# (component 3*i + j). Only the spatial axes were reversed above, so the
# component order must not be reversed as well; doing so would swap the
# x and z rows/columns and misplace shear terms.
F_Field = np.asarray(Jacobian, dtype=float).reshape(Jacobian.shape[:-1] + (3, 3))

# Determinant of F
J = np.linalg.det(F_Field)
PlotTensor(J, Spacing, Title='Volume Change (-)')

# Isochoric deformation gradient F_tilde = J^(-1/3) F. np.cbrt takes the real
# cube root, so J < 0 (folding) does not produce NaN as a fractional power would.
F_tilde = F_Field / np.cbrt(J)[..., None, None]
NormF_tilde = np.linalg.norm(F_tilde, axis=(-1, -2))
# PlotImage(NormF_tilde, Spacing)

# Right Cauchy–Green tensor C = F^T F
C = np.einsum('...ki,...kj->...ij', F_Field, F_Field)
CenterIndex = tuple(n // 2 for n in C.shape[:3])
print(f'Tensor C at image center:\n {C[CenterIndex].round(3)}')