# Optimized Cone Beam FDK Reconstruction
## Enhanced with Parallel Processing using Joblib

This optimized version includes:
- Parallel processing of projections using joblib
- Vectorized operations where possible
- Memory-efficient implementations
- Progress tracking with tqdm
- Better code structure and documentation

In [1]:
import math
import os
from pydicom import dcmread, dcmwrite
import pydicom
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from scipy.fft import fft, ifft
from joblib import Parallel, delayed
import multiprocessing
from functools import partial
import time

# Alias for pi shared by the filter and back-projection code below.
PI: float = math.pi

# Report the CPU count used to size the worker pools later on.
print(f"Available CPU cores: {multiprocessing.cpu_count()}")

Available CPU cores: 24


## Data Loading and Preprocessing

In [2]:
def load_dicom_data(path_list):
    """
    Read every DICOM file in ``path_list`` concurrently and stack the results.

    Files are read on a thread pool of at most 8 workers (joblib 'threading'
    backend). A file that fails to read is reported with ``print`` and skipped.

    Parameters
    ----------
    path_list : iterable of str
        Paths of the DICOM files to read.

    Returns
    -------
    images : ndarray
        ``np.array`` of the ``pixel_array`` of every successfully read file,
        in input order.
    angles : ndarray
        The matching ``GantryAngle`` values.

    Raises
    ------
    ValueError
        If no file could be read.
    """
    def read_single_dicom(fname):
        # Best effort: any failure is reported and turned into a
        # (None, None) placeholder that is filtered out below.
        try:
            dataset = dcmread(fname)
            return dataset.pixel_array, dataset.GantryAngle
        except Exception as e:
            print(f"Error reading {fname}: {e}")
            return None, None

    # Cap the pool size to avoid an I/O bottleneck.
    worker_count = min(8, multiprocessing.cpu_count())

    # The progress bar wraps the task iterator, so it advances as joblib
    # pulls tasks from it (dispatch), not strictly as reads complete.
    tasks = (
        delayed(read_single_dicom)(fname)
        for fname in tqdm(path_list, desc="Loading DICOM files")
    )
    outcomes = Parallel(n_jobs=worker_count, backend='threading')(tasks)

    readable = [pair for pair in outcomes if pair[0] is not None]
    if not readable:
        raise ValueError("No valid DICOM files found")

    pixel_arrays, gantry_angles = zip(*readable)
    return np.array(pixel_arrays), np.array(gantry_angles)

In [3]:
def process_differential_images(images, angles, threshold=10000):
    """
    Convert consecutive frames into difference frames, rejecting outliers.

    Output frame ``i`` is ``images[i] - images[i - 1]``; the first frame is
    differenced against zeros, i.e. kept as is. When a difference frame after
    the first has a maximum above ``threshold``, it is rejected and the
    previous output frame and its angle are repeated in its place. The raw
    predecessor for the next difference is always ``images[i]``, whether or
    not frame ``i`` was rejected.

    Parameters
    ----------
    images : ndarray
        Stack of frames indexed along axis 0; must contain at least one frame.
    angles : sequence
        Gantry angle of each frame.
    threshold : float, optional
        Largest allowed pixel value of an accepted difference frame.

    Returns
    -------
    processed_images : ndarray
        Same shape and dtype as ``images``.
    processed_angles : ndarray
        One angle per output frame.
    """
    # NOTE(review): the subtraction is done in the dtype of `images`. If the
    # pixel data is unsigned (e.g. uint16), negative differences wrap around
    # to large positive values and the threshold test then rejects the frame;
    # confirm the dtype or cast before differencing.
    frames_out = np.zeros_like(images)
    angles_out = [None] * len(images)
    previous_raw = np.zeros_like(images[0])

    for i in tqdm(range(len(images)), desc="Processing differential images"):
        current_raw = images[i]
        difference = current_raw - previous_raw
        rejected = i > 0 and np.max(difference) > threshold
        if rejected:
            # Repeat the last accepted (or repeated) frame and its angle.
            frames_out[i] = frames_out[i - 1]
            angles_out[i] = angles_out[i - 1]
        else:
            frames_out[i] = difference
            angles_out[i] = angles[i]
        previous_raw = current_raw

    return frames_out, np.array(angles_out)

In [4]:
# Data paths
# Data paths
_arc1_pth = r"E:\CMC\pyprojects\radio_therapy\dose-3d\dataset\epid-1-arc-vmat"
_arc2_pth = r"E:\CMC\pyprojects\radio_therapy\dose-3d\dataset\2"


def _sorted_dicom_names(folder):
    """Return the names of the DICOM (*.dcm, any case) files in *folder*, sorted.

    process_differential_images subtracts consecutive frames, so the order of
    the file list matters. os.listdir returns entries in arbitrary order;
    sorting makes the order deterministic.
    NOTE(review): this assumes the file names sort in acquisition order
    (plain string sort: "img10" sorts before "img2") — confirm for this export.
    """
    return sorted(f for f in os.listdir(folder) if f.lower().endswith('.dcm'))


# Get file lists
_files_1 = _sorted_dicom_names(_arc1_pth)
_files_2 = _sorted_dicom_names(_arc2_pth)

# Create full paths (arc 1 first, then arc 2)
_full_paths_1 = [os.path.join(_arc1_pth, f) for f in _files_1]
_full_paths_2 = [os.path.join(_arc2_pth, f) for f in _files_2]
_pth = _full_paths_1 + _full_paths_2

print(f"Total DICOM files found: {len(_pth)}")

Total DICOM files found: 1135


In [5]:
# Load DICOM data
start_time = time.time()
raw_images, raw_angles = load_dicom_data(_pth)
print(f"DICOM loading completed in {time.time() - start_time:.2f} seconds")

# Get image properties from first DICOM.
# Only header attributes are read from `dcm` in this notebook (Rows/Columns
# here, RTImageSID/RadiationMachineSAD in the parameters cell), so skip
# reading and decoding the pixel data a second time.
dcm = dcmread(_pth[0], stop_before_pixels=True)
shape = dcm.Rows, dcm.Columns
print(f"Image shape: {shape}")
print(f"Raw images shape: {raw_images.shape}")

Loading DICOM files: 100%|██████████| 1135/1135 [00:02<00:00, 380.11it/s]


DICOM loading completed in 6.09 seconds
Image shape: (1190, 1190)
Raw images shape: (1135, 1190, 1190)


In [6]:
# Process differential images
start_time = time.time()
# Difference the frames in file order (see process_differential_images).
processed_images, processed_angles = process_differential_images(raw_images, raw_angles)
print(f"Differential processing completed in {time.time() - start_time:.2f} seconds")

# Sort images by gantry angle.
# Fancy indexing returns copies, which is why the unsorted stacks can be
# released at the end of this cell. np.argsort's default kind is not stable,
# so frames with equal angles have no guaranteed relative order; the
# reconstruction sums over all angles, so their order does not matter there.
sorted_indices = np.argsort(processed_angles)
sorted_images = processed_images[sorted_indices]
sorted_angles = processed_angles[sorted_indices]

print(f"Final sorted images shape: {sorted_images.shape}")
print(f"Angle range: {sorted_angles.min():.1f}° to {sorted_angles.max():.1f}°")

# Free memory
del raw_images, processed_images

Processing differential images: 100%|██████████| 1135/1135 [00:02<00:00, 495.01it/s]


Differential processing completed in 3.71 seconds
Final sorted images shape: (1135, 1190, 1190)
Angle range: 0.1° to 359.9°


## Optimized Filter Functions

In [7]:
def filter_shepp_logan(N, d):
    """
    Discrete Shepp-Logan filter kernel in the spatial domain.

    ``h[i] = -2 / (pi^2 * d^2 * (4 k^2 - 1))`` with integer lag
    ``k = i - N // 2``, so the zero-lag tap sits at index ``N // 2`` for both
    even and odd ``N``. Because ``k`` is always an integer, ``4 k^2 - 1`` is
    never zero. Half-integer lags (k = +-1/2) would divide by zero.

    Parameters
    ----------
    N : int
        Number of taps.
    d : float
        Sample spacing (detector pitch); must be non-zero.

    Returns
    -------
    ndarray of float64, length ``N``.

    Raises
    ------
    ValueError
        If ``d`` is zero.
    """
    if d == 0:
        raise ValueError("filter_shepp_logan: sample spacing d must be non-zero")
    k = np.arange(N) - N // 2
    return -2.0 / (np.pi * np.pi * d * d * (4.0 * k ** 2 - 1.0))

def nearest_power_of_2(N):
    """
    Return the smallest power of 2 that is greater than or equal to ``N``.

    Uses exact integer arithmetic. ``math.log2`` rounds for large inputs
    (``log2(2**60 - 1)`` evaluates to ``60.0``), so it is avoided here.

    Parameters
    ----------
    N : int or float
        Target length; non-integers are rounded up first.

    Returns
    -------
    int

    Raises
    ------
    ValueError
        If ``N <= 0``.
    """
    n = math.ceil(N)
    if n < 1:
        raise ValueError(f"nearest_power_of_2: N must be positive, got {N!r}")
    return 1 << (n - 1).bit_length()

## Parallel Reconstruction Functions

In [8]:
def weight_projection_vectorized(projection_beta, SOD, delta_dd):
    """
    Apply the cone-beam pre-weight to a single projection.

    Each pixel is multiplied by ``SOD / sqrt(SOD**2 + r**2 + c**2)``. Here ``r``
    and ``c`` are the pixel's row and column offsets from the detector centre.
    Pixel ``i`` of an axis of length ``N`` is centred at
    ``delta_dd * (i - (N - 1) / 2)``.

    Parameters
    ----------
    projection_beta : 2-D ndarray, indexed (row, column)
    SOD : float
        Source-to-origin distance, in the same units as ``delta_dd``.
    delta_dd : float
        Detector pixel pitch.

    Returns
    -------
    ndarray with the same shape as ``projection_beta``.
    """
    n_rows, n_cols = projection_beta.shape

    row_offsets = delta_dd * (np.arange(n_rows) - (n_rows - 1) / 2.0)
    col_offsets = delta_dd * (np.arange(n_cols) - (n_cols - 1) / 2.0)

    # Broadcast (rows, 1) against (1, cols) instead of building meshgrids.
    dist_sq = SOD**2 + row_offsets[:, np.newaxis]**2 + col_offsets[np.newaxis, :]**2
    return projection_beta * (SOD / np.sqrt(dist_sq))

def filter_projection_fft(weighted_projection, fh_filter, nfft=None):
    """
    Filter every detector row by FFT-based convolution with ``fh_filter / 2``.

    Parameters
    ----------
    weighted_projection : 2-D ndarray (Nrows, Ncolumns)
        Filtering runs along axis 1 (columns).
    fh_filter : 1-D ndarray
        Spatial-domain kernel whose zero-lag tap sits at index
        ``len(fh_filter) // 2``. This is the layout produced by
        ``filter_shepp_logan``. Must not be longer than the FFT length.
    nfft : int, optional
        FFT length. Defaults to ``nearest_power_of_2(2 * Ncolumns - 1)``, which
        is long enough that the circular convolution does not wrap around
        within the returned columns.

    Returns
    -------
    ndarray (Nrows, Ncolumns)
        ``out[:, n] = sum_m p[:, m] * kernel(n - m) / 2``, where ``kernel(k)``
        is ``fh_filter[k + len(fh_filter) // 2]``.
    """
    Nrows, Ncolumns = weighted_projection.shape
    Nfft = nearest_power_of_2(2 * Ncolumns - 1) if nfft is None else int(nfft)

    # Place the kernel in an Nfft-long buffer, then rotate it so the zero lag
    # lands at index 0 (negative lags wrap to the end of the buffer). Without
    # this rotation the convolution is shifted by len(fh_filter)//2 samples.
    # The [:Ncolumns] slice below would then pick up the kernel's far tail
    # instead of the filtered row.
    center = len(fh_filter) // 2
    fh_padded = np.zeros(Nfft)
    fh_padded[:len(fh_filter)] = fh_filter / 2.0
    fh_padded = np.roll(fh_padded, -center)
    fh_fft = fft(fh_padded)

    # Zero-pad each row to Nfft so the circular convolution acts as a linear one
    # over the returned columns.
    projection_padded = np.zeros((Nrows, Nfft))
    projection_padded[:, :Ncolumns] = weighted_projection

    filtered_fft = fft(projection_padded, axis=1) * fh_fft[np.newaxis, :]
    filtered_projection = ifft(filtered_fft, axis=1).real

    return filtered_projection[:, :Ncolumns]

In [9]:
def backproject_single_angle(args):
    """
    Weight, filter and back-project one projection onto the volume grid.

    Designed to be mapped over projections in parallel, so all inputs arrive
    in one tuple.

    Parameters
    ----------
    args : tuple
        ``(projection_beta, beta_rad, SOD, delta_dd, Nimage, fh_filter, beta_num)``:
        the 2-D projection (rows, columns), its gantry angle in radians, the
        source-to-origin distance, the virtual detector pixel pitch, the
        in-plane volume size, the spatial filter kernel, and the total number
        of projections (used to normalise the angular sum).

    Returns
    -------
    ndarray
        Volume of shape ``(Nimage, Nimage, int(Nimage * Nrows / Ncolumns))``
        holding this projection's contribution. Voxels whose ray falls outside
        the interpolatable detector area are 0.
    """
    projection_beta, beta_rad, SOD, delta_dd, Nimage, fh_filter, beta_num = args

    # Weight and filter the projection
    weighted_projection = weight_projection_vectorized(projection_beta, SOD, delta_dd)
    filtered_projection = filter_projection_fft(weighted_projection, fh_filter)

    Nrows, Ncolumns = filtered_projection.shape
    MX, MZ = Nimage, int(Nimage * Nrows / Ncolumns)

    # Reconstruction volume: spans the detector pixel centres, where pixel j
    # of an axis of length N is centred at delta_dd * (j - (N - 1) / 2).
    roi = delta_dd * np.array([-Ncolumns/2.0 + 0.5, Ncolumns/2.0 - 0.5,
                               -Nrows/2.0 + 0.5, Nrows/2.0 - 0.5])

    hx = (roi[1] - roi[0]) / (MX - 1)
    hy = (roi[3] - roi[2]) / (MZ - 1)

    xrange = roi[0] + hx * np.arange(MX)
    yrange = roi[2] + hy * np.arange(MZ)

    XX, YY, ZZ = np.meshgrid(xrange, xrange, yrange, indexing='ij')

    # Backprojection geometry
    cos_beta, sin_beta = np.cos(beta_rad), np.sin(beta_rad)

    U = (SOD + XX * sin_beta - YY * cos_beta) / SOD
    a = (XX * cos_beta + YY * sin_beta) / U
    b = ZZ / U

    # Continuous (fractional) detector indices. With the pixel-centre
    # convention above, physical position a maps to a / delta_dd + (N - 1) / 2.
    # An offset of N // 2 would be half a pixel off when N is even.
    col = a / delta_dd + (Ncolumns - 1) / 2.0
    row = b / delta_dd + (Nrows - 1) / 2.0

    xx = np.floor(col).astype(np.int32)
    yy = np.floor(row).astype(np.int32)

    u1 = col - xx
    u2 = row - yy

    # Keep voxels whose 2x2 interpolation neighbourhood lies on the detector.
    mask = (xx >= 0) & (xx < Ncolumns - 1) & (yy >= 0) & (yy < Nrows - 1)

    temp_rec = np.zeros((MX, MX, MZ))

    if np.any(mask):
        xx_valid = xx[mask]
        yy_valid = yy[mask]
        u1_valid = u1[mask]
        u2_valid = u2[mask]
        U_valid = U[mask]

        # Bilinear interpolation on the filtered projection
        temp_val = ((1 - u1_valid) * (1 - u2_valid) * filtered_projection[yy_valid, xx_valid] +
                    (1 - u1_valid) * u2_valid * filtered_projection[yy_valid + 1, xx_valid] +
                    u1_valid * (1 - u2_valid) * filtered_projection[yy_valid, xx_valid + 1] +
                    u1_valid * u2_valid * filtered_projection[yy_valid + 1, xx_valid + 1])

        # Distance weight 1/U^2 and angular step 2*pi/beta_num
        temp_val = temp_val / (U_valid**2) * 2 * PI / beta_num

        temp_rec[mask] = temp_val

    return temp_rec

In [10]:
def parallel_fdk_reconstruction(projections, angles, SOD, Nimage, delta_dd, n_jobs=-1,
                                batch_size=None):
    """
    Parallel FDK reconstruction using joblib.

    Each projection is weighted, filtered and back-projected by
    ``backproject_single_angle`` in a worker. The per-angle volumes are summed
    batch by batch, so at most ``batch_size`` of them are held in memory at
    once. Collecting one volume per projection before summing would need
    ``n_angles * volume_size`` memory.

    Parameters
    ----------
    projections : ndarray, shape (n_angles, Nrows, Ncolumns)
    angles : array-like
        Gantry angle of each projection, in degrees.
    SOD, Nimage, delta_dd :
        Geometry forwarded to ``backproject_single_angle``.
    n_jobs : int, optional
        Worker count passed to joblib; ``-1`` means
        ``min(cpu_count, n_angles)``.
    batch_size : int, optional
        Projections dispatched per batch; defaults to 4 x the worker count.

    Returns
    -------
    ndarray
        Reconstructed volume (sum over all angles).

    Raises
    ------
    ValueError
        If there are no projections, or ``projections`` and ``angles`` disagree.
    """
    projections = np.asarray(projections)
    angles = np.asarray(angles, dtype=float)
    beta_num = len(angles)
    if beta_num == 0:
        raise ValueError("No projections to reconstruct")
    if projections.ndim != 3 or projections.shape[0] != beta_num:
        raise ValueError(
            f"projections must have shape (n_angles, Nrows, Ncolumns) with "
            f"n_angles == len(angles) == {beta_num}, got {projections.shape}"
        )

    Ncolumns = projections.shape[2]

    # Convert angles to radians
    beta_rad = angles * PI / 180.0

    # The same filter kernel is shared by every projection.
    Nfft = nearest_power_of_2(2 * Ncolumns - 1)
    fh_filter = filter_shepp_logan(Nfft, delta_dd)

    if n_jobs == -1:
        n_jobs = min(multiprocessing.cpu_count(), beta_num)
    if batch_size is None:
        workers = n_jobs if isinstance(n_jobs, int) and n_jobs > 0 else multiprocessing.cpu_count()
        batch_size = 4 * workers
    batch_size = max(1, int(batch_size))

    print(f"Starting parallel reconstruction with {n_jobs} processes...")

    rec_image = None
    # joblib's default backend (loky) serialises interactively defined
    # functions with cloudpickle. The plain 'multiprocessing' backend relies on
    # standard pickling, which typically cannot locate functions defined in a
    # notebook's __main__ when workers are spawned (e.g. on Windows).
    # The context manager reuses one worker pool across all batches.
    with tqdm(total=beta_num, desc="Reconstructing") as pbar:
        with Parallel(n_jobs=n_jobs) as parallel:
            for start in range(0, beta_num, batch_size):
                stop = min(start + batch_size, beta_num)
                partial_volumes = parallel(
                    delayed(backproject_single_angle)(
                        (projections[m], beta_rad[m], SOD, delta_dd,
                         Nimage, fh_filter, beta_num)
                    )
                    for m in range(start, stop)
                )
                for volume in partial_volumes:
                    if rec_image is None:
                        rec_image = np.array(volume, dtype=float)
                    else:
                        rec_image += volume
                pbar.update(stop - start)

    return rec_image

## Run Optimized Reconstruction

In [11]:
# Reconstruction parameters
Nimage = 100  # size of reconstructed image (in-plane voxels per side; z size follows the detector aspect ratio)
# Geometry comes from the header of the first DICOM file (`dcm`).
SID = dcm.RTImageSID
SAD = dcm.RadiationMachineSAD
SOD = SAD  # source to origin distance, in unit mm
SDD = SID  # source to center of detector, in unit mm
width = 0.172  # size of detector cell, in unit mm
# NOTE(review): `width` is hard-coded; confirm it matches the pixel pitch of the exported images.
# Detector pitch scaled by SOD/SDD onto a virtual detector through the origin.
delta_dd = width * SOD / SDD  # interval of the virtual detector cell

print(f"Reconstruction parameters:")
print(f"  Image size: {Nimage}x{Nimage}")
print(f"  SOD: {SOD} mm")
print(f"  SDD: {SDD} mm")
print(f"  Detector pixel size: {width} mm")
print(f"  Virtual detector interval: {delta_dd:.4f} mm")
print(f"  Number of projections: {len(sorted_angles)}")

Reconstruction parameters:
  Image size: 100x100
  SOD: 1000 mm
  SDD: 1499.98304881781 mm
  Detector pixel size: 0.172 mm
  Virtual detector interval: 0.1147 mm
  Number of projections: 1135


In [None]:
# Run parallel reconstruction
start_time = time.time()

# Use fewer jobs to avoid memory issues with large datasets.
# `n_jobs` and `reconstruction_time` are also reported by the performance-summary cell.
n_jobs = min(4, multiprocessing.cpu_count())  # Limit to avoid memory issues

rec_image = parallel_fdk_reconstruction(
    sorted_images, sorted_angles, SOD, Nimage, delta_dd, n_jobs=n_jobs
)

# Wall-clock time of the whole call, including worker start-up.
reconstruction_time = time.time() - start_time
print(f"\nParallel reconstruction completed in {reconstruction_time:.2f} seconds")
print(f"Reconstructed image shape: {rec_image.shape}")
print(f"Reconstruction value range: {rec_image.min():.2f} to {rec_image.max():.2f}")

Starting parallel reconstruction with 4 processes...


Reconstructing:   0%|          | 0/1135 [00:00<?, ?it/s]

## Visualization

In [None]:
# Display reconstruction results
# Slice through the centre of the reconstructed volume. Indices come from
# rec_image.shape itself: the z extent is int(Nimage * Nrows / Ncolumns)
# (see backproject_single_angle), which equals Nimage only for a square
# detector. The former `Nimage * shape[0] // shape[0]` was always Nimage.
NimageZ = rec_image.shape[2]
Z_c = NimageZ // 2
X_c = rec_image.shape[0] // 2
Y_c = rec_image.shape[1] // 2

figure, axis = plt.subplots(1, 3, figsize=(15, 5))
figure.suptitle('Optimized Parallel FDK Reconstruction Results', fontsize=16)

# (panel, 2-D slice, title): X plane, Y plane, Z plane
_views = (
    (axis[0], rec_image[X_c, :, :].T, f'Sagittal (X={X_c})'),
    (axis[1], rec_image[:, Y_c, :].T, f'Coronal (Y={Y_c})'),
    (axis[2], rec_image[:, :, Z_c].T, f'Axial (Z={Z_c})'),
)
for _ax, _slice, _title in _views:
    _ax.imshow(_slice, cmap='CMRmap_r')
    _ax.set_title(_title)
    _ax.axis('off')

plt.tight_layout()
plt.show()

print(f"Center voxel value: {rec_image[X_c, Y_c, Z_c]:.2f}")

## Comparison with TPS (if available)

In [None]:
# Load TPS reference dose for comparison (update path as needed)
try:
    _TPS_pth = r"E:\CMC\pyprojects\radio_therapy\dose-3d\dataset\3DDose\RD.23022024.12 x 12.dcm"

    if os.path.exists(_TPS_pth):
        tps_dcm = dcmread(_TPS_pth)
        tps_image = tps_dcm.pixel_array
        # NOTE(review): RT Dose pixel values are normally integers meant to be
        # multiplied by DoseGridScaling; confirm raw values are intended here.

        # Reorient reconstruction for comparison
        rec_image_oriented = np.transpose(rec_image, (2, 1, 0))

        # Centre indices per volume. The EPID and TPS grids generally differ in
        # size, so reusing the reconstruction-grid indices (X_c/Y_c/Z_c) for
        # both would slice the TPS volume off-centre or out of range.
        _ex, _ey, _ez = (s // 2 for s in rec_image_oriented.shape)
        _tx, _ty, _tz = (s // 2 for s in tps_image.shape)

        # Create comparison plot
        fig, axs = plt.subplots(2, 3, figsize=(15, 10))
        fig.suptitle('Optimized EPID vs TPS Dose Comparison', fontsize=16)

        _panels = (
            # EPID reconstruction (top row)
            (axs[0, 0], rec_image_oriented[_ex, :, :].T, 'EPID - Sagittal'),
            (axs[0, 1], rec_image_oriented[:, _ey, :].T, 'EPID - Coronal'),
            (axs[0, 2], rec_image_oriented[:, :, _ez].T, 'EPID - Axial'),
            # TPS dose (bottom row)
            (axs[1, 0], tps_image[_tx, :, :].T, 'TPS - Sagittal'),
            (axs[1, 1], tps_image[:, _ty, :].T, 'TPS - Coronal'),
            (axs[1, 2], tps_image[:, :, _tz].T, 'TPS - Axial'),
        )
        for _ax, _slice, _title in _panels:
            _ax.imshow(_slice, cmap='CMRmap_r')
            _ax.set_title(_title)
            _ax.axis('off')

        plt.tight_layout()
        plt.show()

        # Calculate scaling factor (peak-to-peak ratio)
        scaler = np.max(tps_image) / np.max(rec_image_oriented)
        print(f"\nScaling factor (TPS/EPID): {scaler:.2f}")
        print(f"TPS center value: {tps_image[_tx, _ty, _tz]}")
        print(f"EPID center value: {rec_image_oriented[_ex, _ey, _ez]:.2f}")

    else:
        print(f"TPS file not found at: {_TPS_pth}")

except Exception as e:
    # Notebook boundary: report and continue so later cells still run.
    print(f"Error loading TPS data: {e}")

## Performance Summary

In [None]:
# Summarise the run. `reconstruction_time` and `n_jobs` come from the
# reconstruction cell; "Time per projection" is total wall-clock time divided
# by the number of projections, not the per-worker time.
print("\n" + "="*50)
print("OPTIMIZATION PERFORMANCE SUMMARY")
print("="*50)
print(f"Total reconstruction time: {reconstruction_time:.2f} seconds")
print(f"Number of projections processed: {len(sorted_angles)}")
print(f"Time per projection: {reconstruction_time/len(sorted_angles):.3f} seconds")
print(f"CPU cores used: {n_jobs}")
print(f"Reconstructed volume size: {rec_image.shape}")
print(f"Total voxels: {np.prod(rec_image.shape):,}")
print("\nOptimizations applied:")
print("✓ Parallel DICOM loading")
print("✓ Vectorized mathematical operations")
print("✓ FFT-based filtering")
print("✓ Parallel backprojection processing")
print("✓ Memory-efficient data handling")
print("✓ Progress tracking")

## Save Results (Optional)

In [None]:
# Uncomment to save the reconstructed volume as DICOM, using the TPS RT Dose
# file as a template. Relies on `tps_dcm`, `rec_image_oriented` and `scaler`
# from the TPS comparison cell.
# NOTE(review): before enabling, confirm that
#   * np.int32 (32-bit signed pixels) matches the template's
#     BitsAllocated / PixelRepresentation,
#   * geometry attributes inherited from the template (pixel spacing, frame
#     offsets, image position) are valid for the EPID grid,
#   * `tps_dcm.copy()` yields a copy independent of the in-memory TPS dataset.
# try:
#     if 'tps_dcm' in locals():
#         # Scale and prepare for DICOM export
#         scaled_image = np.int32(rec_image_oriented * scaler)
#         
#         # Create new DICOM based on TPS template; multi-frame pixel data is
#         # laid out as (frames, rows, columns), i.e. the shape of scaled_image.
#         write_dicom = tps_dcm.copy()
#         write_dicom.NumberOfFrames = str(scaled_image.shape[0])
#         write_dicom.Rows = scaled_image.shape[1]
#         write_dicom.Columns = scaled_image.shape[2]
#         write_dicom.PixelData = scaled_image.tobytes()
#         
#         # Save with timestamp
#         import datetime
#         timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
#         output_path = f"E:\\CMC\\pyprojects\\radio_therapy\\dose-3d\\dataset\\3DDose\\EPID_OPTIMIZED_{timestamp}.dcm"
#         dcmwrite(output_path, write_dicom)
#         print(f"\nOptimized reconstruction saved to: {output_path}")
#     else:
#         print("TPS data not available - cannot save DICOM")
# except Exception as e:
#     print(f"Error saving DICOM: {e}")

print("\nOptimized reconstruction complete!")