In [1]:
import numpy as np
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers

In [2]:
def _count_activation_elements(model):
    """Return the per-sample number of layer-output elements in ``model``.

    Every layer contributes the product of the known (non-``None``)
    dimensions of its ``output_shape``. Layers that are themselves models
    are descended into, so the outputs of their internal layers are counted
    in addition to the nested model's own output.
    """
    # keras.Model(inputs, outputs) instances report the class name
    # 'Functional' and keras.Sequential reports 'Sequential'; 'Model' is kept
    # so containers that matched the original check still match.
    nested_model_types = ('Model', 'Functional', 'Sequential')

    element_count = 0
    for layer in model.layers:
        if layer.__class__.__name__ in nested_model_types:
            element_count += _count_activation_elements(layer)

        out_shape = layer.output_shape
        if isinstance(out_shape, list):
            # Only the first output of a list-valued output_shape is counted.
            out_shape = out_shape[0]

        layer_elements = 1
        for dim in out_shape:
            # Unknown dimensions (the batch axis) are skipped here; the batch
            # size is applied once by the caller.
            if dim is not None:
                layer_elements *= dim
        element_count += layer_elements
    return element_count


def get_model_memory_usage(batch_size, model):
    """Estimate the memory taken by a Keras model's layer outputs and weights.

    The estimate is ``bytes_per_number * (batch_size * activations + weights)``
    where ``activations`` is the per-sample element count of every layer
    output (including layers inside nested models) and ``weights`` is the
    number of trainable plus non-trainable parameters. Nothing else is
    included in the figure (no gradients, optimizer state or workspace).

    Args:
        batch_size: Number of samples per batch.
        model: Object exposing ``layers``, ``trainable_weights`` and
            ``non_trainable_weights``; each layer must expose ``output_shape``.

    Returns:
        The estimate divided by ``1024 ** 3`` (i.e. in GiB, reported by the
        notebook as "GB"), rounded to three decimals.
    """
    try:
        from keras import backend as K
    except ImportError:
        # Standalone Keras is not installed: use the copy bundled with
        # TensorFlow. Only ImportError is handled so that unrelated failures
        # (or Ctrl-C) are not silently swallowed.
        from tensorflow.keras import backend as K

    activation_elements = _count_activation_elements(model)

    # model.*_weights already aggregates the weights of nested models, so the
    # parameters are counted exactly once here rather than per sub-model.
    trainable_count = np.sum([K.count_params(p) for p in model.trainable_weights])
    non_trainable_count = np.sum([K.count_params(p) for p in model.non_trainable_weights])

    # Bytes per element for the backend's default float type (float32 -> 4).
    number_size = {'float16': 2.0, 'float64': 8.0}.get(K.floatx(), 4.0)

    total_memory = number_size * (
        batch_size * activation_elements + trainable_count + non_trainable_count
    )
    return np.round(total_memory / (1024.0 ** 3), 3)

In [3]:
# Edge length used for all three spatial dimensions (width, height, depth)
# of the model input built later in the notebook.
CUBE_SIZE = 512
# Size of the channel axis of the model input.
NUM_CHANNELS = 3

In [4]:
def get_model(width=128, height=128, depth=64, channel=1):
    """Construct the "3dcnn" 3D convolutional classifier.

    The network stacks four identical stages (Conv3D with a 3x3x3 kernel and
    ReLU, 2x max pooling, batch normalization) using 64, 64, 128 and 256
    filters, then applies global average pooling, a 512-unit ReLU dense
    layer, dropout of 0.3 and a 10-unit softmax output.

    Args:
        width, height, depth: Spatial dimensions of the input volume.
        channel: Size of the trailing channel axis of the input.

    Returns:
        An uncompiled ``keras.Model`` named "3dcnn".
    """
    volume = keras.Input((width, height, depth, channel))

    # Convolutional feature extractor: one stage per filter count.
    features = volume
    for n_filters in (64, 64, 128, 256):
        features = layers.Conv3D(filters=n_filters, kernel_size=3, activation="relu")(features)
        features = layers.MaxPool3D(pool_size=2)(features)
        features = layers.BatchNormalization()(features)

    # Classification head.
    features = layers.GlobalAveragePooling3D()(features)
    features = layers.Dense(units=512, activation="relu")(features)
    features = layers.Dropout(0.3)(features)
    class_probs = layers.Dense(units=10, activation="softmax")(features)

    return keras.Model(volume, class_probs, name="3dcnn")


# Build the model on a cubic input (CUBE_SIZE along width, height and depth,
# NUM_CHANNELS channels) and print its layer-by-layer summary.
model = get_model(width=CUBE_SIZE, height=CUBE_SIZE, depth=CUBE_SIZE, channel=NUM_CHANNELS)
model.summary()

Model: "3dcnn"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 512, 512, 512, 3  0         
                             )]                                  
                                                                 
 conv3d (Conv3D)             (None, 510, 510, 510, 64  5248      
                             )                                   
                                                                 
 max_pooling3d (MaxPooling3D  (None, 255, 255, 255, 64  0        
 )                           )                                   
                                                                 
 batch_normalization (BatchN  (None, 255, 255, 255, 64  256      
 ormalization)               )                                   
                                                                 
 conv3d_1 (Conv3D)           (None, 253, 253, 253, 64  110656

In [5]:
# Print the estimated memory footprint of `model` for several batch sizes.
for batch_size in (4, 8, 16, 32):
    mem_size = get_model_memory_usage(batch_size, model)
    print("Batch Size:{} , Memory Usage: {}GB".format(batch_size, mem_size))

Batch Size:4 , Memory Usage: 188.973GB
Batch Size:8 , Memory Usage: 377.94GB
Batch Size:16 , Memory Usage: 755.875GB
Batch Size:32 , Memory Usage: 1511.745GB
