## 'Tiny' Resnet 3D Model

## 1. Libraries

In [1]:
#########################################################################
# 01. Libraries

import tensorflow as tf
from tensorflow.keras import layers, models, optimizers, regularizers, constraints

#########################################################################

## 2. Model

In [2]:
#########################################################################
# 02. Model Functions


def convBlock(x, filters, pool, drop_rate, name):
    """Apply one Conv3D -> BatchNorm -> LeakyReLU stage, with optional dropout and pooling.

    No skip connection is created inside this block; it only stacks layers
    on top of ``x`` and returns the result.

    Args:
        x: Input tensor the layers are applied to.
        filters: Number of filters of the 3x3x3 convolution.
        pool: If truthy, a 2x2x2 max-pooling layer is appended.
        drop_rate: Dropout rate; the dropout layer is only added when > 0.
        name: Prefix for the names of all layers created here
            (``<name>_Conv3D``, ``<name>_Batch_Norm``, ...).

    Returns:
        The output tensor of the last layer that was added.
    """
    prefix = name + '_'

    # 'same' padding keeps the spatial size; the kernel is L2-regularised and
    # max-norm constrained (norm taken over kernel axes 0-3).
    x = layers.Conv3D(filters, kernel_size=(3, 3, 3),
                      kernel_regularizer=regularizers.l2(1e-4),
                      kernel_constraint=constraints.MaxNorm(1, axis=[0, 1, 2, 3]),
                      padding='same', name=prefix + 'Conv3D')(x)
    x = layers.BatchNormalization(name=prefix + 'Batch_Norm')(x)
    # NOTE: the 'Leaky_Ralu' spelling is kept on purpose - layer names are
    # stored in saved models, so renaming would break name-based loading.
    x = layers.LeakyReLU(0.1, name=prefix + 'Leaky_Ralu')(x)
    if drop_rate > 0:
        x = layers.Dropout(drop_rate, name=prefix + 'Dropout')(x)
    if pool:
        x = layers.MaxPool3D((2, 2, 2), name=prefix + 'Max_Pool')(x)

    return x


def resnet3D(input_shape=(32, 200, 200, 1)):
    """Assemble the 'tiny' 3D network as a Keras functional model.

    Five pooled conv stages are chained. After each stage the conv output is
    concatenated with an average-pooled copy of the raw input (pooled once more
    per stage), and the result feeds the next stage. A sixth, un-pooled conv
    stage followed by global max pooling produces the model output.

    Args:
        input_shape: Shape passed to ``layers.Input`` (batch axis excluded).

    Returns:
        A ``models.Model`` named ``'Resnet3D'``.
    """
    x_in = layers.Input(shape=input_shape)

    # (filters, drop_rate) of the five pooled stages, in order.
    stage_specs = ((16, 0.2), (32, 0.3), (64, 0.3), (72, 0.4), (128, 0.4))

    features = x_in   # input to the next conv stage
    shortcut = x_in   # progressively average-pooled copy of the input
    for index, (n_filters, rate) in enumerate(stage_specs, start=1):
        tag = 'Block%d' % index
        conv_out = convBlock(features, filters=n_filters, pool=True,
                             drop_rate=rate, name=tag)
        shortcut = layers.AvgPool3D((2, 2, 2), name=tag + '_Skip_Avg_Pool')(shortcut)
        features = layers.concatenate([conv_out, shortcut], name=tag + '_Skip_Concat')

    head = convBlock(features, filters=256, pool=False, drop_rate=0., name='Block6')
    pooled = layers.GlobalMaxPool3D(name='Block5_Global_Max_Pool')(head)

    return models.Model(inputs=x_in, outputs=pooled, name='Resnet3D')
#########################################################################

In [3]:
#########################################################################

# Build the network with unspecified depth/height/width and a single channel.
model = resnet3D(input_shape=(None, None, None, 1))
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the output.
model.summary(line_length=130)

# Persist architecture + weights to HDF5, without optimizer state.
tf.keras.models.save_model(model=model, filepath='resnet3D.h5', include_optimizer=False)

#########################################################################

Model: "Resnet3D"
__________________________________________________________________________________________________________________________________
Layer (type)                              Output Shape                 Param #         Connected to                               
input_1 (InputLayer)                      [(None, None, None, None, 1) 0                                                          
__________________________________________________________________________________________________________________________________
Block1_Conv3D (Conv3D)                    (None, None, None, None, 16) 448             input_1[0][0]                              
__________________________________________________________________________________________________________________________________
Block1_Batch_Norm (BatchNormalization)    (None, None, None, None, 16) 64              Block1_Conv3D[0][0]                        
_________________________________________________________________