In [1]:
import SimpleITK as sitk
import numpy as np
import os
import time
import Code.assessment as assess
import subprocess

def _run_tool(args):
    """Run an external executable given as an argument list; raise if it fails.

    The arguments are passed as a list rather than a single command string, so
    no shell-style parsing happens and paths containing spaces reach the tool
    unchanged (a bare string is also only usable as a full command line on
    Windows). ``check=True`` turns a non-zero exit status into
    ``subprocess.CalledProcessError`` instead of letting the pipeline carry on
    and read missing -- or stale, left over from an earlier run -- output files.
    """
    subprocess.run(args, check=True)


def _save_warped(assessment, output_path, suffix):
    """Write the image and segmentations most recently warped by ``assessment``.

    Writes ``pre_image_<suffix>.mha``, ``pre_liver_<suffix>.mha`` and
    ``pre_tumor_<suffix>.mha`` into ``output_path`` (WriteImage's
    compression flag is set).
    """
    outputs = (
        ('image', assessment.warped_img),
        ('liver', assessment.warped_seg_list[0]),
        ('tumor', assessment.warped_seg_list[1]),
    )
    for name, image in outputs:
        sitk.WriteImage(image, os.path.join(output_path, f'pre_{name}_{suffix}.mha'), True)


def Perform_LCBM_Registration(input_path, output_path):
    """Run the LC-BM registration pipeline on one pre/post image pair.

    Steps:
      0. N4 bias-field correction of both images, histogram matching of the
         pre image to the post image, and rigid registration (SimpleElastix)
         of the pre image and its liver/tumor labels onto the post image.
      1. LC-part: build the Omega_R / Omega_EC / Omega_IC masks and run the
         external ``Code/LC-part.exe``, which produces the respiratory motion
         field (``phi_rm_field.mha``) and the external shrinkage field
         (``phi_es_field.mha``) in ``output_path``.
      2. BM-part: run the external ``./Code/BM-part.exe``, which produces the
         total shrinkage field (``phi_ts_field.mha``) in ``output_path``.
      3. Assess the LC-part result (phi_rm) and the combined result
         (phi_rm + phi_ts) with ``Code.assessment.Assessment``; print the
         metrics and write the warped pre image/labels for both stages.

    Args:
        input_path: Directory containing ``post_image.mha``, ``post_liver.mha``,
            ``post_ablation.mha``, ``post_tumor.mha``, ``post_needle.txt``,
            ``pre_image.mha``, ``pre_liver.mha`` and ``pre_tumor.mha``.
        output_path: Directory for all intermediate and final results; it is
            created if it does not exist.

    The two executables are referenced by paths relative to the current working
    directory, so this must be called from the directory that contains ``Code``.

    Raises:
        subprocess.CalledProcessError: if LC-part or BM-part exits with a
            non-zero status.
    """
    os.makedirs(output_path, exist_ok=True)

    # step 0: pre-processing, bias field correction and histogram matching, and rigid registration
    print('Step 0: pre-processing, bias field correction and histogram matching, and rigid registration')
    post_image_filename = os.path.join(input_path, 'post_image.mha')
    post_liver_filename = os.path.join(input_path, 'post_liver.mha')
    post_ablation_filename = os.path.join(input_path, 'post_ablation.mha')
    post_tumor_filename = os.path.join(input_path, 'post_tumor.mha')
    post_needle_filename = os.path.join(input_path, 'post_needle.txt')

    pre_image_filename = os.path.join(input_path, 'pre_image.mha')
    pre_liver_filename = os.path.join(input_path, 'pre_liver.mha')
    pre_tumor_filename = os.path.join(input_path, 'pre_tumor.mha')

    pre_image_rigid_filename = os.path.join(output_path, 'pre_image_rigid.mha')
    pre_liver_rigid_filename = os.path.join(output_path, 'pre_liver_rigid.mha')
    pre_tumor_rigid_filename = os.path.join(output_path, 'pre_tumor_rigid.mha')

    post_image_prep_filename = os.path.join(output_path, 'post_image_prep.mha')
    pre_image_prep_filename = os.path.join(output_path, 'pre_image_prep.mha')

    # bias field correction (a bit slow)
    print('\rBias field correction...', end='')
    post_image_bias = sitk.N4BiasFieldCorrection(sitk.ReadImage(post_image_filename, sitk.sitkFloat64))
    pre_image_bias = sitk.N4BiasFieldCorrection(sitk.ReadImage(pre_image_filename, sitk.sitkFloat64))
    print('\rBias field correction done.')

    # histogram matching: map the pre image's intensities onto the post image's
    print('\rHistogram matching...', end='')
    pre_image_hist = sitk.HistogramMatching(pre_image_bias, post_image_bias, 1024, 7)
    print('\rHistogram matching done.')

    # output pre-processed images; they are reloaded so that everything below
    # (and the LC-part executable) works on exactly what is on disk
    sitk.WriteImage(post_image_bias, post_image_prep_filename, True)
    sitk.WriteImage(pre_image_hist, pre_image_prep_filename, True)

    post_image = sitk.ReadImage(post_image_prep_filename)
    pre_image = sitk.ReadImage(pre_image_prep_filename)

    # rigid registration: using SimpleElastix (post image fixed, pre image moving)
    print('\rRigid registration...', end='')
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(post_image)
    elastixImageFilter.SetMovingImage(pre_image)
    elastixImageFilter.SetParameterMap(sitk.GetDefaultParameterMap("rigid"))
    elastixImageFilter.Execute()

    # Convert the rigid result into a dense deformation field so it can be
    # applied with sitk.Resample below.
    transformixImageFilter = sitk.TransformixImageFilter()
    transformixImageFilter.ComputeDeformationFieldOn()
    transformixImageFilter.SetMovingImage(pre_image)  # Bug of SimpleElastix, need to set moving image again
    transformixImageFilter.SetTransformParameterMap(elastixImageFilter.GetTransformParameterMap())
    transformixImageFilter.Execute()

    # DisplacementFieldTransform is built from a VectorFloat64 image, hence the cast
    rigid_field = sitk.Cast(transformixImageFilter.GetDeformationField(), sitk.sitkVectorFloat64)
    rigid_trans = sitk.DisplacementFieldTransform(rigid_field)

    # output rigid registration results (nearest neighbour keeps labels discrete)
    pre_image_rigid = sitk.Resample(pre_image, post_image, rigid_trans, sitk.sitkLinear)
    pre_liver_rigid = sitk.Resample(sitk.ReadImage(pre_liver_filename), post_image, rigid_trans, sitk.sitkNearestNeighbor)
    pre_tumor_rigid = sitk.Resample(sitk.ReadImage(pre_tumor_filename), post_image, rigid_trans, sitk.sitkNearestNeighbor)

    sitk.WriteImage(pre_image_rigid, pre_image_rigid_filename, True)
    sitk.WriteImage(pre_liver_rigid, pre_liver_rigid_filename, True)
    sitk.WriteImage(pre_tumor_rigid, pre_tumor_rigid_filename, True)
    print('\rRigid registration done.')

    # monotonic clock for the elapsed-time measurement of steps 1 and 2
    t_start = time.perf_counter()
    # step 1: LC-part, quantification of external shrinkage
    print('Step 1: LC-part, quantification of external shrinkage')

    liver = sitk.ReadImage(post_liver_filename)
    ablation = sitk.ReadImage(post_ablation_filename)

    # 10 mm dilation radius, converted to a per-axis voxel radius via the spacing
    dilate_mm = np.array([10, 10, 10])
    dilate_pixel = np.ceil(dilate_mm / np.array(liver.GetSpacing())).astype(np.uint)

    liver_dilated = sitk.BinaryDilate(liver, dilate_pixel.tolist())
    ablation_dilated = sitk.BinaryDilate(ablation, dilate_pixel.tolist())

    # Omega_R: dilated liver outside the ablation zone
    omega_R_mask = sitk.And(sitk.BinaryNot(ablation), liver_dilated)
    # Omega_EC: 10 mm shell around the ablation zone, restricted to the liver
    omega_EC_mask = sitk.And(sitk.Xor(ablation_dilated, ablation), liver)
    # Omega_IC: the ablation zone itself
    omega_IC_mask = ablation

    omega_R_mask_filename = os.path.join(output_path, 'omega_R_mask.mha')
    omega_EC_mask_filename = os.path.join(output_path, 'omega_EC_mask.mha')
    omega_IC_mask_filename = os.path.join(output_path, 'omega_IC_mask.mha')
    sitk.WriteImage(omega_R_mask, omega_R_mask_filename, True)
    sitk.WriteImage(omega_EC_mask, omega_EC_mask_filename, True)
    sitk.WriteImage(omega_IC_mask, omega_IC_mask_filename, True)

    # after LC-part, we can get the respiratory motion field (phi_rm) and the external shrinkage field (phi_es)
    print('\rLC-part...', end='')
    _run_tool([
        'Code/LC-part.exe',
        '-f', post_image_prep_filename,
        '-m', pre_image_rigid_filename,
        '-omegaR', omega_R_mask_filename,
        '-omegaEC', omega_EC_mask_filename,
        '-i', '100x100x100',
        '-g', '4',
        '-s', '3',
        '-o', output_path,
    ])
    print('\rLC-part done.')

    # read the respiratory motion field (phi_rm)
    phi_rm_field_filename = os.path.join(output_path, 'phi_rm_field.mha')
    rm_field = sitk.ReadImage(phi_rm_field_filename, sitk.sitkVectorFloat64)
    rm_trans = sitk.DisplacementFieldTransform(rm_field)

    # step 2: BM-part, compensation for internal shrinkage
    print('Step 2: BM-part, compensation for internal shrinkage')

    # after BM-part, we can get the total shrinkage field (phi_ts);
    # phi_es_field.mha is the external shrinkage field produced by LC-part.
    # NOTE(review): BM-part receives the raw (not pre-processed) post image —
    # presumably only as reference geometry; confirm.
    phi_es_field_filename = os.path.join(output_path, 'phi_es_field.mha')
    print('\rBM-part...', end='')
    _run_tool([
        './Code/BM-part.exe',
        '-r', post_image_filename,
        '-n', post_needle_filename,
        '-ic', omega_IC_mask_filename,
        '-ec', omega_EC_mask_filename,
        '-es', phi_es_field_filename,
        '-o', output_path,
    ])
    print('\rBM-part done.')

    # read the total shrinkage field (phi_ts)
    phi_ts_field_filename = os.path.join(output_path, 'phi_ts_field.mha')
    ts_field = sitk.ReadImage(phi_ts_field_filename, sitk.sitkVectorFloat64)
    ts_trans = sitk.DisplacementFieldTransform(ts_field)

    print(f'Processing time: {time.perf_counter() - t_start} seconds')

    # step 3: assessment of the registration accuracy, prediction of the local tumor progression (LTP)
    print('Step 3: assessment of the registration accuracy, prediction of the local tumor progression (LTP)')

    # read the post image/labels and the rigidly aligned pre image/labels
    post_image = sitk.ReadImage(post_image_prep_filename)
    post_liver = sitk.ReadImage(post_liver_filename)
    post_tumor = sitk.ReadImage(post_tumor_filename)
    post_ablation = sitk.ReadImage(post_ablation_filename)

    pre_image = sitk.ReadImage(pre_image_rigid_filename)
    pre_liver = sitk.ReadImage(pre_liver_rigid_filename)
    pre_tumor = sitk.ReadImage(pre_tumor_rigid_filename)

    # assessment
    assessment = assess.Assessment(
        fixed_img=post_image, fixed_seg_list=[post_liver, post_tumor],
        moving_img=pre_image, moving_seg_list=[pre_liver, pre_tumor],
        fixed_ablation=post_ablation, follow_tumor=post_tumor, moving_tumor=pre_tumor,
        mask=post_liver,
    )

    # assess the LC-part (respiratory motion field only)
    trans = sitk.CompositeTransform([rm_trans])
    print('Assessment of LC-part:', assessment.AssessRegistration(trans))
    _save_warped(assessment, output_path, 'rm')

    # assessment: BM-part (respiratory motion + total shrinkage fields)
    trans = sitk.CompositeTransform([rm_trans, ts_trans])
    print('Assessment of BM-part:', assessment.AssessRegistration(trans))
    _save_warped(assessment, output_path, 'ts')

    
# Run the whole pipeline on the author's local data layout
# (absolute Windows paths — change both for another machine).
Perform_LCBM_Registration(
    input_path='D:/Work/Papers/lc-bm/Github/data',
    output_path='D:/Work/Papers/lc-bm/Github/results',
)

Step 0: pre-processing, bias field correction and histogram matching, and rigid registration
Bias field correction done.
Histogram matching done.
Rigid registration done.
Step 1: LC-part, quantification of external shrinkage
LC-part done.
Step 2: BM-part, compensation for internal shrinkage
BM-part done.
Processing time: 398.0189416408539 seconds
Step 3: assessment of the registration accuracy, prediction of the local tumor progression (LTP)
Assessment of LC-part: {'ncc': 0.6850111, 'dice': [0.9507592814297636, 0.6163522012578617], 'hdd': [10.645203142355134, 8.661687324045607], 'ltp_gt': False, 'ltp_result': False}
Assessment of BM-part: {'ncc': 0.68692976, 'dice': [0.9511625015923103, 0.7267605633802816], 'hdd': [10.645203142355134, 6.447043216520032], 'ltp_gt': False, 'ltp_result': False}
