<a href="https://colab.research.google.com/github/VishalMaurya/MLP_Mixer_Model/blob/main/MLP_Mixer_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

In [2]:
# MLP function
def mlp(x, hidden_dims ):
    y = tf.keras.layers.Dense( hidden_dims )( x )
    y = tf.nn.gelu( y )
    y = tf.keras.layers.Dense( x.shape[-1] )( y )

    # extra experiment
    y = tf.keras.layers.Dropout(0.3)( y )

    return y 

In [3]:
# token mixing function
def token_mixing( x, token_mixing_mlp_dims):
    x = tf.keras.layers.LayerNormalization( epsilon=1e-6 )( x )
    x = tf.keras.layers.Permute( dims=[2,1])( x )

    x = mlp(x, token_mixing_mlp_dims)

    return x

def channel_mixing( x, channel_mixing_mlp_dims):
    x = tf.keras.layers.LayerNormalization( epsilon=1e-6 )( x )
    x = mlp( x, channel_mixing_mlp_dims)
    return x

In [4]:
# Mixer function, it is consist of token-mixing and channel-mixing MLPS

def mixer( x, token_mixing_mlp_dims, channel_mixing_mlp_dims ):
    """
    shape of x: ( batch_size, num_patches, channels)
    """

    # token_mixing MLP
    token_mixing_out = token_mixing( x, token_mixing_mlp_dims)
    # Transpose
    token_mixing_out = tf.keras.layers.Permute(dims = [ 2,1] )( token_mixing_out)
    # Skip connection
    token_mixing_out = tf.keras.layers.Add()( [ x, token_mixing_out] ) 
    # channel_mixing MLP
    channel_mixing_out = channel_mixing( token_mixing_out, channel_mixing_mlp_dims)
    # Skip connection
    channel_mixing_out = tf.keras.layers.Add()( [ channel_mixing_out, token_mixing_out] )

    return channel_mixing_out



In [6]:
# Creating input constants and model

input_image_shape = ( 32, 32, 3)
hidden_dims = 128
num_classes = 10
patch_size = 9
num_mixer_layers = 4
token_mixing_mlp_dims = 64
channel_mixing_mlp_dims = 128

# Input layer
inputs = tf.keras.layers.Input( shape=input_image_shape)
# Conv2D to extract patches
patches = tf.keras.layers.Conv2D( hidden_dims, kernel_size=patch_size, strides=patch_size)( inputs)
# Resizing the patches
patches_reshape = tf.keras.layers.Reshape( (patches.shape[1]*patches.shape[2], patches.shape[3]) )( patches)

x = patches_reshape
# create Mixer layers
for _ in range(num_mixer_layers):
    x = mixer(x, token_mixing_mlp_dims, channel_mixing_mlp_dims)

# classifier head
x = tf.keras.layers.LayerNormalization( epsilon=1e-6)( x )
x = tf.keras.layers.GlobalAveragePooling1D()( x )
outputs = tf.keras.layers.Dense( num_classes, activation='softmax')(x)

# create model object
model = tf.keras.models.Model( inputs, outputs)
model.summary()

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3, 3, 128)    31232       input_2[0][0]                    
__________________________________________________________________________________________________
reshape (Reshape)               (None, 9, 128)       0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
layer_normalization (LayerNorma (None, 9, 128)       256         reshape[0][0]                    
______________________________________________________________________________________________