In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Classifier output width: a single unit, i.e. one raw score per image
# (presumably a binary task — confirm against the training labels).
# Model input is (height, width, channels); keep in sync with image_size below.
num_classes, input_shape = 1, (130, 130, 3)

In [None]:
# --- Optimisation ---
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 256
num_epochs = 100  # Full training run; lower this (e.g. to 10) for a quick smoke test.

# --- Patch geometry ---
image_size = 130  # Side length of the square input images; must match input_shape.
patch_size = 6  # Side length of the square patches extracted from each image.
# Whole patch_size x patch_size tiles per image. Floor division, so the
# 130 % 6 = 4 pixel remainder on each axis does not form a patch: 21 * 21 = 441.
num_patches = (image_size // patch_size) ** 2

# --- Transformer encoder ---
projection_dim = 64  # Width of each patch embedding.
num_heads = 4
transformer_units = [
    projection_dim * 2,
    projection_dim,
]  # Hidden sizes of the MLP inside each transformer block.
transformer_layers = 8

# --- Classification head ---
mlp_head_units = [
    2048,
    1024,
]  # Sizes of the dense layers of the final classifier.

In [None]:
# Augmentation pipeline that create_vit() applies to the raw model input.
# The Random* layers perturb geometry and photometry; Rescaling comes last and
# multiplies pixel values by 1/255 (assumes raw 0-255 pixel input — confirm).
data_augmentation = tf.keras.Sequential(
    layers=[
        tf.keras.layers.RandomFlip(mode='horizontal_and_vertical'),
        tf.keras.layers.RandomRotation(factor=0.2),
        tf.keras.layers.RandomBrightness(factor=0.2),
        tf.keras.layers.RandomContrast(factor=0.2),
        tf.keras.layers.Rescaling(scale=1 / 255),
    ]
)

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Append a stack of Dense(GELU) + Dropout layers to a functional-API tensor.

    Args:
        x: Symbolic Keras tensor to build on.
        hidden_units: Iterable of layer widths; one Dense + Dropout pair is
            added per entry, in order.
        dropout_rate: Dropout rate applied after every Dense layer.

    Returns:
        The output tensor of the last Dropout layer (or ``x`` unchanged if
        ``hidden_units`` is empty).
    """
    for units in hidden_units:
        # `keras` is not imported in this notebook; reach GELU through `tf.keras`.
        x = tf.keras.layers.Dense(units, activation=tf.keras.activations.gelu)(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)
    return x

In [None]:
class Patches(tf.keras.layers.Layer):
    """Split a batch of images into flattened, non-overlapping square patches.

    Args:
        patch_size: Side length (in pixels) of each square patch.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class so ``from_config(get_config())``
            can rebuild the layer.

    Call input: a 4-D image batch indexed as (batch, height, width, channels).
    Call output: shape ``(batch, (height // patch_size) * (width // patch_size),
    patch_size * patch_size * channels)``. Patches are taken with stride equal
    to ``patch_size`` and ``padding="VALID"``, so bottom/right border pixels
    that do not fill a whole patch are dropped.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        # sizes == strides -> non-overlapping tiles; rates of 1 -> no dilation.
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # extract_patches yields (batch, rows, cols, patch_size*patch_size*channels).
        # Use the static last dimension when available so downstream Dense layers
        # can be built; fall back to the dynamic value if channels are unknown.
        patch_dims = patches.shape[-1]
        if patch_dims is None:
            patch_dims = tf.shape(patches)[-1]
        # Collapse the (rows, cols) grid into a single patch axis (row-major).
        return tf.reshape(patches, [batch_size, -1, patch_dims])

    def get_config(self):
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
    """Project each patch to ``projection_dim`` and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; sets the size of the
            position-embedding table and the positions added in ``call``.
        projection_dim: Output width of the linear patch projection and of the
            position embedding.
        **kwargs: Standard Keras layer arguments, forwarded to the base class.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        # `layers` is not imported in this notebook; use the tf.keras namespace.
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        # assumes `patch` is (batch, num_patches, features), e.g. Patches' output.
        # Positions 0..num_patches-1 with a leading axis of size 1, so the
        # (1, num_patches, projection_dim) embedding broadcasts over the batch.
        # tf.range works on both Keras 2 and Keras 3 (tf.keras.ops is Keras 3 only).
        positions = tf.expand_dims(
            tf.range(start=0, limit=self.num_patches, delta=1), axis=0
        )
        projected_patches = self.projection(patch)
        return projected_patches + self.position_embedding(positions)

    def get_config(self):
        config = super().get_config()
        # Both constructor arguments are required to rebuild the layer.
        config.update(
            {"num_patches": self.num_patches, "projection_dim": self.projection_dim}
        )
        return config

In [None]:
def create_vit():
    """Build the Vision Transformer classifier as a Keras functional model.

    Reads the notebook-level configuration (``input_shape``, ``patch_size``,
    ``num_patches``, ``projection_dim``, ``num_heads``, ``transformer_units``,
    ``transformer_layers``, ``mlp_head_units``, ``num_classes``) and uses the
    ``data_augmentation``, ``Patches``, ``PatchEncoder`` and ``mlp`` objects
    defined in earlier cells.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images of shape ``input_shape``
        to ``num_classes`` outputs. The final Dense layer has no activation, so
        the outputs are raw logits (presumably to be trained with a
        ``from_logits=True`` loss — confirm).

    Raises:
        ValueError: if ``input_shape`` and ``patch_size`` produce a different
            patch count than ``num_patches`` (the position embedding in
            PatchEncoder would not line up with the patches).
    """
    height, width = input_shape[0], input_shape[1]
    patches_from_shape = (height // patch_size) * (width // patch_size)
    if patches_from_shape != num_patches:
        raise ValueError(
            f"input_shape {input_shape} with patch_size={patch_size} yields "
            f"{patches_from_shape} patches, but num_patches={num_patches}."
        )

    # Use the configured input shape instead of a hard-coded duplicate.
    inputs = tf.keras.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    patches = Patches(patch_size)(augmented)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        # Layer norm, then multi-head self-attention (query and value are both x1).
        x1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        # Skip connection 1.
        x2 = tf.keras.layers.Add()([attention_output, encoded_patches])
        # Layer norm, then MLP.
        x3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        # Skip connection 2.
        encoded_patches = tf.keras.layers.Add()([x3, x2])

    representation = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = tf.keras.layers.Flatten()(representation)
    representation = tf.keras.layers.Dropout(0.5)(representation)
    # Classification head MLP.
    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    # Unactivated output layer -> logits.
    logits = tf.keras.layers.Dense(num_classes)(features)
    return tf.keras.Model(inputs=inputs, outputs=logits)

In [None]:
# Instantiate the (uncompiled) ViT from the configuration and layers defined above.
model = create_vit()