In [1]:
DATADIR = '/archive/bioinformatics/DLLab/KevinNguyen/data/BraTS2020'

In [2]:
# Limit usage to GPU 1
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('XLA_GPU')
print(gpus)

[PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU')]


In [3]:
import sys
sys.path.append('../../')
from medl import unet
import numpy as np
import tensorflow.keras.backend as K

In [4]:
# Dice coefficient metric
def dice(yTrue, yPred):
    yTrueMask = tf.cast(yTrue > 0.5, dtype=tf.float32)
    yPredMask = tf.cast(yPred > 0.5, dtype=tf.float32)
    intersection = tf.reduce_sum(tf.multiply(yTrueMask, yPredMask), axis=(1, 2))
    union = tf.reduce_sum(yTrueMask, axis=(1, 2)) + tf.reduce_sum(yPredMask, axis=(1, 2))
    dice = tf.reduce_mean((2.0 * intersection + 1e-6) / (union + 1e-6))
    return dice

In [5]:
# Image preprocessing function, normalizes values to [0, 1]
def preprocess(arrImage):
    arrOut = arrImage.copy()
    arrOut /= arrOut.max()
    return arrOut

In [6]:
# Load cross-validation data
iFold = 0
strDataPath = os.path.join(DATADIR, f'fold{iFold}.npz')
dictData = np.load(strDataPath)
tupImgShape = dictData['t1ce_train'].shape[1:]
print('Images are size', tupImgShape)

# Create data generators
train_data = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=preprocess)
val_data = tf.keras.preprocessing.image.ImageDataGenerator(preprocessing_function=preprocess)

Images are size (240, 240, 1)


In [7]:
tf.keras.backend.clear_session()
model = unet.unet(input_size=tupImgShape)
model.compile(optimizer=tf.keras.optimizers.Adam(lr = 1e-4), 
              loss='binary_crossentropy', 
              metrics = ['accuracy', dice])
lsCallbacks = [tf.keras.callbacks.EarlyStopping(monitor='val_dice', patience=2, restore_best_weights=True, mode='max')]
model.fit(train_data.flow(dictData['t1ce_train'], dictData['mask_train']),
          validation_data=val_data.flow(dictData['t1ce_val'], dictData['mask_val']),
          epochs=3,
          batch_size=32,
          callbacks=lsCallbacks,
          verbose=1)

Epoch 1/3
  3/314 [..............................] - ETA: 1:13:37 - loss: 0.8268 - accuracy: 0.4055 - dice: 0.0120

KeyboardInterrupt: 

In [8]:
model.summary()

Model: "functional_3"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 240, 240, 1) 0                                            
__________________________________________________________________________________________________
conv2d_24 (Conv2D)              (None, 240, 240, 64) 640         input_2[0][0]                    
__________________________________________________________________________________________________
conv2d_25 (Conv2D)              (None, 240, 240, 64) 36928       conv2d_24[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_4 (MaxPooling2D)  (None, 120, 120, 64) 0           conv2d_25[0][0]                  
_______________________________________________________________________________________