In [1]:
import os

import numpy as np
import pydicom

from tqdm import tqdm
import glob
from ants import registration, from_numpy, pad_image, create_jacobian_determinant_image
from ants import create_warped_grid, plot
from PIL import Image


def read_dicom_files(dicom_dir, pattern="*.*"):
    """Read all DICOM slice files of one series directory.

    Files in ``dicom_dir`` matching the glob ``pattern`` are read with
    ``pydicom.dcmread``. They are first ordered by filename. If every dataset
    carries a usable ``InstanceNumber``, they are then re-ordered by it. A
    plain lexicographic filename sort would put e.g. ``IM10`` before ``IM2``
    and scramble the stacked volume.

    Args:
        dicom_dir: Directory containing the slices of a single series.
        pattern: Glob pattern for the slice files. The default ``"*.*"`` only
            matches names containing a dot; extension-less DICOM files need
            e.g. ``pattern="*"``.

    Returns:
        A list of pydicom datasets (empty if nothing matched), or ``None`` if
        a matched file could not be opened or is not valid DICOM. In that case
        the error and the directory are printed.
    """
    from pydicom.errors import InvalidDicomError

    dicom_files = sorted(glob.glob(os.path.join(dicom_dir, pattern)))
    try:
        datasets = [pydicom.dcmread(dicom_file) for dicom_file in dicom_files]
    except (OSError, InvalidDicomError) as e:
        # These are the failures dcmread actually raises for unreadable or
        # non-DICOM files. IndexError (caught previously) never occurs here.
        print(f"{e}, at path {dicom_dir}")
        return None

    # list.sort is stable, so slices with equal InstanceNumber keep filename order.
    if all(getattr(ds, "InstanceNumber", None) not in (None, "") for ds in datasets):
        datasets.sort(key=lambda ds: int(ds.InstanceNumber))
    return datasets


def affine_registration(moving, fixed, verbose=False, reg_type="AffineFast"):
    """Register ``moving`` onto ``fixed`` with ANTs and return the result dict.

    Thin wrapper around ``ants.registration``. The default ``reg_type``
    ("AffineFast") performs an affine registration: rigid (rotation and
    translation) plus scale.

    Args:
        moving: Image to be transformed.
        fixed: Reference image that defines the output space.
        verbose: Forwarded to ``ants.registration``.
        reg_type: Forwarded as ``type_of_transform``.

    Returns:
        The dictionary produced by ``ants.registration``. The warped moving
        image is stored under ``'warpedmovout'``.
    """
    return registration(
        fixed=fixed,
        moving=moving,
        type_of_transform=reg_type,
        verbose=verbose,
    )


# Base path
root_path = "/media/monib/External Disk/work2022/Base_Dataset/KTL_HCC_Dicom"
save_path_P = "/media/monib/External Disk/work2022/Base_Dataset/KTL_AffineFast_P2A/P"
save_path_A = "/media/monib/External Disk/work2022/Base_Dataset/KTL_AffineFast_P2A/A"

# One sub-directory per patient. Plain files under root_path are skipped,
# because os.listdir() on them below would raise NotADirectoryError.
sub_paths = [
    root_path + f"/{sub_dir}"
    for sub_dir in os.listdir(root_path)
    if os.path.isdir(os.path.join(root_path, sub_dir))
]

all_portal_paths = []
all_arterial_paths = []

for sub_path in sub_paths:
    artery_phase_path = sub_path + "/A/"
    # If any entry of the patient directory contains "V", the portal phase is
    # read from "<patient>/V/" instead of "<patient>/P/". This is presumably a
    # venous-phase naming variant; confirm against the dataset. any() also
    # tolerates several matching entries.
    if any("V" in phases_dir for phases_dir in os.listdir(sub_path)):
        portal_phase_path = sub_path + "/V/"
    else:
        portal_phase_path = sub_path + "/P/"

    all_portal_paths.append(portal_phase_path)
    all_arterial_paths.append(artery_phase_path)

# The patient id is the name of the directory holding the phase folder. This
# does not depend on how deep root_path is.
patient_id_p = [os.path.basename(os.path.dirname(p.rstrip("/"))) for p in all_portal_paths]
patient_id_a = [os.path.basename(os.path.dirname(a.rstrip("/"))) for a in all_arterial_paths]
set_p = set(patient_id_p)
set_a = set(patient_id_a)
# Sanity check: patients present in only one of the two phase lists.
print(set_p.symmetric_difference(set_a))

# Sort ids and paths together so that index i of every list refers to the same
# patient. Sorting each list on its own can misalign them. In a path, the id is
# followed by "/", which sorts after characters such as " ", "-" and ".". For
# example, the path for "12-1" sorts before the path for "12", while the bare
# ids sort the other way round.
aligned = sorted(zip(patient_id_p, all_portal_paths, all_arterial_paths))
patient_id_p_sorted = [pid for pid, _, _ in aligned]
all_portal_paths_sorted = [p for _, p, _ in aligned]
all_arterial_paths_sorted = [a for _, _, a in aligned]
patient_id_a_sorted = sorted(patient_id_a)

def _stack_volume(datasets):
    """Stack each dataset's 2D ``pixel_array`` along a new last axis as float64."""
    return np.stack([ds.pixel_array for ds in datasets], axis=-1).astype(np.float64)


def _cast_slice(slice_2d, dtype):
    """Cast a float slice to ``dtype``, rounding and clipping first for integer types.

    A bare ``astype`` truncates toward zero and wraps values outside the
    integer range.
    """
    dtype = np.dtype(dtype)
    if np.issubdtype(dtype, np.integer):
        info = np.iinfo(dtype)
        slice_2d = np.clip(np.rint(slice_2d), info.min, info.max)
    return slice_2d.astype(dtype)


def _save_registered(moved_volume, template_datasets, out_dir):
    """Write each z-slice of ``moved_volume`` into the matching template dataset and save it.

    Only the first ``moved_volume.shape[2]`` templates are used.
    NOTE(review): PixelData is written uncompressed while TransferSyntaxUID and
    SOP/Series UIDs are copied from the template. Confirm the inputs are
    uncompressed and that sharing UIDs with the source series is acceptable.
    """
    for r, dicom_slice in enumerate(template_datasets[:moved_volume.shape[2]]):
        moved_slice = _cast_slice(moved_volume[:, :, r], dicom_slice.pixel_array.dtype)
        dicom_slice.PixelData = moved_slice.tobytes()
        # Keep the header consistent when the registered in-plane shape differs
        # from the template's.
        dicom_slice.Rows, dicom_slice.Columns = moved_slice.shape
        dicom_slice.SeriesDescription = "[Research & Science] - Generated Data"
        dicom_slice.save_as(os.path.join(out_dir, f"Aff_0000{r:03}.dcm"))


def _save_original(datasets, out_dir):
    """Save the reference series unchanged, renumbered from 0."""
    for b, dicom_slice in enumerate(datasets):
        dicom_slice.save_as(os.path.join(out_dir, f"Orig_000{b:03}.dcm"))


empty_paths = []
for p_path, a_path, dir_name in tqdm(zip(all_portal_paths_sorted, all_arterial_paths_sorted, patient_id_p_sorted),
                                     desc="Progress: "):
    portal_datasets = read_dicom_files(p_path)
    arterial_datasets = read_dicom_files(a_path)

    # read_dicom_files returns None on read errors and [] when nothing matched.
    # Both cases are skipped.
    if not portal_datasets or not arterial_datasets:
        empty_paths.append(p_path)
        continue

    portal_img = from_numpy(_stack_volume(portal_datasets))
    arterial_img = from_numpy(_stack_volume(arterial_datasets))

    # Align from greater to smaller slice counts. The series with more slices
    # is zero-padded and registered onto the other one. Ties register
    # arterial onto portal.
    portal_is_moving = portal_img.shape[2] > arterial_img.shape[2]
    if portal_is_moving:
        padded = pad_image(portal_img, pad_width=(60, 60, 60), value=0.0)
        moved_dict = affine_registration(padded, arterial_img, reg_type="AffineFast")
    else:
        padded = pad_image(arterial_img, pad_width=(60, 60, 60), value=0.0)
        moved_dict = affine_registration(padded, portal_img, reg_type="AffineFast")

    # Convert once instead of indexing the ANTsImage per slice. Depending on the
    # ANTsPy version, indexing either copies the whole image each time or
    # returns an ANTsImage without a numpy-style astype.
    moved_volume = moved_dict['warpedmovout'].numpy()

    save_to_path_p = os.path.join(save_path_P, dir_name + "_p")
    save_to_path_a = os.path.join(save_path_A, dir_name + "_a")
    os.makedirs(save_to_path_p, exist_ok=True)
    os.makedirs(save_to_path_a, exist_ok=True)

    if portal_is_moving:
        _save_registered(moved_volume, portal_datasets, save_to_path_p)
        _save_original(arterial_datasets, save_to_path_a)
    else:
        _save_registered(moved_volume, arterial_datasets, save_to_path_a)
        _save_original(portal_datasets, save_to_path_p)

print("save completed!")

set()



Unknown encoding 'ISO IR 149' - using default encoding instead


Unknown encoding 'ISO_IR 149' - using default encoding instead

Progress: : 712it [4:07:30, 20.86s/it]

save completed!



