In [20]:
import os
import nibabel as nib
import numpy as np
from scipy.ndimage.interpolation import zoom
import scipy as sp
from tqdm.notebook import tqdm_notebook

import keras
from keras.models import Model
from keras.layers import Input
from keras.layers.core import Activation
from keras.layers.convolutional import Conv3D, Deconv3D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.utils.vis_utils import plot_model
def load_nifti(file_path, mask=None, z_factor=None, remove_nan=False):
    """Load a 3D array from a NIFTI file."""
    img = nib.load(file_path)
    struct_arr = np.array(img.get_data())

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        struct_arr *= mask
    if z_factor is not None:
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr):
    """Save a 3D array to a NIFTI file."""
    img = nib.Nifti1Image(struct_arr, np.eye(4))
    nib.save(img, file_path)
    




In [2]:
def prepareData(use_smooth = False, runningOnServer = False):
    rootDir = 'C:/Users/Eshan/Google Drive UALBERTA/Data/' if not runningOnServer else '/mnt/hdd1/lxc-hdd1/tahjid/PD Data/'
    patientList, patientNumbers, dataset = [], [], []
    labelMap = dict(Control=0, PD=1)
    typeMap = dict(FullScan=0, GrayMatter=1, WhiteMatter=2)
    fullScanPath = rootDir + 'FinalData/'
    wmgmpath = rootDir + 'FinalDataWMGM/' if not use_smooth else rootDir + 'FinalDataWMGMSmooth/'
    prefix = 'mwp' if not use_smooth else 'smwp'
    ext = '.nii'
    for i in ['Control', 'PD']:
        path = fullScanPath + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in listOfFiles:
            patientNumbers.append(file[:4])
        path = wmgmpath + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in listOfFiles:
            filename = file[4:8] if not use_smooth else file[5:9]
            if filename not in patientNumbers:
                continue
            if not use_smooth:
                patientList.append([i, file[4:8]])
            else:
                patientList.append([i, file[5:9]])

    for i in patientList:
        path = fullScanPath + i[0] + '/'
        patientIdVal = i[1]
        fullScanvalue = os.path.join(path + i[1] + ext)
        path = wmgmpath + i[0] + '/'
        gmval = os.path.join(path + prefix + str(typeMap['GrayMatter']) + i[1] + ext)
        wmval = os.path.join(path + prefix + str(typeMap['WhiteMatter']) + i[1] + ext)
        labelval = labelMap[i[0]]
        dataset.append([patientIdVal,fullScanvalue, gmval, wmval, labelval])
    return np.array(dataset)
dataset = prepareData()

In [10]:
from sklearn.model_selection import train_test_split
X = dataset[:,:4]
y = dataset[:,4:]

In [11]:
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=.20, random_state=42)


In [17]:
dataset[:,:1]

array([['3004'],
       ['3011'],
       ['3013'],
       ...,
       ['4117'],
       ['4135'],
       ['4136']], dtype='<U76')

In [58]:
class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, data, labels, batch_size=2, dim1=(91,109,91), dim2=(121,145,121) , n_channels=1,
                 n_classes=2, shuffle=True):
        'Initialization'
        self.dim1 = dim1
        self.dim2 = dim2
        self.batch_size = batch_size
        self.labels = list(labels)
        self.data = data
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.list_IDs = list(data[:,:1])
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # print(indexes)
        # Find list of IDs
        # list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        fs, gm, wm, y = self.__data_generation(indexes)

        return fs, gm, wm, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indexes):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
        fs = np.empty((self.batch_size, *self.dim1, self.n_channels))
        gm = np.empty((self.batch_size, *self.dim2, self.n_channels))
        wm = np.empty((self.batch_size, *self.dim2, self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, count in enumerate(indexes):
            val = self.data[count]
            fullscan = load_nifti(val[1])
            graymatter = load_nifti(val[2])
            whitematter = load_nifti(val[3])
            # print(fullscan.shape, graymatter.shape, whitematter.shape)
            # break
            # Store sample
            fs[i,] = fullscan[..., np.newaxis]
            gm[i,] = graymatter[..., np.newaxis]
            wm[i,] = whitematter[..., np.newaxis]


            # Store class
            y[i] = self.labels[i]

        return fs, gm, wm, keras.utils.to_categorical(y, num_classes=self.n_classes)
    

In [59]:
d = DataGenerator(x_train, y_train)


In [60]:
d.__getitem__(0)

[783 131]



* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  

* deprecated from version: 3.0
* Will raise <class 'nibabel.deprecator.ExpiredDeprecationError'> as of version: 5.0
  


(array([[[[[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          ...,
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]]],
 
 
         [[[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[0.],
           [0.],
           [0.],
           ...,
           [0.],
           [0.],
           [0.]],
 
          [[

In [16]:
def generator(phase_train=True, params={'z_size':200, 'strides':(2,2,2), 'kernel_size':(4,4,4)}):
    """
    Returns a Generator Model with input params and phase_train
    Args:
        phase_train (boolean): training phase or not
        params (dict): Dictionary with model parameters
    Returns:
        model (keras.Model): Keras Generator model
    """

    z_size = params['z_size']
    strides = params['strides']
    kernel_size = params['kernel_size']

    inputs = Input(shape=(1, 1, 1, z_size))

    g1 = Deconv3D(filters=512, kernel_size=kernel_size,
                  strides=(1, 1, 1), kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='valid')(inputs)
    g1 = BatchNormalization()(g1, training=phase_train)
    g1 = Activation(activation='relu')(g1)

    g2 = Deconv3D(filters=256, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g1)
    g2 = BatchNormalization()(g2, training=phase_train)
    g2 = Activation(activation='relu')(g2)

    g3 = Deconv3D(filters=128, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g2)
    g3 = BatchNormalization()(g3, training=phase_train)
    g3 = Activation(activation='relu')(g3)

    g4 = Deconv3D(filters=64, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g3)
    g4 = BatchNormalization()(g4, training=phase_train)
    g4 = Activation(activation='relu')(g4)

    g5 = Deconv3D(filters=1, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g4)
    g5 = BatchNormalization()(g5, training=phase_train)
    g5 = Activation(activation='sigmoid')(g5)

    model = Model(inputs=inputs, outputs=g5)
    model.summary()

    return model

In [22]:
model = generator()
model.summary()
plot_model(model, to_file = "generator_plot.png", show_shapes = True, show_layer_names = True)

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_4 (InputLayer)         (None, 1, 1, 1, 200)      0         
_________________________________________________________________
conv3d_transpose_6 (Conv3DTr (None, 512, 4, 4, 203)    33280     
_________________________________________________________________
batch_normalization_6 (Batch (None, 512, 4, 4, 203)    812       
_________________________________________________________________
activation_6 (Activation)    (None, 512, 4, 4, 203)    0         
_________________________________________________________________
conv3d_transpose_7 (Conv3DTr (None, 256, 8, 8, 406)    8388864   
_________________________________________________________________
batch_normalization_7 (Batch (None, 256, 8, 8, 406)    1624      
_________________________________________________________________
activation_7 (Activation)    (None, 256, 8, 8, 406)    0   

ImportError: Failed to import `pydot`. Please install `pydot`. For example with `pip install pydot`.