In [11]:
# Cell 1: Environment Setup and Dependencies
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import pydicom
import numpy as np
import matplotlib.pyplot as plt
from scipy.spatial import distance_matrix
from scipy.spatial.transform import Rotation as R_scipy

try:
    import SimpleITK as sitk
    print(f"SimpleITK version: {sitk.Version.VersionString()}")
except ImportError as e:
    print(f"Error: SimpleITK not found. Install with: pip install SimpleITK")
    sys.exit(1)

print(f"Python: {sys.version.split()[0]}, NumPy: {np.__version__}")

SimpleITK version: 2.5.2
Python: 3.13.7, NumPy: 2.3.3


In [12]:
# Cell 2: Configuration and Directory Setup
# The input and output locations default to the original workstation paths,
# but can be overridden without editing the notebook by setting the
# CT2MRI_BASE_DIR / CT2MRI_OUTPUT_DIR environment variables before the kernel
# is started.
BASE_DIR = os.environ.get('CT2MRI_BASE_DIR', r'c:\Users\zhaoanr\Desktop\_anon3')
MR_DIR = os.path.join(BASE_DIR, 'MR')
CT_DIR = os.path.join(BASE_DIR, 'CT')

OUTPUT_DIR = os.environ.get(
    'CT2MRI_OUTPUT_DIR',
    r'c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_only_output',
)
# Side effect: the output directory is created as soon as this cell runs.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Expected subdirectories
MR_IMAGES_DIR = os.path.join(MR_DIR, 'MR')
MR_RTSTRUCT_DIR = os.path.join(MR_DIR, 'RTSTRUCT')
CT_IMAGES_DIR = os.path.join(CT_DIR, 'CT')
CT_RTSTRUCT_DIR = os.path.join(CT_DIR, 'RTSTRUCT')

print("Directory structure:")
for label, path in [("MR Images", MR_IMAGES_DIR), 
                     ("MR RTStruct", MR_RTSTRUCT_DIR),
                     ("CT Images", CT_IMAGES_DIR), 
                     ("CT RTStruct", CT_RTSTRUCT_DIR)]:
    # isdir (not exists): later cells call os.listdir() on these paths, so a
    # plain file with the same name must not be reported as present.
    exists = "✓" if os.path.isdir(path) else "✗"
    print(f"{exists} {label}: {path}")

print(f"\nOutput: {OUTPUT_DIR}")

Directory structure:
✓ MR Images: c:\Users\zhaoanr\Desktop\_anon3\MR\MR
✓ MR RTStruct: c:\Users\zhaoanr\Desktop\_anon3\MR\RTSTRUCT
✓ CT Images: c:\Users\zhaoanr\Desktop\_anon3\CT\CT
✓ CT RTStruct: c:\Users\zhaoanr\Desktop\_anon3\CT\RTSTRUCT

Output: c:\Users\zhaoanr\Desktop\ct2-mri-registration\structure_only_output


In [13]:
# Cell 3: Load Image Series
def load_dicom_series_sitk(directory, series_description="Series"):
    """Load a DICOM image series from a directory using SimpleITK.

    Candidate files are regular files in ``directory`` whose name ends with
    ``.dcm`` or contains no dot at all. Each candidate header is read with
    pydicom and the files are ordered by the third component of
    ``ImagePositionPatient`` before being handed to
    ``sitk.ImageSeriesReader``.

    Args:
        directory: Path of the folder holding the slices of ONE series.
        series_description: Label used only in the progress messages.

    Returns:
        The image produced by ``sitk.ImageSeriesReader.Execute()``.

    Raises:
        ValueError: If the directory holds no candidate files, or if none of
            the candidates could be read as a slice carrying
            ``ImagePositionPatient``.
    """
    print(f"\nLoading {series_description} from {os.path.basename(directory)}...")
    
    # Get all DICOM files. Sub-directories are excluded explicitly: their names
    # usually contain no dot and would otherwise pass the name filter.
    dicom_files = []
    for name in os.listdir(directory):
        candidate = os.path.join(directory, name)
        if (name.endswith('.dcm') or '.' not in name) and os.path.isfile(candidate):
            dicom_files.append(candidate)
    
    if not dicom_files:
        raise ValueError(f"No DICOM files found in {directory}")
    
    # Sort by Z position
    sorted_files = []
    skipped = 0
    for filepath in dicom_files:
        try:
            header = pydicom.dcmread(filepath, stop_before_pixels=True, force=True)
            if 'ImagePositionPatient' not in header:
                skipped += 1
                continue
            z_pos = float(header.ImagePositionPatient[2])
        except Exception:
            # Best effort: an unreadable file is skipped, but unlike a bare
            # ``except:`` this lets KeyboardInterrupt/SystemExit propagate.
            skipped += 1
            continue
        sorted_files.append((filepath, z_pos))
    
    if not sorted_files:
        # Fail here with a clear message instead of passing an empty file
        # list to the SimpleITK reader.
        raise ValueError(
            f"No readable slices with ImagePositionPatient found in {directory}"
        )
    
    sorted_files.sort(key=lambda x: x[1])
    file_paths = [f[0] for f in sorted_files]
    
    print(f"  Found {len(file_paths)} slices")
    if skipped:
        print(f"  Skipped {skipped} file(s) without a usable ImagePositionPatient")
    
    # Read with SimpleITK
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(file_paths)
    reader.MetaDataDictionaryArrayUpdateOn()
    reader.LoadPrivateTagsOn()
    
    image = reader.Execute()
    
    # Print metadata
    size = image.GetSize()
    spacing = image.GetSpacing()
    origin = image.GetOrigin()
    
    print(f"  Size: {size}")
    print(f"  Spacing: {np.round(spacing, 2)} mm")
    print(f"  Origin: {np.round(origin, 1)} mm")
    
    return image

print("=" * 80, "LOADING IMAGE SERIES", "=" * 80, sep="\n")

# Read each series and convert it to 32-bit float right away.
mr_image_sitk = sitk.Cast(
    load_dicom_series_sitk(MR_IMAGES_DIR, "MR"),
    sitk.sitkFloat32,
)
ct_image_sitk = sitk.Cast(
    load_dicom_series_sitk(CT_IMAGES_DIR, "CT"),
    sitk.sitkFloat32,
)

# Report the resulting pixel types.
mr_pixel_type = mr_image_sitk.GetPixelIDTypeAsString()
ct_pixel_type = ct_image_sitk.GetPixelIDTypeAsString()
print(f"\nMR image type: {mr_pixel_type}")
print(f"CT image type: {ct_pixel_type}")

LOADING IMAGE SERIES

Loading MR from MR...
  Found 38 slices
  Size: (672, 672, 38)
  Spacing: [0.3 0.3 3. ] mm
  Origin: [ -98.5 -110.7  -95.3] mm

Loading CT from CT...
  Found 252 slices
  Size: (512, 512, 252)
  Spacing: [0.98 0.98 1.5 ] mm
  Origin: [-249.5 -482.5 -129. ] mm

MR image type: 32-bit float
CT image type: 32-bit float


In [14]:
# Cell 4: Extract Prostate Structures from RTStructs
def extract_prostate_structures(rtstruct_path, modality_name="Image",
                                max_midline_offset=40.0,
                                z_extent_range=(30.0, 100.0),
                                max_bbox_volume=800000.0):
    """Extract candidate prostate-region structures from an RTStruct file.

    Every ROI that has contour data is reduced to a single point cloud (all
    contour vertices stacked), and is kept when it passes a purely geometric
    heuristic: centroid close to x = 0, moderate extent along z and a bounded
    axis-aligned bounding-box volume.

    NOTE: the heuristic does not look at ROI names, so it also keeps other
    structures that pass the same test (e.g. ROIs named 'Rectum' or
    'Bladder'). Select the actual prostate by name afterwards.

    Args:
        rtstruct_path: Path of the RTSTRUCT DICOM file.
        modality_name: Label used only in the progress messages.
        max_midline_offset: Upper bound (exclusive) on ``abs(centroid_x)``.
        z_extent_range: ``(low, high)`` exclusive bounds on the z extent.
        max_bbox_volume: Upper bound (exclusive) on
            ``x_extent * y_extent * z_extent``.
        The three thresholds are compared against ContourData values as
        stored in the file (presumably patient coordinates in mm - confirm
        for your data). The defaults reproduce the original hard-coded
        values.

    Returns:
        dict mapping ROI name -> ``(N, 3)`` float array of contour points.
        Empty when the file has no structure/contour sequences.
    """
    print(f"\nExtracting {modality_name} prostate structures from:")
    print(f"  {os.path.basename(rtstruct_path)}")
    
    ds = pydicom.dcmread(rtstruct_path)
    
    if not hasattr(ds, 'StructureSetROISequence') or not hasattr(ds, 'ROIContourSequence'):
        print(f"  ✗ No structure data found")
        return {}
    
    # Build ROI info: ROI number -> display name (fallback name when empty)
    roi_info = {}
    for roi in ds.StructureSetROISequence:
        roi_info[roi.ROINumber] = roi.ROIName if roi.ROIName else f'ROI_{roi.ROINumber}'
    
    z_low, z_high = z_extent_range
    
    # Extract prostate structures
    prostate_contours = {}
    
    for contour_seq in ds.ROIContourSequence:
        roi_num = contour_seq.ReferencedROINumber
        roi_name = roi_info.get(roi_num, f'ROI_{roi_num}')
        
        if not hasattr(contour_seq, 'ContourSequence'):
            continue
        
        # Collect all contour points. dtype=float is forced so the result is
        # always a numeric array, whatever numeric type pydicom used for the
        # decimal-string values.
        all_points = []
        for contour in contour_seq.ContourSequence:
            if hasattr(contour, 'ContourData'):
                points = np.array(contour.ContourData, dtype=float).reshape(-1, 3)
                all_points.append(points)
        
        if not all_points:
            continue
        
        all_points = np.vstack(all_points)
        
        # Compute properties to identify prostate
        centroid = np.mean(all_points, axis=0)
        z_extent = all_points[:, 2].max() - all_points[:, 2].min()
        x_extent = all_points[:, 0].max() - all_points[:, 0].min()
        y_extent = all_points[:, 1].max() - all_points[:, 1].min()
        # Bounding-box volume, not the true organ volume.
        volume = x_extent * y_extent * z_extent
        x_pos = abs(centroid[0])
        
        # Heuristic: prostate is near midline, moderate size, moderate Z extent
        if x_pos < max_midline_offset and z_low < z_extent < z_high and volume < max_bbox_volume:
            prostate_contours[roi_name] = all_points
            print(f"  ✓ Found '{roi_name}': {len(all_points)} points, centroid=[{centroid[0]:.1f}, {centroid[1]:.1f}, {centroid[2]:.1f}]")
    
    return prostate_contours

print("\n" + "="*80)
print("EXTRACTING PROSTATE STRUCTURES")
print("="*80)

# Find RTStruct files. os.listdir() returns names in arbitrary order, so the
# candidates are sorted to make the "[0]" choice below reproducible when a
# folder holds more than one RTStruct.
mr_rtstruct_files = sorted(f for f in os.listdir(MR_RTSTRUCT_DIR) if f.endswith('.dcm') or '.' not in f)
ct_rtstruct_files = sorted(f for f in os.listdir(CT_RTSTRUCT_DIR) if f.endswith('.dcm') or '.' not in f)

# Both results default to "nothing found" so later cells can test them safely.
mr_prostate = {}
ct_prostate = {}

if mr_rtstruct_files:
    if len(mr_rtstruct_files) > 1:
        print(f"\nNote: {len(mr_rtstruct_files)} MR RTStruct files found; using '{mr_rtstruct_files[0]}'")
    mr_rtstruct_path = os.path.join(MR_RTSTRUCT_DIR, mr_rtstruct_files[0])
    mr_prostate = extract_prostate_structures(mr_rtstruct_path, "MR")

if ct_rtstruct_files:
    if len(ct_rtstruct_files) > 1:
        print(f"\nNote: {len(ct_rtstruct_files)} CT RTStruct files found; using '{ct_rtstruct_files[0]}'")
    ct_rtstruct_path = os.path.join(CT_RTSTRUCT_DIR, ct_rtstruct_files[0])
    ct_prostate = extract_prostate_structures(ct_rtstruct_path, "CT")

print(f"\nSummary:")
print(f"  MR prostate structures: {len(mr_prostate)}")
print(f"  CT prostate structures: {len(ct_prostate)}")


EXTRACTING PROSTATE STRUCTURES

Extracting MR prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.3356100869.1239.1766424042.372684.dcm
  ✓ Found 'Rectum': 13930 points, centroid=[-11.9, 29.1, -28.1]
  ✓ Found 'Prostate': 7482 points, centroid=[-8.2, -8.2, -45.2]
  ✓ Found 'Bladder': 9630 points, centroid=[-5.1, -36.1, -17.2]

Extracting CT prostate structures from:
  RTSTRUCT.1.3.277.1.7230011.4.1.5.3356100357.1648.1764617987.981073.dcm
  ✓ Found 'Anal_Canal': 1944 points, centroid=[-0.9, -189.3, -57.5]
  ✓ Found 'Bladder': 6872 points, centroid=[1.8, -253.5, 1.7]
  ✓ Found 'Rectum': 5796 points, centroid=[-0.6, -181.4, -7.9]
  ✓ Found 'ThecalSac': 1444 points, centroid=[7.1, -162.3, 223.9]
  ✓ Found 'Prostate': 5200 points, centroid=[1.7, -225.6, -20.0]
  ✓ Found 'Vena_Cava_Inf': 4678 points, centroid=[-6.4, -243.3, 200.3]

Summary:
  MR prostate structures: 3
  CT prostate structures: 6


In [15]:
# Cell 5: Prostate Core Alignment (Structure-Only, No MI)
# Section banner for the alignment cell.
print("\n" + "=" * 80, "PROSTATE CORE ALIGNMENT (STRUCTURE-ONLY, INNER 50%)", "=" * 80, sep="\n")

def get_structure_by_name(contours_dict, structure_name):
    """Look up a structure by name, ignoring case.

    An exact (case-insensitive) name match is preferred, so that asking for
    'Prostate' is not answered with e.g. 'PTV_Prostate' just because that key
    happens to come first. When no exact match exists, the first key that
    contains ``structure_name`` as a substring is returned.

    Args:
        contours_dict: Mapping of structure name -> point data.
        structure_name: Name to look for.

    Returns:
        ``(name, points)`` for the matching entry, or ``(None, None)`` when
        nothing matches.
    """
    wanted = structure_name.lower()

    # Pass 1: exact, case-insensitive match.
    for name, points in contours_dict.items():
        if name.lower() == wanted:
            return name, points

    # Pass 2: substring match in dictionary order.
    for name, points in contours_dict.items():
        if wanted in name.lower():
            return name, points
    return None, None

def extract_inner_core(points, percentile=50):
    """Keep the points lying closest to the centroid of a point cloud.

    The Euclidean distance of every point to the mean of ``points`` is
    computed; points whose distance is at most the requested percentile of
    those distances are kept.

    Returns:
        ``(inner_points, core_indices, threshold_dist)``: the selected
        points, their row indices in ``points`` and the distance cut-off.
    """
    center = np.mean(points, axis=0)
    radial = np.linalg.norm(points - center, axis=1)
    cutoff = np.percentile(radial, percentile)
    selected = np.flatnonzero(radial <= cutoff)
    return points[selected], selected, cutoff

# Translation-only alignment of the MR prostate onto the CT prostate.
# Steps: pick the 'Prostate' structure in each modality, keep the inner 50% of
# its contour points, iteratively shift the MR core by the mean displacement to
# the nearest CT core points, then resample the MR image onto the CT grid and
# write all results to OUTPUT_DIR.

# Get Prostate from both modalities
mr_prostate_name, mr_prostate_points = get_structure_by_name(mr_prostate, 'Prostate')
ct_prostate_name, ct_prostate_points = get_structure_by_name(ct_prostate, 'Prostate')

print(f"\nStructure matching:")
print(f"  MR Prostate: '{mr_prostate_name}' - {len(mr_prostate_points) if mr_prostate_points is not None else 0} points")
print(f"  CT Prostate: '{ct_prostate_name}' - {len(ct_prostate_points) if ct_prostate_points is not None else 0} points")

if mr_prostate_points is not None and ct_prostate_points is not None:
    # Original prostate centroids
    mr_prostate_centroid = np.mean(mr_prostate_points, axis=0)
    ct_prostate_centroid = np.mean(ct_prostate_points, axis=0)
    
    print(f"\nOriginal prostate centroids:")
    print(f"  MR: [{mr_prostate_centroid[0]:.2f}, {mr_prostate_centroid[1]:.2f}, {mr_prostate_centroid[2]:.2f}]")
    print(f"  CT: [{ct_prostate_centroid[0]:.2f}, {ct_prostate_centroid[1]:.2f}, {ct_prostate_centroid[2]:.2f}]")
    print(f"  Initial distance: {np.linalg.norm(ct_prostate_centroid - mr_prostate_centroid):.2f} mm")
    
    # Extract inner 50% core from both prostates
    print(f"\n" + "-"*80)
    print("EXTRACTING PROSTATE CORES (INNER 50%)")
    print("-"*80)
    
    mr_core, mr_core_indices, mr_core_threshold = extract_inner_core(mr_prostate_points, percentile=50)
    ct_core, ct_core_indices, ct_core_threshold = extract_inner_core(ct_prostate_points, percentile=50)
    
    print(f"\nMR prostate core:")
    print(f"  Total points:     {len(mr_prostate_points)}")
    print(f"  Core points:      {len(mr_core)} ({100*len(mr_core)/len(mr_prostate_points):.1f}%)")
    print(f"  Core radius:      {mr_core_threshold:.2f} mm")
    
    print(f"\nCT prostate core:")
    print(f"  Total points:     {len(ct_prostate_points)}")
    print(f"  Core points:      {len(ct_core)} ({100*len(ct_core)/len(ct_prostate_points):.1f}%)")
    print(f"  Core radius:      {ct_core_threshold:.2f} mm")
    
    # Iterative refinement
    print(f"\n" + "-"*80)
    print("ITERATIVE CORE-BASED NEAREST NEIGHBOR ALIGNMENT")
    print("-"*80)
    
    # Work on a copy so mr_core keeps the un-shifted core points.
    mr_core_current = mr_core.copy()
    # Sum of all per-iteration shifts: MR point + cumulative_translation ~ CT point.
    cumulative_translation = np.zeros(3)
    
    max_iterations = 100
    convergence_threshold = 0.1  # mm
    
    refinement_history = []
    
    for iteration in range(max_iterations):
        # Find closest CT core point for each MR core point
        dist_matrix_core = distance_matrix(mr_core_current, ct_core)
        closest_ct_indices = np.argmin(dist_matrix_core, axis=1)
        closest_ct_points = ct_core[closest_ct_indices]
        
        # Compute displacement vectors
        displacements = closest_ct_points - mr_core_current
        
        # Current iteration refinement (use mean)
        current_refinement = np.mean(displacements, axis=0)
        refinement_magnitude = np.linalg.norm(current_refinement)
        
        # Accumulate
        cumulative_translation += current_refinement
        
        # Apply to current positions
        mr_core_current = mr_core_current + current_refinement
        
        # Compute mean distance. This comes from the matrix built at the top of
        # the iteration, i.e. it is the distance BEFORE this iteration's shift.
        mean_distance = np.mean(np.min(dist_matrix_core, axis=1))
        
        refinement_history.append({
            'iteration': iteration + 1,
            'refinement': current_refinement,
            'magnitude': refinement_magnitude,
            'cumulative': cumulative_translation.copy(),
            'mean_dist': mean_distance
        })
        
        print(f"  Iter {iteration+1}: Refinement = [{current_refinement[0]:6.2f}, {current_refinement[1]:6.2f}, {current_refinement[2]:6.2f}] mm, "
              f"Mag = {refinement_magnitude:5.2f} mm, Mean dist = {mean_distance:5.2f} mm")
        
        # Check convergence
        if refinement_magnitude < convergence_threshold:
            print(f"\n  ✓ Converged after {iteration+1} iterations (refinement < {convergence_threshold} mm)")
            break
    else:
        # for/else: reached only when the loop ended without `break`.
        print(f"\n  → Stopped after {max_iterations} iterations (max reached)")
    
    print(f"\n" + "-"*80)
    print("ALIGNMENT SUMMARY")
    print("-"*80)
    print(f"  Total iterations:     {len(refinement_history)}")
    print(f"  Total translation:    [{cumulative_translation[0]:.2f}, {cumulative_translation[1]:.2f}, {cumulative_translation[2]:.2f}] mm")
    print(f"  Translation magnitude: {np.linalg.norm(cumulative_translation):.2f} mm")
    
    # Apply translation to ALL MR prostate points
    mr_prostate_aligned = mr_prostate_points + cumulative_translation
    mr_centroid_aligned = np.mean(mr_prostate_aligned, axis=0)
    
    # Compute alignment quality metrics
    print(f"\n" + "="*80)
    print("ALIGNMENT QUALITY (Full Prostate)")
    print("="*80)
    
    final_centroid_error = np.linalg.norm(mr_centroid_aligned - ct_prostate_centroid)
    print(f"  Centroid error:       {final_centroid_error:.2f} mm")
    
    # Surface distances: for every aligned MR point, the distance to its
    # nearest CT point (one-directional, MR -> CT).
    dist_matrix_full = distance_matrix(mr_prostate_aligned, ct_prostate_points)
    min_distances_prostate = np.min(dist_matrix_full, axis=1)
    
    print(f"  Mean surface dist:    {np.mean(min_distances_prostate):.2f} mm")
    print(f"  Median surface dist:  {np.median(min_distances_prostate):.2f} mm")
    print(f"  95th percentile:      {np.percentile(min_distances_prostate, 95):.2f} mm")
    print(f"  Max distance:         {np.max(min_distances_prostate):.2f} mm")
    
    # Create translation-only transform.
    # sitk.Resample evaluates the transform at points of the OUTPUT grid (the
    # CT reference) to find where to sample the INPUT image (the MR), so it
    # needs the CT -> MR mapping. cumulative_translation was estimated the other
    # way round (MR point + t = CT point), hence the offset is negated here.
    translation_transform = sitk.TranslationTransform(3)
    translation_transform.SetOffset((-cumulative_translation).tolist())
    
    # Apply transform to full MR image
    print(f"\n" + "-"*80)
    print("RESAMPLING FULL MR IMAGE")
    print("-"*80)
    
    mr_aligned = sitk.Resample(
        mr_image_sitk,                 # image to resample (moving)
        ct_image_sitk,                 # reference grid (fixed)
        translation_transform,         # maps CT-space points into MR space
        sitk.sitkLinear,
        0.0,                           # value used outside the MR field of view
        mr_image_sitk.GetPixelID()
    )
    
    # Build transformation matrix (translation-only). This matrix is the
    # point mapping MR -> CT, i.e. the inverse of the resampling transform.
    print(f"\n" + "="*80)
    print("FINAL TRANSFORMATION MATRIX")
    print("="*80)
    
    final_matrix = np.eye(4)
    final_matrix[:3, 3] = cumulative_translation
    
    print(f"\nTransformation (4x4 Homogeneous Matrix):")
    for row in final_matrix:
        print(f"  [{row[0]:10.6f} {row[1]:10.6f} {row[2]:10.6f} {row[3]:10.6f}]")
    
    print(f"\nTranslation: [{cumulative_translation[0]:.3f}, {cumulative_translation[1]:.3f}, {cumulative_translation[2]:.3f}] mm")
    print(f"Rotation: [0.000, 0.000, 0.000] degrees (translation-only)")
    
    # Save outputs
    aligned_output = os.path.join(OUTPUT_DIR, 'mr_prostate_aligned.nii.gz')
    sitk.WriteImage(mr_aligned, aligned_output)
    print(f"\n✓ Saved aligned MR: {aligned_output}")
    
    transform_file = os.path.join(OUTPUT_DIR, 'prostate_translation_transform.tfm')
    sitk.WriteTransform(translation_transform, transform_file)
    print(f"✓ Saved transform: {transform_file}")
    print(f"  (resampling transform: maps CT-space points to MR space, offset = -translation)")
    
    # Save point clouds
    np.save(os.path.join(OUTPUT_DIR, 'mr_core_points_aligned.npy'), mr_core_current)
    np.save(os.path.join(OUTPUT_DIR, 'ct_core_points.npy'), ct_core)
    np.save(os.path.join(OUTPUT_DIR, 'mr_full_points_aligned.npy'), mr_prostate_aligned)
    np.save(os.path.join(OUTPUT_DIR, 'ct_full_points.npy'), ct_prostate_points)
    print(f"✓ Saved point clouds for visualization")
    
    # Save refinement history (NumPy arrays are converted to lists for JSON)
    import json
    history_file = os.path.join(OUTPUT_DIR, 'alignment_history.json')
    with open(history_file, 'w') as f:
        json.dump([{k: v.tolist() if isinstance(v, np.ndarray) else v 
                    for k, v in item.items()} for item in refinement_history], f, indent=2)
    print(f"✓ Saved alignment history: {history_file}")
    
    # Summary file
    summary_file = os.path.join(OUTPUT_DIR, 'alignment_summary.txt')
    with open(summary_file, 'w') as f:
        f.write("=" * 80 + "\n")
        f.write("MRI-CT PROSTATE ALIGNMENT (STRUCTURE-ONLY)\n")
        f.write("=" * 80 + "\n\n")
        
        f.write("Method: Iterative Core-Based Nearest Neighbor (Inner 50%)\n")
        f.write(f"Iterations: {len(refinement_history)}\n")
        f.write(f"Translation: [{cumulative_translation[0]:.3f}, {cumulative_translation[1]:.3f}, {cumulative_translation[2]:.3f}] mm\n")
        f.write(f"Magnitude: {np.linalg.norm(cumulative_translation):.2f} mm\n")
        f.write("Translation maps MR points to CT points; the saved .tfm holds the inverse (CT -> MR) used for resampling.\n\n")
        
        f.write("Alignment Quality:\n")
        f.write(f"  Centroid error:       {final_centroid_error:.2f} mm\n")
        f.write(f"  Mean surface dist:    {np.mean(min_distances_prostate):.2f} mm\n")
        f.write(f"  Median surface dist:  {np.median(min_distances_prostate):.2f} mm\n")
        f.write(f"  95th percentile:      {np.percentile(min_distances_prostate, 95):.2f} mm\n")
        f.write(f"  Max distance:         {np.max(min_distances_prostate):.2f} mm\n")
    
    print(f"✓ Saved summary: {summary_file}")
    
else:
    print("\n✗ Prostate structures not found in both modalities")
    print("   Cannot perform alignment")


PROSTATE CORE ALIGNMENT (STRUCTURE-ONLY, INNER 50%)

Structure matching:
  MR Prostate: 'Prostate' - 7482 points
  CT Prostate: 'Prostate' - 5200 points

Original prostate centroids:
  MR: [-8.16, -8.20, -45.22]
  CT: [1.69, -225.65, -19.95]
  Initial distance: 219.13 mm

--------------------------------------------------------------------------------
EXTRACTING PROSTATE CORES (INNER 50%)
--------------------------------------------------------------------------------

MR prostate core:
  Total points:     7482
  Core points:      3741 (50.0%)
  Core radius:      21.94 mm

CT prostate core:
  Total points:     5200
  Core points:      2600 (50.0%)
  Core radius:      24.16 mm

--------------------------------------------------------------------------------
ITERATIVE CORE-BASED NEAREST NEIGHBOR ALIGNMENT
--------------------------------------------------------------------------------
  Iter 1: Refinement = [ 13.06, -199.99,  31.60] mm, Mag = 202.89 mm, Mean dist = 203.22 mm
  Iter 2: R