In [None]:
# importing necessary libraries

from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

In [None]:
# From : https://transresunet.github.io/main_architecture

def res_block(inputs,filter_size):
    """
    res_block -- two parallel convolutions over the same input, summed element-wise

    Arguments:
    inputs {tensor} -- feature map fed to both convolution branches
    filter_size {int} -- number of output filters used by each branch

    Returns:
    merged {tensor} -- sum of the 3x3 branch and the 1x1 branch
    """
    # The 3x3 branch is created first, then the 1x1 branch; both use 'same'
    # padding and ReLU so their outputs can be added together.
    branches = [
        Conv2D(filter_size, kernel, padding='same', activation='relu')(inputs)
        for kernel in ((3, 3), (1, 1))
    ]

    merged = Add()(branches)
    return merged

def res_path(inputs,filter_size,path_number):
    """
    res_path -- chain of residual blocks used as a modified skip connection

    Arguments:
    inputs {tensor} -- feature map entering the path
    filter_size {int} -- number of filters passed to every residual block
    path_number {int} -- path identifier; 1 -> three blocks, 2 -> two blocks,
                         any other value -> a single block

    Returns:
    skip_connection {tensor} -- output of the last residual block in the chain
    """
    # Every path has one block; path 2 appends one more, path 1 appends two more
    extra_blocks = 1 if path_number == 2 else (2 if path_number == 1 else 0)

    skip_connection = res_block(inputs, filter_size)
    for _ in range(extra_blocks):
        skip_connection = res_block(skip_connection, filter_size)

    return skip_connection

In [None]:
def decoder_block(inputs, res, out_channels, depth):
    """
    decoder_block -- upsample, merge with a skip tensor, then refine with convolutions

    Arguments:
    inputs {tensor} -- feature map from the previous (coarser) stage
    res {tensor} -- skip tensor concatenated with the upsampled input on axis 3
    out_channels {int} -- number of filters for every convolution in the block
    depth {int} -- stage depth; a value greater than 2 adds a third convolution

    Returns:
    features {tensor} -- output of the last convolution
    """
    # Two 3x3 convolutions by default, three for the deeper stages
    num_convs = 3 if depth > 2 else 2

    # 2x bilinear upsampling followed by channel-wise concatenation with the skip
    features = UpSampling2D((2, 2), interpolation='bilinear')(inputs)
    features = concatenate([features, res], axis=3)

    for _ in range(num_convs):
        features = Conv2D(
            out_channels,
            3,
            activation='relu',
            padding='same',
            kernel_initializer='he_normal',
            data_format='channels_last',
        )(features)

    return features

In [None]:
def TransResUNet(input_size=(512, 512, 1), use_res_paths=False):
    """
    TransResUNet -- main architecture of TransResUNet

    The encoder reuses the first three convolutional blocks of VGG16 loaded
    with ImageNet weights, followed by a 1024-filter center block. A small
    classification branch is attached to the center, and three decoder blocks
    rebuild a full-resolution map. The sigmoid segmentation map is multiplied
    by the sigmoid classification score before being returned.

    Arguments:
    input_size {tuple} -- (height, width, channels) of the input image,
                          channels-last. channels must be 1 or 3; a 1-channel
                          input is replicated to 3 channels for the VGG16 layers.
    use_res_paths {bool} -- when True, each skip connection goes through
                            res_path before reaching its decoder block; when
                            False (default) the decoder receives the raw
                            encoder features.

    Returns:
    model {Model} -- Keras model with outputs [seg, class]

    Raises:
    ValueError -- if input_size does not have three entries or its channel
                  count is neither 1 nor 3
    """
    # Validate up front so a bad shape fails with a clear message instead of a
    # shape error from inside VGG16 or its first convolution.
    input_size = tuple(input_size)
    if len(input_size) != 3:
        raise ValueError(
            "`input_size` must be (height, width, channels); got %r" % (input_size,)
        )
    channels = input_size[-1]
    if channels not in (1, 3):
        raise ValueError(
            "`input_size` must have 1 or 3 channels; got %r" % (input_size,)
        )

    # Input
    inputs = Input(input_size)

    # VGG16 with imagenet weights. The pretrained kernels expect 3 input
    # channels, so the encoder is always built for a 3-channel image.
    encoder = VGG16(
        include_top=False,
        weights='imagenet',
        input_shape=input_size[:2] + (3,),
    )

    # Replicate a single-channel image so it matches the pretrained first layer
    if channels == 1:
        encoder_input = concatenate([inputs, inputs, inputs], axis=3)
    else:
        encoder_input = inputs

    # First encoder block
    enc1 = encoder.get_layer(name='block1_conv1')(encoder_input)
    enc1 = encoder.get_layer(name='block1_conv2')(enc1)
    enc2 = MaxPooling2D(pool_size=(2, 2))(enc1)

    # Second encoder block
    enc2 = encoder.get_layer(name='block2_conv1')(enc2)
    enc2 = encoder.get_layer(name='block2_conv2')(enc2)
    enc3 = MaxPooling2D(pool_size=(2, 2))(enc2)

    # Third encoder block
    enc3 = encoder.get_layer(name='block3_conv1')(enc3)
    enc3 = encoder.get_layer(name='block3_conv2')(enc3)
    enc3 = encoder.get_layer(name='block3_conv3')(enc3)
    center = MaxPooling2D(pool_size=(2, 2))(enc3)

    # Center block
    center = Conv2D(1024, (3, 3), activation='relu', padding='same')(center)
    center = Conv2D(1024, (3, 3), activation='relu', padding='same')(center)

    # Classification branch: one score per image, also reshaped so it can
    # scale the segmentation map by broadcasting
    cls = Conv2D(32, (3,3), activation='relu', padding='same')(center)
    cls = Conv2D(1, (1,1))(cls)
    cls = GlobalAveragePooling2D()(cls)
    cls = Activation('sigmoid', name='class')(cls)
    clsr = Reshape((1, 1, 1), name='reshape')(cls)

    # Decoder block corresponding to third encoder
    skip3 = res_path(enc3, 128, 3) if use_res_paths else enc3
    dec3 = decoder_block(center, skip3, 256, 3)

    # Decoder block corresponding to second encoder
    skip2 = res_path(enc2, 64, 2) if use_res_paths else enc2
    dec2 = decoder_block(dec3, skip2, 128, 2)

    # Final block concatenation with first encoded feature
    skip1 = res_path(enc1, 32, 1) if use_res_paths else enc1
    dec1 = decoder_block(dec2, skip1, 64, 1)

    # Output: per-pixel probability gated by the image-level score
    out = Conv2D(1, 1)(dec1)
    out = Activation('sigmoid')(out)
    out = multiply(inputs=[out,clsr], name='seg')

    # Final model
    model = Model(inputs=[inputs], outputs=[out, cls])

    return model