# Resample and Overlay PET/CT -> CT Sim

This notebook reads DICOM series, prints image metadata (origin/spacing/direction),
and contains helper functions to resample a moving volume (PET/CT or CT_2) into the fixed
grid of a target CT (CT_1 or CT_Sim) using SimpleITK. It also includes an overlay viewer
for quick slice-by-slice checks.

In [1]:
# Basic imports and a friendly SimpleITK import check
import os
import sys
import numpy as np
import matplotlib.pyplot as plt

# Try to import SimpleITK; if missing, user can install using pip as shown below.
try:
    import SimpleITK as sitk
except Exception as e:
    sitk = None
    print("SimpleITK not available. Install with: pip install SimpleITK")
    print("Import error:", e)

In [2]:
# Helper functions: read series, build affine, resample, and overlay viewer
def read_dicom_series(folder, series_id=None):
    """Read a DICOM series (or a single image file) and return a SimpleITK image.

    Parameters
    ----------
    folder : str or path-like
        Directory to scan for DICOM series. If GDCM reports no series there,
        ``folder`` is handed to ``sitk.ReadImage`` as a single-file image
        (nifti/mha/etc).
    series_id : str, optional
        SeriesInstanceUID to load. When omitted, the first series reported by
        GDCM is loaded (the historical behaviour) and a notice is printed if
        the folder contains more than one series, since that pick is arbitrary.

    Returns
    -------
    SimpleITK.Image

    Raises
    ------
    RuntimeError
        If SimpleITK is not installed.
    ValueError
        If ``series_id`` is given but not found in ``folder``.
    """
    if sitk is None:
        raise RuntimeError('SimpleITK is required for read_dicom_series')
    folder = str(folder)  # allow pathlib.Path, as the other cells pass str(data_dir)
    reader = sitk.ImageSeriesReader()
    series_IDs = reader.GetGDCMSeriesIDs(folder)
    if not series_IDs:
        # try reading single-file image (nifti/mha/etc)
        return sitk.ReadImage(folder)
    if series_id is None:
        series_id = series_IDs[0]
        if len(series_IDs) > 1:
            print(f'Note: {len(series_IDs)} series found in {folder}; '
                  f'reading the first one ({series_id}). Pass series_id=... to choose another.')
    elif series_id not in series_IDs:
        raise ValueError(f'Series {series_id!r} not found in {folder}')
    series_file_names = reader.GetGDCMSeriesFileNames(folder, series_id)
    reader.SetFileNames(series_file_names)
    return reader.Execute()

def affine_from_4x4(mat4):
    """Convert a 4x4 homogeneous transform into a SimpleITK AffineTransform.

    The upper-left 3x3 block is copied (row-major) as the affine matrix and the
    first three entries of the last column as the translation. The bottom row
    is not inspected.

    Parameters
    ----------
    mat4 : array-like, shape (4, 4)
        Homogeneous transform in world (physical) coordinates.

    Returns
    -------
    SimpleITK.AffineTransform

    Raises
    ------
    RuntimeError
        If SimpleITK is not installed.
    ValueError
        If ``mat4`` does not have shape (4, 4).

    Notes
    -----
    The matrix is used as given; it is NOT inverted. ``sitk.Resample`` applies
    its transform to map points of the output (fixed) grid into the input
    (moving) image, so a matrix that maps moving->fixed must be inverted by the
    caller before it is used for resampling.
    """
    if sitk is None:
        raise RuntimeError('SimpleITK is required for affine_from_4x4')
    # asarray lets callers pass nested lists as well as numpy arrays.
    mat4 = np.asarray(mat4, dtype=float)
    if mat4.shape != (4, 4):
        # Explicit check instead of `assert`, which disappears under `python -O`.
        raise ValueError(f'mat4 must have shape (4, 4), got {mat4.shape}')
    t = sitk.AffineTransform(3)
    t.SetMatrix(mat4[:3, :3].ravel().tolist())
    t.SetTranslation(mat4[:3, 3].tolist())
    return t

def resample_to_target(moving, fixed, transform=None, interpolator='linear', default_value=0.0):
    """Resample ``moving`` onto the voxel grid of ``fixed`` with ``sitk.Resample``.

    Parameters
    ----------
    moving : SimpleITK.Image
        Image to resample. Its pixel type is kept for the output.
    fixed : SimpleITK.Image
        Reference image passed to ``sitk.Resample`` to define the output grid.
    transform : SimpleITK.Transform, optional
        Transform passed to ``sitk.Resample``; a default-constructed
        ``sitk.Transform()`` is used when omitted. Note that ``sitk.Resample``
        uses it to map output (fixed) points into the moving image.
    interpolator : str or SimpleITK interpolator constant
        ``'linear'``, ``'nearest'`` or ``'bspline'`` (case-insensitive), or a
        ``sitk.sitk*`` interpolator value which is passed through unchanged.
        An unrecognised name falls back to linear and emits a warning.
    default_value : float
        Value written where the output grid falls outside ``moving``.

    Returns
    -------
    SimpleITK.Image

    Raises
    ------
    RuntimeError
        If SimpleITK is not installed.
    """
    if sitk is None:
        raise RuntimeError('SimpleITK required for resampling')
    if transform is None:
        transform = sitk.Transform()
    interp_map = { 'linear': sitk.sitkLinear, 'nearest': sitk.sitkNearestNeighbor, 'bspline': sitk.sitkBSpline }
    if isinstance(interpolator, str):
        key = interpolator.strip().lower()
        if key in interp_map:
            sitk_interp = interp_map[key]
        else:
            # Keep the historical linear fallback, but make it visible: a typo
            # here would otherwise silently interpolate label maps linearly.
            import warnings
            warnings.warn(
                f'Unknown interpolator {interpolator!r}; falling back to linear. '
                f'Valid names: {sorted(interp_map)}',
                stacklevel=2,
            )
            sitk_interp = sitk.sitkLinear
    else:
        # Already a SimpleITK interpolator constant.
        sitk_interp = interpolator
    return sitk.Resample(moving, fixed, transform, sitk_interp, default_value, moving.GetPixelID())

def world_z_to_index(image, z_world):
    """Return the (fractional) k index of the world point (0, 0, z_world).

    The offset from the image origin to that point is projected onto the third
    column of the 3x3 direction matrix and divided by the spacing along k.
    The result is not rounded or clamped to the image extent.
    """
    axes = np.array(image.GetDirection()).reshape(3, 3)
    k_axis = axes[:, 2]
    k_spacing = np.array(image.GetSpacing())[2]
    offset = np.array([0.0, 0.0, z_world]) - np.array(image.GetOrigin())
    return np.dot(offset, k_axis) / k_spacing

def show_overlay_slice(fixed, moving_resampled, slice_index=None, z_world=None, cmap_moving='hot', alpha=0.5):
    """Display one slice of ``fixed`` in grayscale with ``moving_resampled`` on top.

    Parameters
    ----------
    fixed, moving_resampled : SimpleITK.Image
        Volumes to display; both are indexed with the same slice index, so
        ``moving_resampled`` is expected to be on the grid of ``fixed``.
    slice_index : int, optional
        Index along the first axis of the numpy array (the k axis of the image).
    z_world : float, optional
        Used only when ``slice_index`` is None: converted to an index with
        ``world_z_to_index`` and rounded.
    cmap_moving : str
        Matplotlib colormap for the overlay.
    alpha : float
        Overlay opacity.

    If neither ``slice_index`` nor ``z_world`` is given, the middle slice is shown.

    Raises
    ------
    RuntimeError
        If SimpleITK is not installed.
    IndexError
        If the resolved slice index lies outside ``[0, n_slices)``. Negative
        indices are rejected rather than wrapped, because a world Z outside the
        volume would otherwise silently show a slice from the other end.
    """
    if sitk is None:
        raise RuntimeError('SimpleITK is required for show_overlay_slice')
    fixed_np = sitk.GetArrayFromImage(fixed)
    mov_np = sitk.GetArrayFromImage(moving_resampled)
    n_slices = fixed.GetSize()[2]
    if slice_index is None and z_world is not None:
        slice_index = int(round(world_z_to_index(fixed, z_world)))
    if slice_index is None:
        slice_index = int(n_slices // 2)
    if not 0 <= slice_index < n_slices:
        raise IndexError(f'slice index {slice_index} is outside the volume (0..{n_slices - 1})')
    arr_k = slice_index

    def _normalise(slice_2d):
        # Work in float so integer pixel types cannot overflow on subtraction.
        slice_2d = slice_2d.astype(float)
        return (slice_2d - slice_2d.min()) / max(1e-6, (slice_2d.max() - slice_2d.min()))

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    ax.imshow(_normalise(fixed_np[arr_k, :, :]), cmap='gray', interpolation='nearest')
    ax.imshow(_normalise(mov_np[arr_k, :, :]), cmap=cmap_moving, alpha=alpha, interpolation='nearest')
    ax.set_title(f"Slice index (array order) {arr_k}")
    ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure once it has been rendered

def save_image(img, out_path):
    """Write ``img`` to ``out_path`` with ``sitk.WriteImage`` and print the path.

    ``out_path`` may be a string or a path-like object. The parent directory is
    created if it does not exist yet, so a fresh output folder does not make
    the write fail.

    Raises
    ------
    RuntimeError
        If SimpleITK is not installed.
    """
    if sitk is None:
        raise RuntimeError('SimpleITK is required for save_image')
    out_path = os.fspath(out_path)
    parent = os.path.dirname(out_path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    sitk.WriteImage(img, out_path)
    print(f"Saved: {out_path}")

def example_workflow(moving_folder, fixed_folder, out_path_resampled='resampled.mha', transform_4x4=None):
    """Read two volumes, resample the moving one onto the fixed grid, save and preview.

    Parameters
    ----------
    moving_folder, fixed_folder : str
        Locations passed to ``read_dicom_series``.
    out_path_resampled : str
        Where the resampled moving image is written.
    transform_4x4 : array-like, shape (4, 4), optional
        Converted with ``affine_from_4x4`` and passed unchanged to
        ``resample_to_target``. When omitted, no transform is supplied.
        NOTE(review): ``sitk.Resample`` maps fixed-grid points into the moving
        image; confirm the supplied matrix follows that direction.

    Returns
    -------
    SimpleITK.Image
        The moving image resampled onto the fixed grid.
    """
    print('Reading fixed image...')
    fixed = read_dicom_series(fixed_folder)
    print('Fixed origin, spacing, size:', fixed.GetOrigin(), fixed.GetSpacing(), fixed.GetSize())
    print('Reading moving image...')
    moving = read_dicom_series(moving_folder)
    print('Moving origin, spacing, size:', moving.GetOrigin(), moving.GetSpacing(), moving.GetSize())
    transform = None
    if transform_4x4 is not None:
        transform = affine_from_4x4(transform_4x4)
        print('Using provided 4x4 transform.')
    print('Resampling moving -> fixed grid ...')
    moving_resampled = resample_to_target(moving, fixed, transform=transform, interpolator='linear', default_value=0.0)
    save_image(moving_resampled, out_path_resampled)
    # The preview is best-effort: a display problem must not lose the saved result.
    try:
        center_k = int(round(fixed.GetSize()[2] / 2.0))
        show_overlay_slice(fixed, moving_resampled, slice_index=center_k)
    except Exception as e:
        print('Could not display overlay:', e)
    return moving_resampled

In [3]:
# Sanity check: confirm the ../WholePelvis path and show the first few entries
from pathlib import Path

data_dir = Path('..', 'WholePelvis')
print('Data dir resolved to:', data_dir.resolve())
if not data_dir.exists():
    print('Directory not found, check path')
else:
    files = list(data_dir.glob('*'))
    print('Number of files in WholePelvis:', len(files))
    # Only the first 20 names are listed to keep the output short.
    for f in files[:20]:
        print('-', f.name)

Data dir resolved to: C:\Users\zhaoanr\Desktop\WholePelvis
Number of files in WholePelvis: 2992
- CT.1.2.246.352.221.4612103617401399903.11679518133698580356.dcm
- CT.1.2.246.352.221.4612281645113694925.4445080184088332458.dcm
- CT.1.2.246.352.221.4613357340562627367.11069548827518099366.dcm
- CT.1.2.246.352.221.4613742148749638854.17444518703121455550.dcm
- CT.1.2.246.352.221.4614246201180082441.12216565514302185904.dcm
- CT.1.2.246.352.221.4614853653591645137.11640949316199802015.dcm
- CT.1.2.246.352.221.4615023050527622172.910398443307771028.dcm
- CT.1.2.246.352.221.4616338419106580712.2364937694379555226.dcm
- CT.1.2.246.352.221.4616748542900588316.4007241067963100804.dcm
- CT.1.2.246.352.221.4616800360097223970.6889008078367591605.dcm
- CT.1.2.246.352.221.4617271686679834815.3537910081571082377.dcm
- CT.1.2.246.352.221.4617305787553047537.16724059461988762812.dcm
- CT.1.2.246.352.221.4620276485284484005.1726138375659829149.dcm
- CT.1.2.246.352.221.4620768376903297969.1769569090718

In [4]:
# Series summary: list SeriesInstanceUIDs and file counts using SimpleITK or pydicom fallback
from collections import Counter
from pathlib import Path
data_dir = Path('..') / 'WholePelvis'
if not data_dir.exists():
    print('Data directory not found:', data_dir)
else:
    try:
        # Prefer SimpleITK's fast series discovery when available
        if 'sitk' in globals() and sitk is not None:
            print('Using SimpleITK to find series IDs...')
            series_ids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(str(data_dir)) or []
            print(f'Found {len(series_ids)} series via SimpleITK')
            for sid in series_ids:
                files = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(data_dir), sid)
                # Computed outside the f-string: reusing the same quote character
                # inside an f-string expression is a SyntaxError before Python 3.12.
                example_name = Path(files[0]).name if files else ''
                print(f'Series: {sid} -> {len(files)} files, example: {example_name}')
        else:
            # Fallback to pydicom to inspect SeriesInstanceUID per file (slower)
            import pydicom
            print('SimpleITK not available; falling back to pydicom (may be slow)')
            uids = Counter()
            files = list(data_dir.glob('*.dcm'))
            for i, f in enumerate(files):
                try:
                    ds = pydicom.dcmread(str(f), stop_before_pixels=True, force=True)
                    sid = getattr(ds, 'SeriesInstanceUID', None) or 'UNKNOWN'
                    uids[sid] += 1
                except Exception:
                    # Unreadable files are counted under their own bucket rather than aborting the scan.
                    uids['READ_ERROR'] += 1
                if i and i % 500 == 0:
                    print(f'Inspected {i} files...')
            print('Series counts (top 20):')
            for sid, cnt in uids.most_common(20):
                print(f'{sid} -> {cnt} files')
    except Exception as e:
        print('Error while enumerating series:', e)

Using SimpleITK to find series IDs...
Found 40 series via SimpleITK
Found 40 series via SimpleITK
Series: 1.2.246.352.221.4621740649674887645.17443183878707904400 -> 2 files, example: RI.1.2.246.352.221.5712045233132034502.15136406492454696350.dcm
Series: 1.2.246.352.221.4621740649674887645.17443183878707904400 -> 2 files, example: RI.1.2.246.352.221.5712045233132034502.15136406492454696350.dcm
Series: 1.2.246.352.221.4629816576916826922.5651701502939223462 -> 88 files, example: CT.1.2.246.352.221.5057563884189011384.3311819419043651519.dcm
Series: 1.2.246.352.221.4629816576916826922.5651701502939223462 -> 88 files, example: CT.1.2.246.352.221.5057563884189011384.3311819419043651519.dcm
Series: 1.2.246.352.221.4682099163491121666.13810091271748640934 -> 2 files, example: RI.1.2.246.352.221.4908038445967358700.6011497693230694805.dcm
Series: 1.2.246.352.221.4682099163491121666.13810091271748640934 -> 2 files, example: RI.1.2.246.352.221.4908038445967358700.6011497693230694805.dcm
Series

## How to run

1. Install requirements if you plan to run the resampling cells:
   - `pip install SimpleITK numpy matplotlib pydicom`
2. Adjust the `example_workflow()` call with your moving/fixed folders (relative to this notebook).
3. Run the cells in order.

In [5]:
# Series selection helpers and sensible defaults
from pathlib import Path
data_dir = Path('..') / 'WholePelvis'

# File counts used to recognise the two volumes of interest in this dataset.
PET_SLICE_COUNT = 366
CTSIM_SLICE_COUNT = 318

# Gather series ids and file counts (if SimpleITK is available)
series_info = []
if 'sitk' in globals() and sitk is not None and data_dir.exists():
    sids = sitk.ImageSeriesReader.GetGDCMSeriesIDs(str(data_dir)) or []
    for sid in sids:
        files = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(data_dir), sid)
        series_info.append((sid, len(files), files[0] if files else None))


def _series_modality(example_file):
    """Return the DICOM Modality (tag 0008|0060) of one file, or None if it cannot be read."""
    try:
        reader = sitk.ImageFileReader()
        reader.SetFileName(example_file)
        reader.ReadImageInformation()  # header only, no pixel data
        if reader.HasMetaDataKey('0008|0060'):
            return reader.GetMetaData('0008|0060').strip()
    except Exception:
        pass
    return None


def _pick_series(file_count, modality):
    """Pick a series with ``file_count`` files, preferring one whose Modality matches.

    File count alone is ambiguous here: the PET series and its companion CT both
    have 366 files. If no series with that count has the requested modality,
    the first series with that count is returned (the previous behaviour).
    """
    same_count = [(sid, ex) for sid, cnt, ex in series_info if cnt == file_count]
    for sid, ex in same_count:
        if ex is not None and _series_modality(ex) == modality:
            return sid
    return same_count[0][0] if same_count else None


# 'PT' is the DICOM Modality code for PET images.
default_pet_sid = _pick_series(PET_SLICE_COUNT, 'PT')
default_ctsim_sid = _pick_series(CTSIM_SLICE_COUNT, 'CT')

# Print discovered candidates
print('Discovered series counts (top 20):')
for sid, cnt, ex in sorted(series_info, key=lambda x: -x[1])[:20]:
    # Computed outside the f-string: nested same-type quotes fail before Python 3.12.
    example_name = Path(ex).name if ex else ''
    print(f'{cnt:4d} files: {sid} -> example file {example_name}')
print()
print(f'Default PET candidate ({PET_SLICE_COUNT} slices):', default_pet_sid)
print(f'Default CT_Sim candidate ({CTSIM_SLICE_COUNT} slices):', default_ctsim_sid)

# Allow user override here:
#   MOVING_SERIES_ID  - the series to resample. Defaults to the 318-file CT candidate,
#                       or the first discovered series when none is found.
#   FIXED_SERIES_IDS  - the target grids. Defaults to the PET candidate (if found) plus
#                       the first other series with at least 300 files.
MOVING_SERIES_ID = default_ctsim_sid or (series_info[0][0] if series_info else None)
FIXED_SERIES_IDS = [default_pet_sid] if default_pet_sid else []
# also include one other large series as fixed (first large series not equal to moving)
for sid, cnt, ex in series_info:
    if sid != MOVING_SERIES_ID and cnt >= 300 and sid not in FIXED_SERIES_IDS:
        FIXED_SERIES_IDS.append(sid)
        break
print('MOVING_SERIES_ID (CT_2 candidate):', MOVING_SERIES_ID)
print('FIXED_SERIES_IDS (targets, e.g. PET and CT_1):', FIXED_SERIES_IDS)

def read_series_by_id(data_dir, series_id):
    """Read one DICOM series, selected by SeriesInstanceUID, from ``data_dir``.

    Parameters
    ----------
    data_dir : str or path-like
        Directory scanned by GDCM; converted with ``str()``.
    series_id : str
        SeriesInstanceUID of the series to load.

    Returns
    -------
    (SimpleITK.Image, sequence of str)
        The loaded image and the file names that make up the series.

    Raises
    ------
    RuntimeError
        If SimpleITK is not available, or if no files belong to ``series_id``.
    """
    if 'sitk' not in globals() or sitk is None:
        raise RuntimeError('SimpleITK required to read and resample series')
    files = sitk.ImageSeriesReader.GetGDCMSeriesFileNames(str(data_dir), series_id)
    if not files:
        # Fail here with a readable message instead of inside reader.Execute().
        raise RuntimeError(f'No DICOM files found for series {series_id!r} in {data_dir}')
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(files)
    img = reader.Execute()
    return img, files

Discovered series counts (top 20):
 366 files: 1.2.246.352.221.5652494895570817685.12284654837981657745 -> example file CT.1.2.246.352.221.4705750904517158974.11652307503497441166.dcm
 366 files: 1.2.246.352.221.5761707951501486900.17648698438667590801 -> example file PE.1.2.246.352.221.5213764312901985085.12878585370538429571.dcm
 318 files: 1.2.246.352.221.5205878071515498524.15068257275121431175 -> example file CT.1.2.246.352.221.5152916959242403634.12029315609761544371.dcm
 318 files: 1.2.246.352.221.5678833251329575635.11111956103594119858 -> example file CT.1.2.246.352.221.5203536764609419051.13230986444841661321.dcm
 176 files: 1.2.246.352.221.5175030514826366843.13068040234059044505 -> example file MR.1.2.246.352.221.4734979789753053407.1008250676121845694.dcm
  88 files: 1.2.246.352.221.4629816576916826922.5651701502939223462 -> example file CT.1.2.246.352.221.5057563884189011384.3311819419043651519.dcm
  88 files: 1.2.246.352.221.4704100350675794783.2863429542978936198 -> exa

In [None]:
# Diagnostic: compute Z-ranges, example voxel mappings, and display overlays (inline)
from pathlib import Path
import numpy as np

def image_z_range(img):
    """Return (z_min, z_max) in world coordinates for the first and last slice.

    Only the voxels at index (0, 0, 0) and (0, 0, nz - 1) are converted with
    TransformIndexToPhysicalPoint; their Z components are returned in
    ascending order.
    """
    last_k = img.GetSize()[2] - 1
    z_first = img.TransformIndexToPhysicalPoint((0, 0, 0))[2]
    z_last = img.TransformIndexToPhysicalPoint((0, 0, last_k))[2]
    return (min(z_first, z_last), max(z_first, z_last))

def print_diagnostics_for_series(series_id, label):
    """Load a series from ``data_dir`` and print its geometry.

    Prints size, spacing, origin, direction and the world Z range of the image
    read with ``read_series_by_id``.

    Parameters
    ----------
    series_id : str
        SeriesInstanceUID to load from the module-level ``data_dir``.
    label : str
        Prefix used in the printed summary line.

    Returns
    -------
    SimpleITK.Image or None
        The loaded image, or None when the series could not be read (the error
        is printed, not raised).
    """
    try:
        img, files = read_series_by_id(data_dir, series_id)
    except Exception as e:
        print('Could not read series', series_id, e)
        return None
    size = img.GetSize()
    spacing = img.GetSpacing()
    origin = img.GetOrigin()
    direction = img.GetDirection()
    zmin, zmax = image_z_range(img)
    print(f'{label}: id={series_id} files={len(files)} size={size} spacing={spacing} origin={origin}')
    # Direction was fetched but never shown; it is needed to interpret origin/spacing.
    print(f'  direction: {direction}')
    print(f'  z-range: {zmin:.3f} to {zmax:.3f} (mm)')
    return img

# Load the moving volume (CT_2) and every fixed target (PET, CT_1), printing diagnostics for each
moving_img = None
if MOVING_SERIES_ID is not None:
    moving_img = print_diagnostics_for_series(MOVING_SERIES_ID, 'MOVING (CT_2)')
fixed_imgs = {}
for i, sid in enumerate(FIXED_SERIES_IDS):
    fixed_imgs[sid] = print_diagnostics_for_series(sid, f'FIXED_{i}')

# Report the shared Z extent of the moving volume and the first fixed target (if both loaded)
if moving_img is not None and FIXED_SERIES_IDS:
    first_fixed = FIXED_SERIES_IDS[0]
    fixed_img = fixed_imgs.get(first_fixed)
    if fixed_img is not None:
        moving_z = image_z_range(moving_img)
        fixed_z = image_z_range(fixed_img)
        # The overlap runs from the larger lower bound to the smaller upper bound.
        overlap_min = max(moving_z[0], fixed_z[0])
        overlap_max = min(moving_z[1], fixed_z[1])
        print(f'Overlap in Z (moving vs fixed[{first_fixed[:8]}]): {overlap_min:.3f} to {overlap_max:.3f}')
        if overlap_max <= overlap_min:
            print('  WARNING: no overlap in Z between moving and fixed')

# Helper to map world Z to voxel index in an image (robust using TransformPhysicalPointToIndex)
def world_z_to_voxel_k(img, z_world):
    """Return the k index of the world point (origin_x, origin_y, z_world).

    The point is converted with ``img.TransformPhysicalPointToIndex``. If that
    call raises, the rounded result of ``world_z_to_index`` is returned instead.
    """
    origin_x, origin_y, _ = img.GetOrigin()
    probe = (origin_x, origin_y, z_world)
    try:
        return img.TransformPhysicalPointToIndex(probe)[2]
    except Exception:
        # approximate formula along the third direction column
        return int(round(world_z_to_index(img, z_world)))

# Show example mappings for a few Zs (origin z of each image)
example_zs = []
if moving_img is not None:
    example_zs.append(moving_img.GetOrigin()[2])
for sid, img in fixed_imgs.items():
    if img is not None:
        example_zs.append(img.GetOrigin()[2])
example_zs = list(dict.fromkeys(example_zs))[:10]  # unique, up to 10
# The first fixed image may be None when its series could not be read; in that
# case its column is left blank instead of crashing on None.
first_fixed_img = fixed_imgs.get(FIXED_SERIES_IDS[0]) if FIXED_SERIES_IDS else None
for z in example_zs:
    moving_k = '' if moving_img is None else world_z_to_voxel_k(moving_img, z)
    fixed_k = '' if first_fixed_img is None else world_z_to_voxel_k(first_fixed_img, z)
    print(f'Example world Z={z:.3f} -> moving k=', moving_k, ' fixed k=', fixed_k)

# Now create inline overlays: resample moving and all fixed images into a common grid (choose first fixed as target)
def make_overlay_grid(target_img, imgs_to_overlay, labels=None, slice_indices=None, cols=5, rows=1, cmap_list=None, alpha_list=None, save_png=False, out_path=None):
    """Show a rows x cols grid of slices of ``target_img`` with overlays on top.

    Parameters
    ----------
    target_img : SimpleITK.Image
        Base volume, drawn in grayscale.
    imgs_to_overlay : list of SimpleITK.Image
        Volumes drawn over the base. They are indexed with the same slice
        index as the target, so they should already be on the target grid.
    labels : optional
        Accepted for interface compatibility; currently not used.
    slice_indices : sequence of int, optional
        Slices to show. By default up to 5 equispaced slices (fewer if the grid
        has fewer cells). Entries beyond the number of grid cells are dropped
        with a notice.
    cols, rows : int
        Grid layout.
    cmap_list, alpha_list : list, optional
        Colormap and opacity per overlay, cycled if shorter than the overlay
        list. Empty or None selects the defaults.
    save_png : bool
        When True and ``out_path`` is given, the figure is also saved there.
    out_path : str, optional
        Destination of the saved figure.
    """
    # imgs_to_overlay: list of SimpleITK images aligned to target grid (same size/spacing)
    target_np = sitk.GetArrayFromImage(target_img)
    # Convert each overlay volume once; doing it per slice copies the whole
    # volume again for every panel.
    overlay_nps = [sitk.GetArrayFromImage(img) for img in imgs_to_overlay]
    nz = target_img.GetSize()[2]
    n_cells = cols * rows
    if slice_indices is None:
        # pick equispaced slices within target range (limit to cols*rows)
        n = min(n_cells, 5)
        slice_indices = [int(round(i*(nz-1)/(n-1))) for i in range(n)] if n > 1 else [nz//2]
    else:
        slice_indices = list(slice_indices)
    if len(slice_indices) > n_cells:
        print(f'Only the first {n_cells} of {len(slice_indices)} requested slices fit the grid')
        slice_indices = slice_indices[:n_cells]
    if not cmap_list:
        cmap_list = ['hot', 'winter', 'viridis'][:len(imgs_to_overlay)]
    if not alpha_list:
        alpha_list = [0.5]*len(imgs_to_overlay)

    def _normalise(slice_2d):
        # Rescale to [0, 1]; the floor on the denominator protects constant slices.
        slice_2d = slice_2d.astype(float)
        return (slice_2d - slice_2d.min()) / max(1e-6, slice_2d.max()-slice_2d.min())

    fig, axes = plt.subplots(rows, cols, figsize=(cols*3, rows*3))
    axes = np.array(axes).reshape(-1)
    for ax in axes:
        ax.axis('off')
    for ax, k in zip(axes, slice_indices):
        ax.imshow(_normalise(target_np[k, :, :]), cmap='gray')
        for j, overlay_np in enumerate(overlay_nps):
            ax.imshow(_normalise(overlay_np[k, :, :]), cmap=cmap_list[j % len(cmap_list)], alpha=alpha_list[j % len(alpha_list)])
        ax.set_title(f'slice k={k}')
    plt.tight_layout()
    if save_png and out_path is not None:
        fig.savefig(out_path, bbox_inches='tight')
        print('Saved grid to', out_path)
    plt.show()
    plt.close(fig)

# Perform resampling and overlays if we have target and moving images
# The first fixed series defines the common grid. Both it and the moving image
# may be None if their series could not be read, so check the loaded images
# rather than only the selected ids.
target_sid = FIXED_SERIES_IDS[0] if FIXED_SERIES_IDS else None
target_img = fixed_imgs.get(target_sid) if target_sid is not None else None
if moving_img is None or target_img is None:
    print('Not enough series selected to display overlays')
else:
    # resample moving into target grid
    moving_res = resample_to_target(moving_img, target_img, transform=None, interpolator='linear', default_value=0.0)
    # also resample any other fixed (e.g., PET) into target grid so we can overlay all three
    other_imgs = []
    for sid in FIXED_SERIES_IDS:
        if sid == target_sid:
            continue
        img = fixed_imgs.get(sid, None)
        if img is not None:
            other_imgs.append(resample_to_target(img, target_img, transform=None, interpolator='linear', default_value=0.0))
    # Build list: first overlay CT_2 (moving_res), then other_imgs (likely PET or CT_1)
    overlays = [moving_res] + other_imgs
    cmap_list = ['winter', 'hot', 'viridis'][:len(overlays)]
    alpha_list = [0.5]*len(overlays)
    # Display single combined overlay for the center slice
    print('Displaying center-slice combined overlay (inline)')
    make_overlay_grid(target_img, overlays, cmap_list=cmap_list, alpha_list=alpha_list, cols=1, rows=1)
    # Display a grid of 5 slices inline (cols=5, rows=1)
    print('Displaying grid of 5 slices across target (inline)')
    make_overlay_grid(target_img, overlays, cmap_list=cmap_list, alpha_list=alpha_list, cols=5, rows=1)
    # Optionally save a 5x1 PNG showing 5 slices (set save_png=True to save)
    out_dir = Path('resample_outputs')
    out_dir.mkdir(parents=True, exist_ok=True)
    out_grid = out_dir / f'overlay_grid_{target_sid[:12]}.png'
    make_overlay_grid(target_img, overlays, cmap_list=cmap_list, alpha_list=alpha_list, cols=5, rows=1, save_png=True, out_path=str(out_grid))

Moving image read: origin, spacing, size = (-389.23828125, -577.23828125, -189.3) (1.5234375, 1.5234375, 1.5) (512, 512, 318)

Processing target series 1.2.246.352.221.5652494895570817685.12284654837981657745
Fixed image origin, spacing, size = (-249.51171875, -473.01171875, -1698.5) (0.9765625, 0.9765625, 3.0) (512, 512, 366)

Processing target series 1.2.246.352.221.5652494895570817685.12284654837981657745
Fixed image origin, spacing, size = (-249.51171875, -473.01171875, -1698.5) (0.9765625, 0.9765625, 3.0) (512, 512, 366)
Saved: resample_outputs\resampled_moving_to_1.2.246.352.221..mha
Saved: resample_outputs\resampled_moving_to_1.2.246.352.221..mha
Saved overlay to resample_outputs\overlay_1.2.246.352.221..png
Saved overlay to resample_outputs\overlay_1.2.246.352.221..png

Processing target series 1.2.246.352.221.5678833251329575635.11111956103594119858
Fixed image origin, spacing, size = (-249.51171875, -437.51171875, -189.3) (0.9765625, 0.9765625, 1.5) (512, 512, 318)

Processin