In [1]:
import os
import re

import numpy as np
import nibabel as nib
from scipy.ndimage.interpolation import zoom
import h5py

In [2]:
# Raw dataset layout: <data_root>/{images,labels}/{train,val}.
data_root = '../RibFrac/'
image_root, label_root = (os.path.join(data_root, kind) for kind in ('images', 'labels'))

image_train_path, image_val_path = (os.path.join(image_root, split) for split in ('train', 'val'))
label_train_path, label_val_path = (os.path.join(label_root, split) for split in ('train', 'val'))

In [3]:
# Destination for the HDF5 volumes written by `preprocessing`, split like the raw data.
prep_root = os.path.join(data_root, 'preprocessed_3d')
prep_train_path, prep_val_path = (os.path.join(prep_root, split) for split in ('train', 'val'))

In [4]:
def _case_id(file_name):
    """Return the first run of digits in ``file_name``; it is used as the case number."""
    match = re.search(r'[0-9]+', file_name)
    if match is None:
        raise ValueError('No case number found in file name {!r}'.format(file_name))
    return match.group()


def _pair_cases(image_names, label_names):
    """Pair image/label names positionally and check that each pair shares a case number.

    Returns a list of ``(case_id, image_name, label_name)`` tuples.
    Raises ValueError on a count mismatch or a case-number mismatch, which would
    otherwise silently pair images with the wrong labels.
    """
    if len(image_names) != len(label_names):
        raise ValueError('Found {} image files but {} label files'.format(
            len(image_names), len(label_names)))
    pairs = []
    for img_name, lbl_name in zip(image_names, label_names):
        case_id = _case_id(img_name)
        if _case_id(lbl_name) != case_id:
            raise ValueError('Image {!r} and label {!r} belong to different cases'.format(
                img_name, lbl_name))
        pairs.append((case_id, img_name, lbl_name))
    return pairs


def _window_normalize(volume, low_th, window):
    """Clip ``volume`` to ``[low_th, low_th + window]`` and rescale it to ``[0, 1]``.

    The top of the scale is the smaller of ``low_th + window`` and the volume's
    maximum, so a volume that never reaches the upper threshold still spans up to 1.
    A volume whose maximum does not exceed ``low_th`` maps to all zeros.
    """
    high_th = low_th + window
    pmax = min(volume.max(), high_th)
    if pmax <= low_th:
        # Nothing lies above the lower threshold; avoid a 0/0 (NaN) division.
        return np.zeros_like(volume, dtype=np.float64)
    return np.clip((volume - low_th) / (pmax - low_th), 0.0, 1.0)


def _resize_to(volume, out_shape):
    """Nearest-neighbour (order 0) resize of ``volume`` to ``out_shape``."""
    scale = np.asarray(out_shape, dtype=np.float64) / np.asarray(volume.shape)
    return zoom(volume, scale, order=0, mode='nearest')


def preprocessing(image_path, label_path, prep_path, order, low_th, window, gamma=None,
                  out_shape=(256, 256, 256)):
    """Convert paired NIfTI image/label volumes into fixed-size HDF5 files.

    Files in ``image_path`` and ``label_path`` are paired in sorted-name order,
    and each pair must share the first number in its file names. For every pair:

    1. The image is resampled by its header voxel spacing (``header.get_zooms()``),
       so that one voxel spans one spacing unit. The unit is presumably mm;
       TODO confirm the NIfTI units.
    2. Intensities are windowed to ``[low_th, low_th + window]`` and rescaled to ``[0, 1]``.
    3. If ``gamma`` is given, the normalized image is raised to that power.
    4. The last axis is moved to the front, and image and label are both resized
       to ``out_shape`` with nearest-neighbour interpolation.
    5. The label is binarized: any non-zero voxel becomes 1.
    6. ``<case number>.h5`` is written into ``prep_path``. It holds the datasets
       ``'raw'`` (float32 image) and ``'label'`` (uint8 mask).

    Args:
        image_path: Directory with the image volumes.
        label_path: Directory with the label volumes (same cases as ``image_path``).
        prep_path: Output directory; it is created if missing.
        order: Spline order of the voxel-spacing resampling of the image. The
            final resize to ``out_shape`` always uses order 0.
        low_th: Lower intensity threshold.
        window: Width of the intensity window above ``low_th``.
        gamma: Optional exponent applied to the normalized image.
        out_shape: Shape of the saved volumes; defaults to ``(256, 256, 256)``.

    Raises:
        ValueError: If the directories hold different numbers of files, if a file
            name has no number, or if the case numbers of a pair disagree.
    """
    image_names = sorted(os.listdir(image_path))
    label_names = sorted(os.listdir(label_path))
    # Validate every pair up front so a mismatch is caught before any heavy work.
    pairs = _pair_cases(image_names, label_names)

    os.makedirs(prep_path, exist_ok=True)

    for case_id, img_name, lbl_name in pairs:
        img_raw = nib.load(os.path.join(image_path, img_name))
        lbl_raw = nib.load(os.path.join(label_path, lbl_name))
        img = img_raw.get_fdata()
        lbl = lbl_raw.get_fdata()

        # Resample so each voxel spans one unit of the header's voxel spacing.
        img_zoomed = zoom(img, img_raw.header.get_zooms(), order=order, mode='nearest')
        img_th = _window_normalize(img_zoomed, low_th, window)
        if gamma is not None:
            img_th = img_th ** gamma

        # Move the last axis first for image and label alike, so they stay aligned.
        img_th = np.moveaxis(img_th, -1, 0)
        lbl_th = np.moveaxis(lbl, -1, 0)

        img_resampled = _resize_to(img_th, out_shape).astype(np.float32)
        # Binarize before casting to uint8 so label values >= 256 cannot wrap to 0.
        lbl_resampled = (_resize_to(lbl_th, out_shape) != 0).astype(np.uint8)

        file_path = os.path.join(prep_path, case_id + '.h5')
        # The context manager closes the file even if a dataset write fails.
        with h5py.File(file_path, 'w') as h5_file:
            h5_file.create_dataset('raw', data=img_resampled)
            h5_file.create_dataset('label', data=lbl_resampled)

In [5]:
# Training split. low_th/window presumably select a bone-range intensity window
# in HU (TODO confirm); gamma < 1 brightens the normalized [0, 1] image.
preprocessing(
    image_path=image_train_path,
    label_path=label_train_path,
    prep_path=prep_train_path,
    order=3,
    low_th=100,
    window=2048,
    gamma=0.4,
)

In [6]:
# Validation split: same parameters as the training split so both are
# normalized identically.
preprocessing(
    image_path=image_val_path,
    label_path=label_val_path,
    prep_path=prep_val_path,
    order=3,
    low_th=100,
    window=2048,
    gamma=0.4,
)