In [1]:
import os
import nibabel as nib
import numpy as np
from scipy.ndimage.interpolation import zoom
import scipy as sp
from tqdm import tqdm, trange
from tqdm.notebook import tqdm_notebook

import keras
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Reshape, Flatten, LeakyReLU, Dropout, Embedding, Concatenate
from keras.layers.core import Activation
from keras.layers.convolutional import Conv3D, Deconv3D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.normalization import BatchNormalization
from keras.utils.vis_utils import plot_model
from keras.optimizers import Adam
import tensorflow as tf
from keras.utils import multi_gpu_model
from keras.utils import generic_utils as keras_generic_utils
import keras.backend as K

import skimage.transform as skt

def load_nifti(file_path, mask=None, z_factor=None, remove_nan=False):
    """Load a 3D array from a NIFTI file."""
    img = nib.load(file_path)
    struct_arr = np.array(img.get_fdata())

    if remove_nan:
        struct_arr = np.nan_to_num(struct_arr)
    if mask is not None:
        struct_arr *= mask
    if z_factor is not None:
        struct_arr = np.around(zoom(struct_arr, z_factor), 0)

    return struct_arr


def save_nifti(file_path, struct_arr):
    """Save a 3D array to a NIFTI file."""
    img = nib.Nifti1Image(struct_arr, np.eye(4))
    nib.save(img, file_path)

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def prepareData(use_smooth = False, runningOnServer = True):
    rootDir = 'C:/Users/Eshan/Google Drive UALBERTA/Data/' if not runningOnServer else '/mnt/hdd1/lxc-hdd1/tahjid/PD Data/'
    patientList, patientNumbers, dataset = [], [], []
    labelMap = dict(Control=0, PD=1)
    typeMap = dict(FullScan=0, GrayMatter=1, WhiteMatter=2)
    fullScanPath = rootDir + 'FinalData/'
    wmgmpath = rootDir + 'FinalDataWMGM/' if not use_smooth else rootDir + 'FinalDataWMGMSmooth/'
    prefix = 'mwp' if not use_smooth else 'smwp'
    ext = '.nii'
    for i in tqdm(['Control', 'PD']):
        path = fullScanPath + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in tqdm(listOfFiles):
            patientNumbers.append(file[:4])
        path = wmgmpath + i + '/'
        listOfFiles = [f for f in os.listdir(path) if f.endswith(ext)]
        for file in tqdm(listOfFiles):
            filename = file[4:8] if not use_smooth else file[5:9]
            if filename not in patientNumbers:
                continue
            if not use_smooth:
                patientList.append([i, file[4:8]])
            else:
                patientList.append([i, file[5:9]])

    for i in tqdm(patientList):
        path = fullScanPath + i[0] + '/'
        patientIdVal = i[1]
        fullScanvalue = os.path.join(path + i[1] + ext)
        path = wmgmpath + i[0] + '/'
        gmval = os.path.join(path + prefix + str(typeMap['GrayMatter']) + i[1] + ext)
        wmval = os.path.join(path + prefix + str(typeMap['WhiteMatter']) + i[1] + ext)
        labelval = labelMap[i[0]]
        dataset.append([patientIdVal,fullScanvalue, gmval, wmval, labelval])
    return np.array(dataset)
dataset = prepareData()

  0%|          | 0/2 [00:00<?, ?it/s]
100%|██████████| 299/299 [00:00<00:00, 2263712.81it/s]

100%|██████████| 626/626 [00:00<00:00, 317642.67it/s]
 50%|█████     | 1/2 [00:00<00:00,  6.01it/s]
100%|██████████| 299/299 [00:00<00:00, 886914.35it/s]

100%|██████████| 714/714 [00:00<00:00, 23364.41it/s]
100%|██████████| 2/2 [00:00<00:00,  5.77it/s]
100%|██████████| 1170/1170 [00:00<00:00, 99320.68it/s]


In [24]:
from sklearn.model_selection import train_test_split


In [27]:
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=.20, random_state=42)


In [3]:
def normalize(input):
    return 2*(input-input.min())/(input.max()-input.min())-1

class DataGenerator(keras.utils.Sequence):
    'Generates data for Keras'
    def __init__(self, data, labels, batch_size=2, dim1=(91,109,91), dim2=(242,145,121) , n_channels=1,
                 n_classes=2, shuffle=True, target_size = 256, resize = True, normalize = True):
        'Initialization'
        self.dim = dim2
#         self.dim2 = dim2
        self.batch_size = batch_size
        self.labels = list(labels)
        self.data = data
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.list_IDs = list(data[:,:1])
        self.resize = resize
        self.normalize = normalize
        self.target_size = tuple(int(a) for a in (target_size/4, target_size/2, target_size/2))
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        return int(np.floor(len(self.list_IDs) / self.batch_size))
    
    def __numbatches__(self):
        return int(np.floor(len(self.list_IDs) / self.__len__()))
    
    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]
        # print(indexes)
        # Find list of IDs
        # list_IDs_temp = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(indexes)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, indexes):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        # Initialization
#         fs = np.empty((self.batch_size, *self.dim1, self.n_channels))
#         gm = np.empty((self.batch_size, *self.dim2, self.n_channels))
#         wm = np.empty((self.batch_size, *self.dim2, self.n_channels))
#         gmwm = np.empty((self.batch_size, *self.dim, self.n_channels))
        gmwm = np.empty((self.batch_size, self.target_size[1], self.target_size[1], self.target_size[2], self.n_channels))
        y = np.empty((self.batch_size), dtype=int)

        # Generate data
        for i, count in enumerate(indexes):
            val = self.data[count]
#             fullscan = load_nifti(val[1])
            graymatter = load_nifti(val[2])
            whitematter = load_nifti(val[3])
        
            graymatter = graymatter.astype(np.float64)
            whitematter = whitematter.astype(np.float64)
            
            if self.normalize:
                graymatter = normalize(graymatter)
                whitematter = normalize(whitematter)
            
            if self.resize:
                graymatter = skt.resize(graymatter, self.target_size, mode = 'constant')
                whitematter = skt.resize(whitematter, self.target_size, mode = 'constant')
            
            graymatter = graymatter.astype(np.float32)
            whitematter = whitematter.astype(np.float32)
            
            # print(fullscan.shape, graymatter.shape, whitematter.shape)
            # break
            # Store sample
#             fs[i,] = fullscan[..., np.newaxis]
#             gm[i,] = graymatter[..., np.newaxis]
#             wm[i,] = whitematter[..., np.newaxis]
#             print(graymatter.shape)
#             print(whitematter.shape)
            gmwm[i,] = np.concatenate((graymatter, whitematter))[..., np.newaxis]


            # Store class
            y[i] = self.labels[i]

        return gmwm, y #keras.utils.to_categorical(y, num_classes=self.n_classes)
    

In [13]:
X = dataset[:,:4]
y = dataset[:,4:]
d = DataGenerator(X, y)


In [28]:
d.__numbatches__()

2

In [14]:
a = d.__getitem__(0)
# b = np.concatenate((a[1][0].reshape(121,145,121), a[1][0].reshape(121,145,121)), axis=0)

In [15]:
a[0][0].shape

(128, 128, 128, 1)

In [4]:
def gen(phase_train=True, n_classes = 2 ,params={'z_size':500, 'strides':(2,2,2), 'kernel_size':(4,4,4)}):
    """
    Returns a Generator Model with input params and phase_train
    Args:
        phase_train (boolean): training phase or not
        params (dict): Dictionary with model parameters
    Returns:
        model (keras.Model): Keras Generator model
    """

    z_size = params['z_size']
    strides = params['strides']
    kernel_size = params['kernel_size']
    
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 50)(in_label)
    li = Dense(z_size)(li)
    li = Reshape((1,1,1,z_size))(li)
    
    
    inputs = Input(shape=(1, 1, 1, z_size))
    
    
    merge_inputs = Concatenate()([inputs, li])
    
    g1 = Deconv3D(filters=128, kernel_size=kernel_size,
                  strides=(1, 1, 1), kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='valid')(merge_inputs)
    g1 = BatchNormalization()(g1, training=phase_train)
    g1 = Activation(activation='relu')(g1)

    g2 = Deconv3D(filters=64, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g1)
    g2 = BatchNormalization()(g2, training=phase_train)
    g2 = Activation(activation='relu')(g2)

    g3 = Deconv3D(filters=32, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g2)
    g3 = BatchNormalization()(g3, training=phase_train)
    g3 = Activation(activation='relu')(g3)

    g4 = Deconv3D(filters=16, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g3)
    g4 = BatchNormalization()(g4, training=phase_train)
    g4 = Activation(activation='relu')(g4)

    g5 = Deconv3D(filters=1, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(g4)
    g5 = BatchNormalization()(g5, training=phase_train)
    g5 = Activation(activation='sigmoid')(g5)

    model = Model(inputs=[inputs, in_label], outputs=g5)
#     model.summary()

    return model

In [22]:
model = gen()
model.summary()
# plt.show()
plot_model(model, to_file = "generator_plot.png", show_shapes = True, show_layer_names = True)
# 

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_4 (Embedding)         (None, 1, 50)        100         input_7[0][0]                    
__________________________________________________________________________________________________
dense_4 (Dense)                 (None, 1, 500)       25500       embedding_4[0][0]                
__________________________________________________________________________________________________
input_8 (InputLayer)            (None, 1, 1, 1, 500) 0                                            
__________________________________________________________________________________________________
reshape_4 

In [5]:
def dis(phase_train = True, n_classes = 2, params={'cube_len':64, 'strides':(2,2,2), 'kernel_size':(4,4,4), 'leak_value':0.2}):
    """
    Returns a Discriminator Model with input params and phase_train 
    Args:
        phase_train (boolean): training phase or not
        params (dict): Dictionary with model parameters    
    Returns:
        model (keras.Model): Keras Discriminator model
    """
    cube_len = params['cube_len']
    strides = params['strides']
    kernel_size = params['kernel_size'] 
    leak_value = params['leak_value']
    
    in_label = Input(shape=(1,))
    li = Embedding(n_classes, 5)(in_label)
    n_nodes = cube_len * cube_len * cube_len
    li = Dense(n_nodes)(li)
    li = Reshape((cube_len, cube_len, cube_len, 1))(li)
    
    inputs = Input(shape=(cube_len, cube_len, cube_len, 1))
    
    merge_inputs = Concatenate()([inputs, li])
    
    d1 = Conv3D(filters=16, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(merge_inputs)
    d1 = BatchNormalization()(d1, training=phase_train)
    d1 = LeakyReLU(leak_value)(d1)

    d2 = Conv3D(filters=32, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(d1)
    d2 = BatchNormalization()(d2, training=phase_train)
    d2 = LeakyReLU(leak_value)(d2)

    d3 = Conv3D(filters=64, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(d2)
    d3 = BatchNormalization()(d3, training=phase_train)
    d3 = LeakyReLU(leak_value)(d3)

    d4 = Conv3D(filters=128, kernel_size=kernel_size,
                  strides=strides, kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='same')(d3)
    d4 = BatchNormalization()(d4, training=phase_train)
    d4 = LeakyReLU(leak_value)(d4)

    d5 = Conv3D(filters=1, kernel_size=(1,1,1),
                  strides=(1, 1, 1), kernel_initializer='glorot_normal',
                  bias_initializer='zeros', padding='valid')(d4)
    d5 = BatchNormalization()(d5, training=phase_train)
    d5 = Activation(activation='sigmoid')(d5) 
    
    d6 = Flatten()(d5)
    d6 = Dropout(0.4)(d6)
    d6 = Dense(1, activation='sigmoid')(d6)
    model = Model(inputs=[inputs, in_label], outputs=d6)
#     model.summary()

    return model

In [15]:
model = dis()
model.summary()
plot_model(model, to_file = "discriminator_plot.png", show_shapes = True, show_layer_names = True)

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_11 (InputLayer)           (None, 1)            0                                            
__________________________________________________________________________________________________
embedding_6 (Embedding)         (None, 1, 5)         10          input_11[0][0]                   
__________________________________________________________________________________________________
dense_9 (Dense)                 (None, 1, 262144)    1572864     embedding_6[0][0]                
__________________________________________________________________________________________________
input_12 (InputLayer)           (None, 64, 64, 64, 1 0                                            
__________________________________________________________________________________________________
reshape_6 

In [6]:
def GAN(generator, discriminator):
#     model = Sequential()
#     model.add(generator)
#     discriminator.trainable = False
#     model.add(discriminator)

    discriminator.trainable = False
    g_image, g_label = generator.input
    g_output = generator.output
    gan_output = discriminator([g_output, g_label])
    model = Model([g_image, g_label], gan_output)
    model.summary()
    return model

In [7]:
def train():
    n_epochs   = 10000
    batch_size = 1
    g_lr       = 0.008
    d_lr       = 0.000001
    beta       = 0.5
    z_size     = 500
    obj_ratio  = 0.5
    discriminator = dis()
    generator = gen()
    try:
        discriminator = multi_gpu_model(discriminator, gpus = 4)
        generator = multi_gpu_model(generator, gpus = 4)
    except: 
        pass
    discriminator_on_generator = GAN(generator, discriminator)
    try:
        discriminator_on_generator = multi_gpu_model(discriminator_on_generator, gpus = 4)
#         generator = multi_gpu_model(generator)
    except: 
        pass
    
    g_optim = Adam(lr=g_lr, beta_1=beta)
    generator.compile(loss='binary_crossentropy', optimizer="SGD")

    d_optim = Adam(lr=d_lr, beta_1=0.9)
    discriminator_on_generator.compile(loss='binary_crossentropy', optimizer=g_optim)
    discriminator.trainable = True
    discriminator.compile(loss='binary_crossentropy', optimizer=d_optim)
  
     
    z_sample = np.random.normal(0, 0.33, size=[batch_size, 1, 1, 1, z_size]).astype(np.float32)
#     volumes = d.getAll(obj=obj, train=True, is_local=is_local, obj_ratio=obj_ratio)
    X = dataset[:,:4]
    y = dataset[:,4:]
    dataloader = DataGenerator(X, y, batch_size = batch_size, target_size = 128)
    print('Data loaded .......')
#     volumes = volu/mes[...,np.newaxis].astype(np.float) 

#     if not os.path.exists(train_sample_directory):
#         os.makedirs(train_sample_directory)
#     if not os.path.exists(model_directory):
#         os.makedirs(model_directory)         

    for epoch in trange(1):
        for batch in trange(dataloader.__numbatches__()):
            batchdata = dataloader.__getitem__(batch)
#             print(batchdata[0].shape)
#             idx = np.random.randint(len(volumes), size=batch_size)
#             x = volumes[idx]
            z = np.random.normal(0, 0.33, size=[batch_size, 1, 1, 1, z_size]).astype(np.float32)
            scans = batchdata[0]
            labels = batchdata[1]

#             generated_volumes = generator.predict(z, verbose=0)

#             X = np.concatenate((x, generated_volumes))
#             Y = np.reshape(labels + [0]*batch_size, (batch_size,))
            
            d_loss = discriminator.train_on_batch([scans, labels], np.ones((batch_size, 1)))       
            print("d_loss : %f" % (d_loss))

#             z = np.random.normal(0, 0.33, size=[batch_size, 1, 1, 1, z_size]).astype(np.float32)            
#             discriminator.trainable = False
#             g_loss = discriminator_on_generator.train_on_batch(z, np.reshape([1]*batch_size, (-1,1,1,1,1)))
#             discriminator.trainable = True

#             print("g_loss : %f" % (g_loss))

#             if epoch % 1000 == 10:
#                 generator.save_weights(model_directory +'generator_' + str(epoch), True)
#                 discriminator.save_weights(model_directory +'discriminator_' + str(epoch), True)

#             if epoch % 500 == 10:
#                 generated_volumes = generator.predict(z_sample, verbose=0)
#                 generated_volumes.dump(train_sample_directory+'/'+str(epoch))  
# train()        

In [9]:
strategy = tf.contrib.distribute.MirroredStrategy()

In [None]:
# with strategy.scope():
#     train()
train()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            (None, 1, 1, 1, 500) 0                                            
__________________________________________________________________________________________________
input_3 (InputLayer)            (None, 1)            0                                            
__________________________________________________________________________________________________
lambda_9 (Lambda)               (None, 1, 1, 1, 500) 0           input_4[0][0]                    
__________________________________________________________________________________________________
lambda_10 (Lambda)              (None, 1)            0           input_3[0][0]                    
__________________________________________________________________________________________________
lambda_11 

  0%|          | 0/1 [00:00<?, ?it/s]
  0%|          | 0/1 [00:00<?, ?it/s][A

MirroredVariable({'/replica:0/task:0/device:GPU:0': <tf.Variable 'Variable:0' shape=() dtype=int32>, '/replica:0/task:0/device:GPU:1': <tf.Variable 'Variable/replica_1:0' shape=() dtype=int32>, '/replica:0/task:0/device:GPU:2': <tf.Variable 'Variable/replica_2:0' shape=() dtype=int32>, '/replica:0/task:0/device:GPU:3': <tf.Variable 'Variable/replica_3:0' shape=() dtype=int32>})

In [None]:
dataloader = DataGenerator(X, y)

In [None]:
x, y = dataloader.__getitem__(0)

In [None]:
x.shape