### PocketNet Performance Profiling
Use this notebook to profile a Pocket U-Net and a full-sized U-Net on a subset of the BraTS tumor segmentation data.

In [None]:
import os
import numpy as np
import pandas as pd
from datetime import datetime
from packaging import version

##### Tensorflow #####
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, metrics
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K

# Set this environment variable to only use the first GPU 
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# For tensorflow 2.x.x allow memory growth on GPU
###################################
gpus = tf.config.list_physical_devices('GPU')

# Loop over the detected devices instead of indexing gpus[0], which raises
# IndexError when TensorFlow does not see any GPU
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

if not gpus:
    print('No GPU detected by TensorFlow - memory growth not configured')
###################################

# Use this to allow growth on TensorFlow v1.x.x
# ###################################
# config = tf.ConfigProto()
 
# # Don't pre-allocate memory; allocate as-needed
# config.gpu_options.allow_growth = True
 
# # Only allow a specified percent of the GPU memory to be allocated
# config.gpu_options.per_process_gpu_memory_fraction = 0.75
 
# # Create a session with the above options specified.
# K.tensorflow_backend.set_session(tf.Session(config = config))
# ##################################

# Print the name and total memory of each GPU in CSV format.
# The leading '!' hands the line to the system shell (IPython syntax).
print('======== GPU Information ========')
!nvidia-smi --query-gpu=gpu_name,memory.total --format=csv

In [None]:
# Uncomment to install tensorboard profile plugin
#!pip install -U tensorboard_plugin_profile

In [None]:
# L2 Dice loss
def DiceLossL2(y_true, y_pred):
    '''
    L2 Dice loss for channels-last tensors
    
    Inputs:
    y_true   : Ground truth of shape (batch size, <spatial axes>, channels),
               e.g. (batch size, depth, height, width, channels) for 3D or
               (batch size, height, width, channels) for 2D
    y_pred   : Prediction with the same shape as y_true
    
    Outputs:
    loss     : One value per batch element - the squared error divided by the
               sum of squared magnitudes, reduced over the spatial axes and
               then averaged over the channel axis
    
    Raises:
    ValueError if y_true has fewer than 3 axes (no spatial axis to reduce)
    '''
    smooth = 0.0000001
    
    rank = len(y_true.shape)
    if rank < 3:
        raise ValueError('DiceLossL2 expects inputs shaped (batch size, ..., channels) '
                         'with at least one spatial axis, got rank ' + str(rank))
    
    # Every axis between the batch axis (first) and the channel axis (last):
    # (1, 2) for 2D data and (1, 2, 3) for 3D data
    spatialAxes = tuple(range(1, rank - 1))
    
    num = K.sum(K.square(y_true - y_pred), axis = spatialAxes)
    
    # smooth keeps the ratio finite when both tensors are all zeros
    den = K.sum(K.square(y_true), axis = spatialAxes) + K.sum(K.square(y_pred), axis = spatialAxes) + smooth
        
    return K.mean(num/den, axis = -1)

In [None]:
def PocketNet(inputShape, 
              numClasses, 
              mode, 
              net, 
              pocket, 
              initFilters, 
              depth):
    
    '''
    PocketNet - Smaller CNN for medical image segmentation
    
    Inputs:
    inputShape   : Size of network input - (depth, height, width, channels) for 3D
                   (height, width, channels) for 2D
    numClasses   : Number of output classes
    mode         : 'seg' or 'class' for segmentation or classification network
    net          : 'unet', 'resnet', or 'densenet' for U-Net, ResNet or DenseNet blocks
    pocket       : True/False for pocket architectures
    initFilters  : Number of starting filters at input level
    depth        : Number of max-pooling layers
    
    Outputs:
    model        : Keras model for training/predicting
    
    Raises:
    ValueError   : If inputShape does not have 3 or 4 entries, or if mode or net
                   is not one of the values listed above
    
    Author: Adrian Celaya
    Last modified: 4.20.2021
    '''
    
    # Validate the arguments before building any layers so that a bad value
    # fails with a clear message instead of an unbound variable further down
    # 3D inputs are (depth, height, width, channels)
    if len(inputShape) == 4:
        dim = '3d'
    # 2D inputs are (height, width, channels)
    elif len(inputShape) == 3:
        dim = '2d'
    else:
        raise ValueError('inputShape must have 3 entries (2D) or 4 entries (3D), got ' 
                         + str(len(inputShape)))
        
    if mode not in ('seg', 'class'):
        raise ValueError("mode must be 'seg' or 'class', got " + repr(mode))
        
    if net not in ('unet', 'resnet', 'densenet'):
        raise ValueError("net must be 'unet', 'resnet', or 'densenet', got " + repr(net))
    
    # Pick the layer classes and kernel/pool sizes once, based on dimensionality
    if dim == '3d':
        Conv = layers.Conv3D
        ConvTranspose = layers.Conv3DTranspose
        MaxPool = layers.MaxPooling3D
        GlobalMaxPool = layers.GlobalMaxPooling3D
        convKernel = (3, 3, 3)
        pointKernel = (1, 1, 1)
        # The first (depth) axis is never downsampled for 3D inputs
        poolSize = (1, 2, 2)
    else:
        Conv = layers.Conv2D
        ConvTranspose = layers.Conv2DTranspose
        MaxPool = layers.MaxPooling2D
        GlobalMaxPool = layers.GlobalMaxPooling2D
        convKernel = (3, 3)
        pointKernel = (1, 1)
        poolSize = (2, 2)
        
    # Parameters for each convolution operation:
    # [0] main convolution, [1] pointwise convolution, [2] transposed convolution
    params = [dict(kernel_size = convKernel, activation = 'relu', padding = 'same'),
              dict(kernel_size = pointKernel, activation = 'relu', padding = 'same'),
              dict(kernel_size = poolSize, strides = poolSize, padding = 'same')]
    
    # Convolution block operator
    def Block(x, filters):
        ### DenseNet block ###
        if net == 'densenet':
            # Two convolutions, each concatenated onto its own input
            for _ in range(2):
                y = Conv(filters, **params[0])(x)
                x = layers.concatenate([x, y])
            
            # Pointwise convolution brings the channel count back to filters
            x = Conv(filters, **params[1])(x)
        
        ### ResNet block ###
        elif net == 'resnet':
            y = Conv(filters, **params[0])(x)
            y = Conv(filters, **params[0])(y)
            
            # The block input is joined to the convolved features by concatenation
            x = layers.concatenate([x, y])
            x = Conv(filters, **params[1])(x)
        
        ### U-Net block ###
        else:
            x = Conv(filters, **params[0])(x)
            x = Conv(filters, **params[0])(x)
                
        return x

    # Downsampling block - Convolution + maxpooling
    def TransitionDown(x, filters):
        skip = Block(x, filters)
        x = MaxPool(pool_size = poolSize, strides = poolSize)(skip)
        return skip, x

    # Upsampling block - Transposed convolution + concatenation + convolution
    def TransitionUp(x, skip, filters):
        x = ConvTranspose(filters, **params[2])(x)
        x = layers.concatenate([x, skip])
        return Block(x, filters)
        
    # Keep filters constant for PocketNet, double them at every level otherwise
    if pocket:
        filters = [initFilters] * (depth + 1)
    else:
        filters = [initFilters * 2 ** i for i in range(depth + 1)]
    
    # Input to network
    inputs = layers.Input(inputShape)
 
    # Encoder path
    x = inputs
    skips = list()
    for i in range(depth):
        skip, x = TransitionDown(x, filters[i])
        skips.append(skip)
        
    # Bottleneck
    x = Block(x, filters[-1])

    # Apply global max-pooling to output of bottleneck if classification
    if mode == 'class':
        # Pooling layer matches the input dimensionality (2D or 3D)
        x = GlobalMaxPool()(x)
        output = layers.Dense(numClasses, activation = 'softmax')(x)
    
    # Continue with decoder path if segmentation
    else:
        # Walk back up the encoder levels, deepest skip connection first
        for i in range(depth - 1, -1, -1):
            x = TransitionUp(x, skips[i], filters[i])
            
        output = Conv(numClasses, pointKernel, activation = 'softmax')(x)
            
    model = Model(inputs = [inputs], outputs = [output])
    return model

In [None]:
class DataGenerator(keras.utils.Sequence):
    '''
    Keras Sequence that serves batches of image/mask arrays loaded with np.load
    
    Inputs:
    dataframe    : pandas DataFrame with 'image' and 'mask' columns holding the
                   file path of each saved array
    batch_size   : Number of rows loaded per batch
    dim          : Leading (non-channel) dimensions of each array
    n_channels   : Size of the last axis of each image
    n_classes    : Size of the last axis of each mask
    shuffle      : If True, reshuffle the rows on creation and after every epoch
    
    Rows left over after the last full batch are dropped for that epoch.
    '''
    def __init__(self, dataframe, batch_size = 4, dim = (5, 240, 240), n_channels = 4, n_classes = 2, shuffle = True):
        # Let the Sequence base class set up its own state
        super().__init__()
        self.dim = dim
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        # Number of full batches per epoch
        return int(np.floor(len(self.dataframe) / self.batch_size))

    def __getitem__(self, index):
        # index counts batches, not rows
        X, y = self.__data_generation(index)
        return X, y

    def on_epoch_end(self):
        # Reshuffle the rows so batches differ from epoch to epoch
        if self.shuffle:
            self.dataframe = self.dataframe.sample(frac = 1).reset_index(drop = True)
        
    def __data_generation(self, index):
        # Load the images and masks for batch number index
        X = np.empty((self.batch_size, *self.dim, self.n_channels))
        y = np.empty((self.batch_size, *self.dim, self.n_classes))

        # Convert the batch number into the first row of that batch. Using the
        # batch number itself as a row offset made consecutive batches overlap
        # and left most of the dataframe unused.
        start = index * self.batch_size
        for offset in range(self.batch_size):
            row = self.dataframe.iloc[start + offset]
            X[offset] = np.load(row['image'])
            y[offset] = np.load(row['mask'])
        return X, y

In [None]:
def RunModel(batchSize, pocket, dataPath = 'profile_data.csv', epochs = 5):
    '''
    Train a 3D U-Net for a few epochs with the TensorBoard profiler attached
    
    Inputs:
    batchSize   : Batch size used by the data generator
    pocket      : True for the pocket U-Net (constant filters), False for the full U-Net
    dataPath    : CSV file that is read into the dataframe given to DataGenerator
    epochs      : Number of training epochs
    
    Outputs:
    None - TensorBoard logs are written under logs/<architecture>_batchsize_<batchSize>
    '''
    
    # This function is called once per configuration. Reset the global Keras
    # state first so earlier models do not accumulate in memory and distort
    # the memory profile of the current run.
    K.clear_session()
    
    # Load dataframe with images and masks paths
    train = pd.read_csv(dataPath)

    # Create logs directory based on architecture and batch size
    if pocket:
        print('Running pocket u-net with batch size ' + str(int(batchSize)))
        logs = 'logs/' + 'pocket_unet_batchsize_' + str(batchSize)
    else:
        print('Running full u-net with batch size ' + str(int(batchSize)))
        logs = 'logs/' + 'full_unet_batchsize_' + str(batchSize)
            
    # Create data generator
    trainGenerator = DataGenerator(dataframe = train, batch_size = batchSize)

    # Create model
    model = PocketNet(inputShape = (5, 240, 240, 4), 
                      numClasses = 2, 
                      mode = 'seg', 
                      net = 'unet', 
                      pocket = pocket, 
                      initFilters = 16, 
                      depth = 4)
    
    # Compile model with Dice loss
    model.compile(optimizer = 'adam', loss = [DiceLossL2])

    # Tensorboard callback set up - profile batches 1 through 10
    tboard = tf.keras.callbacks.TensorBoard(log_dir = logs,
                                            histogram_freq = 1,
                                            profile_batch = (1, 10))
    # Train model
    # The generator's length is the number of full batches, len(train) // batchSize
    model.fit(trainGenerator, 
              epochs = epochs, 
              steps_per_epoch = len(trainGenerator),
              callbacks = [tboard])
    
    ##### END OF FUNCTION#####

In [None]:
# Profile every architecture / batch size combination:
# all pocket U-Net runs first, then all full U-Net runs
batch_sizes = [1, 4, 8, 16, 32]
for pocket in [True, False]:
    for batch in batch_sizes:
        RunModel(batchSize = batch, pocket = pocket)

Use the TensorBoard Profiler to collect training memory usage and average time per step for each model and batch size.

In [None]:
# Load the TensorBoard notebook extension.
%load_ext tensorboard

# Launch TensorBoard on the 'logs' directory, then open the Profile tab
# to view the performance profile
%tensorboard --logdir=logs