# Vision Transformer in TensorFlow (unfinished)

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras



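A TensorFlow/Keras port based on the PyTorch reference implementation at https://github.com/lucidrains/vit-pytorch/blob/main/vit_pytorch/vit.py. Unlike the reference, this version has no class token: the classifier head reads the flattened sequence of all patch embeddings.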

In [2]:
from einops import rearrange, repeat
from einops.layers.tensorflow import Rearrange

In [3]:
from tensorflow.keras import layers as ly
import matplotlib.pyplot as plt

In [4]:
# Preprocessing pipeline embedded at the front of the model. Despite its name it
# applies no random augmentation: it normalizes the input, then resizes every
# image to aug_image_size x aug_image_size.
aug_image_size = 64
data_augmentation = keras.Sequential(
    [ly.Normalization(), ly.Resizing(aug_image_size, aug_image_size)],
    name="data_augmentation",
)
# The Normalization layer only uses training-set statistics once `adapt` has
# been called on the training data:
# data_augmentation.layers[0].adapt(x_train)

Metal device set to: Apple M1


In [5]:
def feed_forward(dim, mlp_dim, dropout=0.1):
    """Build the position-wise MLP used inside each transformer block.

    The MLP expands to `mlp_dim` units with a GELU activation and projects
    back to `dim`. Dropout follows each of the two dense layers.

    Args:
        dim: Output width (the transformer's embedding size).
        mlp_dim: Hidden width of the MLP.
        dropout: Dropout rate applied after each dense layer.

    Returns:
        A `keras.Sequential` containing Dense -> Dropout -> Dense -> Dropout.
    """
    mlp = keras.Sequential()
    mlp.add(ly.Dense(mlp_dim, activation=tf.nn.gelu))  # expansion
    mlp.add(ly.Dropout(dropout))
    mlp.add(ly.Dense(dim))                              # projection back to dim
    mlp.add(ly.Dropout(dropout))
    return mlp

In [6]:
class Identity(tf.keras.layers.Layer):
    """Pass-through layer whose `call` returns its input unchanged.

    Args:
        name: Optional layer name. When omitted, Keras generates a unique one.
        **kwargs: Any other `keras.layers.Layer` constructor options
            (e.g. `dtype`, `trainable`), forwarded to the base class.
    """

    def __init__(self, name=None, **kwargs):
        # `name` used to be mandatory. Defaulting it to None keeps existing
        # `Identity(name=...)` / `Identity("...")` calls working while allowing
        # `Identity()` like any other Keras layer.
        super().__init__(name=name, **kwargs)

    def call(self, x):
        return x

In [7]:
class TransformerBlock(ly.Layer):
    """Pre-norm transformer encoder block.

    Computes `y = x + MHA(LN(x))` followed by `out = y + MLP(LN(y))`, with
    self-attention (query, value and key are all the normalized input).

    Args:
        dim: Embedding width of the token sequence (input and output width).
        mlp_dim: Hidden width of the feed-forward sub-layer.
        heads: Number of attention heads.
        dropout: Dropout rate for the attention layer and the feed-forward MLP.
        dim_head: Per-head query/key size, passed to `MultiHeadAttention` as
            `key_dim`. Defaults to `dim`, which matches the previous
            hard-coded behaviour. The vit-pytorch reference instead uses a
            separate per-head size (`dim_head`).
        **kwargs: Forwarded to `keras.layers.Layer` (e.g. `name`).
    """

    def __init__(self, dim, mlp_dim, heads = 8, dropout = 0.1, dim_head=None, **kwargs):
        super().__init__(**kwargs)
        key_dim = dim if dim_head is None else dim_head
        self.lnorm1 = ly.LayerNormalization(epsilon=1e-6)
        # MultiHeadAttention projects its output back to the query's last
        # dimension (dim), so `key_dim` does not change the block's output width.
        self.attn = ly.MultiHeadAttention(num_heads=heads, key_dim=key_dim, dropout=dropout)
        self.lnorm2 = ly.LayerNormalization(epsilon=1e-6)
        self.ff = feed_forward(dim, mlp_dim, dropout=dropout)

    def call(self, x):
        # Attention sub-layer with residual connection.
        normed = self.lnorm1(x)
        x = x + self.attn(normed, normed, normed)
        # Feed-forward sub-layer with residual connection.
        return x + self.ff(self.lnorm2(x))

In [8]:
class PatchEncoder(ly.Layer):
    """Cut images into square patches, embed them, and add learned positions.

    Non-overlapping `patch_size` x `patch_size` patches are extracted, each is
    flattened and linearly projected to `projection_dim`, and a learned
    embedding indexed by patch position is added.

    Args:
        patch_size: Side length of each square patch (also the stride).
        num_patches: Number of patches per image. It should equal the number
            that `extract_patches` yields for the incoming image size, since
            the position embedding is added across that axis.
        projection_dim: Width of the output embedding for each patch.
    """

    def __init__(self, patch_size, num_patches, projection_dim):
        super().__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.projection = ly.Dense(units=projection_dim)
        self.position_embedding = ly.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, images):
        # Identical window and stride => non-overlapping patches.
        window = [1, self.patch_size, self.patch_size, 1]
        grid = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the spatial patch grid into a single sequence axis.
        sequence = tf.reshape(grid, [tf.shape(images)[0], -1, grid.shape[-1]])

        position_ids = tf.range(self.num_patches)
        return self.projection(sequence) + self.position_embedding(position_ids)

In [9]:
def VisionTransformer(image_size=224, patch_size=16, num_classes=1000,
                     dim=192, depth=12, heads=8, mlp_dim=768, channels = 3,
                      dim_head = 64, dropout = 0.1, emb_dropout = 0.1):
    """Build a ViT-style image classifier as a functional `keras.Model`.

    Pipeline: the module-level `data_augmentation` preprocessing ->
    `PatchEncoder` (patch extraction, projection, learned position embedding)
    -> Dense to `dim` -> embedding dropout -> `depth` transformer blocks ->
    LayerNorm -> flatten all patch tokens -> Dense logits. There is no class
    token.

    Args:
        image_size: Height and width of the square input images.
        patch_size: Side length of the square patches.
        num_classes: Number of output logits.
        dim: Embedding width inside the transformer.
        depth: Number of transformer blocks.
        heads: Number of attention heads in every block.
        mlp_dim: Hidden width of each block's feed-forward MLP.
        channels: Number of channels of the input images.
        dim_head: Accepted for parity with vit-pytorch but currently unused.
            The per-head attention size is decided inside `TransformerBlock`.
        dropout: Dropout rate inside the transformer blocks.
        emb_dropout: Dropout rate applied to the patch embeddings.

    Returns:
        A `keras.Model` mapping (batch, image_size, image_size, channels)
        inputs to (batch, num_classes) un-normalized logits.
    """
    inputs = keras.Input(shape=(image_size, image_size, channels))
    # The patch grid is derived from the *preprocessed* spatial size, which can
    # differ from `image_size` because `data_augmentation` resizes its input.
    augmented = data_augmentation(inputs)
    im_h, im_w = augmented.shape[1:3]
    num_patches = (im_h // patch_size) * (im_w // patch_size)
    patch_dim = channels * patch_size * patch_size

    # NOTE(review): PatchEncoder already projects each patch to `patch_dim`, and
    # the Dense below projects again to `dim`, with no nonlinearity in between.
    # Kept as-is to preserve the existing architecture and parameter layout.
    patch_encoded = PatchEncoder(patch_size, num_patches, patch_dim)(augmented)
    patch_encoded = ly.Dense(dim)(patch_encoded)
    patch_encoded = ly.Dropout(emb_dropout)(patch_encoded)

    for _ in range(depth):
        # Forward `heads`; otherwise every block falls back to its own default.
        patch_encoded = TransformerBlock(dim, mlp_dim, heads=heads, dropout=dropout)(patch_encoded)

    representation = ly.LayerNormalization(epsilon=1e-6)(patch_encoded)
    representation = ly.Flatten()(representation)

    out = ly.Dense(num_classes)(representation)
    return keras.Model(inputs=inputs, outputs=out, name="vision_transformer")

In [11]:
# Model hyper-parameters.
image_size = 224        # input images are image_size x image_size
channels = 3            # RGB
patch_size = 16         # side length of each square patch
embed_dim = 192         # transformer embedding width
mlp_dim = 4 * embed_dim # feed-forward hidden width (4x the embedding)
depth = 12              # number of transformer blocks
n_heads = 3             # attention heads per block
dropout = emb_dropout = 0.1

In [14]:
# Build the model from the configuration cell. Every hyper-parameter defined
# there is passed explicitly, so none (e.g. n_heads) silently falls back to the
# function default.
model = VisionTransformer(image_size=image_size,
                          patch_size=patch_size,
                          num_classes=1000,
                          dim=embed_dim,
                          depth=depth,
                          heads=n_heads,
                          mlp_dim=mlp_dim,
                          channels=channels,
                          dropout=dropout,
                          emb_dropout=emb_dropout)

In [15]:
# Smoke test: a single seeded random image, shaped from the config cell, should
# produce one row of 1000 logits.
inp = np.random.default_rng(0).random((1, image_size, image_size, channels), dtype=np.float32)
logits_shape = model(inp).shape
assert logits_shape.as_list() == [1, 1000], logits_shape
logits_shape

TensorShape([1, 1000])

In [16]:
# Layer-by-layer overview: output shapes and parameter counts.
model.summary()

Model: "vision_transformer"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 data_augmentation (Sequenti  (None, 64, 64, 3)        7         
 al)                                                             
                                                                 
 patch_encoder (PatchEncoder  (None, 16, 768)          602880    
 )                                                               
                                                                 
 dense_1 (Dense)             (None, 16, 192)           147648    
                                                                 
 dropout (Dropout)           (None, 16, 192)           0         
                                                                 
 transformer_block (Transfor  (None, 16, 192)   