# Test U-Net

## U-Net architecture

In [1]:
import importlib
import numpy as np
import tensorflow as tf

from custom_architectures import unet_arch, totalsegmentator_arch

def show_memory(msg = ''):
    print('{}{}'.format(msg if not msg else msg + '\t: ', {
        k : '{:.3f}'.format(v / 1024 ** 3) for k, v in tf.config.experimental.get_memory_info('GPU:0').items()
    }))
    tf.config.experimental.reset_memory_stats('GPU:0')
    
gpus = tf.config.list_physical_devices('GPU')
tf.config.set_visible_devices([gpus[0]], 'GPU')

2023-04-13 10:34:13.647641: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-13 10:34:13.743638: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-13 10:34:13.767841: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until 

## Test U-Net 2D

In [2]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (512, 512, 1), output_dim  = 1,
    drop_rate = lambda i: 0. if i == 4 else 0.25
)
print(model.count_params())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam')

show_memory()

2023-03-23 15:35:39.337311: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-23 15:35:39.707179: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


8611425
{'current': '0.032', 'peak': '0.035'}


In [3]:
batch_size = 24

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, ) + tuple(model.input_shape[1:]), 0., 1.)
out = tf.ones_like(inp)

"""show_memory('Batch init')

model(inp[:batch_size], training = True)

show_memory('Simple call')

with tf.GradientTape() as tape:
    pred = model(inp[:batch_size], training = True)
    l = model.compiled_loss(out[:batch_size], pred)

show_memory('With gradient')

del tape, l, pred

show_memory('After deleting tape')
"""
_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.032', 'peak': '0.035'}
Epoch 1/5


2023-03-23 15:35:41.686235: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204
2023-03-23 15:35:49.280779: W tensorflow/core/common_runtime/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.08GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-03-23 15:35:49.280808: W tensorflow/core/common_runtime/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 5.08GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2023-03-23 15:35:49.683526: W tensorflow/core/common_runtime/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 1.52GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains

Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.219', 'peak': '13.966'}


In [4]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 512, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 down_conv1 (Conv2D)            (None, 512, 512, 32  320         ['input_image[0][0]']            
                                )                                                                 
                                                                                                  
 activation (Activation)        (None, 512, 512, 32  0           ['down_conv1[0][0]']             
                                )                                                             

## Test AM-UNet 2D

In [6]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (512, 512, 1),
    output_dim  = 1,
    
    n_stages    = 5,
    n_conv_per_stage = lambda i: 1 if i <= 1 else 2,
    filters     = [16, 32, 64, 128, 128],
    bnorm       = 'after',
    activation  = 'relu',
    drop_rate   = lambda i: 0. if i == 0 else 0.25,
    
    n_middle_stages = 4,
    n_middle_conv   = 2,
    middle_filters  = 64,
    middle_bnorm    = 'after'
)
print(model.count_params())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam')

show_memory()

2254065
{'current': '0.370', 'peak': '0.371'}


In [8]:
batch_size = 42

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, ) + tuple(model.input_shape[1:]), 0., 1.)
out = tf.ones_like(inp)

"""show_memory('Batch init')

model(inp[:batch_size], training = True)

show_memory('Simple call')

with tf.GradientTape() as tape:
    pred = model(inp[:batch_size], training = True)
    l = model.compiled_loss(out[:batch_size], pred)

show_memory('With gradient')

del tape, l, pred

show_memory('After deleting tape')
"""
_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.298', 'peak': '0.298'}
Epoch 1/5


2023-03-23 16:02:17.752060: W tensorflow/core/kernels/gpu_utils.cc:50] Failed to allocate memory for convolution redzone checking; skipping this check. This is benign and only means that we won't check cudnn for out-of-bounds reads and writes. This message will only be printed once.


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.484', 'peak': '13.884'}


In [5]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 512, 512, 1  0           []                               
                                )]                                                                
                                                                                                  
 down_conv1 (Conv2D)            (None, 512, 512, 16  160         ['input_image[0][0]']            
                                )                                                                 
                                                                                                  
 down_bn1 (BatchNormalization)  (None, 512, 512, 16  64          ['down_conv1[0][0]']             
                                )                                                             

## Test AM-UNet 3D

In [2]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (None, 512, 512, 1),
    output_dim  = 1,
    
    n_stages    = 4,
    n_conv_per_stage = lambda i: 1 if i <= 2 else 2,
    filters     = [16, 32, 64, 128, 128],
    bnorm       = 'never',
    activation  = 'relu',
    drop_rate   = lambda i: 0. if i == 0 else 0.25,
    
    n_middle_stages = 2,
    n_middle_conv   = 2,
    middle_filters  = 64,
    middle_bnorm    = 'never',
    
    mixed_precision = True
)
print(model.count_params())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam')

show_memory()

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0
Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


2023-03-24 10:26:09.319656: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-24 10:26:09.691052: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


3416209
{'current': '0.013', 'peak': '0.016'}


In [5]:
batch_size = 1
seq_len    = 64

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, seq_len) + tuple(model.input_shape[2:]), 0., 1.)
out = tf.ones_like(inp)

_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.518', 'peak': '7.675'}
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.768', 'peak': '12.750'}


In [3]:
batch_size = 1
for i in range(2, 10):
    seq_len    = 32 * i
    print(seq_len)
    show_memory('Before')

    inp = tf.random.uniform((batch_size, seq_len) + tuple(model.input_shape[2:]), 0., 1.)
    _ = model(inp)

    show_memory('After fit')
    tf.config.experimental.reset_memory_stats('GPU:0')

64
Before	: {'current': '0.013', 'peak': '0.016'}


2023-03-24 10:26:16.412280: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


After fit	: {'current': '0.200', 'peak': '3.888'}
96
Before	: {'current': '0.200', 'peak': '0.200'}
After fit	: {'current': '0.294', 'peak': '5.935'}
128
Before	: {'current': '0.294', 'peak': '0.294'}
After fit	: {'current': '0.450', 'peak': '7.920'}
160
Before	: {'current': '0.450', 'peak': '0.450'}
After fit	: {'current': '0.575', 'peak': '9.967'}
192
Before	: {'current': '0.575', 'peak': '0.575'}
After fit	: {'current': '0.638', 'peak': '11.983'}
224
Before	: {'current': '0.638', 'peak': '0.638'}


2023-03-24 10:26:36.306704: W tensorflow/core/common_runtime/bfc_allocator.cc:290] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.25GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.


After fit	: {'current': '0.669', 'peak': '10.369'}
256
Before	: {'current': '0.669', 'peak': '0.669'}


KeyboardInterrupt: 

In [4]:
print(256 * 3 / 10)

76.8


In [None]:
model.summary()

## Test U-Net 3D

In [2]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (128, 512, 512, 1), output_dim  = 1, n_stages    = 6, n_conv_per_stage = [1, 1, 1, 2, 2, 2],
    drop_rate = lambda i: 0. if i == 5 else 0.25,
    concat_mode = 'add',
    pool_strides = [2, 2, 2, 2, 2],
    filters = [8, 16, 32, 64, 128, 256]
)
print(model.count_params())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam')

show_memory()

2023-03-21 10:47:32.945584: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-21 10:47:33.323590: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


5823497
{'current': '0.022', 'peak': '0.026'}


In [3]:
batch_size = 1

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, ) + tuple(model.input_shape[1:]), 0., 1.)
out = tf.ones_like(inp)

_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.022', 'peak': '0.026'}
Epoch 1/5


2023-03-21 10:47:35.758813: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.905', 'peak': '13.136'}


In [4]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 128, 512, 5  0           []                               
                                12, 1)]                                                           
                                                                                                  
 down_conv1 (Conv3D)            (None, 128, 512, 51  224         ['input_image[0][0]']            
                                2, 8)                                                             
                                                                                                  
 activation (Activation)        (None, 128, 512, 51  0           ['down_conv1[0][0]']             
                                2, 8)                                                         

## Test UNet 3D with strides

In [2]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (128, 512, 512, 1), output_dim  = 1, n_stages = 6, n_conv_per_stage = 2,
    drop_rate = lambda i: 0. if i == 5 else 0.25,
    concat_mode = 'add',
    pool_type = None,
    strides = 2,
    filters = [8, 16, 32, 64, 128, 256]
)
print(model.count_params())
model.compile(loss = 'binary_crossentropy', optimizer = 'adam')

show_memory()

2023-03-21 11:18:44.599095: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-21 11:18:44.979307: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


0 (None, 128, 512, 512, 8)
1 (None, 64, 256, 256, 16)
2 (None, 32, 128, 128, 32)
3 (None, 16, 64, 64, 64)
4 (None, 8, 32, 32, 128)
5 (None, 4, 16, 16, 256)
4 (None, 8, 32, 32, 128) (None, 8, 32, 32, 128) 128
3 (None, 16, 64, 64, 64) (None, 16, 64, 64, 64) 64
2 (None, 32, 128, 128, 32) (None, 32, 128, 128, 32) 32
1 (None, 64, 256, 256, 16) (None, 64, 256, 256, 16) 16
0 (None, 128, 512, 512, 8) (None, 128, 512, 512, 8) 8
5896185
{'current': '0.022', 'peak': '0.026'}


In [3]:
batch_size = 1

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, ) + tuple(model.input_shape[1:]), 0., 1.)
out = tf.ones_like(inp)

_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.022', 'peak': '0.026'}
Epoch 1/5


2023-03-21 11:18:53.017038: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.910', 'peak': '13.031'}


In [13]:
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 128, 512, 5  0           []                               
                                12, 1)]                                                           
                                                                                                  
 down_conv1 (Conv3D)            (None, 128, 512, 51  224         ['input_image[0][0]']            
                                2, 8)                                                             
                                                                                                  
 activation_123 (Activation)    (None, 128, 512, 51  0           ['down_conv1[0][0]']             
                                2, 8)                                                       

## Test classifier AM-UNet 3D

In [2]:
importlib.reload(unet_arch)

model = unet_arch.UNet(
    input_shape = (512, 512, None, 1),
    output_dim  = 104,
    final_activation = 'softmax',
    
    n_stages    = 4,
    n_conv_per_stage = lambda i: 1,
    up_n_conv_per_stage = lambda i: min(i, 1),
    filters     = list(np.array([16, 32, 64, 128, 128])),
    bnorm       = 'never',
    activation  = 'relu',
    drop_rate   = lambda i: 0. if i == 0 else 0.25,
    
    n_middle_stages = 2,
    n_middle_conv   = 2,
    middle_filters  = 64,
    middle_bnorm    = 'never',
    
    concat_mode = lambda i: 'concat' if i > 0 else None,
    
    mixed_precision = True
)
print(model.count_params())
model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'adam')

show_memory()

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0
Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


2023-03-30 15:01:57.966552: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-03-30 15:01:58.353265: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 535 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


2519128
{'current': '0.009', 'peak': '0.014'}


In [3]:
batch_size = 1
seq_len    = 48

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, seq_len) + tuple(model.input_shape[2:]), 0., 1.)
out = tf.random.uniform(inp.shape[:-1], 0, model.output_shape[-1], dtype = tf.int32)

print('Input shape : {} - Output shape : {}'.format(inp.shape, out.shape))

_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.009', 'peak': '0.014'}
Input shape : (2, 48, 512, 512, 1) - Output shape : (2, 48, 512, 512)
Epoch 1/5


2023-03-29 15:38:52.065510: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8204


Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.216', 'peak': '8.909'}


In [None]:
batch_size = 1
for i in range(1, 3):
    seq_len    = 16 * i
    show_memory('Before')

    inp = tf.random.uniform((batch_size, seq_len) + tuple(model.input_shape[2:]), 0., 1.)
    print('Input shape  : {}'.format(inp.shape))
    out = model(inp)
    print('Output shape : {}'.format(inp.shape))
    del out
    
    show_memory('After call')
    tf.config.experimental.reset_memory_stats('GPU:0')

In [3]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 512, 512, N  0           []                               
                                one, 1)]                                                          
                                                                                                  
 down_conv1 (Conv3D)            (None, 512, 512, No  448         ['input_image[0][0]']            
                                ne, 16)                                                           
                                                                                                  
 activation (Activation)        (None, 512, 512, No  0           ['down_conv1[0][0]']             
                                ne, 16)                                                       

## Test TotalSegmentator

In [2]:
importlib.reload(totalsegmentator_arch)

model = totalsegmentator_arch.TotalSegmentator(
    input_shape = (None, None, None, 1),
    output_dim  = 105,
    pretrained  = None,
    manual_padding = True,
    
    #n_conv_per_stage = lambda i: 1 if i <= 2 else 2,
    #up_n_conv_per_stage = 1,
    #filters = [16, 32, 64, 128, 128, 128],
    drop_rate   = lambda i: 0. if i == 0 else 0.25,

)
print(model.count_params())
model.compile(loss = 'sparse_categorical_crossentropy', optimizer = 'adam')

show_memory()

2023-04-13 10:34:23.532705: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-04-13 10:34:23.916325: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14783 MB memory:  -> device: 0, name: Quadro RTX 5000, pci bus id: 0000:17:00.0, compute capability: 7.5


30479744
{'current': '0.119', 'peak': '0.130'}


In [4]:
batch_size = 1
seq_len    = 48
img_size   = (256, 256, 1)

show_memory('Before')

inp = tf.random.uniform((batch_size * 2, seq_len) + img_size, 0., 1.)
out = tf.random.uniform(inp.shape[:-1], 0, model.output_shape[-1], dtype = tf.int32)

print('Input shape : {} - Output shape : {}'.format(inp.shape, out.shape))

_ = model.fit(inp, out, epochs = 5, batch_size = batch_size)

show_memory('After fit')
tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.402', 'peak': '0.402'}
Input shape : (2, 48, 256, 256, 1) - Output shape : (2, 48, 256, 256)
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
After fit	: {'current': '0.449', 'peak': '12.968'}


In [9]:
batch_size = 1
img_size   = (512, 512, 1)
for i in range(1, 4):
    seq_len    = 16 * i
    show_memory('Before')

    inp = tf.random.uniform((batch_size, seq_len) + img_size, 0., 1.)
    print('Input shape  : {}'.format(inp.shape))
    out = model(inp)
    print('Output shape : {}'.format(out.shape))
    del out
    
    show_memory('After call')
    tf.config.experimental.reset_memory_stats('GPU:0')

Before	: {'current': '0.212', 'peak': '0.212'}
Input shape  : (1, 16, 512, 512, 1)
Output shape : (1, 16, 512, 512, 105)
After call	: {'current': '0.181', 'peak': '7.992'}
Before	: {'current': '0.181', 'peak': '0.181'}
Input shape  : (1, 32, 512, 512, 1)
Output shape : (1, 32, 512, 512, 105)
After call	: {'current': '0.196', 'peak': '8.671'}
Before	: {'current': '0.196', 'peak': '0.196'}
Input shape  : (1, 48, 512, 512, 1)
Output shape : (1, 48, 512, 512, 105)
After call	: {'current': '0.212', 'peak': '12.455'}


In [4]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, None, None,  0           []                               
                                 None, 1)]                                                        
                                                                                                  
 zero_padding3d (ZeroPadding3D)  (None, None, None,   0          ['input_image[0][0]']            
                                None, 1)                                                          
                                                                                                  
 conv_blocks_context/0/blocks/0  (None, None, None,   896        ['zero_padding3d[0][0]']         
 /conv (Conv3D)                 None, 32)                                                     