# 4D CT Reconstruction Using the `RegisterTimeSeriesImages` Class

This notebook demonstrates the use of the `RegisterTimeSeriesImages` class to register a time series of CT images to a common reference frame.

This is a refactored version of `reconstruct_4d_ct.ipynb` that uses the new class-based approach.


In [None]:
import os

import itk
import numpy as np

from physiomotion4d import RegisterTimeSeriesImages, TransformTools


## Load Data and Set Parameters

Set `quick_run = True` in the configuration cell below for a fast test on a subsample of roughly five frames, or `quick_run = False` to register every frame.


In [None]:
# Load image files: every "slice_*.mha" frame in the data directory, in sorted order
data_dir = os.path.join("..", "..", "data", "Slicer-Heart-CT")
files = [
    os.path.join(data_dir, f)
    for f in sorted(os.listdir(data_dir))
    if f.endswith(".mha") and f.startswith("slice_")
]

# Fail here with a clear message instead of later with a cryptic slicing/index error
# in the configuration cell.
if not files:
    raise FileNotFoundError(f"No slice_*.mha files found in {data_dir!r}")

print(f"Found {len(files)} slice files")


In [None]:
# Configuration
quick_run = True  # Set to False to register every frame (full processing)

# Registration methods to run; the same list is used in both modes.
registration_methods = ["ants", "icon", "ants_icon"]

# Select files and parameters based on mode.
# NOTE: quick mode subsamples `files` in place, so re-run the file-loading cell
# before re-running this one.
if quick_run:
    print("=== QUICK RUN MODE ===")
    total_num_files = len(files)
    target_num_files = 5  # approximate number of frames to keep
    # Clamp to 1: with fewer than `target_num_files` inputs the integer division is 0,
    # and a zero slice/range step raises ValueError.
    file_step = max(1, total_num_files // target_num_files)
    files = files[0:total_num_files:file_step]
    # Positions of the kept frames within the full sorted file list (used in output names).
    files_indx = list(range(0, total_num_files, file_step))
    num_files = len(files)
    reference_image_num = num_files // 2

    # One entry per method, in the same order as `registration_methods`:
    # ants, icon, and ants_icon ([ANTs setting, ICON setting]).
    number_of_iterations_list = [[8, 4, 1], 5, [[8, 4, 1], 5]]
else:
    print("=== FULL RUN MODE ===")
    num_files = len(files)
    files_indx = list(range(num_files))
    reference_image_num = 7

    # One entry per method, in the same order as `registration_methods`.
    number_of_iterations_list = [
        [30, 15, 7, 3],  # ants
        20,  # icon
        [[30, 15, 7, 3], 20],  # ants_icon: [ANTs setting, ICON setting]
    ]

if not 0 <= reference_image_num < num_files:
    raise ValueError(
        f"reference_image_num={reference_image_num} is out of range for {num_files} files"
    )
if len(number_of_iterations_list) != len(registration_methods):
    raise ValueError(
        "number_of_iterations_list must have one entry per registration method"
    )

# Common parameters.
# Take the reference file straight from `files` so the fixed image is exactly the frame
# later passed as `reference_frame` (images[reference_image_num]), instead of
# re-deriving a file name from its index.
reference_image_file = files[reference_image_num]
register_start_to_reference = False
portion_of_prior_transform_to_init_next_transform = 0.0

print(f"Number of files: {num_files}")
print(f"Reference image: {os.path.basename(reference_image_file)}")
print(f"Registration methods: {registration_methods}")
print(f"Number of iterations: {number_of_iterations_list}")


## Load Images


In [None]:
# Load the fixed/reference image with float pixels
fixed_image = itk.imread(reference_image_file, pixel_type=itk.F)
print(f"Fixed image size: {itk.size(fixed_image)}")
print(f"Fixed image spacing: {itk.spacing(fixed_image)}")

# Save a copy of the fixed image next to the registration outputs
os.makedirs("results", exist_ok=True)
out_file = os.path.join("results", "slice_fixed.mha")
itk.imwrite(fixed_image, out_file)
print(f"Saved fixed image to: {out_file}")

# Load every selected frame (float pixels) as a moving image, in `files` order,
# so images[k] corresponds to files[k] and files_indx[k].
images = [itk.imread(file, pixel_type=itk.F) for file in files]
print(f"Loaded {len(images)} moving images")


In [None]:
# Echo the configured method list once, before the registration loop below runs.
print("Registration methods to run:", registration_methods)


## Perform Time Series Registration

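Register every frame to the reference frame with each configured method (`ants`, `icon`, `ants_icon`). Each method's transforms and losses are stored in `all_results` for the cells below.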


In [None]:
# Store results for each method, keyed by method name
all_results = {}

# Iteration settings are paired positionally with `registration_methods`; check the
# lengths so a missing or extra entry is not silently mis-paired or dropped.
if len(number_of_iterations_list) != len(registration_methods):
    raise ValueError(
        "number_of_iterations_list must have one entry per registration method"
    )

# Run the time-series registration once per method
for registration_method, number_of_iterations in zip(
    registration_methods, number_of_iterations_list
):
    print("\n" + "=" * 70)
    print(f"Starting registration with {registration_method.upper()}")
    print("=" * 70)
    print(f"  Starting index: {reference_image_num}")
    print(f"  Register start to reference: {register_start_to_reference}")
    print(f"  Prior transform weight: {portion_of_prior_transform_to_init_next_transform}")
    print(f"  Number of iterations: {number_of_iterations}")

    # Create and configure a registrar for this method
    registrar = RegisterTimeSeriesImages(registration_method=registration_method)
    registrar.set_modality('ct')
    registrar.set_fixed_image(fixed_image)
    registrar.set_number_of_iterations(number_of_iterations)

    # Perform registration
    result = registrar.register_time_series(
        moving_images=images,
        reference_frame=reference_image_num,
        register_reference=register_start_to_reference,
        prior_weight=portion_of_prior_transform_to_init_next_transform,
    )

    # Store results; the transforms are consumed by the saving/visualization cells.
    all_results[registration_method] = result

    # Only the losses are summarized here.
    losses = result["losses"]

    print(f"\n{registration_method.upper()} registration complete!")
    print(f"  Average loss: {np.mean(losses):.6f}")
    print(f"  Min loss: {np.min(losses):.6f}")
    print(f"  Max loss: {np.max(losses):.6f}")

print("\n" + "=" * 70)
print("All registrations complete!")
print("=" * 70)


In [None]:
# Write, for every method and frame, the images resampled in both directions
# plus the forward/inverse transforms, all under the results/ folder.
tfm_tools = TransformTools()


def _results_path(file_name):
    """Return the path of ``file_name`` inside the ``results`` output folder."""
    return os.path.join("results", file_name)


for registration_method in registration_methods:
    method_result = all_results[registration_method]
    fwd_tfms = method_result["forward_transforms"]
    inv_tfms = method_result["inverse_transforms"]

    print(f"Saving {registration_method.upper()} results...")
    for frame_pos, img_indx in enumerate(files_indx):
        print(f"  Saving slice {img_indx:03d}...")
        moving_image = images[frame_pos]
        fwd_tfm = fwd_tfms[frame_pos]
        inv_tfm = inv_tfms[frame_pos]

        # Moving frame resampled onto the fixed image grid (moving -> fixed).
        itk.imwrite(
            tfm_tools.transform_image(moving_image, fwd_tfm, fixed_image),
            _results_path(f"slice_{registration_method}_forward_{img_indx:03d}.mha"),
            compression=True,
        )

        # Fixed image resampled onto this frame's grid (fixed -> moving).
        itk.imwrite(
            tfm_tools.transform_image(fixed_image, inv_tfm, moving_image),
            _results_path(f"slice_fixed_{registration_method}_inverse_{img_indx:03d}.mha"),
            compression=True,
        )

        # Both transform directions as compressed .hdf transform files.
        itk.transformwrite(
            fwd_tfm,
            _results_path(f"slice_{registration_method}_forward_{img_indx:03d}.hdf"),
            compression=True,
        )
        itk.transformwrite(
            inv_tfm,
            _results_path(f"slice_{registration_method}_inverse_{img_indx:03d}.hdf"),
            compression=True,
        )

print("✓ Results saved to results/ directory")


In [None]:
# Report per-frame registration losses and summary statistics for every method
for method_name in registration_methods:
    method_losses = all_results[method_name]["losses"]
    label = method_name.upper()

    print(f"{label} Registration Losses:")
    print("=" * 50)
    for frame_pos, img_indx in enumerate(files_indx):
        # Mark the frame that served as the registration reference.
        status = "(reference)" if frame_pos == reference_image_num else ""
        print(f"  Slice {img_indx:03d}: {method_losses[frame_pos]:.6f} {status}")

    print(f"{label} Statistics:")
    summary_stats = (
        ("Mean", np.mean),
        ("Std", np.std),
        ("Min", np.min),
        ("Max", np.max),
    )
    for stat_name, stat_fn in summary_stats:
        print(f"  {stat_name} loss: {stat_fn(method_losses):.6f}")


## Visualize Registration Quality


In [None]:
# Generate grid image for visualization
grid_image = tfm_tools.generate_grid_image(fixed_image, 30, 1)

for registration_method in registration_methods:
    result = all_results[registration_method]
    inverse_transforms = result["inverse_transforms"]

    print(f"Generating {registration_method.upper()} grid visualizations...")
    for i, img_indx in enumerate(files_indx):
        print(f"  Generating grid for slice {img_indx:03d}...")

        # Transform grid with inverse transform (FM)
        inverse_grid_image = tfm_tools.transform_image(
            grid_image,
            inverse_transforms[i],
            fixed_image,
        )
        itk.imwrite(
            inverse_grid_image,
            os.path.join("results", f"slice_fixed_{registration_method}_inverse_grid_{img_indx:03d}.mha"),
            compression=True
        )

        # Save displacement field as image
        inverse_transform_image = tfm_tools.convert_transform_to_displacement_field(
            inverse_transforms[i],
            fixed_image,
            np_component_type=np.float32,
        )
        itk.imwrite(
            inverse_transform_image,
            os.path.join("results", f"slice_{registration_method}_inverse_{img_indx:03d}_field.mha"),
            compression=True
        )

print("✓ Grid visualizations saved")
