
MolloiLab/MedicalImageRegistration.jl


MedicalImageRegistration.jl

A minimal Julia library for 2D and 3D medical image registration, inspired by torchreg.


Demo

[Animated GIF: registration demo]

The animation shows: static target → misaligned moving image → registration in progress → aligned result → checkerboard overlay comparison

Features

  • Affine Registration: Translation, rotation, zoom, and shear optimization
  • SyN Registration: Symmetric diffeomorphic (deformable) registration
  • Multiresolution: Coarse-to-fine optimization for speed and robustness
  • 2D and 3D: Full support for both image dimensions
  • GPU Acceleration: Transparent CPU/CUDA/Metal support via AcceleratedKernels.jl
  • Automatic Differentiation: Mooncake.jl for gradient computation

Installation

using Pkg
Pkg.add("MedicalImageRegistration")

Quick Start

Affine Registration

using MedicalImageRegistration
using Metal  # or CUDA for NVIDIA GPUs

# Load images as arrays (X, Y, Z, C, N) - Julia convention
moving = MtlArray(rand(Float32, 64, 64, 64, 1, 1))  # GPU array
static = MtlArray(rand(Float32, 64, 64, 64, 1, 1))

# Create registration object
reg = AffineRegistration{Float32}(
    is_3d=true,
    scales=(4, 2),
    iterations=(500, 100),
    array_type=MtlArray  # Use MtlArray for Metal GPU
)

# Run registration
moved = register(reg, moving, static)

# Access the affine matrix
affine = get_affine(reg)

# Apply the learned transform to another image with the same shape
another_image = MtlArray(rand(Float32, 64, 64, 64, 1, 1))
another_moved = transform(reg, another_image)

SyN (Diffeomorphic) Registration

reg = SyNRegistration{Float32}(
    scales=(4, 2, 1),
    iterations=(30, 30, 10),
    array_type=MtlArray
)
moved = register(reg, moving, static)

Custom Loss Functions

reg = AffineRegistration{Float32}(
    is_3d=true,
    learning_rate=0.01f0
)

# Use dice_loss instead of default mse_loss
moved = register(reg, moving, static; loss_fn=dice_loss)

Running the Demo

To run the interactive demo with TestImages.jl:

cd examples
julia demo.jl

This will:

  1. Automatically detect and use Metal GPU (Apple Silicon) if available
  2. Load a test image (cameraman)
  3. Create a synthetically misaligned version
  4. Run affine registration to recover alignment
  5. Generate a GIF animation showing the process
  6. Save output images to examples/output/

GPU Acceleration: The demo automatically uses the Metal GPU on macOS with Apple Silicon and falls back to the CPU when no GPU is available.

GPU Requirements

This package is designed with a GPU-first architecture. While it can run on CPU, optimal performance requires a GPU:

Supported GPU Backends

| Backend | GPU Type | Package | Array Type |
|---|---|---|---|
| Metal | Apple Silicon (M1/M2/M3) | Metal.jl | MtlArray |
| CUDA | NVIDIA GPUs | CUDA.jl | CuArray |
| ROCm | AMD GPUs | AMDGPU.jl | ROCArray |
| CPU | Any | (built-in) | Array |

GPU Usage Example

# Metal (Apple Silicon)
using Metal
using MedicalImageRegistration

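# load_image stands in for your own loader (it is not in the API reference below);
# it should return a Float32 array shaped (X, Y, Z, C, N)
moving = MtlArray(load_image("moving.nii"))
static = MtlArray(load_image("static.nii"))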
reg = AffineRegistration{Float32}(is_3d=true, array_type=MtlArray)
moved = register(reg, moving, static)

# CUDA (NVIDIA)
using CUDA
moving = CuArray(load_image("moving.nii"))
static = CuArray(load_image("static.nii"))
reg = AffineRegistration{Float32}(is_3d=true, array_type=CuArray)
moved = register(reg, moving, static)

Performance Notes

  • GPU acceleration provides roughly a 10-100x speedup over CPU for typical medical image sizes (the exact factor depends on hardware and volume size)
  • Memory requirements scale with image size: expect roughly 4x the image size during the forward pass
  • Multi-resolution pyramid reduces memory usage and improves convergence
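As a rough worked example of the memory rule of thumb above, the snippet below sizes a single Float32 volume with the 512×512×403 dimensions used in the PhysicalImage example. The 4x factor is taken directly from the note above, not measured; actual usage depends on the method, pyramid level, and backend.

```julia
# Rough memory estimate for one Float32 volume.
# The 4x forward-pass factor is the rule of thumb quoted above, not a measurement.
dims = (512, 512, 403, 1, 1)                 # (X, Y, Z, C, N)
image_bytes = prod(dims) * sizeof(Float32)   # 4 bytes per voxel
forward_bytes = 4 * image_bytes

println("image:   ", image_bytes / 2^20, " MiB")
println("forward: ", round(forward_bytes / 2^30; digits=2), " GiB")
```

This works out to 403 MiB for the image and about 1.6 GiB for the forward pass.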

Intensity Conservation (HU Values)

When registering CT images, the default (bilinear/trilinear) interpolation does not guarantee that intensities (Hounsfield Units, HU) are conserved.

Why HU Values Change

Image registration uses interpolation (bilinear/trilinear) to resample the moving image. Interpolation creates new pixel values as weighted averages of neighboring pixels, which can:

  1. Smooth edges: Sharp boundaries between tissues become blurred
  2. Shift mean values: Average HU in a region may change slightly
  3. Introduce new values: Interpolated values may not exist in the original image
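To make these effects concrete, here is a 1D toy (not the package's grid_sample) that resamples a made-up HU profile at a quarter-voxel shift, once with linear interpolation (the 1D analogue of bilinear/trilinear) and once with nearest-neighbor, then counts how many output values never occur in the input:

```julia
# 1D toy: resample a made-up HU profile at a quarter-voxel shift.
profile = Float32[-1000, -1000, 40, 40, 400, 40, 40]   # air, soft tissue, calcium (HU)

# Linear interpolation (1D analogue of bilinear/trilinear), border-clamped.
function sample_linear(x::AbstractVector, pos::Real)
    i = clamp(floor(Int, pos), 1, length(x) - 1)   # left neighbor
    t = pos - i                                      # fractional offset
    return (1 - t) * x[i] + t * x[i + 1]
end

# Nearest-neighbor: copy the closest existing sample.
sample_nearest(x::AbstractVector, pos::Real) = x[clamp(round(Int, pos), 1, length(x))]

positions = 1.25f0:1.0f0:6.25f0                        # every sample shifted by 0.25 voxel
linear  = [sample_linear(profile, p) for p in positions]
nearest = [sample_nearest(profile, p) for p in positions]

input_values = Set(profile)
new_linear  = count(v -> v ∉ input_values, linear)    # values absent from the input
new_nearest = count(v -> v ∉ input_values, nearest)
println("linear:  ", linear, "  new values: ", new_linear)
println("nearest: ", nearest, "  new values: ", new_nearest)
```

Linear interpolation produces -740, 130 and 310 HU, none of which exist in the input (the 130 HU sample lands exactly on the calcium-scoring threshold discussed below), while nearest-neighbor returns only original values.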

HU-Preserving Mode (Nearest-Neighbor)

This package supports a hybrid interpolation mode for HU preservation:

  • During optimization: Bilinear/trilinear interpolation for smooth gradients
  • Final output: Nearest-neighbor interpolation to preserve exact input values

# Register with HU preservation
reg = AffineRegistration{Float32}(is_3d=true)
moved = register(reg, moving_ct, static_ct; final_interpolation=:nearest)

# Output values are EXACT subset of input values (HU preserved)
@assert issubset(Set(moved), Set(moving_ct))  # True!

# Or use transform() with interpolation kwarg
moved_nearest = transform(reg, moving_ct; interpolation=:nearest)

Recommendations for CT Images

| Use Case | Interpolation Mode | Code |
|---|---|---|
| Visual alignment | :bilinear (default) | register(reg, moving, static) |
| Quantitative analysis | :nearest | register(reg, moving, static; final_interpolation=:nearest) |
| Dose calculation | :nearest | transform(reg, ct; interpolation=:nearest) |
| Segmentation transfer | :bilinear + threshold | transform(reg, mask) .> 0.5 |

How Hybrid Mode Works

  1. Optimization phase: Uses smooth bilinear/trilinear interpolation for gradient-based optimization
  2. Final output: Applies the learned transformation with nearest-neighbor to preserve exact values

This ensures the registration converges properly (smooth gradients) while the final result preserves exact intensity values (no interpolation artifacts).
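Optimization cannot use nearest-neighbor because its output is piecewise constant in the sampling position, so its gradient with respect to a translation is zero almost everywhere. A minimal 1D check on a made-up profile follows. It uses central finite differences rather than the Mooncake.jl AD the package relies on, so it stays dependency-free.

```julia
# Sensitivity of a sampled value to the sampling position (i.e. a translation parameter).
profile = Float32[40, 40, 400, 40, 40]   # made-up HU profile with one bright voxel

function sample_linear(x::AbstractVector, pos::Real)
    i = clamp(floor(Int, pos), 1, length(x) - 1)
    t = pos - i
    return (1 - t) * x[i] + t * x[i + 1]
end

sample_nearest(x::AbstractVector, pos::Real) = x[clamp(round(Int, pos), 1, length(x))]

# Central finite difference with respect to the sampling position.
fd(f, pos; h=1e-3) = (f(pos + h) - f(pos - h)) / (2h)

pos = 2.3
g_linear  = fd(p -> sample_linear(profile, p), pos)   # ≈ 360 = profile[3] - profile[2]
g_nearest = fd(p -> sample_nearest(profile, p), pos)  # exactly 0
println("linear: ", g_linear, "  nearest: ", g_nearest)
```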

Shepp-Logan Phantom Demo: Standard vs HU-Preserving

The following demo compares standard bilinear interpolation with HU-preserving nearest-neighbor mode using the Shepp-Logan phantom:

[Side-by-side comparison images: Standard (Bilinear) vs HU-Preserving (Nearest)]

Quantitative Results:

| Metric | Standard (Bilinear) | HU-Preserving (Nearest) |
|---|---|---|
| New values created | 2559 | 0 |
| All values from input | No | Yes |
| Suitable for quantitative analysis | No | Yes |

Run the demo yourself:

cd examples
julia demo_hu_preservation.jl

This demo:

  1. Loads the Shepp-Logan phantom (256x256)
  2. Creates synthetic misalignment (translation, rotation, zoom)
  3. Runs registration with both interpolation modes
  4. Generates comparison GIFs and intensity histograms
  5. Prints quantitative analysis showing value preservation

Clinical CT Registration

For clinical CT imaging with resolution mismatches and contrast agents, use the register_clinical workflow.

The Clinical Scenario

A common clinical scenario is registering cardiac CT scans with different parameters:

| Property | Scan 1 (Static/Reference) | Scan 2 (Moving) |
|---|---|---|
| Contrast | Non-contrast | With IV contrast |
| Slice Thickness | 3.0 mm | 0.5 mm |
| Blood HU | ~40 HU | ~300+ HU |
| Use Case | Calcium scoring | Coronary visualization |

Challenges:

  • 6x resolution difference in z-direction (3mm vs 0.5mm slices)
  • Intensity mismatch from contrast agent (blood goes 40→300+ HU)
  • HU preservation required for quantitative analysis (calcium scoring threshold = 130 HU)

PhysicalImage Type

Wrap volumes with physical spacing metadata:

using MedicalImageRegistration

# Create PhysicalImage from volume + spacing
volume = load_nifti("cardiac_ct.nii")  # placeholder loader; must return an (X, Y, Z, C, N) array
spacing = (0.5f0, 0.5f0, 0.5f0)        # (x, y, z) in mm
origin = (0f0, 0f0, 0f0)               # Optional origin

img = PhysicalImage(volume; spacing=spacing, origin=origin)

# Access properties
spatial_size(img)     # (512, 512, 403)
spatial_spacing(img)  # (0.5, 0.5, 0.5)
img.data              # The underlying array

register_clinical Workflow

using MedicalImageRegistration
using Metal  # For GPU acceleration

# Load CT scans with different resolutions
non_contrast = PhysicalImage(volume1; spacing=(0.5f0, 0.5f0, 3.0f0))  # 3mm z-spacing
contrast = PhysicalImage(volume2; spacing=(0.5f0, 0.5f0, 0.5f0))      # 0.5mm z-spacing

# Register contrast (moving) → non-contrast (static)
result = register_clinical(
    contrast, non_contrast;
    registration_resolution=2.0f0,  # Resample to 2mm isotropic for optimization
    loss_fn=mi_loss,                # Mutual Information for contrast mismatch
    preserve_hu=true,               # Nearest-neighbor for final output
    registration_type=:affine,      # Or :syn for deformable
    verbose=true
)
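Resampling to registration_resolution keeps each scan's physical extent (voxels × spacing) and changes only the voxel count. A small sketch of that arithmetic follows; the scan sizes are hypothetical and the rounding rule is an assumption, not necessarily what register_clinical does internally.

```julia
# Voxel grid after resampling to an isotropic target spacing (mm).
# Physical extent (size × spacing) is preserved; the rounding rule is an assumption.
resampled_size(sz, spacing, target) =
    map((n, s) -> max(1, round(Int, n * s / target)), sz, spacing)

# Hypothetical scan geometries: (X, Y, Z) sizes and spacings in mm.
nc_size,   nc_spacing   = (512, 512, 60),  (0.5f0, 0.5f0, 3.0f0)   # 3 mm slices
ccta_size, ccta_spacing = (512, 512, 360), (0.5f0, 0.5f0, 0.5f0)   # 0.5 mm slices

println(resampled_size(nc_size, nc_spacing, 2.0f0))     # (128, 128, 90)
println(resampled_size(ccta_size, ccta_spacing, 2.0f0)) # (128, 128, 90)
```

Both scans land on the same 128×128×90 grid at 2 mm, which removes the 6x z-anisotropy and cuts the in-plane voxel count 16-fold before optimization starts.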

# Access results
result.moved_image      # PhysicalImage with registered contrast CT
result.transform        # Displacement field (can apply to other images)
result.metrics          # Dict with :mi_before, :mi_after, :mi_improvement
result.metadata         # Registration parameters and image info

Why Mutual Information?

Standard loss functions fail with contrast mismatch:

| Loss Function | Assumption | Problem with Contrast |
|---|---|---|
| MSE | Same intensity = aligned | Blood is 40 vs 300 HU → penalizes correct alignment |
| NCC | Linear intensity relationship | Nonlinear per-tissue contrast enhancement |
| MI | Statistical dependence | Learns 40 HU ↔ 300 HU correspondence ✓ |

Mutual Information measures statistical dependence, not intensity similarity. It learns that anatomically corresponding points have consistent intensity mappings.
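A minimal joint-histogram estimate of MI makes this concrete. It is an illustrative, non-differentiable estimator on hypothetical tissue values, not the package's mi_loss (whose bins=64 default is listed in the API reference). Each sample is one voxel with a tissue label, and contrast raises blood from 40 to 300 HU:

```julia
# Joint-histogram mutual information (nats) between two equally sized arrays.
function mutual_information(a::AbstractArray, b::AbstractArray; bins::Int=16)
    @assert size(a) == size(b)
    lo_a, hi_a = extrema(a)
    lo_b, hi_b = extrema(b)
    # Map a value to a bin index in 1:bins (single bin if the range is empty).
    binof(v, lo, hi) = hi == lo ? 1 : clamp(floor(Int, (v - lo) / (hi - lo) * bins) + 1, 1, bins)
    joint = zeros(Float64, bins, bins)
    for (x, y) in zip(a, b)
        joint[binof(x, lo_a, hi_a), binof(y, lo_b, hi_b)] += 1
    end
    joint ./= sum(joint)
    pa = vec(sum(joint; dims=2))   # marginal distribution of a
    pb = vec(sum(joint; dims=1))   # marginal distribution of b
    mi = 0.0
    for i in 1:bins, j in 1:bins
        p = joint[i, j]
        if p > 0
            mi += p * log(p / (pa[i] * pb[j]))
        end
    end
    return mi
end

# Toy voxels with tissue labels: 1 = blood, 2 = myocardium, 3 = fat (hypothetical HU values).
labels = vcat(fill(1, 100), fill(2, 100), fill(3, 100))
non_contrast = Float32[(40, 60, -100)[l] for l in labels]    # blood ≈ 40 HU
contrast     = Float32[(300, 120, -100)[l] for l in labels]  # blood ≈ 300 HU with contrast
misaligned   = circshift(contrast, 50)                        # breaks the correspondence

mse(x, y) = sum(abs2, x .- y) / length(x)
println("MSE aligned:    ", mse(contrast, non_contrast))
println("MI aligned:     ", mutual_information(contrast, non_contrast))
println("MI misaligned:  ", mutual_information(misaligned, non_contrast))
```

When aligned, the intensities map one-to-one, so MI equals log(3) ≈ 1.10 nats even though MSE is far from zero; shifting the contrast scan by 50 samples drops MI to log(1.5) ≈ 0.41.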

Workflow Under the Hood

┌─────────────────────────────────────────────────┐
│ 1. Resample both to registration_resolution     │
│    (bilinear OK - just for optimization)        │
└─────────────────────────────────────────────────┘
                        ↓
┌─────────────────────────────────────────────────┐
│ 2. Register with MI loss                        │
│    (handles contrast intensity difference)      │
└─────────────────────────────────────────────────┘
                        ↓
┌─────────────────────────────────────────────────┐
│ 3. Upsample transform to original resolution    │
│    (transform is smooth, bilinear OK)           │
└─────────────────────────────────────────────────┘
                        ↓
┌─────────────────────────────────────────────────┐
│ 4. Apply to ORIGINAL moving image               │
│    preserve_hu=true → nearest-neighbor          │
│    Output HU values = EXACT input values        │
└─────────────────────────────────────────────────┘
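Step 3 relies on the displacement field being smooth, so linear interpolation onto the finer grid is adequate. A 1D sketch of that upsampling follows. It assumes displacements are stored in millimetres; a field stored in voxel units would additionally need to be multiplied by coarse_spacing / fine_spacing. The field values are hypothetical.

```julia
# Upsample a smooth 1D displacement field from a 2 mm grid to a 0.5 mm grid.
# Assumption: displacements are in mm, so values are interpolated but not rescaled.
coarse_spacing, fine_spacing = 2.0, 0.5
coarse_disp = [0.0, 0.4, 1.0, 1.2, 1.0]                 # hypothetical displacement (mm)
coarse_x = (0:length(coarse_disp)-1) .* coarse_spacing  # sample positions (mm): 0, 2, ..., 8

# Piecewise-linear interpolation of ys (sampled at sorted xs) at position x.
function interp_linear(xs, ys, x)
    i = clamp(searchsortedlast(xs, x), 1, length(xs) - 1)
    t = (x - xs[i]) / (xs[i + 1] - xs[i])
    return (1 - t) * ys[i] + t * ys[i + 1]
end

fine_x = 0:fine_spacing:last(coarse_x)                 # 0, 0.5, ..., 8 mm
fine_disp = [interp_linear(coarse_x, coarse_disp, x) for x in fine_x]
```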

ClinicalRegistrationResult

struct ClinicalRegistrationResult{T, N, A}
    moved_image::PhysicalImage{T,N,A}  # Registered image
    transform::A                        # Displacement field
    inverse_transform::Union{A,Nothing} # Optional inverse
    metrics::Dict{Symbol, T}            # :mi_before, :mi_after, :mi_improvement
    metadata::Dict{Symbol, Any}         # Spacing, sizes, parameters
end

Apply Transform to Other Images

# Apply the same transform to a segmentation mask
mask = PhysicalImage(mask_volume; spacing=spatial_spacing(contrast))
mask_transformed = transform_clinical(result, mask; interpolation=:nearest)

# Apply inverse transform (if computed)
result = register_clinical(...; compute_inverse=true)
inverse_transformed = transform_clinical_inverse(result, some_image)

Loss Function Selection

| Scenario | Recommended Loss | Example |
|---|---|---|
| Same modality, no contrast | mse_loss | T1 MRI to T1 MRI |
| Same modality with preprocessing | ncc_loss | Skull-stripped MRI |
| Different contrast agents | mi_loss | Contrast CT to non-contrast CT |
| Multi-modal | mi_loss | CT to MRI |
| Binary segmentation | dice_loss | Mask alignment |
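For the binary-segmentation row, the usual Dice definition is 2|A∩B| / (|A| + |B|). The sketch below implements the soft version on two toy masks; the _sketch names are hypothetical, and the package's dice_score/dice_loss may differ in details such as a smoothing term.

```julia
# Soft Dice coefficient and loss (standard definition, no smoothing term).
dice_score_sketch(pred, target) = 2 * sum(pred .* target) / (sum(pred) + sum(target))
dice_loss_sketch(pred, target) = 1 - dice_score_sketch(pred, target)

a = zeros(Float32, 8, 8); a[2:5, 2:5] .= 1   # 4×4 square (16 px)
b = zeros(Float32, 8, 8); b[3:6, 3:6] .= 1   # same square shifted by one pixel
println(dice_score_sketch(a, b))              # overlap 3×3 = 9 → 2·9 / 32 = 0.5625
println(dice_loss_sketch(a, a))               # identical masks → 0.0
```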

Interpolation Mode Selection

| Use Case | Interpolation | Reason |
|---|---|---|
| Visual alignment | :bilinear | Smooth appearance |
| Calcium scoring | :nearest | Exact HU for 130 HU threshold |
| Dose calculation | :nearest | HU → electron density mapping |
| Tissue density measurement | :nearest | Quantitative accuracy |
| Segmentation transfer | :bilinear + threshold | Smooth probability maps |
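The ":bilinear + threshold" recipe for segmentation transfer can be sketched in 2D: interpolate the binary mask as if it were a probability map, then threshold at 0.5. The shift_bilinear function below is a hypothetical stand-in for the package's transform/grid_sample, using border clamping and a sub-voxel translation.

```julia
# Interpolate a binary mask as a probability map, then threshold at 0.5.
# shift_bilinear is a hypothetical helper: translate by (dx, dy) voxels with
# bilinear sampling and border clamping (not the package's grid_sample).
function shift_bilinear(img::AbstractMatrix, dx::Real, dy::Real)
    nx, ny = size(img)
    out = zeros(Float64, nx, ny)
    for j in 1:ny, i in 1:nx
        x = clamp(i - dx, 1, nx); y = clamp(j - dy, 1, ny)   # source position
        i0 = min(floor(Int, x), nx - 1); j0 = min(floor(Int, y), ny - 1)
        tx = x - i0; ty = y - j0
        out[i, j] = (1 - tx) * (1 - ty) * img[i0, j0] + tx * (1 - ty) * img[i0 + 1, j0] +
                    (1 - tx) * ty * img[i0, j0 + 1] + tx * ty * img[i0 + 1, j0 + 1]
    end
    return out
end

mask = zeros(8, 8); mask[3:6, 3:6] .= 1        # 4×4 square, 16 voxels
prob = shift_bilinear(mask, 0.6, 0.0)          # soft edges (about 0.4 and 0.6)
moved_mask = prob .> 0.5                        # binary again
println(sum(moved_mask))                        # 16
```

A 0.6-voxel shift produces soft edge values near 0.4 and 0.6; thresholding yields a clean binary mask moved by one voxel with the same 16-voxel area.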

Complete Example: Cardiac CT Registration

See the full interactive example in examples/cardiac_ct.jl (Pluto notebook).

using MedicalImageRegistration
using DICOM

# Load DICOM series (simplified: load_dicom_series is the helper defined in the notebook, not a DICOM.jl function)
non_contrast_vol, nc_spacing = load_dicom_series("path/to/non_contrast/")
contrast_vol, ccta_spacing = load_dicom_series("path/to/ccta/")

# Create PhysicalImages
nc = PhysicalImage(Float32.(non_contrast_vol); spacing=Float32.(nc_spacing))
ccta = PhysicalImage(Float32.(contrast_vol); spacing=Float32.(ccta_spacing))

# Register
result = register_clinical(
    ccta, nc;
    registration_resolution=2.0f0,
    loss_fn=mi_loss,
    preserve_hu=true,
    registration_type=:affine,
    affine_scales=(4, 2, 1),
    affine_iterations=(50, 25, 10),
    verbose=true
)

# Verify HU preservation
original_values = Set(vec(ccta.data))
output_values = Set(vec(result.moved_image.data))
@assert output_values ⊆ original_values  # True!

# Report metrics
println("MI improved: $(result.metrics[:mi_before]) → $(result.metrics[:mi_after])")

Array Conventions

Julia arrays are column-major, so this package puts the spatial dimensions first and the channel and batch dimensions last, the reverse of PyTorch's (N, C, ...) order:

| Dimension | Julia (this package) | PyTorch (torchreg) |
|---|---|---|
| Spatial | (X, Y) or (X, Y, Z) | (Y, X) or (Z, Y, X) |
| Full 2D | (X, Y, C, N) | (N, C, Y, X) |
| Full 3D | (X, Y, Z, C, N) | (N, C, Z, Y, X) |
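Because the two conventions are exact reverses of each other, converting an array that is already in PyTorch's (N, C, Z, Y, X) order is a single permutedims call (the sizes here are arbitrary):

```julia
# Reverse PyTorch-style (N, C, Z, Y, X) to this package's (X, Y, Z, C, N).
torch_style = rand(Float32, 1, 1, 40, 64, 48)            # (N, C, Z, Y, X)
julia_style = permutedims(torch_style, (5, 4, 3, 2, 1))  # (X, Y, Z, C, N)

size(julia_style)                                         # (48, 64, 40, 1, 1)
julia_style[5, 7, 3, 1, 1] == torch_style[1, 1, 3, 7, 5]  # true: same voxel
```

Note that permutedims copies the data. If you instead wrap the raw memory of a C-order (row-major) tensor as a Julia array with the dimensions listed in reverse, it is already in (X, Y, Z, C, N) order and no permutation is needed.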

API Reference

Core Types

# Affine Registration
AffineRegistration{T}(;
    is_3d::Bool=true,           # 3D (true) or 2D (false)
    scales::Tuple=(4, 2),       # Multi-resolution pyramid scales
    iterations::Tuple=(500, 100), # Iterations per scale
    learning_rate::T=0.01,      # Optimizer learning rate
    with_translation::Bool=true, # Enable translation
    with_rotation::Bool=true,    # Enable rotation
    with_zoom::Bool=true,        # Enable zoom/scale
    with_shear::Bool=false,      # Enable shear
    align_corners::Bool=true,    # Grid sampling mode
    padding_mode::Symbol=:border, # :zeros or :border
    array_type::Type=Array       # Array type for GPU
)

# SyN (Diffeomorphic) Registration
SyNRegistration{T}(;
    scales::Tuple=(4, 2, 1),    # Multi-resolution scales
    iterations::Tuple=(30, 30, 10), # Iterations per scale
    learning_rate::T=0.01,      # Optimizer learning rate
    sigma_flow::T=1.0,          # Flow smoothing sigma
    sigma_img::T=0.0,           # Image smoothing sigma
    lambda_::T=1.0,             # Regularization weight
    time_steps::Int=7,          # Scaling-and-squaring steps
    array_type::Type=Array      # Array type for GPU
)
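time_steps sets how many squarings scaling-and-squaring uses to turn a stationary velocity field into a displacement: divide the velocity by 2^T, then compose the resulting small deformation with itself T times. A 1D sketch in voxel units with linear interpolation follows; it illustrates the technique and is not the package's diffeomorphic_transform.

```julia
# 1D scaling-and-squaring in voxel units (illustrative; border-clamped linear interpolation).
# phi(x) = x + u(x); start from u = v / 2^T, then square T times: phi <- phi ∘ phi.
function interp(u::AbstractVector, x::Real)
    x = clamp(x, 1, length(u))
    i = min(floor(Int, x), length(u) - 1)
    t = x - i
    return (1 - t) * u[i] + t * u[i + 1]
end

function scaling_and_squaring(v::AbstractVector; time_steps::Int=7)
    u = v ./ 2^time_steps                                   # scaling: small deformation
    for _ in 1:time_steps
        # squaring: (phi ∘ phi)(x) = x + u(x) + u(x + u(x))
        u = [u[i] + interp(u, i + u[i]) for i in eachindex(u)]
    end
    return u
end

v = fill(0.8, 32)                     # constant velocity field (voxels)
u = scaling_and_squaring(v)           # ≈ 0.8 everywhere: a pure translation
```

With a constant velocity the result is a pure translation: 0.8/128 doubled seven times gives back a displacement of 0.8 voxels everywhere.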

Functions

# Registration
register(reg, moving, static; loss_fn=mse_loss, verbose=true, final_interpolation=:bilinear)
fit!(reg, moving, static; loss_fn=mse_loss, verbose=true)
transform(reg, image; direction=:forward, interpolation=:bilinear)  # Apply learned transform
reset!(reg)  # Reset parameters to identity

# Affine-specific
get_affine(reg)  # Get current affine matrix
affine_transform(image, theta; interpolation=:bilinear)  # Apply explicit affine matrix
compose_affine(translation, rotation, zoom, shear)  # Build affine matrix
affine_grid(theta, size)  # Generate sampling grid from affine

# Loss Functions
mse_loss(pred, target)  # Mean Squared Error
dice_loss(pred, target)  # 1 - Dice coefficient
dice_score(pred, target) # Dice coefficient
ncc_loss(pred, target; kernel_size=9)  # Normalized Cross Correlation

# Low-level Operations
grid_sample(input, grid; padding_mode=:zeros, align_corners=true, interpolation=:bilinear)
spatial_transform(image, displacement; interpolation=:bilinear)  # Warp with displacement field
diffeomorphic_transform(velocity; time_steps=7)  # Scaling-and-squaring

# Interpolation Modes
# :bilinear/:trilinear - Smooth gradients, creates new values (default)
# :nearest - HU-preserving, returns exact input values, zero gradients

# Clinical Registration (anisotropic voxels, contrast mismatch)
PhysicalImage(data; spacing=(1,1,1), origin=(0,0,0))  # Wrap array with physical metadata
register_clinical(moving, static; registration_resolution=2.0, loss_fn=mi_loss, preserve_hu=true)
transform_clinical(result, image; interpolation=:nearest)  # Apply learned transform
transform_clinical_inverse(result, image)  # Apply inverse transform

# Physical Image Operations
spatial_size(img)      # Returns (X, Y, Z) size
spatial_spacing(img)   # Returns (sx, sy, sz) spacing in mm
resample(img, target_spacing)  # Resample to new spacing

# Multi-modal Loss Functions
mi_loss(pred, target; bins=64)   # Mutual Information (for contrast mismatch)
nmi_loss(pred, target; bins=64)  # Normalized MI (more robust)

Dependencies

License

MIT
