# Moment-based alignment of two stacks
This notebook performs a direct moments-based rigid alignment between a fixed and a moving stack using only centroid and second-order statistics, then saves the aligned volume and transform without any extra preprocessing steps.

## 1. Set Up Environment and Parameters
Import required libraries and define the file paths for the fixed and moving stacks along with the output directory.

In [1]:
from pathlib import Path
import json
import numpy as np
import SimpleITK as sitk
from scipy import ndimage
from skimage.metrics import structural_similarity as ssim

# Update these paths to point to the desired input stacks and output directory
FIXED_PATH = Path('/mnt/nas_jlarsch/johannes/testIn/average_2p_noRot_flip_8b.nrrd')
MOVING_PATH = Path('/mnt/nas_jlarsch/johannes/testIn/L395_f10_anatomy_00002_8b.nrrd')
OUTPUT_DIR = Path('/mnt/f/Johannes/testOutput/fireants/momentsMinimal').resolve()
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f'Fixed stack:  {FIXED_PATH}')
print(f'Moving stack: {MOVING_PATH}')
print(f'Outputs will be stored in: {OUTPUT_DIR}')

Fixed stack:  /mnt/nas_jlarsch/johannes/testIn/average_2p_noRot_flip_8b.nrrd
Moving stack: /mnt/nas_jlarsch/johannes/testIn/L395_f10_anatomy_00002_8b.nrrd
Outputs will be stored in: /mnt/f/Johannes/testOutput/fireants/momentsMinimal


## 2. Load Input Stacks
Read both stacks into SimpleITK images and NumPy arrays while confirming orientation, spacing, and value ranges.

In [2]:
def load_stack(path: Path):
    if not path.exists():
        raise FileNotFoundError(f'Missing stack: {path}')
    image = sitk.ReadImage(str(path))
    array = sitk.GetArrayFromImage(image).astype(np.float32)
    return image, array

fixed_sitk, fixed_np = load_stack(FIXED_PATH)
moving_sitk, moving_np = load_stack(MOVING_PATH)

print('Fixed stack metadata:')
print('  size:', fixed_sitk.GetSize())
print('  spacing:', fixed_sitk.GetSpacing())
print('  direction:', fixed_sitk.GetDirection())
print('  origin:', fixed_sitk.GetOrigin())
print('  dtype:', fixed_np.dtype, 'value range:', (float(fixed_np.min()), float(fixed_np.max())))
print()
print('Moving stack metadata:')
print('  size:', moving_sitk.GetSize())
print('  spacing:', moving_sitk.GetSpacing())
print('  direction:', moving_sitk.GetDirection())
print('  origin:', moving_sitk.GetOrigin())
print('  dtype:', moving_np.dtype, 'value range:', (float(moving_np.min()), float(moving_np.max())))

if fixed_np.shape != moving_np.shape:
    print('Warning: stacks have different shapes:', fixed_np.shape, 'vs', moving_np.shape)

Fixed stack metadata:
  size: (512, 512, 183)
  spacing: (1.0484, 1.0484, 1.0)
  direction: (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
  origin: (0.0, 0.0, 0.0)
  dtype: float32 value range: (0.0, 255.0)

Moving stack metadata:
  size: (512, 512, 206)
  spacing: (1.048361882567406, 1.048361882567406, 1.0)
  direction: (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)
  origin: (0.0, 0.0, 0.0)
  dtype: float32 value range: (0.0, 255.0)


## 3. Compute Image Moments
Compute spatial moments $m_{pq}$ and central moments $\mu_{pq}$ to extract centroids and covariance structures for each stack.

In [None]:
def compute_moments(image: sitk.Image):
    moments = sitk.ImageMoments(image)
    centroid = np.array(moments.GetCenterOfGravity(), dtype=np.float64)
    principal_axes = np.array(moments.GetPrincipalAxes(), dtype=np.float64).reshape(3, 3)
    principal_moments = np.array(moments.GetPrincipalMoments(), dtype=np.float64)
    if np.linalg.det(principal_axes) < 0:
        principal_axes[:, -1] *= -1
    return {
        'centroid_physical': centroid,
        'principal_axes': principal_axes,
        'principal_moments': principal_moments,
    }

fixed_moments = compute_moments(fixed_sitk)
moving_moments = compute_moments(moving_sitk)

print('Fixed centroid (mm):', fixed_moments['centroid_physical'])
print('Moving centroid (mm):', moving_moments['centroid_physical'])
print('Fixed principal moments:', fixed_moments['principal_moments'])
print('Moving principal moments:', moving_moments['principal_moments'])

AttributeError: module 'SimpleITK' has no attribute 'MomentsImageFilter'

## 4. Determine Alignment Transformations
Derive rotation and translation parameters using the first- and second-order moments to build a rigid transform that maps the moving stack to the fixed stack.

In [None]:
fixed_axes = fixed_moments['principal_axes']
moving_axes = moving_moments['principal_axes']
rotation = fixed_axes @ moving_axes.T
if np.linalg.det(rotation) < 0:
    rotation[:, -1] *= -1
translation = fixed_moments['centroid_physical'] - rotation @ moving_moments['centroid_physical']

alignment_parameters = {
    'rotation_matrix': rotation.tolist(),
    'translation_mm': translation.tolist(),
    'fixed_principal_moments': fixed_moments['principal_moments'].tolist(),
    'moving_principal_moments': moving_moments['principal_moments'].tolist(),
}

print('Rotation matrix (moving -> fixed):\n', rotation)
print('Translation (mm):', translation)
print('Determinant of rotation:', np.linalg.det(rotation))

## 5. Apply Moment-Based Alignment
Convert the derived parameters into a rigid affine transform and resample the moving stack to the fixed stack grid.

In [None]:
affine_transform = sitk.AffineTransform(3)
affine_transform.SetMatrix(tuple(rotation.astype(np.float64).ravel()))
affine_transform.SetTranslation(tuple(translation.astype(np.float64)))

resampler = sitk.ResampleImageFilter()
resampler.SetReferenceImage(fixed_sitk)
resampler.SetInterpolator(sitk.sitkLinear)
resampler.SetDefaultPixelValue(float(moving_np.min()))
resampler.SetTransform(affine_transform)
aligned_sitk = resampler.Execute(moving_sitk)
aligned_np = sitk.GetArrayFromImage(aligned_sitk).astype(np.float32)

print('Aligned stack shape:', aligned_np.shape)

## 6. Validate Alignment Quality
Compare the fixed and aligned stacks using mean squared error (MSE) and structural similarity (SSIM) on representative slices.

In [None]:
if fixed_np.shape != aligned_np.shape:
    raise ValueError('Aligned stack shape mismatch; cannot compute metrics.')

mse_value = float(np.mean((fixed_np - aligned_np) ** 2))
mid_index = fixed_np.shape[0] // 2
fixed_slice = fixed_np[mid_index]
aligned_slice = aligned_np[mid_index]
slice_range = float(max(fixed_slice.max(), aligned_slice.max()) - min(fixed_slice.min(), aligned_slice.min()))
if slice_range == 0:
    slice_range = 1.0
ssim_value = float(ssim(fixed_slice, aligned_slice, data_range=slice_range))

print(f'MSE (full volume): {mse_value:.6f}')
print(f'SSIM (mid slice): {ssim_value:.4f}')

## 7. Persist Aligned Stacks
Write the aligned stack and transform to disk and log the parameters used for reproducibility.

In [None]:
aligned_volume_path = OUTPUT_DIR / 'moving_moments_aligned.nrrd'
transform_path = OUTPUT_DIR / 'moments_alignment_affine.tfm'
summary_path = OUTPUT_DIR / 'moments_alignment_summary.json'

sitk.WriteImage(aligned_sitk, str(aligned_volume_path))
sitk.WriteTransform(affine_transform, str(transform_path))

summary_payload = {
    'fixed_path': str(FIXED_PATH),
    'moving_path': str(MOVING_PATH),
    'aligned_volume_path': str(aligned_volume_path),
    'transform_path': str(transform_path),
    'rotation_matrix': alignment_parameters['rotation_matrix'],
    'translation_mm': alignment_parameters['translation_mm'],
    'fixed_centroid_mm': fixed_moments['centroid_physical'].tolist(),
    'moving_centroid_mm': moving_moments['centroid_physical'].tolist(),
    'fixed_principal_moments': alignment_parameters['fixed_principal_moments'],
    'moving_principal_moments': alignment_parameters['moving_principal_moments'],
    'mse_full_volume': mse_value,
    'ssim_mid_slice': ssim_value,
}
with summary_path.open('w', encoding='utf-8') as f:
    json.dump(summary_payload, f, indent=2)

print(f'Saved aligned volume to {aligned_volume_path}')
print(f'Saved affine transform to {transform_path}')
print(f'Saved alignment summary to {summary_path}')