In [1]:
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split
# import wandb

##### Tensorflow #####
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import metrics
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ReduceLROnPlateau, Callback
from wandb.keras import WandbCallback
import tensorflow.keras.backend as K

os.environ['HDF5_USE_FILE_LOCKING'] = 'FALSE'

# Set this environment variable to only use the first GPU on Maeda
# Change to '1' to use second GPU
# Comment out to use both GPUs
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# For tensorflow 2.4.1 allow memory growth on GPU
gpus = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpus[0], True)

# Login to Weights and Biases for metric tracking
# wandb.login()

In [2]:
# L2 Dice loss
def dsc_l2(y_true, y_pred, smooth=0.0000001):
    if len(y_true.shape) == 5:
        num = K.sum(K.square(y_true - y_pred), axis=(1,2,3))
        den = K.sum(K.square(y_true), axis=(1,2,3)) + K.sum(K.square(y_pred), axis=(1,2,3)) + smooth
    elif len(y_true.shape) == 4:
        num = K.sum(K.square(y_true - y_pred), axis=(1,2))
        den = K.sum(K.square(y_true), axis=(1,2)) + K.sum(K.square(y_pred), axis=(1,2)) + smooth
    return K.mean(num/den, axis=-1)

In [3]:
def PocketNet(inputShape, 
              numClasses, 
              mode, 
              net, 
              pocket, 
              initFilters, 
              depth):
    
    '''
    PocketNet - Smaller CNN for medical image segmentation
    
    Inputs:
    inputShape   : Size of network input - (depth, height, width, channels) for 3D
                   (height, width, channels) for 2D
    numClasses   : Number of output classes
    mode         : 'seg' or 'class' for segmenation or classification network
    net          : 'unet', 'resnet', or 'densenet' for U-Net, ResNet or DenseNet blocks
    pocket       : True/False for pocket architectures
    initFilters  : Number of starting filters at input level
    depth        : Number of max-pooling layers
    
    Outputs:
    model        : Keras model for training/predicting
    
    Author: Adrian Celaya
    Last modified: 4.20.2021
    '''
    
    # 3D inputs are (depth, height, width, channels)
    if len(inputShape) == 4:
        dim = '3d'
    # 2D inputs are (height, width, channels)
    elif len(inputShape) == 3:
        dim = '2d'
    
    # Convolution block operator
    def Block(x, filters, params, net, dim):
        # DenseNet block
        if net == 'densenet':
            for _ in range(2):
                if dim == '3d':
                    y = layers.Conv3D(filters, **params[0])(x)
                elif dim == '2d':
                    y = layers.Conv2D(filters, **params[0])(x)
                x = layers.concatenate([x, y])
            x = layers.Conv3D(filters, **params[1])(x)
        
        # ResNet block
        elif net == 'resnet':
            if dim == '3d':
                y = layers.Conv3D(filters, **params[0])(x)
                y = layers.Conv3D(filters, **params[0])(y)
            elif dim == '2d':
                y = layers.Conv2D(filters, **params[0])(x)
                y = layers.Conv2D(filters, **params[0])(y)
                
            x = layers.concatenate([x, y])
            
            if dim == '3d':
                x = layers.Conv3D(filters, **params[1])(x)
            elif dim == '2d':
                x = layers.Conv2D(filters, **params[1])(x)
        
        # U-Net block
        elif net == 'unet':
            if dim == '3d':
                x = layers.Conv3D(filters, **params[0])(x)
                x = layers.Conv3D(filters, **params[0])(x)
            elif dim == '2d':
                x = layers.Conv2D(filters, **params[0])(x)
                x = layers.Conv2D(filters, **params[0])(x)
                
        return x

    # Downsampling block - Convolution + maxpooling
    def TransitionDown(x, filters, params, net, dim):
        skip = Block(x, filters, params, net, dim)
        
        if dim == '3d':
            x = layers.MaxPooling3D(pool_size = (1, 2, 2), strides = (1, 2, 2))(skip)
        elif dim == '2d':
            x = layers.MaxPooling2D(pool_size = (2, 2), strides = (2, 2))(skip)
            
        return skip, x

    # Upsampling block - Transposed convolution + concatenation + convolution
    def TransitionUp(x, skip, filters, params, net, dim):
        
        if dim == '3d':
            x = layers.Conv3DTranspose(filters, **params[2])(x)
        elif dim == '2d':
            x = layers.Conv2DTranspose(filters, **params[2])(x)
            
        x = layers.concatenate([x, skip])
        x = Block(x, filters, params, net, dim)
        return x
    
    # Parameters for each convolution operation
    params = list()
    if dim == '3d':
        params.append(dict(kernel_size = (3, 3, 3), activation = 'relu', padding = 'same'))
        params.append(dict(kernel_size = (1, 1, 1), activation = 'relu', padding = 'same'))
        params.append(dict(kernel_size = (1, 2, 2), strides = (1, 2, 2), padding = 'same'))
    elif dim == '2d':
        params.append(dict(kernel_size = (3, 3), activation = 'relu', padding = 'same'))
        params.append(dict(kernel_size = (1, 1), activation = 'relu', padding = 'same'))
        params.append(dict(kernel_size = (2, 2), strides = (2, 2), padding = 'same'))

        
    # Keep filters constant for PocketNet
    if pocket:
        filters = [initFilters for i in range(depth + 1)]
    else:
        filters = [initFilters * 2 ** (i) for i in range(depth + 1)]
    
    # Input to network
    inputs = layers.Input(inputShape)
 
    # Encoder path
    x = inputs
    skips = list()
    for i in range(depth):
        skip, x = TransitionDown(x, filters[i], params, net, dim)
        skips.append(skip)
        
    # Bottleneck
    x = Block(x, filters[-1], params, net, dim)

    # Apply global max-pooling to output of bottleneck if classification
    if mode == 'class':
        x = layers.GlobalMaxPooling2D()(x)
        output = layers.Dense(numClasses, activation = 'softmax')(x)

    
    # Continue with decoder path if segmentation
    elif mode == 'seg':
        
        for i in range(depth - 1, -1, -1):
            x = TransitionUp(x, skips[i], filters[i], params, net, dim)
            
        if dim == '3d':
            output = layers.Conv3D(numClasses, (1, 1, 1), activation = 'softmax')(x)
        elif dim == '2d':
            output = layers.Conv2D(numClasses, (1, 1), activation = 'softmax')(x)
            
    model = Model(inputs = [inputs], outputs = [output])
    return model 

In [4]:
class DataGenerator(keras.utils.Sequence):
    def __init__(self, dataframe, batch_size = 1, dim = (5, 240, 240), n_channels = 4, n_classes = 2, shuffle = True):
        self.dim = dim
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        return int(np.floor(len(self.dataframe) / self.batch_size))

    def __getitem__(self, index):
        X, y = self.__data_generation(index)
        return X, y

    def on_epoch_end(self):
        if self.shuffle:
            self.dataframe = self.dataframe.sample(frac = 1).reset_index(drop = True)
        
    def __data_generation(self, index):
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, self.n_classes))

        for i in range(index, index + self.batch_size):
            X[i - index] = np.load(self.dataframe.iloc[i]['image'])
            y[i - index] = np.load(self.dataframe.iloc[i]['mask'])
        return X, y

In [5]:
def RunModel(batchSize, pocket):
    
#     experimentName = wandb.util.generate_id()
#     wandb.init(project = 'brats-pocketnet', 
#                group = experimentName)
    
    # Load main dataframe with images and targets
    train = pd.read_csv('brats_slices_paths.csv')
    pats = np.unique(train['id'])

    # Fix a test set and scale up the size of each training set
    trainPats, valPats, _, _ = train_test_split(pats, pats, test_size = 0.20, random_state = 0)
    
    val = train.loc[train['id'].isin(valPats)]
    val = val.reset_index(drop = True)
    numVal = len(val) # Need number of validation patients for keras fit_generator function
    
    train = train.loc[train['id'].isin(trainPats)]
    train = train.reset_index(drop = True)
    numTrain = len(train)

    if pocket:
        print('Running pocket U-Net')
    else:
        print('Running full U-Net')
            
    # Create training and validation data generators
    trainGenerator = DataGenerator(train, batchSize)
    validationGenerator = DataGenerator(val, batchSize)

    # Create model
    model = PocketNet(inputShape = (5, 240, 240, 4), 
                      numClasses = 2, 
                      mode = 'seg', 
                      net = 'unet', 
                      pocket = pocket, 
                      initFilters = 16, 
                      depth = 4)
    
    # Compile model with Dice loss
    model.compile(optimizer = 'adam', loss = [dsc_l2])

    # Reduce learning rate by 0.5 if validation dice coefficient does not improve after 5 epochs
    reduceLr = ReduceLROnPlateau(monitor = 'val_loss', 
                                 mode = 'min',
                                 factor = 0.5, 
                                 patience = 5, 
                                 min_lr = 0.000001, 
                                 verbose = 1)

    # Train model
    model.fit(trainGenerator, 
              epochs = 15, 
              steps_per_epoch = (numTrain // (10 * batchSize)), 
              validation_data = validationGenerator, 
              validation_steps = (numVal // (10 * batchSize)), 
              callbacks = [reduceLr],#, WandbCallback()], 
              verbose = 1, 
              use_multiprocessing = True, 
              workers = 8)
    
    #wandb.finish()
        
    ##### END OF FUNCTION #####

In [6]:
RunModel(batchSize = 16, pocket = True)

Running pocket U-Net
Epoch 1/15

Process Keras_worker_ForkPoolWorker-8:
Process Keras_worker_ForkPoolWorker-6:
Process Keras_worker_ForkPoolWorker-1:
Traceback (most recent call last):
Process Keras_worker_ForkPoolWorker-5:
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

Process Keras_worker_ForkPoolWorker-2:
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Process Keras_worker_ForkPoolWorker-4:


Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-6-65f4acec007c>", line 1, in <module>
    RunModel(batchSize = 16, pocket = True)
  File "<ipython-input-5-4c0edc69efc9>", line 52, in RunModel
    model.fit(trainGenerator,
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/wandb/integration/keras/keras.py", line 124, in new_v2
    return old_v2(*args, **kwargs)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 1100, in fit
    tmp_logs = self.train_function(iterator)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 828, in __call__
    result = self._call(*args, **kwds)
  File "/opt/apps/miniconda

Traceback (most recent call last):
Process Keras_worker_ForkPoolWorker-7:


  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/IPython/core/ultratb.py", line 248, in wrapped
    return f(*args, **kwargs)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/IPython/core/ultratb.py", line 281, in _fixed_getinnerframes
    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 1515, in getinnerframes
    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 1473, in getframeinfo
    filename = getsourcefile(frame) or getfile(frame)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 708, in getsourcefile
    if getattr(getmodule(object, filename), '__loader__', None) is not None:
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 754, in getmodul

Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/pool.py", line 131, in worker
    put((job, i, result))
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Process Keras_worker_ForkPoolWorker-3:
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/queues.py", line 367, in put
    with self._wlock:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._

TypeError: object of type 'NoneType' has no len()

Process Keras_worker_ForkPoolWorker-11:
Process Keras_worker_ForkPoolWorker-13:
Process Keras_worker_ForkPoolWorker-17:
Process Keras_worker_ForkPoolWorker-16:
Process Keras_worker_ForkPoolWorker-14:
Process Keras_worker_ForkPoolWorker-15:
Process Keras_worker_ForkPoolWorker-12:
Process Keras_worker_ForkPoolWorker-10:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    sel

  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/IPython/core/ultratb.py", line 248, in wrapped
    return f(*args, **kwargs)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/site-packages/IPython/core/ultratb.py", line 281, in _fixed_getinnerframes
    records = fix_frame_records_filenames(inspect.getinnerframes(etb, context))
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 1515, in getinnerframes
    frameinfo = (tb.tb_frame,) + getframeinfo(tb, context)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 1473, in getframeinfo
    filename = getsourcefile(frame) or getfile(frame)
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 708, in getsourcefile
    if getattr(getmodule(object, filename), '__loader__', None) is not None:
  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/inspect.py", line 754, in getmodul

  File "/opt/apps/miniconda/miniconda3/envs/tf_latest/lib/python3.8/multiprocessing/connection.py", line 379, in _recv
    chunk = read(handle, remaining)
KeyboardInterrupt
