In [52]:
from keras.layers import InputLayer, Conv2D, Dropout, Input, Lambda, Concatenate, Conv2DTranspose, MaxPooling2D
from keras.models import Model
from typing import Tuple 
from models.layers import ConvBlock, DeconvBlock
from keras import Sequential

import tensorflow as tf 

IMG_SIZE = (128, 128, 3)
        
class SkipAutoEncoder(Model):
    """Encoder/decoder image-to-image model with U-Net style skip connections.

    Five stride-2 ``ConvBlock``s (conv1..conv5) downsample the input. Five
    stride-2 ``DeconvBlock``s (deconv5..deconv1) upsample it back, and every
    decoder stage after the first concatenates (channel axis) the encoder
    feature map of matching resolution. The ``out`` head then applies a
    stride-1 ``DeconvBlock`` and a 1x1 ``tanh`` ``Conv2D`` to produce a
    3-channel image.

    NOTE(review): an earlier version skipped ``deconv2`` and the ``x1`` skip
    and compensated with a stride-2 output block. Every decoder stage is now
    used and the output block has stride 1. This assumes the third positional
    argument of ``DeconvBlock`` is the stride and that it uses 'same'-style
    padding, so stride 1 keeps the spatial size. TODO confirm against
    models.layers. Spatial dims presumably need to be divisible by 2**5.

    Args:
        img_size: input shape without the batch axis, e.g. ``(128, 128, 3)``;
            used to build the model eagerly at construction time.
        **kwargs: forwarded to ``keras.Model``.
    """

    def __init__(self, img_size: Tuple[int, int, int], **kwargs):
        super().__init__(**kwargs)
        # encoder: each block halves the spatial resolution
        self.conv1 = ConvBlock(64, 4, 2, bias=False, batch_norm=True, name='conv1')
        self.conv2 = ConvBlock(128, 4, 2, bias=False, batch_norm=True, name='conv2')
        self.conv3 = ConvBlock(256, 4, 2, bias=False, batch_norm=True, name='conv3')
        self.conv4 = ConvBlock(512, 4, 2, bias=False, batch_norm=True, name='conv4')
        self.conv5 = ConvBlock(512, 4, 2, bias=False, batch_norm=True, name='conv5')

        # decoder: each block doubles the spatial resolution; all but deconv5
        # first concatenate [decoder features, matching encoder skip]
        self.deconv5 = DeconvBlock(512, 4, 2, bias=False, name='deconv5')
        self.deconv4 = Sequential([Concatenate(-1), DeconvBlock(512, 4, 2, bias=False)], name='deconv4')
        self.deconv3 = Sequential([Concatenate(-1), DeconvBlock(512, 4, 2, bias=False)], name='deconv3')
        self.deconv2 = Sequential([Concatenate(-1), DeconvBlock(256, 4, 2, bias=False)], name='deconv2')
        self.deconv1 = Sequential([Concatenate(-1), DeconvBlock(128, 4, 2, bias=False)], name='deconv1')
        # output head: stride 1 because deconv1 already restores the input
        # resolution (five downsamples are matched by five upsamples)
        self.out = Sequential([
            DeconvBlock(512, 5, 1, bias=False),
            Conv2D(3, 1, activation='tanh')
        ], name='output')
        self.build((None, *img_size))

    def call(self, real: tf.Tensor):
        """Run the encoder, then the skip-connected decoder, then the tanh head."""
        x1 = self.conv1(real)   # 1/2 resolution
        x2 = self.conv2(x1)     # 1/4
        x3 = self.conv3(x2)     # 1/8
        x4 = self.conv4(x3)     # 1/16
        x5 = self.conv5(x4)     # 1/32 (bottleneck)
        y4 = self.deconv5(x5)            # 1/16
        y3 = self.deconv4([y4, x4])      # 1/8
        y2 = self.deconv3([y3, x3])      # 1/4
        y1 = self.deconv2([y2, x2])      # 1/2
        y0 = self.deconv1([y1, x1])      # full resolution
        return self.out(y0)
        


TensorShape([10, 128, 128, 3])

In [None]:
# TODO(review): unfinished draft cell (never executed: `In [None]`).
# As originally written it was a SyntaxError, which aborts Restart & Run All:
# `Tuple[int,] **kwargs` is missing a comma and the method has no body.
# If it is completed under the same name, it will also silently shadow the
# working SkipAutoEncoder defined in the cell above. Give the new variant its
# own name before re-enabling it.
# class SkipAutoEncoder(Model):
#     def __init__(self, img_size: Tuple[int, int, int], **kwargs):
#         super().__init__(**kwargs)
#         ...
        

In [19]:
tf.keras.backend.clear_session()

# Build a U-Net++ (nested U-Net) model. Node X^{i,j} sits at depth i
# (spatial size 128 / 2**i) and is the j-th conv block on that level. Each
# decoder node concatenates the upsampled node from the level below with
# every earlier node on its own level (dense skip connections).
nb_filter = [32, 64, 128, 256, 512]  # filters per depth level i = 0..4
DROPOUT_RATE = 0.5                   # applied after every 3x3 convolution


def double_conv(x, filters):
    """Apply two 3x3 'same'-padded ELU convolutions, each followed by dropout."""
    for _ in range(2):
        x = Conv2D(filters, (3, 3), activation='elu', kernel_initializer='he_normal', padding='same')(x)
        x = Dropout(DROPOUT_RATE)(x)
    return x


def upsample(x, filters, name):
    """Double the spatial size of `x` with a learned 2x2, stride-2 transposed convolution."""
    return Conv2DTranspose(filters, (2, 2), strides=(2, 2), name=name, padding='same')(x)


def downsample(x, name=None):
    """Halve the spatial size of `x` with 2x2 max pooling."""
    return MaxPooling2D((2, 2), strides=(2, 2), name=name)(x)


inputs = Input((128, 128, 3))
s = Lambda(lambda x: x / 255)(inputs)  # assumes 0-255 pixel values; scales to [0, 1]

# Layer creation order below matches the original cell, so Keras'
# auto-generated layer names (conv2d_N, dropout_N, ...) are unchanged.
c1 = double_conv(s, nb_filter[0])                # X^{0,0}
p1 = downsample(c1)
c2 = double_conv(p1, nb_filter[1])               # X^{1,0}
p2 = downsample(c2)

up1_2 = upsample(c2, nb_filter[0], 'up12')
conv1_2 = Concatenate(-1, name='merge12')([up1_2, c1])
c3 = double_conv(conv1_2, nb_filter[0])          # X^{0,1}

conv3_1 = double_conv(p2, nb_filter[2])          # X^{2,0}
pool3 = downsample(conv3_1, 'pool3')

up2_2 = upsample(conv3_1, nb_filter[1], 'up22')
merge2_2 = Concatenate(-1, name='merge22')([up2_2, c2])
conv2_2 = double_conv(merge2_2, nb_filter[1])    # X^{1,1}

up1_3 = upsample(conv2_2, nb_filter[0], 'up13')
merge1_3 = Concatenate(-1, name='merge13')([up1_3, c1, c3])
conv1_3 = double_conv(merge1_3, nb_filter[0])    # X^{0,2}

conv4_1 = double_conv(pool3, nb_filter[3])       # X^{3,0}
pool4 = downsample(conv4_1, 'pool4')

up3_2 = upsample(conv4_1, nb_filter[2], 'up32')
merge3_2 = Concatenate(-1, name='merge32')([up3_2, conv3_1])
conv3_2 = double_conv(merge3_2, nb_filter[2])    # X^{2,1}

up2_3 = upsample(conv3_2, nb_filter[1], 'up23')
merge2_3 = Concatenate(-1, name='merge23')([up2_3, c2, conv2_2])
conv2_3 = double_conv(merge2_3, nb_filter[1])    # X^{1,2}

up1_4 = upsample(conv2_3, nb_filter[0], 'up14')
merge1_4 = Concatenate(-1, name='merge14')([up1_4, c1, c3, conv1_3])
conv1_4 = double_conv(merge1_4, nb_filter[0])    # X^{0,3}

conv5_1 = double_conv(pool4, nb_filter[4])       # X^{4,0} (bottleneck)

up4_2 = upsample(conv5_1, nb_filter[3], 'up42')
merge4_2 = Concatenate(-1, name='merge42')([up4_2, conv4_1])
conv4_2 = double_conv(merge4_2, nb_filter[3])    # X^{3,1}

up3_3 = upsample(conv4_2, nb_filter[2], 'up33')
merge3_3 = Concatenate(-1, name='merge33')([up3_3, conv3_1, conv3_2])
conv3_3 = double_conv(merge3_3, nb_filter[2])    # X^{2,2}

up2_4 = upsample(conv3_3, nb_filter[1], 'up24')
merge2_4 = Concatenate(-1, name='merge24')([up2_4, c2, conv2_2, conv2_3])
conv2_4 = double_conv(merge2_4, nb_filter[1])    # X^{1,3}

up1_5 = upsample(conv2_4, nb_filter[0], 'up15')
merge1_5 = Concatenate(-1, name='merge15')([up1_5, c1, c3, conv1_3, conv1_4])
conv1_5 = double_conv(merge1_5, nb_filter[0])    # X^{0,4}

# 1x1 sigmoid head: a single-channel per-pixel map at full resolution.
nestnet_output_4 = Conv2D(1, (1, 1), activation='sigmoid', kernel_initializer='he_normal',
                          name='output_4', padding='same')(conv1_5)

model = Model([inputs], [nestnet_output_4])
# TODO: compile once `bce_dice_loss` is available in the notebook, e.g.
# model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), loss=bce_dice_loss)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_1 (InputLayer)        [(None, 128, 128, 3)]        0         []                            
                                                                                                  
 lambda (Lambda)             (None, 128, 128, 3)          0         ['input_1[0][0]']             
                                                                                                  
 conv2d (Conv2D)             (None, 128, 128, 32)         896       ['lambda[0][0]']              
                                                                                                  
 dropout (Dropout)           (None, 128, 128, 32)         0         ['conv2d[0][0]']              
                                                                                              