# Pre-processing

As discussed in the paper and in `../2_no_finetuning/1_inference_example.ipynb`, the network expects input images of size 1024x1024x3. We therefore pre-process our images as a batch so that they can be used for fine-tuning at a later stage. For simplicity, every image is resized to 1024x1024 in-plane (the network's input resolution).

In [2]:
import sys
import os

# Put the repository root (two directories above the notebook's working
# directory) on the import path so the project-local `utils` package resolves.
dir1 = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if dir1 not in sys.path:
    sys.path.append(dir1)

from utils.environment import setup_data_vars

# NOTE(review): setup_data_vars is defined elsewhere; presumably it populates the
# environment variables read later in this notebook (PROJECT_DIR, nnUNet_raw, ...) — confirm.
setup_data_vars()

In [5]:
import numpy as np

# import nibabel as nib
import SimpleITK as sitk
import os

join = os.path.join
from skimage import transform
from tqdm import tqdm
import cc3d

def pre_process(
      input_path_nii: str
    , input_path_gt: str
    , anatomy: str
    , modality = 'CT'
    , image_prefix = 'zzAMLART_'
    , nii_postfix = '_0000.nii.gz'
    , gt_postfix = '.nii.gz'
    , output_postfix = '.nii.gz'
    , WINDOW_WIDTH = 400
    , WINDOW_LEVEL = 40
    , image_size = 1024
    , voxel_num_thre2d = 50
    , voxel_num_thre3d = 1000
    ):
    # << SETUP DESTINATION >>
    
    # Convince yourself that the anatomy parsed in is the same as the one mentioned in the input paths. Hacky workaround exists, but assume we're cooperating for now.
    assert all([anatomy in x for x in [input_path_nii, input_path_gt]])

    prefix = modality + "_" + anatomy + "_"

    nyp_path_imgs = os.path.join(os.environ.get('PROJECT_DIR'), 'data', 'MedSAM_preprocessed', 'imgs') # e.g. MedSAM_preprocessed/imgs/
    nyp_path_gts = os.path.join(os.environ.get('PROJECT_DIR'), 'data', 'MedSAM_preprocessed', 'gts', prefix[:-1]) # e.g. MedSAM_preprocessed/gts/CT_Bladder/

    os.makedirs(nyp_path_imgs, exist_ok=True) # Nii images
    os.makedirs(nyp_path_gts, exist_ok=True) # Ground truth images

    # << ITERATE OVER ALL IMAGES AND CONVERT >>

    # Get the list of images to process
    img_names = sorted([f for f in os.listdir(input_path_nii) if f.endswith(nii_postfix)])
    gt_names = sorted([f for f in os.listdir(input_path_gt) if f.endswith(gt_postfix)])

    # Get the processed images for images and ground truth
    img_processed_names = sorted([f for f in os.listdir(nyp_path_imgs) if f.endswith(output_postfix)])
    gt_processed_names = sorted([f for f in os.listdir(nyp_path_gts) if f.endswith(output_postfix)])

    # Get list of remaining images assuming each image id is unique.
    img_ids = [int(img.split(image_prefix)[1].split(nii_postfix)[0]) for img in img_names]
    gt_ids = [int(gt.split(image_prefix)[1].split(gt_postfix)[0]) for gt in gt_names]
    processed_img_ids = [int(img.split(prefix)[1].split('.npy')[0]) for img in img_processed_names]
    processed_gt_ids = [int(gt.split(prefix)[1].split('.npy')[0]) for gt in gt_processed_names]

    remaining_img_names = set(img_ids) - set(processed_img_ids)
    remaining_gt_names = set(gt_ids) - set(processed_gt_ids)

    # << PROCESS IMAGES >>
    # Make the sensible assumption that the dimensions of the images and the ground truth segmentations are the same. Therefore, we can process these separately and not be impacted

    for img_id in tqdm(remaining_img_names, desc='Processing images'):
        img_path = os.path.join(input_path_nii, image_prefix + str(img_id).zfill(3) + nii_postfix)
        image_data = sitk.GetArrayFromImage(sitk.ReadImage(img_path))

        if modality == "CT":
            lower_bound = WINDOW_LEVEL - WINDOW_WIDTH / 2
            upper_bound = WINDOW_LEVEL + WINDOW_WIDTH / 2
            image_data_pre = np.clip(image_data, lower_bound, upper_bound)
            image_data_pre = (
                (image_data_pre - np.min(image_data_pre))
                / (np.max(image_data_pre) - np.min(image_data_pre))
                * 255.0    
            )
        else:
            raise NotImplementedError("Modality not supported")

        image_data_pre = np.uint8(image_data_pre)

        resize_img_skimg = transform.resize(
            image_data_pre,
            (image_size, image_size, image_size),
            order=3,
            preserve_range=True,
            mode="constant",
            anti_aliasing=True,
        )

        print(resize_img_skimg.GetShape())


    # for gt_id in tqdm(remaining_gt_names, desc='Processing ground truths'):
    #     gt_path = os.path.join(input_path_gt, image_prefix + str(gt_id).zfill(3) + gt_postfix)
    #     gt_data_ori = sitk.GetArrayFromImage(sitk.ReadImage(gt_path))

    #     # exclude objects with less than 1000 pixels in 3D
    #     gt_data_ori = cc3d.dust(gt_data_ori, threshold=voxel_num_thre3d, connectivity=26, in_place=True)
    #     # remove small objects with less than voxel_num_thre2d pixels in 2D slices. For
    #     # such small objects, the main challenge is detection rather than segmentation
    #     for slice_i in range(gt_data_ori.shape[0]):
    #         gt_i = gt_data_ori[slice_i, :, :]
    #         gt_data_ori = cc3d.dust(gt_i, threshold=voxel_num_thre2d, connectivity=6, in_place=True)

    #     # find non-zero slices
    #     z_index, _, _ = np.where(gt_data_ori > 0)
    #     z_index = np.unique(z_index)
    
    return -1






anatomy = 'Bladder'

# nnU-Net style raw-data layout: <nnUNet_raw>/<dataset folder>/<images|labels>.
# NOTE(review): the environment variable named after the anatomy presumably
# holds that anatomy's dataset folder name — confirm.
raw_dataset_dir = os.path.join(os.environ.get('nnUNet_raw'), os.environ.get(anatomy))

pre_process(
    input_path_nii=os.path.join(raw_dataset_dir, os.environ.get('data_trainingImages')),
    input_path_gt=os.path.join(raw_dataset_dir, os.environ.get('data_trainingLabels')),
    anatomy=anatomy,
)

Processing images:   0%|          | 0/100 [00:00<?, ?it/s]