In [23]:

# Cell 1: Environment Setup and Dependencies
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import pydicom
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage
from scipy.spatial.transform import Rotation as R_scipy

try:
    import SimpleITK as sitk
    print(f"SimpleITK version: {sitk.Version.VersionString()}")
except ImportError as e:
    print(f"Error: SimpleITK not found. Install with: pip install SimpleITK")
    sys.exit(1)

print(f"Python: {sys.version.split()[0]}, NumPy: {np.__version__}")

SimpleITK version: 2.5.2
Python: 3.13.7, NumPy: 2.3.3


In [24]:
# Cell 2: Configuration and Directory Setup
BASE_DIR = r'c:\Users\zhaoanr\Desktop\_anon4'
MR_DIR = os.path.join(BASE_DIR, 'MR')
CT_DIR = os.path.join(BASE_DIR, 'CT')

OUTPUT_DIR = r'c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Expected subdirectories
MR_IMAGES_DIR = os.path.join(MR_DIR, 'MR')
MR_RTSTRUCT_DIR = os.path.join(MR_DIR, 'RTSTRUCT')
CT_IMAGES_DIR = os.path.join(CT_DIR, 'CT')
CT_RTSTRUCT_DIR = os.path.join(CT_DIR, 'RTSTRUCT')

print("Directory structure:")
for label, path in [("MR Images", MR_IMAGES_DIR), 
                     ("MR RTStruct", MR_RTSTRUCT_DIR),
                     ("CT Images", CT_IMAGES_DIR), 
                     ("CT RTStruct", CT_RTSTRUCT_DIR)]:
    exists = "✓" if os.path.exists(path) else "✗"
    print(f"{exists} {label}: {path}")

print(f"\nOutput: {OUTPUT_DIR}")

Directory structure:
✓ MR Images: c:\Users\zhaoanr\Desktop\_anon4\MR\MR
✓ MR RTStruct: c:\Users\zhaoanr\Desktop\_anon4\MR\RTSTRUCT
✓ CT Images: c:\Users\zhaoanr\Desktop\_anon4\CT\CT
✓ CT RTStruct: c:\Users\zhaoanr\Desktop\_anon4\CT\RTSTRUCT

Output: c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_refined_output


In [25]:
# Cell 3: Load Image Series
def load_dicom_series_sitk(directory, series_description="Series"):
    """Load DICOM series from directory using SimpleITK."""
    print(f"\nLoading {series_description} from {os.path.basename(directory)}...")
    
    # Get all DICOM files
    dicom_files = [os.path.join(directory, f) for f in os.listdir(directory) 
                   if f.endswith('.dcm') or not '.' in f]
    
    if not dicom_files:
        raise ValueError(f"No DICOM files found in {directory}")
    
    # Sort by Z position
    sorted_files = []
    for filepath in dicom_files:
        try:
            header = pydicom.dcmread(filepath, stop_before_pixels=True, force=True)
            if 'ImagePositionPatient' in header:
                z_pos = float(header.ImagePositionPatient[2])
                sorted_files.append((filepath, z_pos))
        except:
            continue
    
    sorted_files.sort(key=lambda x: x[1])
    file_paths = [f[0] for f in sorted_files]
    
    print(f"  Found {len(file_paths)} slices")
    
    # Read with SimpleITK
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(file_paths)
    reader.MetaDataDictionaryArrayUpdateOn()
    reader.LoadPrivateTagsOn()
    
    image = reader.Execute()
    
    # Print metadata
    size = image.GetSize()
    spacing = image.GetSpacing()
    origin = image.GetOrigin()
    
    print(f"  Size: {size}")
    print(f"  Spacing: {np.round(spacing, 2)} mm")
    print(f"  Origin: {np.round(origin, 1)} mm")
    
    return image

print("="*80)
print("LOADING IMAGE SERIES")
print("="*80)

mr_image_sitk = load_dicom_series_sitk(MR_IMAGES_DIR, "MR")
ct_image_sitk = load_dicom_series_sitk(CT_IMAGES_DIR, "CT")

# Convert to float32
mr_image_sitk = sitk.Cast(mr_image_sitk, sitk.sitkFloat32)
ct_image_sitk = sitk.Cast(ct_image_sitk, sitk.sitkFloat32)

print(f"\nMR image type: {mr_image_sitk.GetPixelIDTypeAsString()}")
print(f"CT image type: {ct_image_sitk.GetPixelIDTypeAsString()}")

LOADING IMAGE SERIES

Loading MR from MR...
  Found 176 slices
  Size: (768, 768, 176)
  Spacing: [0.49 0.49 1.5 ] mm
  Origin: [-194.6 -250.2  -80. ] mm

Loading CT from CT...
  Found 268 slices
  Size: (512, 512, 268)
  Spacing: [0.98 0.98 1.5 ] mm
  Origin: [-249.5 -490.5 -223. ] mm

MR image type: 32-bit float
CT image type: 32-bit float


In [26]:
# Cell 4: Stage 1 - Mutual Information Registration (with Best Metric Tracking)
print("\n" + "="*80)
print("STAGE 1: MUTUAL INFORMATION REGISTRATION")
print("="*80)

# Registration callback to monitor progress and track best metric
iteration_data = {
    'iteration': 0, 
    'metric_values': [], 
    'multires_iterations': [],
    'best_metric': float('inf'),
    'best_params': None,
    'worse_count': 0,
    'patience': 15  # Stop if metric gets worse for this many consecutive iterations
}

def command_iteration(method):
    if method.GetOptimizerIteration() == 0:
        iteration_data['multires_iterations'].append(iteration_data['iteration'])
    
    iteration_data['iteration'] += 1
    current_metric = method.GetMetricValue()
    iteration_data['metric_values'].append(current_metric)
    
    # Track best metric (more negative is better for MI)
    if current_metric < iteration_data['best_metric']:
        iteration_data['best_metric'] = current_metric
        iteration_data['best_params'] = method.GetOptimizerPosition()
        iteration_data['worse_count'] = 0
    else:
        iteration_data['worse_count'] += 1
    
    if iteration_data['iteration'] % 20 == 0:
        print(f"  Iteration {iteration_data['iteration']:4d}: Metric = {current_metric:.4f} "
              f"(Best: {iteration_data['best_metric']:.4f} at iter {iteration_data['iteration'] - iteration_data['worse_count']})")
    
    # Early stopping if metric consistently degrades
    if iteration_data['worse_count'] >= iteration_data['patience']:
        print(f"\n  ⚠ Early stopping: Metric degraded for {iteration_data['patience']} consecutive iterations")
        print(f"    Current: {current_metric:.4f}, Best: {iteration_data['best_metric']:.4f}")
        method.StopRegistration()

# Setup registration
registration_method = sitk.ImageRegistrationMethod()

# Mutual Information metric
registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
registration_method.SetMetricSamplingPercentage(0.1)

# Interpolator
registration_method.SetInterpolator(sitk.sitkLinear)

# Optimizer - Regular Step Gradient Descent with adjusted parameters
registration_method.SetOptimizerAsRegularStepGradientDescent(
    learningRate=1.0,
    minStep=0.0001,  # Reduced from 0.001 to allow finer steps
    numberOfIterations=200,  # Increased from 100
    gradientMagnitudeTolerance=1e-6
)
registration_method.SetOptimizerScalesFromIndexShift()

# Multi-resolution framework
registration_method.SetShrinkFactorsPerLevel(shrinkFactors=[4, 2, 1])
registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=[2, 1, 0])
registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

# Initial transform - CenteredTransformInitializer with MOMENTS
initial_transform = sitk.CenteredTransformInitializer(
    ct_image_sitk,  # Fixed image (target space)
    mr_image_sitk,  # Moving image
    sitk.VersorRigid3DTransform(),
    sitk.CenteredTransformInitializerFilter.MOMENTS
)

registration_method.SetInitialTransform(initial_transform, inPlace=True)

# Connect callback
registration_method.AddCommand(sitk.sitkIterationEvent, lambda: command_iteration(registration_method))

print("\nStarting MI registration (MR -> CT)...")
print("Initial transform (MOMENTS initialization):")
print(f"  Translation: {initial_transform.GetTranslation()}")
print(f"\nEarly stopping enabled: will stop if metric degrades for {iteration_data['patience']} iterations")

# Execute registration
mi_transform = registration_method.Execute(ct_image_sitk, mr_image_sitk)

final_metric = registration_method.GetMetricValue()
print(f"\n✓ MI Registration complete")
print(f"  Final metric value: {final_metric:.4f}")
print(f"  Best metric value:  {iteration_data['best_metric']:.4f}")
print(f"  Optimizer stop condition: {registration_method.GetOptimizerStopConditionDescription()}")

# Check if we should use the best parameters instead of final
metric_diff = abs(final_metric - iteration_data['best_metric'])
if metric_diff > 0.05 and iteration_data['best_params'] is not None:
    print(f"\n  ⚠ Final metric ({final_metric:.4f}) is worse than best ({iteration_data['best_metric']:.4f})")
    print(f"    Restoring best parameters from iteration {iteration_data['iteration'] - iteration_data['worse_count']}")
    
    # Create new transform with best parameters
    mi_transform_best = sitk.VersorRigid3DTransform()
    mi_transform_best.SetParameters(iteration_data['best_params'])
    mi_transform_best.SetFixedParameters(mi_transform.GetFixedParameters())
    mi_transform = mi_transform_best
    
    print(f"    ✓ Using best transform (metric: {iteration_data['best_metric']:.4f})")

# Get transformation parameters
mi_rotation_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
mi_translation = np.array(mi_transform.GetTranslation())

print(f"\nMI Transform Parameters (MR -> CT):")
print(f"  Translation: [{mi_translation[0]:.2f}, {mi_translation[1]:.2f}, {mi_translation[2]:.2f}] mm")

rotation_obj = R_scipy.from_matrix(mi_rotation_matrix)
euler_angles = rotation_obj.as_euler('xyz', degrees=True)
print(f"  Rotation: X={euler_angles[0]:.2f}°, Y={euler_angles[1]:.2f}°, Z={euler_angles[2]:.2f}°")

# Apply MI transform to MR image
mr_mi_registered = sitk.Resample(
    mr_image_sitk,
    ct_image_sitk,
    mi_transform,
    sitk.sitkLinear,
    0.0,
    mr_image_sitk.GetPixelID()
)

# Save intermediate result
mi_output = os.path.join(OUTPUT_DIR, 'mr_mi_registered.nii.gz')
sitk.WriteImage(mr_mi_registered, mi_output)
print(f"\n✓ Saved MI-registered MR: {mi_output}")

# Save MI transform
mi_transform_file = os.path.join(OUTPUT_DIR, 'mi_transform.tfm')
sitk.WriteTransform(mi_transform, mi_transform_file)
print(f"✓ Saved MI transform: {mi_transform_file}")

# Save metric evolution plot
if len(iteration_data['metric_values']) > 0:
    plt.figure(figsize=(10, 5))
    plt.plot(iteration_data['metric_values'], linewidth=1.5)
    
    # Mark multi-resolution transitions
    for mr_iter in iteration_data['multires_iterations'][1:]:
        plt.axvline(x=mr_iter, color='red', linestyle='--', alpha=0.5, label='Resolution change')
    
    # Mark best metric
    best_idx = iteration_data['metric_values'].index(iteration_data['best_metric'])
    plt.plot(best_idx, iteration_data['best_metric'], 'go', markersize=10, label=f'Best metric: {iteration_data["best_metric"]:.4f}')
    
    plt.xlabel('Iteration')
    plt.ylabel('Mutual Information Metric')
    plt.title('MI Registration Convergence (more negative = better)')
    plt.grid(True, alpha=0.3)
    plt.legend()
    
    plot_file = os.path.join(OUTPUT_DIR, 'mi_convergence.png')
    plt.savefig(plot_file, dpi=150, bbox_inches='tight')
    plt.close()
    print(f"✓ Saved convergence plot: {plot_file}")


STAGE 1: MUTUAL INFORMATION REGISTRATION

Starting MI registration (MR -> CT)...
Initial transform (MOMENTS initialization):
  Translation: (-12.022542943517777, 192.42303869155546, 84.96163815264883)

Early stopping enabled: will stop if metric degrades for 15 iterations
  Iteration   20: Metric = -0.8047 (Best: -0.8047 at iter 19)

  ⚠ Early stopping: Metric degraded for 15 consecutive iterations
    Current: -0.8044, Best: -0.8047

  ⚠ Early stopping: Metric degraded for 15 consecutive iterations
    Current: -0.6893, Best: -0.8047

  ⚠ Early stopping: Metric degraded for 15 consecutive iterations
    Current: -0.5631, Best: -0.8047

✓ MI Registration complete
  Final metric value: -0.5631
  Best metric value:  -0.8047
  Optimizer stop condition: RegularStepGradientDescentOptimizerv4: 

  ⚠ Final metric (-0.5631) is worse than best (-0.8047)
    Restoring best parameters from iteration 19
    ✓ Using best transform (metric: -0.8047)

MI Transform Parameters (MR -> CT):
  Translatio

In [27]:
# Cell 5: Extract Prostate/Rectum Structures from RTStructs
def extract_prostate_structures(rtstruct_path, modality_name="Image"):
    """Extract prostate/rectum structures from RTStruct file."""
    print(f"\nExtracting {modality_name} prostate structures from:")
    print(f"  {os.path.basename(rtstruct_path)}")
    
    ds = pydicom.dcmread(rtstruct_path)
    
    if not hasattr(ds, 'StructureSetROISequence') or not hasattr(ds, 'ROIContourSequence'):
        print(f"  ✗ No structure data found")
        return {}
    
    # Build ROI info
    roi_info = {}
    for roi in ds.StructureSetROISequence:
        roi_info[roi.ROINumber] = roi.ROIName if roi.ROIName else f'ROI_{roi.ROINumber}'
    
    # Extract prostate/rectum structures
    prostate_contours = {}
    
    for contour_seq in ds.ROIContourSequence:
        roi_num = contour_seq.ReferencedROINumber
        roi_name = roi_info.get(roi_num, f'ROI_{roi_num}')
        
        if not hasattr(contour_seq, 'ContourSequence'):
            continue
        
        # Collect all contour points
        all_points = []
        for contour in contour_seq.ContourSequence:
            if hasattr(contour, 'ContourData'):
                points = np.array(contour.ContourData).reshape(-1, 3)
                all_points.append(points)
        
        if not all_points:
            continue
        
        all_points = np.vstack(all_points)
        
        # Compute properties to identify prostate/rectum
        centroid = np.mean(all_points, axis=0)
        z_extent = all_points[:, 2].max() - all_points[:, 2].min()
        x_extent = all_points[:, 0].max() - all_points[:, 0].min()
        y_extent = all_points[:, 1].max() - all_points[:, 1].min()
        volume = x_extent * y_extent * z_extent
        x_pos = abs(centroid[0])
        
        # Heuristic: prostate/rectum is near midline, moderate size, moderate Z extent
        if x_pos < 40 and 30 < z_extent < 100 and volume < 800000:
            prostate_contours[roi_name] = all_points
            print(f"  ✓ Found '{roi_name}': {len(all_points)} points, centroid=[{centroid[0]:.1f}, {centroid[1]:.1f}, {centroid[2]:.1f}]")
    
    return prostate_contours

print("\n" + "="*80)
print("EXTRACTING PROSTATE/RECTUM STRUCTURES")
print("="*80)

# Find RTStruct files
mr_rtstruct_files = [f for f in os.listdir(MR_RTSTRUCT_DIR) if f.endswith('.dcm') or not '.' in f]
ct_rtstruct_files = [f for f in os.listdir(CT_RTSTRUCT_DIR) if f.endswith('.dcm') or not '.' in f]

mr_prostate = {}
ct_prostate = {}

if mr_rtstruct_files:
    mr_rtstruct_path = os.path.join(MR_RTSTRUCT_DIR, mr_rtstruct_files[0])
    mr_prostate = extract_prostate_structures(mr_rtstruct_path, "MR")

if ct_rtstruct_files:
    ct_rtstruct_path = os.path.join(CT_RTSTRUCT_DIR, ct_rtstruct_files[0])
    ct_prostate = extract_prostate_structures(ct_rtstruct_path, "CT")

print(f"\nSummary:")
print(f"  MR prostate structures: {len(mr_prostate)}")
print(f"  CT prostate structures: {len(ct_prostate)}")


EXTRACTING PROSTATE/RECTUM STRUCTURES

Extracting MR prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.3356100869.1309.1766424158.323854.dcm
  ✓ Found 'Anal_Canal': 3456 points, centroid=[0.2, -11.4, -10.7]
  ✓ Found 'Rectum': 12382 points, centroid=[-2.2, 3.8, 42.2]
  ✓ Found 'Prostate': 8350 points, centroid=[-1.9, -44.1, 19.6]

Extracting CT prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.3356100357.1181.1764613163.280539.dcm
  ✓ Found 'Sigmoid_Colon': 7734 points, centroid=[-6.3, -224.7, -2.4]
  ✓ Found 'Duodenum': 4630 points, centroid=[-18.5, -258.9, 148.9]
  ✓ Found 'Bladder': 10154 points, centroid=[-0.3, -263.4, -37.3]
  ✓ Found 'Rectum': 7304 points, centroid=[5.6, -178.0, -49.4]
  ✓ Found 'Prostate': 4978 points, centroid=[4.3, -225.5, -69.2]
  ✓ Found 'Vena_Cava_Inf': 5292 points, centroid=[-12.7, -244.7, 123.9]

Summary:
  MR prostate structures: 3
  CT prostate structures: 6


In [30]:
# Cell 6: Stage 2 - Prostate Core Alignment (Inner 50%, Iterative Refinement)
print("\n" + "="*80)
print("STAGE 2: PROSTATE CORE ALIGNMENT (INNER 50%, ITERATIVE)")
print("="*80)

# Extract specific structures by name
def get_structure_by_name(contours_dict, structure_name):
    """Get structure by exact name match."""
    for name, points in contours_dict.items():
        if structure_name.lower() in name.lower():
            return name, points
    return None, None

def extract_inner_core(points, percentile=50):
    """Extract inner core of points by computing distances to centroid.
    
    Args:
        points: Nx3 array of points
        percentile: percentage of points to keep (closest to centroid)
    
    Returns:
        inner_points: Points in the inner core
        core_indices: Indices of core points in original array
        threshold_dist: Radius of the core
    """
    centroid = np.mean(points, axis=0)
    
    # Compute distances from each point to centroid
    distances = np.linalg.norm(points - centroid, axis=1)
    
    # Find threshold distance for inner percentile
    threshold_dist = np.percentile(distances, percentile)
    
    # Select points within threshold
    core_mask = distances <= threshold_dist
    core_indices = np.where(core_mask)[0]
    inner_points = points[core_mask]
    
    return inner_points, core_indices, threshold_dist

# Get Prostate from both modalities
mr_prostate_name, mr_prostate_points = get_structure_by_name(mr_prostate, 'Prostate')
ct_prostate_name, ct_prostate_points = get_structure_by_name(ct_prostate, 'Prostate')

print(f"\nStructure matching:")
print(f"  MR Prostate: '{mr_prostate_name}' - {len(mr_prostate_points) if mr_prostate_points is not None else 0} points")
print(f"  CT Prostate: '{ct_prostate_name}' - {len(ct_prostate_points) if ct_prostate_points is not None else 0} points")

if mr_prostate_points is not None and ct_prostate_points is not None:
    from scipy.spatial import distance_matrix
    
    # Original prostate centroids
    mr_prostate_centroid_orig = np.mean(mr_prostate_points, axis=0)
    ct_prostate_centroid = np.mean(ct_prostate_points, axis=0)
    
    print(f"\nOriginal prostate centroids (before any transform):")
    print(f"  MR: [{mr_prostate_centroid_orig[0]:.2f}, {mr_prostate_centroid_orig[1]:.2f}, {mr_prostate_centroid_orig[2]:.2f}]")
    print(f"  CT: [{ct_prostate_centroid[0]:.2f}, {ct_prostate_centroid[1]:.2f}, {ct_prostate_centroid[2]:.2f}]")
    print(f"  Distance: {np.linalg.norm(ct_prostate_centroid - mr_prostate_centroid_orig):.2f} mm")
    
    # Get MI transform components
    mi_translation = np.array(mi_transform.GetTranslation())
    mi_rotation_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
    
    print(f"\nMI Translation: [{mi_translation[0]:.2f}, {mi_translation[1]:.2f}, {mi_translation[2]:.2f}] mm")
    
    # Apply MI transform to MR prostate points
    print(f"\n" + "-"*80)
    print("APPLYING MI TRANSFORM TO MR PROSTATE")
    print("-"*80)
    
    # Transform MR points: SUBTRACT translation (moving MR into CT space)
    mr_prostate_mi = mr_prostate_points - mi_translation
    mr_centroid_mi = np.mean(mr_prostate_mi, axis=0)
    
    print(f"MR centroid after MI: [{mr_centroid_mi[0]:.2f}, {mr_centroid_mi[1]:.2f}, {mr_centroid_mi[2]:.2f}]")
    print(f"CT centroid:          [{ct_prostate_centroid[0]:.2f}, {ct_prostate_centroid[1]:.2f}, {ct_prostate_centroid[2]:.2f}]")
    
    centroid_error_after_mi = np.linalg.norm(mr_centroid_mi - ct_prostate_centroid)
    print(f"Distance after MI: {centroid_error_after_mi:.2f} mm")
    
    # Extract inner 50% core from both prostates
    print(f"\n" + "-"*80)
    print("EXTRACTING PROSTATE CORES (INNER 50%)")
    print("-"*80)
    
    mr_core, mr_core_indices, mr_core_threshold = extract_inner_core(mr_prostate_mi, percentile=50)
    ct_core, ct_core_indices, ct_core_threshold = extract_inner_core(ct_prostate_points, percentile=50)
    
    print(f"\nMR prostate core:")
    print(f"  Total points:     {len(mr_prostate_mi)}")
    print(f"  Core points:      {len(mr_core)} ({100*len(mr_core)/len(mr_prostate_mi):.1f}%)")
    print(f"  Core radius:      {mr_core_threshold:.2f} mm")
    print(f"  Core centroid:    [{np.mean(mr_core, axis=0)[0]:.2f}, {np.mean(mr_core, axis=0)[1]:.2f}, {np.mean(mr_core, axis=0)[2]:.2f}]")
    
    print(f"\nCT prostate core:")
    print(f"  Total points:     {len(ct_prostate_points)}")
    print(f"  Core points:      {len(ct_core)} ({100*len(ct_core)/len(ct_prostate_points):.1f}%)")
    print(f"  Core radius:      {ct_core_threshold:.2f} mm")
    print(f"  Core centroid:    [{np.mean(ct_core, axis=0)[0]:.2f}, {np.mean(ct_core, axis=0)[1]:.2f}, {np.mean(ct_core, axis=0)[2]:.2f}]")
    
    # Iterative refinement to allow for larger corrections
    print(f"\n" + "-"*80)
    print("ITERATIVE CORE-BASED NEAREST NEIGHBOR REFINEMENT")
    print("-"*80)
    
    # Start with core after MI transform
    mr_core_current = mr_core.copy()
    cumulative_refinement = np.zeros(3)
    
    max_iterations = 100
    convergence_threshold = 0.1  # mm - stop if refinement becomes very small
    
    refinement_history = []
    
    for iteration in range(max_iterations):
        # Find closest CT core point for each MR core point
        dist_matrix_core = distance_matrix(mr_core_current, ct_core)
        closest_ct_indices = np.argmin(dist_matrix_core, axis=1)
        closest_ct_points = ct_core[closest_ct_indices]
        
        # Compute displacement vectors
        displacements = closest_ct_points - mr_core_current
        
        # Current iteration refinement (use mean)
        current_refinement = np.mean(displacements, axis=0)
        refinement_magnitude = np.linalg.norm(current_refinement)
        
        # Accumulate
        cumulative_refinement += current_refinement
        
        # Apply to current positions
        mr_core_current = mr_core_current + current_refinement
        
        # Compute mean distance for convergence check
        mean_distance = np.mean(np.min(dist_matrix_core, axis=1))
        
        refinement_history.append({
            'iteration': iteration + 1,
            'refinement': current_refinement,
            'magnitude': refinement_magnitude,
            'cumulative': cumulative_refinement.copy(),
            'mean_dist': mean_distance
        })
        
        print(f"  Iter {iteration+1}: Refinement = [{current_refinement[0]:6.2f}, {current_refinement[1]:6.2f}, {current_refinement[2]:6.2f}] mm, "
              f"Mag = {refinement_magnitude:5.2f} mm, Mean dist = {mean_distance:5.2f} mm")
        
        # Check convergence
        if refinement_magnitude < convergence_threshold:
            print(f"\n  ✓ Converged after {iteration+1} iterations (refinement < {convergence_threshold} mm)")
            break
    else:
        print(f"\n  → Stopped after {max_iterations} iterations (max reached)")
    
    # Final cumulative refinement
    refinement_translation = cumulative_refinement
    
    print(f"\n" + "-"*80)
    print("REFINEMENT SUMMARY")
    print("-"*80)
    print(f"  Total iterations:     {len(refinement_history)}")
    print(f"  Cumulative refinement: [{refinement_translation[0]:.2f}, {refinement_translation[1]:.2f}, {refinement_translation[2]:.2f}] mm")
    print(f"  Total magnitude:      {np.linalg.norm(refinement_translation):.2f} mm")
    
    # Apply refinement to ALL MR prostate points (not just core)
    mr_prostate_refined = mr_prostate_mi + refinement_translation
    mr_centroid_refined = np.mean(mr_prostate_refined, axis=0)
    
    # Compute total translation: MI - Refinement (since we subtract MI and add refinement)
    total_translation = mi_translation - refinement_translation
    
    print(f"\n" + "="*80)
    print("TRANSFORMATION SUMMARY")
    print("="*80)
    print(f"MI Translation:        [{mi_translation[0]:.3f}, {mi_translation[1]:.3f}, {mi_translation[2]:.3f}] mm (magnitude: {np.linalg.norm(mi_translation):.2f} mm)")
    print(f"Refinement (Core-NN):  [{refinement_translation[0]:.3f}, {refinement_translation[1]:.3f}, {refinement_translation[2]:.3f}] mm (magnitude: {np.linalg.norm(refinement_translation):.2f} mm)")
    print(f"Total Translation:     [{total_translation[0]:.3f}, {total_translation[1]:.3f}, {total_translation[2]:.3f}] mm (magnitude: {np.linalg.norm(total_translation):.2f} mm)")
    
    # Compute alignment quality metrics using ALL points
    print(f"\n" + "="*80)
    print("ALIGNMENT QUALITY (Full Prostate)")
    print("="*80)
    
    final_centroid_error = np.linalg.norm(mr_centroid_refined - ct_prostate_centroid)
    print(f"  Centroid error:       {final_centroid_error:.2f} mm")
    
    # Surface distances for full prostate
    dist_matrix_full = distance_matrix(mr_prostate_refined, ct_prostate_points)
    min_distances_prostate = np.min(dist_matrix_full, axis=1)
    
    print(f"  Mean surface dist:    {np.mean(min_distances_prostate):.2f} mm")
    print(f"  Median surface dist:  {np.median(min_distances_prostate):.2f} mm")
    print(f"  95th percentile:      {np.percentile(min_distances_prostate, 95):.2f} mm")
    print(f"  Max distance:         {np.max(min_distances_prostate):.2f} mm")
    
    # Also report core-specific metrics
    print(f"\n" + "="*80)
    print("ALIGNMENT QUALITY (Inner 50% Core)")
    print("="*80)
    
    mr_core_refined = mr_core_current  # Already iteratively refined
    mr_core_centroid_refined = np.mean(mr_core_refined, axis=0)
    ct_core_centroid = np.mean(ct_core, axis=0)
    
    core_centroid_error = np.linalg.norm(mr_core_centroid_refined - ct_core_centroid)
    print(f"  Core centroid error:  {core_centroid_error:.2f} mm")
    
    dist_matrix_core_final = distance_matrix(mr_core_refined, ct_core)
    min_distances_core = np.min(dist_matrix_core_final, axis=1)
    
    print(f"  Core mean dist:       {np.mean(min_distances_core):.2f} mm")
    print(f"  Core median dist:     {np.median(min_distances_core):.2f} mm")
    print(f"  Core 95th percentile: {np.percentile(min_distances_core, 95):.2f} mm")
    print(f"  Core max distance:    {np.max(min_distances_core):.2f} mm")
    
    # Create composite transform: MI - Refinement (invert refinement)
    composite_transform = sitk.CompositeTransform(3)
    composite_transform.AddTransform(mi_transform)
    
    refinement_transform = sitk.TranslationTransform(3)
    refinement_transform.SetOffset((-refinement_translation).tolist())  # INVERTED
    composite_transform.AddTransform(refinement_transform)
    
    # Apply composite transform to full MR image
    print(f"\n" + "-"*80)
    print("RESAMPLING FULL MR IMAGE")
    print("-"*80)
    
    mr_refined = sitk.Resample(
        mr_image_sitk,
        ct_image_sitk,
        composite_transform,
        sitk.sitkLinear,
        0.0,
        mr_image_sitk.GetPixelID()
    )
    
    # Build final transformation matrix
    print(f"\n" + "="*80)
    print("FINAL TRANSFORMATION MATRIX")
    print("="*80)
    
    final_matrix = np.eye(4)
    final_matrix[:3, :3] = mi_rotation_matrix
    final_matrix[:3, 3] = total_translation
    
    print(f"\nComposite Transform (4x4 Homogeneous Matrix):")
    for row in final_matrix:
        print(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")
    
    rotation_obj = R_scipy.from_matrix(mi_rotation_matrix)
    euler_angles = rotation_obj.as_euler('xyz', degrees=True)
    
    print(f"\nTranslation: [{total_translation[0]:.3f}, {total_translation[1]:.3f}, {total_translation[2]:.3f}] mm")
    print(f"Rotation: [{euler_angles[0]:.3f}, {euler_angles[1]:.3f}, {euler_angles[2]:.3f}] degrees")
    
    # Save outputs
    refined_output = os.path.join(OUTPUT_DIR, 'mr_prostate_refined.nii.gz')
    sitk.WriteImage(mr_refined, refined_output)
    print(f"\n✓ Saved refined MR: {refined_output}")
    
    composite_file = os.path.join(OUTPUT_DIR, 'prostate_composite_transform.tfm')
    sitk.WriteTransform(composite_transform, composite_file)
    print(f"✓ Saved composite transform: {composite_file}")
    
    # Visualization: save core point clouds for verification
    np.save(os.path.join(OUTPUT_DIR, 'mr_core_points_refined.npy'), mr_core_refined)
    np.save(os.path.join(OUTPUT_DIR, 'ct_core_points.npy'), ct_core)
    np.save(os.path.join(OUTPUT_DIR, 'mr_full_points_refined.npy'), mr_prostate_refined)
    np.save(os.path.join(OUTPUT_DIR, 'ct_full_points.npy'), ct_prostate_points)
    print(f"✓ Saved point clouds for visualization")
    
    # Save refinement history
    import json
    history_file = os.path.join(OUTPUT_DIR, 'refinement_history.json')
    with open(history_file, 'w') as f:
        json.dump([{k: v.tolist() if isinstance(v, np.ndarray) else v 
                    for k, v in item.items()} for item in refinement_history], f, indent=2)
    print(f"✓ Saved refinement history: {history_file}")
    
    # Set variables for Cell 8
    final_transform = composite_transform
    mr_final = mr_refined
    final_error = final_centroid_error
    min_distances = min_distances_prostate
    ct_centroid = ct_prostate_centroid
    strategy_name = f"Core-NearestNeighbor-Iterative (Inner 50%, {len(refinement_history)} iters)"
    
else:
    print("\n✗ Prostate structures not found in both modalities")
    print("   Using MI registration only")
    final_transform = mi_transform
    mr_final = mr_mi_registered
    mr_refined = None


STAGE 2: PROSTATE CORE ALIGNMENT (INNER 50%, ITERATIVE)

Structure matching:
  MR Prostate: 'Prostate' - 8350 points
  CT Prostate: 'Prostate' - 4978 points

Original prostate centroids (before any transform):
  MR: [-1.86, -44.14, 19.61]
  CT: [4.35, -225.51, -69.23]
  Distance: 202.05 mm

MI Translation: [-6.78, 177.48, 89.24] mm

--------------------------------------------------------------------------------
APPLYING MI TRANSFORM TO MR PROSTATE
--------------------------------------------------------------------------------
MR centroid after MI: [4.92, -221.63, -69.63]
CT centroid:          [4.35, -225.51, -69.23]
Distance after MI: 3.95 mm

--------------------------------------------------------------------------------
EXTRACTING PROSTATE CORES (INNER 50%)
--------------------------------------------------------------------------------

MR prostate core:
  Total points:     8350
  Core points:      4175 (50.0%)
  Core radius:      21.27 mm
  Core centroid:    [4.94, -222.43, -72

In [29]:
# Cell 8: Final Transformation Matrix Summary
print("\n" + "="*80)
print("TRANSFORMATION MATRIX SUMMARY")
print("="*80)

# MI Transform
mi_matrix = np.array(mi_transform.GetMatrix()).reshape(3, 3)
mi_trans = np.array(mi_transform.GetTranslation())

mi_affine = np.eye(4)
mi_affine[:3, :3] = mi_matrix
mi_affine[:3, 3] = mi_trans

print("\nSTAGE 1: MI Registration Transform (4x4 Homogeneous Matrix)")
print("-" * 80)
for row in mi_affine:
    print(f"[{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")

rotation_obj = R_scipy.from_matrix(mi_matrix)
mi_euler = rotation_obj.as_euler('xyz', degrees=True)

print(f"\n  Translation: [{mi_trans[0]:.3f}, {mi_trans[1]:.3f}, {mi_trans[2]:.3f}] mm")
print(f"  Rotation: [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees")

if 'refinement_transform' in locals() and 'strategy_name' in locals():
    print(f"\n\nSTAGE 2: Prostate Refinement ({strategy_name})")
    print("-" * 80)
    refine_trans = np.array(refinement_transform.GetOffset())
    
    # Refinement translation (applied after MI transform)
    print(f"  Refinement translation: [{refine_trans[0]:.3f}, {refine_trans[1]:.3f}, {refine_trans[2]:.3f}] mm")
    print(f"  Magnitude: {np.linalg.norm(refine_trans):.3f} mm")
    
    # Total translation is MI + refinement (both applied in sequence)
    total_trans = mi_trans + refine_trans
    print(f"\n  Total translation (MI + Refinement):")
    print(f"    [{total_trans[0]:.3f}, {total_trans[1]:.3f}, {total_trans[2]:.3f}] mm")
    print(f"    Magnitude: {np.linalg.norm(total_trans):.3f} mm")
    
    # For 3D Slicer: Use the same rotation, but adjust translation for Slicer's coordinate system
    # Slicer expects RAS coordinates, ITK uses LPS
    # The transform from ITK to Slicer requires flipping X and Y
    slicer_trans = total_trans.copy()
    slicer_trans[0] = -slicer_trans[0]  # Flip X (L->R)
    slicer_trans[1] = -slicer_trans[1]  # Flip Y (P->A)
    # Z stays the same (S->S)
    
    print(f"\n  For 3D Slicer (RAS coordinates):")
    print(f"    Total translation: [{slicer_trans[0]:.3f}, {slicer_trans[1]:.3f}, {slicer_trans[2]:.3f}] mm")
    print(f"    Rotation: [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees")
    
    # Build final affine matrix for Slicer
    final_affine_slicer = np.eye(4)
    final_affine_slicer[:3, :3] = mi_matrix
    final_affine_slicer[:3, 3] = slicer_trans
    
    print("\n\nFINAL: Transform for 3D Slicer (4x4 Matrix)")
    print("-" * 80)
    for row in final_affine_slicer:
        print(f"[{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")
    
    # Quality metrics
    if 'final_error' in locals() and 'min_distances' in locals():
        print(f"\n\nProstate Alignment Quality:")
        print(f"  Strategy used:        {strategy_name}")
        print(f"  Centroid error:       {final_error:.2f} mm")
        print(f"  Mean surface dist:    {np.mean(min_distances):.2f} mm")
        print(f"  Median surface dist:  {np.median(min_distances):.2f} mm")
        print(f"  95th percentile:      {np.percentile(min_distances, 95):.2f} mm")
    
    # Save to file
    summary_file = os.path.join(OUTPUT_DIR, 'transformation_summary.txt')
    with open(summary_file, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("MRI-CT PROSTATE REGISTRATION TRANSFORMATION SUMMARY\n")
        f.write("=" * 80 + "\n\n")
        
        f.write("STAGE 1: Mutual Information (MI) Registration\n")
        f.write("-" * 80 + "\n")
        f.write("4x4 Homogeneous Matrix:\n")
        for row in mi_affine:
            f.write(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]\n")
        f.write(f"\nTranslation: [{mi_trans[0]:.3f}, {mi_trans[1]:.3f}, {mi_trans[2]:.3f}] mm\n")
        f.write(f"Rotation (XYZ Euler): [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees\n\n")
        
        f.write(f"\nSTAGE 2: Prostate Structure-Based Refinement\n")
        f.write("-" * 80 + "\n")
        f.write(f"Strategy: {strategy_name}\n")
        f.write(f"Refinement translation: [{refine_trans[0]:.3f}, {refine_trans[1]:.3f}, {refine_trans[2]:.3f}] mm\n")
        f.write(f"Total translation (MI + Refinement): [{total_trans[0]:.3f}, {total_trans[1]:.3f}, {total_trans[2]:.3f}] mm\n\n")
        
        f.write("\nFINAL: Composite Transform for 3D Slicer\n")
        f.write("-" * 80 + "\n")
        f.write(f"Translation (RAS): [{slicer_trans[0]:.3f}, {slicer_trans[1]:.3f}, {slicer_trans[2]:.3f}] mm\n")
        f.write(f"Rotation (XYZ Euler): [{mi_euler[0]:.3f}, {mi_euler[1]:.3f}, {mi_euler[2]:.3f}] degrees\n\n")
        f.write("4x4 Homogeneous Matrix:\n")
        for row in final_affine_slicer:
            f.write(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]\n")
        
        if 'final_error' in locals() and 'min_distances' in locals():
            f.write("\n\nAlignment Quality Metrics:\n")
            f.write("-" * 80 + "\n")
            f.write(f"Centroid error:       {final_error:.2f} mm\n")
            f.write(f"Mean surface dist:    {np.mean(min_distances):.2f} mm\n")
            f.write(f"Median surface dist:  {np.median(min_distances):.2f} mm\n")
            f.write(f"95th percentile:      {np.percentile(min_distances, 95):.2f} mm\n")
            f.write(f"Max distance:         {np.max(min_distances):.2f} mm\n")
    
    print(f"\n✓ Summary saved to: {summary_file}")
    
else:
    print("\n(Only MI registration was performed - no refinement data available)")


TRANSFORMATION MATRIX SUMMARY

STAGE 1: MI Registration Transform (4x4 Homogeneous Matrix)
--------------------------------------------------------------------------------
[  0.999938   0.003795  -0.010443  -6.781345]
[ -0.003773   0.999991   0.002072 177.483603]
[  0.010451  -0.002032   0.999943  89.239807]
[  0.000000   0.000000   0.000000   1.000000]

  Translation: [-6.781, 177.484, 89.240] mm
  Rotation: [-0.116, -0.599, -0.216] degrees


STAGE 2: Prostate Refinement (Core-NearestNeighbor-Iterative (Inner 50%, 8 iters))
--------------------------------------------------------------------------------
  Refinement translation: [0.675, 4.283, 2.797] mm
  Magnitude: 5.159 mm

  Total translation (MI + Refinement):
    [-6.107, 181.766, 92.037] mm
    Magnitude: 203.831 mm

  For 3D Slicer (RAS coordinates):
    Total translation: [6.107, -181.766, 92.037] mm
    Rotation: [-0.116, -0.599, -0.216] degrees


FINAL: Transform for 3D Slicer (4x4 Matrix)
----------------------------------