In [2]:

import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

from math import log2
import tensorflow as tf
import tensorflow.keras.layers as L
from tensorflow.keras.models import Model

In [3]:
def mlp(inputs, config_dict):
    """Feed-forward sub-network used inside each transformer encoder block.

    Applies, in order: Dense(``mlp_dim``, GELU) -> Dropout -> Dense(``hidden_dim``,
    GELU) -> Dropout, so the last axis of the result has ``hidden_dim`` units.

    Args:
        inputs: Tensor to transform.
        config_dict: Mapping providing ``'mlp_dim'``, ``'hidden_dim'`` and
            ``'dropout'`` (the rate shared by both Dropout layers).

    Returns:
        The output tensor of the final Dropout layer.
    """
    # NOTE(review): the second Dense also applies GELU; the reference ViT MLP
    # uses a plain linear projection there - confirm this is intentional.
    drop_rate = config_dict['dropout']
    hidden = inputs
    for width in (config_dict['mlp_dim'], config_dict['hidden_dim']):
        hidden = L.Dense(width, activation = 'gelu')(hidden)
        hidden = L.Dropout(drop_rate)(hidden)

    return hidden

In [24]:
def conv(inputs, filters, kernel = 3):
    """Convolution block: Conv2D ('same' padding) -> BatchNormalization -> ReLU.

    Args:
        inputs: Tensor fed to the convolution.
        filters: Number of output channels of the convolution.
        kernel: Kernel size passed to ``Conv2D`` (default 3).

    Returns:
        The output tensor of the ReLU layer.
    """
    features = L.Conv2D(filters, kernel_size = kernel, padding = 'same')(inputs)
    normalized = L.BatchNormalization()(features)

    return L.ReLU()(normalized)

In [5]:
def deconv(inputs, filters, strides = 2):
    """Apply a single transposed convolution (2x2 kernel, 'same' padding).

    Args:
        inputs: Tensor to upsample.
        filters: Number of output channels.
        strides: Stride of the transposed convolution (default 2); with 'same'
            padding Keras scales each spatial dimension by this factor.

    Returns:
        The output tensor of the ``Conv2DTranspose`` layer (no normalisation or
        activation is applied here).
    """
    upsampler = L.Conv2DTranspose(
        filters, kernel_size = 2, padding = 'same', strides = strides
    )

    return upsampler(inputs)

In [18]:
# Transformer encoder block: layer-normalised self-attention and MLP, each with a residual connection
def trans_encoder(inputs, config_dict):
    """One transformer encoder block with two residual sub-blocks.

    Sub-block 1: LayerNormalization -> MultiHeadAttention (query and value are
    both the normalised input) -> add the block input back.
    Sub-block 2: LayerNormalization -> ``mlp`` -> add the sub-block input back.

    Args:
        inputs: Token tensor entering the block.
        config_dict: Mapping providing ``'heads'`` (number of attention heads)
            and ``'hidden_dim'`` (used as the per-head ``key_dim``), plus the
            keys read by ``mlp``.

    Returns:
        The output tensor of the second residual addition.
    """
    # --- Self-attention sub-block ---
    normed_tokens = L.LayerNormalization()(inputs)
    attention = L.MultiHeadAttention(
        num_heads = config_dict['heads'], key_dim = config_dict['hidden_dim']
    )
    attended = attention(normed_tokens, normed_tokens)
    after_attention = L.Add()([attended, inputs])

    # --- Feed-forward sub-block ---
    normed_attention = L.LayerNormalization()(after_attention)
    feed_forward = mlp(normed_attention, config_dict)

    return L.Add()([feed_forward, after_attention])
    

In [35]:
def Trans_Unet_2D(config_dict):
    """Build a UNETR-style 2D segmentation model (transformer encoder + CNN decoder).

    Flattened patches are linearly embedded, summed with a learned position
    embedding and passed through ``num_layers`` ``trans_encoder`` blocks. The
    outputs of blocks 3, 6, 9 and 12 are reshaped to a 2D token grid and
    decoded by four stages of transposed convolutions, with the reshaped raw
    input joined in at the last stage.

    While building, the function prints the shape of the reshaped block-3
    output and the computed upscale exponent.

    Args:
        config_dict: Mapping with the keys ``'image_size'``, ``'patch_size'``,
            ``'num_patches'``, ``'num_channels'``, ``'hidden_dim'`` and
            ``'num_layers'`` (read here), plus ``'heads'``, ``'mlp_dim'`` and
            ``'dropout'`` (read by ``trans_encoder`` / ``mlp``).

    Returns:
        A Keras ``Model`` named ``'UNETR_2D'`` whose input has shape
        ``(num_patches, patch_size * patch_size * num_channels)`` and whose
        output is a single-channel sigmoid map (expected to be
        ``image_size x image_size`` for consistent configurations).

    Raises:
        ValueError: if ``num_layers`` is smaller than 12 (the deepest skip
            connection would not exist) or if ``patch_size`` is not a power of
            two (the decoder can only rescale the token grid by powers of two).
    """
    # Encoder blocks (1-based) whose outputs are used as skip connections.
    skip_connection_index = (3, 6, 9, 12)
    if config_dict['num_layers'] < max(skip_connection_index):
        raise ValueError(
            "config_dict['num_layers'] must be at least "
            f"{max(skip_connection_index)} so that encoder blocks "
            f"{list(skip_connection_index)} exist; got {config_dict['num_layers']}"
        )

    total_upscale_factor = int(log2(config_dict['patch_size']))
    if 2 ** total_upscale_factor != config_dict['patch_size']:
        raise ValueError(
            "config_dict['patch_size'] must be a power of two; "
            f"got {config_dict['patch_size']}"
        )
    # The decoder below doubles the resolution four times (16x overall), so the
    # token grid must be brought to image_size / 16 per side before decoding.
    # `upscale` is the power-of-two exponent by which the grid has to grow
    # (positive) or shrink (negative) to get there; 0 means patch_size == 16.
    upscale = total_upscale_factor - 4

    ### Inputs: one row per flattened patch ###
    input_shape = (
        config_dict['num_patches'],
        config_dict['patch_size'] * config_dict['patch_size'] * config_dict['num_channels'],
    )
    inputs = L.Input(input_shape)

    ### Patch + Position Embedding ###
    patch_embedding = L.Dense(config_dict['hidden_dim'])(inputs)
    positions = tf.range(start = 0,limit = config_dict['num_patches'],delta = 1)
    position_embedding = L.Embedding(
        input_dim = config_dict['num_patches'], output_dim = config_dict['hidden_dim']
    )(positions)
    # The (num_patches, hidden_dim) position table is added to every sample.
    x = patch_embedding + position_embedding

    ### The Transformer Encoder ###
    skip_connections = []
    for i in range(1,config_dict['num_layers']+1,1):
        x = trans_encoder(x,config_dict)

        if i in skip_connection_index:
            skip_connections.append(x)

    ### CNN Decoder ###

    # Raw input reshaped to image layout for the last decoder stage.
    # NOTE(review): a plain reshape only restores the pixel layout if the
    # patches were flattened compatibly (trivially true for patch_size == 1) -
    # confirm against the patch-extraction code.
    z0 = L.Reshape(
        (config_dict['image_size'],config_dict['image_size'],config_dict['num_channels'])
    )(inputs)

    # Token sequences -> (grid, grid, hidden_dim) feature maps.
    shape = (
        config_dict['image_size'] // config_dict['patch_size'],
        config_dict['image_size'] // config_dict['patch_size'],
        config_dict['hidden_dim']
    )
    z3, z6, z9, z12 = [L.Reshape(shape)(z) for z in skip_connections]
    print('z3 shape', z3.shape)

    ## Additional layers for managing different patch sizes
    print('upscale factor',upscale)
    if upscale > 0:  ## patch_size larger than 16 (32, 64, ...): grid is too coarse
        stride = 2 ** upscale
        z3, z6, z9, z12 = [
            deconv(z, z.shape[-1], strides = stride) for z in (z3, z6, z9, z12)
        ]
    elif upscale < 0:  ## patch_size smaller than 16: grid is too fine
        p = 2 ** abs(upscale)
        z3, z6, z9, z12 = [L.MaxPool2D((p,p))(z) for z in (z3, z6, z9, z12)]

    ## Decoder 1
    x = deconv(z12,128)

    s = deconv(z9,128)
    s = conv(s,128)

    x = L.Concatenate()([x,s])

    x = conv(x,128)
    x = conv(x,128)

    ## Decoder 2
    x = deconv(x,64)

    s = conv(z6,64)
    s = deconv(s,64)
    s = deconv(s,64)
    s = conv(s,64)

    x = L.Concatenate()([x,s])

    x = conv(x,64)
    x = conv(x,64)

    ## Decoder 3
    x = deconv(x,32)

    s = deconv(z3,32)
    s = conv(s,32)
    s = deconv(s,32)
    s = conv(s,32)
    s = deconv(s,32)
    s = conv(s,32)

    x = L.Concatenate()([x,s])

    x = conv(x,32)
    x = conv(x,32)

    ## Decoder 4
    x = deconv(x,16)

    s = conv(z0,16)
    s = conv(s,16)

    x = L.Concatenate()([x,s])
    x = conv(x,16)
    x = conv(x,16)

    ## Output: 1x1 convolution to a single sigmoid channel
    outputs = L.Conv2D(1,kernel_size = 1, padding= 'same', activation = 'sigmoid')(x)

    return Model(inputs,outputs, name = 'UNETR_2D')

In [12]:
if __name__=="__main__":
    # Hyper-parameters read by Trans_Unet_2D, trans_encoder and mlp.
    config_dict = {
        'num_layers': 12,    # number of transformer encoder blocks
        'image_size': 512,   # images are image_size x image_size pixels
        'hidden_dim': 64,    # token embedding width
        'mlp_dim': 128,      # width of the first Dense layer in each encoder MLP
        'heads': 6,          # attention heads per encoder block
        'dropout': 0.1,      # dropout rate inside the encoder MLP
        'patch_size': 1,     # side length of each square patch, in pixels
    }
    # NOTE(review): with patch_size = 1 the sequence has 512 * 512 = 262,144
    # tokens; full self-attention over that many tokens is unlikely to fit in
    # memory for a real forward pass - consider a larger patch size (e.g. 16).

    # Patches tile the image as an (image_size // patch_size)-per-side grid, so
    # the patch count is that grid side squared. (image_size**2 // patch_size
    # would only be correct for patch_size == 1.)
    config_dict['num_patches'] = (config_dict['image_size'] // config_dict['patch_size']) ** 2
    config_dict['num_channels'] = 3
    

In [36]:
 # Bind the built model to a name so that later cells which use `model`
 # (e.g. model.summary()) work after Restart Kernel -> Run All.
 model = Trans_Unet_2D(config_dict)

Inputs shape (None, 262144, 3)
Patch <KerasTensor shape=(None, 262144, 64), dtype=float32, sparse=False, name=keras_tensor_1307>
Positions tf.Tensor([     0      1      2 ... 262141 262142 262143], shape=(262144,), dtype=int32)
Position Embeddings tf.Tensor(
[[-0.04750292 -0.01878229 -0.04176264 ... -0.00545417 -0.0211094
  -0.01040163]
 [-0.03994137 -0.0084803  -0.01892308 ... -0.01790547  0.03327367
  -0.00419811]
 [-0.02413393 -0.00270544 -0.01596689 ...  0.04486689 -0.01343553
   0.00218304]
 ...
 [-0.03459259  0.00017107 -0.02176112 ...  0.00421695 -0.03048208
  -0.01478223]
 [ 0.04458486 -0.04883571  0.03633959 ... -0.01017683  0.04318893
  -0.02244467]
 [-0.00593229 -0.01920514 -0.01825304 ... -0.03068141  0.04130179
  -0.00633845]], shape=(262144, 64), dtype=float32)
patch + position <KerasTensor shape=(None, 262144, 64), dtype=float32, sparse=False, name=keras_tensor_1308>
z3 shape (None, 512, 512, 64)
upscale factor -4


<Functional name=UNETR_2D, built=True>

In [26]:
# Show the summary of `model` (presumably the Keras model returned by
# Trans_Unet_2D). Requires `model` to have been bound by an earlier cell;
# on a fresh kernel this raises NameError otherwise.
model.summary()