In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
from matplotlib import pyplot as plt
from utils import normalize_image,load_dcm_data,plot_coregistration_views,plot_difference_map
from scipy.ndimage import rotate, zoom, shift
from scipy.optimize import least_squares
import scipy


In [2]:
#Load the DCMS, sorted and stacked

dcm_path='RadCTTACEomics_1193-20250418T131346Z-001/RadCTTACEomics_1193/10_AP_Ax2.50mm'
dcms_target=load_dcm_data(dcm_path)
dcms_target.sort(key = (lambda x: float(x.SliceLocation)))
pixelarray_target = np.stack([x.pixel_array for x in dcms_target], axis=0)
pixelarray_target=normalize_image(pixelarray_target)

dcm_path='RadCTTACEomics_1193-20250418T131346Z-001/RadCTTACEomics_1193/20_PP_Ax2.50mm'
dcms_moving=load_dcm_data(dcm_path)
dcms_moving.sort(key = (lambda x: float(x.SliceLocation)))
pixelarray_moving = np.stack([x.pixel_array for x in dcms_moving], axis=0)
pixelarray_moving=normalize_image(pixelarray_moving)
pixelarray_target.shape, pixelarray_moving.shape

((188, 512, 512), (188, 512, 512))

For image coregistration, it is essential that both volumes have matching dimensions. One common solution is to crop the larger volume to match the size of the smaller one.

In [3]:
# Crop the target (reference) to match the moving shape in depth (Z-axis)
depth_moving = pixelarray_moving.shape[0]
pixelarray_target = pixelarray_target[:depth_moving, :, :]
pixelarray_target.shape, pixelarray_moving.shape

((188, 512, 512), (188, 512, 512))

## PLOT
From just plotting the image , is hard to see a difference between the moving and fixed targed.

In [None]:
plot_coregistration_views(pixelarray_target, pixelarray_moving)

A more effective approach involves analyzing the voxel-wise differences between the volumes. The results indicate that the most significant errors are concentrated along the borders, suggesting that the primary misalignments occur in these peripheral regions.

In [None]:
plot_difference_map(pixelarray_target,pixelarray_moving)


## Apply Corregistration
In this step, we apply a transformation that includes translation, rotation, and scaling to align the moving volume to the target. To estimate the optimal transformation parameters, we use the least squares optimization method, with the Root Mean Squared Error (RMSE) as the objective (error) function.


### Utils
Frist we start by defining a error function to minimize in which in this case it will be Root mean square deviation. In this case we select RMSE We use RMSE (Root Mean Squared Error) as the loss function because it provides a simple and effective measure of voxel-wise similarity between the fixed and moving volumes. It quantifies the average squared difference in intensity values, penalizing larger mismatches more heavily, which helps guide the optimization toward better alignment. Additionally, RMSE is fully differentiable, making it suitable for use in gradient-based optimization methods like least squares.

In [4]:

def rmse_loss(fixed, moving):
    """Compute Root Mean Squared Error between two volumes."""
    if fixed.shape != moving.shape:
        raise ValueError("Shape mismatch: fixed and moving volumes must have the same shape.")
    return np.sqrt(np.sum((fixed - moving) ** 2))

initial_error= rmse_loss(pixelarray_target,pixelarray_moving)
print("Initial error ", initial_error)

Initial error  411.3909542329376


In [5]:
def mse(ref_img, inp_img):
    """
    Calculates the MSE between two images.
    """
    return np.mean((ref_img-inp_img)**2)


We also define some util function to apply transformation and to get an initial parameter estimate

In [9]:

def find_volume_centroid(volume: np.ndarray):
    """Compute centroid of non-zero voxels in a 3D volume."""
    nonzero = np.argwhere(volume > 0)
    if nonzero.size == 0:
        raise ValueError("Volume is empty or has no non-zero voxels.")
    return np.mean(nonzero, axis=0)

def rotate_volume_all_axes(volume, angles, order=1):
    """Apply sequential 3D rotation around each principal axis."""
    axes = [(1, 2), (0, 2), (0, 1)]  # axial, coronal, sagittal
    for angle, axis in zip(angles, axes):
        volume = rotate(volume, angle=angle, axes=axis, reshape=False, order=order)
    return volume

def apply_transform(volume, translation, rotation_angles, zoom_factors):
    """Apply shift, rotation, and zoom to a volume."""
    transformed = shift(volume, translation, order=1)
    transformed = rotate_volume_all_axes(transformed, rotation_angles, order=1)
    # transformed = zoom(transformed, zoom_factors, order=1)
    return transformed
def apply_transform_rigid(volume, translation, rotation_angles):
    """Apply shift, rotation, and zoom to a volume."""
    transformed = shift(volume, translation, order=1)
    # transformed = np.roll(volume, np.array(translation).astype(int), axis=(0, 1, 2))
    transformed = rotate_volume_all_axes(transformed, rotation_angles, order=1)
    # transformed = zoom(transformed, zoom_factors, order=1)
    return transformed


We firts get the center of volume to get an estimate on the translation, rotation and scale will be empty at the beging

In [None]:
centroid_moving = find_volume_centroid(pixelarray_moving)
centroid_target = find_volume_centroid(pixelarray_target)
initial_translation = centroid_moving - centroid_target  # difference between centers

# Initial guess: [dx, dy, dz, rx, ry, rz, zx, zy, zz]
initial_params = np.concatenate([
    initial_translation,              # translation estimate
    [ np.pi, 0, 0   ]      # In radians                 # no initial rotation
    # [1.0, 1.0, 1.0]                   # no scaling
])
def function_to_minimize(params):
    translation = params[0:3]
    rotation_angles = params[3:6]*(180/np.pi) 
    # zoom_factors = params[6:9]
    transformed = apply_transform_rigid(pixelarray_moving, translation, rotation_angles)
    return mse(pixelarray_target, transformed)

result = least_squares(
    function_to_minimize,
    x0=initial_params,
    verbose=2
)
result.x

   Iteration     Total nfev        Cost      Cost reduction    Step norm     Optimality   
       0              1         1.7408e-03                                    4.09e+02    
       1              2         1.2183e-03      5.22e-04       4.22e+00       4.91e-04    
       2              3         8.0803e-04      4.10e-04       1.05e+00       5.49e-04    
       3              5         3.2252e-04      4.86e-04       5.27e-01       6.07e-04    
       4              7         1.0787e-04      2.15e-04       2.63e-01       4.85e-04    
       5              9         2.6830e-05      8.10e-05       1.32e-01       2.56e-04    
       6             11         1.6640e-05      1.02e-05       6.59e-02       1.50e-04    
       7             13         1.4332e-05      2.31e-06       1.65e-02       7.58e-05    
       8             16         1.4265e-05      6.72e-08       2.06e-03       3.06e-05    
       9             18         1.4252e-05      1.29e-08       5.15e-04       1.69e-05    

We define the fuction to minimize with least square error

In [None]:
def function_to_minimize(params):
    translation = params[0:3]
    rotation_angles = params[3:6]
    zoom_factors = params[6:9]
    transformed = apply_transform(pixelarray_moving, translation, rotation_angles, zoom_factors)
    return rmse_loss(pixelarray_target, transformed)

result = least_squares(function_to_minimize, x0=initial_params, verbose=1)


In [None]:
final_transform = apply_transform(pixelarray_moving, 
                                      translation=result.x[0:3],
                                      rotation_angles=result.x[3:6],
                                      zoom_factors=result.x[6:9])

In [None]:
plot_difference_map(pixelarray_target,final_transform)