In [8]:
import os
import nibabel as nib
import numpy as np
from scipy.ndimage.interpolation import zoom
import scipy as sp
from tqdm import tqdm, trange
from tqdm.notebook import tqdm_notebook

import keras
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Reshape, Flatten, LeakyReLU, Dropout, Embedding, Concatenate
from keras.layers.core import Activation
from keras.layers.convolutional import Conv3D, Deconv3D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.utils.vis_utils import plot_model
from keras.optimizers import Adam
import tensorflow as tf
from keras.utils import multi_gpu_model
import skimage.transform as skt

from keras.utils import generic_utils as keras_generic_utils
import keras.backend as K
K.set_image_data_format('channels_first')

In [2]:
def load_nifti(file_path, mask=None, z_factor=None, remove_nan=False):
    """Load a 3D array from a NIFTI file."""
    img = nib.load(file_path)
    struct_arr = np.array(img.get_fdata())

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        struct_arr *= mask
    if z_factor is not None:
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr):
    """Save a 3D array to a NIFTI file."""
    img = nib.Nifti1Image(struct_arr, np.eye(4))
    nib.save(img, file_path)

In [5]:

def prepare_data(use_smooth = False, running_on_server = False):
    root_dir = 'C:/Users/Eshan/Google Drive UALBERTA/Data/' if not running_on_server else '/mnt/hdd1/lxc-hdd1/tahjid/PD Data/'
    patient_list, patient_numbers, dataset = [], [], []
    label_map = dict(Control=0, PD=1)
    type_map = dict(FullScan=0, GrayMatter=1, WhiteMatter=2)
    full_scan_path = root_dir + 'FinalData/'
    wmgmpath = root_dir + 'FinalDataWMGM/' if not use_smooth else root_dir + 'FinalDataWMGMSmooth/'
    prefix = 'mwp' if not use_smooth else 'smwp'
    ext = '.nii'
    for i in tqdm(['Control', 'PD']):
        path = full_scan_path + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in tqdm(listOfFiles):
            patient_numbers.append(file[:4])
        path = wmgmpath + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in tqdm(listOfFiles):
            filename = file[4:8] if not use_smooth else file[5:9]
            if filename not in patient_numbers:
                continue
            if not use_smooth:
                patient_list.append([i, file[4:8]])
            else:
                patient_list.append([i, file[5:9]])

    for i in tqdm(patient_list):
        path = full_scan_path + i[0] + '/'
        patientIdVal = i[1]
        fullScanvalue = os.path.join(path + i[1] + ext)
        path = wmgmpath + i[0] + '/'
        gmval = os.path.join(path + prefix + str(type_map['GrayMatter']) + i[1] + ext)
        wmval = os.path.join(path + prefix + str(type_map['WhiteMatter']) + i[1] + ext)
        labelval = label_map[i[0]]
        dataset.append([patientIdVal,fullScanvalue, gmval, wmval, labelval])
    return np.array(dataset)
dataset = prepare_data()


  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 299/299 [00:00<?, ?it/s]A

100%|██████████| 626/626 [00:00<00:00, 208814.56it/s]

100%|██████████| 299/299 [00:00<?, ?it/s]A

100%|██████████| 714/714 [00:00<00:00, 119074.87it/s]
100%|██████████| 2/2 [00:00<00:00, 45.45it/s]
100%|██████████| 1170/1170 [00:00<00:00, 5651.80it/s]


In [6]:

def normalize(input):
    """Normalize inputs between -1 and +1"""
    normd = 2*(input-input.min())/(input.max()-input.min())-1
    return normd

In [38]:

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, data, labels, batch_size = 2, dim1 = (91,109,91), dim2 = (242, 145, 121) , n_channels=1,
                 n_classes = 2, shuffle = True,
                 target_size = (256/4, 256/2, 256/2), resize = True, normalize = True):
        'Initialization'
        # self.dim1 = dim1
        self.dim = dim2
        self.batch_size = batch_size
        self.labels = list(labels)
        self.data = data
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.list_IDs = list(data[:,:1])
        self.target_size = tuple(int(a) for a in target_size)
        self.resize = resize
        self.normalize = normalize
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __numbatches__(self):
        return int(np.floor(len(self.list_IDs) / self.__len__()))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # Generate data
        X, y = self.__data_generation(indexes)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indexes):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        gmwm = np.empty((self.batch_size, self.n_channels,  self.target_size[1], self.target_size[1], self.target_size[2]))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, count in enumerate(indexes):
            val = self.data[count]
            graymatter = load_nifti(val[2])
            whitematter = load_nifti(val[3])

            graymatter = graymatter.astype(np.float64)
            whitematter = whitematter.astype(np.float64)

            if self.normalize:
                graymatter = normalize(graymatter)
                whitematter = normalize(whitematter)

            if self.resize:
                graymatter = skt.resize(graymatter, self.target_size, mode = 'constant')
                whitematter = skt.resize(whitematter, self.target_size, mode = 'constant')

            print(graymatter.shape)
            print(whitematter.shape)
            gmwm[i,] = np.concatenate((graymatter, whitematter))[np.newaxis, ...]

            # Store class
            y[i] = self.labels[i]

        return gmwm, y

In [39]:
X = dataset[:,:4]
y = dataset[:,4:]
d = DataGenerator(X, y)

In [40]:
a = d.__getitem__(0)

(64, 128, 128)
(64, 128, 128)
(64, 128, 128)
(64, 128, 128)


In [42]:
a[0].shape

(2, 1, 128, 128, 128)

(64, 128, 128)