## LuNet

As described in Hermans, Beyer & Leibe, "In Defense of the Triplet Loss for Person Re-Identification" (2017): https://arxiv.org/pdf/1703.07737.pdf

In [5]:
from tensorflow.keras.layers import Input, Dense, BatchNormalization, GlobalAveragePooling2D, Flatten
from tensorflow.keras.layers import ReLU, Conv2D, MaxPool2D, LeakyReLU, Add, Conv3D
from tensorflow.keras import Model

def ResBlock(tensor, n1, n2, n3):
    """Bottleneck residual block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    Every convolution uses stride 1, 'same' padding and He-normal
    initialisation, and is followed by BatchNormalization and
    LeakyReLU(alpha=0.3). The last LeakyReLU is applied to the residual
    branch *before* the addition; the sum is returned without a further
    activation.

    Args:
        tensor: Input Keras tensor. Its channel count is read from the last
            axis, so this assumes channels_last layout — TODO confirm if the
            block is reused with channels_first data.
        n1: Filters of the first 1x1 convolution.
        n2: Filters of the 3x3 convolution.
        n3: Filters of the final 1x1 convolution (the block's output channels).

    Returns:
        Keras tensor with ``n3`` channels and the same spatial size as
        ``tensor`` (all convolutions are stride 1 with 'same' padding).
    """
    x = tensor
    for filters, kernel in ((n1, (1, 1)), (n2, (3, 3)), (n3, (1, 1))):
        x = Conv2D(filters=filters, kernel_size=kernel, strides=1, padding='same',
                   kernel_initializer="he_normal")(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(alpha=0.3)(x)

    # An identity shortcut is only valid when the input already has n3
    # channels. Compare against the input's actual channel count, not n1:
    # the two need not match. Comparing with n1 made Add() fail whenever the
    # input had a different width, and added a needless projection when the
    # input already had n3 channels.
    if tensor.shape[-1] != n3:
        shortcut = Conv2D(filters=n3, kernel_size=(1, 1), strides=1, padding='same',
                          kernel_initializer='he_normal')(tensor)
    else:
        shortcut = tensor
    return Add()([shortcut, x])
    
def ResBlock2(tensor, n1, n2):
    """Residual block of two 3x3 conv -> BatchNorm -> LeakyReLU(0.3) stages.

    The first convolution has ``n1`` filters and the second ``n2``. Both use
    stride 1, 'same' padding and He-normal initialisation. The shortcut is
    *always* a 1x1 convolution projecting ``tensor`` to ``n2`` channels,
    even if the input already has ``n2`` channels. The projected input is
    added to the branch output, and the sum is returned without a further
    activation.

    Args:
        tensor: Input Keras tensor.
        n1: Filters of the first 3x3 convolution.
        n2: Filters of the second 3x3 convolution and of the projection.

    Returns:
        Keras tensor with ``n2`` channels and the same spatial size as ``tensor``.
    """
    branch = tensor
    for width in (n1, n2):
        branch = Conv2D(filters=width, kernel_size=(3, 3), strides=1, padding='same',
                        kernel_initializer='he_normal')(branch)
        branch = BatchNormalization()(branch)
        branch = LeakyReLU(alpha=0.3)(branch)

    projection = Conv2D(filters=n2, kernel_size=(1, 1), strides=1, padding='same',
                        kernel_initializer='he_normal')(tensor)
    return Add()([projection, branch])

    
    
# Residual body of the network. Each inner tuple is one stage: the
# (n1, n2, n3) filter triples fed to ResBlock in order. Every stage is
# closed by a 3x3 / stride-2 max-pool with 'same' padding.
LUNET_STAGES = (
    ((128, 32, 128),),
    ((128, 32, 128), (128, 32, 128), (128, 64, 256)),
    ((256, 64, 256), (256, 64, 256)),
    ((256, 64, 256), (256, 64, 256), (256, 128, 512)),
    ((512, 128, 512), (512, 128, 512)),
)

input_layer = Input(shape=(128, 64, 3))
# Stem: 7x7 convolution with default ('valid') padding and no activation,
# which shrinks the 128x64 input to 122x58.
x = Conv2D(filters=128, kernel_size=(7,7))(input_layer)

for stage in LUNET_STAGES:
    for filters in stage:
        x = ResBlock(x, *filters)
    x = MaxPool2D(pool_size=(3,3), strides=(2,2), padding='same')(x)

# Head: one final residual block, then Flatten -> Dense(512) -> BN -> ReLU
# -> Dense(128) with no activation on the output layer.
x = ResBlock2(x, 512, 128)
x = Flatten()(x)
x = Dense(512)(x)
x = BatchNormalization()(x)
x = ReLU()(x)
output = Dense(128)(x)

model = Model(inputs=input_layer, outputs=output)
model.summary()


Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_4 (InputLayer)            [(None, 128, 64, 3)] 0                                            
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 122, 58, 128) 18944       input_4[0][0]                    
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 122, 58, 128) 16512       conv2d_88[0][0]                  
__________________________________________________________________________________________________
batch_normalization_72 (BatchNo (None, 122, 58, 128) 512         conv2d_89[0][0]                  
____________________________________________________________________________________________