In [None]:
import os
import SimpleITK as sitk
from tqdm import tqdm
import skimage
from skimage.transform import resize
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Square output size (rows, cols) that every resized radiograph must have.
target_shape = (512,) * 2

In [None]:
# Folder holding the original NIfTI radiographs, and the folder the resized
# copies are written to (under the same file names).
(path_to_data_dir,
 path_to_resized_data_dir) = ('/data/chest_radiograph/nifti_files',
                              '/data/chest_radiograph/resized_nifti_files')

In [None]:
# File names to process. os.listdir returns entries in arbitrary order, so the
# list is sorted to make the processing (and failure-report) order
# reproducible across runs and filesystems.
img_list = sorted(os.listdir(path_to_data_dir))

In [None]:
def pad_image_to_target_shape(img, small_edge_idx, pad_size):
    """Zero-pad ``img`` by ``pad_size`` pixels along a single axis.

    The padding is split as evenly as possible between the two ends of the
    axis; when ``pad_size`` is odd, the extra pixel goes on the leading end
    (e.g. the top rows for axis 0 of a 2-D array).

    Args:
        img: Array to pad (any number of dimensions).
        small_edge_idx: Index of the axis to pad; negative indices count
            from the last axis as usual.
        pad_size: Total number of pixels to add along that axis (>= 0).

    Returns:
        A new array equal to ``img.shape`` except along ``small_edge_idx``,
        which is longer by ``pad_size``; added pixels are 0.

    Raises:
        ValueError: If ``pad_size`` is negative.
        IndexError: If ``small_edge_idx`` is not a valid axis of ``img``.
    """
    if pad_size < 0:
        raise ValueError(f"pad_size must be non-negative, got {pad_size}")
    # Indexing a range normalizes negative axes and raises IndexError for
    # out-of-range ones instead of silently padding the wrong axis.
    axis = range(img.ndim)[small_edge_idx]
    pad_width = [(0, 0)] * img.ndim
    pad_width[axis] = (pad_size // 2 + pad_size % 2, pad_size // 2)
    return np.pad(img, pad_width=pad_width, mode='constant', constant_values=0)

In [None]:
def reshape(np_img, target_shape=(512, 512), preserve_range=False):
    """Resize a 2-D image to ``target_shape`` while keeping its aspect ratio.

    The image is scaled by one factor for both axes, chosen so that it fits
    inside ``target_shape`` (the limiting axis matches the target exactly).
    It is then zero-padded along the other axis up to the target size. When
    the padding amount is odd, the extra pixel goes on the leading side.

    Args:
        np_img: 2-D image array, e.g. as returned by ``sitk.GetArrayFromImage``.
        target_shape: Output ``(rows, cols)``. Defaults to ``(512, 512)``,
            the size this function previously hard-coded.
        preserve_range: Forwarded to ``skimage.transform.resize``. With the
            default ``False``, skimage converts the input following
            ``img_as_float`` conventions, so integer intensities are rescaled
            (unsigned types end up in [0, 1]). Pass ``True`` to keep the
            original intensity values.

    Returns:
        Array of shape ``target_shape``.

    Raises:
        ValueError: If ``np_img`` or ``target_shape`` is not 2-D, or either
            has a non-positive axis length.
    """
    target_shape = tuple(int(size) for size in target_shape)
    org_shape = np_img.shape
    if np_img.ndim != 2 or len(target_shape) != 2:
        raise ValueError(
            f"reshape expects a 2-D image and a 2-D target_shape, got image "
            f"shape {org_shape} and target_shape {target_shape}")
    if min(org_shape) < 1 or min(target_shape) < 1:
        raise ValueError(
            f"image shape {org_shape} and target_shape {target_shape} must "
            f"have positive lengths")
    # A single factor preserves the aspect ratio. Taking the minimum ratio
    # guarantees neither axis overshoots the target, whether or not the
    # target is square.
    scale = min(t / o for t, o in zip(target_shape, org_shape))
    # Clamp to [1, target] so rounding can neither produce an empty axis
    # (very elongated images) nor exceed the target.
    fitted_shape = tuple(
        min(t, max(1, round(o * scale))) for t, o in zip(target_shape, org_shape))
    resized = resize(np_img, fitted_shape, mode='constant', cval=0,
                     preserve_range=preserve_range)
    # Letterbox: zero-pad whatever is left on each axis (0 on the limiting axis).
    pad_width = [((t - r) // 2 + (t - r) % 2, (t - r) // 2)
                 for t, r in zip(target_shape, resized.shape)]
    return np.pad(resized, pad_width=pad_width, mode='constant', constant_values=0)

In [None]:
def save_img_as_sitk_to_path(np_img, org_sitk_img, path, spacing=None):
    """Write ``np_img`` to ``path`` as a SimpleITK image with ``org_sitk_img``'s header info.

    Geometry and metadata are taken from ``org_sitk_img``:

    * spacing (or ``spacing`` if given);
    * origin and direction, when both images have the same dimension
      (previously these were dropped, so the written file lost its
      orientation);
    * every metadata key/value pair.

    NOTE: if ``np_img`` was resized, its pixel grid no longer matches the
    original spacing. The copied spacing then does not describe the physical
    pixel size; pass the rescaled spacing via ``spacing`` if that matters.

    Args:
        np_img: Pixel array to write (numpy axis order, as for
            ``sitk.GetImageFromArray``).
        org_sitk_img: Image whose geometry and metadata are copied.
        path: Output file path; the format is chosen by ``sitk.WriteImage``.
        spacing: Optional spacing overriding ``org_sitk_img.GetSpacing()``.
    """
    new_sitk_img = sitk.GetImageFromArray(np_img)
    new_sitk_img.SetSpacing(org_sitk_img.GetSpacing() if spacing is None else spacing)
    # Origin/direction vectors are sized by image dimension; only copy them
    # when the dimensions agree.
    if new_sitk_img.GetDimension() == org_sitk_img.GetDimension():
        new_sitk_img.SetOrigin(org_sitk_img.GetOrigin())
        new_sitk_img.SetDirection(org_sitk_img.GetDirection())
    for key in org_sitk_img.GetMetaDataKeys():
        if not new_sitk_img.HasMetaDataKey(key):
            new_sitk_img.SetMetaData(key, org_sitk_img.GetMetaData(key))
    sitk.WriteImage(new_sitk_img, path)



In [None]:
# Resize every file in `img_list` and write it under the same name to
# `path_to_resized_data_dir`. A failing image is reported with its cause and
# collected in `failed_imgs` instead of aborting the whole run.
img_dimensions = []  # NOTE(review): never populated here; drop if no later cell uses it
# Without the output folder every write fails, which used to surface only as
# one opaque message per image.
os.makedirs(path_to_resized_data_dir, exist_ok=True)
failed_imgs = []
for img in tqdm(img_list):
    img_path = os.path.join(path_to_data_dir, img)
    save_path = os.path.join(path_to_resized_data_dir, img)
    try:
        sitk_img = sitk.ReadImage(img_path)
        np_img = sitk.GetArrayFromImage(sitk_img)
        resized_np_img = reshape(np_img)
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if resized_np_img.shape != tuple(target_shape):
            raise ValueError(f"unexpected output shape {resized_np_img.shape}, "
                             f"expected {tuple(target_shape)}")
        save_img_as_sitk_to_path(resized_np_img, sitk_img, save_path)
    # `Exception` rather than a bare `except:` so KeyboardInterrupt still stops the loop.
    except Exception as exc:
        failed_imgs.append((img, repr(exc)))
        # tqdm.write keeps the progress bar intact.
        tqdm.write(f"Something went wrong with image {img}: {exc!r}")
print(f"{len(img_list) - len(failed_imgs)}/{len(img_list)} images resized, "
      f"{len(failed_imgs)} failed")