# Vision Transformer (ViT)

![ViT](media/vision_transformer/vit.png "ViT")

In [None]:
# Render matplotlib figures inline in the cell output.
%matplotlib inline
# Reload imported modules before each execution so that edits to the local
# `utils` package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from utils.transformer import TransformerEncoder, PatchClassEmbedding, Patches
from utils.visualize import plotPatches, plotHistory
from utils.tools import CustomSchedule
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import cv2
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

In [None]:
# set some paths
# Directory that receives the checkpoint weights written during training.
model_dir = Path('bin')
# Create it up front (idempotent) so saving never depends on whether the
# installed Keras version creates missing parent directories by itself.
model_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Restrict TensorFlow to the first GPU and let it grow its memory allocation
# on demand instead of reserving the whole device.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
        tf.config.experimental.set_memory_growth(gpus[0], True)
    except RuntimeError as err:
        # Device settings can only be changed before the runtime initialises
        # the GPUs; this happens when the cell is re-run in a live kernel.
        print(err)
else:
    # CPU-only machine: nothing to configure (indexing gpus[0] would fail).
    print("No GPU detected: running on CPU.")

# 1.0 Import the Dataset

## 1.1 Download a dataset (Cats_vs_Dogs)

In [None]:
# No `split` is requested, so `ds_train` is a mapping keyed by split name
# (build_dataset indexes it with 'train'); `ds_info` carries the metadata.
ds_train, ds_info = tfds.load(
    "cats_vs_dogs",
    with_info=True,
    as_supervised=True,  # yield (image, label) pairs
    shuffle_files=True,
)

In [None]:
# Display the dataset metadata object returned by tfds.load(with_info=True).
ds_info

In [None]:
# Names of the 'label' feature classes; their count sizes the classifier
# output layer in build_vit.
label_names = ds_info.features['label'].names
print(label_names)

In [None]:
# Number of examples in the 'train' split; used to pre-allocate the image
# and label arrays in build_dataset.
n_images = ds_info.splits['train'].num_examples
print(n_images)

# 2.0 Prepare the Dataset

In [None]:
# dataset configurations
input_size = (224, 224, 3)  # (height, width, channels) every image is resized to
patch_size = 16             # side length, in pixels, of one square patch
test_size = 0.2             # fraction of the samples held out by train_test_split
# Patch-grid rows x columns. Equal to the previous (H // p) ** 2 for square
# inputs, but stays correct if height and width ever differ.
# NOTE(review): assumes utils.transformer.Patches tiles without overlap or padding - confirm.
num_patches = (input_size[0] // patch_size) * (input_size[1] // patch_size)

In [None]:
def build_dataset(ds, n_images):
    """Load the 'train' split into memory as resized NumPy arrays.

    Every image is resized to the notebook-level `input_size`
    (height, width, channels) with OpenCV's default interpolation.

    Args:
        ds: mapping from split name to an iterable of (image, label) pairs
            whose image exposes `.numpy()`; only `ds['train']` is read.
        n_images: number of examples to pre-allocate room for; at most this
            many examples are read.

    Returns:
        Tuple `(X, y)`: `X` is a float32 array of shape
        (n, height, width, channels) and `y` a float32 array of shape (n,),
        where n is the number of examples actually read (<= n_images).
    """
    height, width, channels = input_size
    X = np.empty((n_images, height, width, channels), dtype="float32")
    y = np.empty((n_images,), dtype="float32")

    n_read = 0
    # Wrap the dataset itself and pass `total`, so the bar shows progress and ETA.
    for image, label in tqdm(ds['train'], total=n_images):
        if n_read == n_images:
            # Never write past the pre-allocated buffers.
            break
        # cv2.resize expects the target size as (width, height).
        X[n_read] = cv2.resize(image.numpy(), (width, height))
        y[n_read] = label
        n_read += 1

    # Drop rows that were never filled: np.empty leaves them uninitialised.
    return X[:n_read], y[:n_read]

In [None]:
# Materialise the whole split in memory: n_images x 224 x 224 x 3 float32
# values for the images plus one float32 label per image.
X, y = build_dataset(ds_train, n_images)

## 2.1 Visualize patch creation

In [None]:
# Visual check from utils.visualize; presumably draws the patch grid for
# the first 2 images - confirm against the helper.
plotPatches(X, n_images=2, patch_size=patch_size)

## 2.2 Split the dataset

In [None]:
# Hold out `test_size` of the samples; `stratify=y` keeps the class ratio
# identical in both parts and the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=test_size,
    random_state=42,
    stratify=y,
)

## 2.3 Build a pre-process pipeline with Keras preprocessing layers

In [None]:
# Standardisation layer (per channel with the default axis). Its statistics
# are estimated on the training images only, and it is referenced by name
# instead of by its position inside the Sequential model.
normalization = tf.keras.layers.experimental.preprocessing.Normalization()
normalization.adapt(X_train)

# The random augmentations (active in training mode only) run on raw pixel
# values and normalisation comes last: newer Keras releases clip the output of
# RandomContrast to [0, 255], which would wreck already-standardised inputs.
pre_process_pipeline = tf.keras.Sequential([
        tf.keras.layers.experimental.preprocessing.RandomFlip("horizontal_and_vertical"),
        tf.keras.layers.experimental.preprocessing.RandomContrast(0.2),
        tf.keras.layers.experimental.preprocessing.RandomRotation(factor=0.03),
        tf.keras.layers.experimental.preprocessing.RandomZoom(height_factor=0.3, width_factor=0.3),
        normalization,
    ], name="pre_process_pipeline")

# 3.0 Build the Vision Transformer (ViT)

In [None]:
# model configurations
# All of these except `mlp_head_size` are handed to TransformerEncoder;
# `d_model` and `mlp_head_size` are also read inside build_vit.
n_layers = 6
n_heads = 4
d_model = 128
d_ff = 2 * d_model
dropout = 0.1
activation = tf.nn.gelu
mlp_head_size = 512

In [None]:
def build_vit(transformer):
    """Assemble the ViT classifier around a given encoder.

    Stages: pre-processing -> patch extraction -> linear projection ->
    class/position embedding -> `transformer` -> first-token slice -> MLP head.

    Reads the notebook-level settings `input_size`, `patch_size`, `d_model`,
    `num_patches`, `mlp_head_size`, `activation`, `label_names` and the
    `pre_process_pipeline` model.

    Args:
        transformer: callable layer/model applied to the embedded sequence.

    Returns:
        A `tf.keras.models.Model` mapping images of shape `input_size` to
        un-normalised class scores (logits), one per entry of `label_names`.
    """
    # Input
    inputs = tf.keras.layers.Input(shape=input_size)

    # Data pre-processing pipeline (part of the model graph)
    x = pre_process_pipeline(inputs)

    # Patch creation
    x = Patches(patch_size)(x)

    # Linear projection of the flattened patches to the model width
    x = tf.keras.layers.Dense(d_model)(x)

    # Position embedding + extra learnable class embedding
    x = PatchClassEmbedding(d_model, num_patches)(x)

    # Transformer model
    x = transformer(x)

    # Keep only the token at sequence index 0
    # (assumes PatchClassEmbedding puts the class embedding first - confirm).
    x = tf.keras.layers.Lambda(lambda x: x[:, 0, :])(x)

    # MLP head. The hidden layer needs a non-linearity: two stacked Dense
    # layers without one are equivalent to a single linear map.
    x = tf.keras.layers.Dense(mlp_head_size, activation=activation)(x)
    # No activation on the output: the loss is configured with from_logits=True.
    outputs = tf.keras.layers.Dense(len(label_names))(x)

    return tf.keras.models.Model(inputs, outputs)

In [None]:
# Encoder instance built from the model configuration above; arguments are
# passed positionally in the order utils.transformer.TransformerEncoder expects.
transformer = TransformerEncoder(
    d_model,
    n_heads,
    d_ff,
    dropout,
    activation,
    n_layers,
)

In [None]:
# Wrap the encoder into the full classifier and print its layer summary.
vit_model = build_vit(transformer)
vit_model.summary()

# 4.0 Train the Network

In [None]:
# some training configurations
epochs = 100       # passes over the training data
batch_size = 128   # samples per gradient step
lr = 3e-4          # constant base learning rate

In [None]:
# Learning-rate schedule from utils.tools (presumably a warm-up schedule, given
# its `warmup_steps` argument - confirm). It gets its own name so the float
# `lr` from the configuration cell is not silently rebound to a schedule object.
lr_schedule = CustomSchedule(d_model, warmup_steps=20000.0)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)
vit_model.compile(
    optimizer=optimizer,
    # The model outputs raw logits and the labels are integer class ids.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # Named "accuracy" so the validation metric is reported as "val_accuracy".
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")])

In [None]:
# Checkpoint callback: writes weights only, and only when the monitored
# validation accuracy improves.
name_model = 'vision_transformer.h5'
checkpointer = tf.keras.callbacks.ModelCheckpoint(
    filepath=model_dir.joinpath(name_model),
    save_weights_only=True,
    save_best_only=True,
    monitor="val_accuracy",
)

In [None]:
# Validate (and pick the best checkpoint) on a slice of the training data so
# that X_test stays untouched until the final evaluation; selecting weights on
# the test split would make the reported test score optimistic.
# Keras takes the validation slice from the end of the arrays before shuffling;
# X_train was already shuffled by train_test_split.
val_split = 0.1
history = vit_model.fit(
    x=X_train,
    y=y_train,
    batch_size=batch_size,
    epochs=epochs,
    validation_split=val_split,
    callbacks=[checkpointer])

# 5.0 Test the Model

In [None]:
# load best weights
# (same path the ModelCheckpoint callback writes to during training)
vit_model.load_weights(model_dir.joinpath(name_model))

In [None]:
# Evaluate the model on the held-out split; reports the loss and the
# compiled accuracy metric.
vit_model.evaluate(X_test, y_test)