## CT and RT Structure data Pre-processing for anatomy change prediction

In [None]:
from pathlib import Path
from multiprocessing import Pool
import logging

import click
import pandas as pd
import numpy as np
import SimpleITK as sitk

import os, sys, glob

In [None]:
# Root folder of the registered NIfTI data; each entry is used below as a
# per-patient sub-folder (e.g. <root>/<patient>/CT.nii.gz).
reg_nii_Path = "./Howard_Reg_Data/Registrated_ImgGTVpGTVnDose/"
# os.listdir returns entries in arbitrary order; sort so that every run
# processes the patients in the same order.
patients = sorted(os.listdir(reg_nii_Path))
print(len(patients))

# Output root for the resampled / cropped data.
savePath = './Howard_ResampleNormCrop_Data_WithDose/'
# exist_ok avoids the check-then-create race of `if not os.path.exists(...)`.
os.makedirs(savePath, exist_ok=True)

In [None]:
def SitkInfo(SitkData):
    """Print pixel-type and size information of an image, one item per line.

    Args:
        SitkData: object exposing the getters called below (GetPixelIDValue,
            GetPixelIDTypeAsString, GetNumberOfComponentsPerPixel,
            GetDimension, GetWidth, GetHeight, GetDepth) — presumably a
            SimpleITK image.

    Returns:
        None. Each line printed is "<caption> <value>".
    """
    # (caption, bound getter) pairs, listed in the order they are printed.
    report = (
        ('PixelIDValue', SitkData.GetPixelIDValue),
        ('PixelIDTypeAsString', SitkData.GetPixelIDTypeAsString),
        ('NumberOfComponentsPerPixel', SitkData.GetNumberOfComponentsPerPixel),
        ('GetDimension', SitkData.GetDimension),
        ('GetWidth', SitkData.GetWidth),
        ('GetHeight', SitkData.GetHeight),
        ('GetDepth', SitkData.GetDepth),
    )
    for caption, getter in report:
        print(caption, getter())

### 3. Resampling

In [None]:
# Resampled volumes are written directly into the save folder; the cropped
# patches go into its 'input' and 'label' sub-folders.
resample_path = savePath
input_path, label_path = (
    os.path.join(resample_path, subfolder) for subfolder in ('input', 'label')
)

In [None]:
# Create the output folders. exist_ok=True makes this idempotent and avoids
# the check-then-create race of testing os.path.exists first.
for out_dir in (resample_path, input_path, label_path):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# Target voxel spacing per axis (x, y, z) — presumably millimetres; confirm
# against the input data.
resampling = [2, 2, 4]

# One shared resample filter; its origin, size, interpolator and default pixel
# value are set per patient later on.
resampler = sitk.ResampleImageFilter()
resampler.SetOutputSpacing(resampling)
# Identity direction matrix (3x3, flattened).
resampler.SetOutputDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])

In [None]:
def get_bouding_boxes(ct, pt):
    """Return the overlap of the physical extents of two images.

    Only origin, size and spacing are used; the direction matrices are
    ignored, so the result is meaningful only when both images share the
    same (axis-aligned) direction.

    Args:
        ct: first image; must provide GetOrigin(), GetSize(), GetSpacing().
        pt: second image with the same interface.

    Returns:
        1-D numpy array: the element-wise maximum of the two origins followed
        by the element-wise minimum of the two far corners
        (origin + size * spacing).
    """

    def _extent(img):
        # Near corner (origin) and far corner of the image.
        near = np.array(img.GetOrigin())
        far = near + np.array(img.GetSize()) * np.array(img.GetSpacing())
        return near, far

    ct_near, ct_far = _extent(ct)
    pt_near, pt_far = _extent(pt)

    lower = np.maximum(ct_near, pt_near)
    upper = np.minimum(ct_far, pt_far)
    return np.concatenate([lower, upper], axis=0)

In [None]:
import matplotlib.pyplot as plt
def SaveSitkImg(SitkData1, SitkData2, SitkData3, SitkData4, Path, p):
    """Save a 4x3 grid of central slices of four volumes as one PNG.

    Each row shows one volume; the columns are the central sagittal, coronal
    and axial slices. The first three rows are displayed with a fixed
    intensity range of [-500, 500]; the fourth row uses matplotlib's automatic
    scaling.

    Args:
        SitkData1, SitkData2, SitkData3, SitkData4: 3-D SimpleITK images.
        Path: directory in which the PNG is written.
        p: file name without extension; the output is ``<Path>/<p>.png``.
    """
    fig, ax = plt.subplots(4, 3, figsize=(28, 20))
    try:
        volumes = (SitkData1, SitkData2, SitkData3, SitkData4)
        for row, SitkData in enumerate(volumes):
            # GetArrayFromImage returns the voxels indexed as (z, y, x).
            image_data = sitk.GetArrayFromImage(SitkData)
            center_z, center_y, center_x = (dim // 2 for dim in image_data.shape)

            # Fixing x gives a sagittal plane, fixing y a coronal plane and
            # fixing z an axial plane.
            views = (
                ('Sagittal', image_data[:, :, center_x]),
                ('Coronal', image_data[:, center_y, :]),
                ('Axial', image_data[center_z, :, :]),
            )

            # Fixed window for the first three rows only; the last row is
            # left to auto-scale.
            window = {'vmin': -500, 'vmax': 500} if row < 3 else {}

            for col, (title, plane) in enumerate(views):
                panel = ax[row, col]
                panel.imshow(np.flipud(plane), cmap='gray', aspect='equal', **window)
                panel.set_title(title)
                panel.axis('off')

        fig.savefig(os.path.join(Path, p+".png"))
    finally:
        # plt.clf() would only clear the figure; closing it releases the
        # memory so figures do not pile up across patients.
        plt.close(fig)

In [None]:
def resample_one_patient(p):
    """Resample all volumes of one patient onto a common grid and save them.

    Reads CT, CBCT1, CBCT2, Dose and the GTVp/GTVn masks of each scan from
    ``reg_nii_Path/<p>/``, resamples them with the module-level ``resampler``
    (output spacing ``resampling``) onto a grid covering the CT extent, and
    writes the results to ``resample_path`` as ``<p>_<name>.nii.gz``. Two
    overview PNGs (``<p>_Ori.png`` and ``<p>_Resample.png``) are written as
    well.

    Args:
        p: patient folder name inside ``reg_nii_Path``.
    """
    patient_dir = os.path.join(reg_nii_Path, p)

    def _read(name):
        return sitk.ReadImage(os.path.join(patient_dir, name + '.nii.gz'))

    def _write(image, name):
        sitk.WriteImage(image, os.path.join(resample_path, p + '_' + name + '.nii.gz'))

    scan_names = ('CT', 'CBCT1', 'CBCT2')
    mask_names = (
        'GTVp_CT', 'GTVp_CBCT1', 'GTVp_CBCT2',
        'GTVn_CT', 'GTVn_CBCT1', 'GTVn_CBCT2',
    )

    scans = {name: _read(name) for name in scan_names}
    masks = {name: _read(name) for name in mask_names}
    dose = _read('Dose')

    SaveSitkImg(scans['CT'], scans['CBCT1'], scans['CBCT2'], dose, resample_path, p+'_Ori')

    # Passing the CT twice yields the CT's own extent, i.e. the output grid
    # spans exactly the planning CT.
    bb = get_bouding_boxes(scans['CT'], scans['CT'])
    size = np.round((bb[3:] - bb[:3]) / resampling).astype(int)
    resampler.SetOutputOrigin(bb[:3].tolist())
    # SetSize expects plain Python ints rather than numpy integers.
    resampler.SetSize([int(k) for k in size])

    # Intensity images: B-spline interpolation, padded with -1000 outside the
    # source volume.
    resampler.SetInterpolator(sitk.sitkBSpline)
    resampler.SetDefaultPixelValue(-1000)
    for name in scan_names:
        scans[name] = resampler.Execute(scans[name])
        _write(scans[name], name)

    # Dose: voxels outside the dose grid must be 0, not the CT padding value.
    # NOTE(review): B-spline interpolation can overshoot and produce small
    # negative dose values near steep gradients — confirm this is acceptable.
    resampler.SetDefaultPixelValue(0)
    dose = resampler.Execute(dose)
    _write(dose, 'Dose')

    # Label masks: nearest neighbour keeps the label values intact; voxels
    # outside the mask grid are background (0).
    resampler.SetInterpolator(sitk.sitkNearestNeighbor)
    for name in mask_names:
        masks[name] = resampler.Execute(masks[name])

    SaveSitkImg(scans['CT'], scans['CBCT1'], scans['CBCT2'], dose, resample_path, p+'_Resample')
    for name in mask_names:
        _write(masks[name], name)

In [None]:
# Resample every patient that has a dose file and remember which ones were
# handled.
processed = []
for p in patients:
    dose_file = os.path.join(reg_nii_Path, p, 'Dose.nii.gz')
    if not os.path.exists(dose_file):
        continue
    resample_one_patient(p)
    print(p)
    processed.append(p)

### 4. Cropping

In [None]:
def find_centroid(mask, label=255):
    """Return the centroid of a labelled region as a voxel index.

    Args:
        mask: SimpleITK label image.
        label: value of the region whose centroid is wanted (default 255).

    Returns:
        numpy float64 array holding the index of the voxel that contains the
        physical centroid of the region.

    Raises:
        ValueError: if ``label`` does not occur in ``mask`` (e.g. an empty
            mask), so no centroid can be computed.
    """
    stats = sitk.LabelShapeStatisticsImageFilter()
    stats.Execute(mask)

    # Fail with a clear message instead of printing and then crashing on an
    # undefined variable.
    if not stats.HasLabel(label):
        raise ValueError(
            f"Label {label} not found in mask; cannot compute its centroid."
        )

    centroid_coords = stats.GetCentroid(label)
    centroid_idx = mask.TransformPhysicalPointToIndex(centroid_coords)

    return np.asarray(centroid_idx, dtype=np.float64)

In [None]:
exclude_patients = []  # patients to skip when cropping (no clinic data)


def tune_range(min_d, max_d, d, size_d, p):
    """Shift a crop interval so that it lies inside ``[0, d]``.

    An interval that already fits is returned unchanged. One that starts
    before 0 is moved to ``[0, size_d]``; one that ends after ``d`` is moved
    to ``[d - size_d, d]``.

    Args:
        min_d: proposed start index (inclusive).
        max_d: proposed end index (exclusive).
        d: image extent along this axis.
        size_d: patch extent along this axis.
        p: patient identifier, used only in the error message.

    Returns:
        Tuple ``(min_d, max_d)`` of the adjusted interval.

    Raises:
        ValueError: if the patch is larger than the image along this axis.
    """
    # A patch that fits exactly (size_d == d) is valid; only a larger one
    # cannot be extracted.
    if size_d > d:
        raise ValueError(
            f"Cannot extract the patch with the shape {size_d} from the image with the shape {d} for patient {p}."
        )

    if min_d < 0:
        min_d = 0
        max_d = min_d + size_d

    if max_d > d:
        max_d = d
        min_d = max_d - size_d

    return min_d, max_d

# Crop every resampled patient to a fixed-size patch centred on the GTVp of
# the planning CT. CT and CBCT1 (plus dose and their masks) are written to
# input_path, CBCT2 and its masks to label_path.
for p in patients:
    # Same selection rule as the resampling loop: a dose file must exist.
    if not os.path.exists(os.path.join(reg_nii_Path, p, 'Dose.nii.gz')):
        continue

    # Decide about exclusion before touching the disk, so excluded patients
    # do not cost ten volume reads.
    if p in exclude_patients:
        print('skip ', p)
        continue

    dose = sitk.ReadImage(os.path.join(resample_path, p+'_Dose.nii.gz'))

    image = sitk.ReadImage(os.path.join(resample_path, p+'_CT.nii.gz'))
    mask = sitk.ReadImage(os.path.join(resample_path, p+'_GTVp_CT.nii.gz'))
    mask_n = sitk.ReadImage(os.path.join(resample_path, p+'_GTVn_CT.nii.gz'))

    image1 = sitk.ReadImage(os.path.join(resample_path, p+'_CBCT1.nii.gz'))
    mask1 = sitk.ReadImage(os.path.join(resample_path, p+'_GTVp_CBCT1.nii.gz'))
    mask1_n = sitk.ReadImage(os.path.join(resample_path, p+'_GTVn_CBCT1.nii.gz'))

    image2 = sitk.ReadImage(os.path.join(resample_path, p+'_CBCT2.nii.gz'))
    mask2 = sitk.ReadImage(os.path.join(resample_path, p+'_GTVp_CBCT2.nii.gz'))
    mask2_n = sitk.ReadImage(os.path.join(resample_path, p+'_GTVn_CBCT2.nii.gz'))

    patch_size = np.array([128,128,32])

    # Patch of patch_size voxels around the centroid of the CT GTVp mask.
    tumour_center = find_centroid(mask)
    size = patch_size
    min_coords = np.floor(tumour_center - size / 2).astype(np.int64)
    max_coords = np.floor(tumour_center + size / 2).astype(np.int64)
    min_x, min_y, min_z = min_coords
    max_x, max_y, max_z = max_coords

    (img_x, img_y, img_z) = image.GetSize()

    # Shift the patch back inside the volume where it sticks out.
    min_x, max_x = tune_range(min_x, max_x, img_x, size[0], p)
    min_y, max_y = tune_range(min_y, max_y, img_y, size[1], p)
    min_z, max_z = tune_range(min_z, max_z, img_z, size[2], p)

    # One (x, y, z) slice tuple shared by every volume of this patient.
    roi = (
        slice(int(min_x), int(max_x)),
        slice(int(min_y), int(max_y)),
        slice(int(min_z), int(max_z)),
    )

    dose = dose[roi]
    image = image[roi]
    image1 = image1[roi]
    image2 = image2[roi]

    # Clamp image intensities to [-500, 500] and convert to float32.
    image = sitk.Clamp(image, sitk.sitkFloat32, -500, 500)
    image1 = sitk.Clamp(image1, sitk.sitkFloat32, -500, 500)
    image2 = sitk.Clamp(image2, sitk.sitkFloat32, -500, 500)

    mask = mask[roi]
    mask1 = mask1[roi]
    mask2 = mask2[roi]
    mask_n = mask_n[roi]
    mask1_n = mask1_n[roi]
    mask2_n = mask2_n[roi]

    sitk.WriteImage(dose, os.path.join(input_path, p+'_Dose.nii.gz'))

    sitk.WriteImage(image, os.path.join(input_path, p+'_CT.nii.gz'))
    sitk.WriteImage(mask, os.path.join(input_path, p+'_GTVp_CT.nii.gz'))
    sitk.WriteImage(mask_n, os.path.join(input_path, p+'_GTVn_CT.nii.gz'))

    sitk.WriteImage(image1, os.path.join(input_path, p+'_CBCT1.nii.gz'))
    sitk.WriteImage(mask1, os.path.join(input_path, p+'_GTVp_CBCT1.nii.gz'))
    sitk.WriteImage(mask1_n, os.path.join(input_path, p+'_GTVn_CBCT1.nii.gz'))

    sitk.WriteImage(image2, os.path.join(label_path, p+'_CBCT2.nii.gz'))
    sitk.WriteImage(mask2, os.path.join(label_path, p+'_GTVp_CBCT2.nii.gz'))
    sitk.WriteImage(mask2_n, os.path.join(label_path, p+'_GTVn_CBCT2.nii.gz'))

    SaveSitkImg(image, image1, image2, dose, resample_path, p+'_Crop')
