In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
import glob
from skimage import transform
import nibabel as nib
import tqdm
from PIL import Image


def normalize_img(img):
    norm_img = img/np.max(img)
    return norm_img


def crop_pad_resize(image, nx, ny):
    x, y = image.shape

    # difference in nr of pixels (divide by 2 since we have 2 sides)
    x_s = (x - nx) // 2
    y_s = (y - ny) // 2
    x_c = (nx - x) // 2
    y_c = (ny - y) // 2

    if x > nx and y > ny:
        # if image is larger in both dimensions cut a slice
        slice_cropped = image[x_s:x_s + nx, y_s:y_s + ny]

    else:
        # if one dim is smaller fill that side up with 0's
        slice_cropped = np.zeros((nx, ny))

        if x <= nx and y > ny:
            # fill up x direction with 0's, cut in x direction
            slice_cropped[x_c:x_c + x, :] = image[:, y_s:y_s + ny]
        elif x > nx and y <= ny:
            # fill up y direction with 0's, cut in y direction
            slice_cropped[:, y_c:y_c + y] = image[x_s:x_s + nx, :]
        else:
            # if dimensions are as desired, keep the original slice
            slice_cropped[x_c:x_c + x, y_c:y_c + y] = image[:, :]

    return slice_cropped


def preprocess(input_folder, target_resolution, target_size, train_test_val):

    nx, ny = target_size
    len_inp = len(input_folder)+1

    for folder in os.listdir(input_folder):

        if folder != '.ipynb_checkpoints':

            folder_path = os.path.join(input_folder, folder)

            if not os.path.exists(os.path.join('preprocessed/'+train_test_val, folder_path[len_inp:])):
                os.mkdir(os.path.join('preprocessed/'+train_test_val, folder_path[len_inp:]))

            for file in glob.glob(os.path.join(folder_path, 'patient???_frame??.nii.gz')):
                file_base = file.split('.nii.gz')[0]
                file_mask = file_base + '_gt.nii.gz'

                img_nii = nib.load(file)
                img_dat = img_nii.get_fdata()

                mask_nii = nib.load(file_mask)
                mask_dat = mask_nii.get_fdata()

                img = img_nii.get_fdata()
                mask = mask_nii.get_fdata()

                pixel_size = img_nii.header.get_zooms()

                scale_vector = [pixel_size[0] / target_resolution[0], pixel_size[1] / target_resolution[1]]

                for zz in tqdm.tqdm(range(img.shape[2])):
                    slice_img = np.squeeze(img[:, :, zz])
                    slice_img = normalize_img(np.squeeze(img[:, :, zz]))
                    img_rescaled = transform.rescale(slice_img,
                                                     scale_vector,
                                                     order=1,
                                                     preserve_range=True,
                                                     mode='constant')

                    slice_mask = np.squeeze(mask[:, :, zz])
                    slice_mask = normalize_img(np.squeeze(mask[:, :, zz]))
                    mask_rescaled = transform.rescale(slice_mask,
                                                      scale_vector,
                                                      order=0,
                                                      preserve_range=True,
                                                      mode='constant')

                    img_cropped = crop_pad_resize(img_rescaled, nx, ny)
                    mask_cropped = crop_pad_resize(mask_rescaled, nx, ny)

                    img_list.append(img_rescaled)
                    mask_list.append(mask_rescaled)

                    img_loc = os.path.join('preprocessed/'+train_test_val, file[len_inp:-7]+'_slice{:01}'.format(zz)+'.png')
                    img_fin = Image.fromarray(img_cropped)
                    img_fin = img_fin.convert("L")
                    img_fin.save(img_loc)
                    
                    mask_loc = os.path.join('preprocessed/'+train_test_val, file[len_inp:-7]+'_gt_slice{:01}'.format(zz)+'.png')
                    mask_fin = Image.fromarray(mask_cropped)
                    mask_fin = mask_fin.convert("L")
                    mask_fin.save(mask_loc)


target_resolution = (1.36719, 1.36719)
target_size = (212, 212)
img_list = []
mask_list = []

preprocess('data/train', target_resolution, target_size, train_test_val='train')
preprocess('data/test', target_resolution, target_size, train_test_val='test')

  norm_img = img/np.max(img)
  min_val = min_func(input_image)
  max_val = max_func(input_image)
  and min_func(output_image) <= cval <= max_func(output_image))
100%|██████████| 10/10 [00:00<00:00, 74.28it/s]
100%|██████████| 10/10 [00:00<00:00, 103.28it/s]
