In [1]:

# Cell 1: Environment Setup and Dependencies
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import pydicom
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.spatial.transform import Rotation as R_scipy

try:
    import SimpleITK as sitk
    print(f"SimpleITK version: {sitk.Version.VersionString()}")
except ImportError as e:
    print(f"Error: SimpleITK not found. Install with: pip install SimpleITK")
    sys.exit(1)

print(f"Python: {sys.version.split()[0]}, NumPy: {np.__version__}")

SimpleITK version: 2.5.2
Python: 3.13.7, NumPy: 2.3.3


In [2]:
# Cell 2: Configuration and Directory Setup
BASE_DIR = r'c:\Users\zhaoanr\Desktop\_anon'
MR_DIR = os.path.join(BASE_DIR, 'MR')
CT_DIR = os.path.join(BASE_DIR, 'CT')

OUTPUT_DIR = r'c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Expected subdirectories
MR_IMAGES_DIR = os.path.join(MR_DIR, 'MR')
MR_RTSTRUCT_DIR = os.path.join(MR_DIR, 'RTSTRUCT')
CT_IMAGES_DIR = os.path.join(CT_DIR, 'CT')
CT_RTSTRUCT_DIR = os.path.join(CT_DIR, 'RTSTRUCT')

print("Directory structure:")
for label, path in [("MR Images", MR_IMAGES_DIR), 
                     ("MR RTStruct", MR_RTSTRUCT_DIR),
                     ("CT Images", CT_IMAGES_DIR), 
                     ("CT RTStruct", CT_RTSTRUCT_DIR)]:
    exists = "✓" if os.path.exists(path) else "✗"
    print(f"{exists} {label}: {path}")

print(f"\nOutput: {OUTPUT_DIR}")

Directory structure:
✓ MR Images: c:\Users\zhaoanr\Desktop\_anon\MR\MR
✓ MR RTStruct: c:\Users\zhaoanr\Desktop\_anon\MR\RTSTRUCT
✓ CT Images: c:\Users\zhaoanr\Desktop\_anon\CT\CT
✓ CT RTStruct: c:\Users\zhaoanr\Desktop\_anon\CT\RTSTRUCT

Output: c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output


In [3]:
# Cell 3: Load Image Series
def load_dicom_series_sitk(directory, series_description="Series"):
    """Load DICOM series from directory using SimpleITK."""
    print(f"\nLoading {series_description} from {os.path.basename(directory)}...")
    
    # Get all DICOM files
    dicom_files = [os.path.join(directory, f) for f in os.listdir(directory) 
                   if f.endswith('.dcm') or not '.' in f]
    
    if not dicom_files:
        raise ValueError(f"No DICOM files found in {directory}")
    
    # Sort by Z position
    sorted_files = []
    for filepath in dicom_files:
        try:
            header = pydicom.dcmread(filepath, stop_before_pixels=True, force=True)
            if 'ImagePositionPatient' in header:
                z_pos = float(header.ImagePositionPatient[2])
                sorted_files.append((filepath, z_pos))
        except:
            continue
    
    sorted_files.sort(key=lambda x: x[1])
    file_paths = [f[0] for f in sorted_files]
    
    print(f"  Found {len(file_paths)} slices")
    
    # Read with SimpleITK
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(file_paths)
    reader.MetaDataDictionaryArrayUpdateOn()
    reader.LoadPrivateTagsOn()
    
    image = reader.Execute()
    
    # Print metadata
    size = image.GetSize()
    spacing = image.GetSpacing()
    origin = image.GetOrigin()
    
    print(f"  Size: {size}")
    print(f"  Spacing: {np.round(spacing, 2)} mm")
    print(f"  Origin: {np.round(origin, 1)} mm")
    
    return image

print("="*80)
print("LOADING IMAGE SERIES")
print("="*80)

mr_image_sitk = load_dicom_series_sitk(MR_IMAGES_DIR, "MR")
ct_image_sitk = load_dicom_series_sitk(CT_IMAGES_DIR, "CT")

# Convert to float32
mr_image_sitk = sitk.Cast(mr_image_sitk, sitk.sitkFloat32)
ct_image_sitk = sitk.Cast(ct_image_sitk, sitk.sitkFloat32)

print(f"\nMR image type: {mr_image_sitk.GetPixelIDTypeAsString()}")
print(f"CT image type: {ct_image_sitk.GetPixelIDTypeAsString()}")

LOADING IMAGE SERIES

Loading MR from MR...
  Found 176 slices
  Size: (768, 768, 176)
  Spacing: [0.49 0.49 1.5 ] mm
  Origin: [-180.4 -199.3  -32.9] mm

Loading CT from CT...
  Found 334 slices
  Size: (512, 512, 334)
  Spacing: [0.98 0.98 1.5 ] mm
  Origin: [-249.5 -411.5 -152.9] mm

MR image type: 32-bit float
CT image type: 32-bit float


In [4]:
# Cell 4: Stage 1 - Mutual Information Registration (CORRECTED DIRECTION)
print("\n" + "="*80)
print("STAGE 1: MUTUAL INFORMATION REGISTRATION")
print("="*80)

# Registration callback to monitor progress
iteration_data = {'iteration': 0, 'metric_values': [], 'multires_iterations': []}

def command_iteration(method):
    if method.GetOptimizerIteration() == 0:
        iteration_data['multires_iterations'].append(iteration_data['iteration'])
    iteration_data['iteration'] += 1
    iteration_data['metric_values'].append(method.GetMetricValue())
    
    if iteration_data['iteration'] % 20 == 0:
        print(f"  Iteration {iteration_data['iteration']:4d}: Metric = {method.GetMetricValue():.4f}")

# Setup registration
registration_method = sitk.ImageRegistrationMethod()

# Mutual Information metric
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.1)

# Interpolator
registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer - Regular Step Gradient Descent
registration_method.SetOptimizerAsRegularStepGradientDescent(
    learningRate=1.0,
    minStep=0.001,
    numberOfIterations=100,
    gradientMagnitudeTolerance=1e-6
)
registration_method.SetOptimizerScalesFromIndexShift()

# Multi-resolution framework
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Initial transform - CenteredTransformInitializer with MOMENTS
# CORRECTED: CT is fixed, MR is moving (register MR -> CT)
initial_transform = sitk.CenteredTransformInitializer(
    ct_image_sitk,  # Fixed image (target space)
    mr_image_sitk,  # Moving image
    sitk.VersorRigid3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS
)

registration_method.SetInitialTransform(initial_transform, inPlace=True)

# Connect callback
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))

print("\nStarting MI registration (MR -> CT)...")
print("Initial transform (MOMENTS initialization):")
print(f"  Translation: {initial_transform.GetTranslation()}")

# Execute registration: register MR to CT space (directly, no inversion needed)
mi_transform = registration_method.Execute(ct_image_sitk, mr_image_sitk)

print(f"\n✓ MI Registration complete")
print(f"  Final metric value: {registration_method.GetMetricValue():.4f}")
print(f"  Optimizer stop condition: {registration_method.GetOptimizerStopConditionDescription()}")

# Get transformation parameters
mi_rotation_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
mi_translation = np.array(mi_transform.GetTranslation())

print(f"\nMI Transform Parameters (MR -> CT):")
print(f"  Translation: [{mi_translation[0]:.2f}, {mi_translation[1]:.2f}, {mi_translation[2]:.2f}] mm")

rotation_obj = R_scipy.from_matrix(mi_rotation_matrix)
euler_angles = rotation_obj.as_euler('xyz', degrees=True)
print(f"  Rotation: X={euler_angles[0]:.2f}°, Y={euler_angles[1]:.2f}°, Z={euler_angles[2]:.2f}°")

# Apply MI transform to MR image
mr_mi_registered = sitk.Resample(
    mr_image_sitk,
    ct_image_sitk,
    mi_transform,
    sitk.sitkLinear,
    0.0,
    mr_image_sitk.GetPixelID()
)

# Save intermediate result
mi_output = os.path.join(OUTPUT_DIR, 'mr_mi_registered.nii.gz')
sitk.WriteImage(mr_mi_registered, mi_output)
print(f"\n✓ Saved MI-registered MR: {mi_output}")

# Save MI transform
mi_transform_file = os.path.join(OUTPUT_DIR, 'mi_transform.tfm')
sitk.WriteTransform(mi_transform, mi_transform_file)
print(f"✓ Saved MI transform: {mi_transform_file}")


STAGE 1: MUTUAL INFORMATION REGISTRATION

Starting MI registration (MR -> CT)...
Initial transform (MOMENTS initialization):
  Translation: (2.9270443295896866, 155.4231429940117, 16.622463266975075)
  Iteration   20: Metric = -0.6374
  Iteration   40: Metric = -0.8116
  Iteration   60: Metric = -0.8134
  Iteration   80: Metric = -0.6206

✓ MI Registration complete
  Final metric value: -0.6207
  Optimizer stop condition: RegularStepGradientDescentOptimizerv4: Step too small after 13 iterations. Current step (0.000976562) is less than minimum step (0.001).

MI Transform Parameters (MR -> CT):
  Translation: [10.90, 172.54, 48.12] mm
  Rotation: X=0.06°, Y=-0.76°, Z=-0.65°

✓ Saved MI-registered MR: c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output\mr_mi_registered.nii.gz
✓ Saved MI transform: c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output\mi_transform.tfm


In [5]:
# Cell 5: Extract Prostate/Rectum Structures from RTStructs
def extract_prostate_structures(rtstruct_path, modality_name="Image"):
    """Extract prostate/rectum structures from RTStruct file."""
    print(f"\nExtracting {modality_name} prostate structures from:")
    print(f"  {os.path.basename(rtstruct_path)}")
    
    ds = pydicom.dcmread(rtstruct_path)
    
    if not hasattr(ds, 'StructureSetROISequence') or not hasattr(ds, 'ROIContourSequence'):
        print(f"  ✗ No structure data found")
        return {}
    
    # Build ROI info
    roi_info = {}
    for roi in ds.StructureSetROISequence:
        roi_info[roi.ROINumber] = roi.ROIName if roi.ROIName else f'ROI_{roi.ROINumber}'
    
    # Extract prostate/rectum structures
    prostate_contours = {}
    
    for contour_seq in ds.ROIContourSequence:
        roi_num = contour_seq.ReferencedROINumber
        roi_name = roi_info.get(roi_num, f'ROI_{roi_num}')
        
        if not hasattr(contour_seq, 'ContourSequence'):
            continue
        
        # Collect all contour points
        all_points = []
        for contour in contour_seq.ContourSequence:
            if hasattr(contour, 'ContourData'):
                points = np.array(contour.ContourData).reshape(-1, 3)
                all_points.append(points)
        
        if not all_points:
            continue
        
        all_points = np.vstack(all_points)
        
        # Compute properties to identify prostate/rectum
        centroid = np.mean(all_points, axis=0)
        z_extent = all_points[:, 2].max() - all_points[:, 2].min()
        x_extent = all_points[:, 0].max() - all_points[:, 0].min()
        y_extent = all_points[:, 1].max() - all_points[:, 1].min()
        volume = x_extent * y_extent * z_extent
        x_pos = abs(centroid[0])
        
        # Heuristic: prostate/rectum is near midline, moderate size, moderate Z extent
        if x_pos < 40 and 30 < z_extent < 100 and volume < 800000:
            prostate_contours[roi_name] = all_points
            print(f"  ✓ Found '{roi_name}': {len(all_points)} points, centroid=[{centroid[0]:.1f}, {centroid[1]:.1f}, {centroid[2]:.1f}]")
    
    return prostate_contours

print("\n" + "="*80)
print("EXTRACTING PROSTATE/RECTUM STRUCTURES")
print("="*80)

# Find RTStruct files
mr_rtstruct_files = [f for f in os.listdir(MR_RTSTRUCT_DIR) if f.endswith('.dcm') or not '.' in f]
ct_rtstruct_files = [f for f in os.listdir(CT_RTSTRUCT_DIR) if f.endswith('.dcm') or not '.' in f]

mr_prostate = {}
ct_prostate = {}

if mr_rtstruct_files:
    mr_rtstruct_path = os.path.join(MR_RTSTRUCT_DIR, mr_rtstruct_files[0])
    mr_prostate = extract_prostate_structures(mr_rtstruct_path, "MR")

if ct_rtstruct_files:
    ct_rtstruct_path = os.path.join(CT_RTSTRUCT_DIR, ct_rtstruct_files[0])
    ct_prostate = extract_prostate_structures(ct_rtstruct_path, "CT")

print(f"\nSummary:")
print(f"  MR prostate structures: {len(mr_prostate)}")
print(f"  CT prostate structures: {len(ct_prostate)}")


EXTRACTING PROSTATE/RECTUM STRUCTURES

Extracting MR prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.3356101125.1705.1766167756.922527.dcm
  ✓ Found 'Rectum': 14160 points, centroid=[16.1, 42.9, 67.9]
  ✓ Found 'Anal_Canal': 4348 points, centroid=[14.5, 43.3, 3.2]
  ✓ Found 'Bladder': 24496 points, centroid=[11.7, -23.9, 88.6]
  ✓ Found 'Prostate': 6634 points, centroid=[14.0, 6.6, 42.8]

Extracting CT prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.548144899.484.1754913885.567543.dcm
  ✓ Found 'Rectum': 7274 points, centroid=[1.0, -127.2, 12.9]
  ✓ Found 'Anal_Canal': 3136 points, centroid=[1.2, -119.4, -51.0]
  ✓ Found 'Bladder': 8742 points, centroid=[-0.2, -190.4, 30.9]
  ✓ Found 'Sigmoid_Colon': 8168 points, centroid=[7.4, -170.5, 70.0]
  ✓ Found 'Prostate': 4102 points, centroid=[0.5, -164.4, -8.4]
  ✓ Found 'Aorta': 4072 points, centroid=[33.4, -174.8, 320.5]

Summary:
  MR prostate structures: 4
  CT prostate structures: 6


In [6]:
# Cell 6: Stage 2 - Prostate Nearest Neighbor Refinement
print("\n" + "="*80)
print("STAGE 2: PROSTATE NEAREST NEIGHBOR ALIGNMENT")
print("="*80)

# Extract specific structures by name
def get_structure_by_name(contours_dict, structure_name):
    """Get structure by exact name match."""
    for name, points in contours_dict.items():
        if structure_name.lower() in name.lower():
            return name, points
    return None, None

# Get Prostate from both modalities
mr_prostate_name, mr_prostate_points = get_structure_by_name(mr_prostate, 'Prostate')
ct_prostate_name, ct_prostate_points = get_structure_by_name(ct_prostate, 'Prostate')

print(f"\nStructure matching:")
print(f"  MR Prostate: '{mr_prostate_name}' - {len(mr_prostate_points) if mr_prostate_points is not None else 0} points")
print(f"  CT Prostate: '{ct_prostate_name}' - {len(ct_prostate_points) if ct_prostate_points is not None else 0} points")

if mr_prostate_points is not None and ct_prostate_points is not None:
    from scipy.spatial import distance_matrix
    
    # Original prostate centroids
    mr_prostate_centroid_orig = np.mean(mr_prostate_points, axis=0)
    ct_prostate_centroid = np.mean(ct_prostate_points, axis=0)
    
    print(f"\nOriginal prostate centroids (before any transform):")
    print(f"  MR: [{mr_prostate_centroid_orig[0]:.2f}, {mr_prostate_centroid_orig[1]:.2f}, {mr_prostate_centroid_orig[2]:.2f}]")
    print(f"  CT: [{ct_prostate_centroid[0]:.2f}, {ct_prostate_centroid[1]:.2f}, {ct_prostate_centroid[2]:.2f}]")
    print(f"  Distance: {np.linalg.norm(ct_prostate_centroid - mr_prostate_centroid_orig):.2f} mm")
    
    # Get MI transform components
    mi_translation = np.array(mi_transform.GetTranslation())
    mi_rotation_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
    
    print(f"\nMI Translation: [{mi_translation[0]:.2f}, {mi_translation[1]:.2f}, {mi_translation[2]:.2f}] mm")
    
    # Apply MI transform to MR prostate points by SUBTRACTING
    print(f"\n" + "-"*80)
    print("APPLYING MI TRANSFORM TO MR PROSTATE")
    print("-"*80)
    
    # Transform MR points: SUBTRACT translation (moving MR into CT space)
    mr_prostate_mi = mr_prostate_points - mi_translation
    mr_centroid_mi = np.mean(mr_prostate_mi, axis=0)
    
    print(f"MR centroid after MI: [{mr_centroid_mi[0]:.2f}, {mr_centroid_mi[1]:.2f}, {mr_centroid_mi[2]:.2f}]")
    print(f"CT centroid:          [{ct_prostate_centroid[0]:.2f}, {ct_prostate_centroid[1]:.2f}, {ct_prostate_centroid[2]:.2f}]")
    
    centroid_error_after_mi = np.linalg.norm(mr_centroid_mi - ct_prostate_centroid)
    print(f"Distance after MI: {centroid_error_after_mi:.2f} mm")
    
    # Compute nearest neighbor refinement (using all prostate points)
    print(f"\n" + "-"*80)
    print("COMPUTING NEAREST NEIGHBOR REFINEMENT (All Points)")
    print("-"*80)
    
    # Find closest CT point for each MR point
    dist_matrix_temp = distance_matrix(mr_prostate_mi, ct_prostate_points)
    closest_ct_indices = np.argmin(dist_matrix_temp, axis=1)
    closest_ct_points = ct_prostate_points[closest_ct_indices]
    
    # Compute displacement vectors (from MR points to their nearest CT points)
    displacements = closest_ct_points - mr_prostate_mi
    
    # Refinement is the mean of all displacement vectors
    refinement_translation = np.mean(displacements, axis=0)
    
    print(f"  Points used: {len(mr_prostate_mi)} (100.0%)")
    print(f"\nRefinement: [{refinement_translation[0]:.2f}, {refinement_translation[1]:.2f}, {refinement_translation[2]:.2f}] mm")
    print(f"Magnitude: {np.linalg.norm(refinement_translation):.2f} mm")
    
    # Apply refinement to MR prostate points
    mr_prostate_refined = mr_prostate_mi + refinement_translation
    mr_centroid_refined = np.mean(mr_prostate_refined, axis=0)
    
    # Compute total translation: MI - Refinement (since we subtract MI and add refinement)
    total_translation = mi_translation - refinement_translation
    
    print(f"\n" + "="*80)
    print("TRANSFORMATION SUMMARY")
    print("="*80)
    print(f"MI Translation:        [{mi_translation[0]:.3f}, {mi_translation[1]:.3f}, {mi_translation[2]:.3f}] mm")
    print(f"Refinement (NN-All):   [{refinement_translation[0]:.3f}, {refinement_translation[1]:.3f}, {refinement_translation[2]:.3f}] mm")
    print(f"Total Translation:     [{total_translation[0]:.3f}, {total_translation[1]:.3f}, {total_translation[2]:.3f}] mm")
    
    # Compute alignment quality metrics
    final_centroid_error = np.linalg.norm(mr_centroid_refined - ct_prostate_centroid)
    
    dist_matrix_final = distance_matrix(mr_prostate_refined, ct_prostate_points)
    min_distances_prostate = np.min(dist_matrix_final, axis=1)
    
    print(f"\n" + "="*80)
    print("ALIGNMENT QUALITY")
    print("="*80)
    print(f"  Centroid error:       {final_centroid_error:.2f} mm")
    print(f"  Mean surface dist:    {np.mean(min_distances_prostate):.2f} mm")
    print(f"  Median surface dist:  {np.median(min_distances_prostate):.2f} mm")
    print(f"  95th percentile:      {np.percentile(min_distances_prostate, 95):.2f} mm")
    
    # Create composite transform: MI - Refinement (invert refinement)
    composite_transform = sitk.CompositeTransform(3)
    composite_transform.AddTransform(mi_transform)
    
    refinement_transform = sitk.TranslationTransform(3)
    refinement_transform.SetOffset((-refinement_translation).tolist())  # INVERTED
    composite_transform.AddTransform(refinement_transform)
    
    # Apply composite transform to full MR image
    mr_refined = sitk.Resample(
        mr_image_sitk,
        ct_image_sitk,
        composite_transform,
        sitk.sitkLinear,
        0.0,
        mr_image_sitk.GetPixelID()
    )
    
    # Build final transformation matrix
    print(f"\n" + "="*80)
    print("FINAL TRANSFORMATION MATRIX")
    print("="*80)
    
    final_matrix = np.eye(4)
    final_matrix[:3, :3] = mi_rotation_matrix
    final_matrix[:3, 3] = total_translation
    
    print(f"\nComposite Transform (4x4 Homogeneous Matrix):")
    for row in final_matrix:
        print(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")
    
    rotation_obj = R_scipy.from_matrix(mi_rotation_matrix)
    euler_angles = rotation_obj.as_euler('xyz', degrees=True)
    
    print(f"\nTranslation: [{total_translation[0]:.3f}, {total_translation[1]:.3f}, {total_translation[2]:.3f}] mm")
    print(f"Rotation: [{euler_angles[0]:.3f}, {euler_angles[1]:.3f}, {euler_angles[2]:.3f}] degrees")
    
    # Save outputs
    refined_output = os.path.join(OUTPUT_DIR, 'mr_prostate_refined.nii.gz')
    sitk.WriteImage(mr_refined, refined_output)
    print(f"\n✓ Saved refined MR: {refined_output}")
    
    composite_file = os.path.join(OUTPUT_DIR, 'prostate_composite_transform.tfm')
    sitk.WriteTransform(composite_transform, composite_file)
    print(f"✓ Saved composite transform: {composite_file}")
    
    # Set variables for Cell 8
    final_transform = composite_transform
    mr_final = mr_refined
    final_error = final_centroid_error
    min_distances = min_distances_prostate
    ct_centroid = ct_prostate_centroid
    strategy_name = "NearestNeighbor"
    
else:
    print("\n✗ Prostate structures not found in both modalities")
    print("   Using MI registration only")
    final_transform = mi_transform
    mr_final = mr_mi_registered
    mr_refined = None


STAGE 2: PROSTATE NEAREST NEIGHBOR ALIGNMENT

Structure matching:
  MR Prostate: 'Prostate' - 6634 points
  CT Prostate: 'Prostate' - 4102 points

Original prostate centroids (before any transform):
  MR: [13.95, 6.63, 42.83]
  CT: [0.53, -164.40, -8.38]
  Distance: 179.03 mm

MI Translation: [10.90, 172.54, 48.12] mm

--------------------------------------------------------------------------------
APPLYING MI TRANSFORM TO MR PROSTATE
--------------------------------------------------------------------------------
MR centroid after MI: [3.05, -165.91, -5.29]
CT centroid:          [0.53, -164.40, -8.38]
Distance after MI: 4.26 mm

--------------------------------------------------------------------------------
COMPUTING NEAREST NEIGHBOR REFINEMENT (All Points)
--------------------------------------------------------------------------------
  Points used: 6634 (100.0%)

Refinement: [-0.58, 0.28, -0.45] mm
Magnitude: 0.78 mm

TRANSFORMATION SUMMARY
MI Translation:        [10.900, 172.538

In [7]:
# Cell 8: Final Transformation Matrix Summary
print("\n" + "="*80)
print("TRANSFORMATION MATRIX SUMMARY")
print("="*80)

# MI Transform
mi_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
mi_trans = np.array(mi_transform.GetTranslation())

mi_affine = np.eye(4)
mi_affine[:3, :3] = mi_matrix
mi_affine[:3, 3] = mi_trans

print("\nSTAGE 1: MI Registration Transform (4x4 Homogeneous Matrix)")
print("-" * 80)
for row in mi_affine:
    print(f"[{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")

rotation_obj = R_scipy.from_matrix(mi_matrix)
mi_euler = rotation_obj.as_euler('xyz', degrees=True)

print(f"\n  Translation: [{mi_trans[0]:.3f}, {mi_trans[1]:.3f}, {mi_trans[2]:.3f}] mm")
print(f"  Rotation: [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees")

if 'refinement_transform' in locals() and 'strategy_name' in locals():
    print(f"\n\nSTAGE 2: Prostate Refinement ({strategy_name})")
    print("-" * 80)
    refine_trans = np.array(refinement_transform.GetOffset())
    
    # Refinement translation (applied after MI transform)
    print(f"  Refinement translation: [{refine_trans[0]:.3f}, {refine_trans[1]:.3f}, {refine_trans[2]:.3f}] mm")
    print(f"  Magnitude: {np.linalg.norm(refine_trans):.3f} mm")
    
    # Total translation is MI + refinement (both applied in sequence)
    total_trans = mi_trans + refine_trans
    print(f"\n  Total translation (MI + Refinement):")
    print(f"    [{total_trans[0]:.3f}, {total_trans[1]:.3f}, {total_trans[2]:.3f}] mm")
    print(f"    Magnitude: {np.linalg.norm(total_trans):.3f} mm")
    
    # For 3D Slicer: Use the same rotation, but adjust translation for Slicer's coordinate system
    # Slicer expects RAS coordinates, ITK uses LPS
    # The transform from ITK to Slicer requires flipping X and Y
    slicer_trans = total_trans.copy()
    slicer_trans[0] = -slicer_trans[0]  # Flip X (L->R)
    slicer_trans[1] = -slicer_trans[1]  # Flip Y (P->A)
    # Z stays the same (S->S)
    
    print(f"\n  For 3D Slicer (RAS coordinates):")
    print(f"    Total translation: [{slicer_trans[0]:.3f}, {slicer_trans[1]:.3f}, {slicer_trans[2]:.3f}] mm")
    print(f"    Rotation: [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees")
    
    # Build final affine matrix for Slicer
    final_affine_slicer = np.eye(4)
    final_affine_slicer[:3, :3] = mi_matrix
    final_affine_slicer[:3, 3] = slicer_trans
    
    print("\n\nFINAL: Transform for 3D Slicer (4x4 Matrix)")
    print("-" * 80)
    for row in final_affine_slicer:
        print(f"[{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")
    
    # Quality metrics
    if 'final_error' in locals() and 'min_distances' in locals():
        print(f"\n\nProstate Alignment Quality:")
        print(f"  Strategy used:        {strategy_name}")
        print(f"  Centroid error:       {final_error:.2f} mm")
        print(f"  Mean surface dist:    {np.mean(min_distances):.2f} mm")
        print(f"  Median surface dist:  {np.median(min_distances):.2f} mm")
        print(f"  95th percentile:      {np.percentile(min_distances, 95):.2f} mm")
    
    # Save to file
    summary_file = os.path.join(OUTPUT_DIR, 'transformation_summary.txt')
    with open(summary_file, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("MRI-CT PROSTATE REGISTRATION TRANSFORMATION SUMMARY\n")
        f.write("=" * 80 + "\n\n")
        
        f.write("STAGE 1: Mutual Information (MI) Registration\n")
        f.write("-" * 80 + "\n")
        f.write("4x4 Homogeneous Matrix:\n")
        for row in mi_affine:
            f.write(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]\n")
        f.write(f"\nTranslation: [{mi_trans[0]:.3f}, {mi_trans[1]:.3f}, {mi_trans[2]:.3f}] mm\n")
        f.write(f"Rotation (XYZ Euler): [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees\n\n")
        
        f.write(f"\nSTAGE 2: Prostate Structure-Based Refinement\n")
        f.write("-" * 80 + "\n")
        f.write(f"Strategy: {strategy_name}\n")
        f.write(f"Refinement translation: [{refine_trans[0]:.3f}, {refine_trans[1]:.3f}, {refine_trans[2]:.3f}] mm\n")
        f.write(f"Total translation (MI + Refinement): [{total_trans[0]:.3f}, {total_trans[1]:.3f}, {total_trans[2]:.3f}] mm\n\n")
        
        f.write("\nFINAL: Composite Transform for 3D Slicer\n")
        f.write("-" * 80 + "\n")
        f.write(f"Translation (RAS): [{slicer_trans[0]:.3f}, {slicer_trans[1]:.3f}, {slicer_trans[2]:.3f}] mm\n")
        f.write(f"Rotation (XYZ Euler): [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees\n\n")
        f.write("4x4 Homogeneous Matrix:\n")
        for row in final_affine_slicer:
            f.write(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]\n")
        
        if 'final_error' in locals() and 'min_distances' in locals():
            f.write("\n\nAlignment Quality Metrics:\n")
            f.write("-" * 80 + "\n")
            f.write(f"Centroid error:       {final_error:.2f} mm\n")
            f.write(f"Mean surface dist:    {np.mean(min_distances):.2f} mm\n")
            f.write(f"Median surface dist:  {np.median(min_distances):.2f} mm\n")
            f.write(f"95th percentile:      {np.percentile(min_distances, 95):.2f} mm\n")
            f.write(f"Max distance:         {np.max(min_distances):.2f} mm\n")
    
    print(f"\n✓ Summary saved to: {summary_file}")
    
else:
    print("\n(Only MI registration was performed - no refinement data available)")


TRANSFORMATION MATRIX SUMMARY

STAGE 1: MI Registration Transform (4x4 Homogeneous Matrix)
--------------------------------------------------------------------------------
[  0.999847   0.011309  -0.013322  10.900246]
[ -0.011322   0.999935  -0.000905 172.537551]
[  0.013311   0.001056   0.999911  48.119952]
[  0.000000   0.000000   0.000000   1.000000]

  Translation: [10.900, 172.538, 48.120] mm
  Rotation: [0.061, -0.763, -0.649] degrees


STAGE 2: Prostate Refinement (NearestNeighbor)
--------------------------------------------------------------------------------
  Refinement translation: [0.577, -0.276, 0.449] mm
  Magnitude: 0.782 mm

  Total translation (MI + Refinement):
    [11.477, 172.261, 48.569] mm
    Magnitude: 179.345 mm

  For 3D Slicer (RAS coordinates):
    Total translation: [-11.477, -172.261, 48.569] mm
    Rotation: [0.061, -0.763, -0.649] degrees


FINAL: Transform for 3D Slicer (4x4 Matrix)
---------------------------------------------------------------------