In [16]:
# NOTE(review): `Concatenate` from typing_extensions appears unused in this
# cell (the model concatenates with keras' `L.Concatenate`); likely an editor
# auto-import — confirm before removing.
from typing_extensions import Concatenate
import os
# Reduce TensorFlow's C++ log output. This is set before `import tensorflow`
# below so the value is already in the environment when TensorFlow loads.
os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"

import tensorflow as tf
# `L` is the Keras layers namespace used by every model-building function below.
from tensorflow.keras import layers as L

def inverted_residual_block(inputs, num_of_filters, strides=1, expansion_ratio=1):
    """Inverted residual block: 1x1 expansion -> 3x3 depthwise -> 1x1 projection.

    Args:
        inputs: 4-D Keras tensor; its channel count is read from
            ``inputs.shape[-1]`` (channels-last layout assumed, the Keras default).
        num_of_filters: Number of output channels of the final 1x1 projection.
        strides: Stride of the 3x3 depthwise convolution.
        expansion_ratio: Multiplier applied to the input channel count for the
            1x1 expansion convolution.

    Returns:
        The projected tensor, with ``inputs`` added to it when the block keeps
        both the spatial size (``strides == 1``) and the channel count.
    """
    in_channels = inputs.shape[-1]

    # Point-wise (1x1) expansion convolution.
    x = L.Conv2D(
        filters=expansion_ratio * in_channels,
        kernel_size=1,
        padding="same",
        use_bias=False,
    )(inputs)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    # Depth-wise 3x3 convolution; any spatial downsampling happens here.
    x = L.DepthwiseConv2D(
        kernel_size=3,
        strides=strides,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    # Point-wise (1x1) projection; deliberately no activation after this BN
    # (linear bottleneck).
    x = L.Conv2D(
        filters=num_of_filters,
        kernel_size=1,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)

    # Residual connection. With "same" padding and stride 1 the spatial size is
    # unchanged, so only the channel count has to match. Comparing channel
    # counts (plain ints) instead of whole shapes avoids depending on how the
    # unknown batch dimension (None) compares across TF/Keras versions.
    if strides == 1 and in_channels == num_of_filters:
        return L.Add()([inputs, x])
    return x

def mlp(x, mlp_d, d, dropout_rate=0.1):
    """Two-layer feed-forward network used inside the transformer encoder.

    Expands the features to ``mlp_d`` units with a swish activation, then
    projects back to ``d`` units with no activation. A Dropout layer with
    rate ``dropout_rate`` follows each Dense layer.
    """
    for units, activation in ((mlp_d, "swish"), (d, None)):
        x = L.Dense(units, activation=activation)(x)
        x = L.Dropout(dropout_rate)(x)
    return x


def transformer_encoder(x, num_heads, d, mlp_d):
    """Pre-norm transformer encoder layer.

    Two residual sub-layers: LayerNorm -> multi-head self-attention, then
    LayerNorm -> MLP (``mlp`` with hidden width ``mlp_d`` and output width
    ``d``). ``d`` is also the per-head ``key_dim`` of the attention.
    """
    # Self-attention sub-layer: the normalized tensor is both query and value.
    normed = L.LayerNormalization()(x)
    attended = L.MultiHeadAttention(num_heads=num_heads, key_dim=d)(normed, normed)
    x = L.Add()([attended, x])

    # Feed-forward sub-layer.
    normed = L.LayerNormalization()(x)
    x = L.Add()([mlp(normed, mlp_d, d), x])
    return x


def mobile_vit_block(inputs, num_filters, dimension, patch_size=2, num_layers=1, num_heads=1):
    """MobileViT block: local convolutions, a transformer stack, then fusion.

    Args:
        inputs: 4-D Keras tensor of shape (B, H, W, C); H and W must be known
            statically.
        num_filters: Output channels of the final 3x3 fusion convolution.
        dimension: Channel width the transformer operates in.
        patch_size: Patch side length; H * W must be divisible by
            ``patch_size ** 2``.
        num_layers: Number of stacked transformer encoder layers.
        num_heads: Attention heads per encoder layer (default 1, the previous
            hard-coded value).

    Returns:
        Tensor of shape (B, H, W, num_filters).

    Raises:
        ValueError: if H or W is unknown, or H * W is not divisible by
            ``patch_size ** 2``.
    """
    _, H, W, C = inputs.shape

    # Validate up front: unknown spatial dims would otherwise fail with a
    # TypeError on None arithmetic, and a non-divisible H*W would only fail
    # later inside Reshape with a hard-to-read size-mismatch error.
    P = patch_size * patch_size
    if H is None or W is None:
        raise ValueError(
            f"mobile_vit_block needs static height and width, got shape {inputs.shape}"
        )
    if (H * W) % P != 0:
        raise ValueError(
            f"mobile_vit_block: H*W={H * W} is not divisible by patch_size**2={P}"
        )
    N = (H * W) // P

    # Local representation: 3x3 convolution keeping C channels.
    x = L.Conv2D(
        filters=C,
        kernel_size=3,
        padding="same",
        use_bias=False,
    )(inputs)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    # 1x1 convolution projecting to the transformer width.
    x = L.Conv2D(
        filters=dimension,
        kernel_size=1,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    # (B, H, W, D) -> (B, P, N, D). NOTE: this is a plain row-major reshape,
    # not a patch unfold: the P axis holds runs of consecutive raster-order
    # pixels rather than the pixel position inside each patch. The reshape back
    # below is its exact inverse, so every pixel returns to where it came from.
    # NOTE(review): keras MultiHeadAttention's default attention_axes=None
    # attends over all non-batch/non-feature axes, so on this rank-4 tensor
    # every token attends to all P*N tokens (global attention), unlike the
    # MobileViT paper's per-position attention across N. With global attention
    # the token ordering above should not materially change the result —
    # confirm this is the intended behavior.
    x = L.Reshape((P, N, dimension))(x)

    # Transformer encoder stack.
    for _ in range(num_layers):
        x = transformer_encoder(x, num_heads, dimension, dimension * 2)

    # Back to the spatial layout (B, H, W, D).
    x = L.Reshape((H, W, dimension))(x)

    # 1x1 convolution projecting back to the input channel count C.
    x = L.Conv2D(
        filters=C,
        kernel_size=1,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    # Fuse global features with the block input along the channel axis.
    x = L.Concatenate()([x, inputs])

    # 3x3 fusion convolution to the requested output width.
    x = L.Conv2D(
        filters=num_filters,
        kernel_size=3,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)
    return x

def MobileViT(input_shape, num_channels, d, expansion_ratio, num_layers=(2, 4, 3),
              num_classes=1000, patch_size=2):
    """Build a MobileViT image classifier as a Keras functional model.

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        num_channels: Sequence of at least 12 channel counts, consumed in order
            by the stem, the four stages and the final 1x1 convolution.
        d: Sequence of at least 3 transformer widths, one per MobileViT block.
        expansion_ratio: Expansion factor passed to every inverted residual block.
        num_layers: Sequence of at least 3 transformer depths, one per
            MobileViT block. The default is a tuple so it cannot be mutated
            and shared across calls.
        num_classes: Number of output units of the softmax classifier.
        patch_size: Patch side length forwarded to every MobileViT block.

    Returns:
        An uncompiled ``tf.keras.models.Model`` mapping images to class
        probabilities.

    Raises:
        ValueError: if ``num_channels``, ``d`` or ``num_layers`` is too short.
    """
    # Fail early with a clear message instead of an IndexError mid-build.
    if len(num_channels) < 12:
        raise ValueError(f"num_channels needs at least 12 entries, got {len(num_channels)}")
    if len(d) < 3:
        raise ValueError(f"d needs at least 3 entries, got {len(d)}")
    if len(num_layers) < 3:
        raise ValueError(f"num_layers needs at least 3 entries, got {len(num_layers)}")

    ## Input layer
    inputs = L.Input(input_shape)

    ## Stem: stride-2 3x3 conv, then a stride-1 inverted residual block.
    x = L.Conv2D(
        filters=num_channels[0],
        kernel_size=3,
        strides=2,
        padding="same",
        use_bias=False,
    )(inputs)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)
    x = inverted_residual_block(x, num_channels[1], strides=1, expansion_ratio=expansion_ratio)

    ## Stage 1: downsample, then two stride-1 inverted residual blocks.
    x = inverted_residual_block(x, num_channels[2], strides=2, expansion_ratio=expansion_ratio)
    x = inverted_residual_block(x, num_channels[3], strides=1, expansion_ratio=expansion_ratio)
    x = inverted_residual_block(x, num_channels[4], strides=1, expansion_ratio=expansion_ratio)

    ## Stage 2: downsample, then a MobileViT block.
    x = inverted_residual_block(x, num_channels[5], strides=2, expansion_ratio=expansion_ratio)
    x = mobile_vit_block(x, num_channels[6], d[0], patch_size=patch_size, num_layers=num_layers[0])

    ## Stage 3: downsample, then a MobileViT block.
    x = inverted_residual_block(x, num_channels[7], strides=2, expansion_ratio=expansion_ratio)
    x = mobile_vit_block(x, num_channels[8], d[1], patch_size=patch_size, num_layers=num_layers[1])

    ## Stage 4: downsample, MobileViT block, then a 1x1 conv to the head width.
    x = inverted_residual_block(x, num_channels[9], strides=2, expansion_ratio=expansion_ratio)
    x = mobile_vit_block(x, num_channels[10], d[2], patch_size=patch_size, num_layers=num_layers[2])
    x = L.Conv2D(
        filters=num_channels[11],
        kernel_size=1,
        padding="same",
        use_bias=False,
    )(x)
    x = L.BatchNormalization()(x)
    x = L.Activation("swish")(x)

    ## Classifier: global average pooling followed by a softmax Dense layer.
    x = L.GlobalAveragePooling2D()(x)
    outputs = L.Dense(num_classes, activation="softmax")(x)

    model = tf.keras.models.Model(inputs, outputs)
    return model
def MobileViT_s(input_shape, num_classes):
    """Build the MobileViT-S preset.

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        num_classes: Number of output classes.
    """
    return MobileViT(
        input_shape=input_shape,
        num_channels=[16, 32, 64, 64, 64, 96, 144, 128, 192, 160, 240, 640],
        d=[144, 192, 240],
        expansion_ratio=4,
        num_classes=num_classes,
    )

def MobileViT_xs(input_shape, num_classes):
    """Build the MobileViT-XS preset.

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        num_classes: Number of output classes.
    """
    return MobileViT(
        input_shape=input_shape,
        num_channels=[16, 32, 48, 48, 48, 64, 96, 80, 120, 96, 144, 384],
        d=[96, 120, 144],
        expansion_ratio=4,
        num_classes=num_classes,
    )

def MobileViT_xxs(input_shape, num_classes):
    """Build the MobileViT-XXS preset (expansion ratio 2 instead of 4).

    Args:
        input_shape: Shape of one input image, e.g. ``(256, 256, 3)``.
        num_classes: Number of output classes.
    """
    return MobileViT(
        input_shape=input_shape,
        num_channels=[16, 32, 24, 24, 24, 48, 64, 64, 80, 80, 96, 320],
        d=[64, 80, 96],
        expansion_ratio=2,
        num_classes=num_classes,
    )

if __name__ == "__main__":
    # Smoke run: build the S preset for 256x256 3-channel input with 1000
    # classes and print its layer table.
    model = MobileViT_s(input_shape=(256, 256, 3), num_classes=1000)
    model.summary()



Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_14 (InputLayer)       [(None, 256, 256, 3)]        0         []                            
                                                                                                  
 conv2d_162 (Conv2D)         (None, 128, 128, 16)         432       ['input_14[0][0]']            
                                                                                                  
 batch_normalization_206 (B  (None, 128, 128, 16)         64        ['conv2d_162[0][0]']          
 atchNormalization)                                                                               
                                                                                                  
 activation_159 (Activation  (None, 128, 128, 16)         0         ['batch_normalization_20